diff --git a/compyle/array.py b/compyle/array.py index 5b5a034..e737ac4 100644 --- a/compyle/array.py +++ b/compyle/array.py @@ -97,6 +97,22 @@ def get_backend(backend=None): minmax_operator_tpl = """ + __device__ ${dtype}() + { + } + + __device__ ${dtype}(${dtype} const volatile &src) + { + % for prop in prop_names: + % if not only_max: + this->cur_min_${prop} = src.cur_min_${prop}; + % endif + % if not only_min: + this->cur_max_${prop} = src.cur_max_${prop}; + % endif + % endfor + } + __device__ ${dtype} volatile &operator=( ${dtype} const &src) volatile { diff --git a/compyle/ast_utils.py b/compyle/ast_utils.py index e622142..eaa190d 100644 --- a/compyle/ast_utils.py +++ b/compyle/ast_utils.py @@ -9,6 +9,18 @@ basestring = str if PY_VER > 2 else basestring +def get_string_value(node): + """Return a string literal's value or None if *node* is not a string.""" + ast_constant = getattr(ast, 'Constant', None) + if ast_constant is not None and isinstance(node, ast_constant) and \ + isinstance(node.value, str): + return node.value + ast_str = getattr(ast, 'Str', None) + if ast_str is not None and isinstance(node, ast_str): + return node.s + return None + + class NameLister(ast.NodeVisitor): """Utility class to collect the Names in an AST. """ diff --git a/compyle/cython_generator.py b/compyle/cython_generator.py index a0e7e51..7ab9ba9 100644 --- a/compyle/cython_generator.py +++ b/compyle/cython_generator.py @@ -21,7 +21,7 @@ from .types import KnownType, Undefined, get_declare_info from .config import get_config -from .ast_utils import get_assigned, has_return +from .ast_utils import get_assigned, get_string_value, has_return from .utils import getsourcelines logger = logging.getLogger(__name__) @@ -247,11 +247,12 @@ def parse_declare(code): if call.func.id != 'declare': raise CodeGenerationError('Unknown declare statement: %s' % code) arg0 = call.args[0] - if not isinstance(arg0, ast.Str): - err = 'Type should be a string, given :%r' % arg0.s + type_str = get_string_value(arg0) + if type_str is None: + err = 'Type should be a string, given :%r' % getattr(arg0, 'value', arg0) raise CodeGenerationError(err) - return get_declare_info(arg0.s) + return get_declare_info(type_str) class CythonGenerator(object): diff --git a/compyle/jit.py b/compyle/jit.py index 080fd42..2b15bc3 100644 --- a/compyle/jit.py +++ b/compyle/jit.py @@ -8,6 +8,7 @@ import time from pytools import memoize from .config import get_config +from .ast_utils import get_string_value from .cython_generator import CythonGenerator from .transpiler import Transpiler, BUILTINS from .types import (dtype_to_ctype, get_declare_info, @@ -198,15 +199,16 @@ def warn(self, message, node): warnings.warn(msg) def visit_declare(self, node): - if not isinstance(node.args[0], ast.Str): + type_str = get_string_value(node.args[0]) + if type_str is None: self.error("Argument to declare should be a string.", node) - type_str = node.args[0].s return self.get_declare_type(type_str) def visit_cast(self, node): - if not isinstance(node.args[1], ast.Str): + type_str = get_string_value(node.args[1]) + if type_str is None: self.error("Cast type should be a string.", node) - return node.args[1].s + return type_str def visit_address(self, node): base_type = self.visit(node.args[0]) @@ -294,6 +296,9 @@ def visit_BinOp(self, node): def visit_Num(self, node): return get_ctype_from_arg(node.n) + def visit_Constant(self, node): + return get_ctype_from_arg(node.value) + def visit_UnaryOp(self, node): return self.visit(node.operand) diff --git a/compyle/translator.py b/compyle/translator.py index 7a10a92..c786222 100644 --- a/compyle/translator.py +++ b/compyle/translator.py @@ -25,6 +25,7 @@ from .cython_generator import ( CodeGenerationError, KnownType, Undefined, all_numeric ) +from .ast_utils import get_string_value from .utils import getsource PY_VER = sys.version_info.major @@ -235,7 +236,7 @@ def _indent_block(self, code): def _remove_docstring(self, body): if body and isinstance(body[0], ast.Expr) and \ - isinstance(body[0].value, ast.Str): + get_string_value(body[0].value) is not None: return body[1:] else: return body @@ -351,9 +352,9 @@ def visit_Assign(self, node): left, right = node.targets[0], node.value if isinstance(right, ast.Call) and \ isinstance(right.func, ast.Name) and right.func.id == 'declare': - if not isinstance(right.args[0], ast.Str): + type = get_string_value(right.args[0]) + if type is None: self.error("Argument to declare should be a string.", node) - type = right.args[0].s if isinstance(left, ast.Name): self._known.add(left.id) return self._get_variable_declaration(type, [self.visit(left)]) @@ -395,7 +396,10 @@ def visit_Call(self, node): elif 'atomic' in node.func.id: return self.render_atomic(node.func.id, node.args[0]) elif node.func.id == 'cast': - return '(%s) (%s)' % (node.args[1].s, self.visit(node.args[0])) + type_str = get_string_value(node.args[1]) + if type_str is None: + self.error("Cast type should be a string.", node) + return '(%s) (%s)' % (type_str, self.visit(node.args[0])) else: return '{func}({args})'.format( func=node.func.id, @@ -682,6 +686,14 @@ def visit_NameConstant(self, node): else: return value + def visit_Constant(self, node): + value = node.value + if value is True or value is False or value is None: + return self._replacements[value] + if isinstance(value, str): + return r'"%s"' % value + return literal_to_float(value, self._use_double) + def visit_Not(self, node): return '!'