Commit e600e173 authored by zhifan huang's avatar zhifan huang

fix: fix problem when merge_requests

remove useless test, correct typo and eek problem, refactor verbose
command

see nexedi/re6stnet!38
parent 3140768d
...@@ -25,17 +25,17 @@ class Node(nemu.Node): ...@@ -25,17 +25,17 @@ class Node(nemu.Node):
@property @property
def ip(self): def ip(self):
if hasattr(self, "_ip"): try:
return str(self._ip) return str(self._ip)
except AttributeError:
# return 1 ipv4 address of the one interface, reverse mode # return 1 ipv4 address of the one interface, reverse mode
for iface in self.get_interfaces()[::-1]: for iface in self.get_interfaces()[::-1]:
for addr in iface.get_addresses(): for addr in iface.get_addresses():
addr = addr['address'] addr = addr['address']
if '.' in addr: if '.' in addr:
#TODO different type problem? #TODO different type problem?
self._ip = addr self._ip = addr
return addr return addr
def connect_switch(self, switch, ip, prefix_len=24): def connect_switch(self, switch, ip, prefix_len=24):
...@@ -49,7 +49,7 @@ class NetManager(object): ...@@ -49,7 +49,7 @@ class NetManager(object):
"""contain all the nemu object created, so they can live more time""" """contain all the nemu object created, so they can live more time"""
def __init__(self): def __init__(self):
self.object = [] self.object = []
self.registrys = {} self.registries = {}
def connectible_test(nm): def connectible_test(nm):
...@@ -61,11 +61,13 @@ def connectible_test(nm): ...@@ -61,11 +61,13 @@ def connectible_test(nm):
Raise: Raise:
AssertionError AssertionError
""" """
for reg in nm.registrys: for reg in nm.registries:
for node in nm.registrys[reg]: for node in nm.registries[reg]:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE) app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
ret = app0.wait() ret = app0.wait()
assert ret == 0, "network construct failed {} to {}".format(node.ip, reg.ip) if ret:
raise ConnectionError(
"network construct failed {} to {}".format(node.ip, reg.ip))
logging.debug("each node can ping to their registry") logging.debug("each node can ping to their registry")
...@@ -89,7 +91,7 @@ def net_route(): ...@@ -89,7 +91,7 @@ def net_route():
m2_if_0 = machine2.connect_switch(switch1, "192.168.1.3") m2_if_0 = machine2.connect_switch(switch1, "192.168.1.3")
nm.object.append(switch1) nm.object.append(switch1)
nm.registrys[registry] = [machine1, machine2] nm.registries[registry] = [machine1, machine2]
connectible_test(nm) connectible_test(nm)
return nm return nm
...@@ -116,7 +118,7 @@ def net_demo(): ...@@ -116,7 +118,7 @@ def net_demo():
nm = NetManager() nm = NetManager()
nm.object = [internet, switch3, switch1, switch2, gateway1, gateway2] nm.object = [internet, switch3, switch1, switch2, gateway1, gateway2]
nm.registrys = {registry: [m1, m2, m3, m4, m5, m6, m7, m8]} nm.registries = {registry: [m1, m2, m3, m4, m5, m6, m7, m8]}
# for node in [g1, m3, m4, m5]: # for node in [g1, m3, m4, m5]:
# print "pid: {}".format(node.pid) # print "pid: {}".format(node.pid)
...@@ -183,7 +185,7 @@ def network_direct(): ...@@ -183,7 +185,7 @@ def network_direct():
registry = Node() registry = Node()
m0 = Node() m0 = Node()
nm = NetManager() nm = NetManager()
nm.registrys = {registry: [m0]} nm.registries = {registry: [m0]}
re_if_0, m_if_0 = nemu.P2PInterface.create_pair(registry, m0) re_if_0, m_if_0 = nemu.P2PInterface.create_pair(registry, m0)
...@@ -200,5 +202,5 @@ def network_direct(): ...@@ -200,5 +202,5 @@ def network_direct():
return nm return nm
if __name__ == "__main__": if __name__ == "__main__":
nm = network_demo() nm = net_demo()
time.sleep(1000000) time.sleep(1000000)
...@@ -29,9 +29,7 @@ class MultiPing(MultiPing): ...@@ -29,9 +29,7 @@ class MultiPing(MultiPing):
except socket.timeout: except socket.timeout:
pass pass
except socket.error as e: except socket.error as e:
if e.errno == errno.EWOULDBLOCK: if e.errno != errno.EWOULDBLOCK:
pass
else:
raise raise
return pkts return pkts
...@@ -39,21 +37,18 @@ def main(): ...@@ -39,21 +37,18 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-a', nargs = '+', help = 'the list of addresses to ping') parser.add_argument('-a', nargs = '+', help = 'the list of addresses to ping')
parser.add_argument('--retry', action='store_true', help='retry ping unitl success') parser.add_argument('--retry', action='store_true', help='retry ping unitl success')
args = parser.parse_args() args = parser.parse_args()
addrs = args.a addrs = args.a
retry = args.retry retry = args.retry
no_responses = ""
while True: while retry and no_responses:
mp = MultiPing(addrs) mp = MultiPing(addrs)
mp.send() mp.send()
_, no_responses = mp.receive(PING_TIMEOUT) _, no_responses = mp.receive(PING_TIMEOUT)
if retry and no_responses: sys.stdout.write(" ".join(no_responses))
continue
else:
sys.stdout.write(" ".join(no_responses))
return
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -10,6 +10,7 @@ import time ...@@ -10,6 +10,7 @@ import time
import re import re
import tempfile import tempfile
import logging import logging
import errno
from subprocess import PIPE, call from subprocess import PIPE, call
from pathlib2 import Path from pathlib2 import Path
...@@ -117,21 +118,27 @@ class Re6stRegistry(object): ...@@ -117,21 +118,27 @@ class Re6stRegistry(object):
serial=ip_to_serial(self.ip6)) serial=ip_to_serial(self.ip6))
def run(self): def run(self):
cmd = ("{script} --ca {ca} --key {key} --dh {dh} --ipv4 10.42.0.0/16 8 " cmd =['--ca', self.ca_crt, '--key', self.ca_key, '--dh', DH_FILE,
" --logfile {log} --db {db} --run {run} --hello 4 --mailhost s " '--ipv4', '10.42.0.0/16', '8', '--logfile', self.log, '--db', self.db,
"-v4 --client-count {nb}") '--run', self.run_path, '--hello', '4', '--mailhost', 's', '-v4',
cmd = cmd.format(script=RE6ST_REGISTRY, ca=self.ca_crt, '--client-count', (self.client_number+1)//2]
key=self.ca_key, dh=DH_FILE, log=self.log, db=self.db,
run=self.run_path, nb=(self.client_number+1)//2).split() #convert PosixPath to str, can be remove in python3
cmd = map(str, cmd)
cmd = RE6ST_REGISTRY.split() + cmd
logging.debug("run registry %s at ns: %s with cmd: %s", logging.debug("run registry %s at ns: %s with cmd: %s",
self.name, self.node.pid, " ".join(cmd)) self.name, self.node.pid, " ".join(cmd))
self.proc = self.node.Popen(cmd, stdout=PIPE, stderr=PIPE) self.proc = self.node.Popen(cmd, stdout=PIPE, stderr=PIPE)
def clean(self): def clean(self):
"""remove the file created last time""" """remove the file created last time"""
for f in [self.log]: try:
if f.exists(): self.log.unlink()
f.unlink() except OSError as e:
if e.errno != errno.ENOENT:
raise
def __del__(self): def __del__(self):
try: try:
...@@ -206,9 +213,8 @@ class Re6stNode(object): ...@@ -206,9 +213,8 @@ class Re6stNode(object):
def create_node(self): def create_node(self):
"""create necessary file for node""" """create necessary file for node"""
logging.info("create dir of node %s", self.name) logging.info("create dir of node %s", self.name)
cmd = "{script} --registry {registry_url} --email {email}" cmd = ["--registry", self.registry.url, '--email', self.email]
cmd = cmd.format(script=RE6ST_CONF, registry_url=self.registry.url, cmd = RE6ST_CONF.split() + cmd
email=self.email).split()
p = self.node.Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, p = self.node.Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE,
cwd=str(self.path)) cwd=str(self.path))
# read token # read token
...@@ -235,14 +241,13 @@ class Re6stNode(object): ...@@ -235,14 +241,13 @@ class Re6stNode(object):
def run(self, *args): def run(self, *args):
"""execute re6stnet""" """execute re6stnet"""
cmd = ("{script} --log {log} --run {run} --state {state}" cmd = ['--log', self.path, '--run', self.run_path, '--state', self.path,
" --dh {dh} --ca {ca} --cert {cert} --key {key} -v4" '--dh', DH_FILE, '--ca', self.registry.ca_crt, '--cert', self.crt,
" --registry {registry} --console {console}" '--key', self.key, '-v4', '--registry', self.registry.url,
) '--console', self.console]
cmd = cmd.format(script=RE6STNET, log=self.path, run=self.run_path, cmd = map(str, cmd)
state=self.path, dh=DH_FILE, ca=self.registry.ca_crt, cmd = RE6STNET.split() + cmd
cert=self.crt, key=self.key, registry=self.registry.url,
console=self.console).split()
cmd += args cmd += args
logging.debug("run node %s at ns: %s with cmd: %s", logging.debug("run node %s at ns: %s with cmd: %s",
self.name, self.node.pid, " ".join(cmd)) self.name, self.node.pid, " ".join(cmd))
...@@ -252,8 +257,11 @@ class Re6stNode(object): ...@@ -252,8 +257,11 @@ class Re6stNode(object):
"""remove the file created last time""" """remove the file created last time"""
for name in ["re6stnet.log", "babeld.state", "cache.db", "babeld.log"]: for name in ["re6stnet.log", "babeld.state", "cache.db", "babeld.log"]:
f = self.path / name f = self.path / name
if f.exists(): try:
f.unlink() f.unlink()
except OSError as e:
if e.errno != errno.ENOENT:
raise
def stop(self): def stop(self):
"""stop running re6stnet process""" """stop running re6stnet process"""
......
...@@ -15,23 +15,23 @@ PING_PATH = str(Path(__file__).parent.resolve() / "ping.py") ...@@ -15,23 +15,23 @@ PING_PATH = str(Path(__file__).parent.resolve() / "ping.py")
BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2' BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2'
def deploy_re6st(nm, recreate=False): def deploy_re6st(nm, recreate=False):
net = nm.registrys net = nm.registries
nodes = [] nodes = []
registrys = [] registries = []
re6st_wrap.Re6stRegistry.registry_seq = 0 re6st_wrap.Re6stRegistry.registry_seq = 0
re6st_wrap.Re6stNode.node_seq = 0 re6st_wrap.Re6stNode.node_seq = 0
for registry in net: for registry in net:
reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::", len(net[registry]), reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::", len(net[registry]),
recreate=recreate) recreate=recreate)
reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name) reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name)
registrys.append(reg) registries.append(reg)
reg_node.run("--gateway", "--disable-proto", "none", "--ip", registry.ip) reg_node.run("--gateway", "--disable-proto", "none", "--ip", registry.ip)
nodes.append(reg_node) nodes.append(reg_node)
for m in net[registry]: for m in net[registry]:
node = re6st_wrap.Re6stNode(m, reg) node = re6st_wrap.Re6stNode(m, reg)
node.run("-i" + m.iface.name) node.run("-i" + m.iface.name)
nodes.append(node) nodes.append(node)
return nodes, registrys return nodes, registries
def wait_stable(nodes, timeout=240): def wait_stable(nodes, timeout=240):
"""try use ping6 from each node to the other until ping success to all the """try use ping6 from each node to the other until ping success to all the
......
"""Re6st unittest module """Re6st unittest module
""" """
# contatin the test case \ No newline at end of file
__all__ = ["test_registry",
"test_registry_client",
"test_conf",
"test_tunnel"]
...@@ -90,17 +90,20 @@ class TestRegistryServer(unittest.TestCase): ...@@ -90,17 +90,20 @@ class TestRegistryServer(unittest.TestCase):
"0000000000000\0" # ERROR, IndexError: msg is null "0000000000000\0" # ERROR, IndexError: msg is null
] ]
res1 = self.server.recv(4) try:
res2 = self.server.recv(4) res1 = self.server.recv(4)
res3 = self.server.recv(4) res2 = self.server.recv(4)
res4 = self.server.recv(4) res3 = self.server.recv(4)
res4 = self.server.recv(4)
self.assertEqual(res1, (None, None)) # not contain \0
self.assertEqual(res2, (None, None)) # binary to digital failed self.assertEqual(res1, (None, None)) # not contain \0
self.assertEqual(res3, (None, None)) # code don't match self.assertEqual(res2, (None, None)) # binary to digital failed
self.assertEqual(res4, ("0001001001001", "a_msg")) self.assertEqual(res3, (None, None)) # code don't match
self.assertEqual(res4, ("0001001001001", "a_msg"))
del self.server.sock.recv except:
pass
finally:
del self.server.sock.recv
def test_onTimeout(self): def test_onTimeout(self):
# old token, cert, not old token, cert # old token, cert, not old token, cert
......
__all__ = ["test_multi_gateway_manager", "test_base_tunnel_manager"]
...@@ -32,82 +32,14 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -32,82 +32,14 @@ class testBaseTunnelManager(unittest.TestCase):
self.cache.same_country = False self.cache.same_country = False
address = [(2, [('10.0.0.2', '1194', 'udp'), ('10.0.0.2', '1194', 'tcp')])] address = [(2, [('10.0.0.2', '1194', 'udp'), ('10.0.0.2', '1194', 'tcp')])]
self.tunnel = tunnel.BaseTunnelManager(self.control_socket, self.tunnel = tunnel.BaseTunnelManager(self.control_socket,
self.cache, self.cert, None, address) self.cache, self.cert, None, address)
def tearDown(self): def tearDown(self):
self.tunnel.close() self.tunnel.close()
del self.tunnel del self.tunnel
#TODO selectTimeout in contain callback, removing, update
@patch("re6st.tunnel.BaseTunnelManager._babel_dump_one", create=True)
@patch("re6st.tunnel.BaseTunnelManager._babel_dump_two", create=True)
def test_babel_dump(self, two, one):
""" case two func in requesting_dump"""
self.tunnel._BaseTunnelManager__requesting_dump = set(['one', 'two'])
self.tunnel.babel_dump()
# assert is empty
self.assertFalse(self.tunnel._BaseTunnelManager__requesting_dump)
one.assert_called_once()
two.assert_called_once()
@patch("re6st.ctl.Babel.request_dump")
def test_request_dump_empty(self, request_dump):
"""case when self.__requesting_dump is None or empty"""
reason = "rina"
self.tunnel._BaseTunnelManager__request_dump(reason)
self.assertEqual(self.tunnel._BaseTunnelManager__requesting_dump, set([reason]))
request_dump.assert_called_once()
@patch("re6st.ctl.Babel.request_dump")
def test___request_dump_not_empty(self, request_dump):
"""case when self.__requesting_dump is not empty"""
self.tunnel._BaseTunnelManager__requesting_dump = set(["rina"])
reason = "reason"
self.tunnel._BaseTunnelManager__request_dump(reason)
self.assertEqual(self.tunnel._BaseTunnelManager__requesting_dump, set([reason, "rina"]))
request_dump.assert_not_called()
def test_selectTimeout_add_callback(self):
"""case add new callback"""
self.tunnel._timeouts = [(1, self.tunnel.close)]
callback = self.tunnel.babel_dump
self.tunnel.selectTimeout(10, callback)
self.assertIn((10, callback), self.tunnel._timeouts)
def test_selectTimeout_removing(self):
"""case remove a callback"""
removed = self.tunnel.babel_dump
self.tunnel._timeouts = [(1, self.tunnel.close), (10, removed)]
self.tunnel.selectTimeout(None, removed)
self.assertEqual(self.tunnel._timeouts, [(1, self.tunnel.close)])
def test_selectTimeout_update(self):
"""case update a callback"""
updated = self.tunnel.babel_dump
self.tunnel._timeouts = [(1, self.tunnel.close), (10, updated)]
self.tunnel.selectTimeout(100, updated)
self.assertEqual(self.tunnel._timeouts, [(1, self.tunnel.close), (100, updated)])
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout") @patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
def test_invalidatePeers(self, selectTimeout): def test_invalidatePeers(self, selectTimeout):
...@@ -125,19 +57,19 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -125,19 +57,19 @@ class testBaseTunnelManager(unittest.TestCase):
self.tunnel._peers = [p1, p2, p3] self.tunnel._peers = [p1, p2, p3]
self.tunnel.invalidatePeers() self.tunnel.invalidatePeers()
self.assertEqual(self.tunnel._peers, [p1, p3]) self.assertEqual(self.tunnel._peers, [p1, p3])
selectTimeout.assert_called_once_with(p1.stop_date, self.tunnel.invalidatePeers) selectTimeout.assert_called_once_with(p1.stop_date, self.tunnel.invalidatePeers)
# Because _makeTunnel is defined in sub class of BaseTunnelManager, so i comment # Because _makeTunnel is defined in sub class of BaseTunnelManager, so i comment
# the follow test # the follow test
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True) # @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel): # def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """ # """code is 1, peer and msg not none """
# c = chr(1) # c = chr(1)
# msg = "address" # msg = "address"
# peer = x509.Peer("000001") # peer = x509.Peer("000001")
# self.tunnel._connecting = {peer} # self.tunnel._connecting = {peer}
# self.tunnel._processPacket(c + msg, peer) # self.tunnel._processPacket(c + msg, peer)
...@@ -162,7 +94,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -162,7 +94,7 @@ class testBaseTunnelManager(unittest.TestCase):
in my opion, this function return address in form address,port,portocl in my opion, this function return address in form address,port,portocl
and each address join by ; and each address join by ;
it will truncate address which has more than 3 element it will truncate address which has more than 3 element
""" """
c = chr(1) c = chr(1)
peer = x509.Peer("000001") peer = x509.Peer("000001")
peer.protocol = 1 peer.protocol = 1
...@@ -172,17 +104,6 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -172,17 +104,6 @@ class testBaseTunnelManager(unittest.TestCase):
res = self.tunnel._processPacket(c, peer) res = self.tunnel._processPacket(c, peer)
self.assertEqual(res, "1,1,1;0,0,0;2,2,2") self.assertEqual(res, "1,1,1;0,0,0;2,2,2")
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
def test_processPacket_version(self, selectTimeout):
"""code is 0, for network version, peer is none"""
c = chr(0)
self.tunnel._processPacket(c)
self.assertEqual(selectTimeout.call_args[0][1], self.tunnel.newVersion)
@patch("re6st.x509.Cert.verifyVersion", Mock(return_value=True)) @patch("re6st.x509.Cert.verifyVersion", Mock(return_value=True))
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout") @patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
...@@ -194,18 +115,18 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -194,18 +115,18 @@ class testBaseTunnelManager(unittest.TestCase):
peer = x509.Peer("000001") peer = x509.Peer("000001")
version1 = "00003" version1 = "00003"
version2 = "00007" version2 = "00007"
self.tunnel._version = "00005" self.tunnel._version = version3 = "00005"
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer) res = self.tunnel._processPacket(c + version1, peer)
self.tunnel._processPacket(c + version2, peer) self.tunnel._processPacket(c + version2, peer)
self.assertEqual(res, "00005") self.assertEqual(res, version3)
self.assertEqual(self.tunnel._version, version2) self.assertEqual(self.tunnel._version, version2)
self.assertEqual(peer.version, version2) self.assertEqual(peer.version, version2)
self.assertEqual(selectTimeout.call_args[0][1], self.tunnel.newVersion) self.assertEqual(selectTimeout.call_args[0][1], self.tunnel.newVersion)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,7 +10,7 @@ from re6st import registry ...@@ -10,7 +10,7 @@ from re6st import registry
def generate_csr(): def generate_csr():
"""generate a certificate request """generate a certificate request
return: return:
crypto.Pekey and crypto.X509Req both in pem format crypto.Pekey and crypto.X509Req both in pem format
""" """
key = crypto.PKey() key = crypto.PKey()
...@@ -27,7 +27,7 @@ def generate_csr(): ...@@ -27,7 +27,7 @@ def generate_csr():
def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
"""generate a certificate """generate a certificate
return return
crypto.X509Cert in pem format crypto.X509Cert in pem format
""" """
if type(ca) is str: if type(ca) is str:
...@@ -51,10 +51,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -51,10 +51,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.set_pubkey(req.get_pubkey()) cert.set_pubkey(req.get_pubkey())
cert.set_serial_number(serial) cert.set_serial_number(serial)
cert.sign(ca_key, 'sha512') cert.sign(ca_key, 'sha512')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) return crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
return cert
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial): def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr() pkey, csr = generate_csr()
...@@ -91,7 +88,7 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042): ...@@ -91,7 +88,7 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file: with open(cert_file, 'w') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert return key, cert
...@@ -99,7 +96,7 @@ def prefix2cn(prefix): ...@@ -99,7 +96,7 @@ def prefix2cn(prefix):
return "%u/%u" % (int(prefix, 2), len(prefix)) return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial): def serial2prefix(serial):
return bin(serial)[2:].rjust(16, '0') return bin(serial)[2:].rjust(16, '0')
# pkey: private key # pkey: private key
def decrypt(pkey, incontent): def decrypt(pkey, incontent):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment