Commit 0d19c086 authored by Bryton Lacquement's avatar Bryton Lacquement 🚪

tests: also test the recorded data

parent 3817aa59
from lib2to3.tests.test_fixers import FixerTestCase as lib2to3FixerTestCase from lib2to3.tests.test_fixers import FixerTestCase as lib2to3FixerTestCase
import sqlite3
from my2to3.trace import database, get_data, tracing_functions
class FixerTestCase(lib2to3FixerTestCase): class FixerTestCase(lib2to3FixerTestCase):
...@@ -6,22 +9,20 @@ class FixerTestCase(lib2to3FixerTestCase): ...@@ -6,22 +9,20 @@ class FixerTestCase(lib2to3FixerTestCase):
def setUp(self, fix_list=None, fixer_pkg="my2to3", options=None): def setUp(self, fix_list=None, fixer_pkg="my2to3", options=None):
super(FixerTestCase, self).setUp(fix_list, fixer_pkg, options) super(FixerTestCase, self).setUp(fix_list, fixer_pkg, options)
if self.fixer.endswith('_trace'): # Clear the database
fix_name = 'fix_' + self.fixer # XXX: a better way probably exists.
self.fixer_module = fixer_module = getattr(__import__('my2to3.fixes', fromlist=[fix_name]), fix_name) # TODO:
# - refactor
self.traces = [] # - optimize
# Wrap fixer_module.trace, to populate self.traces conn = sqlite3.connect(database)
self.old_insert_trace = fixer_module.insert_trace c = conn.cursor()
def decorate(func): for table in c.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
def call(*args): c.execute("DELETE FROM %s" % table)
self.traces.append(args) conn.commit()
return func(*args) conn.close()
return call
fixer_module.insert_trace = decorate(fixer_module.insert_trace)
def tearDown(self, *args, **kw): def assertDataEqual(self, table, data):
super(FixerTestCase, self).tearDown(*args, **kw) self.assertEqual(get_data(table), data)
if self.fixer.endswith('_trace'): def exec_code(self, string):
self.fixer_module.insert_trace = self.old_insert_trace exec(compile(string, "<string>", 'exec'), {f.__name__: f for f in tracing_functions})
...@@ -7,49 +7,90 @@ class testFixDivisionTrace(FixerTestCase): ...@@ -7,49 +7,90 @@ class testFixDivisionTrace(FixerTestCase):
fixer = "division_trace" fixer = "division_trace"
def test_simple_division(self): def test_simple_division(self):
b = """x / y""" b = """10 / 20"""
a = """division_traced("<string>",1,0,x , y)""" a = """division_traced("<string>",1,0,10 , 20)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int')])
def test_nested_divisions(self):
b = """10 / 20 / 30"""
a = """division_traced("<string>",1,1,division_traced("<string>",1,0,10 , 20) , 30)"""
self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0), (u'<string>', 1, 1)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int'), (u'<string>', 1, 1, u'int', u'int')])
def test_nested_divisions_with_parentheses_1(self): def test_nested_divisions_with_parentheses_1(self):
b = """(x / y) / z""" b = """(10 / 20) / 30"""
a = """division_traced("<string>",1,1,(division_traced("<string>",1,0,x , y)) , z)""" a = """division_traced("<string>",1,1,(division_traced("<string>",1,0,10 , 20)) , 30)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0), (u'<string>', 1, 1)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int'), (u'<string>', 1, 1, u'int', u'int')])
def test_nested_divisions_with_parentheses_2(self): def test_nested_divisions_with_parentheses_2(self):
b = """x / (y / z)""" b = """30 / (20 / 10)"""
a = """division_traced("<string>",1,1,x , (division_traced("<string>",1,0,y , z)))""" a = """division_traced("<string>",1,1,30 , (division_traced("<string>",1,0,20 , 10)))"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0), (u'<string>', 1, 1)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int'), (u'<string>', 1, 1, u'int', u'int')])
def test_inline_division_1(self): def test_inline_division_1(self):
b = """1 / 2 / 3""" b = """10 / 20 / 30"""
a = """division_traced("<string>",1,1,division_traced("<string>",1,0,1 , 2) , 3)""" a = """division_traced("<string>",1,1,division_traced("<string>",1,0,10 , 20) , 30)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0), (u'<string>', 1, 1)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int'), (u'<string>', 1, 1, u'int', u'int')])
def test_inline_division_2(self): def test_inline_division_2(self):
b = """1 / 2 * 3""" b = """10 / 20 * 30"""
a = """division_traced("<string>",1,0,1 , 2) * 3""" a = """division_traced("<string>",1,0,10 , 20) * 30"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int')])
def test_inline_division_3(self): def test_inline_division_3(self):
b = """1 * 2 / 3""" b = """10 * 20 / 30"""
a = """division_traced("<string>",1,0,1 * 2 , 3)""" a = """division_traced("<string>",1,0,10 * 20 , 30)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int')])
def test_division_on_line_continuation(self): def test_division_on_line_continuation(self):
b = """x \ b = """10 \
/ y""" / 20"""
a = """ division_traced("<string>",1,0,x \ a = """division_traced("<string>",1,0,10 \
, y)""" , 20)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int')])
def test_multiline_division(self): def test_multiline_division(self):
b = """foo = \ b = """foo = \
x / y""" 10 / 20"""
a = """foo = \ a = """foo = \
division_traced("<string>",1,0,x , y)""" division_traced("<string>",1,0,10 , 20)"""
self.check(b, a) self.check(b, a)
self.exec_code(a)
self.assertDataEqual("division_modified", [(u'<string>', 1, 0)])
self.assertDataEqual("division_trace", [(u'<string>', 1, 0, u'int', u'int')])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,9 +16,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -16,9 +16,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as e: except Exception as e:
3 3
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, []) self.assertDataEqual("nested_except_trace", [])
def test_except(self): def test_except(self):
a = """ a = """
...@@ -30,9 +29,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -30,9 +29,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as e: except Exception as e:
3 3
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, [(u'<string>', 4, 7)]) self.assertDataEqual("nested_except_trace", [(u'<string>', 4, 7)])
def test_except_2(self): def test_except_2(self):
a = """ a = """
...@@ -44,9 +42,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -44,9 +42,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as f: except Exception as f:
3 3
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, []) self.assertDataEqual("nested_except_trace", [])
def test_multiple_except(self): def test_multiple_except(self):
a = """ a = """
...@@ -61,9 +58,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -61,9 +58,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as e: except Exception as e:
4 4
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, [(u'<string>', 7, 10), (u'<string>', 4, 7), (u'<string>', 4, 10)]) self.assertDataEqual("nested_except_trace", [(u'<string>', 7, 10), (u'<string>', 4, 7), (u'<string>', 4, 10)])
def test_else(self): def test_else(self):
a = """ a = """
...@@ -77,9 +73,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -77,9 +73,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as e: except Exception as e:
4 4
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, []) self.assertDataEqual("nested_except_trace", [])
def test_finally(self): def test_finally(self):
a = """ a = """
...@@ -93,9 +88,8 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -93,9 +88,8 @@ class testFixNestedExceptTrace(FixerTestCase):
except Exception as e: except Exception as e:
4 4
""" """
self.assertEqual(self.traces, [])
self.unchanged(a) self.unchanged(a)
self.assertEqual(self.traces, []) self.assertDataEqual("nested_except_trace", [])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -44,19 +44,20 @@ def get_fixers(): ...@@ -44,19 +44,20 @@ def get_fixers():
] ]
def get_data(table, columns_to_select, conditions): def get_data(table, columns_to_select='*', conditions={}):
# TODO: # TODO:
# - refactor # - refactor
# - optimize # - optimize
conn = sqlite3.connect(database) conn = sqlite3.connect(database)
c = conn.cursor() c = conn.cursor()
data = c.execute("SELECT %s FROM %s WHERE %s" % ( query = "SELECT %s FROM %s" % (
', '.join(columns_to_select), ', '.join(columns_to_select),
table, table,
' AND '.join(k + " = " + v for k, v in conditions.items()))
) )
if conditions:
query += "WHERE " + ' AND '.join(k + " = " + v for k, v in conditions.items())
try: try:
return data.fetchall() return c.execute(query).fetchall()
finally: finally:
conn.close() conn.close()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment