Skip to content

Commit b1653c0

Browse files
committed
[tensorrt] [byoc] [plugin] allows save external data
1 parent dcd46ca commit b1653c0

3 files changed

Lines changed: 15 additions & 5 deletions

File tree

python/tvm/tpat/cuda/kernel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import os
19+
1820
import tvm
1921
import tvm.contrib.graph_executor as runtime
2022
import tvm.relay as relay
@@ -86,7 +88,7 @@ def run(self):
8688
mod, params = relay.frontend.from_onnx(self._config.onnx_model)
8789

8890
# 2. Tune it
89-
if self._enable_tunning:
91+
if self._enable_tunning and not os.path.exists(self._config.work_dir):
9092
tunning_option = self._config._tune_option()
9193
ms.relay_integration.tune_relay(mod=mod, params=params, **tunning_option)
9294

python/tvm/tpat/cuda/onnx_util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,14 @@ def _handle_trt_not_support_type(
9090
_remove_unnecessary_cast_nodes(graph)
9191

9292
try:
93-
onnx.save(gs.export_onnx(graph), output_model_path)
93+
onnx.save(gs.export_onnx(graph), output_model_path["name"])
9494
except:
95-
onnx.save(gs.export_onnx(graph), output_model_path, save_as_external_data=True)
95+
onnx.save(
96+
gs.export_onnx(graph),
97+
output_model_path["name"],
98+
save_as_external_data=True,
99+
location=output_model_path["weights"],
100+
)
96101

97102

98103
def _remove_unnecessary_cast_nodes(graph):

python/tvm/tpat/cuda/pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def pipeline(
119119
node_names: list[str],
120120
enable_tunning: bool,
121121
tunning_option: object,
122-
output_onnx: str,
122+
output_onnx: object,
123123
) -> Tuple[str, list[str]]:
124124
"""Generate plugins for specified nodes in an ONNX model.
125125
@@ -135,8 +135,11 @@ def pipeline(
135135
Flag indicating whether tunning is enabled.
136136
tunning_option : object
137137
Tunning option provided for ms.relay_integration.tune_relay, you don't need to specify mod, params and target.
138-
output_onnx : str
138+
output_onnx : object
139+
{ "name": xx, "weights": xx }
139140
Path to the output ONNX file where the modified model will be saved.
141+
It will firstly try to save without weights, if it fails, it will then
142+
save it with weights.
140143
141144
Returns
142145
-------

0 commit comments

Comments
 (0)