From 5ac2677b642ddebdc99409c652c74381d1329610 Mon Sep 17 00:00:00 2001
From: Guillermo Etchebarne
Date: Tue, 12 May 2026 11:12:28 -0300
Subject: [PATCH] distributed/rpc/rl: compute REINFORCE returns per observer

---
 distributed/rpc/rl/main.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/distributed/rpc/rl/main.py b/distributed/rpc/rl/main.py
index 91451ecc84..c010e19e5b 100644
--- a/distributed/rpc/rl/main.py
+++ b/distributed/rpc/rl/main.py
@@ -177,11 +177,18 @@ def finish_episode(self):
         as the reward of the current episode.
         """
 
-        # joins probs and rewards from different observers into lists
-        R, probs, rewards = 0, [], []
+        # joins probs and per-observer discounted returns into flat lists;
+        # returns are computed per observer so trajectories from different
+        # observers don't bleed into each other through the reverse cumulative sum
+        probs, returns = [], []
         for ob_id in self.rewards:
+            R = 0
+            ob_returns = []
+            for r in self.rewards[ob_id][::-1]:
+                R = r + args.gamma * R
+                ob_returns.insert(0, R)
+            returns.extend(ob_returns)
             probs.extend(self.saved_log_probs[ob_id])
-            rewards.extend(self.rewards[ob_id])
 
         # use the minimum observer reward to calculate the running reward
         min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
@@ -192,10 +199,7 @@ def finish_episode(self):
             self.rewards[ob_id] = []
             self.saved_log_probs[ob_id] = []
 
-        policy_loss, returns = [], []
-        for r in rewards[::-1]:
-            R = r + args.gamma * R
-            returns.insert(0, R)
+        policy_loss = []
         returns = torch.tensor(returns)
         returns = (returns - returns.mean()) / (returns.std() + self.eps)
         for log_prob, R in zip(probs, returns):