Skip to content
This repository was archived by the owner on Jun 14, 2025. It is now read-only.

Commit 053a4cb

Browse files
authored
Fixing elementwise absolute value operation on UserTypes (#23)
* Steps to fix some particularly egregious elementwise abs performance on UserTypes * Slightly sneakier way of avoiding name mangling in new builtin * Fix elementwise abs test to avoid indexing into a UserType, other than in the outer Fortran test code
1 parent 128e255 commit 053a4cb

4 files changed

Lines changed: 49 additions & 21 deletions

File tree

dagrt/builtins_python.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def builtin_dot_product(a, b):
5959
return np.vdot(a, b)
6060

6161

62+
def builtin_elementwise_abs(x):
63+
import numpy as np
64+
return np.abs(x)
65+
66+
6267
def builtin_array(n):
6368
import numpy as np
6469
if n != np.floor(n):
@@ -141,6 +146,7 @@ def builtin_print(arg):
141146
"<builtin>norm_2": builtin_norm_2,
142147
"<builtin>norm_inf": builtin_norm_inf,
143148
"<builtin>dot_product": builtin_dot_product,
149+
"<builtin>elementwise_abs": builtin_elementwise_abs,
144150
"<builtin>array": builtin_array,
145151
"<builtin>matmul": builtin_matmul,
146152
"<builtin>transpose": builtin_transpose,

dagrt/codegen/fortran.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,12 +2407,15 @@ def codegen_builtin_len(results, function, args, arg_kinds,
24072407
code_generator.emit("")
24082408

24092409

2410-
class AbsComputer(TypeVisitorWithResult):
2411-
def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map):
2410+
class AbsComputer(AssignmentEmitter):
2411+
def visit_BuiltinType(self, fortran_type, fortran_expr_str, index_expr_map,
2412+
rhs_expr, is_rhs_target):
2413+
expr = self.code_generator.expr(rhs_expr)
24122414
self.code_generator.emit(
2413-
"{result} = abs({result})"
2415+
"{result} = abs({expr})"
24142416
.format(
2415-
result=self.result_expr))
2417+
result=fortran_expr_str,
2418+
expr=expr))
24162419

24172420

24182421
def codegen_builtin_elementwise_abs(results, function, args, arg_kinds,
@@ -2421,25 +2424,23 @@ def codegen_builtin_elementwise_abs(results, function, args, arg_kinds,
24212424

24222425
from dagrt.data import Scalar, Array, UserType
24232426
x_kind = arg_kinds[0]
2424-
if isinstance(x_kind, Scalar):
2425-
if x_kind.is_real_valued:
2426-
ftype = BuiltinType("real*8")
2427-
else:
2428-
ftype = BuiltinType("complex*16")
2429-
elif isinstance(x_kind, UserType):
2430-
ftype = code_generator.user_type_map[x_kind.identifier]
2431-
elif isinstance(x_kind, Array):
2427+
if isinstance(x_kind, Scalar) or isinstance(x_kind, Array):
24322428
code_generator.emit("{result} = abs({arg})".format(
24332429
result=result,
24342430
arg=args[0]))
24352431
return
2432+
elif isinstance(x_kind, UserType):
2433+
ftype = code_generator.user_type_map[x_kind.identifier]
24362434
else:
24372435
raise TypeError("unsupported kind for elementwise_abs argument: %s" % x_kind)
24382436

2439-
code_generator.emit(f"{result} = 0")
2440-
code_generator.emit("")
2441-
2442-
AbsComputer(code_generator, result)(ftype, args[0], {})
2437+
# Need to pass argument to assignment emitter as a variable (for mappers)
2438+
# Call it a target to avoid name mangling
2439+
from pymbolic import var
2440+
argvar = var("<target>" + args[0])
2441+
code_generator.sym_kind_table.set(
2442+
None, "<target>" + args[0], UserType(x_kind.identifier))
2443+
AbsComputer(code_generator)(ftype, result, {}, argvar, is_rhs_target=False)
24432444
code_generator.emit("")
24442445

24452446

test/test_codegen_fortran.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,32 @@ def test_self_dep_in_loop():
150150
fortran_libraries=["lapack", "blas"])
151151

152152

153+
class AbsFailure:
154+
pass
155+
156+
153157
def test_elementwise_abs():
154158
with CodeBuilder(name="primary") as cb:
155-
cb("y", "<func>f(0, <state>ytype)")
156-
cb("<state>ytype", "y")
157-
# Test new builtin on a usertype.
158-
cb("z", "<builtin>elementwise_abs(<state>ytype)")
159159
cb("i", "<builtin>array(20)")
160160
cb("i[j]", "-j",
161161
loops=(("j", 0, 20),))
162162
# Test new builtin on an array type.
163163
cb("k", "<builtin>elementwise_abs(i)")
164+
with cb.if_("k[20] > 19"):
165+
cb.raise_(AbsFailure)
166+
with cb.if_("k[20] < 19"):
167+
cb.raise_(AbsFailure)
164168
# Test new builtin on a scalar.
165169
cb("l", "<builtin>elementwise_abs(-20)")
170+
with cb.if_("l > 20"):
171+
cb.raise_(AbsFailure)
172+
with cb.if_("l < 20"):
173+
cb.raise_(AbsFailure)
174+
cb("y", "<func>f(0, <state>ytype)")
175+
cb("<state>ytype", "y")
176+
# Test new builtin on a usertype.
177+
cb("<state>ytype", "<builtin>elementwise_abs(<state>ytype)")
178+
# (We check this in the outer test code)
166179

167180
code = create_DAGCode_with_steady_phase(cb.statements)
168181

test/test_element_abs.f90

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,26 @@ program test_element_abs
1313
real*8, dimension(100) :: y0
1414

1515
integer i
16+
integer stderr
17+
parameter(stderr=0)
1618

1719
! start code ----------------------------------------------------------------
1820

1921
dagrt_state_ptr => dagrt_state
2022

2123

2224
do i = 1, 100
23-
y0 = i
25+
y0(i) = i
2426
end do
2527

2628
call timestep_initialize(dagrt_state=dagrt_state_ptr, state_ytype=y0)
2729
call timestep_run(dagrt_state=dagrt_state_ptr)
30+
! For the UserType, check that the absolute value did its job.
31+
do i = 1, 100
32+
if (dagrt_state%state_ytype(i) /= 2*y0(i)) then
33+
write(stderr,*) "UserType elementwise abs failure"
34+
endif
35+
end do
2836
call timestep_shutdown(dagrt_state=dagrt_state_ptr)
2937

3038
end program

0 commit comments

Comments
 (0)