Commit 253c87d: Fix bugs

1 parent b79b37f · commit 253c87d
11 files changed: 459 additions and 123 deletions

README.md

Lines changed: 7 additions & 12 deletions

````diff
@@ -1,11 +1,13 @@
-# arch_eval library
+# arch_eval
 
 [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
 [![PyTorch](https://img.shields.io/badge/PyTorch-1.9+-ee4c2c.svg)](https://pytorch.org/)
 [![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)
 [![GitHub Repo](https://img.shields.io/badge/GitHub-lof310%2Farch__eval-blue)](https://github.com/lof310/arch_eval)
+[![Stars](https://img.shields.io/github/stars/lof310/transformer)](#)
+[![Downloads](https://img.shields.io/github/downloads/lof310/transformer/total)](https://github.com/lof310/transformer/releases)
 
-**arch_eval** is a High-Level library for Efficient and Fast Architecture Evaluation and Comparison of Machine Learning models. It provides a unified interface for training, benchmarking, and hyperparameter optimization with features like distributed training, mixed precision, and real-time visualization.
+High-Level library for Efficient and Fast Architecture Evaluation and Comparison of Machine Learning models. It provides a unified interface for training, benchmarking, and hyperparameter optimization with features like distributed training, mixed precision, and real-time visualization.
 
 ## Features
 
@@ -15,11 +17,10 @@
 - **Advanced Mixed Precision**: AMP with float16, bfloat16, and experimental FP8 support.
 - **Gradient Checkpointing**: Reduce memory footprint for large models.
 - **Rich Visualization**: Real-time training windows, video recording of metrics, and publication‑ready plots.
-- **Logging**: DirectIntegration with Weights & Biases.
+- **Logging**: Integration with Weights & Biases.
 - **Hyperparameter Optimization**: Grid search and random search out of the box.
 - **Extensible Plugin System**: Custom hooks and callbacks for maximum flexibility.
-- **Robust Data Handling**: Supports PyTorch Datasets, synthetic data, torchvision datasets, Hugging Face datasets, and streaming.
-- **Production-Ready**: Configurable timeouts, retry logic and deterministic execution.
+- **Data Handling**: Supports PyTorch Datasets, synthetic data, torchvision datasets, Hugging Face datasets, and streaming.
 
 ## Installation
 
@@ -37,11 +38,6 @@ pip install -e .
 pip install .
 ```
 
-Or Install directly with pip
-```bash
-pip install arch_eval
-```
-
 ## Quick Start
 
 ### 1. Train a Single Model
@@ -68,8 +64,7 @@ class MLP(nn.Module):
         self.net = nn.Sequential(
             nn.Linear(input_size, hidden),
             nn.GELU(),
-            nn.Linear(hidden, num_classes),
-            nn.Softmax(dim=-1)
+            nn.Linear(hidden, num_classes)
         )
 
     def forward(self, x):
````
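
The last hunk above drops the trailing `nn.Softmax(dim=-1)` from the README's Quick Start model. When a classifier is trained with `nn.CrossEntropyLoss`, the network should return raw logits: the loss already applies log-softmax internally, so an extra Softmax double-normalizes the outputs and tends to slow training. Below is a minimal sketch of the corrected pattern; the constructor defaults and the loss pairing are illustrative assumptions, since the rest of the Quick Start is not shown in this diff.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Sketch of the corrected Quick Start model: it now emits raw logits."""
    def __init__(self, input_size: int = 10, hidden: int = 64, num_classes: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_classes),  # no trailing Softmax
        )

    def forward(self, x):
        return self.net(x)

model = MLP()
criterion = nn.CrossEntropyLoss()           # applies log-softmax internally
logits = model(torch.randn(4, 10))          # shape: (batch, num_classes)
loss = criterion(logits, torch.tensor([0, 1, 2, 0]))
loss.backward()
```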

arch_eval/__init__.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -11,8 +11,6 @@
 from arch_eval.logging.logger_config import setup_logging
 from arch_eval.plugins.manager import PluginManager
 
-# from arch_eval.interpret import permutation_importance, attention_weights
-
 _plugin_manager = PluginManager()
 _plugin_manager.discover_plugins()
 
@@ -27,6 +25,4 @@
     "init_distributed",
     "cleanup_distributed",
     "HyperparameterOptimizer",
-    "permutation_importance",
-    "attention_weights",
 ]
```

(binary file, 1004 Bytes: not shown)
(binary file, 2.53 KB: not shown)

arch_eval/core/benchmark.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -19,7 +19,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 def _train_single_process(args):
     """Helper for process-based parallelism with memory cleanup."""
     model_info, config = args
@@ -61,7 +60,6 @@ def _train_single_process(args):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-
 class Benchmark:
     """Benchmark multiple models for comparison."""
 
```

arch_eval/core/config.py

Lines changed: 20 additions & 2 deletions

```diff
@@ -1,28 +1,33 @@
 """Configuration dataclasses for Trainer and Benchmark."""
 
 import os
+import warnings
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 
+
 class TaskType(str, Enum):
     REGRESSION = "regression"
     CLASSIFICATION = "classification"
     NEXT_TOKEN_PREDICTION = "next-token-prediction"
 
+
 class DistributedBackend(str, Enum):
     NONE = "none"
     DATAPARALLEL = "dp"
     DISTRIBUTED = "ddp"
     FSDP = "fsdp"
 
+
 class MixedPrecisionDtype(str, Enum):
     FLOAT16 = "float16"
     BFLOAT16 = "bfloat16"
     FP8 = "fp8" # experimental
 
+
 def _serialize_callable(obj: Any) -> Any:
     """Convert a callable to a serializable representation."""
     if obj is None:
@@ -34,6 +39,7 @@ def _serialize_callable(obj: Any) -> Any:
         warnings.warn(f"Callable {obj} may not be picklable.")
         return str(obj)
 
+
 def _deserialize_callable(rep: Any) -> Any:
     """Restore a callable from its serialized representation."""
     if rep is None or not isinstance(rep, tuple):
@@ -47,6 +53,7 @@ def _deserialize_callable(rep: Any) -> Any:
         raise ValueError(f"Could not restore function {module_name}.{func_name}: {e}")
     return rep
 
+
 def _serialize_dtype(dtype: torch.dtype) -> str:
     """Convert torch.dtype to string."""
     return str(dtype).split('.')[-1]
@@ -56,6 +63,7 @@ def _deserialize_dtype(dtype_str: str) -> torch.dtype:
     """Convert string back to torch.dtype."""
     return getattr(torch, dtype_str)
 
+
 @dataclass
 class BaseConfig:
     """Base configuration with common fields."""
@@ -112,7 +120,8 @@ def __post_init__(self):
         if self.seed is not None:
             torch.manual_seed(self.seed)
             torch.cuda.manual_seed_all(self.seed)
-            import numpy as np, random
+            import numpy as np
+            import random
             np.random.seed(self.seed)
             random.seed(self.seed)
         if self.deterministic:
@@ -138,6 +147,7 @@ def __setstate__(self, state):
         state['dtype'] = _deserialize_dtype(state['dtype'])
         self.__dict__.update(state)
 
+
 @dataclass
 class TrainingConfig(BaseConfig):
     """Configuration for Trainer."""
@@ -217,6 +227,13 @@ class TrainingConfig(BaseConfig):
     # Profiling
     profiler: Optional[Dict[str, Any]] = None # ej. {"enabled": True, "activities": ["cpu", "cuda"], "schedule": {...}}
 
+    # Memory
+    gc_collect_interval: int = 50
+
+    # Confusion matrix
+    log_confusion_matrix: bool = False
+    confusion_matrix_labels: Optional[List[str]] = None
+
     def __post_init__(self):
         super().__post_init__()
         if self.log_to_wandb and self.wandb_project is None:
@@ -250,7 +267,7 @@ def __post_init__(self):
             raise ConfigurationError("distributed_world_size must be >= 1")
         if self.mixed_precision_dtype == MixedPrecisionDtype.FP8:
             try:
-                import transformer_engine.pytorch as te
+                import transformer_engine.pytorch as te # noqa
             except ImportError:
                 raise ConfigurationError("FP8 requires NVIDIA Transformer Engine installed.")
 
@@ -286,6 +303,7 @@ def __setstate__(self, state):
         state['callbacks'] = restored_callbacks
         super().__setstate__(state)
 
+
 @dataclass
 class BenchmarkConfig(BaseConfig):
     """Configuration for Benchmark."""
```

arch_eval/core/trainer.py

Lines changed: 38 additions & 29 deletions

```diff
@@ -2,7 +2,6 @@
 
 import logging
 import os
-import time
 from collections import defaultdict
 from contextlib import AbstractContextManager
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -19,9 +18,9 @@
 from arch_eval.distributed import cleanup_distributed, get_wrapped_model, init_distributed
 from arch_eval.logging.logger_config import LoggerAdapter
 from arch_eval.metrics.calculator import MetricCalculator
-from arch_eval.plugins.manager import PluginManager, hook
+from arch_eval.plugins.manager import PluginManager
 from arch_eval.profiler import profiler_context
-from arch_eval.utils.device import memory_summary
+from arch_eval.utils.device import memory_summary, auto_device
 from arch_eval.viz.viz import PlotSaver, RealtimeWindow, VideoRecorder
 
 logger = logging.getLogger(__name__)
@@ -61,13 +60,15 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
         else:
             self.model = model
 
-        self._validate_model()
         self.device = torch.device(config.device)
         self.model = self.model.to(self.device).to(config.dtype)
 
         self.dataset_handler = DatasetHandler(config)
         self.train_loader, self.val_loader, self.test_loader = self.dataset_handler.prepare_loaders()
 
+        # Validate model with a real batch (if train_loader exists)
+        self._validate_model_with_data()
+
         self.metric_calculator = MetricCalculator(
             config.task, config.device, output_transform=config.model_output_transform
         )
@@ -81,15 +82,15 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
         self.amp_dtype = self._get_amp_dtype()
         self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and config.grad_scaler else None
 
-        # Gradient checkpointing (experimental) TODO
+        # Gradient checkpointing (experimental)
         if config.gradient_checkpointing:
             self._apply_gradient_checkpointing()
 
         # Visualization
         self.window = None
         if config.realtime:
             try:
-                self.window = RealtimeWindow(config)
+                self.window = RealtimeWindow(config, metric_names=config.viz_metrics)
                 if getattr(self.window, "disabled", False):
                     self.window = None
             except Exception as e:
@@ -120,6 +121,9 @@ def __init__(self, model: nn.Module, config: TrainingConfig):
         self.accumulation_steps = config.gradient_accumulation_steps
         self.current_accum_step = 0
 
+        # Initialize checkpoint best metric
+        self.checkpoint_best_metric = None
+
         self.logger.info(f"Trainer initialized on {self.device}\n{memory_summary()}")
 
     def _get_amp_dtype(self):
@@ -135,10 +139,7 @@ def _get_amp_dtype(self):
         return torch.float16
 
     def _apply_gradient_checkpointing(self):
-        """Experimental: attempts to enable gradient checkpointing on specified modules.
-        This is not a standard PyTorch feature; models must implement it internally.
-        The current implementation sets a '_gradient_checkpointing' attribute on modules,
-        which may be used by custom layers. For most models, this will have no effect."""
+        """Experimental: attempts to enable gradient checkpointing on specified modules."""
         if self.config.gradient_checkpointing_modules:
             for name in self.config.gradient_checkpointing_modules:
                 module = dict(self.model.named_modules()).get(name)
@@ -150,14 +151,24 @@ def _apply_gradient_checkpointing(self):
                     module._gradient_checkpointing = True
         self.logger.warning("Gradient checkpointing is experimental and may not work as expected.")
 
-    def _validate_model(self):
-        shape = self.config.input_shape or (1, 10)
-        dummy = torch.randn(1, *shape).to(torch.device(self.config.device))
+    def _validate_model_with_data(self):
+        """Run a forward pass on a single batch to ensure the model accepts the data."""
+        if self.train_loader is None:
+            self.logger.warning("No training loader – skipping model validation.")
+            return
         try:
+            # Get one batch
+            data, targets = next(iter(self.train_loader))
+            data = data.to(self.device)
+            targets = targets.to(self.device)
+            self.model.eval()
             with torch.no_grad():
-                self.model(dummy)
+                _ = self.model(data)
+            self.model.train()
+            self.logger.info("Model validation passed.")
         except Exception as e:
-            raise ModelError(f"Model validation failed: {e}")
+            raise ModelError(f"Model validation failed on a real batch: {e}. "
+                             "Check that your model's input size matches the dataset features.")
 
     def _setup_optimizers(self):
         self.optimizers = []
@@ -231,7 +242,7 @@ def _setup_loss_function(self):
 
     def _compute_loss(self, output, targets):
         if isinstance(output, tuple) and len(output) == 2:
-            return output[1] # Assume second element is loss
+            return output[1]  # Assume second element is loss
         else:
             return self.criterion(output, targets)
 
@@ -298,8 +309,6 @@ def train(self) -> Dict[str, List[float]]:
             self.window.close()
         if self.config.log_to_wandb:
             wandb.finish()
-        if hasattr(self, "tb_writer") and self.tb_writer:
-            self.tb_writer.close()
         if self.config.distributed_backend != DistributedBackend.NONE:
             cleanup_distributed()
 
@@ -334,7 +343,6 @@ def _train_epoch(self) -> Dict[str, float]:
             else:
                 loss.backward()
 
-            # self.plugin_manager.execute_hook("on_backward", self, loss.item() * self.accumulation_steps) # FIXME: Possible Overhead
             self.current_accum_step += 1
 
             if self.current_accum_step % self.accumulation_steps == 0:
@@ -397,7 +405,6 @@ def _evaluate(self, loader: DataLoader, split: str) -> Dict[str, float]:
         total_loss = 0.0
         metric_accum = defaultdict(float)
         count = 0
-        autocast = torch.cuda.amp.autocast if self.use_amp else NullContext
 
         # Reset confusion matrix accumulator if needed
        if self.config.log_confusion_matrix and split == "val":
@@ -437,8 +444,8 @@ def _evaluate(self, loader: DataLoader, split: str) -> Dict[str, float]:
            wandb.log({
                f"confusion_matrix/{split}": wandb.plot.confusion_matrix(
                    probs=None,
-                    y_true=self.metric_calculator._all_targets,
-                    preds=self.metric_calculator._all_preds,
+                    y_true=np.array(self.metric_calculator._all_targets),
+                    preds=np.array(self.metric_calculator._all_preds),
                    class_names=class_names
                )
            }, step=self.current_epoch)
@@ -475,9 +482,6 @@ def _log_metrics(self, metrics: Dict[str, float], step: int):
 
         if self.config.log_to_wandb:
             wandb.log(metrics, step=step)
-        if hasattr(self, "tb_writer") and self.tb_writer:
-            for k, v in metrics.items():
-                self.tb_writer.add_scalar(k, v, step)
 
         self.plugin_manager.execute_hook("on_log", self, metrics, step)
 
@@ -502,13 +506,18 @@ def _save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
         current = metrics.get(self.config.checkpoint_metric)
         if current is not None:
             mode = "min" if "loss" in self.config.checkpoint_metric else "max"
-            improved = (mode == "min" and current < self.checkpoint_best_metric) or (
-                mode == "max" and current > self.checkpoint_best_metric
-            )
-            if improved:
+            if self.checkpoint_best_metric is None:
                 self.checkpoint_best_metric = current
                 is_best = True
                 save_this = True
+            else:
+                improved = (mode == "min" and current < self.checkpoint_best_metric) or (
+                    mode == "max" and current > self.checkpoint_best_metric
+                )
+                if improved:
+                    self.checkpoint_best_metric = current
+                    is_best = True
+                    save_this = True
         else:
             if epoch % self.config.save_frequency == 0:
                 save_this = True
```
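
The final hunk is the central bug fix of this commit: the old code compared `current` against `self.checkpoint_best_metric` before any best value had ever been stored (the attribute is only now initialized to `None` in `__init__`), so the first evaluated epoch could either fail on that comparison or never be recorded as the best checkpoint. The new logic treats the first observed value as the best and applies the min/max comparison only afterwards. Below is an isolated sketch of that bookkeeping, using local names that are not part of the Trainer API.

```python
from typing import Optional, Tuple

def update_best(best: Optional[float], current: float, mode: str) -> Tuple[bool, Optional[float]]:
    """Return (improved, new_best); the first observed value always counts as an improvement."""
    if best is None:
        return True, current
    improved = (mode == "min" and current < best) or (mode == "max" and current > best)
    return improved, current if improved else best

best = None
for val_loss in [0.9, 0.7, 0.8, 0.6]:
    improved, best = update_best(best, val_loss, mode="min")
    print(val_loss, improved, best)   # improvements at 0.9, 0.7 and 0.6; not at 0.8
```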
