Commit b9c7ed1d authored by Bram Schoenmakers's avatar Bram Schoenmakers

Allow that an edge ID may occur multiple times in a graph.

For example:

Foo id:1
Bar p:1
Baz p:1

Would only store the ID for the last added edge (Foo,Baz) and forget
about the ID of 1 for (Foo,Bar).

Also make sure that add_edge does not add an edge twice, otherwise the
edge ID could be accidentally overwritten.
parent 51f3c03a
...@@ -25,15 +25,15 @@ class DirectedGraph(object): ...@@ -25,15 +25,15 @@ class DirectedGraph(object):
The p_id is the id of the edge, if the client wishes to maintain this. The p_id is the id of the edge, if the client wishes to maintain this.
""" """
if not self.has_node(p_from): if not self.has_edge(p_from, p_to):
self.add_node(p_from) if not self.has_node(p_from):
self.add_node(p_from)
if not self.has_node(p_to): if not self.has_node(p_to):
self.add_node(p_to) self.add_node(p_to)
self._edges[p_from].add(p_to) self._edges[p_from].add(p_to)
if p_id: self._edge_numbers[(p_from, p_to)] = p_id
self._edge_numbers[p_id] = (p_from, p_to)
def has_path(self, p_from, p_to): def has_path(self, p_from, p_to):
""" """
...@@ -125,7 +125,14 @@ class DirectedGraph(object): ...@@ -125,7 +125,14 @@ class DirectedGraph(object):
""" """
Returns True if the client registered an edge with the given id. Returns True if the client registered an edge with the given id.
""" """
return p_id in self._edge_numbers result = False
for edge_id in self._edge_numbers.itervalues():
if edge_id == p_id:
result = True
break
return result
def edge_id(self, p_from, p_to): def edge_id(self, p_from, p_to):
""" """
...@@ -133,14 +140,10 @@ class DirectedGraph(object): ...@@ -133,14 +140,10 @@ class DirectedGraph(object):
Returns None if the edge does not exist or has no value assigned. Returns None if the edge does not exist or has no value assigned.
""" """
result = None try:
return self._edge_numbers[(p_from, p_to)]
for key, value in self._edge_numbers.iteritems(): except KeyError:
if value == (p_from, p_to): return None
result = key
break
return result
def remove_edge(self, p_from, p_to, remove_unconnected_nodes=True): def remove_edge(self, p_from, p_to, remove_unconnected_nodes=True):
""" """
...@@ -152,9 +155,10 @@ class DirectedGraph(object): ...@@ -152,9 +155,10 @@ class DirectedGraph(object):
if self.has_edge(p_from, p_to): if self.has_edge(p_from, p_to):
self._edges[p_from].remove(p_to) self._edges[p_from].remove(p_to)
edge_id = self.edge_id(p_from, p_to) try:
if edge_id: del self._edge_numbers[(p_from, p_to)]
del self._edge_numbers[edge_id] except KeyError:
return None
if remove_unconnected_nodes: if remove_unconnected_nodes:
if self.is_isolated(p_from): if self.is_isolated(p_from):
......
...@@ -7,7 +7,7 @@ class GraphTest(unittest.TestCase): ...@@ -7,7 +7,7 @@ class GraphTest(unittest.TestCase):
self.graph = Graph.DirectedGraph() self.graph = Graph.DirectedGraph()
self.graph.add_edge(1, 2, 1) self.graph.add_edge(1, 2, 1)
self.graph.add_edge(2, 4) self.graph.add_edge(2, 4, "Test")
self.graph.add_edge(4, 3) self.graph.add_edge(4, 3)
self.graph.add_edge(4, 6) self.graph.add_edge(4, 6)
self.graph.add_edge(6, 2) self.graph.add_edge(6, 2)
...@@ -26,9 +26,17 @@ class GraphTest(unittest.TestCase): ...@@ -26,9 +26,17 @@ class GraphTest(unittest.TestCase):
for i in range(1, 7): for i in range(1, 7):
self.assertTrue(self.graph.has_node(i)) self.assertTrue(self.graph.has_node(i))
def test_has_edge_ids(self):
self.assertTrue(self.graph.has_edge_id(1))
self.assertTrue(self.graph.has_edge_id("Test"))
self.assertFalse(self.graph.has_edge_id("1"))
def test_incoming_neighbors1(self): def test_incoming_neighbors1(self):
self.assertEquals(self.graph.incoming_neighbors(1), set()) self.assertEquals(self.graph.incoming_neighbors(1), set())
def test_edge_id_of_nonexistent_edge(self):
self.assertFalse(self.graph.edge_id(1, 6))
def test_incoming_neighbors2(self): def test_incoming_neighbors2(self):
self.assertEquals(self.graph.incoming_neighbors(2), set([1, 6])) self.assertEquals(self.graph.incoming_neighbors(2), set([1, 6]))
...@@ -127,8 +135,17 @@ class GraphTest(unittest.TestCase): ...@@ -127,8 +135,17 @@ class GraphTest(unittest.TestCase):
# the one and only edge must be removed now # the one and only edge must be removed now
self.assertFalse(self.graph.has_edge(1, 3)) self.assertFalse(self.graph.has_edge(1, 3))
def test_add_double_edge_with_id(self):
self.graph.add_edge(1, 3, "Dummy")
self.assertFalse(self.graph.has_edge_id("Dummy"))
self.graph.remove_edge(1, 3)
# the one and only edge must be removed now
self.assertFalse(self.graph.has_edge(1, 3))
def test_str_output(self): def test_str_output(self):
out = 'digraph g {\n 1\n 1 -> 2 [label="1"]\n 1 -> 3\n 2\n 2 -> 4\n 3\n 3 -> 5\n 4\n 4 -> 3\n 4 -> 6\n 5\n 6\n 6 -> 2\n}\n' out = 'digraph g {\n 1\n 1 -> 2 [label="1"]\n 1 -> 3\n 2\n 2 -> 4 [label="Test"]\n 3\n 3 -> 5\n 4\n 4 -> 3\n 4 -> 6\n 5\n 6\n 6 -> 2\n}\n'
self.assertEquals(str(self.graph), out) self.assertEquals(str(self.graph), out)
def test_dot_output_without_labels(self): def test_dot_output_without_labels(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment