Commit 74f2a5c6 authored by Romain Courteaud

First tests

parent 16f75e94
This diff is collapsed.
import unittest
from urlchecker_db import LogDB
from urlchecker_status import logStatus
class UrlCheckerStatusTestCase(unittest.TestCase):
    """Exercise logStatus against a LogDB opened on ":memory:"."""

    def setUp(self):
        # Every test builds its own LogDB on the ":memory:" path and creates
        # the tables, so tests do not share rows.
        self.db = LogDB(":memory:")
        self.db.createTables()

    def test_logStatus_insert(self):
        """One logStatus call stores one Status row and returns its id."""
        result = logStatus(self.db, "foo")
        # unittest assertion methods are used instead of bare ``assert`` so
        # the checks still run under ``python -O`` and report both values
        # on failure.
        self.assertEqual(self.db.Status.select().count(), 1)
        self.assertEqual(
            self.db.Status.get(self.db.Status.text == "foo").id, result
        )

    def test_logStatus_insertTwice(self):
        """Logging the same text twice creates two rows with growing ids."""
        result1 = logStatus(self.db, "foo")
        result2 = logStatus(self.db, "foo")
        self.assertEqual(self.db.Status.select().count(), 2)
        self.assertLess(result1, result2)
def suite():
    """Return a TestSuite containing the tests of UrlCheckerStatusTestCase."""
    # unittest.makeSuite is deprecated since Python 3.11 and removed in 3.13;
    # the default loader's loadTestsFromTestCase is the supported equivalent.
    test_suite = unittest.TestSuite()
    test_suite.addTest(
        unittest.defaultTestLoader.loadTestsFromTestCase(
            UrlCheckerStatusTestCase
        )
    )
    return test_suite
if __name__ == "__main__":
    # When run as a script, let unittest resolve the module-level name
    # "suite" and run what it returns.
    unittest.main(defaultTest="suite")
...@@ -6,7 +6,7 @@ from playhouse.sqlite_ext import SqliteExtDatabase ...@@ -6,7 +6,7 @@ from playhouse.sqlite_ext import SqliteExtDatabase
class LogDB: class LogDB:
def __init__(self, sqlite_path): def __init__(self, sqlite_path):
self._db = SqliteExtDatabase( self._db = SqliteExtDatabase(
sqlite_path, pragmas=(("journal_mode", "WAL"),) sqlite_path, pragmas=(("journal_mode", "WAL"), ("foreign_keys", 1))
) )
self._db.connect() self._db.connect()
...@@ -62,6 +62,8 @@ class LogDB: ...@@ -62,6 +62,8 @@ class LogDB:
ip = peewee.TextField(index=True) ip = peewee.TextField(index=True)
url = peewee.TextField(index=True) url = peewee.TextField(index=True)
status_code = peewee.IntegerField() status_code = peewee.IntegerField()
class Meta:
primary_key = peewee.CompositeKey("status", "ip", "url")
self.Status = Status self.Status = Status
self.ConfigurationChange = ConfigurationChange self.ConfigurationChange = ConfigurationChange
......
...@@ -113,6 +113,6 @@ def getServerIpDict(db, status_id, resolver_dict, domain_list, rdtype): ...@@ -113,6 +113,6 @@ def getServerIpDict(db, status_id, resolver_dict, domain_list, rdtype):
if address not in server_ip_dict: if address not in server_ip_dict:
server_ip_dict[address] = [] server_ip_dict[address] = []
if domain_text not in server_ip_dict[address]: if domain_text not in server_ip_dict[address]:
# Do not duplicate the domain # Do not duplicate the domain
server_ip_dict[address].append(domain_text) server_ip_dict[address].append(domain_text)
return server_ip_dict return server_ip_dict
...@@ -13,7 +13,7 @@ def getUrlHostname(url): ...@@ -13,7 +13,7 @@ def getUrlHostname(url):
return urlparse(url).hostname return urlparse(url).hostname
def getUserAgent(self, version="0"): def getUserAgent(version):
return "%s/%s (+%s)" % ( return "%s/%s (+%s)" % (
"URLCHECKER", "URLCHECKER",
version, version,
...@@ -22,16 +22,10 @@ def getUserAgent(self, version="0"): ...@@ -22,16 +22,10 @@ def getUserAgent(self, version="0"):
def request( def request(
method,
url, url,
headers=None, headers=None,
stream=False,
timeout=TIMEOUT,
allow_redirects=False,
verify=True,
session=requests, session=requests,
version=None, version=0
**kwargs,
): ):
if headers is None: if headers is None:
...@@ -42,11 +36,12 @@ def request( ...@@ -42,11 +36,12 @@ def request(
# XXX user agent # XXX user agent
headers["User-Agent"] = getUserAgent(version) headers["User-Agent"] = getUserAgent(version)
kwargs["stream"] = stream kwargs = {}
kwargs["timeout"] = timeout kwargs["stream"] = False
kwargs["allow_redirects"] = allow_redirects kwargs["timeout"] = TIMEOUT
kwargs["verify"] = verify kwargs["allow_redirects"] = False
args = [method, url] kwargs["verify"] = True
args = ["GET", url]
kwargs["headers"] = headers kwargs["headers"] = headers
...@@ -89,33 +84,34 @@ def logHttpStatus(db, ip, url, code, status_id): ...@@ -89,33 +84,34 @@ def logHttpStatus(db, ip, url, code, status_id):
previous_entry = db.HttpCodeChange.create( previous_entry = db.HttpCodeChange.create(
status=status_id, ip=ip, url=url, status_code=code status=status_id, ip=ip, url=url, status_code=code
) )
return previous_entry.id return previous_entry.status
def checkHttpStatus(db, status_id, url, ip, bot_version): def checkHttpStatus(db, status_id, url, ip, bot_version):
parsed_url = urlparse(url) parsed_url = urlparse(url)
hostname = parsed_url.hostname hostname = parsed_url.hostname
request_kw = {}
session = requests.Session()
# SNI Support # SNI Support
if parsed_url.scheme == "https": if parsed_url.scheme == "https":
# Provide SNI support # Provide SNI support
base_url = urlunsplit( base_url = urlunsplit(
(parsed_url.scheme, parsed_url.netloc, "", "", "") (parsed_url.scheme, parsed_url.netloc, "", "", "")
) )
session = requests.Session()
session.mount(base_url, ForcedIPHTTPSAdapter(dest_ip=ip)) session.mount(base_url, ForcedIPHTTPSAdapter(dest_ip=ip))
request_kw['session'] = session
ip_url = url ip_url = url
elif parsed_url.scheme == "http": elif parsed_url.scheme == "http":
# Force IP location # Force IP location
parsed_url = parsed_url._replace(netloc=ip) parsed_url = parsed_url._replace(netloc=ip)
ip_url = parsed_url.geturl() ip_url = parsed_url.geturl()
else:
raise NotImplementedError('Unhandled url: %s' % url)
response = request( response = request(
"GET",
ip_url, ip_url,
headers={"Host": hostname}, headers={"Host": hostname},
session=session,
version=bot_version, version=bot_version,
**request_kw
) )
logHttpStatus(db, ip, url, response.status_code, status_id) logHttpStatus(db, ip, url, response.status_code, status_id)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment