Skip to content
75 changes: 69 additions & 6 deletions imsim/photon_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,19 +315,82 @@ def make_photon_batches(config, base, logger, phot_objects: list[ObjectInfo], fa
@staticmethod
def make_photon_subbatches(batch, nsubbatch):
"""
Divide a batch of objects into a list of smaller subbatches of approximately the same size.
Divide a batch of objects into a list of smaller subbatches of approximately the same
total number of photons.

Parameters:
batch: A list making a batch of objects to be divided into subbatches.
nsubbatch: The number of subbatches to create from the batch. Must be a positive integer.
Returns:
subbatches: A list of subbatches, each a valid batch in its own right, together containing all objects in the original batch.
"""
nobj = len(batch)
nobj_per_subbatch, nobj_extra = divmod(nobj, nsubbatch)
section_sizes = nobj_extra * [nobj_per_subbatch + 1] + (nsubbatch - nobj_extra) * [nobj_per_subbatch]
section_indices = [0] + list(itertools.accumulate(section_sizes))
subbatches = [batch[section_indices[i]:section_indices[i+1]] for i in range(nsubbatch)]
# This is going to be a bin-packing with fragmentation problem.
# Fragmentation is required because the range of fluxes in the objects
# is huge; it would be impossible to get roughly even workloads across
# the subbatches without splitting some objects.
# Our goal is also to reduce the number of object lookups; that means
# that we want to minimize the number of object splits.

nphotons_total = sum(obj.phot_flux for obj in batch)
photons_per_subbatch, extra_photons = divmod(nphotons_total, nsubbatch)

# We aren't asking for a perfectly equal distribution of photons; just
# something close. So we set a loose constraint.
loose_photons_per_subbatch = round(1.05 * photons_per_subbatch)

# We want the order of objects in the batch from brightest to faintest.
# This helps ensure that we minimize fragmentation by placing the
# largest objects (i.e. most difficult to fit) first, while the
# sub-batches are empty.
sorted_objects = sorted(batch, key=lambda obj: obj.phot_flux, reverse=True)

# Sub-batching loop.
subbatches = [[] for _ in range(nsubbatch)]
current_subbatch = 0
for obj in sorted_objects:
remaining_flux = obj.phot_flux
while remaining_flux > 0:
# Calculate current sub-batch flux.
# It would be nicer to store rather then recalculate.
subbatch = subbatches[current_subbatch]
subbatch_flux = sum(sb_obj.phot_flux for sb_obj in subbatch)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you say, it would be nicer to keep this as a running total, rather than O(N^2) recalculating it each time. Is that hard for some reason? I'm thinking the easiest way would be to keep an array with the flux in each batch so far.

subbatch_fluxes = [0. for _ in range(nsubbatch)]

Then update it each time you add something to subbatches.

There are also patterns that only keep the running sum for the current batch, but the advantage of this is there is a nice debug statement you could put at the end with the list of realized fluxes per batch, which would be handy to be able to easily see when running.

available_flux = photons_per_subbatch - subbatch_flux
if remaining_flux <= available_flux:
# This sub-batch can contain the entirety of the remaining
# object flux. Place entirety it all in here, the continue
# without advancing in case more can go in.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expected the first test to be if remaining_flux <= loose_photons_per_subbatch - subbatch_flux, in which case put it in entirely. And if remaining_flux >= photons_per_subbatch - subbatch_flux then bump to the next batch.

I think this cleanly gets what you now have both here and the start of the else branch.

subbatches[current_subbatch].append(
dataclasses.replace(obj, phot_flux=remaining_flux)
)
# If it was an exact fit, move to the next sub-batch.
if remaining_flux == available_flux:
current_subbatch = (1 + current_subbatch) % nsubbatch
remaining_flux = 0
else:
# The object is too bright to fit entirely in the current sub-batch.
# See if we can fit it in with the loose fitting criterion.
available_flux_loose = loose_photons_per_subbatch - subbatch_flux
if remaining_flux <= available_flux_loose:
# Place the entirety of the remaining object in this sub-batch.
subbatches[current_subbatch].append(
dataclasses.replace(obj, phot_flux=remaining_flux)
)
remaining_flux = 0
# Advance to the next sub-batch. Do this no matter what
# in an attempt to prevent too much bunching up in one
# sub-batch and unnecessary fragmentation of objects.
# (It's more likely that the next sub-batch will have
# space if we do this.)
current_subbatch = (1 + current_subbatch) % nsubbatch
else:
# If even the loose criterion isn't enough, then fill to
# the top and advance to next sub-batch.
subbatches[current_subbatch].append(
dataclasses.replace(obj, phot_flux=available_flux)
)
remaining_flux -= available_flux
current_subbatch = (1 + current_subbatch) % nsubbatch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this algorithm correctly, I don't think the % nsubbatch is ever necessary. Is this a relic from a previous implementation? Or am I missing something here?


return subbatches

@staticmethod
Expand Down
213 changes: 125 additions & 88 deletions tests/test_photon_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ def create_phot_obj_list(num_objects, flux=1e5, start_num=0):
def create_faint_obj_list(num_objects, flux=100, start_num=0):
    # Build num_objects FAINT-mode objects with consecutive indices from start_num.
    return [ObjectInfo(start_num + n, flux, ProcessingMode.FAINT)
            for n in range(num_objects)]

def shuffle_batch(batch):
    """Shuffle ``batch`` in place, then return a new list of the same objects
    with each object's index field rewritten to its post-shuffle position."""
    shuffle(batch)
    # Renamed loop variable: the original shadowed the builtin `object`.
    return [replace(obj, index=i) for i, obj in enumerate(batch)]

def create_mixed_obj_list():
    """Build a shuffled, re-indexed batch of 10 FFT, 9 PHOT and 1 FAINT objects."""
    objects = create_fft_obj_list(10)
    objects += create_phot_obj_list(9)
    objects += create_faint_obj_list(1)
    # Shuffle and re-index so object indices match their list positions.
    return shuffle_batch(objects)

def run_partition_all_same_object_type(create_list_fn, desired_mode):
Expand Down Expand Up @@ -69,21 +72,34 @@ def test_partition_objects_mixed_all_types():
builder = valid_image_types["LSST_PhotonPoolingImage"]
nbatch = 10
orig_objects = create_mixed_obj_list()
print("Original objects:")
for obj in orig_objects:
print(f" Index: {obj.index}, flux: {obj.phot_flux}, mode: {obj.mode}")
fft_objects, phot_objects, faint_objects = builder.partition_objects(
orig_objects,
nbatch,
)
print("FFT objects:")
for obj in fft_objects:
print(f" Index: {obj.index}, flux: {obj.phot_flux}, mode: {obj.mode}")
print("photon objects:")
for obj in phot_objects:
print(f" Index: {obj.index}, flux: {obj.phot_flux}, mode: {obj.mode}")
print("faint objects:")
for obj in faint_objects:
print(f" Index: {obj.index}, flux: {obj.phot_flux}, mode: {obj.mode}")
# Assert correct number of objects in each list.
assert len(fft_objects) == 10
assert len(phot_objects) == 9
assert len(faint_objects) == 1
# Assert that all objects in the original list appear once in the new list.
all_objects = fft_objects + phot_objects + faint_objects
counts = Counter(object.index for object in all_objects)
print("Counts:", counts)
assert all([counts[obj.index] == 1 for obj in orig_objects])
expected_mode = [ProcessingMode.FFT] * 10 + [ProcessingMode.PHOT] * 9 + [ProcessingMode.FAINT]
assert all([object.mode == expected_mode[i] for i, object in enumerate(all_objects)])

return

def test_partition_objects_photon_and_faint():
Expand Down Expand Up @@ -176,95 +192,116 @@ def test_make_batches():

return

def test_make_photon_batches():
    """
    Ensure that the photon batching method correctly handles PHOT and FAINT
    object types and their fluxes are distributed correctly across the batches.
    """
    builder = valid_image_types["LSST_PhotonPoolingImage"]
    n_obj_phot = 15
    n_obj_faint = 5
    nobjects = n_obj_phot + n_obj_faint
    phot_objects = create_phot_obj_list(n_obj_phot, start_num=0)
    faint_objects = create_faint_obj_list(n_obj_faint, start_num=n_obj_phot)
    objects = phot_objects + faint_objects
    # Record each object's original flux, positioned by its index (the
    # start_num chaining above makes index == list position).
    orig_flux = np.array([obj.phot_flux for obj in objects])

    # Create 11 batches to ensure things don't divide nicely.
    nbatch = 11
    batches = builder.make_photon_batches({}, {}, None, phot_objects, faint_objects, nbatch)

    # Count how many times the objects appear in the batches and sum their total
    # flux across all batches.
    count = Counter(obj.index for batch in batches for obj in batch)
    total_flux = np.zeros(nobjects)
    for batch in batches:
        for obj in batch:
            total_flux[obj.index] += obj.phot_flux

    # Assert that the PHOT objects appear in all batches (This may not be
    # correct in the future if PHOT objects are spread across subsets of batches
    # rather than all of them.)
    # Also assert that FAINT objects appear once and only once.
    for obj in objects:
        if obj.mode == ProcessingMode.PHOT:
            assert count[obj.index] == nbatch
        elif obj.mode == ProcessingMode.FAINT:
            assert count[obj.index] == 1

    # Assert the summed flux across the objects in the batches is correct.
    np.testing.assert_array_almost_equal(total_flux, orig_flux)

def assert_subbatches(batch, expected_subbatch_len, subbatches):
    """Check that ``subbatches`` is a faithful, correctly-sized split of ``batch``."""
    # The length of the full batch equals the sum of the sub-batch lengths.
    assert len(batch) == sum(len(subbatch) for subbatch in subbatches)
    # The flattened list of sub-batches equals the original batch, including
    # the ordering of the objects.
    assert batch == [obj for subbatch in subbatches for obj in subbatch]
    # Every object in the original batch appears exactly once across the sub-batches.
    counts = Counter(obj.index for subbatch in subbatches for obj in subbatch)
    assert all(counts[obj.index] == 1 for obj in batch)
    # The sub-batches have the expected lengths.
    assert all(len(subbatch) == expected_subbatch_len[i]
               for i, subbatch in enumerate(subbatches))

def run_subbatch_test(name, batch, nsubbatch):
    """Run make_photon_subbatches on ``batch`` and check its flux bookkeeping.

    In general there are multiple ways to split the batch, so rather than
    pinning exact sub-batch contents we assert that each object appears with
    its original total flux across however many sub-batches it appears in,
    that the total flux across all sub-batches equals the total batch flux,
    and that the fullest sub-batch contains <= 1.1 * the flux in the least
    full one.
    NOTE(review): the algorithm only bounds each sub-batch at ~1.05x the
    nominal level; it makes no guarantee about the max/min ratio, so the 1.1
    bound here is tuned to the particular test cases — revisit if cases change.
    """
    total_original_flux = sum(obj.phot_flux for obj in batch)
    subbatches = valid_image_types["LSST_PhotonPoolingImage"].make_photon_subbatches(batch, nsubbatch)
    # Debug output of the realized split, useful when a case fails.
    print("Subbatches in test:", name)
    for i, subbatch in enumerate(subbatches):
        print(f" Subbatch {i}: {[(obj.index, obj.phot_flux) for obj in subbatch]}")
    assert len(subbatches) == nsubbatch
    # Each object's flux must be preserved across its fragments.
    for obj in batch:
        total_obj_flux = sum(
            sum(sb_obj.phot_flux for sb_obj in subbatch if sb_obj.index == obj.index)
            for subbatch in subbatches
        )
        assert obj.phot_flux == total_obj_flux
    total_subbatch_fluxes = [sum(obj.phot_flux for obj in subbatch) for subbatch in subbatches]
    # No flux is lost or invented by the split.
    assert sum(total_subbatch_fluxes) == total_original_flux
    # No zero- or negative-flux fragments should ever be emitted.
    assert all(obj.phot_flux > 0 for subbatch in subbatches for obj in subbatch)
    # Rough load-balance check (see NOTE in the docstring).
    assert max(total_subbatch_fluxes) <= 1.1 * min(total_subbatch_fluxes)


def test_make_photon_subbatches():
    """
    Test the newer sub-batching method which attempts to spread the batch
    flux equally across the sub-batches.
    Some of these tests may be too restrictive, in particular those which specify
    the exact sub-batch contents. It may be better to set those aside and only
    check that the sub-batches balance flux and contain all the objects with the
    correct total flux.
    """
    # Create a batch with a 'large' number of objects with varying fluxes, but
    # with none dominating the total. We want them to be spread across the
    # sub-batches s.t. each has a flux of 2e4.
    batch = [ObjectInfo(0, 1e4, ProcessingMode.PHOT),
             ObjectInfo(1, 5e3, ProcessingMode.PHOT),
             ObjectInfo(2, 2e3, ProcessingMode.PHOT),
             ObjectInfo(3, 7e3, ProcessingMode.PHOT),
             ObjectInfo(4, 3e3, ProcessingMode.PHOT),
             ObjectInfo(5, 5e3, ProcessingMode.PHOT),
             ObjectInfo(6, 1e4, ProcessingMode.PHOT),
             ObjectInfo(7, 2e4, ProcessingMode.PHOT),
             ObjectInfo(8, 1e4, ProcessingMode.PHOT),
             ObjectInfo(9, 8e3, ProcessingMode.PHOT),
             ]
    run_subbatch_test("equal distribution", batch, 4)

    # Create a batch with a total flux of 1e6 photons. We should end up with 10
    # sub-batches of 1e5 photons each. The majority of the flux is in a few very
    # bright objects, so we want to see these correctly being split up across
    # multiple sub-batches to make sure no one sub-batch requires a lot of
    # time/memory. Yes, we're doing more work in the background for the extra
    # objects, but this should be very little work as long as it's only for a
    # very few bright objects.
    batch = [ObjectInfo(0, 6e5, ProcessingMode.PHOT),
             ObjectInfo(1, 2e5, ProcessingMode.PHOT),
             ObjectInfo(2, 1e5, ProcessingMode.PHOT),
             ObjectInfo(3, 4e4, ProcessingMode.PHOT),
             ObjectInfo(4, 2e4, ProcessingMode.PHOT),
             ObjectInfo(5, 1e4, ProcessingMode.PHOT),
             ObjectInfo(6, 1e4, ProcessingMode.PHOT),
             ObjectInfo(7, 1e4, ProcessingMode.PHOT),
             ObjectInfo(8, 5e3, ProcessingMode.PHOT),
             ObjectInfo(9, 5e3, ProcessingMode.PHOT),
             ]
    run_subbatch_test("bright object fragmentation", batch, 10)

    # Here there's still one very bright object, but it leaves a little bit of
    # space in the first subbatch for something else to go in. The other faint
    # objects pack into the second one.
    batch = [ObjectInfo(0, 8e5, ProcessingMode.PHOT),
             ObjectInfo(1, 3e5, ProcessingMode.PHOT),
             ObjectInfo(2, 6e5, ProcessingMode.PHOT),
             ObjectInfo(3, 3e5, ProcessingMode.PHOT),
             ]
    run_subbatch_test("fragmentation in first sub-batch", batch, 2)

    # Check assignment to sub-batches earlier than the one just filled.
    # NOTE(review): the current first-fit-style algorithm never moves
    # backwards except via the end-of-cycle wraparound, so this may be
    # equivalent to the previous test; it would matter for best-fit type
    # implementations.
    batch = [ObjectInfo(0, 8e4, ProcessingMode.PHOT),
             ObjectInfo(1, 8e4, ProcessingMode.PHOT),
             ObjectInfo(2, 4e4, ProcessingMode.PHOT),
             ]
    run_subbatch_test("filling early sub-batches", batch, 2)


def test_make_photon_subbatches_non_simple():
    """Sub-batching cases where the flux does not divide evenly across the
    sub-batches, requiring non-trivial fragmentation of objects."""
    # The first test places a total of 1e6 photons across 7 sub-batches,
    # i.e. 1e6 // 7 = 142857 photons per sub-batch with remainder 1.
    batch = [ObjectInfo(0, 5e5, ProcessingMode.PHOT),
             ObjectInfo(1, 5e5, ProcessingMode.PHOT),
             ]
    run_subbatch_test("small non-simple fragmentation", batch, 7)

    # Then place 3 objects with total flux 1.1e6 in 31 sub-batches,
    # i.e. 1.1e6 // 31 = 35483 photons per sub-batch with remainder 27.
    batch = [ObjectInfo(0, 1e5, ProcessingMode.PHOT),
             ObjectInfo(1, 5e5, ProcessingMode.PHOT),
             ObjectInfo(2, 5e5, ProcessingMode.PHOT),
             ]
    run_subbatch_test("large non-simple fragmentation", batch, 31)

if __name__ == "__main__":
testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)]
Expand Down
Loading