Skip to content

stat_tracking: fix AttributeError on second update() call caused by in-place np.stack overwrite#3

Open
Dev-X25874 wants to merge 1 commit into
radixark:mainfrom
Dev-X25874:fix/stat-tracker-extend-crash
Open

stat_tracking: fix AttributeError on second update() call caused by in-place np.stack overwrite#3
Dev-X25874 wants to merge 1 commit into
radixark:mainfrom
Dev-X25874:fix/stat-tracker-extend-crash

Conversation

@Dev-X25874
Copy link
Copy Markdown

What

PerPromptStatTracker.update() crashes with AttributeError: 'numpy.ndarray' object has no attribute 'extend' on any call after the first when the same
prompt key is seen again.

Root Cause

The second for loop in update() overwrites self.stats[prompt] with the
result of np.stack(...), converting the stored value from a list to a
numpy.ndarray. On the next call, the first for loop tries to call
.extend() on that ndarray, which does not exist on that type.

# before — mutates self.stats[prompt] to ndarray
self.stats[prompt] = np.stack(self.stats[prompt])

Fix

Compute the stacked array into a local variable prompt_history instead of
overwriting self.stats[prompt], so the list is preserved across calls.

# after — self.stats[prompt] stays a list
prompt_history = np.array(self.stats[prompt])

All downstream mean/std calculations use prompt_history; no values change.

Impact

  • Affects all training scripts that rely on cross-call reward history
    accumulation (train_sd3.py, train_flux.py, train_flux_fast.py, etc.)
  • The built-in main() smoke test only calls update() once before clear(),
    so the bug was not caught locally
  • One-line change, zero behaviour delta

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant