Register __ior__ as alias of bitwise_or (fixes #2584)#2702
Open
LeSingh1 wants to merge 1 commit into
Open
Conversation
`sanitize_op_kind` strips the dunder wrapper from torch op names, so `aten::__or__` becomes `or` and matches `bitwise_or`'s alias. But the in-place form `aten::__ior__` sanitizes to `ior` -- the leading `i` is preserved because `unify_inplace_and_functional` only strips a trailing `_`, and `sanitize_op_kind` only strips the dunder wrapper. The result is that `tensor |= other` fails with "PyTorch convert function for op '__ior__' not implemented", which is the gemma-3-1b-it failure reported in apple#2584. Add `ior` to the `bitwise_or` alias list (matching how `or`/`xor` are already registered for the post-sanitize form of `__or__`/`__xor__`) and a regression test in `TestBitwiseOr`. Fixes apple#2584
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
sanitize_op_kindstrips the dunder wrapper from torch op names, soaten::__or__becomesorand matchesbitwise_or's alias. The in-place formaten::__ior__sanitizes toior-- the leadingiis preserved becauseunify_inplace_and_functionalonly strips a trailing_, andsanitize_op_kindonly strips the dunder wrapper.The result:
tensor |= otherfails withThis is the gemma-3-1b-it failure reported in #2584.
Fix
Add
"ior"to thebitwise_oralias list, matching howor/xorare already registered to handle the post-sanitize form of__or__/__xor__.I deliberately scoped this to
__ior__only, since that's what the issue reports. If reviewers prefer, similar aliases foriand/ixorare a straightforward follow-up.Tests
Added
TestBitwiseOr::test_ior_operator, mirroring the existingtest_or_operatorbut usingx |= yso the registered alias path is exercised across all three frontends (TorchScript, TorchExport, ExecuTorch).I confirmed locally that:
x |= ytobitwise_orand runs through the full converter pipeline.__or__path.Native libs (
BlobWriter) are not available in my dev environment so I could not run the existingrun_compare_torchend-to-end suite locally -- CI will exercise it.Issue
Fixes #2584