Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 84 additions & 14 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections.abc
import inspect
import logging
import platform
import warnings
Expand Down Expand Up @@ -572,23 +573,37 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
return dynamo_load_cross_compiled_exported_program(file_path)


def load(file_path: str = "") -> Any:
def load(
file_path: str = "", extra_files: Optional[dict[str, Any]] = None, **kwargs: Any
) -> Any:
"""
Load either a Torchscript model or ExportedProgram.

Loads a TorchScript or ExportedProgram file from disk. File type will be detect the type using try, except.

Arguments:
file_path (str): Path to file on the disk
extra_files (dict[str, Any]): Extra files to load with the model

Example:
# Load with extra files.
extra_files = {"foo.txt": ""} # values will be replaced with serialized data
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
print(extra_files["foo.txt"])

Raises:
ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
"""

try:
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
exp_program = torch.export.load(file_path)
return exp_program
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
ts_module = function_overload_with_kwargs(
torch.export.load,
file_path,
extra_files=extra_files,
**kwargs,
)
return ts_module
except Exception:
logger.info(
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
Expand All @@ -597,9 +612,14 @@ def load(file_path: str = "") -> Any:
pass

try:
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
ts_module = torch.jit.load(file_path)
return ts_module
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
exp_program = function_overload_with_kwargs(
torch.jit.load,
file_path,
_extra_files=extra_files,
**kwargs,
)
return exp_program
except Exception:
logger.info(
f"Loading the provided file {file_path} via torch.jit.load() (after failing to load with torch.export.load()) failed with the following error",
Expand All @@ -614,6 +634,7 @@ def save(
module: Any,
file_path: str = "",
*,
extra_files: Optional[dict[str, str]] = None,
output_format: str = "exported_program",
inputs: Optional[Sequence[torch.Tensor | Input]] = None,
arg_inputs: Optional[Sequence[torch.Tensor | Input]] = None,
Expand Down Expand Up @@ -845,7 +866,17 @@ def _extract_tensor(obj: Any) -> Any:
"Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
)
else:
torch.jit.save(module, file_path)
if arg_inputs is not None:
logger.warning(
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
)
function_overload_with_kwargs(
torch.jit.save,
module,
file_path,
_extra_files=extra_files,
**kwargs,
)
elif module_type == _ModuleType.ep:
if output_format == "torchscript":
raise ValueError(
Expand All @@ -857,7 +888,14 @@ def _extract_tensor(obj: Any) -> Any:
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
if output_format == "exported_program":
torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
function_overload_with_kwargs(
torch.export.save,
module,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
Expand All @@ -878,7 +916,13 @@ def _extract_tensor(obj: Any) -> Any:
module_ts = torch.jit.trace(
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
)
torch.jit.save(module_ts, file_path)
function_overload_with_kwargs(
torch.jit.save,
module_ts,
file_path,
_extra_files=extra_files,
**kwargs,
)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export
Expand All @@ -901,8 +945,13 @@ def _extract_tensor(obj: Any) -> Any:
use_legacy_exporter=_use_legacy,
)
if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
function_overload_with_kwargs(
torch.export.save,
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand Down Expand Up @@ -975,8 +1024,13 @@ def _extract_tensor(obj: Any) -> Any:
)

if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
function_overload_with_kwargs(
torch.export.save,
exp_program,
file_path,
pickle_protocol=pickle_protocol,
extra_files=extra_files,
**kwargs,
)
elif output_format == "aot_inductor":
inductor_configs = {}
Expand All @@ -992,3 +1046,19 @@ def _extract_tensor(obj: Any) -> Any:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)


def function_overload_with_kwargs(
    fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Call ``fn`` with only the keyword arguments its signature accepts.

    Used to forward optional keyword arguments (e.g. ``extra_files``,
    ``pickle_protocol``) to ``torch.export.save/load`` / ``torch.jit.save/load``
    across torch versions where a given keyword may not exist yet. Unsupported
    keywords are dropped with a warning instead of raising ``TypeError``.

    Arguments:
        fn (Callable): Target callable to invoke
        *args (Any): Positional arguments forwarded unchanged
        **kwargs (Any): Keyword arguments, filtered against ``fn``'s signature

    Returns:
        Any: Whatever ``fn`` returns
    """
    try:
        fn_params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some builtins / C-implemented callables expose no signature;
        # in that case we cannot filter, so forward everything as-is.
        return fn(*args, **kwargs)

    # If fn declares a **kwargs catch-all, every keyword is acceptable and
    # filtering by exact parameter name would wrongly drop valid arguments.
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in fn_params.values()
    )

    fn_kwargs = {}
    for k, v in kwargs.items():
        if accepts_var_kwargs or k in fn_params:
            fn_kwargs[k] = v
        else:
            logger.warning(
                f"Keyword argument {k} is not a valid argument for {fn.__name__}"
            )

    return fn(*args, **fn_kwargs)
59 changes: 59 additions & 0 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,62 @@ def forward(self, x):
cos_sim > COSINE_THRESHOLD,
msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_save_load_extra_files(ir, tmpdir):
    """
    Tests the save/load API with ``extra_files`` on the exported_program
    format (model compiled using the dynamo workflow), verifying that the
    extra files round-trip alongside the serialized module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            conv = self.conv(x)
            relu = self.relu(conv)
            mul = relu * 0.5
            return mul

    ep_path = os.path.join(tmpdir, "trt.er")
    model = MyModule().eval().cuda()
    inp = torch.randn((1, 3, 224, 224)).to("cuda")

    trt_gm = torchtrt.compile(
        model,
        ir=ir,
        inputs=[inp],
        min_block_size=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
    )
    assertions.assertTrue(
        isinstance(trt_gm, torch.fx.GraphModule),
        msg="test_save_load_extra_files output type does not match with torch.fx.GraphModule",
    )
    outputs_trt = trt_gm(inp)
    # Save as an exported program, attaching a metadata extra file
    torchtrt.save(
        trt_gm,
        ep_path,
        output_format="exported_program",
        inputs=[inp],
        extra_files={"metadata": "Saving with extra files"},
    )

    # Values in the dict are replaced with the deserialized file contents
    loaded_extra_files = {"metadata": None}
    trt_ep_module = torchtrt.load(ep_path, extra_files=loaded_extra_files)
    outputs_trt_deser = trt_ep_module.module()(inp)

    cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
    assertions.assertTrue(
        loaded_extra_files["metadata"] == "Saving with extra files",
        msg="Extra files not saved and loaded correctly",
    )
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"test_save_load_extra_files TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
Loading