diff --git a/runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp b/runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp index 02155a4d9b4..27423af5ee0 100644 --- a/runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp +++ b/runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp @@ -28,6 +28,10 @@ std::array tensor_shape_to_c_string_impl( std::strcpy(p, kLimitExceededError); return out; } + if (shape.empty()) { + std::strcpy(p, "()"); + return out; + } *p++ = '('; for (const auto elem : shape) { if (elem < 0 || diff --git a/runtime/core/exec_aten/util/test/tensor_shape_to_c_string_test.cpp b/runtime/core/exec_aten/util/test/tensor_shape_to_c_string_test.cpp index 8059360afd9..7d23e412061 100644 --- a/runtime/core/exec_aten/util/test/tensor_shape_to_c_string_test.cpp +++ b/runtime/core/exec_aten/util/test/tensor_shape_to_c_string_test.cpp @@ -17,6 +17,12 @@ using executorch::runtime::Span; using executorch::runtime::tensor_shape_to_c_string; using executorch::runtime::internal::kMaximumPrintableTensorShapeElement; +TEST(TensorShapeToCStringTest, ScalarShape) { + Span scalar_shape; + auto str = tensor_shape_to_c_string(scalar_shape); + EXPECT_STREQ(str.data(), "()"); +} + TEST(TensorShapeToCStringTest, Basic) { std::array sizes = {123, 456, 789}; auto str = tensor_shape_to_c_string( @@ -29,6 +35,24 @@ TEST(TensorShapeToCStringTest, Basic) { EXPECT_STREQ(str.data(), "(1234567890)"); } +TEST(TensorShapeToCStringTest, RankOneShape) { + std::array sizes = {3}; + auto str = tensor_shape_to_c_string( + Span(sizes.data(), sizes.size())); + EXPECT_STREQ(str.data(), "(3)"); +} + +TEST(TensorShapeToCStringTest, InvalidNegativeDimension) { + std::array sizes = {-3}; + auto str = tensor_shape_to_c_string( + Span(sizes.data(), sizes.size())); + if constexpr (std::numeric_limits::is_signed) { + EXPECT_STREQ(str.data(), "(ERR)"); + } else { + EXPECT_EQ(str.data(), "(" + std::to_string(sizes[0]) + ")"); + } +} + TEST(TensorShapeToCStringTest, NegativeItems) { std::array sizes = {-1, -3, -2, 4}; auto str = tensor_shape_to_c_string( @@ -78,6 +102,23 @@ TEST(TensorShapeToCStringTest, MaximumLength) { EXPECT_EQ(expected_str, str.data()); } +TEST(TensorShapeToCStringTest, NearDimensionLimit) { + std::array sizes; + std::fill(sizes.begin(), sizes.end(), 3); + + auto str = tensor_shape_to_c_string( + Span(sizes.data(), sizes.size())); + + std::ostringstream expected; + expected << '(' << sizes[0]; + for (size_t ii = 1; ii < sizes.size(); ++ii) { + expected << ", " << sizes[ii]; + } + expected << ')'; + + EXPECT_EQ(expected.str(), str.data()); +} + TEST(TensorShapeToCStringTest, ExceedsDimensionLimit) { std::array sizes; std::fill(sizes.begin(), sizes.end(), kMaximumPrintableTensorShapeElement);