[examples][XeGPU] Add XeGPU MLP example #56

Open

tkarna wants to merge 6 commits into llvm:main from tkarna:xegpu-mlp

Conversation

@tkarna (Contributor) commented Feb 23, 2026:

  • Adds the xegpu_mlp example. It supports arbitrary MLP models, with optional ReLU (applied on all layers); bias is not yet supported.
  • The matrix multiplication example gains a no-accumulate-c option to compute MLP-like C=A*B instead of C+=A*B.
  • The matmul and MLP examples share the same payload generator and lowering schedule:
    • The payload generator moved to lighthouse/ingress/gpu/matmul.py.
    • The schedule moved to lighthouse/schedule/xegpu/matmul_schedule.py.

Example: run the simplest KernelBench MLP:

```sh
python mlp.py -b 128 -i 16384 -o 8192 --hidden-sizes 16384 16384
```

which prints:

```txt
MLP with 3 layers
  Layer 0: M=128, N=16384, K=16384
  Layer 1: M=128, N=16384, K=16384
  Layer 2: M=128, N=8192, K=16384
b=128 i=16384 o=8192 hs=16384,16384 dt=f16,f32 time(us): [...] GFLOPS: [...]
```

# cache allocated memrefs
self.gpu_memrefs = {}

def _allocate_array(
Member: can we outline the allocation into a helper module?

):
"""Transform schedule for matmul-like payload."""
try:
mod = bundle_xepu_mlp_schedule(

Misprint?

Comment on lines -32 to -41
def emit_gpu_copy(suffix: str, element_type: ir.Type, rank: int = 2):
    """Emit GPU copy function."""
    dyn = ir.ShapedType.get_dynamic_size()
    memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type)

    @func.func(memref_dyn_t, memref_dyn_t, name="gpu_copy_" + suffix)
    def copy_func(src, dst):
        gpu.memcpy(None, [], dst, src)

    copy_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
Contributor: This pattern (@func.func + c_interface) could be a decorator provided somewhere in lighthouse (or, even better, the @func.func decorator could accept this flag, or attributes in general).
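A minimal sketch of what such a decorator could look like; the name c_interface_func is hypothetical (not existing lighthouse API), and the import paths are an assumption to be replaced by whatever the surrounding file already imports:

```python
# Hypothetical helper: wraps the @func.func decorator already used in the
# snippet above and sets llvm.emit_c_interface, so call sites no longer
# repeat the attribute assignment. Import paths are an assumption.
from mlir import ir
from mlir.dialects import func


def c_interface_func(*arg_types, name: str):
    """Like @func.func, but also marks the function with llvm.emit_c_interface."""

    def decorator(body):
        wrapped = func.func(*arg_types, name=name)(body)
        wrapped.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        return wrapped

    return decorator
```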

execution_engine: ExecutionEngine,
) -> ctypes.Structure:
key = (name, dtype_str)
if key in self.gpu_memrefs:
Contributor: Why allow the same name with different types?

Comment on lines +95 to +100
alloc_func = execution_engine.lookup("gpu_alloc_" + dtype_str)
mref = make_nd_memref_descriptor(len(shape), as_ctype(dtype))()
ptr_mref = ctypes.pointer(ctypes.pointer(mref))
ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape]
alloc_func(get_packed_arg([ptr_mref] + ptr_dims))
self.gpu_memrefs[key] = mref
@fschlimb (Contributor) commented Feb 24, 2026: Consider using execution_engine.invoke.
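A hedged sketch of that suggestion, reusing the ptr_mref/ptr_dims pointers built in the snippet above:

```python
# Sketch only: ExecutionEngine.invoke packs the ctypes arguments itself,
# so the explicit lookup() + get_packed_arg() pair is no longer needed.
execution_engine.invoke("gpu_alloc_" + dtype_str, ptr_mref, *ptr_dims)
self.gpu_memrefs[key] = mref
```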

# use integer values to avoid f16/f32 floating point discrepancies
def gen_random(shape, dtype):
    # generate values in range [-3, 3]
    a = np.round(6 * np.random.random_sample(shape)) - 3
@fschlimb (Contributor) commented Feb 24, 2026: np.random.randint(-3, 4, shape)?
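For reference, the suggested call (note that randint's upper bound is exclusive, and that it returns integers rather than the float64 values the original expression produced):

```python
# Reviewer's suggestion: integers in [-3, 3]; the upper bound 4 is exclusive.
a = np.random.randint(-3, 4, shape)
```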

Comment on lines +449 to +453
if shape in self.param_db:
    params = self.param_db[shape]
else:
    raise ValueError(f"No parameters found for matmul shape {shape}")
parameters[f"layer_{i}"] = params
Contributor:
Suggested change: replace

    if shape in self.param_db:
        params = self.param_db[shape]
    else:
        raise ValueError(f"No parameters found for matmul shape {shape}")
    parameters[f"layer_{i}"] = params

with

    if shape not in self.param_db:
        raise ValueError(f"No parameters found for matmul shape {shape}")
    parameters[f"layer_{i}"] = self.param_db[shape]

}


class ParameterOracleMLP:
Contributor: Why is this a class and not just a function?

"xegpu-inst",
"final",
],
help="Dump kernel IR at different stages of lowering.",
Contributor: Might want to also mention that it will do nothing other than lowering/dumping.
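A possible wording for that help string (illustrative only; exact phrasing is up to the author):

```python
# Hypothetical help text reflecting the reviewer's point.
help=(
    "Dump kernel IR at the given stage of lowering and stop; "
    "nothing is compiled or executed beyond that point."
),
```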

Comment on lines +601 to +610
parts = [
    f"b={args.batch_size}",
    f"i={args.input_size}",
    f"o={args.output_size}",
    f"hs={list2str(hidden_sizes)}",
    f"dt={ab_type},{c_type}",
    f"time(us): {elapsed:.2f}",
    f"GFLOPS: {gflops:.2f}",
]
print(" ".join(parts))
Contributor:
Suggested change: replace

    parts = [
        f"b={args.batch_size}",
        f"i={args.input_size}",
        f"o={args.output_size}",
        f"hs={list2str(hidden_sizes)}",
        f"dt={ab_type},{c_type}",
        f"time(us): {elapsed:.2f}",
        f"GFLOPS: {gflops:.2f}",
    ]
    print(" ".join(parts))

with

    print(
        f"b={args.batch_size} "
        f"i={args.input_size} "
        f"o={args.output_size} "
        f"hs={list2str(hidden_sizes)} "
        f"dt={ab_type},{c_type} "
        f"time(us): {elapsed:.2f} "
        f"GFLOPS: {gflops:.2f}"
    )

Comment on lines +24 to +34
```sh
python mlp.py -b 128 -i 16384 -o 8192 --hidden-sizes 16384 16384 ...
```

which corresponds to

```txt
MLP with 3 layers
Layer 0: M=128, N=16384, K=16384
Layer 1: M=128, N=16384, K=16384
Layer 2: M=128, N=8192, K=16384
```
@fschlimb (Contributor) commented Feb 24, 2026: For clarity, consider using different values for every parameter (instead of 16384 for three parameters). Same above.
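For example (hypothetical sizes, chosen only so each dimension is distinguishable in the output):

```sh
python mlp.py -b 128 -i 4096 -o 8192 --hidden-sizes 16384 2048
```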

Comment on lines +12 to +18
@func.func(*inputs, name="gpu_alloc_" + suffix)
def alloc_func(*shape):
    dims = [arith.index_cast(index_t, a) for a in shape]
    alloc = gpu.alloc(memref_dyn_t, None, [], dims, [])
    return alloc

alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
Contributor: See comment in deleted file above.

Comment on lines +72 to +79
@linalg.generic(
    [c_tensor],
    [empty],
    [id_map, id_map],
    [par_iter, par_iter],
)
def f(a, b):
    return arith.extf(c_type, a)
Contributor: Doesn't arith.extf/arith.truncf operate directly on tensors?
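If the Python builder accepts tensor operands (arith.extf is elementwise-mappable in MLIR), the generic could in principle collapse to a single op; tensor_c_f32_t below is a hypothetical name for the f32 tensor type matching c_tensor's shape:

```python
# Sketch only: arith.extf applied directly to the tensor value, replacing the
# linalg.generic above; tensor_c_f32_t is a placeholder type name.
ext = arith.extf(tensor_c_f32_t, c_tensor)
```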

emit_gpu_copy(suffix, element_type)


def emit_mlp_layer(
Contributor: Maybe this could live in its own file, rather than in something that's called "matmul".

return terminal


def generate_matmul_payload(
Contributor: Consider a better name. Also see above.

B = args[1]
C = args[-1]
bias = args[2] if has_bias else None
a_tensor = bufferization.to_tensor(tensor_a_t, A, restrict=True)
Contributor: emit_buf_to_tensor?

Comment on lines +295 to +300
if to_dealloc is not None:
    gpu.dealloc(None, [], to_dealloc)
    to_dealloc = None
if i != nlayers - 1:
    # deallocate after next layer
    to_dealloc = c_memref
Contributor: Use deallocate_memrefs_on_exit?


class PipelineInterrupt(Exception):
    """Exception to signal early termination of the transform schedule."""

Contributor: remove

Contributor (Author): It is useful to have a dedicated exception to mark an intended pipeline interrupt, so it can be caught and differentiated from other, more severe exceptions.

Without such an interrupt exception, stopping the pipeline in the middle gets somewhat complicated; e.g., you'd have to return an extra boolean indicating "pipeline was interrupted".
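A minimal sketch of the pattern being described (function and argument names are illustrative, not the PR's actual API):

```python
# Sketch only: the schedule raises PipelineInterrupt once the requested IR
# stage has been dumped; the driver treats that as a normal early exit while
# real lowering failures still propagate.
try:
    apply_schedule(payload_module, dump_stage=requested_stage)
except PipelineInterrupt:
    pass  # intended early termination after dumping
```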

Contributor: Sorry, yes, I meant remove the empty line :)

loop.HoistLoopInvariantSubsetsOp(k_loop)

transform.apply_cse(func)
canonicalize(func)
Contributor: Are all these canonicalizers needed?

Comment on lines +355 to +383
# A tile load layout
layout_load_a = {
    "sg_layout": sg_layout,
    "sg_data": sg_tile_a,
    "inst_data": load_tile_a,
}
desc_op_a = xegpu.get_desc_op(tile_a)
# A tile load op anchor layout
load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0)
xegpu.set_op_layout_attr(load_op_a, **layout_load_a)
# A tile dpas layout
layout_dpas_a = layout_load_a.copy()
layout_dpas_a["inst_data"] = dpas_shape_a
convert_layout(tile_a, layout_load_a, layout_dpas_a)

# B tile load layout
layout_load_b = {
    "sg_layout": sg_layout,
    "sg_data": sg_tile_b,
    "inst_data": load_tile_b,
}
desc_op_b = xegpu.get_desc_op(tile_b)
# B tile load op anchor layout
load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0)
xegpu.set_op_layout_attr(load_op_b, **layout_load_b)
# B tile dpas layout
layout_dpas_b = layout_load_b.copy()
layout_dpas_b["inst_data"] = dpas_shape_b
convert_layout(tile_b, layout_load_b, layout_dpas_b)
Contributor: Maybe deduplicate code?
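One possible shape for the deduplication, reusing only calls that already appear in the snippet above; the helper name set_tile_layouts is made up:

```python
# Sketch of a shared helper for the A/B layout assignment above.
# sg_layout, anytype, xegpu, transform and convert_layout come from the
# enclosing scope, exactly as in the original code.
def set_tile_layouts(tile, sg_data, load_inst_data, dpas_inst_data):
    layout_load = {
        "sg_layout": sg_layout,
        "sg_data": sg_data,
        "inst_data": load_inst_data,
    }
    desc_op = xegpu.get_desc_op(tile)
    load_op = transform.get_consumers_of_result(anytype, desc_op, 0)
    xegpu.set_op_layout_attr(load_op, **layout_load)
    layout_dpas = dict(layout_load, inst_data=dpas_inst_data)
    convert_layout(tile, layout_load, layout_dpas)


set_tile_layouts(tile_a, sg_tile_a, load_tile_a, dpas_shape_a)
set_tile_layouts(tile_b, sg_tile_b, load_tile_b, dpas_shape_b)
```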
