@@ -92,7 +92,7 @@ def step(self, closure=None):
9292
9393 step = step_t .item ()
9494
95- base_update , update , alpha = self ._inner (step , p ,
95+ base_update , update , alpha = self ._inner (step , p , group ,
9696 ** {k : state .get (k ) for k in self .shared_statistics },
9797 ** {k : state .get (k ) for k in self .base_statistics },
9898 ** {k : state .get (k ) for k in self .true_statistics })
@@ -132,7 +132,7 @@ def __init__(self, params, lr: float = 1e-3,
132132 decay_to_init = decay_to_init , default_to_baseline = default_to_baseline ,
133133 enforce_baseline = enforce_baseline )
134134
135- def _inner (self , step : int , p : Parameter , do_baseline : bool , group : Dict [str , Any ], exp_avg : Tensor ,
135+ def _inner (self , step : int , p : Parameter , group : Dict [str , Any ], exp_avg : Tensor ,
136136 exp_avg_sq : Optional [Tensor ] = None , exp_avg_true_sq : Optional [Tensor ] = None
137137 ) -> Tuple [Optional [Tensor ], Optional [Tensor ], float ]:
138138 if len (group ["betas" ]) == 2 :
@@ -166,7 +166,7 @@ def __init__(self, params, lr: float = 1e-3,
166166 decay_to_init = decay_to_init , default_to_baseline = default_to_baseline ,
167167 enforce_baseline = enforce_baseline )
168168
169- def _inner (self , step : int , p : Parameter , do_baseline : bool , group : Dict [str , Any ],
169+ def _inner (self , step : int , p : Parameter , group : Dict [str , Any ],
170170 exp_avg : Optional [Tensor ] = None , exp_avg_sq : Optional [Tensor ] = None ,
171171 exp_avg_true : Optional [Tensor ] = None , exp_avg_true_sq : Optional [Tensor ] = None
172172 ) -> Tuple [Optional [Tensor ], Optional [Tensor ], float ]:
0 commit comments