Commit 2290b757 authored by Pedro Oliveira's avatar Pedro Oliveira

Prune <- almost every event done and tested; IGMP now notifies multicast...

Prune <- almost every event done and tested; IGMP now notifies multicast routing (if members are or not interested in receiving traffic from group G)
parent c6d3cf5a
...@@ -7,6 +7,7 @@ from Packet.PacketPimHeader import PacketPimHeader ...@@ -7,6 +7,7 @@ from Packet.PacketPimHeader import PacketPimHeader
from Interface import Interface from Interface import Interface
import Main import Main
from utils import HELLO_HOLD_TIME_TIMEOUT from utils import HELLO_HOLD_TIME_TIMEOUT
from Neighbor import Neighbor
class Hello: class Hello:
...@@ -20,7 +21,7 @@ class Hello: ...@@ -20,7 +21,7 @@ class Hello:
self.thread.start() self.thread.start()
def send_handle(self): def send_handle(self):
for (_, interface) in list(Main.interfaces.items()): for interface in list(Main.interfaces.values()):
self.packet_send_handle(interface) self.packet_send_handle(interface)
# reschedule timer # reschedule timer
...@@ -51,14 +52,46 @@ class Hello: ...@@ -51,14 +52,46 @@ class Hello:
# receive handler # receive handler
def receive_handle(self, packet: ReceivedPacket): def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src ip = packet.ip_header.ip_src
print("ip = ", ip) print("ip = ", ip)
options = packet.payload.payload.get_options() options = packet.payload.payload.get_options()
if Main.get_neighbor(ip) is None:
if (1 in options) and (20 in options):
hello_hold_time = options[1]
generation_id = options[20]
else:
raise Exception
with interface.neighbors_lock.genWlock():
if ip in interface.neighbors:
neighbor = interface.neighbors[ip]
else:
interface.neighbors[ip] = Neighbor(interface, ip, generation_id, hello_hold_time)
return
with neighbor.neighbor_lock:
# Already know Neighbor
print("neighbor conhecido")
neighbor.heartbeat()
if neighbor.hello_hold_time != hello_hold_time:
print("keep alive period diferente")
neighbor.set_hello_hold_time(hello_hold_time)
if neighbor.generation_id != generation_id:
print("neighbor reiniciado")
neighbor.set_generation_id(generation_id)
'''
with interface.neighbors_lock.genWlock():
#if interface.get_neighbor(ip) is None:
if ip in interface.neighbors:
# Unknown Neighbor # Unknown Neighbor
if (1 in options) and (20 in options): if (1 in options) and (20 in options):
try: try:
Main.add_neighbor(packet.interface, ip, options[20], options[1]) #Main.add_neighbor(packet.interface, ip, options[20], options[1])
print("non neighbor and options inside") print("non neighbor and options inside")
except Exception: except Exception:
# Received Neighbor with Timeout # Received Neighbor with Timeout
...@@ -78,3 +111,4 @@ class Hello: ...@@ -78,3 +111,4 @@ class Hello:
print("neighbor reiniciado") print("neighbor reiniciado")
neighbor.remove() neighbor.remove()
Main.add_neighbor(packet.interface, ip, options[20], options[1]) Main.add_neighbor(packet.interface, ip, options[20], options[1])
'''
...@@ -10,7 +10,7 @@ class IGMP: ...@@ -10,7 +10,7 @@ class IGMP:
interface = packet.interface interface = packet.interface
ip_src = packet.ip_header.ip_src ip_src = packet.ip_header.ip_src
ip_dst = packet.ip_header.ip_dst ip_dst = packet.ip_header.ip_dst
print("ip = ", ip_src) #print("ip = ", ip_src)
igmp_hdr = packet.payload igmp_hdr = packet.payload
igmp_type = igmp_hdr.type igmp_type = igmp_hdr.type
......
...@@ -5,6 +5,8 @@ import netifaces ...@@ -5,6 +5,8 @@ import netifaces
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
import Main import Main
import traceback import traceback
from RWLock.RWLock import RWLockWrite
class Interface(object): class Interface(object):
MCAST_GRP = '224.0.0.13' MCAST_GRP = '224.0.0.13'
...@@ -21,7 +23,10 @@ class Interface(object): ...@@ -21,7 +23,10 @@ class Interface(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified # explicitly join the multicast group on the interface specified
s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface)) #s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface # set socket output interface
s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface)) s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface))
...@@ -36,17 +41,29 @@ class Interface(object): ...@@ -36,17 +41,29 @@ class Interface(object):
self.interface_enabled = True self.interface_enabled = True
# generation id # generation id
self.generation_id = random.getrandbits(32) #self.generation_id = random.getrandbits(32)
# todo neighbors # todo neighbors
self.neighbors = {} #self.neighbors = {}
#self.neighbors_lock = RWLockWrite()
# run receive method in background # run receive method in background
receive_thread = threading.Thread(target=self.receive) #receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True #receive_thread.daemon = True
receive_thread.start() #receive_thread.start()
def receive(self): def receive(self):
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
else:
packet = None
return packet
except Exception:
return None
"""
while self.interface_enabled: while self.interface_enabled:
try: try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024) (raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
...@@ -56,10 +73,11 @@ class Interface(object): ...@@ -56,10 +73,11 @@ class Interface(object):
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
continue continue
"""
def send(self, data: bytes): def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data: if self.interface_enabled and data:
self.socket.sendto(data, (Interface.MCAST_GRP, 0)) self.socket.sendto(data, (group_ip, 0))
def remove(self): def remove(self):
self.interface_enabled = False self.interface_enabled = False
...@@ -68,3 +86,29 @@ class Interface(object): ...@@ -68,3 +86,29 @@ class Interface(object):
except Exception: except Exception:
pass pass
self.socket.close() self.socket.close()
def is_enabled(self):
return self.interface_enabled
def get_ip(self):
return self.ip_interface
"""
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
"""
\ No newline at end of file
...@@ -60,11 +60,11 @@ class InterfaceIGMP(object): ...@@ -60,11 +60,11 @@ class InterfaceIGMP(object):
from Packet.PacketIpHeader import PacketIpHeader from Packet.PacketIpHeader import PacketIpHeader
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \ (verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN]) struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN])
print(proto) #print(proto)
if proto != socket.IPPROTO_IGMP: if proto != socket.IPPROTO_IGMP:
continue continue
print((raw_packet, x)) #print((raw_packet, x))
packet = ReceivedPacket(raw_packet, self) packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet) Main.igmp.receive_handle(packet)
except Exception: except Exception:
......
import threading
import random
from Interface import Interface
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
class InterfacePim(Interface):
MCAST_GRP = '224.0.0.13'
def __init__(self, interface_name: str):
super().__init__(interface_name)
# generation id
self.generation_id = random.getrandbits(32)
# pim neighbors
self.neighbors = {}
self.neighbors_lock = RWLockWrite()
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def receive(self):
while self.is_enabled():
try:
packet = super().receive()
if packet:
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet)
except:
traceback.print_exc()
continue
"""
while self.interface_enabled:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception:
traceback.print_exc()
continue
"""
def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip)
def remove(self):
super().remove()
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
...@@ -43,6 +43,12 @@ class JoinPrune: ...@@ -43,6 +43,12 @@ class JoinPrune:
#Main.kernel.routing[source_group].recv_join_msg(interface_index, packet) #Main.kernel.routing[source_group].recv_join_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet) Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except: except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ??? # todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc() traceback.print_exc()
print("ATENCAO!!!!") print("ATENCAO!!!!")
...@@ -55,6 +61,12 @@ class JoinPrune: ...@@ -55,6 +61,12 @@ class JoinPrune:
#Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet) #Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet) Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except: except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ??? # todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc() traceback.print_exc()
print("ATENCAO!!!!") print("ATENCAO!!!!")
......
...@@ -273,16 +273,18 @@ class Kernel: ...@@ -273,16 +273,18 @@ class Kernel:
while self.running: while self.running:
try: try:
msg = self.socket.recv(5000) msg = self.socket.recv(5000)
print(len(msg)) #print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20]) (_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20])
if im_mbz != 0:
continue
print(im_msgtype) print(im_msgtype)
print(im_mbz) print(im_mbz)
print(im_vif) print(im_vif)
print(socket.inet_ntoa(im_src)) print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst)) print(socket.inet_ntoa(im_dst))
print(struct.unpack("II B B B B 4s 4s", msg[:20])) print(struct.unpack("II B B B B 4s 4s", msg[:20]))
if im_mbz != 0:
continue
ip_src = socket.inet_ntoa(im_src) ip_src = socket.inet_ntoa(im_src)
ip_dst = socket.inet_ntoa(im_dst) ip_dst = socket.inet_ntoa(im_dst)
...@@ -301,20 +303,58 @@ class Kernel: ...@@ -301,20 +303,58 @@ class Kernel:
# receive multicast (S,G) packet and multicast routing table has no (S,G) entry # receive multicast (S,G) packet and multicast routing table has no (S,G) entry
def igmpmsg_nocache_handler(self, ip_src, ip_dst, iif): def igmpmsg_nocache_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst) source_group_pair = (ip_src, ip_dst)
"""
with self.rwlock.genWlock(): with self.rwlock.genWlock():
if source_group_pair in self.routing: if source_group_pair in self.routing:
return kernel_entry = self.routing[(ip_src, ip_dst)]
else:
kernel_entry = KernelEntry(ip_src, ip_dst, iif)
self.routing[(ip_src, ip_dst)] = kernel_entry
self.set_multicast_route(kernel_entry)
kernel_entry.recv_data_msg(iif)
"""
"""
with self.rwlock.genRlock():
if source_group_pair in self.routing:
kernel_entry = self.routing[(ip_src, ip_dst)]
with self.rwlock.genWlock():
if source_group_pair in self.routing:
kernel_entry = self.routing[(ip_src, ip_dst)]
else:
kernel_entry = KernelEntry(ip_src, ip_dst, iif) kernel_entry = KernelEntry(ip_src, ip_dst, iif)
self.routing[(ip_src, ip_dst)] = kernel_entry self.routing[(ip_src, ip_dst)] = kernel_entry
self.set_multicast_route(kernel_entry) self.set_multicast_route(kernel_entry)
kernel_entry.recv_data_msg(iif)
"""
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
# receive multicast (S,G) packet in a outbound_interface # receive multicast (S,G) packet in a outbound_interface
def igmpmsg_wrongvif_handler(self, ip_src, ip_dst, iif): def igmpmsg_wrongvif_handler(self, ip_src, ip_dst, iif):
#kernel_entry = self.routing[(ip_src, ip_dst)] #kernel_entry = self.routing[(ip_src, ip_dst)]
kernel_entry = self.get_routing_entry((ip_src, ip_dst)) self.get_routing_entry((ip_src, ip_dst), create_if_not_existent=True).recv_data_msg(iif)
kernel_entry.recv_data_msg(iif) #kernel_entry.recv_data_msg(iif)
"""
def get_routing_entry(self, source_group: tuple): def get_routing_entry(self, source_group: tuple):
with self.rwlock.genRlock(): with self.rwlock.genRlock():
return self.routing[source_group] return self.routing[source_group]
"""
def get_routing_entry(self, source_group: tuple, create_if_not_existent=False):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if source_group in self.routing:
return self.routing[(ip_src, ip_dst)]
with self.rwlock.genWlock():
if source_group in self.routing:
return self.routing[(ip_src, ip_dst)]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst, 0)
self.routing[source_group] = kernel_entry
#self.set_multicast_route(kernel_entry)
return kernel_entry
else:
return None
\ No newline at end of file
...@@ -2,24 +2,20 @@ import netifaces ...@@ -2,24 +2,20 @@ import netifaces
import time import time
from prettytable import PrettyTable from prettytable import PrettyTable
from Interface import Interface from InterfacePIM import InterfacePim
from InterfaceIGMP import InterfaceIGMP from InterfaceIGMP import InterfaceIGMP
from Kernel import Kernel from Kernel import Kernel
from Neighbor import Neighbor
from threading import Lock from threading import Lock
interfaces = {} # interfaces with multicast routing enabled interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces igmp_interfaces = {} # igmp interfaces
neighbors = {} # multicast router neighbors
neighbors_lock = Lock()
protocols = {} protocols = {}
kernel = None kernel = None
igmp = None igmp = None
def add_interface(interface_name, pim=False, igmp=False): def add_interface(interface_name, pim=False, igmp=False):
global interfaces
if pim is True and interface_name not in interfaces: if pim is True and interface_name not in interfaces:
interface = Interface(interface_name) interface = InterfacePim(interface_name)
interfaces[interface_name] = interface interfaces[interface_name] = interface
protocols[0].force_send(interface) protocols[0].force_send(interface)
if igmp is True and interface_name not in igmp_interfaces: if igmp is True and interface_name not in igmp_interfaces:
...@@ -27,35 +23,29 @@ def add_interface(interface_name, pim=False, igmp=False): ...@@ -27,35 +23,29 @@ def add_interface(interface_name, pim=False, igmp=False):
igmp_interfaces[interface_name] = interface igmp_interfaces[interface_name] = interface
def remove_interface(interface_name, pim=False, igmp=False): def remove_interface(interface_name, pim=False, igmp=False):
global interfaces
global neighbors
if pim is True and ((interface_name in interfaces) or interface_name == "*"): if pim is True and ((interface_name in interfaces) or interface_name == "*"):
if interface_name == "*": if interface_name == "*":
interface_name = list(interfaces.keys()) interface_name_list = list(interfaces.keys())
else: else:
interface_name = [interface_name] interface_name_list = [interface_name]
for if_name in interface_name: for if_name in interface_name_list:
protocols[0].force_send_remove(interfaces[if_name]) protocols[0].force_send_remove(interfaces[if_name])
interfaces[if_name].remove() interfaces[if_name].remove()
del interfaces[if_name] del interfaces[if_name]
print("removido interface") print("removido interface")
for (ip_neighbor, neighbor) in list(neighbors.items()):
if neighbor.contact_interface not in interfaces:
neighbor.remove()
if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"): if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
if interface_name == "*": if interface_name == "*":
interface_name = list(igmp_interfaces.keys()) interface_name_list = list(igmp_interfaces.keys())
else: else:
interface_name = [interface_name] interface_name_list = [interface_name]
for if_name in interface_name: for if_name in interface_name_list:
igmp_interfaces[if_name].remove() igmp_interfaces[if_name].remove()
del igmp_interfaces[if_name] del igmp_interfaces[if_name]
print("removido interface") print("removido interface")
"""
def add_neighbor(contact_interface, ip, random_number, hello_hold_time): def add_neighbor(contact_interface, ip, random_number, hello_hold_time):
global neighbors global neighbors
with neighbors_lock: with neighbors_lock:
...@@ -81,20 +71,23 @@ def remove_neighbor(ip): ...@@ -81,20 +71,23 @@ def remove_neighbor(ip):
if ip in neighbors: if ip in neighbors:
del neighbors[ip] del neighbors[ip]
print("removido neighbor") print("removido neighbor")
"""
def add_protocol(protocol_number, protocol_obj): def add_protocol(protocol_number, protocol_obj):
global protocols global protocols
protocols[protocol_number] = protocol_obj protocols[protocol_number] = protocol_obj
def list_neighbors(): def list_neighbors():
global neighbors interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time() check_time = time.time()
t = PrettyTable(['Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"]) for interface in interfaces_list:
for ip, neighbor in list(neighbors.items()): for neighbor in interface.get_neighbors():
uptime = check_time - neighbor.time_of_last_update uptime = check_time - neighbor.time_of_last_update
uptime = 0 if (uptime < 0) else uptime uptime = 0 if (uptime < 0) else uptime
t.add_row([ip, neighbor.hello_hold_time, neighbor.generation_id, time.strftime("%H:%M:%S", time.gmtime(uptime))]) t.add_row(
[interface.interface_name, neighbor.ip, neighbor.hello_hold_time, neighbor.generation_id, time.strftime("%H:%M:%S", time.gmtime(uptime))])
print(t) print(t)
return str(t) return str(t)
...@@ -126,7 +119,7 @@ def list_enabled_interfaces(): ...@@ -126,7 +119,7 @@ def list_enabled_interfaces():
from Packet.PacketPimGraft import PacketPimGraft from Packet.PacketPimGraft import PacketPimGraft
ph = PacketPimGraft("10.0.0.13", 210) ph = PacketPimGraft("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], [])) ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph)) pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes()) interfaces[interface].send(pckt.bytes())
...@@ -165,6 +158,29 @@ def list_igmp_state(): ...@@ -165,6 +158,29 @@ def list_igmp_state():
t.add_row([interface_name, state_txt, group_addr, group_state_txt]) t.add_row([interface_name, state_txt, group_addr, group_state_txt])
return str(t) return str(t)
def list_routing_state():
routing_entries = kernel.routing.values()
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', "Is Forwarding?"])
for entry in routing_entries:
ip = entry.source_ip
group = entry.group_ip
for index in vif_indexes:
interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index]
is_forwarding = interface_state.is_forwarding()
try:
prune_state = type(interface_state._prune_state).__name__
assert_state = type(interface_state._assert_state).__name__
except:
prune_state = "-"
assert_state = "-"
t.add_row([ip, group, interface_name, prune_state, assert_state, is_forwarding])
return str(t)
def main(interfaces_to_add=[]): def main(interfaces_to_add=[]):
from Hello import Hello from Hello import Hello
......
...@@ -3,6 +3,7 @@ import time ...@@ -3,6 +3,7 @@ import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT
from Interface import Interface from Interface import Interface
import Main import Main
from threading import Lock
class Neighbor: class Neighbor:
...@@ -16,6 +17,10 @@ class Neighbor: ...@@ -16,6 +17,10 @@ class Neighbor:
self.hello_hold_time = None self.hello_hold_time = None
self.set_hello_hold_time(hello_hold_time) self.set_hello_hold_time(hello_hold_time)
self.time_of_last_update = time.time() self.time_of_last_update = time.time()
self.neighbor_lock = Lock()
# todo
Main.protocols[0].force_send(contact_interface)
def set_hello_hold_time(self, hello_hold_time: int): def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time self.hello_hold_time = hello_hold_time
...@@ -30,6 +35,14 @@ class Neighbor: ...@@ -30,6 +35,14 @@ class Neighbor:
else: else:
self.neighbor_liveness_timer = None self.neighbor_liveness_timer = None
def set_generation_id(self, generation_id):
if self.generation_id is None:
self.generation_id = generation_id
elif self.generation_id != generation_id:
self.generation_id = generation_id
self.set_hello_hold_time(self.hello_hold_time)
self.time_of_last_update = time.time()
def heartbeat(self): def heartbeat(self):
if (self.hello_hold_time != HELLO_HOLD_TIME_TIMEOUT) and \ if (self.hello_hold_time != HELLO_HOLD_TIME_TIMEOUT) and \
(self.hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT): (self.hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT):
...@@ -44,4 +57,5 @@ class Neighbor: ...@@ -44,4 +57,5 @@ class Neighbor:
print('HELLO TIMER EXPIRED... remove neighbor') print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None: if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel() self.neighbor_liveness_timer.cancel()
Main.remove_neighbor(self.ip) #Main.remove_neighbor(self.ip)
del self.contact_interface.neighbors[self.ip]
...@@ -60,18 +60,18 @@ class PacketIGMPHeader(PacketPayload): ...@@ -60,18 +60,18 @@ class PacketIGMPHeader(PacketPayload):
@staticmethod @staticmethod
def parse_bytes(data: bytes): def parse_bytes(data: bytes):
print("parseIGMPHdr: ", data) #print("parseIGMPHdr: ", data)
igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN] igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN]
(type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr) (type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr)
print(type, max_resp_time, rcv_checksum, group_address) #print(type, max_resp_time, rcv_checksum, group_address)
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:] msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
print("checksum calculated: " + str(checksum(msg_to_checksum))) #print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum: if checksum(msg_to_checksum) != rcv_checksum:
print("wrong checksum") #print("wrong checksum")
raise Exception("wrong checksum") raise Exception("wrong checksum")
igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:] igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:]
......
...@@ -43,5 +43,5 @@ from Packet.PacketPimJoinPrune import PacketPimJoinPrune ...@@ -43,5 +43,5 @@ from Packet.PacketPimJoinPrune import PacketPimJoinPrune
class PacketPimGraft(PacketPimJoinPrune): class PacketPimGraft(PacketPimJoinPrune):
PIM_TYPE = 6 PIM_TYPE = 6
def __init__(self, upstream_neighbor_address, hold_time): def __init__(self, upstream_neighbor_address):
super().__init__(upstream_neighbor_address, hold_time) super().__init__(upstream_neighbor_address=upstream_neighbor_address, hold_time=0)
...@@ -62,7 +62,7 @@ class MyDaemon(Daemon): ...@@ -62,7 +62,7 @@ class MyDaemon(Daemon):
elif args.list_neighbors: elif args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors())) connection.sendall(pickle.dumps(Main.list_neighbors()))
elif args.list_state: elif args.list_state:
connection.sendall(pickle.dumps(Main.list_igmp_state())) connection.sendall(pickle.dumps(Main.list_igmp_state() + "\n\n\n\n\n\n" + Main.list_routing_state()))
elif args.add_interface: elif args.add_interface:
Main.add_interface(args.add_interface[0], pim=True) Main.add_interface(args.add_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR) connection.shutdown(socket.SHUT_RDWR)
......
...@@ -18,6 +18,10 @@ class GroupState(object): ...@@ -18,6 +18,10 @@ class GroupState(object):
# lock # lock
self.lock = Lock() self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self): def print_state(self):
return self.state.print_state() return self.state.print_state()
...@@ -99,3 +103,30 @@ class GroupState(object): ...@@ -99,3 +103,30 @@ class GroupState(object):
def receive_group_specific_query(self, max_response_time: int): def receive_group_specific_query(self, max_response_time: int):
with self.lock: with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time) self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoMembersPresent
...@@ -5,7 +5,7 @@ from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQ ...@@ -5,7 +5,7 @@ from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQ
from .querier.Querier import Querier from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState from .GroupState import GroupState
from RWLock.RWLock import RWLockWrite
if TYPE_CHECKING: if TYPE_CHECKING:
from InterfaceIGMP import InterfaceIGMP from InterfaceIGMP import InterfaceIGMP
...@@ -22,6 +22,7 @@ class RouterState(object): ...@@ -22,6 +22,7 @@ class RouterState(object):
# state of each group # state of each group
# Key: GroupIPAddress, Value: GroupState object # Key: GroupIPAddress, Value: GroupState object
self.group_state = {} self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query # send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10) packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
...@@ -80,24 +81,40 @@ class RouterState(object): ...@@ -80,24 +81,40 @@ class RouterState(object):
############################################ ############################################
# group state methods # group state methods
############################################ ############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_v1_membership_report(self, packet: ReceivedPacket): def receive_v1_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address igmp_group = packet.payload.group_address
if igmp_group not in self.group_state: #if igmp_group not in self.group_state:
self.group_state[igmp_group] = GroupState(self, igmp_group) # self.group_state[igmp_group] = GroupState(self, igmp_group)
self.group_state[igmp_group].receive_v1_membership_report() #self.group_state[igmp_group].receive_v1_membership_report()
self.get_group_state(igmp_group).receive_v1_membership_report()
def receive_v2_membership_report(self, packet: ReceivedPacket): def receive_v2_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address igmp_group = packet.payload.group_address
if igmp_group not in self.group_state: #if igmp_group not in self.group_state:
self.group_state[igmp_group] = GroupState(self, igmp_group) # self.group_state[igmp_group] = GroupState(self, igmp_group)
self.group_state[igmp_group].receive_v2_membership_report() #self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(igmp_group).receive_v2_membership_report()
def receive_leave_group(self, packet: ReceivedPacket): def receive_leave_group(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address igmp_group = packet.payload.group_address
if igmp_group in self.group_state: #if igmp_group in self.group_state:
self.group_state[igmp_group].receive_leave_group() # self.group_state[igmp_group].receive_leave_group()
self.get_group_state(igmp_group).receive_leave_group()
def receive_query(self, packet: ReceivedPacket): def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet) self.interface_state.receive_query(self, packet)
...@@ -105,5 +122,7 @@ class RouterState(object): ...@@ -105,5 +122,7 @@ class RouterState(object):
# process group specific query # process group specific query
if igmp_group != "0.0.0.0" and igmp_group in self.group_state: if igmp_group != "0.0.0.0" and igmp_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_time max_response_time = packet.payload.max_resp_time
self.group_state[igmp_group].receive_group_specific_query(max_response_time) #self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(igmp_group).receive_group_specific_query(max_response_time)
\ No newline at end of file
...@@ -7,10 +7,10 @@ if TYPE_CHECKING: ...@@ -7,10 +7,10 @@ if TYPE_CHECKING:
def group_membership_timeout(group_state: 'GroupState'): def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'): def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -7,10 +7,11 @@ if TYPE_CHECKING: ...@@ -7,10 +7,11 @@ if TYPE_CHECKING:
def group_membership_timeout(group_state: 'GroupState'): def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'): def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -25,12 +25,12 @@ def receive_v1_membership_report(group_state: 'GroupState'): ...@@ -25,12 +25,12 @@ def receive_v1_membership_report(group_state: 'GroupState'):
def receive_v2_membership_report(group_state: 'GroupState'): def receive_v2_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
group_state.set_timer() group_state.set_timer()
group_state.state = MembersPresent group_state.state = MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'): def receive_leave_group(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -6,11 +6,12 @@ if TYPE_CHECKING: ...@@ -6,11 +6,12 @@ if TYPE_CHECKING:
def group_membership_timeout(group_state: 'GroupState'): def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.clear_retransmit_timer() group_state.clear_retransmit_timer()
group_state.state = NoMembersPresent group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'): def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -6,9 +6,11 @@ if TYPE_CHECKING: ...@@ -6,9 +6,11 @@ if TYPE_CHECKING:
def group_membership_timeout(group_state: 'GroupState'): def group_membership_timeout(group_state: 'GroupState'):
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'): def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -21,21 +21,21 @@ def retransmit_timeout(group_state: 'GroupState'): ...@@ -21,21 +21,21 @@ def retransmit_timeout(group_state: 'GroupState'):
def receive_v1_membership_report(group_state: 'GroupState'): def receive_v1_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
group_state.set_timer() group_state.set_timer()
group_state.set_v1_host_timer() group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent group_state.state = Version1MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_v2_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer() group_state.set_timer()
group_state.state = MembersPresent group_state.state = MembersPresent
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'): def receive_leave_group(group_state: 'GroupState'):
# do nothing # do nothing
......
...@@ -6,9 +6,11 @@ if TYPE_CHECKING: ...@@ -6,9 +6,11 @@ if TYPE_CHECKING:
def group_membership_timeout(group_state: 'GroupState'): def group_membership_timeout(group_state: 'GroupState'):
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent group_state.state = NoMembersPresent
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'): def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.state = MembersPresent group_state.state = MembersPresent
......
import Main import Main
import socket import socket
import struct
import netifaces import netifaces
import threading
from tree.root_interface import SFRMRootInterface from tree.root_interface import SFRMRootInterface
from tree.non_root_interface import SFRMNonRootInterface from tree.non_root_interface import SFRMNonRootInterface
from threading import Timer from threading import Timer, Lock
class KernelEntry: class KernelEntry:
TREE_TIMEOUT = 180 TREE_TIMEOUT = 180
...@@ -18,35 +16,42 @@ class KernelEntry: ...@@ -18,35 +16,42 @@ class KernelEntry:
# ip of neighbor of the rpf # ip of neighbor of the rpf
self._rpf_node = None self._rpf_node = None
self._has_members = True # todo check via igmp # (S,G) starts IG state
self._was_in_group = True self._was_in_group = True
self._rpf_is_origin = False
self._liveliness_timer = None
# todo
self._rpf_is_origin = False
# decide inbound interface based on rpf check # decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()] self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
#Main.kernel.flood(ip_src=source_ip, ip_dst=group_ip, iif=self.inbound_interface_index)
#import time Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index)
#time.sleep(5)
self.interface_state = {} # type: Dict[int, SFRMTreeInterface] self.interface_state = {} # type: Dict[int, SFRMTreeInterface]
for i in range(Main.kernel.MAXVIFS): #for i in range(Main.kernel.MAXVIFS):
for i in Main.kernel.vif_index_to_name_dic.keys():
try:
if i == self.inbound_interface_index: if i == self.inbound_interface_index:
self.interface_state[i] = SFRMRootInterface(self, i, False) self.interface_state[i] = SFRMRootInterface(self, i, False)
else: else:
self.interface_state[i] = SFRMNonRootInterface(self, i) self.interface_state[i] = SFRMNonRootInterface(self, i)
except:
continue
print('Tree created') self._multicast_change = Lock()
self.evaluate_ingroup() self._lock_test2 = Lock()
self.CHANGE_STATE_LOCK = Lock()
print('Tree created')
self._liveliness_timer = None
if self.is_originater(): if self.is_originater():
self.set_liveliness_timer() self.set_liveliness_timer()
print('set SAT') print('set SAT')
self._lock = threading.RLock() #self._lock = threading.RLock()
def get_inbound_interface_index(self): def get_inbound_interface_index(self):
...@@ -76,7 +81,7 @@ class KernelEntry: ...@@ -76,7 +81,7 @@ class KernelEntry:
if self.is_originater(): if self.is_originater():
self.clear_liveliness_timer() self.clear_liveliness_timer()
self.interface_state[index].recv_data_msg() self.interface_state[index].recv_data_msg(None, None)
def recv_assert_msg(self, index, packet): def recv_assert_msg(self, index, packet):
print("recv assert msg") print("recv assert msg")
...@@ -128,16 +133,24 @@ class KernelEntry: ...@@ -128,16 +133,24 @@ class KernelEntry:
def is_in_group(self): def is_in_group(self):
# todo # todo
#if self.get_has_members(): #if self.get_has_members():
if True: #if True:
# return True
"""
for index in Main.kernel.vif_index_to_name_dic.keys():
if self.interface_state[index].is_forwarding():
return True return True
return False
"""
for interface in self.interface_state.values(): for interface in self.interface_state.values():
if interface.is_forwarding(): if interface.is_forwarding():
return True return True
return False return False
def evaluate_ingroup(self): def evaluate_ingroup(self):
with self._lock_test2:
is_ig = self.is_in_group() is_ig = self.is_in_group()
if self._was_in_group != is_ig: if self._was_in_group != is_ig:
...@@ -175,7 +188,7 @@ class KernelEntry: ...@@ -175,7 +188,7 @@ class KernelEntry:
def change(self): def change(self):
# todo: changes on unicast routing or multicast routing... # todo: changes on unicast routing or multicast routing...
with self._multicast_change:
Main.kernel.set_multicast_route(self) Main.kernel.set_multicast_route(self)
def delete(self): def delete(self):
......
...@@ -74,7 +74,7 @@ class SFMRAssertWinner(SFMRAssertABC): ...@@ -74,7 +74,7 @@ class SFMRAssertWinner(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('data_arrival, W -> W') print('data_arrival, W -> W')
interface.send_assert() interface.send_assert()
@staticmethod @staticmethod
...@@ -83,7 +83,7 @@ class SFMRAssertWinner(SFMRAssertABC): ...@@ -83,7 +83,7 @@ class SFMRAssertWinner(SFMRAssertABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric @type metric: SFMRAssertMetric
''' '''
interface.rprint('recv_better_metric, W -> L') print('recv_better_metric, W -> L')
interface._set_assert_state(AssertState.Looser) interface._set_assert_state(AssertState.Looser)
interface._set_winner_metric(metric) interface._set_winner_metric(metric)
...@@ -94,7 +94,7 @@ class SFMRAssertWinner(SFMRAssertABC): ...@@ -94,7 +94,7 @@ class SFMRAssertWinner(SFMRAssertABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric @type metric: SFMRAssertMetric
''' '''
interface.rprint('recv_worse_metric, W -> W') print('recv_worse_metric, W -> W')
interface.send_assert() interface.send_assert()
...@@ -118,14 +118,14 @@ class SFMRAssertWinner(SFMRAssertABC): ...@@ -118,14 +118,14 @@ class SFMRAssertWinner(SFMRAssertABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.send_reset() interface.send_reset()
interface.rprint('aw_rpc_worsens, W -> W') print('aw_rpc_worsens, W -> W')
@staticmethod @staticmethod
def is_now_root(interface): def is_now_root(interface):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('is_now_root, W -> W') print('is_now_root, W -> W')
interface.send_reset() interface.send_reset()
...@@ -134,14 +134,14 @@ class SFMRAssertWinner(SFMRAssertABC): ...@@ -134,14 +134,14 @@ class SFMRAssertWinner(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('recv_reset, W -> W') print('recv_reset, W -> W')
@staticmethod @staticmethod
def is_now_pruned(interface): def is_now_pruned(interface):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('is_now_pruned, W -> W') print('is_now_pruned, W -> W')
class SFMRAssertLooser(SFMRAssertABC): class SFMRAssertLooser(SFMRAssertABC):
...@@ -150,7 +150,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -150,7 +150,7 @@ class SFMRAssertLooser(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('data_arrival, L -> L') print('data_arrival, L -> L')
@staticmethod @staticmethod
def recv_better_metric(interface, metric): def recv_better_metric(interface, metric):
...@@ -158,7 +158,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -158,7 +158,7 @@ class SFMRAssertLooser(SFMRAssertABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric @type metric: SFMRAssertMetric
''' '''
interface.rprint('recv_better_metric, L -> L') print('recv_better_metric, L -> L')
interface._set_winner_metric(metric) interface._set_winner_metric(metric)
...@@ -168,7 +168,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -168,7 +168,7 @@ class SFMRAssertLooser(SFMRAssertABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric @type metric: SFMRAssertMetric
''' '''
interface.rprint('recv_worse_metric, L -> W') print('recv_worse_metric, L -> W')
interface.send_assert() interface.send_assert()
interface._set_assert_state(AssertState.Winner) interface._set_assert_state(AssertState.Winner)
...@@ -179,7 +179,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -179,7 +179,7 @@ class SFMRAssertLooser(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('aw_failure, L -> W') print('aw_failure, L -> W')
interface._set_assert_state(AssertState.Winner) interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None) interface._set_winner_metric(None)
...@@ -190,7 +190,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -190,7 +190,7 @@ class SFMRAssertLooser(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('al_rpc_improves, L -> W') print('al_rpc_improves, L -> W')
interface._set_assert_state(AssertState.Winner) interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None) interface._set_winner_metric(None)
...@@ -209,14 +209,14 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -209,14 +209,14 @@ class SFMRAssertLooser(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('is_now_root, L -> L') print('is_now_root, L -> L')
@staticmethod @staticmethod
def recv_reset(interface): def recv_reset(interface):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('recv_reset, L -> W') print('recv_reset, L -> W')
interface._set_assert_state(AssertState.Winner) interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None) interface._set_winner_metric(None)
...@@ -227,7 +227,7 @@ class SFMRAssertLooser(SFMRAssertABC): ...@@ -227,7 +227,7 @@ class SFMRAssertLooser(SFMRAssertABC):
''' '''
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
interface.rprint('is_now_pruned, L -> W') print('is_now_pruned, L -> W')
interface._set_assert_state(AssertState.Winner) interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None) interface._set_winner_metric(None)
......
...@@ -15,23 +15,23 @@ from .prune import SFMRPruneState, SFMRPruneStateABC ...@@ -15,23 +15,23 @@ from .prune import SFMRPruneState, SFMRPruneStateABC
from .tree_interface import SFRMTreeInterface from .tree_interface import SFRMTreeInterface
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert from Packet.PacketPimAssert import PacketPimAssert
from threading import Lock
class SFRMNonRootInterface(SFRMTreeInterface): class SFRMNonRootInterface(SFRMTreeInterface):
DIPT_TIME = 3.0 DIPT_TIME = 3.0
def __init__(self, kernel_entry, interface_id): def __init__(self, kernel_entry, interface_id):
SFRMTreeInterface.__init__(self, kernel_entry, interface_id, None) SFRMTreeInterface.__init__(self, kernel_entry, interface_id)
self._assert_state = AssertState.Winner self._assert_state = AssertState.Winner
self._assert_metric = None self._assert_metric = None
self._prune_state = SFMRPruneState.DIP self._prune_state = SFMRPruneState.DIP
#self._dipt = Timer(SFRMNonRootInterface.DIPT_TIME, self.__dipt_expires)
#self._dipt.start()
self._dipt = None self._dipt = None
self.set_dipt_timer() self.set_dipt_timer()
self.send_prune() self.send_prune()
# Override # Override
def recv_data_msg(self, msg=None, sender=None): def recv_data_msg(self, msg=None, sender=None):
if self._prune_state != SFMRPruneState.NDI: if self._prune_state != SFMRPruneState.NDI:
...@@ -51,9 +51,6 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -51,9 +51,6 @@ class SFRMNonRootInterface(SFRMTreeInterface):
else: else:
winner_metric = self.get_metric() winner_metric = self.get_metric()
ip_sender = msg.ip_header.ip_src ip_sender = msg.ip_header.ip_src
pkt_assert = msg.payload.payload # type: PacketPimAssert pkt_assert = msg.payload.payload # type: PacketPimAssert
msg_metric = SFMRAssertMetric(metric_preference=pkt_assert.metric_preference, route_metric=pkt_assert.metric, ip_address=ip_sender) msg_metric = SFMRAssertMetric(metric_preference=pkt_assert.metric_preference, route_metric=pkt_assert.metric, ip_address=ip_sender)
...@@ -74,21 +71,15 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -74,21 +71,15 @@ class SFRMNonRootInterface(SFRMTreeInterface):
# Override # Override
def recv_prune_msg(self, msg, sender, in_group): def recv_prune_msg(self, msg, sender, in_group):
super().recv_prune_msg(msg, sender, in_group) super().recv_prune_msg(msg, sender, in_group)
#with self.prune_lock:
self._prune_state.recv_prune(self) self._prune_state.recv_prune(self)
# Override # Override
def recv_join_msg(self, msg, sender, in_group): def recv_join_msg(self, msg, sender, in_group):
super().recv_join_msg(msg, sender, in_group) super().recv_join_msg(msg, sender, in_group)
#with self.prune_lock:
self._prune_state.recv_join(self) self._prune_state.recv_join(self)
def forward_data_msg(self, msg):
pass
#def forward_data_msg(self, msg):
# if self.is_forwarding():
# self._interface.send_mcast(msg)
def send_assert(self): def send_assert(self):
(source, group) = self.get_tree_id() (source, group) = self.get_tree_id()
from Packet.Packet import Packet from Packet.Packet import Packet
...@@ -111,17 +102,17 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -111,17 +102,17 @@ class SFRMNonRootInterface(SFRMTreeInterface):
def send_prune(self): def send_prune(self):
SFRMTreeInterface.send_prune(self) SFRMTreeInterface.send_prune(self)
#if self._dipt.is_ticking():
if self._dipt.is_alive():
self._dipt.cancel()
# Override # Override
def is_forwarding(self): def is_forwarding(self):
return self._assert_state == AssertState.Winner \ return self._assert_state == AssertState.Winner \
and self._prune_state != SFMRPruneState.NDI and (self.igmp_has_members() or not self.is_pruned())
def is_pruned(self):
return self._prune_state == SFMRPruneState.NDI
# Override # Override
def nbr_died(self, node): def nbr_died(self, node):
# todo
if self._get_winner_metric() is not None \ if self._get_winner_metric() is not None \
and self._get_winner_metric().get_ip_address() == node\ and self._get_winner_metric().get_ip_address() == node\
and self._prune_state != SFMRPruneState.NDI: and self._prune_state != SFMRPruneState.NDI:
...@@ -146,21 +137,19 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -146,21 +137,19 @@ class SFRMNonRootInterface(SFRMTreeInterface):
def __dipt_expires(self): def __dipt_expires(self):
print('DIPT expired') print('DIPT expired')
self._prune_state.dipt_expires(self) self._prune_state.dipt_expires(self)
def get_metric(self): def get_metric(self):
return SFMRAssertMetric.spt_assert_metric(self) return SFMRAssertMetric.spt_assert_metric(self)
def _set_assert_state(self, value): def _set_assert_state(self, value: SFMRAssertABC):
assert isinstance(value, SFMRAssertABC) with self.get_state_lock():
if value != self._assert_state: if value != self._assert_state:
self._assert_state = value self._assert_state = value
self.change_tree()
self.evaluate_ingroup() self.evaluate_ingroup()
#Convergence.mark_change() #Convergence.mark_change()
self.change_tree()
def _get_winner_metric(self): def _get_winner_metric(self):
''' '''
...@@ -170,11 +159,13 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -170,11 +159,13 @@ class SFRMNonRootInterface(SFRMTreeInterface):
def _set_winner_metric(self, value): def _set_winner_metric(self, value):
assert isinstance(value, SFMRAssertMetric) or value is None assert isinstance(value, SFMRAssertMetric) or value is None
# todo
self._assert_metric = value self._assert_metric = value
# Override # Override
def set_cost(self, value): def set_cost(self, value):
# todo
"""
if value != self._cost and self._prune_state != SFMRPruneState.NDI: if value != self._cost and self._prune_state != SFMRPruneState.NDI:
if self.is_forwarding() and value > self._cost: if self.is_forwarding() and value > self._cost:
SFRMTreeInterface.set_cost(self, value) SFRMTreeInterface.set_cost(self, value)
...@@ -189,16 +180,16 @@ class SFRMNonRootInterface(SFRMTreeInterface): ...@@ -189,16 +180,16 @@ class SFRMNonRootInterface(SFRMTreeInterface):
SFRMTreeInterface.set_cost(self, value) SFRMTreeInterface.set_cost(self, value)
else: else:
SFRMTreeInterface.set_cost(self, value) SFRMTreeInterface.set_cost(self, value)
"""
raise NotImplemented
def _set_prune_state(self, value): def _set_prune_state(self, value: SFMRPruneStateABC):
assert isinstance(value, SFMRPruneStateABC) with self.get_state_lock():
if value != self._prune_state: if value != self._prune_state:
self._prune_state = value self._prune_state = value
self.evaluate_ingroup()
#Convergence.mark_change()
self.change_tree() self.change_tree()
self.evaluate_ingroup()
if value == SFMRPruneState.NDI: if value == SFMRPruneState.NDI:
self._assert_state.is_now_pruned(self) self._assert_state.is_now_pruned(self)
......
...@@ -52,7 +52,7 @@ class SFMRDownstreamInterested(SFMRPruneStateABC): ...@@ -52,7 +52,7 @@ class SFMRDownstreamInterested(SFMRPruneStateABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
if len(interface.get_interface().neighbors) == 1: if len(interface.get_interface().neighbors) <= 1:
print('recv_prune, DI -> NDI (only 1 nbr)') print('recv_prune, DI -> NDI (only 1 nbr)')
interface._set_prune_state(SFMRPruneState.NDI) interface._set_prune_state(SFMRPruneState.NDI)
...@@ -110,13 +110,13 @@ class SFMRDownstreamInterestedPending(SFMRPruneStateABC): ...@@ -110,13 +110,13 @@ class SFMRDownstreamInterestedPending(SFMRPruneStateABC):
@type interface: SFRMNonRootInterface @type interface: SFRMNonRootInterface
''' '''
# TODO foi alterado pelo Pedro... necessita de verificacao se esta OK... # TODO foi alterado pelo Pedro... necessita de verificacao se esta OK...
print('recv_prune, DIP -> DIP') #print('recv_prune, DIP -> DIP')
if len(interface.get_interface().neighbors) == 1: if len(interface.get_interface().neighbors) <= 1:
print('recv_prune, DIP -> DI (only 1 nbr)') print('recv_prune, DIP -> NDI (only 1 nbr)')
else:
print('recv_prune, DIP -> NDI')
interface._set_prune_state(SFMRPruneState.NDI) interface._set_prune_state(SFMRPruneState.NDI)
interface.clear_dipt_timer() interface.clear_dipt_timer()
else:
print('recv_prune, DIP -> DIP')
@staticmethod @staticmethod
def recv_join(interface): def recv_join(interface):
...@@ -145,7 +145,8 @@ class SFMRDownstreamInterestedPending(SFMRPruneStateABC): ...@@ -145,7 +145,8 @@ class SFMRDownstreamInterestedPending(SFMRPruneStateABC):
print('is_now_root, DIP -> DI') print('is_now_root, DIP -> DI')
interface._set_prune_state(SFMRPruneState.DI) interface._set_prune_state(SFMRPruneState.DI)
interface._get_dipt().stop() #interface._get_dipt().stop()
interface.clear_dipt_timer()
@staticmethod @staticmethod
def new_nbr(interface): def new_nbr(interface):
......
...@@ -21,9 +21,16 @@ class SFRMRootInterface(SFRMTreeInterface): ...@@ -21,9 +21,16 @@ class SFRMRootInterface(SFRMTreeInterface):
evaluate_ig_cb, evaluate_ig_cb,
is_originater: bool, ): is_originater: bool, ):
''' '''
SFRMTreeInterface.__init__(self, kernel_entry, interface_id, None) SFRMTreeInterface.__init__(self, kernel_entry, interface_id)
self._is_originater = is_originater self._is_originater = is_originater
def recv_data_msg(self, msg, sender):
#with self.CHANGE_STATE_LOCK:
# self._kernel_entry.evaluate_ingroup()
#if not self._kernel_entry.is_in_group():
# self.send_prune()
return
#Override #Override
#def recv_assert_msg(self, msg: SFMRAssertMsg, sender: Addr): #def recv_assert_msg(self, msg: SFMRAssertMsg, sender: Addr):
def recv_assert_msg(self, msg, sender): def recv_assert_msg(self, msg, sender):
...@@ -33,15 +40,13 @@ class SFRMRootInterface(SFRMTreeInterface): ...@@ -33,15 +40,13 @@ class SFRMRootInterface(SFRMTreeInterface):
def recv_prune_msg(self, msg, sender, in_group): def recv_prune_msg(self, msg, sender, in_group):
super().recv_prune_msg(msg, sender, in_group) super().recv_prune_msg(msg, sender, in_group)
if in_group: #if in_group:
with self._kernel_entry._lock_test2:
if self._kernel_entry._was_in_group:
print("I WILL SEND JOIN") print("I WILL SEND JOIN")
self.send_join() self.send_join()
print("I SENT JOIN") print("I SENT JOIN")
def forward_data_msg(self, msg):
pass
def send_join(self): def send_join(self):
# Originaters dont need to send prunes or joins # Originaters dont need to send prunes or joins
if self._is_originater: if self._is_originater:
......
...@@ -5,13 +5,16 @@ Created on Jul 16, 2015 ...@@ -5,13 +5,16 @@ Created on Jul 16, 2015
''' '''
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import Main import Main
from threading import Lock, RLock
import traceback
#from convergence import Convergence #from convergence import Convergence
#from sfmr.messages.prune import SFMRPruneMsg #from sfmr.messages.prune import SFMRPruneMsg
#from .router_interface import SFMRInterface #from .router_interface import SFMRInterface
class SFRMTreeInterface(metaclass=ABCMeta): class SFRMTreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id, evaluate_ig_cb): def __init__(self, kernel_entry, interface_id):
''' '''
@type interface: SFMRInterface @type interface: SFMRInterface
@type node: Node @type node: Node
...@@ -24,10 +27,20 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -24,10 +27,20 @@ class SFRMTreeInterface(metaclass=ABCMeta):
#self._node = node #self._node = node
#self._tree_id = tree_id #self._tree_id = tree_id
#self._cost = cost #self._cost = cost
self._evaluate_ig = evaluate_ig_cb #self._evaluate_ig = evaluate_ig_cb
try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip)
self._igmp_has_members = group_state.add_multicast_routing_entry(self)
except:
#traceback.print_exc()
self._igmp_has_members = False
self._igmp_lock = RLock()
#self.rprint('new ' + self.__class__.__name__) #self.rprint('new ' + self.__class__.__name__)
#Convergence.mark_change()
def recv_data_msg(self, msg, sender): def recv_data_msg(self, msg, sender):
pass pass
...@@ -40,19 +53,15 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -40,19 +53,15 @@ class SFRMTreeInterface(metaclass=ABCMeta):
pass pass
def recv_prune_msg(self, msg, sender, in_group): def recv_prune_msg(self, msg, sender, in_group):
print("SUPER PRUNE")
pass pass
def recv_join_msg(self, msg, sender, in_group): def recv_join_msg(self, msg, sender, in_group):
print("SUPER JOIN")
pass
@abstractmethod
def forward_data_msg(self, msg):
pass pass
def forward_state_reset_msg(self, msg): def forward_state_reset_msg(self, msg):
self._interface.send_mcast(msg) #self._interface.send_mcast(msg)
# todo
raise NotImplemented
def send_prune(self): def send_prune(self):
try: try:
...@@ -91,8 +100,20 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -91,8 +100,20 @@ class SFRMTreeInterface(metaclass=ABCMeta):
print('Tree Interface deleted') print('Tree Interface deleted')
def evaluate_ingroup(self): def evaluate_ingroup(self):
# todo help self._evaluate_ig() self._kernel_entry.evaluate_ingroup()
return
def notify_igmp(self, has_members: bool):
with self.get_state_lock():
with self._igmp_lock:
if has_members != self._igmp_has_members:
self._igmp_has_members = has_members
self.change_tree()
self.evaluate_ingroup()
def igmp_has_members(self):
with self._igmp_lock:
return self._igmp_has_members
''' '''
def rprint(self, msg, *entrys): def rprint(self, msg, *entrys):
...@@ -107,10 +128,10 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -107,10 +128,10 @@ class SFRMTreeInterface(metaclass=ABCMeta):
return '{}<{}>'.format(self.__class__, self._interface.get_link()) return '{}<{}>'.format(self.__class__, self._interface.get_link())
def get_link(self): def get_link(self):
# todo
return self._interface.get_link() return self._interface.get_link()
def get_interface(self): def get_interface(self):
import Main
kernel = Main.kernel kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id] interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name] interface = Main.interfaces[interface_name]
...@@ -121,20 +142,20 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -121,20 +142,20 @@ class SFRMTreeInterface(metaclass=ABCMeta):
return self.get_ip() return self.get_ip()
def get_ip(self): def get_ip(self):
import Main #kernel = Main.kernel
kernel = Main.kernel #interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface_name = kernel.vif_index_to_name_dic[self._interface_id] #import netifaces
import netifaces #netifaces.ifaddresses(interface_name)
netifaces.ifaddresses(interface_name) #ip = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
ip = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr'] ip = self.get_interface().get_ip()
return ip return ip
def get_tree_id(self): def get_tree_id(self):
#return self._tree_id
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip) return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
def get_cost(self): def get_cost(self):
#return self._cost #return self._cost
# todo
return 10 return 10
def set_cost(self, value): def set_cost(self, value):
...@@ -142,3 +163,6 @@ class SFRMTreeInterface(metaclass=ABCMeta): ...@@ -142,3 +163,6 @@ class SFRMTreeInterface(metaclass=ABCMeta):
def change_tree(self): def change_tree(self):
self._kernel_entry.change() self._kernel_entry.change()
def get_state_lock(self):
return self._kernel_entry.CHANGE_STATE_LOCK
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment