Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions docs_nnx/guides/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@ jupytext:

# NNX Demo

```{code-cell} ipython3
```{code-cell}
import jax
from jax import numpy as jnp
from flax import nnx
```

### [1] NNX is Pythonic

```{code-cell} ipython3
```{code-cell}
:outputId: d8ef66d5-6866-4d5c-94c2-d22512bfe718


class Block(nnx.Module):
def __init__(self, din, dout, *, rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))
Expand Down Expand Up @@ -56,7 +56,7 @@ print(f'{model = }'[:500] + '\n...')

Because NNX Modules contain their own state, they are very easy to inspect:

```{code-cell} ipython3
```{code-cell}
:outputId: 10a46b0f-2993-4677-c26d-36a4ddf33449

print(f'{model.count = }')
Expand All @@ -66,7 +66,7 @@ print(f'{model.blocks[0].linear.kernel = }')

### [2] Model Surgery is Intuitive

```{code-cell} ipython3
```{code-cell}
:outputId: e6f86be8-3537-4c48-f471-316ee0fb6c45

# Module sharing
Expand All @@ -83,7 +83,7 @@ print(f'{y.shape = }')

### [3] Interacting with JAX is easy

```{code-cell} ipython3
```{code-cell}
:outputId: 9a3f378b-739e-4f45-9968-574651200ede

graphdef, state = model.split()
Expand All @@ -95,7 +95,7 @@ print(f'{state = }'[:500] + '\n...')
print(f'\n{graphdef = }'[:300] + '\n...')
```

```{code-cell} ipython3
```{code-cell}
:outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d

graphdef, state = model.split()
Expand All @@ -116,7 +116,7 @@ print(f'{y.shape = }')
print(f'{model.count.value = }')
```

```{code-cell} ipython3
```{code-cell}
params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count)

@jax.jit
Expand All @@ -135,7 +135,7 @@ print(f'{y.shape = }')
print(f'{model.count = }')
```

```{code-cell} ipython3
```{code-cell}
class Parent(nnx.Module):
def __init__(self, model: MLP):
self.model = model
Expand Down Expand Up @@ -163,6 +163,6 @@ print(f'{y.shape = }')
print(f'{parent.model.count.value = }')
```

```{code-cell} ipython3
```{code-cell}

```
4 changes: 2 additions & 2 deletions docs_nnx/guides/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/pytree.md
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ NNX Modules are `Pytree`s that have two additional methods for tracking intermediate
class Block(nnx.Module):
def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x):
y = nnx.relu(self.dropout(self.bn(self.linear(x))))
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/randomness.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from flax import nnx
class Model(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1)
self.drop = nnx.Dropout(0.1, deterministic=False)

def __call__(self, x, *, rngs):
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
Expand Down Expand Up @@ -90,7 +90,7 @@ Specifically, this will use the RngStream `rngs.params` for weight initialization
The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument.

```{code-cell} ipython3
dropout = nnx.Dropout(0.5)
dropout = nnx.Dropout(0.5, deterministic=False)
```

```{code-cell} ipython3
Expand Down Expand Up @@ -159,7 +159,7 @@ Say you want to train a model that uses dropout on a batch of data. You don't wa
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1)
self.drop = nnx.Dropout(0.1, deterministic=False)

def __call__(self, x, rngs):
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
Expand Down Expand Up @@ -199,7 +199,7 @@ So far, we have looked at passing random state directly to each Module when it g
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
Expand Down Expand Up @@ -296,7 +296,7 @@ class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))

Expand Down
46 changes: 23 additions & 23 deletions docs_nnx/hijax/hijax.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jupytext:

# Hijax

```{code-cell} ipython3
```{code-cell}
from flax import nnx
import jax
import jax.numpy as jnp
Expand All @@ -19,7 +19,7 @@ import optax
current_mode = nnx.var_defaults().hijax # ignore: only needed for testing
```

```{code-cell} ipython3
```{code-cell}
nnx.var_defaults(hijax=True)

rngs = nnx.Rngs(0)
Expand All @@ -44,7 +44,7 @@ for _ in range(3):

State propagation:

```{code-cell} ipython3
```{code-cell}
v = nnx.Variable(jnp.array(0), hijax=True)

