Skip to content

refactor(deps): remove PyTorch dependency, use tf.data and pure Python loaders#11

Open
GVourvachakis wants to merge 1 commit into
PredictiveIntelligenceLab:mainfrom
GVourvachakis:remove-torch-dependency
Open

refactor(deps): remove PyTorch dependency, use tf.data and pure Python loaders#11
GVourvachakis wants to merge 1 commit into
PredictiveIntelligenceLab:mainfrom
GVourvachakis:remove-torch-dependency

Conversation

@GVourvachakis
Copy link
Copy Markdown

@GVourvachakis GVourvachakis commented May 27, 2026

Summary

This PR removes PyTorch as a runtime dependency from CViT.

PyTorch was previously used only for data-loading utilities and legacy
serialization, not for model definition, training, optimization, or numerical
kernels. The data-loading paths are now unified around tf.data, which was
already used in src/data_pipeline.py, and plain Python iteration where no
framework abstraction is needed.

The main ML stack remains:

  • JAX
  • Flax
  • Optax
  • TensorFlow only for tf.data

Changes

File Removed Replaced with
adv/loaders.py torch.utils.data.Dataset subclass Pure Python infinite generator using JAX sampling
adv/models.py unused torch.utils.data import, duplicate/stale imports Clean imports and current jax.tree_util API
adv/*_test.py reliance on loader attributes .u, .y, .s Explicit get_train_val_test_data() helper
dr/eval.py Dataset, DataLoader, Subset, BaseDataset tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)
ns/eval.py TensorDataset, Dataset, DataLoader, Subset, BaseDataset tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)
swe/eval.py Dataset, DataLoader, Subset, BaseDataset tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)
swe/swe_pipeline.py runtime torch.load("normstats.pt") runtime np.load("normstats.npz")
src/data_pipeline.py torch.utils.data.Dataset inheritance framework-free BaseDataset with __len__ / __getitem__
requirements.txt torch added missing runtime deps tqdm, wandb

SWE normstats.pt migration

SWE normalization stats were previously stored as a PyTorch .pt file. Runtime
loading now expects NumPy .npz stats:

normstats.npz

A one-time migration script is provided at the bottom of swe/swe_pipeline.py:

python swe/swe_pipeline.py /path/to/ShallowWater2D/

This converts:

normstats.pt -> normstats.npz

Because this is only a legacy conversion path, PyTorch is no longer listed in
requirements.txt. To run the migration script, use an environment that has
PyTorch installed, or temporarily install PyTorch for the conversion only.

After conversion, normal training/evaluation uses only normstats.npz.

What is not changed

  • Model architectures are not changed.
  • Optimizers and training objectives are not changed.
  • Sampling logic is preserved.
  • Batch sizes and drop_last/drop_remainder behavior are preserved.
  • Normalization logic is preserved after .pt -> .npz conversion.
  • Numerical behavior should be unchanged.

The only non-data-path compatibility updates are import cleanups and replacing
deprecated JAX helpers such as jax.tree_map / jax.tree_leaves with
jax.tree_util.

Verification

Performed repository-wide checks:

  • Confirmed no import torch or from torch statements remain in .py files.
  • Confirmed all files using tf.data import tensorflow as tf.
  • Confirmed normstats.pt appears only inside the guarded one-time migration block in swe/swe_pipeline.py.
  • Confirmed torch is removed from requirements.txt.
  • Ran syntax compilation for all Python files.
  • Ran import smoke checks for core modules.
  • Ran small smoke checks for:
    • ADV generator batching
    • tf.data eval loader pattern
    • BaseDataset
    • SWE .npz normstats loading
    • CViT forward initialization

Requested validation

Please verify with existing experiment runs that L2 errors on the
dr, ns, swe, and adv benchmarks remain unchanged after this refactor.

…Python

PyTorch was used as a data-loading and legacy serialization utility, not for model definition, training, or optimization. This commit removes the runtime dependency by:

- Replacing torch.utils.data Dataset/DataLoader usage in dr/, ns/, and swe/ eval scripts with tf.data pipelines already used by src/data_pipeline.py
- Replacing the GridSampling Dataset subclass in adv/ with a plain Python generator; the index argument was never used and batches were JAX-generated
- Replacing runtime torch.load('normstats.pt') in swe/swe_pipeline.py with NumPy .npz loading; a one-time optional migration script remains in that file
- Removing dead imports such as Subset and TensorDataset that were never called
- Removing BaseDataset's torch.Dataset inheritance because pure Python indexing is sufficient
- Removing torch from requirements.txt and adding the missing tqdm and wandb runtime dependencies

The full ML stack remains JAX + Flax + Optax. TensorFlow is retained solely for its tf.data pipeline API.

No numerical behaviour is changed. Existing normstats.pt files must be converted once using the migration script in swe/swe_pipeline.py from an environment with PyTorch installed.
@GVourvachakis GVourvachakis reopened this May 27, 2026
@GVourvachakis GVourvachakis changed the title refactor(deps): remove PyTorch dependency — replace with tf.data and pure Python refactor(deps): remove PyTorch dependency, use tf.data and pure Python loaders May 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant