Skip to content

Commit 6ede655

Browse files
authored
Merge pull request lcmd-epfl#116 from lcmd-epfl/fix-spahm-a
Fix and refactor SPAHM(a)
2 parents 20ad599 + 89c91f4 commit 6ede655

60 files changed

Lines changed: 636 additions & 490 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

qstack/basis_opt/opt.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def energy(x):
4343
E += qbbt.energy_mol(newbasis, m)
4444
return E
4545

46-
4746
def gradient(x):
4847
"""Compute total loss function (fitting error) and gradient for given exponents.
4948
@@ -76,7 +75,6 @@ def gradient(x):
7675
dE_dx = dE_da * exponents
7776
return E, dE_dx
7877

79-
8078
def gradient_only(x):
8179
"""Compute only the gradient of the loss function (wrapper for optimization algorithms).
8280
@@ -88,7 +86,6 @@ def gradient_only(x):
8886
"""
8987
return gradient(x)[1]
9088

91-
9289
def read_bases(basis_files):
9390
"""Read basis set definitions from files or dicts.
9491
@@ -117,7 +114,6 @@ def read_bases(basis_files):
117114
basis.update(i)
118115
return basis
119116

120-
121117
def make_bf_start():
122118
"""Create basis function index bounds for each element.
123119
@@ -131,7 +127,6 @@ def make_bf_start():
131127
bf_bounds[q] = [start, start+nbf[i]]
132128
return bf_bounds
133129

134-
135130
def make_moldata(fname):
136131
"""Create molecular data dictionary from file or dict.
137132

qstack/compound.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from qstack.reorder import get_mrange
1010
from qstack.mathutils.array import stack_padding
1111
from qstack.mathutils.rotation_matrix import rotate_euler
12+
from qstack.tools import Cursor
1213

1314

1415
# detects a charge-spin line, containing only two ints (one positive or negative, the other positive and nonzero)
@@ -319,49 +320,49 @@ def singleatom_basis_enumerator(basis):
319320
ao_starts = []
320321
l_per_bas = []
321322
n_per_bas = []
322-
cursor = 0
323+
cursor = Cursor(action='ranger')
323324
cursor_per_l = []
324325
for bas in basis:
325326
# shape of `bas`, l, then another optional constant, then lists [exp, coeff, coeff, coeff]
326327
# that make a matrix between the number of functions (number of coeff per list)
327328
# and the number of primitive gaussians (one per list)
328329
l = bas[0]
329330
while len(cursor_per_l) <= l:
330-
cursor_per_l.append(0)
331-
331+
cursor_per_l.append(Cursor(action='ranger'))
332332
n_count = len(bas[-1])-1
333-
n_start = cursor_per_l[l]
334-
cursor_per_l[l] += n_count
335-
336333
l_per_bas += [l] * n_count
337-
n_per_bas.extend(range(n_start, n_start+n_count))
334+
n_per_bas.extend(cursor_per_l[l].add(n_count))
338335
msize = 2*l+1
339-
ao_starts.extend(range(cursor, cursor+msize*n_count, msize))
340-
cursor += msize*n_count
336+
ao_starts.extend(cursor.add(msize*n_count)[::msize])
341337
return l_per_bas, n_per_bas, ao_starts
342338

343339

344-
def basis_flatten(mol, return_both=True):
340+
def basis_flatten(mol, return_both=True, return_shells=False):
345341
"""Flatten a basis set definition for AOs.
346342
347343
Args:
348344
mol (pyscf.gto.Mole): pyscf Mole object.
349345
return_both (bool): Whether to return both AO info and primitive Gaussian info. Defaults to True.
346+
return_shells (bool): Whether to return angular momenta per shell. Defaults to False.
350347
351348
Returns:
352349
- numpy.ndarray: 3×mol.nao int array where each column corresponds to an AO and rows are:
353-
- 0: atom index
354-
- 1: angular momentum quantum number l
355-
- 2: magnetic quantum number m
350+
- 0: atom index
351+
- 1: angular momentum quantum number l
352+
- 2: magnetic quantum number m
356353
If return_both is True, also returns:
357354
- numpy.ndarray: 2×mol.nao×max_n float array where index (i,j,k) means:
358-
- i: 0 for exponent, 1 for contraction coefficient of a primitive Gaussian
359-
- j: AO index
360-
- k: radial function index (padded with zeros if necessary)
355+
- i: 0 for exponent, 1 for contraction coefficient of a primitive Gaussian
356+
- j: AO index
357+
- k: radial function index (padded with zeros if necessary)
358+
If return_shell is True, also returns:
359+
- numpy.ndarray: angular momentum quantum number for each shell
360+
361361
"""
362362
x = []
363+
L = []
363364
y = np.zeros((3, mol.nao), dtype=int)
364-
i = 0
365+
i = Cursor(action='slicer')
365366
a = mol.bas_exps()
366367
for iat in range(mol.natm):
367368
for bas_id in mol.atom_shell_ids(iat):
@@ -373,11 +374,13 @@ def basis_flatten(mol, return_both=True):
373374
for c in cs.T:
374375
ac = np.array([a[bas_id], c])
375376
x.extend([ac]*msize)
376-
y[:2,i:i+msize*n] = np.array([[iat, l]]*msize*n).T
377-
y[2,i:i+msize*n] = [*get_mrange(l)]*n
378-
i += msize*n
377+
y[:,i(msize*n)] = np.vstack((np.array([[iat, l]]*msize*n).T, [*get_mrange(l)]*n))
378+
if return_shells:
379+
L.extend([l]*n)
380+
381+
ret = [y]
379382
if return_both:
380-
x = stack_padding(x).transpose((1,0,2))
381-
return y, x
382-
else:
383-
return y
383+
ret.append(stack_padding(x).transpose((1,0,2)))
384+
if return_shells:
385+
ret.append(np.array(L))
386+
return ret[0] if len(ret)==1 else ret

qstack/constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
https://physics.nist.gov/cuu/Constants/
44
https://physics.nist.gov/cuu/Constants/Table/allascii.txt
55
"""
6+
import math
7+
8+
69
# Constants
710
SPEED_LIGHT = 299792458.0
811
PLANCK = 6.62607004e-34
9-
HBAR = PLANCK/(2*3.141592653589793)
12+
HBAR = PLANCK/(2*math.pi)
1013
FUND_CHARGE = 1.6021766208e-19
1114
MOL_NA = 6.022140857e23
1215
MASS_E = 9.10938356e-31
@@ -20,4 +23,4 @@
2023
BOHR2ANGS = 0.52917721092 # Angstroms
2124
HARTREE2J = HBAR**2/(MASS_E*(BOHR2ANGS*1e-10)**2)
2225
HARTREE2EV = 27.21138602
23-
AU2DEBYE = FUND_CHARGE * BOHR2ANGS*1e-10 / DEBYE # 2.541746
26+
AU2DEBYE = FUND_CHARGE * BOHR2ANGS*1e-10 / DEBYE # 2.541746

qstack/equio.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import numpy as np
77
from pyscf import data
88
import metatensor
9-
from qstack.reorder import get_mrange
9+
from qstack.tools import Cursor
10+
from qstack.reorder import get_mrange, pyscf2gpr_l1_order
1011
from qstack.compound import singleatom_basis_enumerator
1112

1213

@@ -26,8 +27,6 @@
2627

2728
_molid_name = 'mol_id'
2829

29-
_pyscf2gpr_l1_order = [1,2,0]
30-
3130

3231
def _get_llist(mol):
3332
"""Get list of angular momentum quantum numbers for basis functions of each element of a molecule.
@@ -50,7 +49,7 @@ def _get_tsize(tensor):
5049
Returns:
5150
int: Total size of the tensor (total number of elements).
5251
"""
53-
return sum([np.prod(tensor.block(key).values.shape) for key in tensor.keys])
52+
return sum(np.prod(tensor.block(key).values.shape) for key in tensor.keys)
5453

5554

5655
def _labels_to_array(labels):
@@ -115,26 +114,24 @@ def vector_to_tensormap(mol, c):
115114
# Fill in the blocks
116115

117116
iq = dict.fromkeys(llists.keys(), 0)
118-
i = 0
117+
i = Cursor(action='slicer')
119118
for q in atom_charges:
120119
if llists[q]==sorted(llists[q]):
121120
for l in set(llists[q]):
122121
msize = 2*l+1
123-
nsize = blocks[(l,q)].shape[-1]
124-
cslice = c[i:i+nsize*msize].reshape(nsize,msize).T
122+
nsize = blocks[l,q].shape[-1]
123+
cslice = c[i(nsize*msize)].reshape(nsize,msize).T
125124
if l==1: # for l=1, the pyscf order is x,y,z (1,-1,0)
126-
cslice = cslice[_pyscf2gpr_l1_order]
127-
blocks[(l,q)][iq[q],:,:] = cslice
128-
i += msize*nsize
125+
cslice = cslice[pyscf2gpr_l1_order]
126+
blocks[l,q][iq[q],:,:] = cslice
129127
else:
130128
il = dict.fromkeys(range(max(llists[q]) + 1), 0)
131129
for l in llists[q]:
132130
msize = 2*l+1
133-
cslice = c[i:i+msize]
131+
cslice = c[i(msize)]
134132
if l==1: # for l=1, the pyscf order is x,y,z (1,-1,0)
135-
cslice = cslice[_pyscf2gpr_l1_order]
136-
blocks[(l,q)][iq[q],:,il[l]] = cslice
137-
i += msize
133+
cslice = cslice[pyscf2gpr_l1_order]
134+
blocks[l,q][iq[q],:,il[l]] = cslice
138135
il[l] += 1
139136
iq[q] += 1
140137

@@ -242,58 +239,54 @@ def matrix_to_tensormap(mol, dm):
242239

243240
if all(llists[q]==sorted(llists[q]) for q in llists):
244241
iq1 = dict.fromkeys(elements, 0)
245-
i1 = 0
242+
i1 = Cursor(action='slicer')
246243
for iat1, q1 in enumerate(atom_charges):
247244
for l1 in set(llists[q1]):
248245
msize1 = 2*l1+1
249246
nsize1 = llists[q1].count(l1)
250247
iq2 = dict.fromkeys(elements, 0)
251-
i2 = 0
248+
i1.add(nsize1*msize1)
249+
i2 = Cursor(action='slicer')
252250
for iat2, q2 in enumerate(atom_charges):
253251
for l2 in set(llists[q2]):
254252
msize2 = 2*l2+1
255253
nsize2 = llists[q2].count(l2)
256-
dmslice = dm[i1:i1+nsize1*msize1,i2:i2+nsize2*msize2].reshape(nsize1,msize1,nsize2,msize2)
254+
dmslice = dm[i1(),i2(nsize2*msize2)].reshape(nsize1,msize1,nsize2,msize2)
257255
dmslice = np.transpose(dmslice, axes=[1,3,0,2]).reshape(msize1,msize2,-1)
258256
block = tensor_blocks[tm_label_vals.index((l1,l2,q1,q2))]
259257
at_p = block.samples.position((iat1,iat2))
260-
blocks[(l1,l2,q1,q2)][at_p,:,:,:] = dmslice
261-
i2 += msize2*nsize2
258+
blocks[l1,l2,q1,q2][at_p,:,:,:] = dmslice
262259
iq2[q2] += 1
263-
i1 += msize1*nsize1
264260
iq1[q1] += 1
265261
else:
266262
iq1 = dict.fromkeys(elements, 0)
267-
i1 = 0
263+
i1 = Cursor(action='slicer')
268264
for iat1, q1 in enumerate(atom_charges):
269265
il1 = dict.fromkeys(range(max(llists[q1]) + 1), 0)
270266
for l1 in llists[q1]:
271-
msize1 = 2*l1+1
267+
i1.add(2*l1+1)
272268
iq2 = dict.fromkeys(elements, 0)
273-
i2 = 0
269+
i2 = Cursor(action='slicer')
274270
for iat2, q2 in enumerate(atom_charges):
275271
il2 = dict.fromkeys(range(max(llists[q2]) + 1), 0)
276272
for l2 in llists[q2]:
277-
msize2 = 2*l2+1
278-
dmslice = dm[i1:i1+msize1,i2:i2+msize2]
273+
dmslice = dm[i1(),i2(2*l2+1)]
279274
block = tensor_blocks[tm_label_vals.index((l1, l2, q1, q2))]
280275
at_p = block.samples.position((iat1, iat2))
281276
n_p = block.properties.position((il1[l1], il2[l2]))
282-
blocks[(l1,l2,q1,q2)][at_p,:,:,n_p] = dmslice
283-
i2 += msize2
277+
blocks[l1,l2,q1,q2][at_p,:,:,n_p] = dmslice
284278
il2[l2] += 1
285279
iq2[q2] += 1
286-
i1 += msize1
287280
il1[l1] += 1
288281
iq1[q1] += 1
289282

290283
# Fix the m order (for l=1, the pyscf order is x,y,z (1,-1,0))
291284
for key in blocks:
292285
l1,l2 = key[:2]
293286
if l1==1:
294-
blocks[key] = np.ascontiguousarray(blocks[key][:,_pyscf2gpr_l1_order,:,:])
287+
blocks[key] = np.ascontiguousarray(blocks[key][:,pyscf2gpr_l1_order,:,:])
295288
if l2==1:
296-
blocks[key] = np.ascontiguousarray(blocks[key][:,:,_pyscf2gpr_l1_order,:])
289+
blocks[key] = np.ascontiguousarray(blocks[key][:,:,pyscf2gpr_l1_order,:])
297290

298291
# Build tensor map
299292
tensor_blocks = [metatensor.TensorBlock(values=blocks[key], samples=block_samp_labels[key], components=block_comp_labels[key], properties=block_prop_labels[key]) for key in tm_label_vals]
@@ -492,7 +485,7 @@ def split(tensor):
492485
continue
493486
sampleidx = [t[0] for t in samples]
494487
samplelbl = [t[1] for t in samples]
495-
#sampleidx = [block.samples.position(lbl) for lbl in samplelbl]
488+
# sampleidx = [block.samples.position(lbl) for lbl in samplelbl]
496489

497490
blocks[key] = block.values[sampleidx]
498491
block_samp_labels[key] = metatensor.Labels(tensor.sample_names[1:], np.array(samplelbl)[:,1:])

qstack/fields/dm.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Density matrix manipulation and analysis functions."""
22

