Skip to content

Commit bf6e60e

Browse files
jckingcopybara-github
authored andcommitted
Implement optimized string and bytes concatenation
PiperOrigin-RevId: 738849121
1 parent e249db7 commit bf6e60e

7 files changed

Lines changed: 178 additions & 22 deletions

File tree

common/internal/byte_string.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,51 @@ T ConsumeAndDestroy(T& object) {
5656

5757
} // namespace
5858

59+
ByteString ByteString::Concat(const ByteString& lhs, const ByteString& rhs,
60+
absl::Nonnull<google::protobuf::Arena*> arena) {
61+
ABSL_DCHECK(arena != nullptr);
62+
63+
if (lhs.empty()) {
64+
return rhs;
65+
}
66+
if (rhs.empty()) {
67+
return lhs;
68+
}
69+
70+
if (lhs.GetKind() == ByteStringKind::kLarge ||
71+
rhs.GetKind() == ByteStringKind::kLarge) {
72+
// If either the left or right are absl::Cord, use absl::Cord.
73+
absl::Cord result;
74+
result.Append(lhs.ToCord());
75+
result.Append(rhs.ToCord());
76+
return ByteString(std::move(result));
77+
}
78+
79+
const size_t lhs_size = lhs.size();
80+
const size_t rhs_size = rhs.size();
81+
const size_t result_size = lhs_size + rhs_size;
82+
ByteString result;
83+
if (result_size <= kSmallByteStringCapacity) {
84+
// If the resulting string fits in inline storage, do it.
85+
result.rep_.small.size = result_size;
86+
result.rep_.small.arena = arena;
87+
lhs.CopyToArray(result.rep_.small.data);
88+
rhs.CopyToArray(result.rep_.small.data + lhs_size);
89+
} else {
90+
// Otherwise allocate on the arena.
91+
char* result_data =
92+
reinterpret_cast<char*>(arena->AllocateAligned(result_size));
93+
lhs.CopyToArray(result_data);
94+
rhs.CopyToArray(result_data + lhs_size);
95+
result.rep_.medium.data = result_data;
96+
result.rep_.medium.size = result_size;
97+
result.rep_.medium.owner =
98+
reinterpret_cast<uintptr_t>(arena) | kMetadataOwnerArenaBit;
99+
result.rep_.medium.kind = ByteStringKind::kMedium;
100+
}
101+
return result;
102+
}
103+
59104
ByteString::ByteString(Allocator<> allocator, absl::string_view string) {
60105
ABSL_DCHECK_LE(string.size(), max_size());
61106
auto* arena = allocator.arena();
@@ -249,6 +294,25 @@ void ByteString::RemoveSuffix(size_t n) {
249294
}
250295
}
251296

