61 changes: 61 additions & 0 deletions examples/offline/config/d4rl/algo/categorical_bdpo.yaml
@@ -0,0 +1,61 @@
# @package _global_

data:
  norm_obs: false
  norm_reward: iql_mujoco

eval:
  num_samples: 10
  interval: 50000

train_steps: 2_000_000
pretrain_steps: 2_000_000

algo:
  name: categorical_bdpo
  warmup_steps: 50_000
  temperature: 0.0
  diffusion:
    actor:
      lr: 0.00001
      lr_decay_steps: ${train_steps}
      clip_grad_norm: 1.0
      ema: 0.005
      ema_every: 1
    behavior:
      lr: 0.0003
      lr_decay_steps: null
      clip_grad_norm: 1.0
      ema: 0.005
      ema_every: 1
    noise_schedule: vp
    resnet: false
    dropout: 0.0
    layer_norm: false
    time_dim: 64
    mlp_hidden_dims: [256, 256, 256, 256]
    resnet_hidden_dims: [256, 256, 256]
    solver: ddpm
    steps: 5
    clip_sampler: true
    x_min: -1.0
    x_max: 1.0
  critic:
    discount: 0.99
    ensemble_size: 10
    eta: 1.0
    hidden_dims: [256, 256, 256]
    output_nodes: 101
    v_min: -10.0
    v_max: 10.0
    symexp_before_mean: false
    lr: 0.0003
    lr_decay_steps: ${train_steps}
    clip_grad_norm: 1.0
    layer_norm: false
    ema: 0.005
    ema_every: 1
    steps: 5
    num_samples: 10
    solver: ddpm
    update_ratio: 5
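The critic block's output_nodes, v_min, and v_max read like a C51-style categorical value head, which is presumably what distinguishes this variant from the base bdpo config. A minimal sketch of the implied support, under the assumption of evenly spaced atoms (the symexp_before_mean flag, false here, hints at an optional symexp-warped support; even spacing is an assumption, not read from this PR's code):

    import jax.numpy as jnp

    # Assumed C51-style support: output_nodes evenly spaced atoms on [v_min, v_max].
    v_min, v_max, output_nodes = -10.0, 10.0, 101
    atoms = jnp.linspace(v_min, v_max, output_nodes)  # shape (101,), spacing 0.2
    # A scalar Q-value is recovered as the probability-weighted mean over atoms:
    # q = (jax.nn.softmax(logits, axis=-1) * atoms).sum(axis=-1)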
1 change: 1 addition & 0 deletions examples/offline/main_d4rl.py
@@ -22,6 +22,7 @@
SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = {
    "iql": IQLAgent,
    "bdpo": BDPOAgent,
    "categorical_bdpo": CategoricalBDPOAgent,
    "ivr": IVRAgent,
    "fql": FQLAgent,
    "dac": DACAgent,
1 change: 1 addition & 0 deletions flowrl/agent/offline/__init__.py
@@ -1,5 +1,6 @@
from ..base import BaseAgent
from .bdpo.bdpo import BDPOAgent
from .bdpo.categorical_bdpo import CategoricalBDPOAgent
from .dac import DACAgent
from .dql import DQLAgent
from .dtql import DTQLAgent
83 changes: 83 additions & 0 deletions flowrl/agent/offline/bdpo/bdpo.py
@@ -307,6 +307,75 @@ def actor_loss_fn(actor_params: Param, *args, **kwargs) -> Tuple[jnp.ndarray, Metric]:
    return rng, new_actor, new_actor_target, metrics


@partial(jax.jit, static_argnames=("T", "num_samples"))
def jit_compute_statistics(
    rng: PRNGKey,
    vt_target: Model,
    q0_target: Model,
    actor: DDPM,
    actor_target: DDPM,
    behavior_target: DDPM,
    batch: Batch,
    T: int,
    num_samples: int,
) -> Tuple[PRNGKey, Metric]:
    """Compute per-timestep statistics for debugging and hyperparameter tuning."""
    alpha_hats = actor_target.alpha_hats

    B = batch.obs.shape[0]
    t_proto = jnp.ones((*batch.obs.shape[:-1], num_samples, 1), dtype=jnp.int32)

    obs = batch.obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, obs_dim)
    x0 = batch.action[..., jnp.newaxis, :].repeat(num_samples, axis=-2)  # (B, N, act_dim)

    tot_metrics = {}
    for i in range(0, T + 1):
        rng, eps_key = jax.random.split(rng)
        eps_sample = jax.random.normal(eps_key, x0.shape, dtype=jnp.float32)
        t = t_proto * i

        xt = jnp.sqrt(alpha_hats[i]) * x0 + jnp.sqrt(1 - alpha_hats[i]) * eps_sample

        # vt critic values
        vt_val = vt_target(obs, xt, t)  # (E, B, N, 1)

        # noise predictions
        eps_actor = actor(xt, t, condition=obs)
        eps_actor_target = actor_target(xt, t, condition=obs)
        eps_behavior_target = behavior_target(xt, t, condition=obs)
        bc_loss = 0.5 * ((eps_actor - eps_behavior_target) ** 2).sum(axis=-1).mean()
        bc_target_loss = 0.5 * ((eps_actor_target - eps_behavior_target) ** 2).sum(axis=-1).mean()

        # vt gradient w.r.t. action
        def vt_fn_single(a, s, t_):
            return vt_target(s, a, t_).mean(axis=0).squeeze()
        vt_grad = jax.vmap(jax.grad(vt_fn_single))(
            xt.reshape(B * num_samples, -1),
            obs.reshape(B * num_samples, -1),
            t.reshape(B * num_samples, -1),
        )

        # q0 gradient w.r.t. action
        def q0_fn_single(a, s):
            return q0_target(s, a).mean(axis=0).squeeze()
        q0_grad = jax.vmap(jax.grad(q0_fn_single))(
            xt.reshape(B * num_samples, -1),
            obs.reshape(B * num_samples, -1),
        )

        # vt_std_e: spread of vt across the E ensemble members (axis 0);
        # vt_std_a: spread of the ensemble-mean vt across the N action samples (axis -2)
        tot_metrics.update({
            f"stats/vt_std_e_{i}": vt_val.std(axis=0).mean(),
            f"stats/vt_std_a_{i}": vt_val.mean(axis=0).std(axis=-2).mean(),
            f"stats/bc_loss_{i}": bc_loss,
            f"stats/bc_target_loss_{i}": bc_target_loss,
            f"stats/vt_grad_{i}": jnp.abs(vt_grad).mean(),
            f"stats/q0_grad_{i}": jnp.abs(q0_grad).mean(),
            f"stats/vt_mean_{i}": vt_val.mean(),
        })

    return rng, tot_metrics
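For reference, the xt construction above is the standard variance-preserving DDPM forward marginal, x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps. A self-contained sanity check of the variance-preserving property (illustrative only, not part of the diff):

    import jax
    import jax.numpy as jnp

    # With unit-variance x0 and eps, Var[x_t] = alpha_hat + (1 - alpha_hat) = 1,
    # which is what the "vp" noise_schedule in the config preserves at every step.
    k0, k1 = jax.random.split(jax.random.PRNGKey(0))
    alpha_hat = 0.3
    x0 = jax.random.normal(k0, (100_000,))
    eps = jax.random.normal(k1, (100_000,))
    xt = jnp.sqrt(alpha_hat) * x0 + jnp.sqrt(1 - alpha_hat) * eps
    print(xt.var())  # ~1.0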


class BDPOAgent(BaseAgent):
"""
Behavior-Regularized Diffusion Policy Optimization (BDPO)
@@ -518,6 +587,20 @@ def pretrain_step(self, batch: Batch, step: int) -> Metric:
        self._n_pretraining_steps += 1
        return metrics

    def compute_statistics(self, batch: Batch) -> Metric:
        self.rng, stats = jit_compute_statistics(
            self.rng,
            self.vt_target,
            self.q0_target,
            self.actor,
            self.actor_target,
            self.behavior_target,
            batch,
            self.cfg.diffusion.steps,
            self.cfg.critic.num_samples,
        )
        return stats
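A minimal sketch of how these diagnostics might be consumed during evaluation; the data source, logger, and call site are hypothetical stand-ins, not part of this PR:

    # Hypothetical call site: compute_statistics advances self.rng internally
    # and returns a flat dict of stats/*_{i} scalars, one entry per timestep i.
    batch = replay_buffer.sample(batch_size)  # hypothetical data source
    stats = agent.compute_statistics(batch)
    logger.log(stats, step=step)              # hypothetical logger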

    def sample_actions(
        self,
        obs: jnp.ndarray,