3+
import numpy as np
34
from pyscf import dft
45
from qstack import constants
5-
import numpy as np
6+
from qstack.tools import Cursor
67

78

89
def get_converged_mf(mol, xc, dm0=None, verbose=False):
@@ -79,27 +80,26 @@ def sphericalize_density_matrix(mol, dm):
7980
A numpy ndarray with the sphericalized density matrix.
8081
"""
8182
idx_by_l = [[] for i in range(constants.MAX_L)]
82-
i0 = 0
83+
i0 = Cursor(action='ranger')
8384
for ib in range(mol.nbas):
8485
l = mol.bas_angular(ib)
86+
msize = 2*l+1
8587
nc = mol.bas_nctr(ib)
86-
i1 = i0 + nc * (l*2+1)
87-
idx_by_l[l].extend(range(i0, i1, l*2+1))
88-
i0 = i1
88+
idx_by_l[l].extend(i0(nc*msize)[::msize])
8989

9090
spherical_dm = np.zeros_like(dm)
9191

9292
for l in np.nonzero(idx_by_l)[0]:
93+
msize = 2*l+1
9394
for idx in idx_by_l[l]:
9495
for jdx in idx_by_l[l]:
9596
if l == 0:
9697
spherical_dm[idx,jdx] = dm[idx,jdx]
9798
else:
9899
trace = 0
99-
for m in range(2*l+1):
100+
for m in range(msize):
100101
trace += dm[idx+m,jdx+m]
101-
for m in range(2*l+1):
102-
spherical_dm[idx+m,jdx+m] = trace / (2*l+1)
102+
for m in range(msize):
103+
spherical_dm[idx+m,jdx+m] = trace / msize
103104

104105
return spherical_dm
105-

qstack/fields/dori.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def eval_rho_dm(mol, ao, dm, deriv=2):
4444
if deriv==2:
4545
DM_dAO_dr_i = 2 * _dot_ao_dm(mol, dAO_dr[i], dm, None, None, None)
4646
for j in range(i, 3):
47-
d2rho_dr2[i,j] = _contract_rho(dAO_dr[j], DM_dAO_dr_i) + 2.0*np.einsum('ip,ip->i', d2AO_dr2[triu_idx[(i,j)]], DM_AO)
47+
d2rho_dr2[i,j] = _contract_rho(dAO_dr[j], DM_dAO_dr_i) + 2.0*np.einsum('ip,ip->i', d2AO_dr2[triu_idx[i,j]], DM_AO)
4848
d2rho_dr2[j,i] = d2rho_dr2[i,j]
4949

5050
if deriv==1:

qstack/fields/excited.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def exciton_properties_dm(mol, hole, part):
120120
dist = np.linalg.norm(hole_r-part_r)
121121
hole_extent = np.sqrt(hole_r2-hole_r@hole_r)
122122
part_extent = np.sqrt(part_r2-part_r@part_r)
123-
return(dist, hole_extent, part_extent)
123+
return dist, hole_extent, part_extent
124124

125125

126126
def exciton_properties(mol, hole, part):

qstack/mathutils/array.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Array manipulation utility functions."""
22

33
import numpy as np
4+
from qstack.tools import slice_generator
45

56

67
def scatter(values, indices):
@@ -89,9 +90,7 @@ def vstack_padding(xs):
8990
if len(np.unique(shapes_other_axes, axis=0))==1:
9091
return np.vstack(xs)
9192
X = np.zeros((shapes_common_axis.sum(), *shapes_other_axes.max(axis=0)))
92-
idx = 0
93-
for x in xs:
94-
slices = (np.s_[idx:idx+x.shape[0]], *(np.s_[0:s] for s in x.shape[1:]))
93+
for x, s0 in slice_generator(xs, inc=lambda x: x.shape[0]):
94+
slices = (s0, *(np.s_[0:s] for s in x.shape[1:]))
9595
X[slices] = x
96-
idx += x.shape[0]
9796
return X

0 commit comments

Comments
 (0)