diff --git a/examples/offline/config/d4rl/algo/categorical_bdpo.yaml b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml
new file mode 100644
index 0000000..f7e86c4
--- /dev/null
+++ b/examples/offline/config/d4rl/algo/categorical_bdpo.yaml
@@ -0,0 +1,61 @@
+# @package _global_
+
+data:
+  norm_obs: false
+  norm_reward: iql_mujoco
+
+eval:
+  num_samples: 10
+  interval: 50000
+
+train_steps: 2_000_000
+pretrain_steps: 2_000_000
+
+algo:
+  name: categorical_bdpo
+  warmup_steps: 50_000
+  temperature: 0.0
+  diffusion:
+    actor:
+      lr: 0.00001
+      lr_decay_steps: ${train_steps}
+      clip_grad_norm: 1.0
+      ema: 0.005
+      ema_every: 1
+    behavior:
+      lr: 0.0003
+      lr_decay_steps: null
+      clip_grad_norm: 1.0
+      ema: 0.005
+      ema_every: 1
+    noise_schedule: vp
+    resnet: false
+    dropout: 0.0
+    layer_norm: false
+    time_dim: 64
+    mlp_hidden_dims: [256, 256, 256, 256]
+    resnet_hidden_dims: [256, 256, 256]
+    solver: ddpm
+    steps: 5
+    clip_sampler: true
+    x_min: -1.0
+    x_max: 1.0
+  critic:
+    discount: 0.99
+    ensemble_size: 10
+    eta: 1.0
+    hidden_dims: [256, 256, 256]
+    output_nodes: 101
+    v_min: -10.0
+    v_max: 10.0
+    symexp_before_mean: false
+    lr: 0.0003
+    lr_decay_steps: ${train_steps}
+    clip_grad_norm: 1.0
+    layer_norm: false
+    ema: 0.005
+    ema_every: 1
+    steps: 5
+    num_samples: 10
+    solver: ddpm
+    update_ratio: 5
diff --git a/examples/offline/main_d4rl.py b/examples/offline/main_d4rl.py
index 90d6fc4..58bb580 100644
--- a/examples/offline/main_d4rl.py
+++ b/examples/offline/main_d4rl.py
@@ -22,6 +22,7 @@ SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = {
     "iql": IQLAgent,
     "bdpo": BDPOAgent,
+    "categorical_bdpo": CategoricalBDPOAgent,
     "ivr": IVRAgent,
     "fql": FQLAgent,
     "dac": DACAgent,
diff --git a/flowrl/agent/offline/__init__.py b/flowrl/agent/offline/__init__.py
index 2493492..78168ba 100644
--- a/flowrl/agent/offline/__init__.py
+++ b/flowrl/agent/offline/__init__.py
@@ -1,5 +1,6 @@
 from ..base import BaseAgent
 from .bdpo.bdpo import BDPOAgent
+from .bdpo.categorical_bdpo import CategoricalBDPOAgent
 from .dac import DACAgent
 from .dql import DQLAgent
 from .dtql import DTQLAgent
diff --git a/flowrl/agent/offline/bdpo/bdpo.py b/flowrl/agent/offline/bdpo/bdpo.py
index 86a5653..447b98f 100644
--- a/flowrl/agent/offline/bdpo/bdpo.py
+++ b/flowrl/agent/offline/bdpo/bdpo.py
@@ -307,6 +307,75 @@ def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Me
     return rng, new_actor, new_actor_target, metrics
 
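+# NOTE: `T` and `num_samples` are static under jit, so the Python-level
+# `for i in range(0, T + 1)` loop below is unrolled at trace time into T + 1
+# copies of its body. This keeps every per-timestep metric under a distinct
+# key (e.g. "stats/vt_mean_3"), at the cost of compile time growing linearly
+# with T; cheap for the configured steps=5, but worth knowing before raising
+# the step count.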
+@partial(jax.jit, static_argnames=("T", "num_samples"))
+def jit_compute_statistics(
+    rng: PRNGKey,
+    vt_target: Model,
+    q0_target: Model,
+    actor: DDPM,
+    actor_target: DDPM,
+    behavior_target: DDPM,
+    batch: Batch,
+    T: int,
+    num_samples: int,
+) -> Tuple[PRNGKey, Metric]:
+    """Compute per-timestep statistics for debugging and hyperparameter tuning."""
+    alpha_hats = actor_target.alpha_hats
+
+    B = batch.obs.shape[0]
+    t_proto = jnp.ones((*batch.obs.shape[:-1], num_samples, 1), dtype=jnp.int32)
+
+    obs = batch.obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, obs_dim)
+    x0 = batch.action[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, act_dim)
+
+    tot_metrics = {}
+    for i in range(0, T + 1):
+        rng, eps_key = jax.random.split(rng)
+        eps_sample = jax.random.normal(eps_key, x0.shape, dtype=jnp.float32)
+        t = t_proto * i
+
+        xt = jnp.sqrt(alpha_hats[i]) * x0 + jnp.sqrt(1 - alpha_hats[i]) * eps_sample
+
+        # vt critic values
+        vt_val = vt_target(obs, xt, t)  # (E, B, N, 1)
+
+        # noise predictions
+        eps_actor = actor(xt, t, condition=obs)
+        eps_actor_target = actor_target(xt, t, condition=obs)
+        eps_behavior_target = behavior_target(xt, t, condition=obs)
+        bc_loss = 0.5 * ((eps_actor - eps_behavior_target) ** 2).sum(axis=-1).mean()
+        bc_target_loss = 0.5 * ((eps_actor_target - eps_behavior_target) ** 2).sum(axis=-1).mean()
+
+        # vt gradient w.r.t. action
+        def vt_fn_single(a, s, t_):
+            return vt_target(s, a, t_).mean(axis=0).squeeze()
+        vt_grad = jax.vmap(jax.grad(vt_fn_single))(
+            xt.reshape(B * num_samples, -1),
+            obs.reshape(B * num_samples, -1),
+            t.reshape(B * num_samples, -1),
+        )
+
+        # q0 gradient w.r.t. action
+        def q0_fn_single(a, s):
+            return q0_target(s, a).mean(axis=0).squeeze()
+        q0_grad = jax.vmap(jax.grad(q0_fn_single))(
+            xt.reshape(B * num_samples, -1),
+            obs.reshape(B * num_samples, -1),
+        )
+
+        tot_metrics.update({
+            f"stats/vt_std_e_{i}": vt_val.std(axis=0).mean(),
+            f"stats/vt_std_a_{i}": vt_val.mean(axis=0).std(axis=-2).mean(),
+            f"stats/bc_loss_{i}": bc_loss,
+            f"stats/bc_target_loss_{i}": bc_target_loss,
+            f"stats/vt_grad_{i}": jnp.abs(vt_grad).mean(),
+            f"stats/q0_grad_{i}": jnp.abs(q0_grad).mean(),
+            f"stats/vt_mean_{i}": vt_val.mean(),
+        })
+
+    return rng, tot_metrics
+
+
 class BDPOAgent(BaseAgent):
     """
     Behavior-Regularized Diffusion Policy Optimization (BDPO)
     https://arxiv.org/abs/2502.04778
@@ -518,6 +587,20 @@ def pretrain_step(self, batch: Batch, step: int) -> Metric:
         self._n_pretraining_steps += 1
         return metrics
 
+    def compute_statistics(self, batch: Batch) -> Metric:
+        self.rng, stats = jit_compute_statistics(
+            self.rng,
+            self.vt_target,
+            self.q0_target,
+            self.actor,
+            self.actor_target,
+            self.behavior_target,
+            batch,
+            self.cfg.diffusion.steps,
+            self.cfg.critic.num_samples,
+        )
+        return stats
+
     def sample_actions(
         self,
         obs: jnp.ndarray,
diff --git a/flowrl/agent/offline/bdpo/categorical_bdpo.py b/flowrl/agent/offline/bdpo/categorical_bdpo.py
new file mode 100644
index 0000000..6cd4aa4
--- /dev/null
+++ b/flowrl/agent/offline/bdpo/categorical_bdpo.py
@@ -0,0 +1,795 @@
+from functools import partial
+from typing import List
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import optax
+
+from flowrl.agent.base import BaseAgent
+from flowrl.config.offline.algo.categorical_bdpo import (
+    CategoricalBDPOConfig,
+    BDPODiffusionConfig,
+    BDPODiffusionTrainConfig,
+)
+from flowrl.flow.ddpm import DDPM, DDPMBackbone, jit_update_ddpm
+from flowrl.functional.activation import mish
+from flowrl.functional.ema import ema_update
+from flowrl.functional.math import symlog, symexp
+from flowrl.module.critic import CategoricalCritic, CategoricalCriticWithDiscreteTime, Ensemblize
+from flowrl.module.mlp import MLP, ResidualMLP
+from flowrl.module.model import Model
+from flowrl.module.time_embedding import PositionalEmbedding
+from flowrl.types import *
+
+EPS = 1e-6
+
+def get_penalty(
+    actor_eps: jnp.ndarray,
+    behavior_eps: jnp.ndarray,
+    t: jnp.ndarray,
+    T: int,
+    alphas: jnp.ndarray,
+    alpha_hats: jnp.ndarray,
+    betas: jnp.ndarray,
+) -> jnp.ndarray:
+    return 0.5 * betas[t] * ((actor_eps - behavior_eps) ** 2) / (1 - betas[t]) / (1 - alpha_hats[t])
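+
+# The expression above is the per-step KL divergence between the actor's and
+# the behavior's reverse-time kernels p(x_{t-1} | x_t): both are Gaussians
+# sharing the schedule variance (sigma_t^2 = beta_t under the approximate
+# posterior variance used below), so with alpha_t = 1 - beta_t:
+#
+#   KL = ||mu_actor - mu_behavior||^2 / (2 * beta_t)
+#      = 0.5 * beta_t * ||eps_actor - eps_behavior||^2 / ((1 - beta_t) * (1 - alpha_hat_t))
+#
+# which matches the return value term by term. `T` and `alphas` are unused
+# here but keep the signature aligned with the non-categorical BDPO code.
+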
+def get_atoms(v_min: float, v_max: float, num_atoms: int) -> jnp.ndarray:
+    """Get support atoms (pre-transformation) for the categorical distribution."""
+    return jnp.linspace(v_min, v_max, num_atoms)
+
+def logits_to_value(
+    logits: jnp.ndarray,
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+    symexp_before_mean: bool,
+) -> jnp.ndarray:
+    """Expected value (post-transformation) from categorical logits.
+
+    Args:
+        logits: raw logits, shape (E, B, N, num_atoms)
+        v_min: minimum support value
+        v_max: maximum support value
+        num_atoms: number of atoms in the categorical distribution
+        symexp_before_mean: if True, apply symexp per sample before averaging over N
+            (consistent with per-sample action selection; unbiased mean in the original space);
+            if False, average over N in symlog space and then apply symexp once
+            (biased by Jensen's inequality but potentially more stable).
+
+    Returns:
+        expected value, shape (B, 1)
+    """
+    # The computation proceeds in the following order:
+    #   1. softmax over atoms
+    #   2. mean over ensemble members
+    #   3. compute the expected value
+    #   4. mean over the num_q_samples dimension
+    #   5. symexp
+    # Steps 1, 2, and 3 are in the correct relative order, consistent with BRC.
+    # Discussion of the order: https://github.com/typoverflow/flow-rl/pull/25#discussion_r2839940571
+    E, B, N, _ = logits.shape  # unpacking enforces a 4D (E, B, N, num_atoms) shape
+    atoms = get_atoms(v_min, v_max, num_atoms)  # (num_atoms,)
+    probs = jax.nn.softmax(logits, axis=-1).mean(axis=0)  # (B, N, num_atoms)
+    value = (probs * atoms).sum(axis=-1, keepdims=True)  # (B, N, 1), in symlog space
+    if symexp_before_mean:
+        value = symexp(value).mean(axis=1)  # symexp per sample, then mean → (B, 1)
+    else:
+        value = symexp(value.mean(axis=1))  # mean in symlog space, then symexp → (B, 1)
+    return value
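+
+# A quick numeric check of `symexp_before_mean` (hypothetical values): suppose
+# two action samples land at symlog-space values 1.0 and 3.0.
+#   True:  mean(symexp([1.0, 3.0])) = mean([1.718, 19.086]) ≈ 10.40
+#   False: symexp(mean([1.0, 3.0])) = symexp(2.0)           ≈  6.39
+# Since symexp is convex on positive inputs, averaging in symlog space
+# understates the original-space mean (Jensen's inequality), but it is far
+# less sensitive to a single large sample.
+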
+def categorical_project(
+    target_values: jnp.ndarray,
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+) -> Tuple[jnp.ndarray, Metric]:
+    """C51-style projection: scalar targets (original-space) → categorical distribution.
+
+    Args:
+        target_values: shape (*B, 1)
+        v_min: minimum support value
+        v_max: maximum support value
+        num_atoms: number of atoms in the categorical distribution
+
+    Returns:
+        target distribution, shape (*B, num_atoms)
+        metrics dict with pre-clip statistics (keys: pre_clip_max, pre_clip_min,
+        pre_clip_mean, clip_ratio; all computed in the transformed/symlog space)
+    """
+    delta_z = (v_max - v_min) / (num_atoms - 1)
+
+    symlog_values = symlog(target_values)
+    clip_ratio = ((symlog_values < v_min) | (symlog_values > v_max)).mean()
+    metrics = {
+        "pre_clip_max": symlog_values.max(),
+        "pre_clip_min": symlog_values.min(),
+        "pre_clip_mean": symlog_values.mean(),
+        "clip_ratio": clip_ratio,
+    }
+
+    clipped = jnp.clip(symlog_values, v_min, v_max)
+    b = (clipped - v_min) / delta_z  # (*B, 1)
+    l = jnp.floor(b).astype(jnp.int32)  # (*B, 1)
+    u = jnp.ceil(b).astype(jnp.int32)  # (*B, 1)
+
+    d_u = b - l.astype(jnp.float32)  # (*B, 1)
+    d_l = 1.0 - d_u  # (*B, 1)
+
+    probs = (d_l * jax.nn.one_hot(l.squeeze(-1), num_atoms)
+             + d_u * jax.nn.one_hot(u.squeeze(-1), num_atoms))
+    return probs, metrics
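+
+# Worked example (hypothetical numbers): with v_min=-10, v_max=10 and
+# num_atoms=5 the support is [-10, -5, 0, 5, 10] and delta_z = 5. An
+# original-space target of 23.53 gives symlog(23.53) ≈ 3.2, so
+#   b = (3.2 + 10) / 5 = 2.64, l = 2, u = 3, d_u = 0.64, d_l = 0.36,
+# i.e. probs = [0, 0, 0.36, 0.64, 0], whose expectation 0.36*0 + 0.64*5 = 3.2
+# recovers the clipped symlog-space target exactly. When b is an integer,
+# l == u with d_u = 0, so all mass lands on one atom and no special case is
+# needed.
+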
+@partial(jax.jit, static_argnames=("training", "num_samples", "solver", "temperature", "v_min", "v_max", "num_atoms"))
+def jit_sample_and_select(
+    rng: PRNGKey,
+    model: DDPM,
+    q0: Model,
+    obs: jnp.ndarray,
+    training: bool,
+    num_samples: int,
+    solver: str,
+    temperature: Optional[float],
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+) -> Tuple[PRNGKey, jnp.ndarray]:
+    B = obs.shape[0]
+    rng, xT_rng = jax.random.split(rng)
+
+    # sample
+    obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)
+    xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], model.x_dim))
+    rng, actions, _ = model.sample(rng, xT, obs_repeat, training, solver)
+    if temperature is None:
+        return rng, actions[:, 0]
+    else:
+        qs_logits = q0(
+            obs_repeat,
+            actions,
+        )  # (E, B, num_samples, num_atoms)
+        atoms = get_atoms(v_min, v_max, num_atoms)
+        probs = jax.nn.softmax(qs_logits, axis=-1).mean(axis=0)  # (B, num_samples, num_atoms)
+        qs = symexp((probs * atoms).sum(axis=-1))  # (B, num_samples)
+        if temperature <= 0.0:
+            idx = qs.argmax(axis=-1)
+        else:
+            rng, select_rng = jax.random.split(rng)
+            idx = jax.random.categorical(select_rng, logits=qs / (EPS + temperature), axis=-1)
+        actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), idx]
+        return rng, actions
+
+@partial(jax.jit, static_argnames=("ema", "do_ema_update"))
+def jit_update_behavior(
+    rng: PRNGKey,
+    behavior: DDPM,
+    behavior_target: DDPM,
+    batch: Batch,
+    ema: float,
+    do_ema_update: bool,
+) -> Tuple[PRNGKey, Model, Model, Metric]:
+    rng, new_behavior, metrics = jit_update_ddpm(rng, behavior, batch.action, batch.obs)
+    if do_ema_update:
+        new_behavior_target = ema_update(new_behavior, behavior_target, ema)
+    else:
+        new_behavior_target = behavior_target
+    return rng, new_behavior, new_behavior_target, metrics
+
+def update_critic(
+    rng: PRNGKey,
+    q0: Model,
+    q0_target: Model,
+    vt: Model,
+    vt_target: Model,
+    actor_target: DDPM,
+    behavior_target: DDPM,
+    batch: Batch,
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+    T: int,
+    discount: float,
+    eta: float,
+    num_q_samples: int,
+    solver: str,
+    ema: float,
+    do_ema_update: bool,
+    symexp_before_mean: bool,
+) -> Tuple[PRNGKey, Model, Model, Model, Model, Metric]:
+    B = batch.obs.shape[0]
+    A = batch.action.shape[-1]
+    alphas = actor_target.alphas
+    alpha_hats = actor_target.alpha_hats
+    betas = actor_target.betas
+
+    rng, xT_rng = jax.random.split(rng)
+
+    # q0 target
+    next_obs_repeat = batch.next_obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2)
+    xT = jax.random.normal(xT_rng, (*next_obs_repeat.shape[:-1], A))
+    rng, next_action, _ = actor_target.sample(
+        rng,
+        xT,
+        next_obs_repeat,
+        training=False,
+        solver=solver,
+    )
+    q0_target_value = q0_target(
+        next_obs_repeat,
+        next_action,
+    )  # (E, B, N, num_atoms)
+    q0_target_value = logits_to_value(q0_target_value, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+    q0_target_value = batch.reward + discount * (1 - batch.terminal) * q0_target_value  # (B, 1)
+    q0_target_probs, q0_proj_metrics = categorical_project(q0_target_value, v_min, v_max, num_atoms)  # (B, num_atoms)
+
+    # vt target
+    obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2)
+    rng, rep_xt_1, xt, rep_t_1, t, history = actor_target.onestep_sample(
+        rng,
+        batch.action,
+        batch.obs,
+        training=False,
+        num_samples=num_q_samples,
+        solver=solver,
+        sample_xt=True,
+        t=None,
+    )
+    vt_target_value1 = vt_target(
+        obs_repeat,
+        rep_xt_1,
+        rep_t_1,
+    )  # (E, B, N, num_atoms)
+    vt_target_value1 = logits_to_value(vt_target_value1, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+    vt_target_value2 = q0_target(
+        obs_repeat,
+        rep_xt_1,
+    )  # (E, B, N, num_atoms)
+    vt_target_value2 = logits_to_value(vt_target_value2, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+    vt_target_value = (t != 1) * vt_target_value1 + (t == 1) * vt_target_value2
+
+    vt_actor_eps = history
+    vt_behavior_eps = behavior_target(xt, t, batch.obs)
+    vt_penalty = get_penalty(vt_actor_eps, vt_behavior_eps, t, T, alphas, alpha_hats, betas)
+
+    vt_target_value = vt_target_value - eta * vt_penalty.sum(axis=-1, keepdims=True)  # (B, 1)
+    vt_target_probs, vt_proj_metrics = categorical_project(vt_target_value, v_min, v_max, num_atoms)  # (B, num_atoms)
+
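+    # Both losses below are categorical cross-entropies against the projected
+    # target distribution:
+    #     L = -E_batch[ sum_i p_i^target * log softmax(logits)_i ]
+    # which equals KL(p^target || p_pred) plus the target entropy (constant in
+    # the parameters), i.e. the usual C51-style objective.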
+    def q0_loss_fn(q0_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]:
+        pred = q0.apply(
+            {"params": q0_params},
+            batch.obs,
+            batch.action,
+        )  # (E, B, num_atoms)
+        log_probs = jax.nn.log_softmax(pred, axis=-1)
+        loss = -(q0_target_probs[jnp.newaxis, :] * log_probs).sum(axis=-1).mean()
+        E, B, _ = pred.shape
+        q0_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+        return loss, {
+            "loss/q0_loss": loss,
+            "misc/q0_mean": q0_value.mean(),
+            "misc/reward": batch.reward.mean(),
+            "misc/q0_target_symlog_max": q0_proj_metrics["pre_clip_max"],
+            "misc/q0_target_symlog_min": q0_proj_metrics["pre_clip_min"],
+            "misc/q0_target_symlog_mean": q0_proj_metrics["pre_clip_mean"],
+            "misc/q0_target_clip_ratio": q0_proj_metrics["clip_ratio"],
+        }
+
+    def vt_loss_fn(vt_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]:
+        pred = vt.apply(
+            {"params": vt_params},
+            batch.obs,
+            xt,
+            t,
+        )  # (E, B, num_atoms)
+        # sum over atoms (cross-entropy) → mean over batch and ensemble members → scalar
+        log_probs = jax.nn.log_softmax(pred, axis=-1)
+        loss = -(vt_target_probs[jnp.newaxis, :] * log_probs).sum(axis=-1).mean()
+        E, B, _ = pred.shape
+        vt_value = logits_to_value(pred.reshape(E, B, 1, num_atoms), v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+        return loss, {
+            "loss/vt_loss": loss,
+            "misc/vt_mean": vt_value.mean(),
+            "misc/vt_logits_std": jax.nn.softmax(pred, axis=-1).std(axis=0).mean(),
+            "misc/vt_penalty": vt_penalty.mean(),
+            "misc/vt_target_symlog_max": vt_proj_metrics["pre_clip_max"],
+            "misc/vt_target_symlog_min": vt_proj_metrics["pre_clip_min"],
+            "misc/vt_target_symlog_mean": vt_proj_metrics["pre_clip_mean"],
+            "misc/vt_target_clip_ratio": vt_proj_metrics["clip_ratio"],
+        }
+
+    new_q0, q0_metrics = q0.apply_gradient(q0_loss_fn)
+    new_vt, vt_metrics = vt.apply_gradient(vt_loss_fn)
+    if do_ema_update:
+        new_q0_target = ema_update(new_q0, q0_target, ema)
+        new_vt_target = ema_update(new_vt, vt_target, ema)
+    else:
+        new_q0_target = q0_target
+        new_vt_target = vt_target
+    return rng, new_q0, new_q0_target, new_vt, new_vt_target, {
+        **q0_metrics,
+        **vt_metrics,
+    }
+
+def update_actor(
+    rng: PRNGKey,
+    actor: DDPM,
+    actor_target: DDPM,
+    behavior_target: DDPM,
+    q0_target: Model,
+    vt_target: Model,
+    batch: Batch,
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+    T: int,
+    eta: float,
+    num_q_samples: int,
+    solver: str,
+    ema: float,
+    do_ema_update: bool,
+    symexp_before_mean: bool,
+) -> Tuple[PRNGKey, Model, Model, Metric]:
+    alphas = actor.alphas
+    alpha_hats = actor.alpha_hats
+    betas = actor.betas
+
+    rng, sample_rng = jax.random.split(rng)
+
+    def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]:
+        sample_rng_, rep_xt_1, xt, rep_t_1, t, history = actor.onestep_sample(
+            sample_rng,
+            batch.action,
+            batch.obs,
+            training=True,
+            num_samples=num_q_samples,
+            solver=solver,
+            sample_xt=True,
+            t=None,
+            params=actor_params,
+        )
+        obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_q_samples, axis=-2)
+        target_1 = vt_target(
+            obs_repeat,
+            rep_xt_1,
+            rep_t_1,
+        )  # (E, B, N, num_atoms)
+        target_1 = logits_to_value(target_1, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+        target_2 = q0_target(
+            obs_repeat,
+            rep_xt_1,
+        )  # (E, B, N, num_atoms)
+        target_2 = logits_to_value(target_2, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+        target = (t != 1) * target_1 + (t == 1) * target_2
+
+        actor_eps = history
+        behavior_eps = behavior_target(xt, t, batch.obs)
+        penalty = get_penalty(actor_eps, behavior_eps, t, T, alphas, alpha_hats, betas)
+        target = target - eta * penalty.sum(axis=-1, keepdims=True)
+        actor_loss = -target.mean()
+        return actor_loss, {
+            "loss/actor_loss": actor_loss,
+        }
+
+    new_actor, metrics = actor.apply_gradient(actor_loss_fn)
+    if do_ema_update:
+        new_actor_target = ema_update(new_actor, actor_target, ema)
+    else:
+        new_actor_target = actor_target
+    metrics["misc/eta"] = eta
+    return rng, new_actor, new_actor_target, metrics
+
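+# NOTE: every scalar hyperparameter below is marked static, so jit caches one
+# compiled executable per distinct (eta, v_min, v_max, ...) tuple. That is
+# fine when they are fixed for a whole run, but sweeping e.g. eta inside a
+# single process would recompile once per value; passing such scalars as
+# traced arguments instead would avoid the recompilation.
+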
+@partial(jax.jit, static_argnames=(
+    "do_actor_update",
+    "v_min",
+    "v_max",
+    "num_atoms",
+    "diffusion_steps",
+    "discount",
+    "eta",
+    "num_q_samples",
+    "solver",
+    "critic_ema",
+    "do_critic_ema_update",
+    "actor_ema",
+    "do_actor_ema_update",
+    "symexp_before_mean",
+))
+def jit_train_step(
+    rng: PRNGKey,
+    # models
+    q0: Model,
+    q0_target: Model,
+    vt: Model,
+    vt_target: Model,
+    actor: DDPM,
+    actor_target: DDPM,
+    behavior_target: DDPM,
+    batch: Batch,
+    do_actor_update: bool,
+    # categorical parameters
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+    # other hyperparameters
+    diffusion_steps: int,
+    discount: float,
+    eta: float,
+    num_q_samples: int,
+    solver: str,
+    symexp_before_mean: bool,
+    # ema update
+    critic_ema: float,
+    do_critic_ema_update: bool,
+    actor_ema: float,
+    do_actor_ema_update: bool,
+):
+    metrics = {}
+    rng, q0, q0_target, vt, vt_target, critic_metrics = update_critic(
+        rng=rng,
+        q0=q0,
+        q0_target=q0_target,
+        vt=vt,
+        vt_target=vt_target,
+        actor_target=actor_target,
+        behavior_target=behavior_target,
+        batch=batch,
+        v_min=v_min,
+        v_max=v_max,
+        num_atoms=num_atoms,
+        T=diffusion_steps,
+        discount=discount,
+        eta=eta,
+        num_q_samples=num_q_samples,
+        solver=solver,
+        ema=critic_ema,
+        do_ema_update=do_critic_ema_update,
+        symexp_before_mean=symexp_before_mean,
+    )
+    metrics.update(critic_metrics)
+    if do_actor_update:
+        rng, actor, actor_target, actor_metrics = update_actor(
+            rng=rng,
+            actor=actor,
+            actor_target=actor_target,
+            behavior_target=behavior_target,
+            q0_target=q0_target,
+            vt_target=vt_target,
+            batch=batch,
+            v_min=v_min,
+            v_max=v_max,
+            num_atoms=num_atoms,
+            T=diffusion_steps,
+            eta=eta,
+            num_q_samples=num_q_samples,
+            solver=solver,
+            ema=actor_ema,
+            do_ema_update=do_actor_ema_update,
+            symexp_before_mean=symexp_before_mean,
+        )
+        metrics.update(actor_metrics)
+    return rng, q0, q0_target, vt, vt_target, actor, actor_target, metrics
+
+
+@partial(jax.jit, static_argnames=("T", "num_samples", "v_min", "v_max", "num_atoms", "symexp_before_mean"))
+def jit_compute_statistics(
+    rng: PRNGKey,
+    vt_target: Model,
+    q0_target: Model,
+    actor: DDPM,
+    actor_target: DDPM,
+    behavior_target: DDPM,
+    batch: Batch,
+    T: int,
+    num_samples: int,
+    v_min: float,
+    v_max: float,
+    num_atoms: int,
+    symexp_before_mean: bool,
+) -> Tuple[PRNGKey, Metric]:
+    """Compute per-timestep statistics for debugging and hyperparameter tuning."""
+    alpha_hats = actor_target.alpha_hats
+
+    B = batch.obs.shape[0]
+    t_proto = jnp.ones((*batch.obs.shape[:-1], num_samples, 1), dtype=jnp.int32)
+
+    obs = batch.obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, obs_dim)
+    x0 = batch.action[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, act_dim)
+
+    tot_metrics = {}
+    for i in range(0, T + 1):
+        rng, eps_key = jax.random.split(rng)
+        eps_sample = jax.random.normal(eps_key, x0.shape, dtype=jnp.float32)
+        t = t_proto * i
+
+        xt = jnp.sqrt(alpha_hats[i]) * x0 + jnp.sqrt(1 - alpha_hats[i]) * eps_sample
+
+        # vt critic values (categorical: logits -> value)
+        vt_logits = vt_target(obs, xt, t)  # (E, B, N, num_atoms)
+        vt_val = logits_to_value(vt_logits, v_min, v_max, num_atoms, symexp_before_mean)  # (B, 1)
+
+        # per-ensemble-member values for std computation
+        atoms = get_atoms(v_min, v_max, num_atoms)
+        probs = jax.nn.softmax(vt_logits, axis=-1)  # (E, B, N, num_atoms)
+        vt_per_member = symexp((probs * atoms).sum(axis=-1, keepdims=True))  # (E, B, N, 1)
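+        # `vt_std_e_{i}` below measures disagreement across the E ensemble
+        # members (std over axis 0); `vt_std_a_{i}` measures the spread of the
+        # ensemble-mean value across the N noised action samples (std over
+        # axis -2). Both are computed in the original (post-symexp) space.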
+
+        # noise predictions
+        eps_actor = actor(xt, t, condition=obs)
+        eps_actor_target = actor_target(xt, t, condition=obs)
+        eps_behavior_target = behavior_target(xt, t, condition=obs)
+        bc_loss = 0.5 * ((eps_actor - eps_behavior_target) ** 2).sum(axis=-1).mean()
+        bc_target_loss = 0.5 * ((eps_actor_target - eps_behavior_target) ** 2).sum(axis=-1).mean()
+
+        # vt gradient w.r.t. action
+        def vt_fn_single(a, s, t_):
+            logits = vt_target(s, a, t_)  # (E, num_atoms)
+            logits = logits[:, jnp.newaxis, jnp.newaxis, :]  # (E, 1, 1, num_atoms)
+            return logits_to_value(logits, v_min, v_max, num_atoms, symexp_before_mean).squeeze()
+        vt_grad = jax.vmap(jax.grad(vt_fn_single))(
+            xt.reshape(B * num_samples, -1),
+            obs.reshape(B * num_samples, -1),
+            t.reshape(B * num_samples, -1),
+        )
+
+        # q0 gradient w.r.t. action
+        def q0_fn_single(a, s):
+            logits = q0_target(s, a)  # (E, num_atoms)
+            logits = logits[:, jnp.newaxis, jnp.newaxis, :]  # (E, 1, 1, num_atoms)
+            return logits_to_value(logits, v_min, v_max, num_atoms, symexp_before_mean).squeeze()
+        q0_grad = jax.vmap(jax.grad(q0_fn_single))(
+            xt.reshape(B * num_samples, -1),
+            obs.reshape(B * num_samples, -1),
+        )
+
+        tot_metrics.update({
+            f"stats/vt_std_e_{i}": vt_per_member.std(axis=0).mean(),
+            f"stats/vt_std_a_{i}": vt_per_member.mean(axis=0).std(axis=-2).mean(),
+            f"stats/bc_loss_{i}": bc_loss,
+            f"stats/bc_target_loss_{i}": bc_target_loss,
+            f"stats/vt_grad_{i}": jnp.abs(vt_grad).mean(),
+            f"stats/q0_grad_{i}": jnp.abs(q0_grad).mean(),
+            f"stats/vt_mean_{i}": vt_val.mean(),
+        })
+
+    return rng, tot_metrics
+
+
+class CategoricalBDPOAgent(BaseAgent):
+    """
+    Categorical version of Behavior-Regularized Diffusion Policy Optimization (BDPO)
+    https://arxiv.org/abs/2502.04778
+    """
+    name = "CategoricalBDPOAgent"
+    model_names = ["behavior", "behavior_target", "actor", "actor_target", "q0", "q0_target", "vt", "vt_target"]
+
+    def __init__(self, obs_dim: int, act_dim: int, cfg: CategoricalBDPOConfig, seed: int):
+        super().__init__(obs_dim, act_dim, cfg, seed)
+        self.cfg = cfg
+        self.rng, behavior_rng, actor_rng, q0_rng, vt_rng = jax.random.split(self.rng, 5)
+
+        # define behavior and actor
+        time_embedding = PositionalEmbedding(output_dim=cfg.diffusion.time_dim)
+        if cfg.diffusion.resnet:
+            noise_predictor = ResidualMLP(
+                hidden_dims=cfg.diffusion.resnet_hidden_dims,
+                output_dim=act_dim,
+                activation=mish,
+                layer_norm=True,
+                dropout=cfg.diffusion.dropout,
+                multiplier=1,
+            )
+        else:
+            noise_predictor = MLP(
+                hidden_dims=cfg.diffusion.mlp_hidden_dims,
+                output_dim=act_dim,
+                activation=mish,
+                layer_norm=cfg.diffusion.layer_norm,
+                dropout=cfg.diffusion.dropout,
+            )
+        backbone_def = DDPMBackbone(
+            noise_predictor=noise_predictor,
+            time_embedding=time_embedding,
+        )
+
+        def create_ddpm(network_def: nn.Module, rng: PRNGKey, cfg: BDPODiffusionConfig, train_cfg: BDPODiffusionTrainConfig):
+            if train_cfg.lr_decay_steps is not None:
+                lr = optax.cosine_decay_schedule(train_cfg.lr, train_cfg.lr_decay_steps)
+            else:
+                lr = train_cfg.lr
+            ddpm = DDPM.create(
+                network=network_def,
+                rng=rng,
+                inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim))),
+                x_dim=self.act_dim,
+                steps=cfg.steps,
+                noise_schedule=cfg.noise_schedule,
+                noise_schedule_params=None,
+                approx_postvar=True,
+                clip_sampler=cfg.clip_sampler,
+                x_min=cfg.x_min,
+                x_max=cfg.x_max,
+                optimizer=optax.adam(learning_rate=lr),
+                clip_grad_norm=train_cfg.clip_grad_norm,
+            )
+            ddpm_target = DDPM.create(
+                network=network_def,
+                rng=rng,
+                x_dim=self.act_dim,
+                inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim))),
+                steps=cfg.steps,
+                noise_schedule=cfg.noise_schedule,
+                noise_schedule_params=None,
+                approx_postvar=True,
+                clip_sampler=cfg.clip_sampler,
+                x_min=cfg.x_min,
+                x_max=cfg.x_max,
+            )
+            return ddpm, ddpm_target
+
+        self.behavior, self.behavior_target = create_ddpm(backbone_def, behavior_rng, cfg.diffusion, cfg.diffusion.behavior)
+        self.actor, self.actor_target = create_ddpm(backbone_def, actor_rng, cfg.diffusion, cfg.diffusion.actor)
+
+        # define critic networks
+        if cfg.critic.lr_decay_steps is not None:
+            q0_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps)
+            vt_lr = optax.cosine_decay_schedule(cfg.critic.lr, cfg.critic.lr_decay_steps)
+        else:
+            q0_lr = cfg.critic.lr
+            vt_lr = cfg.critic.lr
+        q0_def = Ensemblize(
+            base=CategoricalCritic(
+                backbone=MLP(
+                    hidden_dims=cfg.critic.hidden_dims,
+                    activation=mish,
+                    layer_norm=cfg.critic.layer_norm,
+                ),
+                output_nodes=cfg.critic.output_nodes,
+            ),
+            ensemble_size=cfg.critic.ensemble_size,
+        )
+        vt_def = Ensemblize(
+            base=CategoricalCriticWithDiscreteTime(
+                backbone=MLP(
+                    hidden_dims=cfg.critic.hidden_dims,
+                    activation=mish,
+                    layer_norm=True,
+                ),
+                time_embedding=time_embedding,
+                output_nodes=cfg.critic.output_nodes,
+            ),
+            ensemble_size=cfg.critic.ensemble_size,
+        )
+        self.q0 = Model.create(
+            q0_def,
+            q0_rng,
+            inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))),
+            optimizer=optax.adam(learning_rate=q0_lr),
+            clip_grad_norm=cfg.critic.clip_grad_norm,
+        )
+        self.q0_target = Model.create(
+            q0_def,
+            q0_rng,
+            inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))),
+        )
+        self.vt = Model.create(
+            vt_def,
+            vt_rng,
+            inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.zeros((1, 1))),
+            optimizer=optax.adam(learning_rate=vt_lr),
+            clip_grad_norm=cfg.critic.clip_grad_norm,
+        )
+        self.vt_target = Model.create(
+            vt_def,
+            vt_rng,
+            inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.zeros((1, 1))),
+        )
+
+        self.warmup_steps = cfg.warmup_steps
+        self._is_pretraining = True  # switches to False once prepare_training() is called
+        self._n_pretraining_steps = 0
+        self._n_training_steps = 0
+
+    @property
+    def saved_model_names(self) -> List[str]:
+        if self._is_pretraining:
+            return ["behavior_target"]
+        else:
+            return self.model_names
+
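+    # `prepare_training` warm-starts RL fine-tuning: assuming the usual
+    # convention that an ema rate of 1.0 copies the source parameters
+    # outright, both the actor and its target start from the pretrained
+    # behavior policy before any critic-guided update is applied.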
+    def prepare_training(self):
+        self.actor = ema_update(self.behavior_target, self.actor, 1.0)
+        self.actor_target = ema_update(self.behavior_target, self.actor_target, 1.0)
+        self._is_pretraining = False
+
+    def train_step(self, batch: Batch, step: int) -> Metric:
+        do_actor_update = (self._n_training_steps >= self.warmup_steps) \
+            and (self._n_training_steps % self.cfg.critic.update_ratio == 0)
+        do_critic_ema_update = self._n_training_steps % self.cfg.critic.ema_every == 0
+        if do_actor_update:
+            actor_step = self._n_training_steps // self.cfg.critic.update_ratio
+            do_actor_ema_update = actor_step % self.cfg.diffusion.actor.ema_every == 0
+        else:
+            do_actor_ema_update = False
+        self.rng, self.q0, self.q0_target, self.vt, self.vt_target, self.actor, self.actor_target, metrics = jit_train_step(
+            rng=self.rng,
+            q0=self.q0,
+            q0_target=self.q0_target,
+            vt=self.vt,
+            vt_target=self.vt_target,
+            actor=self.actor,
+            actor_target=self.actor_target,
+            behavior_target=self.behavior_target,
+            batch=batch,
+            do_actor_update=do_actor_update,
+            v_min=self.cfg.critic.v_min,
+            v_max=self.cfg.critic.v_max,
+            num_atoms=self.cfg.critic.output_nodes,
+            diffusion_steps=self.cfg.diffusion.steps,
+            discount=self.cfg.critic.discount,
+            eta=self.cfg.critic.eta,
+            num_q_samples=self.cfg.critic.num_samples,
+            solver=self.cfg.critic.solver,
+            symexp_before_mean=self.cfg.critic.symexp_before_mean,
+            critic_ema=self.cfg.critic.ema,
+            do_critic_ema_update=do_critic_ema_update,
+            actor_ema=self.cfg.diffusion.actor.ema,
+            do_actor_ema_update=do_actor_ema_update,
+        )
+        self._n_training_steps += 1
+        return metrics
+
+    def pretrain_step(self, batch: Batch, step: int) -> Metric:
+        self.rng, self.behavior, self.behavior_target, metrics = jit_update_behavior(
+            self.rng,
+            self.behavior,
+            self.behavior_target,
+            batch,
+            self.cfg.diffusion.behavior.ema,
+            self._n_pretraining_steps % self.cfg.diffusion.behavior.ema_every == 0,
+        )
+        self._n_pretraining_steps += 1
+        return metrics
+
+    def compute_statistics(self, batch: Batch) -> Metric:
+        self.rng, stats = jit_compute_statistics(
+            self.rng,
+            self.vt_target,
+            self.q0_target,
+            self.actor,
+            self.actor_target,
+            self.behavior_target,
+            batch,
+            self.cfg.diffusion.steps,
+            self.cfg.critic.num_samples,
+            self.cfg.critic.v_min,
+            self.cfg.critic.v_max,
+            self.cfg.critic.output_nodes,
+            self.cfg.critic.symexp_before_mean,
+        )
+        return stats
+
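+    # Action-selection semantics (see jit_sample_and_select): temperature=None
+    # returns one sample uniformly (the first of N i.i.d. draws),
+    # temperature=0.0 picks the argmax-Q sample greedily, and temperature>0
+    # softmax-samples among the N candidates with logits Q / temperature.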
+    def sample_actions(
+        self,
+        obs: jnp.ndarray,
+        deterministic: bool = True,
+        num_samples: int = 1,
+    ) -> Tuple[jnp.ndarray, Metric]:
+        if self._is_pretraining:
+            use_model = self.behavior_target
+            temperature = None
+            num_samples = 1
+        else:
+            use_model = self.actor_target
+            temperature = self.cfg.temperature
+        self.rng, action = jit_sample_and_select(
+            self.rng,
+            use_model,
+            self.q0_target,
+            obs,
+            training=False,
+            num_samples=num_samples,
+            solver=self.cfg.diffusion.solver,
+            temperature=temperature,
+            v_min=self.cfg.critic.v_min,
+            v_max=self.cfg.critic.v_max,
+            num_atoms=self.cfg.critic.output_nodes,
+        )
+        return action, {}
diff --git a/flowrl/config/offline/__init__.py b/flowrl/config/offline/__init__.py
index 9d6a01a..e1d495f 100644
--- a/flowrl/config/offline/__init__.py
+++ b/flowrl/config/offline/__init__.py
@@ -2,6 +2,7 @@
 from .algo.base import BaseAlgoConfig
 from .algo.bdpo import BDPOConfig
+from .algo.categorical_bdpo import CategoricalBDPOConfig
 from .algo.dac import DACConfig
 from .algo.dql import DQLConfig
 from .algo.dtql import DTQLConfig
@@ -21,6 +22,7 @@ _CONFIGS = {
     "iql": IQLConfig,
     "bdpo": BDPOConfig,
+    "categorical_bdpo": CategoricalBDPOConfig,
     "ivr": IVRConfig,
     "dac": DACConfig,
     "dql": DQLConfig,
diff --git a/flowrl/config/offline/algo/categorical_bdpo.py b/flowrl/config/offline/algo/categorical_bdpo.py
new file mode 100644
index 0000000..235103c
--- /dev/null
+++ b/flowrl/config/offline/algo/categorical_bdpo.py
@@ -0,0 +1,62 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+from .base import BaseAlgoConfig
+
+
+@dataclass
+class BDPODiffusionTrainConfig():
+    lr: float
+    lr_decay_steps: Optional[int]  # null in the yaml disables the cosine decay schedule
+    clip_grad_norm: float
+    ema: float
+    ema_every: int
+
+@dataclass
+class BDPODiffusionConfig():
+    actor: BDPODiffusionTrainConfig
+    behavior: BDPODiffusionTrainConfig
+    noise_schedule: str
+    resnet: bool
+    dropout: Optional[float]
+    layer_norm: bool
+    time_dim: int
+    mlp_hidden_dims: List[int]
+    resnet_hidden_dims: List[int]
+    solver: str
+    steps: int
+    clip_sampler: bool
+    x_min: float
+    x_max: float
+
+@dataclass
+class BDPOCriticConfig():
+    discount: float
+    ensemble_size: int
+    eta: float
+    hidden_dims: List[int]
+    output_nodes: int
+    v_min: float
+    v_max: float
+    symexp_before_mean: bool  # if True: symexp per sample then mean over N; if False: mean over N then symexp (Jensen bias but more stable)
+
+    lr: float
+    lr_decay_steps: Optional[int]
+    clip_grad_norm: float
+    layer_norm: bool
+    ema: float
+    ema_every: int
+    steps: int
+    num_samples: int
+    solver: str
+    update_ratio: int
+
+@dataclass
+class CategoricalBDPOConfig(BaseAlgoConfig):
+    name: str
+    warmup_steps: int
+    temperature: Optional[float]  # None is uniform, 0.0 is greedy
+    diffusion: BDPODiffusionConfig
+    critic: BDPOCriticConfig
diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py
index 913ffe2..3b0f7a5 100644
--- a/flowrl/module/critic.py
+++ b/flowrl/module/critic.py
@@ -268,3 +268,35 @@ def __call__(
             self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init()
         )(x)
         return x
+
+class CategoricalCriticWithDiscreteTime(nn.Module):
+    backbone: nn.Module
+    time_embedding: nn.Module
+    output_nodes: int
+    kernel_init: Initializer = init.default_kernel_init
+    bias_init: Initializer = init.default_bias_init
+
+    @nn.compact
+    def __call__(
+        self,
+        obs: jnp.ndarray,
+        action: Optional[jnp.ndarray] = None,
+        t: Optional[jnp.ndarray] = None,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        t_ff = self.time_embedding(t)
+        t_ff = MLP(
+            hidden_dims=[t_ff.shape[-1] * 2, t_ff.shape[-1]],
+            activation=mish,
+            kernel_init=self.kernel_init,
+            bias_init=self.bias_init,
+        )(t_ff)
+        x = jnp.concatenate(
+            [item for item in (obs, action, t_ff) if item is not None],
+            axis=-1,
+        )
+        x = self.backbone(x, training=training)
+        x = nn.Dense(
+            self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init()
+        )(x)
+        return x
\ No newline at end of file
diff --git a/scripts/d4rl/categorical_bdpo.sh b/scripts/d4rl/categorical_bdpo.sh
new file mode 100644
index 0000000..5c4105c
--- /dev/null
+++ b/scripts/d4rl/categorical_bdpo.sh
@@ -0,0 +1,95 @@
+export XLA_FLAGS='--xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0'
+# Specify which GPUs to use
+GPUS=(0 1 2 3 4 5 6 7)  # Modify this array to specify which GPUs to use
+SEEDS=(0 1 2 3 4)
+NUM_EACH_GPU=2
+
+PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
+
+TASKS=(
+    # locomotion tasks
+    "hopper-medium-v2"
+    "hopper-medium-replay-v2"
+    "hopper-medium-expert-v2"
+    "walker2d-medium-v2"
+    "walker2d-medium-replay-v2"
+    "walker2d-medium-expert-v2"
+    "halfcheetah-medium-v2"
+    "halfcheetah-medium-replay-v2"
+    "halfcheetah-medium-expert-v2"
+    # antmaze tasks
+    "antmaze-umaze-v0"
+    "antmaze-umaze-diverse-v0"
+    "antmaze-medium-play-v0"
+    "antmaze-medium-diverse-v0"
+    "antmaze-large-play-v0"
+    "antmaze-large-diverse-v0"
+)
+
+
+SHARED_ARGS=(
+    "algo=categorical_bdpo"
+    "log.tag=default"
+    "log.project=flow-rl"
+    "log.entity=lamda-rl"
+)
+
+ANTMAZE_ARGS=(
+    "algo.critic.layer_norm=true"
+    "algo.critic.discount=0.995"
+    "data.norm_reward=antmaze100"
+    "eval.num_episodes=100"
+    "algo.diffusion.mlp_hidden_dims=[512,512,512,512]"
+    "algo.diffusion.behavior.lr=1e-4"
+)
+
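+# Per-task overrides. Note the assumption being made here: ${ANTMAZE_ARGS[@]}
+# is expanded inside a double-quoted assignment, which flattens the array into
+# one space-joined string; this is safe only because no element contains
+# spaces.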
+declare -A TASK_ARGS
+TASK_ARGS=(
+    ["halfcheetah-medium-v2"]="algo.critic.eta=0.05"
+    ["halfcheetah-medium-replay-v2"]="algo.critic.eta=0.05"
+    ["halfcheetah-medium-expert-v2"]="algo.critic.eta=0.05"
+    ["hopper-medium-v2"]="algo.critic.eta=0.2 algo.critic.ensemble_size=20"
+    ["hopper-medium-replay-v2"]="algo.critic.eta=0.2"
+    ["hopper-medium-expert-v2"]="algo.critic.eta=0.2"
+    ["walker2d-medium-v2"]="algo.critic.eta=0.15"
+    ["walker2d-medium-replay-v2"]="algo.critic.eta=0.15"
+    ["walker2d-medium-expert-v2"]="algo.critic.eta=0.15"
+    # antmaze
+    ["antmaze-umaze-v0"]="algo.critic.eta=0.5 ${ANTMAZE_ARGS[@]}"
+    ["antmaze-umaze-diverse-v0"]="algo.temperature=null algo.critic.eta=0.5 ${ANTMAZE_ARGS[@]}"
+    ["antmaze-medium-play-v0"]="algo.critic.eta=0.2 ${ANTMAZE_ARGS[@]}"
+    ["antmaze-medium-diverse-v0"]="algo.critic.eta=0.2 ${ANTMAZE_ARGS[@]}"
+    ["antmaze-large-play-v0"]="algo.critic.eta=1.0 ${ANTMAZE_ARGS[@]}"
+    ["antmaze-large-diverse-v0"]="algo.critic.eta=1.0 ${ANTMAZE_ARGS[@]}"
+)
+
+
+# The first argument names the experiment (used for the parallel log
+# directory); any remaining arguments are appended to every command.
+# Exit if no name is given.
+name=$1
+if [ -z "$name" ]; then
+    echo "Usage: $0 <experiment-name> [extra hydra overrides...]"
+    exit 1
+fi
+shift
+EXTRA_ARGS=("$@")
+
+
+run_task() {
+    task=$1
+    seed=$2
+    slot=$3
+    # Calculate device index based on available GPUs
+    num_gpus=${#GPUS[@]}
+    device_idx=$((slot % num_gpus))
+    device=${GPUS[$device_idx]}
+    echo "Running $task $seed on GPU $device"
+    command="python3 examples/offline/main_d4rl.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]} ${TASK_ARGS[$task]} ${EXTRA_ARGS[@]}"
+    if [ -n "$DRY_RUN" ]; then
+        echo "$command"
+    else
+        echo "$command"
+        $command
+    fi
+}
+
+. env_parallel.bash
+if [ -n "$DRY_RUN" ]; then
+    env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: "${TASKS[@]}" ::: "${SEEDS[@]}"
+else
+    env_parallel --bar --results logs/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: "${TASKS[@]}" ::: "${SEEDS[@]}"
+fi
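+
+# Example invocations (assumed usage; DRY_RUN only prints the commands):
+#   DRY_RUN=1 bash scripts/d4rl/categorical_bdpo.sh smoke-test
+#   bash scripts/d4rl/categorical_bdpo.sh sweep-eta algo.critic.eta=0.1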