Skip to content

Commit e33ca90

Browse files
committed
review comments
1 parent 805b2b9 commit e33ca90

2 files changed

Lines changed: 32 additions & 10 deletions

File tree

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,17 @@ def cross_compile_for_windows(
261261
if use_explicit_typing:
262262
if len(enabled_precisions) != 1 or not any(
263263
x in enabled_precisions
264-
for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
264+
for x in {
265+
torch.float32,
266+
dtype.f32,
267+
torch.float4_e2m1fn_x2,
268+
dtype.f4,
269+
torch.float8_e4m3fn,
270+
dtype.f8,
271+
}
265272
):
266273
raise AssertionError(
267-
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
274+
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4, dtype.f8). enabled_precisions should not be used when use_explicit_typing=True"
268275
)
269276

270277
if use_fp32_acc:
@@ -641,10 +648,17 @@ def compile(
641648
if use_explicit_typing:
642649
if len(enabled_precisions) != 1 or not any(
643650
x in enabled_precisions
644-
for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
651+
for x in {
652+
torch.float32,
653+
dtype.f32,
654+
torch.float4_e2m1fn_x2,
655+
dtype.f4,
656+
torch.float8_e4m3fn,
657+
dtype.f8,
658+
}
645659
):
646660
raise AssertionError(
647-
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
661+
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4, dtype.f8). enabled_precisions should not be used when use_explicit_typing=True"
648662
)
649663

650664
if autocast_low_precision_type is not None:
@@ -1310,10 +1324,17 @@ def convert_exported_program_to_serialized_trt_engine(
13101324
if use_explicit_typing:
13111325
if len(enabled_precisions) != 1 or not any(
13121326
x in enabled_precisions
1313-
for x in {torch.float32, dtype.f32, torch.float4_e2m1fn_x2, dtype.f4}
1327+
for x in {
1328+
torch.float32,
1329+
dtype.f32,
1330+
torch.float4_e2m1fn_x2,
1331+
dtype.f4,
1332+
torch.float8_e4m3fn,
1333+
dtype.f8,
1334+
}
13141335
):
13151336
raise AssertionError(
1316-
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True"
1337+
f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4, dtype.f8). enabled_precisions should not be used when use_explicit_typing=True"
13171338
)
13181339

13191340
if use_fp32_acc:
@@ -1447,18 +1468,20 @@ def convert_exported_program_to_serialized_trt_engine(
14471468
settings=settings,
14481469
engine_cache=engine_cache,
14491470
)
1450-
except UnsupportedOperatorException:
1471+
except UnsupportedOperatorException as e:
14511472
logger.error(
14521473
f"Conversion of module {gm} not currently fully supported or convertible!",
14531474
exc_info=True,
14541475
)
1455-
raise
1476+
raise UnsupportedOperatorException(
1477+
f"Conversion of module {gm} not currently fully supported or convertible!"
1478+
) from e
14561479
except Exception as e:
14571480
logger.error(
14581481
f"While interpreting the module got an error: {e}",
14591482
exc_info=True,
14601483
)
1461-
raise
1484+
raise RuntimeError(f"While interpreting the module got an error: {e}") from e
14621485

14631486
serialized_engine: bytes = interpreter_result.serialized_engine
14641487
return serialized_engine

tests/py/dynamo/models/test_models_export.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ def calibrate_loop(model):
386386
exp_program,
387387
inputs=[input_tensor],
388388
enabled_precisions={torch.float8_e4m3fn},
389-
use_explicit_typing=False,
390389
min_block_size=1,
391390
cache_built_engines=False,
392391
reuse_cached_engines=False,

0 commit comments

Comments
 (0)