refactor(deps): remove PyTorch dependency, use tf.data and pure Python loaders #11
Open
GVourvachakis wants to merge 1 commit into
…Python
PyTorch was used as a data-loading and legacy serialization utility, not for model definition, training, or optimization. This commit removes the runtime dependency by:
- Replacing torch.utils.data Dataset/DataLoader usage in dr/, ns/, and swe/ eval scripts with tf.data pipelines already used by src/data_pipeline.py
- Replacing the GridSampling Dataset subclass in adv/ with a plain Python generator; the index argument was never used and batches were JAX-generated
- Replacing runtime torch.load('normstats.pt') in swe/swe_pipeline.py with NumPy .npz loading; a one-time optional migration script remains in that file
- Removing dead imports such as Subset and TensorDataset that were never called
- Removing BaseDataset's torch.Dataset inheritance because pure Python indexing is sufficient
- Removing torch from requirements.txt and adding the missing tqdm and wandb runtime dependencies
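The two plain-Python replacements in the list above (the generator in adv/ and the BaseDataset without torch inheritance) could look roughly like the sketch below. This is not the code from the PR: the field names `inputs`/`outputs`, the function name `grid_sampler`, and its parameters are hypothetical, and NumPy's RNG stands in for the JAX batch generation the commit message mentions so the example stays self-contained.

```python
import numpy as np


class BaseDataset:
    """Plain-Python dataset: indexing only needs __len__ and __getitem__."""

    def __init__(self, inputs, outputs):
        # Hypothetical fields; the real BaseDataset may store different arrays.
        assert len(inputs) == len(outputs)
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.outputs[index]


def grid_sampler(coords, batch_size, num_batches, seed=0):
    """Yield random batches of grid points; no index argument is involved."""
    # Assumption: NumPy RNG here, whereas the PR says batches are JAX-generated.
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        idx = rng.choice(len(coords), size=batch_size, replace=False)
        yield coords[idx]
```

Because neither object inherits from a framework class, any loop or `tf.data` pipeline can consume them directly.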
The full ML stack remains JAX + Flax + Optax. TensorFlow is retained solely for its tf.data pipeline API.
No numerical behaviour is changed. Existing normstats.pt files must be converted once using the migration script in swe/swe_pipeline.py from an environment with PyTorch installed.
Summary
This PR removes PyTorch as a runtime dependency from CViT.
PyTorch was previously used only for data-loading utilities and legacy
serialization, not for model definition, training, optimization, or numerical
kernels. The data-loading paths are now unified around `tf.data`, which was already used in `src/data_pipeline.py`, and plain Python iteration where no framework abstraction is needed.
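As an illustration of the `tf.data` path, here is a minimal sketch of an eval loader built from in-memory arrays. The function name `make_eval_loader` and its arguments are hypothetical and not taken from the repository; only the `from_tensor_slices(...).batch(...).prefetch(...)` chain reflects what this PR describes.

```python
import numpy as np
import tensorflow as tf


def make_eval_loader(inputs, outputs, batch_size, drop_remainder=False):
    """Batch in-memory arrays with tf.data and yield tuples of NumPy arrays."""
    ds = tf.data.Dataset.from_tensor_slices((inputs, outputs))
    # drop_remainder=True plays the role of DataLoader(drop_last=True).
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds.as_numpy_iterator()


if __name__ == "__main__":
    x = np.zeros((10, 4), dtype=np.float32)
    y = np.arange(10, dtype=np.int64)
    for xb, yb in make_eval_loader(x, y, batch_size=4):
        print(xb.shape, yb.shape)
```

Returning NumPy batches keeps the downstream JAX code free of TensorFlow tensors.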
The main ML stack remains:
- JAX
- Flax
- Optax
- TensorFlow, only for `tf.data`

Changes

| File | Removed | Replacement or addition |
| --- | --- | --- |
| `adv/loaders.py` | `torch.utils.data.Dataset` subclass | plain Python generator |
| `adv/models.py` | `torch.utils.data` import, duplicate/stale imports | `jax.tree_util` API |
| `adv/*_test.py` | `.u`, `.y`, `.s` | `get_train_val_test_data()` helper |
| `dr/eval.py` | `Dataset`, `DataLoader`, `Subset`, `BaseDataset` | `tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)` |
| `ns/eval.py` | `TensorDataset`, `Dataset`, `DataLoader`, `Subset`, `BaseDataset` | `tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)` |
| `swe/eval.py` | `Dataset`, `DataLoader`, `Subset`, `BaseDataset` | `tf.data.Dataset.from_tensor_slices(...).batch(...).prefetch(...)` |
| `swe/swe_pipeline.py` | `torch.load("normstats.pt")` | `np.load("normstats.npz")` |
| `src/data_pipeline.py` | `torch.utils.data.Dataset` inheritance | `BaseDataset` with `__len__`/`__getitem__` |
| `requirements.txt` | `torch` | `tqdm`, `wandb` |

SWE `normstats.pt` migration

SWE normalization stats were previously stored as a PyTorch `.pt` file. Runtime loading now expects NumPy `.npz` stats.

A one-time migration script is provided at the bottom of `swe/swe_pipeline.py`. It converts `normstats.pt` to `normstats.npz`.

Because this is only a legacy conversion path, PyTorch is no longer listed in
`requirements.txt`. To run the migration script, use an environment that has PyTorch installed, or temporarily install PyTorch for the conversion only.
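For orientation, a conversion and loading path of this kind might be sketched as follows. This is not the script shipped in `swe/swe_pipeline.py`: the function names are hypothetical, and it assumes `normstats.pt` holds a flat dict mapping names to tensors, which the PR text does not confirm.

```python
import numpy as np


def stats_to_npz(stats, npz_path):
    """Write a name -> array mapping to a .npz file."""
    arrays = {}
    for name, value in stats.items():
        # Torch tensors expose .detach(); convert them without importing torch.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        arrays[name] = np.asarray(value)
    np.savez(npz_path, **arrays)


def migrate_normstats(pt_path="normstats.pt", npz_path="normstats.npz"):
    """One-time conversion; the only place PyTorch is needed."""
    import torch  # imported lazily so normal runs never require it

    # Assumption: the file contains a dict of tensors.
    stats = torch.load(pt_path, map_location="cpu")
    stats_to_npz(stats, npz_path)


def load_normstats(npz_path="normstats.npz"):
    """Runtime loader that depends on NumPy only."""
    with np.load(npz_path) as data:
        return {name: data[name] for name in data.files}
```

Keeping the `torch` import inside the migration function means the module can still be imported in an environment without PyTorch.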
After conversion, normal training/evaluation uses only
`normstats.npz`.

What is not changed

- … `drop_last`/`drop_remainder` behavior are preserved.
- … `.pt -> .npz` conversion.

The only non-data-path compatibility updates are import cleanups and replacing
deprecated JAX helpers such as
`jax.tree_map`/`jax.tree_leaves` with `jax.tree_util`.

Verification

Performed repository-wide checks:
- No `import torch` or `from torch` statements remain in `.py` files.
- Files that use `tf.data` import `tensorflow as tf`.
- `normstats.pt` appears only inside the guarded one-time migration block in `swe/swe_pipeline.py`.
- `torch` is removed from `requirements.txt`.

- … `tf.data` eval loader pattern
- … `BaseDataset`
- … `.npz` normstats loading

Requested validation
Please verify with existing experiment runs that L2 errors on the
`dr`, `ns`, `swe`, and `adv` benchmarks remain unchanged after this refactor.
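One way to make that comparison concrete is sketched below. It assumes "L2 error" means the per-sample relative L2 error averaged over the evaluation set, which this PR does not define, and the tolerance in `errors_match` is an arbitrary illustrative choice. Both function names are hypothetical.

```python
import numpy as np


def relative_l2_error(pred, target):
    """Mean over samples of ||pred - target||_2 / ||target||_2."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    # Flatten everything except the leading sample axis.
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    num = np.linalg.norm(pred - target, axis=1)
    den = np.linalg.norm(target, axis=1)
    return float(np.mean(num / den))


def errors_match(before, after, rtol=1e-6):
    """Compare the metric from a pre-refactor run with a post-refactor run."""
    return bool(np.isclose(before, after, rtol=rtol, atol=0.0))
```

Running the same checkpoint through the old and new loaders and passing both metrics to `errors_match` would flag any drift introduced by batching or ordering differences.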