It seems that the method `is_valid_transition` of `OutOfGraphReplayBuffer` does not check whether the stacked frames come from a previous, truncated trajectory. In that case the index should be invalid.
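It only checks whether:
- the stacked frames come from a previous terminated trajectory (line 462 of `dopamine/replay_memory/circular_replay_buffer.py` at commit ce36aab):
  ```python
  if self.get_terminal_stack(index)[:-1].any():
  ```
- the frames from `index` over the update horizon contain the end of a truncated trajectory (line 468, same file and commit):
  ```python
  if i in self.episode_end_indices and not self._store['terminal'][i]:
  ```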
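The asymmetry between these two checks can be reproduced with a small pure-Python model. This is a sketch, not Dopamine's code: `modulo_range` here is a list-returning stand-in and `passes_existing_checks` is a hypothetical name. It assumes that the stack at `index` spans frames `index - stack_size + 1` through `index`, which is how the `[:-1]` slice in the first check reads.

```python
def modulo_range(start, length, modulo):
    # List-returning stand-in: start, start+1, ... wrapped around the buffer.
    return [(start + i) % modulo for i in range(length)]


def passes_existing_checks(index, terminal, episode_end_indices,
                           stack_size, update_horizon):
    # Hypothetical model of the two checks quoted above.
    capacity = len(terminal)
    # Check 1: a terminal flag on a frame stacked before `index`.
    stack = modulo_range(index - stack_size + 1, stack_size, capacity)
    if any(terminal[i] for i in stack[:-1]):
        return False
    # Check 2: a truncated episode end at `index` or within the update horizon.
    for i in modulo_range(index, update_horizon, capacity):
        if i in episode_end_indices and not terminal[i]:
            return False
    return True


# Slots 0 and 2 are padding, slot 1 ends an episode, slot 3 starts the next one.
terminated = passes_existing_checks(2, [0, 1, 0, 0], {1},
                                    stack_size=2, update_horizon=1)
truncated = passes_existing_checks(2, [0, 0, 0, 0], {1},
                                   stack_size=2, update_horizon=1)
print(terminated, truncated)  # False True
```

With the same layout, index 2 is rejected when the episode terminated but accepted when it was truncated, although its stacked state mixes two episodes in both cases.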
Here is a minimal example where this is a problem:
```python
import numpy as np
from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer

replay_buffer = OutOfGraphReplayBuffer(
    observation_shape=(1,), stack_size=2, replay_capacity=10, batch_size=1)
# First episode: a single step, truncated (episode_end=True, not terminal).
replay_buffer.add(np.array([1]), 1, 1, False, episode_end=True)
# Second episode starts here.
replay_buffer.add(np.array([2]), 2, 2, False)
print(replay_buffer._store["observation"][:4])
print(replay_buffer.sample_transition_batch())
```

Output:

```
[[0], [1], [0], [2]]  # there is no valid index to sample.
(array([[[1, 0]]], dtype=uint8), array([0], dtype=int32), array([0.], dtype=float32), array([[[0, 2]]], dtype=uint8), array([2], dtype=int32), array([2.], dtype=float32), array([0], dtype=uint8), array([2], dtype=int32))
```
Here, index 2 is considered valid even though it is not: the state `array([[[1, 0]]])` combines an observation from the previous trajectory (`[1]`) with the zero padding that starts the new trajectory (`[0]`).
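To see how index 2 mixes two trajectories, the stacked states can be rebuilt from the reported observation store. This sketch assumes frames are stacked along the last axis, which matches the shape of the reported `array([[[1, 0]]])` (minus the batch dimension); `stack_at` is a hypothetical helper, not part of Dopamine, and only the four filled slots are modeled.

```python
import numpy as np


def modulo_range(start, length, modulo):
    # List-returning stand-in: start, start+1, ... wrapped around the buffer.
    return [(start + i) % modulo for i in range(length)]


def stack_at(observations, index, stack_size):
    # Hypothetical helper: the state at `index` is the last `stack_size`
    # frames ending at `index`, stacked along the last axis.
    idx = modulo_range(index - stack_size + 1, stack_size, len(observations))
    return np.stack([observations[i] for i in idx], axis=-1)


# Observation store from the example: padding, [1], padding, [2].
observations = np.array([[0], [1], [0], [2]], dtype=np.uint8)
state = stack_at(observations, 2, stack_size=2)
next_state = stack_at(observations, 3, stack_size=2)
print(state.tolist(), next_state.tolist())  # [[1, 0]] [[0, 2]]
```

The state at index 2 takes frame 1 (the last observation of the truncated episode) and frame 2 (padding of the new episode), matching the sampled batch above.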
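To fix this, the loop at line 467 of `dopamine/replay_memory/circular_replay_buffer.py` (commit ce36aab):

```python
for i in modulo_range(index, self._update_horizon, self._replay_capacity):
```

could be changed so that it starts at the first stacked frame and also covers the `stack_size - 1` frames before `index`:

```python
for i in modulo_range(index - self._stack_size + 1,
                      self._stack_size - 1 + self._update_horizon,
                      self._replay_capacity):
```

Extending the length, and not only shifting the start, keeps `index` itself in the window, so the last transition of a truncated episode stays invalid.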
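The effect of widening the window can be checked on the example layout with plain lists. In this sketch (hypothetical helpers, not Dopamine's API), slot 1 is the only episode end and no frame is terminal, as in the example; `True` means a truncated episode end falls inside the window, so the transition would be rejected. It compares the current window with one spanning from the first stacked frame through the update horizon.

```python
def modulo_range(start, length, modulo):
    # List-returning stand-in: start, start+1, ... wrapped around the buffer.
    return [(start + i) % modulo for i in range(length)]


def truncated_in_window(start, length, terminal, episode_end_indices):
    # True if any slot in the window ends a truncated (non-terminal) episode.
    capacity = len(terminal)
    return any(i in episode_end_indices and not terminal[i]
               for i in modulo_range(start, length, capacity))


stack_size, update_horizon = 2, 1
terminal = [0] * 10        # replay_capacity=10, no terminal flags
episode_end_indices = {1}  # slot 1 holds [1], added with episode_end=True

results = {}
for index in (1, 2):
    # Current loop: index .. index + update_horizon - 1.
    current = truncated_in_window(index, update_horizon,
                                  terminal, episode_end_indices)
    # Widened loop: first stacked frame .. index + update_horizon - 1.
    widened = truncated_in_window(index - stack_size + 1,
                                  stack_size - 1 + update_horizon,
                                  terminal, episode_end_indices)
    results[index] = (current, widened)

print(results)  # {1: (True, True), 2: (False, True)}
```

With the current window only index 1 is rejected; the widened window rejects index 2 as well, leaving no valid index, as the comment in the example output expects.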