|
19 | 19 | OutputSpec, |
20 | 20 | TensorArgument, |
21 | 21 | ) |
| 22 | +from torch_tensorrt._features import ENABLED_FEATURES |
22 | 23 | from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX |
23 | 24 |
|
24 | 25 |
|
@@ -483,36 +484,18 @@ def inline_trt_modules( |
483 | 484 | f"trt_module_node: {trt_module_node.name} does not have the metadata which should be set during dynamo compile_module step." |
484 | 485 | ) |
485 | 486 | num_outputs = len(trt_module_node.meta["val"]) |
486 | | - # Insert a call_function node to perform inference on TRT engine |
487 | 487 | with gm.graph.inserting_before(trt_module_node): |
488 | | - if cross_compile_module: |
489 | | - engine_info = trt_module._pack_engine_info() |
490 | | - engine_bytes = engine_info[ENGINE_IDX] |
491 | | - engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") |
492 | | - # insert the no_placeholder node in the graph which should be replaced to the actual execute_engine node while load in the windows |
493 | | - trt_node = gm.graph.call_function( |
494 | | - torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, |
495 | | - (trt_module_node.args, *engine_info), |
496 | | - ) |
497 | | - else: |
498 | | - # for the normal workflow: use the execute_engine node |
499 | | - engine_name = f"{name}_engine" |
500 | | - # TODO: THROWS SOME WARNING ABOUT A LACK OF UNDERLYING REFERENCE TO THE OWNING GRAPH MODULE |
501 | | - # SAYS THERES 3 OPTIONS, SUBMODULE, PARAMETER, OR BUFFER, BUFFER SEEMS THE BEST BUT I THINK ITS KEYED TO TENSORS |
502 | | - setattr(gm, engine_name, trt_module.engine) |
503 | | - engine_node = gm.graph.get_attr(engine_name) |
504 | | - |
505 | | - trt_node = gm.graph.call_function( |
506 | | - torch.ops.tensorrt.execute_engine.default, |
507 | | - (trt_module_node.args, engine_node), |
508 | | - ) |
509 | | - # meta["val"] should be a lighter version of a tensor. For eg: it should be a FakeTensor (with output shape and dtype properties) |
510 | | - # Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but |
511 | | - # for custom object nodes, it should be CustomObjArgument |
512 | | - engine_node.meta["val"] = CustomObjArgument( |
513 | | - name=engine_node.name, class_fqn="" |
514 | | - ) |
515 | | - # set trt_node.meta with trt_module_node.meta |
| 488 | + # Always embed engine data as primitive string args via no_op_placeholder |
| 489 | + # so torch.export does not pickle torch.classes.tensorrt.Engine (which |
| 490 | + # requires the C++ TorchBind class at load time). |
| 491 | + # torch_tensorrt.load() lowers placeholders → execute_engine. |
| 492 | + engine_info = trt_module._pack_engine_info() |
| 493 | + engine_bytes = engine_info[ENGINE_IDX] |
| 494 | + engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") |
| 495 | + trt_node = gm.graph.call_function( |
| 496 | + torch.ops.tensorrt.no_op_placeholder_for_execute_engine.default, |
| 497 | + (trt_module_node.args, *engine_info), |
| 498 | + ) |
516 | 499 | assert num_outputs > 0 |
517 | 500 | trt_node.meta["val"] = trt_module_node.meta["val"] |
518 | 501 |
|
@@ -557,7 +540,12 @@ def replace_execute_engine_no_op_node( |
557 | 540 | packed_engine_info[ENGINE_IDX] = base64.b64decode( |
558 | 541 | engine_bytes.encode("utf-8") |
559 | 542 | ) |
560 | | - trt_engine = torch.classes.tensorrt.Engine(tuple(packed_engine_info)) |
| 543 | + if ENABLED_FEATURES.torch_tensorrt_runtime: |
| 544 | + trt_engine = torch.classes.tensorrt.Engine(tuple(packed_engine_info)) |
| 545 | + else: |
| 546 | + from torch_tensorrt.dynamo.runtime._PythonTRTEngine import TRTEngine |
| 547 | + |
| 548 | + trt_engine = TRTEngine(packed_engine_info) |
561 | 549 | setattr(gm, engine_name, trt_engine) |
562 | 550 | engine_node = gm.graph.get_attr(engine_name) |
563 | 551 |
|
|
0 commit comments