TreeFragment.py 7.63 KB
Newer Older
1 2 3 4 5 6 7 8
#
# TreeFragments - parsing of strings to trees
#

import re
from cStringIO import StringIO
from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
9 10
import Symtab
import PyrexTypes
11
from Visitor import VisitorTransform
12
from Nodes import Node, StatListNode
13 14 15
from ExprNodes import NameNode
import Parsing
import Main
16
import UtilNodes
17 18 19 20 21 22 23

"""
Support for parsing strings into code trees.
"""

class StringParseContext(Main.Context):
    """
    Compilation context used when parsing a code snippet from a string.

    It resolves exactly one module -- the one named at construction
    time -- and refuses lookups of any other module.
    """

    def __init__(self, include_directories, name):
        # The second argument is handed to Main.Context as an empty dict;
        # its meaning is defined by Main.Context (not visible here).
        Main.Context.__init__(self, include_directories, {})
        # Name of the single module this context will resolve.
        self.module_name = name

    def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
        """
        Return a newly constructed ModuleScope for module_name.

        relative_to, pos and need_pxd are accepted for interface
        compatibility but are not used.  Any module name other than the one
        this context was created for raises AssertionError, because
        cimports/includes are not supported from string snippets.
        """
        if module_name == self.module_name:
            return ModuleScope(module_name, parent_module = None, context = self)
        raise AssertionError("Not yet supporting any cimports/includes from string code snippets")

def parse_from_strings(name, code, pxds={}, level=None):
    """
    Parse a unicode string of Cython code into a tree.

    This is mostly used for internal compiler purposes (building code
    snippets that transforms emit) and for unit testing.

    name  - descriptive name of the snippet; used as the module name and as
            the source name (e.g. in error messages)
    code  - unicode string containing the Cython code
    pxds  - accepted for interface compatibility; not used by this function
    level - None to parse code as a whole module (Parsing.p_module);
            otherwise forwarded as ``level`` to Parsing.p_code

    Returns whatever tree the selected parser entry point produces.
    """
    # Source files carry an encoding, but a unicode snippet has none to
    # worry about -- so test code passed in must not have an encoding header.
    assert isinstance(code, unicode), "unicode code snippets only please"
    source_encoding = "UTF-8"

    descriptor = StringSourceDescriptor(name, code)

    context = StringParseContext([], name)
    scope = context.find_module(name, pos = (name, 1, 0), need_pxd = 0)

    # The scanner reads bytes: hand it the encoded snippet together with
    # the encoding used to produce them.
    byte_stream = StringIO(code.encode(source_encoding))
    scanner = PyrexScanner(byte_stream, descriptor,
                           source_encoding = source_encoding,
                           scope = scope, context = context)

    if level is not None:
        return Parsing.p_code(scanner, level=level)
    return Parsing.p_module(scanner, 0, name)

class TreeCopier(VisitorTransform):
    """
    Transform that copies a tree: each visited node is cloned with
    clone_node() and the clone's children are then visited in turn.
    A None node is passed through as None.
    """

    def visit_Node(self, node):
        if node is None:
            return None
        duplicate = node.clone_node()
        self.visitchildren(duplicate)
        return duplicate

class ApplyPositionAndCopy(TreeCopier):
    """
    Copy a tree like TreeCopier, but overwrite the ``pos`` attribute of
    every copied node with one fixed position.
    """

    def __init__(self, pos):
        super(ApplyPositionAndCopy, self).__init__()
        # Position assigned to every node of the copy.
        self.pos = pos

    def visit_Node(self, node):
        copied = super(ApplyPositionAndCopy, self).visit_Node(node)
        # TreeCopier.visit_Node passes None straight through; setting an
        # attribute on it would raise AttributeError.
        if copied is not None:
            copied.pos = self.pos
        return copied

