diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index 8aba03626..743530abd 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -11,7 +11,10 @@ load("//tensorflow_text:tftext.bzl", "tf_cc_library", "tflite_cc_library") licenses(["notice"]) # Visibility rules -package(default_visibility = ["//visibility:public"]) +package( + default_applicable_licenses = ["//tensorflow_text:license"], + default_visibility = ["//visibility:public"], +) exports_files(["LICENSE"]) @@ -347,7 +350,10 @@ cc_library( "darts_clone_trie_wrapper.h", ], deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow_text/core/kernels/darts_clone_trie_test.cc b/tensorflow_text/core/kernels/darts_clone_trie_test.cc index a80c28353..b800b0e5e 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_test.cc +++ b/tensorflow_text/core/kernels/darts_clone_trie_test.cc @@ -31,7 +31,7 @@ TEST(DartsCloneTrieTest, CreateCursorPointToRootAndTryTraverseOneStep) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; int data; @@ -56,7 +56,7 @@ TEST(DartsCloneTrieTest, CreateCursorAndTryTraverseSeveralSteps) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; int data; @@ -76,7 +76,7 @@ TEST(DartsCloneTrieTest, TraversePathNotExisted) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); 
ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; @@ -94,7 +94,7 @@ TEST(DartsCloneTrieTest, TraverseOnUtf8Path) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; int data; @@ -115,7 +115,7 @@ TEST(DartsCloneTrieTest, TraverseOnPartialUtf8Path) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; int data; @@ -135,7 +135,7 @@ TEST(DartsCloneTrieTest, TraverseOnUtf8PathNotExisted) { ASSERT_OK_AND_ASSIGN(std::vector trie_array, BuildDartsCloneTrie(vocab_tokens)); ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, - DartsCloneTrieWrapper::Create(trie_array.data())); + DartsCloneTrieWrapper::Create(trie_array)); DartsCloneTrieWrapper::TraversalCursor cursor; @@ -183,6 +183,20 @@ TEST(DartsCloneTrieBuildError, NegativeValues) { StatusIs(util::error::INVALID_ARGUMENT)); } +TEST(DartsCloneTrieTest, OutOfBoundsAccessIsRejected) { + std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; + ASSERT_OK_AND_ASSIGN(std::vector trie_array, + BuildDartsCloneTrie(vocab_tokens)); + // Wrap using a constrained span to emulate an out-of-bounds access attempt.
+ auto span = absl::MakeSpan(trie_array.data(), 1); + ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, + DartsCloneTrieWrapper::Create(span)); + + DartsCloneTrieWrapper::TraversalCursor cursor = + trie.CreateTraversalCursorPointToRoot(); + EXPECT_FALSE(trie.TryTraverseOneStep(cursor, 'd')); +} + } // namespace trie_utils } // namespace text } // namespace tensorflow diff --git a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h index 43067ec1b..aa80d66bd 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h +++ b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h @@ -30,7 +30,10 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" namespace tensorflow { namespace text { @@ -51,14 +54,14 @@ class DartsCloneTrieWrapper { uint32_t unit = 0; }; - // Constructs an instance by passing in the pointer to the trie array data. + // Constructs an instance by passing in the span of the trie array data. // The caller needs to make sure that 'trie_array' points to a valid structure // returned by darts_clone trie builder. The caller also needs to maintain the // availability of 'trie_array' throughout the lifetime of this instance. static absl::StatusOr Create( - const uint32_t* trie_array) { - if (trie_array == nullptr) { - return absl::InvalidArgumentError("trie_array is nullptr."); + absl::Span trie_array) { + if (trie_array.empty() || trie_array.data() == nullptr) { + return absl::InvalidArgumentError("trie_array is empty or nullptr."); } return DartsCloneTrieWrapper(trie_array); } @@ -70,13 +73,18 @@ class DartsCloneTrieWrapper { // Creates a cursor pointing to the 'node_id'. TraversalCursor CreateTraversalCursor(uint32_t node_id) { + if (node_id >= trie_array_.size()) { + return {0, 0}; + } return {node_id, trie_array_[node_id]}; } // Sets the cursor to point to 'node_id'. 
void SetTraversalCursor(TraversalCursor& cursor, uint32_t node_id) { - cursor.node_id = node_id; - cursor.unit = trie_array_[node_id]; + if (node_id < trie_array_.size()) { + cursor.node_id = node_id; + cursor.unit = trie_array_[node_id]; + } } // Traverses one step from 'cursor' following 'ch'. If successful (i.e., there @@ -84,6 +92,9 @@ class DartsCloneTrieWrapper { // Otherwise, does nothing (i.e., 'cursor' is not changed) and returns false. bool TryTraverseOneStep(TraversalCursor& cursor, unsigned char ch) const { const uint32_t next_node_id = cursor.node_id ^ offset(cursor.unit) ^ ch; + if (next_node_id >= trie_array_.size()) { + return false; + } const uint32_t next_node_unit = trie_array_[next_node_id]; if (label(next_node_unit) != ch) { return false; @@ -108,15 +119,18 @@ class DartsCloneTrieWrapper { if (!has_leaf(cursor.unit)) { return false; } - const uint32_t value_unit = - trie_array_[cursor.node_id ^ offset(cursor.unit)]; + const uint32_t value_node_id = cursor.node_id ^ offset(cursor.unit); + if (value_node_id >= trie_array_.size()) { + return false; + } + const uint32_t value_unit = trie_array_[value_node_id]; out_data = value(value_unit); return true; } private: // Use Create() instead of the constructor. - explicit DartsCloneTrieWrapper(const uint32_t* trie_array) + explicit DartsCloneTrieWrapper(absl::Span trie_array) : trie_array_(trie_array) {} // The actual implementation of TryTraverseSeveralSteps. @@ -127,6 +141,9 @@ class DartsCloneTrieWrapper { for (; size > 0; --size, ++ptr) { const unsigned char ch = static_cast(*ptr); cur_id ^= offset(cur_unit) ^ ch; + if (cur_id >= trie_array_.size()) { + return false; + } cur_unit = trie_array_[cur_id]; if (label(cur_unit) != ch) { return false; @@ -157,8 +174,8 @@ class DartsCloneTrieWrapper { return static_cast(unit & 0x7fffffff); } - // The pointer to the darts trie array. - const uint32_t* trie_array_; + // The darts trie array represented as a span for bounds awareness.
+ absl::Span trie_array_; }; } // namespace trie_utils diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer.h b/tensorflow_text/core/kernels/fast_bert_normalizer.h index efd5102bf..f3a8618e6 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer.h +++ b/tensorflow_text/core/kernels/fast_bert_normalizer.h @@ -82,7 +82,7 @@ class FastBertNormalizer { // which is not owned by this instance and should be kept alive through the // lifetime of the instance. static absl::StatusOr Create( - const uint32_t* trie_data, int data_for_codepoint_zero, + absl::Span trie_data, int data_for_codepoint_zero, const char* normalized_string_pool) { FastBertNormalizer result; SH_ASSIGN_OR_RETURN(auto trie, @@ -106,7 +106,9 @@ class FastBertNormalizer { // `GetFastBertNormalizerModel()` is autogenerated by flatbuffer. auto model = GetFastBertNormalizerModel(model_flatbuffer); return Create( - model->trie_array()->data(), model->data_for_codepoint_zero(), + absl::MakeSpan(model->trie_array()->data(), + model->trie_array()->size()), + model->data_for_codepoint_zero(), reinterpret_cast(model->normalized_string_pool()->data())); } diff --git a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc b/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc index 808d2a09c..cdcdb3470 100644 --- a/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc +++ b/tensorflow_text/core/kernels/fast_bert_normalizer_model_builder.cc @@ -229,7 +229,7 @@ FastBertNormalizerFactory::FastBertNormalizerFactory( return; } auto char_set_recognizer_mapper = FastBertNormalizer::Create( - trie_data_.data(), data_for_codepoint_zero_, mapped_value_pool_.data()); + trie_data_, data_for_codepoint_zero_, mapped_value_pool_.data()); if (!char_set_recognizer_mapper.ok()) { // Should never happen since the same code must have passed the unit tests. 
LOG(ERROR) << "Unexpected error: Failed to initialize " diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc index 88feee121..36b4585a8 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc @@ -49,7 +49,8 @@ FastWordpieceTokenizer::Create(const void* config_flatbuffer) { // `GetFastWordpieceTokenizerConfig()` is autogenerated by flatbuffer. tokenizer.config_ = GetFastWordpieceTokenizerConfig(config_flatbuffer); auto trie_or = trie_utils::DartsCloneTrieWrapper::Create( - tokenizer.config_->trie_array()->data()); + absl::MakeSpan(tokenizer.config_->trie_array()->data(), + tokenizer.config_->trie_array()->size())); if (!trie_or.ok()) { return absl::InvalidArgumentError( "Failed to create DartsCloneTrieWrapper from " diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc index 9467c4d6e..3003b232e 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_model_builder.cc @@ -434,7 +434,7 @@ absl::Status FastWordpieceBuilder::ConstructTrie( trie_utils::BuildDartsCloneTrie(keys, values)); SH_ASSIGN_OR_RETURN( trie_utils::DartsCloneTrieWrapper trie, - trie_utils::DartsCloneTrieWrapper::Create(trie_array_.data())); + trie_utils::DartsCloneTrieWrapper::Create(trie_array_)); trie_.emplace(std::move(trie)); if (trie_array_.size() > diff --git a/tensorflow_text/python/ops/test_data/fast_bert_normalizer_model_lower_case_nfd_strip_accents.fb b/tensorflow_text/python/ops/test_data/fast_bert_normalizer_model_lower_case_nfd_strip_accents.fb index 04186b144..38e7a183f 100644 Binary files a/tensorflow_text/python/ops/test_data/fast_bert_normalizer_model_lower_case_nfd_strip_accents.fb and 
b/tensorflow_text/python/ops/test_data/fast_bert_normalizer_model_lower_case_nfd_strip_accents.fb differ