Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import datetime
import decimal
import enum
import itertools
import math
import re
import types
import warnings
import weakref
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
Expand Down Expand Up @@ -528,6 +530,55 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
)
return comparator(orig_dict, new_dict, superset_obj)

# Handle itertools infinite iterators
if isinstance(orig, itertools.count):
# repr reliably reflects internal state, e.g. "count(5)" or "count(5, 2)"
return repr(orig) == repr(new)

if isinstance(orig, itertools.repeat):
# repr reliably reflects internal state, e.g. "repeat(5)" or "repeat(5, 3)"
return repr(orig) == repr(new)

if isinstance(orig, itertools.cycle):
# cycle has no useful repr and no public attributes; use __reduce__ to extract state.
# __reduce__ returns (cls, (remaining_iter,), (saved_items, first_pass_done)).
# NOTE: consuming the remaining_iter is destructive to the cycle object, but this is
# acceptable since the comparator is the final consumer of captured return values.
# NOTE: __reduce__ on itertools.cycle was removed in Python 3.14.
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
orig_reduce = orig.__reduce__()
new_reduce = new.__reduce__()
orig_remaining = list(orig_reduce[1][0])
new_remaining = list(new_reduce[1][0])
orig_saved, orig_started = orig_reduce[2]
new_saved, new_started = new_reduce[2]
if orig_started != new_started:
return False
return comparator(orig_remaining, new_remaining, superset_obj) and comparator(
orig_saved, new_saved, superset_obj
)
except TypeError:
# Python 3.14+: __reduce__ removed. Fall back to consuming elements from both
# cycles and comparing. Since the comparator is the final consumer, this is safe.
sample_size = 200
orig_sample = [next(orig) for _ in range(sample_size)]
new_sample = [next(new) for _ in range(sample_size)]
return comparator(orig_sample, new_sample, superset_obj)

# Handle remaining itertools types (chain, islice, starmap, product, permutations, etc.)
# by materializing into lists. count/repeat/cycle are already handled above.
# NOTE: materializing is destructive (consumes the iterator) and will hang on infinite input,
# but the three infinite itertools types are already handled above.
if type(orig).__module__ == "itertools":
if isinstance(orig, itertools.groupby):
# groupby yields (key, group_iterator) — materialize groups too
orig_groups = [(k, list(g)) for k, g in orig]
new_groups = [(k, list(g)) for k, g in new]
return comparator(orig_groups, new_groups, superset_obj)
return comparator(list(orig), list(new), superset_obj)

# re.Pattern can be made better by DFA Minimization and then comparing
if isinstance(
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
Expand Down
306 changes: 306 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,312 @@ class Color4(IntFlag):
assert not comparator(id1, id3)


def test_itertools_count() -> None:
import itertools

# Equal: same start and step (default step=1)
assert comparator(itertools.count(0), itertools.count(0))
assert comparator(itertools.count(5), itertools.count(5))
assert comparator(itertools.count(0, 1), itertools.count(0, 1))
assert comparator(itertools.count(10, 3), itertools.count(10, 3))

# Equal: negative start and step
assert comparator(itertools.count(-5, -2), itertools.count(-5, -2))

# Equal: float start and step
assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1))

# Not equal: different start
assert not comparator(itertools.count(0), itertools.count(1))
assert not comparator(itertools.count(5), itertools.count(10))

# Not equal: different step
assert not comparator(itertools.count(0, 1), itertools.count(0, 2))
assert not comparator(itertools.count(0, 1), itertools.count(0, -1))

# Not equal: different type
assert not comparator(itertools.count(0), 0)
assert not comparator(itertools.count(0), [0, 1, 2])

# Equal after partial consumption (both advanced to the same state)
a = itertools.count(0)
b = itertools.count(0)
next(a)
next(b)
assert comparator(a, b)

# Not equal after different consumption
a = itertools.count(0)
b = itertools.count(0)
next(a)
assert not comparator(a, b)

# Works inside containers
assert comparator([itertools.count(0)], [itertools.count(0)])
assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)})
assert not comparator([itertools.count(0)], [itertools.count(1)])


def test_itertools_repeat() -> None:
import itertools

# Equal: infinite repeat
assert comparator(itertools.repeat(5), itertools.repeat(5))
assert comparator(itertools.repeat("hello"), itertools.repeat("hello"))

# Equal: bounded repeat
assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3))
assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10))

# Not equal: different value
assert not comparator(itertools.repeat(5), itertools.repeat(6))
assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3))

# Not equal: different count
assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4))

# Not equal: bounded vs infinite
assert not comparator(itertools.repeat(5), itertools.repeat(5, 3))

# Not equal: different type
assert not comparator(itertools.repeat(5), 5)
assert not comparator(itertools.repeat(5), [5])

# Equal after partial consumption
a = itertools.repeat(5, 5)
b = itertools.repeat(5, 5)
next(a)
next(b)
assert comparator(a, b)

# Not equal after different consumption
a = itertools.repeat(5, 5)
b = itertools.repeat(5, 5)
next(a)
assert not comparator(a, b)

# Works inside containers
assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)])
assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)])


def test_itertools_cycle() -> None:
import itertools

# Equal: same sequence
assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3]))
assert comparator(itertools.cycle("abc"), itertools.cycle("abc"))

# Not equal: different sequence
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4]))
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2]))

# Not equal: different type
assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3])

# Equal after same partial consumption
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
next(a)
next(b)
assert comparator(a, b)

# Not equal after different consumption
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
next(a)
assert not comparator(a, b)

# Equal after consuming a full cycle
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
for _ in range(3):
next(a)
next(b)
assert comparator(a, b)

# Equal at same position across different full-cycle counts
a = itertools.cycle([1, 2, 3])
b = itertools.cycle([1, 2, 3])
for _ in range(4):
next(a)
for _ in range(7):
next(b)
# Both at position 1 within the cycle (4%3 == 7%3 == 1)
assert comparator(a, b)

# Works inside containers
assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])])
assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])])


def test_itertools_chain() -> None:
import itertools

assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4]))
assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5]))
assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]]))
assert comparator(itertools.chain(), itertools.chain())
assert not comparator(itertools.chain([1]), itertools.chain([1, 2]))


def test_itertools_islice() -> None:
import itertools

assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5))
assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6))
assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5))
assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6))


def test_itertools_product() -> None:
import itertools

assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2))
assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2))
assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4]))
assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5]))


def test_itertools_permutations_combinations() -> None:
import itertools

assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2))
assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2))
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
assert comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABC", 2),
)
assert not comparator(
itertools.combinations_with_replacement("ABC", 2),
itertools.combinations_with_replacement("ABD", 2),
)


def test_itertools_accumulate() -> None:
import itertools

assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4]))
assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5]))
assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10))
assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0))


def test_itertools_filtering() -> None:
import itertools

# compress
assert comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
)
assert not comparator(
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
)

# dropwhile
assert comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
)
assert not comparator(
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
)

# takewhile
assert comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
)
assert not comparator(
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
)

# filterfalse
assert comparator(
itertools.filterfalse(lambda x: x % 2, range(10)),
itertools.filterfalse(lambda x: x % 2, range(10)),
)


def test_itertools_starmap() -> None:
import itertools

assert comparator(
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
)
assert not comparator(
itertools.starmap(pow, [(2, 3), (3, 2)]),
itertools.starmap(pow, [(2, 3), (3, 3)]),
)


def test_itertools_zip_longest() -> None:
import itertools

assert comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="-"),
)
assert not comparator(
itertools.zip_longest("AB", "xyz", fillvalue="-"),
itertools.zip_longest("AB", "xyz", fillvalue="*"),
)


def test_itertools_groupby() -> None:
import itertools

assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC"))
assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC"))
assert comparator(itertools.groupby([]), itertools.groupby([]))

# With key function
assert comparator(
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
def test_itertools_pairwise() -> None:
import itertools

assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4]))
assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5]))


@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
def test_itertools_batched() -> None:
import itertools

assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3))
assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2))


def test_itertools_in_containers() -> None:
import itertools

# Itertools objects nested in dicts/lists
assert comparator(
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
)
assert not comparator(
[itertools.product("AB", repeat=2)],
[itertools.product("AC", repeat=2)],
)

# Different itertools types should not match
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))


def test_numpy():
try:
import numpy as np
Expand Down
Loading