This file contains guidelines for AI coding agents working in this repository.
We use uv for all Python operations. Always prefix commands with uv run:
uv run python ... # Run Python code
uv run pytest ... # Run tests
uv run ty check ... # Type checking
uv run ruff check ... # Linting
uv run ruff format ... # Formatting
Run all unit tests:
uv run pytest tests/unit
Run a single test file:
uv run pytest tests/unit/aggregation/test_upgrad.py -W error
Run a single test function:
uv run pytest tests/unit/aggregation/test_upgrad.py::test_function_name -W error
Run tests, including the slow ones:
uv run pytest tests/unit --runslow
Run doc tests (examples in docstrings and .rst files):
uv run pytest tests/doc
Run tests on CUDA (requires a GPU):
CUBLAS_WORKSPACE_CONFIG=:4096:8 PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit
Run tests with coverage:
uv run pytest tests/unit tests/doc --cov=src
Run type checking:
uv run ty check
Run linting:
uv run ruff check
Run formatting:
uv run ruff format
Run all checks together:
uv run ty check && uv run ruff check && uv run ruff format
Build the HTML documentation (from the docs folder):
uv run make html
Clean the documentation build (from the docs folder):
uv run make clean

- Only write docstrings for public functions or for functions with more than 4 lines of code
- Use Sphinx style (:param my_param: Does something); never use the :returns: key
- Include usage examples in docstrings for public classes
- Always update the corresponding doc tests in tests/doc/ when modifying examples
- Always include type hints in function prototypes (including -> None)
- Do not include type hints when initializing local variables
- Use ty for type checking
- Use combine-as-imports = true (configured in pyproject.toml)
- Order imports: standard library, third-party, local
- Use absolute imports (e.g., from torchjd.aggregation import UPGrad)
- Line length: 100 characters (enforced by ruff)
- Quote style: double quotes
- Run uv run ruff format before committing
- Classes: PascalCase (e.g., class UPGrad)
- Functions/methods: snake_case (e.g., def jac_to_grad)
- Constants: UPPER_SNAKE_CASE
- Private functions: prefix with _ (e.g., _internal_func)
- Private modules: prefix with _ (e.g., _utils)
- Use appropriate exception types from Python standard library or PyTorch
- Provide clear error messages with context
- Validate inputs at function boundaries
- Use assertions for internal invariants
- Use the partial tensor functions from tests/utils/tensors.py (e.g., ones_(), randn_()); they automatically create tensors on the correct device and with the correct dtype
- Manually move models and random number generators (rng) to DEVICE (defined in settings.py)
- Use ModuleFactory to create modules on the correct device
- Mark slow tests with @pytest.mark.slow
- Run the affected tests after making changes:
uv run pytest tests/unit/<module> -W error
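The docstring, type-hint, naming, and error-handling rules above can be seen together in the minimal sketch below. It is plain Python with no torch dependency so that it runs on its own; the names MIN_NUM_VALUES, _check_values, normalize, and Normalizer are hypothetical illustrations and are not part of torchjd.

```python
MIN_NUM_VALUES = 1  # Constant: UPPER_SNAKE_CASE


def _check_values(values: list[float]) -> None:  # Private and short: no docstring needed
    if len(values) < MIN_NUM_VALUES:
        raise ValueError(f"Expected at least {MIN_NUM_VALUES} value(s). Found {len(values)}.")


def normalize(values: list[float], total: float = 1.0) -> list[float]:  # Function: snake_case
    """
    Rescales non-negative values so that they sum to ``total``.

    :param values: The non-negative values to rescale. At least one of them must be positive.
    :param total: The sum that the rescaled values should have.
    """

    # Validate inputs at the function boundary, with messages that show the bad input.
    _check_values(values)
    if any(value < 0.0 for value in values):
        raise ValueError(f"Expected non-negative values. Found values={values}.")
    current_total = sum(values)  # No type hint on local variables
    if current_total == 0.0:
        raise ValueError(f"Expected at least one positive value. Found values={values}.")

    result = [value * total / current_total for value in values]
    assert len(result) == len(values)  # Internal invariant, not input validation
    return result


class Normalizer:  # Class: PascalCase
    """
    Rescales values so that they sum to a fixed total.

    :param total: The sum that the rescaled values should have. Must be positive.

    Example:

        >>> normalizer = Normalizer(total=2.0)
        >>> normalizer([1.0, 3.0])
        [0.5, 1.5]
    """

    def __init__(self, total: float = 1.0) -> None:
        if total <= 0.0:
            raise ValueError(f"Expected a positive total. Found total={total}.")
        self.total = total

    def __call__(self, values: list[float]) -> list[float]:
        return normalize(values, self.total)
```

Here, _check_values has no docstring because it is private and short, while the public normalize and Normalizer do. The ValueError messages report the offending input, and the assert only guards an invariant that valid inputs can never break.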
src/torchjd/
├── __init__.py # Public API exports
├── autojac/ # Jacobian computation engine
│ ├── _backward.py
│ ├── _jac.py
│ ├── _jac_to_grad.py
│ ├── _mtl_backward.py
│ └── _transform/ # Internal transform implementations
├── autogram/ # Gramian-based engine
│ ├── _engine.py
│ └── ...
├── aggregation/ # Aggregators and weightings
│ ├── _upgrad.py
│ ├── _mean.py
│ └── _utils/ # Internal utilities
└── _linalg/ # Linear algebra utilities
tests/
├── unit/ # Unit tests (mirrors src structure)
├── doc/ # Docstring and rst example tests
└── utils/ # Test utilities (tensors.py, settings.py)
- Subclass the Aggregator or Weighting base class
- Aggregators must be immutable (their state must not change after construction)
- Implement the mathematical mapping as documented
- Add corresponding weighting if applicable
- Update __init__.py to export the new class
- Raise a DeprecationWarning for deprecated public functionality
- Add a test in tests/unit/test_deprecations.py to verify the warning
- Follow SOLID principles
- Keep code simple and readable
- Prefer explicit over implicit
- Add type hints to all public interfaces
- Document complex algorithms with comments
- Run all checks before submitting:
uv run ty check && uv run ruff check && uv run ruff format
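The aggregator and deprecation rules above can be illustrated with a toy sketch. It deliberately does not use torchjd: ToyWeighting and ToyAggregator are hypothetical stand-ins for the real Aggregator and Weighting base classes (which operate on torch tensors and may have a different interface), a matrix is assumed to be a plain list of rows, and toy_mean_of_rows is an invented deprecated alias. The warning is issued with warnings.warn, the standard mechanism for deprecation warnings. Exporting the new class from __init__.py is not shown.

```python
import warnings
from abc import ABC, abstractmethod

# Hypothetical representation: a matrix is a list of rows (one row per objective).
Matrix = list[list[float]]


class ToyWeighting(ABC):
    """Hypothetical stand-in for a Weighting base class: maps a matrix to one weight per row."""

    @abstractmethod
    def __call__(self, matrix: Matrix) -> list[float]: ...


class ToyAggregator(ABC):
    """Hypothetical stand-in for an Aggregator base class: maps a matrix to one vector."""

    @abstractmethod
    def __call__(self, matrix: Matrix) -> list[float]: ...


class ToyMeanWeighting(ToyWeighting):
    """Gives the same weight to every row."""

    def __call__(self, matrix: Matrix) -> list[float]:
        return [1.0 / len(matrix)] * len(matrix)


class ToyMean(ToyAggregator):
    """
    Averages the rows of a matrix.

    Example:

        >>> ToyMean()([[1.0, 2.0], [3.0, 4.0]])
        [2.0, 3.0]
    """

    def __init__(self) -> None:
        self._weighting = ToyMeanWeighting()  # Set once here and never modified afterwards

    def __call__(self, matrix: Matrix) -> list[float]:
        if len(matrix) == 0:
            raise ValueError("Expected a matrix with at least one row. Found an empty matrix.")
        weights = self._weighting(matrix)
        num_columns = len(matrix[0])
        # Weighted sum of the rows; no attribute of self is changed, so calls are repeatable.
        return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(num_columns)]


def toy_mean_of_rows(matrix: Matrix) -> list[float]:
    """
    Deprecated alias of ToyMean, kept only for backward compatibility.

    :param matrix: The matrix whose rows should be averaged.
    """

    warnings.warn(
        "toy_mean_of_rows is deprecated. Use ToyMean instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return ToyMean()(matrix)


def test_toy_mean_of_rows_is_deprecated() -> None:
    # The kind of check that belongs in tests/unit/test_deprecations.py (stdlib version).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        toy_mean_of_rows([[1.0, 2.0]])
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

In a real test file, pytest.warns(DeprecationWarning) is the more idiomatic way to check the warning; warnings.catch_warnings is used here only to keep the sketch free of third-party dependencies.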