Commit 6f0132db authored by Bryton Lacquement's avatar Bryton Lacquement 🚪

fixes: major refactoring to eliminate boilerplate code

parent 6c078a2e
...@@ -2,14 +2,28 @@ from collections import defaultdict ...@@ -2,14 +2,28 @@ from collections import defaultdict
from lib2to3.fixer_base import BaseFix as lib2to3_BaseFix from lib2to3.fixer_base import BaseFix as lib2to3_BaseFix
import lib2to3.fixer_util import lib2to3.fixer_util
from my2to3.trace import create_table, get_data, tracing_functions
class BaseFix(lib2to3_BaseFix): class BaseFix(lib2to3_BaseFix):
def start_tree(self, tree, filename): pass
super(BaseFix, self).start_tree(tree, filename)
class BaseStaticTraceFix(BaseFix):
def __init__(self, *args, **kwargs):
super(BaseStaticTraceFix, self).__init__(*args, **kwargs)
# Note: id is used to differentiate the divisions of the same line.
shared_columns = "filename", "lineno", "id"
self.insert_trace = create_table(self.basename + "_trace", *(shared_columns + self.traced_information))
self.insert_modified = create_table(self.basename + "_modified", *shared_columns)
def start_tree(self, *args, **kwargs):
super(BaseStaticTraceFix, self).start_tree(*args, **kwargs)
self.ids = defaultdict(int) self.ids = defaultdict(int)
def traced_call(self, name, insert_function, node, children): def traced_call(self, node, children):
# Important: every node inside "children" should be cloned before this # Important: every node inside "children" should be cloned before this
# function is called. # function is called.
...@@ -22,7 +36,7 @@ class BaseFix(lib2to3_BaseFix): ...@@ -22,7 +36,7 @@ class BaseFix(lib2to3_BaseFix):
id_ = self.ids[lineno] id_ = self.ids[lineno]
new_node = lib2to3.fixer_util.Call( new_node = lib2to3.fixer_util.Call(
lib2to3.fixer_util.Name(name), lib2to3.fixer_util.Name(self.basename + "_traced"),
args=[ args=[
lib2to3.fixer_util.String('"%s"' % filename), lib2to3.fixer_util.String('"%s"' % filename),
lib2to3.fixer_util.Comma(), lib2to3.fixer_util.Comma(),
...@@ -35,6 +49,50 @@ class BaseFix(lib2to3_BaseFix): ...@@ -35,6 +49,50 @@ class BaseFix(lib2to3_BaseFix):
self.ids[lineno] += 1 self.ids[lineno] += 1
insert_function(filename, lineno, id_) self.insert_modified(filename, lineno, id_)
return new_node return new_node
class BaseDynamicTraceFix(BaseStaticTraceFix):
def __init__(self, *args, **kwargs):
super(BaseDynamicTraceFix, self).__init__(*args, **kwargs)
def f(filename, lineno, id_, *args, **kwargs):
result, values = self._dynamic_trace(*args, **kwargs)
self.insert_trace(filename, lineno, id_, *values)
return result
f.__name__ = self.basename + "_traced"
tracing_functions.append(f)
class BaseSupportFix(BaseFix):
def __init__(self, *args, **kwargs):
super(BaseSupportFix, self).__init__(*args, **kwargs)
# Note: id is used to differentiate the divisions of the same line.
self.insert_support = create_table(self.basename + "_support", "filename", "lineno", "id", "status")
def analyze_data(self, filename, lineno, id_):
class_name = "Fix%sTrace" % self.basename.capitalize()
traced_information = getattr(
__import__("my2to3.fixes.fix_%s_trace" % self.basename, fromlist=[class_name]),
class_name
).traced_information
data = get_data(
self.basename + "_trace",
traced_information,
dict(filename=filename, lineno=lineno, id=id_)
)
if data:
try:
results, status = self._analyze_data()
except Exception:
results, status = False, "unknown"
else:
results, status = False, "no data"
self.insert_support(filename[1:-1], lineno, id_, status)
return results
...@@ -3,36 +3,26 @@ from lib2to3.pygram import python_symbols as syms ...@@ -3,36 +3,26 @@ from lib2to3.pygram import python_symbols as syms
from lib2to3.pytree import Leaf, Node from lib2to3.pytree import Leaf, Node
import os import os
from . import BaseFix from . import BaseSupportFix
from my2to3.trace import create_table, get_data from my2to3.trace import create_table
from my2to3.util import add_future, data2types from my2to3.util import add_future, data2types
# id is used to differentiate the divisions of the same line. class FixDivisionSupport(BaseSupportFix):
insert_support = create_table("division_support", "filename", "lineno", "id", "status") """Rewrites division_traced(n, a, b) into Py2/Py3-compatible division
"""
basename = "division"
def analyze_data(data):
if not data:
return False, "no data"
try: @staticmethod
def _analyze_data(data):
types = data2types(data) types = data2types(data)
except Exception as e:
# Probably, one type is not a builtins
return False, "unknown"
if len(types) == 1: if len(types) == 1:
dividend, divisor = types[0] dividend, divisor = types[0]
if dividend is divisor is int: if dividend is divisor is int:
return True, "automatic" return True, "automatic"
return False, "manual" return False, "manual"
class FixDivisionSupport(BaseFix):
"""Rewrites division_traced(n, a, b) into Py2/Py3-compatible division
"""
# Inspired by https://github.com/python/cpython/blob/e42b705188271da108de42b55d9344642170aa2b/Lib/lib2to3/fixes/fix_xrange.py#L14 # Inspired by https://github.com/python/cpython/blob/e42b705188271da108de42b55d9344642170aa2b/Lib/lib2to3/fixes/fix_xrange.py#L14
PATTERN = """ PATTERN = """
...@@ -42,15 +32,7 @@ class FixDivisionSupport(BaseFix): ...@@ -42,15 +32,7 @@ class FixDivisionSupport(BaseFix):
def transform(self, node, results): def transform(self, node, results):
args = results['args'] args = results['args']
filename, lineno, id_ = [l.value for l in args.children[:-4:2]] should_change = self.analyze_data(*[l.value for l in args.children[:-4:2]])
should_change, status = analyze_data(
get_data(
"division_trace",
["dividend_type", "divisor_type"],
dict(filename=filename, lineno=lineno, id=id_)
)
)
insert_support(filename[1:-1], lineno, id_, status)
if should_change: if should_change:
add_future(node, 'division') add_future(node, 'division')
operator = Leaf(lib2to3.pgen2.token.DOUBLESLASH, "//") operator = Leaf(lib2to3.pgen2.token.DOUBLESLASH, "//")
......
...@@ -4,30 +4,22 @@ import lib2to3.pgen2 ...@@ -4,30 +4,22 @@ import lib2to3.pgen2
from lib2to3.pygram import python_symbols as syms from lib2to3.pygram import python_symbols as syms
from lib2to3.pytree import Node from lib2to3.pytree import Node
from . import BaseFix from . import BaseDynamicTraceFix
from my2to3.trace import create_table, register_tracing_function
from my2to3.util import parse_type from my2to3.util import parse_type
insert_trace = create_table("division_trace", "filename", "lineno", "id", "dividend_type", "divisor_type") class FixDivisionTrace(BaseDynamicTraceFix):
insert_modified = create_table("division_modified", "filename", "lineno", "id")
@register_tracing_function
def division_traced(filename, lineno, id_, dividend, divisor):
insert_trace(
filename,
lineno,
id_,
parse_type(type(dividend)),
parse_type(type(divisor))
)
return dividend / divisor
class FixDivisionTrace(BaseFix):
"""Rewrites a / b into division_traced(id, a, b) """Rewrites a / b into division_traced(id, a, b)
""" """
basename = "division"
traced_information = "dividend_type", "divisor_type"
@staticmethod
def _dynamic_trace(dividend, divisor):
return (
dividend / divisor,
(parse_type(type(dividend)), parse_type(type(divisor)))
)
def match(self, node): def match(self, node):
if node.type == syms.term: if node.type == syms.term:
...@@ -64,7 +56,7 @@ class FixDivisionTrace(BaseFix): ...@@ -64,7 +56,7 @@ class FixDivisionTrace(BaseFix):
comma.prefix = children[1].prefix comma.prefix = children[1].prefix
children[1] = comma children[1] = comma
previous_node = self.traced_call("division_traced", insert_modified, node, children) previous_node = self.traced_call(node, children)
else: else:
# It's not a division operation # It's not a division operation
previous_node = Node(syms.term, children) previous_node = Node(syms.term, children)
......
from lib2to3.pygram import python_symbols as syms from lib2to3.pygram import python_symbols as syms
from . import BaseFix from . import BaseStaticTraceFix
from my2to3.trace import create_table
insert_trace = create_table("nested_except_trace", "filename", "lineno_parent", "lineno_child") class FixNestedExceptTrace(BaseStaticTraceFix):
class FixNestedExceptTrace(BaseFix):
"""This fixer detects scope bugs which can occur due to nested except clauses """This fixer detects scope bugs which can occur due to nested except clauses
which use the same variable name. which use the same variable name.
...@@ -33,6 +29,9 @@ class FixNestedExceptTrace(BaseFix): ...@@ -33,6 +29,9 @@ class FixNestedExceptTrace(BaseFix):
print(e) print(e)
raise raise
""" """
basename = "nested_except"
traced_information = "lineno_parent", "lineno_child"
# https://github.com/python/cpython/blob/3549ca313a6103a3adb281ef3a849298b7d7f72c/Lib/lib2to3/fixes/fix_except.py#L39-L45 # https://github.com/python/cpython/blob/3549ca313a6103a3adb281ef3a849298b7d7f72c/Lib/lib2to3/fixes/fix_except.py#L39-L45
PATTERN = """ PATTERN = """
try_stmt< 'try' ':' (simple_stmt | suite) try_stmt< 'try' ':' (simple_stmt | suite)
...@@ -59,6 +58,6 @@ class FixNestedExceptTrace(BaseFix): ...@@ -59,6 +58,6 @@ class FixNestedExceptTrace(BaseFix):
for child in self.list_except_clause_children(body): for child in self.list_except_clause_children(body):
try: try:
if except_clause.children[3] == child.children[3]: if except_clause.children[3] == child.children[3]:
insert_trace(self.filename, except_clause.get_lineno(), child.get_lineno()) self.insert_trace(self.filename, None, None, except_clause.get_lineno(), child.get_lineno())
except Exception: except Exception:
pass pass
...@@ -4,30 +4,23 @@ import lib2to3.pgen2 ...@@ -4,30 +4,23 @@ import lib2to3.pgen2
from lib2to3.pygram import python_symbols as syms from lib2to3.pygram import python_symbols as syms
from lib2to3.pytree import Node from lib2to3.pytree import Node
from . import BaseFix from . import BaseDynamicTraceFix
from my2to3.trace import create_table, register_tracing_function
from my2to3.util import parse_type from my2to3.util import parse_type
insert_trace = create_table("round_trace", "filename", "lineno", "id", "number", "ndigits") class FixRoundTrace(BaseDynamicTraceFix):
insert_modified = create_table("round_modified", "filename", "lineno", "id")
@register_tracing_function
def round_traced(filename, lineno, id_, number, ndigits=0):
insert_trace(
filename,
lineno,
id_,
number,
ndigits
)
return round(number, ndigits)
class FixRoundTrace(BaseFix):
"""Rewrites round(a[, b]) into round_traced(id, a[, b]) """Rewrites round(a[, b]) into round_traced(id, a[, b])
""" """
basename = "round"
traced_information = "number", "ndigits"
@staticmethod
def _dynamic_trace(number, ndigits=0):
return (
round(number, ndigits),
(parse_type(type(number)), parse_type(type(ndigits)))
)
# Inspired by https://github.com/python/cpython/blob/e42b705188271da108de42b55d9344642170aa2b/Lib/lib2to3/fixes/fix_xrange.py#L14 # Inspired by https://github.com/python/cpython/blob/e42b705188271da108de42b55d9344642170aa2b/Lib/lib2to3/fixes/fix_xrange.py#L14
PATTERN = """ PATTERN = """
power<'round' trailer< '(' args=any ')' > power<'round' trailer< '(' args=any ')' >
...@@ -45,4 +38,4 @@ class FixRoundTrace(BaseFix): ...@@ -45,4 +38,4 @@ class FixRoundTrace(BaseFix):
# e.g. round(1,23, 4) # e.g. round(1,23, 4)
children = [leaf.clone() for leaf in args.children] children = [leaf.clone() for leaf in args.children]
node.replace(self.traced_call("division_traced", insert_modified, node, children)) node.replace(self.traced_call(node, children))
...@@ -30,7 +30,7 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -30,7 +30,7 @@ class testFixNestedExceptTrace(FixerTestCase):
3 3
""" """
self.unchanged(a) self.unchanged(a)
self.assertDataEqual("nested_except_trace", [(u'<string>', 4, 7)]) self.assertDataEqual("nested_except_trace", [(u'<string>', None, None, 4, 7)])
def test_except_2(self): def test_except_2(self):
a = """ a = """
...@@ -59,7 +59,7 @@ class testFixNestedExceptTrace(FixerTestCase): ...@@ -59,7 +59,7 @@ class testFixNestedExceptTrace(FixerTestCase):
4 4
""" """
self.unchanged(a) self.unchanged(a)
self.assertDataEqual("nested_except_trace", [(u'<string>', 7, 10), (u'<string>', 4, 7), (u'<string>', 4, 10)]) self.assertDataEqual("nested_except_trace", [(u'<string>', None, None, 7, 10), (u'<string>', None, None, 4, 7), (u'<string>', None, None, 4, 10)])
def test_else(self): def test_else(self):
a = """ a = """
......
...@@ -8,9 +8,6 @@ conn = sqlite3.connect(database) ...@@ -8,9 +8,6 @@ conn = sqlite3.connect(database)
tracing_functions = [] tracing_functions = []
def register_tracing_function(f):
tracing_functions.append(f)
return f
def create_table(table, *columns): def create_table(table, *columns):
...@@ -115,11 +112,9 @@ class ModuleImporter: ...@@ -115,11 +112,9 @@ class ModuleImporter:
def patch_imports(whitelist): def patch_imports(whitelist):
sys.meta_path.append(ModuleImporter(whitelist)) sys.meta_path.append(ModuleImporter(whitelist))
# XXX: This makes sure that "my2to3.trace.tracing_functions" is correctly
# populated. Let's find a better way.
apply_fixers("", "dummy")
for f in tracing_functions: for f in tracing_functions:
setattr(__builtin__, f.__name__, f) setattr(__builtin__, f.__name__, f)
for fixer in get_fixers():
# Import each fixer, to make sure that "my2to3.trace.tracing_functions" is
# correctly populated.
__import__(fixer)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment