Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions adv/cvit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
import h5py
import jax.numpy as jnp
from models import OperatorModel
from loaders import get_train_val_test_loaders
from loaders import get_train_val_test_data


batch_size = 256
grid_size = 200
n_iterations = 100000

_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size)
_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data()
model = OperatorModel(jax.random.PRNGKey(0), "cvit")
model.load_model()

inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s

s_pred = []
for i in range(10):
u, y, s = (
Expand Down
6 changes: 2 additions & 4 deletions adv/deeponet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
import h5py
import jax.numpy as jnp
from models import OperatorModel
from loaders import get_train_val_test_loaders
from loaders import get_train_val_test_data


batch_size = 256
grid_size = 200
n_iterations = 100000

_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size)
_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data()
model = OperatorModel(jax.random.PRNGKey(0), "deeponet")
model.load_model()

inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s

s_pred = []
for i in range(10):
u, y, s = (
Expand Down
6 changes: 2 additions & 4 deletions adv/fno_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
import h5py
import jax.numpy as jnp
from models import OperatorModel
from loaders import get_train_val_test_loaders
from loaders import get_train_val_test_data


batch_size = 256
grid_size = 200
n_iterations = 100000

_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size)
_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data()
model = OperatorModel(jax.random.PRNGKey(0), "fno1d")
model.load_model()

inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s

s_pred = []
for i in range(10):
u, y, s = (
Expand Down
88 changes: 50 additions & 38 deletions adv/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,28 @@
import jax.numpy as jnp
import numpy as np
import einops
import torch.utils.data as data
from functools import partial


# defining the dataloader:
class GridSampling(data.Dataset):
def __init__(self, key, u, y, s, batch_size, grid_size):
self.key = key
self.u = u
self.y = y
self.s = s
self.batch_size = batch_size
self.grid_size = grid_size

def __getitem__(self, index):
"Generate one batch of data"
self.key, subkey = jax.random.split(self.key)
batch = self.__data_generation(subkey)
return batch

@partial(jax.jit, static_argnums=(0,))
def __data_generation(self, key):
batch_idx = jax.random.randint(key, (self.batch_size,), 0, self.u.shape[0])


def grid_sampling_iter(key, u, y, s, batch_size, grid_size):
@jax.jit
def _data_generation(sample_key):
batch_idx = jax.random.randint(sample_key, (batch_size,), 0, u.shape[0])
grid_idx = jnp.sort(
jax.random.randint(key, (self.grid_size,), 0, self.u.shape[1])
jax.random.randint(sample_key, (grid_size,), 0, u.shape[1])
)

return (
self.u[batch_idx, :],
self.y[batch_idx, :][:, grid_idx],
self.s[batch_idx, :][:, grid_idx],
u[batch_idx, :],
y[batch_idx, :][:, grid_idx],
s[batch_idx, :][:, grid_idx],
)

while True:
key, subkey = jax.random.split(key)
yield _data_generation(subkey)

def get_train_val_test_loaders(batch_size, grid_size):

def get_train_val_test_data():
data_dir = "/scratch/PDEDatasets/advection_1d"
inputs = np.load(f"{data_dir}/adv_a0.npy")
outputs = np.load(f"{data_dir}/adv_aT.npy")
Expand Down Expand Up @@ -68,29 +55,54 @@ def get_train_val_test_loaders(batch_size, grid_size):
grid[idx[-n_test:]],
)

