Skip to content

fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock#23

Merged
typoverflow merged 2 commits intodev_categoricalfrom
copilot/sub-pr-21-again
Feb 21, 2026
Merged

fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock#23
typoverflow merged 2 commits intodev_categoricalfrom
copilot/sub-pr-21-again

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Feb 21, 2026

Two bugs in the newly added BroNet and Simba modules that would cause runtime errors or silently ignore a user-configurable parameter.

Fixes

  • bronet.py: nn.Dense(self.hidden_dims, ...) passed a Sequence[int] where an int is required. Fixed to nn.Dense(self.hidden_dims[0], ...) — safe because setup() already asserts all dims are equal.

  • simba.py: SimbaBlock hardcoded the expansion factor as hidden_dim * 4, making the multiplier field on both SimbaBlock and Simba a no-op. Fixed to hidden_dim * self.multiplier.

# Before — multiplier param had no effect
x = nn.Dense(self.hidden_dim * 4, kernel_init=init.he_normal())(x)

# After — honours the configurable multiplier
x = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(x)

✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.

Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>
Copilot AI changed the title [WIP] Add BroNet, SimBa and categorical critic to networks fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock Feb 21, 2026
Copilot AI requested a review from typoverflow February 21, 2026 06:48
@typoverflow typoverflow marked this pull request as ready for review February 21, 2026 07:09
@typoverflow typoverflow merged commit 33e5391 into dev_categorical Feb 21, 2026
typoverflow added a commit that referenced this pull request Feb 21, 2026
* feat: add BroNet, simba and categorical critic

* fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock (#23)

* Initial plan

* fix: correct Dense input dim in BroNet and use multiplier in SimbaBlock

Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants