Skip to content

Commit 6e40c08

Browse files
committed
fix(optim): correct interface to _inner
1 parent 4d94281 commit 6e40c08

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
name='truegrad',
1111
license='BSD',
1212
description='PyTorch interface for TrueGrad-AdamW',
13-
version='2.3.4',
13+
version='2.3.5',
1414
long_description=README,
1515
url='https://github.com/clashluke/truegrad',
1616
packages=setuptools.find_packages(),

truegrad/optim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def step(self, closure=None):
9292

9393
step = step_t.item()
9494

95-
base_update, update, alpha = self._inner(step, p,
95+
base_update, update, alpha = self._inner(step, p, group,
9696
**{k: state.get(k) for k in self.shared_statistics},
9797
**{k: state.get(k) for k in self.base_statistics},
9898
**{k: state.get(k) for k in self.true_statistics})
@@ -132,7 +132,7 @@ def __init__(self, params, lr: float = 1e-3,
132132
decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
133133
enforce_baseline=enforce_baseline)
134134

135-
def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any], exp_avg: Tensor,
135+
def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg: Tensor,
136136
exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
137137
) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
138138
if len(group["betas"]) == 2:
@@ -166,7 +166,7 @@ def __init__(self, params, lr: float = 1e-3,
166166
decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
167167
enforce_baseline=enforce_baseline)
168168

169-
def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any],
169+
def _inner(self, step: int, p: Parameter, group: Dict[str, Any],
170170
exp_avg: Optional[Tensor] = None, exp_avg_sq: Optional[Tensor] = None,
171171
exp_avg_true: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
172172
) -> Tuple[Optional[Tensor], Optional[Tensor], float]:

0 commit comments

Comments
 (0)