- PyTorch Lightning-like training API for JAX-based neural networks, with comprehensive features:
  - `LightningModule`: Base class for defining training models with `training_step()`, `validation_step()`, and `configure_optimizers()` hooks (see the usage sketch after this feature list)
  - `Trainer`: Orchestration class for managing training loops, epochs, and device placement
  - `TrainOutput`/`EvalOutput`: Structured output types for training and evaluation results
- 10+ built-in callbacks for customizing training behavior:
  - `ModelCheckpoint`: Automatic model saving based on monitored metrics
  - `EarlyStopping`: Stop training when metrics plateau
  - `LearningRateMonitor`: Track and log learning rate changes
  - `GradientClipCallback`: Gradient clipping for training stability
  - `Timer`: Track training time
  - `RichProgressBar`/`TQDMProgressBar`: Visual progress indicators
  - `LambdaCallback`/`PrintCallback`: Custom callback utilities
- 6 pluggable logging backends:
  - `TensorBoardLogger`: TensorBoard integration
  - `WandBLogger`: Weights & Biases integration
  - `CSVLogger`: Simple CSV file logging
  - `NeptuneLogger`: Neptune.ai integration
  - `MLFlowLogger`: MLflow integration
  - `CompositeLogger`: Combine multiple loggers
- JAX-compatible data loading with distributed support:
  - `DataLoader`/`DistributedDataLoader`: Efficient batch loading
  - `Dataset`, `ArrayDataset`, `DictDataset`, `IterableDataset`: Dataset abstractions
  - `Sampler`, `RandomSampler`, `SequentialSampler`, `BatchSampler`, `DistributedSampler`: Sampling strategies
- Multi-device and multi-host training strategies:
  - `SingleDeviceStrategy`: Single-device execution
  - `DataParallelStrategy`: Data parallelism across devices
  - `ShardedDataParallelStrategy`/`FullyShardedDataParallelStrategy`: Memory-efficient sharded training
  - `AutoStrategy`: Automatic strategy selection
  - `all_reduce`, `broadcast`: Distributed communication primitives
- Comprehensive checkpoint management:
  - `CheckpointManager`: Manage multiple checkpoints with retention policies
  - `save_checkpoint`/`load_checkpoint`: Save and restore model states
  - `find_checkpoint`/`list_checkpoints`: Checkpoint discovery utilities
- Multiple progress bar implementations:
  - `SimpleProgressBar`: Basic text-based progress
  - `TQDMProgressBarWrapper`: TQDM-based progress
  - `RichProgressBarWrapper`: Rich library-based progress
- Enhanced module documentation: All public modules now include comprehensive docstrings with examples, parameter descriptions, and usage guidelines directly in `__init__.py` files
- Reorganized imports: Cleaner and more consistent import structure across all modules
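
A minimal usage sketch of the Lightning-style workflow. The hook names match the notes above, but the constructor arguments, hook return types, and `Trainer` signature shown here are assumptions drawn from the PyTorch Lightning analogy, not the verified braintools API:

```python
import jax.numpy as jnp
import optax


class RegressionTask:  # stands in for a LightningModule subclass
    def __init__(self, params):
        self.params = params

    def training_step(self, batch):
        x, y = batch
        pred = x @ self.params["w"] + self.params["b"]
        return jnp.mean((pred - y) ** 2)      # scalar loss consumed by the Trainer

    def validation_step(self, batch):
        return {"val_loss": self.training_step(batch)}

    def configure_optimizers(self):
        return optax.adam(1e-3)               # any Optax optimizer

# Assumed orchestration, by analogy with PyTorch Lightning:
# trainer = Trainer(max_epochs=10, callbacks=[EarlyStopping(monitor="val_loss")])
# trainer.fit(model, train_loader, val_loader)
```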
- The entire `braintools.param` module has been removed, including:
  - Data containers (`Data`)
  - Parameter wrappers (`Param`, `Const`)
  - State containers (`ArrayHidden`, `ArrayParam`)
  - Regularization classes (`GaussianReg`, `L1Reg`, `L2Reg`)
  - All transform classes (`SigmoidT`, `SoftplusT`, `AffineT`, etc.)
  - Utility functions (`get_param()`, `get_size()`)
- Users relying on these features should migrate to alternative implementations or pin to version 0.1.6
- Hierarchical data container: Added `Data` for composed state storage and cloning.
- Parameter wrappers: Added `Param` and `Const` with built-in transforms and optional regularization.
- State containers: Added `ArrayHidden` and `ArrayParam` with transform-aware `.data` access.
- Regularization priors: Added `GaussianReg`, `L1Reg`, and `L2Reg` with optional trainable hyperparameters.
- Utilities: Added `get_param()` and `get_size()` helpers for parameter/state handling.
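
To illustrate the transform-aware `.data` idea, here is an independent sketch (not the braintools implementation): the raw value is stored unconstrained so gradient updates can act on it freely, and `.data` applies the transform on access:

```python
import jax.numpy as jnp


class SigmoidParam:
    """Keeps an unconstrained raw value; exposes a (0, 1)-bounded view."""

    def __init__(self, value):
        value = jnp.asarray(value)
        self.raw = jnp.log(value) - jnp.log1p(-value)  # logit: inverse sigmoid

    @property
    def data(self):
        return 1.0 / (1.0 + jnp.exp(-self.raw))        # forward sigmoid


p = SigmoidParam(0.3)
print(p.data)  # ~0.3; optimizers update the unconstrained p.raw
```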
- New `ReluT` transform for lower-bounded parameters.
- Expanded transform suite now includes `PositiveT`, `NegativeT`, `ScaledSigmoidT`, `PowerT`, `OrderedT`, `SimplexT`, and `UnitVectorT`.
- Transform naming cleanup: Standardized transform class names with the `*T` suffix (e.g., `SigmoidT`, `SoftplusT`, `AffineT`, `ChainT`, `MaskedT`, `ClipT`).
- Expanded param API docs: Added sections for data containers, state containers, regularization, and utilities, and updated the transform listings in `docs/apis/param.rst`.
- API index update: Added the `param` API page to `docs/index.rst`.
- New test coverage: Added tests for data containers, modules, regularization, state, transforms, and utilities across the param module.
- Transform API renames: Transform classes now use the `*T` suffix (e.g., `Sigmoid` -> `SigmoidT`).
- Custom transform removed: The `Custom` transform is no longer part of the public API.
- Initializer RNG: `TruncatedNormal` now defaults to `numpy.random` when no RNG is provided.
- 7 new bijective transforms for constrained optimization and probabilistic modeling:
  - `Positive`: Constrains parameters to (0, +∞) using an exponential transformation
  - `Negative`: Constrains parameters to (-∞, 0) using a negative softplus
  - `ScaledSigmoid`: Sigmoid with an adjustable sharpness/temperature parameter (`beta`)
  - `Power`: Box-Cox-family power transformation for variance stabilization
  - `Ordered`: Ensures monotonically increasing output vectors (useful for cutpoints in ordinal regression)
  - `Simplex`: Stick-breaking transformation for probability vectors summing to 1
  - `UnitVector`: Projects vectors onto the unit sphere (L2 norm = 1)
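
For instance, the ordered-vector construction is typically implemented by exponentiating increments; the sketch below shows that standard math, which `Ordered` presumably follows (the function names here are illustrative, not the braintools API):

```python
import jax.numpy as jnp

def ordered_forward(x):
    # First coordinate is free; subsequent increments are exp(x_k) > 0,
    # so the output is strictly increasing.
    return jnp.cumsum(jnp.concatenate([x[:1], jnp.exp(x[1:])]))

def ordered_inverse(y):
    # Invert by taking the log of consecutive differences.
    return jnp.concatenate([y[:1], jnp.log(jnp.diff(y))])

y = ordered_forward(jnp.array([0.5, -1.0, 2.0]))  # e.g. ordinal cutpoints
```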
- Jacobian computation: Added a `log_abs_det_jacobian()` method to the `Transform` base class and its implementations for probabilistic modeling
  - Implemented for: `Identity`, `Sigmoid`, `Softplus`, `Log`, `Exp`, `Affine`, `Chain`, `Positive`
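
For the `Sigmoid` case this quantity has a simple closed form; the standalone sketch below checks it against an autodiff Jacobian (the method name matches the notes, but the braintools signature may differ):

```python
import jax
import jax.numpy as jnp

def sigmoid_log_abs_det_jacobian(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x));
    # its log, written stably: -softplus(x) - softplus(-x)
    return jnp.sum(-jax.nn.softplus(x) - jax.nn.softplus(-x))

x = jnp.array([0.5, -1.2])
jac = jax.jacfwd(jax.nn.sigmoid)(x)            # diagonal Jacobian matrix
print(jnp.log(jnp.abs(jnp.linalg.det(jac))))   # numeric log|det J|
print(sigmoid_log_abs_det_jacobian(x))         # matches the closed form
```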
- Gradient computation with respect to the hyperparameters of surrogate gradient functions
- Fixed a batching issue in surrogate gradient functions
- `__repr__` methods: Added string representations to all `Transform` classes and the `Param` class for better debugging
- Enhanced documentation: Updated `docs/apis/param.rst` with a comprehensive API reference
  - Organized sections: Base Classes, Parameter Wrapper, Bounded Transforms, Positive/Negative Transforms, Advanced Transforms, Composition Transforms
  - Descriptive explanations for each transform's use case
- Comprehensive test coverage: Added 28 new tests for param module (45 total tests passing)
  - Tests for all new transforms: roundtrips, constraints, and `__repr__` methods
  - Tests for `log_abs_det_jacobian` correctness
  - Tests for edge cases and numerical stability
- New `apply()` method: Added an `apply()` method to all LR schedulers for more flexible learning rate application
  - Allows applying learning rate transformations without stepping the scheduler
  - Useful for custom training loops and learning rate inspection
- Comprehensive test coverage: Added 118+ tests covering all 17 learning rate schedulers
  - Tests for basic functionality, optimizer integration, JIT compilation, and state persistence
  - Full coverage of edge cases and special modes for each scheduler
  - Validates correctness under `@brainstate.transform.jit` compilation
- Restructured tutorial organization: Renamed and reorganized documentation files for better clarity
  - Moved module tutorials into subdirectories (`conn/`, `init/`, `input/`, `file/`, `surrogate/`)
  - Updated the table-of-contents structure across all modules
  - Improved navigation with consolidated index files (`index.md` instead of `toc_*.md`)
- Enhanced visual branding: Updated project logo from JPG to high-resolution PNG format
  - Better quality and transparency support
  - Consistent branding across documentation
- Test improvements: Refactored scheduler tests with better organization and coverage
  - Each scheduler now has 5-10 dedicated tests
  - Tests verify basic functionality, optimizer integration, JIT compilation, multiple param groups, and state-dict save/load
  - Discovered and documented key implementation behaviors (epoch counting, initialization patterns)
- Updated GitHub Actions: Bumped actions to latest versions for improved security and performance
  - `actions/download-artifact`: v5 → v6
  - `actions/upload-artifact`: v4 → v5
  - Better artifact handling in the CI pipeline
- Fixed edge cases in learning rate scheduler state management
- Corrected epoch counting behavior in milestone-based schedulers
- Improved JIT compilation compatibility for all schedulers
- All 17 learning rate schedulers now have comprehensive test coverage (100%)
- Enhanced reliability for training workflows with thorough validation
- Improved developer experience with better documentation structure
- New comprehensive surrogate gradient system for training spiking neural networks (SNNs)
- 18+ surrogate gradient functions with straight-through estimator support:
  - Sigmoid-based: `Sigmoid`, `SoftSign`, `Arctan`, `ERF`
  - Piecewise: `PiecewiseQuadratic`, `PiecewiseExp`, `PiecewiseLeakyRelu`
  - ReLU-based: `ReluGrad`, `LeakyRelu`, `LogTailedRelu`
  - Distribution-inspired: `GaussianGrad`, `MultiGaussianGrad`, `InvSquareGrad`, `SlayerGrad`
  - Advanced: `S2NN`, `QPseudoSpike`, `SquarewaveFourierSeries`, `NonzeroSignLog`
- Customizable hyperparameters (alpha, sigma, width, etc.) for fine-tuning gradient behavior
- Comprehensive tutorials: 2 detailed notebooks covering basics and customization
- Enables gradient-based training of SNNs via backpropagation through time
- Over 2,600 lines of implementation with extensive test coverage
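
The core mechanism behind all of these is a straight-through spike with a substituted backward pass. A minimal sketch in plain JAX, assuming a sigmoid-derivative surrogate with sharpness `alpha` (braintools' own classes wrap this pattern with more options):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def spike(v, alpha):
    return (v >= 0.0).astype(v.dtype)          # Heaviside forward pass

def spike_fwd(v, alpha):
    return spike(v, alpha), v                  # save membrane potential

def spike_bwd(alpha, v, g):
    s = jax.nn.sigmoid(alpha * v)
    return (g * alpha * s * (1.0 - s),)        # surrogate derivative

spike.defvjp(spike_fwd, spike_bwd)

v = jnp.array([-0.5, 0.1, 2.0])
grads = jax.grad(lambda u: spike(u, 4.0).sum())(v)  # nonzero near threshold
```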
- `ExponentialDecayLR` scheduler: Fine-grained exponential decay with step-based control
  - Support for transition steps, staircase mode, delayed start, and bounded decay
  - Better control than the epoch-based `ExponentialLR` for step-level scheduling
  - Compatible with Optax's `exponential_decay` schedule
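
The schedule follows the standard Optax `exponential_decay` formula; the standalone sketch below reproduces the math (parameter names mirror Optax's `init_value`/`transition_steps`/`decay_rate`/`transition_begin`/`staircase`/`end_value`, which may differ from the braintools scheduler's own):

```python
def exponential_decay_lr(step, init_value, transition_steps, decay_rate,
                         transition_begin=0, staircase=False, end_value=None):
    step = max(step - transition_begin, 0)            # delayed start
    exponent = step / transition_steps
    if staircase:
        exponent = step // transition_steps           # discrete jumps
    lr = init_value * decay_rate ** exponent
    if end_value is not None:                         # bounded decay
        lr = max(lr, end_value) if decay_rate < 1 else min(lr, end_value)
    return lr

# Halve the rate every 1000 steps, never below 1e-5:
lr = exponential_decay_lr(2500, 1e-3, 1000, 0.5, end_value=1e-5)
```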
- Deprecation warnings added for future API changes:
  - Deprecated `beta1` and `beta2` parameters in the Adam optimizer (use `b1` and `b2` instead)
  - Deprecated `unit` parameter in various initializers (use `UNITLESS` by default)
  - Deprecated `init_call` function, replaced with `param` for improved consistency
- Enhanced state management: Refactored `UniqueStateManager` to utilize pytree methods
- Comprehensive tests: Added extensive tests for `UniqueStateManager` methods and edge cases
- Updated API documentation for new surrogate gradient module
- Added learning rate scheduler documentation for `ExponentialDecayLR`
- Clarified docstrings for the `FixedProb` class and the variance scaling initializer
- Updated copyright information from BDP Ecosystem Limited to BrainX Ecosystem Limited
- Improved consistency across codebase with standardized function signatures
- Better default parameter handling (`UNITLESS` for unit parameters)
- Enhanced test coverage for state management and optimizers
- Improved correlation and firing metrics implementation
- Enhanced LFP (Local Field Potential) analysis functions
- Better error handling and validation in metric computations
- Deprecation notices (not yet removed, but will be in future versions):
  - `beta1`/`beta2` parameters in the Adam optimizer (use `b1`/`b2`)
  - `unit` parameter in initializers (defaults to `UNITLESS`)
  - `init_call` function (use `param` instead)
- This release focuses on enabling gradient-based training for spiking neural networks
- The surrogate gradient module is a major addition for neuromorphic computing and SNN research
- Enhanced learning rate scheduling provides more control for training workflows
- Momentum optimizers: Added `Momentum` and `MomentumNesterov` optimizers with gradient transformations
- Improved state management: Refactored optimizer state handling with a new `OptimState` class for better encapsulation
- `ZeroInit` initializer: New zero initialization class for weights and parameters
- `VarianceScaling` export: Added `VarianceScaling` to module exports for easier access
- Enhanced optimizer state management for better performance and maintainability
- Simplified initialization API with additional export options
- Updated documentation for new initialization methods
- Refactored test structure for initialization module
- Improved learning rate scheduler implementation
- Unified initialization API consolidating all weight and parameter initialization strategies
- Distance-based initialization: Support for distance-modulated weight patterns
- Variance scaling strategies: Xavier, He, LeCun initialization methods
- Orthogonal initialization for improved training stability
- Composite distributions for complex initialization patterns
- Simplified API with consistent parameter naming across all initializers
- Topological network patterns:
  - Small-world and scale-free networks
  - Hierarchical and core-periphery structures
  - Modular and clustered random connectivity
- Enhanced biological connectivity:
  - Excitatory-inhibitory balanced networks
  - Distance-dependent connectivity with multiple profiles
  - Compartment-specific connectivity (dendrite, soma, axon)
- Spatial connectivity improvements:
  - 2D convolutional kernels for spatial networks
  - Position-based connectivity with normalization
  - Distance modulation using composable profiles
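
A generic sketch of distance-dependent connectivity with a Gaussian profile (the profile choice, parameter names, and helper below are illustrative, not the braintools API):

```python
import numpy as np

def distance_dependent_adjacency(positions, p_max=0.5, sigma=100.0, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # Pairwise Euclidean distances between neuron positions
    d = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    # Connection probability decays with distance via a Gaussian profile
    p = p_max * np.exp(-(d ** 2) / (2 * sigma ** 2))
    np.fill_diagonal(p, 0.0)                  # no self-connections
    return rng.random(p.shape) < p            # Boolean adjacency matrix

pos = np.random.default_rng(0).uniform(0, 1000, size=(200, 2))  # µm coordinates
adj = distance_dependent_adjacency(pos)
```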
- Full Optax optimizer support: Adam, SGD, RMSProp, AdaGrad, AdaDelta, and more
- Advanced learning rate schedulers:
  - Cosine annealing with warm restarts
  - Polynomial decay with warmup
  - Piecewise constant schedules
  - Sequential and chained schedulers
- Improved optimizer state management with unique state handling
- Parameter groups with per-group learning rates
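
Per-group learning rates can be expressed directly with Optax's `multi_transform`, which this feature presumably builds on (the braintools wrapper's own interface may differ):

```python
import jax.numpy as jnp
import optax

params = {"encoder": jnp.ones(3), "head": jnp.ones(2)}

# Label each parameter subtree, then assign one optimizer per label
tx = optax.multi_transform(
    {"slow": optax.adam(1e-4), "fast": optax.adam(1e-2)},
    param_labels={"encoder": "slow", "head": "fast"},
)
state = tx.init(params)
grads = {"encoder": jnp.ones(3), "head": jnp.ones(2)}
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
```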
- Simplified `conn` module API with direct class access
- Refactored initialization calls for consistency
- Improved type annotations throughout
- Better default parameter handling
- Updated tutorial structure for connectivity patterns
- New examples for topological networks
- Enhanced API documentation with detailed examples
- Improved code readability in tutorials
- Comprehensive test coverage for new features
- Better error handling and validation
- Consistent naming conventions
- Removed deprecated and redundant code
- Renamed `PointNeuronConnectivity` to `PointConnectivity`
- Renamed `ConvKernel` to `Conv2dKernel`
- Unified initializer names (e.g., `ConstantWeight` → `Constant`)
- Removed the `PopulationRateConnectivity` class
- Changed some parameter names for clarity (e.g., unified use of the `rng` parameter)
- New visualization modules for neural data analysis:
  - `neural.py`: Spike rasters, population activity, connectivity matrices, firing-rate maps
  - `three_d.py`: 3D visualizations for neural networks, brain surfaces, trajectories, and electrode arrays
  - `statistical.py`: Statistical plotting tools (confusion matrices, ROC curves, correlation plots)
  - `interactive.py`: Interactive visualizations with Plotly support
  - `colormaps.py`: Neural-specific colormaps and publication-ready styling
- 15+ new tutorial notebooks covering all visualization techniques
- Brain-specific colormaps for membrane potential, spike activity, and connectivity
- New ODE integrators:
  - Runge-Kutta methods: RK23, RK45, RKF45, DOP853, DOPRI5, SSPRK33
  - Specialized methods: Midpoint, Heun, RK4 (3/8 rule), Ralston RK2/RK3, Bogacki-Shampine
- New SDE integrators: Heun, Tamed Euler, Implicit Euler, SRK2, SRK3, SRK4
- IMEX integrators for stiff equations: Euler, ARS(2,2,2), CNAB
- DDE integrators for delay differential equations
- Comprehensive test coverage and accuracy verification
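
For reference, the classic fixed-step RK4 update that this family of methods generalizes (an independent sketch; the braintools integrator signatures are not shown in these notes):

```python
import jax.numpy as jnp

def rk4_step(f, t, y, dt):
    # Four slope evaluations, combined with the standard 1/6,2/6,2/6,1/6 weights
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt * k1 / 2)
    k3 = f(t + dt / 2, y + dt * k2 / 2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

# e.g. exponential decay dy/dt = -y, one step from y(0) = 1
y1 = rk4_step(lambda t, y: -y, 0.0, jnp.array(1.0), 0.1)
```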
- Spike encoders: Rate, Poisson, Population, Latency, and Temporal encoders
- Enhanced spike operations with bitwise functionality
- Spike metrics: Victor-Purpura distance, spike train synchrony, correlation indices
- Tutorial notebooks for spike encoding and analysis
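
A minimal Poisson rate-encoding sketch: each normalized intensity becomes a per-timestep Bernoulli spike probability (illustrative only; the braintools encoder classes expose their own configuration):

```python
import jax
import jax.numpy as jnp

def poisson_encode(key, intensities, n_steps, max_prob=0.5):
    # Returns an (n_steps, n_inputs) Boolean spike train
    p = jnp.clip(intensities, 0.0, 1.0) * max_prob
    return jax.random.bernoulli(key, p, shape=(n_steps,) + intensities.shape)

spikes = poisson_encode(jax.random.PRNGKey(0), jnp.array([0.1, 0.9]), 100)
print(spikes.mean(axis=0))    # empirical rates ≈ [0.05, 0.45]
```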
- `NevergradOptimizer`: Integration with the Nevergrad optimization library
- `ScipyOptimizer`: Enhanced SciPy optimization with flexible bounds support
- Refactored optimizer architecture for better extensibility
- Support for dict and sequence parameter bounds
- Enhanced msgpack serialization with mismatch handling options
- Improved checkpoint loading with better error recovery
- Support for handling mismatched keys during state restoration
- LFP analysis functions: Power spectral density, coherence analysis, phase-amplitude coupling
- Functional connectivity: Dynamic connectivity computation
- Classification metrics: Binary, multiclass, focal loss, and smoothing techniques
- Regression losses: MSE, MAE, Huber, and quantile losses
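
As one example from this family, the Huber loss in its standard closed form (a reference sketch, not the braintools signature):

```python
import jax.numpy as jnp

def huber_loss(pred, target, delta=1.0):
    # Quadratic for |error| <= delta, linear beyond; continuous at the joint
    err = jnp.abs(pred - target)
    quad = jnp.minimum(err, delta)
    return jnp.mean(0.5 * quad ** 2 + delta * (err - quad))

loss = huber_loss(jnp.array([0.0, 3.0]), jnp.array([0.5, 0.0]))
```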
- Added comprehensive API documentation for all new modules
- Created tutorials for:
  - ODE/SDE integration methods
  - Classification and regression losses
  - Pairwise and embedding similarity
  - Spiking metrics and LFP analysis
  - Advanced neural visualization techniques
- Updated project description from "brain modeling" to "brain simulation"
- Changed references from BrainPy to BrainTools throughout
- Added extensive unit tests for all new modules
- Improved type hints and parameter documentation
- Better error handling and validation
- Consistent API design across modules
- Refactored optimizer module structure (moved from a single `optimizer.py` to separate modules)
- Removed the unused `key` parameter from spike encoder methods
- Updated some function signatures for clarity
- Fixed `Softplus` unit scaling issues
- Corrected paths in publish workflow
- Fixed formatting in ODE integrator documentation
- Resolved msgpack checkpoint handling errors