Skip to content

TileSet MJX codegen fix#3299

Open
tkelestemur wants to merge 4 commits into
google-deepmind:mainfrom
tkelestemur:tarik/mjx-warp-nested-dataclass-methods
Open

TileSet MJX codegen fix#3299
tkelestemur wants to merge 4 commits into
google-deepmind:mainfrom
tkelestemur:tarik/mjx-warp-nested-dataclass-methods

Conversation

@tkelestemur
Copy link
Copy Markdown
Contributor

This fixes MJX Warp type generation so methods on nested dataclasses from vendored mujoco_warp are preserved in the generated MJX shim types.

The immediate issue is TileSet: mujoco_warp now defines structural __eq__ and __hash__ (see google-deepmind/mujoco_warp#1276), but MJX was regenerating TileSet only from annotations, so those methods were dropped. I updated generate_warp_types.py to copy explicit methods from nested dataclasses in addition to generating their fields and pytree hooks.

Since the MJX shim stores these fields as numpy-backed metadata, copied zero-argument .numpy() calls are rewritten through a small _as_numpy_array(...) helper. This keeps the generated methods compatible with MJX np.ndarray fields while still allowing the method bodies to come from the vendored source.

I also regenerated mjx/warp/types.py and added a test for TileSet structural equality and hashing.

Tested with:

bash mujoco/mjx/codegen/update_for_mujoco_warp.sh
python mujoco/mjx/codegen/generate_warp_types.py --mjx_warp_types_out_path=mujoco/mjx/warp/types.py --mjx_types_path=mujoco/mjx/_src/types.py --logtostderr
python -m mujoco.mjx.warp.types_test
git diff --check
python -m py_compile mujoco/mjx/codegen/generate_warp_types.py mujoco/mjx/warp/types.py mujoco/mjx/warp/types_test.py

@btaba
Copy link
Copy Markdown
Collaborator

btaba commented May 30, 2026

@tkelestemur I don't see the test added to this PR, you mentioned the test in the OP, can you add the test so I understand a bit better why you need this?

Comment thread mjx/mujoco/mjx/warp/types.py Outdated
def _as_numpy_array(value):
if hasattr(value, 'numpy'):
value = value.numpy()
return np.asarray(value)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method seems a bit odd, if a value has a numpy attr this returns np.asarray(value.numpy()) ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, fixed. _as_numpy_array now returns value.numpy() directly when that path is available, and only uses np.asarray for the fallback path.

Comment thread mjx/mujoco/mjx/warp/types.py Outdated
if self.__class__ is not other.__class__:
return NotImplemented
return self.size == other.size and np.array_equal(
np.asarray(_as_numpy_array(self.adr)),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this would be doing np.asarray(np.asarray(value.numpy())) potentially?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that extra wrapping was unnecessary. I added a generic cleanup in the method adapter so generated method bodies collapse np.asarray(_as_numpy_array(...)) to _as_numpy_array(...).

@tkelestemur tkelestemur force-pushed the tarik/mjx-warp-nested-dataclass-methods branch from b676c06 to c80d8fb Compare June 1, 2026 14:46
@tkelestemur
Copy link
Copy Markdown
Contributor Author

@btaba Added the test now. It checks that generated TileSet instances compare structurally and hash the same when their numpy-backed adr values are equivalent, and that a different adr does not compare equal.

This was initial point of the mujoco_warp PR.

"""


def _format_docstring(docstring: str) -> str:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is all of this docstring formatting needed? it wasn't here in the original PR. does pyink not do the job?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now removed.

adr: np.ndarray
size: int

def __eq__(self, other) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the amount of new code to do this codegen, I'd prefer these methods to be one-off hardcoded for TileSet, until it really seems necessary to automate these code-blocks. I'd prefer to keep the codegen more minimal as implemented in this PR. LGTM otherwise

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed -- I removed the codegen changes and restored generate_warp_types.py to match main. I also verified that codegen does not break the hardcoded functions.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm confused now, if you run the codegen, this codeblock would be deleted? Did you forget to add the hardcoded part into generate_warp_types.py for this one-off eq and hash?

…-dataclass-methods

# Conflicts:
#	mjx/mujoco/mjx/warp/types.py
adr: np.ndarray
size: int

def __eq__(self, other) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I'm confused now, if you run the codegen, this codeblock would be deleted? Did you forget to add the hardcoded part into generate_warp_types.py for this one-off eq and hash?

class_start = src.index(f'class {cls_name}:')
method_start = src.index(
' def __eq__(self, other) -> bool:\n', class_start
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this won't generalize right? let's just not do the _write_manual_method_notes

adr: np.ndarray
size: int

# Manually kept in this generated shim until TileSet method generation is
Copy link
Copy Markdown
Collaborator

@btaba btaba Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can add this comment in the generate_warp_types file

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants