1 change: 1 addition & 0 deletions — _unittests/ut_ci_models/test_ci_export.py

@@ -10,6 +10,7 @@
 
 class TestCiExport(ExtTestCase):
     @hide_stdout()
+    @requires_transformers("4.55")
     def test_main_qwen25_tiny_llm(self):
         main_qwen25(
             model_id="arnir0/Tiny-LLM",
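Note: the new `requires_transformers("4.55")` gate presumably skips this test whenever the installed `transformers` is older than 4.55, matching the version bump in `test_documentation_examples.py` below. A minimal sketch of such a decorator, assuming it is built on standard `unittest` skips (the real helper lives in `onnx_diagnostic.ext_test_case` and may differ):

import unittest
from packaging.version import Version

def requires_transformers_sketch(min_version: str):
    # Hypothetical stand-in for requires_transformers: skip the decorated
    # test unless transformers >= min_version is importable.
    try:
        import transformers
        ok = Version(transformers.__version__) >= Version(min_version)
    except ImportError:
        ok = False
    return unittest.skipUnless(ok, f"requires transformers>={min_version}")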
86 changes: 83 additions & 3 deletions — _unittests/ut_investigate/test_input_observer.py

@@ -4,14 +4,14 @@
 from onnx_diagnostic.ext_test_case import ExtTestCase
 from onnx_diagnostic.investigate.input_observer import (
     InputObserver,
-    infer_dynamic_dimensions,
+    _infer_dynamic_dimensions,
 )
 
 
 class TestInputObserver(ExtTestCase):
     def test_infer_dynamic_dimensions(self):
-        self.assertEqual([2], infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)]))
-        self.assertEqual([0, 2], infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)]))
+        self.assertEqual([2], _infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)]))
+        self.assertEqual([0, 2], _infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)]))
 
     def test_io_captured_args(self):
         class Model(torch.nn.Module):
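Note on the rename: `infer_dynamic_dimensions` becomes private (`_infer_dynamic_dimensions`). Judging from the assertions above, it returns the indices of dimensions whose value varies across the observed shapes. A minimal sketch consistent with those assertions, not the library's actual implementation:

def _infer_dynamic_dimensions_sketch(shapes):
    # Indices where the dimension is not constant across observations:
    # [(1, 2, 3), (1, 2, 4)] -> [2]; [(1, 2, 3), (2, 2, 4)] -> [0, 2].
    first = shapes[0]
    return [
        axis
        for axis in range(len(first))
        if any(shape[axis] != first[axis] for shape in shapes[1:])
    ]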
@@ -42,6 +42,36 @@ def forward(self, x, y):
         self.assertIsInstance(args, tuple)
         self.assertEqual(2, len(args))
 
+    def test_io_captured_not_forward(self):
+        class Model(torch.nn.Module):
+            def notforward(self, w):
+                return w.abs()
+
+            def forward(self, x, y):
+                return x + self.notforward(y)
+
+        inputs = [
+            (torch.randn((5, 6)), torch.randn((1, 6))),
+            (torch.randn((7, 7)), torch.randn((1, 7))),
+            (torch.randn((7, 8)), torch.randn((1, 8))),
+            (torch.randn((7, 9)), torch.randn((1, 9))),
+        ]
+
+        model = Model()
+        observer = InputObserver()
+        with observer(model, method_name="notforward"):
+            for args in inputs:
+                model(*args)
+        self.assertEqual(len(observer.info), 3)
+        for i in range(3):
+            self.assertEqual(len(observer.info.flat_outputs[i]), 1)
+
+        cst = torch.export.Dim.DYNAMIC
+        self.assertEqual(({1: cst},), observer.infer_dynamic_shapes())
+        args = observer.infer_arguments()
+        self.assertIsInstance(args, tuple)
+        self.assertEqual(1, len(args))
+
     def test_io_captured_kwargs(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
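Note: this new test shows that `InputObserver` can hook a method other than `forward` via `method_name`; the inferred shapes and arguments then describe the observed method's signature, not the model's. Here `notforward` receives a single tensor whose second dimension varies (6, 7, 8, 9), hence the one-element `({1: DYNAMIC},)` result. A hedged, self-contained usage sketch mirroring the test (the `helper` name is hypothetical):

import torch
from onnx_diagnostic.investigate.input_observer import InputObserver

class Model(torch.nn.Module):
    def helper(self, w):
        return w.abs()

    def forward(self, x, y):
        return x + self.helper(y)

model = Model()
observer = InputObserver()
with observer(model, method_name="helper"):
    for d in (6, 7, 8):
        model(torch.randn((5, d)), torch.randn((1, d)))

# Shapes describe helper(w), not forward(x, y): only dim 1 of w varied.
print(observer.infer_dynamic_shapes())  # ({1: torch.export.Dim.DYNAMIC},)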
@@ -516,6 +546,56 @@ def forward(self, x, custom=None):
                 model(*args)
         self.assertEqual(expected, observer.infer_dynamic_shapes())
 
+    def test_io_captured_args_kwargs_dynamic_batch(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y, z=None, w=None):
+                r = x + y
+                if z is not None:
+                    r += z
+                if w is not None:
+                    r += w
+                return r
+
+        inputs = [
+            (
+                (torch.randn((5, 6)), torch.randn((1, 6))),
+                dict(z=torch.randn((5, 6)), w=torch.randn((1, 6))),
+            ),
+            (
+                (torch.randn((5, 7)), torch.randn((1, 7))),
+                dict(z=torch.randn((5, 7)), w=torch.randn((1, 7))),
+            ),
+            (
+                (torch.randn((5, 8)), torch.randn((1, 8))),
+                dict(z=torch.randn((5, 8)), w=torch.randn((1, 8))),
+            ),
+            (
+                (torch.randn((5, 9)), torch.randn((1, 9))),
+                dict(z=torch.randn((5, 9)), w=torch.randn((1, 9))),
+            ),
+        ]
+
+        model = Model()
+        expected = [model(*args, **kwargs) for args, kwargs in inputs]
+        observer = InputObserver()
+        with observer(model):
+            for args, kwargs in inputs:
+                model(*args, **kwargs)
+        self.assertEqual(len(observer.info), 3)
+        for i in range(3):
+            self.assertEqual(len(observer.info.flat_outputs[i]), 1)
+            torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0])
+
+        cst = torch.export.Dim.DYNAMIC
+        self.assertEqual(
+            dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
+            observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"}),
+        )
+        self.assertEqual(
+            dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
+            observer.infer_dynamic_shapes(add_batch_dimension_for={"x", "z"}),
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
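Note: the new `add_batch_dimension_for` argument apparently forces a dynamic dimension 0 onto the selected inputs even when that dimension was constant in every captured call (here `x` and `z` always had batch size 5, so only dim 1 would otherwise be inferred as dynamic). Inputs can seemingly be selected either by positional index or by parameter name, which is why `{0, "z"}` and `{"x", "z"}` yield the same result. A hedged restatement of what the assertions check, continuing from the `observer` captured in the test above:

cst = torch.export.Dim.DYNAMIC
ds = observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"})
# x and z get an extra dynamic batch axis; y and w keep only the
# dimension that actually varied across the captured inputs.
assert ds["x"] == {0: cst, 1: cst} and ds["z"] == {0: cst, 1: cst}
assert ds["y"] == {1: cst} and ds["w"] == {1: cst}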
6 changes: 3 additions & 3 deletions — _unittests/ut_xrun_doc/test_documentation_examples.py

@@ -87,10 +87,10 @@ def add_test_methods(cls):

         if (
             not reason
-            and name in {"plot_export_tiny_llm.py"}
-            and not has_transformers("4.51")
+            and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"}
+            and not has_transformers("4.55")
         ):
-            reason = "transformers<4.51"
+            reason = "transformers<4.55"
 
         if (
             not reason