train_dataloader = GridSampling(
jax.random.PRNGKey(42),
train_data = (
jnp.array(inputs_train),
jnp.array(grid_train),
jnp.array(outputs_train),
)
val_data = (
jnp.array(inputs_val),
jnp.array(grid_val),
jnp.array(outputs_val),
)
test_data = (
jnp.array(inputs_test),
jnp.array(grid_test),
jnp.array(outputs_test),
)

return train_data, val_data, test_data


def get_train_val_test_loaders(batch_size, grid_size):
train_data, val_data, test_data = get_train_val_test_data()
inputs_train, grid_train, outputs_train = train_data
inputs_val, grid_val, outputs_val = val_data
inputs_test, grid_test, outputs_test = test_data

train_dataloader = grid_sampling_iter(
jax.random.PRNGKey(42),
inputs_train,
grid_train,
outputs_train,
batch_size,
grid_size,
)

val_loader = GridSampling(
val_loader = grid_sampling_iter(
jax.random.PRNGKey(42),
jnp.array(inputs_val),
jnp.array(grid_val),
jnp.array(outputs_val),
inputs_val,
grid_val,
outputs_val,
batch_size,
200,
)

test_dataloader = GridSampling(
test_dataloader = grid_sampling_iter(
jax.random.PRNGKey(42),
jnp.array(inputs_test),
jnp.array(grid_test),
jnp.array(outputs_test),
inputs_test,
grid_test,
outputs_test,
batch_size,
200,
)
Expand Down
34 changes: 10 additions & 24 deletions adv/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import jax
import optax
import itertools
import einops
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import flax.linen as nn
import torch.utils.data as data
import pickle

from typing import Any, Callable, Sequence, Optional, Union, Dict
from functools import partial
from typing import Callable
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from jax.nn.initializers import normal, xavier_uniform
from tqdm.auto import trange


Expand Down Expand Up @@ -204,16 +201,6 @@ def __call__(self, u, y):
u = nn.Dense(self.output_dim)(u)
return u


from typing import Callable

import einops
import flax.linen as nn
import jax.numpy as jnp
from einops import rearrange, repeat
from jax.nn.initializers import normal, xavier_uniform


# Positional embedding from masked autoencoder https://arxiv.org/abs/2111.06377
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
Expand Down Expand Up @@ -516,10 +503,9 @@ def __init__(self, key, model_name):
self.params = self.init(key, x, coords)

# getting model size in MB:
self.model_size = (
sum([p.size for p in jax.tree.leaves(self.params)]) * 4 / 1024 / 1024
)
self.model_count = sum([p.size for p in jax.tree_leaves(self.params)])
param_leaves = jax.tree_util.tree_leaves(self.params)
self.model_size = sum(p.size for p in param_leaves) * 4 / 1024 / 1024
self.model_count = sum(p.size for p in param_leaves)

print(f"Model size: {self.model_size:.2f} MB")
print(f"Model count: {self.model_count}")
Expand Down
6 changes: 2 additions & 4 deletions adv/nomad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
import h5py
import jax.numpy as jnp
from models import OperatorModel
from loaders import get_train_val_test_loaders
from loaders import get_train_val_test_data


batch_size = 256
grid_size = 200
n_iterations = 200000

_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size)
_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data()
model = OperatorModel(jax.random.PRNGKey(0), "nomad")
model.load_model()

inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s

s_pred = []
for i in range(10):
u, y, s = (
Expand Down
5 changes: 0 additions & 5 deletions dr/dr_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os

import h5py

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

from einops import rearrange


Expand Down Expand Up @@ -46,4 +42,3 @@ def create_dr_datasets(filename, prev_steps, pred_steps, train_samples, test_sam

return train_dataset, test_dataset, mean, std


43 changes: 9 additions & 34 deletions dr/eval.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
import os

import einops

import jax

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

import orbax.checkpoint as ocp


from torch.utils.data import Dataset, DataLoader, Subset

import tensorflow as tf
from src.model import CVit
from src.utils import create_optimizer, create_train_state, create_checkpoint_manager, rollout
from src.data_pipeline import BaseDataset


from dr_pipeline import create_dr_datasets


Expand Down Expand Up @@ -45,13 +34,8 @@ def evaluate(config):
config.dataset.train_samples,
config.dataset.test_samples)

test_dataset = BaseDataset(test_inputs, test_outputs)

test_loader = DataLoader(test_dataset,
batch_size=32,
shuffle=False,
drop_last=True,
num_workers=8)
test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs))
test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

# Create a grid for cvit
_, t, h, w, c = test_inputs.shape
Expand All @@ -61,9 +45,8 @@ def evaluate(config):
coords = jnp.hstack([x_star.flatten()[:, None], y_star.flatten()[:, None]])

l2_error_list = []
for batch in test_loader:
batch = jax.tree_map(lambda x: jnp.array(x), batch)
x, y = batch
for x, y in test_loader:
x, y = jnp.array(x), jnp.array(y)
pred = model.apply(state.params, x, coords)
pred = pred.reshape(-1, 1, h, w, c)

Expand All @@ -86,18 +69,12 @@ def evaluate(config):
config.dataset.train_samples,
config.dataset.test_samples)

test_dataset = BaseDataset(test_inputs, test_outputs)

test_loader = DataLoader(test_dataset,
batch_size=32,
shuffle=False,
drop_last=True,
num_workers=8)
test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs))
test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

l2_error_list = []
for batch in test_loader:
batch = jax.tree_map(lambda x: jnp.array(x), batch)
x, y = batch
for x, y in test_loader:
x, y = jnp.array(x), jnp.array(y)

pred = rollout(state, x, coords,
prev_steps=config.dataset.prev_steps,
Expand All @@ -121,5 +98,3 @@ def evaluate(config):





1 change: 0 additions & 1 deletion dr/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from absl import app
from absl import flags
from ml_collections import config_flags

import train, train_vit
import eval

Expand Down
5 changes: 0 additions & 5 deletions dr/train.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import os

import time
import ml_collections
import wandb

from jax import random
import jax.numpy as jnp
import orbax.checkpoint as ocp


from src.model import CVit
from src.utils import create_optimizer, create_train_state, create_checkpoint_manager, create_train_step, create_eval_step
from src.data_pipeline import create_dataloaders, batch_parser

from dr_pipeline import create_dr_datasets


Expand Down
2 changes: 0 additions & 2 deletions ns/configs/cvit_16x16.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import ml_collections

import jax.numpy as jnp


def get_config():
"""Get the default hyperparameter configuration."""
Expand Down
Loading