@contextlib.contextmanager
def local_seed(vmas_random_state):
    """Run the enclosed code under a private random state, isolated from the caller's.

    On entry the current global states of the torch, NumPy and Python ``random``
    generators are saved and replaced by the ones stored in ``vmas_random_state``.
    On exit the (possibly advanced) private states are written back into
    ``vmas_random_state`` and the caller's global states are restored.

    The restore step runs even when the enclosed code raises, so a failing call
    can no longer leave the process-wide generators set to the private state.

    Args:
        vmas_random_state: mutable sequence of three entries, updated in place:
            ``[0]`` a torch RNG state (as returned by ``torch.random.get_rng_state``),
            ``[1]`` a NumPy state (as returned by ``np.random.get_state``),
            ``[2]`` a Python state (as returned by ``random.getstate``).

    Yields:
        None. The private state is active for the duration of the ``with`` body
        (or of the decorated call when used as a decorator).
    """
    # NOTE(review): torch.random.get_rng_state is documented as covering the default
    # CPU generator only; CUDA generators are not swapped here - confirm whether
    # device-side sampling needs the same isolation.
    # NOTE(review): not re-entrant. A nested use with the same list would re-install
    # the state that was current when the outer context was entered.
    caller_torch_state = torch.random.get_rng_state()
    caller_np_state = np.random.get_state()
    caller_py_state = random.getstate()

    try:
        torch.random.set_rng_state(vmas_random_state[0])
        np.random.set_state(vmas_random_state[1])
        random.setstate(vmas_random_state[2])
        try:
            yield
        finally:
            # Persist how far the private generators advanced, also on error, so
            # the draws consumed by a failed call are not replayed by the next one.
            # This only runs once all three private states were installed.
            vmas_random_state[0] = torch.random.get_rng_state()
            vmas_random_state[1] = np.random.get_state()
            vmas_random_state[2] = random.getstate()
    finally:
        # Always hand the global generators back to the caller untouched.
        torch.random.set_rng_state(caller_torch_state)
        np.random.set_state(caller_np_state)
        random.setstate(caller_py_state)
def test_seeding():
    """Check that environment seeding is reproducible and isolated from global RNGs.

    Three properties are exercised:

    1. Re-seeding the environment with the same seed reproduces the same reset
       observation.
    2. Re-seeding the global torch generator between ``env.seed`` and ``env.reset``
       does not change that observation.
    3. Seeding and resetting the environment does not advance or alter the global
       torch, NumPy or Python ``random`` generators.
    """
    # Function-scope imports so the test does not depend on the module's import block.
    import random

    import numpy as np

    env = make_env(scenario="balance", num_envs=2, seed=0)

    # 1. Same environment seed -> same first observation entry after reset.
    env.seed(0)
    random_obs = env.reset()[0][0, 0]
    env.seed(0)
    assert random_obs == env.reset()[0][0, 0]

    # 2. The global torch seed must not leak into the environment's sampling.
    env.seed(0)
    torch.manual_seed(1)
    assert random_obs == env.reset()[0][0, 0]

    # 3. Record what each global generator yields right after being seeded with 0 ...
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    random_obs = torch.randn(1)
    expected_np = np.random.rand()
    expected_py = random.random()

    # ... then re-seed, use the environment, and require the very same draws.
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    env.seed(1)
    env.reset()
    assert random_obs == torch.randn(1)
    assert expected_np == np.random.rand()
    assert expected_py == random.random()
