Skip to content

Register __ior__ as alias of bitwise_or (fixes #2584)#2702

Open
LeSingh1 wants to merge 1 commit into
apple:mainfrom
LeSingh1:fix/register-ior-op
Open

Register __ior__ as alias of bitwise_or (fixes #2584)#2702
LeSingh1 wants to merge 1 commit into
apple:mainfrom
LeSingh1:fix/register-ior-op

Conversation

@LeSingh1
Copy link
Copy Markdown

Summary

sanitize_op_kind strips the dunder wrapper from torch op names, so aten::__or__ becomes or and matches bitwise_or's alias. The in-place form aten::__ior__ sanitizes to ior -- the leading i is preserved because unify_inplace_and_functional only strips a trailing _, and sanitize_op_kind only strips the dunder wrapper.

The result: tensor |= other fails with

PyTorch convert function for op '__ior__' not implemented.

This is the gemma-3-1b-it failure reported in #2584.

Fix

Add "ior" to the bitwise_or alias list, matching how or / xor are already registered to handle the post-sanitize form of __or__ / __xor__.

@register_torch_op(torch_alias=["or", "ior"])
def bitwise_or(context, node):
    _bitwise_as_logical_if_boolean(context, node, "bitwise_or", logical_or)

I deliberately scoped this to __ior__ only, since that's what the issue reports. If reviewers prefer, similar aliases for iand / ixor are a straightforward follow-up.

Tests

Added TestBitwiseOr::test_ior_operator, mirroring the existing test_or_operator but using x |= y so the registered alias path is exercised across all three frontends (TorchScript, TorchExport, ExecuTorch).

I confirmed locally that:

  • The fixed code dispatches x |= y to bitwise_or and runs through the full converter pipeline.
  • A side-by-side manual run also confirms there is no regression in the existing __or__ path.

Native libs (BlobWriter) are not available in my dev environment so I could not run the existing run_compare_torch end-to-end suite locally -- CI will exercise it.

Issue

Fixes #2584

`sanitize_op_kind` strips the dunder wrapper from torch op names, so
`aten::__or__` becomes `or` and matches `bitwise_or`'s alias. But the
in-place form `aten::__ior__` sanitizes to `ior` -- the leading `i` is
preserved because `unify_inplace_and_functional` only strips a trailing
`_`, and `sanitize_op_kind` only strips the dunder wrapper.

The result is that `tensor |= other` fails with
"PyTorch convert function for op '__ior__' not implemented", which is
the gemma-3-1b-it failure reported in apple#2584.

Add `ior` to the `bitwise_or` alias list (matching how `or`/`xor` are
already registered for the post-sanitize form of `__or__`/`__xor__`)
and a regression test in `TestBitwiseOr`.

Fixes apple#2584
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

PyTorch __ior__ op is not implemented for conversion

1 participant