Commit 2215186

Cristian Garcia authored and Flax Authors committed
change defaults for Dropout and BatchNorm
Changes `Dropout.deterministic` and `BatchNorm.use_running_average` to be `None` by default; users now have to provide them explicitly by either:

1. Passing them to the constructor, e.g. `self.bn = nnx.BatchNorm(..., use_running_average=False)`
2. Passing them to `__call__`, e.g. `self.dropout(x, deterministic=False)`
3. Using `nnx.view` to create a view of the model with specific values, e.g. `train_model = nnx.view(model, deterministic=False, use_running_average=False)`

PiperOrigin-RevId: 877557940
1 parent ef6e4b1 commit 2215186

16 files changed: 62 additions & 44 deletions
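In practice, the three options look like this — a minimal sketch based on the commit message and the diffs below (shapes and rates are arbitrary; applying `nnx.view` to a single layer rather than a whole model is an assumption here):

```python
import jax.numpy as jnp
from flax import nnx

x = jnp.ones((4, 8))

# Option 1: provide the flag at construction time.
bn = nnx.BatchNorm(8, use_running_average=False, rngs=nnx.Rngs(0))
y = bn(x)  # uses batch statistics

# Option 2: provide the flag at call time.
dropout = nnx.Dropout(0.1, rngs=nnx.Rngs(1))
y = dropout(x, deterministic=False)  # masking applied for this call

# Option 3: set flags on a node all at once via nnx.view.
train_bn = nnx.view(bn, use_running_average=False)
```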

docs_nnx/guides/demo.md

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ from flax import nnx
 class Block(nnx.Module):
   def __init__(self, din, dout, *, rngs):
     self.linear = nnx.Linear(din, dout, rngs=rngs)
-    self.bn = nnx.BatchNorm(dout, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)

   def __call__(self, x):
     return nnx.relu(self.bn(self.linear(x)))

docs_nnx/guides/performance.md

Lines changed: 2 additions & 2 deletions

@@ -26,8 +26,8 @@ import optax
 class Model(nnx.Module):
   def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
     self.linear = nnx.Linear(din, dmid, rngs=rngs)
-    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
-    self.dropout = nnx.Dropout(0.2, rngs=rngs)
+    self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
+    self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs)
     self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

   def __call__(self, x):

docs_nnx/guides/pytree.md

Lines changed: 2 additions & 2 deletions

@@ -436,8 +436,8 @@ NNX Modules are `Pytree`s that have two additional methods for traking intermedi
 class Block(nnx.Module):
   def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
     self.linear = nnx.Linear(din, dout, rngs=rngs)
-    self.bn = nnx.BatchNorm(dout, rngs=rngs)
-    self.dropout = nnx.Dropout(0.1, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)
+    self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

   def __call__(self, x):
     y = nnx.relu(self.dropout(self.bn(self.linear(x))))

docs_nnx/guides/randomness.md

Lines changed: 5 additions & 5 deletions

@@ -18,7 +18,7 @@ from flax import nnx
 class Model(nnx.Module):
   def __init__(self, *, rngs: nnx.Rngs):
     self.linear = nnx.Linear(20, 10, rngs=rngs)
-    self.drop = nnx.Dropout(0.1)
+    self.drop = nnx.Dropout(0.1, deterministic=False)

   def __call__(self, x, *, rngs):
     return nnx.relu(self.drop(self.linear(x), rngs=rngs))

@@ -90,7 +90,7 @@ Specifically, this will use the RngSteam `rngs.params` for weight initialization
 The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument.

 ```{code-cell} ipython3
-dropout = nnx.Dropout(0.5)
+dropout = nnx.Dropout(0.5, deterministic=False)
 ```

 ```{code-cell} ipython3

@@ -159,7 +159,7 @@ Say you want to train a model that uses dropout on a batch of data. You don't wa
 class Model(nnx.Module):
   def __init__(self, rngs: nnx.Rngs):
     self.linear = nnx.Linear(20, 10, rngs=rngs)
-    self.drop = nnx.Dropout(0.1)
+    self.drop = nnx.Dropout(0.1, deterministic=False)

   def __call__(self, x, rngs):
     return nnx.relu(self.drop(self.linear(x), rngs=rngs))

@@ -199,7 +199,7 @@ So far, we have looked at passing random state directly to each Module when it g
 class Model(nnx.Module):
   def __init__(self, rngs: nnx.Rngs):
     self.linear = nnx.Linear(20, 10, rngs=rngs)
-    self.drop = nnx.Dropout(0.1, rngs=rngs)
+    self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

   def __call__(self, x):
     return nnx.relu(self.drop(self.linear(x)))

@@ -296,7 +296,7 @@ class Count(nnx.Variable): pass
 class RNNCell(nnx.Module):
   def __init__(self, din, dout, rngs):
     self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
-    self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
+    self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs, rng_collection='recurrent_dropout')
     self.dout = dout
     self.count = Count(jnp.array(0, jnp.uint32))

docs_nnx/hijax/hijax.md

Lines changed: 2 additions & 2 deletions

@@ -153,8 +153,8 @@ print("mutable =", v_ref)
 class Block(nnx.Module):
   def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
     self.linear = Linear(din, dmid, rngs=rngs)
-    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
-    self.dropout = nnx.Dropout(0.1, rngs=rngs)
+    self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
+    self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)
     self.linear_out = Linear(dmid, dout, rngs=rngs)

   def __call__(self, x):

docs_nnx/nnx_basics.md

Lines changed: 2 additions & 2 deletions

@@ -96,8 +96,8 @@ The example below shows how to define a simple `MLP` by subclassing `Module`. Th
 class MLP(nnx.Module):
   def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
     self.linear1 = Linear(din, dmid, rngs=rngs)
-    self.dropout = nnx.Dropout(rate=0.1)
-    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
+    self.dropout = nnx.Dropout(rate=0.1, deterministic=False)
+    self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
     self.linear2 = Linear(dmid, dout, rngs=rngs)

   def __call__(self, x: jax.Array, rngs: nnx.Rngs):

flax/nnx/module.py

Lines changed: 4 additions & 4 deletions

@@ -444,12 +444,12 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool =
   >>> class Block(nnx.Module):
   ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
   ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
-  ...     self.dropout = nnx.Dropout(0.5, deterministic=False)
-  ...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
+  ...     self.dropout = nnx.Dropout(0.5)
+  ...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
   ...
   >>> block = Block(2, 5, rngs=nnx.Rngs(0))
   >>> block.dropout.deterministic, block.batch_norm.use_running_average
-  (False, False)
+  (None, None)
   >>> new_block = nnx.view(block, deterministic=True, use_running_average=True)
   >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
   (True, True)

@@ -459,7 +459,7 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool =
   >>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True)
   >>> # Only the dropout will be modified
   >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
-  (True, False)
+  (True, None)

   Args:
     node: the object to create a copy of.
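Putting the docstring change together: under the new defaults both flags start as `None`, and `nnx.view` is the way to flip an entire model between train and eval behavior. A minimal sketch, assuming the `nnx.view` semantics shown in the doctest above:

```python
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)        # deterministic is None
    self.batch_norm = nnx.BatchNorm(dout, rngs=rngs)  # use_running_average is None

  def __call__(self, x):
    return self.dropout(nnx.relu(self.batch_norm(self.linear(x))))

block = Block(2, 5, rngs=nnx.Rngs(0))
train_block = nnx.view(block, deterministic=False, use_running_average=False)
eval_block = nnx.view(block, deterministic=True, use_running_average=True)

x = jnp.ones((4, 2))
y_train = train_block(x)  # batch statistics + dropout masking
y_eval = eval_block(x)    # running averages, dropout disabled
```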

flax/nnx/nn/normalization.py

Lines changed: 12 additions & 3 deletions

@@ -290,7 +290,7 @@ def __init__(
     self,
     num_features: int,
     *,
-    use_running_average: bool = False,
+    use_running_average: bool | None = None,
     axis: int = -1,
     momentum: float = 0.99,
     epsilon: float = 1e-5,

@@ -364,8 +364,17 @@ def __call__(
     use_running_average = first_from(
       use_running_average,
       self.use_running_average,
-      error_msg="""No `use_running_average` argument was provided to BatchNorm
-      as either a __call__ argument, class attribute, or nnx.flag.""",
+      error_msg=(
+        'No `use_running_average` argument was provided to BatchNorm.'
+        ' Consider one of the following options:\n\n'
+        '1. Pass `use_running_average` to the BatchNorm constructor:\n\n'
+        '   self.bn = nnx.BatchNorm(..., use_running_average=True/False)\n\n'
+        '2. Pass `use_running_average` to the BatchNorm __call__:\n\n'
+        '   self.bn(x, use_running_average=True/False)\n\n'
+        '3. Use `nnx.view` to create a view of the model with a'
+        ' specific `use_running_average` value:\n\n'
+        '   model_view = nnx.view(model, use_running_average=True/False)\n'
+      ),
     )
     feature_axes = _canonicalize_axes(x.ndim, self.axis)
     reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
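The `first_from` chain above resolves the flag from the `__call__` argument first, then the module attribute, and raises the new error only if both are `None`. A minimal sketch of call-time resolution, assuming the signatures shown in this diff:

```python
import jax.numpy as jnp
from flax import nnx

bn = nnx.BatchNorm(4, rngs=nnx.Rngs(0))  # use_running_average now defaults to None
x = jnp.ones((2, 4))

y = bn(x, use_running_average=False)  # training: normalize with batch statistics
y = bn(x, use_running_average=True)   # eval: normalize with stored running averages
# bn(x) alone would raise the error above, since the flag is unset everywhere.
```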

flax/nnx/nn/stochastic.py

Lines changed: 12 additions & 3 deletions

@@ -73,7 +73,7 @@ def __init__(
     rate: float,
     *,
     broadcast_dims: Sequence[int] = (),
-    deterministic: bool = False,
+    deterministic: bool | None = None,
     rng_collection: str = 'dropout',
     rngs: rnglib.Rngs | rnglib.RngStream | None = None,
   ):

@@ -117,8 +117,17 @@ def __call__(
     deterministic = first_from(
       deterministic,
       self.deterministic,
-      error_msg="""No `deterministic` argument was provided to Dropout
-      as either a __call__ argument or class attribute""",
+      error_msg=(
+        'No `deterministic` argument was provided to Dropout.'
+        ' Consider one of the following options:\n\n'
+        '1. Pass `deterministic` to the Dropout constructor:\n\n'
+        '   self.dropout = nnx.Dropout(..., deterministic=True/False)\n\n'
+        '2. Pass `deterministic` to the Dropout __call__:\n\n'
+        '   self.dropout(x, deterministic=True/False)\n\n'
+        '3. Use `nnx.view` to create a view of the model with a'
+        ' specific `deterministic` value:\n\n'
+        '   model_view = nnx.view(model, deterministic=True/False)\n'
+      ),
     )

     if (self.rate == 0.0) or deterministic:
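`Dropout` resolves `deterministic` the same way, with the `__call__` argument taking precedence over the attribute. A minimal sketch, again assuming the signatures in this diff:

```python
import jax.numpy as jnp
from flax import nnx

dropout = nnx.Dropout(0.5, deterministic=False, rngs=nnx.Rngs(0))
x = jnp.ones((2, 4))

y = dropout(x)                      # attribute applies: masking on
y = dropout(x, deterministic=True)  # __call__ argument wins: identity output

unset = nnx.Dropout(0.5, rngs=nnx.Rngs(1))
# unset(x) would raise the error above, since deterministic is None everywhere.
```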

tests/nnx/bridge/module_test.py

Lines changed: 1 addition & 1 deletion

@@ -280,7 +280,7 @@ def test_pure_nnx_submodule(self):
     class NNXLayer(nnx.Module):
       def __init__(self, dim, dropout, rngs):
         self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
-        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout, deterministic=False, rngs=rngs)
         self.count = nnx.Intermediate(jnp.array([0.]))
       def __call__(self, x):
         # Required check to avoid state update in `init()`. Can this be avoided?