import contextlib -import functools import math import random from ctypes import byref @@ -47,22 +46,6 @@ def local_seed(vmas_random_state): random.setstate(py_state) -def apply_local_seed(cls): - """Applies the local seed to all the functions.""" - for attr_name, attr_value in cls.__dict__.items(): - if callable(attr_value): - wrapped = attr_value # Keep reference to original method - - @functools.wraps(wrapped) - def wrapper(self, *args, _wrapped=wrapped, **kwargs): - with local_seed(cls.vmas_random_state): - return _wrapped(self, *args, **kwargs) - - setattr(cls, attr_name, wrapper) - return cls - - -@apply_local_seed class Environment(TorchVectorizedObject): metadata = { "render.modes": ["human", "rgb_array"], @@ -74,6 +57,7 @@ class Environment(TorchVectorizedObject): random.getstate(), ] + @local_seed(vmas_random_state) def __init__( self, scenario: BaseScenario, @@ -108,7 +92,7 @@ def __init__( self.grad_enabled = grad_enabled self.terminated_truncated = terminated_truncated - observations = self.reset(seed=seed) + observations = self._reset(seed=seed) # configure spaces self.multidiscrete_actions = multidiscrete_actions @@ -121,6 +105,7 @@ def __init__( self.visible_display = None self.text_lines = None + @local_seed(vmas_random_state) def reset( self, seed: Optional[int] = None, @@ -132,13 +117,104 @@ def reset( Resets the environment in a vectorized way Returns observations for all envs and agents """ + return self._reset( + seed=seed, + return_observations=return_observations, + return_info=return_info, + return_dones=return_dones, + ) + + @local_seed(vmas_random_state) + def reset_at( + self, + index: int, + return_observations: bool = True, + return_info: bool = False, + return_dones: bool = False, + ): + """ + Resets the environment at index + Returns observations for all agents in that environment + """ + return self._reset_at( + index=index, + return_observations=return_observations, + return_info=return_info, + return_dones=return_dones, + ) + + 
@local_seed(vmas_random_state) + def get_from_scenario( + self, + get_observations: bool, + get_rewards: bool, + get_infos: bool, + get_dones: bool, + dict_agent_names: Optional[bool] = None, + ): + """ + Get the environment data from the scenario + + Args: + get_observations (bool): whether to return the observations + get_rewards (bool): whether to return the rewards + get_infos (bool): whether to return the infos + get_dones (bool): whether to return the dones + dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys + or in a list + + Returns: + The agents' data + + """ + return self._get_from_scenario( + get_observations=get_observations, + get_rewards=get_rewards, + get_infos=get_infos, + get_dones=get_dones, + dict_agent_names=dict_agent_names, + ) + + @local_seed(vmas_random_state) + def seed(self, seed=None): + """ + Sets the seed for the environment + Args: + seed (int, optional): Seed for the environment. Defaults to None. + + """ + return self._seed(seed=seed) + + @local_seed(vmas_random_state) + def done(self): + """ + Get the done flags for the scenario. 
+ + Returns: + Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False) + + """ + return self._done() + + def _reset( + self, + seed: Optional[int] = None, + return_observations: bool = True, + return_info: bool = False, + return_dones: bool = False, + ): + """ + Resets the environment in a vectorized way + Returns observations for all envs and agents + """ + if seed is not None: - self.seed(seed) + self._seed(seed) # reset world self.scenario.env_reset_world_at(env_index=None) self.steps = torch.zeros(self.num_envs, device=self.device) - result = self.get_from_scenario( + result = self._get_from_scenario( get_observations=return_observations, get_infos=return_info, get_rewards=False, @@ -146,7 +222,7 @@ def reset( ) return result[0] if result and len(result) == 1 else result - def reset_at( + def _reset_at( self, index: int, return_observations: bool = True, @@ -161,7 +237,7 @@ def reset_at( self.scenario.env_reset_world_at(index) self.steps[index] = 0 - result = self.get_from_scenario( + result = self._get_from_scenario( get_observations=return_observations, get_infos=return_info, get_rewards=False, @@ -170,7 +246,7 @@ def reset_at( return result[0] if result and len(result) == 1 else result - def get_from_scenario( + def _get_from_scenario( self, get_observations: bool, get_rewards: bool, @@ -218,16 +294,22 @@ def get_from_scenario( if self.terminated_truncated: if get_dones: - terminated, truncated = self.done() + terminated, truncated = self._done() result = [obs, rewards, terminated, truncated, infos] else: if get_dones: - dones = self.done() + dones = self._done() result = [obs, rewards, dones, infos] return [data for data in result if data is not None] - def seed(self, seed=None): + def _seed(self, seed=None): + """ + Sets the seed for the environment + Args: + seed (int, optional): Seed for the environment. Defaults to None. 
+ + """ if seed is None: seed = 0 torch.manual_seed(seed) @@ -235,6 +317,7 @@ def seed(self, seed=None): random.seed(seed) return [seed] + @local_seed(vmas_random_state) def step(self, actions: Union[List, Dict]): """Performs a vectorized step on all sub environments using `actions`. Args: @@ -309,14 +392,21 @@ def step(self, actions: Union[List, Dict]): self.steps += 1 - return self.get_from_scenario( + return self._get_from_scenario( get_observations=True, get_infos=True, get_rewards=True, get_dones=True, ) - def done(self): + def _done(self): + """ + Get the done flags for the scenario. + + Returns: + Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False) + + """ terminated = self.scenario.done().clone() if self.max_steps is not None: @@ -427,6 +517,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE): f"Invalid type of observation {obs} for agent {agent.name}" ) + @local_seed(vmas_random_state) def get_random_action(self, agent: Agent) -> torch.Tensor: """Returns a random action for the given agent. @@ -652,6 +743,7 @@ def _set_action(self, action, agent): ) agent.action.c += noise + @local_seed(vmas_random_state) def render( self, mode="human", From a8dd1260f3d38a2556c0a923882eda2ac21d202c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 28 Mar 2025 14:46:24 +0000 Subject: [PATCH 3/4] local seed --- tests/test_vmas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vmas.py b/tests/test_vmas.py index ab4cd4db..ca033782 100644 --- a/tests/test_vmas.py +++ b/tests/test_vmas.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025. +# Copyright (c) 2022-2024. # ProrokLab (https://www.proroklab.org/) # All rights reserved. 
import math From 6e76e55b6ba041cb547f6697338953b6a3ae5ea3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 28 Mar 2025 16:24:05 +0000 Subject: [PATCH 4/4] local seed --- vmas/simulator/environment/environment.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index 15c755c5..a061c914 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -47,6 +47,10 @@ def local_seed(vmas_random_state): class Environment(TorchVectorizedObject): + """ + The VMAS environment + """ + metadata = { "render.modes": ["human", "rgb_array"], "runtime.vectorized": True, @@ -320,16 +324,15 @@ def _seed(self, seed=None): @local_seed(vmas_random_state) def step(self, actions: Union[List, Dict]): """Performs a vectorized step on all sub environments using `actions`. + Args: - actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape - '(self.num_envs, action_size_of_agent)'. + actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, action_size_of_agent)'. 
+ Returns: - obs: List on len 'self.n_agents' of which each element is a torch.Tensor - of shape '(self.num_envs, obs_size_of_agent)' + obs: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, obs_size_of_agent)' rewards: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs)' dones: Tensor of len 'self.num_envs' of which each element is a bool - infos : List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric - and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)' + infos: List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)' Examples: >>> import vmas @@ -345,6 +348,7 @@ def step(self, actions: Union[List, Dict]): >>> obs = env.reset() >>> for _ in range(10): ... obs, rews, dones, info = env.step(env.get_random_actions()) + """ if isinstance(actions, Dict): actions_dict = actions @@ -578,7 +582,7 @@ def get_random_action(self, agent: Agent) -> torch.Tensor: return action def get_random_actions(self) -> Sequence[torch.Tensor]: - """Returns random actions for all agents that you can feed to :class:`step` + """Returns random actions for all agents that you can feed to :meth:`step` Returns: Sequence[torch.tensor]: the random actions for the agents @@ -767,6 +771,7 @@ def render( Render function for environment using pyglet On servers use mode="rgb_array" and set + ``` export DISPLAY=':99.0' Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 & @@ -774,8 +779,7 @@ def render( :param mode: One of human or rgb_array :param env_index: Index of the environment to render - :param agent_index_focus: If specified the camera will stay on the agent with this index. 
- If None, the camera will stay in the center and zoom out to contain all agents + :param agent_index_focus: If specified the camera will stay on the agent with this index. If None, the camera will stay in the center and zoom out to contain all agents :param visualize_when_rgb: Also run human visualization when mode=="rgb_array" :param plot_position_function: A function to plot under the rendering. The function takes a numpy array with shape (n_points, 2), which represents a set of x,y values to evaluate f over and plot it @@ -789,6 +793,7 @@ def render( :param plot_position_function_cmap_range: The range of the cmap in case plot_position_function outputs a single value :param plot_position_function_cmap_alpha: The alpha of the cmap in case plot_position_function outputs a single value :return: Rgb array or None, depending on the mode + """ self._check_batch_index(env_index) assert (