Commit c6d3cf5a authored by Pedro Oliveira's avatar Pedro Oliveira

Backup commit: has IGMPv2 implementation, Assert and Prune/Join state machine <- not all done yet

parent 6cbbc7ef
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
import Main
import traceback
class Assert:
TYPE = 5
def __init__(self):
Main.add_protocol(Assert.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
interface_name = interface.interface_name
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_assert = packet.payload.payload # type: PacketPimAssert
metric = pkt_assert.metric
metric_preference = pkt_assert.metric_preference
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
interface_name = packet.interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
try:
#Main.kernel.routing[source_group].recv_assert_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_assert_msg(interface_index, packet)
except:
traceback.print_exc()
......@@ -124,7 +124,6 @@ class Daemon:
""" Check For the existence of a unix pid. """
try:
os.kill(pid, 0)
return True
except OSError:
return False
else:
return True
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Interface import Interface
import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class Graft:
TYPE = 6
def __init__(self):
Main.add_protocol(Graft.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload
# TODO
raise Exception
\ No newline at end of file
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Interface import Interface
import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class GraftAck:
TYPE = 7
def __init__(self):
Main.add_protocol(GraftAck.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload
# TODO
raise Exception
\ No newline at end of file
......@@ -29,10 +29,10 @@ class Hello:
def packet_send_handle(self, interface: Interface):
pim_payload = PacketPimHello()
pim_payload.add_option(1, Hello.TRIGGERED_HELLO_DELAY)
pim_payload.add_option(1, 3.5 * Hello.TRIGGERED_HELLO_DELAY)
pim_payload.add_option(20, interface.generation_id)
ph = PacketPimHeader(pim_payload)
packet = Packet(pim_header=ph)
packet = Packet(payload=ph)
interface.send(packet.bytes())
def force_send(self, interface: Interface):
......@@ -46,14 +46,14 @@ class Hello:
pim_payload.add_option(1, HELLO_HOLD_TIME_TIMEOUT)
pim_payload.add_option(20, interface.generation_id)
ph = PacketPimHeader(pim_payload)
packet = Packet(pim_header=ph)
packet = Packet(payload=ph)
interface.send(packet.bytes())
# receive handler
def receive_handle(self, packet: ReceivedPacket):
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.pim_header.payload.get_options()
options = packet.payload.payload.get_options()
if Main.get_neighbor(ip) is None:
# Unknown Neighbor
if (1 in options) and (20 in options):
......
from Packet.ReceivedPacket import ReceivedPacket
from utils import *
from ipaddress import IPv4Address
class IGMP:
# receive handler
@staticmethod
def receive_handle(packet: ReceivedPacket):
interface = packet.interface
ip_src = packet.ip_header.ip_src
ip_dst = packet.ip_header.ip_dst
print("ip = ", ip_src)
igmp_hdr = packet.payload
igmp_type = igmp_hdr.type
igmp_group = igmp_hdr.group_address
# source ip can't be 0.0.0.0 or multicast
if ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast:
return
if igmp_type == Version_1_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v1_membership_report(packet)
elif igmp_type == Version_2_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v2_membership_report(packet)
elif igmp_type == Leave_Group and ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_leave_group(packet)
elif igmp_type == Membership_Query and (ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0")):
interface.interface_state.receive_query(packet)
else:
raise Exception("Exception igmp packet: type={}; ip_dst={}; packet_group_report={}".format(igmp_type, ip_dst, igmp_group))
......@@ -6,12 +6,14 @@ from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class Interface:
class Interface(object):
MCAST_GRP = '224.0.0.13'
# substituir ip por interface ou algo parecido
def __init__(self, interface_name: str):
self.interface_name = interface_name
ip_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
self.ip_interface = ip_interface
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_PIM)
......@@ -36,6 +38,9 @@ class Interface:
# generation id
self.generation_id = random.getrandbits(32)
# todo neighbors
self.neighbors = {}
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
......@@ -47,10 +52,7 @@ class Interface:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
#print("packet received bytes: ", packet.bytes())
#print("pim type received = ", packet.pim_header.msg_type)
#print("options received = ", packet.pim_header.payload.options)
Main.protocols[packet.pim_header.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception:
traceback.print_exc()
continue
......
import socket
import struct
import threading
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(object):
ETH_P_IP = 0x0800 # Internet Protocol packet
PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str):
# RECEIVE SOCKET
rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# allow all multicast packets
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.PACKET_MR_ALLMULTI, struct.pack("i HH BBBBBBBB", 0, InterfaceIGMP.PACKET_MR_ALLMULTI, 0, 0,0,0,0,0,0,0,0))
# bind to interface
rcv_s.bind((interface_name, 0))
self.recv_socket = rcv_s
# SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
self.send_socket = snd_s
self.interface_enabled = True
self.interface_name = interface_name
from igmp.RouterState import RouterState
self.interface_state = RouterState(self)
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
def send(self, data: bytes, address: str="224.0.0.1"):
if self.interface_enabled:
self.send_socket.sendto(data, (address, 0))
def receive(self):
while self.interface_enabled:
try:
(raw_packet, x) = self.recv_socket.recvfrom(256 * 1024)
if raw_packet:
raw_packet = raw_packet[14:]
from Packet.PacketIpHeader import PacketIpHeader
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN])
print(proto)
if proto != socket.IPPROTO_IGMP:
continue
print((raw_packet, x))
packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet)
except Exception:
traceback.print_exc()
continue
def remove(self):
self.interface_enabled = False
self.recv_socket.close()
self.send_socket.close()
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Interface import Interface
import Main
import traceback
class JoinPrune:
TYPE = 3
def __init__(self):
Main.add_protocol(JoinPrune.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
# todo holdtime
holdtime = pkt_join_prune.hold_time
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
pruned_src_addresses = group.pruned_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_join_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
This diff is collapsed.
......@@ -3,26 +3,33 @@ import time
from prettytable import PrettyTable
from Interface import Interface
from InterfaceIGMP import InterfaceIGMP
from Kernel import Kernel
from Neighbor import Neighbor
from threading import Lock
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
neighbors = {} # multicast router neighbors
neighbors_lock = Lock()
protocols = {}
kernel = None
igmp = None
def add_interface(interface_name):
def add_interface(interface_name, pim=False, igmp=False):
global interfaces
if interface_name not in interfaces:
if pim is True and interface_name not in interfaces:
interface = Interface(interface_name)
interfaces[interface_name] = interface
protocols[0].force_send(interface)
if igmp is True and interface_name not in igmp_interfaces:
interface = InterfaceIGMP(interface_name)
igmp_interfaces[interface_name] = interface
def remove_interface(interface_name):
def remove_interface(interface_name, pim=False, igmp=False):
global interfaces
global neighbors
if (interface_name in interfaces) or interface_name == "*":
if pim is True and ((interface_name in interfaces) or interface_name == "*"):
if interface_name == "*":
interface_name = list(interfaces.keys())
else:
......@@ -37,25 +44,43 @@ def remove_interface(interface_name):
if neighbor.contact_interface not in interfaces:
neighbor.remove()
if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
if interface_name == "*":
interface_name = list(igmp_interfaces.keys())
else:
interface_name = [interface_name]
for if_name in interface_name:
igmp_interfaces[if_name].remove()
del igmp_interfaces[if_name]
print("removido interface")
def add_neighbor(contact_interface, ip, random_number, hello_hold_time):
global neighbors
if ip not in neighbors:
print("ADD NEIGHBOR")
neighbors[ip] = Neighbor(contact_interface, ip, random_number, hello_hold_time)
protocols[0].force_send(contact_interface)
with neighbors_lock:
if ip not in neighbors:
print("ADD NEIGHBOR")
n = Neighbor(contact_interface, ip, random_number, hello_hold_time)
neighbors[ip] = n
protocols[0].force_send(contact_interface)
# todo check neighbor in interface
contact_interface.neighbors[ip] = n
def get_neighbor(ip) -> Neighbor:
global neighbors
if ip not in neighbors:
return None
return neighbors[ip]
with neighbors_lock:
if ip not in neighbors:
return None
return neighbors[ip]
def remove_neighbor(ip):
global neighbors
if ip in neighbors:
del neighbors[ip]
print("removido neighbor")
with neighbors_lock:
if ip in neighbors:
del neighbors[ip]
print("removido neighbor")
def add_protocol(protocol_number, protocol_obj):
global protocols
......@@ -85,31 +110,75 @@ def list_enabled_interfaces():
ph = PacketPimJoinPrune("10.0.0.13", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1", "10.1.1.1"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(pim_header=PacketPimHeader(ph))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
ph = PacketPimJoinPrune("ff08::1", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("2001:1:a:b:c::1", ["1.1.1.1", "2001:1:a:b:c::2"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1"], ["2001:1:a:b:c::3"]))
pckt = Packet(pim_header=PacketPimHeader(ph))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimAssert import PacketPimAssert
ph = PacketPimAssert("224.12.12.12", "10.0.0.2", 210, 2)
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
t = PrettyTable(['Interface', 'IP', 'Enabled'])
from Packet.PacketPimGraft import PacketPimGraft
ph = PacketPimGraft("10.0.0.13", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
t = PrettyTable(['Interface', 'IP', 'PIM/IMGP Enabled', 'IGMP State'])
for interface in netifaces.interfaces():
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
status = interface in interfaces
t.add_row([interface, ip, status])
try:
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
pim_enabled = interface in interfaces
igmp_enabled = interface in igmp_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled)
if igmp_enabled:
state = igmp_interfaces[interface].interface_state.print_state()
else:
state = "-"
t.add_row([interface, ip, enabled, state])
except Exception:
continue
print(t)
return str(t)
def list_igmp_state():
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()):
interface_state = interface_obj.interface_state
state_txt = interface_state.print_state()
print(interface_state.group_state.items())
for (group_addr, group_state) in list(interface_state.group_state.items()):
print(group_addr)
group_state_txt = group_state.print_state()
t.add_row([interface_name, state_txt, group_addr, group_state_txt])
return str(t)
def main(interfaces_to_add=[]):
from Hello import Hello
Hello()
from IGMP import IGMP
from Assert import Assert
from JoinPrune import JoinPrune
Hello()
Assert()
JoinPrune()
global kernel
kernel = Kernel()
global igmp
igmp = IGMP()
for interface in interfaces_to_add:
add_interface(interface)
......@@ -4,6 +4,7 @@ from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT
from Interface import Interface
import Main
class Neighbor:
def __init__(self, contact_interface: Interface, ip, generation_id: int, hello_hold_time: int):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
......@@ -24,7 +25,7 @@ class Neighbor:
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.remove()
elif hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT:
self.neighbor_liveness_timer = Timer(4 * hello_hold_time, self.remove)
self.neighbor_liveness_timer = Timer(hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
else:
self.neighbor_liveness_timer = None
......@@ -35,7 +36,7 @@ class Neighbor:
print("HEARTBEAT")
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_liveness_timer = Timer(4 * self.hello_hold_time, self.remove)
self.neighbor_liveness_timer = Timer(self.hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
self.time_of_last_update = time.time()
......
from Packet.PacketIpHeader import PacketIpHeader
from Packet.PacketPimHeader import PacketPimHeader
from .PacketIpHeader import PacketIpHeader
from .PacketPayload import PacketPayload
class Packet:
# ter ip header
# pim header
# pim options
def __init__(self, ip_header: PacketIpHeader = None, pim_header: PacketPimHeader = None):
class Packet(object):
def __init__(self, ip_header: PacketIpHeader = None, payload: PacketPayload = None):
self.ip_header = ip_header
self.pim_header = pim_header
self.payload = payload
# maybe remover
'''def add_option(self, option: PacketPimOption):
self.pim_header.add_option(option)
'''
def bytes(self):
return self.pim_header.bytes()
def bytes(self) -> bytes:
return self.payload.bytes()
import struct
from utils import checksum
import socket
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Max Resp Time | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv |S| QRV | QQIC | Number of Sources (N) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address [1] |
+- -+
| Source Address [2] |
+- . -+
. . .
. . .
+- -+
| Source Address [N] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIGMPHeader(PacketPayload):
IGMP_TYPE = 2
IGMP_HDR = "! BB H 4s"
IGMP_HDR_LEN = struct.calcsize(IGMP_HDR)
IGMP3_SRC_ADDR_HDR = "! BB H "
IGMP3_SRC_ADDR_HDR_LEN = struct.calcsize(IGMP3_SRC_ADDR_HDR)
IPv4_HDR = "! 4s"
IPv4_HDR_LEN = struct.calcsize(IPv4_HDR)
Membership_Query = 0x11
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Version_1_Membership_Report = 0x12
def __init__(self, type: int, max_resp_time: int, group_address: str="0.0.0.0"):
# todo check type
self.type = type
self.max_resp_time = max_resp_time
self.group_address = group_address
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
socket.inet_aton(self.group_address))
igmp_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", igmp_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
print("parseIGMPHdr: ", data)
igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN]
(type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr)
print(type, max_resp_time, rcv_checksum, group_address)
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum:
print("wrong checksum")
raise Exception("wrong checksum")
igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:]
group_address = socket.inet_ntoa(group_address)
pkt = PacketIGMPHeader(type, max_resp_time, group_address)
return pkt
\ No newline at end of file
......@@ -21,7 +21,6 @@ import socket
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
#IP_HDR2 = "! B"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
......
import abc
class PacketPayload(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def bytes(self) -> bytes:
"""Get packet payload in bytes format"""
@abc.abstractmethod
def __len__(self):
"""Get packet payload length"""
@staticmethod
@abc.abstractmethod
def parse_bytes(data: bytes):
"""From bytes create a object payload"""
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimAssert:
PIM_TYPE = 5
PIM_HDR_ASSERT = "! %ss %ss LL"
PIM_HDR_ASSERT_WITHOUT_ADDRESS = "! LL"
PIM_HDR_ASSERT_v4 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_ASSERT_v6 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_ASSERT_WITHOUT_ADDRESS)
PIM_HDR_ASSERT_v4_LEN = struct.calcsize(PIM_HDR_ASSERT_v4)
PIM_HDR_ASSERT_v6_LEN = struct.calcsize(PIM_HDR_ASSERT_v6)
def __init__(self, multicast_group_address: str or bytes, source_address: str or bytes, metric_preference, metric):
if type(multicast_group_address) is bytes:
multicast_group_address = socket.inet_ntoa(multicast_group_address)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
self.multicast_group_address = multicast_group_address
self.source_address = source_address
self.metric_preference = metric_preference
self.metric = metric
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group_address).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
msg = multicast_group_address + source_address + struct.pack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS,
0x7FFFFFFF & self.metric_preference,
self.metric)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
source_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_addr_len = len(source_addr_obj)
data = data[source_addr_len:]
(metric_preference, metric) = struct.unpack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS, data[:PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN])
pim_payload = PacketPimAssert(multicast_group_addr_obj.group_address, source_addr_obj.unicast_address, metric_preference, metric)
return pim_payload
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraft(PacketPimJoinPrune):
PIM_TYPE = 6
def __init__(self, upstream_neighbor_address, hold_time):
super().__init__(upstream_neighbor_address, hold_time)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraftAck(PacketPimJoinPrune):
PIM_TYPE = 7
def __init__(self, upstream_neighbor_address, hold_time):
super().__init__(upstream_neighbor_address, hold_time)
......@@ -2,8 +2,14 @@ import struct
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from utils import checksum
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from utils import checksum
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
......@@ -11,15 +17,22 @@ from utils import checksum
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHeader:
class PacketPimHeader(PacketPayload):
PIM_VERSION = 2
PIM_HDR = "! BB H"
PIM_HDR_LEN = struct.calcsize(PIM_HDR)
PIM_MSG_TYPES = {0: PacketPimHello,
3: PacketPimJoinPrune,
5: PacketPimAssert,
6: PacketPimGraft,
7: PacketPimGraftAck,
9: PacketPimStateRefresh
}
def __init__(self, payload):
self.payload = payload
#self.msg_type = msg_type
def get_pim_type(self):
return self.payload.PIM_TYPE
......@@ -58,6 +71,8 @@ class PacketPimHeader:
raise Exception
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
'''
if pim_type == 0: # hello
pim_payload = PacketPimHello.parse_bytes(pim_payload)
elif pim_type == 3: # join/prune
......@@ -68,8 +83,9 @@ class PacketPimHeader:
print(i.multicast_group)
print(i.joined_src_addresses)
print(i.pruned_src_addresses)
elif pim_type == 5: # assert
pim_payload = PacketPimAssert.parse_bytes(pim_payload)
else:
raise Exception
'''
return PacketPimHeader(pim_payload)
......@@ -26,15 +26,18 @@ class PacketPimHello:
PIM_MSG_TYPES_LENGTH = {1: 2,
20: 4,
21: 4,
}
# todo: pensar melhor na implementacao state refresh capable option...
def __init__(self):
self.options = {}
def add_option(self, option_type: int, option_value: int):
if option_value is None:
del self.options[option_type]
return
def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value
def get_options(self):
......
......@@ -46,6 +46,9 @@ class PacketPimJoinPrune:
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
upstream_neighbor_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
......
......@@ -43,11 +43,11 @@ class PacketPimJoinPruneMulticastGroup:
PIM_HDR_JOINED_PRUNED_SOURCE_v4_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
PIM_HDR_JOINED_PRUNED_SOURCE_v6_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
def __init__(self, multicast_group, joined_src_addresses: list, pruned_src_addresses: list):
def __init__(self, multicast_group: str or bytes, joined_src_addresses: list=[], pruned_src_addresses: list=[]):
if type(multicast_group) not in (str, bytes):
raise Exception
elif type(multicast_group) is bytes:
self.multicast_group = socket.inet_ntoa(self.multicast_group)
multicast_group = socket.inet_ntoa(multicast_group)
if type(joined_src_addresses) is not list:
raise Exception
......
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Originator Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Masklen | TTL |P|N|O|Reserved | Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimStateRefresh:
PIM_TYPE = 9
PIM_HDR_STATE_REFRESH = "! %ss %ss %ss I I BBBB"
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES = "! I I BBBB"
PIM_HDR_STATE_REFRESH_v4 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_STATE_REFRESH_v6 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES)
PIM_HDR_STATE_REFRESH_v4_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v4)
PIM_HDR_STATE_REFRESH_v6_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v6)
def __init__(self, multicast_group_adress: str or bytes, source_address: str or bytes, originator_adress: str or bytes,
metric_preference: int, metric: int, mask_len: int, ttl: int, prune_indicator_flag: bool,
prune_now_flag: bool, assert_override_flag: bool, interval: int):
if type(multicast_group_adress) is bytes:
multicast_group_adress = socket.inet_ntoa(multicast_group_adress)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if type(originator_adress) is bytes:
originator_adress = socket.inet_ntoa(originator_adress)
self.multicast_group_adress = multicast_group_adress
self.source_address = source_address
self.originator_adress = originator_adress
self.metric_preference = metric_preference
self.metric = metric
self.mask_len = mask_len
self.ttl = ttl
self.prune_indicator_flag = prune_indicator_flag
self.prune_now_flag = prune_now_flag
self.assert_override_flag = assert_override_flag
self.interval = interval
def bytes(self) -> bytes:
multicast_group_adress = PacketPimEncodedGroupAddress(self.multicast_group_adress).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
originator_adress = PacketPimEncodedUnicastAddress(self.originator_adress).bytes()
prune_and_assert_flags = (self.prune_indicator_flag << 7) | (self.prune_now_flag << 6) | (self.assert_override_flag << 5)
msg = multicast_group_adress + source_address + originator_adress + \
struct.pack(self.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, 0x7FFFFFFF & self.metric_preference,
self.metric, self.mask_len, self.ttl, prune_and_assert_flags, self. interval)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_adress_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_adress_len = len(multicast_group_adress_obj)
data = data[multicast_group_adress_len:]
source_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_address_len = len(source_address_obj)
data = data[source_address_len:]
originator_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
originator_address_len = len(originator_address_obj)
data = data[originator_address_len:]
(metric_preference, metric, mask_len, ttl, reserved_and_prune_and_assert_flags, interval) = struct.unpack(PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, data[:PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN])
metric_preference = 0x7FFFFFFF & metric_preference
prune_indicator_flag = (0x80 & reserved_and_prune_and_assert_flags) >> 7
prune_now_flag = (0x40 & reserved_and_prune_and_assert_flags) >> 6
assert_override_flag = (0x20 & reserved_and_prune_and_assert_flags) >> 5
pim_payload = PacketPimStateRefresh(multicast_group_adress_obj.group_address, source_address_obj.unicast_address,
originator_address_obj.unicast_address, metric_preference, metric, mask_len,
ttl, prune_indicator_flag, prune_now_flag, assert_override_flag, interval)
return pim_payload
import struct
from Packet.Packet import Packet
from Packet.PacketIpHeader import PacketIpHeader
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
import socket
from utils import checksum
from Packet.PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from Interface import Interface
class ReceivedPacket(Packet):
def __init__(self, raw_packet, interface):
# choose payload protocol class based on ip protocol number
payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN]
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr)
protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
pim_header = PacketPimHeader.parse_bytes(packet_without_ip_hdr)
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, pim_header=pim_header)
\ No newline at end of file
super().__init__(ip_header=ip_header, payload=payload)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read Write Lock
"""
import threading
import time
class RWLockRead(object):
"""
A Read/Write lock giving preference to Reader
"""
def __init__(self):
self.V_ReadCount = 0
self.A_Resource = threading.Lock()
self.A_LockReadCount = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
self.V_Locked = self.A_RWLock.A_Resource.acquire(blocking, timeout)
return self.V_Locked
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockRead._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockRead._aWriter(self)
class RWLockWrite(object):
"""
A Read/Write lock giving preference to Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.V_WriteCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockWriteCount = threading.Lock()
self.A_LockReadEntry = threading.Lock()
self.A_LockReadTry = threading.Lock()
self.A_Resource = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadEntry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadEntry.release()
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
return False
self.A_RWLock.V_ReadCount += 1
if (self.A_RWLock.V_ReadCount == 1):
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if (self.A_RWLock.V_ReadCount == 0):
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockWriteCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_WriteCount += 1
if (self.A_RWLock.V_WriteCount == 1):
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_WriteCount -= 1
self.A_RWLock.A_LockWriteCount.release()
return False
self.A_RWLock.A_LockWriteCount.release()
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if self.A_RWLock.V_WriteCount == 0:
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if (self.A_RWLock.V_WriteCount == 0):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockWrite._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockWrite._aWriter(self)
class RWLockFair(object):
"""
A Read/Write lock giving fairness to both Reader and Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockRead = threading.Lock()
self.A_LockWrite = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockRead.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockFair._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockFair._aWriter(self)
......@@ -16,7 +16,7 @@ def client_socket(data_to_send):
# Connect the socket to the port where the server is listening
server_address = './uds_socket'
print('connecting to %s' % server_address)
#print('connecting to %s' % server_address)
try:
sock.connect(server_address)
sock.sendall(pickle.dumps(data_to_send))
......@@ -26,7 +26,7 @@ def client_socket(data_to_send):
except socket.error:
pass
finally:
print('closing socket')
#print('closing socket')
sock.close()
......@@ -61,14 +61,22 @@ class MyDaemon(Daemon):
connection.sendall(pickle.dumps(Main.list_enabled_interfaces()))
elif args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors()))
elif args.list_state:
connection.sendall(pickle.dumps(Main.list_igmp_state()))
elif args.add_interface:
Main.add_interface(args.add_interface[0])
Main.add_interface(args.add_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.add_interface_igmp:
Main.add_interface(args.add_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.remove_interface:
Main.remove_interface(args.remove_interface[0])
Main.remove_interface(args.remove_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif args.stop:
Main.remove_interface("*")
Main.remove_interface("*", pim=True, igmp=True)
connection.shutdown(socket.SHUT_RDWR)
except Exception:
connection.shutdown(socket.SHUT_RDWR)
......@@ -88,9 +96,12 @@ if __name__ == "__main__":
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
args = parser.parse_args()
......
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Interface import Interface
import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class StateRefresh:
TYPE = 9
def __init__(self):
Main.add_protocol(StateRefresh.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload
# TODO
raise Exception
\ No newline at end of file
from threading import Timer
from .wrapper import NoMembersPresent
from utils import GroupMembershipInterval, LastMemberQueryInterval, TYPE_CHECKING
from threading import Lock
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
def __init__(self, router_state: 'RouterState', group_ip: str):
self.router_state = router_state
self.group_ip = group_ip
self.state = NoMembersPresent
self.timer = None
self.v1_host_timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = GroupMembershipInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_v1_host_timer(self):
self.clear_v1_host_timer()
v1_host_timer = Timer(GroupMembershipInterval, self.group_membership_v1_timeout)
v1_host_timer.start()
self.v1_host_timer = v1_host_timer
def clear_v1_host_timer(self):
if self.v1_host_timer is not None:
self.v1_host_timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastMemberQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def group_membership_v1_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_v1_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_v1_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v1_membership_report(self)
def receive_v2_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v2_membership_report(self)
def receive_leave_group(self):
with self.lock:
self.get_interface_group_state().receive_leave_group(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from threading import Timer
from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, TYPE_CHECKING
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
if TYPE_CHECKING:
from InterfaceIGMP import InterfaceIGMP
class RouterState(object):
def __init__(self, interface: 'InterfaceIGMP'):
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
else:
self.interface_state = NonQuerier
############################################
# group state methods
############################################
def receive_v1_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
if igmp_group not in self.group_state:
self.group_state[igmp_group] = GroupState(self, igmp_group)
self.group_state[igmp_group].receive_v1_membership_report()
def receive_v2_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
if igmp_group not in self.group_state:
self.group_state[igmp_group] = GroupState(self, igmp_group)
self.group_state[igmp_group].receive_v2_membership_report()
def receive_leave_group(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
if igmp_group in self.group_state:
self.group_state[igmp_group].receive_leave_group()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
igmp_group = packet.payload.group_address
# process group specific query
if igmp_group != "0.0.0.0" and igmp_group in self.group_state:
max_response_time = packet.payload.max_resp_time
self.group_state[igmp_group].receive_group_specific_query(max_response_time)
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from ..wrapper import NoMembersPresent
from ..wrapper import CheckingMembership
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.state = CheckingMembership
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, TYPE_CHECKING
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership
from ipaddress import IPv4Address
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/10.0) * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return NonQuerier.get_members_present_state()
from Packet.PacketIGMPHeader import PacketIGMPHeader
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING - !!!!
group_state.clear_retransmit_timer()
group_state.state = NoMembersPresent
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_addr = group_state.group_ip
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.state = CheckingMembership
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
# do nothing
return
from ..wrapper import MembersPresent
from ..wrapper import Version1MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.state = Version1MembersPresent
def receive_v2_membership_report(group_state: 'GroupState'):
group_ip = group_state.group_ip
# TODO NOTIFY ROUTING + !!!!
group_state.set_timer()
group_state.state = MembersPresent
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
from ipaddress import IPv4Address
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
# do nothing
return
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastMemberQueryInterval * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return Version1MembersPresent
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
# TODO NOTIFY ROUTING - !!!!
group_state.state = NoMembersPresent
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.state = MembersPresent
def retransmit_timeout(group_state: 'GroupState'):
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.set_timer()
group_state.set_v1_host_timer()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
# do nothing
return
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_membership_state()
def print_state():
return "CheckingMembership"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_members_present_state()
def print_state():
return "MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_members_present_state()
def print_state():
return "NoMembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_version_1_members_present_state()
def print_state():
return "Version1MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import Main
import socket
import struct
import netifaces
import threading
from tree.root_interface import SFRMRootInterface
from tree.non_root_interface import SFRMNonRootInterface
from threading import Timer
class KernelEntry:
TREE_TIMEOUT = 180
def __init__(self, source_ip: str, group_ip: str, inbound_interface_index: int):
self.source_ip = source_ip
self.group_ip = group_ip
# ip of neighbor of the rpf
self._rpf_node = None
self._has_members = True # todo check via igmp
self._was_in_group = True
self._rpf_is_origin = False
self._liveliness_timer = None
# decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
#Main.kernel.flood(ip_src=source_ip, ip_dst=group_ip, iif=self.inbound_interface_index)
#import time
#time.sleep(5)
self.interface_state = {} # type: Dict[int, SFRMTreeInterface]
for i in range(Main.kernel.MAXVIFS):
if i == self.inbound_interface_index:
self.interface_state[i] = SFRMRootInterface(self, i, False)
else:
self.interface_state[i] = SFRMNonRootInterface(self, i)
print('Tree created')
self.evaluate_ingroup()
if self.is_originater():
self.set_liveliness_timer()
print('set SAT')
self._lock = threading.RLock()
def get_inbound_interface_index(self):
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
outbound_indexes = [0]*Main.kernel.MAXVIFS
for (index, state) in self.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
def check_rpf(self):
from pyroute2 import IPRoute
# from utils import if_indextoname
ipr = IPRoute()
# obter index da interface
# rpf_interface_index = ipr.get_routes(family=socket.AF_INET, dst=ip)[0]['attrs'][2][1]
# interface_name = if_indextoname(rpf_interface_index)
# return interface_name
# obter ip da interface de saida
rpf_interface_source = ipr.get_routes(family=socket.AF_INET, dst=self.source_ip)[0]['attrs'][3][1]
return rpf_interface_source
def recv_data_msg(self, index):
if self.is_originater():
self.clear_liveliness_timer()
self.interface_state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
print("recv assert msg")
self.interface_state[index].recv_assert_msg(packet, None)
def recv_reset_msg(self, msg, sender):
# todo
return
def recv_prune_msg(self, index, packet):
print("recv prune msg")
self.interface_state[index].recv_prune_msg(packet, None, self.is_in_group())
def recv_join_msg(self, index, packet):
print("recv join msg")
self.interface_state[index].recv_join_msg(packet, None, self.is_in_group())
def recv_state_reset_msg(self, msg, sender):
# todo
return
def set_liveliness_timer(self):
self.clear_liveliness_timer()
timer = Timer(self.TREE_TIMEOUT, self.___liveliness_timer_expired)
timer.start()
self._liveliness_timer = timer
def clear_liveliness_timer(self):
if self._liveliness_timer is not None:
self._liveliness_timer.cancel()
def ___liveliness_timer_expired(self):
#todo
return
def network_update(self, change, args):
#todo
return
def update(self, caller, arg):
#todo
return
def nbr_event(self, link, node, event):
# todo
return
def is_in_group(self):
# todo
#if self.get_has_members():
if True:
return True
for interface in self.interface_state.values():
if interface.is_forwarding():
return True
return False
def evaluate_ingroup(self):
is_ig = self.is_in_group()
if self._was_in_group != is_ig:
if is_ig:
print('transitoned to IG')
#self._up_if.send_join()
self.interface_state[self.inbound_interface_index].send_join()
else:
print('transitoned to OG')
#self._up_if.send_prune()
self.interface_state[self.inbound_interface_index].send_prune()
self._was_in_group = is_ig
def is_originater(self):
# todo
#return self._rpf_node == self.get_source()
return False
def get_source(self):
return self.source_ip
def get_group(self):
return self.group_ip
def get_has_members(self):
#return self._has_members
return True
def set_has_members(self, value):
assert isinstance(value, bool)
self._has_members = value
self.evaluate_ingroup()
def change(self):
# todo: changes on unicast routing or multicast routing...
Main.kernel.set_multicast_route(self)
def delete(self):
for state in self.interface_state.values():
state.delete()
self.clear_liveliness_timer()
Main.kernel.remove_multicast_route(self)
from abc import ABCMeta, abstractstaticmethod
class SFMRAssertABC(metaclass=ABCMeta):
@abstractstaticmethod
def data_arrival(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@staticmethod
def recv_better_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
raise NotImplemented()
@abstractstaticmethod
def recv_worse_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
raise NotImplemented()
@abstractstaticmethod
def aw_failure(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@abstractstaticmethod
def al_rpc_better_than_aw(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@abstractstaticmethod
def aw_rpc_worsens(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@abstractstaticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@abstractstaticmethod
def recv_reset(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
@abstractstaticmethod
def is_now_pruned(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplemented()
class SFMRAssertWinner(SFMRAssertABC):
@staticmethod
def data_arrival(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('data_arrival, W -> W')
interface.send_assert()
@staticmethod
def recv_better_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
interface.rprint('recv_better_metric, W -> L')
interface._set_assert_state(AssertState.Looser)
interface._set_winner_metric(metric)
@staticmethod
def recv_worse_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
interface.rprint('recv_worse_metric, W -> W')
interface.send_assert()
@staticmethod
def aw_failure(interface):
'''
@type interface: SFRMNonRootInterface
'''
assert False
@staticmethod
def al_rpc_better_than_aw(interface):
'''
@type interface: SFRMNonRootInterface
'''
assert False
@staticmethod
def aw_rpc_worsens(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.send_reset()
interface.rprint('aw_rpc_worsens, W -> W')
@staticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('is_now_root, W -> W')
interface.send_reset()
@staticmethod
def recv_reset(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('recv_reset, W -> W')
@staticmethod
def is_now_pruned(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('is_now_pruned, W -> W')
class SFMRAssertLooser(SFMRAssertABC):
@staticmethod
def data_arrival(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('data_arrival, L -> L')
@staticmethod
def recv_better_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
interface.rprint('recv_better_metric, L -> L')
interface._set_winner_metric(metric)
@staticmethod
def recv_worse_metric(interface, metric):
'''
@type interface: SFRMNonRootInterface
@type metric: SFMRAssertMetric
'''
interface.rprint('recv_worse_metric, L -> W')
interface.send_assert()
interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None)
@staticmethod
def aw_failure(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('aw_failure, L -> W')
interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None)
@staticmethod
def al_rpc_better_than_aw(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('al_rpc_improves, L -> W')
interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None)
interface.send_assert()
@staticmethod
def aw_rpc_worsens(interface):
'''
@type interface: SFRMNonRootInterface
'''
assert False
@staticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('is_now_root, L -> L')
@staticmethod
def recv_reset(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('recv_reset, L -> W')
interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None)
@staticmethod
def is_now_pruned(interface):
'''
@type interface: SFRMNonRootInterface
'''
interface.rprint('is_now_pruned, L -> W')
interface._set_assert_state(AssertState.Winner)
interface._set_winner_metric(None)
class AssertState():
Winner = SFMRAssertWinner()
Looser = SFMRAssertLooser()
'''
Created on Jul 20, 2015
@author: alex
'''
import ipaddress
class SFMRAssertMetric(object):
'''
Note: we consider the node name the ip of the metric.
'''
def __init__(self, metric_preference: int or float, route_metric: int or float, ip_address: str):
if type(ip_address) is str:
ip_address = ipaddress.ip_address(ip_address)
self._metric_preference = metric_preference
self._route_metric = route_metric
self._ip_address = ip_address
def is_worse_than(self, other):
assert isinstance(other, SFMRAssertMetric)
if self.get_metric_preference() != other.get_metric_preference():
return self.get_metric_preference() > other.get_metric_preference()
elif self.get_route_metric() != other.get_route_metric():
return self.get_route_metric() > other.get_route_metric()
else:
return self.get_ip_address() <= other.get_ip_address()
@staticmethod
def infinite_assert_metric():
'''
@rtype SFMRAssertMetric
@type tree_if: SFRMTreeInterface
'''
#metric = SFMRAssertMetric()
#metric._metric = float("Inf")
#metric._node = ""
metric_preference = float("Inf")
route_metric = float("Inf")
ip = "0.0.0.0"
metric = SFMRAssertMetric(metric_preference=metric_preference, route_metric=route_metric, ip_address=ip)
return metric
@staticmethod
def spt_assert_metric(tree_if):
'''
@rtype SFMRAssertMetric
@type tree_if: SFRMTreeInterface
'''
#metric = SFMRAssertMetric()
#metric._metric = tree_if.get_cost()
#metric._node = tree_if.get_node()
metric_preference = 10 # todo check how to get metric preference
route_metric = tree_if.get_cost()
ip = tree_if.get_ip()
metric = SFMRAssertMetric(metric_preference=metric_preference, route_metric=route_metric, ip_address=ip)
return metric
# overrides
#def __str__(self):
# return "AssertMetric<%d:%d:%s>" % (self._metric_preference, self._node)
def get_metric_preference(self):
return self._metric_preference
def get_route_metric(self):
return self._route_metric
def get_ip_address(self):
return self._ip_address
def set_metric_preference(self, metric_preference):
self._metric_preference = metric_preference
def set_route_metric(self, route_metric):
self._route_metric = route_metric
def set_ip_address(self, ip):
self._ip_address = ip
'''
Created on Jul 16, 2015
@author: alex
'''
#from convergence import Convergence
#from des.event.timer import Timer
from threading import Timer
from .assert_ import AssertState, SFMRAssertABC
#from .messages.assert_msg import SFMRAssertMsg
#from .messages.reset import SFMResetMsg
from .metric import SFMRAssertMetric
from .prune import SFMRPruneState, SFMRPruneStateABC
from .tree_interface import SFRMTreeInterface
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
class SFRMNonRootInterface(SFRMTreeInterface):
DIPT_TIME = 3.0
def __init__(self, kernel_entry, interface_id):
SFRMTreeInterface.__init__(self, kernel_entry, interface_id, None)
self._assert_state = AssertState.Winner
self._assert_metric = None
self._prune_state = SFMRPruneState.DIP
#self._dipt = Timer(SFRMNonRootInterface.DIPT_TIME, self.__dipt_expires)
#self._dipt.start()
self._dipt = None
self.set_dipt_timer()
self.send_prune()
# Override
def recv_data_msg(self, msg=None, sender=None):
if self._prune_state != SFMRPruneState.NDI:
self._assert_state.data_arrival(self)
# Override
def recv_assert_msg(self, msg: ReceivedPacket, sender=None):
'''
@type msg: SFMRAssertMsg
@type sender: Addr
'''
if self._prune_state == SFMRPruneState.NDI:
return
if self._assert_state == AssertState.Looser:
winner_metric = self._get_winner_metric()
else:
winner_metric = self.get_metric()
ip_sender = msg.ip_header.ip_src
pkt_assert = msg.payload.payload # type: PacketPimAssert
msg_metric = SFMRAssertMetric(metric_preference=pkt_assert.metric_preference, route_metric=pkt_assert.metric, ip_address=ip_sender)
if winner_metric.is_worse_than(msg_metric):
self._assert_state.recv_better_metric(self, msg_metric)
else:
self._assert_state.recv_worse_metric(self, msg_metric)
# Override
def recv_reset_msg(self, msg, sender):
'''
@type msg: SFMResetMsg
@type sender: Addr
'''
if self._prune_state != SFMRPruneState.NDI:
self._assert_state.recv_reset(self)
# Override
def recv_prune_msg(self, msg, sender, in_group):
super().recv_prune_msg(msg, sender, in_group)
self._prune_state.recv_prune(self)
# Override
def recv_join_msg(self, msg, sender, in_group):
super().recv_join_msg(msg, sender, in_group)
self._prune_state.recv_join(self)
def forward_data_msg(self, msg):
pass
#def forward_data_msg(self, msg):
# if self.is_forwarding():
# self._interface.send_mcast(msg)
def send_assert(self):
(source, group) = self.get_tree_id()
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimAssert import PacketPimAssert
ph = PacketPimAssert(multicast_group_address=group, source_address=source, metric_preference=10, metric=2)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print('sent assert msg')
def send_reset(self):
# todo msg = SFMResetMsg(self.get_tree_id())
msg = None
self.get_interface().send_mcast(msg)
self.rprint('sent reset msg')
raise NotImplemented()
# Override
def send_prune(self):
SFRMTreeInterface.send_prune(self)
#if self._dipt.is_ticking():
if self._dipt.is_alive():
self._dipt.cancel()
# Override
def is_forwarding(self):
return self._assert_state == AssertState.Winner \
and self._prune_state != SFMRPruneState.NDI
# Override
def nbr_died(self, node):
if self._get_winner_metric() is not None \
and self._get_winner_metric().get_ip_address() == node\
and self._prune_state != SFMRPruneState.NDI:
self._assert_state.aw_failure(self)
self._prune_state.lost_nbr(self)
# Override
def nbr_connected(self):
self._prune_state.new_nbr(self)
# Override
def is_now_root(self):
self._assert_state.is_now_root(self)
self._prune_state.is_now_root(self)
# Override
def delete(self):
SFRMTreeInterface.delete(self)
#self._get_dipt().cancel()
self.clear_dipt_timer()
def __dipt_expires(self):
print('DIPT expired')
self._prune_state.dipt_expires(self)
def get_metric(self):
return SFMRAssertMetric.spt_assert_metric(self)
def _set_assert_state(self, value):
assert isinstance(value, SFMRAssertABC)
if value != self._assert_state:
self._assert_state = value
self.evaluate_ingroup()
#Convergence.mark_change()
self.change_tree()
def _get_winner_metric(self):
'''
@rtype: SFMRAssertMetric
'''
return self._assert_metric
def _set_winner_metric(self, value):
assert isinstance(value, SFMRAssertMetric) or value is None
self._assert_metric = value
# Override
def set_cost(self, value):
if value != self._cost and self._prune_state != SFMRPruneState.NDI:
if self.is_forwarding() and value > self._cost:
SFRMTreeInterface.set_cost(self, value)
self._assert_state.aw_rpc_worsens(self)
elif not self.is_forwarding(
) and value < self._get_winner_metric().get_metric():
SFRMTreeInterface.set_cost(self, value)
self._assert_state.al_rpc_better_than_aw(self)
else:
SFRMTreeInterface.set_cost(self, value)
else:
SFRMTreeInterface.set_cost(self, value)
def _set_prune_state(self, value):
assert isinstance(value, SFMRPruneStateABC)
if value != self._prune_state:
self._prune_state = value
self.evaluate_ingroup()
#Convergence.mark_change()
self.change_tree()
if value == SFMRPruneState.NDI:
self._assert_state.is_now_pruned(self)
def _get_dipt(self):
'''
@rtype: Timer
'''
return self._dipt
def set_dipt_timer(self):
self.clear_dipt_timer()
timer = Timer(self.DIPT_TIME, self.__dipt_expires)
timer.start()
self._dipt = timer
def clear_dipt_timer(self):
if self._dipt is not None:
self._dipt.cancel()
from abc import ABCMeta, abstractstaticmethod
class SFMRPruneStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def recv_prune(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
@abstractstaticmethod
def recv_join(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
@abstractstaticmethod
def dipt_expires(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
@abstractstaticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
@abstractstaticmethod
def new_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
@abstractstaticmethod
def lost_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
raise NotImplementedError()
class SFMRDownstreamInterested(SFMRPruneStateABC):
@staticmethod
def recv_prune(interface):
'''
@type interface: SFRMNonRootInterface
'''
if len(interface.get_interface().neighbors) == 1:
print('recv_prune, DI -> NDI (only 1 nbr)')
interface._set_prune_state(SFMRPruneState.NDI)
else:
print('recv_prune, DI -> DIP')
interface._set_prune_state(SFMRPruneState.DIP)
#interface._get_dipt().start()
interface.set_dipt_timer()
@staticmethod
def recv_join(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('recv_join, DI -> DI')
@staticmethod
def dipt_expires(interface):
'''
@type interface: SFRMNonRootInterface
'''
assert False
@staticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('is_now_root, DI -> DI')
@staticmethod
def new_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('new_nbr, DI -> N')
@staticmethod
def lost_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('lost_nbr, DI -> DIP')
interface.send_prune()
interface._set_prune_state(SFMRPruneState.DIP)
#interface._get_dipt().start()
interface.set_dipt_timer()
class SFMRDownstreamInterestedPending(SFMRPruneStateABC):
@staticmethod
def recv_prune(interface):
'''
@type interface: SFRMNonRootInterface
'''
# TODO foi alterado pelo Pedro... necessita de verificacao se esta OK...
print('recv_prune, DIP -> DIP')
if len(interface.get_interface().neighbors) == 1:
print('recv_prune, DIP -> DI (only 1 nbr)')
else:
print('recv_prune, DIP -> NDI')
interface._set_prune_state(SFMRPruneState.NDI)
interface.clear_dipt_timer()
@staticmethod
def recv_join(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('recv_join, DIP -> DI')
interface._set_prune_state(SFMRPruneState.DI)
interface.clear_dipt_timer()
@staticmethod
def dipt_expires(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('dipt_expires, DIP -> NDI')
interface._set_prune_state(SFMRPruneState.NDI)
@staticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('is_now_root, DIP -> DI')
interface._set_prune_state(SFMRPruneState.DI)
interface._get_dipt().stop()
@staticmethod
def new_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('new_nbr, DIP -> DIP')
interface.send_prune()
#interface._get_dipt().reset()
interface.set_dipt_timer()
@staticmethod
def lost_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('lost_nbr, DIP -> DIP')
#todo alterado pelo Pedro... necessita de verificar se esta OK...
interface.send_prune()
interface.set_dipt_timer()
class SFMRNoDownstreamInterested(SFMRPruneStateABC):
@staticmethod
def recv_prune(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('recv_prune, NDI -> NDI')
@staticmethod
def recv_join(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('recv_join, NDI -> DI')
interface._set_prune_state(SFMRPruneState.DI)
#interface._get_dipt().stop()
interface.clear_dipt_timer()
@staticmethod
def dipt_expires(interface):
'''
@type interface: SFRMNonRootInterface
'''
assert False
@staticmethod
def is_now_root(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('is_now_root, NDI -> DI')
interface._set_prune_state(SFMRPruneState.DI)
@staticmethod
def new_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('new_nbr, NDI -> NDI')
interface.send_prune()
@staticmethod
def lost_nbr(interface):
'''
@type interface: SFRMNonRootInterface
'''
print('lost_nbr, NDI -> NDI')
class SFMRPruneState():
DI = SFMRDownstreamInterested()
DIP = SFMRDownstreamInterestedPending()
NDI = SFMRNoDownstreamInterested()
'''
Created on Jul 16, 2015
@author: alex
'''
#from des.addr import Addr
#from .messages.assert_msg import SFMRAssertMsg
#from .messages.join import SFMRJoinMsg
from .tree_interface import SFRMTreeInterface
class SFRMRootInterface(SFRMTreeInterface):
def __init__(
self, kernel_entry, interface_id, is_originater: bool):
'''
interface,
node,
tree_id,
cost,
evaluate_ig_cb,
is_originater: bool, ):
'''
SFRMTreeInterface.__init__(self, kernel_entry, interface_id, None)
self._is_originater = is_originater
#Override
#def recv_assert_msg(self, msg: SFMRAssertMsg, sender: Addr):
def recv_assert_msg(self, msg, sender):
pass
#Override
def recv_prune_msg(self, msg, sender, in_group):
super().recv_prune_msg(msg, sender, in_group)
if in_group:
print("I WILL SEND JOIN")
self.send_join()
print("I SENT JOIN")
def forward_data_msg(self, msg):
pass
def send_join(self):
# Originaters dont need to send prunes or joins
if self._is_originater:
return
print("I WILL SEND JOIN")
#msg = SFMRJoinMsg(self.get_tree_id())
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
(source, group) = self.get_tree_id()
# todo help ip of upstream neighbor
ph = PacketPimJoinPrune("123.123.123.123", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print('sent join msg')
#Override
def is_forwarding(self):
return False
#Override
def is_now_root(self):
assert False
#Override
def delete(self):
super().delete()
'''
Created on Jul 16, 2015
@author: alex
'''
from convergence import Convergence
from des.entities.node import NodeChanges
from des.event.timer import Timer
from sfmr.messages.state_reset import SFMRStateResetMsg
from sfmr.non_root_interface import SFRMNonRootInterface
from sfmr.root_interface import SFRMRootInterface
from sfmr.router_interface import SFMRInterface, NeighborEvent
class SFMRTree(object):
TREE_TIMEOUT = 180
def __init__(self, rprint, unicastd, tree_id, tree_liveliness_callback,
ifs, node, has_members):
'''
@type ifs: dict
@type node: Node
'''
self._rpf_node = None
self._rpf_link = None
self._rprint = rprint
self._tree_id = tree_id
self._unicastd = unicastd
self._node = node
self._has_members = has_members
self._was_in_group = True
self._rpf_is_origin = False
self._liveliness_timer = Timer(None, SFMRTree.TREE_TIMEOUT,
self.___liveliness_timer_expired)
self._died_cb = tree_liveliness_callback
self._interfaces = dict()
self._up_if = None
self.set_rpf()
self._create_root_if(self._rpf_link, ifs.pop(self._rpf_link))
for k, v in ifs.items():
self._create_non_root_if(k, v)
self.rprint('Tree created')
self.evaluate_ingroup()
if self.is_originater():
self._liveliness_timer.start()
self.rprint('set SAT')
def set_rpf(self):
"""
Updates the reverse path forward node and link from the unicast daemon
returning true if there is a change in the rpf_link
@type unid: Unicast
@rtype: (Bool, Bool)
@return: The first bool indicates if rpf_link has changed
The second indicates if rpf_node has changed
"""
next_hop_addr = self._unicastd.next_hop(self.get_source())
node_has_changed = next_hop_addr.get_node() != self._rpf_node
link_has_changed = next_hop_addr.get_link() != self._rpf_link
self._rpf_node = next_hop_addr.get_node()
self._rpf_link = next_hop_addr.get_link()
if link_has_changed:
self.rprint("Tree rpf link changed", 'to', self._rpf_link)
return link_has_changed, node_has_changed
def recv_data_msg(self, msg, sender):
'''
@type msg: DataMsg
@type sender: Addr
'''
if self.is_originater():
self._liveliness_timer.reset()
self._interfaces[sender.get_link()].recv_data_msg(msg, sender)
if sender.get_link() == self._rpf_link:
for interface in self._interfaces.values():
interface.forward_data_msg(msg)
def recv_assert_msg(self, msg, sender):
'''
@type msg: SFMRAssertMsg
@type sender: Addr
'''
self._interfaces[sender.get_link()].recv_assert_msg(msg, sender)
def recv_reset_msg(self, msg, sender):
'''
@type msg: SFMResetMsg
@type sender: Addr
'''
self._interfaces[sender.get_link()].recv_reset_msg(msg, sender)
def recv_prune_msg(self, msg, sender):
'''
@type msg: SFMResetMsg
@type sender: Addr
'''
self._interfaces[sender.get_link()].recv_prune_msg(
msg, sender, self.is_in_group())
def recv_join_msg(self, msg, sender):
'''
@type msg: SFMResetMsg
@type sender: Addr
'''
self._interfaces[sender.get_link()].recv_join_msg(
msg, sender, self.is_in_group())
def recv_state_reset_msg(self, msg, sender):
'''
@type msg: SFMResetMsg
@type sender: Addr
'''
self.flood_state_reset(msg)
self._died_cb(self.get_tree_id())
def ___liveliness_timer_expired(self):
self.rprint('Tree liveliness timer expired')
self.flood_state_reset(SFMRStateResetMsg(self.get_tree_id()))
self._died_cb(self.get_tree_id())
def flood_state_reset(self, msg):
for interface in self._interfaces.values():
interface.forward_state_reset_msg(msg)
def network_update(self, change, args):
assert isinstance(args, SFMRInterface)
link = args.get_link()
if NodeChanges.NewIf == change:
self._create_non_root_if(link, args)
elif NodeChanges.CrashIf == change:
self._interfaces.pop(link).delete()
if link == self._rpf_link:
self._up_if = None
if self._liveliness_timer.is_ticking():
self._liveliness_timer.stop()
self.rprint('stop SAT')
elif NodeChanges.IfCostChange == change:
pass
else:
assert False, "this should never be called (case switch)"
def update(self, caller, arg):
""" called when there is a change in the routing Daemons """
link_ch, node_ch = self.set_rpf()
if self._rpf_link is None:
self.rprint('Lost unicast connection to source')
self._died_cb(self.get_tree_id())
return
if link_ch:
if self._up_if is not None:
old_link = self._up_if.get_link()
old_router_if = self._up_if.get_interface()
self._interfaces.pop(old_link).delete()
self._create_non_root_if(old_link, old_router_if)
self._up_if = None
rpf_router_if = self._interfaces[self._rpf_link].get_interface()
old_if = self._interfaces.pop(self._rpf_link)
self._create_root_if(self._rpf_link, rpf_router_if)
old_if.is_now_root()
old_if.delete()
if self.is_in_group():
self._up_if.send_join()
if link_ch and node_ch:
pass
for interface in self._interfaces.values():
interface.set_cost(self._unicastd.cost_to(self.get_source()))
def nbr_event(self, link, node, event):
'''
@type link: Link
@type node: Node
@type event: NeighborEvent
'''
if NeighborEvent.timedOut == event:
self._interfaces[link].nbr_died(node)
elif NeighborEvent.genIdChanged == event:
self._interfaces[link].nbr_connected()
elif NeighborEvent.newNbr == event:
self._interfaces[link].nbr_connected()
else:
assert False
def is_in_group(self):
if self.get_has_members():
return True
for interface in self._interfaces.values():
if interface.is_forwarding():
return True
return False
def evaluate_ingroup(self):
is_ig = self.is_in_group()
if self._was_in_group != is_ig:
if is_ig:
self.rprint('transitoned to IG')
self._up_if.send_join()
else:
self.rprint('transitoned to OG')
self._up_if.send_prune()
self._was_in_group = is_ig
def is_originater(self):
return self._rpf_node == self.get_source()
def delete(self):
for interface in self._interfaces.values():
interface.delete()
self._liveliness_timer.stop()
self.rprint('Tree deleted')
Convergence.mark_change()
def rprint(self, msg, *entrys):
self._rprint(msg, '({}, {})'.format(self.get_tree_id()[0],
self.get_tree_id()[1]), *entrys)
def get_tree_id(self):
return self._tree_id
def get_source(self):
return self._tree_id[0]
def get_group(self):
return self._tree_id[1]
def get_node(self):
'''
@rtype: Node
'''
return self._node
def get_has_members(self):
return self._has_members
def set_has_members(self, value):
assert isinstance(value, bool)
self._has_members = value
self.evaluate_ingroup()
def _create_root_if(self, link, router_interface):
# assert self._up_if is None
assert link not in self._interfaces
self._interfaces[link] = SFRMRootInterface(
self.rprint, router_interface,
self.get_node(),
self.get_tree_id(),
self._unicastd.cost_to(self.get_source()), self.evaluate_ingroup,
self.is_originater())
self._up_if = self._interfaces[link]
def _create_non_root_if(self, link, router_interface):
assert link not in self._interfaces
nrif = SFRMNonRootInterface(self.rprint, router_interface,
self.get_node(),
self.get_tree_id(),
self._unicastd.cost_to(self.get_source()),
self.evaluate_ingroup)
self._interfaces[link] = nrif
nrif._dipt.start()
nrif.send_prune()
'''
Created on Jul 16, 2015
@author: alex
'''
from abc import ABCMeta, abstractmethod
import Main
#from convergence import Convergence
#from sfmr.messages.prune import SFMRPruneMsg
#from .router_interface import SFMRInterface
class SFRMTreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id, evaluate_ig_cb):
'''
@type interface: SFMRInterface
@type node: Node
'''
#assert isinstance(interface, SFMRInterface)
self._kernel_entry = kernel_entry
self._interface_id = interface_id
#self._interface = interface
#self._node = node
#self._tree_id = tree_id
#self._cost = cost
self._evaluate_ig = evaluate_ig_cb
#self.rprint('new ' + self.__class__.__name__)
#Convergence.mark_change()
def recv_data_msg(self, msg, sender):
pass
@abstractmethod
def recv_assert_msg(self, msg, sender):
pass
def recv_reset_msg(self, msg, sender):
pass
def recv_prune_msg(self, msg, sender, in_group):
print("SUPER PRUNE")
pass
def recv_join_msg(self, msg, sender, in_group):
print("SUPER JOIN")
pass
@abstractmethod
def forward_data_msg(self, msg):
pass
def forward_state_reset_msg(self, msg):
self._interface.send_mcast(msg)
def send_prune(self):
try:
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
(source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune("123.123.123.123", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print('sent prune msg')
except:
return
@abstractmethod
def is_forwarding(self):
pass
def nbr_died(self, node):
pass
def nbr_connected(self):
pass
@abstractmethod
def is_now_root(self):
pass
@abstractmethod
def delete(self):
print('Tree Interface deleted')
def evaluate_ingroup(self):
# todo help self._evaluate_ig()
return
'''
def rprint(self, msg, *entrys):
self._rprint(msg,
self._interface.get_link(),
*entrys)
'''
def rprint(self, msg, *entrys):
return
def __str__(self):
return '{}<{}>'.format(self.__class__, self._interface.get_link())
def get_link(self):
return self._interface.get_link()
def get_interface(self):
import Main
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name]
return interface
def get_node(self):
# todo: para ser substituido por get_ip
return self.get_ip()
def get_ip(self):
import Main
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
import netifaces
netifaces.ifaddresses(interface_name)
ip = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
return ip
def get_tree_id(self):
#return self._tree_id
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
def get_cost(self):
#return self._cost
return 10
def set_cost(self, value):
self._cost = value
def change_tree(self):
self._kernel_entry.change()
......@@ -35,3 +35,61 @@ def checksum(pkt: bytes) -> bytes:
s = ~s
return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting)
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
# IGMP timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MaxResponseTime_QueryResponseInterval = QueryResponseInterval*10
GroupMembershipInterval = RobustnessVariable * QueryInterval + QueryResponseInterval
OtherQuerierPresentInterval = RobustnessVariable * QueryInterval + QueryResponseInterval/2
StartupQueryInterval = QueryInterval / 4
StartupQueryCount = RobustnessVariable
LastMemberQueryInterval = 1
MaxResponseTime_LastMemberQueryInterval = LastMemberQueryInterval*10
LastMemberQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
Version1RouterPresentTimeout = 400
# IGMP msg type
Membership_Query = 0x11
Version_1_Membership_Report = 0x12
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment