Commit fd60f995 authored by Pedro Oliveira's avatar Pedro Oliveira

IGMP efficient packet reception (via BPF) & StateRefresh/Originator state...

IGMP efficient packet reception (via BPF) & StateRefresh/Originator state machine & socket to recv (S,G) data packets (also with BPF) & fix some state machine errors
parent 43fc51da
......@@ -18,8 +18,10 @@ class Hello:
options = packet.payload.payload.get_options()
if (1 in options) and (20 in options):
hello_hold_time = options[1]
generation_id = options[20]
#hello_hold_time = options[1]
hello_hold_time = options[1].holdtime
#generation_id = options[20]
generation_id = options[20].generation_id
else:
raise Exception
......
......@@ -34,6 +34,7 @@ class Interface(object):
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
s.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
......@@ -62,6 +63,7 @@ class Interface(object):
packet = None
return packet
except Exception:
traceback.print_exc()
return None
"""
......
......@@ -5,6 +5,7 @@ import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from ctypes import create_string_buffer, addressof
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
......@@ -12,17 +13,32 @@ if not hasattr(socket, 'SO_BINDTODEVICE'):
class InterfaceIGMP(object):
ETH_P_IP = 0x0800 # Internet Protocol packet
FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 3, 0x00000800),
struct.pack('HBBI', 0x30, 0, 0, 0x00000017),
struct.pack('HBBI', 0x15, 0, 1, 0x00000002),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
SO_ATTACH_FILTER = 26
PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str, vif_index:int):
# RECEIVE SOCKET
rcv_s = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# allow all multicast packets
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.PACKET_MR_ALLMULTI, struct.pack("i HH BBBBBBBB", 0, InterfaceIGMP.PACKET_MR_ALLMULTI, 0, 0,0,0,0,0,0,0,0))
# receive only IGMP packets by setting a BPF filter
filters = b''.join(InterfaceIGMP.FILTER_IGMP)
b = create_string_buffer(filters)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceIGMP.FILTER_IGMP), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.SO_ATTACH_FILTER, fprog)
# bind to interface
rcv_s.bind((interface_name, 0))
rcv_s.bind((interface_name, 0x0800))
self.recv_socket = rcv_s
......@@ -62,17 +78,9 @@ class InterfaceIGMP(object):
def receive(self):
while self.interface_enabled:
try:
(raw_packet, x) = self.recv_socket.recvfrom(256 * 1024)
(raw_packet, _) = self.recv_socket.recvfrom(256 * 1024)
if raw_packet:
raw_packet = raw_packet[14:]
from Packet.PacketIpHeader import PacketIpHeader
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, raw_packet[:PacketIpHeader.IP_HDR_LEN])
#print(proto)
if proto != socket.IPPROTO_IGMP:
continue
#print((raw_packet, x))
packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet)
except Exception:
......
......@@ -5,12 +5,14 @@ from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
from Packet.PacketPimHelloOptions import *
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet
from Hello import Hello
from utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer
from tree.globals import REFRESH_INTERVAL
class InterfacePim(Interface):
MCAST_GRP = '224.0.0.13'
......@@ -20,7 +22,7 @@ class InterfacePim(Interface):
MAX_TRIGGERED_HELLO_PERIOD = 5
def __init__(self, interface_name: str, vif_index:int):
def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False):
super().__init__(interface_name)
# generation id
......@@ -33,9 +35,8 @@ class InterfacePim(Interface):
self.hello_timer.start()
# todo: state refresh capable
self._state_refresh_capable = False
# state refresh capable
self._state_refresh_capable = state_refresh_capable
# todo: lan delay enabled
self._lan_delay_enabled = False
......@@ -58,10 +59,6 @@ class InterfacePim(Interface):
receive_thread.daemon = True
receive_thread.start()
def create_virtual_interface(self):
self.vif_index = Main.kernel.create_virtual_interface(ip_interface=self.ip_interface, interface_name=self.interface_name)
def receive(self):
while self.is_enabled():
try:
......@@ -80,8 +77,15 @@ class InterfacePim(Interface):
self.hello_timer.cancel()
pim_payload = PacketPimHello()
pim_payload.add_option(1, 3.5 * Hello.TRIGGERED_HELLO_DELAY)
pim_payload.add_option(20, self.generation_id)
pim_payload.add_option(PacketPimHelloHoldtime(holdtime=3.5 * Hello.TRIGGERED_HELLO_DELAY))
pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
# TODO implementar LANPRUNEDELAY e OVERRIDE_INTERVAL por interface e nas maquinas de estados ler valor de interface e nao do globals.py
#pim_payload.add_option(PacketPimHelloLANPruneDelay(lan_prune_delay=self._propagation_delay, override_interval=self._override_interval))
if self._state_refresh_capable:
pim_payload.add_option(PacketPimHelloStateRefreshCapable(REFRESH_INTERVAL))
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
......@@ -96,8 +100,8 @@ class InterfacePim(Interface):
# send pim_hello timeout message
pim_payload = PacketPimHello()
pim_payload.add_option(1, HELLO_HOLD_TIME_TIMEOUT)
pim_payload.add_option(20, self.generation_id)
pim_payload.add_option(PacketPimHelloHoldtime(holdtime=HELLO_HOLD_TIME_TIMEOUT))
pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
......@@ -130,3 +134,7 @@ class InterfacePim(Interface):
def remove_neighbor(self, ip):
with self.neighbors_lock.genWlock():
del self.neighbors[ip]
def is_state_refresh_enabled(self):
return self._state_refresh_capable
......@@ -41,6 +41,12 @@ class Kernel:
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self):
# Kernel is running
self.running = True
......@@ -66,6 +72,9 @@ class Kernel:
self.rwlock = RWLockWrite()
self.interface_lock = Lock()
# Create register interface
# todo useless in PIM-DM... useful in PIM-SM
#self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER)
# Create virtual interfaces
'''
......@@ -149,8 +158,55 @@ class Kernel:
return index
def create_pim_interface(self, interface_name: str, state_refresh_capable:bool):
from InterfacePIM import InterfacePim
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if pim_interface:
# already exists
return
elif igmp_interface:
index = igmp_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index, state_refresh_capable)
self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def create_igmp_interface(self, interface_name: str):
from InterfaceIGMP import InterfaceIGMP
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if igmp_interface:
# already exists
return
elif pim_interface:
index = pim_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.igmp_interface:
igmp_interface = InterfaceIGMP(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
'''
def create_interface(self, interface_name: str, igmp:bool = False, pim:bool = False):
from InterfaceIGMP import InterfaceIGMP
from InterfacePIM import InterfacePim
......@@ -180,7 +236,7 @@ class Kernel:
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
'''
......@@ -263,20 +319,6 @@ class Kernel:
# TODO: ver melhor tabela routing
#self.routing[(socket.inet_ntoa(source_ip), socket.inet_ntoa(group_ip))] = {"inbound_interface_index": inbound_interface_index, "outbound_interfaces": outbound_interfaces}
'''
def flood(self, ip_src, ip_dst, iif):
source_ip = socket.inet_aton(ip_src)
group_ip = socket.inet_aton(ip_dst)
outbound_interfaces = [1]*Kernel.MAXVIFS
outbound_interfaces[iif] = 0
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, iif, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
'''
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
......@@ -312,8 +354,7 @@ class Kernel:
def handler(self):
while self.running:
try:
msg = self.socket.recv(5000)
#print(len(msg))
msg = self.socket.recv(20)
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20])
print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
......@@ -336,6 +377,9 @@ class Kernel:
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
#elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
# print("IGMP_WHOLEPKT")
# self.igmpmsg_wholepacket_handler(ip_src, ip_dst)
else:
raise Exception
except Exception:
......@@ -379,6 +423,17 @@ class Kernel:
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
#kernel_entry.recv_data_msg(iif)
''' useless in PIM-DM... useful in PIM-SM
def igmpmsg_wholepacket_handler(self, ip_src, ip_dst):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg()
#kernel_entry.recv_data_msg(iif)
'''
"""
def get_routing_entry(self, source_group: tuple):
with self.rwlock.genRlock():
......
......@@ -15,6 +15,14 @@ kernel = None
igmp = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
def add_igmp_interface(interface_name):
kernel.create_igmp_interface(interface_name=interface_name)
'''
def add_interface(interface_name, pim=False, igmp=False):
#if pim is True and interface_name not in interfaces:
# interface = InterfacePim(interface_name)
......@@ -28,7 +36,7 @@ def add_interface(interface_name, pim=False, igmp=False):
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
'''
def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"):
......@@ -76,38 +84,6 @@ def list_neighbors():
def list_enabled_interfaces():
global interfaces
# TESTE DE PIM JOIN/PRUNE
for interface in interfaces:
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
ph = PacketPimJoinPrune("10.0.0.13", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1", "10.1.1.1"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
ph = PacketPimJoinPrune("ff08::1", 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("2001:1:a:b:c::1", ["1.1.1.1", "2001:1:a:b:c::2"], []))
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.123", ["1.1.1.1"], ["2001:1:a:b:c::3"]))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimAssert import PacketPimAssert
ph = PacketPimAssert("224.12.12.12", "10.0.0.2", 210, 2)
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
from Packet.PacketPimGraft import PacketPimGraft
ph = PacketPimGraft("10.0.0.13")
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup("239.123.123.124", ["1.1.1.2", "10.1.1.2"], []))
pckt = Packet(payload=PacketPimHeader(ph))
interfaces[interface].send(pckt.bytes())
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State'])
for interface in netifaces.interfaces():
......@@ -191,12 +167,14 @@ def main():
from JoinPrune import JoinPrune
from GraftAck import GraftAck
from Graft import Graft
from StateRefresh import StateRefresh
Hello()
Assert()
JoinPrune()
Graft()
GraftAck()
StateRefresh()
global kernel
kernel = Kernel()
......
import struct
from abc import ABCMeta, abstractstaticmethod
from .PacketPimHelloOptions import PacketPimHelloOptions, PacketPimHelloStateRefreshCapable, PacketPimHelloGenerationID, PacketPimHelloLANPruneDelay, PacketPimHelloHoldtime
'''
0 1 2 3
......@@ -25,6 +27,7 @@ class PacketPimHello:
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
PIM_MSG_TYPES_LENGTH = {1: 2,
2: 4,
20: 4,
21: 4,
}
......@@ -33,16 +36,26 @@ class PacketPimHello:
def __init__(self):
self.options = {}
'''
def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value
'''
def add_option(self, option: 'PacketPimHelloOptions'):
#if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
# option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option.type] = option
def get_options(self):
return self.options
'''
def bytes(self) -> bytes:
res = b''
for (option_type, option_value) in self.options.items():
......@@ -50,10 +63,21 @@ class PacketPimHello:
type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length)
res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big'))
return res
'''
def bytes(self) -> bytes:
res = b''
for option in self.options.values():
res += option.bytes()
return res
def __len__(self):
return len(self.bytes())
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
......@@ -66,13 +90,24 @@ class PacketPimHello:
(option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length])
option_value_number = int.from_bytes(option_value, byteorder='big')
print("option value: ", option_value_number)
'''
options_list.append({"OPTION TYPE": option_type,
"OPTION LENGTH": option_length,
"OPTION VALUE": option_value_number
})
'''
#options_list.append({"OPTION TYPE": option_type,
# "OPTION LENGTH": option_length,
# "OPTION VALUE": option_value_number
# })
pim_payload.add_option(option_type, option_value_number)
data = data[option_length:]
return pim_payload
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
option = PacketPimHelloOptions.parse_bytes(data)
option_length = len(option)
data = data[option_length:]
pim_payload.add_option(option)
return pim_payload
import struct
from abc import ABCMeta
import math
class PacketPimHelloOptions(metaclass=ABCMeta):
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type: int, length: int):
self.type = type
self.length = length
def bytes(self) -> bytes:
return struct.pack(PacketPimHelloOptions.PIM_HDR_OPTS, self.type, self.length)
def __len__(self):
return self.PIM_HDR_OPTS_LEN + self.length
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
(type, length) = struct.unpack(PacketPimHelloOptions.PIM_HDR_OPTS,
data[:PacketPimHelloOptions.PIM_HDR_OPTS_LEN])
print("TYPE:", type)
print("LENGTH:", length)
data = data[PacketPimHelloOptions.PIM_HDR_OPTS_LEN:]
#return PIM_MSG_TYPES[type](data)
return PIM_MSG_TYPES.get(type, PacketPimHelloUnknown).parse_bytes(data, type, length)
class PacketPimHelloStateRefreshCapable(PacketPimHelloOptions):
PIM_HDR_OPT = "! BBH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 1 | Interval | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
VERSION = 1
def __init__(self, interval: int):
super().__init__(type=21, length=4)
self.interval = interval
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.VERSION, self.interval, 0)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(version, interval, _) = struct.unpack(PacketPimHelloStateRefreshCapable.PIM_HDR_OPT,
data[:PacketPimHelloStateRefreshCapable.PIM_HDR_OPT_LEN])
return PacketPimHelloStateRefreshCapable(interval)
class PacketPimHelloLANPruneDelay(PacketPimHelloOptions):
PIM_HDR_OPT = "! HH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|T| LAN Prune Delay | Override Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, lan_prune_delay: float, override_interval: float):
super().__init__(type=2, length=4)
self.lan_prune_delay = 0x7FFF & math.ceil(lan_prune_delay)
self.override_interval = math.ceil(override_interval)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.lan_prune_delay, self.override_interval)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(lan_prune_delay, override_interval) = struct.unpack(PacketPimHelloLANPruneDelay.PIM_HDR_OPT,
data[:PacketPimHelloLANPruneDelay.PIM_HDR_OPT_LEN])
lan_prune_delay = lan_prune_delay & 0x7FFF
return PacketPimHelloLANPruneDelay(lan_prune_delay=lan_prune_delay, override_interval=override_interval)
class PacketPimHelloHoldtime(PacketPimHelloOptions):
PIM_HDR_OPT = "! H"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, holdtime: int or float):
super().__init__(type=1, length=2)
self.holdtime = int(holdtime)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.holdtime)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(holdtime, ) = struct.unpack(PacketPimHelloHoldtime.PIM_HDR_OPT,
data[:PacketPimHelloHoldtime.PIM_HDR_OPT_LEN])
print("HOLDTIME:", holdtime)
return PacketPimHelloHoldtime(holdtime=holdtime)
class PacketPimHelloGenerationID(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Generation ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, generation_id: int):
super().__init__(type=20, length=4)
self.generation_id = generation_id
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.generation_id)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(generation_id, ) = struct.unpack(PacketPimHelloGenerationID.PIM_HDR_OPT,
data[:PacketPimHelloGenerationID.PIM_HDR_OPT_LEN])
print("GenerationID:", generation_id)
return PacketPimHelloGenerationID(generation_id=generation_id)
class PacketPimHelloUnknown(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Unknown |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type, length):
super().__init__(type=type, length=length)
print("PIM Hello Option Unknown... TYPE=", type, "LENGTH=", length)
def bytes(self) -> bytes:
raise Exception
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
return PacketPimHelloUnknown(type, length)
PIM_MSG_TYPES = {1: PacketPimHelloHoldtime,
2: PacketPimHelloLANPruneDelay,
20: PacketPimHelloGenerationID,
21: PacketPimHelloStateRefreshCapable,
}
......@@ -64,10 +64,13 @@ class MyDaemon(Daemon):
elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state()))
elif 'add_interface' in args and args.add_interface:
Main.add_interface(args.add_interface[0], pim=True)
Main.add_pim_interface(args.add_interface[0], False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_interface(args.add_interface_igmp[0], igmp=True)
Main.add_igmp_interface(args.add_interface_igmp[0])
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True)
......@@ -99,6 +102,7 @@ if __name__ == "__main__":
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
......
......@@ -4,9 +4,9 @@ from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Interface import Interface
import Main
from utils import HELLO_HOLD_TIME_TIMEOUT
class StateRefresh:
......@@ -17,8 +17,27 @@ class StateRefresh:
# receive handler
def receive_handle(self, packet: ReceivedPacket):
#check if interface supports state refresh
if not packet.interface._state_refresh_capable:
return
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
# TODO
raise Exception
\ No newline at end of file
interface_index = packet.interface.vif_index
source = pkt_state_refresh.source_address
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
try:
# import time
# time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
pass
......@@ -65,7 +65,10 @@ class UnicastRouting(object):
unicast_routing_entry = UnicastRouting.get_route(ip_dst)
entry_protocol = unicast_routing_entry["proto"]
entry_cost = unicast_routing_entry["priority"]
return (entry_protocol, entry_cost)
mask = unicast_routing_entry["dst_len"]
if entry_cost is None:
entry_cost = 0
return (entry_protocol, entry_cost, mask)
"""
def get_rpf(ip_dst: str):
......
import subprocess
import struct
import socket
from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet
def get_s_g_bpf_filter_code(source, group, interface_name):
cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b''
FILTER = []
tmp = result.stdout.read().splitlines()
num = int(tmp[0])
for line in tmp[1:]: #read and store result in log file
print(line)
#FILTER += (struct.pack("HBBI", *tuple(map(int, line.split(b' ')))), )
bpf_filter += struct.pack("HBBI", *tuple(map(int, line.split(b' '))))
print(num)
print(FILTER)
# defined in linux/filter.h.
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', num, mem_addr_of_filters)
# Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
s.bind((interface_name, ETH_P_IP))
return s
......@@ -8,6 +8,7 @@ from .tree_interface import TreeInterface
from threading import Timer, Lock, RLock
from tree.metric import AssertMetric
import UnicastRouting
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
class KernelEntry:
TREE_TIMEOUT = 180
......@@ -37,21 +38,12 @@ class KernelEntry:
print("RPF_NODE:", UnicastRouting.get_route(source_ip))
print(self.rpf_node == source_ip)
# (S,G) starts IG state
self._was_olist_null = False
# todo
#self._rpf_is_origin = False
self._originator_state = OriginatorState.NotOriginator
# decide inbound interface based on rpf check
self.inbound_interface_index = Main.kernel.vif_dic[self.check_rpf()]
#Main.kernel.flood(source_ip, group_ip, self.inbound_interface_index)
self.interface_state = {} # type: Dict[int, TreeInterface]
for i in Main.kernel.vif_index_to_name_dic.keys():
try:
......@@ -71,10 +63,6 @@ class KernelEntry:
self.change()
self.evaluate_olist_change()
print('Tree created')
#self._liveliness_timer = None
#if self.is_originater():
# self.set_liveliness_timer()
# print('set SAT')
#self._lock = threading.RLock()
......@@ -92,9 +80,9 @@ class KernelEntry:
return UnicastRouting.check_rpf(self.source_ip)
#################################
# Receive (S,G) packet
#################################
################################################
# Receive (S,G) data packets or control packets
################################################
def recv_data_msg(self, index):
print("recv data")
self.interface_state[index].recv_data_msg()
......@@ -132,8 +120,38 @@ class KernelEntry:
def recv_state_refresh_msg(self, index, packet):
print("recv state refresh msg")
prune_indicator = 1
self.interface_state[index].recv_state_refresh_msg(prune_indicator)
source_of_state_refresh = packet.ip_header.ip_src
metric_preference = packet.payload.payload.metric_preference
metric = packet.payload.payload.metric
mask_len = packet.payload.payload.mask_len
ttl = packet.payload.payload.ttl
prune_indicator_flag = packet.payload.payload.prune_indicator_flag #P
assert_override_flag = packet.payload.payload.assert_override_flag #O
interval = packet.payload.payload.interval
received_metric = AssertMetric(metric_preference=metric_preference, route_metric=metric, ip_address=source_of_state_refresh, state_refresh_interval=interval)
self.interface_state[index].recv_state_refresh_msg(received_metric, prune_indicator_flag)
iif = packet.interface.vif_index
if iif != self.inbound_interface_index:
return
if self.interface_state[iif].get_neighbor_RPF() != source_of_state_refresh:
return
# todo refresh limit
if ttl == 0:
return
self.forward_state_refresh_msg(packet.payload.payload)
################################################
# Send state refresh msg
################################################
def forward_state_refresh_msg(self, state_refresh_packet):
for interface in self.interface_state.values():
interface.send_state_refresh(state_refresh_packet)
###############################################################
......@@ -177,13 +195,9 @@ class KernelEntry:
self.rpf_node = rpf_node
self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null)
def update(self, caller, arg):
#todo
return
def nbr_event(self, link, node, event):
# todo
# todo pode ser interessante verificar se a adicao/remocao de vizinhos se altera o olist
return
def is_olist_null(self):
......
......@@ -37,12 +37,11 @@ class AssertStateABC(metaclass=ABCMeta):
raise NotImplementedError()
@abstractstaticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", assert_time, better_metric):
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
"""
Receive Preferred Assert OR State Refresh
@type interface: TreeInterface
@type assert_time: int
@type better_metric: AssertMetric
"""
raise NotImplementedError()
......@@ -160,11 +159,11 @@ class NoInfoState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W')
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None):
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
'''
@type interface: TreeInterface
'''
#interface.assert_timer.set_timer(assert_time)
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
# event caused by Assert Msg
assert_timer_value = pim_globals.ASSERT_TIME
......@@ -175,12 +174,8 @@ class NoInfoState(AssertStateABC):
interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser)
#interface.assert_timer.reset()
#interface.assert_state = AssertState.Loser
#interface.assert_winner_metric = better_metric
# todo MUST also multicast a Prune(S,G) to the Assert winner <- TO THE colocar endereco do winner
# MUST also multicast a Prune(S,G) to the Assert winner
if interface.could_assert():
interface.send_prune(holdtime=assert_timer_value)
......@@ -240,14 +235,11 @@ class WinnerState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, W -> W')
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None):
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
'''
@type better_metric: AssertMetric
'''
#interface.assert_timer.set_timer(assert_time)
#interface.assert_timer.reset()
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
# event caused by AssertMsg
assert_timer_value = pim_globals.ASSERT_TIME
......@@ -256,22 +248,15 @@ class WinnerState(AssertStateABC):
assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
#interface.assert_state = AssertState.Loser
interface.set_assert_state(AssertState.Loser)
if interface.could_assert:
interface.send_prune(holdtime=assert_timer_value)
interface.send_prune(holdtime=assert_timer_value)
print('receivedPreferedMetric, W -> L')
@staticmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", state_refresh_interval):
#interface.assert_timer.set_timer(time)
interface.set_assert_timer(state_refresh_interval*3)
#interface.assert_timer.reset()
@staticmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"):
......@@ -334,12 +319,11 @@ class LoserState(AssertStateABC):
'receivedInferiorMetricFromNonWinner_couldAssertIsTrue, L -> L')
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, state_refresh_interval = None):
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric):
'''
@type better_metric: AssertMetric
'''
#interface.assert_timer.set_timer(assert_time)
#interface.assert_timer.reset()
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
assert_timer_value = pim_globals.ASSERT_TIME
else:
......@@ -353,7 +337,7 @@ class LoserState(AssertStateABC):
if interface.could_assert():
# todo enviar holdtime = assert_timer_value???!
interface.send_prune()
interface.send_prune(holdtime=assert_timer_value)
print('receivedPreferedMetric, L -> L')
......
......@@ -34,7 +34,7 @@ class DownstreamStateABS(metaclass=ABCMeta):
raise NotImplementedError()
@abstractstaticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
......@@ -127,7 +127,7 @@ class NoInfo(DownstreamStateABS):
print('receivedGraft, NI -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
......@@ -221,7 +221,7 @@ class PrunePending(DownstreamStateABS):
print('receivedGraft, PP -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
......@@ -335,7 +335,7 @@ class Pruned(DownstreamStateABS):
print('receivedGraft, P -> NI')
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream", prune_holdtime):
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
......
......@@ -10,6 +10,7 @@ ASSERT_TIME = 180
GRAFT_RETRY_PERIOD = 3
JT_OVERRIDE_INTERVAL = 3.0
OVERRIDE_INTERVAL = 2.5
PROPAGATION_DELAY = 0.5
REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210
T_LIMIT = 210
......
......@@ -5,13 +5,14 @@ class AssertMetric(object):
Note: we consider the node name the ip of the metric.
'''
def __init__(self, metric_preference: int or float = float("Inf"), route_metric: int or float = float("Inf"), ip_address: str = "0.0.0.0"):
def __init__(self, metric_preference: int or float = float("Inf"), route_metric: int or float = float("Inf"), ip_address: str = "0.0.0.0", state_refresh_interval:int = None):
if type(ip_address) is str:
ip_address = ipaddress.ip_address(ip_address)
self._metric_preference = metric_preference
self._route_metric = route_metric
self._ip_address = ip_address
self._state_refresh_interval = state_refresh_interval
def is_better_than(self, other):
if self.metric_preference != other.metric_preference:
......@@ -39,7 +40,7 @@ class AssertMetric(object):
'''
(source_ip, _) = tree_if.get_tree_id()
import UnicastRouting
metric_preference, metric_cost = UnicastRouting.get_metric(source_ip)
(metric_preference, metric_cost, _) = UnicastRouting.get_metric(source_ip)
return AssertMetric(metric_preference, metric_cost, tree_if.get_ip())
......@@ -75,6 +76,14 @@ class AssertMetric(object):
self._ip_address = value
@property
def state_refresh_interval(self):
return self._state_refresh_interval
@state_refresh_interval.setter
def state_refresh_interval(self, value):
self._state_refresh_interval = value
def get_ip(self):
return str(self._ip_address)
......@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals
class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def recvDataMsgFromSource(tree):
pass
......@@ -22,33 +23,32 @@ class OriginatorStateABC(metaclass=ABCMeta):
class Originator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
tree.source_active_timer.reset()
tree.set_source_active_timer()
@staticmethod
def SRTexpires(tree):
'''
@type tree: Tree
'''
tree.rprint('SRT expired, O to O')
print('SRT expired, O to O')
tree.state_refresh_timer.reset()
tree.send_state_refresh_msg()
tree.set_state_refresh_timer()
tree.create_state_refresh_msg()
@staticmethod
def SATexpires(tree):
tree.rprint('SAT expired, O to NO')
print('SAT expired, O to NO')
tree.source_active_timer.stop()
tree.state_refresh_timer.stop()
tree.originator_state = OriginatorState.NotOriginator
tree.clear_state_refresh_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
@staticmethod
def SourceNotConnected(tree):
tree.rprint('Source no longer directly connected, O to NO')
print('Source no longer directly connected, O to NO')
tree.source_active_timer.stop()
tree.state_refresh_timer.stop()
tree.originator_state = OriginatorState.NotOriginator
tree.clear_state_refresh_timer()
tree.clear_source_active_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
class NotOriginator(OriginatorStateABC):
......@@ -57,14 +57,12 @@ class NotOriginator(OriginatorStateABC):
'''
@type interface: Tree
'''
tree.originator_state = OriginatorState.Originator
tree.set_originator_state(OriginatorState.Originator)
tree.state_refresh_timer.start()
tree.source_active_timer.start()
tree.set_state_refresh_timer()
tree.set_source_active_timer()
tree.rprint('new DataMsg from Source, NO to O')
# Since the recording of the TTL is common to both states,its registering is made on the
# Tree.new_state_refresh_msg(...) method
print('new DataMsg from Source, NO to O')
@staticmethod
def SRTexpires(tree):
......@@ -76,7 +74,7 @@ class NotOriginator(OriginatorStateABC):
@staticmethod
def SourceNotConnected(tree):
pass
return
class OriginatorState():
......
......@@ -15,8 +15,13 @@ from .metric import AssertMetric
from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
from threading import Lock
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
import traceback
class TreeInterfaceDownstream(TreeInterface):
def __init__(self, kernel_entry, interface_id):
......@@ -88,7 +93,7 @@ class TreeInterfaceDownstream(TreeInterface):
# Timer timeout
###########################################
def prune_pending_timeout(self):
self._prune_state.PPTexpires(self, 10)
self._prune_state.PPTexpires(self)
def prune_timeout(self):
self._prune_state.PTexpires(self)
......@@ -103,7 +108,6 @@ class TreeInterfaceDownstream(TreeInterface):
def recv_prune_msg(self, upstream_neighbor_address, holdtime):
super().recv_prune_msg(upstream_neighbor_address, holdtime)
#TODO if upstream_neighbor_address == self.get_ip():
if upstream_neighbor_address == self.get_ip():
self.set_receceived_prune_holdtime(holdtime)
self._prune_state.receivedPrune(self, holdtime)
......@@ -124,6 +128,55 @@ class TreeInterfaceDownstream(TreeInterface):
self._prune_state.receivedGraft(self, source_ip)
######################################
# Send messages
######################################
def send_state_refresh(self, state_refresh_msg_received):
if not self.get_interface()._state_refresh_capable:
return
interval = state_refresh_msg_received.interval
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
if self.lost_assert():
return
prune_indicator_bit = 0
if self.is_pruned():
prune_indicator_bit = 1
# TODO set timer
# todo maybe ja feito na maquina de estados Prune downstream
# if state_refresh_capable
# set PT....
import UnicastRouting
(metric_preference, metric, mask) = UnicastRouting.get_metric(state_refresh_msg_received.source_address)
assert_override_flag = 0
if self._assert_state == AssertState.NoInfo:
assert_override_flag = 1
try:
ph = PacketPimStateRefresh(multicast_group_adress=state_refresh_msg_received.multicast_group_adress,
source_address=state_refresh_msg_received.source_address,
originator_adress=state_refresh_msg_received.originator_adress,
metric_preference=metric_preference, metric=metric, mask_len=mask,
ttl=state_refresh_msg_received.ttl - 1,
prune_indicator_flag=prune_indicator_bit,
prune_now_flag=state_refresh_msg_received.prune_now_flag,
assert_override_flag=assert_override_flag,
interval=interval)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
##########################################################
# Override
def is_forwarding(self):
......
......@@ -12,6 +12,12 @@ from .upstream_prune import UpstreamState
from threading import Timer
from .globals import *
import random
from .metric import AssertMetric
from .originator import OriginatorState, OriginatorStateABC
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
import traceback
from . import DataPacketsSocket
import threading
class TreeInterfaceUpstream(TreeInterface):
......@@ -22,10 +28,37 @@ class TreeInterfaceUpstream(TreeInterface):
self._override_timer = None
self._prune_limit_timer = None
self._originator_state = None
self._originator_state = OriginatorState.NotOriginator
self._state_refresh_timer = None
self._source_active_timer = None
self._prune_now_counter = 0
if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self)
self._originator_state.recvDataMsgFromSource(self)
# TODO TESTE SOCKET RECV DATA PCKTS
self.socket_is_enabled = True
(s,g) = self.get_tree_id()
interface_name = self.get_interface().interface_name
self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name)
# run receive method in background
receive_thread = threading.Thread(target=self.socket_recv)
receive_thread.daemon = True
receive_thread.start()
def socket_recv(self):
while self.socket_is_enabled:
try:
self.socket_pkt.recvfrom(0)
print("PACOTE DADOS RECEBIDO")
self.recv_data_msg()
except:
traceback.print_exc()
continue
##########################################
# Set state
......@@ -38,6 +71,9 @@ class TreeInterfaceUpstream(TreeInterface):
self.change_tree()
self.evaluate_ingroup()
def set_originator_state(self, new_state: OriginatorStateABC):
if new_state != self._originator_state:
self._originator_state = new_state
##########################################
# Check timers
......@@ -81,6 +117,26 @@ class TreeInterfaceUpstream(TreeInterface):
if self._prune_limit_timer is not None:
self._prune_limit_timer.cancel()
# State Refresh timers
def set_state_refresh_timer(self):
self.clear_state_refresh_timer()
self._state_refresh_timer = Timer(REFRESH_INTERVAL, self.state_refresh_timeout)
self._state_refresh_timer.start()
def clear_state_refresh_timer(self):
if self._state_refresh_timer is not None:
self._state_refresh_timer.cancel()
def set_source_active_timer(self):
self.clear_source_active_timer()
self._source_active_timer = Timer(SOURCE_LIFETIME, self.source_active_timeout)
self._source_active_timer.start()
def clear_source_active_timer(self):
if self._source_active_timer is not None:
self._source_active_timer.cancel()
###########################################
# Timer timeout
###########################################
......@@ -93,6 +149,13 @@ class TreeInterfaceUpstream(TreeInterface):
def prune_limit_timeout(self):
return
# State Refresh timers
def state_refresh_timeout(self):
self._originator_state.SRTexpires(self)
def source_active_timeout(self):
self._originator_state.SATexpires(self)
###########################################
# Recv packets
###########################################
......@@ -101,12 +164,9 @@ class TreeInterfaceUpstream(TreeInterface):
if self.is_olist_null() and not self.is_prune_limit_timer_running() and not self.is_S_directly_conn():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
def recv_state_refresh_msg(self, prune_indicator: int):
# todo check rpf nbr
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
if self.is_S_directly_conn():
self._originator_state.recvDataMsgFromSource(self)
def recv_join_msg(self, upstream_neighbor_address):
super().recv_join_msg(upstream_neighbor_address)
......@@ -122,6 +182,33 @@ class TreeInterfaceUpstream(TreeInterface):
# todo check rpf nbr
self._graft_prune_state.recvGraftAckFromRPFnbr(self)
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator: int):
super().recv_state_refresh_msg(received_metric, prune_indicator)
if self.get_neighbor_RPF() != received_metric.get_ip():
return
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
####################################
def create_state_refresh_msg(self):
self._prune_now_counter+=1
self._prune_now_counter%=3
(source_ip, group_ip) = self.get_tree_id()
ph = PacketPimStateRefresh(multicast_group_adress=group_ip,
source_address=source_ip,
originator_adress=self.get_ip(),
metric_preference=0, metric=0, mask_len=0,
ttl=256,
prune_indicator_flag=0,
prune_now_flag=(self._prune_now_counter+1)//3,
assert_override_flag=0,
interval=60)
self._kernel_entry.forward_state_refresh_msg(ph)
###########################################
# Change olist
###########################################
......@@ -147,14 +234,20 @@ class TreeInterfaceUpstream(TreeInterface):
#Override
def delete(self):
super().delete()
self.socket_is_enabled = False
self.socket_pkt.close()
self.clear_graft_retry_timer()
self.clear_assert_timer()
self.clear_prune_limit_timer()
self.clear_override_timer()
self.clear_state_refresh_timer()
self.clear_source_active_timer()
def is_downstream(self):
return False
def is_originator(self):
return self._originator_state == OriginatorState.Originator
#-------------------------------------------------------------------------
# Properties
......
......@@ -44,6 +44,7 @@ class TreeInterface(metaclass=ABCMeta):
#self._cost = cost
#self._evaluate_ig = evaluate_ig_cb
# Local Membership State
try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
......@@ -56,9 +57,6 @@ class TreeInterface(metaclass=ABCMeta):
self._local_membership_state = LocalMembership.NoInfo
# Local Membership State
#self._local_membership_state = None # todo NoInfo or Include
# Prune State
self._prune_state = DownstreamState.NoInfo
self._prune_pending_timer = None
......@@ -134,7 +132,7 @@ class TreeInterface(metaclass=ABCMeta):
pass
def recv_assert_msg(self, received_metric: AssertMetric):
if self._assert_winner_metric.is_better_than(received_metric):
if self.my_assert_metric().is_better_than(received_metric):
# received inferior assert
if self._assert_winner_metric.ip_address == received_metric.ip_address:
# received from assert winner
......@@ -142,16 +140,10 @@ class TreeInterface(metaclass=ABCMeta):
elif self.could_assert():
# received from non assert winner and could_assert
self._assert_state.receivedInferiorMetricFromNonWinner_couldAssertIsTrue(self)
else:
elif received_metric.is_better_than(self._assert_winner_metric):
#received preferred assert
self._assert_state.receivedPreferedMetric(self, received_metric)
def recv_reset_msg(self):
pass
def recv_prune_msg(self, upstream_neighbor_address, holdtime):
if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self)
......@@ -167,14 +159,8 @@ class TreeInterface(metaclass=ABCMeta):
def recv_graft_ack_msg(self):
pass
def recv_state_refresh_msg(self, prune_indicator):
pass
def forward_state_reset_msg(self):
raise NotImplemented
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator):
self.recv_assert_msg(received_metric)
######################################
......@@ -185,48 +171,37 @@ class TreeInterface(metaclass=ABCMeta):
try:
(source, group) = self.get_tree_id()
# todo self.get_rpf_()
ip_dst = self.get_neighbor_RPF()
ph = PacketPimGraft(ip_dst)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_dst)
#msg = GraftMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except:
traceback.print_exc()
return
def send_graft_ack(self, ip_sender):
print("send graft ack")
try:
(source, group) = self.get_tree_id()
# todo endereco?!!
ph = PacketPimGraftAck(ip_sender)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_sender)
#msg = GraftAckMsg(self.get_tree().tree_id, self.get_node())
#self.pim_if.send_mcast(msg)
except:
traceback.print_exc()
return
def send_prune(self, holdtime=None):
if holdtime is None:
holdtime = T_LIMIT
print("send prune")
try:
(source, group) = self.get_tree_id()
# todo help ip of ph
#ph = PacketPimJoinPrune("123.123.123.123", 210)
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
......@@ -242,7 +217,6 @@ class TreeInterface(metaclass=ABCMeta):
holdtime = T_LIMIT
try:
(source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune(self.get_ip(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
......@@ -252,24 +226,18 @@ class TreeInterface(metaclass=ABCMeta):
except:
traceback.print_exc()
return
# todo
#msg = PruneMsg(self.get_tree().tree_id,
# self.get_node(), self._assert_timer.time_left())
#self.pim_if.send_mcast(msg)
def send_join(self):
print("send join")
try:
(source, group) = self.get_tree_id()
# todo help ip of ph
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
#msg = JoinMsg(self.get_tree().tree_id, self.get_rpf_())
#self.pim_if.send_mcast(msg)
except:
traceback.print_exc()
return
......@@ -290,8 +258,6 @@ class TreeInterface(metaclass=ABCMeta):
return
def send_assert_cancel(self):
print("send assert cancel")
......@@ -304,12 +270,12 @@ class TreeInterface(metaclass=ABCMeta):
except:
traceback.print_exc()
return
#msg = AssertMsg.new_assert_cancel(self.tree_id)
#self.pim_if.send_mcast(msg)
def send_state_refresh(self):
# todo time
self._assert_state.sendStateRefresh(self)
def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh):
pass
#############################################################
@abstractmethod
def is_forwarding(self):
......@@ -364,10 +330,6 @@ class TreeInterface(metaclass=ABCMeta):
def __str__(self):
return '{}<{}>'.format(self.__class__, self._interface.get_link())
def get_link(self):
# todo
return self._interface.get_link()
def get_interface(self):
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
......@@ -396,8 +358,6 @@ class TreeInterface(metaclass=ABCMeta):
raise NotImplementedError()
#def get_rpf_(self):
# return self.get_neighbor_RPF()
# obtain ip of RPF'(S)
......
......@@ -169,8 +169,9 @@ class Forward(UpstreamStateABC):
@type interface: TreeInterfaceUpstream
"""
#interface.set_ot()
interface.set_override_timer()
# if OT is not running the router must set OT to t_override seconds
if not interface.is_override_timer_running():
interface.set_override_timer()
print('stateRefreshArrivesRPFnbr_pruneIs1, F -> F')
......@@ -332,14 +333,8 @@ class Pruned(UpstreamStateABC):
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
#interface.set_state(UpstreamState.Pruned)
# todo send prune?!?!?!?!
#timer = interface._prune_limit_timer
#timer.set_timer(interface.t_override)
#timer.start()
interface.set_prune_limit_timer()
interface.send_prune()
print("dataArrivesRPFinterface_OListNull_PLTstoped, P -> P")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment