From 630716f92d31cc84c651c110c549b80f5385409d Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Sun, 22 Feb 2026 22:23:13 +0800 Subject: [PATCH 01/11] feat: add Categorical BDPO --- .../config/d4rl/algo/categorical_bdpo.yaml | 63 ++ examples/offline/main_d4rl.py | 1 + flowrl/agent/offline/__init__.py | 1 + flowrl/agent/offline/bdpo/categorical_bdpo.py | 701 ++++++++++++++++++ flowrl/config/offline/__init__.py | 2 + .../config/offline/algo/categorical_bdpo.py | 64 ++ flowrl/module/critic.py | 31 + 7 files changed, 863 insertions(+) create mode 100644 examples/offline/config/d4rl/algo/categorical_bdpo.yaml create mode 100644 flowrl/agent/offline/bdpo/categorical_bdpo.py create mode 100644 flowrl/config/offline/algo/categorical_bdpo.py diff --git a/examples/offline/config/d4rl/algo/categorical_bdpo.yaml b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml new file mode 100644 index 0000000..c70fe62 --- /dev/null +++ b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml @@ -0,0 +1,63 @@ +# @package _global_ + +data: + norm_obs: false + norm_reward: iql_mujoco + +eval: + num_samples: 10 + interval: 50000 + +train_steps: 2_000_000 +pretrain_steps: 2_000_000 + +algo: + name: categorical_bdpo + warmup_steps: 50_000 + temperature: 0.0 + diffusion: + actor: + lr: 0.00001 + lr_decay_steps: ${train_steps} + clip_grad_norm: 1.0 + ema: 0.005 + ema_every: 1 + behavior: + lr: 0.0003 + lr_decay_steps: null + clip_grad_norm: 1.0 + ema: 0.005 + ema_every: 1 + noise_schedule: vp + resnet: false + dropout: 0.0 + layer_norm: false + time_dim: 64 + mlp_hidden_dims: [256, 256, 256, 256] + resnet_hidden_dims: [256, 256, 256] + solver: ddpm + steps: 5 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + critic: + discount: 0.99 + q_target: lcb + maxQ: false + ensemble_size: 10 + rho: 0.0 + eta: 1.0 + hidden_dims: [256, 256, 256] + output_nodes: 101 + v_min: -10.0 + v_max: 10.0 + lr: 0.0003 + lr_decay_steps: ${train_steps} + clip_grad_norm: 1.0 + layer_norm: false + ema: 0.005 + 
ema_every: 1 + steps: 5 + num_samples: 10 + solver: ddpm + update_ratio: 5 diff --git a/examples/offline/main_d4rl.py b/examples/offline/main_d4rl.py index 90d6fc4..58bb580 100644 --- a/examples/offline/main_d4rl.py +++ b/examples/offline/main_d4rl.py @@ -22,6 +22,7 @@ SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = { "iql": IQLAgent, "bdpo": BDPOAgent, + "categorical_bdpo": CategoricalBDPOAgent, "ivr": IVRAgent, "fql": FQLAgent, "dac": DACAgent, diff --git a/flowrl/agent/offline/__init__.py b/flowrl/agent/offline/__init__.py index 2493492..78168ba 100644 --- a/flowrl/agent/offline/__init__.py +++ b/flowrl/agent/offline/__init__.py @@ -1,5 +1,6 @@ from ..base import BaseAgent from .bdpo.bdpo import BDPOAgent +from .bdpo.categorical_bdpo import CategoricalBDPOAgent from .dac import DACAgent from .dql import DQLAgent from .dtql import DTQLAgent diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py new file mode 100644 index 0000000..afe6d24 --- /dev/null +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -0,0 +1,701 @@ +from functools import partial +from typing import List + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.config.offline.algo.categorical_bdpo import ( + CategoricalBDPOConfig, + BDPODiffusionConfig, + BDPODiffusionTrainConfig, +) +from flowrl.flow.ddpm import DDPM, DDPMBackbone, jit_update_ddpm +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.functional.math import symlog, symexp +from flowrl.module.critic import CategoricalCriticWithDiscreteTime, Ensemblize, ScalarCritic +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import PositionalEmbedding +from flowrl.types import * + +EPS = 1e-6 + +def get_target(rng: PRNGKey, vs: jnp.ndarray, maxQ: bool, q_target: str, rho: float) -> 
Tuple[PRNGKey, jnp.ndarray]: + # vs is of shape (E, B, N, 1) + if maxQ: + vs = vs.max(axis=-2) + else: + vs = vs.mean(axis=-2) + if q_target == "min": + vs = vs.min(axis=0) + elif q_target == "convex": + vs = rho * vs.min(axis=0) + (1-rho) * vs.max(axis=0) + elif q_target == "lcb": + vs = vs.mean(axis=0) - rho * vs.std(axis=0) + elif q_target == "rand_convex": + rng, alpha_key = jax.random.split(rng) + alphas = jax.random.uniform(alpha_key, vs.shape) + alphas /= (alphas.sum(axis=0, keepdims=True) + EPS) + vs = (vs * alphas).sum(axis=0) + else: + raise NotImplementedError(f"Unrecognized Q-target type: {q_target}. ") + return rng, vs + +def get_penalty( + actor_eps, + behavior_eps, + t: jnp.ndarray, + T: int, + alphas: jnp.ndarray, + alpha_hats: jnp.ndarray, + betas: jnp.ndarray +) -> jnp.ndarray: + return 0.5 * betas[t] * ((actor_eps - behavior_eps)**2) / (1 - betas[t]) / (1 - alpha_hats[t]) + +def get_atoms(v_min: float, v_max: float, num_atoms: int) -> jnp.ndarray: + """Get support atoms (pre-transformation) for categorical distribution.""" + return jnp.linspace(v_min, v_max, num_atoms) + +def logits_to_value( + logits: jnp.ndarray, + v_min: float, + v_max: float, + num_atoms: int +) -> jnp.ndarray: + """Expected value (post-transformation) from categorical logits. + + Args: + logits: raw logits, shape (E, B, N, num_atoms) + v_min: minimum support value + v_max: maximum support value + num_atoms: number of atoms in the categorical distribution + Returns: + expected value, shape (B, 1) + """ + # TODO: check if current order is correct: + # 1. softmax over atoms + # 2. mean over ensemble members + # 3. compute expected value + # 4. symexp + # 5. mean over num_q_samples dimension + # Steps 1, 2, and 3 are in the correct relative order, consistent with BRC. + # However, the placement of steps 4 and 5 is open for discussion; + # we might consider swapping their order or even moving them before step 3. + # This requires further deliberation. 
+ E, B, N, _ = logits.shape # assertion for the shape of logits, will raise an error if the shape is not as expected + atoms = get_atoms(v_min, v_max, num_atoms) # (num_atoms,) + probs = jax.nn.softmax(logits, axis=-1).mean(axis=0) # (B, N, num_atoms) + value = (probs * atoms).sum(axis=-1, keepdims=True) # (B, N, 1) + value = symexp(value) # (B, N, 1) + value = value.mean(axis=1) # (B, 1) + return value + +def categorical_project( + target_values: jnp.ndarray, + v_min: float, + v_max: float, + num_atoms: int +) -> Tuple[jnp.ndarray, Metric]: + """C51-style projection: scalar targets (original-space) → categorical distribution. + + Args: + target_values: shape (*B, 1) + v_min: minimum support value + v_max: maximum support value + num_atoms: number of atoms in the categorical distribution + Returns: + target distribution, shape (*B, num_atoms) + metrics dict with pre-clip statistics (keys: pre_clip_max, pre_clip_min, + pre_clip_mean, clip_ratio; all computed in the transformed/symlog space) + """ + delta_z = (v_max - v_min) / (num_atoms - 1) + + symlog_values = symlog(target_values) + clip_ratio = ((symlog_values < v_min) | (symlog_values > v_max)).mean() + metrics = { + "pre_clip_max": symlog_values.max(), + "pre_clip_min": symlog_values.min(), + "pre_clip_mean": symlog_values.mean(), + "clip_ratio": clip_ratio, + } + + clipped = jnp.clip(symlog_values, v_min, v_max) + b = (clipped - v_min) / delta_z # (*B, 1) + l = jnp.floor(b).astype(jnp.int32) # (*B, 1) + u = jnp.ceil(b).astype(jnp.int32) # (*B, 1) + + d_u = (b - l.astype(jnp.float32)) # (*B) + d_l = 1.0 - d_u # (*B) + + probs = (d_l * jax.nn.one_hot(l.squeeze(-1), num_atoms) + + d_u * jax.nn.one_hot(u.squeeze(-1), num_atoms)) + return probs, metrics + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver", "temperature")) +def jit_sample_and_select( + rng: PRNGKey, + model: DDPM, + q0: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, + temperature: float, +) -> 
Tuple[PRNGKey, jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], model.x_dim)) + rng, actions, _ = model.sample(rng, xT, obs_repeat, training, solver) + if temperature is None: + return rng, actions[:, 0] + else: + qs = q0( + obs_repeat, + actions, + ) + qs = qs.mean(axis=0).reshape(B, num_samples) + if temperature <= 0.0: + idx = qs.argmax(axis=-1) + else: + rng, select_rng = jax.random.split(rng) + idx = jax.random.categorical(select_rng, logits=qs/(1e-6+temperature), axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), idx] + return rng, actions + +@partial(jax.jit, static_argnames=("ema", "do_ema_update")) +def jit_update_behavior( + rng: PRNGKey, + behavior: DDPM, + behavior_target: DDPM, + batch: Batch, + ema: float, + do_ema_update: bool, +) -> Tuple[PRNGKey, Model, Model, Metric]: + rng, new_behavior, metrics = jit_update_ddpm(rng, behavior, batch.action, batch.obs) + if do_ema_update: + new_behavior_target = ema_update(new_behavior, behavior_target, ema) + else: + new_behavior_target = behavior_target + return rng, new_behavior, new_behavior_target, metrics + +def update_critic( + rng: PRNGKey, + q0: Model, + q0_target: Model, + vt: Model, + vt_target: Model, + actor_target: DDPM, + behavior_target: DDPM, + batch: Batch, + v_min: float, + v_max: float, + num_atoms: int, + T: int, + discount: float, + eta: float, + rho: float, + num_q_samples: int, + q_target: str, + maxQ: bool, + solver: str, + ema: float, + do_ema_update: bool, +) -> Tuple[PRNGKey, Model, Model, Model, Model, Metric]: + B = batch.obs.shape[0] + A = batch.action.shape[-1] + alphas = actor_target.alphas + alpha_hats = actor_target.alpha_hats + betas = actor_target.betas + + rng, xT_rng = jax.random.split(rng) + + # q0 target + next_obs_repeat = batch.next_obs[..., jnp.newaxis, 
:].repeat(num_q_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*next_obs_repeat.shape[:-1], A)) + rng, next_action, _ = actor_target.sample( + rng, + xT, + next_obs_repeat, + training=False, + solver=solver, + ) + q0_target_value = q0_target( + next_obs_repeat, + next_action + ) + rng, q0_target_value = get_target(rng, q0_target_value, maxQ, q_target, rho) + + q0_target_value = batch.reward + discount * (1-batch.terminal) * q0_target_value # (B, 1) + + # vt target + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2) + rng, rep_xt_1, xt, rep_t_1, t, history = actor_target.onestep_sample( + rng, + batch.action, + batch.obs, + training=False, + num_samples=num_q_samples, + solver=solver, + sample_xt=True, + t=None, + ) + vt_target_value1 = vt_target( + obs_repeat, + rep_xt_1, + rep_t_1 + ) # (E, B, N, num_atoms) + vt_target_value1 = logits_to_value(vt_target_value1, v_min, v_max, num_atoms) # (B, 1) + vt_target_value2 = q0_target( + obs_repeat, + rep_xt_1 + ) + rng, vt_target_value2 = get_target(rng, vt_target_value2, False, q_target, rho) + vt_target_value = (t != 1) * vt_target_value1 + (t == 1) * vt_target_value2 + + vt_actor_eps = history + vt_behavior_eps = behavior_target(xt, t, batch.obs) + vt_penalty = get_penalty(vt_actor_eps, vt_behavior_eps, t, T, alphas, alpha_hats, betas) + + vt_target_value = vt_target_value - eta * vt_penalty.sum(axis=-1, keepdims=True) # (B, 1) + vt_target_probs, vt_proj_metrics = categorical_project(vt_target_value, v_min, v_max, num_atoms) # (B, num_atoms) + + def q0_loss_fn(q0_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: + pred = q0.apply( + {"params": q0_params}, + batch.obs, + batch.action + ) # (E, B, 1) + loss = ((pred - q0_target_value[jnp.newaxis, :])**2).mean() + return loss, { + "loss/q0_loss": loss, + "misc/q0_mean": pred.mean(), + "misc/reward": batch.reward.mean(), + } + def vt_loss_fn(vt_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: + pred = vt.apply( + 
{"params": vt_params}, + batch.obs, + xt, + t + ) # (E, B, num_atoms) + # sum over atoms (cross-entropy loss) → mean over batch → mean over ensemble members → scalar + loss = -((pred * vt_target_probs[jnp.newaxis, :])**2).sum(axis=-1).mean() + E, B, _ = pred.shape + vt_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms) # (B, 1) + return loss, { + "loss/vt_loss": loss, + "misc/vt_mean": vt_value.mean(), + "misc/vt_logits_std": jax.nn.softmax(pred, axis=-1).std(axis=0).mean(), + "misc/vt_penalty": vt_penalty.mean(), + "misc/vt_target_symlog_max": vt_proj_metrics["pre_clip_max"], + "misc/vt_target_symlog_min": vt_proj_metrics["pre_clip_min"], + "misc/vt_target_symlog_mean": vt_proj_metrics["pre_clip_mean"], + "misc/vt_target_clip_ratio": vt_proj_metrics["clip_ratio"], + } + new_q0, q0_metrics = q0.apply_gradient(q0_loss_fn) + new_vt, vt_metrics = vt.apply_gradient(vt_loss_fn) + if do_ema_update: + new_q0_target = ema_update(new_q0, q0_target, ema) + new_vt_target = ema_update(new_vt, vt_target, ema) + else: + new_q0_target = q0_target + new_vt_target = vt_target + return rng, new_q0, new_q0_target, new_vt, new_vt_target, { + **q0_metrics, + **vt_metrics + } + +def update_actor( + rng: PRNGKey, + actor: DDPM, + actor_target: DDPM, + behavior_target: DDPM, + q0_target: Model, + vt_target: Model, + batch: Batch, + v_min: float, + v_max: float, + num_atoms: int, + T: int, + eta, + rho: float, + num_q_samples: int, + q_target: str, + solver: str, + ema: float, + do_ema_update: bool, +) -> Tuple[PRNGKey, Model, Model, Metric]: + alphas = actor.alphas + alpha_hats = actor.alpha_hats + betas = actor.betas + + rng, sample_rng = jax.random.split(rng) + + def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: + sample_rng_, rep_xt_1, xt, rep_t_1, t, history = actor.onestep_sample( + sample_rng, + batch.action, + batch.obs, + training=True, + num_samples=num_q_samples, + solver=solver, + sample_xt=True, + t=None, + 
params=actor_params + ) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2) + target_1 = vt_target( + obs_repeat, + rep_xt_1, + rep_t_1 + ) # (E, B, N, num_atoms) + target_1 = logits_to_value(target_1, v_min, v_max, num_atoms) # (B, 1) + target_2 = q0_target( + obs_repeat, + rep_xt_1 + ) + sample_rng_, target_2 = get_target(sample_rng_, target_2, False, q_target, rho=rho) + target = (t != 1) * target_1 + (t == 1) * target_2 + + actor_eps = history + behavior_eps = behavior_target(xt, t, batch.obs) + penalty = get_penalty(actor_eps, behavior_eps, t, T, alphas, alpha_hats, betas) + target = target - eta * penalty.sum(axis=-1, keepdims=True) + actor_loss = - target.mean() + return actor_loss, { + "loss/actor_loss": actor_loss + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + if do_ema_update: + new_actor_target = ema_update(new_actor, actor_target, ema) + else: + new_actor_target = actor_target + metrics["misc/eta"] = eta + return rng, new_actor, new_actor_target, metrics + +@partial(jax.jit, static_argnames=( + "do_actor_update", + "v_min", + "v_max", + "num_atoms", + "diffusion_steps", + "discount", + "eta", + "rho", + "num_q_samples", + "q_target", + "maxQ", + "solver", + "critic_ema", + "do_critic_ema_update", + "actor_ema", + "do_actor_ema_update", +)) +def jit_train_step( + rng: PRNGKey, + # models + q0: Model, + q0_target: Model, + vt: Model, + vt_target: Model, + actor: DDPM, + actor_target: DDPM, + behavior_target: DDPM, + batch: Batch, + do_actor_update: bool, + # categorical parameters + v_min: float, + v_max: float, + num_atoms: int, + # other hyperparameters + diffusion_steps: int, + discount: float, + eta: float, + rho: float, + num_q_samples: int, + q_target: str, + maxQ: bool, + solver: str, + # ema update + critic_ema: float, + do_critic_ema_update: bool, + actor_ema: float, + do_actor_ema_update: bool, +): + metrics = {} + rng, q0, q0_target, vt, vt_target, critic_metrics = update_critic( + rng=rng, + q0=q0, 
+ q0_target=q0_target, + vt=vt, + vt_target=vt_target, + actor_target=actor_target, + behavior_target=behavior_target, + batch=batch, + v_min=v_min, + v_max=v_max, + num_atoms=num_atoms, + T=diffusion_steps, + discount=discount, + eta=eta, + rho=rho, + num_q_samples=num_q_samples, + q_target=q_target, + maxQ=maxQ, + solver=solver, + ema=critic_ema, + do_ema_update=do_critic_ema_update, + ) + metrics.update(critic_metrics) + if do_actor_update: + rng, actor, actor_target, actor_metrics = update_actor( + rng=rng, + actor=actor, + actor_target=actor_target, + behavior_target=behavior_target, + q0_target=q0_target, + vt_target=vt_target, + batch=batch, + v_min=v_min, + v_max=v_max, + num_atoms=num_atoms, + T=diffusion_steps, + eta=eta, + rho=rho, + num_q_samples=num_q_samples, + q_target=q_target, + solver=solver, + ema=actor_ema, + do_ema_update=do_actor_ema_update, + ) + metrics.update(actor_metrics) + return rng, q0, q0_target, vt, vt_target, actor, actor_target, metrics + + +class CategoricalBDPOAgent(BaseAgent): + """ + Categorical version of Behavior-Regularized Diffusion Policy Optimization (BDPO) + https://arxiv.org/abs/2502.04778 + """ + name = "CategoricalBDPOAgent" + model_names = ["behavior", "behavior_target", "actor", "actor_target", "q0", "q0_target", "vt", "vt_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: CategoricalBDPOConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, behavior_rng, actor_rng, q0_rng, vt_rng = jax.random.split(self.rng, 5) + + # define behavior and actor + time_embedding = PositionalEmbedding(output_dim=cfg.diffusion.time_dim) + if cfg.diffusion.resnet: + noise_predictor = ResidualMLP( + hidden_dims=cfg.diffusion.resnet_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=True, + dropout=cfg.diffusion.dropout, + multiplier=1, + ) + else: + noise_predictor = MLP( + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + 
layer_norm=cfg.diffusion.layer_norm, + dropout=cfg.diffusion.dropout + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding + ) + + def create_ddpm(network_def: nn.Module, rng: PRNGKey, cfg: BDPODiffusionConfig, train_cfg: BDPODiffusionTrainConfig): + if train_cfg.lr_decay_steps is not None: + lr = optax.cosine_decay_schedule(train_cfg.lr, train_cfg.lr_decay_steps) + else: + lr = train_cfg.lr + ddpm = DDPM.create( + network=network_def, + rng=rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.steps, + noise_schedule=cfg.noise_schedule, + noise_schedule_params=None, + approx_postvar=True, + clip_sampler=cfg.clip_sampler, + x_min=cfg.x_min, + x_max=cfg.x_max, + optimizer=optax.adam(learning_rate=lr), + clip_grad_norm=train_cfg.clip_grad_norm, + ) + ddpm_target = DDPM.create( + network=network_def, + rng=rng, + x_dim=self.act_dim, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + steps=cfg.steps, + noise_schedule=cfg.noise_schedule, + noise_schedule_params=None, + approx_postvar=True, + clip_sampler=cfg.clip_sampler, + x_min=cfg.x_min, + x_max=cfg.x_max, + ) + return ddpm, ddpm_target + + self.behavior, self.behavior_target = create_ddpm(backbone_def, behavior_rng, cfg.diffusion, cfg.diffusion.behavior) + self.actor, self.actor_target = create_ddpm(backbone_def, actor_rng, cfg.diffusion, cfg.diffusion.actor) + + # define critic networks + if cfg.critic.lr_decay_steps is not None: + q0_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps) + vt_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps) + q0_def = Ensemblize( + base=ScalarCritic( + backbone=MLP( + hidden_dims=cfg.critic.hidden_dims, + activation=mish, + layer_norm=cfg.critic.layer_norm, + ) + ), + ensemble_size=cfg.critic.ensemble_size, + ) + vt_def = Ensemblize( + base=CategoricalCriticWithDiscreteTime( 
+ backbone=MLP( + hidden_dims=cfg.critic.hidden_dims, + activation=mish, + layer_norm=True, + ), + time_embedding=time_embedding, + output_nodes=cfg.critic.output_nodes, + ), + ensemble_size=cfg.critic.ensemble_size, + ) + self.q0 = Model.create( + q0_def, + q0_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + optimizer=optax.adam(learning_rate=q0_lr), + clip_grad_norm=cfg.critic.clip_grad_norm, + ) + self.q0_target = Model.create( + q0_def, + q0_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + ) + self.vt = Model.create( + vt_def, + vt_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.zeros((1, 1))), + optimizer=optax.adam(learning_rate=vt_lr), + clip_grad_norm = cfg.critic.clip_grad_norm, + ) + self.vt_target = Model.create( + vt_def, + vt_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.zeros((1, 1))), + ) + + self.warmup_steps = cfg.warmup_steps + self._is_pretraining = True # will switch to False after prepared for training + self._n_pretraining_steps = 0 + self._n_training_steps = 0 + + @property + def saved_model_names(self) -> List[str]: + if self._is_pretraining: + return ["behavior_target"] + else: + return self.model_names + + def prepare_training(self): + self.actor = ema_update(self.behavior_target, self.actor, 1.0) + self.actor_target = ema_update(self.behavior_target, self.actor_target, 1.0) + self._is_pretraining = False + + def train_step(self, batch: Batch, step: int) -> Metric: + do_actor_update = (self._n_training_steps >= self.warmup_steps)\ + and (self._n_training_steps % self.cfg.critic.update_ratio == 0) + do_critic_ema_update = self._n_training_steps % self.cfg.critic.ema_every == 0 + if do_actor_update: + actor_step = self._n_training_steps // self.cfg.critic.update_ratio + do_actor_ema_update = actor_step % self.cfg.diffusion.actor.ema_every == 0 + else: + do_actor_ema_update = False + self.rng, self.q0, self.q0_target, self.vt, 
self.vt_target, self.actor, self.actor_target, metrics = jit_train_step( + rng=self.rng, + q0=self.q0, + q0_target=self.q0_target, + vt=self.vt, + vt_target=self.vt_target, + actor=self.actor, + actor_target=self.actor_target, + behavior_target=self.behavior_target, + batch=batch, + do_actor_update=do_actor_update, + v_min=self.cfg.critic.v_min, + v_max=self.cfg.critic.v_max, + num_atoms=self.cfg.critic.output_nodes, + diffusion_steps=self.cfg.diffusion.steps, + discount=self.cfg.critic.discount, + eta=self.cfg.critic.eta, + rho=self.cfg.critic.rho, + num_q_samples=self.cfg.critic.num_samples, + q_target=self.cfg.critic.q_target, + maxQ=self.cfg.critic.maxQ, + solver=self.cfg.critic.solver, + critic_ema=self.cfg.critic.ema, + do_critic_ema_update=do_critic_ema_update, + actor_ema=self.cfg.diffusion.actor.ema, + do_actor_ema_update=do_actor_ema_update, + ) + self._n_training_steps += 1 + return metrics + + def pretrain_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.behavior, self.behavior_target, metrics = jit_update_behavior( + self.rng, + self.behavior, + self.behavior_target, + batch, + self.cfg.diffusion.behavior.ema, + self._n_pretraining_steps % self.cfg.diffusion.behavior.ema_every == 0 + ) + self._n_pretraining_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if self._is_pretraining: + use_model = self.behavior_target + temperature = None + num_samples = 1 + else: + use_model = self.actor_target + temperature = self.cfg.temperature + num_samples = num_samples + self.rng, action = jit_sample_and_select( + self.rng, + use_model, + self.q0_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + temperature=temperature, + ) + return action, {} diff --git a/flowrl/config/offline/__init__.py b/flowrl/config/offline/__init__.py index 9d6a01a..e1d495f 100644 --- 
a/flowrl/config/offline/__init__.py +++ b/flowrl/config/offline/__init__.py @@ -2,6 +2,7 @@ from .algo.base import BaseAlgoConfig from .algo.bdpo import BDPOConfig +from .algo.categorical_bdpo import CategoricalBDPOConfig from .algo.dac import DACConfig from .algo.dql import DQLConfig from .algo.dtql import DTQLConfig @@ -21,6 +22,7 @@ _CONFIGS = { "iql": IQLConfig, "bdpo": BDPOConfig, + "categorical_bdpo": CategoricalBDPOConfig, "ivr": IVRConfig, "dac": DACConfig, "dql": DQLConfig, diff --git a/flowrl/config/offline/algo/categorical_bdpo.py b/flowrl/config/offline/algo/categorical_bdpo.py new file mode 100644 index 0000000..f32e697 --- /dev/null +++ b/flowrl/config/offline/algo/categorical_bdpo.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from omegaconf import MISSING + +from .base import BaseAlgoConfig + + +@dataclass +class BDPODiffusionTrainConfig(): + lr: float + lr_decay_steps: int + clip_grad_norm: float + ema: float + ema_every: float + +@dataclass +class BDPODiffusionConfig(): + actor: BDPODiffusionTrainConfig + behavior: BDPODiffusionTrainConfig + noise_schedule: str + resnet: bool + dropout: Optional[float] + layer_norm: bool + time_dim: int + mlp_hidden_dims: List[int] + resnet_hidden_dims: List[int] + solver: str + steps: int + clip_sampler: bool + x_min: float + x_max: float + +@dataclass +class BDPOCriticConfig(): + discount: float + q_target: str + maxQ: bool + ensemble_size: int + rho: float + eta: float + hidden_dims: List[int] + output_nodes: int + v_min: float + v_max: float + + lr: float + lr_decay_steps: int + clip_grad_norm: float + layer_norm: bool + ema: float + ema_every: int + steps: int + num_samples: int + solver: str + update_ratio: int + +@dataclass +class CategoricalBDPOConfig(BaseAlgoConfig): + name: str + warmup_steps: int + temperature: Optional[float] # None is uniform, 0.0 is greedy + diffusion: BDPODiffusionConfig + critic: BDPOCriticConfig diff --git 
a/flowrl/module/critic.py b/flowrl/module/critic.py index 913ffe2..a53a581 100644 --- a/flowrl/module/critic.py +++ b/flowrl/module/critic.py @@ -268,3 +268,34 @@ def __call__( self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init() )(x) return x + +class CategoricalCriticWithDiscreteTime(nn.Module): + backbone: nn.Module + time_embedding: nn.Module + output_nodes: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init + + @nn.compact + def __call__( + self, + obs: jnp.ndarray, + action: Optional[jnp.ndarray] = None, + t: jnp.ndarray = None, + training: bool = False, + ) -> jnp.ndarray: + t_ff = self.time_embedding(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1]*2, t_ff.shape[-1]], + activation=mish, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )(t_ff) + x = jnp.concatenate([ + obs, action, t_ff + ], axis=-1) + x = self.backbone(x, training=training) + x = nn.Dense( + self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init() + )(x) + return x \ No newline at end of file From 1dc64d7329ab26236da3b4c7c3eb2c306cc10968 Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Sun, 22 Feb 2026 22:28:00 +0800 Subject: [PATCH 02/11] feat: add d4rl script for categorical_bdpo --- scripts/d4rl/categorical_bdpo.sh | 96 ++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 scripts/d4rl/categorical_bdpo.sh diff --git a/scripts/d4rl/categorical_bdpo.sh b/scripts/d4rl/categorical_bdpo.sh new file mode 100644 index 0000000..1ff0958 --- /dev/null +++ b/scripts/d4rl/categorical_bdpo.sh @@ -0,0 +1,96 @@ +export XLA_FLAGS='--xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0' +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=2 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # locomotion tasks + "hopper-medium-v2" + "hopper-medium-replay-v2" + 
"hopper-medium-expert-v2" + "walker2d-medium-v2" + "walker2d-medium-replay-v2" + "walker2d-medium-expert-v2" + "halfcheetah-medium-v2" + "halfcheetah-medium-replay-v2" + "halfcheetah-medium-expert-v2" + # antmaze tasks + "antmaze-umaze-v0" + "antmaze-umaze-diverse-v0" + "antmaze-medium-play-v0" + "antmaze-medium-diverse-v0" + "antmaze-large-play-v0" + "antmaze-large-diverse-v0" +) + + +SHARED_ARGS=( + "algo=categorical_bdpo" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + +ANTMAZE_ARGS=( + "algo.critic.layer_norm=true" + "algo.critic.maxQ=true" + "algo.critic.discount=0.995" + "data.norm_reward=antmaze100" + "eval.num_episodes=100" + "algo.diffusion.mlp_hidden_dims=[512,512,512,512]" + "algo.diffusion.behavior.lr=1e-4" +) + +declare -A TASK_ARGS +TASK_ARGS=( + ["halfcheetah-medium-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" + ["halfcheetah-medium-replay-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" + ["halfcheetah-medium-expert-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" + ["hopper-medium-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0 algo.critic.ensemble_size=20" + ["hopper-medium-replay-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0" + ["hopper-medium-expert-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0" + ["walker2d-medium-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" + ["walker2d-medium-replay-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" + ["walker2d-medium-expert-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" + # antmaze + ["antmaze-umaze-v0"]="algo.critic.eta=0.5 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + ["antmaze-umaze-diverse-v0"]="algo.temperature=null algo.critic.eta=0.5 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + ["antmaze-medium-play-v0"]="algo.critic.eta=0.2 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + ["antmaze-medium-diverse-v0"]="algo.critic.eta=0.2 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + ["antmaze-large-play-v0"]="algo.critic.eta=1.0 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + 
["antmaze-large-diverse-v0"]="algo.critic.eta=1.0 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" +) + + +# first arugment is the name of the experiment +# any other arguments will be added to the command +# if no arguments are given, exit + + +run_task() { + task=$1 + seed=$2 + slot=$3 + # Calculate device index based on available GPUs + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $task $seed on GPU $device" + command="python3 examples/offline/main_d4rl.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]} ${TASK_ARGS[$task]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results logs/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi From cfa0627530a97b46a7177ade4c58891ac9250399 Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Mon, 23 Feb 2026 17:19:37 +0800 Subject: [PATCH 03/11] fix: shape comment Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flowrl/agent/offline/bdpo/categorical_bdpo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index afe6d24..019bd85 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -128,8 +128,8 @@ def categorical_project( l = jnp.floor(b).astype(jnp.int32) # (*B, 1) u = jnp.ceil(b).astype(jnp.int32) # (*B, 1) - d_u = (b - l.astype(jnp.float32)) # (*B) - d_l = 1.0 - d_u # (*B) + d_u = (b - l.astype(jnp.float32)) # (*B, 1) + d_l = 1.0 - d_u # (*B, 1) probs = (d_l * jax.nn.one_hot(l.squeeze(-1), num_atoms) + d_u * jax.nn.one_hot(u.squeeze(-1), num_atoms)) From 162f3018a40d5760d9c18c99e7cfb98da7a005de Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: 
Mon, 23 Feb 2026 17:35:03 +0800 Subject: [PATCH 04/11] fix: cross-entropy loss Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flowrl/agent/offline/bdpo/categorical_bdpo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index 019bd85..16e7d0b 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -286,7 +286,8 @@ def vt_loss_fn(vt_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: t ) # (E, B, num_atoms) # sum over atoms (cross-entropy loss) → mean over batch → mean over ensemble members → scalar - loss = -((pred * vt_target_probs[jnp.newaxis, :])**2).sum(axis=-1).mean() + log_probs = jax.nn.log_softmax(pred, axis=-1) + loss = -(vt_target_probs[jnp.newaxis, :] * log_probs).sum(axis=-1).mean() E, B, _ = pred.shape vt_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms) # (B, 1) return loss, { From d279af11595a5209db6519394fd8f679731daccb Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Mon, 23 Feb 2026 17:40:57 +0800 Subject: [PATCH 05/11] fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/d4rl/categorical_bdpo.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/d4rl/categorical_bdpo.sh b/scripts/d4rl/categorical_bdpo.sh index 1ff0958..915229f 100644 --- a/scripts/d4rl/categorical_bdpo.sh +++ b/scripts/d4rl/categorical_bdpo.sh @@ -65,7 +65,7 @@ TASK_ARGS=( ) -# first arugment is the name of the experiment +# first argument is the name of the experiment # any other arguments will be added to the command # if no arguments are given, exit From b9257b10159e7b9646b01e9f1982d0d26efae2fb Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Mon, 23 Feb 2026 17:42:53 +0800 Subject: [PATCH 06/11] fix: update critic input concat for consistency Co-authored-by: Copilot 
<175728472+Copilot@users.noreply.github.com> --- flowrl/module/critic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py index a53a581..3b0f7a5 100644 --- a/flowrl/module/critic.py +++ b/flowrl/module/critic.py @@ -291,9 +291,10 @@ def __call__( kernel_init=self.kernel_init, bias_init=self.bias_init, )(t_ff) - x = jnp.concatenate([ - obs, action, t_ff - ], axis=-1) + x = jnp.concatenate( + [item for item in (obs, action, t_ff) if item is not None], + axis=-1, + ) x = self.backbone(x, training=training) x = nn.Dense( self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init() From 79ddad140e022a8325b5e8956b40913e3d8b686b Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Mon, 23 Feb 2026 17:46:00 +0800 Subject: [PATCH 07/11] fix: lr undefined when lr_decay_steps is None Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flowrl/agent/offline/bdpo/categorical_bdpo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index 16e7d0b..d43d204 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -560,6 +560,9 @@ def create_ddpm(network_def: nn.Module, rng: PRNGKey, cfg: BDPODiffusionConfig, if cfg.critic.lr_decay_steps is not None: q0_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps) vt_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps) + else: + q0_lr = cfg.critic.lr + vt_lr = cfg.critic.lr q0_def = Ensemblize( base=ScalarCritic( backbone=MLP( From ef53cb209ea9d7b1a30202135e905402c59ff65c Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Mon, 23 Feb 2026 17:48:06 +0800 Subject: [PATCH 08/11] fix: type annotation of jit_sample_and_select Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flowrl/agent/offline/bdpo/categorical_bdpo.py | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index d43d204..a15cef3 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -145,7 +145,7 @@ def jit_sample_and_select( num_samples: int, solver: str, temperature: float, -) -> Tuple[PRNGKey, jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: +) -> Tuple[PRNGKey, jnp.ndarray]: B = obs.shape[0] rng, xT_rng = jax.random.split(rng) From 8c3d3be713b8b982d45d8a3d49d27096e1cae5dc Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Tue, 24 Feb 2026 09:27:17 +0800 Subject: [PATCH 09/11] fix: change order in `logits_to_value` --- flowrl/agent/offline/bdpo/categorical_bdpo.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index a15cef3..c7ec55d 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -76,22 +76,20 @@ def logits_to_value( Returns: expected value, shape (B, 1) """ - # TODO: check if current order is correct: + # The computation proceeds in the following order: # 1. softmax over atoms # 2. mean over ensemble members # 3. compute expected value - # 4. symexp - # 5. mean over num_q_samples dimension + # 4. mean over num_q_samples dimension + # 5. symexp # Steps 1, 2, and 3 are in the correct relative order, consistent with BRC. - # However, the placement of steps 4 and 5 is open for discussion; - # we might consider swapping their order or even moving them before step 3. - # This requires further deliberation. 
+ # Discussion on the order: https://github.com/typoverflow/flow-rl/pull/25#discussion_r2839940571 E, B, N, _ = logits.shape # assertion for the shape of logits, will raise an error if the shape is not as expected atoms = get_atoms(v_min, v_max, num_atoms) # (num_atoms,) probs = jax.nn.softmax(logits, axis=-1).mean(axis=0) # (B, N, num_atoms) value = (probs * atoms).sum(axis=-1, keepdims=True) # (B, N, 1) - value = symexp(value) # (B, N, 1) value = value.mean(axis=1) # (B, 1) + value = value.mean(axis=1) # (B, 1) + value = symexp(value) # (B, 1) return value def categorical_project( From a401632fea2c0f7c660d78ed35f42ebefbc49687 Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Fri, 6 Mar 2026 10:44:26 +0800 Subject: [PATCH 10/11] add compute_statistics to BDPO --- flowrl/agent/offline/bdpo/bdpo.py | 83 ++++++++++++++++ flowrl/agent/offline/bdpo/categorical_bdpo.py | 97 +++++++++++++++++++ 2 files changed, 180 insertions(+) diff --git a/flowrl/agent/offline/bdpo/bdpo.py b/flowrl/agent/offline/bdpo/bdpo.py index 86a5653..447b98f 100644 --- a/flowrl/agent/offline/bdpo/bdpo.py +++ b/flowrl/agent/offline/bdpo/bdpo.py @@ -307,6 +307,75 @@ def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Me return rng, new_actor, new_actor_target, metrics +@partial(jax.jit, static_argnames=("T", "num_samples")) +def jit_compute_statistics( + rng: PRNGKey, + vt_target: Model, + q0_target: Model, + actor: DDPM, + actor_target: DDPM, + behavior_target: DDPM, + batch: Batch, + T: int, + num_samples: int, +) -> Tuple[PRNGKey, Metric]: + """Compute per-timestep statistics for debugging and hyperparameter tuning.""" + alpha_hats = actor_target.alpha_hats + + B = batch.obs.shape[0] + t_proto = jnp.ones((*batch.obs.shape[:-1], num_samples, 1), dtype=jnp.int32) + + obs = batch.obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) # (B, N, obs_dim) + x0 = batch.action[..., jnp.newaxis, :].repeat(num_samples, axis=-2) # (B, N, act_dim) + + tot_metrics = {} + for i in range(0, T + 1): + rng, eps_key = 
jax.random.split(rng) + eps_sample = jax.random.normal(eps_key, x0.shape, dtype=jnp.float32) + t = t_proto * i + + xt = jnp.sqrt(alpha_hats[i]) * x0 + jnp.sqrt(1 - alpha_hats[i]) * eps_sample + + # vt critic values + vt_val = vt_target(obs, xt, t) # (E, B, N, 1) + + # noise predictions + eps_actor = actor(xt, t, condition=obs) + eps_actor_target = actor_target(xt, t, condition=obs) + eps_behavior_target = behavior_target(xt, t, condition=obs) + bc_loss = 0.5 * ((eps_actor - eps_behavior_target) ** 2).sum(axis=-1).mean() + bc_target_loss = 0.5 * ((eps_actor_target - eps_behavior_target) ** 2).sum(axis=-1).mean() + + # vt gradient w.r.t. action + def vt_fn_single(a, s, t_): + return vt_target(s, a, t_).mean(axis=0).squeeze() + vt_grad = jax.vmap(jax.grad(vt_fn_single))( + xt.reshape(B * num_samples, -1), + obs.reshape(B * num_samples, -1), + t.reshape(B * num_samples, -1), + ) + + # q0 gradient w.r.t. action + def q0_fn_single(a, s): + return q0_target(s, a).mean(axis=0).squeeze() + q0_grad = jax.vmap(jax.grad(q0_fn_single))( + xt.reshape(B * num_samples, -1), + obs.reshape(B * num_samples, -1), + ) + + tot_metrics.update({ + f"stats/vt_std_e_{i}": vt_val.std(axis=0).mean(), + f"stats/vt_std_a_{i}": vt_val.mean(axis=0).std(axis=-2).mean(), + f"stats/bc_loss_{i}": bc_loss, + f"stats/bc_target_loss_{i}": bc_target_loss, + f"stats/vt_grad_{i}": jnp.abs(vt_grad).mean(), + f"stats/q0_grad_{i}": jnp.abs(q0_grad).mean(), + f"stats/vt_mean_{i}": vt_val.mean(), + }) + + return rng, tot_metrics + + class BDPOAgent(BaseAgent): """ Behavior-Regularized Diffusion Policy Optimization (BDPO) @@ -518,6 +587,20 @@ def pretrain_step(self, batch: Batch, step: int) -> Metric: self._n_pretraining_steps += 1 return metrics + def compute_statistics(self, batch: Batch) -> Metric: + self.rng, stats = jit_compute_statistics( + self.rng, + self.vt_target, + self.q0_target, + self.actor, + self.actor_target, + self.behavior_target, + batch, + self.cfg.diffusion.steps, + 
self.cfg.critic.num_samples, + ) + return stats + def sample_actions( self, obs: jnp.ndarray, diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index c7ec55d..31f6d91 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -479,6 +479,86 @@ def jit_train_step( return rng, q0, q0_target, vt, vt_target, actor, actor_target, metrics +@partial(jax.jit, static_argnames=("T", "num_samples", "v_min", "v_max", "num_atoms")) +def jit_compute_statistics( + rng: PRNGKey, + vt_target: Model, + q0_target: Model, + actor: DDPM, + actor_target: DDPM, + behavior_target: DDPM, + batch: Batch, + T: int, + num_samples: int, + v_min: float, + v_max: float, + num_atoms: int, +) -> Tuple[PRNGKey, Metric]: + """Compute per-timestep statistics for debugging and hyperparameter tuning.""" + alpha_hats = actor_target.alpha_hats + + B = batch.obs.shape[0] + t_proto = jnp.ones((*batch.obs.shape[:-1], num_samples, 1), dtype=jnp.int32) + + obs = batch.obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) # (B, N, obs_dim) + x0 = batch.action[..., jnp.newaxis, :].repeat(num_samples, axis=-2) # (B, N, act_dim) + + tot_metrics = {} + for i in range(0, T + 1): + rng, eps_key = jax.random.split(rng) + eps_sample = jax.random.normal(eps_key, x0.shape, dtype=jnp.float32) + t = t_proto * i + + xt = jnp.sqrt(alpha_hats[i]) * x0 + jnp.sqrt(1 - alpha_hats[i]) * eps_sample + + # vt critic values (categorical: logits -> value) + vt_logits = vt_target(obs, xt, t) # (E, B, N, num_atoms) + vt_val = logits_to_value(vt_logits, v_min, v_max, num_atoms) # (B, 1) + + # per-ensemble-member values for std computation + atoms = get_atoms(v_min, v_max, num_atoms) + probs = jax.nn.softmax(vt_logits, axis=-1) # (E, B, N, num_atoms) + vt_per_member = symexp((probs * atoms).sum(axis=-1, keepdims=True)) # (E, B, N, 1) + + # noise predictions + eps_actor = actor(xt, t, condition=obs) + eps_actor_target = 
actor_target(xt, t, condition=obs) + eps_behavior_target = behavior_target(xt, t, condition=obs) + bc_loss = 0.5 * ((eps_actor - eps_behavior_target) ** 2).sum(axis=-1).mean() + bc_target_loss = 0.5 * ((eps_actor_target - eps_behavior_target) ** 2).sum(axis=-1).mean() + + # vt gradient w.r.t. action + def vt_fn_single(a, s, t_): + logits = vt_target(s, a, t_) # (E, num_atoms) + logits = logits[:, jnp.newaxis, jnp.newaxis, :] # (E, 1, 1, num_atoms) + return logits_to_value(logits, v_min, v_max, num_atoms).squeeze() + vt_grad = jax.vmap(jax.grad(vt_fn_single))( + xt.reshape(B * num_samples, -1), + obs.reshape(B * num_samples, -1), + t.reshape(B * num_samples, -1), + ) + + # q0 gradient w.r.t. action + def q0_fn_single(a, s): + return q0_target(s, a).mean(axis=0).squeeze() + q0_grad = jax.vmap(jax.grad(q0_fn_single))( + xt.reshape(B * num_samples, -1), + obs.reshape(B * num_samples, -1), + ) + + tot_metrics.update({ + f"stats/vt_std_e_{i}": vt_per_member.std(axis=0).mean(), + f"stats/vt_std_a_{i}": vt_per_member.mean(axis=0).std(axis=-2).mean(), + f"stats/bc_loss_{i}": bc_loss, + f"stats/bc_target_loss_{i}": bc_target_loss, + f"stats/vt_grad_{i}": jnp.abs(vt_grad).mean(), + f"stats/q0_grad_{i}": jnp.abs(q0_grad).mean(), + f"stats/vt_mean_{i}": vt_val.mean(), + }) + + return rng, tot_metrics + + class CategoricalBDPOAgent(BaseAgent): """ Categorical version of Behavior-Regularized Diffusion Policy Optimization (BDPO) @@ -676,6 +756,23 @@ def pretrain_step(self, batch: Batch, step: int) -> Metric: self._n_pretraining_steps += 1 return metrics + def compute_statistics(self, batch: Batch) -> Metric: + self.rng, stats = jit_compute_statistics( + self.rng, + self.vt_target, + self.q0_target, + self.actor, + self.actor_target, + self.behavior_target, + batch, + self.cfg.diffusion.steps, + self.cfg.critic.num_samples, + self.cfg.critic.v_min, + self.cfg.critic.v_max, + self.cfg.critic.output_nodes, + ) + return stats + def sample_actions( self, obs: jnp.ndarray, From 
4fa5b8ca1d469ab955a0cdf9f62599717e590676 Mon Sep 17 00:00:00 2001 From: Cao Mingjun Date: Fri, 20 Mar 2026 19:07:01 +0800 Subject: [PATCH 11/11] feat: categorical Q --- .../config/d4rl/algo/categorical_bdpo.yaml | 4 +- flowrl/agent/offline/bdpo/categorical_bdpo.py | 129 +++++++++--------- .../config/offline/algo/categorical_bdpo.py | 4 +- scripts/d4rl/categorical_bdpo.sh | 31 ++--- 4 files changed, 79 insertions(+), 89 deletions(-) diff --git a/examples/offline/config/d4rl/algo/categorical_bdpo.yaml b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml index c70fe62..f7e86c4 100644 --- a/examples/offline/config/d4rl/algo/categorical_bdpo.yaml +++ b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml @@ -42,15 +42,13 @@ algo: x_max: 1.0 critic: discount: 0.99 - q_target: lcb - maxQ: false ensemble_size: 10 - rho: 0.0 eta: 1.0 hidden_dims: [256, 256, 256] output_nodes: 101 v_min: -10.0 v_max: 10.0 + symexp_before_mean: false lr: 0.0003 lr_decay_steps: ${train_steps} clip_grad_norm: 1.0 diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py index 31f6d91..6cd4aa4 100644 --- a/flowrl/agent/offline/bdpo/categorical_bdpo.py +++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py @@ -16,7 +16,7 @@ from flowrl.functional.activation import mish from flowrl.functional.ema import ema_update from flowrl.functional.math import symlog, symexp -from flowrl.module.critic import CategoricalCriticWithDiscreteTime, Ensemblize, ScalarCritic +from flowrl.module.critic import CategoricalCritic, CategoricalCriticWithDiscreteTime, Ensemblize from flowrl.module.mlp import MLP, ResidualMLP from flowrl.module.model import Model from flowrl.module.time_embedding import PositionalEmbedding @@ -24,27 +24,6 @@ EPS = 1e-6 -def get_target(rng: PRNGKey, vs: jnp.ndarray, maxQ: bool, q_target: str, rho: float) -> Tuple[PRNGKey, jnp.ndarray]: - # vs is of shape (E, B, N, 1) - if maxQ: - vs = vs.max(axis=-2) - else: - vs = vs.mean(axis=-2) - if 
q_target == "min": - vs = vs.min(axis=0) - elif q_target == "convex": - vs = rho * vs.min(axis=0) + (1-rho) * vs.max(axis=0) - elif q_target == "lcb": - vs = vs.mean(axis=0) - rho * vs.std(axis=0) - elif q_target == "rand_convex": - rng, alpha_key = jax.random.split(rng) - alphas = jax.random.uniform(alpha_key, vs.shape) - alphas /= (alphas.sum(axis=0, keepdims=True) + EPS) - vs = (vs * alphas).sum(axis=0) - else: - raise NotImplementedError(f"Unrecognized Q-target type: {q_target}. ") - return rng, vs - def get_penalty( actor_eps, behavior_eps, @@ -64,7 +43,8 @@ def logits_to_value( logits: jnp.ndarray, v_min: float, v_max: float, - num_atoms: int + num_atoms: int, + symexp_before_mean: bool, ) -> jnp.ndarray: """Expected value (post-transformation) from categorical logits. @@ -73,6 +53,10 @@ def logits_to_value( v_min: minimum support value v_max: maximum support value num_atoms: number of atoms in the categorical distribution + symexp_before_mean: if True, apply symexp per sample before averaging over N + (consistent with per-sample action selection; unbiased mean in original space); + if False, average over N in symlog space then apply symexp once + (biased by Jensen's inequality but potentially more stable). 
Returns: expected value, shape (B, 1) """ @@ -87,9 +71,11 @@ def logits_to_value( E, B, N, _ = logits.shape # assertion for the shape of logits, will raise an error if the shape is not as expected atoms = get_atoms(v_min, v_max, num_atoms) # (num_atoms,) probs = jax.nn.softmax(logits, axis=-1).mean(axis=0) # (B, N, num_atoms) - value = (probs * atoms).sum(axis=-1, keepdims=True) # (B, N, 1) - value = value.mean(axis=1) # (B, 1) - value = symexp(value) # (B, 1) + value = (probs * atoms).sum(axis=-1, keepdims=True) # (B, N, 1), in symlog space + if symexp_before_mean: + value = symexp(value).mean(axis=1) # symexp per sample, then mean → (B, 1) + else: + value = symexp(value.mean(axis=1)) # mean in symlog space, then symexp → (B, 1) return value def categorical_project( @@ -133,7 +119,7 @@ def categorical_project( + d_u * jax.nn.one_hot(u.squeeze(-1), num_atoms)) return probs, metrics -@partial(jax.jit, static_argnames=("training", "num_samples", "solver", "temperature")) +@partial(jax.jit, static_argnames=("training", "num_samples", "solver", "temperature", "v_min", "v_max", "num_atoms")) def jit_sample_and_select( rng: PRNGKey, model: DDPM, @@ -143,6 +129,9 @@ def jit_sample_and_select( num_samples: int, solver: str, temperature: float, + v_min: float, + v_max: float, + num_atoms: int, ) -> Tuple[PRNGKey, jnp.ndarray]: B = obs.shape[0] rng, xT_rng = jax.random.split(rng) @@ -154,11 +143,13 @@ def jit_sample_and_select( if temperature is None: return rng, actions[:, 0] else: - qs = q0( + qs_logits = q0( obs_repeat, actions, - ) - qs = qs.mean(axis=0).reshape(B, num_samples) + ) # (E, B, num_samples, num_atoms) + atoms = get_atoms(v_min, v_max, num_atoms) + probs = jax.nn.softmax(qs_logits, axis=-1).mean(axis=0) # (B, num_samples, num_atoms) + qs = symexp((probs * atoms).sum(axis=-1)) # (B, num_samples) if temperature <= 0.0: idx = qs.argmax(axis=-1) else: @@ -198,13 +189,11 @@ def update_critic( T: int, discount: float, eta: float, - rho: float, num_q_samples: int, - 
q_target: str, - maxQ: bool, solver: str, ema: float, do_ema_update: bool, + symexp_before_mean: bool, ) -> Tuple[PRNGKey, Model, Model, Model, Model, Metric]: B = batch.obs.shape[0] A = batch.action.shape[-1] @@ -227,10 +216,11 @@ def update_critic( q0_target_value = q0_target( next_obs_repeat, next_action - ) - rng, q0_target_value = get_target(rng, q0_target_value, maxQ, q_target, rho) + ) # (E, B, N, num_atoms) + q0_target_value = logits_to_value(q0_target_value, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) q0_target_value = batch.reward + discount * (1-batch.terminal) * q0_target_value # (B, 1) + q0_target_probs, q0_proj_metrics = categorical_project(q0_target_value, v_min, v_max, num_atoms) # (B, num_atoms) # vt target obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2) @@ -249,12 +239,12 @@ def update_critic( rep_xt_1, rep_t_1 ) # (E, B, N, num_atoms) - vt_target_value1 = logits_to_value(vt_target_value1, v_min, v_max, num_atoms) # (B, 1) + vt_target_value1 = logits_to_value(vt_target_value1, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) vt_target_value2 = q0_target( obs_repeat, rep_xt_1 - ) - rng, vt_target_value2 = get_target(rng, vt_target_value2, False, q_target, rho) + ) # (E, B, N, num_atoms) + vt_target_value2 = logits_to_value(vt_target_value2, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) vt_target_value = (t != 1) * vt_target_value1 + (t == 1) * vt_target_value2 vt_actor_eps = history @@ -269,12 +259,19 @@ def q0_loss_fn(q0_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: {"params": q0_params}, batch.obs, batch.action - ) # (E, B, 1) - loss = ((pred - q0_target_value[jnp.newaxis, :])**2).mean() + ) # (E, B, num_atoms) + log_probs = jax.nn.log_softmax(pred, axis=-1) + loss = -(q0_target_probs[jnp.newaxis, :] * log_probs).sum(axis=-1).mean() + E, B, _ = pred.shape + q0_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) return 
loss, { "loss/q0_loss": loss, - "misc/q0_mean": pred.mean(), + "misc/q0_mean": q0_value.mean(), "misc/reward": batch.reward.mean(), + "misc/q0_target_symlog_max": q0_proj_metrics["pre_clip_max"], + "misc/q0_target_symlog_min": q0_proj_metrics["pre_clip_min"], + "misc/q0_target_symlog_mean": q0_proj_metrics["pre_clip_mean"], + "misc/q0_target_clip_ratio": q0_proj_metrics["clip_ratio"], } def vt_loss_fn(vt_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: pred = vt.apply( @@ -287,7 +284,7 @@ def vt_loss_fn(vt_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]: log_probs = jax.nn.log_softmax(pred, axis=-1) loss = -(vt_target_probs[jnp.newaxis, :] * log_probs).sum(axis=-1).mean() E, B, _ = pred.shape - vt_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms) # (B, 1) + vt_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) return loss, { "loss/vt_loss": loss, "misc/vt_mean": vt_value.mean(), @@ -324,12 +321,11 @@ def update_actor( num_atoms: int, T: int, eta, - rho: float, num_q_samples: int, - q_target: str, solver: str, ema: float, do_ema_update: bool, + symexp_before_mean: bool, ) -> Tuple[PRNGKey, Model, Model, Metric]: alphas = actor.alphas alpha_hats = actor.alpha_hats @@ -355,12 +351,12 @@ def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Me rep_xt_1, rep_t_1 ) # (E, B, N, num_atoms) - target_1 = logits_to_value(target_1, v_min, v_max, num_atoms) # (B, 1) + target_1 = logits_to_value(target_1, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) target_2 = q0_target( obs_repeat, rep_xt_1 - ) - sample_rng_, target_2 = get_target(sample_rng_, target_2, False, q_target, rho=rho) + ) # (E, B, N, num_atoms) + target_2 = logits_to_value(target_2, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) target = (t != 1) * target_1 + (t == 1) * target_2 actor_eps = history @@ -388,15 +384,13 @@ def actor_loss_fn(actor_params: Param, 
*args, **kwargs) -> Tuple[jnp.ndarray, Me "diffusion_steps", "discount", "eta", - "rho", "num_q_samples", - "q_target", - "maxQ", "solver", "critic_ema", "do_critic_ema_update", "actor_ema", "do_actor_ema_update", + "symexp_before_mean", )) def jit_train_step( rng: PRNGKey, @@ -418,11 +412,9 @@ def jit_train_step( diffusion_steps: int, discount: float, eta: float, - rho: float, num_q_samples: int, - q_target: str, - maxQ: bool, solver: str, + symexp_before_mean: bool, # ema update critic_ema: float, do_critic_ema_update: bool, @@ -445,13 +437,11 @@ def jit_train_step( T=diffusion_steps, discount=discount, eta=eta, - rho=rho, num_q_samples=num_q_samples, - q_target=q_target, - maxQ=maxQ, solver=solver, ema=critic_ema, do_ema_update=do_critic_ema_update, + symexp_before_mean=symexp_before_mean, ) metrics.update(critic_metrics) if do_actor_update: @@ -468,18 +458,17 @@ def jit_train_step( num_atoms=num_atoms, T=diffusion_steps, eta=eta, - rho=rho, num_q_samples=num_q_samples, - q_target=q_target, solver=solver, ema=actor_ema, do_ema_update=do_actor_ema_update, + symexp_before_mean=symexp_before_mean, ) metrics.update(actor_metrics) return rng, q0, q0_target, vt, vt_target, actor, actor_target, metrics -@partial(jax.jit, static_argnames=("T", "num_samples", "v_min", "v_max", "num_atoms")) +@partial(jax.jit, static_argnames=("T", "num_samples", "v_min", "v_max", "num_atoms", "symexp_before_mean")) def jit_compute_statistics( rng: PRNGKey, vt_target: Model, @@ -493,6 +482,7 @@ def jit_compute_statistics( v_min: float, v_max: float, num_atoms: int, + symexp_before_mean: bool, ) -> Tuple[PRNGKey, Metric]: """Compute per-timestep statistics for debugging and hyperparameter tuning.""" alpha_hats = actor_target.alpha_hats @@ -513,7 +503,7 @@ def jit_compute_statistics( # vt critic values (categorical: logits -> value) vt_logits = vt_target(obs, xt, t) # (E, B, N, num_atoms) - vt_val = logits_to_value(vt_logits, v_min, v_max, num_atoms) # (B, 1) + vt_val = 
logits_to_value(vt_logits, v_min, v_max, num_atoms, symexp_before_mean) # (B, 1) # per-ensemble-member values for std computation atoms = get_atoms(v_min, v_max, num_atoms) @@ -531,7 +521,7 @@ def jit_compute_statistics( def vt_fn_single(a, s, t_): logits = vt_target(s, a, t_) # (E, num_atoms) logits = logits[:, jnp.newaxis, jnp.newaxis, :] # (E, 1, 1, num_atoms) - return logits_to_value(logits, v_min, v_max, num_atoms).squeeze() + return logits_to_value(logits, v_min, v_max, num_atoms, symexp_before_mean).squeeze() vt_grad = jax.vmap(jax.grad(vt_fn_single))( xt.reshape(B * num_samples, -1), obs.reshape(B * num_samples, -1), @@ -540,7 +530,9 @@ def vt_fn_single(a, s, t_): # q0 gradient w.r.t. action def q0_fn_single(a, s): - return q0_target(s, a).mean(axis=0).squeeze() + logits = q0_target(s, a) # (E, num_atoms) + logits = logits[:, jnp.newaxis, jnp.newaxis, :] # (E, 1, 1, num_atoms) + return logits_to_value(logits, v_min, v_max, num_atoms, symexp_before_mean).squeeze() q0_grad = jax.vmap(jax.grad(q0_fn_single))( xt.reshape(B * num_samples, -1), obs.reshape(B * num_samples, -1), @@ -642,12 +634,13 @@ def create_ddpm(network_def: nn.Module, rng: PRNGKey, cfg: BDPODiffusionConfig, q0_lr = cfg.critic.lr vt_lr = cfg.critic.lr q0_def = Ensemblize( - base=ScalarCritic( + base=CategoricalCritic( backbone=MLP( hidden_dims=cfg.critic.hidden_dims, activation=mish, layer_norm=cfg.critic.layer_norm, - ) + ), + output_nodes=cfg.critic.output_nodes, ), ensemble_size=cfg.critic.ensemble_size, ) @@ -731,11 +724,9 @@ def train_step(self, batch: Batch, step: int) -> Metric: diffusion_steps=self.cfg.diffusion.steps, discount=self.cfg.critic.discount, eta=self.cfg.critic.eta, - rho=self.cfg.critic.rho, num_q_samples=self.cfg.critic.num_samples, - q_target=self.cfg.critic.q_target, - maxQ=self.cfg.critic.maxQ, solver=self.cfg.critic.solver, + symexp_before_mean=self.cfg.critic.symexp_before_mean, critic_ema=self.cfg.critic.ema, do_critic_ema_update=do_critic_ema_update, 
actor_ema=self.cfg.diffusion.actor.ema, @@ -770,6 +761,7 @@ def compute_statistics(self, batch: Batch) -> Metric: self.cfg.critic.v_min, self.cfg.critic.v_max, self.cfg.critic.output_nodes, + self.cfg.critic.symexp_before_mean, ) return stats @@ -796,5 +788,8 @@ def sample_actions( num_samples=num_samples, solver=self.cfg.diffusion.solver, temperature=temperature, + v_min=self.cfg.critic.v_min, + v_max=self.cfg.critic.v_max, + num_atoms=self.cfg.critic.output_nodes, ) return action, {} diff --git a/flowrl/config/offline/algo/categorical_bdpo.py b/flowrl/config/offline/algo/categorical_bdpo.py index f32e697..235103c 100644 --- a/flowrl/config/offline/algo/categorical_bdpo.py +++ b/flowrl/config/offline/algo/categorical_bdpo.py @@ -34,15 +34,13 @@ class BDPODiffusionConfig(): @dataclass class BDPOCriticConfig(): discount: float - q_target: str - maxQ: bool ensemble_size: int - rho: float eta: float hidden_dims: List[int] output_nodes: int v_min: float v_max: float + symexp_before_mean: bool # if True: symexp per sample then mean over N; if False: mean over N then symexp (Jensen bias but more stable) lr: float lr_decay_steps: int diff --git a/scripts/d4rl/categorical_bdpo.sh b/scripts/d4rl/categorical_bdpo.sh index 915229f..5c4105c 100644 --- a/scripts/d4rl/categorical_bdpo.sh +++ b/scripts/d4rl/categorical_bdpo.sh @@ -36,7 +36,6 @@ SHARED_ARGS=( ANTMAZE_ARGS=( "algo.critic.layer_norm=true" - "algo.critic.maxQ=true" "algo.critic.discount=0.995" "data.norm_reward=antmaze100" "eval.num_episodes=100" @@ -46,22 +45,22 @@ ANTMAZE_ARGS=( declare -A TASK_ARGS TASK_ARGS=( - ["halfcheetah-medium-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" - ["halfcheetah-medium-replay-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" - ["halfcheetah-medium-expert-v2"]="algo.critic.eta=0.05 algo.critic.rho=0.5" - ["hopper-medium-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0 algo.critic.ensemble_size=20" - ["hopper-medium-replay-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0" - 
["hopper-medium-expert-v2"]="algo.critic.eta=0.2 algo.critic.rho=2.0" - ["walker2d-medium-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" - ["walker2d-medium-replay-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" - ["walker2d-medium-expert-v2"]="algo.critic.eta=0.15 algo.critic.rho=1.0" + ["halfcheetah-medium-v2"]="algo.critic.eta=0.05" + ["halfcheetah-medium-replay-v2"]="algo.critic.eta=0.05" + ["halfcheetah-medium-expert-v2"]="algo.critic.eta=0.05" + ["hopper-medium-v2"]="algo.critic.eta=0.2 algo.critic.ensemble_size=20" + ["hopper-medium-replay-v2"]="algo.critic.eta=0.2" + ["hopper-medium-expert-v2"]="algo.critic.eta=0.2" + ["walker2d-medium-v2"]="algo.critic.eta=0.15" + ["walker2d-medium-replay-v2"]="algo.critic.eta=0.15" + ["walker2d-medium-expert-v2"]="algo.critic.eta=0.15" # antmaze - ["antmaze-umaze-v0"]="algo.critic.eta=0.5 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" - ["antmaze-umaze-diverse-v0"]="algo.temperature=null algo.critic.eta=0.5 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" - ["antmaze-medium-play-v0"]="algo.critic.eta=0.2 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" - ["antmaze-medium-diverse-v0"]="algo.critic.eta=0.2 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" - ["antmaze-large-play-v0"]="algo.critic.eta=1.0 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" - ["antmaze-large-diverse-v0"]="algo.critic.eta=1.0 algo.critic.rho=0.8 ${ANTMAZE_ARGS[@]}" + ["antmaze-umaze-v0"]="algo.critic.eta=0.5 ${ANTMAZE_ARGS[@]}" + ["antmaze-umaze-diverse-v0"]="algo.temperature=null algo.critic.eta=0.5 ${ANTMAZE_ARGS[@]}" + ["antmaze-medium-play-v0"]="algo.critic.eta=0.2 ${ANTMAZE_ARGS[@]}" + ["antmaze-medium-diverse-v0"]="algo.critic.eta=0.2 ${ANTMAZE_ARGS[@]}" + ["antmaze-large-play-v0"]="algo.critic.eta=1.0 ${ANTMAZE_ARGS[@]}" + ["antmaze-large-diverse-v0"]="algo.critic.eta=1.0 ${ANTMAZE_ARGS[@]}" )