Skip to content

Commit fed7296

Browse files
committed
split get_ana_cov_blocks.py into recipe script and compute script
1 parent b8829b1 commit fed7296

2 files changed

Lines changed: 400 additions & 234 deletions

File tree

project/SO/pISO/python/get_ana_cov_blocks.py

Lines changed: 22 additions & 234 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import argparse
99
from os.path import join as opj
10-
from itertools import product
1110
import time
1211

1312
import numpy as npy
@@ -30,8 +29,6 @@
3029
d.read_from_file(args.paramfile)
3130
log = log.get_logger(**d)
3231

33-
coupling_cache_size = args.coupling_cache_size
34-
3532
surveys = d["surveys"]
3633
lmax = d["lmax"]
3734
cov_correlation_by_noise_model = d['cov_correlation_by_noise_model']
@@ -51,6 +48,8 @@
5148
mcm_dir = d['mcm_dir']
5249
cov_dir = d['cov_dir']
5350

51+
t0 = time.time()
52+
5453
canonized_sn_field_info2canonized_connected_combo_2pt = npy.load(opj(cov_dir, 'canonized_sn_field_info2canonized_connected_combo_2pt.npy'), allow_pickle=True).item()
5554
canonized_w2s = npy.load(opj(cov_dir, 'canonized_w2s.npy'), allow_pickle=True).item()
5655

@@ -61,6 +60,17 @@
6160
reference_sn_field_info2reference_canonized_disconnected_combo_4pt = npy.load(opj(cov_dir, 'reference_sn_field_info2reference_canonized_disconnected_combo_4pt.npy'), allow_pickle=True).item()
6261
canonized_wls = npy.load(opj(cov_dir, 'canonized_wls.npy'), allow_pickle=True).item()
6362

63+
cov_block_sets2can_discon_com_4pts_and_optypes = npy.load(opj(cov_dir, 'cov_block_sets2can_discon_com_4pts_and_optypes.npy'), allow_pickle=True).item()
64+
cov_block2TEB_block2can_sn_alm_info2nterms = npy.load(opj(cov_dir, 'cov_block2TEB_block2can_sn_alm_info2nterms.npy'), allow_pickle=True).item()
65+
66+
log.info(f'[Rank {so_mpi.rank}] Load metadata in {(time.time() - t0):.3f} seconds')
67+
68+
optype2str = {
69+
0: '00',
70+
1: '02',
71+
2: '++'
72+
}
73+
6474
def update_pseudospectra_dict(f1, f2, pseudospectra_dict=None):
6575
if pseudospectra_dict is None:
6676
pseudospectra_dict = {}
@@ -154,40 +164,6 @@ def pols_disconnected_combo_4pt2ducc_optype(pol1, pol2, pol3, pol4):
154164
# if 2, then the spintype is ++, which is ducc optype 2
155165
return spin2_1 + spin2_2
156166

157-
def update_ducc_inputs_and_nterms(sna1, sna2, sna3, sna4,
                                  this_block_can_discon_com_4pts_and_optypes,
                                  can_sn_alm_info2nterms):
    """Register one disconnected 4pt term for a cov TEB sub-block.

    Each ``sna`` is a ``(survey, map, TEB, split)`` tuple. The canonical
    disconnected 4pt combo and its ducc optype are added to
    ``this_block_can_discon_com_4pts_and_optypes`` (a set, so duplicates are
    no-ops), and the occurrence count of the canonized alm-info key is
    incremented in ``can_sn_alm_info2nterms``. Both containers are mutated
    in place; nothing is returned.
    """
    # update ducc inputs with minimal unique couplings, and track their order
    sv1, m1, TEB1, split1 = sna1
    sv2, m2, TEB2, split2 = sna2
    sv3, m3, TEB3, split3 = sna3
    sv4, m4, TEB4, split4 = sna4

    pol1 = TEB2pol(TEB1)
    pol2 = TEB2pol(TEB2)
    pol3 = TEB2pol(TEB3)
    pol4 = TEB2pol(TEB4)

    snf1 = (sv1, m1, pol1, split1)
    snf2 = (sv2, m2, pol2, split2)
    snf3 = (sv3, m3, pol3, split3)
    snf4 = (sv4, m4, pol4, split4)
    can_discon_com_4pt = get_can_discon_com_4pt(snf1, snf2, snf3, snf4)

    # NOTE: although using uncanonized pol1, pol2, pol3, pol4, the optype is
    # insensitive to disconnected 4pt canonization
    optype = pols_disconnected_combo_4pt2ducc_optype(pol1, pol2, pol3, pol4)

    # adding to set does nothing if already in set
    this_block_can_discon_com_4pts_and_optypes.add((can_discon_com_4pt, optype))

    # track the number of times this term has appeared in this cov TEB
    # sub-block; dict.get keeps the count update to a single expression
    can_sn_alm_info = pspipe_list.canonize_disconnected_4pt(sna1, sna2, sna3, sna4)
    can_sn_alm_info2nterms[can_sn_alm_info] = can_sn_alm_info2nterms.get(can_sn_alm_info, 0) + 1
