Commit 43fc51da authored by Pedro Oliveira's avatar Pedro Oliveira

Fix messages, interfaces added in kernel, graft dst ip address, return ip address type str

parent 9290b913
...@@ -14,7 +14,7 @@ class InterfaceIGMP(object): ...@@ -14,7 +14,7 @@ class InterfaceIGMP(object):
PACKET_MR_ALLMULTI = 2 PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str): def __init__(self, interface_name: str, vif_index:int):
# RECEIVE SOCKET # RECEIVE SOCKET
rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP)) rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
...@@ -39,6 +39,9 @@ class InterfaceIGMP(object): ...@@ -39,6 +39,9 @@ class InterfaceIGMP(object):
from igmp.RouterState import RouterState from igmp.RouterState import RouterState
self.interface_state = RouterState(self) self.interface_state = RouterState(self)
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# run receive method in background # run receive method in background
receive_thread = threading.Thread(target=self.receive) receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True receive_thread.daemon = True
...@@ -47,6 +50,11 @@ class InterfaceIGMP(object): ...@@ -47,6 +50,11 @@ class InterfaceIGMP(object):
def get_ip(self): def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr'] return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
@property
def ip_interface(self):
return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"): def send(self, data: bytes, address: str="224.0.0.1"):
if self.interface_enabled: if self.interface_enabled:
self.send_socket.sendto(data, (address, 0)) self.send_socket.sendto(data, (address, 0))
......
...@@ -20,7 +20,7 @@ class InterfacePim(Interface): ...@@ -20,7 +20,7 @@ class InterfacePim(Interface):
MAX_TRIGGERED_HELLO_PERIOD = 5 MAX_TRIGGERED_HELLO_PERIOD = 5
def __init__(self, interface_name: str): def __init__(self, interface_name: str, vif_index:int):
super().__init__(interface_name) super().__init__(interface_name)
# generation id # generation id
...@@ -50,12 +50,18 @@ class InterfacePim(Interface): ...@@ -50,12 +50,18 @@ class InterfacePim(Interface):
self.neighbors = {} self.neighbors = {}
self.neighbors_lock = RWLockWrite() self.neighbors_lock = RWLockWrite()
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# run receive method in background # run receive method in background
receive_thread = threading.Thread(target=self.receive) receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True receive_thread.daemon = True
receive_thread.start() receive_thread.start()
def create_virtual_interface(self):
self.vif_index = Main.kernel.create_virtual_interface(ip_interface=self.ip_interface, interface_name=self.interface_name)
def receive(self): def receive(self):
while self.is_enabled(): while self.is_enabled():
try: try:
...@@ -66,16 +72,6 @@ class InterfacePim(Interface): ...@@ -66,16 +72,6 @@ class InterfacePim(Interface):
traceback.print_exc() traceback.print_exc()
continue continue
"""
while self.interface_enabled:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception:
traceback.print_exc()
continue
"""
def send(self, data: bytes, group_ip: str=MCAST_GRP): def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip) super().send(data=data, group_ip=group_ip)
...@@ -107,6 +103,7 @@ class InterfacePim(Interface): ...@@ -107,6 +103,7 @@ class InterfacePim(Interface):
self.send(packet.bytes()) self.send(packet.bytes())
super().remove() super().remove()
Main.kernel.remove_virtual_interface(self.ip_interface)
def add_neighbor(self, ip, random_number, hello_hold_time): def add_neighbor(self, ip, random_number, hello_hold_time):
......
...@@ -13,82 +13,6 @@ from tree.tree_if_upstream import * ...@@ -13,82 +13,6 @@ from tree.tree_if_upstream import *
from tree.tree_if_downstream import * from tree.tree_if_downstream import *
from tree.KernelEntry import KernelEntry from tree.KernelEntry import KernelEntry
"""
class KernelEntry:
def __init__(self, source_ip: str, group_ip: str, inbound_interface_index: int):
self.source_ip = source_ip
self.group_ip = group_ip
# decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
# all other interfaces = outbound
#self.outbound_interfaces = [1] * Kernel.MAXVIFS
#self.outbound_interfaces[self.inbound_interface_index] = 0
self._lock = threading.Lock()
# todo
self.state = {} # type: Dict[int, SFRMTreeInterface]
for i in range(Kernel.MAXVIFS):
if i == self.inbound_interface_index:
self.state[i] = SFRMRootInterface(self, i, False)
else:
self.state[i] = SFRMNonRootInterface(self, i)
def lock(self):
self._lock.acquire()
def unlock(self):
self._lock.release()
def get_inbound_interface_index(self):
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
# todo check state of outbound interfaces
outbound_indexes = [0]*Kernel.MAXVIFS
for (index, state) in self.state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
def check_rpf(self):
from pyroute2 import IPRoute
# from utils import if_indextoname
ipr = IPRoute()
# obter index da interface
# rpf_interface_index = ipr.get_routes(family=socket.AF_INET, dst=ip)[0]['attrs'][2][1]
# interface_name = if_indextoname(rpf_interface_index)
# return interface_name
# obter ip da interface de saida
rpf_interface_source = ipr.get_routes(family=socket.AF_INET, dst=socket.inet_ntoa(self.source_ip))[0]['attrs'][3][1]
return rpf_interface_source
def recv_data_msg(self, index):
self.state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
self.state[index].recv_assert_msg(packet, None)
def recv_prune_msg(self, index, packet):
self.state[index].recv_prune_msg(None, None)
def recv_join_msg(self, index, packet):
self.state[index].recv_join_msg(None, None)
def change(self):
# todo: changes on unicast routing or multicast routing...
Main.kernel.set_multicast_route(self)
def delete(self):
Main.kernel.remove_multicast_route(self)
"""
class Kernel: class Kernel:
# MRT # MRT
...@@ -140,9 +64,11 @@ class Kernel: ...@@ -140,9 +64,11 @@ class Kernel:
self.socket = s self.socket = s
self.rwlock = RWLockWrite() self.rwlock = RWLockWrite()
self.interface_lock = Lock()
# Create virtual interfaces # Create virtual interfaces
'''
interfaces = netifaces.interfaces() interfaces = netifaces.interfaces()
for interface in interfaces: for interface in interfaces:
try: try:
...@@ -154,6 +80,10 @@ class Kernel: ...@@ -154,6 +80,10 @@ class Kernel:
self.create_virtual_interface(ip_interface=addr, interface_name=interface) self.create_virtual_interface(ip_interface=addr, interface_name=interface)
except Exception: except Exception:
continue continue
'''
self.pim_interface = {} # name: interface_pim
self.igmp_interface = {} # name: interface_igmp
# receive signals from kernel with a background thread # receive signals from kernel with a background thread
handler_thread = threading.Thread(target=self.handler) handler_thread = threading.Thread(target=self.handler)
...@@ -174,19 +104,112 @@ class Kernel: ...@@ -174,19 +104,112 @@ class Kernel:
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */ struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
}; };
''' '''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index: int = None, flags=0x0): '''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, flags=0x0):
with self.interface_lock:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
if type(ip_interface) is str: if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface) ip_interface = socket.inet_aton(ip_interface)
if index is None:
index = len(self.vif_dic)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface, socket.inet_aton("0.0.0.0")) struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface, socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif) self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index self.vif_name_to_index_dic[interface_name] = index
with self.rwlock.genWlock():
for kernel_entry in list(self.routing.values()):
kernel_entry.new_interface(index)
return index
'''
######################### new create virtual if
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
#with self.interface_lock:
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
with self.rwlock.genWlock():
for kernel_entry in list(self.routing.values()):
kernel_entry.new_interface(index)
return index
def create_interface(self, interface_name: str, igmp:bool = False, pim:bool = False):
from InterfaceIGMP import InterfaceIGMP
from InterfacePIM import InterfacePim
if (not igmp and not pim):
return
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if pim_interface:
index = pim_interface.vif_index
elif igmp_interface:
index = igmp_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if pim and interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index)
self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface
if igmp and interface_name not in self.igmp_interface:
igmp_interface = InterfaceIGMP(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def remove_interface(self, interface_name, igmp:bool=False, pim:bool=False):
with self.interface_lock:
ip_interface = None
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
if (igmp and not igmp_interface) or (pim and not pim_interface) or (not igmp and not pim):
return
if pim:
pim_interface = self.pim_interface.pop(interface_name)
ip_interface = pim_interface.ip_interface
pim_interface.remove()
elif igmp:
igmp_interface = self.igmp_interface.pop(interface_name)
ip_interface = igmp_interface.ip_interface
igmp_interface.remove()
if (not self.igmp_interface.get(interface_name) and not self.pim_interface.get(interface_name)):
self.remove_virtual_interface(ip_interface)
def remove_virtual_interface(self, ip_interface): def remove_virtual_interface(self, ip_interface):
#with self.interface_lock:
index = self.vif_dic[ip_interface] index = self.vif_dic[ip_interface]
struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0")) struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0"))
...@@ -197,6 +220,12 @@ class Kernel: ...@@ -197,6 +220,12 @@ class Kernel:
del self.vif_index_to_name_dic[index] del self.vif_index_to_name_dic[index]
# TODO alterar MFC's para colocar a 0 esta interface # TODO alterar MFC's para colocar a 0 esta interface
with self.rwlock.genWlock():
for kernel_entry in list(self.routing.values()):
kernel_entry.remove_interface(index)
''' '''
/* Cache manipulation structures for mrouted and PIMd */ /* Cache manipulation structures for mrouted and PIMd */
...@@ -213,6 +242,10 @@ class Kernel: ...@@ -213,6 +242,10 @@ class Kernel:
''' '''
def set_multicast_route(self, kernel_entry: KernelEntry): def set_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip) source_ip = socket.inet_aton(kernel_entry.source_ip)
print("============")
print(type(kernel_entry.group_ip))
print(kernel_entry.group_ip)
print("============")
group_ip = socket.inet_aton(kernel_entry.group_ip) group_ip = socket.inet_aton(kernel_entry.group_ip)
outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes() outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes()
...@@ -230,6 +263,7 @@ class Kernel: ...@@ -230,6 +263,7 @@ class Kernel:
# TODO: ver melhor tabela routing # TODO: ver melhor tabela routing
#self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces} #self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces}
'''
def flood(self, ip_src, ip_dst, iif): def flood(self, ip_src, ip_dst, iif):
source_ip = socket.inet_aton(ip_src) source_ip = socket.inet_aton(ip_src)
group_ip = socket.inet_aton(ip_dst) group_ip = socket.inet_aton(ip_dst)
...@@ -242,6 +276,7 @@ class Kernel: ...@@ -242,6 +276,7 @@ class Kernel:
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0) #struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, iif, *outbound_interfaces_and_other_parameters) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, iif, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl) self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
'''
def remove_multicast_route(self, kernel_entry: KernelEntry): def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip) source_ip = socket.inet_aton(kernel_entry.source_ip)
...@@ -251,6 +286,7 @@ class Kernel: ...@@ -251,6 +286,7 @@ class Kernel:
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters) struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl) self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl)
self.routing.pop((kernel_entry.source_ip, kernel_entry.group_ip))
def exit(self): def exit(self):
self.running = False self.running = False
...@@ -348,7 +384,7 @@ class Kernel: ...@@ -348,7 +384,7 @@ class Kernel:
with self.rwlock.genRlock(): with self.rwlock.genRlock():
return self.routing[source_group] return self.routing[source_group]
""" """
def get_routing_entry(self, source_group: tuple, create_if_not_existent=False): def get_routing_entry(self, source_group: tuple, create_if_not_existent=True):
ip_src = source_group[0] ip_src = source_group[0]
ip_dst = source_group[1] ip_dst = source_group[1]
with self.rwlock.genRlock(): with self.rwlock.genRlock():
...@@ -361,6 +397,7 @@ class Kernel: ...@@ -361,6 +397,7 @@ class Kernel:
elif create_if_not_existent: elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst, 0) kernel_entry = KernelEntry(ip_src, ip_dst, 0)
self.routing[source_group] = kernel_entry self.routing[source_group] = kernel_entry
kernel_entry.change()
#self.set_multicast_route(kernel_entry) #self.set_multicast_route(kernel_entry)
return kernel_entry return kernel_entry
else: else:
......
...@@ -16,39 +16,45 @@ igmp = None ...@@ -16,39 +16,45 @@ igmp = None
def add_interface(interface_name, pim=False, igmp=False): def add_interface(interface_name, pim=False, igmp=False):
if pim is True and interface_name not in interfaces: #if pim is True and interface_name not in interfaces:
interface = InterfacePim(interface_name) # interface = InterfacePim(interface_name)
interfaces[interface_name] = interface # interfaces[interface_name] = interface
if igmp is True and interface_name not in igmp_interfaces: # interface.create_virtual_interface()
interface = InterfaceIGMP(interface_name) #if igmp is True and interface_name not in igmp_interfaces:
igmp_interfaces[interface_name] = interface # interface = InterfaceIGMP(interface_name)
# igmp_interfaces[interface_name] = interface
kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp)
#if pim:
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
def remove_interface(interface_name, pim=False, igmp=False): def remove_interface(interface_name, pim=False, igmp=False):
if pim is True and ((interface_name in interfaces) or interface_name == "*"): #if pim is True and ((interface_name in interfaces) or interface_name == "*"):
if interface_name == "*": # if interface_name == "*":
interface_name_list = list(interfaces.keys()) # interface_name_list = list(interfaces.keys())
else: # else:
interface_name_list = [interface_name] # interface_name_list = [interface_name]
for if_name in interface_name_list: # for if_name in interface_name_list:
interface_obj = interfaces.pop(if_name) # interface_obj = interfaces.pop(if_name)
interface_obj.remove() # interface_obj.remove()
#interfaces[if_name].remove() # #interfaces[if_name].remove()
#del interfaces[if_name] # #del interfaces[if_name]
print("removido interface") # print("removido interface")
print(interfaces) # print(interfaces)
if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"): #if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
if interface_name == "*": # if interface_name == "*":
interface_name_list = list(igmp_interfaces.keys()) # interface_name_list = list(igmp_interfaces.keys())
else: # else:
interface_name_list = [interface_name] # interface_name_list = [interface_name]
for if_name in interface_name_list: # for if_name in interface_name_list:
igmp_interfaces[if_name].remove() # igmp_interfaces[if_name].remove()
del igmp_interfaces[if_name] # del igmp_interfaces[if_name]
print("removido interface") # print("removido interface")
print(igmp_interfaces) # print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def add_protocol(protocol_number, protocol_obj): def add_protocol(protocol_number, protocol_obj):
global protocols global protocols
...@@ -199,3 +205,8 @@ def main(): ...@@ -199,3 +205,8 @@ def main():
global u global u
u = UnicastRouting.UnicastRouting() u = UnicastRouting.UnicastRouting()
global interfaces
global igmp_interfaces
interfaces = kernel.pim_interface
igmp_interfaces = kernel.igmp_interface
...@@ -39,7 +39,7 @@ class KernelEntry: ...@@ -39,7 +39,7 @@ class KernelEntry:
# (S,G) starts IG state # (S,G) starts IG state
self._was_olist_null = None self._was_olist_null = False
# todo # todo
#self._rpf_is_origin = False #self._rpf_is_origin = False
...@@ -49,7 +49,7 @@ class KernelEntry: ...@@ -49,7 +49,7 @@ class KernelEntry:
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()] self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index) #Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index)
self.interface_state = {} # type: Dict[int, TreeInterface] self.interface_state = {} # type: Dict[int, TreeInterface]
...@@ -68,7 +68,8 @@ class KernelEntry: ...@@ -68,7 +68,8 @@ class KernelEntry:
self._lock_test2 = RLock() self._lock_test2 = RLock()
self.CHANGE_STATE_LOCK = RLock() self.CHANGE_STATE_LOCK = RLock()
#self._was_olist_null = self.is_olist_null() #self._was_olist_null = self.is_olist_null()
self.change()
self.evaluate_olist_change()
print('Tree created') print('Tree created')
#self._liveliness_timer = None #self._liveliness_timer = None
#if self.is_originater(): #if self.is_originater():
...@@ -122,7 +123,8 @@ class KernelEntry: ...@@ -122,7 +123,8 @@ class KernelEntry:
def recv_graft_msg(self, index, packet): def recv_graft_msg(self, index, packet):
print("recv graft msg") print("recv graft msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_graft_msg(upstream_neighbor_address) source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip)
def recv_graft_ack_msg(self, index, packet): def recv_graft_ack_msg(self, index, packet):
print("recv graft ack msg") print("recv graft ack msg")
...@@ -218,3 +220,24 @@ class KernelEntry: ...@@ -218,3 +220,24 @@ class KernelEntry:
state.delete() state.delete()
Main.kernel.remove_multicast_route(self) Main.kernel.remove_multicast_route(self)
######################################
# Interface change
#######################################
def new_interface(self, index):
with self.CHANGE_STATE_LOCK:
self.interface_state[index] = TreeInterfaceDownstream(self, index)
self.change()
self.evaluate_olist_change()
def remove_interface(self, index):
with self.CHANGE_STATE_LOCK:
#check if removed interface is root interface
if self.inbound_interface_index == index:
self.delete()
else:
self.interface_state[index].delete()
del self.interface_state[index]
self.change()
self.evaluate_olist_change()
...@@ -105,10 +105,9 @@ class AssertStateABC(metaclass=ABCMeta): ...@@ -105,10 +105,9 @@ class AssertStateABC(metaclass=ABCMeta):
def _sendAssert_setAT(interface: "TreeInterfaceDownstream"): def _sendAssert_setAT(interface: "TreeInterfaceDownstream"):
interface.send_assert()
#interface.assert_timer.set_timer(pim_globals.ASSERT_TIME) #interface.assert_timer.set_timer(pim_globals.ASSERT_TIME)
interface.set_assert_timer(pim_globals.ASSERT_TIME) interface.set_assert_timer(pim_globals.ASSERT_TIME)
interface.send_assert()
#interface.assert_timer.reset() #interface.assert_timer.reset()
@staticmethod @staticmethod
...@@ -136,12 +135,12 @@ class NoInfoState(AssertStateABC): ...@@ -136,12 +135,12 @@ class NoInfoState(AssertStateABC):
""" """
@type interface: TreeInterface @type interface: TreeInterface
""" """
NoInfoState._sendAssert_setAT(interface)
interface.set_assert_state(AssertState.Winner) interface.set_assert_state(AssertState.Winner)
#interface.assert_winner_metric = interface.assert_metric
interface.set_assert_winner_metric(interface.my_assert_metric()) interface.set_assert_winner_metric(interface.my_assert_metric())
NoInfoState._sendAssert_setAT(interface)
#interface.assert_winner_metric = interface.assert_metric
print('receivedDataFromDownstreamIf, NI -> W') print('receivedDataFromDownstreamIf, NI -> W')
@staticmethod @staticmethod
...@@ -150,12 +149,12 @@ class NoInfoState(AssertStateABC): ...@@ -150,12 +149,12 @@ class NoInfoState(AssertStateABC):
@staticmethod @staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"): def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"):
NoInfoState._sendAssert_setAT(interface) interface.set_assert_state(AssertState.Winner)
interface.set_assert_winner_metric(interface.my_assert_metric())
NoInfoState._sendAssert_setAT(interface)
#interface.assert_state = AssertState.Winner #interface.assert_state = AssertState.Winner
interface.set_assert_state(AssertState.Winner)
#interface.assert_winner_metric = interface.assert_metric #interface.assert_winner_metric = interface.assert_metric
interface.set_assert_winner_metric(interface.my_assert_metric())
print( print(
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W') 'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W')
...@@ -174,12 +173,12 @@ class NoInfoState(AssertStateABC): ...@@ -174,12 +173,12 @@ class NoInfoState(AssertStateABC):
assert_timer_value = state_refresh_interval*3 assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value) interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser)
#interface.assert_timer.reset() #interface.assert_timer.reset()
#interface.assert_state = AssertState.Loser #interface.assert_state = AssertState.Loser
interface.set_assert_state(AssertState.Loser)
#interface.assert_winner_metric = better_metric #interface.assert_winner_metric = better_metric
interface.set_assert_winner_metric(better_metric)
# todo MUST also multicast a Prune(S,G) to the Assert winner <- TO THE colocar endereco do winner # todo MUST also multicast a Prune(S,G) to the Assert winner <- TO THE colocar endereco do winner
if interface.could_assert(): if interface.could_assert():
......
...@@ -25,7 +25,7 @@ class DownstreamStateABS(metaclass=ABCMeta): ...@@ -25,7 +25,7 @@ class DownstreamStateABS(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
@abstractstaticmethod @abstractstaticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"): def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
""" """
Receive Graft(S,G) Receive Graft(S,G)
...@@ -113,7 +113,7 @@ class NoInfo(DownstreamStateABS): ...@@ -113,7 +113,7 @@ class NoInfo(DownstreamStateABS):
print("receivedJoin, NI -> NI") print("receivedJoin, NI -> NI")
@staticmethod @staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"): def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
""" """
Receive Graft(S,G) Receive Graft(S,G)
...@@ -122,7 +122,7 @@ class NoInfo(DownstreamStateABS): ...@@ -122,7 +122,7 @@ class NoInfo(DownstreamStateABS):
# todo why pt stop???!!! # todo why pt stop???!!!
#interface.get_pt().stop() #interface.get_pt().stop()
interface.send_graft_ack() interface.send_graft_ack(source_ip)
print('receivedGraft, NI -> NI') print('receivedGraft, NI -> NI')
...@@ -205,7 +205,7 @@ class PrunePending(DownstreamStateABS): ...@@ -205,7 +205,7 @@ class PrunePending(DownstreamStateABS):
print('receivedJoin, PP -> NI') print('receivedJoin, PP -> NI')
@staticmethod @staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"): def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
""" """
Receive Graft(S,G) Receive Graft(S,G)
...@@ -216,7 +216,7 @@ class PrunePending(DownstreamStateABS): ...@@ -216,7 +216,7 @@ class PrunePending(DownstreamStateABS):
interface.clear_prune_pending_timer() interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo) interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack() interface.send_graft_ack(source_ip)
print('receivedGraft, PP -> NI') print('receivedGraft, PP -> NI')
...@@ -321,7 +321,7 @@ class Pruned(DownstreamStateABS): ...@@ -321,7 +321,7 @@ class Pruned(DownstreamStateABS):
print('receivedPrune, P -> NI') print('receivedPrune, P -> NI')
@staticmethod @staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream"): def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
""" """
Receive Graft(S,G) Receive Graft(S,G)
...@@ -330,7 +330,7 @@ class Pruned(DownstreamStateABS): ...@@ -330,7 +330,7 @@ class Pruned(DownstreamStateABS):
#interface.get_pt().stop() #interface.get_pt().stop()
interface.clear_prune_timer() interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo) interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack() interface.send_graft_ack(source_ip)
print('receivedGraft, P -> NI') print('receivedGraft, P -> NI')
......
...@@ -74,3 +74,7 @@ class AssertMetric(object): ...@@ -74,3 +74,7 @@ class AssertMetric(object):
value = ipaddress.ip_address(value) value = ipaddress.ip_address(value)
self._ip_address = value self._ip_address = value
def get_ip(self):
return str(self._ip_address)
...@@ -103,26 +103,31 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -103,26 +103,31 @@ class TreeInterfaceDownstream(TreeInterface):
def recv_prune_msg(self, upstream_neighbor_address, holdtime): def recv_prune_msg(self, upstream_neighbor_address, holdtime):
super().recv_prune_msg(upstream_neighbor_address, holdtime) super().recv_prune_msg(upstream_neighbor_address, holdtime)
# set here??? #TODO if upstream_neighbor_address == self.get_ip():
if upstream_neighbor_address == self.get_ip():
self.set_receceived_prune_holdtime(holdtime) self.set_receceived_prune_holdtime(holdtime)
self._prune_state.receivedPrune(self, holdtime) self._prune_state.receivedPrune(self, holdtime)
# Override # Override
def recv_join_msg(self, upstream_neighbor_address): def recv_join_msg(self, upstream_neighbor_address):
super().recv_join_msg(upstream_neighbor_address) super().recv_join_msg(upstream_neighbor_address)
if upstream_neighbor_address == self.get_ip():
self._prune_state.receivedJoin(self) self._prune_state.receivedJoin(self)
# Override # Override
def recv_graft_msg(self, upstream_neighbor_address): def recv_graft_msg(self, upstream_neighbor_address, source_ip):
super().recv_graft_msg(upstream_neighbor_address) print("GRAFT!!!")
self._prune_state.receivedGraft(self) super().recv_graft_msg(upstream_neighbor_address, source_ip)
if upstream_neighbor_address == self.get_ip():
self._prune_state.receivedGraft(self, source_ip)
# Override # Override
def is_forwarding(self): def is_forwarding(self):
return ((len(self.get_interface().neighbors) >= 1 and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert() return ((len(self.get_interface().neighbors) >= 1 and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
# todo wtf is boundary??!!
#return self._assert_state == AssertState.Winner and self.is_in_group() #return self._assert_state == AssertState.Winner and self.is_in_group()
def is_pruned(self): def is_pruned(self):
...@@ -138,7 +143,9 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -138,7 +143,9 @@ class TreeInterfaceDownstream(TreeInterface):
# Override # Override
def delete(self): def delete(self):
TreeInterface.delete(self) TreeInterface.delete(self)
#self._get_dipt().cancel() self.clear_assert_timer()
self.clear_prune_timer()
self.clear_prune_pending_timer()
def get_metric(self): def get_metric(self):
return AssertMetric.spt_assert_metric(self) return AssertMetric.spt_assert_metric(self)
......
...@@ -118,6 +118,7 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -118,6 +118,7 @@ class TreeInterfaceUpstream(TreeInterface):
self._graft_prune_state.seePrune(self) self._graft_prune_state.seePrune(self)
def recv_graft_ack_msg(self): def recv_graft_ack_msg(self):
print("GRAFT ACK!!!")
# todo check rpf nbr # todo check rpf nbr
self._graft_prune_state.recvGraftAckFromRPFnbr(self) self._graft_prune_state.recvGraftAckFromRPFnbr(self)
...@@ -146,6 +147,10 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -146,6 +147,10 @@ class TreeInterfaceUpstream(TreeInterface):
#Override #Override
def delete(self): def delete(self):
super().delete() super().delete()
self.clear_graft_retry_timer()
self.clear_assert_timer()
self.clear_prune_limit_timer()
self.clear_override_timer()
def is_downstream(self): def is_downstream(self):
return False return False
......
...@@ -160,7 +160,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -160,7 +160,7 @@ class TreeInterface(metaclass=ABCMeta):
if upstream_neighbor_address == self.get_ip(): if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self) self._assert_state.receivedPruneOrJoinOrGraft(self)
def recv_graft_msg(self, upstream_neighbor_address): def recv_graft_msg(self, upstream_neighbor_address, source_ip):
if upstream_neighbor_address == self.get_ip(): if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self) self._assert_state.receivedPruneOrJoinOrGraft(self)
...@@ -186,30 +186,33 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -186,30 +186,33 @@ class TreeInterface(metaclass=ABCMeta):
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo self.get_rpf_() # todo self.get_rpf_()
ph = PacketPimGraft("10.0.0.13") ip_dst = self.get_neighbor_RPF()
ph = PacketPimGraft(ip_dst)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes(), ip_dst)
#msg = GraftMsg(self.get_tree().tree_id, self.get_rpf_()) #msg = GraftMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg) #self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc()
return return
def send_graft_ack(self): def send_graft_ack(self, ip_sender):
print("send graft ack") print("send graft ack")
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo endereco?!! # todo endereco?!!
ph = PacketPimGraftAck("10.0.0.13") ph = PacketPimGraftAck(ip_sender)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes(), ip_sender)
#msg = GraftAckMsg(self.get_tree().tree_id, self.get_node()) #msg = GraftAckMsg(self.get_tree().tree_id, self.get_node())
#self.pim_if.send_mcast(msg) #self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc()
return return
...@@ -224,13 +227,14 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -224,13 +227,14 @@ class TreeInterface(metaclass=ABCMeta):
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo help ip of ph # todo help ip of ph
#ph = PacketPimJoinPrune("123.123.123.123", 210) #ph = PacketPimJoinPrune("123.123.123.123", 210)
ph = PacketPimJoinPrune("123.123.123.123", holdtime) ph = PacketPimJoinPrune(self.get_neighbor_RPF(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes())
print('sent prune msg') print('sent prune msg')
except: except:
traceback.print_exc()
return return
...@@ -246,6 +250,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -246,6 +250,7 @@ class TreeInterface(metaclass=ABCMeta):
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes())
print("send prune echo") print("send prune echo")
except: except:
traceback.print_exc()
return return
# todo # todo
#msg = PruneMsg(self.get_tree().tree_id, #msg = PruneMsg(self.get_tree().tree_id,
...@@ -258,7 +263,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -258,7 +263,7 @@ class TreeInterface(metaclass=ABCMeta):
try: try:
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
# todo help ip of ph # todo help ip of ph
ph = PacketPimJoinPrune("123.123.123.123", 210) ph = PacketPimJoinPrune(self.get_neighbor_RPF(), 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
...@@ -266,6 +271,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -266,6 +271,7 @@ class TreeInterface(metaclass=ABCMeta):
#msg = JoinMsg(self.get_tree().tree_id, self.get_rpf_()) #msg = JoinMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg) #self.pim_if.send_mcast(msg)
except: except:
traceback.print_exc()
return return
...@@ -280,6 +286,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -280,6 +286,7 @@ class TreeInterface(metaclass=ABCMeta):
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes())
except: except:
traceback.print_exc()
return return
...@@ -295,6 +302,7 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -295,6 +302,7 @@ class TreeInterface(metaclass=ABCMeta):
self.get_interface().send(pckt.bytes()) self.get_interface().send(pckt.bytes())
except: except:
traceback.print_exc()
return return
#msg = AssertMsg.new_assert_cancel(self.tree_id) #msg = AssertMsg.new_assert_cancel(self.tree_id)
#self.pim_if.send_mcast(msg) #self.pim_if.send_mcast(msg)
...@@ -388,8 +396,8 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -388,8 +396,8 @@ class TreeInterface(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
def get_rpf_(self): #def get_rpf_(self):
return self.get_neighbor_RPF() # return self.get_neighbor_RPF()
# obtain ip of RPF'(S) # obtain ip of RPF'(S)
...@@ -397,11 +405,14 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -397,11 +405,14 @@ class TreeInterface(metaclass=ABCMeta):
''' '''
RPF'(S) RPF'(S)
''' '''
if not self.is_assert_winner(): if self.i_am_assert_loser():
return self._assert_winner_ip return self._assert_winner_metric.get_ip()
else: else:
return self._kernel_entry.rpf_node return self._kernel_entry.rpf_node
def i_am_assert_loser(self):
return self._assert_state == AssertState.Loser
def is_assert_winner(self): def is_assert_winner(self):
return not self.is_downstream() and not self._assert_state == AssertState.Loser return not self.is_downstream() and not self._assert_state == AssertState.Loser
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment