Replace string parsing with WaveTensorType bindings in water_emitter#839
Replace string parsing with WaveTensorType bindings in water_emitter#839
Conversation
|
|
||
| def _map_address_space(addr: str) -> str: | ||
| # TODO: Use WaveAddressSpaceAttr bindings when they exist | ||
| def _map_address_space(addr: str) -> WaveAddressSpaceAttr: |
There was a problem hiding this comment.
@copilot it looks like the type annotation in the original code was wrong and this function takes an int, not str. This int can take only some values defined somewhere in the code, could you use that instead of strings?
There was a problem hiding this comment.
Fixed in c0dbb92. Changed parameter type from str to int and now properly uses AddressSpace enum values (GLOBAL_MEMORY=2, SHARED_MEMORY=1, REGISTER=0) with direct equality checks.
| ) | ||
| # Convert symbolic shape to WaveSymbolAttr list | ||
| shape_attrs = [WaveSymbolAttr.get(str(s)) for s in type_.symbolic_shape] | ||
| element_type = _dtype_to_mlir(ctx, type_.dtype) |
| shape_attrs = [WaveSymbolAttr.get(str(s)) for s in type_.symbolic_shape] | ||
| element_type = _dtype_to_mlir(ctx, type_.dtype) | ||
| address_space_attr = _map_address_space(type_.address_space) | ||
| return WaveTensorType.get(shape_attrs, True, element_type, address_space_attr) |
There was a problem hiding this comment.
@copilot this code looks identical to the one above, can this be factored out into a helper function?
There was a problem hiding this comment.
Fixed in c0dbb92. Created _create_wave_tensor_type() helper function that both Register and Memory branches now use, eliminating the duplication.
| DATATYPE_MAP: dict[str, Callable[[], ir.Type]] = { | ||
| "f16": ir.F16Type.get, | ||
| "f32": ir.F32Type.get, | ||
| "f64": ir.F64Type.get, | ||
| "i1": lambda: ir.IntegerType.get_signless(1), | ||
| "i32": lambda: ir.IntegerType.get_signless(32), | ||
| "i64": lambda: ir.IntegerType.get_signless(64), | ||
| } |
There was a problem hiding this comment.
Added in fc3a221. Added "i8": lambda: ir.IntegerType.get_signless(8) to the DATATYPE_MAP.
There was a problem hiding this comment.
@copilot I'd also like to see i16 and bf16 added
There was a problem hiding this comment.
Added in 2c463da. Added mappings for:
"i16": lambda: ir.IntegerType.get_signless(16)"bf16": ir.BF16Type.get
|
@copilot there still seem to be test failures. Please run the tests and fix any broken test cases. |
Fixed in 77e140c. The issue was that |
|
@copilot please run |
Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com>
…ctor duplicated code Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com>
Co-authored-by: tgymnich <7985149+tgymnich@users.noreply.github.com>
Co-authored-by: tgymnich <7985149+tgymnich@users.noreply.github.com>
Co-authored-by: tgymnich <7985149+tgymnich@users.noreply.github.com>
Co-authored-by: tgymnich <7985149+tgymnich@users.noreply.github.com>
2c463da to
e11d39d
Compare
| # Handle int values | ||
| if addr == AddressSpace.GLOBAL_MEMORY.value: | ||
| return WaveAddressSpaceAttr.get(WaveAddressSpace.Global) | ||
| elif addr == AddressSpace.SHARED_MEMORY.value: | ||
| return WaveAddressSpaceAttr.get(WaveAddressSpace.Shared) | ||
| elif addr == AddressSpace.REGISTER.value: | ||
| return WaveAddressSpaceAttr.get(WaveAddressSpace.Register) | ||
| else: | ||
| # If we can't determine the address space, use unspecified for now | ||
| # TODO: Sometimes the address space is determined in the hyperparameters | ||
| # Check whether a mapping for `addr` is defined there. | ||
| return WaveAddressSpaceAttr.get(WaveAddressSpace.Unspecified) |
There was a problem hiding this comment.
@copilot I don't like integers being used for address spaces. Can you in a separate PR go over all occurrences of this in the code base and replace the int value with the proper enum value.
There was a problem hiding this comment.
Understood. I'll create a separate PR to refactor the codebase to use AddressSpace enum values instead of int values throughout. This will improve type safety and code clarity.
…ree-org#839) 1. **Added proper imports**: - WaveTensorType, WaveSymbolAttr, WaveAddressSpaceAttr, WaveAddressSpace from water_mlir - AddressSpace from wave_lang.kernel.lang.kernel_buffer 2. **Fixed `_map_address_space` function** (lines 186-216): - Now handles three input types: - `IndexSymbol` (sympy.Symbol like `GLOBAL_ADDRESS_SPACE`) - checks for `name` attribute and matches string - `AddressSpace` enum - converts to int value - `int` values - maps directly to MLIR enum - Properly maps all address space types to WaveAddressSpaceAttr - Formatted to match pre-commit standards (proper spacing, quote consistency, comment placement) 3. **Created `_create_wave_tensor_type` helper** (lines 219-228): - Eliminates code duplication between Register and Memory type handling - Converts symbolic shapes to WaveSymbolAttr list - Uses correct function name `_dtype_to_mlir_scalar_type` (not `_dtype_to_mlir`) - Constructs WaveTensorType with proper bindings 4. **Updated `_type_to_wave_mlir` function** (lines 231-257): - Simplified by using the new helper function - Both Register and Memory branches now call `_create_wave_tensor_type` - No more string parsing or concatenation 5. **Added data type mappings to DATATYPE_MAP** (lines 162-170): - Added `"i8": lambda: ir.IntegerType.get_signless(8)` - Added `"i16": lambda: ir.IntegerType.get_signless(16)` - Added `"bf16": ir.BF16Type.get` - Enables support for 8-bit integers, 16-bit integers, and bfloat16 types in Wave dialect Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ftynse <1512299+ftynse@users.noreply.github.com> Co-authored-by: tgymnich <7985149+tgymnich@users.noreply.github.com>
Implementation Plan
_map_address_spacefunction to use WaveAddressSpaceAttr bindings instead of string mapping_type_to_wave_mlirfunction to use WaveTensorType.get() instead of ir.Type.parse()_map_address_spaceto acceptintinstead ofstrand useAddressSpaceenum values_dtype_to_mlirto_dtype_to_mlir_scalar_type_create_wave_tensor_typehelper functionSummary
Successfully implemented the WaveTensorType bindings in water_emitter.py:200 and addressed all review feedback, including fixing test failures and code formatting.
Changes Made:
Added proper imports:
Fixed
_map_address_spacefunction (lines 186-216):IndexSymbol(sympy.Symbol likeGLOBAL_ADDRESS_SPACE) - checks fornameattribute and matches stringAddressSpaceenum - converts to int valueintvalues - maps directly to MLIR enumCreated
_create_wave_tensor_typehelper (lines 219-228):_dtype_to_mlir_scalar_type(not_dtype_to_mlir)Updated
_type_to_wave_mlirfunction (lines 231-257):_create_wave_tensor_typeAdded data type mappings to DATATYPE_MAP (lines 162-170):
"i8": lambda: ir.IntegerType.get_signless(8)"i16": lambda: ir.IntegerType.get_signless(16)"bf16": ir.BF16Type.getSecurity Summary
No security vulnerabilities found.
Original prompt
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.