190-
191167
@numba.njit(parallel=True)
192168
def add_term_to_pseudo_cov_block(pseudo_cov_block, num_terms, w4_1234, w4_coupling, w2_12, w2_34, C12, C34, coupling):
193169
# important to cast the scalar to the right type before multiplication,
@@ -217,199 +193,7 @@ def add_term_to_pseudo_cov_block(pseudo_cov_block, num_terms, w4_1234, w4_coupli
217193
for sv in surveys:
218194
nsplits[sv] = d[f'n_splits_{sv}']
219195

220-
# first, figure out all the "shared couplings" sets of cov blocks, such that the
221-
# total number of couplings in each block is <= the cache size. NOTE: a nominal
222-
# cov_block_set might get "chopped" by blindly cutting all the cov blocks into
223-
# equal-length subtasks
224-
so_mpi.init(True)
225-
226-
t0 = time.time()
227-
228-
subtasks = so_mpi.taskrange(imin=0, imax=n_covs - 1)
229-
230-
cov_block_sets2can_discon_com_4pts_and_optypes = {}
231-
cov_block2TEB_block2can_sn_alm_info2nterms = {}
232-
233-
# need to initialize objects before while loop, that otherwise are re-initialized
234-
# in the loop
235-
cov_block_set = []
236-
can_discon_com_4pts_and_optypes = set()
237-
i = 0
238-
while True:
239-
task = subtasks[i]
240-
svi, mi = ni_list[task].split('&')
241-
svj, mj = nj_list[task].split('&')
242-
svp, mp = np_list[task].split('&')
243-
svq, mq = nq_list[task].split('&')
244-
cov_block = ((svi, mi), (svj, mj),
245-
(svp, mp), (svq, mq))
246-
247-
# "n" holds the "noise correlation group" information: f1 and f2 have
248-
# correlated noise only if ni == nj
249-
if cov_correlation_by_noise_model:
250-
ni = (svi, mapnames2noise_model_tags[f'{svi}_{mi}'])
251-
nj = (svj, mapnames2noise_model_tags[f'{svj}_{mj}'])
252-
np = (svp, mapnames2noise_model_tags[f'{svp}_{mp}'])
253-
nq = (svq, mapnames2noise_model_tags[f'{svq}_{mq}'])
254-
else:
255-
ni = svi
256-
nj = svj
257-
np = svp
258-
nq = svq
259-
260-
# we need to figure out which couplings we actually need first
261-
#
262-
# for each block, see which unique couplings are needed, and then try to add
263-
# to existing block set of unique couplings. if resulting merged set fits in
264-
# the cache, go to the next block, otherwise, end set and redo this block
265-
this_block_can_discon_com_4pts_and_optypes = set()
266-
267-
# for each cov TEB sub-block, tracks how many times a canonical 4pt combo
268-
# of (sv, m, TEB, split)s recurs, so it can be added once (times this count)
269-
# rather than each time
270-
TEB_block2can_sn_alm_info2nterms = {} # "alm_info" since keys are TEB instead of T and pol
271-
272-
splits_cross_iterator_ij = pspipe_list.get_splits_cross_iterator(svi, nsplits[svi], svj, nsplits[svj])
273-
splits_cross_iterator_pq = pspipe_list.get_splits_cross_iterator(svp, nsplits[svp], svq, nsplits[svq])
274-
for (TEBi, TEBj), (TEBp, TEBq) in product(spectra, repeat=2):
275-
if (TEBi, TEBj, TEBp, TEBq) not in TEB_block2can_sn_alm_info2nterms:
276-
TEB_block2can_sn_alm_info2nterms[TEBi, TEBj, TEBp, TEBq] = {}
277-
278-
can_sn_alm_info2nterms = TEB_block2can_sn_alm_info2nterms[TEBi, TEBj, TEBp, TEBq]
279-
280-
for (si, sj), (sp, sq) in product(splits_cross_iterator_ij, splits_cross_iterator_pq):
281-
282-
# ssss ipjq
283-
update_ducc_inputs_and_nterms((svi, mi, TEBi, 's'), (svp, mp, TEBp, 's'),
284-
(svj, mj, TEBj, 's'), (svq, mq, TEBq, 's'),
285-
this_block_can_discon_com_4pts_and_optypes,
286-
can_sn_alm_info2nterms)
287-
288-
# ssnn ipjq
289-
if nj == nq and sj == sq:
290-
update_ducc_inputs_and_nterms((svi, mi, TEBi, 's'), (svp, mp, TEBp, 's'),
291-
(svj, mj, TEBj, f'n{sj}'), (svq, mq, TEBq, f'n{sj}'),
292-
this_block_can_discon_com_4pts_and_optypes,
293-
can_sn_alm_info2nterms)
294-
295-
# nnss ipjq
296-
if ni == np and si == sp:
297-
update_ducc_inputs_and_nterms((svi, mi, TEBi, f'n{si}'), (svp, mp, TEBp, f'n{si}'),
298-
(svj, mj, TEBj, 's'), (svq, mq, TEBq, 's'),
299-
this_block_can_discon_com_4pts_and_optypes,
300-
can_sn_alm_info2nterms)
301-
302-
# nnnn ipjq
303-
if ni == np and si == sp and nj == nq and sj == sq:
304-
update_ducc_inputs_and_nterms((svi, mi, TEBi, f'n{si}'), (svp, mp, TEBp, f'n{si}'),
305-
(svj, mj, TEBj, f'n{sj}'), (svq, mq, TEBq, f'n{sj}'),
306-
this_block_can_discon_com_4pts_and_optypes,
307-
can_sn_alm_info2nterms)
308-
309-
# ssss iqjp
310-
update_ducc_inputs_and_nterms((svi, mi, TEBi, 's'), (svq, mq, TEBq, 's'),
311-
(svj, mj, TEBj, 's'), (svp, mp, TEBp, 's'),
312-
this_block_can_discon_com_4pts_and_optypes,
313-
can_sn_alm_info2nterms)
314-
315-
# ssnn iqjp
316-
if nj == np and sj == sp:
317-
update_ducc_inputs_and_nterms((svi, mi, TEBi, 's'), (svq, mq, TEBq, 's'),
318-
(svj, mj, TEBj, f'n{sj}'), (svp, mp, TEBp, f'n{sj}'),
319-
this_block_can_discon_com_4pts_and_optypes,
320-
can_sn_alm_info2nterms)
321-
322-
# nnss iqjp
323-
if ni == nq and si == sq:
324-
update_ducc_inputs_and_nterms((svi, mi, TEBi, f'n{si}'), (svq, mq, TEBq, f'n{si}'),
325-
(svj, mj, TEBj, 's'), (svp, mp, TEBp, 's'),
326-
this_block_can_discon_com_4pts_and_optypes,
327-
can_sn_alm_info2nterms)
328-
329-
# nnnn iqjp
330-
if ni == nq and si == sq and nj == np and sj == sp:
331-
update_ducc_inputs_and_nterms((svi, mi, TEBi, f'n{si}'), (svq, mq, TEBq, f'n{si}'),
332-
(svj, mj, TEBj, f'n{sj}'), (svp, mp, TEBp, f'n{sj}'),
333-
this_block_can_discon_com_4pts_and_optypes,
334-
can_sn_alm_info2nterms)
335-
336-
cov_block2TEB_block2can_sn_alm_info2nterms[cov_block] = TEB_block2can_sn_alm_info2nterms
337-
338-
# there are now four possibilities for what to do with this block:
339-
# (a) if the current block requires more couplings than the cache size
340-
# limit, and the current cache is empty, then we have no recourse:
341-
# ending the set, resetting the cache, and redoing the block will of course
342-
# never work. therefore, we first force the one block into the cache, and
343-
# then end the set and reset the cache. we then go on to the next block.
344-
# this *does* "violate" the cache limit, so we issue a warning
345-
# (b) like (a), if adding the current block's couplings to the cache would
346-
# result in a cache size more than the cache size limit, but unlike (a) if
347-
# the cache is not empty, we do have a recourse: end the set, reset the
348-
# cache, and then redo this block with an empty cache.
349-
# (c) if we are on the last block of all the subtasks, but we know we are
350-
# not going to redo this block with an empty cache (i.e., not (b)), then
351-
# we are also on the last task of the loop. like (a) we must force the block
352-
# into the cache and end the set. unlike (a), we break the loop instead of
353-
# going on to the next block. it's possible that (a) and (c) occur at the
354-
# same time, in which case (c) takes priority.
355-
# (d) otherwise proceed: add this block to the current cache and go on to
356-
# the next block. hopefully this happens most of the time
357-
358-
single_block_set = False
359-
end_set_and_redo_block = False
360-
if len(can_discon_com_4pts_and_optypes & this_block_can_discon_com_4pts_and_optypes) > coupling_cache_size:
361-
single_block_set = len(can_discon_com_4pts_and_optypes) == 0
362-
end_set_and_redo_block = len(can_discon_com_4pts_and_optypes) > 0
363-
364-
end_loop = (i+1 == len(subtasks)) and not end_set_and_redo_block
365-
366-
if single_block_set:
367-
log.warning(f"[Rank {so_mpi.rank}, Task {task}] Number of couplings for cov block {cov_block} is "
368-
f"{len(this_block_can_discon_com_4pts_and_optypes)}, which exceeds the coupling cache "
369-
f"size of {coupling_cache_size}. Adding to single-block-set, may result in OOM later.")
370-
371-
if single_block_set or end_loop:
372-
cov_block_set.append(cov_block)
373-
can_discon_com_4pts_and_optypes &= this_block_can_discon_com_4pts_and_optypes
374-
375-
if single_block_set or end_set_and_redo_block or end_loop:
376-
cov_block_sets2can_discon_com_4pts_and_optypes[tuple(cov_block_set)] = can_discon_com_4pts_and_optypes
377-
378-
cov_block_set = []
379-
can_discon_com_4pts_and_optypes = set()
380-
381-
if single_block_set and not end_loop:
382-
i += 1
383-
continue
384-
if end_set_and_redo_block:
385-
continue
386-
if end_loop:
387-
break
388-
else:
389-
cov_block_set.append(cov_block)
390-
can_discon_com_4pts_and_optypes &= this_block_can_discon_com_4pts_and_optypes
391-
i += 1
392-
393-
log.info(f'[Rank {so_mpi.rank}] Loop over cov block sets in {(time.time() - t0):.3f} seconds')
394-
395-
t0 = time.time()
396-
397-
# these may be useful to check later
398-
cov_block_sets2can_discon_com_4pts_and_optypes = so_mpi.gather_set_or_dict(cov_block_sets2can_discon_com_4pts_and_optypes,
399-
allgather=True,
400-
overlap_allowed=False)
401-
402-
cov_block2TEB_block2can_sn_alm_info2nterms = so_mpi.gather_set_or_dict(cov_block2TEB_block2can_sn_alm_info2nterms,
403-
allgather=True,
404-
overlap_allowed=False)
405-
406-
if so_mpi.rank == 0:
407-
npy.save(opj(cov_dir, 'cov_block_sets2can_discon_com_4pts_and_optypes.npy'), cov_block_sets2can_discon_com_4pts_and_optypes)
408-
npy.save(opj(cov_dir, 'cov_block2TEB_block2can_sn_alm_info2nterms.npy'), cov_block2TEB_block2can_sn_alm_info2nterms)
409-
410-
log.info(f'[Rank {so_mpi.rank}] Save cov block sets in {(time.time() - t0):.3f} seconds')
411-
412-
# now mpi over cov_block_sets
196+
# mpi over cov_block_sets
413197
cov_block_sets = list(cov_block_sets2can_discon_com_4pts_and_optypes.keys())
414198
n_cov_block_sets = len(cov_block_sets)
415199
subtasks = so_mpi.taskrange(imin=0, imax=n_cov_block_sets - 1)
@@ -441,15 +225,17 @@ def add_term_to_pseudo_cov_block(pseudo_cov_block, num_terms, w4_1234, w4_coupli
441225

442226
optype_counts = {}
443227
for optype in optypes_for_ducc:
444-
if optype in optype_counts:
445-
optype_counts[optype] += 1
228+
optypestr = optype2str[optype]
229+
if optypestr in optype_counts:
230+
optype_counts[optypestr] += 1
446231
else:
447-
optype_counts[optype] = 1
232+
optype_counts[optypestr] = 1
448233

449234
specs_for_ducc = None
450235
optypes_for_ducc = None
451236

452-
log.info(f'[Rank {so_mpi.rank}, Task {task}] Calculated {optype_counts} couplings in {(time.time() - t0):.3f} seconds')
237+
optypesstr = ', '.join([f'{ct} {s}-type couplings' for s, ct in optype_counts.items()])
238+
log.info(f'[Rank {so_mpi.rank}, Task {task}] Calculated {optypesstr} in {(time.time() - t0):.3f} seconds')
453239

454240
# now add all terms together for each cov block
455241
for i, cov_block in enumerate(cov_block_set):
@@ -568,6 +354,8 @@ def add_term_to_pseudo_cov_block(pseudo_cov_block, num_terms, w4_1234, w4_coupli
568354
dense=True, dtype=npy.float32)
569355

570356
# finalize: need to divide the split factor from each side and cast to double
357+
splits_cross_iterator_ij = pspipe_list.get_splits_cross_iterator(svi, nsplits[svi], svj, nsplits[svj])
358+
splits_cross_iterator_pq = pspipe_list.get_splits_cross_iterator(svp, nsplits[svp], svq, nsplits[svq])
571359
ana_cov /= (len(splits_cross_iterator_ij) * len(splits_cross_iterator_pq))
572360
ana_cov = ana_cov.astype(npy.float64, copy=False)
573361

0 commit comments

Comments
 (0)