Commit d48b47a4 authored by Łukasz Nowak's avatar Łukasz Nowak Committed by Łukasz Nowak

updater: Make stateful decision

If at least once certificate has been downloaded from KeDiFa it shall never
use again the fall-back, as otherwise it would result with a problem, that
next unsuccessful download from KeDiFa would result replacement with
fall-back.

In order to do so state file is introduced keeping list of overridden
certificates. As now there is critical path regarding fetching certificates,
the lock is created to avoid concurrent updates.
parent 53e99f68
...@@ -142,6 +142,13 @@ def updater(*args): ...@@ -142,6 +142,13 @@ def updater(*args):
'certificate, and DESTINATION is the output file.' 'certificate, and DESTINATION is the output file.'
) )
parser.add_argument(
'state',
type=str,
help='Path to JSON state file for fallback recognition, on which locks '
'will happen.'
)
parser.add_argument( parser.add_argument(
'--identity', '--identity',
type=argparse.FileType('r'), type=argparse.FileType('r'),
...@@ -185,7 +192,7 @@ def updater(*args): ...@@ -185,7 +192,7 @@ def updater(*args):
parsed = parser.parse_args(args) parsed = parser.parse_args(args)
u = Updater( u = Updater(
parsed.sleep, parsed.mapping.name, parsed.master_certificate, parsed.sleep, parsed.mapping.name, parsed.state, parsed.master_certificate,
parsed.on_update, parsed.identity.name, parsed.server_ca_certificate.name, parsed.on_update, parsed.identity.name, parsed.server_ca_certificate.name,
parsed.once parsed.once
) )
......
...@@ -22,6 +22,7 @@ import contextlib ...@@ -22,6 +22,7 @@ import contextlib
import datetime import datetime
import httplib import httplib
import ipaddress import ipaddress
import json
import mock import mock
import multiprocessing import multiprocessing
import os import os
...@@ -308,11 +309,14 @@ class KedifaIntegrationTest(KedifaMixinCaucase, unittest.TestCase): ...@@ -308,11 +309,14 @@ class KedifaIntegrationTest(KedifaMixinCaucase, unittest.TestCase):
mapping = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False) mapping = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False)
mapping.write("%s %s" % (url, destination)) mapping.write("%s %s" % (url, destination))
mapping.close() mapping.close()
state = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False)
state.close()
updater( updater(
'--once', '--once',
'--server-ca-certificate', self.ca_crt_pem, '--server-ca-certificate', self.ca_crt_pem,
'--identity', certificate, '--identity', certificate,
mapping.name, mapping.name,
state.name
) )
def updater_get(self, url, certificate): def updater_get(self, url, certificate):
...@@ -1089,123 +1093,242 @@ class KedifaIntegrationTest(KedifaMixinCaucase, unittest.TestCase): ...@@ -1089,123 +1093,242 @@ class KedifaIntegrationTest(KedifaMixinCaucase, unittest.TestCase):
) )
class KedifaUpdaterMappingTest(KedifaMixin, unittest.TestCase): class KedifaUpdaterMixin(KedifaMixin):
def setUp(self):
super(KedifaUpdaterMixin, self).setUp()
state = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False)
state.close()
self.state = state.name
def setupMapping(self, mapping_content=''): def setupMapping(self, mapping_content=''):
mapping = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False) mapping = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False)
mapping.write(mapping_content) mapping.write(mapping_content)
mapping.close() mapping.close()
self.mapping = mapping.name self.mapping = mapping.name
class KedifaUpdaterMappingTest(KedifaUpdaterMixin, unittest.TestCase):
def test_updateMapping_empty(self): def test_updateMapping_empty(self):
self.setupMapping() self.setupMapping()
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {}) self.assertEqual(u.mapping, {})
def test_updateMapping_normal(self): def test_updateMapping_normal(self):
self.setupMapping('url file') self.setupMapping('url file')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(u.mapping, {'file': ('url', None)})
def test_updateMapping_morewhite(self): def test_updateMapping_morewhite(self):
self.setupMapping('url \t file') self.setupMapping('url \t file')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(u.mapping, {'file': ('url', None)})
def test_updateMapping_one_empty(self): def test_updateMapping_one_empty(self):
self.setupMapping('url file\n \n') self.setupMapping('url file\n \n')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(u.mapping, {'file': ('url', None)})
def test_updateMapping_one_not_enough(self): def test_updateMapping_one_not_enough(self):
self.setupMapping('url file\nbuzz\n') self.setupMapping('url file\nbuzz\n')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(u.mapping, {'file': ('url', None)})
def test_updateMapping_one_too_much(self): def test_updateMapping_with_fallback(self):
self.setupMapping('url file\nbuzz oink aff\n') self.setupMapping('url file\nbuzz oink fallback\n')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(
u.mapping, {'file': ('url', None), 'oink': ('buzz', 'fallback')})
def test_updateMapping_one_comment(self): def test_updateMapping_one_comment(self):
self.setupMapping('url file\n#buzz uff\n') self.setupMapping('url file\n#buzz uff\n')
u = Updater(1, self.mapping, None, None, None, None, True) u = Updater(1, self.mapping, self.state, None, None, None, None, True)
u.updateMapping() u.updateMapping()
self.assertEqual(u.mapping, {'file': 'url'}) self.assertEqual(u.mapping, {'file': ('url', None)})
class KedifaUpdaterUpdateCertificateTest(KedifaMixin, unittest.TestCase): class KedifaUpdaterUpdateCertificateTest(
def setupMapping(self, mapping_content=''): KedifaUpdaterMixin, unittest.TestCase):
mapping = tempfile.NamedTemporaryFile(dir=self.testdir, delete=False) def setUp(self):
mapping.write(mapping_content) super(KedifaUpdaterUpdateCertificateTest, self).setUp()
mapping.close()
self.mapping = mapping.name
def _update(self, certificate, fetch, master_content):
certificate_file = tempfile.NamedTemporaryFile( certificate_file = tempfile.NamedTemporaryFile(
dir=self.testdir, delete=False) dir=self.testdir, delete=False)
certificate_file.write(certificate)
certificate_file.close() certificate_file.close()
self.setupMapping('http://example.com %s' % (certificate_file.name,)) self.certificate_file_name = certificate_file.name
def _update(self, certificate, fetch, master_content, fallback=None):
with open(self.certificate_file_name, 'w') as fh:
fh.write(certificate)
fallback_file = None
if fallback:
fallback_file = tempfile.NamedTemporaryFile(
dir=self.testdir, delete=False)
fallback_file.write(fallback)
fallback_file.close()
mapping = 'http://example.com %s' % (self.certificate_file_name,)
if fallback_file:
mapping = '%s %s' % (mapping, fallback_file.name)
self.setupMapping(mapping)
u = Updater( u = Updater(
1, self.mapping, '/master/certificate/file', None, None, None, True) 1, self.mapping, self.state, '/master/certificate/file', None, None,
None, True)
u.updateMapping() u.updateMapping()
u.readState()
with mock.patch.object( with mock.patch.object(
Updater, 'fetchCertificate', return_value=fetch): Updater, 'fetchCertificate', return_value=fetch):
result = u.updateCertificate(certificate_file.name, master_content) result = u.updateCertificate(self.certificate_file_name, master_content)
return open(certificate_file.name, 'r').read(), result u.writeState()
return open(self.certificate_file_name, 'r').read(), result
def test_nocert_nofetch_nomaster(self): def assertState(self, state):
with open(self.state, 'r') as fh:
json_state = json.load(fh)
self.assertEqual(
json_state,
state
)
def test_nocert_nofetch_nomaster_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='', fetch='', master_content=None) certificate='', fetch='', master_content=None)
self.assertEqual('', certificate) self.assertEqual('', certificate)
self.assertFalse(update) self.assertFalse(update)
self.assertState({})
def test_cert_nofetch_nomaster(self): def test_cert_nofetch_nomaster_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='old content', fetch='', master_content=None) certificate='old content', fetch='', master_content=None)
self.assertEqual('old content', certificate) self.assertEqual('old content', certificate)
self.assertFalse(update) self.assertFalse(update)
self.assertState({})
def test_nocert_fetch_nomaster(self): def test_nocert_fetch_nomaster_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='', fetch='content', master_content=None) certificate='', fetch='content', master_content=None)
self.assertEqual('content', certificate) self.assertEqual('content', certificate)
self.assertTrue(update) self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_cert_fetch_nomaster(self): def test_cert_fetch_nomaster_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='old content', fetch='content', master_content=None) certificate='old content', fetch='content', master_content=None)
self.assertEqual('content', certificate) self.assertEqual('content', certificate)
self.assertTrue(update) self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_nocert_nofetch_master(self): def test_nocert_nofetch_master_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='', fetch='', master_content='master') certificate='', fetch='', master_content='master')
self.assertEqual('master', certificate) self.assertEqual('master', certificate)
self.assertTrue(update) self.assertTrue(update)
self.assertState({})
def test_cert_nofetch_master(self): def test_cert_nofetch_master_nofallback(self):
# This is important feature. Master certifcate does not override existing
# certificate, so it can use provided outside of KeDiFa
certificate, update = self._update( certificate, update = self._update(
certificate='old content', fetch='', master_content='master') certificate='old content', fetch='', master_content='master')
self.assertEqual('old content', certificate) self.assertEqual('old content', certificate)
self.assertFalse(update) self.assertFalse(update)
self.assertState({})
def test_nocert_fetch_master(self): def test_nocert_fetch_master_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='', fetch='content', master_content='master') certificate='', fetch='content', master_content='master')
self.assertEqual('content', certificate) self.assertEqual('content', certificate)
self.assertTrue(update) self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_cert_fetch_master(self): def test_cert_fetch_master_nofallback(self):
certificate, update = self._update( certificate, update = self._update(
certificate='old content', fetch='content', master_content='master') certificate='old content', fetch='content', master_content='master')
self.assertEqual('content', certificate) self.assertEqual('content', certificate)
self.assertTrue(update) self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_nocert_nofetch_nomaster_fallback(self):
certificate, update = self._update(
certificate='', fetch='', master_content=None, fallback='fallback')
self.assertEqual('fallback', certificate)
self.assertTrue(update)
self.assertState({})
def test_cert_nofetch_nomaster_fallback(self):
certificate, update = self._update(
certificate='old content', fetch='', master_content=None,
fallback='fallback')
self.assertEqual('fallback', certificate)
self.assertTrue(update)
self.assertState({})
def test_cert_nofetch_nomaster_fallback_overridden(self):
with open(self.state, 'w') as fh:
json.dump({self.certificate_file_name: True}, fh)
certificate, update = self._update(
certificate='old content', fetch='', master_content=None,
fallback='fallback')
self.assertEqual('old content', certificate)
self.assertFalse(update)
self.assertState({self.certificate_file_name: True})
def test_nocert_fetch_nomaster_fallback(self):
certificate, update = self._update(
certificate='', fetch='content', master_content=None,
fallback='fallback')
self.assertEqual('content', certificate)
self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_cert_fetch_nomaster_fallback(self):
certificate, update = self._update(
certificate='old content', fetch='content', master_content=None,
fallback='fallback')
self.assertEqual('content', certificate)
self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_nocert_nofetch_master_fallback(self):
certificate, update = self._update(
certificate='', fetch='', master_content='master',
fallback='fallback')
self.assertEqual('fallback', certificate)
self.assertTrue(update)
self.assertState({})
def test_cert_nofetch_master_fallback(self):
certificate, update = self._update(
certificate='old content', fetch='', master_content='master',
fallback='fallback')
self.assertEqual('fallback', certificate)
self.assertTrue(update)
self.assertState({})
def test_cert_nofetch_master_fallback_overridden(self):
with open(self.state, 'w') as fh:
json.dump({self.certificate_file_name: True}, fh)
certificate, update = self._update(
certificate='old content', fetch='', master_content='master',
fallback='fallback')
self.assertEqual('old content', certificate)
self.assertFalse(update)
self.assertState({self.certificate_file_name: True})
def test_nocert_fetch_master_fallback(self):
certificate, update = self._update(
certificate='', fetch='content', master_content='master',
fallback='fallback')
self.assertEqual('content', certificate)
self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
def test_cert_fetch_master_fallback(self):
certificate, update = self._update(
certificate='old content', fetch='content', master_content='master',
fallback='fallback')
self.assertEqual('content', certificate)
self.assertTrue(update)
self.assertState({self.certificate_file_name: True})
import httplib import httplib
import json
import os import os
import requests import requests
import sys
import time import time
import zc.lockfile
class Updater(object): class Updater(object):
def __init__(self, sleep, mapping_file, master_certificate_file, on_update, def __init__(self, sleep, mapping_file, state_file, master_certificate_file,
identity_file, server_ca_certificate_file, once): on_update, identity_file, server_ca_certificate_file, once):
self.sleep = sleep self.sleep = sleep
self.mapping_file = mapping_file self.mapping_file = mapping_file
self.state_file = state_file
self.state_lock_file = '%s.lock' % (state_file, )
self.master_certificate_file = master_certificate_file self.master_certificate_file = master_certificate_file
self.on_update = on_update self.on_update = on_update
self.identity_file = identity_file self.identity_file = identity_file
...@@ -25,11 +30,15 @@ class Updater(object): ...@@ -25,11 +30,15 @@ class Updater(object):
if not line: if not line:
continue continue
line_content = line.split() line_content = line.split()
if len(line_content) != 2: if len(line_content) == 2:
url, certificate = line_content
fallback = None
elif len(line_content) == 3:
url, certificate, fallback = line_content
else:
print 'Line %r is incorrect' % (line,) print 'Line %r is incorrect' % (line,)
continue continue
url, certificate = line_content self.mapping[certificate] = (url, fallback)
self.mapping[certificate] = url
def fetchCertificate(self, url, certificate_file): def fetchCertificate(self, url, certificate_file):
certificate = '' certificate = ''
...@@ -50,9 +59,17 @@ class Updater(object): ...@@ -50,9 +59,17 @@ class Updater(object):
return certificate return certificate
def updateCertificate(self, certificate_file, master_content=None): def updateCertificate(self, certificate_file, master_content=None):
url = self.mapping[certificate_file] url, fallback_file = self.mapping[certificate_file]
certificate = self.fetchCertificate(url, certificate_file) certificate = self.fetchCertificate(url, certificate_file)
fallback_overridden = self.state_dict.get(certificate_file, False)
fallback = ''
if fallback_file:
try:
with open(fallback_file, 'r') as fh:
fallback = fh.read() or None
except IOError:
pass
current = '' current = ''
try: try:
with open(certificate_file, 'r') as fh: with open(certificate_file, 'r') as fh:
...@@ -61,11 +78,16 @@ class Updater(object): ...@@ -61,11 +78,16 @@ class Updater(object):
current = '' current = ''
if not(certificate): if not(certificate):
if not current and master_content is not None: if fallback and not fallback_overridden:
certificate = fallback
elif not current and master_content is not None:
url = self.master_certificate_file url = self.master_certificate_file
certificate = master_content certificate = master_content
else: else:
return False return False
else:
self.state_dict[certificate_file] = True
if current != certificate: if current != certificate:
with open(certificate_file, 'w') as fh: with open(certificate_file, 'w') as fh:
fh.write(certificate) fh.write(certificate)
...@@ -79,8 +101,23 @@ class Updater(object): ...@@ -79,8 +101,23 @@ class Updater(object):
status = os.system(self.on_update) status = os.system(self.on_update)
print 'Called %r with status %i' % (self.on_update, status) print 'Called %r with status %i' % (self.on_update, status)
def loop(self): def readState(self):
while True: self.state_dict = {}
try:
with open(self.state_file, 'r') as fh:
try:
self.state_dict = json.load(fh)
except ValueError:
pass
except IOError:
pass
def writeState(self):
with open(self.state_file, 'w') as fh:
json.dump(self.state_dict, fh, indent=2)
def action(self):
self.readState()
self.updateMapping() self.updateMapping()
updated = False updated = False
...@@ -105,7 +142,22 @@ class Updater(object): ...@@ -105,7 +142,22 @@ class Updater(object):
if updated: if updated:
self.callOnUpdate() self.callOnUpdate()
self.writeState()
def loop(self):
while True:
try:
lock = zc.lockfile.LockFile(self.state_lock_file)
except zc.lockfile.LockError as e:
print e,
if self.once:
print '...exiting.'
sys.exit(1)
else:
print "...will try again later."
else:
self.action()
lock.close()
if self.once: if self.once:
break break
print 'Sleeping for %is' % (self.sleep,) print 'Sleeping for %is' % (self.sleep,)
......
...@@ -49,6 +49,7 @@ setup( ...@@ -49,6 +49,7 @@ setup(
install_requires=[ install_requires=[
'cryptography', # for working with certificates 'cryptography', # for working with certificates
'requests', # for getter and updater 'requests', # for getter and updater
'zc.lockfile', # for stateful updater
'urllib3 >= 1.18', # https://github.com/urllib3/urllib3/issues/258 'urllib3 >= 1.18', # https://github.com/urllib3/urllib3/issues/258
'caucase', # provides utils for certificate management; 'caucase', # provides utils for certificate management;
# version requirement caucase >= 0.9.3 is dropped, as it # version requirement caucase >= 0.9.3 is dropped, as it
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment