Skip to content

Commit 3a6cf78

Browse files
authored
Merge pull request #867
update rocRAND wrappers
2 parents 4d9bff9 + 2019c6a commit 3a6cf78

6 files changed

Lines changed: 289 additions & 69 deletions

File tree

gen/rocrand/generator.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using Clang.Generators
2+
using JuliaFormatter
3+
4+
include_dir = normpath(joinpath(ENV["ROCM_PATH"], "include"))
5+
rocrand_dir = joinpath(include_dir, "rocrand")
6+
options = load_options("rocrand/rocrand-generator.toml")
7+
8+
args = get_default_args()
9+
push!(args, "-I$include_dir")
10+
11+
rocrand_h = read(joinpath(rocrand_dir, "rocrand.h"), String)
12+
open("./rocrand.h", "w") do io
13+
println(io, """
14+
#include <stddef.h>
15+
16+
typedef void* hipStream_t;
17+
typedef struct { unsigned int x, y, z, w; } uint4;
18+
""")
19+
print(io, rocrand_h)
20+
end
21+
headers = [
22+
"./rocrand.h"
23+
]
24+
25+
ctx = create_context(headers, args, options)
26+
27+
# build without printing so we can do custom rewriting
28+
build!(ctx, BUILDSTAGE_NO_PRINTING)
29+
30+
# custom rewriter
31+
function rewrite!(e::Expr)
32+
if e.head === :const
33+
@assert Meta.isexpr(e.args[1], :(=))
34+
rhs = e.args[1].args[2]
35+
if Meta.isexpr(rhs, :call)
36+
if rhs.args[1] == :(*) && rhs.args[3] == :f
37+
e.args[1].args[2] = :(Float32($(rhs.args[2])))
38+
elseif rhs.args[1] == :(Cuint)
39+
e.args[1].args[2] = :($(rhs.args[2]) % Cuint)
40+
end
41+
end
42+
return e
43+
end
44+
(e.head === :function && Meta.isexpr(e.args[1], :call)) || return e
45+
f = e.args[1].args[1]
46+
if !(f isa Symbol)
47+
@assert f in (:(Base.getproperty), :(Base.setproperty!), :(Base.propertynames))
48+
return e
49+
end
50+
stmts = e.args[2].args
51+
map!(stmts, stmts) do ex
52+
Meta.isexpr(ex, :macrocall) || return ex
53+
ex.args[1] === Symbol("@ccall") || return ex
54+
# TODO: should this be `@gcsafe_ccall`?
55+
# ex.args[1] = Symbol("@gcsafe_ccall")
56+
Expr(:macrocall, Symbol("@check"), nothing, ex)
57+
end
58+
pushfirst!(stmts, :(AMDGPU.prepare_state()))
59+
return e
60+
end
61+
62+
function rewrite!(dag::ExprDAG)
63+
for node in get_nodes(dag)
64+
for expr in get_exprs(node)
65+
rewrite!(expr)
66+
end
67+
end
68+
end
69+
70+
rewrite!(ctx.dag)
71+
72+
# print
73+
build!(ctx, BUILDSTAGE_PRINTING_ONLY)
74+
75+
path = options["general"]["output_file_path"]
76+
format_file(path, YASStyle())

gen/rocrand/rocrand-generator.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[general]
2+
library_name = "librocrand"
3+
output_file_path = "../src/rand/librocrand.jl"
4+
export_symbol_prefixes = []
5+
print_using_CEnum = false
6+
output_ignorelist = [
7+
"(__)?hip.*",
8+
"(__)?HIP.*",
9+
"rocrand_status",
10+
"half",
11+
"SKEIN_KS_PARITY64",
12+
]
13+
14+
[codegen]
15+
use_ccall_macro = true

src/rand/error.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
export ROCRANDError
22

33
import .AMDGPU: @check, check
4+
using CEnum: @cenum
5+
6+
@cenum rocrand_status::UInt32 begin
7+
ROCRAND_STATUS_SUCCESS = 0
8+
ROCRAND_STATUS_VERSION_MISMATCH = 100
9+
ROCRAND_STATUS_NOT_CREATED = 101
10+
ROCRAND_STATUS_ALLOCATION_FAILED = 102
11+
ROCRAND_STATUS_TYPE_ERROR = 103
12+
ROCRAND_STATUS_OUT_OF_RANGE = 104
13+
ROCRAND_STATUS_LENGTH_NOT_MULTIPLE = 105
14+
ROCRAND_STATUS_DOUBLE_PRECISION_REQUIRED = 106
15+
ROCRAND_STATUS_LAUNCH_FAILURE = 107
16+
ROCRAND_STATUS_INTERNAL_ERROR = 108
17+
end
418

519
struct ROCRANDError <: Exception
620
code::rocrand_status

0 commit comments

Comments
 (0)