297+
void ByteString::CopyToArray(absl::Nonnull<char*> out) const {
298+
ABSL_DCHECK(out != nullptr);
299+
300+
switch (GetKind()) {
301+
case ByteStringKind::kSmall: {
302+
absl::string_view small = GetSmall();
303+
std::memcpy(out, small.data(), small.size());
304+
} break;
305+
case ByteStringKind::kMedium: {
306+
absl::string_view medium = GetMedium();
307+
std::memcpy(out, medium.data(), medium.size());
308+
} break;
309+
case ByteStringKind::kLarge: {
310+
const absl::Cord& large = GetLarge();
311+
(CopyCordToArray)(large, out);
312+
} break;
313+
}
314+
}
315+
252316
std::string ByteString::ToString() const {
253317
switch (GetKind()) {
254318
case ByteStringKind::kSmall:

common/internal/byte_string.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ namespace cel {
4343

4444
class BytesValueInputStream;
4545
class BytesValueOutputStream;
46+
class StringValue;
4647

4748
namespace common_internal {
4849

@@ -171,6 +172,9 @@ absl::string_view LegacyByteString(const ByteString& string, bool stable,
171172
class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]]
172173
ByteString final {
173174
public:
175+
static ByteString Concat(const ByteString& lhs, const ByteString& rhs,
176+
absl::Nonnull<google::protobuf::Arena*> arena);
177+
174178
ByteString() : ByteString(NewDeleteAllocator()) {}
175179

176180
explicit ByteString(absl::Nullable<const char*> string)
@@ -333,6 +337,7 @@ ByteString final {
333337
friend struct ByteStringTestFriend;
334338
friend class cel::BytesValueInputStream;
335339
friend class cel::BytesValueOutputStream;
340+
friend class cel::StringValue;
336341
friend absl::string_view LegacyByteString(
337342
const ByteString& string, bool stable,
338343
absl::Nonnull<google::protobuf::Arena*> arena);
@@ -475,6 +480,8 @@ ByteString final {
475480

476481
static void DestroyLarge(LargeByteStringRep& rep) { GetLarge(rep).~Cord(); }
477482

483+
void CopyToArray(absl::Nonnull<char*> out) const;
484+
478485
ByteStringRep rep_;
479486
};
480487

common/values/bytes_value.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "absl/strings/cord.h"
2424
#include "absl/strings/str_cat.h"
2525
#include "absl/strings/string_view.h"
26+
#include "common/internal/byte_string.h"
2627
#include "common/value.h"
2728
#include "internal/status_macros.h"
2829
#include "internal/strings.h"
@@ -53,6 +54,12 @@ std::string BytesDebugString(const Bytes& value) {
5354

5455
} // namespace
5556

57+
BytesValue BytesValue::Concat(const BytesValue& lhs, const BytesValue& rhs,
58+
absl::Nonnull<google::protobuf::Arena*> arena) {
59+
return BytesValue(
60+
common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena));
61+
}
62+
5663
std::string BytesValue::DebugString() const { return BytesDebugString(*this); }
5764

5865
absl::Status BytesValue::SerializeTo(

common/values/bytes_value.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class BytesValue final : private common_internal::ValueMixin<BytesValue> {
8181
absl::Nullable<google::protobuf::Arena*> arena
8282
ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete;
8383

84+
static BytesValue Concat(const BytesValue& lhs, const BytesValue& rhs,
85+
absl::Nonnull<google::protobuf::Arena*> arena
86+
ABSL_ATTRIBUTE_LIFETIME_BOUND);
87+
8488
ABSL_DEPRECATED("Use From")
8589
explicit BytesValue(absl::Nullable<const char*> value) : value_(value) {}
8690

common/values/string_value.cc

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
// limitations under the License.
1414

1515
#include <cstddef>
16+
#include <cstring>
1617
#include <string>
17-
#include <utility>
1818

1919
#include "google/protobuf/wrappers.pb.h"
2020
#include "absl/base/nullability.h"
2121
#include "absl/functional/overload.h"
2222
#include "absl/log/absl_check.h"
2323
#include "absl/status/status.h"
2424
#include "absl/strings/cord.h"
25+
#include "absl/strings/match.h"
2526
#include "absl/strings/str_cat.h"
2627
#include "absl/strings/string_view.h"
28+
#include "common/internal/byte_string.h"
2729
#include "common/value.h"
2830
#include "internal/status_macros.h"
2931
#include "internal/strings.h"
@@ -57,19 +59,8 @@ std::string StringDebugString(const Bytes& value) {
5759

5860
StringValue StringValue::Concat(const StringValue& lhs, const StringValue& rhs,
5961
absl::Nonnull<google::protobuf::Arena*> arena) {
60-
ABSL_DCHECK(arena != nullptr);
61-
62-
if (lhs.IsEmpty()) {
63-
return rhs;
64-
}
65-
if (rhs.IsEmpty()) {
66-
return lhs;
67-
}
68-
69-
absl::Cord result;
70-
result.Append(lhs.ToCord());
71-
result.Append(rhs.ToCord());
72-
return StringValue(std::move(result));
62+
return StringValue(
63+
common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena));
7364
}
7465

7566
std::string StringValue::DebugString() const {
@@ -204,4 +195,79 @@ int StringValue::Compare(const StringValue& string) const {
204195
[this](const auto& alternative) -> int { return Compare(alternative); });
205196
}
206197

198+
bool StringValue::StartsWith(absl::string_view string) const {
199+
return value_.Visit(absl::Overload(
200+
[&](absl::string_view lhs) -> bool {
201+
return absl::StartsWith(lhs, string);
202+
},
203+
[&](const absl::Cord& lhs) -> bool { return lhs.StartsWith(string); }));
204+
}
205+
206+
bool StringValue::StartsWith(const absl::Cord& string) const {
207+
return value_.Visit(absl::Overload(
208+
[&](absl::string_view lhs) -> bool {
209+
return lhs.size() >= string.size() &&
210+
lhs == string.Subcord(0, lhs.size());
211+
},
212+
[&](const absl::Cord& lhs) -> bool { return lhs.StartsWith(string); }));
213+
}
214+
215+
bool StringValue::StartsWith(const StringValue& string) const {
216+
return string.value_.Visit(absl::Overload(
217+
[&](absl::string_view rhs) -> bool { return StartsWith(rhs); },
218+
[&](const absl::Cord& rhs) -> bool { return StartsWith(rhs); }));
219+
}
220+
221+
bool StringValue::EndsWith(absl::string_view string) const {
222+
return value_.Visit(absl::Overload(
223+
[&](absl::string_view lhs) -> bool {
224+
return absl::EndsWith(lhs, string);
225+
},
226+
[&](const absl::Cord& lhs) -> bool { return lhs.EndsWith(string); }));
227+
}
228+
229+
bool StringValue::EndsWith(const absl::Cord& string) const {
230+
return value_.Visit(absl::Overload(
231+
[&](absl::string_view lhs) -> bool {
232+
return lhs.size() >= string.size() &&
233+
lhs == string.Subcord(string.size() - lhs.size(), lhs.size());
234+
},
235+
[&](const absl::Cord& lhs) -> bool { return lhs.EndsWith(string); }));
236+
}
237+
238+
bool StringValue::EndsWith(const StringValue& string) const {
239+
return string.value_.Visit(absl::Overload(
240+
[&](absl::string_view rhs) -> bool { return EndsWith(rhs); },
241+
[&](const absl::Cord& rhs) -> bool { return EndsWith(rhs); }));
242+
}
243+
244+
bool StringValue::Contains(absl::string_view string) const {
245+
return value_.Visit(absl::Overload(
246+
[&](absl::string_view lhs) -> bool {
247+
return absl::StrContains(lhs, string);
248+
},
249+
[&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
250+
}
251+
252+
bool StringValue::Contains(const absl::Cord& string) const {
253+
return value_.Visit(absl::Overload(
254+
[&](absl::string_view lhs) -> bool {
255+
if (auto flat = string.TryFlat(); flat) {
256+
return absl::StrContains(lhs, *flat);
257+
}
258+
// There is no nice way to do this. We cannot use std::search due to
259+
// absl::Cord::CharIterator being an input iterator instead of a forward
260+
// iterator. So just make an external cord with a noop releaser. We know
261+
// the external cord will not outlive this function.
262+
return absl::MakeCordFromExternal(lhs, []() {}).Contains(string);
263+
},
264+
[&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
265+
}
266+
267+
bool StringValue::Contains(const StringValue& string) const {
268+
return string.value_.Visit(absl::Overload(
269+
[&](absl::string_view rhs) -> bool { return Contains(rhs); },
270+
[&](const absl::Cord& rhs) -> bool { return Contains(rhs); }));
271+
}
272+
207273
} // namespace cel

common/values/string_value.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,18 @@ class StringValue final : private common_internal::ValueMixin<StringValue> {
196196
int Compare(const absl::Cord& string) const;
197197
int Compare(const StringValue& string) const;
198198

199+
bool StartsWith(absl::string_view string) const;
200+
bool StartsWith(const absl::Cord& string) const;
201+
bool StartsWith(const StringValue& string) const;
202+
203+
bool EndsWith(absl::string_view string) const;
204+
bool EndsWith(const absl::Cord& string) const;
205+
bool EndsWith(const StringValue& string) const;
206+
207+
bool Contains(absl::string_view string) const;
208+
bool Contains(const absl::Cord& string) const;
209+
bool Contains(const StringValue& string) const;
210+
199211
absl::optional<absl::string_view> TryFlat() const
200212
ABSL_ATTRIBUTE_LIFETIME_BOUND {
201213
return value_.TryFlat();

runtime/standard/string_functions.cc

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "absl/base/nullability.h"
2020
#include "absl/status/status.h"
2121
#include "absl/status/statusor.h"
22-
#include "absl/strings/match.h"
2322
#include "absl/strings/str_cat.h"
2423
#include "absl/strings/string_view.h"
2524
#include "base/builtins.h"
@@ -41,10 +40,7 @@ absl::StatusOr<StringValue> ConcatString(
4140
absl::Nonnull<const google::protobuf::DescriptorPool*>,
4241
absl::Nonnull<google::protobuf::MessageFactory*>,
4342
absl::Nonnull<google::protobuf::Arena*> arena) {
44-
// TODO: use StringValue::Concat when remaining interop usages
45-
// removed. Modern concat implementation forces additional copies when
46-
// converting to legacy string values.
47-
return StringValue(arena, absl::StrCat(value1.ToString(), value2.ToString()));
43+
return StringValue::Concat(value1, value2, arena);
4844
}
4945

5046
// Concatenation for bytes type.
@@ -60,15 +56,15 @@ absl::StatusOr<BytesValue> ConcatBytes(
6056
}
6157

6258
bool StringContains(const StringValue& value, const StringValue& substr) {
63-
return absl::StrContains(value.ToString(), substr.ToString());
59+
return value.Contains(substr);
6460
}
6561

6662
bool StringEndsWith(const StringValue& value, const StringValue& suffix) {
67-
return absl::EndsWith(value.ToString(), suffix.ToString());
63+
return value.EndsWith(suffix);
6864
}
6965

7066
bool StringStartsWith(const StringValue& value, const StringValue& prefix) {
71-
return absl::StartsWith(value.ToString(), prefix.ToString());
67+
return value.StartsWith(prefix);
7268
}
7369

7470
absl::Status RegisterSizeFunctions(FunctionRegistry& registry) {

0 commit comments

Comments
 (0)