TileSet MJX codegen fix#3299
Conversation
|
@tkelestemur I don't see the test added to this PR, you mentioned the test in the OP, can you add the test so I understand a bit better why you need this? |
| def _as_numpy_array(value): | ||
| if hasattr(value, 'numpy'): | ||
| value = value.numpy() | ||
| return np.asarray(value) |
There was a problem hiding this comment.
this method seems a bit odd, if a value has a numpy attr this returns np.asarray(value.numpy()) ?
There was a problem hiding this comment.
Good point, fixed. _as_numpy_array now returns value.numpy() directly when that path is available, and only uses np.asarray for the fallback path.
| if self.__class__ is not other.__class__: | ||
| return NotImplemented | ||
| return self.size == other.size and np.array_equal( | ||
| np.asarray(_as_numpy_array(self.adr)), |
There was a problem hiding this comment.
and this would be doing np.asarray(np.asarray(value.numpy())) potentially?
There was a problem hiding this comment.
Yep, that extra wrapping was unnecessary. I added a generic cleanup in the method adapter so generated method bodies collapse np.asarray(_as_numpy_array(...)) to _as_numpy_array(...).
b676c06 to
c80d8fb
Compare
|
@btaba Added the test now. It checks that generated This was initial point of the mujoco_warp PR. |
| """ | ||
|
|
||
|
|
||
| def _format_docstring(docstring: str) -> str: |
There was a problem hiding this comment.
why is all of this docstring formatting needed? it wasn't here in the original PR. does pyink not do the job?
There was a problem hiding this comment.
this is now removed.
| adr: np.ndarray | ||
| size: int | ||
|
|
||
| def __eq__(self, other) -> bool: |
There was a problem hiding this comment.
Given the amount of new code to do this codegen, I'd prefer these methods to be one-off hardcoded for TileSet, until it really seems necessary to automate these code-blocks. I'd prefer to keep the codegen more minimal as implemented in this PR. LGTM otherwise
There was a problem hiding this comment.
Agreed -- I removed the codegen changes and restored generate_warp_types.py to match main. I also verified that codegen does not break the hardcoded functions.
There was a problem hiding this comment.
Wait I'm confused now, if you run the codegen, this codeblock would be deleted? Did you forget to add the hardcoded part into generate_warp_types.py for this one-off eq and hash?
…-dataclass-methods # Conflicts: # mjx/mujoco/mjx/warp/types.py
| adr: np.ndarray | ||
| size: int | ||
|
|
||
| def __eq__(self, other) -> bool: |
There was a problem hiding this comment.
Wait I'm confused now, if you run the codegen, this codeblock would be deleted? Did you forget to add the hardcoded part into generate_warp_types.py for this one-off eq and hash?
| class_start = src.index(f'class {cls_name}:') | ||
| method_start = src.index( | ||
| ' def __eq__(self, other) -> bool:\n', class_start | ||
| ) |
There was a problem hiding this comment.
this won't generalize right? let's just not do the _write_manual_method_notes
| adr: np.ndarray | ||
| size: int | ||
|
|
||
| # Manually kept in this generated shim until TileSet method generation is |
There was a problem hiding this comment.
you can add this comment in the generate_warp_types file
This fixes MJX Warp type generation so methods on nested dataclasses from vendored mujoco_warp are preserved in the generated MJX shim types.
The immediate issue is
TileSet:mujoco_warpnow defines structural__eq__and__hash__(see google-deepmind/mujoco_warp#1276), but MJX was regeneratingTileSetonly from annotations, so those methods were dropped. I updated generate_warp_types.py to copy explicit methods from nested dataclasses in addition to generating their fields and pytree hooks.Since the MJX shim stores these fields as numpy-backed metadata, copied zero-argument
.numpy()calls are rewritten through a small_as_numpy_array(...)helper. This keeps the generated methods compatible with MJX np.ndarray fields while still allowing the method bodies to come from the vendored source.I also regenerated mjx/warp/types.py and added a test for TileSet structural equality and hashing.
Tested with: