diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c5da39d1f..a415947e70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `code:get_object_code/1` - Added `erlang:display_string/1` and `erlang:display_string/2` - Added Thumb-2 support to armv6m JIT backend, optimizing code for ARMv7-M and later cores +- Added support for `binary:split/2,3` list patterns and `trim` / `trim_all` options ### Changed - ~10% binary size reduction by rewriting module loading logic diff --git a/libs/estdlib/src/binary.erl b/libs/estdlib/src/binary.erl index 3417e83b2f..55e76af863 100644 --- a/libs/estdlib/src/binary.erl +++ b/libs/estdlib/src/binary.erl @@ -172,13 +172,12 @@ part(_Binary, _Pos, _Len) -> %% @equiv split(Binary, Pattern, []) %% @param Binary binary to split %% @param Pattern pattern to perform the split -%% @return a list composed of one or two binaries +%% @return a list of binaries %% @doc Split a binary according to pattern. %% If pattern is not found, returns a singleton list with the passed binary. -%% Unlike Erlang/OTP, pattern must be a binary. %% @end %%----------------------------------------------------------------------------- --spec split(Binary :: binary(), Pattern :: binary()) -> [binary()]. +-spec split(Binary :: binary(), Pattern :: binary() | [binary(), ...]) -> [binary()]. split(_Binary, _Pattern) -> erlang:nif_error(undefined). @@ -186,14 +185,16 @@ split(_Binary, _Pattern) -> %% @param Binary binary to split %% @param Pattern pattern to perform the split %% @param Options options for the split -%% @return a list composed of one or two binaries +%% @return a list of binaries %% @doc Split a binary according to pattern. %% If pattern is not found, returns a singleton list with the passed binary. -%% Unlike Erlang/OTP, pattern must be a binary. -%% Only implemented option is `global' +%% Pattern can be a binary or a non-empty list of non-empty binaries. +%% Implemented options are `global', `trim', and `trim_all'. %% @end %%----------------------------------------------------------------------------- --spec split(Binary :: binary(), Pattern :: binary(), Option :: [global]) -> [binary()]. +-spec split( + Binary :: binary(), Pattern :: binary() | [binary(), ...], Option :: [global | trim | trim_all] +) -> [binary()]. split(_Binary, _Pattern, _Option) -> erlang:nif_error(undefined). diff --git a/src/libAtomVM/defaultatoms.def b/src/libAtomVM/defaultatoms.def index 62d7722f18..838c7470ed 100644 --- a/src/libAtomVM/defaultatoms.def +++ b/src/libAtomVM/defaultatoms.def @@ -166,6 +166,8 @@ X(CAST_ATOM, "\x5", "$cast") X(UNICODE_ATOM, "\x7", "unicode") X(GLOBAL_ATOM, "\x6", "global") +X(TRIM_ATOM, "\x4", "trim") +X(TRIM_ALL_ATOM, "\x8", "trim_all") X(TYPE_ATOM, "\x4", "type") X(NAME_ATOM, "\x4", "name") X(ARITY_ATOM, "\x5", "arity") diff --git a/src/libAtomVM/nifs.c b/src/libAtomVM/nifs.c index cae5d8d5a4..15c70299c3 100644 --- a/src/libAtomVM/nifs.c +++ b/src/libAtomVM/nifs.c @@ -152,6 +152,7 @@ static NativeHandlerResult process_console_mailbox(Context *ctx); static term make_list_from_utf8_buf(const uint8_t *buf, size_t buf_len, Context *ctx); static term make_list_from_ascii_buf(const uint8_t *buf, size_t len, Context *ctx); +static bool is_valid_pattern(term t); static term nif_binary_at_2(Context *ctx, int argc, term argv[]); static term nif_binary_copy(Context *ctx, int argc, term argv[]); @@ -3500,76 +3501,219 @@ static term nif_binary_part_3(Context *ctx, int argc, term argv[]) return term_maybe_create_sub_binary(pattern_term, slice.pos, slice.len, &ctx->heap, ctx->global); } +typedef struct +{ + const char *data; + size_t size; +} SplitPattern; + +static void init_split_patterns(term pattern_term, SplitPattern *patterns, size_t *shortest_pattern_length) +{ + size_t shortest = SIZE_MAX; + + if (term_is_binary(pattern_term)) { + patterns[0].data = term_binary_data(pattern_term); + patterns[0].size = term_binary_size(pattern_term); + shortest = patterns[0].size; + } else { + size_t index = 0; + term list = pattern_term; + while (term_is_nonempty_list(list)) { + term pattern = term_get_list_head(list); + patterns[index].data = term_binary_data(pattern); + patterns[index].size = term_binary_size(pattern); + if (patterns[index].size < shortest) { + shortest = patterns[index].size; + } + list = term_get_list_tail(list); + index++; + } + } + + *shortest_pattern_length = shortest; +} + +static const char *find_pattern(const char *binary, size_t binary_size, const SplitPattern *patterns, + size_t pattern_count, size_t *matched_pattern_index) +{ + const char *best_match = NULL; + size_t best_match_size = 0; + size_t best_match_index = 0; + + for (size_t i = 0; i < pattern_count; i++) { + size_t pattern_size = patterns[i].size; + if (binary_size < pattern_size) { + continue; + } + + const char *candidate = memmem(binary, binary_size, patterns[i].data, pattern_size); + if (candidate == NULL) { + continue; + } + + if (best_match == NULL || candidate < best_match || (candidate == best_match && pattern_size > best_match_size)) { + best_match = candidate; + best_match_size = pattern_size; + best_match_index = i; + } + } + + if (best_match != NULL) { + *matched_pattern_index = best_match_index; + } + + return best_match; +} + +static inline bool term_is_empty_binary(term t) +{ + return term_is_binary(t) && term_binary_size(t) == 0; +} + +static term trim_split_result(term list, bool trim, bool trim_all) +{ + if (trim_all) { + term trimmed_list = term_nil(); + term prev_kept = term_nil(); + bool has_prev_kept = false; + + term cursor = list; + while (term_is_nonempty_list(cursor)) { + term head = term_get_list_head(cursor); + term next = term_get_list_tail(cursor); + + if (!term_is_empty_binary(head)) { + if (!has_prev_kept) { + trimmed_list = cursor; + } else { + term_get_list_ptr(prev_kept)[LIST_TAIL_INDEX] = cursor; + } + prev_kept = cursor; + has_prev_kept = true; + } + + cursor = next; + } + + if (has_prev_kept) { + term_get_list_ptr(prev_kept)[LIST_TAIL_INDEX] = term_nil(); + } + + return trimmed_list; + } + + if (trim) { + term last_non_empty = term_nil(); + bool has_last_non_empty = false; + + term cursor = list; + while (term_is_nonempty_list(cursor)) { + if (!term_is_empty_binary(term_get_list_head(cursor))) { + last_non_empty = cursor; + has_last_non_empty = true; + } + cursor = term_get_list_tail(cursor); + } + + if (!has_last_non_empty) { + return term_nil(); + } + + term_get_list_ptr(last_non_empty)[LIST_TAIL_INDEX] = term_nil(); + } + + return list; +} + static term nif_binary_split(Context *ctx, int argc, term argv[]) { term bin_term = argv[0]; term pattern_term = argv[1]; VALIDATE_VALUE(bin_term, term_is_binary); - VALIDATE_VALUE(pattern_term, term_is_binary); + VALIDATE_VALUE(pattern_term, is_valid_pattern); bool global = false; + bool trim = false; + bool trim_all = false; if (argc == 3) { term options = argv[2]; if (UNLIKELY(!term_is_list(options))) { RAISE_ERROR(BADARG_ATOM); } - if (term_is_nonempty_list(options)) { + // Match BEAM semantics and ignore an improper tail after a valid + // option prefix, e.g. [global | foo]. + while (term_is_nonempty_list(options)) { term head = term_get_list_head(options); - term tail = term_get_list_tail(options); - if (UNLIKELY(head != GLOBAL_ATOM)) { - RAISE_ERROR(BADARG_ATOM); - } - if (UNLIKELY(!term_is_nil(tail))) { - RAISE_ERROR(BADARG_ATOM); + switch (head) { + case GLOBAL_ATOM: + global = true; + break; + case TRIM_ATOM: + trim = true; + break; + case TRIM_ALL_ATOM: + trim_all = true; + break; + default: + RAISE_ERROR(BADARG_ATOM); } - global = true; + options = term_get_list_tail(options); } } - int bin_size = term_binary_size(bin_term); - int pattern_size = term_binary_size(pattern_term); + size_t pattern_count = 1; + if (term_is_list(pattern_term)) { + int proper = 0; + pattern_count = term_list_length(pattern_term, &proper); + if (UNLIKELY(!proper)) { + RAISE_ERROR(BADARG_ATOM); + } + } - if (UNLIKELY(pattern_size == 0)) { - RAISE_ERROR(BADARG_ATOM); + SplitPattern *patterns = malloc(sizeof(SplitPattern) * pattern_count); + if (IS_NULL_PTR(patterns)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); } + size_t shortest_pattern_length = 0; + init_split_patterns(pattern_term, patterns, &shortest_pattern_length); + + size_t bin_size = term_binary_size(bin_term); const char *bin_data = term_binary_data(bin_term); - const char *pattern_data = term_binary_data(pattern_term); // Count segments first to allocate memory once. size_t num_segments = 1; const char *temp_bin_data = bin_data; - int temp_bin_size = bin_size; + size_t temp_bin_size = bin_size; size_t heap_size = 0; do { - const char *found = (const char *) memmem(temp_bin_data, temp_bin_size, pattern_data, pattern_size); + size_t matched_pattern_index = 0; + const char *found = find_pattern(temp_bin_data, temp_bin_size, patterns, pattern_count, &matched_pattern_index); if (!found) { break; } num_segments++; heap_size += CONS_SIZE + term_sub_binary_heap_size(argv[0], found - temp_bin_data); - int next_search_offset = found - temp_bin_data + pattern_size; + size_t next_search_offset = (found - temp_bin_data) + patterns[matched_pattern_index].size; temp_bin_data += next_search_offset; temp_bin_size -= next_search_offset; - } while (global && temp_bin_size >= pattern_size); + } while (global && temp_bin_size >= shortest_pattern_length); heap_size += CONS_SIZE + term_sub_binary_heap_size(argv[0], temp_bin_size); term result_list = term_nil(); - if (num_segments == 1) { - // not found - if (UNLIKELY(memory_ensure_free_with_roots(ctx, LIST_SIZE(1, 0), 1, argv, MEMORY_CAN_SHRINK) != MEMORY_GC_OK)) { - RAISE_ERROR(OUT_OF_MEMORY_ATOM); - } - - return term_list_prepend(argv[0], result_list, &ctx->heap); + size_t needed_heap_size = num_segments == 1 ? LIST_SIZE(1, 0) : heap_size; + if (UNLIKELY(memory_ensure_free_with_roots(ctx, needed_heap_size, 2, argv, MEMORY_CAN_SHRINK) != MEMORY_GC_OK)) { + free(patterns); + RAISE_ERROR(OUT_OF_MEMORY_ATOM); } - // binary:split/2,3 always return sub binaries, except when copied binaries are as small as sub-binaries. - if (UNLIKELY(memory_ensure_free_with_roots(ctx, heap_size, 2, argv, MEMORY_CAN_SHRINK) != MEMORY_GC_OK)) { - RAISE_ERROR(OUT_OF_MEMORY_ATOM); + if (num_segments == 1) { + result_list = term_list_prepend(argv[0], result_list, &ctx->heap); + free(patterns); + return trim_split_result(result_list, trim, trim_all); } // Allocate list first @@ -3577,16 +3721,17 @@ static term nif_binary_split(Context *ctx, int argc, term argv[]) result_list = term_list_prepend(term_nil(), result_list, &ctx->heap); } - // Reset pointers after allocation + // Reset pointers after allocation / possible GC. bin_data = term_binary_data(argv[0]); - pattern_data = term_binary_data(argv[1]); + init_split_patterns(argv[1], patterns, &shortest_pattern_length); term list_cursor = result_list; temp_bin_data = bin_data; temp_bin_size = bin_size; term *list_ptr = term_get_list_ptr(list_cursor); do { - const char *found = (const char *) memmem(temp_bin_data, temp_bin_size, pattern_data, pattern_size); + size_t matched_pattern_index = 0; + const char *found = find_pattern(temp_bin_data, temp_bin_size, patterns, pattern_count, &matched_pattern_index); if (found) { term tok = term_maybe_create_sub_binary(argv[0], temp_bin_data - bin_data, found - temp_bin_data, &ctx->heap, ctx->global); @@ -3595,7 +3740,7 @@ static term nif_binary_split(Context *ctx, int argc, term argv[]) list_cursor = list_ptr[LIST_TAIL_INDEX]; list_ptr = term_get_list_ptr(list_cursor); - int next_search_offset = found - temp_bin_data + pattern_size; + size_t next_search_offset = (found - temp_bin_data) + patterns[matched_pattern_index].size; temp_bin_data += next_search_offset; temp_bin_size -= next_search_offset; } @@ -3607,7 +3752,9 @@ static term nif_binary_split(Context *ctx, int argc, term argv[]) } } while (!term_is_nil(list_cursor)); - return result_list; + free(patterns); + + return trim_split_result(result_list, trim, trim_all); } static term nif_binary_replace(Context *ctx, int argc, term argv[]) diff --git a/tests/erlang_tests/test_binary_split.erl b/tests/erlang_tests/test_binary_split.erl index 8b907a7e21..54d5684f99 100644 --- a/tests/erlang_tests/test_binary_split.erl +++ b/tests/erlang_tests/test_binary_split.erl @@ -31,7 +31,17 @@ start() -> ok = split_compare(<<"Test">>, <<>>), ok = split_compare2(<<"Test">>, <<>>), ok = split_compare2(<<"helloSEPARATORworld">>, <<"hello">>, <<"world">>), + ok = split_compare_expected([<<"f">>, <<"bar">>], <<"foobar">>, [<<"oo">>, <<"o">>]), + ok = split_compare_expected([<<>>, <<>>, <<>>], <<"aba">>, [<<"a">>, <<"ab">>], [global]), + ok = split_compare_expected([<<>>, <<"a">>], <<":a::">>, <<":">>, [global, trim]), + ok = split_compare_expected([<<"a">>], <<":a::">>, <<":">>, [global, trim_all]), + ok = split_compare_expected([], <<>>, <<":">>, [trim]), + ok = split_compare_expected([<<"abc">>], <<"abc">>, [<<"z">>, <<"y">>], [trim_all]), + ok = split_compare_expected([<<"a">>], <<"a">>, <<":">>, [global | foo]), ok = fail_split(<<>>), + ok = fail_split([]), + ok = fail_split([<<>>]), + ok = fail_split([foo]), ok = fail_split({1, 2}), ok = fail_split2({1, 2}), case erlang:system_info(machine) of @@ -58,6 +68,18 @@ split_compare2(Bin, Part1, Part2) -> ok = compare_bin(Part1, A), ok = compare_bin(B, Part2). +split_compare_expected(Expected, Bin, Pattern) -> + split_compare_expected(Expected, Bin, Pattern, []). + +split_compare_expected(Expected, Bin, Pattern, Options) -> + compare_bin_list(Expected, binary:split(Bin, Pattern, Options)). + +compare_bin_list([], []) -> + ok; +compare_bin_list([ExpectedHead | ExpectedTail], [ActualHead | ActualTail]) -> + ok = compare_bin(ExpectedHead, ActualHead), + compare_bin_list(ExpectedTail, ActualTail). + compare_bin(Bin1, Bin2) -> compare_bin(Bin1, Bin2, byte_size(Bin1) - 1). diff --git a/tests/libs/estdlib/test_binary.erl b/tests/libs/estdlib/test_binary.erl index 65bd0f0e6f..bc7039624e 100644 --- a/tests/libs/estdlib/test_binary.erl +++ b/tests/libs/estdlib/test_binary.erl @@ -35,6 +35,20 @@ test_split() -> ?ASSERT_MATCH(binary:split(<<"foobar">>, <<"ooz">>), [<<"foobar">>]), ?ASSERT_MATCH(binary:split(<<"foobar">>, <<"o">>), [<<"f">>, <<"obar">>]), ?ASSERT_MATCH(binary:split(<<"foobar">>, <<"o">>, [global]), [<<"f">>, <<>>, <<"bar">>]), + ?ASSERT_MATCH(binary:split(<<"foobar">>, [<<"oo">>, <<"o">>]), [<<"f">>, <<"bar">>]), + ?ASSERT_MATCH(binary:split(<<"aba">>, [<<"a">>, <<"ab">>], [global]), [<<>>, <<>>, <<>>]), + ?ASSERT_MATCH(binary:split(<<":a::">>, <<":">>, [global, trim]), [<<>>, <<"a">>]), + ?ASSERT_MATCH(binary:split(<<":a::">>, <<":">>, [global, trim_all]), [<<"a">>]), + ?ASSERT_MATCH(binary:split(<<>>, <<":">>, [trim]), []), + ?ASSERT_EXCEPTION(binary:split(<<"a">>, []), error, badarg), + ?ASSERT_EXCEPTION(binary:split(<<"a">>, [<<>>]), error, badarg), + ?ASSERT_EXCEPTION(binary:split(<<"a">>, [foo]), error, badarg), + ?ASSERT_EXCEPTION(binary:split(<<"a">>, <<":">>, [foo]), error, badarg), + ?ASSERT_MATCH(binary:split(<<"a">>, <<":">>, [global | foo]), [<<"a">>]), + ?ASSERT_MATCH(binary:split(<<"aba">>, [<<"a">>, <<"ab">>]), [<<>>, <<"a">>]), + ?ASSERT_MATCH(binary:split(<<"aba">>, [<<"ab">>, <<"a">>]), [<<>>, <<"a">>]), + ?ASSERT_MATCH(binary:split(<<":">>, <<":">>, [trim]), []), + ?ASSERT_MATCH(binary:split(<<":">>, <<":">>, [trim_all]), []), ok. test_list_to_bin() ->