Commit 2fa4df3b authored by Robert Bradshaw's avatar Robert Bradshaw

Parse distutils directives.

parent a8404cbd
...@@ -19,6 +19,90 @@ def cached_method(f): ...@@ -19,6 +19,90 @@ def cached_method(f):
return res return res
return wrapper return wrapper
def parse_list(s):
if s[0] == '[' and s[-1] == ']':
s = s[1:-1]
delimiter = ','
else:
delimiter = ' '
s, literals = strip_string_literals(s)
def unquote(literal):
literal = literal.strip()
if literal[0] == "'":
return literals[literal[1:-1]]
else:
return literal
return [unquote(item) for item in s.split(delimiter)]
transitive_str = object()
transitive_list = object()
distutils_settings = {
'name': str,
'sources': list,
'define_macros': list,
'undef_macros': list,
'libraries': transitive_list,
'library_dirs': transitive_list,
'runtime_library_dirs': transitive_list,
'include_dirs': transitive_list,
'extra_objects': list,
'extra_compile_args': list,
'extra_link_args': list,
'export_symbols': list,
'depends': transitive_list,
'language': transitive_str,
}
def line_iter(source):
start = 0
while True:
end = source.find('\n', start)
if end == -1:
yield source[start:]
return
yield source[start:end]
start = end+1
class DistutilsInfo(object):
def __init__(self, source):
self.values = {}
for line in line_iter(source):
line = line.strip()
if line != '' and line[0] != '#':
break
line = line[1:].strip()
if line[:10] == 'distutils:':
line = line[10:]
ix = line.index('=')
key = str(line[:ix].strip())
value = line[ix+1:].strip()
type = distutils_settings[key]
if type in (list, transitive_list):
value = parse_list(value)
if key == 'define_macros':
value = [tuple(macro.split('=')) for macro in value]
self.values[key] = value
def merge(self, other):
for key, value in other.values.items():
type = distutils_settings[key]
if type is transitive_str and key not in self.values:
self.values[key] = value
elif type is transitive_list:
if key in self.values:
all = self.values[key]
for v in value:
if v not in all:
all.append(v)
else:
self.values[key] = value
return self
def strip_string_literals(code, prefix='__Pyx_L'): def strip_string_literals(code, prefix='__Pyx_L'):
""" """
Normalizes every string literal to be of the form '__Pyx_Lxxx', Normalizes every string literal to be of the form '__Pyx_Lxxx',
...@@ -85,6 +169,7 @@ def parse_dependencies(source_filename): ...@@ -85,6 +169,7 @@ def parse_dependencies(source_filename):
# The only catch is that we must strip comments and string # The only catch is that we must strip comments and string
# literals ahead of time. # literals ahead of time.
source = Utils.open_source_file(source_filename, "rU").read() source = Utils.open_source_file(source_filename, "rU").read()
distutils_info = DistutilsInfo(source)
source = re.sub('#.*', '', source) source = re.sub('#.*', '', source)
source, literals = strip_string_literals(source) source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ') source = source.replace('\\\n', ' ')
...@@ -105,7 +190,7 @@ def parse_dependencies(source_filename): ...@@ -105,7 +190,7 @@ def parse_dependencies(source_filename):
includes.append(literals[groups[5]]) includes.append(literals[groups[5]])
else: else:
externs.append(literals[groups[7]]) externs.append(literals[groups[7]])
return cimports, includes, externs return cimports, includes, externs, distutils_info
class DependencyTree(object): class DependencyTree(object):
...@@ -120,7 +205,7 @@ class DependencyTree(object): ...@@ -120,7 +205,7 @@ class DependencyTree(object):
@cached_method @cached_method
def cimports_and_externs(self, filename): def cimports_and_externs(self, filename):
cimports, includes, externs = self.parse_dependencies(filename) cimports, includes, externs = self.parse_dependencies(filename)[:3]
cimports = set(cimports) cimports = set(cimports)
externs = set(externs) externs = set(externs)
for include in includes: for include in includes:
...@@ -149,7 +234,7 @@ class DependencyTree(object): ...@@ -149,7 +234,7 @@ class DependencyTree(object):
if module[0] == '.': if module[0] == '.':
raise NotImplementedError, "New relative imports." raise NotImplementedError, "New relative imports."
if filename is not None: if filename is not None:
relative = '.'.join(self.package(filename) + module.split('.')) relative = '.'.join(self.package(filename) + tuple(module.split('.')))
pxd = self.context.find_pxd_file(relative, None) pxd = self.context.find_pxd_file(relative, None)
if pxd: if pxd:
return pxd return pxd
...@@ -158,9 +243,9 @@ class DependencyTree(object): ...@@ -158,9 +243,9 @@ class DependencyTree(object):
@cached_method @cached_method
def cimported_files(self, filename): def cimported_files(self, filename):
if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'): if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'):
self_pxd = (filename[:-4] + '.pxd',) self_pxd = [filename[:-4] + '.pxd']
else: else:
self_pxd = () self_pxd = []
a = self.cimports(filename) a = self.cimports(filename)
b = filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)]) b = filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)])
if len(a) != len(b): if len(a) != len(b):
...@@ -186,6 +271,12 @@ class DependencyTree(object): ...@@ -186,6 +271,12 @@ class DependencyTree(object):
def newest_dependency(self, filename): def newest_dependency(self, filename):
return self.transitive_merge(filename, self.extract_timestamp, max) return self.transitive_merge(filename, self.extract_timestamp, max)
def distutils_info0(self, filename):
return self.parse_dependencies(filename)[3]
def distutils_info(self, filename):
return self.transitive_merge(filename, self.distutils_info0, DistutilsInfo.merge)
def transitive_merge(self, node, extract, merge): def transitive_merge(self, node, extract, merge):
try: try:
seen = self._transitive_cache[extract, merge] seen = self._transitive_cache[extract, merge]
...@@ -229,6 +320,8 @@ def create_dependency_tree(ctx=None): ...@@ -229,6 +320,8 @@ def create_dependency_tree(ctx=None):
_dep_tree = DependencyTree(ctx) _dep_tree = DependencyTree(ctx)
return _dep_tree return _dep_tree
# TODO: Take common options.
# TODO: Symbolic names (e.g. for numpy.include_dirs()
def create_extension_list(filepatterns, ctx=None): def create_extension_list(filepatterns, ctx=None):
deps = create_dependency_tree(ctx) deps = create_dependency_tree(ctx)
if isinstance(filepatterns, str): if isinstance(filepatterns, str):
...@@ -238,7 +331,7 @@ def create_extension_list(filepatterns, ctx=None): ...@@ -238,7 +331,7 @@ def create_extension_list(filepatterns, ctx=None):
for file in glob(pattern): for file in glob(pattern):
pkg = deps.package(file) pkg = deps.package(file)
name = deps.fully_qualifeid_name(file) name = deps.fully_qualifeid_name(file)
module_list.append(Extension(name=name, sources=[file])) module_list.append(Extension(name=name, sources=[file], **deps.distutils_info(file).values))
return module_list return module_list
def cythonize(module_list, ctx=None): def cythonize(module_list, ctx=None):
......
PYTHON setup.py build_ext --inplace
PYTHON -c "import a"
######## setup.py ########
# TODO: Better interface...
from Cython.Compiler.Dependencies import create_extension_list, cythonize
from distutils.core import setup
setup(
ext_modules = cythonize(create_extension_list("*.pyx")),
)
######## my_lib.pxd ########
# distutils: language=c++
cdef extern from "my_lib_helper.cpp" namespace "A":
int x
######## my_lib_helper.cpp #######
namespace A {
int x = 100;
};
######## a.pyx ########
from my_lib cimport x
print x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment