From aaf5d0e107d0db792e235bbd70b0dd1333049819 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Mon, 9 Feb 2026 21:44:49 +0800 Subject: [PATCH 1/2] wip --- embodichain/lab/gym/utils/gym_utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 3555322c..227140f9 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -77,6 +77,27 @@ def convert_observation_to_space( array = np.array(observation) dtype = array.dtype space = spaces.Box(-np.inf, np.inf, shape=array.shape, dtype=dtype) + elif isinstance(observation, torch.Tensor): + if unbatched: + shape = observation.shape[1:] + else: + shape = observation.shape + dtype = observation.dtype + # Convert torch dtype to numpy dtype + if dtype in (torch.float32, torch.float64, torch.float16): + low, high = -np.inf, np.inf + np_dtype = np.float32 + elif dtype in (torch.int32, torch.int64, torch.int16, torch.int8): + info = np.iinfo(np.int32) + low, high = info.min, info.max + np_dtype = np.int32 + elif dtype == torch.bool: + low, high = 0, 1 + np_dtype = np.bool_ + else: + low, high = -np.inf, np.inf + np_dtype = np.float32 + space = spaces.Box(low, high, shape=shape, dtype=np_dtype) elif isinstance(observation, np.ndarray): if unbatched: shape = observation.shape[1:] From f3b3170c6798f2b8f1996a9938f0bb70bb9c5658 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Mon, 9 Feb 2026 22:03:44 +0800 Subject: [PATCH 2/2] wip --- embodichain/lab/gym/utils/gym_utils.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 227140f9..5cee8b63 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -83,20 +83,21 @@ def convert_observation_to_space( else: shape = observation.shape dtype = observation.dtype - # Convert torch dtype to numpy dtype - if dtype in (torch.float32, torch.float64, torch.float16): + # Map torch dtype to numpy dtype and reuse get_dtype_bounds for consistency + torch_to_numpy_dtype = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + } + np_dtype = torch_to_numpy_dtype.get(dtype, np.float32) + low, high = get_dtype_bounds(np_dtype) + if np.issubdtype(np_dtype, np.floating): low, high = -np.inf, np.inf - np_dtype = np.float32 - elif dtype in (torch.int32, torch.int64, torch.int16, torch.int8): - info = np.iinfo(np.int32) - low, high = info.min, info.max - np_dtype = np.int32 - elif dtype == torch.bool: - low, high = 0, 1 - np_dtype = np.bool_ - else: - low, high = -np.inf, np.inf - np_dtype = np.float32 space = spaces.Box(low, high, shape=shape, dtype=np_dtype) elif isinstance(observation, np.ndarray): if unbatched: