diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 3555322c..5cee8b63 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -77,6 +77,28 @@ def convert_observation_to_space( array = np.array(observation) dtype = array.dtype space = spaces.Box(-np.inf, np.inf, shape=array.shape, dtype=dtype) + elif isinstance(observation, torch.Tensor): + if unbatched: + shape = observation.shape[1:] + else: + shape = observation.shape + dtype = observation.dtype + # Map torch dtype to numpy dtype and reuse get_dtype_bounds for consistency + torch_to_numpy_dtype = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + } + np_dtype = torch_to_numpy_dtype.get(dtype, np.float32) + low, high = get_dtype_bounds(np_dtype) + if np.issubdtype(np_dtype, np.floating): + low, high = -np.inf, np.inf + space = spaces.Box(low, high, shape=shape, dtype=np_dtype) elif isinstance(observation, np.ndarray): if unbatched: shape = observation.shape[1:]