@jax.jit
Expand All @@ -54,14 +54,14 @@ def inc(v):
print(v[...]); inc(v); print(v[...])
```

```{code-cell} ipython3
```{code-cell}
v = nnx.Variable(jnp.array(0), hijax=True)
print(jax.make_jaxpr(inc)(v))
```

Pytree values:

```{code-cell} ipython3
```{code-cell}
v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True)

@jax.jit
Expand All @@ -74,7 +74,7 @@ print(v); inc_and_double(v); print(v)

Dynamic state structure:

```{code-cell} ipython3
```{code-cell}
rngs = nnx.Rngs(0)
x = rngs.uniform((4, 5))
w = rngs.normal((5, 3))
Expand All @@ -91,7 +91,7 @@ y = linear(x, w, metrics)
print("After:", metrics)
```

```{code-cell} ipython3
```{code-cell}
# set default Variable mode for the rest of the guide
nnx.var_defaults(hijax=True)

Expand All @@ -102,7 +102,7 @@ print(variable)

### Mutability

```{code-cell} ipython3
```{code-cell}
class Linear(nnx.Module):
def __init__(self, in_features, out_features, rngs: nnx.Rngs):
self.kernel = nnx.Param(rngs.normal((in_features, out_features)))
Expand All @@ -116,7 +116,7 @@ print(f"{nnx.vars_as(model, mutable=False) = !s}")
print(f"{nnx.vars_as(model, mutable=True) = !s}")
```

```{code-cell} ipython3
```{code-cell}
v = nnx.Variable(jnp.array(0))
v_immut = nnx.vars_as(v, mutable=False)
assert not v_immut.mutable
Expand All @@ -129,15 +129,15 @@ except Exception as e:

### Ref support

```{code-cell} ipython3
```{code-cell}
v = nnx.Variable(jnp.array(0))
v_ref = nnx.vars_as(v, ref=True)
assert v_ref.ref
print(v_ref)
print(v_ref.get_raw_value())
```

```{code-cell} ipython3
```{code-cell}
v_immut = nnx.vars_as(v_ref, mutable=False)
assert not v_immut.ref
print("immutable =", v_immut)
Expand All @@ -149,12 +149,12 @@ print("mutable =", v_ref)

### Examples

```{code-cell} ipython3
```{code-cell}
class Block(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)
self.linear_out = Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand All @@ -164,7 +164,7 @@ class Block(nnx.Module):

#### Training Loop

```{code-cell} ipython3
```{code-cell}
# hijax Variables by default
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
Expand All @@ -188,7 +188,7 @@ for _ in range(3):

#### Scan Over Layers

```{code-cell} ipython3
```{code-cell}
# TODO: does not work with hijax yet
# @jax.vmap
# def create_stack(rngs):
Expand All @@ -212,7 +212,7 @@ for _ in range(3):

#### Mutable Outputs

```{code-cell} ipython3
```{code-cell}
@jax.jit
def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)
Expand All @@ -223,7 +223,7 @@ except Exception as e:
print(f"Error:", e)
```

```{code-cell} ipython3
```{code-cell}
@jax.jit
def create_model(rngs):
return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)
Expand All @@ -235,7 +235,7 @@ print("model.linear =", model.linear)

#### Reference Sharing (aliasing)

```{code-cell} ipython3
```{code-cell}
# NOTE: doesn't currently fail on the jax side
def get_error(f, *args):
try:
Expand All @@ -252,7 +252,7 @@ def f(a, b):
print(get_error(f, x, x))
```

```{code-cell} ipython3
```{code-cell}
# NOTE: doesn't currently fail on the jax side
class HasShared(nnx.Pytree):
def __init__(self):
Expand All @@ -269,14 +269,14 @@ print(get_error(g, has_shared))
print(has_shared) # updates don't propagate
```

```{code-cell} ipython3
```{code-cell}
print("Duplicates found:")
if (all_duplicates := nnx.find_duplicates(has_shared)):
for duplicates in all_duplicates:
print("-", duplicates)
```

```{code-cell} ipython3
```{code-cell}
@jax.jit
def h(graphdef, state):
has_shared = nnx.merge(graphdef, state)
Expand All @@ -287,7 +287,7 @@ h(graphdef, state)
print(has_shared)
```

```{code-cell} ipython3
```{code-cell}
# clean up for CI tests
_ = nnx.var_defaults(hijax=current_mode)
```
Loading
Loading