class TemplateTransform(VisitorTransform):
    """
    Makes a copy of a template tree while doing substitutions.

    A dictionary "substitutions" should be passed in when calling
    the transform, mapping names to replacement nodes. Replacement
    then happens like this:
     - If an ExprStatNode contains a single NameNode whose name is
       a key in the substitutions dictionary, the ExprStatNode is
       replaced with a copy of the tree given in the dictionary.
       It is the responsibility of the caller that the replacement
       node is a valid statement.
     - If a single NameNode is otherwise encountered, it is replaced
       if its name is listed in the substitutions dictionary in the
       same way. It is the responsibility of the caller to make sure
       that the replacement node is a valid expression.

    Also a list "temps" should be passed. Any names listed will
    be transformed into anonymous, temporary names (one
    UtilNodes.TempHandle of py_object_type per name). When any temps are
    given, the copied tree is wrapped in a UtilNodes.TempsBlockNode.

    Currently supported for tempnames is:
    NameNode
    (various function and class definition nodes etc. should be added to this)

    Positions: if "pos" is not None, every copied node and every node of
    each inserted replacement tree gets that position. Otherwise copied
    nodes keep their own position, and each replacement tree gets the
    position of the node it replaces, recursively applied to every member
    node.
    """

    def __call__(self, node, substitutions, temps, pos):
        self.substitutions = substitutions
        self.pos = pos
        # Allocate one temporary per requested temp name.
        tempmap = {}
        temphandles = []
        for temp in temps:
            handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
            tempmap[temp] = handle
            temphandles.append(handle)
        self.tempmap = tempmap
        result = super(TemplateTransform, self).__call__(node)
        if temps:
            # Declare the temporaries around the whole copied tree.
            result = UtilNodes.TempsBlockNode(self.get_pos(node),
                                              temps=temphandles,
                                              body=result)
        return result

    def get_pos(self, node):
        """Return the forced position if one was given, else node.pos."""
        # Use an explicit None test, the same check visit_Node applies,
        # so every code path agrees on when a position is "given".
        if self.pos is not None:
            return self.pos
        else:
            return node.pos

    def visit_Node(self, node):
        if node is None:
            return None
        else:
            c = node.clone_node()
            if self.pos is not None:
                c.pos = self.pos
            self.visitchildren(c)
            return c

    def try_substitution(self, node, key):
        """
        Return a position-adjusted copy of the substitution registered
        under key, or an ordinary copy of node when there is none.
        """
        sub = self.substitutions.get(key)
        if sub is not None:
            return ApplyPositionAndCopy(self.get_pos(node))(sub)
        else:
            return self.visit_Node(node) # make copy as usual

    def visit_NameNode(self, node):
        temphandle = self.tempmap.get(node.name)
        if temphandle is not None:
            # Replace name with temporary
            return temphandle.ref(self.get_pos(node))
        else:
            return self.try_substitution(node, node.name)

    def visit_ExprStatNode(self, node):
        # If an expression-as-statement consists of only a replaceable
        # NameNode, we replace the entire statement, not only the NameNode
        if isinstance(node.expr, NameNode):
            return self.try_substitution(node, node.expr.name)
        else:
            return self.visit_Node(node)

def copy_code_tree(node):
    """Return a copy of the tree rooted at node, made with a TreeCopier."""
    copier = TreeCopier()
    return copier(node)

# Matches the run of leading spaces (tabs are not counted) at line start.
# A plain u"" literal is used: the pattern has no backslashes, and u""
# (unlike ur"") is also valid syntax on Python 3.3+.
INDENT_RE = re.compile(u"^ *")

def strip_common_indent(lines):
    """
    Strips empty lines and common indentation from the list of strings
    given in lines.

    Lines consisting only of whitespace are dropped. The number of
    characters removed from the front of every remaining line is the
    smallest count of leading spaces found on any of them. Returns a new
    list; if no non-blank lines remain, the result is an empty list.
    """
    # TODO: Consider textwrap.dedent instead (note: it keeps blank lines
    # and also treats tabs as indentation).
    lines = [x for x in lines if x.strip() != u""]
    if not lines:
        # min() over an empty sequence would raise ValueError.
        return lines
    minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
    return [x[minindent:] for x in lines]

class TreeFragment(object):
    """
    A code tree built from a unicode string of Cython code (or wrapping an
    already-built Node) that can be copied or used as a substitution
    template.
    """

    def __init__(self, code, name="(tree fragment)", pxds=None, temps=None, pipeline=None, level=None):
        """
        code     - unicode source or an existing Node. For unicode, blank
                   lines and common indentation are stripped before parsing.
        name     - descriptive name passed on to parse_from_strings
        pxds     - dict of name -> unicode source. Each value is reformatted
                   like code and passed on to parse_from_strings. A
                   non-empty value combined with a Node raises
                   NotImplementedError.
        temps    - names that substitute() turns into temporaries
        pipeline - callables applied in order to the parsed tree
                   (not used when code is a Node)
        level    - None to parse as a module (the module body is then used);
                   otherwise passed on to parse_from_strings
                   (not used when code is a Node)

        Raises ValueError if code is neither unicode nor a Node.
        """
        # None sentinels instead of {} / [] defaults: a default container
        # is created once and shared by every call, and temps is stored on
        # the instance below.
        if pxds is None:
            pxds = {}
        if temps is None:
            temps = []
        if pipeline is None:
            pipeline = []

        if isinstance(code, unicode):
            def fmt(x):
                return u"\n".join(strip_common_indent(x.split(u"\n")))

            fmt_code = fmt(code)
            fmt_pxds = {}
            for key, value in pxds.iteritems():
                fmt_pxds[key] = fmt(value)

            mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level)
            if level is None:
                t = t.body # Make sure a StatListNode is at the top
            if not isinstance(t, StatListNode):
                t = StatListNode(pos=mod.pos, stats=[t])
            for transform in pipeline:
                t = transform(t)
            self.root = t
        elif isinstance(code, Node):
            if pxds != {}: raise NotImplementedError()
            self.root = code
        else:
            raise ValueError("Unrecognized code format (accepts unicode and Node)")
        self.temps = temps

    def copy(self):
        """Return a copy of the fragment's tree, made by copy_code_tree."""
        return copy_code_tree(self.root)

    def substitute(self, nodes=None, temps=None, pos=None):
        """
        Return a copy of the fragment's tree with substitutions applied.

        nodes - dict mapping names to replacement nodes
        temps - extra names to turn into temporaries, in addition to the
                temps given at construction time
        pos   - position to apply to the result; see TemplateTransform

        The substitution rules are those of TemplateTransform.
        """
        if nodes is None:
            nodes = {}
        if temps is None:
            temps = []
        return TemplateTransform()(self.root,
                                   substitutions = nodes,
                                   temps = self.temps + temps, pos = pos)