jax_hf provides two JAX-jitted solvers for the Hartree–Fock free-energy minimisation problem on 2D k-meshes:
- Direct minimisation (primary): preconditioned Riemannian CG on Stiefel × capped simplex, eigen-free inner loop, Cayley retraction, one Fock build per iteration.
- Reference SCF (baseline / fallback): standard Roothaan iteration with linear mixing.
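The linear mixing used by the reference SCF solver can be illustrated on a scalar toy problem. The sketch below is plain Python, not jax_hf code: `linear_mixing_scf` and `toy_update` are hypothetical names, and the `mixing` argument is only assumed to play the same damping role as the `mixing` field of SCFConfig.

```python
def linear_mixing_scf(update, x0, mixing=0.3, tol=1e-10, max_iter=200):
    """Iterate x <- (1 - mixing) * x + mixing * update(x) until the step is below tol.

    Hypothetical helper for illustration only; returns (x, n_iter, converged).
    """
    x = x0
    for n_iter in range(1, max_iter + 1):
        x_out = update(x)                            # output of one self-consistency step
        x_new = (1.0 - mixing) * x + mixing * x_out  # damped (linearly mixed) update
        if abs(x_new - x) < tol:
            return x_new, n_iter, True
        x = x_new
    return x, max_iter, False


def toy_update(x):
    """Toy self-consistency map with fixed point x* = 0.4 and slope -1.5 there."""
    return 1.0 - 1.5 * x


# Damped iteration contracts (effective slope 0.7 - 0.3 * 1.5 = 0.25) and converges.
x_damped, n_damped, ok_damped = linear_mixing_scf(toy_update, x0=0.0, mixing=0.3)

# Undamped iteration (mixing=1.0) has slope -1.5 and moves away from the fixed point.
x_plain, n_plain, ok_plain = linear_mixing_scf(toy_update, x0=0.0, mixing=1.0, max_iter=50)
```

Smaller mixing values trade speed for robustness: the same map that diverges without damping converges once the effective slope is pulled inside the unit interval.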
Exchange and Hartree can both be included, and the exchange kernel may
be layer-resolved. See examples/ for density-scan scripts on a
bilayer graphene model.
v2.0.0 note: This release is a clean-slate rewrite. The entire public API has changed relative to the deprecated v1.x line (which was already a skeleton in v1.1.0). See
MIGRATION.md for the migration guide.
```bash
pip install jax-hf
```

```python
import jax.numpy as jnp
import jax_hf

# Build a HartreeFockKernel: precomputes the FFT of the interaction kernel,
# the Hartree matrix, etc., ready for JIT.
kernel = jax_hf.HartreeFockKernel(
    weights=weights,          # (nk1, nk2) k-point weights
    hamiltonian=hamiltonian,  # (nk1, nk2, nb, nb) single-particle Hamiltonian
    coulomb_q=coulomb_q,      # (nk1, nk2, 1, 1) scalar or (nk1, nk2, nb, nb) layer-resolved
    T=0.1,
    include_hartree=False,    # set True for Hartree; also pass reference_density + hartree_matrix
    include_exchange=True,
)

# Solve (direct minimisation, default)
result = jax_hf.solve(kernel, P0=jnp.zeros_like(hamiltonian), n_electrons=N)
print(result.energy, result.converged, result.n_iter)
# result.density, result.fock, result.Q, result.p, result.mu, result.history

# Or use SCF as a fallback baseline
result_scf = jax_hf.solve_scf(kernel, P0=jnp.zeros_like(hamiltonian), n_electrons=N)
```

Both solvers take a Config dataclass with sensible defaults:
```python
jax_hf.SolverConfig(max_iter=200, tol_E=1e-7, max_step=0.6, project_fn=None, ...)
jax_hf.SCFConfig(max_iter=200, mixing=0.3, density_tol=1e-7, comm_tol=1e-6, ...)
```

project_fn lets you enforce symmetry constraints (spin, valley, time
reversal, spatial) on the density and Fock at every iteration. See
jax_hf.symmetry.make_project_fn.
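One common way to build a symmetry projection is to average a matrix over the symmetry group. The sketch below does this for a single two-state block and a Z2 symmetry (a swap of the two states, standing in for a spin flip). It is plain Python and does not call jax_hf.symmetry.make_project_fn, whose signature is not shown here; `project_z2`, `matmul`, `dagger`, and `SWAP` are illustrative names only.

```python
def matmul(a, b):
    """Product of two square matrices stored as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def dagger(a):
    """Conjugate transpose of a square matrix stored as nested lists."""
    n = len(a)
    return [[a[j][i].conjugate() for j in range(n)] for i in range(n)]


def project_z2(p, u):
    """Average p over the group {1, u}, assuming u is unitary with u @ u = 1."""
    rotated = matmul(matmul(u, p), dagger(u))
    n = len(p)
    return [[0.5 * (p[i][j] + rotated[i][j]) for j in range(n)] for i in range(n)]


# Swap of the two basis states (hypothetical stand-in for a spin-flip symmetry).
SWAP = [[0.0, 1.0], [1.0, 0.0]]

# A Hermitian 2x2 block that is not symmetric under the swap.
p = [[0.8, 0.1 + 0.2j], [0.1 - 0.2j, 0.2]]

p_sym = project_z2(p, SWAP)          # symmetrised block
p_twice = project_z2(p_sym, SWAP)    # projecting again changes nothing
```

The group average preserves the trace and Hermiticity of the block and is idempotent, which are the properties one would want from a projection applied at every iteration.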
| Name | Purpose |
|---|---|
| HartreeFockKernel | Problem + precomputed arrays |
| solve (alias solve_direct_minimization), SolverConfig, SolveResult | Primary solver |
| solve_scf, SCFConfig, SCFResult | Reference SCF solver |
| build_fock, hf_energy, free_energy, occupation_entropy | HF objective building blocks |
| solve_continuation, ContinuationResult, resample_kgrid | Coarse → fine multigrid driver + k-grid resampler |
Lower-level modules (jax_hf.utils, jax_hf.symmetry, jax_hf.linalg,
jax_hf.fock) expose the individual pieces for users who need them.
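At finite temperature T and fixed electron number, the free energy combines an energy with an occupation entropy, and the chemical potential is set by the electron count. The sketch below shows the textbook version of these pieces in plain Python for a handful of levels. It is not the library's implementation: `fermi`, `find_mu`, and `occupation_entropy_sketch` are hypothetical helpers, and the weight normalisation and sign conventions used by jax_hf may differ.

```python
import math


def fermi(e, mu, T):
    """Fermi-Dirac occupation, arranged so that exp never overflows."""
    x = (e - mu) / T
    if x > 0:
        z = math.exp(-x)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(x))


def find_mu(energies, weights, n_electrons, T, tol=1e-12):
    """Bisect for mu such that sum_i w_i * f(e_i; mu, T) equals n_electrons."""
    lo = min(energies) - 50.0 * T
    hi = max(energies) + 50.0 * T
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        count = sum(w * fermi(e, mid, T) for e, w in zip(energies, weights))
        if count < n_electrons:
            lo = mid   # too few electrons: raise the chemical potential
        else:
            hi = mid
    return 0.5 * (lo + hi)


def occupation_entropy_sketch(occ, weights):
    """S = -sum_i w_i [f ln f + (1 - f) ln(1 - f)], with 0 ln 0 taken as 0."""
    total = 0.0
    for f, w in zip(occ, weights):
        for prob in (f, 1.0 - f):
            if prob > 0.0:
                total -= w * prob * math.log(prob)
    return total


# Four levels, symmetric about zero, half filled: mu sits at zero by symmetry.
energies = [-1.0, -0.2, 0.2, 1.0]
weights = [1.0, 1.0, 1.0, 1.0]
T = 0.1

mu = find_mu(energies, weights, n_electrons=2.0, T=T)
occ = [fermi(e, mu, T) for e in energies]
entropy = occupation_entropy_sketch(occ, weights)
```

With these definitions a free energy would be formed as E - T * S, where E is the Hartree–Fock energy evaluated at the same occupations.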
For large fine grids, solve_continuation runs a cheap coarse solve first
and uses its density to seed the fine solve. The two stages can mix and
match direct minimisation and SCF:
```python
import jax.numpy as jnp
import jax_hf
from jax_hf import SCFConfig, SolverConfig

coarse = jax_hf.HartreeFockKernel(weights_c, h_c, Vq_c, T=0.1)
fine = jax_hf.HartreeFockKernel(weights_f, h_f, Vq_f, T=0.1)

result = jax_hf.solve_continuation(
    coarse, fine, P0_coarse=jnp.zeros_like(h_c),
    n_electrons_coarse=N, n_electrons_fine=N,
    coarse_config=SCFConfig(max_iter=50, mixing=0.5),     # robust coarse
    fine_config=SolverConfig(max_iter=200, tol_E=1e-8),   # fast fine
)
# result.coarse, result.fine (each a SolveResult or SCFResult)
# result.P0_fine (resampled coarse density used to seed the fine solve)
```

The driver is intentionally algorithm-agnostic: it resamples the coarse
density onto the fine grid via resample_kgrid and hands off. Callers
that need physics-aware seeding (reference-density interpolation,
self-energy seeds, filling-consistent electron counts across grids) should
construct both kernels themselves.
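One simple way to move a quantity from a coarse periodic grid to a finer one is linear interpolation with wrap-around. The sketch below shows the idea in one dimension in plain Python. It is not the implementation of resample_kgrid, which may use a different scheme and operates on 2D k-grids; `resample_periodic_1d` is a hypothetical helper that assumes uniform grids sharing their first point.

```python
def resample_periodic_1d(values, n_fine):
    """Linearly interpolate samples on a periodic uniform grid onto n_fine points."""
    n_coarse = len(values)
    out = []
    for j in range(n_fine):
        x = j * n_coarse / n_fine      # fine point in units of the coarse spacing
        i0 = int(x) % n_coarse         # coarse point at or to the left of x
        i1 = (i0 + 1) % n_coarse       # right neighbour, wrapping around the zone edge
        t = x - int(x)                 # fractional distance from i0 towards i1
        out.append((1.0 - t) * values[i0] + t * values[i1])
    return out


coarse_samples = [0.0, 1.0, 0.0, -1.0]
fine_samples = resample_periodic_1d(coarse_samples, 8)
# Coarse values are reproduced at every second fine point; the points in
# between are midpoints of their two coarse neighbours.
```

A seed built this way is only a starting guess, so interpolation error is acceptable as long as the fine solve converges from it.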
- examples/multilayer_graphene_density_scan.py: PM/SVP density scan for bilayer graphene, direct minimisation, Fock only
- examples/multilayer_graphene_density_scan_extended.py: adds spin-polarised and "SVP flipped" branches (4 total)
- examples/multilayer_graphene_density_scan_hartree.py: same four branches with layer-resolved Coulomb and Hartree included
- examples/multilayer_graphene_reference_scf_scan.py: SCF baseline scan for side-by-side comparison
```bash
pytest tests/
```

The bilayer regression tests (tests/test_bilayer_regression.py) require
contimod and contimod_graphene and will be skipped otherwise.