From f933ac8a148bcb0d1b2ba4d3157b68beda33e204 Mon Sep 17 00:00:00 2001 From: GVourvachakis Date: Wed, 27 May 2026 23:07:30 +0300 Subject: [PATCH] refactor(deps): replace PyTorch data utilities with tf.data and pure Python PyTorch was used as a data-loading and legacy serialization utility, not for model definition, training, or optimization. This commit removes the runtime dependency by: - Replacing torch.utils.data Dataset/DataLoader usage in dr/, ns/, and swe/ eval scripts with tf.data pipelines already used by src/data_pipeline.py - Replacing the GridSampling Dataset subclass in adv/ with a plain Python generator; the index argument was never used and batches were JAX-generated - Replacing runtime torch.load('normstats.pt') in swe/swe_pipeline.py with NumPy .npz loading; a one-time optional migration script remains in that file - Removing dead imports such as Subset and TensorDataset that were never called - Removing BaseDataset's torch.Dataset inheritance because pure Python indexing is sufficient - Removing torch from requirements.txt and adding the missing tqdm and wandb runtime dependencies The full ML stack remains JAX + Flax + Optax. TensorFlow is retained solely for its tf.data pipeline API. No numerical behaviour is changed. Existing normstats.pt files must be converted once using the migration script in swe/swe_pipeline.py from an environment with PyTorch installed. --- adv/cvit_test.py | 6 +-- adv/deeponet_test.py | 6 +-- adv/fno_test.py | 6 +-- adv/loaders.py | 88 ++++++++++++++++++++--------------- adv/models.py | 34 ++++---------- adv/nomad_test.py | 6 +-- dr/dr_pipeline.py | 5 -- dr/eval.py | 43 ++++------------- dr/main.py | 1 - dr/train.py | 5 -- ns/configs/cvit_16x16.py | 2 - ns/configs/cvit_4x4.py | 2 - ns/configs/cvit_8x8.py | 2 - ns/configs/cvit_base_8x8.py | 2 - ns/configs/cvit_small_8x8.py | 2 - ns/eval.py | 35 ++++---------- ns/main.py | 1 - ns/ns_pipeline.py | 5 +- ns/train.py | 7 +-- requirements.txt | 3 +- src/data_pipeline.py | 17 ++----- src/model.py | 1 - src/utils.py | 5 +- swe/configs/cvit_16x16.py | 2 - swe/configs/cvit_32x32.py | 2 - swe/configs/cvit_4x4.py | 2 - swe/configs/cvit_8x8.py | 2 - swe/configs/cvit_base_8x8.py | 2 - swe/configs/cvit_small_8x8.py | 2 - swe/configs/vit.py | 2 - swe/eval.py | 37 ++++----------- swe/main.py | 1 - swe/swe_pipeline.py | 65 +++++++++++++++++++++++--- swe/train.py | 5 -- swe/train_vit.py | 13 ++---- 35 files changed, 165 insertions(+), 254 deletions(-) diff --git a/adv/cvit_test.py b/adv/cvit_test.py index fce6ca5..a352ce1 100644 --- a/adv/cvit_test.py +++ b/adv/cvit_test.py @@ -7,19 +7,17 @@ import h5py import jax.numpy as jnp from models import OperatorModel -from loaders import get_train_val_test_loaders +from loaders import get_train_val_test_data batch_size = 256 grid_size = 200 n_iterations = 100000 -_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size) +_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data() model = OperatorModel(jax.random.PRNGKey(0), "cvit") model.load_model() -inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s - s_pred = [] for i in range(10): u, y, s = ( diff --git a/adv/deeponet_test.py b/adv/deeponet_test.py index 5a075d5..8ea6c9e 100644 --- a/adv/deeponet_test.py +++ b/adv/deeponet_test.py @@ -7,19 +7,17 @@ import h5py import jax.numpy as jnp from models import OperatorModel -from loaders import get_train_val_test_loaders +from loaders import get_train_val_test_data batch_size = 256 grid_size = 200 n_iterations = 100000 -_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size) +_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data() model = OperatorModel(jax.random.PRNGKey(0), "deeponet") model.load_model() -inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s - s_pred = [] for i in range(10): u, y, s = ( diff --git a/adv/fno_test.py b/adv/fno_test.py index 4d9d808..7dcdfb6 100644 --- a/adv/fno_test.py +++ b/adv/fno_test.py @@ -7,19 +7,17 @@ import h5py import jax.numpy as jnp from models import OperatorModel -from loaders import get_train_val_test_loaders +from loaders import get_train_val_test_data batch_size = 256 grid_size = 200 n_iterations = 100000 -_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size) +_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data() model = OperatorModel(jax.random.PRNGKey(0), "fno1d") model.load_model() -inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s - s_pred = [] for i in range(10): u, y, s = ( diff --git a/adv/loaders.py b/adv/loaders.py index 9bc5be4..6a6b83b 100644 --- a/adv/loaders.py +++ b/adv/loaders.py @@ -2,41 +2,28 @@ import jax.numpy as jnp import numpy as np import einops -import torch.utils.data as data -from functools import partial - - -# defining the dataloader: -class GridSampling(data.Dataset): - def __init__(self, key, u, y, s, batch_size, grid_size): - self.key = key - self.u = u - self.y = y - self.s = s - self.batch_size = batch_size - self.grid_size = grid_size - - def __getitem__(self, index): - "Generate one batch of data" - self.key, subkey = jax.random.split(self.key) - batch = self.__data_generation(subkey) - return batch - - @partial(jax.jit, static_argnums=(0,)) - def __data_generation(self, key): - batch_idx = jax.random.randint(key, (self.batch_size,), 0, self.u.shape[0]) + + +def grid_sampling_iter(key, u, y, s, batch_size, grid_size): + @jax.jit + def _data_generation(sample_key): + batch_idx = jax.random.randint(sample_key, (batch_size,), 0, u.shape[0]) grid_idx = jnp.sort( - jax.random.randint(key, (self.grid_size,), 0, self.u.shape[1]) + jax.random.randint(sample_key, (grid_size,), 0, u.shape[1]) ) return ( - self.u[batch_idx, :], - self.y[batch_idx, :][:, grid_idx], - self.s[batch_idx, :][:, grid_idx], + u[batch_idx, :], + y[batch_idx, :][:, grid_idx], + s[batch_idx, :][:, grid_idx], ) + while True: + key, subkey = jax.random.split(key) + yield _data_generation(subkey) -def get_train_val_test_loaders(batch_size, grid_size): + +def get_train_val_test_data(): data_dir = "/scratch/PDEDatasets/advection_1d" inputs = np.load(f"{data_dir}/adv_a0.npy") outputs = np.load(f"{data_dir}/adv_aT.npy") @@ -68,29 +55,54 @@ def get_train_val_test_loaders(batch_size, grid_size): grid[idx[-n_test:]], ) - train_dataloader = GridSampling( - jax.random.PRNGKey(42), + train_data = ( jnp.array(inputs_train), jnp.array(grid_train), jnp.array(outputs_train), + ) + val_data = ( + jnp.array(inputs_val), + jnp.array(grid_val), + jnp.array(outputs_val), + ) + test_data = ( + jnp.array(inputs_test), + jnp.array(grid_test), + jnp.array(outputs_test), + ) + + return train_data, val_data, test_data + + +def get_train_val_test_loaders(batch_size, grid_size): + train_data, val_data, test_data = get_train_val_test_data() + inputs_train, grid_train, outputs_train = train_data + inputs_val, grid_val, outputs_val = val_data + inputs_test, grid_test, outputs_test = test_data + + train_dataloader = grid_sampling_iter( + jax.random.PRNGKey(42), + inputs_train, + grid_train, + outputs_train, batch_size, grid_size, ) - val_loader = GridSampling( + val_loader = grid_sampling_iter( jax.random.PRNGKey(42), - jnp.array(inputs_val), - jnp.array(grid_val), - jnp.array(outputs_val), + inputs_val, + grid_val, + outputs_val, batch_size, 200, ) - test_dataloader = GridSampling( + test_dataloader = grid_sampling_iter( jax.random.PRNGKey(42), - jnp.array(inputs_test), - jnp.array(grid_test), - jnp.array(outputs_test), + inputs_test, + grid_test, + outputs_test, batch_size, 200, ) diff --git a/adv/models.py b/adv/models.py index 26eca49..7324b75 100644 --- a/adv/models.py +++ b/adv/models.py @@ -1,16 +1,13 @@ -import jax -import optax import itertools -import einops -import jax.numpy as jnp -import numpy as np -import matplotlib.pyplot as plt -import flax.linen as nn -import torch.utils.data as data import pickle - -from typing import Any, Callable, Sequence, Optional, Union, Dict from functools import partial +from typing import Callable +import einops +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from jax.nn.initializers import normal, xavier_uniform from tqdm.auto import trange @@ -204,16 +201,6 @@ def __call__(self, u, y): u = nn.Dense(self.output_dim)(u) return u - -from typing import Callable - -import einops -import flax.linen as nn -import jax.numpy as jnp -from einops import rearrange, repeat -from jax.nn.initializers import normal, xavier_uniform - - # Positional embedding from masked autoencoder https://arxiv.org/abs/2111.06377 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): assert embed_dim % 2 == 0 @@ -516,10 +503,9 @@ def __init__(self, key, model_name): self.params = self.init(key, x, coords) # getting model size in MB: - self.model_size = ( - sum([p.size for p in jax.tree.leaves(self.params)]) * 4 / 1024 / 1024 - ) - self.model_count = sum([p.size for p in jax.tree_leaves(self.params)]) + param_leaves = jax.tree_util.tree_leaves(self.params) + self.model_size = sum(p.size for p in param_leaves) * 4 / 1024 / 1024 + self.model_count = sum(p.size for p in param_leaves) print(f"Model size: {self.model_size:.2f} MB") print(f"Model count: {self.model_count}") diff --git a/adv/nomad_test.py b/adv/nomad_test.py index ed453b4..59357a5 100644 --- a/adv/nomad_test.py +++ b/adv/nomad_test.py @@ -7,19 +7,17 @@ import h5py import jax.numpy as jnp from models import OperatorModel -from loaders import get_train_val_test_loaders +from loaders import get_train_val_test_data batch_size = 256 grid_size = 200 n_iterations = 200000 -_, _, test_loader = get_train_val_test_loaders(batch_size, grid_size) +_, _, (inputs_test, grid_test, outputs_test) = get_train_val_test_data() model = OperatorModel(jax.random.PRNGKey(0), "nomad") model.load_model() -inputs_test, grid_test, outputs_test = test_loader.u, test_loader.y, test_loader.s - s_pred = [] for i in range(10): u, y, s = ( diff --git a/dr/dr_pipeline.py b/dr/dr_pipeline.py index 1d821dd..56276bf 100644 --- a/dr/dr_pipeline.py +++ b/dr/dr_pipeline.py @@ -1,10 +1,6 @@ -import os - import h5py - import numpy as np from numpy.lib.stride_tricks import sliding_window_view - from einops import rearrange @@ -46,4 +42,3 @@ def create_dr_datasets(filename, prev_steps, pred_steps, train_samples, test_sam return train_dataset, test_dataset, mean, std - diff --git a/dr/eval.py b/dr/eval.py index 0b8eef9..403a643 100644 --- a/dr/eval.py +++ b/dr/eval.py @@ -1,22 +1,11 @@ import os - import einops - -import jax - import jax.numpy as jnp from jax.flatten_util import ravel_pytree - import orbax.checkpoint as ocp - - -from torch.utils.data import Dataset, DataLoader, Subset - +import tensorflow as tf from src.model import CVit from src.utils import create_optimizer, create_train_state, create_checkpoint_manager, rollout -from src.data_pipeline import BaseDataset - - from dr_pipeline import create_dr_datasets @@ -45,13 +34,8 @@ def evaluate(config): config.dataset.train_samples, config.dataset.test_samples) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader(test_dataset, - batch_size=32, - shuffle=False, - drop_last=True, - num_workers=8) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) # Create a grid for cvit _, t, h, w, c = test_inputs.shape @@ -61,9 +45,8 @@ def evaluate(config): coords = jnp.hstack([x_star.flatten()[:, None], y_star.flatten()[:, None]]) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = model.apply(state.params, x, coords) pred = pred.reshape(-1, 1, h, w, c) @@ -86,18 +69,12 @@ def evaluate(config): config.dataset.train_samples, config.dataset.test_samples) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader(test_dataset, - batch_size=32, - shuffle=False, - drop_last=True, - num_workers=8) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = rollout(state, x, coords, prev_steps=config.dataset.prev_steps, @@ -121,5 +98,3 @@ def evaluate(config): - - diff --git a/dr/main.py b/dr/main.py index fbacc65..4c8848f 100644 --- a/dr/main.py +++ b/dr/main.py @@ -1,7 +1,6 @@ from absl import app from absl import flags from ml_collections import config_flags - import train, train_vit import eval diff --git a/dr/train.py b/dr/train.py index 0d6508a..505fd60 100644 --- a/dr/train.py +++ b/dr/train.py @@ -1,18 +1,13 @@ import os - import time import ml_collections import wandb - from jax import random import jax.numpy as jnp import orbax.checkpoint as ocp - - from src.model import CVit from src.utils import create_optimizer, create_train_state, create_checkpoint_manager, create_train_step, create_eval_step from src.data_pipeline import create_dataloaders, batch_parser - from dr_pipeline import create_dr_datasets diff --git a/ns/configs/cvit_16x16.py b/ns/configs/cvit_16x16.py index 68b761c..29c3e3e 100644 --- a/ns/configs/cvit_16x16.py +++ b/ns/configs/cvit_16x16.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/ns/configs/cvit_4x4.py b/ns/configs/cvit_4x4.py index 3db0db7..9a70e5c 100644 --- a/ns/configs/cvit_4x4.py +++ b/ns/configs/cvit_4x4.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/ns/configs/cvit_8x8.py b/ns/configs/cvit_8x8.py index e5117a4..7f11625 100644 --- a/ns/configs/cvit_8x8.py +++ b/ns/configs/cvit_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/ns/configs/cvit_base_8x8.py b/ns/configs/cvit_base_8x8.py index 36694ea..231199f 100644 --- a/ns/configs/cvit_base_8x8.py +++ b/ns/configs/cvit_base_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/ns/configs/cvit_small_8x8.py b/ns/configs/cvit_small_8x8.py index abddb0b..ec6b4a4 100644 --- a/ns/configs/cvit_small_8x8.py +++ b/ns/configs/cvit_small_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/ns/eval.py b/ns/eval.py index 1b4d8f9..370a901 100644 --- a/ns/eval.py +++ b/ns/eval.py @@ -1,16 +1,9 @@ import os - import einops - - -import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree - import orbax.checkpoint as ocp - -from torch.utils.data import TensorDataset, Dataset, DataLoader, Subset - +import tensorflow as tf from src.model import CVit from src.utils import ( create_optimizer, @@ -18,8 +11,6 @@ create_checkpoint_manager, rollout, ) -from src.data_pipeline import BaseDataset - from ns_pipeline import prepare_ns_dataset @@ -52,11 +43,8 @@ def evaluate(config): num_samples=1000, ) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader( - test_dataset, batch_size=32, shuffle=False, drop_last=True, num_workers=8 - ) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) # Create a grid for cvit _, t, h, w, c = test_inputs.shape @@ -66,9 +54,8 @@ def evaluate(config): coords = jnp.hstack([x_star.flatten()[:, None], y_star.flatten()[:, None]]) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = model.apply(state.params, x, coords) pred = pred.reshape(-1, 1, h, w, c) @@ -94,16 +81,12 @@ def evaluate(config): num_samples=1000, ) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader( - test_dataset, batch_size=32, shuffle=False, drop_last=True, num_workers=8 - ) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = rollout( state, diff --git a/ns/main.py b/ns/main.py index 87465d5..24fd3cc 100644 --- a/ns/main.py +++ b/ns/main.py @@ -1,7 +1,6 @@ from absl import app from absl import flags from ml_collections import config_flags - import train import eval diff --git a/ns/ns_pipeline.py b/ns/ns_pipeline.py index 9dbaf28..af13ed1 100644 --- a/ns/ns_pipeline.py +++ b/ns/ns_pipeline.py @@ -1,11 +1,8 @@ import os import re - import h5py - import numpy as np from numpy.lib.stride_tricks import sliding_window_view - from einops import rearrange @@ -89,7 +86,7 @@ def create_ns_datasets(config): # self.keys = keys # self.filename = directory + mode + ".zarr" # # normalization constants -# self.normstats = torch.load(directory + "normstats.pt") +# self.normstats = np.load(directory + "normstats.npz") # # self.prev_steps = prev_steps # self.pred_steps = pred_steps diff --git a/ns/train.py b/ns/train.py index fdc4242..be06fd9 100644 --- a/ns/train.py +++ b/ns/train.py @@ -1,14 +1,10 @@ import os - import time import ml_collections import wandb - -from jax import random, vmap, jit +from jax import random import jax.numpy as jnp import orbax.checkpoint as ocp - - from src.model import CVit from src.utils import ( create_optimizer, @@ -18,7 +14,6 @@ create_eval_step, ) from src.data_pipeline import create_dataloaders, batch_parser - from ns_pipeline import create_ns_datasets diff --git a/requirements.txt b/requirements.txt index eab51d8..0418bb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ numpy optax orbax-checkpoint tensorflow_cpu -torch +tqdm +wandb xarray zarr diff --git a/src/data_pipeline.py b/src/data_pipeline.py index 25b0db6..41ea1ab 100644 --- a/src/data_pipeline.py +++ b/src/data_pipeline.py @@ -1,35 +1,25 @@ import numpy as np - from einops import rearrange - import jax from jax import random - import jax.numpy as jnp - - import flax - import tensorflow as tf -from torch.utils.data import Dataset, DataLoader, Subset +class BaseDataset: + """Minimal indexable dataset wrapper for equally sized arrays.""" -class BaseDataset(Dataset): def __init__(self, *datasets): - super().__init__() - # Ensure all datasets have the same length assert all( len(datasets[0]) == len(dataset) for dataset in datasets ), "All datasets must have the same length" self.datasets = datasets def __len__(self): - # Assuming all datasets have the same length, use the first one to determine the length return len(self.datasets[0]) def __getitem__(self, index): - # Retrieve the corresponding item from each dataset return tuple(dataset[index] for dataset in self.datasets) @@ -47,7 +37,8 @@ def prefetch(dataset, n_prefetch=None): """Prefetches data to device and converts to numpy array.""" ds_iter = iter(dataset) ds_iter = map( - lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), ds_iter + lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x), + ds_iter, ) if n_prefetch: ds_iter = map(prepare_tf_data, ds_iter) diff --git a/src/model.py b/src/model.py index c9394cf..d0f6769 100644 --- a/src/model.py +++ b/src/model.py @@ -1,5 +1,4 @@ from typing import Callable - import einops import flax.linen as nn import jax.numpy as jnp diff --git a/src/utils.py b/src/utils.py index 62e3ff0..ee1e19b 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,11 +1,8 @@ import os - import jax import jax.numpy as jnp -from jax import lax, jit, vmap, pmap, random, tree_map, jacfwd - +from jax import jit, random, vmap from flax.training import train_state - import optax import orbax.checkpoint as ocp diff --git a/swe/configs/cvit_16x16.py b/swe/configs/cvit_16x16.py index 90bb993..70e3580 100644 --- a/swe/configs/cvit_16x16.py +++ b/swe/configs/cvit_16x16.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/cvit_32x32.py b/swe/configs/cvit_32x32.py index 5533604..06ef901 100644 --- a/swe/configs/cvit_32x32.py +++ b/swe/configs/cvit_32x32.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/cvit_4x4.py b/swe/configs/cvit_4x4.py index e00b7f3..e175a94 100644 --- a/swe/configs/cvit_4x4.py +++ b/swe/configs/cvit_4x4.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/cvit_8x8.py b/swe/configs/cvit_8x8.py index 1932577..9800a5e 100644 --- a/swe/configs/cvit_8x8.py +++ b/swe/configs/cvit_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/cvit_base_8x8.py b/swe/configs/cvit_base_8x8.py index 3d5bbe4..eb463f7 100644 --- a/swe/configs/cvit_base_8x8.py +++ b/swe/configs/cvit_base_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/cvit_small_8x8.py b/swe/configs/cvit_small_8x8.py index b8c6b72..e67406c 100644 --- a/swe/configs/cvit_small_8x8.py +++ b/swe/configs/cvit_small_8x8.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/configs/vit.py b/swe/configs/vit.py index 3f70ef5..440bc61 100644 --- a/swe/configs/vit.py +++ b/swe/configs/vit.py @@ -1,7 +1,5 @@ import ml_collections -import jax.numpy as jnp - def get_config(): """Get the default hyperparameter configuration.""" diff --git a/swe/eval.py b/swe/eval.py index b9da01c..9c81ca3 100644 --- a/swe/eval.py +++ b/swe/eval.py @@ -1,17 +1,9 @@ import os - import einops - -import jax - import jax.numpy as jnp from jax.flatten_util import ravel_pytree - import orbax.checkpoint as ocp - - -from torch.utils.data import Dataset, DataLoader, Subset - +import tensorflow as tf from src.model import CVit from src.utils import ( create_optimizer, @@ -19,9 +11,6 @@ create_checkpoint_manager, rollout, ) -from src.data_pipeline import BaseDataset - - from swe_pipeline import prepare_swe_dataset @@ -55,11 +44,8 @@ def evaluate(config): num_samples=1000, ) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader( - test_dataset, batch_size=32, shuffle=False, drop_last=True, num_workers=8 - ) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) # Create a grid for cvit _, t, h, w, c = test_inputs.shape @@ -69,9 +55,8 @@ def evaluate(config): coords = jnp.hstack([x_star.flatten()[:, None], y_star.flatten()[:, None]]) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = model.apply(state.params, x, coords) pred = pred.reshape(-1, 1, h, w, c) @@ -98,16 +83,12 @@ def evaluate(config): num_samples=1000, ) - test_dataset = BaseDataset(test_inputs, test_outputs) - - test_loader = DataLoader( - test_dataset, batch_size=32, shuffle=False, drop_last=True, num_workers=8 - ) + test_dataset = tf.data.Dataset.from_tensor_slices((test_inputs, test_outputs)) + test_loader = test_dataset.batch(32, drop_remainder=True).prefetch(tf.data.AUTOTUNE) l2_error_list = [] - for batch in test_loader: - batch = jax.tree_map(lambda x: jnp.array(x), batch) - x, y = batch + for x, y in test_loader: + x, y = jnp.array(x), jnp.array(y) pred = rollout( state, diff --git a/swe/main.py b/swe/main.py index ed8de8d..b1fb673 100644 --- a/swe/main.py +++ b/swe/main.py @@ -1,7 +1,6 @@ from absl import app from absl import flags from ml_collections import config_flags - import train import eval diff --git a/swe/swe_pipeline.py b/swe/swe_pipeline.py index a20dae0..88ebc62 100644 --- a/swe/swe_pipeline.py +++ b/swe/swe_pipeline.py @@ -1,12 +1,22 @@ import xarray as xr - import numpy as np from numpy.lib.stride_tricks import sliding_window_view - from einops import rearrange -import jax -import torch + +def _load_norm_stats(filename): + with np.load(filename) as data: + keys = sorted( + { + name.rsplit("_", 1)[0] + for name in data.files + if name.endswith("_mean") or name.endswith("_std") + } + ) + return { + key: {"mean": data[f"{key}_mean"], "std": data[f"{key}_std"]} + for key in keys + } # Construct the full dataset @@ -18,8 +28,7 @@ def prepare_swe_dataset( ds = xr.open_zarr(filename) - norm_stats = torch.load(directory + "normstats.pt") - norm_stats = jax.tree_map(lambda x: np.array(x), norm_stats) + norm_stats = _load_norm_stats(directory + "normstats.npz") data_dict = {key: [] for key in keys} @@ -81,7 +90,7 @@ def create_swe_datasets(config): # self.keys = keys # self.filename = directory + mode + ".zarr" # # normalization constants -# self.normstats = torch.load(directory + "normstats.pt") +# self.normstats = _load_norm_stats(directory + "normstats.npz") # # self.prev_steps = prev_steps # self.pred_steps = pred_steps @@ -155,3 +164,45 @@ def create_swe_datasets(config): # batch = jax.tree_map(lambda x: jnp.array(x), batch) # # return batch + + +if __name__ == "__main__": + import argparse + import importlib + + parser = argparse.ArgumentParser( + description="Convert SWE normstats.pt files to normstats.npz." + ) + parser.add_argument( + "directory", + nargs="?", + default="", + help="Directory containing normstats.pt. Defaults to the current directory.", + ) + args = parser.parse_args() + + directory = args.directory + if directory and not directory.endswith("/"): + directory += "/" + + try: + torch = importlib.import_module("torch") + except ModuleNotFoundError as exc: + raise SystemExit( + "PyTorch is required only for converting normstats.pt to " + "normstats.npz. Install torch in this environment or run the " + "converter from an environment that already has PyTorch." + ) from exc + + norm_stats = torch.load(directory + "normstats.pt", map_location="cpu") + + np.savez( + directory + "normstats.npz", + **{ + f"{key}_{stat_name}": np.array(stat_value) + for key, stats in norm_stats.items() + for stat_name, stat_value in stats.items() + if stat_name in {"mean", "std"} + }, + ) + print(f"Saved {directory}normstats.npz from {directory}normstats.pt.") diff --git a/swe/train.py b/swe/train.py index 1814ddb..a4b16c6 100644 --- a/swe/train.py +++ b/swe/train.py @@ -1,14 +1,10 @@ import os - import time import ml_collections import wandb - from jax import random import jax.numpy as jnp import orbax.checkpoint as ocp - - from src.model import CVit from src.utils import ( create_optimizer, @@ -18,7 +14,6 @@ create_eval_step, ) from src.data_pipeline import create_dataloaders, batch_parser - from swe_pipeline import create_swe_datasets diff --git a/swe/train_vit.py b/swe/train_vit.py index d8458ba..1c0ac1a 100644 --- a/swe/train_vit.py +++ b/swe/train_vit.py @@ -1,23 +1,16 @@ import os - import time import ml_collections import wandb - from einops import rearrange - - import jax from jax import random, jit import jax.numpy as jnp import orbax.checkpoint as ocp - from flax.training import train_state - from model import Vit from utils import create_optimizer, create_checkpoint_manager -from data_pipeline import create_dataloaders, batch_parser - +from data_pipeline import create_dataloaders from swe_pipeline import create_swe_datasets @@ -144,7 +137,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict): last_loss = 1.0 for step in range(config.training.num_steps): batch = next(train_iter) - batch = jax.tree_map(lambda x: jnp.squeeze(x), batch) + batch = jax.tree_util.tree_map(lambda x: jnp.squeeze(x), batch) state, loss = train_step_fn(state, batch) # Evaluate model @@ -153,7 +146,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict): smse_list = [] for _ in range(config.logging.eval_steps): batch = next(test_iter) - batch = jax.tree_map(lambda x: jnp.squeeze(x), batch) + batch = jax.tree_util.tree_map(lambda x: jnp.squeeze(x), batch) l2_error, smse = eval_step_fn(state, batch) l2_error_list.append(l2_error)