Skip to content

Commit 694755f

Browse files
committed
Enabled op-agnostic serialization for both runtimes
1 parent 5b1bde7 commit 694755f

4 files changed

Lines changed: 62 additions & 45 deletions

File tree

py/torch_tensorrt/_compile.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def convert_method_to_trt_engine(
545545
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
546546
)
547547

548-
return dynamo_convert_exported_program_to_serialized_trt_engine(
548+
return dynamo_convert_exported_program_to_serialized_trt_engine( # type: ignore[no-any-return]
549549
exp_program,
550550
arg_inputs=tuple(arg_inputs),
551551
kwarg_inputs=torchtrt_kwarg_inputs,
@@ -594,35 +594,40 @@ def load(
594594
Raises:
595595
ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
596596
"""
597+
from torch_tensorrt.dynamo._exporter import replace_execute_engine_no_op_node
597598

598599
try:
599-
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
600-
ts_module = function_overload_with_kwargs(
600+
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
601+
exp_program = function_overload_with_kwargs(
601602
torch.export.load,
602603
file_path,
603604
extra_files=extra_files,
604605
**kwargs,
605606
)
606-
return ts_module
607-
except Exception:
607+
gm = exp_program.graph_module
608+
if any(
609+
"no_op_placeholder_for_execute_engine" in n.name for n in gm.graph.nodes
610+
):
611+
return replace_execute_engine_no_op_node(exp_program)
612+
return exp_program
613+
except Exception as e:
608614
logger.info(
609-
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
615+
f"Loading the provided file {file_path} via torch.export.load() failed with the following error: {e}",
610616
exc_info=True,
611617
)
612-
pass
613618

614619
try:
615-
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
616-
exp_program = function_overload_with_kwargs(
620+
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
621+
ts_module = function_overload_with_kwargs(
617622
torch.jit.load,
618623
file_path,
619624
_extra_files=extra_files,
620625
**kwargs,
621626
)
622-
return exp_program
623-
except Exception:
627+
return ts_module
628+
except Exception as e:
624629
logger.info(
625-
f"Loading the provided file {file_path} via torch.jit.load() (after failing to load with torch.export.load()) failed with the following error",
630+
f"Loading the provided file {file_path} via torch.jit.load() (after failing to load with torch.export.load()) failed with the following error: {e}",
626631
exc_info=True,
627632
)
628633
raise ValueError(
@@ -805,8 +810,8 @@ def _all_are_input_objects(obj: Any) -> bool:
805810
f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}"
806811
)
807812

808-
arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore
809-
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore
813+
arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device()))
814+
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device())
810815

811816
else:
812817
# Mixed case: some inputs are Tensors, some are Input objects

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
OutputSpec,
2020
TensorArgument,
2121
)
22+
from torch_tensorrt._features import ENABLED_FEATURES
2223
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX
2324

2425

@@ -483,36 +484,18 @@ def inline_trt_modules(
483484
f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step."
484485
)
485486
num_outputs = len(trt_module_node.meta["val"])
486-
# Insert a call_function node to perform inference on TRT engine
487487
with gm.graph.inserting_before(trt_module_node):
488-
if cross_compile_module:
489-
engine_info = trt_module._pack_engine_info()
490-
engine_bytes = engine_info[ENGINE_IDX]
491-
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
492-
# insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows
493-
trt_node = gm.graph.call_function(
494-
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
495-
(trt_module_node.args, *engine_info),
496-
)
497-
else:
498-
# for the normal workflow: use the execute_engine node
499-
engine_name = f"{name}_engine"
500-
# TODO: THROWS SOME WARNING ABOUT A LACK OF UNDERLYING REFERENCE TO THE OWNING GRAPH MODULE
501-
# SAYS THERES 3 OPTIONS, SUBMODULE, PARAMETER, OR BUFFER, BUFFER SEEMS THE BEST BUT I THINK ITS KEYED TO TENSORS
502-
setattr(gm, engine_name, trt_module.engine)
503-
engine_node = gm.graph.get_attr(engine_name)
504-
505-
trt_node = gm.graph.call_function(
506-
torch.ops.tensorrt.execute_engine.default,
507-
(trt_module_node.args, engine_node),
508-
)
509-
# meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties)
510-
# Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
511-
# for custom object nodes, it should be CustomObjArgument
512-
engine_node.meta["val"] = CustomObjArgument(
513-
name=engine_node.name, class_fqn=""
514-
)
515-
# set trt_node.meta with trt_module_node.meta
488+
# Always embed engine data as primitive string args via no_op_placeholder
489+
# so torch.export does not pickle torch.classes.tensorrt.Engine (which
490+
# requires the C++ TorchBind class at load time).
491+
# torch_tensorrt.load() lowers placeholders → execute_engine.
492+
engine_info = trt_module._pack_engine_info()
493+
engine_bytes = engine_info[ENGINE_IDX]
494+
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
495+
trt_node = gm.graph.call_function(
496+
torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default,
497+
(trt_module_node.args, *engine_info),
498+
)
516499
assert num_outputs > 0
517500
trt_node.meta["val"] = trt_module_node.meta["val"]
518501

@@ -557,7 +540,12 @@ def replace_execute_engine_no_op_node(
557540
packed_engine_info[ENGINE_IDX] = base64.b64decode(
558541
engine_bytes.encode("utf-8")
559542
)
560-
trt_engine = torch.classes.tensorrt.Engine(tuple(packed_engine_info))
543+
if ENABLED_FEATURES.torch_tensorrt_runtime:
544+
trt_engine = torch.classes.tensorrt.Engine(tuple(packed_engine_info))
545+
else:
546+
from torch_tensorrt.dynamo.runtime._PythonTRTEngine import TRTEngine
547+
548+
trt_engine = TRTEngine(packed_engine_info)
561549
setattr(gm, engine_name, trt_engine)
562550
engine_node = gm.graph.get_attr(engine_name)
563551

py/torch_tensorrt/dynamo/runtime/_PythonTRTEngine.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,25 @@ def execute_engine(
707707
) -> List[torch.Tensor]:
708708
outputs = engine.execute(input_tensors)
709709
return [outputs] if isinstance(outputs, torch.Tensor) else list(outputs)
710+
711+
@torch.library.custom_op( # type: ignore[misc]
712+
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
713+
)
714+
def no_op_placeholder_for_execute_engine(
715+
inputs: List[torch.Tensor],
716+
abi_version: str,
717+
name: str,
718+
serialized_device_info: str,
719+
serialized_engine: str,
720+
serialized_in_binding_names: str,
721+
serialized_out_binding_names: str,
722+
serialized_hardware_compatible: str,
723+
serialized_metadata: str,
724+
serialized_target_platform: str,
725+
serialized_require_output_allocator: str,
726+
serialized_resource_allocation_strategy: str,
727+
) -> List[torch.Tensor]:
728+
raise RuntimeError(
729+
"TensorRT engine placeholder reached eager execution; load this artifact with "
730+
"torch_tensorrt.load() so placeholders are lowered to execute_engine."
731+
)

py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,9 @@ def no_op_placeholder_for_execute_engine(
354354
serialized_metadata: str,
355355
serialized_target_platform: str,
356356
serialized_require_output_allocator: str,
357+
serialized_resource_allocation_strategy: str,
357358
) -> List[torch.Tensor]:
358359
raise RuntimeError(
359-
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
360+
"TensorRT engine placeholder reached eager execution; load this artifact with "
361+
"torch_tensorrt.load() so placeholders are lowered to execute_engine."
360362
)

0 commit comments

Comments
 (0)