66import numpy as np
77from pyscf import data
88import metatensor
9- from qstack .reorder import get_mrange
9+ from qstack .tools import Cursor
10+ from qstack .reorder import get_mrange , pyscf2gpr_l1_order
1011from qstack .compound import singleatom_basis_enumerator
1112
1213
2627
2728_molid_name = 'mol_id'
2829
29- _pyscf2gpr_l1_order = [1 ,2 ,0 ]
30-
3130
3231def _get_llist (mol ):
3332 """Get list of angular momentum quantum numbers for basis functions of each element of a molecule.
@@ -50,7 +49,7 @@ def _get_tsize(tensor):
5049 Returns:
5150 int: Total size of the tensor (total number of elements).
5251 """
53- return sum ([ np .prod (tensor .block (key ).values .shape ) for key in tensor .keys ] )
52+ return sum (np .prod (tensor .block (key ).values .shape ) for key in tensor .keys )
5453
5554
5655def _labels_to_array (labels ):
@@ -115,26 +114,24 @@ def vector_to_tensormap(mol, c):
115114 # Fill in the blocks
116115
117116 iq = dict .fromkeys (llists .keys (), 0 )
118- i = 0
117+ i = Cursor ( action = 'slicer' )
119118 for q in atom_charges :
120119 if llists [q ]== sorted (llists [q ]):
121120 for l in set (llists [q ]):
122121 msize = 2 * l + 1
123- nsize = blocks [( l ,q ) ].shape [- 1 ]
124- cslice = c [i : i + nsize * msize ].reshape (nsize ,msize ).T
122+ nsize = blocks [l ,q ].shape [- 1 ]
123+ cslice = c [i ( nsize * msize ) ].reshape (nsize ,msize ).T
125124 if l == 1 : # for l=1, the pyscf order is x,y,z (1,-1,0)
126- cslice = cslice [_pyscf2gpr_l1_order ]
127- blocks [(l ,q )][iq [q ],:,:] = cslice
128- i += msize * nsize
125+ cslice = cslice [pyscf2gpr_l1_order ]
126+ blocks [l ,q ][iq [q ],:,:] = cslice
129127 else :
130128 il = dict .fromkeys (range (max (llists [q ]) + 1 ), 0 )
131129 for l in llists [q ]:
132130 msize = 2 * l + 1
133- cslice = c [i : i + msize ]
131+ cslice = c [i ( msize ) ]
134132 if l == 1 : # for l=1, the pyscf order is x,y,z (1,-1,0)
135- cslice = cslice [_pyscf2gpr_l1_order ]
136- blocks [(l ,q )][iq [q ],:,il [l ]] = cslice
137- i += msize
133+ cslice = cslice [pyscf2gpr_l1_order ]
134+ blocks [l ,q ][iq [q ],:,il [l ]] = cslice
138135 il [l ] += 1
139136 iq [q ] += 1
140137
@@ -242,58 +239,54 @@ def matrix_to_tensormap(mol, dm):
242239
243240 if all (llists [q ]== sorted (llists [q ]) for q in llists ):
244241 iq1 = dict .fromkeys (elements , 0 )
245- i1 = 0
242+ i1 = Cursor ( action = 'slicer' )
246243 for iat1 , q1 in enumerate (atom_charges ):
247244 for l1 in set (llists [q1 ]):
248245 msize1 = 2 * l1 + 1
249246 nsize1 = llists [q1 ].count (l1 )
250247 iq2 = dict .fromkeys (elements , 0 )
251- i2 = 0
248+ i1 .add (nsize1 * msize1 )
249+ i2 = Cursor (action = 'slicer' )
252250 for iat2 , q2 in enumerate (atom_charges ):
253251 for l2 in set (llists [q2 ]):
254252 msize2 = 2 * l2 + 1
255253 nsize2 = llists [q2 ].count (l2 )
256- dmslice = dm [i1 : i1 + nsize1 * msize1 ,i2 : i2 + nsize2 * msize2 ].reshape (nsize1 ,msize1 ,nsize2 ,msize2 )
254+ dmslice = dm [i1 () ,i2 ( nsize2 * msize2 ) ].reshape (nsize1 ,msize1 ,nsize2 ,msize2 )
257255 dmslice = np .transpose (dmslice , axes = [1 ,3 ,0 ,2 ]).reshape (msize1 ,msize2 ,- 1 )
258256 block = tensor_blocks [tm_label_vals .index ((l1 ,l2 ,q1 ,q2 ))]
259257 at_p = block .samples .position ((iat1 ,iat2 ))
260- blocks [(l1 ,l2 ,q1 ,q2 )][at_p ,:,:,:] = dmslice
261- i2 += msize2 * nsize2
258+ blocks [l1 ,l2 ,q1 ,q2 ][at_p ,:,:,:] = dmslice
262259 iq2 [q2 ] += 1
263- i1 += msize1 * nsize1
264260 iq1 [q1 ] += 1
265261 else :
266262 iq1 = dict .fromkeys (elements , 0 )
267- i1 = 0
263+ i1 = Cursor ( action = 'slicer' )
268264 for iat1 , q1 in enumerate (atom_charges ):
269265 il1 = dict .fromkeys (range (max (llists [q1 ]) + 1 ), 0 )
270266 for l1 in llists [q1 ]:
271- msize1 = 2 * l1 + 1
267+ i1 . add ( 2 * l1 + 1 )
272268 iq2 = dict .fromkeys (elements , 0 )
273- i2 = 0
269+ i2 = Cursor ( action = 'slicer' )
274270 for iat2 , q2 in enumerate (atom_charges ):
275271 il2 = dict .fromkeys (range (max (llists [q2 ]) + 1 ), 0 )
276272 for l2 in llists [q2 ]:
277- msize2 = 2 * l2 + 1
278- dmslice = dm [i1 :i1 + msize1 ,i2 :i2 + msize2 ]
273+ dmslice = dm [i1 (),i2 (2 * l2 + 1 )]
279274 block = tensor_blocks [tm_label_vals .index ((l1 , l2 , q1 , q2 ))]
280275 at_p = block .samples .position ((iat1 , iat2 ))
281276 n_p = block .properties .position ((il1 [l1 ], il2 [l2 ]))
282- blocks [(l1 ,l2 ,q1 ,q2 )][at_p ,:,:,n_p ] = dmslice
283- i2 += msize2
277+ blocks [l1 ,l2 ,q1 ,q2 ][at_p ,:,:,n_p ] = dmslice
284278 il2 [l2 ] += 1
285279 iq2 [q2 ] += 1
286- i1 += msize1
287280 il1 [l1 ] += 1
288281 iq1 [q1 ] += 1
289282
290283 # Fix the m order (for l=1, the pyscf order is x,y,z (1,-1,0))
291284 for key in blocks :
292285 l1 ,l2 = key [:2 ]
293286 if l1 == 1 :
294- blocks [key ] = np .ascontiguousarray (blocks [key ][:,_pyscf2gpr_l1_order ,:,:])
287+ blocks [key ] = np .ascontiguousarray (blocks [key ][:,pyscf2gpr_l1_order ,:,:])
295288 if l2 == 1 :
296- blocks [key ] = np .ascontiguousarray (blocks [key ][:,:,_pyscf2gpr_l1_order ,:])
289+ blocks [key ] = np .ascontiguousarray (blocks [key ][:,:,pyscf2gpr_l1_order ,:])
297290
298291 # Build tensor map
299292 tensor_blocks = [metatensor .TensorBlock (values = blocks [key ], samples = block_samp_labels [key ], components = block_comp_labels [key ], properties = block_prop_labels [key ]) for key in tm_label_vals ]
@@ -492,7 +485,7 @@ def split(tensor):
492485 continue
493486 sampleidx = [t [0 ] for t in samples ]
494487 samplelbl = [t [1 ] for t in samples ]
495- #sampleidx = [block.samples.position(lbl) for lbl in samplelbl]
488+ # sampleidx = [block.samples.position(lbl) for lbl in samplelbl]
496489
497490 blocks [key ] = block .values [sampleidx ]
498491 block_samp_labels [key ] = metatensor .Labels (tensor .sample_names [1 :], np .array (samplelbl )[:,1 :])
0 commit comments