Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions source/isaaclab_contrib/isaaclab_contrib/rl/rlinf/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,38 @@ def _register_gr00t_converters(cfg: dict) -> None:
Args:
cfg: The IsaacLab-specific configuration dictionary (``env.train.isaaclab``).
"""
from rlinf.models.embodiment.gr00t import simulation_io

obs_converter_type = cfg.get("obs_converter_type", "dex3")

if obs_converter_type not in simulation_io.OBS_CONVERSION:
simulation_io.OBS_CONVERSION[obs_converter_type] = _convert_isaaclab_obs_to_gr00t
logger.info(f"Registered obs converter: {obs_converter_type}")
simulation_modules = []
try:
from rlinf.models.embodiment.gr00t import simulation_io as gr00t_simulation_io

simulation_modules.append(("gr00t", gr00t_simulation_io))
except Exception as exc:
logger.debug(f"Could not import GR00T N1.5 simulation_io: {exc}")

try:
from rlinf.models.embodiment.gr00t_n1d6 import simulation_io as gr00t_n1d6_simulation_io

simulation_modules.append(("gr00t_n1d6", gr00t_n1d6_simulation_io))
except Exception as exc:
logger.debug(f"Could not import GR00T N1.6 simulation_io: {exc}")

try:
from rlinf.models.embodiment.gr00t_n1d7 import simulation_io as gr00t_n1d7_simulation_io

simulation_modules.append(("gr00t_n1d7", gr00t_n1d7_simulation_io))
except Exception as exc:
logger.debug(f"Could not import GR00T N1.7 simulation_io: {exc}")

if obs_converter_type not in simulation_io.ACTION_CONVERSION:
simulation_io.ACTION_CONVERSION[obs_converter_type] = _convert_gr00t_to_isaaclab_action
logger.info(f"Registered action converter: {obs_converter_type}")
for module_name, simulation_io in simulation_modules:
if obs_converter_type not in simulation_io.OBS_CONVERSION:
simulation_io.OBS_CONVERSION[obs_converter_type] = _convert_isaaclab_obs_to_gr00t
logger.info(f"Registered {module_name} obs converter: {obs_converter_type}")

if obs_converter_type not in simulation_io.ACTION_CONVERSION:
simulation_io.ACTION_CONVERSION[obs_converter_type] = _convert_gr00t_to_isaaclab_action
logger.info(f"Registered {module_name} action converter: {obs_converter_type}")


def _convert_isaaclab_obs_to_gr00t(env_obs: dict) -> dict:
Expand Down Expand Up @@ -338,10 +359,18 @@ def _convert_isaaclab_obs_to_gr00t(env_obs: dict) -> dict:
gr00t_key = spec.get("gr00t_key")
slice_range = spec.get("slice", [0, states_np.shape[-1]])
if gr00t_key:
groot_obs[gr00t_key] = states_np[:, :, slice_range[0] : slice_range[1]]

# Pass through task descriptions
groot_obs["annotation.human.action.task_description"] = env_obs.get("task_descriptions", [])
state_part = states_np[:, :, slice_range[0] : slice_range[1]]
if "scale" in spec:
state_part = state_part * np.asarray(spec["scale"], dtype=state_part.dtype)
if "offset" in spec:
state_part = state_part + np.asarray(spec["offset"], dtype=state_part.dtype)
groot_obs[gr00t_key] = state_part

# Pass through task descriptions. SO-101 N1.6 checkpoints use
# annotation.human.task_description, while older LIBERO-style configs use
# annotation.human.action.task_description.
language_key = gr00t_mapping.get("language_key", "annotation.human.action.task_description")
groot_obs[language_key] = env_obs.get("task_descriptions", [])

return groot_obs

Expand All @@ -367,8 +396,24 @@ def _convert_gr00t_to_isaaclab_action(action_chunk: dict, chunk_size: int = 1) -
prefix_pad = action_mapping.get("prefix_pad", 0)
suffix_pad = action_mapping.get("suffix_pad", 0)

# Concatenate all action parts
action_parts = [v[:, :chunk_size, :] for v in action_chunk.values()]
# Concatenate action parts in the configured order when provided.
action_keys = action_mapping.get("gr00t_action_keys") or list(action_chunk.keys())
action_parts = []
for key in action_keys:
if key in action_chunk:
action_parts.append(action_chunk[key][:, :chunk_size, :])
continue
short_key = key.split(".", 1)[1] if key.startswith("action.") else f"action.{key}"
if short_key in action_chunk:
action_parts.append(action_chunk[short_key][:, :chunk_size, :])
continue
logger.warning(
f"GR00T action key '{key}' (also tried '{short_key}') not found in action chunk "
f"(available: {list(action_chunk)}); this entry will be skipped and the action tensor "
"will be narrower than expected."
)
if not action_parts:
Comment thread
johnnynunez marked this conversation as resolved.
raise KeyError(f"No configured GR00T action keys found in action chunk: keys={list(action_chunk)}")
action_concat = np.concatenate(action_parts, axis=-1)

# Apply padding
Expand All @@ -379,6 +424,10 @@ def _convert_gr00t_to_isaaclab_action(action_chunk: dict, chunk_size: int = 1) -
mode="constant",
constant_values=0,
)
if "scale" in action_mapping:
action_concat = action_concat * np.asarray(action_mapping["scale"], dtype=action_concat.dtype)
if "offset" in action_mapping:
action_concat = action_concat + np.asarray(action_mapping["offset"], dtype=action_concat.dtype)
return action_concat


Expand Down