diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 51f958f34..41ed1a413 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -3,9 +3,11 @@ import datetime import decimal import enum +import itertools import math import re import types +import warnings import weakref from collections import ChainMap, OrderedDict, deque from importlib.util import find_spec @@ -528,6 +530,55 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: ) return comparator(orig_dict, new_dict, superset_obj) + # Handle itertools infinite iterators + if isinstance(orig, itertools.count): + # repr reliably reflects internal state, e.g. "count(5)" or "count(5, 2)" + return repr(orig) == repr(new) + + if isinstance(orig, itertools.repeat): + # repr reliably reflects internal state, e.g. "repeat(5)" or "repeat(5, 3)" + return repr(orig) == repr(new) + + if isinstance(orig, itertools.cycle): + # cycle has no useful repr and no public attributes; use __reduce__ to extract state. + # __reduce__ returns (cls, (remaining_iter,), (saved_items, first_pass_done)). + # NOTE: consuming the remaining_iter is destructive to the cycle object, but this is + # acceptable since the comparator is the final consumer of captured return values. + # NOTE: __reduce__ on itertools.cycle was removed in Python 3.14. + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + orig_reduce = orig.__reduce__() + new_reduce = new.__reduce__() + orig_remaining = list(orig_reduce[1][0]) + new_remaining = list(new_reduce[1][0]) + orig_saved, orig_started = orig_reduce[2] + new_saved, new_started = new_reduce[2] + if orig_started != new_started: + return False + return comparator(orig_remaining, new_remaining, superset_obj) and comparator( + orig_saved, new_saved, superset_obj + ) + except TypeError: + # Python 3.14+: __reduce__ removed. Fall back to consuming elements from both + # cycles and comparing. Since the comparator is the final consumer, this is safe. + sample_size = 200 + orig_sample = [next(orig) for _ in range(sample_size)] + new_sample = [next(new) for _ in range(sample_size)] + return comparator(orig_sample, new_sample, superset_obj) + + # Handle remaining itertools types (chain, islice, starmap, product, permutations, etc.) + # by materializing into lists. count/repeat/cycle are already handled above. + # NOTE: materializing is destructive (consumes the iterator) and will hang on infinite input, + # but the three infinite itertools types are already handled above. + if type(orig).__module__ == "itertools": + if isinstance(orig, itertools.groupby): + # groupby yields (key, group_iterator) — materialize groups too + orig_groups = [(k, list(g)) for k, g in orig] + new_groups = [(k, list(g)) for k, g in new] + return comparator(orig_groups, new_groups, superset_obj) + return comparator(list(orig), list(new), superset_obj) + # re.Pattern can be made better by DFA Minimization and then comparing if isinstance( orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern) diff --git a/tests/test_comparator.py b/tests/test_comparator.py index db766cd66..28eeb8490 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -417,6 +417,312 @@ class Color4(IntFlag): assert not comparator(id1, id3) +def test_itertools_count() -> None: + import itertools + + # Equal: same start and step (default step=1) + assert comparator(itertools.count(0), itertools.count(0)) + assert comparator(itertools.count(5), itertools.count(5)) + assert comparator(itertools.count(0, 1), itertools.count(0, 1)) + assert comparator(itertools.count(10, 3), itertools.count(10, 3)) + + # Equal: negative start and step + assert comparator(itertools.count(-5, -2), itertools.count(-5, -2)) + + # Equal: float start and step + assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1)) + + # Not equal: different start + assert not comparator(itertools.count(0), itertools.count(1)) + assert not comparator(itertools.count(5), itertools.count(10)) + + # Not equal: different step + assert not comparator(itertools.count(0, 1), itertools.count(0, 2)) + assert not comparator(itertools.count(0, 1), itertools.count(0, -1)) + + # Not equal: different type + assert not comparator(itertools.count(0), 0) + assert not comparator(itertools.count(0), [0, 1, 2]) + + # Equal after partial consumption (both advanced to the same state) + a = itertools.count(0) + b = itertools.count(0) + next(a) + next(b) + assert comparator(a, b) + + # Not equal after different consumption + a = itertools.count(0) + b = itertools.count(0) + next(a) + assert not comparator(a, b) + + # Works inside containers + assert comparator([itertools.count(0)], [itertools.count(0)]) + assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)}) + assert not comparator([itertools.count(0)], [itertools.count(1)]) + + +def test_itertools_repeat() -> None: + import itertools + + # Equal: infinite repeat + assert comparator(itertools.repeat(5), itertools.repeat(5)) + assert comparator(itertools.repeat("hello"), itertools.repeat("hello")) + + # Equal: bounded repeat + assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3)) + assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10)) + + # Not equal: different value + assert not comparator(itertools.repeat(5), itertools.repeat(6)) + assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3)) + + # Not equal: different count + assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4)) + + # Not equal: bounded vs infinite + assert not comparator(itertools.repeat(5), itertools.repeat(5, 3)) + + # Not equal: different type + assert not comparator(itertools.repeat(5), 5) + assert not comparator(itertools.repeat(5), [5]) + + # Equal after partial consumption + a = itertools.repeat(5, 5) + b = itertools.repeat(5, 5) + next(a) + next(b) + assert comparator(a, b) + + # Not equal after different consumption + a = itertools.repeat(5, 5) + b = itertools.repeat(5, 5) + next(a) + assert not comparator(a, b) + + # Works inside containers + assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)]) + assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)]) + + +def test_itertools_cycle() -> None: + import itertools + + # Equal: same sequence + assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3])) + assert comparator(itertools.cycle("abc"), itertools.cycle("abc")) + + # Not equal: different sequence + assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4])) + assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2])) + + # Not equal: different type + assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3]) + + # Equal after same partial consumption + a = itertools.cycle([1, 2, 3]) + b = itertools.cycle([1, 2, 3]) + next(a) + next(b) + assert comparator(a, b) + + # Not equal after different consumption + a = itertools.cycle([1, 2, 3]) + b = itertools.cycle([1, 2, 3]) + next(a) + assert not comparator(a, b) + + # Equal after consuming a full cycle + a = itertools.cycle([1, 2, 3]) + b = itertools.cycle([1, 2, 3]) + for _ in range(3): + next(a) + next(b) + assert comparator(a, b) + + # Equal at same position across different full-cycle counts + a = itertools.cycle([1, 2, 3]) + b = itertools.cycle([1, 2, 3]) + for _ in range(4): + next(a) + for _ in range(7): + next(b) + # Both at position 1 within the cycle (4%3 == 7%3 == 1) + assert comparator(a, b) + + # Works inside containers + assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])]) + assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])]) + + +def test_itertools_chain() -> None: + import itertools + + assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4])) + assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5])) + assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]])) + assert comparator(itertools.chain(), itertools.chain()) + assert not comparator(itertools.chain([1]), itertools.chain([1, 2])) + + +def test_itertools_islice() -> None: + import itertools + + assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5)) + assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6)) + assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5)) + assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6)) + + +def test_itertools_product() -> None: + import itertools + + assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2)) + assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2)) + assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4])) + assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5])) + + +def test_itertools_permutations_combinations() -> None: + import itertools + + assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2)) + assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2)) + assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2)) + assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3)) + assert comparator( + itertools.combinations_with_replacement("ABC", 2), + itertools.combinations_with_replacement("ABC", 2), + ) + assert not comparator( + itertools.combinations_with_replacement("ABC", 2), + itertools.combinations_with_replacement("ABD", 2), + ) + + +def test_itertools_accumulate() -> None: + import itertools + + assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4])) + assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5])) + assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10)) + assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0)) + + +def test_itertools_filtering() -> None: + import itertools + + # compress + assert comparator( + itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), + itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), + ) + assert not comparator( + itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), + itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]), + ) + + # dropwhile + assert comparator( + itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + ) + assert not comparator( + itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]), + ) + + # takewhile + assert comparator( + itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + ) + assert not comparator( + itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), + itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]), + ) + + # filterfalse + assert comparator( + itertools.filterfalse(lambda x: x % 2, range(10)), + itertools.filterfalse(lambda x: x % 2, range(10)), + ) + + +def test_itertools_starmap() -> None: + import itertools + + assert comparator( + itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]), + itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]), + ) + assert not comparator( + itertools.starmap(pow, [(2, 3), (3, 2)]), + itertools.starmap(pow, [(2, 3), (3, 3)]), + ) + + +def test_itertools_zip_longest() -> None: + import itertools + + assert comparator( + itertools.zip_longest("AB", "xyz", fillvalue="-"), + itertools.zip_longest("AB", "xyz", fillvalue="-"), + ) + assert not comparator( + itertools.zip_longest("AB", "xyz", fillvalue="-"), + itertools.zip_longest("AB", "xyz", fillvalue="*"), + ) + + +def test_itertools_groupby() -> None: + import itertools + + assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC")) + assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC")) + assert comparator(itertools.groupby([]), itertools.groupby([])) + + # With key function + assert comparator( + itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x), + itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x), + ) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+") +def test_itertools_pairwise() -> None: + import itertools + + assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4])) + assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5])) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+") +def test_itertools_batched() -> None: + import itertools + + assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3)) + assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2)) + + +def test_itertools_in_containers() -> None: + import itertools + + # Itertools objects nested in dicts/lists + assert comparator( + {"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)}, + {"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)}, + ) + assert not comparator( + [itertools.product("AB", repeat=2)], + [itertools.product("AC", repeat=2)], + ) + + # Different itertools types should not match + assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2)) + + def test_numpy(): try: import numpy as np