
FhG-IISB/jNO




Warning: This is a research-level repository. It may contain bugs and is subject to continuous change without notice.

Install

Quick install from PyPI:

pip install jax-neural-operators

If an NVIDIA GPU is available, also install the CUDA-enabled JAX build:

pip install "jax[cuda]"
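To confirm that JAX actually picked up the GPU after this install, JAX's own device query can be used. This is a plain JAX check, not part of jNO:

```python
import jax

# Default backend JAX will run on: "gpu" after a working CUDA install, otherwise "cpu".
backend = jax.default_backend()

# All devices visible to JAX (this list respects CUDA_VISIBLE_DEVICES).
devices = jax.devices()

print(backend)
print(devices)
```

If the backend is still "cpu", the CUDA-enabled JAX build was not picked up.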

For local development (recommended on Linux aarch64 when gmsh wheels are unavailable on PyPI), use micromamba and install jNO in editable mode from the root of a clone of this repository:

micromamba create -n jno python=3.12 pip -y
micromamba activate jno
micromamba install -n jno -c conda-forge gmsh python-gmsh -y
pip install -e .

Minimal DeepONet Example

Create a Python file with the following contents:

import jno
import jax
import optax
import foundax

dir = jno.setup("./runs/test")

# Domain: 500 copies of the rectangle [0, 2] x [0, 1], one per sample of k
dom = 500 * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
xb, yb, _ = dom.variable("boundary")

random_k = jax.random.uniform(jax.random.PRNGKey(0), shape=(500, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", random_k)

# Neural Network
fx = foundax.deeponet(n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128, activation=jax.numpy.tanh)
net = jno.nn.wrap(fx)
net.optimizer(optax.adam(learning_rate=optax.schedules.cosine_decay_schedule(init_value=1e-3, decay_steps=20_000, alpha=1e-5)))

# Forward pass and hard enforcement of BCs via output transformation
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.dd(x) + u.dd(y)) + 1.0  # PDE Loss

# Checkpointing (saves every 5000 epochs, keeps best 3)
cb = jno.callback.checkpoint(save_interval_epochs=5000, best_fn=lambda m: m["total_loss"])

# Create -> Train -> Save
crux = jno.core(constraints=[pde.mse], domain=dom).print_shapes()
crux.solve(epochs=20_000, batchsize=32, callbacks=[cb]).plot(f"{dir}/training.png")
jno.save(crux, f"{dir}/model.pkl")

# Inference on a test domain with a finer mesh; k is drawn from [0.1, 1.9], wider than the training range [0.5, 1.5]
tst_dom = 16 * jno.domain.rect(mesh_size=0.01, x_range=(0, 2), y_range=(0, 1))
tst_dom.variable("k", jax.random.uniform(jax.random.PRNGKey(0), shape=(16, 1, 1), minval=0.1, maxval=1.9))

pred, x, y, k = crux.eval([u, x, y, k], domain=tst_dom)
print(pred.shape, x.shape, y.shape, k.shape)

Then run it, selecting the GPU via CUDA_VISIBLE_DEVICES and the random seed via JNO_SEED:

CUDA_VISIBLE_DEVICES=<gpu_id> JNO_SEED=<seed> python <filename>.py
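For readers new to DeepONet: in the standard formulation, a branch network encodes the input-function samples (here the single value k, matching n_sensors=1), a trunk network encodes the query coordinates (coord_dim=2), and the prediction is the dot product of their basis_functions-dimensional outputs plus a bias. The plain-JAX sketch below illustrates only that standard formulation. It is not foundax's implementation, and the helpers init_mlp, mlp and deeponet as well as the layer layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: random (W, b) pairs for an MLP with the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for sub_key, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(sub_key, (n_in, n_out)) * jnp.sqrt(1.0 / n_in)
        params.append((W, jnp.zeros(n_out)))
    return params


def mlp(params, h):
    """tanh hidden layers, linear output layer."""
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b


def deeponet(params, sensors, coords):
    """u(k, x) = sum_i branch_i(k) * trunk_i(x) + bias (hypothetical sketch).

    sensors: (n_sensors,) samples of the input function, e.g. the scalar k
    coords:  (n_points, coord_dim) query points
    returns: (n_points,) predicted solution values
    """
    b = mlp(params["branch"], sensors)  # (p,)
    t = mlp(params["trunk"], coords)    # (n_points, p)
    return t @ b + params["bias"]


n_sensors, coord_dim, p, hidden = 1, 2, 32, 128
key_branch, key_trunk = jax.random.split(jax.random.PRNGKey(0))
params = {
    "branch": init_mlp(key_branch, [n_sensors, hidden, p]),
    "trunk": init_mlp(key_trunk, [coord_dim, hidden, p]),
    "bias": jnp.zeros(()),
}

coords = jax.random.uniform(jax.random.PRNGKey(1), (100, coord_dim))
u = deeponet(params, jnp.array([1.0]), coords)

# A batch of k values: vectorize over the sensors argument only.
u_batch = jax.vmap(deeponet, in_axes=(None, 0, None))(params, jnp.ones((4, n_sensors)), coords)
print(u.shape, u_batch.shape)
```

The dot-product structure is what lets the same trained model be queried at arbitrary coordinates, which is why the example can evaluate on a finer test mesh than it was trained on.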
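The forward pass multiplies the network output by x(2 - x)y(1 - y), which vanishes on all four edges of [0, 2] x [0, 1]. The homogeneous Dirichlet condition u = 0 therefore holds exactly, and only the residual of k(u_xx + u_yy) + 1 = 0 (that is, -k Laplacian(u) = 1) has to be penalized. The plain-JAX sketch below reproduces that construction with a hypothetical stand-in network (net_stub) and computes the second derivatives with nested jax.grad. It assumes that u.dd(x) in jNO means the second partial derivative with respect to x and that pde.mse is the mean squared residual; both readings follow from the example but are not documented here.

```python
import jax
import jax.numpy as jnp


def net_stub(k, x, y):
    """Hypothetical stand-in for the trained network: any smooth function of (k, x, y)."""
    return jnp.sin(x + 0.5 * y) * k


def u(k, x, y):
    # Hard BC enforcement: the factor is zero on x = 0, x = 2, y = 0 and y = 1.
    return net_stub(k, x, y) * x * (2.0 - x) * y * (1.0 - y)


# Second partial derivatives via nested reverse-mode differentiation.
u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)
u_yy = jax.grad(jax.grad(u, argnums=2), argnums=2)


def residual(k, x, y):
    """Pointwise residual of k * (u_xx + u_yy) + 1 = 0."""
    return k * (u_xx(k, x, y) + u_yy(k, x, y)) + 1.0


# Boundary values are exactly zero, whatever the network outputs.
boundary = [u(1.0, 0.0, 0.3), u(1.0, 2.0, 0.3), u(1.0, 1.2, 0.0), u(1.0, 1.2, 1.0)]
print(boundary)

# Mean squared residual over random interior points (assumed analogue of pde.mse).
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.uniform(key_x, (256,), minval=0.0, maxval=2.0)
ys = jax.random.uniform(key_y, (256,), minval=0.0, maxval=1.0)
res = jax.vmap(residual, in_axes=(None, 0, 0))(1.0, xs, ys)
loss = jnp.mean(res**2)
print(loss)
```

Because the boundary condition is built into the ansatz, no boundary loss term is needed, which is why the example passes only pde.mse to jno.core.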

Foundation models and other neural networks

These models are maintained in a separate repository (foundax) so that they can also be used independently of jNO. Install it with:

pip install foundax

Citation

If you use jNO, please cite the following paper:

@article{armbruster2026jNO,
  author  = {Armbruster, Leon, ....},
  title   = {{jNO}: A JAX Library for Neural Operator and PDE Foundation Model Training},
  journal = {},
  year    = {},
}
