Hi, Albert! Thanks for generously sharing this work!
I found that the Hydra parameters of the diffusion policy in `diffusion_policy.yaml` (`Adapt3R/config/algo/diffusion_policy.yaml`, lines 20 to 26 in 9563f06)

```yaml
ema_factory:
  _target_: diffusers.training_utils.EMAModel
  _partial_: true
  decay: 0.9999
  use_ema_warmup: false
  inv_gamma: 1.0
  power: 0.75
```

do not match the constructor of `EMAModel`:

```python
class EMAModel:
    """
    Exponential Moving Average of models weights
    """
    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
        device=None,
    ):
        ...
```

This mismatch triggers an error when the EMA model is initialized.
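Replacing line 48 of `Adapt3R/adapt3r/algos/diffusion_policy.py` (in 9563f06)

```python
self.ema: EMAModel = ema_factory(parameters=self.networks.parameters())
```

with

```python
self.ema: EMAModel = ema_factory(model=self.networks)
```

could resolve this error, provided the `decay` and `use_ema_warmup` keys are also removed from the config, since the signature above does not accept them either.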
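To make the mismatch concrete, here is a minimal sketch that checks which config keys a constructor would reject. `LegacyEMAModel` is a hypothetical stand-in that copies the signature quoted above (no `diffusers` import needed), and `unsupported_kwargs` is a hypothetical helper. The `functools.partial` call approximates what Hydra returns for `_partial_: true`.

```python
import functools
import inspect


class LegacyEMAModel:
    """Hypothetical stand-in mirroring the EMAModel signature quoted above."""

    def __init__(self, model, update_after_step=0, inv_gamma=1.0,
                 power=2 / 3, min_value=0.0, max_value=0.9999, device=None):
        self.model = model


def unsupported_kwargs(target, kwargs):
    """Return (sorted) the keyword names that target's constructor rejects."""
    params = inspect.signature(target).parameters
    # A **kwargs parameter would silently accept any extra key.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(k for k in kwargs if k not in params)


# Keys from diffusion_policy.yaml, minus Hydra's own _target_ / _partial_.
config_kwargs = {"decay": 0.9999, "use_ema_warmup": False,
                 "inv_gamma": 1.0, "power": 0.75}

# With `_partial_: true`, Hydra hands back roughly this partial object.
ema_factory = functools.partial(LegacyEMAModel, **config_kwargs)

print(unsupported_kwargs(LegacyEMAModel, {**config_kwargs, "model": None}))
# ['decay', 'use_ema_warmup']

try:
    ema_factory(model=object())  # still fails: 'decay' is not accepted
except TypeError as exc:
    print("TypeError:", exc)

# Dropping the two unsupported keys lets construction succeed.
trimmed_factory = functools.partial(LegacyEMAModel, inv_gamma=1.0, power=0.75)
ema = trimmed_factory(model=object())
```

Running the same check with `parameters` instead of `model` also flags `parameters`, which matches the original failure at line 48.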
Besides, the following snippet (`Adapt3R/adapt3r/algos/diffusion_policy.py`, lines 55 to 68 in 9563f06)

```python
def train(self, mode=True):
    """Override train method to manage EMA parameters."""
    if mode and self._using_ema_params:
        # Switching to train mode, restore original parameters
        self.ema.restore(self.networks.parameters())
        self._using_ema_params = False
    elif not mode and not self._using_ema_params:
        # Switching to eval mode, use EMA parameters
        self.ema.store(self.networks.parameters())
        self.ema.copy_to(self.networks.parameters())
        self._using_ema_params = True

    # Call parent train method
    return super().train(mode)
```

raises `AttributeError: 'EMAModel' object has no attribute 'store'`. Thus, calling

```python
self.ema.step(self.networks)
```

may be the right solution.
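For reference, `step` and the `store`/`copy_to`/`restore` calls do different jobs. `step` folds the current weights into the running average. The other three temporarily swap the averaged weights into the live network for evaluation and then put the training weights back. Below is a minimal pure-Python sketch of that swap. It assumes parameters are plain lists of floats rather than torch tensors, and `SimpleEMA` (with its fixed-decay update) and `ToyPolicy` are hypothetical illustrations, not the `diffusers` implementation.

```python
class SimpleEMA:
    """Hypothetical EMA helper; parameters are lists of floats, not tensors."""

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = list(params)  # running average of the weights
        self.backup = None          # training weights saved by store()

    def step(self, params):
        # shadow <- decay * shadow + (1 - decay) * current weights
        self.shadow = [self.decay * s + (1 - self.decay) * p
                       for s, p in zip(self.shadow, params)]

    def store(self, params):
        self.backup = list(params)

    def copy_to(self, params):
        params[:] = self.shadow  # overwrite the live weights in place

    def restore(self, params):
        params[:] = self.backup
        self.backup = None


class ToyPolicy:
    """Mirrors the train() override above, minus torch and nn.Module."""

    def __init__(self, weights, ema):
        self.weights = weights
        self.ema = ema
        self._using_ema_params = False

    def train(self, mode=True):
        if mode and self._using_ema_params:
            self.ema.restore(self.weights)  # back to training weights
            self._using_ema_params = False
        elif not mode and not self._using_ema_params:
            self.ema.store(self.weights)    # save training weights
            self.ema.copy_to(self.weights)  # load averaged weights
            self._using_ema_params = True
        return self


weights = [0.0, 0.0]
policy = ToyPolicy(weights, SimpleEMA(weights, decay=0.5))
for _ in range(2):
    weights[:] = [w + 1.0 for w in weights]  # stand-in for an optimizer step
    policy.ema.step(weights)                 # update the running average

policy.train(False)
print(weights)  # [1.25, 1.25]  (averaged weights during eval)
policy.train(True)
print(weights)  # [2.0, 2.0]    (training weights restored)
```

In this sketch, `step` alone never changes `weights`, so evaluation would still run on the raw training weights unless some form of the store/copy/restore swap is kept.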
Thanks for your time.