Commit b22b441b authored by Romain Courteaud

Add timeout parameter

parent 7b9db1df
...@@ -210,10 +210,10 @@ class UrlCheckerDNSTestCase(unittest.TestCase): ...@@ -210,10 +210,10 @@ class UrlCheckerDNSTestCase(unittest.TestCase):
# buildResolver # buildResolver
################################################ ################################################
def test_buildResolver_default(self): def test_buildResolver_default(self):
resolver = buildResolver("127.0.0.1") resolver = buildResolver("127.0.0.1", 4)
assert resolver.nameservers == ["127.0.0.1"] assert resolver.nameservers == ["127.0.0.1"]
assert resolver.timeout == 2 assert resolver.timeout == 4
assert resolver.lifetime == 2 assert resolver.lifetime == 4
assert resolver.edns == -1 assert resolver.edns == -1
################################################ ################################################
......
...@@ -311,13 +311,14 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -311,13 +311,14 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
"https://example.org/foo?bar=1", "https://example.org/foo?bar=1",
) )
assert ( assert (
len(mock_request.call_args.kwargs) == 3 len(mock_request.call_args.kwargs) == 4
), mock_request.call_args.kwargs ), mock_request.call_args.kwargs
assert mock_request.call_args.kwargs["headers"] == { assert mock_request.call_args.kwargs["headers"] == {
"Host": "example.org" "Host": "example.org"
} }
assert mock_request.call_args.kwargs["session"] is not None assert mock_request.call_args.kwargs["session"] is not None
assert mock_request.call_args.kwargs["version"] == 2 assert mock_request.call_args.kwargs["version"] == 2
assert mock_request.call_args.kwargs["timeout"] == 2
assert self.db.HttpCodeChange.select().count() == 1 assert self.db.HttpCodeChange.select().count() == 1
assert self.db.HttpCodeChange.get().ip == ip assert self.db.HttpCodeChange.get().ip == ip
......
...@@ -61,11 +61,12 @@ class WebBot: ...@@ -61,11 +61,12 @@ class WebBot:
def iterateLoop(self): def iterateLoop(self):
status_id = logStatus(self._db, "loop") status_id = logStatus(self._db, "loop")
timeout = int(self.config["TIMEOUT"])
# logPlatform(self._db, __version__, status_id) # logPlatform(self._db, __version__, status_id)
# Calculate the resolver list # Calculate the resolver list
resolver_ip_list = getReachableResolverList( resolver_ip_list = getReachableResolverList(
self._db, status_id, self.config["NAMESERVER"].split() self._db, status_id, self.config["NAMESERVER"].split(), timeout
) )
if not resolver_ip_list: if not resolver_ip_list:
return return
...@@ -76,7 +77,7 @@ class WebBot: ...@@ -76,7 +77,7 @@ class WebBot:
# Get the list of server to check # Get the list of server to check
# XXX Check DNS expiration # XXX Check DNS expiration
server_ip_dict = getDomainIpDict( server_ip_dict = getDomainIpDict(
self._db, status_id, resolver_ip_list, domain_list, "A" self._db, status_id, resolver_ip_list, domain_list, "A", timeout
) )
# Check TCP port for the list of IP found # Check TCP port for the list of IP found
...@@ -86,7 +87,9 @@ class WebBot: ...@@ -86,7 +87,9 @@ class WebBot:
for server_ip in server_ip_list: for server_ip in server_ip_list:
# XXX Check SSL certificate expiration # XXX Check SSL certificate expiration
for port, protocol in [(80, "http"), (443, "https")]: for port, protocol in [(80, "http"), (443, "https")]:
if isTcpPortOpen(self._db, server_ip, port, status_id): if isTcpPortOpen(
self._db, server_ip, port, status_id, timeout
):
for hostname in server_ip_dict[server_ip]: for hostname in server_ip_dict[server_ip]:
url = "%s://%s" % (protocol, hostname) url = "%s://%s" % (protocol, hostname)
if url not in url_dict: if url not in url_dict:
...@@ -103,7 +106,9 @@ class WebBot: ...@@ -103,7 +106,9 @@ class WebBot:
# Check HTTP Status # Check HTTP Status
for url in url_dict: for url in url_dict:
for ip in url_dict[url]: for ip in url_dict[url]:
checkHttpStatus(self._db, status_id, url, ip, __version__) checkHttpStatus(
self._db, status_id, url, ip, __version__, timeout
)
# XXX Check location header and check new url recursively # XXX Check location header and check new url recursively
# XXX Parse HTML, fetch found link, css, js, image # XXX Parse HTML, fetch found link, css, js, image
# XXX Check HTTP Cache # XXX Check HTTP Cache
......
...@@ -18,6 +18,7 @@ from urlchecker_bot import create_bot ...@@ -18,6 +18,7 @@ from urlchecker_bot import create_bot
@click.option("--nameserver", "-n", help="The IP of the DNS server.") @click.option("--nameserver", "-n", help="The IP of the DNS server.")
@click.option("--url", "-u", help="The url to check.") @click.option("--url", "-u", help="The url to check.")
@click.option("--domain", "-d", help="The domain to check.") @click.option("--domain", "-d", help="The domain to check.")
@click.option("--timeout", "-t", help="The timeout value.")
@click.option( @click.option(
"--configuration", "-f", help="The path of the configuration file." "--configuration", "-f", help="The path of the configuration file."
) )
...@@ -29,7 +30,9 @@ from urlchecker_bot import create_bot ...@@ -29,7 +30,9 @@ from urlchecker_bot import create_bot
default="plain", default="plain",
show_default=True, show_default=True,
) )
def runUrlChecker(run, sqlite, nameserver, url, domain, configuration, output): def runUrlChecker(
run, sqlite, nameserver, url, domain, timeout, configuration, output
):
# click.echo("Running url checker bot") # click.echo("Running url checker bot")
mapping = {} mapping = {}
......
...@@ -30,6 +30,8 @@ def createConfiguration( ...@@ -30,6 +30,8 @@ def createConfiguration(
) )
if "FORMAT" not in config[CONFIG_SECTION]: if "FORMAT" not in config[CONFIG_SECTION]:
config[CONFIG_SECTION]["FORMAT"] = "json" config[CONFIG_SECTION]["FORMAT"] = "json"
if "TIMEOUT" not in config[CONFIG_SECTION]:
config[CONFIG_SECTION]["TIMEOUT"] = "1"
if config[CONFIG_SECTION]["SQLITE"] == ":memory:": if config[CONFIG_SECTION]["SQLITE"] == ":memory:":
# Do not loop when using temporary DB # Do not loop when using temporary DB
......
...@@ -90,20 +90,20 @@ def logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list): ...@@ -90,20 +90,20 @@ def logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list):
return previous_entry.status_id return previous_entry.status_id
def buildResolver(resolver_ip): def buildResolver(resolver_ip, timeout):
resolver = dns.resolver.Resolver(configure=False) resolver = dns.resolver.Resolver(configure=False)
resolver.nameservers.append(resolver_ip) resolver.nameservers.append(resolver_ip)
resolver.timeout = TIMEOUT resolver.timeout = timeout
resolver.lifetime = TIMEOUT resolver.lifetime = timeout
resolver.edns = -1 resolver.edns = -1
return resolver return resolver
def queryDNS(db, status_id, resolver_ip, domain_text, rdtype): def queryDNS(db, status_id, resolver_ip, domain_text, rdtype, timeout=TIMEOUT):
# only A (and AAAA) has address property # only A (and AAAA) has address property
assert rdtype == "A" assert rdtype == "A"
resolver = buildResolver(resolver_ip) resolver = buildResolver(resolver_ip, timeout)
try: try:
answer_list = [ answer_list = [
x.address x.address
...@@ -123,14 +123,16 @@ def queryDNS(db, status_id, resolver_ip, domain_text, rdtype): ...@@ -123,14 +123,16 @@ def queryDNS(db, status_id, resolver_ip, domain_text, rdtype):
return answer_list return answer_list
def getReachableResolverList(db, status_id, resolver_ip_list): def getReachableResolverList(db, status_id, resolver_ip_list, timeout=TIMEOUT):
# Create a list of resolver object # Create a list of resolver object
result_ip_list = [] result_ip_list = []
# Check the DNS server availability once # Check the DNS server availability once
# to prevent using it later if it is down # to prevent using it later if it is down
for resolver_ip in resolver_ip_list: for resolver_ip in resolver_ip_list:
resolver_state = "open" resolver_state = "open"
answer_list = queryDNS(db, status_id, resolver_ip, URL_TO_CHECK, "A") answer_list = queryDNS(
db, status_id, resolver_ip, URL_TO_CHECK, "A", timeout
)
if len(answer_list) == 0: if len(answer_list) == 0:
# We expect a valid response # We expect a valid response
...@@ -156,12 +158,14 @@ def expandDomainList(domain_list): ...@@ -156,12 +158,14 @@ def expandDomainList(domain_list):
return domain_list return domain_list
def getDomainIpDict(db, status_id, resolver_ip_list, domain_list, rdtype): def getDomainIpDict(
db, status_id, resolver_ip_list, domain_list, rdtype, timeout=TIMEOUT
):
server_ip_dict = {} server_ip_dict = {}
for domain_text in domain_list: for domain_text in domain_list:
for resolver_ip in resolver_ip_list: for resolver_ip in resolver_ip_list:
answer_list = queryDNS( answer_list = queryDNS(
db, status_id, resolver_ip, domain_text, rdtype db, status_id, resolver_ip, domain_text, rdtype, timeout
) )
for address in answer_list: for address in answer_list:
if address not in server_ip_dict: if address not in server_ip_dict:
......
...@@ -25,7 +25,7 @@ def getUserAgent(version): ...@@ -25,7 +25,7 @@ def getUserAgent(version):
) )
def request(url, headers=None, session=requests, version=0): def request(url, timeout=TIMEOUT, headers=None, session=requests, version=0):
if headers is None: if headers is None:
headers = {} headers = {}
...@@ -37,7 +37,7 @@ def request(url, headers=None, session=requests, version=0): ...@@ -37,7 +37,7 @@ def request(url, headers=None, session=requests, version=0):
kwargs = {} kwargs = {}
kwargs["stream"] = False kwargs["stream"] = False
kwargs["timeout"] = TIMEOUT kwargs["timeout"] = timeout
kwargs["allow_redirects"] = False kwargs["allow_redirects"] = False
kwargs["verify"] = True kwargs["verify"] = True
args = ["GET", url] args = ["GET", url]
...@@ -125,10 +125,10 @@ def logHttpStatus(db, ip, url, code, status_id): ...@@ -125,10 +125,10 @@ def logHttpStatus(db, ip, url, code, status_id):
return previous_entry.status_id return previous_entry.status_id
def checkHttpStatus(db, status_id, url, ip, bot_version): def checkHttpStatus(db, status_id, url, ip, bot_version, timeout=TIMEOUT):
parsed_url = urlparse(url) parsed_url = urlparse(url)
hostname = parsed_url.hostname hostname = parsed_url.hostname
request_kw = {} request_kw = {"timeout": timeout}
# SNI Support # SNI Support
if parsed_url.scheme == "https": if parsed_url.scheme == "https":
# Provide SNI support # Provide SNI support
......
...@@ -2,6 +2,7 @@ import socket ...@@ -2,6 +2,7 @@ import socket
import errno import errno
from peewee import fn from peewee import fn
TIMEOUT = 2 TIMEOUT = 2
...@@ -86,10 +87,10 @@ def logNetwork(db, ip, transport, port, state, status_id): ...@@ -86,10 +87,10 @@ def logNetwork(db, ip, transport, port, state, status_id):
return previous_entry.status_id return previous_entry.status_id
def isTcpPortOpen(db, ip, port, status_id): def isTcpPortOpen(db, ip, port, status_id, timeout=TIMEOUT):
is_open = False is_open = False
sock = socket.socket() sock = socket.socket()
sock.settimeout(TIMEOUT) sock.settimeout(timeout)
try: try:
sock.connect((ip, port)) sock.connect((ip, port))
state = "open" state = "open"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment