Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions tensorflow_serving/servables/tensorflow/tflite_interpreter_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,33 @@ absl::Status TfLiteInterpreterWrapper::SetStringData(
// [4] offset of each string (int32_t)
// [sizeof(int32_t) * (num_strings + 1)]] total size of strings
// [sizeof(int32_t) * (num_strings + 2)] batch.data()
int32_t num_strings = batch_size;
offset_.clear();
(void)batch_size;
std::vector<size_t> offsets;
size_t total_size = 0;
offset_.push_back(static_cast<int32_t>(total_size));
offsets.push_back(total_size);
for (const auto& tensor : tensors) {
const auto& flat = tensor->flat<tstring>();
for (int i = 0; i < flat.size(); ++i) {
if (flat(i).size() > std::numeric_limits<size_t>::max() - total_size) {
return absl::InternalError("String input is too large.");
}
total_size += flat(i).size();
offset_.push_back(static_cast<int32_t>(total_size));
offsets.push_back(total_size);
}
}
size_t required_bytes = total_size + sizeof(int32_t) * (num_strings + 2);
const size_t num_strings = offsets.size() - 1;
if (num_strings > std::numeric_limits<int32_t>::max()) {
return absl::InternalError("Too many string inputs.");
}
const size_t header_entries = num_strings + 2;
if (header_entries > std::numeric_limits<size_t>::max() / sizeof(int32_t)) {
return absl::InternalError("String input header is too large.");
}
const size_t header_bytes = sizeof(int32_t) * header_entries;
if (total_size > std::numeric_limits<size_t>::max() - header_bytes) {
return absl::InternalError("String input buffer is too large.");
}
size_t required_bytes = total_size + header_bytes;
if (tensor_buffer_.find(tensor_index) == tensor_buffer_.end()) {
return absl::InternalError(
absl::StrCat("Tensor input for index not found: ", tensor_index));
Expand All @@ -87,13 +102,18 @@ absl::Status TfLiteInterpreterWrapper::SetStringData(
free(tflite_tensor->data.raw);
}
tflite_tensor->data.raw = reinterpret_cast<char*>(malloc(required_bytes));
if (tflite_tensor->data.raw == nullptr) {
return absl::ResourceExhaustedError("Failed to allocate string input.");
}
tensor_buffer_max_bytes_[tensor_index] = required_bytes;
}
tensor_buffer_[tensor_index].reset(tflite_tensor->data.raw);
memcpy(tensor_buffer_[tensor_index].get(), &num_strings, sizeof(int32_t));
int32_t start = sizeof(int32_t) * (num_strings + 2);
for (size_t i = 0; i < offset_.size(); i++) {
size_t size_offset_i = start + offset_[i];
const int32_t num_strings_i32 = static_cast<int32_t>(num_strings);
memcpy(tensor_buffer_[tensor_index].get(), &num_strings_i32,
sizeof(int32_t));
size_t start = header_bytes;
for (size_t i = 0; i < offsets.size(); i++) {
size_t size_offset_i = start + offsets[i];
if (size_offset_i > std::numeric_limits<int32_t>::max()) {
return absl::InternalError(
absl::StrCat("Invalid size, string input too large:", size_offset_i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class TfLiteInterpreterWrapper {
int batch_size_ = 1;
std::map<int, std::unique_ptr<char>> tensor_buffer_;
std::map<int, size_t> tensor_buffer_max_bytes_;
std::vector<int32_t> offset_;
#ifdef TFLITE_PROFILE
int max_num_entries_;
tflite::profiling::ProfileSummarizer run_summarizer_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,43 @@ TEST(TfLiteInterpreterWrapper, TfLiteInterpreterWrapperTest) {
::testing::ElementsAreArray(expected_strs));
}

TEST(TfLiteInterpreterWrapper, SetStringDataUsesFlattenedStringCount) {
std::string model_bytes;
TF_ASSERT_OK(ReadFileToString(Env::Default(),
test_util::TestSrcDirPath(kParseExampleModel),
&model_bytes));
auto model = tflite::FlatBufferModel::BuildFromModel(
flatbuffers::GetRoot<tflite::Model>(model_bytes.data()));
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::ops::custom::AddParseExampleOp(&resolver);
std::unique_ptr<tflite::Interpreter> interpreter;
ASSERT_EQ(tflite::InterpreterBuilder(*model, resolver)(&interpreter,
/*num_threads=*/1),
kTfLiteOk);
ASSERT_EQ(interpreter->inputs().size(), 1);
const int idx = interpreter->inputs()[0];
auto* tensor = interpreter->tensor(idx);
ASSERT_EQ(tensor->type, kTfLiteString);
ASSERT_EQ(interpreter->ResizeInputTensor(idx, {1}), kTfLiteOk);
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);

auto interpreter_wrapper =
std::make_unique<TfLiteInterpreterWrapper>(std::move(interpreter));

Tensor input(DT_STRING, TensorShape({1, 2}));
input.flat<tstring>()(0) = "first";
input.flat<tstring>()(1) = "second";
std::vector<const Tensor*> data = {&input};

auto* wrapped = interpreter_wrapper->Get();
tensor = wrapped->tensor(idx);
TF_ASSERT_OK(interpreter_wrapper->SetStringData(data, tensor, idx,
input.dim_size(0)));

const auto strings = ExtractVector<std::string>(wrapped->tensor(idx));
EXPECT_THAT(strings, ::testing::ElementsAre("first", "second"));
}

} // namespace internal
} // namespace serving
} // namespace tensorflow