Commit 3b9f95ef authored by Stefan Behnel's avatar Stefan Behnel

implement relative cimports and fix some general issues with relative imports

parent 19fb7521
...@@ -123,7 +123,7 @@ class Context(object): ...@@ -123,7 +123,7 @@ class Context(object):
pos = (module_name, 0, 0) pos = (module_name, 0, 0)
raise CompileError(pos, raise CompileError(pos,
"'%s' is not a valid module name" % module_name) "'%s' is not a valid module name" % module_name)
if "." not in module_name and relative_to: if relative_to:
if debug_find_module: if debug_find_module:
print("...trying relative import") print("...trying relative import")
scope = relative_to.lookup_submodule(module_name) scope = relative_to.lookup_submodule(module_name)
...@@ -139,7 +139,7 @@ class Context(object): ...@@ -139,7 +139,7 @@ class Context(object):
for name in module_name.split("."): for name in module_name.split("."):
scope = scope.find_submodule(name) scope = scope.find_submodule(name)
if debug_find_module: if debug_find_module:
print("...scope =", scope) print("...scope = %s" % scope)
if not scope.pxd_file_loaded: if not scope.pxd_file_loaded:
if debug_find_module: if debug_find_module:
print("...pxd not loaded") print("...pxd not loaded")
...@@ -149,7 +149,7 @@ class Context(object): ...@@ -149,7 +149,7 @@ class Context(object):
print("...looking for pxd file") print("...looking for pxd file")
pxd_pathname = self.find_pxd_file(module_name, pos) pxd_pathname = self.find_pxd_file(module_name, pos)
if debug_find_module: if debug_find_module:
print("......found ", pxd_pathname) print("......found %s" % pxd_pathname)
if not pxd_pathname and need_pxd: if not pxd_pathname and need_pxd:
package_pathname = self.search_include_directories(module_name, ".py", pos) package_pathname = self.search_include_directories(module_name, ".py", pos)
if package_pathname and package_pathname.endswith('__init__.py'): if package_pathname and package_pathname.endswith('__init__.py'):
...@@ -162,7 +162,7 @@ class Context(object): ...@@ -162,7 +162,7 @@ class Context(object):
print("Context.find_module: Parsing %s" % pxd_pathname) print("Context.find_module: Parsing %s" % pxd_pathname)
rel_path = module_name.replace('.', os.sep) + os.path.splitext(pxd_pathname)[1] rel_path = module_name.replace('.', os.sep) + os.path.splitext(pxd_pathname)[1]
if not pxd_pathname.endswith(rel_path): if not pxd_pathname.endswith(rel_path):
rel_path = pxd_pathname # safety measure to prevent printing incorrect paths rel_path = pxd_pathname # safety measure to prevent printing incorrect paths
source_desc = FileSourceDescriptor(pxd_pathname, rel_path) source_desc = FileSourceDescriptor(pxd_pathname, rel_path)
err, result = self.process_pxd(source_desc, scope, module_name) err, result = self.process_pxd(source_desc, scope, module_name)
if err: if err:
......
...@@ -6864,15 +6864,22 @@ class FromCImportStatNode(StatNode): ...@@ -6864,15 +6864,22 @@ class FromCImportStatNode(StatNode):
# from ... cimport statement # from ... cimport statement
# #
# module_name string Qualified name of module # module_name string Qualified name of module
# relative_level int or None Relative import: number of dots before module_name
# imported_names [(pos, name, as_name, kind)] Names to be imported # imported_names [(pos, name, as_name, kind)] Names to be imported
child_attrs = [] child_attrs = []
module_name = None
relative_level = None
imported_names = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
if not env.is_module_scope: if not env.is_module_scope:
error(self.pos, "cimport only allowed at module level") error(self.pos, "cimport only allowed at module level")
return return
module_scope = env.find_module(self.module_name, self.pos) if self.relative_level and self.relative_level > env.qualified_name.count('.'):
error(self.pos, "relative cimport beyond main package is not allowed")
module_scope = env.find_module(self.module_name, self.pos, relative_level=self.relative_level)
module_name = module_scope.qualified_name
env.add_imported_module(module_scope) env.add_imported_module(module_scope)
for pos, name, as_name, kind in self.imported_names: for pos, name, as_name, kind in self.imported_names:
if name == "*": if name == "*":
...@@ -6886,29 +6893,27 @@ class FromCImportStatNode(StatNode): ...@@ -6886,29 +6893,27 @@ class FromCImportStatNode(StatNode):
entry.used = 1 entry.used = 1
else: else:
if kind == 'struct' or kind == 'union': if kind == 'struct' or kind == 'union':
entry = module_scope.declare_struct_or_union(name, entry = module_scope.declare_struct_or_union(
kind = kind, scope = None, typedef_flag = 0, pos = pos) name, kind=kind, scope=None, typedef_flag=0, pos=pos)
elif kind == 'class': elif kind == 'class':
entry = module_scope.declare_c_class(name, pos = pos, entry = module_scope.declare_c_class(name, pos=pos, module_name=module_name)
module_name = self.module_name)
else: else:
submodule_scope = env.context.find_module(name, relative_to = module_scope, pos = self.pos) submodule_scope = env.context.find_module(name, relative_to=module_scope, pos=self.pos)
if submodule_scope.parent_module is module_scope: if submodule_scope.parent_module is module_scope:
env.declare_module(as_name or name, submodule_scope, self.pos) env.declare_module(as_name or name, submodule_scope, self.pos)
else: else:
error(pos, "Name '%s' not declared in module '%s'" error(pos, "Name '%s' not declared in module '%s'" % (name, module_name))
% (name, self.module_name))
if entry: if entry:
local_name = as_name or name local_name = as_name or name
env.add_imported_entry(local_name, entry, pos) env.add_imported_entry(local_name, entry, pos)
if self.module_name.startswith('cpython'): # enough for now if module_name.startswith('cpython'): # enough for now
if self.module_name in utility_code_for_cimports: if module_name in utility_code_for_cimports:
env.use_utility_code(UtilityCode.load_cached( env.use_utility_code(UtilityCode.load_cached(
*utility_code_for_cimports[self.module_name])) *utility_code_for_cimports[module_name]))
for _, name, _, _ in self.imported_names: for _, name, _, _ in self.imported_names:
fqname = '%s.%s' % (self.module_name, name) fqname = '%s.%s' % (module_name, name)
if fqname in utility_code_for_cimports: if fqname in utility_code_for_cimports:
env.use_utility_code(UtilityCode.load_cached( env.use_utility_code(UtilityCode.load_cached(
*utility_code_for_cimports[fqname])) *utility_code_for_cimports[fqname]))
...@@ -6969,7 +6974,7 @@ class FromImportStatNode(StatNode): ...@@ -6969,7 +6974,7 @@ class FromImportStatNode(StatNode):
env.use_utility_code(UtilityCode.load_cached("ExtTypeTest", "ObjectHandling.c")) env.use_utility_code(UtilityCode.load_cached("ExtTypeTest", "ObjectHandling.c"))
break break
else: else:
entry = env.lookup(target.name) entry = env.lookup(target.name)
# check whether or not entry is already cimported # check whether or not entry is already cimported
if (entry.is_type and entry.type.name == name if (entry.is_type and entry.type.name == name
and hasattr(entry.type, 'module_name')): and hasattr(entry.type, 'module_name')):
...@@ -6978,8 +6983,8 @@ class FromImportStatNode(StatNode): ...@@ -6978,8 +6983,8 @@ class FromImportStatNode(StatNode):
continue continue
try: try:
# cimported with relative name # cimported with relative name
module = env.find_module(self.module.module_name.value, module = env.find_module(self.module.module_name.value, pos=self.pos,
pos=None) relative_level=self.module.level)
if entry.type.module_name == module.qualified_name: if entry.type.module_name == module.qualified_name:
continue continue
except AttributeError: except AttributeError:
......
...@@ -762,8 +762,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -762,8 +762,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return node return node
def visit_FromCImportStatNode(self, node): def visit_FromCImportStatNode(self, node):
if (node.module_name == u"cython") or \ if not node.relative_level and (
node.module_name.startswith(u"cython."): node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
submodule = (node.module_name + u".")[7:] submodule = (node.module_name + u".")[7:]
newimp = [] newimp = []
......
...@@ -1319,24 +1319,19 @@ def p_from_import_statement(s, first_statement = 0): ...@@ -1319,24 +1319,19 @@ def p_from_import_statement(s, first_statement = 0):
while s.sy == '.': while s.sy == '.':
level += 1 level += 1
s.next() s.next()
if s.sy == 'cimport':
s.error("Relative cimport is not supported yet")
else: else:
level = None level = None
if level is not None and s.sy == 'import': if level is not None and s.sy in ('import', 'cimport'):
# we are dealing with "from .. import foo, bar" # we are dealing with "from .. import foo, bar"
dotted_name_pos, dotted_name = s.position(), '' dotted_name_pos, dotted_name = s.position(), ''
elif level is not None and s.sy == 'cimport':
# "from .. cimport"
s.error("Relative cimport is not supported yet")
else:
(dotted_name_pos, _, dotted_name, _) = \
p_dotted_name(s, as_allowed = 0)
if s.sy in ('import', 'cimport'):
kind = s.sy
s.next()
else: else:
if level is None and Future.absolute_import in s.context.future_directives:
level = 0
(dotted_name_pos, _, dotted_name, _) = p_dotted_name(s, as_allowed=False)
if s.sy not in ('import', 'cimport'):
s.error("Expected 'import' or 'cimport'") s.error("Expected 'import' or 'cimport'")
kind = s.sy
s.next()
is_cimport = kind == 'cimport' is_cimport = kind == 'cimport'
is_parenthesized = False is_parenthesized = False
...@@ -1359,7 +1354,7 @@ def p_from_import_statement(s, first_statement = 0): ...@@ -1359,7 +1354,7 @@ def p_from_import_statement(s, first_statement = 0):
if dotted_name == '__future__': if dotted_name == '__future__':
if not first_statement: if not first_statement:
s.error("from __future__ imports must occur at the beginning of the file") s.error("from __future__ imports must occur at the beginning of the file")
elif level is not None: elif level:
s.error("invalid syntax") s.error("invalid syntax")
else: else:
for (name_pos, name, as_name, kind) in imported_names: for (name_pos, name, as_name, kind) in imported_names:
...@@ -1374,9 +1369,10 @@ def p_from_import_statement(s, first_statement = 0): ...@@ -1374,9 +1369,10 @@ def p_from_import_statement(s, first_statement = 0):
s.context.future_directives.add(directive) s.context.future_directives.add(directive)
return Nodes.PassStatNode(pos) return Nodes.PassStatNode(pos)
elif kind == 'cimport': elif kind == 'cimport':
return Nodes.FromCImportStatNode(pos, return Nodes.FromCImportStatNode(
module_name = dotted_name, pos, module_name=dotted_name,
imported_names = imported_names) relative_level=level,
imported_names=imported_names)
else: else:
imported_name_strings = [] imported_name_strings = []
items = [] items = []
......
...@@ -1081,13 +1081,19 @@ class ModuleScope(Scope): ...@@ -1081,13 +1081,19 @@ class ModuleScope(Scope):
entry.name = name entry.name = name
return entry return entry
def find_module(self, module_name, pos): def find_module(self, module_name, pos, relative_level=-1):
# Find a module in the import namespace, interpreting # Find a module in the import namespace, interpreting
# relative imports relative to this module's parent. # relative imports relative to this module's parent.
# Finds and parses the module's .pxd file if the module # Finds and parses the module's .pxd file if the module
# has not been referenced before. # has not been referenced before.
return self.global_scope().context.find_module( module_scope = self.global_scope()
module_name, relative_to = self.parent_module, pos = pos) if relative_level is not None and relative_level > 0:
# merge current absolute module name and relative import name into qualified name
current_module = module_scope.qualified_name.split('.')
base_package = current_module[:-relative_level]
module_name = '.'.join(base_package + module_name.split('.'))
return module_scope.context.find_module(
module_name, relative_to=None if relative_level == 0 else self.parent_module, pos=pos)
def find_submodule(self, name): def find_submodule(self, name):
# Find and return scope for a submodule of this module, # Find and return scope for a submodule of this module,
......
# mode: error
# tag: cimport
from ..relative_cimport cimport some_name
from ..cython cimport declare
_ERRORS="""
4:0: 'relative_cimport.pxd' not found
4:0: 'some_name.pxd' not found
4:0: relative cimport beyond main package is not allowed
4:32: Name 'some_name' not declared in module 'relative_cimport'
6:0: 'declare.pxd' not found
6:0: relative cimport beyond main package is not allowed
6:22: Name 'declare' not declared in module 'cython'
"""
# mode: run
# tag: cimport
PYTHON setup.py build_ext --inplace
PYTHON -c "from pkg.b import test; assert test() == (1, 2)"
PYTHON -c "from pkg.sub.c import test; assert test() == (1, 2)"
######## setup.py ########
from distutils.core import setup
from Cython.Build import cythonize
from Cython.Distutils.extension import Extension
setup(
ext_modules=cythonize('**/*.pyx'),
)
######## pkg/__init__.py ########
######## pkg/sub/__init__.py ########
######## pkg/a.pyx ########
cdef class test_pxd:
pass
######## pkg/a.pxd ########
cdef class test_pxd:
cdef public int x
cdef public int y
######## pkg/b.pyx ########
from .a cimport test_pxd
def test():
cdef test_pxd obj = test_pxd()
obj.x = 1
obj.y = 2
return (obj.x, obj.y)
######## pkg/sub/c.pyx ########
from ..a cimport test_pxd
def test():
cdef test_pxd obj = test_pxd()
obj.x = 1
obj.y = 2
return (obj.x, obj.y)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment