Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions compyle/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ def get_backend(backend=None):

minmax_operator_tpl = """

__device__ ${dtype}()
{
}

__device__ ${dtype}(${dtype} const volatile &src)
{
% for prop in prop_names:
% if not only_max:
this->cur_min_${prop} = src.cur_min_${prop};
% endif
% if not only_min:
this->cur_max_${prop} = src.cur_max_${prop};
% endif
% endfor
}
Comment on lines +100 to +114
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback


__device__ ${dtype} volatile &operator=(
${dtype} const &src) volatile
{
Expand Down
12 changes: 12 additions & 0 deletions compyle/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
basestring = str if PY_VER > 2 else basestring


def get_string_value(node):
"""Return a string literal's value or None if *node* is not a string."""
ast_constant = getattr(ast, 'Constant', None)
if ast_constant is not None and isinstance(node, ast_constant) and \
isinstance(node.value, str):
return node.value
ast_str = getattr(ast, 'Str', None)
if ast_str is not None and isinstance(node, ast_str):
return node.s
return None


class NameLister(ast.NodeVisitor):
"""Utility class to collect the Names in an AST.
"""
Expand Down
9 changes: 5 additions & 4 deletions compyle/cython_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .types import KnownType, Undefined, get_declare_info
from .config import get_config
from .ast_utils import get_assigned, has_return
from .ast_utils import get_assigned, get_string_value, has_return
from .utils import getsourcelines

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -247,11 +247,12 @@ def parse_declare(code):
if call.func.id != 'declare':
raise CodeGenerationError('Unknown declare statement: %s' % code)
arg0 = call.args[0]
if not isinstance(arg0, ast.Str):
err = 'Type should be a string, given :%r' % arg0.s
type_str = get_string_value(arg0)
if type_str is None:
err = 'Type should be a string, given :%r' % getattr(arg0, 'value', arg0)
raise CodeGenerationError(err)

return get_declare_info(arg0.s)
return get_declare_info(type_str)


class CythonGenerator(object):
Expand Down
13 changes: 9 additions & 4 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
from pytools import memoize
from .config import get_config
from .ast_utils import get_string_value
from .cython_generator import CythonGenerator
from .transpiler import Transpiler, BUILTINS
from .types import (dtype_to_ctype, get_declare_info,
Expand Down Expand Up @@ -198,15 +199,16 @@ def warn(self, message, node):
warnings.warn(msg)

def visit_declare(self, node):
if not isinstance(node.args[0], ast.Str):
type_str = get_string_value(node.args[0])
if type_str is None:
self.error("Argument to declare should be a string.", node)
type_str = node.args[0].s
return self.get_declare_type(type_str)

def visit_cast(self, node):
if not isinstance(node.args[1], ast.Str):
type_str = get_string_value(node.args[1])
if type_str is None:
self.error("Cast type should be a string.", node)
return node.args[1].s
return type_str

def visit_address(self, node):
base_type = self.visit(node.args[0])
Expand Down Expand Up @@ -294,6 +296,9 @@ def visit_BinOp(self, node):
def visit_Num(self, node):
return get_ctype_from_arg(node.n)

def visit_Constant(self, node):
return get_ctype_from_arg(node.value)

def visit_UnaryOp(self, node):
return self.visit(node.operand)

Expand Down
20 changes: 16 additions & 4 deletions compyle/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .cython_generator import (
CodeGenerationError, KnownType, Undefined, all_numeric
)
from .ast_utils import get_string_value
from .utils import getsource

PY_VER = sys.version_info.major
Expand Down Expand Up @@ -235,7 +236,7 @@ def _indent_block(self, code):

def _remove_docstring(self, body):
if body and isinstance(body[0], ast.Expr) and \
isinstance(body[0].value, ast.Str):
get_string_value(body[0].value) is not None:
return body[1:]
else:
return body
Expand Down Expand Up @@ -351,9 +352,9 @@ def visit_Assign(self, node):
left, right = node.targets[0], node.value
if isinstance(right, ast.Call) and \
isinstance(right.func, ast.Name) and right.func.id == 'declare':
if not isinstance(right.args[0], ast.Str):
type = get_string_value(right.args[0])
if type is None:
self.error("Argument to declare should be a string.", node)
type = right.args[0].s
if isinstance(left, ast.Name):
self._known.add(left.id)
return self._get_variable_declaration(type, [self.visit(left)])
Expand Down Expand Up @@ -395,7 +396,10 @@ def visit_Call(self, node):
elif 'atomic' in node.func.id:
return self.render_atomic(node.func.id, node.args[0])
elif node.func.id == 'cast':
return '(%s) (%s)' % (node.args[1].s, self.visit(node.args[0]))
type_str = get_string_value(node.args[1])
if type_str is None:
self.error("Cast type should be a string.", node)
return '(%s) (%s)' % (type_str, self.visit(node.args[0]))
else:
return '{func}({args})'.format(
func=node.func.id,
Expand Down Expand Up @@ -682,6 +686,14 @@ def visit_NameConstant(self, node):
else:
return value

def visit_Constant(self, node):
value = node.value
if value is True or value is False or value is None:
return self._replacements[value]
if isinstance(value, str):
return r'"%s"' % value
return literal_to_float(value, self._use_double)

def visit_Not(self, node):
return '!'

Expand Down