diff --git a/CMakeLists.txt b/CMakeLists.txt index d28c3bd289c..7cd4e6d7ce2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -216,6 +216,9 @@ set(ENABLE_TPROXY anything else to enable." ) option(ENABLE_QUICHE "Use quiche (default OFF)") +option(ENABLE_HIGHWAY_DISPATCH + "Use Google Highway runtime dispatch for SIMD kernels (default OFF; without it the scalar paths are used)" +) option(ENABLE_EXAMPLE "Build example directory (default OFF)") set(TS_MAX_HOST_NAME_LEN @@ -447,9 +450,14 @@ if(EXTERNAL_LIBSWOC) find_package(libswoc REQUIRED) endif() -if(EXTERNAL_HWY) - message(STATUS "Looking for external highway") - find_package(HWY "1.4.0" CONFIG REQUIRED) +if(ENABLE_HIGHWAY_DISPATCH) + if(EXTERNAL_HWY) + message(STATUS "Looking for external highway") + find_package(HWY "1.4.0" CONFIG REQUIRED) + set(TS_HAS_HIGHWAY_DISPATCH ${HWY_FOUND}) + else() + set(TS_HAS_HIGHWAY_DISPATCH 1) + endif() endif() include(Check128BitCas) diff --git a/CMakePresets.json b/CMakePresets.json index cec863644f3..08e316b316b 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -280,6 +280,15 @@ "CMAKE_BUILD_TYPE": "Release" } }, + { + "name": "branch-highway", + "displayName": "CI branch with Highway SIMD dispatch", + "description": "Builds the Highway runtime-dispatched SIMD kernels (base64, ASCII tolower) so their unit tests run as real SIMD-vs-scalar parity checks instead of scalar-vs-scalar.", + "inherits": ["branch"], + "cacheVariables": { + "ENABLE_HIGHWAY_DISPATCH": "ON" + } + }, { "name": "branch-debug", "displayName": "CI branch debug", diff --git a/NOTICE b/NOTICE index e1537312911..7fb1e7a6358 100644 --- a/NOTICE +++ b/NOTICE @@ -133,3 +133,19 @@ Reference implementation: https://github.com/Thesys-lab/sosp23-s3fifo Highway is a C++ library that provides portable SIMD/vector intrinsics. 
https://github.com/google/highway + +~~ + +src/tscore/ink_ascii_tolower_dispatch.cc AVX-512BW kernel design (fused +mask_add and masked-tail load/store) is adapted from Tony Finch's +copytolower64.c (0BSD OR MIT-0). +https://dotat.at/cgi/git/vectolower.git/ + +~~ + +src/tscore/ink_base64_dispatch.cc contains SIMD base64 encode/decode kernels +derived from the simdutf library, re-expressed using Google Highway. The +algorithms (Wojciech Muła and Daniel Lemire's vectorized base64, and aqrit's +combined standard/URL-safe classifier) and lookup tables originate there. +simdutf: https://github.com/simdutf/simdutf (Apache-2.0 / MIT / BSL-1.0) +Copyright (c) 2021 The simdutf authors diff --git a/include/tscore/ink_ascii_tolower.h b/include/tscore/ink_ascii_tolower.h new file mode 100644 index 00000000000..8d69121796b --- /dev/null +++ b/include/tscore/ink_ascii_tolower.h @@ -0,0 +1,95 @@ +/** @file + + Bulk ASCII tolower copy. + + Used on the URL canonicalization fast path for cache-key digests + (src/proxy/hdrs/URL.cc::url_CryptoHash_get_fast) and any other place that + needs to fold ASCII to lowercase over a small-to-moderate buffer. + + Semantics match a byte-at-a-time loop using ParseRules::ink_tolower(): + + - Bytes in 'A'..'Z' (0x41..0x5A) are folded to 'a'..'z' (bit 5 set). All + other bytes (including 0x80..0xFF) pass through unchanged. There is no + UTF-8 case folding. + + - In-place use (dst == src) is supported. Partial overlap where dst != src + but the ranges intersect is not supported. + + Two implementations, gated by the ENABLE_HIGHWAY_DISPATCH CMake option + (TS_HAS_HIGHWAY_DISPATCH at compile time): + + - ON: the runtime-dispatched SIMD implementation in + src/tscore/ink_ascii_tolower_dispatch.cc, built against Google Highway. + The highest SIMD target supported by the live CPU is selected once and + cached, so a conservatively compiled binary still runs the widest body + its CPU supports. This is the canonical accelerated path. 
+ + - OFF (default): a portable scalar loop, which the compiler auto-vectorizes + for the build's target. No hand-written intrinsics. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +#pragma once + +#include +#include + +#include "tscore/ink_config.h" + +namespace ts::ascii +{ + +#if TS_HAS_HIGHWAY_DISPATCH + +// Out-of-line, runtime-dispatched SIMD implementation. Defined in +// src/tscore/ink_ascii_tolower_dispatch.cc via Highway HWY_EXPORT. +void tolower_copy_dispatch(char *dst, const char *src, std::size_t n) noexcept; + +inline void +tolower_copy(char *dst, const char *src, std::size_t n) noexcept +{ + tolower_copy_dispatch(dst, src, n); +} + +#else // !TS_HAS_HIGHWAY_DISPATCH — portable scalar fallback + +inline void +tolower_copy(char *dst, const char *src, std::size_t n) noexcept +{ + // The unsigned (c - 'A') < 26 test is true only for 'A'..'Z'; every other + // byte (including 0x80..0xFF) wraps to >= 26 and passes through unchanged. + // The compiler auto-vectorizes this loop for the build's target. + for (std::size_t i = 0; i < n; ++i) { + auto c = static_cast(src[i]); + dst[i] = (static_cast(c - 'A') < 26) ? 
static_cast(c | 0x20) : static_cast(c); + } +} + +#endif // TS_HAS_HIGHWAY_DISPATCH + +// Thin sugar over tolower_copy for the in-place case. Makes call sites like +// ts::ascii::tolower_inplace(buf, n) read naturally instead of +// ts::ascii::tolower_copy(buf, buf, n). +inline void +tolower_inplace(char *buf, std::size_t n) noexcept +{ + tolower_copy(buf, buf, n); +} + +} // namespace ts::ascii diff --git a/include/tscore/ink_config.h.cmake.in b/include/tscore/ink_config.h.cmake.in index 73c8b860fb9..9867f06a56e 100644 --- a/include/tscore/ink_config.h.cmake.in +++ b/include/tscore/ink_config.h.cmake.in @@ -140,6 +140,7 @@ const int DEFAULT_STACKSIZE = @DEFAULT_STACK_SIZE@; /* Feature Flags */ #cmakedefine01 TS_HAS_128BIT_CAS #cmakedefine01 TS_HAS_BACKTRACE +#cmakedefine01 TS_HAS_HIGHWAY_DISPATCH #cmakedefine01 TS_HAS_IN6_IS_ADDR_UNSPECIFIED #cmakedefine01 TS_HAS_IP_TOS #cmakedefine01 TS_HAS_JEMALLOC diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 17771e4ee5f..68e5b6eb98c 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -93,7 +93,7 @@ add_subdirectory(ls-hpack) # CMakeLists.txt slightly modified to fully disable # building tests if `HWY_ENABLE_TESTS=OFF`, we intend # to upstream this change -if(NOT EXTERNAL_HWY) +if(ENABLE_HIGHWAY_DISPATCH AND NOT EXTERNAL_HWY) message(STATUS "Using internal highway") set(HWY_FORCE_STATIC_LIBS ON diff --git a/src/proxy/hdrs/URL.cc b/src/proxy/hdrs/URL.cc index 68d84b5d481..1c9eb4170b2 100644 --- a/src/proxy/hdrs/URL.cc +++ b/src/proxy/hdrs/URL.cc @@ -25,6 +25,7 @@ #include #include "tscore/ink_platform.h" #include "tscore/ink_memory.h" +#include "tscore/ink_ascii_tolower.h" #include "proxy/hdrs/URL.h" #include "proxy/hdrs/MIME.h" #include "proxy/hdrs/HTTP.h" @@ -1684,16 +1685,6 @@ url_describe(HdrHeapObjImpl *raw, bool /* recurse ATS_UNUSED */) * * ***********************************************************************/ -static inline void -memcpy_tolower(char *d, const char *s, int n) -{ - while (n--) { - *d 
= ParseRules::ink_tolower(*s); - s++; - d++; - } -} - // fast path for CryptoHash, HTTP, no user/password/params/query, // no buffer overflow, no unescaping needed @@ -1704,7 +1695,7 @@ url_CryptoHash_get_fast(const URLImpl *url, CryptoContext &ctx, CryptoHash *hash char *p; p = buffer; - memcpy_tolower(p, url->m_ptr_scheme, url->m_len_scheme); + ts::ascii::tolower_copy(p, url->m_ptr_scheme, url->m_len_scheme); p += url->m_len_scheme; *p++ = ':'; *p++ = '/'; @@ -1713,7 +1704,7 @@ url_CryptoHash_get_fast(const URLImpl *url, CryptoContext &ctx, CryptoHash *hash *p++ = ':'; // no password *p++ = '@'; - memcpy_tolower(p, url->m_ptr_host, url->m_len_host); + ts::ascii::tolower_copy(p, url->m_ptr_host, url->m_len_host); p += url->m_len_host; *p++ = '/'; memcpy(p, url->m_ptr_path, url->m_len_path); diff --git a/src/proxy/hdrs/unit_tests/test_URL.cc b/src/proxy/hdrs/unit_tests/test_URL.cc index dc5ff4ade74..8dda4e9501a 100644 --- a/src/proxy/hdrs/unit_tests/test_URL.cc +++ b/src/proxy/hdrs/unit_tests/test_URL.cc @@ -659,6 +659,39 @@ std::vector get_hash_test_cases = { !IGNORE_QUERY, HAS_EQUAL_HASH, }, + { + // Uppercase host must hash identically to its lowercased form. NOTE(review): + // the host is 15 bytes (under one 16-byte vector) and the URL has a query, + // which the url_CryptoHash_get_fast fast path excludes; confirm the path hit. + "Uppercase host: equal hashes", + "http://ONE.EXAMPLE.COM/a/path?name=value", + "http://one.example.com/a/path?name=value", + !IGNORE_QUERY, + HAS_EQUAL_HASH, + }, + { + "Mixed-case host: equal hashes", + "http://One.Example.Com/a/path?name=value", + "http://one.example.com/a/path?name=value", + !IGNORE_QUERY, + HAS_EQUAL_HASH, + }, + { + "Uppercase scheme: equal hashes", + "HTTP://one.example.com/a/path?name=value", + "http://one.example.com/a/path?name=value", + !IGNORE_QUERY, + HAS_EQUAL_HASH, + }, + { + // Long uppercase host crosses 16- and 32-byte SIMD body boundaries so + // the wider paths (when compiled in) are exercised by this fixture. 
+ "Long uppercase host: equal hashes", + "http://A-VERY-LONG-HOST-NAME-FOR-SIMD.EXAMPLE.COM/a/path", + "http://a-very-long-host-name-for-simd.example.com/a/path", + !IGNORE_QUERY, + HAS_EQUAL_HASH, + }, }; /** Return the hash related to a URI. diff --git a/src/proxy/http/remap/UrlRewrite.cc b/src/proxy/http/remap/UrlRewrite.cc index fbda217462b..d2af3b343e9 100644 --- a/src/proxy/http/remap/UrlRewrite.cc +++ b/src/proxy/http/remap/UrlRewrite.cc @@ -22,6 +22,8 @@ */ +#include "tscore/ink_ascii_tolower.h" + #include "proxy/http/remap/UrlRewrite.h" #include "proxy/http/remap/RemapYamlConfig.h" #include "iocore/eventsystem/ConfigProcessor.h" @@ -931,10 +933,7 @@ UrlRewrite::_mappingLookup(MappingsStore &mappings, URL *request_url, int reques return false; } - // lowercase - for (int i = 0; i < request_host_len; ++i) { - request_host_lower[i] = tolower(request_host[i]); - } + ts::ascii::tolower_copy(request_host_lower, request_host, request_host_len); request_host_lower[request_host_len] = 0; bool retval = false; diff --git a/src/proxy/http/remap/unit-tests/test_RemapRules.cc b/src/proxy/http/remap/unit-tests/test_RemapRules.cc index f51e39d0dd6..35ae9c5e8c7 100644 --- a/src/proxy/http/remap/unit-tests/test_RemapRules.cc +++ b/src/proxy/http/remap/unit-tests/test_RemapRules.cc @@ -269,3 +269,58 @@ map_with_recv_port http://front.example.com \ } } } + +SCENARIO("UrlRewrite host lookup is case-insensitive", "[proxy][remap]") +{ + // _mappingLookup lower-cases the request host before consulting the hash + // table; these scenarios exercise that path with inputs that would not + // match in a strict byte-compare. Sized to cross the 16-byte SSE2 body + // for hosts that get a real SIMD pass. 
+ GIVEN("A forward map with a lowercase source host") + { + auto urlrw = std::make_unique(); + std::string config = R"RMCFG( +map http://www.example.com http://origin.example.com + )RMCFG"; + + auto cpath = write_test_remap(config, "case_insensitive"); + int rc = urlrw->BuildTable(cpath.c_str()); + REQUIRE(rc == TS_SUCCESS); + REQUIRE(urlrw->rule_count() == 1); + + EasyURL url("http://www.example.com"); + UrlMappingContainer urlmap; + + THEN("uppercase request host matches the lowercase rule") + { + const char *host = "WWW.EXAMPLE.COM"; + REQUIRE(urlrw->forwardMappingLookup(&url.url, 80, host, strlen(host), urlmap)); + } + THEN("mixed-case request host matches the lowercase rule") + { + const char *host = "Www.Example.Com"; + REQUIRE(urlrw->forwardMappingLookup(&url.url, 80, host, strlen(host), urlmap)); + } + } + + GIVEN("A forward map with a long host that exercises the 16-byte SIMD body") + { + auto urlrw = std::make_unique(); + std::string config = R"RMCFG( +map http://a-very-long-host-name-for-simd.example.com http://origin.example.com + )RMCFG"; + + auto cpath = write_test_remap(config, "case_insensitive_long"); + int rc = urlrw->BuildTable(cpath.c_str()); + REQUIRE(rc == TS_SUCCESS); + + EasyURL url("http://a-very-long-host-name-for-simd.example.com"); + UrlMappingContainer urlmap; + + THEN("an all-uppercase 42-char host (covers >=32 SIMD bytes) matches") + { + const char *host = "A-VERY-LONG-HOST-NAME-FOR-SIMD.EXAMPLE.COM"; + REQUIRE(urlrw->forwardMappingLookup(&url.url, 80, host, strlen(host), urlmap)); + } + } +} diff --git a/src/proxy/http2/HPACK.cc b/src/proxy/http2/HPACK.cc index 7e4fd974f57..34ff607ced2 100644 --- a/src/proxy/http2/HPACK.cc +++ b/src/proxy/http2/HPACK.cc @@ -23,6 +23,7 @@ #include "proxy/http2/HPACK.h" +#include "tscore/ink_ascii_tolower.h" #include "tsutil/LocalBuffer.h" #include "swoc/TextView.h" @@ -789,9 +790,7 @@ hpack_encode_header_block(HpackIndexingTable &indexing_table, uint8_t *out_buf, int name_len = original_name.size(); 
ts::LocalBuffer local_buffer(name_len); char *lower_name = local_buffer.data(); - for (int i = 0; i < name_len; i++) { - lower_name[i] = ParseRules::ink_tolower(original_name[i]); - } + ts::ascii::tolower_copy(lower_name, original_name.data(), name_len); std::string_view name{lower_name, static_cast(name_len)}; std::string_view value = field.value_get(); diff --git a/src/proxy/http2/unit_tests/test_HpackIndexingTable.cc b/src/proxy/http2/unit_tests/test_HpackIndexingTable.cc index ad373211fb8..d682f29abe8 100644 --- a/src/proxy/http2/unit_tests/test_HpackIndexingTable.cc +++ b/src/proxy/http2/unit_tests/test_HpackIndexingTable.cc @@ -24,6 +24,7 @@ limitations under the License. */ +#include #include #include @@ -531,3 +532,34 @@ TEST_CASE("HPACK high level APIs", "[hpack]") } } } + +// Validates that hpack_encode_header_block() lower-cases mixed-case field +// names per RFC 7540 § 8.1.2 before emitting them. The lower-case step is the +// path that goes through ts::ascii::tolower_copy; if a regression broke the +// lowercasing, the byte-for-byte comparison below would fail. +TEST_CASE("HPACK encode lower-cases mixed-case field names", "[hpack]") +{ + uint8_t buf_mixed[BUFSIZE_FOR_REGRESSION_TEST]; + uint8_t buf_lower[BUFSIZE_FOR_REGRESSION_TEST]; + HpackIndexingTable table_mixed(MAX_TABLE_SIZE); + HpackIndexingTable table_lower(MAX_TABLE_SIZE); + + // Use a name long enough to exercise the 16-byte SSE2 body when present. 
+ auto encode_one = [](HpackIndexingTable &table, uint8_t *buf, const char *name, const char *value) -> int64_t { + std::unique_ptr headers(new HTTPHdr, destroy_http_hdr); + headers->create(HTTP_TYPE_REQUEST); + MIMEField *field = mime_field_create(headers->m_heap, headers->m_http->m_fields_impl); + field->name_set(headers->m_heap, headers->m_http->m_fields_impl, name, strlen(name)); + field->value_set(headers->m_heap, headers->m_http->m_fields_impl, value, strlen(value)); + mime_hdr_field_attach(headers->m_http->m_fields_impl, field, 1, nullptr); + std::memset(buf, 0, BUFSIZE_FOR_REGRESSION_TEST); + return hpack_encode_header_block(table, buf, BUFSIZE_FOR_REGRESSION_TEST, headers.get()); + }; + + int64_t mixed_len = encode_one(table_mixed, buf_mixed, "Long-Custom-Header-Name", "abc"); + int64_t lower_len = encode_one(table_lower, buf_lower, "long-custom-header-name", "abc"); + + REQUIRE(mixed_len > 0); + REQUIRE(mixed_len == lower_len); + REQUIRE(std::memcmp(buf_mixed, buf_lower, lower_len) == 0); +} diff --git a/src/proxy/http3/QPACK.cc b/src/proxy/http3/QPACK.cc index dfdd2d278b3..92ac2b024f6 100644 --- a/src/proxy/http3/QPACK.cc +++ b/src/proxy/http3/QPACK.cc @@ -25,6 +25,7 @@ #include "proxy/hdrs/XPACK.h" #include "proxy/http3/QPACK.h" #include "tscore/ink_defs.h" +#include "tscore/ink_ascii_tolower.h" #include "tscore/ink_memory.h" #define QPACKDebug(fmt, ...) 
Dbg(dbg_ctl_qpack, "[%s] " fmt, this->_qc->cids().data(), ##__VA_ARGS__) @@ -369,9 +370,7 @@ QPACK::_encode_header(const MIMEField &field, uint16_t base_index, IOBufferBlock { auto name{field.name_get()}; char *lowered_name = this->_arena.str_store(name.data(), name.length()); - for (size_t i = 0; i < name.length(); i++) { - lowered_name[i] = ParseRules::ink_tolower(lowered_name[i]); - } + ts::ascii::tolower_inplace(lowered_name, name.length()); auto value{field.value_get()}; // TODO Set never_index flag on/off according to encoding headers diff --git a/src/tscore/CMakeLists.txt b/src/tscore/CMakeLists.txt index 3d70052b199..23141e8fc47 100644 --- a/src/tscore/CMakeLists.txt +++ b/src/tscore/CMakeLists.txt @@ -106,6 +106,17 @@ else() target_sources(tscore PRIVATE HKDF_openssl.cc) endif() +# Highway runtime-dispatched SIMD kernels (ASCII tolower, base64). Only +# compiled when ENABLE_HIGHWAY_DISPATCH=ON; otherwise ink_ascii_tolower.h and +# ink_base64.cc use their scalar paths. The source dir is added to the include +# path so Highway's foreach_target.h self-include of each *_dispatch.cc +# resolves. 
+if(TS_HAS_HIGHWAY_DISPATCH) + target_sources(tscore PRIVATE ink_ascii_tolower_dispatch.cc ink_base64_dispatch.cc) + target_link_libraries(tscore PRIVATE hwy::hwy) + target_include_directories(tscore PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +endif() + target_link_libraries( tscore PUBLIC OpenSSL::Crypto libswoc::libswoc yaml-cpp::yaml-cpp systemtap::systemtap resolv::resolv ts::tsutil ) @@ -159,6 +170,7 @@ if(BUILD_TESTING) unit_tests/test_Tokenizer.cc unit_tests/test_arena.cc + unit_tests/test_ink_ascii_tolower.cc unit_tests/test_ink_base64.cc unit_tests/test_ink_inet.cc unit_tests/test_ink_memory.cc unit_tests/test_ink_string.cc diff --git a/src/tscore/ink_ascii_tolower_dispatch.cc b/src/tscore/ink_ascii_tolower_dispatch.cc new file mode 100644 index 00000000000..7cfc6f8eb83 --- /dev/null +++ b/src/tscore/ink_ascii_tolower_dispatch.cc @@ -0,0 +1,147 @@ +/** @file + + Runtime-dispatched implementation of ts::ascii::tolower_copy. + + Enabled via ENABLE_HIGHWAY_DISPATCH=ON at build time. When the option + is off, ink_ascii_tolower.h provides a portable scalar fallback loop + and this translation unit is not compiled. + + Why dispatch: when the build target is intentionally conservative + (e.g. -march=westmere for binary portability), the scalar loop in + ink_ascii_tolower.h is auto-vectorized only for that target and cannot exploit + AVX2 or AVX-512 on hosts that support them. Highway's foreach_target + mechanism emits one body per enabled target inside this TU using + per-function GCC target attributes, regardless of the TU's own + -march. At runtime, the highest target supported by the live CPU is + selected and its function pointer is cached for direct reuse. Effect: + a Westmere-compiled binary still runs the AVX-512 body on Ice Lake. + + Per-call cost vs a fully-inlined implementation: ~1.3 ns flat (one direct + call to this function, then one indirect call through the cached + pointer). 
The trade for that overhead is unlocking the wider vector + path on bulk inputs, which can be ~2x faster on modern CPUs even from + a conservative build target. + + Output is byte-for-byte identical to the scalar fallback (see the + test_ink_ascii_tolower parity suite in src/tscore/unit_tests). + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#include "tscore/ink_ascii_tolower.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ink_ascii_tolower_dispatch.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace ts::ascii +{ +namespace HWY_NAMESPACE +{ + namespace hn = hwy::HWY_NAMESPACE; + + // Templated chunk worker, forced inline so the small-N branch below + // stays a per-call-site dead-code-eliminable check rather than a real + // jump table. Body is the same as the static-dispatch reference: single + // compare via the unsigned(v - 'A') < 26 trick, fused masked add via + // MaskedAddOr (lowers to _mm512_mask_add_epi8 on AVX3; falls back to + // IfThenElse(Add) on targets without native masked-add), and a + // LoadN/StoreN masked tail. 
+ template + HWY_ATTR HWY_INLINE void + tolower_chunk(D d, char *dst, const char *src, std::size_t n) + { + using V = hn::Vec; + const V A = hn::Set(d, static_cast>('A')); + const V R = hn::Set(d, static_cast>(26)); + const V B5 = hn::Set(d, static_cast>(0x20)); + const std::size_t N = hn::Lanes(d); + + const auto *in = reinterpret_cast(src); + auto *out = reinterpret_cast(dst); + + std::size_t i = 0; + if (n >= N) { + for (; i + N <= n; i += N) { + const V v = hn::LoadU(d, in + i); + const auto is_upper = hn::Lt(hn::Sub(v, A), R); + const V folded = hn::MaskedAddOr(v, is_upper, v, B5); + hn::StoreU(folded, d, out + i); + } + } + if (i < n) { + const std::size_t rem = n - i; + const V v = hn::LoadN(d, in + i, rem); + const auto is_upper = hn::Lt(hn::Sub(v, A), R); + const V folded = hn::MaskedAddOr(v, is_upper, v, B5); + hn::StoreN(folded, d, out + i, rem); + } + } + + // Small-N gate. On wide targets (AVX2 = 32 lanes, AVX3 = 64 lanes) the + // masked LoadN/StoreN on a sub-vector input pays ~10 ns of fixed setup + // before doing any work. Dropping to a 16-byte CappedTag for those + // inputs collapses the gate to SSSE3-class ops, which is what the + // hand-written cascade falls through to anyway. On targets where the + // ScalableTag is already 16 bytes (SSE2/NEON), both branches lower to + // the same body and the gate folds out at compile time. + HWY_ATTR void + tolower_copy_target(char *dst, const char *src, std::size_t n) + { + const hn::ScalableTag d_full; + if (n < hn::Lanes(d_full)) { + const hn::CappedTag d16; + tolower_chunk(d16, dst, src, n); + } else { + tolower_chunk(d_full, dst, src, n); + } + } + +} // namespace HWY_NAMESPACE +} // namespace ts::ascii +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace ts::ascii +{ + +HWY_EXPORT(tolower_copy_target); + +// Cached-pointer dispatch. 
On first call, the lazy-init lambda invokes +// the dispatch table (which patches the entry to point to the chosen +// target) and then captures the now-resolved pointer for direct reuse. +// Thread-safe via C++11 magic-static semantics; runs the resolver once. +// Per-call cost after init is one indirect call through the cached +// pointer. +void +tolower_copy_dispatch(char *dst, const char *src, std::size_t n) noexcept +{ + using tolower_fn_t = void (*)(char *, const char *, std::size_t); + static const tolower_fn_t fn = []() noexcept { + char dummy_dst = 0, dummy_src = 'A'; + HWY_DYNAMIC_DISPATCH(tolower_copy_target)(&dummy_dst, &dummy_src, 1); + return HWY_DYNAMIC_POINTER(tolower_copy_target); + }(); + fn(dst, src, n); +} + +} // namespace ts::ascii +#endif // HWY_ONCE diff --git a/src/tscore/ink_base64.cc b/src/tscore/ink_base64.cc index 208626f5262..635500cc0c4 100644 --- a/src/tscore/ink_base64.cc +++ b/src/tscore/ink_base64.cc @@ -1,6 +1,34 @@ /** @file - A brief file description + Base64 encoding and decoding as according to RFC1521. Similar to uudecode. + + The public entry points (ats_base64_encode / ats_base64_decode, also exposed + through TSBase64Encode / TSBase64Decode) dispatch between two implementations + that share the canonical scalar primitives in ink_base64_scalar.h: + + - A hand-rolled scalar path, always present, used directly and for inputs + below the SIMD crossover threshold. It avoids the SIMD path's runtime + dispatch overhead, which would otherwise dominate for tiny inputs (e.g. + the 8-byte SnowflakeID encode). + + - When ENABLE_HIGHWAY_DISPATCH is on (TS_HAS_HIGHWAY_DISPATCH), a SIMD path + in ink_base64_dispatch.cc built on Google Highway, used for larger + inputs. It produces output byte-for-byte identical to the scalar path. + + RFC 1521 requires inserting line breaks for long lines. The basic web + authentication scheme does not require them. 
This implementation is intended + for web-related use, and line breaks are not implemented. + + Contract preserved by both paths: + + - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line + breaks, trailing NUL at outBuffer[length]. + + - decode: accepts standard (`+`, `/`) and URL-safe (`-`, `_`) alphabets in + the same input; tolerates missing padding; on any non-alphabet byte + (whitespace, `=`, or garbage) truncates and returns success with whatever + was decoded; trailing NUL at outBuffer[length]; supports in-place decode + (dst == src). @section license License @@ -20,73 +48,42 @@ See the License for the specific language governing permissions and limitations under the License. */ - -/* - * Base64 encoding and decoding as according to RFC1521. Similar to uudecode. - * - * RFC 1521 requires inserting line breaks for long lines. The basic web - * authentication scheme does not require them. This implementation is - * intended for web-related use, and line breaks are not implemented. - * - */ #include "tscore/ink_platform.h" +#include "tscore/ink_config.h" #include "tscore/ink_base64.h" #include "tscore/ink_assert.h" +#include "ink_base64_scalar.h" + +#if TS_HAS_HIGHWAY_DISPATCH +#include "ink_base64_dispatch.h" + +// Inputs at or below these byte counts stay on the scalar path, where they +// outrun the SIMD path's per-call dispatch overhead. The thresholds are +// conservative; both paths are correct at every size. inBufferSize for encode +// is the binary plaintext length; for decode it is the base64-encoded length. 
+namespace +{ +constexpr size_t BASE64_ENCODE_SIMD_THRESHOLD = 24; +constexpr size_t BASE64_DECODE_SIMD_THRESHOLD = 32; +} // namespace +#endif + bool ats_base64_encode(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t outBufSize, size_t *length) { - static const char _codes[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - char *obuf = outBuffer; - char in_tail[4]; - if (outBufSize < ats_base64_encode_dstlen(inBufferSize)) { return false; } - while (inBufferSize > 2) { - *obuf++ = _codes[(inBuffer[0] >> 2) & 077]; - *obuf++ = _codes[((inBuffer[0] & 03) << 4) | ((inBuffer[1] >> 4) & 017)]; - *obuf++ = _codes[((inBuffer[1] & 017) << 2) | ((inBuffer[2] >> 6) & 017)]; - *obuf++ = _codes[inBuffer[2] & 077]; - - inBufferSize -= 3; - inBuffer += 3; - } - - /* - * We've done all the input groups of three chars. We're left - * with 0, 1, or 2 input chars. We have to add zero-bits to the - * right if we don't have enough input chars. - * If 0 chars left, we're done. - * If 1 char left, form 2 output chars, and add 2 pad chars to output. - * If 2 chars left, form 3 output chars, add 1 pad char to output. 
- */ - if (inBufferSize == 0) { - *obuf = '\0'; - if (length) { - *length = (obuf - outBuffer); - } - } else { - memset(in_tail, 0, sizeof(in_tail)); - memcpy(in_tail, inBuffer, inBufferSize); - - *(obuf) = _codes[(in_tail[0] >> 2) & 077]; - *(obuf + 1) = _codes[((in_tail[0] & 03) << 4) | ((in_tail[1] >> 4) & 017)]; - *(obuf + 2) = _codes[((in_tail[1] & 017) << 2) | ((in_tail[2] >> 6) & 017)]; - *(obuf + 3) = _codes[in_tail[2] & 077]; - - if (inBufferSize == 1) { - *(obuf + 2) = '='; - } - *(obuf + 3) = '='; - *(obuf + 4) = '\0'; - - if (length) { - *length = (obuf + 4) - outBuffer; - } +#if TS_HAS_HIGHWAY_DISPATCH + if (inBufferSize > BASE64_ENCODE_SIMD_THRESHOLD) { + ts::base64::encode_dispatch(inBuffer, inBufferSize, outBuffer, length); + return true; } +#endif + ts::base64::encode_scalar(inBuffer, inBufferSize, outBuffer, length); return true; } @@ -96,76 +93,24 @@ ats_base64_encode(const char *inBuffer, size_t inBufferSize, char *outBuffer, si return ats_base64_encode(reinterpret_cast(inBuffer), inBufferSize, outBuffer, outBufSize, length); } -/*------------------------------------------------------------------------- - This is a reentrant, and malloc free implementation of ats_base64_decode. 
- -------------------------------------------------------------------------*/ -#ifdef DECODE -#undef DECODE -#endif - -#define DECODE(x) printableToSixBit[(unsigned char)x] -#define MAX_PRINT_VAL 63 - -/* Converts a printable character to it's six bit representation */ -const unsigned char printableToSixBit[256] = { - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, - 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, - 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; - bool ats_base64_decode(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t outBufSize, size_t *length) { - size_t decodedBytes = 0; - unsigned char *buf = outBuffer; - - // Make sure there is sufficient space in the output buffer if (outBufSize < ats_base64_decode_dstlen(inBufferSize)) { return false; } - // Ignore any trailing ='s or other undecodable characters: consume only the - // leading run of base64-alphabet bytes. - // TODO: Perhaps that ought to be an error instead? 
- size_t inBytes = 0; - while (inBytes < inBufferSize && DECODE(inBuffer[inBytes]) <= MAX_PRINT_VAL) { - ++inBytes; - } - - // Decode complete 4-character groups into 3 bytes each. Process only whole - // groups here so the loop never reads past the alphabet prefix; the previous - // code ran one extra iteration when inBytes was not a multiple of four (a - // read out of bounds of the input) and then read inBuffer[-2]. - while (inBytes >= 4) { - buf[0] = static_cast(DECODE(inBuffer[0]) << 2 | DECODE(inBuffer[1]) >> 4); - buf[1] = static_cast(DECODE(inBuffer[1]) << 4 | DECODE(inBuffer[2]) >> 2); - buf[2] = static_cast(DECODE(inBuffer[2]) << 6 | DECODE(inBuffer[3])); - buf += 3; - inBuffer += 4; - decodedBytes += 3; - inBytes -= 4; - } - - // Decode a trailing 2- or 3-character group; a lone trailing character does - // not encode a full byte and is dropped (as an RFC 4648 decoder requires). - if (inBytes >= 2) { - buf[0] = static_cast(DECODE(inBuffer[0]) << 2 | DECODE(inBuffer[1]) >> 4); - decodedBytes += 1; - if (inBytes >= 3) { - buf[1] = static_cast(DECODE(inBuffer[1]) << 4 | DECODE(inBuffer[2]) >> 2); - decodedBytes += 1; - } - } - - outBuffer[decodedBytes] = '\0'; - - if (length) { - *length = decodedBytes; +#if TS_HAS_HIGHWAY_DISPATCH + if (inBufferSize > BASE64_DECODE_SIMD_THRESHOLD) { + // The SIMD path validates inline and truncates at the first non-alphabet + // byte, so no separate alphabet pre-scan is needed here. + ts::base64::decode_dispatch(inBuffer, inBufferSize, outBuffer, length); + return true; } +#endif + // Ignore any trailing `=`s or other undecodable characters, then decode the + // valid alphabet prefix. 
+ ts::base64::decode_scalar_prefix(inBuffer, inBufferSize, outBuffer, length); return true; } diff --git a/src/tscore/ink_base64_dispatch.cc b/src/tscore/ink_base64_dispatch.cc new file mode 100644 index 00000000000..21aebdaaddd --- /dev/null +++ b/src/tscore/ink_base64_dispatch.cc @@ -0,0 +1,258 @@ +/** @file + + Runtime-dispatched SIMD base64 encode/decode, built against Google Highway. + + Enabled by ENABLE_HIGHWAY_DISPATCH=ON. Highway's foreach_target mechanism + emits one body per enabled SIMD target from this single source; at runtime + the best target supported by the live CPU is selected once and cached, so a + conservatively compiled binary (e.g. -march=x86-64) still runs the AVX-512 + body on capable hardware. + + The SIMD math follows the well-known vectorized base64 algorithms popularized + by Wojciech Muła and Daniel Lemire and used by the simdutf library, expressed + here in Highway's portable ops (see NOTICE): + + - decode: aqrit's "default_or_url" classifier translates ASCII to 6-bit + values for the standard and URL-safe alphabets at once and flags any + non-alphabet byte; the 4x6-bit groups are packed to 3 bytes with two + pairwise multiply-adds and a shuffle. Validation is fused into the loop: + only fully-valid 16-byte blocks are consumed by SIMD, and the remainder + (including any non-alphabet truncation point) is finished by the scalar + decoder, so output matches the scalar path exactly. + + - encode: the Muła reshuffle splits each 3-byte group into four 6-bit + fields with one multiply-high and one multiply-low, then a pshufb-based + table maps 6-bit values to the standard alphabet. The 1-2 byte tail and + `=` padding are produced by the scalar encoder. + + In-place decode (out == in) is preserved: each block is fully loaded before + its bounds-safe StoreN, whose end never passes the next block's load. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#include "ink_base64_dispatch.h" +#include "ink_base64_scalar.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ink_base64_dispatch.cc" +#include +#include + +HWY_BEFORE_NAMESPACE(); +namespace ts::base64 +{ +namespace HWY_NAMESPACE +{ + namespace hn = hwy::HWY_NAMESPACE; + + // per-128-bit-block constants, replicated to any width via LoadDup128 + alignas(16) static constexpr int8_t kMul1[16] = {0x40, 0x01, 0x40, 0x01, 0x40, 0x01, 0x40, 0x01, + 0x40, 0x01, 0x40, 0x01, 0x40, 0x01, 0x40, 0x01}; + alignas(16) static constexpr int16_t kMul2[8] = {0x1000, 0x0001, 0x1000, 0x0001, 0x1000, 0x0001, 0x1000, 0x0001}; + alignas(16) static constexpr uint8_t kPack[16] = {2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, 0x80, 0x80, 0x80, 0x80}; + + // ---- DECODE ---- + + // aqrit's "default_or_url" classifier (from simdutf's + // to_base64_mask): maps ASCII to 6-bit values for both + // standard (+ /) and URL-safe (- _) bytes in one pass, matching + // printableToSixBit exactly. The raw check mask flags whitespace as invalid + // (we omit simdutf's whitespace-XOR correction), giving ATS's truncate-on- + // non-alphabet semantics. Sets *ok false if any lane is non-alphabet. 
+ alignas(16) static constexpr uint8_t kDeltaAsso[16] = {0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x16}; + alignas(16) static constexpr uint8_t kDeltaValues[16] = {0xBF, 0xE0, 0xB9, 0x13, 0x04, 0xBF, 0xBF, 0xB9, + 0xB9, 0x00, 0xFF, 0x11, 0xFF, 0xBF, 0x10, 0xB9}; + alignas(16) static constexpr uint8_t kCheckAsso[16] = {0x0D, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x01, 0x01, 0x03, 0x07, 0x0B, 0x0E, 0x0B, 0x06}; + alignas(16) static constexpr uint8_t kCheckValues[16] = {0x80, 0x80, 0x80, 0x80, 0xCF, 0xBF, 0xD5, 0xA6, + 0xB5, 0xA1, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80}; + + template + HWY_ATTR HWY_INLINE hn::Vec + classify(D d, hn::Vec v, bool *ok) + { + const hn::RebindToSigned di; + const hn::Repartition d32; + + const auto shifted = hn::BitCast(d, hn::ShiftRight<3>(hn::BitCast(d32, v))); // shift is done on 32-bit lanes, so bits cross byte boundaries within a lane + auto delta_hash = hn::AverageRound(hn::TableLookupBytesOr0(hn::LoadDup128(d, kDeltaAsso), hn::And(v, hn::Set(d, 0x8F))), shifted); // 0x8F mask + Or0 = pshufb indexing (low nibble selects, bit 7 set yields 0) on every target; Highway leaves indices in [16, 0x80) implementation-defined + delta_hash = hn::And(delta_hash, hn::Set(d, 0x0F)); // now in 0..15, so the plain TableLookupBytes below is well defined + const auto check_hash = hn::AverageRound(hn::TableLookupBytesOr0(hn::LoadDup128(d, kCheckAsso), hn::And(v, hn::Set(d, 0x8F))), shifted); // same pshufb-style lookup; result is not masked and may have bit 7 set + + const auto out = + hn::SaturatedAdd(hn::BitCast(di, hn::TableLookupBytes(hn::LoadDup128(d, kDeltaValues), delta_hash)), hn::BitCast(di, v)); + const auto chk = + hn::SaturatedAdd(hn::BitCast(di, hn::TableLookupBytesOr0(hn::LoadDup128(d, kCheckValues), hn::And(check_hash, hn::Set(d, 0x8F)))), hn::BitCast(di, v)); + + *ok = hn::AllFalse(di, hn::Lt(chk, hn::Zero(di))); // chk sign bit set => invalid + return hn::BitCast(d, out); + } + + // Decode one 16-char block -> 12 bytes. Returns false without storing if the + // block holds a non-alphabet byte. The 12 valid output bytes are contiguous + // at the front, so a bounds-safe StoreN(12) suffices (no overrun).
+ HWY_ATTR HWY_INLINE bool + decode_block16(const char *in, unsigned char *out) + { + const hn::Full128 d; + const hn::RebindToSigned d8i; + const hn::Repartition d16; + const hn::Repartition d32; + + const auto v = hn::LoadU(d, reinterpret_cast(in)); // unaligned load of the 16 input chars + bool ok; // always written by classify() + const auto val = classify(d, v, &ok); + if (!ok) { + return false; // nothing stored; DecodeImpl hands this block to the scalar tail + } + + const auto mul1 = hn::BitCast(d8i, hn::Load(d, reinterpret_cast(kMul1))); + const auto t0 = hn::SatWidenMulPairwiseAdd(d16, val, mul1); // maddubs: val[2k] * 0x40 + val[2k+1] -> 12 bits per 16-bit lane + const auto t1 = hn::WidenMulPairwiseAdd(d32, t0, hn::Load(d16, kMul2)); // madd: t0[2k] * 0x1000 + t0[2k+1] -> 24 bits per 32-bit lane + const auto packed = hn::TableLookupBytes(hn::BitCast(d, t1), hn::Load(d, kPack)); // kPack picks bytes 2,1,0 of each 32-bit lane; lanes 12..15 (index 0x80) are never stored + + hn::StoreN(packed, d, out, 12); // writes exactly 12 bytes + return true; + } + + HWY_ATTR void + DecodeImpl(const char *in, size_t in_len, unsigned char *out, size_t *out_len) + { + size_t i = 0, o = 0; + for (; i + 16 <= in_len; i += 16, o += 12) { // 16 chars in, 12 bytes out per block + if (!decode_block16(in + i, out + o)) { + break; // non-alphabet byte somewhere in this block; i and o still point at its start + } + } + // Scalar finishes the remainder: truncate at the first non-alphabet byte + // then decode 4-groups + a 2/3 char tail (and write the trailing NUL). + // Identical to running the scalar decoder over the whole alphabet prefix, + // because the SIMD loop consumed only fully-valid 4-group-aligned blocks. + size_t tail_len = 0; + decode_scalar_prefix(in + i, in_len - i, out + o, &tail_len); + o += tail_len; + if (out_len) { + *out_len = o; + } + } + + // ---- ENCODE ---- + + // 6-bit value (0..63) -> standard-alphabet ASCII, from simdutf's + // lookup_pshufb_improved (Muła): reduce to a 4-bit class, look up a per-class + // offset, add. Standard alphabet (+ /), matching encode_scalar.
+#define U8(x) static_cast(x) + alignas(16) static constexpr uint8_t kShiftLUT[16] = {U8('a' - 26), // class 0: idx 26..51 -> 'a'..'z' + U8('0' - 52), // classes 1..10: idx 52..61 -> '0'..'9' (offset is negative, wraps in 8 bits) + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('0' - 52), + U8('+' - 62), // class 11: idx 62 -> '+' + U8('/' - 63), // class 12: idx 63 -> '/' + U8('A'), // class 13: idx 0..25 -> 'A'..'Z' + 0, // classes 14 and 15 are never produced by to_ascii + 0}; +#undef U8 + + template + HWY_ATTR HWY_INLINE hn::Vec + to_ascii(D d, hn::Vec idx) + { + auto res = hn::SaturatedSub(idx, hn::Set(d, 51)); // 52..63 -> 1..12, else 0 + const auto less = hn::Lt(idx, hn::Set(d, 26)); // 0..25 (uppercase class) + res = hn::Or(res, hn::IfThenElseZero(less, hn::Set(d, 13))); // uppercase lanes had res == 0, so they become class 13 + res = hn::TableLookupBytes(hn::LoadDup128(d, kShiftLUT), res); // class (0..13) -> per-class ASCII offset + return hn::Add(res, idx); // offset + 6-bit value = ASCII code + } + + // Encode 12 input bytes per 16-byte block -> 16 ASCII chars. `in` must have + // >= 16 readable bytes (over-reads bytes 12..15, only 0..11 used). Muła + // reshuffle: spread the 3 bytes of each group across a 32-bit lane, then split + // into four 6-bit fields with one mulhi + one mullo (each a pair of per-16-bit + // variable shifts).
+ template + HWY_ATTR HWY_INLINE void + encode_chunk(D d, const unsigned char *in, char *out) + { + const hn::Repartition d32; + const hn::Repartition d16; + + alignas(16) static constexpr uint8_t kSpread[16] = {1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7, 10, 9, 11, 10}; // each 3-byte group g0,g1,g2 is laid out as g1,g0,g2,g1 + + const auto in32 = hn::BitCast(d32, hn::TableLookupBytes(hn::LoadU(d, in), hn::LoadDup128(d, kSpread))); // loads a full vector from `in`; kSpread only references bytes 0..11 + + const auto t0 = hn::And(in32, hn::Set(d32, 0x0fc0fc00u)); + const auto t1 = hn::MulHigh(hn::BitCast(d16, t0), hn::BitCast(d16, hn::Set(d32, 0x04000040u))); // MulHigh by 0x0040 / 0x0400 is a right shift by 10 / 6 within a 16-bit lane + const auto t2 = hn::And(in32, hn::Set(d32, 0x003f03f0u)); + const auto t3 = hn::Mul(hn::BitCast(d16, t2), hn::BitCast(d16, hn::Set(d32, 0x01000010u))); // Mul by 0x0010 / 0x0100 is a left shift by 4 / 8 within a 16-bit lane + const auto idx = hn::BitCast(d, hn::Or(t1, t3)); // four 6-bit values per 32-bit lane + + hn::StoreU(to_ascii(d, idx), d, reinterpret_cast(out)); // stores a full vector of ASCII chars + } + + HWY_ATTR void + EncodeImpl(const unsigned char *in, size_t in_len, char *out, size_t *out_len) + { + const hn::Full128 d; + size_t i = 0, o = 0; + while (in_len - i >= 16) { // encode_chunk loads 16 bytes, so only run it while 16 remain + encode_chunk(d, in + i, out + o); + i += 12; // only 12 of the loaded bytes are consumed per chunk + o += 16; + } + size_t tail = 0; + encode_scalar(in + i, in_len - i, out + o, &tail); // scalar encoder handles the rest, the `=` padding and the trailing NUL + o += tail; + if (out_len) { + *out_len = o; + } + } + +} // namespace HWY_NAMESPACE +} // namespace ts::base64 +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace ts::base64 +{ +HWY_EXPORT(DecodeImpl); +HWY_EXPORT(EncodeImpl); + +void +decode_dispatch(const char *in, size_t in_len, unsigned char *out, size_t *out_len) +{ + HWY_DYNAMIC_DISPATCH(DecodeImpl)(in, in_len, out, out_len); // forwards to the DecodeImpl body chosen by Highway's dynamic dispatch +} + +void +encode_dispatch(const unsigned char *in, size_t in_len, char *out, size_t *out_len) +{ + HWY_DYNAMIC_DISPATCH(EncodeImpl)(in, in_len, out, out_len); // forwards to the EncodeImpl body chosen by Highway's dynamic dispatch +} + +} // namespace ts::base64 +#endif diff --git a/src/tscore/ink_base64_dispatch.h b/src/tscore/ink_base64_dispatch.h new file mode 100644 index 00000000000..98f3eef335d --- /dev/null +++ b/src/tscore/ink_base64_dispatch.h @@ -0,0 +1,49 @@ +/** @file + + Runtime-dispatched SIMD base64 entry points,
implemented in + ink_base64_dispatch.cc against Google Highway and selected at runtime via + HWY_DYNAMIC_DISPATCH. Only built and used when ENABLE_HIGHWAY_DISPATCH is on + (TS_HAS_HIGHWAY_DISPATCH at compile time); ink_base64.cc routes large inputs + here and falls back to the scalar path otherwise. + + Both functions produce output byte-for-byte identical to the scalar + encode_scalar / count_alphabet_prefix + decode_scalar in ink_base64_scalar.h. + decode consumes only fully-validated SIMD blocks and hands the remainder + (including any truncation at a non-alphabet byte) to the scalar tail. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ts::base64 +{ +// Decode `in_len` bytes of base64 input. Output buffer capacity has already +// been validated by the caller (ats_base64_decode). Writes a trailing NUL at +// out[*out_len]. Supports in-place use (out == reinterpret_cast(in)). +void decode_dispatch(const char *in, size_t in_len, unsigned char *out, size_t *out_len); + +// Encode `in_len` binary bytes. Output buffer capacity already validated. +// Writes a trailing NUL at out[*out_len]. 
+void encode_dispatch(const unsigned char *in, size_t in_len, char *out, size_t *out_len); + +} // namespace ts::base64 diff --git a/src/tscore/ink_base64_scalar.h b/src/tscore/ink_base64_scalar.h new file mode 100644 index 00000000000..d3ff8d39c3e --- /dev/null +++ b/src/tscore/ink_base64_scalar.h @@ -0,0 +1,179 @@ +/** @file + + Scalar base64 encode/decode primitives, shared by the always-present scalar + path in ink_base64.cc and (when ENABLE_HIGHWAY_DISPATCH is on) the SIMD + kernel's scalar tail in ink_base64_dispatch.cc. Keeping a single definition + here guarantees the two paths cannot drift. + + These are the canonical reference semantics for ATS base64: + + - encode: standard RFC 1521 alphabet (`+`, `/`), `=` padding, no line + breaks, trailing NUL at outBuffer[length]. + + - decode: accepts standard (`+`, `/`) and URL-safe (`-`, `_`) alphabets + mixed in the same input; truncates at the first non-alphabet byte + (whitespace, `=`, or garbage); tolerates missing padding; trailing NUL + at outBuffer[length]; supports in-place decode (dst == src). + + decode_scalar restructures the historical tail handling: the previous code + ran one extra loop iteration past the alphabet prefix when the valid length + was not a multiple of four (reading bytes beyond the prefix, potentially out + of bounds) and then read inBuffer[-2]. This version processes only complete + 4-character groups and decodes a 2- or 3-character tail explicitly, dropping + a lone trailing character. The decoded length and bytes are identical to the + historical code for every well-defined input; only the out-of-bounds reads + are removed. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace ts::base64 +{ +// Converts a printable character to its six-bit representation; 64 marks a +// non-alphabet byte. Both standard (`+`=62, `/`=63) and URL-safe (`-`=62, +// `_`=63) punctuation are accepted. +inline constexpr unsigned char printableToSixBit[256] = { + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 62, 64, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 63, + 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64}; + +inline constexpr unsigned char MAX_PRINT_VAL = 63; + +inline unsigned 
char +decode_byte(char c) +{ + return printableToSixBit[static_cast(c)]; +} + +// Count the leading base64-alphabet bytes (standard or URL-safe). Any byte at +// or after this index is whitespace, `=`, or garbage and terminates the input. +inline size_t +count_alphabet_prefix(const char *inBuffer, size_t inBufferSize) +{ + size_t valid = 0; + + while (valid < inBufferSize && decode_byte(inBuffer[valid]) <= MAX_PRINT_VAL) { // table value 64 marks a non-alphabet byte + ++valid; + } + return valid; +} + +// Hand-rolled scalar encode. Caller has already validated outBufSize. +inline void +encode_scalar(const unsigned char *inBuffer, size_t inBufferSize, char *outBuffer, size_t *length) +{ + static const char _codes[66] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; // 64 alphabet chars + NUL; the array has one spare byte + char *obuf = outBuffer; + char in_tail[4]; // zero-padded copy of a 1- or 2-byte tail + + while (inBufferSize > 2) { // whole 3-byte groups -> 4 chars; masks are octal (077 = 6 bits, 017 = 4 bits, 03 = 2 bits) + *obuf++ = _codes[(inBuffer[0] >> 2) & 077]; + *obuf++ = _codes[((inBuffer[0] & 03) << 4) | ((inBuffer[1] >> 4) & 017)]; + *obuf++ = _codes[((inBuffer[1] & 017) << 2) | ((inBuffer[2] >> 6) & 017)]; + *obuf++ = _codes[inBuffer[2] & 077]; + + inBufferSize -= 3; + inBuffer += 3; + } + + if (inBufferSize == 0) { // no tail: terminate and report the length without the NUL + *obuf = '\0'; + if (length) { + *length = (obuf - outBuffer); + } + } else { + memset(in_tail, 0, sizeof(in_tail)); + memcpy(in_tail, inBuffer, inBufferSize); // inBufferSize is 1 or 2 here + + *(obuf) = _codes[(in_tail[0] >> 2) & 077]; + *(obuf + 1) = _codes[((in_tail[0] & 03) << 4) | ((in_tail[1] >> 4) & 017)]; + *(obuf + 2) = _codes[((in_tail[1] & 017) << 2) | ((in_tail[2] >> 6) & 017)]; + *(obuf + 3) = _codes[in_tail[2] & 077]; + + if (inBufferSize == 1) { // one tail byte: the third char is padding as well + *(obuf + 2) = '='; + } + *(obuf + 3) = '='; // a tail always ends in at least one `=` + *(obuf + 4) = '\0'; + + if (length) { + *length = (obuf + 4) - outBuffer; // tail contributes 4 chars; NUL not counted + } + } +} + +// Hand-rolled scalar decode. Caller has pre-scanned with count_alphabet_prefix +// so every byte in inBuffer[0..inBufferSize) is in the base64 alphabet, and +// has validated outBufSize.
+inline void +decode_scalar(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) +{ + size_t decodedBytes = 0; + unsigned char *buf = outBuffer; + + while (inBufferSize >= 4) { // whole 4-char groups -> 3 bytes each + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + buf[2] = static_cast(decode_byte(inBuffer[2]) << 6 | decode_byte(inBuffer[3])); + buf += 3; + inBuffer += 4; + decodedBytes += 3; + inBufferSize -= 4; + } + + if (inBufferSize >= 2) { // 2-char tail -> 1 byte, 3-char tail -> 2 bytes; a single leftover char is dropped + buf[0] = static_cast(decode_byte(inBuffer[0]) << 2 | decode_byte(inBuffer[1]) >> 4); + decodedBytes += 1; + if (inBufferSize >= 3) { + buf[1] = static_cast(decode_byte(inBuffer[1]) << 4 | decode_byte(inBuffer[2]) >> 2); + decodedBytes += 1; + } + } + + outBuffer[decodedBytes] = '\0'; // trailing NUL: the output needs room for decodedBytes + 1 bytes + if (length) { + *length = decodedBytes; // NUL not counted + } +} + +// Decode the leading base64-alphabet run of inBuffer[0..inBufferSize): truncate +// at the first non-alphabet byte, write the decoded bytes plus trailing NUL to +// outBuffer, and set *length (if non-null) to the decoded byte count. This is +// the canonical scalar decode, used directly by ats_base64_decode and as the +// SIMD path's scalar tail in ink_base64_dispatch.cc. +inline void +decode_scalar_prefix(const char *inBuffer, size_t inBufferSize, unsigned char *outBuffer, size_t *length) +{ + const size_t valid = count_alphabet_prefix(inBuffer, inBufferSize); // length of the leading alphabet-only run + + decode_scalar(inBuffer, valid, outBuffer, length); +} + +} // namespace ts::base64 diff --git a/src/tscore/unit_tests/test_ink_ascii_tolower.cc b/src/tscore/unit_tests/test_ink_ascii_tolower.cc new file mode 100644 index 00000000000..5f2cb8df3c9 --- /dev/null +++ b/src/tscore/unit_tests/test_ink_ascii_tolower.cc @@ -0,0 +1,132 @@ +/** @file + + Unit tests for ts::ascii::tolower_copy and ts::ascii::tolower_inplace.
+ + Runs as part of the standard test_tscore binary so the helper's SIMD + and scalar paths are exercised by ctest in every build, not just when + ENABLE_BENCHMARKS is set. + + @section license License + + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +#include + +#include "tscore/ink_ascii_tolower.h" +#include "tscore/ParseRules.h" + +#include +#include +#include +#include + +namespace +{ + +// Same mixed-case ASCII distribution we use in the benchmark, so the unit +// tests exercise inputs that look like real URL/header bytes. +std::vector +make_mixed_case_ascii(std::size_t n, std::uint64_t seed) +{ + std::mt19937_64 rng(seed); + std::vector v(n); + for (std::size_t i = 0; i < n; ++i) { + auto r = static_cast(rng() & 0x3FU); + if (r < 26U) { + v[i] = static_cast('A' + r); + } else if (r < 52U) { + v[i] = static_cast('a' + (r - 26U)); + } else { + static constexpr char kNonAlpha[] = "0123456789-_./:"; + v[i] = kNonAlpha[r % (sizeof(kNonAlpha) - 1U)]; + } + } + return v; +} + +// Byte-at-a-time reference, equivalent to the prior static-inline +// memcpy_tolower in URL.cc. Anything ts::ascii::tolower_copy produces must +// match this for every input we test. 
+void +tolower_reference(char *d, const char *s, std::size_t n) noexcept +{ + while (n--) { + *d = ParseRules::ink_tolower(*s); + ++s; + ++d; + } +} + +} // namespace + +TEST_CASE("ts::ascii::tolower_copy matches scalar reference", "[ts_ascii_tolower]") +{ + // Bracket every SIMD body width (16/32/64) with both equal-to and + // offset-from-multiple lengths so the cascade transitions and the + // AVX-512BW masked tail are all exercised. + for (std::size_t sz : std::array{0, 1, 5, 15, 16, 17, 23, 31, 32, 33, 63, 64, 65, 257}) { + auto input = make_mixed_case_ascii(sz, 0xC0FFEE + sz); + std::vector expected(sz); + std::vector actual(sz); + + tolower_reference(expected.data(), input.data(), sz); + ts::ascii::tolower_copy(actual.data(), input.data(), sz); + + CAPTURE(sz); + REQUIRE(actual == expected); + } +} + +TEST_CASE("ts::ascii::tolower_copy preserves non-ASCII bytes", "[ts_ascii_tolower]") +{ + // Every byte value 0..255 should round-trip unchanged unless it is in + // 'A'..'Z', in which case it should map to 'a'..'z'. Guards against any + // future "speed-up" that widens the case-fold range past ASCII. + std::array input; + for (std::size_t i = 0; i < 256; ++i) { + input[i] = static_cast(i); + } + std::array output; + ts::ascii::tolower_copy(output.data(), reinterpret_cast(input.data()), input.size()); + + for (std::size_t i = 0; i < 256; ++i) { + auto in = static_cast(i); + auto out = static_cast(output[i]); + auto exp = (in >= 'A' && in <= 'Z') ? static_cast(in | 0x20) : in; + CAPTURE(i); + REQUIRE(out == exp); + } +} + +TEST_CASE("ts::ascii::tolower_inplace matches tolower_copy", "[ts_ascii_tolower]") +{ + // The inplace form must produce the same result as a non-overlapping copy. + // Exercise the same boundary sizes so the SIMD bodies and the AVX-512BW + // masked load/store are all exercised in-place. 
+ for (std::size_t sz : std::array{0, 1, 5, 15, 16, 17, 23, 31, 32, 33, 63, 64, 65, 257}) { + auto input = make_mixed_case_ascii(sz, 0xBADF00D + sz); + std::vector expected(sz); + std::vector in_place(input); + + tolower_reference(expected.data(), input.data(), sz); + ts::ascii::tolower_inplace(in_place.data(), sz); + + CAPTURE(sz); + REQUIRE(in_place == expected); + } +} diff --git a/src/tscore/unit_tests/test_ink_base64.cc b/src/tscore/unit_tests/test_ink_base64.cc index bc7ac39816a..2e103ff04d4 100644 --- a/src/tscore/unit_tests/test_ink_base64.cc +++ b/src/tscore/unit_tests/test_ink_base64.cc @@ -2,10 +2,12 @@ Unit tests for ats_base64_encode / ats_base64_decode. - Includes a regression test for the out-of-bounds read that occurred when the - decodable prefix length was not a multiple of four. The decode helper places - the input in an exact-size heap buffer so that, under AddressSanitizer, any - read past the input (as the old decoder did) is caught. + These run in both build configurations. With ENABLE_HIGHWAY_DISPATCH off, + the public entry points are the scalar path and these checks pin its + behavior. With it on, large inputs take the Highway SIMD path and every + check below becomes a byte-for-byte parity test of SIMD vs. the scalar + primitives (the oracle), which is why the sizes deliberately straddle the + SIMD thresholds and run up to several KB. @section license License @@ -30,189 +32,158 @@ #include +#include "../ink_base64_scalar.h" // scalar oracle: encode_scalar / decode_scalar / count_alphabet_prefix + #include -#include #include -#include #include #include #include namespace { -// Decode `b64` with the input held in an EXACT-size heap buffer, so a read -// past the input trips AddressSanitizer's red zone (this is what the -// out-of-bounds bug did for prefixes whose length is not a multiple of four). +// Deterministic pseudo-random bytes (fixed seed -> reproducible). 
std::string -decode_tight(const std::string &b64) +prng_bytes(size_t n, uint32_t seed) { - std::vector in(b64.size() ? b64.size() : 1); - if (!b64.empty()) { - std::memcpy(in.data(), b64.data(), b64.size()); + std::mt19937 rng(seed); + std::uniform_int_distribution d(0, 255); + std::string s(n, '\0'); + for (auto &c : s) { + c = static_cast(d(rng)); } + return s; +} - // ats_base64_decode_dstlen() already includes the trailing NUL, so this is - // the exact documented minimum -- no slack, so any write at or past the - // bound trips ASan. - std::vector out(ats_base64_decode_dstlen(b64.size()), 0xCC); - size_t n = 0; - bool ok = ats_base64_decode(in.data(), b64.size(), out.data(), out.size(), &n); - REQUIRE(ok); - REQUIRE(out[n] == '\0'); // trailing NUL contract +// Encode via the scalar oracle. +std::string +oracle_encode(const std::string &bin) +{ + std::vector out(ats_base64_encode_dstlen(bin.size()) + 1, '\xCC'); + size_t n = 0; + ts::base64::encode_scalar(reinterpret_cast(bin.data()), bin.size(), out.data(), &n); + return std::string(out.data(), n); +} + +// Decode via the scalar oracle (count alphabet prefix, then decode it). +std::string +oracle_decode(const std::string &b64) +{ + const size_t valid = ts::base64::count_alphabet_prefix(b64.data(), b64.size()); + std::vector out(ats_base64_decode_dstlen(b64.size()) + 1, 0xCC); + size_t n = 0; + ts::base64::decode_scalar(b64.data(), valid, out.data(), &n); return std::string(reinterpret_cast(out.data()), n); } +// Encode via the public entry point (SIMD when enabled). 
std::string -encode(const std::string &bin) +public_encode(const std::string &bin) { - std::vector out(ats_base64_encode_dstlen(bin.size()), '\0'); + std::vector out(ats_base64_encode_dstlen(bin.size()) + 1, '\xDD'); size_t n = 0; bool ok = ats_base64_encode(bin.data(), bin.size(), out.data(), out.size(), &n); REQUIRE(ok); - REQUIRE(out[n] == '\0'); + REQUIRE(out[n] == '\0'); // trailing NUL contract return std::string(out.data(), n); } -const std::string kAlpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +// Decode via the public entry point (SIMD when enabled). +std::string +public_decode(const std::string &b64) +{ + std::vector out(ats_base64_decode_dstlen(b64.size()) + 1, 0xDD); + size_t n = 0; + bool ok = ats_base64_decode(b64.data(), b64.size(), out.data(), out.size(), &n); + REQUIRE(ok); + REQUIRE(out[n] == '\0'); + return std::string(reinterpret_cast(out.data()), n); +} + +const std::string kStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +const std::string kUrl = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; } // namespace TEST_CASE("ats_base64 known vectors", "[base64]") { - // RFC 4648 §10. - CHECK(encode("") == ""); - CHECK(encode("f") == "Zg=="); - CHECK(encode("fo") == "Zm8="); - CHECK(encode("foo") == "Zm9v"); - CHECK(encode("foob") == "Zm9vYg=="); - CHECK(encode("fooba") == "Zm9vYmE="); - CHECK(encode("foobar") == "Zm9vYmFy"); - - CHECK(decode_tight("Zm9vYmFy") == "foobar"); + // RFC 4648 §10 test vectors. 
+ CHECK(public_encode("") == ""); + CHECK(public_encode("f") == "Zg=="); + CHECK(public_encode("fo") == "Zm8="); + CHECK(public_encode("foo") == "Zm9v"); + CHECK(public_encode("foob") == "Zm9vYg=="); + CHECK(public_encode("fooba") == "Zm9vYmE="); + CHECK(public_encode("foobar") == "Zm9vYmFy"); + + CHECK(public_decode("Zg==") == "f"); + CHECK(public_decode("Zm8=") == "fo"); + CHECK(public_decode("Zm9vYmFy") == "foobar"); } -TEST_CASE("ats_base64_decode does not read past a non-4-aligned prefix", "[base64]") +TEST_CASE("ats_base64 encode parity and round-trip across sizes", "[base64]") { - // Regression: a decodable prefix whose length is not a multiple of four made - // the old decoder run an extra loop iteration past the prefix (out of bounds - // of the input when the buffer ends there) and read inBuffer[-2]. These - // unpadded inputs exercise prefix lengths 2, 3, 6, 7 (i.e. 2 and 3 mod 4); - // decode_tight runs them in exact-size buffers so ASan would catch any - // over-read. - CHECK(decode_tight("") == ""); - CHECK(decode_tight("Zg") == "f"); // 2 chars -> 1 byte - CHECK(decode_tight("Zm8") == "fo"); // 3 chars -> 2 bytes - CHECK(decode_tight("Zm9v") == "foo"); // 4 chars -> 3 bytes - CHECK(decode_tight("Zm9vYg") == "foob"); // 6 chars -> 4 bytes - CHECK(decode_tight("Zm9vYmE") == "fooba"); // 7 chars -> 5 bytes - - // A lone trailing character does not encode a full byte and is dropped. - CHECK(decode_tight("Zm9vYg==").size() == 4); - CHECK(decode_tight("Q") == ""); - - // Sweep every prefix length (hence every length mod 4) in an exact-size - // buffer; assert the decoded length matches the RFC formula for that many - // alphabet characters and rely on ASan for over-read detection. - for (size_t len = 0; len <= 300; ++len) { - std::string s; - s.reserve(len); - for (size_t k = 0; k < len; ++k) { - s.push_back(kAlpha[(k * 7 + 3) % 64]); - } - const size_t rem = len % 4; - const size_t expected = (len / 4) * 3 + (rem ? 
rem - 1 : 0); - INFO("len=" << len); - CHECK(decode_tight(s).size() == expected); + // Straddle the SIMD thresholds (encode 24, decode 32) and go well past them. + std::vector sizes; + for (size_t n = 0; n <= 96; ++n) { + sizes.push_back(n); + } + for (size_t n : {100u, 127u, 128u, 129u, 200u, 255u, 256u, 511u, 512u, 1000u, 4096u, 4099u}) { + sizes.push_back(n); } -} -TEST_CASE("ats_base64 round-trips across sizes", "[base64]") -{ - std::mt19937 rng(20240601); - std::uniform_int_distribution byte(0, 255); + for (size_t n : sizes) { + const std::string bin = prng_bytes(n, static_cast(n * 2654435761u + 1)); - for (size_t n = 0; n <= 256; ++n) { - std::string bin(n, '\0'); - for (auto &c : bin) { - c = static_cast(byte(rng)); - } + const std::string enc_pub = public_encode(bin); + const std::string enc_ref = oracle_encode(bin); INFO("size=" << n); - CHECK(decode_tight(encode(bin)) == bin); + CHECK(enc_pub == enc_ref); // SIMD encode == scalar encode + CHECK(public_decode(enc_pub) == bin); // round-trips + CHECK(public_decode(enc_pub) == oracle_decode(enc_pub)); // SIMD decode == scalar decode } } -TEST_CASE("ats_base64_decode accepts the URL-safe alphabet", "[base64]") +TEST_CASE("ats_base64 decode parity for standard and URL-safe alphabets", "[base64]") { - // '-' and '_' map to the same six-bit values as '+' and '/', so a URL-safe - // string decodes to the same bytes as its standard-alphabet equivalent. - CHECK(decode_tight("____") == decode_tight("////")); - CHECK(decode_tight("----") == decode_tight("++++")); - - // Round-trip through the URL-safe alphabet: encode (standard), translate - // '+'/'/' to '-'/'_', decode, and expect the original bytes back. 
- std::mt19937 rng(99); - std::uniform_int_distribution byte(0, 255); - for (size_t n = 0; n <= 200; ++n) { - std::string bin(n, '\0'); - for (auto &c : bin) { - c = static_cast(byte(rng)); - } - std::string url = encode(bin); - for (auto &c : url) { - if (c == '+') { - c = '-'; - } else if (c == '/') { - c = '_'; + for (const std::string *alpha : {&kStd, &kUrl}) { + std::mt19937 rng(0xBEEF); + std::uniform_int_distribution pick(0, 63); + for (size_t n = 0; n <= 300; ++n) { + std::string s; + s.reserve(n); + for (size_t k = 0; k < n; ++k) { + s.push_back((*alpha)[pick(rng)]); } + INFO("alphabet=" << (alpha == &kUrl ? "url" : "std") << " len=" << n); + CHECK(public_decode(s) == oracle_decode(s)); } - INFO("size=" << n); - CHECK(decode_tight(url) == bin); } -} -TEST_CASE("ats_base64_decode accepts both alphabets mixed in one input", "[base64]") -{ - // The standard and URL-safe punctuation may appear in the same input. - std::mt19937 rng(1234); - std::uniform_int_distribution byte(0, 255); - std::uniform_int_distribution coin(0, 1); - for (size_t n = 0; n <= 200; ++n) { - std::string bin(n, '\0'); - for (auto &c : bin) { - c = static_cast(byte(rng)); - } - std::string enc = encode(bin); - for (auto &c : enc) { - if (c == '+' && coin(rng)) { - c = '-'; - } else if (c == '/' && coin(rng)) { - c = '_'; - } - } - INFO("size=" << n); - CHECK(decode_tight(enc) == bin); - } + // Both alphabets mixed within one input must decode identically to scalar. + const std::string mixed = "QWxhZGRpbjpvcGVuIHNlc2FtZQ" + "-_+/" + "QUJDREVGabcdef0123456789"; + CHECK(public_decode(mixed) == oracle_decode(mixed)); } -TEST_CASE("ats_base64_decode truncates at the first non-alphabet byte", "[base64]") +TEST_CASE("ats_base64 decode truncates at first non-alphabet byte", "[base64]") { - // A non-alphabet byte (whitespace, '=', or other garbage) ends the decodable - // input; decoding stops there and yields the decode of the prefix before it. 
- // Injecting it at every position also exercises the over-read fix under ASan. - const char *terminators = " \t\n\r=*@"; - std::mt19937 rng(55); - std::uniform_int_distribution pick(0, 63); - for (size_t len : {1u, 2u, 3u, 4u, 5u, 7u, 8u, 16u, 17u, 31u, 33u, 64u, 100u}) { + // A non-alphabet byte (whitespace, '=', or garbage) ends the input; the + // public path must match the scalar oracle exactly, including the SIMD path. + const char *terminators = " \t\n\r=*@"; + + for (size_t len : {1u, 4u, 7u, 16u, 17u, 31u, 32u, 33u, 48u, 63u, 64u, 65u, 96u, 128u, 200u}) { std::string base; for (size_t k = 0; k < len; ++k) { - base.push_back(kAlpha[pick(rng)]); + base.push_back(kStd[(k * 7 + 3) % 64]); } for (size_t pos = 0; pos < len; ++pos) { for (const char *t = terminators; *t; ++t) { std::string s = base; s[pos] = *t; - INFO("len=" << len << " pos=" << pos << " term=" << static_cast(*t)); - CHECK(decode_tight(s) == decode_tight(base.substr(0, pos))); + INFO("len=" << len << " pos=" << pos << " term=" << int(*t)); + CHECK(public_decode(s) == oracle_decode(s)); } } } @@ -220,19 +191,15 @@ TEST_CASE("ats_base64_decode truncates at the first non-alphabet byte", "[base64 TEST_CASE("ats_base64_decode supports in-place (dst == src)", "[base64]") { - std::mt19937 rng(7); - std::uniform_int_distribution byte(0, 255); + for (size_t n : {0u, 1u, 2u, 3u, 10u, 33u, 48u, 64u, 100u, 257u, 1000u}) { + const std::string bin = prng_bytes(n, static_cast(n + 7)); + const std::string enc = oracle_encode(bin); - for (size_t n : {0u, 1u, 2u, 3u, 10u, 33u, 48u, 64u, 100u, 257u}) { - std::string bin(n, '\0'); - for (auto &c : bin) { - c = static_cast(byte(rng)); - } - const std::string enc = encode(bin); - const std::string expect = decode_tight(enc); + // Decode into a separate buffer for the expected result. + const std::string expect = public_decode(enc); - // Large enough to hold both the input copied in and the decoded output. 
- std::vector buf(std::max(enc.size(), ats_base64_decode_dstlen(enc.size())), '\0'); + // Decode in place: output overwrites the input buffer. + std::vector buf(std::max(enc.size(), ats_base64_decode_dstlen(enc.size())) + 1, '\0'); std::copy(enc.begin(), enc.end(), buf.begin()); size_t n_out = 0; bool ok = ats_base64_decode(buf.data(), enc.size(), reinterpret_cast(buf.data()), buf.size(), &n_out); @@ -245,12 +212,16 @@ TEST_CASE("ats_base64_decode supports in-place (dst == src)", "[base64]") TEST_CASE("ats_base64 rejects undersized output buffers", "[base64]") { const std::string bin = "hello world, base64"; - const std::string enc = encode(bin); - size_t n = 0; + const std::string enc = oracle_encode(bin); + size_t n = 0; std::vector small_enc(ats_base64_encode_dstlen(bin.size()) - 1); CHECK_FALSE(ats_base64_encode(bin.data(), bin.size(), small_enc.data(), small_enc.size(), &n)); std::vector small_dec(ats_base64_decode_dstlen(enc.size()) - 1); CHECK_FALSE(ats_base64_decode(enc.data(), enc.size(), small_dec.data(), small_dec.size(), &n)); + + // Exactly the required size must succeed. 
+ std::vector ok_enc(ats_base64_encode_dstlen(bin.size())); + CHECK(ats_base64_encode(bin.data(), bin.size(), ok_enc.data(), ok_enc.size(), &n)); } diff --git a/tests/fuzzing/CMakeLists.txt b/tests/fuzzing/CMakeLists.txt index 3639de02409..2d523d5efe3 100644 --- a/tests/fuzzing/CMakeLists.txt +++ b/tests/fuzzing/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(fuzz_proxy_protocol fuzz_proxy_protocol.cc) add_executable(fuzz_rec_http fuzz_rec_http.cc) add_executable(fuzz_yamlcpp fuzz_yamlcpp.cc) add_executable(fuzz_http3frame fuzz_http3frame.cc) +add_executable(fuzz_base64 fuzz_base64.cc) # Need to rewrite the ESI Parser to remove dependencies on TS API #target_link_libraries(fuzz_esi PRIVATE esi-common esicore tscore tsapi inkevent http) @@ -47,6 +48,10 @@ target_link_libraries(fuzz_yamlcpp PRIVATE yaml-cpp) target_link_options(fuzz_yamlcpp PRIVATE "-fuse-ld=lld") target_link_libraries(fuzz_http3frame PRIVATE ts::tscore ts::quic) target_link_options(fuzz_http3frame PRIVATE "-fuse-ld=lld") +target_link_libraries(fuzz_base64 PRIVATE ts::tscore) +target_link_options(fuzz_base64 PRIVATE "-fuse-ld=lld") +# fuzz_base64 cross-checks the public path against the scalar reference header. +target_include_directories(fuzz_base64 PRIVATE ${CMAKE_SOURCE_DIR}/src/tscore) target_sources( fuzz_hpack PRIVATE ${CMAKE_SOURCE_DIR}/src/proxy/http2/HTTP2.cc ${CMAKE_SOURCE_DIR}/src/proxy/http2/Http2Frame.cc diff --git a/tests/fuzzing/fuzz_base64.cc b/tests/fuzzing/fuzz_base64.cc new file mode 100644 index 00000000000..969fc1694c3 --- /dev/null +++ b/tests/fuzzing/fuzz_base64.cc @@ -0,0 +1,115 @@ +/** @file + + fuzzing tscore base64 (ats_base64_encode / ats_base64_decode) + + Treats the fuzz input as untrusted base64 and decodes it, then round-trips + arbitrary bytes through encode/decode. Every operation is cross-checked + against the scalar reference primitives, so any divergence of the public + (SIMD, when ENABLE_HIGHWAY_DISPATCH is on) path from the scalar path aborts. 
+  Run under AddressSanitizer to catch any out-of-bounds access on the
+  untrusted-input decode path.
+
+  @section license License
+
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+#include  // size_t, used by ink_base64.h
+
+#include
+
+#include "ink_base64_scalar.h" // scalar reference (added to include path by CMake)
+
+#include
+#include
+#include
+#include
+#include
+
+constexpr size_t kMaxInputLength = 65536; // typed constant (was a macro); larger fuzz inputs are skipped
+
+namespace
+{
+[[noreturn]] void
+fail(const char *what)
+{
+  std::fprintf(stderr, "base64 fuzz mismatch: %s\n", what); // `what` names the check that diverged
+  std::abort();                                             // never returns; terminates the fuzz run on this input
+}
+
+// Decode `in` two ways and require identical results (length, bytes, NUL).
+void
+check_decode(const char *in, size_t len)
+{
+  const size_t cap = ats_base64_decode_dstlen(len);
+
+  std::vector out_pub(cap + 1, 0xAB); // one spare byte so out_pub[n_ref] can be inspected below
+  size_t n_pub = 0;
+  if (!ats_base64_decode(in, len, out_pub.data(), out_pub.size(), &n_pub)) {
+    fail("decode returned false with sufficient buffer");
+  }
+
+  const size_t valid = ts::base64::count_alphabet_prefix(in, len);
+  std::vector out_ref(cap + 1, 0xCD); // fill pattern differs from out_pub (0xAB)
+  size_t n_ref = 0;
+  ts::base64::decode_scalar(in, valid, out_ref.data(), &n_ref);
+
+  if (n_pub != n_ref || std::memcmp(out_pub.data(), out_ref.data(), n_ref) != 0 || out_pub[n_ref] != '\0') {
+    fail("decode parity");
+  }
+}
+} // namespace
+
+extern "C" int
+LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
+{
+  if (size > kMaxInputLength) {
+    return 0;
+  }
+
+  // 1. Decode the untrusted input directly (any byte values).
+  check_decode(reinterpret_cast(data), size);
+
+  // 2. Encode the input bytes; cross-check encode, then decode back and
+  //    confirm the original bytes are recovered.
+  const size_t ecap = ats_base64_encode_dstlen(size);
+  std::vector enc_pub(ecap + 1, 1);
+  size_t ne_pub = 0;
+  if (!ats_base64_encode(reinterpret_cast(data), size, enc_pub.data(), enc_pub.size(), &ne_pub)) {
+    fail("encode returned false with sufficient buffer");
+  }
+
+  std::vector enc_ref(ecap + 1, 2);
+  size_t ne_ref = 0;
+  ts::base64::encode_scalar(data, size, enc_ref.data(), &ne_ref);
+  if (ne_pub != ne_ref || std::memcmp(enc_pub.data(), enc_ref.data(), ne_ref + 1) != 0) {
+    fail("encode parity");
+  }
+
+  // Encoded output is pure alphabet (+ padding); decoding it must recover the
+  // original bytes, and the two decode paths must agree.
+  check_decode(enc_pub.data(), ne_pub);
+
+  std::vector back(ats_base64_decode_dstlen(ne_pub) + 1, 0);
+  size_t nb = 0;
+  const bool decoded = ats_base64_decode(enc_pub.data(), ne_pub, back.data(), back.size(), &nb); // status is checked, not dropped
+  if (!decoded || nb != size || std::memcmp(back.data(), data, size) != 0) {
+    fail("encode/decode round-trip");
+  }
+
+  return 0;
+}
diff --git a/tools/benchmark/CMakeLists.txt b/tools/benchmark/CMakeLists.txt
index 49f25fad1c1..ef9d8ae5888 100644
--- a/tools/benchmark/CMakeLists.txt
+++ b/tools/benchmark/CMakeLists.txt
@@ -33,17 +33,20 @@ target_link_libraries(benchmark_ProxyAllocator PRIVATE Catch2::Catch2WithMain ts
 add_executable(benchmark_SharedMutex benchmark_SharedMutex.cc)
 target_link_libraries(benchmark_SharedMutex PRIVATE Catch2::Catch2 ts::tscore libswoc::libswoc)
 
 add_executable(benchmark_Random benchmark_Random.cc)
 target_link_libraries(benchmark_Random PRIVATE Catch2::Catch2WithMain ts::tscore)
 
 add_executable(benchmark_HostDB benchmark_HostDB.cc)
 target_link_libraries(
   benchmark_HostDB
   PRIVATE ts::tscore
           ts::tsutil
           ts::inkevent
           ts::http
           ts::http_remap
           ts::inkcache
           ts::inkhostdb
 )
+
+add_executable(benchmark_ascii_tolower benchmark_ascii_tolower.cc)
+target_link_libraries(benchmark_ascii_tolower PRIVATE Catch2::Catch2WithMain ts::tscore)
diff --git a/tools/benchmark/benchmark_ascii_tolower.cc b/tools/benchmark/benchmark_ascii_tolower.cc
new file mode 100644
index 00000000000..dcc155e4c8c
--- /dev/null
+++ b/tools/benchmark/benchmark_ascii_tolower.cc
@@ -0,0 +1,138 @@
+/** @file
+
+  Micro benchmark for ts::ascii::tolower_copy against a byte-at-a-time
+  scalar loop equivalent to the prior URL.cc::memcpy_tolower definition.
+
+  @section license License
+
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+#define CATCH_CONFIG_ENABLE_BENCHMARKING
+
+#include
+#include
+#include
+
+#include "tscore/ink_ascii_tolower.h"
+#include "tscore/ParseRules.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace
+{
+
+// Sizes chosen to mirror the URL.cc hot path:
+//   4-8B    - common HTTP scheme strings ("http", "https")
+//   16-32B  - typical host names
+//   64-256B - long host names / cache-key segments
+//   1024B   - stress the inner loop
+constexpr std::array kSizes{4, 8, 16, 24, 32, 64, 256, 1024};
+
+// Same character distribution we expect from URL/host input: ASCII letters
+// (mixed case), digits, and the small set of non-alpha bytes that legitimately
+// appear in URLs.
+std::vector
+make_mixed_case_ascii(std::size_t n, std::uint64_t seed = 0xABCDEFULL)
+{
+  std::mt19937_64 rng(seed);
+  std::vector v(n);
+  for (std::size_t i = 0; i < n; ++i) {
+    auto r = static_cast(rng() & 0x3FU); // low six bits: 0..63
+    if (r < 26U) {
+      v[i] = static_cast('A' + r);
+    } else if (r < 52U) {
+      v[i] = static_cast('a' + (r - 26U));
+    } else {
+      static constexpr char kNonAlpha[] = "0123456789-_./:";
+      v[i] = kNonAlpha[r % (sizeof(kNonAlpha) - 1U)]; // -1U excludes the terminating NUL
+    }
+  }
+  return v;
+}
+
+// Mirror of the prior static inline memcpy_tolower() from URL.cc, kept here
+// as the baseline the SIMD path is expected to beat.
+inline void
+tolower_scalar(char *d, const char *s, std::size_t n) noexcept
+{
+  while (n--) { // one byte per iteration until n bytes have been written
+    *d = ParseRules::ink_tolower(*s);
+    ++s;
+    ++d;
+  }
+}
+
+} // namespace
+
+TEST_CASE("active SIMD configuration", "[tolower][config]")
+{
+  // Report which tolower_copy implementation this binary was built with, so
+  // the benchmark numbers can be attributed to the right code path.
+  std::cout << "ts::ascii::tolower_copy implementation: ";
+#if TS_HAS_HIGHWAY_DISPATCH
+  std::cout << "Highway runtime dispatch (selects best available target at startup)";
+#elif defined(__AVX512BW__)
+  std::cout << "compile-time cascade — AVX-512BW (64B body + masked tail, gated at n>=64) + AVX2 + SSE2";
+#elif defined(__AVX2__)
+  std::cout << "compile-time cascade — AVX2 (32B body) + SSE2 (16B drain)";
+#elif defined(__SSE2__)
+  std::cout << "compile-time cascade — SSE2 (16B body)";
+#elif defined(__ARM_NEON) || defined(__aarch64__)
+  std::cout << "compile-time cascade — NEON (16B body)";
+#else
+  std::cout << "compile-time cascade — scalar only";
+#endif
+  std::cout << '\n';
+  SUCCEED(); // report-only case: nothing to assert
+}
+
+TEST_CASE("tolower throughput", "[bench][tolower]")
+{
+  for (std::size_t sz : kSizes) { // one scalar / tolower_copy benchmark pair per size
+    auto input = make_mixed_case_ascii(sz);
+    std::vector output_scalar(sz);
+    std::vector output_simd(sz);
+
+    // Catch::Benchmark::keep_memory clobbers the buffer in the compiler's
+    // model, forcing it to materialize every byte we wrote. Without this an
+    // optimizing compiler can shrink or DCE the inline body's stores past
+    // the first element we observed.
+
+    std::string scalar_name = "scalar " + std::to_string(sz) + "B";
+    BENCHMARK(scalar_name.c_str())
+    {
+      tolower_scalar(output_scalar.data(), input.data(), sz);
+      Catch::Benchmark::keep_memory(output_scalar.data());
+      return output_scalar[0];
+    };
+
+    std::string simd_name = "ts::atc " + std::to_string(sz) + "B";
+    BENCHMARK(simd_name.c_str())
+    {
+      ts::ascii::tolower_copy(output_simd.data(), input.data(), sz);
+      Catch::Benchmark::keep_memory(output_simd.data());
+      return output_simd[0];
+    };
+  }
+}