diff --git a/.bumpversion.toml b/.bumpversion.toml index 0b6016416b1..56d6c8ef5b5 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "7.1.0-beta.1" +current_version = "7.1.0-beta.4" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)(-(?P(beta|rc))\\.(?P\\d+))?" serialize = [ "{major}.{minor}.{patch}-{prerelease}.{prerelease_num}", diff --git a/Cargo.lock b/Cargo.lock index 1b902bd3e4d..e20b54534a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -506,9 +506,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "aws-config" @@ -531,7 +531,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.4.0", + "http 1.4.1", "ring", "time", "tokio", @@ -594,7 +594,7 @@ dependencies = [ "bytes-utils", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "percent-encoding", @@ -622,7 +622,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -652,7 +652,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "lru", "percent-encoding", @@ -681,7 +681,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -705,7 +705,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -730,7 +730,7 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -751,7 +751,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "percent-encoding", "sha2", "time", @@ -780,7 +780,7 @@ dependencies = [ "bytes", "crc-fast", "hex", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "md-5", @@ -814,7 +814,7 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "percent-encoding", @@ -833,7 +833,7 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-rustls", "hyper-util", @@ -890,7 +890,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -910,7 +910,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "pin-project-lite", "tokio", "tracing", @@ -928,7 +928,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -976,7 +976,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -1009,7 +1009,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1220,9 +1220,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytecheck" @@ -3093,7 +3093,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3358,9 +3358,9 @@ dependencies = [ [[package]] name = "geohash" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb94b1a65401d6cbf22958a9040aa364812c26674f841bee538b12c135db1e6" +checksum = "7f58890382f70caccc5fa388981f7ac80c913795042afce9f3e065695d8f7464" dependencies = [ "geo-types", "libm", @@ -3464,7 +3464,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.4.0", + "http 1.4.1", "indexmap 2.14.0", "slab", "tokio", @@ -3573,7 +3573,7 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" dependencies = [ "dirs", "futures", - "http 1.4.0", + "http 1.4.1", "indicatif", "libc", "log", @@ -3597,7 +3597,7 @@ checksum = "430b33fa84f92796d4d263070b6c0d3ca219df7b9a0e1853ee431029b1612bcd" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "more-asserts", "serde", "thiserror 2.0.18", @@ -3633,9 +3633,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -3659,7 +3659,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", ] [[package]] @@ -3670,7 +3670,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "pin-project-lite", ] @@ -3704,7 +3704,7 @@ dependencies = [ "futures-channel", "futures-core", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3721,7 +3721,7 @@ version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-util", "rustls", @@ -3758,7 +3758,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "hyper", "ipnet", @@ -3863,6 +3863,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_locale" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5a396343c7208121dc86e35623d3dfe19814a7613cfd14964994cdc9c9a2e26" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_locale_data", + "icu_provider", + "potential_utf", + "tinystr", + "zerovec", +] + [[package]] name = "icu_locale_core" version = "2.2.0" @@ -3871,11 +3886,18 @@ checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", + "serde", "tinystr", "writeable", "zerovec", ] +[[package]] +name = "icu_locale_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fdcc9ac77c6d74ff5cf6e65ef3181d6af32003b16fce3a77fb451d2f695993" + [[package]] name = "icu_normalizer" version = "2.2.0" @@ -3924,6 +3946,8 @@ checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", + "serde", + "stable_deref_trait", "writeable", "yoke", "zerofrom", @@ -3931,6 +3955,27 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_segmenter" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0794db0b1a86193ac9c48768d0e6c52c54448e0870ad87907d456ee0dac964" +dependencies = [ + "icu_collections", + "icu_locale", + "icu_provider", + "icu_segmenter_data", + "potential_utf", + "utf8_iter", + "zerovec", +] + +[[package]] +name = "icu_segmenter_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a2c462a4d927d512f5f882a033ddd62f33a05bb9f230d98f736ac3dc85938f" + [[package]] name = "id-arena" version = "2.3.0" @@ -4121,18 +4166,18 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jieba-macros" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a29cfc5dcd898604c6f80363411fa6b6b08e27d1d253d6225b9cb6702ea02fc0" +checksum = "661344b2412fb00aee1841d2405c9a31f7c91cf6e578a8e953647c43dd1a8b0a" dependencies = [ "phf_codegen", ] [[package]] name = "jieba-rs" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3245d6e9d1d5facbd6a23848d6b67e3439738ccbb4fa5a3d65da315ba1a910a2" +checksum = "d7ef90d6209fcff084a01b488c4199d882e3764b15ff0e7a6b5d7efaa46e1e4f" dependencies = [ "cedarwood", "jieba-macros", @@ -4143,9 +4188,9 @@ dependencies = [ [[package]] name = "jiff" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +checksum = "392c70591e8749fe235ddaf513e6f58b26bce3dcc16524cecc8936f75afa161e" dependencies = [ "jiff-static", "jiff-tzdb-platform", @@ -4155,14 +4200,14 @@ dependencies = [ "portable-atomic-util", "serde_core", "wasm-bindgen", - "windows-sys 0.61.2", + "windows-link", ] [[package]] name = "jiff-static" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +checksum = "47b605b0c050d845fc355bb11eb3f9a8deddc218ea60c76e61aa1f2adfb2c96a" dependencies = [ "proc-macro2", "quote", @@ -4245,9 +4290,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if 1.0.4", "futures-util", @@ -4321,7 +4366,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "all_asserts", "approx", @@ -4423,7 +4468,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4471,7 +4516,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrayref", "paste", @@ -4480,7 +4525,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4517,7 +4562,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4550,7 +4595,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4570,7 +4615,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -4615,7 +4660,7 @@ dependencies = [ [[package]] name = "lance-examples" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "all_asserts", "arrow", @@ -4641,7 +4686,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -4681,7 +4726,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "datafusion", "geo-traits", @@ -4695,7 +4740,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "approx", "arc-swap", @@ -4732,6 +4777,7 @@ dependencies = [ "jieba-rs", "jsonb", "lance-arrow", + "lance-arrow-stats", "lance-core", "lance-datafusion", "lance-datagen", @@ -4773,7 +4819,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-arith", @@ -4793,7 +4839,7 @@ dependencies = [ "criterion", "deepsize", "futures", - "http 1.4.0", + "http 1.4.1", "io-uring", "lance-arrow", "lance-core", @@ -4822,7 +4868,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "approx", "arrow-array", @@ -4843,7 +4889,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "async-trait", @@ -4855,7 +4901,7 @@ dependencies = [ [[package]] name = "lance-namespace-datafusion" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-schema", @@ -4871,7 +4917,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-ipc", @@ -4927,7 +4973,7 @@ dependencies = [ [[package]] name = "lance-select" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4944,7 +4990,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4991,7 +5037,7 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "proc-macro2", "quote", @@ -5000,7 +5046,7 @@ dependencies = [ [[package]] name = "lance-testing" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-schema", @@ -5011,8 +5057,9 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ + "icu_segmenter", "jieba-rs", "lindera", "rust-stemmers", @@ -5022,7 +5069,7 @@ dependencies = [ [[package]] name = "lance-tools" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "clap", "lance-core", @@ -5203,9 +5250,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "loom" @@ -5751,7 +5798,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body-util", "httparse", "humantime", @@ -5871,7 +5918,7 @@ dependencies = [ "base64 0.22.1", "bytes", "futures", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "jiff", "log", @@ -5880,7 +5927,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.38.4", "reqsign-core", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "tokio", @@ -5896,7 +5943,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "048b1b29c503263bdd80a9afe46a68cd02ea9bd361185b1feab4b151078998e9" dependencies = [ "futures", - "http 1.4.0", + "http 1.4.1", "mea", "opendal-core", ] @@ -5940,7 +5987,7 @@ checksum = "7452bf3ec61cfd81ac9ad9ada17825931e9e371d44a045c6bfab9596c0a2ac3b" dependencies = [ "base64 0.22.1", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -5960,7 +6007,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f9884c2d8cf8ba2bb077d79c877dac5863ba3bab9e2c9c1e41a2e0491404772" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -5978,7 +6025,7 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffb0e45d6c8dcf66ce2da20e241bcb80e6e540e109a4ff20f318f6c9b4c54e0c" dependencies = [ - "http 1.4.0", + "http 1.4.1", "opendal-core", ] @@ -5989,7 +6036,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55a0765ba451b6effdbf514b7b50060530ff8a29e4231c4a3ab7792c016408e6" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -6007,7 +6054,7 @@ checksum = "70a49477a10163431896d106136117f5670717f9c9e49cf6f710528800c6633a" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", @@ -6028,11 +6075,11 @@ checksum = "7b2ab7a2a8a11dfe257ef4db5c0de798acbcd0d6429c37382dad2154bc06a388" dependencies = [ "bytes", "hf-xet", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", ] @@ -6044,7 +6091,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29c8a917829ad06d21b639558532cb0101fe49b040d946d673a73018683fac05" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -6063,7 +6110,7 @@ dependencies = [ "base64 0.22.1", "bytes", "crc32c", - "http 1.4.0", + "http 1.4.1", "log", "md-5", "opendal-core", @@ -6522,6 +6569,8 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ + "serde_core", + "writeable", "zerovec", ] @@ -7130,7 +7179,7 @@ checksum = "57ac2757f3140aa2e213b554148ae0b52733e624fc6723f0cc6bb3d440176c95" dependencies = [ "anyhow", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -7148,7 +7197,7 @@ dependencies = [ "anyhow", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "quick-xml 0.39.4", @@ -7170,7 +7219,7 @@ dependencies = [ "base64 0.22.1", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "pem", @@ -7195,7 +7244,7 @@ dependencies = [ "futures", "hex", "hmac", - "http 1.4.0", + "http 1.4.1", "jiff", "log", "percent-encoding", @@ -7222,7 +7271,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35cc609b49c69e76ecaceb775a03f792d1ed3e7755ab3548d4534fd801e3242e" dependencies = [ "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "percent-encoding", @@ -7242,7 +7291,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e128f19525861dbded59e1e7c17653a8ed63d573ca04aed708d552dbef5bb32a" dependencies = [ "anyhow", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -7262,7 +7311,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -7301,15 +7350,15 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -7347,8 +7396,8 @@ checksum = "07bc3f1384cffa4f274dad2d4ddd73aed32fed8f786d96c6be8aa4e5fd3c3b58" dependencies = [ "anyhow", "async-trait", - "http 1.4.0", - "reqwest 0.13.3", + "http 1.4.1", + "reqwest 0.13.4", "thiserror 2.0.18", "tower-service", ] @@ -7840,9 +7889,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -8640,6 +8689,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", + "serde_core", "zerovec", ] @@ -8839,7 +8889,7 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags 2.11.1", "bytes", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -8859,7 +8909,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -9324,9 +9374,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if 1.0.4", "once_cell", @@ -9337,9 +9387,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -9347,9 +9397,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -9357,9 +9407,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -9370,9 +9420,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] @@ -9439,9 +9489,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" dependencies = [ "js-sys", "wasm-bindgen", @@ -9833,7 +9883,7 @@ dependencies = [ "base64 0.22.1", "deadpool", "futures", - "http 1.4.0", + "http 1.4.1", "http-body-util", "hyper", "hyper-util", @@ -9993,13 +10043,13 @@ dependencies = [ "clap", "crc32fast", "futures", - "http 1.4.0", + "http 1.4.1", "hyper", "lazy_static", "more-asserts", "rand 0.10.1", "redb", - "reqwest 0.13.3", + "reqwest 0.13.4", "reqwest-middleware", "serde", "serde_json", @@ -10067,7 +10117,7 @@ dependencies = [ "chrono", "clap", "gearhash", - "http 1.4.0", + "http 1.4.1", "itertools 0.14.0", "lazy_static", "more-asserts", @@ -10112,7 +10162,7 @@ dependencies = [ "oneshot", "pin-project", "rand 0.10.1", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "shellexpand", @@ -10238,6 +10288,7 @@ dependencies = [ "displaydoc", "yoke", "zerofrom", + "zerovec", ] [[package]] @@ -10246,6 +10297,7 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ + "serde", "yoke", "zerofrom", "zerovec-derive", diff --git a/Cargo.toml b/Cargo.toml index 7bd2192e770..871c0320bea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ resolver = "3" [workspace.package] -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" @@ -56,30 +56,31 @@ rust-version = "1.91.0" [workspace.dependencies] arc-swap = "1.7" libc = "0.2.176" -lance = { version = "=7.1.0-beta.1", path = "./rust/lance", default-features = false } -lance-arrow = { version = "=7.1.0-beta.1", path = "./rust/lance-arrow" } -lance-core = { version = "=7.1.0-beta.1", path = "./rust/lance-core" } -lance-datafusion = { version = "=7.1.0-beta.1", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=7.1.0-beta.1", path = "./rust/lance-datagen" } -lance-encoding = { version = "=7.1.0-beta.1", path = "./rust/lance-encoding" } -lance-file = { version = "=7.1.0-beta.1", path = "./rust/lance-file" } -lance-geo = { version = "=7.1.0-beta.1", path = "./rust/lance-geo" } -lance-index = { version = "=7.1.0-beta.1", path = "./rust/lance-index" } -lance-io = { version = "=7.1.0-beta.1", path = "./rust/lance-io", default-features = false } -lance-linalg = { version = "=7.1.0-beta.1", path = "./rust/lance-linalg" } -lance-namespace = { version = "=7.1.0-beta.1", path = "./rust/lance-namespace" } -lance-namespace-impls = { version = "=7.1.0-beta.1", path = "./rust/lance-namespace-impls" } +lance = { version = "=7.1.0-beta.4", path = "./rust/lance", default-features = false } +lance-arrow = { version = "=7.1.0-beta.4", path = "./rust/lance-arrow" } +lance-core = { version = "=7.1.0-beta.4", path = "./rust/lance-core" } +lance-datafusion = { version = "=7.1.0-beta.4", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=7.1.0-beta.4", path = "./rust/lance-datagen" } +lance-encoding = { version = "=7.1.0-beta.4", path = "./rust/lance-encoding" } +lance-file = { version = "=7.1.0-beta.4", path = "./rust/lance-file" } +lance-geo = { version = "=7.1.0-beta.4", path = "./rust/lance-geo" } +lance-index = { version = "=7.1.0-beta.4", path = "./rust/lance-index" } +lance-io = { version = "=7.1.0-beta.4", path = "./rust/lance-io", default-features = false } +lance-linalg = { version = "=7.1.0-beta.4", path = "./rust/lance-linalg" } +lance-namespace = { version = "=7.1.0-beta.4", path = "./rust/lance-namespace" } +lance-namespace-impls = { version = "=7.1.0-beta.4", path = "./rust/lance-namespace-impls" } lance-namespace-datafusion = { version = "=7.0.0-beta.9", path = "./rust/lance-namespace-datafusion" } lance-namespace-reqwest-client = "0.7.7" -lance-select = { version = "=7.1.0-beta.1", path = "./rust/lance-select" } -lance-tokenizer = { version = "=7.1.0-beta.1", path = "./rust/lance-tokenizer" } -lance-table = { version = "=7.1.0-beta.1", path = "./rust/lance-table" } -lance-test-macros = { version = "=7.1.0-beta.1", path = "./rust/lance-test-macros" } -lance-testing = { version = "=7.1.0-beta.1", path = "./rust/lance-testing" } +lance-select = { version = "=7.1.0-beta.4", path = "./rust/lance-select" } +lance-tokenizer = { version = "=7.1.0-beta.4", path = "./rust/lance-tokenizer" } +lance-table = { version = "=7.1.0-beta.4", path = "./rust/lance-table" } +lance-test-macros = { version = "=7.1.0-beta.4", path = "./rust/lance-test-macros" } +lance-testing = { version = "=7.1.0-beta.4", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "58.0.0", optional = false, features = ["prettyprint"] } lance-arrow-scalar = { version = "=58.0.0", path = "./rust/arrow-scalar" } +lance-arrow-stats = { version = "=58.0.0", path = "./rust/arrow-stats" } arrow-arith = "58.0.0" arrow-array = "58.0.0" arrow-buffer = "58.0.0" @@ -101,7 +102,7 @@ half = { "version" = "2.1", default-features = false, features = [ "num-traits", "std", ] } -lance-bitpacking = { version = "=7.1.0-beta.1", path = "./rust/compression/bitpacking" } +lance-bitpacking = { version = "=7.1.0-beta.4", path = "./rust/compression/bitpacking" } bitpacking = "0.9" bitvec = "1" bytes = "1.11.1" @@ -141,7 +142,7 @@ deepsize = "0.2.0" dirs = "6.0.0" either = "1.0" fst = { version = "0.4.7", features = ["levenshtein"] } -fsst = { version = "=7.1.0-beta.1", path = "./rust/compression/fsst" } +fsst = { version = "=7.1.0-beta.4", path = "./rust/compression/fsst" } futures = "0.3" geoarrow-array = "0.8" geoarrow-schema = "0.8" @@ -151,9 +152,10 @@ geo-types = "0.7.16" http = "1.1.0" humantime = "2.2.0" hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } +icu_segmenter = { version = "2.2", default-features = false, features = ["compiled_data"] } io-uring = "0.7" itertools = "0.13" -jieba-rs = { version = "0.9.0", default-features = false } +jieba-rs = { version = "0.10.0", default-features = false } jsonb = { version = "0.5.3", default-features = false, features = ["databend"] } libm = "0.2.15" log = "0.4" diff --git a/RUST_THIRD_PARTY_LICENSES.html b/RUST_THIRD_PARTY_LICENSES.html index 1ec36159b0f..e7db17f2115 100644 --- a/RUST_THIRD_PARTY_LICENSES.html +++ b/RUST_THIRD_PARTY_LICENSES.html @@ -18370,12 +18370,16 @@

Unicode License v3

Used by:

  • icu_collections 2.2.0
  • +
  • icu_locale 2.2.0
  • icu_locale_core 2.2.0
  • +
  • icu_locale_data 2.2.0
  • icu_normalizer 2.2.0
  • icu_normalizer_data 2.2.0
  • icu_properties 2.2.0
  • icu_properties_data 2.2.0
  • icu_provider 2.2.0
  • +
  • icu_segmenter 2.2.0
  • +
  • icu_segmenter_data 2.2.0
  • litemap 0.8.2
  • potential_utf 0.1.5
  • tinystr 0.8.3
  • diff --git a/docs/src/format/index/scalar/fts.md b/docs/src/format/index/scalar/fts.md index 33c4a5ed0da..792702aca08 100644 --- a/docs/src/format/index/scalar/fts.md +++ b/docs/src/format/index/scalar/fts.md @@ -80,9 +80,21 @@ The full text search index supports multiple tokenizer types for different text | **whitespace** | Splits only on whitespace characters | Preserve punctuation | | **raw** | No tokenization, treats entire text as single token | Exact matching | | **ngram** | Breaks text into overlapping character sequences | Substring/fuzzy search | +| **icu** | ICU dictionary-based Unicode word segmentation | Mixed-language text | | **jieba/*** | Chinese text tokenizer with word segmentation | Chinese text | | **lindera/*** | Japanese text tokenizer with morphological analysis | Japanese text | +#### ICU Tokenizer (Mixed-language text) + +The ICU tokenizer uses Unicode word boundary rules and dictionary-based segmentation for complex scripts. It is useful for mixed-language text where the default `simple` tokenizer would keep an unspaced CJK span as one large token. + +- **Models**: Uses compiled ICU4X segmenter data bundled with Lance +- **Usage**: Specify as `icu` +- **Features**: + - Unicode-aware word boundary detection + - Dictionary-based segmentation for Chinese, Japanese, Khmer, Lao, Myanmar, and Thai + - No external language model download required + #### Jieba Tokenizer (Chinese) Jieba is a popular Chinese text segmentation library that uses a dictionary-based approach with statistical methods for word segmentation. @@ -257,4 +269,4 @@ Here are the query types enabled by the FTS index: | **phrase** | Exact phrase matching with position information (requires `with_position: true`) | `{"phrase": {"query": "exact phrase"}}` | AtMost | | **boolean** | Complex boolean queries with must/should/must_not clauses for sophisticated search logic | `{"boolean": {"must": [...], "should": [...]}}` | AtMost | | **multi_match** | Search across multiple fields simultaneously with unified scoring | `{"multi_match": [{"field1": "query"}, ...]}` | AtMost | -| **boost** | Boost relevance scores for specific terms or queries by a configurable factor | `{"boost": {"query": {...}, "factor": 2.0}}` | AtMost | \ No newline at end of file +| **boost** | Boost relevance scores for specific terms or queries by a configurable factor | `{"boost": {"query": {...}, "factor": 2.0}}` | AtMost | diff --git a/docs/src/guide/tokenizer.md b/docs/src/guide/tokenizer.md index 096d972b7cf..192574656f5 100644 --- a/docs/src/guide/tokenizer.md +++ b/docs/src/guide/tokenizer.md @@ -1,6 +1,6 @@ # Tokenizers -Currently, Lance has built-in support for Jieba and Lindera. However, it doesn't come with its own language models. +Currently, Lance has built-in support for ICU, Jieba, and Lindera. ICU uses built-in segmenter data. Jieba and Lindera require external language models. If tokenization is needed, you can download language models by yourself. You can specify the location where the language models are stored by setting the environment variable LANCE_LANGUAGE_MODEL_HOME. If it's not set, the default value is @@ -12,6 +12,14 @@ ${system data directory}/lance/language_models It also supports configuring user dictionaries, which makes it convenient for users to expand their own dictionaries without retraining the language models. +## ICU Tokenizer + +ICU uses Unicode word boundary rules and bundled dictionary data for complex scripts. It is useful for mixed-language text and does not require downloading a language model. + +```python +ds.create_scalar_index("text", "INVERTED", base_tokenizer="icu") +``` + ## Language Models of Jieba ### Downloading the Model diff --git a/docs/src/quickstart/full-text-search.md b/docs/src/quickstart/full-text-search.md index 1f965d64a2c..17327e40bc5 100644 --- a/docs/src/quickstart/full-text-search.md +++ b/docs/src/quickstart/full-text-search.md @@ -90,7 +90,7 @@ ds.create_scalar_index( index_type="INVERTED", name="text_idx", # Optional index name (if omitted, default is "text_idx") with_position=False, # Set True to enable phrase queries (stores token positions) - base_tokenizer="simple", # Tokenizer: "simple" (whitespace+punct), "whitespace", or "raw" (no tokenization) + base_tokenizer="simple", # Tokenizer: "simple" (whitespace+punct), "icu", "whitespace", or "raw" (no tokenization) language="English", # Language used for stemming + stop words (only used if `stem` or `remove_stop_words` is True) max_token_length=40, # Drop tokens longer than this length lower_case=True, # Lowercase text before tokenization @@ -109,6 +109,7 @@ ds.create_scalar_index( Lance also supports multilingual tokenization: +- **icu**: Unicode word segmentation with built-in ICU dictionaries - **jieba/default**: Chinese text tokenization using Jieba - **lindera/ipadic**: Japanese text tokenization using Lindera with IPAdic dictionary - **lindera/ko-dic**: Korean text tokenization using Lindera with Ko-dic dictionary diff --git a/docs/src/quickstart/index.md b/docs/src/quickstart/index.md index 00daa6a4ee4..606948263c4 100644 --- a/docs/src/quickstart/index.md +++ b/docs/src/quickstart/index.md @@ -19,11 +19,24 @@ pip install pylance For the latest features and bug fixes, you can install the preview version: -```bash -pip install --pre --extra-index-url https://pypi.fury.io/lance-format/pylance -``` +=== "pip" + + ```bash + pip install --pre --extra-index-url https://pypi.fury.io/lance-format/ pylance + ``` + +=== "uv" + + ```bash + uv venv + uv pip install --prerelease allow --index https://pypi.fury.io/lance-format/ pylance + + # To add to pyproject.toml, just do: + uv add --prerelease allow --index https://pypi.fury.io/lance-format/ pylance + ``` -> Note: Preview releases receive the same level of testing as regular releases. +!!! note + Preview releases receive the same level of testing as regular releases. ## Set Up Your Environment diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index 264ea4294b7..ab1e8efc003 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -466,9 +466,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "aws-config" @@ -491,7 +491,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.4.0", + "http 1.4.1", "ring", "time", "tokio", @@ -552,7 +552,7 @@ dependencies = [ "bytes", "bytes-utils", "fastrand", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "percent-encoding", "pin-project-lite", @@ -579,7 +579,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -603,7 +603,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -628,7 +628,7 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -648,7 +648,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "percent-encoding", "sha2", "time", @@ -678,7 +678,7 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "percent-encoding", @@ -697,7 +697,7 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-rustls", "hyper-util", @@ -754,7 +754,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -774,7 +774,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "pin-project-lite", "tokio", "tracing", @@ -791,7 +791,7 @@ dependencies = [ "bytes", "bytes-utils", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -837,7 +837,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -870,7 +870,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1039,9 +1039,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytemuck" @@ -2509,7 +2509,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "rand 0.9.4", @@ -2766,9 +2766,9 @@ dependencies = [ [[package]] name = "geohash" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb94b1a65401d6cbf22958a9040aa364812c26674f841bee538b12c135db1e6" +checksum = "7f58890382f70caccc5fa388981f7ac80c913795042afce9f3e065695d8f7464" dependencies = [ "geo-types", "libm", @@ -2866,7 +2866,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.4.0", + "http 1.4.1", "indexmap 2.14.0", "slab", "tokio", @@ -2975,7 +2975,7 @@ checksum = "430b33fa84f92796d4d263070b6c0d3ca219df7b9a0e1853ee431029b1612bcd" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "more-asserts", "serde", "thiserror 2.0.18", @@ -3011,9 +3011,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -3037,7 +3037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", ] [[package]] @@ -3048,7 +3048,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "pin-project-lite", ] @@ -3082,7 +3082,7 @@ dependencies = [ "futures-channel", "futures-core", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3099,7 +3099,7 @@ version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-util", "rustls", @@ -3119,7 +3119,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "hyper", "ipnet", @@ -3422,9 +3422,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +checksum = "392c70591e8749fe235ddaf513e6f58b26bce3dcc16524cecc8936f75afa161e" dependencies = [ "jiff-static", "jiff-tzdb-platform", @@ -3434,14 +3434,14 @@ dependencies = [ "portable-atomic-util", "serde_core", "wasm-bindgen", - "windows-sys 0.61.2", + "windows-link", ] [[package]] name = "jiff-static" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +checksum = "47b605b0c050d845fc355bb11eb3f9a8deddc218ea60c76e61aa1f2adfb2c96a" dependencies = [ "proc-macro2", "quote", @@ -3549,9 +3549,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if 1.0.4", "futures-util", @@ -3616,7 +3616,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arc-swap", "arrow", @@ -3688,7 +3688,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -3708,7 +3708,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrayref", "paste", @@ -3717,7 +3717,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -3752,7 +3752,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -3784,7 +3784,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -3802,7 +3802,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -3837,7 +3837,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -3868,7 +3868,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "datafusion", "geo-traits", @@ -3882,7 +3882,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arc-swap", "arrow", @@ -3950,7 +3950,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-arith", @@ -3969,7 +3969,7 @@ dependencies = [ "chrono", "deepsize", "futures", - "http 1.4.0", + "http 1.4.1", "io-uring", "lance-arrow", "lance-core", @@ -3992,7 +3992,7 @@ dependencies = [ [[package]] name = "lance-jni" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4028,7 +4028,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4044,7 +4044,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "async-trait", @@ -4056,7 +4056,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-ipc", @@ -4100,7 +4100,7 @@ dependencies = [ [[package]] name = "lance-select" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4114,7 +4114,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4152,7 +4152,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "rust-stemmers", "serde", @@ -4275,9 +4275,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "loom" @@ -4635,7 +4635,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body-util", "httparse", "humantime", @@ -4727,7 +4727,7 @@ dependencies = [ "base64", "bytes", "futures", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "jiff", "log", @@ -4736,7 +4736,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.38.4", "reqsign-core", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "tokio", @@ -4752,7 +4752,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "048b1b29c503263bdd80a9afe46a68cd02ea9bd361185b1feab4b151078998e9" dependencies = [ "futures", - "http 1.4.0", + "http 1.4.1", "mea", "opendal-core", ] @@ -4796,7 +4796,7 @@ checksum = "7452bf3ec61cfd81ac9ad9ada17825931e9e371d44a045c6bfab9596c0a2ac3b" dependencies = [ "base64", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -4816,7 +4816,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f9884c2d8cf8ba2bb077d79c877dac5863ba3bab9e2c9c1e41a2e0491404772" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -4834,7 +4834,7 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffb0e45d6c8dcf66ce2da20e241bcb80e6e540e109a4ff20f318f6c9b4c54e0c" dependencies = [ - "http 1.4.0", + "http 1.4.1", "opendal-core", ] @@ -4845,7 +4845,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55a0765ba451b6effdbf514b7b50060530ff8a29e4231c4a3ab7792c016408e6" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -4863,7 +4863,7 @@ checksum = "70a49477a10163431896d106136117f5670717f9c9e49cf6f710528800c6633a" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", @@ -4884,11 +4884,11 @@ checksum = "7b2ab7a2a8a11dfe257ef4db5c0de798acbcd0d6429c37382dad2154bc06a388" dependencies = [ "bytes", "hf-xet", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", ] @@ -4900,7 +4900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29c8a917829ad06d21b639558532cb0101fe49b040d946d673a73018683fac05" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -4919,7 +4919,7 @@ dependencies = [ "base64", "bytes", "crc32c", - "http 1.4.0", + "http 1.4.1", "log", "md-5", "opendal-core", @@ -5715,7 +5715,7 @@ checksum = "57ac2757f3140aa2e213b554148ae0b52733e624fc6723f0cc6bb3d440176c95" dependencies = [ "anyhow", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -5733,7 +5733,7 @@ dependencies = [ "anyhow", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "quick-xml 0.39.4", @@ -5755,7 +5755,7 @@ dependencies = [ "base64", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "pem", @@ -5780,7 +5780,7 @@ dependencies = [ "futures", "hex", "hmac", - "http 1.4.0", + "http 1.4.1", "jiff", "log", "percent-encoding", @@ -5807,7 +5807,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35cc609b49c69e76ecaceb775a03f792d1ed3e7755ab3548d4534fd801e3242e" dependencies = [ "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "percent-encoding", @@ -5827,7 +5827,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e128f19525861dbded59e1e7c17653a8ed63d573ca04aed708d552dbef5bb32a" dependencies = [ "anyhow", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -5847,7 +5847,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -5882,15 +5882,15 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" dependencies = [ "base64", "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -5928,8 +5928,8 @@ checksum = "07bc3f1384cffa4f274dad2d4ddd73aed32fed8f786d96c6be8aa4e5fd3c3b58" dependencies = [ "anyhow", "async-trait", - "http 1.4.0", - "reqwest 0.13.3", + "http 1.4.1", + "reqwest 0.13.4", "thiserror 2.0.18", "tower-service", ] @@ -6318,9 +6318,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -7099,7 +7099,7 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags", "bytes", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -7119,7 +7119,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -7491,9 +7491,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if 1.0.4", "once_cell", @@ -7504,9 +7504,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -7514,9 +7514,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7524,9 +7524,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -7537,9 +7537,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] @@ -7606,9 +7606,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" dependencies = [ "js-sys", "wasm-bindgen", @@ -8176,13 +8176,13 @@ dependencies = [ "clap", "crc32fast", "futures", - "http 1.4.0", + "http 1.4.1", "hyper", "lazy_static", "more-asserts", "rand 0.10.1", "redb", - "reqwest 0.13.3", + "reqwest 0.13.4", "reqwest-middleware", "serde", "serde_json", @@ -8250,7 +8250,7 @@ dependencies = [ "chrono", "clap", "gearhash", - "http 1.4.0", + "http 1.4.1", "itertools 0.14.0", "lazy_static", "more-asserts", @@ -8295,7 +8295,7 @@ dependencies = [ "oneshot", "pin-project", "rand 0.10.1", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "shellexpand", diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 434159a57ef..cd300e249f3 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lance-jni" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" edition = "2024" authors = ["Lance Devs "] rust-version = "1.91" diff --git a/java/lance-jni/src/async_scanner.rs b/java/lance-jni/src/async_scanner.rs index eada9287c47..7cb71c37086 100644 --- a/java/lance-jni/src/async_scanner.rs +++ b/java/lance-jni/src/async_scanner.rs @@ -191,6 +191,7 @@ pub extern "system" fn Java_org_lance_ipc_AsyncScanner_createAsyncScanner<'local batch_readahead: jint, column_orderings: JObject<'local>, use_scalar_index: jboolean, + fast_search: jboolean, substrait_aggregate_obj: JObject<'local>, ) -> JObject<'local> { crate::ok_or_throw!( @@ -213,6 +214,7 @@ pub extern "system" fn Java_org_lance_ipc_AsyncScanner_createAsyncScanner<'local batch_readahead, column_orderings, use_scalar_index, + fast_search, substrait_aggregate_obj, ) ) @@ -237,6 +239,7 @@ fn inner_create_async_scanner<'local>( batch_readahead: jint, column_orderings: JObject<'local>, use_scalar_index: jboolean, + fast_search: jboolean, substrait_aggregate_obj: JObject<'local>, ) -> Result> { let dataset_guard = @@ -260,6 +263,7 @@ fn inner_create_async_scanner<'local>( batch_readahead, column_orderings, use_scalar_index, + fast_search, substrait_aggregate_obj, }; diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index b77c170d497..f18b0d92a27 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -245,6 +245,7 @@ pub(crate) struct ScannerOptions<'a> { pub batch_readahead: jint, pub column_orderings: JObject<'a>, pub use_scalar_index: jboolean, + pub fast_search: jboolean, pub substrait_aggregate_obj: JObject<'a>, } @@ -317,7 +318,9 @@ pub(crate) fn build_scanner_with_options<'a>( let key_array = env.get_vec_f32_from_method(&java_obj, "getKey")?; let key = Float32Array::from(key_array); let k = env.get_int_as_usize_from_method(&java_obj, "getK")?; - let _ = scanner.nearest(&column, &key, k); + scanner + .nearest(&column, &key, k) + .map_err(|err| Error::input_error(err.to_string()))?; let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?; scanner.minimum_nprobes(minimum_nprobes); @@ -361,6 +364,10 @@ pub(crate) fn build_scanner_with_options<'a>( Ok(()) })?; + if options.fast_search == JNI_TRUE { + scanner.fast_search(); + } + scanner.batch_readahead(options.batch_readahead as usize); env.get_optional(&options.column_orderings, |env, java_obj| { @@ -413,6 +420,7 @@ pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>( batch_readahead: jint, // int column_orderings: JObject<'local>, // Optional> use_scalar_index: jboolean, // boolean + fast_search: jboolean, // boolean substrait_aggregate_obj: JObject<'local>, // Optional collect_stats: jboolean, // boolean ) -> JObject<'local> { @@ -436,6 +444,7 @@ pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>( batch_readahead, column_orderings, use_scalar_index, + fast_search, substrait_aggregate_obj, collect_stats, ) @@ -461,6 +470,7 @@ fn inner_create_scanner<'local>( batch_readahead: jint, column_orderings: JObject<'local>, use_scalar_index: jboolean, + fast_search: jboolean, substrait_aggregate_obj: JObject<'local>, collect_stats: jboolean, ) -> Result> { @@ -485,6 +495,7 @@ fn inner_create_scanner<'local>( batch_readahead, column_orderings, use_scalar_index, + fast_search, substrait_aggregate_obj, }; diff --git a/java/lance-jni/src/delta.rs b/java/lance-jni/src/delta.rs index 21a4f726ed1..d5a6b0f3a27 100755 --- a/java/lance-jni/src/delta.rs +++ b/java/lance-jni/src/delta.rs @@ -109,7 +109,7 @@ fn inner_native_build<'local>( } #[unsafe(no_mangle)] -pub extern "system" fn Java_org_lance_delta_DatasetDelta_listTransactions<'local>( +pub extern "system" fn Java_org_lance_delta_DatasetDelta_nativeListTransactions<'local>( mut env: JNIEnv<'local>, j_delta: JObject<'local>, ) -> JObject<'local> { @@ -140,7 +140,7 @@ fn inner_list_transactions<'local>( } #[unsafe(no_mangle)] -pub extern "system" fn Java_org_lance_delta_DatasetDelta_getInsertedRows<'local>( +pub extern "system" fn Java_org_lance_delta_DatasetDelta_nativeGetInsertedRows<'local>( mut env: JNIEnv<'local>, j_delta: JObject<'local>, stream_addr: jlong, @@ -164,7 +164,7 @@ fn inner_get_inserted_rows<'local>( } #[unsafe(no_mangle)] -pub extern "system" fn Java_org_lance_delta_DatasetDelta_getUpdatedRows<'local>( +pub extern "system" fn Java_org_lance_delta_DatasetDelta_nativeGetUpdatedRows<'local>( mut env: JNIEnv<'local>, j_delta: JObject<'local>, stream_addr: jlong, diff --git a/java/pom.xml b/java/pom.xml index 4fcb5e2a9de..b4fb17de876 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -7,7 +7,7 @@ org.lance lance-core Lance Core - 7.1.0-beta.1 + 7.1.0-beta.4 jar Lance Format Java API diff --git a/java/src/main/java/org/lance/ipc/AsyncScanner.java b/java/src/main/java/org/lance/ipc/AsyncScanner.java index 59ecdebd750..2ec317cb245 100644 --- a/java/src/main/java/org/lance/ipc/AsyncScanner.java +++ b/java/src/main/java/org/lance/ipc/AsyncScanner.java @@ -79,6 +79,7 @@ public static AsyncScanner create( options.getBatchReadahead(), options.getColumnOrderings(), options.isUseScalarIndex(), + options.isFastSearch(), options.getSubstraitAggregate()); scanner.allocator = allocator; return scanner; @@ -101,6 +102,7 @@ static native AsyncScanner createAsyncScanner( int batchReadahead, Optional> columnOrderings, boolean useScalarIndex, + boolean fastSearch, Optional substraitAggregate); /** diff --git a/java/src/main/java/org/lance/ipc/LanceScanner.java b/java/src/main/java/org/lance/ipc/LanceScanner.java index c4c9c55fef0..edd3ebc22cc 100644 --- a/java/src/main/java/org/lance/ipc/LanceScanner.java +++ b/java/src/main/java/org/lance/ipc/LanceScanner.java @@ -75,6 +75,7 @@ public static LanceScanner create( options.getBatchReadahead(), options.getColumnOrderings(), options.isUseScalarIndex(), + options.isFastSearch(), options.getSubstraitAggregate(), options.isCollectStats()); scanner.allocator = allocator; @@ -100,6 +101,7 @@ static native LanceScanner createScanner( int batchReadahead, Optional> columnOrderings, boolean useScalarIndex, + boolean fastSearch, Optional substraitAggregate, boolean collectStats); diff --git a/java/src/main/java/org/lance/ipc/Query.java b/java/src/main/java/org/lance/ipc/Query.java index 3ad1301ee59..48013b375ee 100644 --- a/java/src/main/java/org/lance/ipc/Query.java +++ b/java/src/main/java/org/lance/ipc/Query.java @@ -140,6 +140,10 @@ public Builder setColumn(String column) { /** * Sets the vector to be searched. * + *

    This API accepts a single query vector. The array length must match the target vector + * column dimension. Batch nearest-neighbor search with multiple query vectors requires a + * list-shaped query input and is not available through this {@code float[]} entry point. + * * @param key The search vector. * @return The Builder instance for method chaining. */ diff --git a/java/src/main/java/org/lance/ipc/ScanOptions.java b/java/src/main/java/org/lance/ipc/ScanOptions.java index 293c2f3d3dd..68c485e39a3 100644 --- a/java/src/main/java/org/lance/ipc/ScanOptions.java +++ b/java/src/main/java/org/lance/ipc/ScanOptions.java @@ -39,6 +39,46 @@ public class ScanOptions { private final boolean useScalarIndex; private final Optional substraitAggregate; private final boolean collectStats; + private final boolean fastSearch; + + public ScanOptions( + Optional> fragmentIds, + Optional batchSize, + Optional> columns, + Optional filter, + Optional substraitFilter, + Optional limit, + Optional offset, + Optional nearest, + Optional fullTextQuery, + boolean prefilter, + boolean withRowId, + boolean withRowAddress, + int batchReadahead, + Optional> columnOrderings, + boolean useScalarIndex, + Optional substraitAggregate, + boolean collectStats) { + this( + fragmentIds, + batchSize, + columns, + filter, + substraitFilter, + limit, + offset, + nearest, + fullTextQuery, + prefilter, + withRowId, + withRowAddress, + batchReadahead, + columnOrderings, + useScalarIndex, + substraitAggregate, + collectStats, + false); + } /** * Constructor for LanceScanOptions. @@ -60,6 +100,8 @@ public class ScanOptions { * @param columnOrderings (Optional) Column orderings for result sorting. * @param useScalarIndex Whether to use scalar indices for the scan. Default is true. * @param substraitAggregate (Optional) Substrait aggregate expression for aggregate pushdown. + * @param collectStats Whether to collect scan execution statistics. Default is false. + * @param fastSearch Whether to only search indexed fragments. Default is false. */ public ScanOptions( Optional> fragmentIds, @@ -78,7 +120,8 @@ public ScanOptions( Optional> columnOrderings, boolean useScalarIndex, Optional substraitAggregate, - boolean collectStats) { + boolean collectStats, + boolean fastSearch) { Preconditions.checkArgument( !(filter.isPresent() && substraitFilter.isPresent()), "cannot set both substrait filter and string filter"); @@ -99,6 +142,7 @@ public ScanOptions( this.useScalarIndex = useScalarIndex; this.substraitAggregate = substraitAggregate; this.collectStats = collectStats; + this.fastSearch = fastSearch; } /** @@ -231,6 +275,15 @@ public boolean isUseScalarIndex() { return useScalarIndex; } + /** + * Get whether to only search indexed fragments. + * + * @return true if unindexed fragments should be skipped, false otherwise. + */ + public boolean isFastSearch() { + return fastSearch; + } + /** * Get the substrait aggregate expression. * @@ -264,6 +317,7 @@ public String toString() { .add("batchReadahead", batchReadahead) .add("columnOrdering", columnOrderings) .add("useScalarIndex", useScalarIndex) + .add("fastSearch", fastSearch) .add( "substraitAggregate", substraitAggregate.map(buf -> "ByteBuffer[" + buf.remaining() + " bytes]").orElse(null)) @@ -288,6 +342,7 @@ public static class Builder { private int batchReadahead = 16; private Optional> columnOrderings = Optional.empty(); private boolean useScalarIndex = true; + private boolean fastSearch = false; private Optional substraitAggregate = Optional.empty(); private boolean collectStats = false; @@ -314,6 +369,7 @@ public Builder(ScanOptions options) { this.batchReadahead = options.getBatchReadahead(); this.columnOrderings = options.getColumnOrderings(); this.useScalarIndex = options.isUseScalarIndex(); + this.fastSearch = options.isFastSearch(); this.substraitAggregate = options.getSubstraitAggregate(); this.collectStats = options.isCollectStats(); } @@ -481,6 +537,21 @@ public Builder useScalarIndex(boolean useScalarIndex) { return this; } + /** + * Set whether to only search indexed fragments. + * + *

    This is a weak-consistency mode for vector search, full text search, and scalar-indexed + * filters. It can reduce latency by skipping recently appended fragments that are not covered + * by the relevant index. + * + * @param fastSearch true to skip unindexed fragments, false otherwise. Default is false. + * @return Builder instance for method chaining. + */ + public Builder fastSearch(boolean fastSearch) { + this.fastSearch = fastSearch; + return this; + } + /** * Set the substrait aggregate expression. * @@ -529,7 +600,8 @@ public ScanOptions build() { columnOrderings, useScalarIndex, substraitAggregate, - collectStats); + collectStats, + fastSearch); } } } diff --git a/java/src/test/java/org/lance/AsyncScannerTest.java b/java/src/test/java/org/lance/AsyncScannerTest.java index 98f46887b64..43b15ecf100 100644 --- a/java/src/test/java/org/lance/AsyncScannerTest.java +++ b/java/src/test/java/org/lance/AsyncScannerTest.java @@ -13,6 +13,10 @@ */ package org.lance; +import org.lance.index.IndexOptions; +import org.lance.index.IndexParams; +import org.lance.index.IndexType; +import org.lance.index.scalar.ScalarIndexParams; import org.lance.ipc.AsyncScanner; import org.lance.ipc.ScanOptions; @@ -28,7 +32,9 @@ import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -129,6 +135,56 @@ void testAsyncScanWithFilter(@TempDir Path tempDir) throws Exception { } } + @Test + void testFastSearchSkipsUnindexedFragments(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("async_scanner_fast_search_scalar_index").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + + try (Dataset dataset = testDataset.write(1, 100)) { + ScalarIndexParams scalarParams = ScalarIndexParams.create("btree", "{}"); + IndexParams indexParams = IndexParams.builder().setScalarIndexParams(scalarParams).build(); + IndexOptions indexOptions = + IndexOptions.builder(Collections.singletonList("id"), IndexType.BTREE, indexParams) + .withIndexName("id_btree_index") + .replace(true) + .build(); + dataset.createIndex(indexOptions); + + FragmentMetadata metadata = testDataset.createNewFragment(10); + FragmentOperation.Append appendOp = + new FragmentOperation.Append(Collections.singletonList(metadata)); + try (Dataset appended = + Dataset.commit(allocator, datasetPath, appendOp, Optional.of(dataset.version()))) { + ScanOptions normalOptions = new ScanOptions.Builder().filter("id < 5").build(); + try (AsyncScanner scanner = AsyncScanner.create(appended, normalOptions, allocator)) { + ArrowReader reader = scanner.scanBatchesAsync().get(10, TimeUnit.SECONDS); + assertEquals(10, countRows(reader)); + reader.close(); + } + + ScanOptions fastOptions = + new ScanOptions.Builder().filter("id < 5").fastSearch(true).build(); + try (AsyncScanner scanner = AsyncScanner.create(appended, fastOptions, allocator)) { + ArrowReader reader = scanner.scanBatchesAsync().get(10, TimeUnit.SECONDS); + assertEquals(5, countRows(reader)); + reader.close(); + } + } + } + } + } + + private static int countRows(ArrowReader reader) throws Exception { + int rowCount = 0; + while (reader.loadNextBatch()) { + rowCount += reader.getVectorSchemaRoot().getRowCount(); + } + return rowCount; + } + /** * Example 3: Multiple concurrent async scans. * diff --git a/java/src/test/java/org/lance/DeltaTest.java b/java/src/test/java/org/lance/DeltaTest.java index 72537207524..ac7056840e4 100755 --- a/java/src/test/java/org/lance/DeltaTest.java +++ b/java/src/test/java/org/lance/DeltaTest.java @@ -28,11 +28,12 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.file.Path; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -41,9 +42,9 @@ public class DeltaTest { @Test - public void testInsertedRowsComparedAgainst() throws IOException { + public void testInsertedRowsComparedAgainst(@TempDir Path tempDir) throws IOException { try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - String uri = "memory://delta_demo"; + String uri = tempDir.resolve("delta_demo").toString(); // Build initial batch (2 rows) Schema schema = new Schema( @@ -79,7 +80,11 @@ public void testInsertedRowsComparedAgainst() throws IOException { org.apache.arrow.c.ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, reader1, stream1); Dataset ds = - Dataset.write().stream(stream1).uri(uri).mode(WriteParams.WriteMode.CREATE).execute(); + Dataset.write().stream(stream1) + .uri(uri) + .mode(WriteParams.WriteMode.CREATE) + .enableStableRowIds(true) + .execute(); // Append one row (v2) VectorSchemaRoot root2 = VectorSchemaRoot.create(schema, allocator); @@ -107,40 +112,35 @@ public void testInsertedRowsComparedAgainst() throws IOException { Dataset.write().stream(stream2).uri(uri).mode(WriteParams.WriteMode.APPEND).execute(); DatasetDelta delta = ds2.delta(1L); - try { - try (ArrowReader inserted = delta.getInsertedRows()) { - int total = 0; - boolean foundRow = false; - - while (inserted.loadNextBatch()) { - VectorSchemaRoot outRoot = inserted.getVectorSchemaRoot(); - Schema outSchema = outRoot.getSchema(); - List names = - outSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()); - Assertions.assertTrue(names.contains("_row_created_at_version")); - Assertions.assertTrue(names.contains("_row_last_updated_at_version")); - - IntVector outId = (IntVector) outRoot.getVector("id"); - VarCharVector outVal = (VarCharVector) outRoot.getVector("val"); - - for (int i = 0; i < outRoot.getRowCount(); i++) { - int id = outId.get(i); - byte[] bytes = outVal.get(i); - String val = new String(bytes, java.nio.charset.StandardCharsets.UTF_8); - if (id == 3 && "c".equals(val)) { - foundRow = true; - } + try (ArrowReader inserted = delta.getInsertedRows()) { + int total = 0; + boolean foundRow = false; + + while (inserted.loadNextBatch()) { + VectorSchemaRoot outRoot = inserted.getVectorSchemaRoot(); + Schema outSchema = outRoot.getSchema(); + List names = + outSchema.getFields().stream().map(Field::getName).collect(Collectors.toList()); + Assertions.assertTrue(names.contains("_row_created_at_version")); + Assertions.assertTrue(names.contains("_row_last_updated_at_version")); + + IntVector outId = (IntVector) outRoot.getVector("id"); + VarCharVector outVal = (VarCharVector) outRoot.getVector("val"); + + for (int i = 0; i < outRoot.getRowCount(); i++) { + int id = outId.get(i); + byte[] bytes = outVal.get(i); + String val = new String(bytes, java.nio.charset.StandardCharsets.UTF_8); + if (id == 3 && "c".equals(val)) { + foundRow = true; } - - total += outRoot.getRowCount(); } - Assertions.assertEquals(1, total); - Assertions.assertTrue(foundRow, "Inserted row (id=3, val=c) not found in delta"); + total += outRoot.getRowCount(); } - } catch (UnsatisfiedLinkError e) { - Assumptions.assumeTrue( - false, "JNI for DatasetDelta.getInsertedRows not available: " + e.getMessage()); + + Assertions.assertEquals(1, total); + Assertions.assertTrue(foundRow, "Inserted row (id=3, val=c) not found in delta"); } } } @@ -148,10 +148,9 @@ public void testInsertedRowsComparedAgainst() throws IOException { } @Test - public void testListTransactionsExplicitRange() { + public void testListTransactionsExplicitRange(@TempDir Path tempDir) throws IOException { try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { - String uri = "memory://delta_demo_tx"; - // v1 + String uri = tempDir.resolve("delta_demo_tx").toString(); Schema schema = new Schema( Arrays.asList( @@ -159,21 +158,59 @@ public void testListTransactionsExplicitRange() { "id", new org.apache.arrow.vector.types.pojo.ArrowType.Int(32, true)), Field.nullable( "val", org.apache.arrow.vector.types.pojo.ArrowType.Utf8.INSTANCE))); - try (Dataset ds = Dataset.create(allocator, uri, schema, new WriteParams.Builder().build())) { - // v2 - WriteParams params = - new WriteParams.Builder().withMode(WriteParams.WriteMode.APPEND).build(); - try (Dataset ds2 = Dataset.create(allocator, uri, schema, params); ) { + + // v1: create with two rows. + byte[] batch1 = writeBatch(allocator, schema, new int[] {1, 2}, new String[] {"a", "b"}); + try (ArrowStreamReader reader1 = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(batch1), allocator); + ArrowArrayStream stream1 = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, reader1, stream1); + Dataset.write().stream(stream1) + .uri(uri) + .mode(WriteParams.WriteMode.CREATE) + .execute() + .close(); + } + + // v2: append one row. + byte[] batch2 = writeBatch(allocator, schema, new int[] {3}, new String[] {"c"}); + try (ArrowStreamReader reader2 = + new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(batch2), allocator); + ArrowArrayStream stream2 = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, reader2, stream2); + try (Dataset ds2 = + Dataset.write().stream(stream2).uri(uri).mode(WriteParams.WriteMode.APPEND).execute()) { DatasetDelta delta = ds2.delta(1L, 2L); - try { - List txs = delta.listTransactions(); - Assertions.assertTrue(txs.size() == 1); - } catch (UnsatisfiedLinkError e) { - Assumptions.assumeTrue( - false, "JNI for DatasetDelta.listTransactions not available: " + e.getMessage()); - } + List txs = delta.listTransactions(); + Assertions.assertEquals(1, txs.size(), "delta v1..v2 should contain exactly one txn"); } } } } + + /** Helper: serialize a single Arrow batch with the given schema and (id, val) pairs. */ + private static byte[] writeBatch(RootAllocator allocator, Schema schema, int[] ids, String[] vals) + throws IOException { + Assertions.assertEquals(ids.length, vals.length, "ids and vals must align"); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + try { + root.allocateNew(); + IntVector idVec = (IntVector) root.getVector("id"); + VarCharVector valVec = (VarCharVector) root.getVector("val"); + for (int i = 0; i < ids.length; i++) { + idVec.setSafe(i, ids[i]); + valVec.setSafe(i, vals[i].getBytes(java.nio.charset.StandardCharsets.UTF_8)); + } + root.setRowCount(ids.length); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + return out.toByteArray(); + } finally { + root.close(); + } + } } diff --git a/java/src/test/java/org/lance/ScannerTest.java b/java/src/test/java/org/lance/ScannerTest.java index da80a74662c..894b208e8af 100644 --- a/java/src/test/java/org/lance/ScannerTest.java +++ b/java/src/test/java/org/lance/ScannerTest.java @@ -660,6 +660,43 @@ void testUseScalarIndex(@TempDir Path tempDir) throws Exception { } } + @Test + void testFastSearchSkipsUnindexedFragments(@TempDir Path tempDir) throws Exception { + String datasetPath = tempDir.resolve("dataset_scanner_fast_search_scalar_index").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = + new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + try (Dataset dataset = testDataset.write(1, 100)) { + ScalarIndexParams scalarParams = ScalarIndexParams.create("btree", "{}"); + IndexParams indexParams = IndexParams.builder().setScalarIndexParams(scalarParams).build(); + IndexOptions options = + IndexOptions.builder(Collections.singletonList("id"), IndexType.BTREE, indexParams) + .withIndexName("id_btree_index") + .replace(true) + .build(); + dataset.createIndex(options); + + FragmentMetadata metadata = testDataset.createNewFragment(10); + FragmentOperation.Append appendOp = + new FragmentOperation.Append(Collections.singletonList(metadata)); + try (Dataset appended = + Dataset.commit(allocator, datasetPath, appendOp, Optional.of(dataset.version()))) { + try (LanceScanner scanner = + appended.newScan(new ScanOptions.Builder().filter("id < 5").build())) { + assertEquals(10, scanner.countRows()); + } + + try (LanceScanner scanner = + appended.newScan( + new ScanOptions.Builder().filter("id < 5").fastSearch(true).build())) { + assertEquals(5, scanner.countRows()); + } + } + } + } + } + private void validScanResult(Dataset dataset, int fragmentId, int rowCount) throws Exception { try (Scanner scanner = dataset.newScan( diff --git a/java/src/test/java/org/lance/VectorSearchTest.java b/java/src/test/java/org/lance/VectorSearchTest.java index 0a34640da7e..8a82ecb2849 100644 --- a/java/src/test/java/org/lance/VectorSearchTest.java +++ b/java/src/test/java/org/lance/VectorSearchTest.java @@ -63,35 +63,37 @@ void test_create_index() throws Exception { } } - // rust/lance-linalg/src/distance/l2.rs:256:5: - // 5assertion `left == right` failed - // Directly panic instead of throwing an exception - // @Test - // void search_invalid_vector() throws Exception { - // try (TestVectorDataset testVectorDataset = new - // TestVectorDataset(tempDir.resolve("test_create_index"))) { - // try (Dataset dataset = testVectorDataset.create()) { - // float[] key = new float[30]; - // for (int i = 0; i < 30; i++) { - // key[i] = (float) (i + 30); - // } - // ScanOptions options = new ScanOptions.Builder() - // .nearest(new Query.Builder() - // .setColumn(TestVectorDataset.vectorColumnName) - // .setKey(key) - // .setK(5) - // .setUseIndex(false) - // .build()) - // .build(); - // assertThrows(IllegalArgumentException.class, () -> { - // try (Scanner scanner = dataset.newScan(options)) { - // try (ArrowReader reader = scanner.scanBatches()) { - // } - // } - // }); - // } - // } - // } + @Test + void search_invalid_vector() throws Exception { + try (TestVectorDataset testVectorDataset = + new TestVectorDataset(tempDir.resolve("search_invalid_vector"))) { + try (Dataset dataset = testVectorDataset.create()) { + float[] key = new float[30]; + for (int i = 0; i < 30; i++) { + key[i] = (float) (i + 30); + } + ScanOptions options = + new ScanOptions.Builder() + .nearest( + new Query.Builder() + .setColumn(TestVectorDataset.vectorColumnName) + .setKey(key) + .setK(5) + .setUseIndex(false) + .build()) + .build(); + assertThrows( + IllegalArgumentException.class, + () -> { + try (Scanner scanner = dataset.newScan(options)) { + try (ArrowReader reader = scanner.scanBatches()) { + reader.loadNextBatch(); + } + } + }); + } + } + } @ParameterizedTest @ValueSource(booleans = {false, true}) diff --git a/memtest/src/lib.rs b/memtest/src/lib.rs index f9caaa73e4d..e505f666e0d 100644 --- a/memtest/src/lib.rs +++ b/memtest/src/lib.rs @@ -23,6 +23,9 @@ pub unsafe extern "C" fn memtest_get_stats(stats: *mut MemtestStats) { if stats.is_null() { return; } + if (stats as usize).wrapping_rem(std::mem::align_of::()) != 0 { + return; + } (*stats).total_allocations = STATS .total_allocations @@ -47,3 +50,29 @@ pub unsafe extern "C" fn memtest_get_stats(stats: *mut MemtestStats) { pub extern "C" fn memtest_reset_stats() { STATS.reset(); } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_stats_null() { + unsafe { + memtest_get_stats(std::ptr::null_mut()); + } + } + + #[test] + fn test_get_stats_misaligned() { + unsafe { + let align = std::mem::align_of::(); + let size = std::mem::size_of::(); + let buf_size = size.saturating_add(align).saturating_add(align); + let mut buf = vec![0u8; buf_size]; + let base = buf.as_mut_ptr() as usize; + let offset = if base.wrapping_rem(align) == 0 { 1 } else { 0 }; + let misaligned = buf.as_mut_ptr().add(offset) as *mut MemtestStats; + memtest_get_stats(misaligned); + } + } +} diff --git a/python/Cargo.lock b/python/Cargo.lock index 399f18afee3..8a86ace4dc8 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -557,9 +557,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "aws-config" @@ -582,7 +582,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.4.0", + "http 1.4.1", "ring", "time", "tokio", @@ -643,7 +643,7 @@ dependencies = [ "bytes", "bytes-utils", "fastrand", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "percent-encoding", "pin-project-lite", @@ -670,7 +670,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -694,7 +694,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -718,7 +718,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -743,7 +743,7 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "regex-lite", "tracing", ] @@ -763,7 +763,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "percent-encoding", "sha2", "time", @@ -793,7 +793,7 @@ dependencies = [ "bytes-utils", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "percent-encoding", @@ -812,7 +812,7 @@ dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-rustls", "hyper-util", @@ -869,7 +869,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -889,7 +889,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "pin-project-lite", "tokio", "tracing", @@ -907,7 +907,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.4.0", + "http 1.4.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -955,7 +955,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -988,7 +988,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1163,9 +1163,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytecheck" @@ -2853,7 +2853,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3119,9 +3119,9 @@ dependencies = [ [[package]] name = "geohash" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb94b1a65401d6cbf22958a9040aa364812c26674f841bee538b12c135db1e6" +checksum = "7f58890382f70caccc5fa388981f7ac80c913795042afce9f3e065695d8f7464" dependencies = [ "geo-types", "libm", @@ -3219,7 +3219,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.4.0", + "http 1.4.1", "indexmap 2.14.0", "slab", "tokio", @@ -3328,7 +3328,7 @@ checksum = "430b33fa84f92796d4d263070b6c0d3ca219df7b9a0e1853ee431029b1612bcd" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "more-asserts", "serde", "thiserror 2.0.18", @@ -3364,9 +3364,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -3390,7 +3390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", ] [[package]] @@ -3401,7 +3401,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "pin-project-lite", ] @@ -3435,7 +3435,7 @@ dependencies = [ "futures-channel", "futures-core", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3452,7 +3452,7 @@ version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ - "http 1.4.0", + "http 1.4.1", "hyper", "hyper-util", "rustls", @@ -3472,7 +3472,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "hyper", "ipnet", @@ -3577,6 +3577,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_locale" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5a396343c7208121dc86e35623d3dfe19814a7613cfd14964994cdc9c9a2e26" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_locale_data", + "icu_provider", + "potential_utf", + "tinystr", + "zerovec", +] + [[package]] name = "icu_locale_core" version = "2.2.0" @@ -3585,11 +3600,18 @@ checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", + "serde", "tinystr", "writeable", "zerovec", ] +[[package]] +name = "icu_locale_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fdcc9ac77c6d74ff5cf6e65ef3181d6af32003b16fce3a77fb451d2f695993" + [[package]] name = "icu_normalizer" version = "2.2.0" @@ -3638,6 +3660,8 @@ checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", + "serde", + "stable_deref_trait", "writeable", "yoke", "zerofrom", @@ -3645,6 +3669,27 @@ dependencies = [ "zerovec", ] +[[package]] +name = "icu_segmenter" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0794db0b1a86193ac9c48768d0e6c52c54448e0870ad87907d456ee0dac964" +dependencies = [ + "icu_collections", + "icu_locale", + "icu_provider", + "icu_segmenter_data", + "potential_utf", + "utf8_iter", + "zerovec", +] + +[[package]] +name = "icu_segmenter_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a2c462a4d927d512f5f882a033ddd62f33a05bb9f230d98f736ac3dc85938f" + [[package]] name = "id-arena" version = "2.3.0" @@ -3775,18 +3820,18 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jieba-macros" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a29cfc5dcd898604c6f80363411fa6b6b08e27d1d253d6225b9cb6702ea02fc0" +checksum = "661344b2412fb00aee1841d2405c9a31f7c91cf6e578a8e953647c43dd1a8b0a" dependencies = [ "phf_codegen", ] [[package]] name = "jieba-rs" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3245d6e9d1d5facbd6a23848d6b67e3439738ccbb4fa5a3d65da315ba1a910a2" +checksum = "d7ef90d6209fcff084a01b488c4199d882e3764b15ff0e7a6b5d7efaa46e1e4f" dependencies = [ "cedarwood", "jieba-macros", @@ -3797,9 +3842,9 @@ dependencies = [ [[package]] name = "jiff" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +checksum = "392c70591e8749fe235ddaf513e6f58b26bce3dcc16524cecc8936f75afa161e" dependencies = [ "jiff-static", "jiff-tzdb-platform", @@ -3809,14 +3854,14 @@ dependencies = [ "portable-atomic-util", "serde_core", "wasm-bindgen", - "windows-sys 0.61.2", + "windows-link", ] [[package]] name = "jiff-static" -version = "0.2.24" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +checksum = "47b605b0c050d845fc355bb11eb3f9a8deddc218ea60c76e61aa1f2adfb2c96a" dependencies = [ "proc-macro2", "quote", @@ -3899,9 +3944,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if 1.0.4", "futures-util", @@ -3975,7 +4020,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arc-swap", "arrow", @@ -4048,7 +4093,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4066,9 +4111,31 @@ dependencies = [ "rand 0.9.4", ] +[[package]] +name = "lance-arrow-scalar" +version = "58.0.0" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-row", + "arrow-schema", + "half", +] + +[[package]] +name = "lance-arrow-stats" +version = "58.0.0" +dependencies = [ + "arrow-array", + "arrow-schema", + "lance-arrow-scalar", +] + [[package]] name = "lance-bitpacking" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrayref", "paste", @@ -4077,7 +4144,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4112,7 +4179,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4144,7 +4211,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4162,7 +4229,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -4197,7 +4264,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-arith", "arrow-array", @@ -4228,7 +4295,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "datafusion", "geo-traits", @@ -4242,7 +4309,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arc-swap", "arrow", @@ -4275,6 +4342,7 @@ dependencies = [ "jieba-rs", "jsonb", "lance-arrow", + "lance-arrow-stats", "lance-core", "lance-datafusion", "lance-datagen", @@ -4311,7 +4379,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-arith", @@ -4330,7 +4398,7 @@ dependencies = [ "chrono", "deepsize", "futures", - "http 1.4.0", + "http 1.4.1", "io-uring", "lance-arrow", "lance-core", @@ -4353,7 +4421,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4369,7 +4437,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "async-trait", @@ -4381,7 +4449,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-ipc", @@ -4425,7 +4493,7 @@ dependencies = [ [[package]] name = "lance-select" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow-array", "arrow-buffer", @@ -4439,7 +4507,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -4479,8 +4547,9 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ + "icu_segmenter", "jieba-rs", "lindera", "rust-stemmers", @@ -4694,9 +4763,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "loom" @@ -5092,7 +5161,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body-util", "httparse", "humantime", @@ -5184,7 +5253,7 @@ dependencies = [ "base64", "bytes", "futures", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "jiff", "log", @@ -5193,7 +5262,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.38.4", "reqsign-core", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "tokio", @@ -5209,7 +5278,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "048b1b29c503263bdd80a9afe46a68cd02ea9bd361185b1feab4b151078998e9" dependencies = [ "futures", - "http 1.4.0", + "http 1.4.1", "mea", "opendal-core", ] @@ -5253,7 +5322,7 @@ checksum = "7452bf3ec61cfd81ac9ad9ada17825931e9e371d44a045c6bfab9596c0a2ac3b" dependencies = [ "base64", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -5273,7 +5342,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f9884c2d8cf8ba2bb077d79c877dac5863ba3bab9e2c9c1e41a2e0491404772" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "opendal-service-azure-common", @@ -5291,7 +5360,7 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffb0e45d6c8dcf66ce2da20e241bcb80e6e540e109a4ff20f318f6c9b4c54e0c" dependencies = [ - "http 1.4.0", + "http 1.4.1", "opendal-core", ] @@ -5302,7 +5371,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55a0765ba451b6effdbf514b7b50060530ff8a29e4231c4a3ab7792c016408e6" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -5320,7 +5389,7 @@ checksum = "70a49477a10163431896d106136117f5670717f9c9e49cf6f710528800c6633a" dependencies = [ "async-trait", "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", @@ -5341,11 +5410,11 @@ checksum = "7b2ab7a2a8a11dfe257ef4db5c0de798acbcd0d6429c37382dad2154bc06a388" dependencies = [ "bytes", "hf-xet", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "percent-encoding", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", ] @@ -5357,7 +5426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29c8a917829ad06d21b639558532cb0101fe49b040d946d673a73018683fac05" dependencies = [ "bytes", - "http 1.4.0", + "http 1.4.1", "log", "opendal-core", "quick-xml 0.38.4", @@ -5376,7 +5445,7 @@ dependencies = [ "base64", "bytes", "crc32c", - "http 1.4.0", + "http 1.4.1", "log", "md-5", "opendal-core", @@ -5770,6 +5839,8 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ + "serde_core", + "writeable", "zerovec", ] @@ -5899,7 +5970,7 @@ dependencies = [ [[package]] name = "pylance" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" dependencies = [ "arrow", "arrow-array", @@ -6406,7 +6477,7 @@ checksum = "57ac2757f3140aa2e213b554148ae0b52733e624fc6723f0cc6bb3d440176c95" dependencies = [ "anyhow", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -6424,7 +6495,7 @@ dependencies = [ "anyhow", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "quick-xml 0.39.4", @@ -6446,7 +6517,7 @@ dependencies = [ "base64", "bytes", "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "pem", @@ -6471,7 +6542,7 @@ dependencies = [ "futures", "hex", "hmac", - "http 1.4.0", + "http 1.4.1", "jiff", "log", "percent-encoding", @@ -6498,7 +6569,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35cc609b49c69e76ecaceb775a03f792d1ed3e7755ab3548d4534fd801e3242e" dependencies = [ "form_urlencoded", - "http 1.4.0", + "http 1.4.1", "jsonwebtoken", "log", "percent-encoding", @@ -6518,7 +6589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e128f19525861dbded59e1e7c17653a8ed63d573ca04aed708d552dbef5bb32a" dependencies = [ "anyhow", - "http 1.4.0", + "http 1.4.1", "log", "percent-encoding", "reqsign-core", @@ -6538,7 +6609,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -6573,15 +6644,15 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" dependencies = [ "base64", "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -6619,8 +6690,8 @@ checksum = "07bc3f1384cffa4f274dad2d4ddd73aed32fed8f786d96c6be8aa4e5fd3c3b58" dependencies = [ "anyhow", "async-trait", - "http 1.4.0", - "reqwest 0.13.3", + "http 1.4.1", + "reqwest 0.13.4", "thiserror 2.0.18", "tower-service", ] @@ -7039,9 +7110,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -7728,6 +7799,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", + "serde_core", "zerovec", ] @@ -7874,7 +7946,7 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "bitflags 2.11.1", "bytes", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -7894,7 +7966,7 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http 1.4.0", + "http 1.4.1", "http-body 1.0.1", "http-body-util", "pin-project-lite", @@ -8304,9 +8376,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if 1.0.4", "once_cell", @@ -8317,9 +8389,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -8327,9 +8399,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -8337,9 +8409,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -8350,9 +8422,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] @@ -8419,9 +8491,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" dependencies = [ "js-sys", "wasm-bindgen", @@ -8923,13 +8995,13 @@ dependencies = [ "clap", "crc32fast", "futures", - "http 1.4.0", + "http 1.4.1", "hyper", "lazy_static", "more-asserts", "rand 0.10.1", "redb", - "reqwest 0.13.3", + "reqwest 0.13.4", "reqwest-middleware", "serde", "serde_json", @@ -8997,7 +9069,7 @@ dependencies = [ "chrono", "clap", "gearhash", - "http 1.4.0", + "http 1.4.1", "itertools 0.14.0", "lazy_static", "more-asserts", @@ -9042,7 +9114,7 @@ dependencies = [ "oneshot", "pin-project", "rand 0.10.1", - "reqwest 0.13.3", + "reqwest 0.13.4", "serde", "serde_json", "shellexpand", @@ -9162,6 +9234,7 @@ dependencies = [ "displaydoc", "yoke", "zerofrom", + "zerovec", ] [[package]] @@ -9170,6 +9243,7 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ + "serde", "yoke", "zerofrom", "zerovec-derive", diff --git a/python/Cargo.toml b/python/Cargo.toml index c1bc0630d47..81f5e36a000 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "7.1.0-beta.1" +version = "7.1.0-beta.4" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" @@ -40,6 +40,7 @@ lance-datagen = { path = "../rust/lance-datagen", optional = true } lance-encoding = { path = "../rust/lance-encoding" } lance-file = { path = "../rust/lance-file" } lance-index = { path = "../rust/lance-index", features = [ + "tokenizer-icu", "tokenizer-lindera", "tokenizer-jieba", ] } diff --git a/python/RUST_THIRD_PARTY_LICENSES.html b/python/RUST_THIRD_PARTY_LICENSES.html index e022348a414..8e550498f20 100644 --- a/python/RUST_THIRD_PARTY_LICENSES.html +++ b/python/RUST_THIRD_PARTY_LICENSES.html @@ -13820,8 +13820,8 @@

    Used by:

  • base64-simd 0.8.0
  • deepsize_derive 0.1.2
  • i_overlay 4.0.7
  • -
  • jieba-macros 0.9.0
  • -
  • jieba-rs 0.9.0
  • +
  • jieba-macros 0.10.0
  • +
  • jieba-rs 0.10.0
  • kanaria 0.2.0
  • libm 0.2.16
  • lindera-dictionary 3.0.7
  • @@ -15863,12 +15863,16 @@

    Unicode License v3

    Used by:

    • icu_collections 2.2.0
    • +
    • icu_locale 2.2.0
    • icu_locale_core 2.2.0
    • +
    • icu_locale_data 2.2.0
    • icu_normalizer 2.2.0
    • icu_normalizer_data 2.2.0
    • icu_properties 2.2.0
    • icu_properties_data 2.2.0
    • icu_provider 2.2.0
    • +
    • icu_segmenter 2.2.0
    • +
    • icu_segmenter_data 2.2.0
    • litemap 0.8.2
    • potential_utf 0.1.5
    • tinystr 0.8.3
    • diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 5737ec013b5..70e6867d8dd 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -92,7 +92,7 @@ pa.Array, pa.Scalar, np.ndarray, - Iterable[float], + Iterable[Union[float, Iterable[float]]], ] LANCE_COMMIT_MESSAGE_KEY = "__lance_commit_message" _BLOB_PANDAS_MODE_LAZY = "lazy" @@ -1009,6 +1009,7 @@ def scanner( fragment_readahead: Optional[int] = None, scan_in_order: Optional[bool] = None, fragments: Optional[Iterable[LanceFragment]] = None, + index_segments: Optional[Iterable[Union[str, uuid.UUID]]] = None, full_text_query: Optional[Union[str, dict, FullTextQuery]] = None, *, prefilter: Optional[bool] = None, @@ -1099,6 +1100,17 @@ def scanner( "distance_range": (0.0, 1.0), } + ``q`` may also be a 2-D array-like value, or a list of vectors, for + fixed-size vector columns. In that case Lance runs a batch nearest-neighbor + query, returns up to ``k`` rows for each query vector, and adds + an Int32 non-null ``query_index`` as the first output column to identify + the source query for each result row. + Flattened 1-D arrays whose length is a multiple of the vector dimension are + rejected. Datasets that already contain a ``query_index`` column cannot be + used for batch nearest-neighbor search. When ``use_index`` is true and a + vector index is available, each query vector is searched through the index + path; otherwise the flat batch path is used. + batch_size: int, default None The maximum number of rows per batch. In some cases batches can be smaller than this size. Note: this can be overridden by @@ -1124,6 +1136,11 @@ def scanner( fragments: iterable of LanceFragment, default None If specified, only scan these fragments. If scan_in_order is True, then the fragments will be scanned in the order given. + index_segments: iterable of str or uuid.UUID, default None + If specified, restrict vector index search to these index segment UUIDs. + Only supported for vector search. If fragments is also specified, rows + from those fragments not covered by the selected index segments will be + searched with flat KNN. prefilter: bool, default False If True then the filter will be applied before the vector query is run. This will generate more correct results but it may be a more costly @@ -1172,8 +1189,9 @@ def scanner( - query: str The query string to search for. fast_search: bool, default False - If True, then the search will only be performed on the indexed data, which - yields faster search time. + If True, then vector search, full text search, and scalar-indexed + filters will only search indexed fragments, which yields faster + search time but may skip recently appended unindexed data. scan_stats_callback: Callable[[ScanStatistics], None], default None A callback function that will be called with the scan statistics after the scan is complete. Errors raised by the callback will be logged but not @@ -1251,6 +1269,7 @@ def setopt(opt, val): setopt(builder.fragment_readahead, fragment_readahead) setopt(builder.scan_in_order, scan_in_order) setopt(builder.with_fragments, fragments) + setopt(builder.with_index_segments, index_segments) setopt(builder.late_materialization, late_materialization) setopt(builder.blob_handling, blob_handling) setopt(builder.with_row_id, with_row_id) @@ -1417,6 +1436,8 @@ def to_table( use_stats: bool, optional, default True Use stats pushdown during filters. fast_search: bool, optional, default False + Only search indexed fragments for vector, full text, and scalar-indexed + filter queries. This may skip recently appended unindexed data. full_text_query: str or dict, optional query string to search for, the results will be ranked by BM25. e.g. "hello world", would match documents contains "hello" or "world". @@ -5686,6 +5707,7 @@ def __init__(self, ds: LanceDataset): self._fragment_readahead: Optional[int] = None self._scan_in_order = True self._fragments = None + self._index_segments = None self._with_row_id = False self._with_row_address = False self._use_stats = True @@ -5970,6 +5992,24 @@ def with_fragments( self._fragments = fragments return self + def with_index_segments( + self, index_segments: Optional[Iterable[Union[str, uuid.UUID]]] + ) -> ScannerBuilder: + if index_segments is not None: + segment_ids = [] + for segment_id in index_segments: + if isinstance(segment_id, (str, uuid.UUID)): + segment_ids.append(str(segment_id)) + else: + raise TypeError( + "index_segments must be an iterable of str or uuid.UUID. " + f"Got {type(segment_id)} instead." + ) + index_segments = segment_ids + + self._index_segments = index_segments + return self + def nearest( self, column: str, @@ -5989,6 +6029,16 @@ def nearest( Parameters ---------- + q: QueryVectorLike + A single query vector or, for fixed-size vector columns, a 2-D array-like + or list-shaped batch of query vectors. Batch queries return up to ``k`` rows + per query and include Int32 non-null ``query_index`` as the first output + column. Flattened 1-D inputs whose length is a multiple of the vector + dimension are rejected. Datasets with an existing ``query_index`` column + cannot be used for batch search. + When ``use_index`` is true and a vector index is available, each query + vector is searched through the index path; otherwise the flat batch path + is used. query_parallelism: int, optional Maximum partition-search concurrency for a single vector query. The default is 0. Value 0 uses the automatic policy, which @@ -6015,10 +6065,10 @@ def nearest( return self def fast_search(self, flag: bool) -> ScannerBuilder: - """Enable fast search, which only perform search on the indexed data. + """Enable fast search, which only performs search on indexed fragments. - Users can use `Table::optimize()` or `create_index()` to include the new data - into index, thus make new data searchable. + Users can use `Table::optimize()` or `create_index()` to include new data + in an index, thus making new data searchable. """ self._fast_search = flag return self @@ -6137,6 +6187,7 @@ def to_scanner(self) -> LanceScanner: self._fragment_readahead, self._scan_in_order, self._fragments, + self._index_segments, self._with_row_id, self._with_row_address, self._use_stats, @@ -7137,7 +7188,15 @@ def _build_vector_search_query( column: str The name of the vector column to search. q: QueryVectorLike - The query vector. + The query vector. For fixed-size vector columns, this may be a 2-D + array-like or list-shaped batch of query vectors. Batch queries return up to + ``k`` rows per query vector and include Int32 non-null ``query_index`` as + the first output column. + Flattened 1-D inputs whose length is a multiple of the vector dimension are + rejected. Datasets with an existing ``query_index`` column cannot be used for + batch search. When ``use_index`` is true and a vector index is available, + each query vector is searched through the index path; otherwise the flat batch + path is used. k: int, optional The number of nearest neighbors to return. metric: str, optional diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index dd91fb19614..2af79b4d072 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -253,6 +253,7 @@ class _Dataset: fragment_readahead: Optional[int] = None, scan_in_order: Optional[bool] = None, fragments: Optional[List[_Fragment]] = None, + index_segments: Optional[List[str]] = None, with_row_id: Optional[bool] = None, with_row_address: Optional[bool] = None, use_stats: Optional[bool] = None, diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index dc2030c4b2b..5379fa518a8 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -4426,6 +4426,22 @@ def test_use_scalar_index(tmp_path: Path): ).explain_plan(True) +def test_fast_search_scalar_index_skips_unindexed_fragments(tmp_path: Path): + table = pa.table({"filter": range(100)}) + dataset = lance.write_dataset(table, tmp_path, max_rows_per_file=100) + dataset.create_scalar_index("filter", "BTREE") + dataset = lance.write_dataset( + pa.table({"filter": range(100, 110)}), tmp_path, mode="append" + ) + + normal = dataset.to_table(filter="filter >= 95") + fast = dataset.to_table(filter="filter >= 95", fast_search=True) + + assert normal.num_rows == 15 + assert fast.num_rows == 5 + assert sorted(fast.column("filter").to_pylist()) == list(range(95, 100)) + + EXPECTED_DEFAULT_STORAGE_VERSION = stable_version() EXPECTED_MAJOR_VERSION = int(stable_version().split(".")[0]) EXPECTED_MINOR_VERSION = int(stable_version().split(".")[1]) diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 5ddee78912e..7b4dede319b 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -651,9 +651,13 @@ def test_lance_mem_pool_env_var(tmp_path): @pytest.mark.parametrize("with_position", [True, False]) -def test_full_text_search(dataset, with_position): +@pytest.mark.parametrize("base_tokenizer", ["simple", "icu"]) +def test_full_text_search(dataset, with_position, base_tokenizer): dataset.create_scalar_index( - "doc", index_type="INVERTED", with_position=with_position + "doc", + index_type="INVERTED", + with_position=with_position, + base_tokenizer=base_tokenizer, ) row = dataset.take(indices=[0], columns=["doc"]) query = row.column(0)[0].as_py() @@ -840,7 +844,8 @@ def test_ngram_fts(tmp_path): ) -def test_fts_fts(tmp_path): +@pytest.mark.parametrize("base_tokenizer", ["simple", "icu"]) +def test_fts_fts(tmp_path, base_tokenizer): # Tests creating two FTS indices with the same name but different parameters dataset = lance.write_dataset( pa.table( @@ -855,7 +860,11 @@ def test_fts_fts(tmp_path): tmp_path, ) dataset.create_scalar_index( - "text", "INVERTED", with_position=True, remove_stop_words=False + "text", + "INVERTED", + with_position=True, + remove_stop_words=False, + base_tokenizer=base_tokenizer, ) results = dataset.to_table(full_text_query='"was a puppy"', prefilter=True) @@ -865,7 +874,11 @@ def test_fts_fts(tmp_path): assert results.num_rows == 3 dataset.create_scalar_index( - "text", "INVERTED", name="no_pos_idx", with_position=False + "text", + "INVERTED", + name="no_pos_idx", + with_position=False, + base_tokenizer=base_tokenizer, ) # There is no way to currently specify which index to use. Instead @@ -1024,7 +1037,8 @@ def num_indices(ds): assert ds.to_table(full_text_query="iota")["id"].to_pylist() == [5] -def test_fts_score(tmp_path): +@pytest.mark.parametrize("base_tokenizer", ["simple", "icu"]) +def test_fts_score(tmp_path, base_tokenizer): # the number of tokens matters for scoring, # make a table that all docs have the same number of tokens data = pa.table( @@ -1034,7 +1048,7 @@ def test_fts_score(tmp_path): } ) ds = lance.write_dataset(data, tmp_path) - ds.create_scalar_index("text", "INVERTED") + ds.create_scalar_index("text", "INVERTED", base_tokenizer=base_tokenizer) results = ds.to_table(full_text_query="lance search text") assert results.num_rows == 3 @@ -1046,7 +1060,7 @@ def test_fts_score(tmp_path): "text", "lance search text", tmp_path, - index_params={"with_position": False}, + index_params={"with_position": False, "base_tokenizer": base_tokenizer}, ) @@ -1665,6 +1679,22 @@ def test_jieba_tokenizer(tmp_path): assert results["_rowid"].to_pylist() == [0] +def test_icu_tokenizer(tmp_path): + data = pa.table( + { + "text": ["Hello, こんにちは世界!", "Hello, こんにちは!"], + } + ) + ds = lance.write_dataset(data, tmp_path, mode="overwrite") + ds.create_scalar_index("text", "INVERTED", base_tokenizer="icu") + results = ds.to_table( + full_text_query="世界", + prefilter=True, + with_row_id=True, + ) + assert results["_rowid"].to_pylist() == [0] + + def test_jieba_invalid_user_dict_tokenizer(tmp_path): set_language_model_path() data = pa.table( diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 356f72a5e66..4cf1c4947e3 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -9,6 +9,7 @@ import string import tempfile import time +import uuid from pathlib import Path from typing import Optional @@ -180,6 +181,127 @@ def test_flat(dataset): run(dataset) +@pytest.mark.parametrize( + "queries", + [ + np.random.randn(2, 128).astype(np.float32), + np.random.randn(1, 128).astype(np.float32), + ], + ids=["two_queries", "single_query"], +) +def test_batch_flat_query_matches_repeated_single_queries(dataset, queries): + k = 5 + query_count = queries.shape[0] + + batch = dataset.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": queries, + "k": k, + "use_index": False, + }, + ) + + assert batch.num_rows == query_count * k + assert batch.column_names == ["query_index", "id", "_distance"] + query_index_field = batch.schema.field("query_index") + assert query_index_field.type == pa.int32() + assert not query_index_field.nullable + expected_query_index = sum([[i] * k for i in range(query_count)], []) + assert batch["query_index"].to_pylist() == expected_query_index + + _assert_batch_matches_single_queries( + dataset, + queries, + k=k, + nearest_kwargs={"use_index": False}, + ) + + +def _assert_batch_matches_single_queries(ds, queries, k, nearest_kwargs): + batch = ds.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": queries, + "k": k, + **nearest_kwargs, + }, + ) + if "distance_range" in nearest_kwargs: + lo, hi = nearest_kwargs["distance_range"] + assert all(lo <= d < hi for d in batch["_distance"].to_pylist()) + + for query_index, query in enumerate(queries): + single = ds.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": query, + "k": k, + **nearest_kwargs, + }, + ) + batch_slice = batch.filter(pc.field("query_index") == query_index) + assert batch_slice["id"].to_pylist() == single["id"].to_pylist() + np.testing.assert_allclose( + batch_slice["_distance"].to_numpy(), + single["_distance"].to_numpy(), + ) + + +def test_batch_vector_search_rejects_dataset_query_index_column(tmp_path): + dim = 128 + table = create_table(nvec=80, ndim=dim) + table = table.append_column( + "query_index", + pa.array(range(80), type=pa.uint32()), + ) + ds = lance.write_dataset(table, tmp_path / "with_query_index") + + queries = np.random.randn(2, dim).astype(np.float32) + with pytest.raises(Exception, match="query_index"): + ds.to_table( + columns=["id", "query_index"], + nearest={ + "column": "vector", + "q": queries, + "k": 5, + "use_index": False, + }, + ) + + +def test_flat_1d_query_length_multiple_of_dim_is_rejected(dataset): + q = np.random.randn(256).astype(np.float32) + with pytest.raises(ValueError, match=r"256.*128"): + dataset.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": q, + "k": 5, + "use_index": False, + }, + ) + + +def test_batch_fast_search_without_index_returns_empty_with_query_index(dataset): + queries = np.random.randn(2, 128).astype(np.float32) + batch = dataset.to_table( + columns=["id"], + nearest={ + "column": "vector", + "q": queries, + "k": 5, + }, + fast_search=True, + ) + assert batch.num_rows == 0 + assert "query_index" in batch.column_names + + def test_ann(indexed_dataset): run(indexed_dataset) @@ -2105,6 +2227,24 @@ def func(rs: pa.Table): run(dataset, q=np.array(q), assert_func=func) +def test_scanner_rejects_unknown_index_segments(tmp_path): + tbl = create_table() + dataset = lance.write_dataset(tbl, tmp_path) + dataset = dataset.create_index("vector", index_type="IVF_FLAT", num_partitions=4) + + with pytest.raises( + ValueError, match="with_index_segments referenced unknown index segments" + ): + dataset.scanner( + nearest={ + "column": "vector", + "q": np.random.randn(128).astype(np.float32), + "k": 10, + }, + index_segments=[uuid.uuid4()], + ).to_table() + + def test_vector_index_distance_range(tmp_path): """Ensure vector index honors distance_range.""" ndim = 128 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index c868504e87c..d31cb870d0b 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -33,6 +33,7 @@ use pyo3::{ pyclass, types::{IntoPyDict, PyDict}, }; +use uuid::Uuid; use lance::dataset::AutoCleanupParams; use lance::dataset::cleanup::CleanupPolicyBuilder; @@ -1106,7 +1107,7 @@ impl Dataset { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature=(columns=None, columns_with_transform=None, filter=None, search_filter=None, prefilter=None, limit=None, offset=None, nearest=None, batch_size=None, batch_size_bytes=None, io_buffer_size=None, batch_readahead=None, fragment_readahead=None, scan_in_order=None, fragments=None, with_row_id=None, with_row_address=None, use_stats=None, substrait_filter=None, fast_search=None, full_text_query=None, late_materialization=None, blob_handling=None, use_scalar_index=None, include_deleted_rows=None, scan_stats_callback=None, strict_batch_size=None, order_by=None, disable_scoring_autoprojection=None, substrait_aggregate=None))] + #[pyo3(signature=(columns=None, columns_with_transform=None, filter=None, search_filter=None, prefilter=None, limit=None, offset=None, nearest=None, batch_size=None, batch_size_bytes=None, io_buffer_size=None, batch_readahead=None, fragment_readahead=None, scan_in_order=None, fragments=None, index_segments=None, with_row_id=None, with_row_address=None, use_stats=None, substrait_filter=None, fast_search=None, full_text_query=None, late_materialization=None, blob_handling=None, use_scalar_index=None, include_deleted_rows=None, scan_stats_callback=None, strict_batch_size=None, order_by=None, disable_scoring_autoprojection=None, substrait_aggregate=None))] fn scanner( self_: PyRef<'_, Self>, columns: Option>, @@ -1124,6 +1125,7 @@ impl Dataset { fragment_readahead: Option, scan_in_order: Option, fragments: Option>, + index_segments: Option>, with_row_id: Option, with_row_address: Option, use_stats: Option, @@ -1302,6 +1304,22 @@ impl Dataset { scanner.with_fragments(fragments); } + if let Some(index_segments) = index_segments { + let index_segments = index_segments + .into_iter() + .map(|segment| { + Uuid::parse_str(&segment).map_err(|err| { + PyValueError::new_err(format!( + "invalid index segment uuid '{segment}': {err}" + )) + }) + }) + .collect::>>()?; + scanner + .with_index_segments(index_segments) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + } + if let Some(scan_stats_callback) = scan_stats_callback { let callback = Self::make_scan_stats_callback(scan_stats_callback.clone())?; scanner.scan_stats_callback(callback); @@ -1371,7 +1389,12 @@ impl Dataset { let (_, element_type) = get_vector_type(self_.ds.schema(), &column) .map_err(|e| PyValueError::new_err(e.to_string()))?; let scanner = match element_type { - DataType::UInt8 => { + DataType::UInt8 + if !matches!( + q.data_type(), + DataType::List(_) | DataType::FixedSizeList(_, _) + ) => + { let q = arrow::compute::cast(&q, &DataType::UInt8).map_err(|e| { PyValueError::new_err(format!("Failed to cast q to binary vector: {}", e)) })?; diff --git a/rust/arrow-scalar/src/lib.rs b/rust/arrow-scalar/src/lib.rs index 04246589296..70c468ebb99 100644 --- a/rust/arrow-scalar/src/lib.rs +++ b/rust/arrow-scalar/src/lib.rs @@ -16,6 +16,8 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use arrow_array::cast::AsArray; +use arrow_array::types::{Float16Type, Float32Type, Float64Type}; use arrow_array::{ArrayRef, make_array, new_null_array}; use arrow_cast::display::ArrayFormatter; use arrow_data::transform::MutableArrayData; @@ -109,6 +111,28 @@ impl ArrowScalar { pub fn is_null(&self) -> bool { self.array.null_count() == 1 } + + /// Returns `true` if this scalar is a non-null floating-point NaN. + /// + /// ``` + /// use lance_arrow_scalar::ArrowScalar; + /// + /// assert!(ArrowScalar::from(f32::NAN).is_nan()); + /// assert!(!ArrowScalar::from(1.0f32).is_nan()); + /// assert!(!ArrowScalar::from(1i32).is_nan()); + /// ``` + pub fn is_nan(&self) -> bool { + if self.is_null() { + return false; + } + + match self.data_type() { + DataType::Float16 => self.array.as_primitive::().value(0).is_nan(), + DataType::Float32 => self.array.as_primitive::().value(0).is_nan(), + DataType::Float64 => self.array.as_primitive::().value(0).is_nan(), + _ => false, + } + } } impl PartialEq for ArrowScalar { @@ -282,6 +306,23 @@ mod tests { assert_eq!(a.cmp(&b), expected); } + #[rstest] + #[case::float16_nan(ArrowScalar::from(half::f16::NAN), true)] + #[case::float32_nan(ArrowScalar::from(f32::NAN), true)] + #[case::float64_nan(ArrowScalar::from(f64::NAN), true)] + #[case::float64_finite(ArrowScalar::from(1.0f64), false)] + #[case::int32(ArrowScalar::from(1i32), false)] + fn test_is_nan(#[case] scalar: ArrowScalar, #[case] expected: bool) { + assert_eq!(scalar.is_nan(), expected); + } + + #[test] + fn test_null_is_not_nan() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![None])); + let scalar = ArrowScalar::try_from_array(array).unwrap(); + assert!(!scalar.is_nan()); + } + #[test] fn test_display_string() { let s = ArrowScalar::from("hello world"); diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs index ca64aacdd80..3fd1841ada7 100644 --- a/rust/lance-datafusion/src/planner.rs +++ b/rust/lance-datafusion/src/planner.rs @@ -917,7 +917,7 @@ impl Planner { pub fn optimize_expr(&self, expr: Expr) -> Result { let df_schema = Arc::new(DFSchema::try_from(self.schema.as_ref().clone())?); - // DataFusion needs the simplify and coerce passes to be applied before + // DataFusion needs the coerce and simplify passes to be applied before // expressions can be handled by the physical planner. let simplify_context = SimplifyContext::default() .with_schema(df_schema.clone()) @@ -925,8 +925,9 @@ impl Planner { let simplifier = datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); - let expr = simplifier.simplify(expr)?; + // Coerce before simplify to match DataFusion's analyzer-before-optimizer pipeline. let expr = simplifier.coerce(expr, &df_schema)?; + let expr = simplifier.simplify(expr)?; Ok(expr) } @@ -1011,6 +1012,7 @@ impl TreeNodeVisitor<'_> for ColumnCapturingVisitor { #[cfg(test)] mod tests { + use std::any::Any; use crate::logical_expr::ExprExt; @@ -1100,6 +1102,59 @@ mod tests { ); } + #[derive(Debug, Eq, PartialEq, Hash)] + struct StrictFloat64Udf { + signature: Signature, + } + + impl StrictFloat64Udf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for StrictFloat64Udf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "strict_float64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult { + let data_type = args.args[0].data_type(); + assert_eq!( + data_type, + DataType::Float64, + "strict_float64 expected Float64, got {data_type}" + ); + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0)))) + } + } + + #[test] + fn test_coerce_before_simplify() { + let planner = Planner::new(Arc::new(Schema::empty())); + let strict_float64 = Arc::new(ScalarUDF::new_from_impl(StrictFloat64Udf::new())); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(strict_float64, vec![lit(0_i64)])) + .eq(lit(0.0_f64)); + + let optimized = planner.optimize_expr(expr).unwrap(); + + planner.create_physical_expr(&optimized).unwrap(); + } + #[test] fn test_nested_col_refs() { let schema = Arc::new(Schema::new(vec![ diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 8bde58fce7e..3dbe3aeeba9 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -38,6 +38,7 @@ itertools.workspace = true jieba-rs = { workspace = true, optional = true } jsonb.workspace = true lance-arrow.workspace = true +lance-arrow-stats.workspace = true lance-core.workspace = true lance-datafusion.workspace = true lance-encoding.workspace = true @@ -91,8 +92,10 @@ geo = ["dep:lance-geo", "lance-geo/geo", "dep:geoarrow-array", "dep:geoarrow-sch protoc = ["dep:protobuf-src"] jieba-rs = ["tokenizer-jieba"] lindera = ["tokenizer-lindera"] +icu = ["tokenizer-icu"] tokenizer-lindera = ["lance-tokenizer/tokenizer-lindera"] tokenizer-jieba = ["dep:jieba-rs", "lance-tokenizer/tokenizer-jieba"] +tokenizer-icu = ["lance-tokenizer/tokenizer-icu"] [build-dependencies] prost-build.workspace = true diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index bc0e39c206f..5ab138ff481 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -17,8 +17,7 @@ use std::fmt::Debug; use std::pin::Pin; use std::{any::Any, ops::Bound, sync::Arc}; -use datafusion_expr::Expr; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{Expr, expr::ScalarFunction}; use deepsize::DeepSizeOf; use inverted::query::{FtsQuery, FtsQueryNode, FtsSearchParams, MatchQuery, fill_fts_query_column}; use lance_core::{Error, Result}; diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index 95807c47d3f..af37f982d1c 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -20,6 +20,7 @@ use arrow_array::{Array, UInt64Array}; mod as_bytes; pub mod sbbf; use arrow_schema::{DataType, Field}; +use lance_arrow_stats::StatisticsAccumulator; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; @@ -647,7 +648,7 @@ impl BloomFilterIndexBuilder { struct BloomFilterProcessor { params: BloomFilterIndexBuilderParams, sbbf: Option, - cur_zone_has_null: bool, + statistics: Option, } impl BloomFilterProcessor { @@ -655,7 +656,7 @@ impl BloomFilterProcessor { let mut processor = Self { params, sbbf: None, - cur_zone_has_null: false, + statistics: None, }; processor.reset()?; Ok(processor) @@ -744,6 +745,11 @@ impl ZoneProcessor for BloomFilterProcessor { Error::invalid_input("BloomFilterProcessor did not initialize bloom filter") })?; + let statistics = self + .statistics + .get_or_insert_with(|| StatisticsAccumulator::new(array.data_type())); + statistics.update(array)?; + let has_null = match array.data_type() { // Signed integers DataType::Int8 => { @@ -946,7 +952,7 @@ impl ZoneProcessor for BloomFilterProcessor { }; // Update the current zone's null tracking - self.cur_zone_has_null = self.cur_zone_has_null || has_null; + debug_assert_eq!(has_null, array.null_count() > 0); Ok(()) } @@ -954,16 +960,21 @@ impl ZoneProcessor for BloomFilterProcessor { let bloom_filter = self.sbbf.as_ref().ok_or_else(|| { Error::invalid_input("BloomFilterProcessor did not initialize bloom filter") })?; + let has_null = self + .statistics + .as_ref() + .map(|statistics| statistics.statistics().null_count > 0) + .unwrap_or(false); Ok(BloomFilterStatistics { bound, - has_null: self.cur_zone_has_null, + has_null, bloom_filter: bloom_filter.clone(), }) } fn reset(&mut self) -> Result<()> { self.sbbf = Some(Self::build_filter(&self.params)?); - self.cur_zone_has_null = false; + self.statistics = None; Ok(()) } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index fc716b0474a..6adf4457f05 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -78,7 +78,7 @@ fn scorer_terms( /// per-segment stat aggregation, term deduplication, and fuzzy-expansion /// union. Keeps a single source of truth for BM25 IDF arithmetic across /// segments. -pub fn build_global_bm25_scorer( +pub async fn build_global_bm25_scorer( indices: &[Arc], query_tokens: &Tokens, params: &FtsSearchParams, @@ -88,7 +88,7 @@ pub fn build_global_bm25_scorer( lance_core::Error::invalid_input("FTS index requires at least one segment") })?; let (mut total_tokens, mut num_docs, first_token_docs) = - first_index.bm25_stats_for_terms(&terms); + first_index.bm25_stats_for_terms(&terms).await?; let mut token_docs = HashMap::with_capacity(terms.len()); for (term, count) in terms.iter().cloned().zip(first_token_docs.into_iter()) { token_docs.insert(term, count); @@ -96,7 +96,7 @@ pub fn build_global_bm25_scorer( for index in indices.iter().skip(1) { let (segment_total_tokens, segment_num_docs, segment_token_docs) = - index.bm25_stats_for_terms(&terms); + index.bm25_stats_for_terms(&terms).await?; total_tokens += segment_total_tokens; num_docs += segment_num_docs; for (term, count) in terms.iter().zip(segment_token_docs.into_iter()) { diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index d08dacd26e7..152afc8bc0c 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -516,22 +516,77 @@ impl InvertedIndex { }) } - pub fn bm25_base_scorer(&self, query_tokens: &Tokens) -> MemBM25Scorer { + /// Build a single-segment [`MemBM25Scorer`] whose per-term IDF table + /// covers every token that the per-partition scoring loop will look + /// up. For fuzzy queries that means the union of Levenshtein + /// expansions, not just the raw query tokens — otherwise + /// `query_weight(expanded_token)` returns 0 and the BM25 contribution + /// of every expanded match is discarded. + pub async fn bm25_base_scorer( + &self, + query_tokens: &Tokens, + params: &FtsSearchParams, + ) -> Result { let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref())); - let token_docs = query_tokens - .into_iter() - .map(|token| (token.to_string(), scorer.num_docs_containing_token(token))) - .collect::>(); - MemBM25Scorer::new(scorer.total_tokens(), scorer.num_docs(), token_docs) + let mut terms: Vec = Vec::new(); + let mut seen = HashSet::new(); + if matches!(params.fuzziness, Some(n) if n != 0) { + let expanded = self.expand_fuzzy_tokens(query_tokens, params)?; + for idx in 0..expanded.len() { + let token = expanded.get_token(idx); + if seen.insert(token.to_string()) { + terms.push(token.to_string()); + } + } + } else { + for token in query_tokens { + if seen.insert(token.to_string()) { + terms.push(token.to_string()); + } + } + } + let mut token_docs = HashMap::with_capacity(terms.len()); + for term in &terms { + let df = self.df_for_term(term).await?; + token_docs.insert(term.clone(), df); + } + Ok(MemBM25Scorer::new( + scorer.total_tokens(), + scorer.num_docs(), + token_docs, + )) } - pub fn bm25_stats_for_terms(&self, terms: &[String]) -> (u64, usize, Vec) { + pub async fn bm25_stats_for_terms(&self, terms: &[String]) -> Result<(u64, usize, Vec)> { let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref())); - let token_docs = terms + let token_docs = + futures::future::try_join_all(terms.iter().map(|term| self.df_for_term(term))).await?; + Ok((scorer.total_tokens(), scorer.num_docs(), token_docs)) + } + + /// Sum the posting-list length for `term` across this index's partitions + /// via single-row reads, with partition lookups bounded by the store's + /// `io_parallelism()`. + async fn df_for_term(&self, term: &str) -> Result { + let io_parallelism = self.store.io_parallelism(); + let futures = self + .partitions .iter() - .map(|term| scorer.num_docs_containing_token(term)) - .collect(); - (scorer.total_tokens(), scorer.num_docs(), token_docs) + .map(|part| { + let part = part.clone(); + async move { + match part.tokens.get(term) { + Some(token_id) => part.inverted_list.posting_len_for_token(token_id).await, + None => Ok(0), + } + } + }) + .collect::>(); + let dfs: Vec = stream::iter(futures) + .buffer_unordered(io_parallelism) + .try_collect() + .await?; + Ok(dfs.into_iter().sum()) } /// Expand fuzzy query tokens against all partitions in this segment. @@ -570,11 +625,17 @@ impl InvertedIndex { metrics: Arc, base_scorer: Option<&MemBM25Scorer>, ) -> Result<(Vec, Vec)> { + // The wand only consults `scorer.doc_weight`, which is metadata-free. + // The outer aggregation below consults `scorer.query_weight`, which + // hits per-token `posting_len`; building a `MemBM25Scorer` with + // precomputed per-term IDFs avoids the v2 bulk metadata pull. let local_scorer; let scorer: &dyn Scorer = if let Some(base_scorer) = base_scorer { base_scorer } else { - local_scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref())); + local_scorer = self + .bm25_base_scorer(tokens.as_ref(), params.as_ref()) + .await?; &local_scorer }; @@ -1052,7 +1113,7 @@ impl InvertedPartition { } pub fn is_legacy(&self) -> bool { - self.inverted_list.lengths.is_none() + self.inverted_list.is_legacy_layout() } pub async fn load( @@ -1596,15 +1657,10 @@ impl TokenSet { pub struct PostingListReader { reader: Arc, - // legacy format only - offsets: Option>, - - // from metadata for legacy format - // from column for new format - max_scores: Option>, - - // new format only - lengths: Option>, + /// Layout-specific metadata. V2 keeps its per-token max-score and + /// length columns lazy so opening a partition doesn't drag O(num_tokens) + /// bytes off cold storage when the caller only needs `df` for a few terms. + metadata: PostingMetadata, has_position: bool, posting_tail_codec: PostingTailCodec, @@ -1613,6 +1669,33 @@ pub struct PostingListReader { index_cache: WeakLanceCache, } +/// Per-token metadata (max_score, length) needed by the BM25 query and stats +/// paths. The legacy and v2 formats store this metadata in different +/// places, with very different cost profiles for cold-load: the variants +/// surface that asymmetry so callers can choose a per-token or bulk access +/// pattern. +enum PostingMetadata { + /// Legacy v1: offsets and max_scores are encoded in the file's schema + /// metadata, so they are already in memory by the time `try_new` returns. + LegacyV1 { + offsets: Vec, + max_scores: Option>, + }, + /// V2: per-token `max_score` and `length` live as columns in the + /// posting file. The bulk vectors are filled lazily by + /// `ensure_metadata_loaded`, and the stats path can also fetch a single + /// token via `posting_len_for_token` without forcing the bulk load. + V2 { + metadata: tokio::sync::OnceCell, + }, +} + +#[derive(Debug, Clone)] +struct LoadedPostingMetadata { + max_scores: Vec, + lengths: Vec, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum PositionsLayout { None, @@ -1622,18 +1705,40 @@ enum PositionsLayout { impl std::fmt::Debug for PostingListReader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InvertedListReader") - .field("offsets", &self.offsets) - .field("max_scores", &self.max_scores) - .finish() + let mut s = f.debug_struct("InvertedListReader"); + match &self.metadata { + PostingMetadata::LegacyV1 { + offsets, + max_scores, + } => { + s.field("layout", &"legacy_v1") + .field("offsets", offsets) + .field("max_scores", max_scores); + } + PostingMetadata::V2 { metadata } => { + s.field("layout", &"v2") + .field("metadata_loaded", &metadata.initialized()); + } + } + s.finish() } } impl DeepSizeOf for PostingListReader { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.offsets.deep_size_of_children(context) - + self.max_scores.deep_size_of_children(context) - + self.lengths.deep_size_of_children(context) + match &self.metadata { + PostingMetadata::LegacyV1 { + offsets, + max_scores, + } => offsets.deep_size_of_children(context) + max_scores.deep_size_of_children(context), + PostingMetadata::V2 { metadata } => metadata + .get() + .map(|loaded| { + loaded.max_scores.deep_size_of_children(context) + + loaded.lengths.deep_size_of_children(context) + }) + .unwrap_or(0), + } } } @@ -1651,29 +1756,21 @@ impl PostingListReader { }; let posting_tail_codec = parse_posting_tail_codec(&reader.schema().metadata)?; let has_position = positions_layout != PositionsLayout::None; - let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() { + let metadata = if reader.schema().field(POSTING_COL).is_none() { let (offsets, max_scores) = Self::load_metadata(reader.schema())?; - (Some(offsets), max_scores, None) + PostingMetadata::LegacyV1 { + offsets, + max_scores, + } } else { - let metadata = reader - .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL])) - .await?; - let max_scores = metadata[MAX_SCORE_COL] - .as_primitive::() - .values() - .to_vec(); - let lengths = metadata[LENGTH_COL] - .as_primitive::() - .values() - .to_vec(); - (None, Some(max_scores), Some(lengths)) + PostingMetadata::V2 { + metadata: tokio::sync::OnceCell::new(), + } }; Ok(Self { reader, - offsets, - max_scores, - lengths, + metadata, has_position, posting_tail_codec, positions_layout, @@ -1702,9 +1799,9 @@ impl PostingListReader { // the number of posting lists pub fn len(&self) -> usize { - match self.offsets { - Some(ref offsets) => offsets.len(), - None => self.reader.num_rows(), + match &self.metadata { + PostingMetadata::LegacyV1 { offsets, .. } => offsets.len(), + PostingMetadata::V2 { .. } => self.reader.num_rows(), } } @@ -1720,33 +1817,131 @@ impl PostingListReader { self.posting_tail_codec } + fn is_legacy_layout(&self) -> bool { + matches!(self.metadata, PostingMetadata::LegacyV1 { .. }) + } + + /// Sync access to `posting_len`. Requires v2 metadata to already be + /// loaded via [`ensure_metadata_loaded`]; the bm25 scoring path enforces + /// that contract before kicking off wand. The stats path uses + /// [`Self::posting_len_for_token`] instead, which avoids the bulk load. pub(crate) fn posting_len(&self, token_id: u32) -> usize { let token_id = token_id as usize; - - match self.offsets { - Some(ref offsets) => { + match &self.metadata { + PostingMetadata::LegacyV1 { offsets, .. } => { let next_offset = offsets .get(token_id + 1) .copied() .unwrap_or(self.reader.num_rows()); next_offset - offsets[token_id] } - None => { - if let Some(lengths) = &self.lengths { - lengths[token_id] as usize - } else { - panic!("posting list reader is not initialized") + PostingMetadata::V2 { metadata } => { + let metadata = metadata + .get() + .expect("v2 posting metadata must be bulk-loaded before sync posting_len; call ensure_metadata_loaded first"); + metadata.lengths[token_id] as usize + } + } + } + + /// Async access to a single token's posting list length. For v2 + /// indexes this reads a single row from `LENGTH_COL` if the bulk metadata + /// has not been loaded yet, and never triggers the bulk load itself. The + /// stats path uses this so a single-term `df` lookup costs O(1) bytes + /// rather than O(num_unique_tokens). + pub(crate) async fn posting_len_for_token(&self, token_id: u32) -> Result { + match &self.metadata { + PostingMetadata::LegacyV1 { .. } => Ok(self.posting_len(token_id)), + PostingMetadata::V2 { metadata } => { + if let Some(metadata) = metadata.get() { + return Ok(metadata.lengths[token_id as usize] as usize); } + let token_id = token_id as usize; + let batch = self + .reader + .read_range(token_id..token_id + 1, Some(&[LENGTH_COL])) + .await?; + let len = batch[LENGTH_COL].as_primitive::().value(0); + Ok(len as usize) } } } + /// Async access to a single token's `(max_score, length)` pair. Mirrors + /// [`Self::posting_len_for_token`] but covers both columns the scoring + /// path needs, in one read. For v2 indexes that have not been + /// bulk-loaded this issues one `read_range(token..token+1, [MAX_SCORE, + /// LENGTH])`; for legacy v1 the values come from in-memory schema + /// metadata. + pub(crate) async fn posting_metadata_for_token( + &self, + token_id: u32, + ) -> Result<(Option, Option)> { + match &self.metadata { + PostingMetadata::LegacyV1 { max_scores, .. } => { + Ok((max_scores.as_ref().map(|m| m[token_id as usize]), None)) + } + PostingMetadata::V2 { metadata } => { + if let Some(loaded) = metadata.get() { + return Ok(( + Some(loaded.max_scores[token_id as usize]), + Some(loaded.lengths[token_id as usize]), + )); + } + let token_id_usize = token_id as usize; + let batch = self + .reader + .read_range( + token_id_usize..token_id_usize + 1, + Some(&[MAX_SCORE_COL, LENGTH_COL]), + ) + .await?; + let max_score = batch[MAX_SCORE_COL].as_primitive::().value(0); + let length = batch[LENGTH_COL].as_primitive::().value(0); + Ok((Some(max_score), Some(length))) + } + } + } + + /// Force the v2 bulk metadata (`max_scores`, `lengths`) into + /// memory. Cheap to call repeatedly; no-op for legacy v1 indexes whose + /// metadata is already populated from schema metadata at `try_new` time. + pub(crate) async fn ensure_metadata_loaded(&self) -> Result<()> { + let PostingMetadata::V2 { metadata } = &self.metadata else { + return Ok(()); + }; + metadata + .get_or_try_init(|| async { + let batch = self + .reader + .read_range( + 0..self.reader.num_rows(), + Some(&[MAX_SCORE_COL, LENGTH_COL]), + ) + .await?; + let max_scores = batch[MAX_SCORE_COL] + .as_primitive::() + .values() + .to_vec(); + let lengths = batch[LENGTH_COL] + .as_primitive::() + .values() + .to_vec(); + Ok::(LoadedPostingMetadata { + max_scores, + lengths, + }) + }) + .await?; + Ok(()) + } + pub(crate) async fn posting_batch( &self, token_id: u32, with_position: bool, ) -> Result { - if self.offsets.is_some() { + if self.is_legacy_layout() { self.posting_batch_legacy(token_id, with_position).await } else { let token_id = token_id as usize; @@ -1784,8 +1979,11 @@ impl PostingListReader { } let length = self.posting_len(token_id); + let PostingMetadata::LegacyV1 { offsets, .. } = &self.metadata else { + unreachable!("posting_batch_legacy is only reachable on legacy v1 layout"); + }; let token_id = token_id as usize; - let offset = self.offsets.as_ref().unwrap()[token_id]; + let offset = offsets[token_id]; let batch = self .reader .read_range(offset..offset + length, Some(&columns)) @@ -1806,8 +2004,15 @@ impl PostingListReader { .get_or_insert_with_key(cache_key, || async move { metrics.record_part_load(); info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id); - let batch = self.posting_batch(token_id, false).await?; - self.posting_list_from_batch(&batch, token_id) + // Fetch the posting batch and this token's (max_score, + // length) in parallel; for cold v2 partitions this is one + // single-row metadata read plus one posting-row read, + // instead of pulling the full per-token metadata table. + let (batch, (max_score, length)) = futures::try_join!( + self.posting_batch(token_id, false), + self.posting_metadata_for_token(token_id), + )?; + self.posting_list_from_batch(&batch, max_score, length) }) .await? .as_ref() @@ -1842,16 +2047,13 @@ impl PostingListReader { pub(crate) fn posting_list_from_batch( &self, batch: &RecordBatch, - token_id: u32, + max_score: Option, + length: Option, ) -> Result { Self::posting_list_from_batch_parts( batch, - self.max_scores - .as_ref() - .map(|max_scores| max_scores[token_id as usize]), - self.lengths - .as_ref() - .map(|lengths| lengths[token_id as usize]), + max_score, + length, self.posting_tail_codec, self.positions_layout, ) @@ -1907,14 +2109,27 @@ impl PostingListReader { )); } + // Make sure max_scores/lengths are populated before we clone them into + // the blocking task; otherwise the v2 branch would unwrap empty + // OnceCells. + self.ensure_metadata_loaded().await?; + let read_batch_start = Instant::now(); let batch = self.read_batch(with_position).await?; let read_batch_elapsed = read_batch_start.elapsed(); - let legacy_layout = self.offsets.is_some(); - let offsets = self.offsets.clone(); - let max_scores = self.max_scores.clone(); - let lengths = self.lengths.clone(); + let (legacy_layout, offsets, max_scores, lengths) = match &self.metadata { + PostingMetadata::LegacyV1 { + offsets, + max_scores, + } => (true, Some(offsets.clone()), max_scores.clone(), None), + PostingMetadata::V2 { metadata } => ( + false, + None, + metadata.get().map(|loaded| loaded.max_scores.clone()), + metadata.get().map(|loaded| loaded.lengths.clone()), + ), + }; let posting_tail_codec = self.posting_tail_codec; let positions_layout = self.positions_layout; let populate_start = Instant::now(); @@ -1971,15 +2186,41 @@ impl PostingListReader { &self, with_position: bool, ) -> Result> + '_> { + // read_all walks every posting list; the bulk metadata is paid for + // unconditionally, so just load it once up front and index into it + // synchronously below. + self.ensure_metadata_loaded().await?; let batch = self.read_batch(with_position).await?; Ok((0..self.len()).map(move |i| { let token_id = i as u32; let range = self.posting_list_range(token_id); let batch = batch.slice(i, range.end - range.start); - self.posting_list_from_batch(&batch, token_id) + let (max_score, length) = self.bulk_metadata_for_token(token_id); + self.posting_list_from_batch(&batch, max_score, length) })) } + /// Sync lookup of `(max_score, length)` from the bulk-loaded metadata. + /// Only safe after [`Self::ensure_metadata_loaded`]; callers that hold + /// the OnceCell-loaded reference (e.g. read_all, prewarm) use this to + /// avoid the per-token IO path. + fn bulk_metadata_for_token(&self, token_id: u32) -> (Option, Option) { + match &self.metadata { + PostingMetadata::LegacyV1 { max_scores, .. } => { + (max_scores.as_ref().map(|m| m[token_id as usize]), None) + } + PostingMetadata::V2 { metadata } => { + let loaded = metadata.get().expect( + "v2 metadata must be bulk-loaded before bulk_metadata_for_token; call ensure_metadata_loaded first", + ); + ( + Some(loaded.max_scores[token_id as usize]), + Some(loaded.lengths[token_id as usize]), + ) + } + } + } + async fn read_positions(&self, token_id: u32) -> Result { let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move { let positions = match self.positions_layout { @@ -2038,13 +2279,13 @@ impl PostingListReader { } fn posting_list_range(&self, token_id: u32) -> Range { - match self.offsets { - Some(ref offsets) => { + match &self.metadata { + PostingMetadata::LegacyV1 { offsets, .. } => { let offset = offsets[token_id as usize]; let posting_len = self.posting_len(token_id); offset..offset + posting_len } - None => { + PostingMetadata::V2 { .. } => { let token_id = token_id as usize; token_id..token_id + 1 } @@ -2052,9 +2293,10 @@ impl PostingListReader { } fn posting_columns(&self, with_position: bool) -> Vec<&'static str> { - let mut base_columns = match self.offsets { - Some(_) => vec![ROW_ID, FREQUENCY_COL], - None => vec![POSTING_COL], + let mut base_columns = if self.is_legacy_layout() { + vec![ROW_ID, FREQUENCY_COL] + } else { + vec![POSTING_COL] }; if with_position { match self.positions_layout { @@ -5045,18 +5287,27 @@ mod tests { // Verify the partitions were loaded correctly - // Verify posting list lengths (note: partition order may differ from creation order) - // Verify based on actual loading order + // Verify posting list lengths (note: partition order may differ from creation order). + // `posting_len_for_token` works for both legacy and v2 layouts without + // forcing the V2-only bulk metadata load. + let pl_0_0 = index.partitions[0] + .inverted_list + .posting_len_for_token(0) + .await + .unwrap(); + let pl_1_0 = index.partitions[1] + .inverted_list + .posting_len_for_token(0) + .await + .unwrap(); if index.partitions[0].id() == 0 { - // If partition[0] is ID=0, then it should have 1 document - assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1); - assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4); + assert_eq!(pl_0_0, 1); + assert_eq!(pl_1_0, 4); assert_eq!(index.partitions[0].docs.len(), 1); assert_eq!(index.partitions[1].docs.len(), 4); } else { - // If partition[0] is ID=1, then it should have 4 documents - assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4); - assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1); + assert_eq!(pl_0_0, 4); + assert_eq!(pl_1_0, 1); assert_eq!(index.partitions[0].docs.len(), 4); assert_eq!(index.partitions[1].docs.len(), 1); } @@ -5146,7 +5397,7 @@ mod tests { .unwrap(); let inverted_list = &index.partitions[0].inverted_list; assert!( - inverted_list.offsets.is_none(), + !inverted_list.is_legacy_layout(), "test should use modern posting layout" ); @@ -5177,6 +5428,294 @@ mod tests { ); } + /// IO accounting for the IO-counting stats test below: tracks bytes + /// pulled from the posting file so we can assert that the stats path is + /// O(1) in num_unique_tokens. + #[derive(Debug, Default)] + struct PostingMetadataCounter { + rows_read: std::sync::atomic::AtomicUsize, + metadata_rows_read: std::sync::atomic::AtomicUsize, + read_range_calls: std::sync::atomic::AtomicUsize, + } + + impl PostingMetadataCounter { + fn rows_read(&self) -> usize { + self.rows_read.load(std::sync::atomic::Ordering::Relaxed) + } + fn metadata_rows_read(&self) -> usize { + self.metadata_rows_read + .load(std::sync::atomic::Ordering::Relaxed) + } + fn read_range_calls(&self) -> usize { + self.read_range_calls + .load(std::sync::atomic::Ordering::Relaxed) + } + } + + struct CountingPostingReader { + inner: Arc, + counter: Arc, + } + + #[async_trait] + impl IndexReader for CountingPostingReader { + async fn read_record_batch(&self, n: u64, batch_size: u64) -> Result { + self.inner.read_record_batch(n, batch_size).await + } + async fn read_global_buffer(&self, index: u32) -> Result { + self.inner.read_global_buffer(index).await + } + async fn read_range( + &self, + range: std::ops::Range, + projection: Option<&[&str]>, + ) -> Result { + let n = range.end - range.start; + self.counter + .read_range_calls + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.counter + .rows_read + .fetch_add(n, std::sync::atomic::Ordering::Relaxed); + let touches_metadata = projection + .map(|cols| cols.contains(&MAX_SCORE_COL) || cols.contains(&LENGTH_COL)) + .unwrap_or(false); + if touches_metadata { + self.counter + .metadata_rows_read + .fetch_add(n, std::sync::atomic::Ordering::Relaxed); + } + self.inner.read_range(range, projection).await + } + async fn num_batches(&self, batch_size: u64) -> u32 { + self.inner.num_batches(batch_size).await + } + fn num_rows(&self) -> usize { + self.inner.num_rows() + } + fn schema(&self) -> &lance_core::datatypes::Schema { + self.inner.schema() + } + } + + #[derive(Debug)] + struct CountingStore { + inner: Arc, + posting_file: String, + counter: Arc, + } + + impl DeepSizeOf for CountingStore { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.inner.deep_size_of_children(context) + } + } + + #[async_trait] + impl IndexStore for CountingStore { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn clone_arc(&self) -> Arc { + Arc::new(Self { + inner: self.inner.clone(), + posting_file: self.posting_file.clone(), + counter: self.counter.clone(), + }) + } + fn io_parallelism(&self) -> usize { + self.inner.io_parallelism() + } + async fn new_index_file( + &self, + name: &str, + schema: Arc, + ) -> Result> { + self.inner.new_index_file(name, schema).await + } + async fn open_index_file(&self, name: &str) -> Result> { + let reader = self.inner.open_index_file(name).await?; + if name == self.posting_file { + Ok(Arc::new(CountingPostingReader { + inner: reader, + counter: self.counter.clone(), + })) + } else { + Ok(reader) + } + } + async fn copy_index_file(&self, name: &str, dest_store: &dyn IndexStore) -> Result<()> { + self.inner.copy_index_file(name, dest_store).await + } + async fn rename_index_file(&self, name: &str, new_name: &str) -> Result<()> { + self.inner.rename_index_file(name, new_name).await + } + async fn delete_index_file(&self, name: &str) -> Result<()> { + self.inner.delete_index_file(name).await + } + async fn list_files_with_sizes(&self) -> Result> { + self.inner.list_files_with_sizes().await + } + } + + async fn load_counted_v2_index( + num_tokens: usize, + ) -> (Arc, Arc) { + let tmpdir = TempObjDir::default(); + let inner_store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default()); + for i in 0..num_tokens { + builder.tokens.add(format!("t{}", i)); + let mut pl = PostingListBuilder::new(false); + pl.add(i as u32, PositionRecorder::Count(1)); + builder.posting_lists.push(pl); + builder.docs.append(i as u64, 1); + } + builder.write(inner_store.as_ref()).await.unwrap(); + + let metadata = HashMap::from([ + ( + "partitions".to_owned(), + serde_json::to_string(&vec![0u64]).unwrap(), + ), + ( + "params".to_owned(), + serde_json::to_string(&InvertedIndexParams::default()).unwrap(), + ), + ( + TOKEN_SET_FORMAT_KEY.to_owned(), + TokenSetFormat::default().to_string(), + ), + ]); + let mut writer = inner_store + .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty())) + .await + .unwrap(); + writer.finish_with_metadata(metadata).await.unwrap(); + + let counter = Arc::new(PostingMetadataCounter::default()); + let counting_store: Arc = Arc::new(CountingStore { + inner: inner_store, + posting_file: posting_file_path(0), + counter: counter.clone(), + }); + let index = InvertedIndex::load(counting_store, None, &LanceCache::no_cache()) + .await + .unwrap(); + (index, counter) + } + + /// IO regression test for the lazy posting-metadata refactor. Builds a + /// v2 InvertedIndex with `num_tokens` tokens in a single partition, + /// wraps the IndexStore so reads against the posting file are counted, + /// then asserts: + /// + /// * `InvertedIndex::load` does not touch the posting file at all + /// (`InvertedPartition::load` only needs the token file and docs file). + /// * `bm25_stats_for_terms(["t0"])` reads exactly one row from the + /// posting file (the single LENGTH_COL entry for token 0) regardless + /// of how many unique tokens the partition has. + /// + /// Before this refactor, `PostingListReader::try_new` did + /// `read_range(0..num_rows, [MAX_SCORE_COL, LENGTH_COL])`, so the + /// `metadata_rows_read` figure scaled linearly with `num_tokens` even + /// when nobody asked for those stats. The cases below exercise that + /// scaling explicitly. + #[rstest::rstest] + #[case::tokens_10(10)] + #[case::tokens_100(100)] + #[case::tokens_1000(1000)] + #[tokio::test] + async fn test_bm25_stats_for_terms_is_lazy(#[case] num_tokens: usize) { + let (index, counter) = load_counted_v2_index(num_tokens).await; + assert!( + !index.partitions[0].inverted_list.is_legacy_layout(), + "this test only proves the lazy path for v2 indexes", + ); + + // Opening the partition must not pull anything from the posting file. + // Pre-fix, `PostingListReader::try_new` issued one read_range here for + // [MAX_SCORE_COL, LENGTH_COL] covering every unique token. + assert_eq!( + counter.read_range_calls(), + 0, + "InvertedIndex::load must not read the posting file (was {} calls)", + counter.read_range_calls(), + ); + assert_eq!(counter.rows_read(), 0); + + let (total_tokens, num_docs, dfs) = index + .bm25_stats_for_terms(&["t0".to_string()]) + .await + .unwrap(); + assert_eq!(total_tokens, num_tokens as u64); + assert_eq!(num_docs, num_tokens); + assert_eq!(dfs, vec![1]); + + // Stats must pull a constant number of metadata rows from the posting + // file regardless of how many tokens the partition has. One term, one + // partition, one row. + assert_eq!( + counter.metadata_rows_read(), + 1, + "stats path should read exactly 1 metadata row per (term, partition); \ + got {} (read_range_calls={}, rows_read={}, num_tokens={})", + counter.metadata_rows_read(), + counter.read_range_calls(), + counter.rows_read(), + num_tokens, + ); + } + + #[tokio::test] + async fn test_posting_list_metadata_reads_scale_with_query_size() { + // Cold-start scoring used to bulk-read `0..num_tokens` of the + // [MAX_SCORE_COL, LENGTH_COL] columns the first time `bm25_search` + // ran against a partition. Now `posting_list` fetches each token's + // (max_score, length) as a single-row read alongside its posting + // batch, so K concurrent posting_list lookups should pull O(K) + // metadata rows, not O(num_tokens). + let num_tokens = 32; + let queried_tokens: [u32; 4] = [0, 1, 2, 3]; + let (index, counter) = load_counted_v2_index(num_tokens).await; + let inverted_list = index.partitions[0].inverted_list.clone(); + assert!( + !inverted_list.is_legacy_layout(), + "this test only proves the lazy path for v2 indexes", + ); + + let metrics = Arc::new(NoOpMetricsCollector); + stream::iter(queried_tokens) + .map(|token_id| { + let inverted_list = inverted_list.clone(); + let metrics = metrics.clone(); + async move { + inverted_list + .posting_list(token_id, false, metrics.as_ref()) + .await + .unwrap(); + } + }) + .buffer_unordered(queried_tokens.len()) + .collect::>() + .await; + + assert_eq!( + counter.metadata_rows_read(), + queried_tokens.len(), + "K posting_list calls should read K metadata rows (not the full \ + {}-row metadata table); got {} rows across {} read_range calls", + num_tokens, + counter.metadata_rows_read(), + counter.read_range_calls(), + ); + } + #[tokio::test] async fn test_prewarm_with_positions_populates_separate_position_cache() { let tmpdir = TempObjDir::default(); diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs index fb1e8cc3639..ed0fd80638d 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer.rs @@ -22,6 +22,8 @@ use crate::pbold; use crate::scalar::inverted::tokenizer::document_tokenizer::{ JsonTokenizer, LanceTokenizer, TextTokenizer, }; +#[cfg(feature = "tokenizer-icu")] +use lance_tokenizer::IcuTokenizer; pub use lance_tokenizer::Language; use lance_tokenizer::{ AsciiFoldingFilter, LowerCaser, NgramTokenizer, RawTokenizer, RemoveLongFilter, @@ -41,6 +43,7 @@ pub struct InvertedIndexParams { /// - `simple`: splits tokens on whitespace and punctuation /// - `whitespace`: splits tokens on whitespace /// - `raw`: no tokenization + /// - `icu`: ICU dictionary-based word segmentation /// - `lindera/*`: Lindera tokenizer /// - `jieba/*`: Jieba tokenizer /// @@ -195,6 +198,7 @@ impl InvertedIndexParams { /// - `whitespace`: splits tokens on whitespace /// - `raw`: no tokenization /// - `ngram`: N-Gram tokenizer + /// - `icu`: ICU dictionary-based word segmentation /// - `lindera/*`: Lindera tokenizer /// - `jieba/*`: Jieba tokenizer /// @@ -385,6 +389,8 @@ impl InvertedIndexParams { "simple" => Ok(TextAnalyzer::builder(SimpleTokenizer::default()).dynamic()), "whitespace" => Ok(TextAnalyzer::builder(WhitespaceTokenizer::default()).dynamic()), "raw" => Ok(TextAnalyzer::builder(RawTokenizer::default()).dynamic()), + #[cfg(feature = "tokenizer-icu")] + "icu" => Ok(TextAnalyzer::builder(IcuTokenizer::default()).dynamic()), "ngram" => { let tokenizer = NgramTokenizer::new( self.min_ngram_length as usize, @@ -437,6 +443,8 @@ pub fn language_model_home() -> Option { #[cfg(test)] mod tests { use super::InvertedIndexParams; + #[cfg(feature = "tokenizer-icu")] + use lance_tokenizer::TokenStream; #[test] fn test_build_only_fields_are_not_serialized() { @@ -485,4 +493,19 @@ mod tests { ); assert_eq!(json.get("num_workers"), Some(&serde_json::Value::from(3))); } + + #[cfg(feature = "tokenizer-icu")] + #[test] + fn test_build_icu_tokenizer() { + let mut tokenizer = InvertedIndexParams::default() + .base_tokenizer("icu".to_string()) + .stem(false) + .remove_stop_words(false) + .build() + .unwrap(); + let mut stream = tokenizer.token_stream_for_doc("Hello, こんにちは世界!"); + let mut tokens = Vec::new(); + stream.process(&mut |token| tokens.push(token.text.clone())); + assert_eq!(tokens, vec!["hello", "こんにちは", "世界"]); + } } diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 8ba08d3a255..64601a64f96 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -22,13 +22,14 @@ use crate::scalar::{ BuiltinIndexType, CreatedIndex, SargableQuery, ScalarIndexParams, UpdateCriteria, compute_next_prefix, }; -use datafusion::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; -use datafusion_expr::Accumulator; +use lance_arrow_stats::StatisticsAccumulator; use lance_core::cache::{LanceCache, WeakLanceCache}; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; -use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array, new_empty_array}; +use arrow_array::{ + ArrayRef, RecordBatch, UInt32Array, UInt64Array, new_empty_array, new_null_array, +}; use arrow_schema::{DataType, Field}; use datafusion::execution::SendableRecordBatchStream; use datafusion_common::ScalarValue; @@ -131,9 +132,45 @@ impl DeepSizeOf for ZoneMapIndex { } impl ZoneMapIndex { - /// Evaluates whether a zone could potentially contain values matching the query - /// For NaN, total order is used here - /// reference: https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp + fn scalar_is_nan(value: &ScalarValue) -> bool { + match value { + ScalarValue::Float16(Some(value)) => value.is_nan(), + ScalarValue::Float32(Some(value)) => value.is_nan(), + ScalarValue::Float64(Some(value)) => value.is_nan(), + _ => false, + } + } + + /// Returns true if the zone has a non-null, non-NaN min value. + fn zone_has_finite_min(zone: &ZoneMapStatistics) -> bool { + !(zone.min.is_null() || Self::scalar_is_nan(&zone.min)) + } + + /// Returns true if both min and max are non-null / non-NaN. + fn zone_has_finite_extrema(zone: &ZoneMapStatistics) -> bool { + Self::zone_has_finite_min(zone) && !(zone.max.is_null() || Self::scalar_is_nan(&zone.max)) + } + + fn finite_value_may_be_in_zone(value: &ScalarValue, zone: &ZoneMapStatistics) -> bool { + if !Self::zone_has_finite_min(zone) || value < &zone.min { + return false; + } + + if Self::scalar_is_nan(&zone.max) { + // A NaN max means this zone had both NaNs and finite values. The + // finite max is not persisted, so keep the zone as a false positive + // instead of using total ordering to prune it. + return true; + } + + !zone.max.is_null() && value <= &zone.max + } + + /// Evaluates whether a zone could potentially contain values matching the query. + /// + /// NaN query values use the explicit `nan_count`. When the stored max is + /// NaN we do not treat it as a finite upper bound; that representation means + /// the zone had finite values plus NaNs, and the finite max was not persisted. fn evaluate_zone_against_query( &self, zone: &ZoneMapStatistics, @@ -165,20 +202,19 @@ impl ZoneMapIndex { return Ok(zone.nan_count > 0); } - // Check if target is within the zone's range - // Handle the case where zone.max is NaN (zone contains both finite values and NaN) - let min_check = target >= &zone.min; - let max_check = match &zone.max { - ScalarValue::Float16(Some(f)) if f.is_nan() => true, - ScalarValue::Float32(Some(f)) if f.is_nan() => true, - ScalarValue::Float64(Some(f)) if f.is_nan() => true, - _ => target <= &zone.max, - }; - Ok(min_check && max_check) + if !Self::zone_has_finite_min(zone) { + return Ok(false); + } + + Ok(Self::finite_value_may_be_in_zone(target, zone)) } SargableQuery::Range(start, end) => { // Zone overlaps with query range if there's any intersection between // the zone's [min, max] and the query's range + if !Self::zone_has_finite_min(zone) { + return Ok(false); + } + let zone_min = &zone.min; let zone_max = &zone.max; @@ -301,24 +337,28 @@ impl ZoneMapIndex { if f.is_nan() { zone.nan_count > 0 } else { - value >= &zone.min && value <= &zone.max + Self::finite_value_may_be_in_zone(value, zone) } } ScalarValue::Float32(Some(f)) => { if f.is_nan() { zone.nan_count > 0 } else { - value >= &zone.min && value <= &zone.max + Self::finite_value_may_be_in_zone(value, zone) } } ScalarValue::Float64(Some(f)) => { if f.is_nan() { zone.nan_count > 0 } else { - value >= &zone.min && value <= &zone.max + Self::finite_value_may_be_in_zone(value, zone) } } - _ => value >= &zone.min && value <= &zone.max, + _ => { + Self::zone_has_finite_extrema(zone) + && value >= &zone.min + && value <= &zone.max + } } } })) @@ -754,79 +794,91 @@ impl ZoneMapIndexBuilder { /// trainer takes care of chunking and fragment boundaries. struct ZoneMapProcessor { data_type: DataType, - min: MinAccumulator, - max: MaxAccumulator, - null_count: u32, - nan_count: u32, + statistics: StatisticsAccumulator, } impl ZoneMapProcessor { fn new(data_type: DataType) -> Result { - let min = MinAccumulator::try_new(&data_type)?; - let max = MaxAccumulator::try_new(&data_type)?; Ok(Self { + statistics: StatisticsAccumulator::new(&data_type), data_type, - min, - max, - null_count: 0, - nan_count: 0, }) } - fn count_nans(array: &ArrayRef) -> u32 { - match array.data_type() { - DataType::Float16 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - DataType::Float32 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - DataType::Float64 => { - let array = array - .as_any() - .downcast_ref::() - .unwrap(); - array.values().iter().filter(|&&x| x.is_nan()).count() as u32 - } - _ => 0, + fn scalar_value_from_stat( + value: Option<&ArrayRef>, + data_type: &DataType, + ) -> Result { + let array = value + .cloned() + .unwrap_or_else(|| new_null_array(data_type, 1)); + Ok(ScalarValue::try_from_array(&array, 0)?) + } + + fn stat_count_to_u32(name: &str, value: u64) -> Result { + u32::try_from(value).map_err(|_| { + Error::invalid_input(format!( + "{} value {} exceeds the supported UInt32 range", + name, value + )) + }) + } + + fn nan_scalar(data_type: &DataType) -> Option { + match data_type { + DataType::Float16 => Some(ScalarValue::Float16(Some(half::f16::NAN))), + DataType::Float32 => Some(ScalarValue::Float32(Some(f32::NAN))), + DataType::Float64 => Some(ScalarValue::Float64(Some(f64::NAN))), + _ => None, } } + + fn max_value_from_stats( + value: Option<&ArrayRef>, + data_type: &DataType, + nan_count: u32, + ) -> Result { + if nan_count > 0 + && let Some(nan) = Self::nan_scalar(data_type) + { + // DataFusion's max accumulator surfaced NaN as the zone max. Keep + // that stored zonemap shape while using arrow_stats so existing + // range/equality pruning remains conservative around NaN. + return Ok(nan); + } + Self::scalar_value_from_stat(value, data_type) + } } impl ZoneProcessor for ZoneMapProcessor { type ZoneStatistics = ZoneMapStatistics; fn process_chunk(&mut self, array: &ArrayRef) -> Result<()> { - self.null_count += array.null_count() as u32; - self.nan_count += Self::count_nans(array); - self.min.update_batch(std::slice::from_ref(array))?; - self.max.update_batch(std::slice::from_ref(array))?; + self.statistics.update(array)?; Ok(()) } fn finish_zone(&mut self, bound: ZoneBound) -> Result { + let statistics = self.statistics.statistics(); + let nan_count = Self::stat_count_to_u32("nan_count", statistics.nan_count.unwrap_or(0))?; Ok(ZoneMapStatistics { - min: self.min.evaluate()?, - max: self.max.evaluate()?, - null_count: self.null_count, - nan_count: self.nan_count, + min: Self::scalar_value_from_stat( + statistics.min.as_ref().map(|scalar| scalar.as_array()), + &self.data_type, + )?, + max: Self::max_value_from_stats( + statistics.max.as_ref().map(|scalar| scalar.as_array()), + &self.data_type, + nan_count, + )?, + null_count: Self::stat_count_to_u32("null_count", statistics.null_count)?, + nan_count, bound, }) } fn reset(&mut self) -> Result<()> { - self.min = MinAccumulator::try_new(&self.data_type)?; - self.max = MaxAccumulator::try_new(&self.data_type)?; - self.null_count = 0; - self.nan_count = 0; + self.statistics.reset(); Ok(()) } } diff --git a/rust/lance-index/src/vector/ivf/shuffler.rs b/rust/lance-index/src/vector/ivf/shuffler.rs index cbeb9556eab..175d5ecb9fb 100644 --- a/rust/lance-index/src/vector/ivf/shuffler.rs +++ b/rust/lance-index/src/vector/ivf/shuffler.rs @@ -28,6 +28,7 @@ use futures::stream::repeat_with; use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream}; use lance_arrow::RecordBatchExt; use lance_core::cache::LanceCache; +use lance_core::utils::futures::StreamOnDropExt; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{Error, ROW_ID, Result, datatypes::Schema}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; @@ -53,12 +54,12 @@ use crate::vector::transform::Transformer; const UNSORTED_BUFFER: &str = "unsorted.lance"; const SHUFFLE_BATCH_SIZE: usize = 1024; -fn get_temp_dir() -> Result { - // Note: using keep here means we will not delete this TempDir automatically - let dir = tempfile::TempDir::new()?.keep(); +/// Returns the temp dir path plus a guard whose `Drop` removes the directory. +fn get_temp_dir() -> Result<(Path, tempfile::TempDir)> { + let dir = tempfile::TempDir::new()?; let tmp_dir_path = - Path::from_filesystem_path(dir).map_err(|e| Error::io_source(Box::new(e)))?; - Ok(tmp_dir_path) + Path::from_filesystem_path(dir.path()).map_err(|e| Error::io_source(Box::new(e)))?; + Ok((tmp_dir_path, dir)) } /// A builder for a partition of data @@ -343,11 +344,22 @@ pub async fn shuffle_dataset( // step 3: load the sorted chunks, consumers are expect to be responsible for merging the streams let start = std::time::Instant::now(); - let stream = + let streams = IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files).await?; info!("merged partitioned shuffles in {:?}", start.elapsed()); - Ok(stream) + // Clone the temp-dir guard into each returned stream so the shuffle + // files are removed only after the consumer drops every stream. + let temp_dir_guard = shuffler.owned_temp_dir.clone(); + let guarded_streams = streams + .into_iter() + .map(|stream| { + let guard = temp_dir_guard.clone(); + stream.on_drop(move || drop(guard)) + }) + .collect::>(); + + Ok(guarded_streams) } pub async fn shuffle_vectors( @@ -385,6 +397,10 @@ pub struct IvfShuffler { output_dir: Path, + // `Some` for an auto-created `output_dir`; cleanup runs when the last + // clone of this `Arc` is dropped. `None` when the caller owns cleanup. + owned_temp_dir: Option>, + // whether the lance file is v1 (legacy) or v2 is_legacy: bool, @@ -410,9 +426,12 @@ impl IvfShuffler { is_legacy: bool, shuffle_output_root_filename: Option, ) -> Result { - let output_dir = match output_dir { - Some(output_dir) => output_dir, - None => get_temp_dir()?, + let (output_dir, owned_temp_dir) = match output_dir { + Some(output_dir) => (output_dir, None), + None => { + let (path, dir) = get_temp_dir()?; + (path, Some(Arc::new(dir))) + } }; let shuffle_output_root_filename = match shuffle_output_root_filename { @@ -423,6 +442,7 @@ impl IvfShuffler { Ok(Self { num_partitions, output_dir, + owned_temp_dir, unsorted_buffers: vec![], is_legacy, shuffle_output_root_filename, @@ -1198,4 +1218,63 @@ mod test { assert_eq!(num_batches, NUM_PARTITIONS * expected_num_part_files); } + + // Auto-created shuffler temp dir must be removed once the shuffler and + // its returned streams are dropped. + #[tokio::test] + async fn test_shuffler_cleans_up_auto_temp_dir() { + let (stream, mut shuffler) = make_stream_and_shuffler(false); + + // Snapshot the path without cloning the `Arc` — a clone here would + // block cleanup on drop. + let temp_dir_path = shuffler + .owned_temp_dir + .as_ref() + .expect("shuffler built with output_dir = None should own a TempDir guard") + .path() + .to_path_buf(); + + assert!( + temp_dir_path.is_dir(), + "auto-created shuffler temp dir should exist while shuffler is alive: {:?}", + temp_dir_path, + ); + + shuffler.write_unsorted_stream(stream).await.unwrap(); + let partition_files = shuffler.write_partitioned_shuffles(100, 1).await.unwrap(); + assert_eq!(partition_files.len(), 1); + + assert!( + temp_dir_path.join("unsorted.lance").is_file(), + "shuffler should have written unsorted.lance into its working dir: {:?}", + temp_dir_path, + ); + assert!( + temp_dir_path.join("sorted_0.lance").is_file(), + "shuffler should have written sorted_0.lance into its working dir: {:?}", + temp_dir_path, + ); + + let mut result_streams = + IvfShuffler::load_partitioned_shuffles(&shuffler.output_dir, partition_files) + .await + .unwrap(); + + while let Some(mut s) = result_streams.pop() { + while let Some(item) = s.next().await { + let _ = item.unwrap(); + } + } + drop(result_streams); + // Dropping the shuffler releases the last `Arc`, which + // removes the on-disk directory. + drop(shuffler); + + assert!( + !temp_dir_path.exists(), + "auto-created shuffler temp dir should be removed once the IvfShuffler and \ + its returned streams are dropped, but it still exists: {:?}", + temp_dir_path, + ); + } } diff --git a/rust/lance-io/src/object_reader.rs b/rust/lance-io/src/object_reader.rs index ff0128fd003..d6d5de98f0b 100644 --- a/rust/lance-io/src/object_reader.rs +++ b/rust/lance-io/src/object_reader.rs @@ -188,25 +188,29 @@ impl Reader for CloudObjectReader { #[instrument(level = "debug", skip(self))] fn get_range(&self, range: Range) -> BoxFuture<'static, OSResult> { - let get_request = Arc::new(GetRequest { - object_store: self.object_store.clone(), - path: self.path.clone(), - options: GetOptions { - range: Some( - Range { - start: range.start as u64, - end: range.end as u64, - } - .into(), - ), - ..Default::default() - }, - }); - Box::pin(do_get_with_outer_retry( - self.download_retry_count, - get_request, - move || format!("range {:?}", range), - )) + let object_store = self.object_store.clone(); + let path = self.path.clone(); + let get_range = Range { + start: range.start as u64, + end: range.end as u64, + }; + Box::pin(async move { + let bytes = do_with_retry(move || { + let object_store = object_store.clone(); + let path = path.clone(); + let get_range = get_range.clone(); + Box::pin(async move { object_store.get_ranges(&path, &[get_range]).await }) + }) + .await?; + + bytes + .into_iter() + .next() + .ok_or_else(|| object_store::Error::Generic { + store: "CloudObjectReader", + source: "get_ranges returned no bytes".into(), + }) + }) } #[instrument(level = "debug", skip_all)] diff --git a/rust/lance-io/src/object_store/dynamic_opendal.rs b/rust/lance-io/src/object_store/dynamic_opendal.rs index 081a9b1cc5e..3367eacd507 100644 --- a/rust/lance-io/src/object_store/dynamic_opendal.rs +++ b/rust/lance-io/src/object_store/dynamic_opendal.rs @@ -3,8 +3,10 @@ use std::collections::HashMap; use std::fmt; +use std::ops::Range; use std::sync::Arc; +use bytes::Bytes; use futures::{StreamExt, TryStreamExt, stream, stream::BoxStream}; use object_store::path::Path; use object_store::{ @@ -170,6 +172,18 @@ impl OSObjectStore for DynamicOpenDalStore { .await } + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> object_store::Result> { + self.current_store() + .await + .map_err(|e| self.map_store_error(e))? + .get_ranges(location, ranges) + .await + } + fn delete_stream( &self, locations: BoxStream<'static, object_store::Result>, diff --git a/rust/lance-tokenizer/Cargo.toml b/rust/lance-tokenizer/Cargo.toml index 0f9909522ce..783d76e1d3c 100644 --- a/rust/lance-tokenizer/Cargo.toml +++ b/rust/lance-tokenizer/Cargo.toml @@ -12,6 +12,7 @@ categories.workspace = true rust-version.workspace = true [dependencies] +icu_segmenter = { workspace = true, optional = true } jieba-rs = { workspace = true, optional = true } lindera = { workspace = true, optional = true } rust-stemmers = "1.2.0" @@ -21,6 +22,7 @@ unicode-normalization = "0.1.25" [features] jieba-rs = ["dep:jieba-rs"] lindera = ["dep:lindera"] +tokenizer-icu = ["dep:icu_segmenter"] tokenizer-jieba = ["jieba-rs"] tokenizer-lindera = ["lindera"] diff --git a/rust/lance-tokenizer/src/icu.rs b/rust/lance-tokenizer/src/icu.rs new file mode 100644 index 00000000000..4e36f23115f --- /dev/null +++ b/rust/lance-tokenizer/src/icu.rs @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use icu_segmenter::{WordSegmenter, WordSegmenterBorrowed, options::WordBreakInvariantOptions}; + +use crate::{TextAnalyzer, TextAnalyzerBuilder, Token, TokenStream, Tokenizer}; + +#[derive(Clone)] +pub struct IcuTokenizer { + segmenter: WordSegmenterBorrowed<'static>, +} + +impl Default for IcuTokenizer { + fn default() -> Self { + Self { + segmenter: WordSegmenter::new_dictionary(WordBreakInvariantOptions::default()), + } + } +} + +impl IcuTokenizer { + pub fn analyzer(self) -> TextAnalyzer { + TextAnalyzer::builder(self).build() + } + + pub fn analyzer_builder(self) -> TextAnalyzerBuilder { + TextAnalyzer::builder(self).dynamic() + } +} + +pub struct IcuTokenStream { + tokens: Vec, + index: usize, +} + +impl TokenStream for IcuTokenStream { + fn advance(&mut self) -> bool { + if self.index < self.tokens.len() { + self.index += 1; + true + } else { + false + } + } + + fn token(&self) -> &Token { + &self.tokens[self.index - 1] + } + + fn token_mut(&mut self) -> &mut Token { + &mut self.tokens[self.index - 1] + } +} + +impl Tokenizer for IcuTokenizer { + type TokenStream<'a> = IcuTokenStream; + + fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> { + let mut boundaries = self.segmenter.segment_str(text); + let mut tokens = Vec::new(); + let Some(mut offset_from) = boundaries.next() else { + return IcuTokenStream { tokens, index: 0 }; + }; + + for offset_to in boundaries { + let token_text = &text[offset_from..offset_to]; + if token_text.chars().any(char::is_alphanumeric) { + tokens.push(Token { + offset_from, + offset_to, + position: tokens.len(), + text: token_text.to_owned(), + position_length: 1, + }); + } + offset_from = offset_to; + } + + IcuTokenStream { tokens, index: 0 } + } +} + +#[cfg(test)] +mod tests { + use crate::{IcuTokenizer, Token, TokenStream, Tokenizer}; + + fn collect_tokens(text: &str) -> Vec { + let mut tokenizer = IcuTokenizer::default(); + let mut stream = tokenizer.token_stream(text); + let mut tokens = Vec::new(); + stream.process(&mut |token| tokens.push(token.clone())); + tokens + } + + #[test] + fn test_icu_tokenizer_segments_mixed_text() { + let tokens = collect_tokens("Hello, こんにちは世界!"); + + assert_eq!( + tokens + .iter() + .map(|token| token.text.as_str()) + .collect::>(), + vec!["Hello", "こんにちは", "世界"] + ); + assert_eq!( + tokens + .iter() + .map(|token| (token.offset_from, token.offset_to, token.position)) + .collect::>(), + vec![(0, 5, 0), (7, 22, 1), (22, 28, 2)] + ); + } + + #[test] + fn test_icu_tokenizer_skips_non_word_segments() { + let tokens = collect_tokens("Mark'd ye his words?"); + + assert_eq!( + tokens + .iter() + .map(|token| token.text.as_str()) + .collect::>(), + vec!["Mark'd", "ye", "his", "words"] + ); + } +} diff --git a/rust/lance-tokenizer/src/lib.rs b/rust/lance-tokenizer/src/lib.rs index 54cc45dbbf4..2c441c58845 100644 --- a/rust/lance-tokenizer/src/lib.rs +++ b/rust/lance-tokenizer/src/lib.rs @@ -4,6 +4,8 @@ mod alphanum_only; mod analyzer; mod ascii_folding_filter; +#[cfg(feature = "tokenizer-icu")] +mod icu; #[cfg(feature = "tokenizer-jieba")] mod jieba; mod lower_caser; @@ -22,6 +24,8 @@ mod lindera; pub use alphanum_only::AlphaNumOnlyFilter; pub use analyzer::{TextAnalyzer, TextAnalyzerBuilder}; pub use ascii_folding_filter::AsciiFoldingFilter; +#[cfg(feature = "tokenizer-icu")] +pub use icu::IcuTokenizer; #[cfg(feature = "tokenizer-jieba")] pub use jieba::JiebaTokenizer; #[cfg(feature = "tokenizer-lindera")] diff --git a/rust/lance/src/dataset/mem_wal/api.rs b/rust/lance/src/dataset/mem_wal/api.rs index c8a3bc0441d..812b0066171 100644 --- a/rust/lance/src/dataset/mem_wal/api.rs +++ b/rust/lance/src/dataset/mem_wal/api.rs @@ -556,12 +556,21 @@ impl DatasetMemWalExt for Dataset { // Load index configs for each maintained index let mut index_configs = Vec::new(); for index_name in maintained_indexes { - let index_meta = self.load_index_by_name(index_name).await?.ok_or_else(|| { - Error::invalid_input(format!( - "Index '{}' from maintained_indexes not found on dataset", - index_name - )) - })?; + // A maintained index can split into multiple physical segments + // (e.g. `optimize_indices(append)` deltas), which the singular + // `load_index_by_name` rejects. Every segment carries the same + // type and params, so take the first match. + let index_meta = self + .load_indices_by_name(index_name) + .await? + .into_iter() + .next() + .ok_or_else(|| { + Error::invalid_input(format!( + "Index '{}' from maintained_indexes not found on dataset", + index_name + )) + })?; // Detect index type and create appropriate config let type_url = index_meta diff --git a/rust/lance/src/dataset/mem_wal/index/btree.rs b/rust/lance/src/dataset/mem_wal/index/btree.rs index a54112ec4e7..a65104d434d 100644 --- a/rust/lance/src/dataset/mem_wal/index/btree.rs +++ b/rust/lance/src/dataset/mem_wal/index/btree.rs @@ -289,70 +289,6 @@ impl BTreeMemIndex { Ok(batches) } - /// Export the index data as sorted RecordBatches with reversed row positions. - /// - /// This is used when flushing MemTable to disk with batches in reverse order. - /// Since the flushed data will have rows in reverse order, we need to map - /// the row positions accordingly: - /// `reversed_position = total_rows - original_position - 1` - /// - /// # Arguments - /// * `batch_size` - Maximum number of entries per batch - /// * `total_rows` - Total number of rows in the MemTable (needed for position reversal) - pub fn to_training_batches_reversed( - &self, - batch_size: usize, - total_rows: usize, - ) -> Result> { - use arrow_schema::{DataType, Field, Schema}; - use lance_core::ROW_ID; - use lance_index::scalar::registry::VALUE_COLUMN_NAME; - use std::sync::Arc; - - if self.lookup.is_empty() { - return Ok(vec![]); - } - - // Get the data type from the first key - let first_entry = self.lookup.front().unwrap(); - let data_type = first_entry.key().value.0.data_type(); - - // Create schema for training data - let schema = Arc::new(Schema::new(vec![ - Field::new(VALUE_COLUMN_NAME, data_type, true), - Field::new(ROW_ID, DataType::UInt64, false), - ])); - - let total_rows_u64 = total_rows as u64; - let mut batches = Vec::new(); - let mut values: Vec = Vec::with_capacity(batch_size); - let mut row_ids: Vec = Vec::with_capacity(batch_size); - - for entry in self.lookup.iter() { - let key = entry.key(); - values.push(key.value.0.clone()); - // Reverse the row position: new_pos = total_rows - old_pos - 1 - let reversed_position = total_rows_u64 - key.row_position - 1; - row_ids.push(reversed_position); - - if values.len() >= batch_size { - // Build and emit a batch - let batch = self.build_training_batch(&schema, &values, &row_ids)?; - batches.push(batch); - values.clear(); - row_ids.clear(); - } - } - - // Emit any remaining data - if !values.is_empty() { - let batch = self.build_training_batch(&schema, &values, &row_ids)?; - batches.push(batch); - } - - Ok(batches) - } - /// Build a single training batch from values and row IDs. fn build_training_batch( &self, @@ -506,63 +442,6 @@ mod tests { assert_eq!(row_ids.value(5), 5); // id=12 -> row 5 } - #[test] - fn test_btree_index_to_training_batches_reversed() { - use lance_core::ROW_ID; - use lance_index::scalar::registry::VALUE_COLUMN_NAME; - - let schema = create_test_schema(); - let index = BTreeMemIndex::new(0, "id".to_string()); - - let batch1 = create_test_batch(&schema, 0); // ids: 0, 1, 2 - let batch2 = create_test_batch(&schema, 10); // ids: 10, 11, 12 - - index.insert(&batch1, 0).unwrap(); // row positions 0, 1, 2 - index.insert(&batch2, 3).unwrap(); // row positions 3, 4, 5 - - // Export as training batches with reversed positions - // total_rows = 6, so reversed positions are: - // original 0 -> 6-0-1 = 5 - // original 1 -> 6-1-1 = 4 - // original 2 -> 6-2-1 = 3 - // original 3 -> 6-3-1 = 2 - // original 4 -> 6-4-1 = 1 - // original 5 -> 6-5-1 = 0 - let batches = index.to_training_batches_reversed(100, 6).unwrap(); - assert_eq!(batches.len(), 1); - - let batch = &batches[0]; - assert_eq!(batch.num_rows(), 6); - - // Check values are still in sorted order (0, 1, 2, 10, 11, 12) - let values = batch - .column_by_name(VALUE_COLUMN_NAME) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(values.value(0), 0); - assert_eq!(values.value(1), 1); - assert_eq!(values.value(2), 2); - assert_eq!(values.value(3), 10); - assert_eq!(values.value(4), 11); - assert_eq!(values.value(5), 12); - - // Check row IDs are reversed - let row_ids = batch - .column_by_name(ROW_ID) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(row_ids.value(0), 5); // id=0 was at row 0 -> reversed to 5 - assert_eq!(row_ids.value(1), 4); // id=1 was at row 1 -> reversed to 4 - assert_eq!(row_ids.value(2), 3); // id=2 was at row 2 -> reversed to 3 - assert_eq!(row_ids.value(3), 2); // id=10 was at row 3 -> reversed to 2 - assert_eq!(row_ids.value(4), 1); // id=11 was at row 4 -> reversed to 1 - assert_eq!(row_ids.value(5), 0); // id=12 was at row 5 -> reversed to 0 - } - #[test] fn test_btree_index_snapshot() { let schema = create_test_schema(); diff --git a/rust/lance/src/dataset/mem_wal/index/fts.rs b/rust/lance/src/dataset/mem_wal/index/fts.rs index 535b9c53fbd..46340a25b4c 100644 --- a/rust/lance/src/dataset/mem_wal/index/fts.rs +++ b/rust/lance/src/dataset/mem_wal/index/fts.rs @@ -9,7 +9,7 @@ //! - **One writer** (`insert` / `insert_with_batch_position`) at a time per //! index. Callers are responsible for that invariant; this is consistent //! with `IndexStore`'s usage from `ShardWriter`. -//! - **Many readers** (`search*`, `expand_fuzzy`, `to_index_builder_reversed`) +//! - **Many readers** (`search*`, `expand_fuzzy`, `to_index_builder`) //! in parallel with the writer. Reads are lock-free aside from a brief //! tokenizer-pool checkout. //! - **Per-batch monotonic visibility**: a reader either sees every row of @@ -41,7 +41,7 @@ //! # On-disk format //! //! At flush time we hand off to `lance_index::scalar::inverted::builder::InnerBuilder` -//! via `to_index_builder_reversed`, which merges every partition and the tail +//! via `to_index_builder`, which merges every partition and the tail //! into one builder. The on-disk format is unchanged from Lance's existing //! inverted index. @@ -1396,11 +1396,9 @@ impl FtsMemIndex { /// Export the in-memory FTS index to an `InnerBuilder` ready to be /// written to disk. /// - /// `total_rows` is the total number of rows in the MemTable being - /// flushed; row positions are reversed (`reversed = total_rows - pos - - /// 1`) to match the LSM-friendly newest-first flush order used by the - /// rest of MemWAL. - pub fn to_index_builder_reversed( + /// Doc row positions are kept in insert order to match the forward-written + /// flush data file 1:1. `total_rows` is used only to validate positions. + pub fn to_index_builder( &self, partition_id: u64, total_rows: usize, @@ -1434,9 +1432,9 @@ impl FtsMemIndex { )); } - // Step 2: assign doc_ids in ascending reversed-position order, so the - // on-disk layout matches MemWAL's newest-first flush convention. - let mut entries: Vec<(u64, u64, u32)> = Vec::with_capacity(all_docs.len()); + // Step 2: assign doc_ids in ascending insert-position order, so the + // stored row positions line up 1:1 with the forward-written data file. + let mut entries: Vec<(u64, u32)> = Vec::with_capacity(all_docs.len()); for (original, num_tokens) in &all_docs { if *original >= total_rows_u64 { return Err(Error::io(format!( @@ -1444,13 +1442,13 @@ impl FtsMemIndex { original, total_rows ))); } - entries.push((total_rows_u64 - original - 1, *original, *num_tokens)); + entries.push((*original, *num_tokens)); } - entries.sort_by_key(|(rev, _, _)| *rev); + entries.sort_by_key(|(original, _)| *original); let mut docs = DocSet::default(); let mut original_to_doc_id: HashMap = HashMap::with_capacity(entries.len()); - for (rev, original, num_tokens) in &entries { - let doc_id = docs.append(*rev, *num_tokens); + for (original, num_tokens) in &entries { + let doc_id = docs.append(*original, *num_tokens); original_to_doc_id.insert(*original, doc_id); } @@ -3712,14 +3710,14 @@ mod tests { } #[test] - fn test_to_index_builder_reversed_smoke() { + fn test_to_index_builder_smoke() { // Ensure flush works on a minimal input. let schema = create_test_schema(); let index = FtsMemIndex::new(1, "description".to_string()); let batch = create_test_batch(&schema); index.insert(&batch, 0).unwrap(); - let builder = index.to_index_builder_reversed(42, 3).unwrap(); + let builder = index.to_index_builder(42, 3).unwrap(); // The builder can be consumed by callers; we just check it built. assert!(builder.id() > 0 || builder.id() == 42); } @@ -3879,7 +3877,7 @@ mod tests { assert_eq!(st.partitions.len(), 1); assert_eq!(st.tail.visible_count(), 1); - let builder = index.to_index_builder_reversed(7, 300).unwrap(); + let builder = index.to_index_builder(7, 300).unwrap(); assert_eq!(builder.id(), 7); // A non-empty builder proves the partition and the tail both reached // the flush; end-to-end flush correctness is covered by the MemTable diff --git a/rust/lance/src/dataset/mem_wal/memtable/flush.rs b/rust/lance/src/dataset/mem_wal/memtable/flush.rs index fe9e643e719..f8e9e7db1d6 100644 --- a/rust/lance/src/dataset/mem_wal/memtable/flush.rs +++ b/rust/lance/src/dataset/mem_wal/memtable/flush.rs @@ -3,10 +3,13 @@ //! MemTable flush to persistent storage. +use std::collections::HashMap; use std::sync::Arc; +use arrow_array::RecordBatch; use bytes::Bytes; use lance_core::cache::LanceCache; +use lance_core::utils::deletion::DeletionVector; use lance_core::{Error, Result}; use lance_index::IndexType; use lance_index::mem_wal::{FlushedGeneration, ShardManifest}; @@ -14,9 +17,11 @@ use lance_index::scalar::{IndexStore, ScalarIndexParams}; use lance_io::object_store::ObjectStore; use lance_table::format::IndexMetadata; use lance_table::io::commit::write_manifest_file_to_path; +use lance_table::io::deletion::write_deletion_file; use log::info; use object_store::ObjectStoreExt; use object_store::path::Path; +use roaring::RoaringBitmap; use tracing::instrument; use uuid::Uuid; @@ -24,6 +29,7 @@ use super::super::index::MemIndexConfig; use super::super::memtable::MemTable; use crate::Dataset; use crate::dataset::mem_wal::manifest::ShardManifestStore; +use crate::dataset::mem_wal::scanner::exec::{compute_pk_hash, validate_pk_types}; use crate::dataset::mem_wal::util::{flushed_memtable_path, generate_random_hash}; #[derive(Debug, Clone)] @@ -33,6 +39,29 @@ pub struct FlushResult { pub covered_wal_entry_position: u64, } +/// Build the within-generation deletion vector for forward-written flush data. +/// +/// `batches` are in on-disk (insert) order, so the newest version of each +/// primary key is at the largest offset: the last occurrence of a PK hash is +/// kept and every earlier occurrence is marked deleted. Keys are hashed +/// (collisions accepted, consistent with the read path). +fn compute_dedup_deletions(batches: &[RecordBatch], pk_indices: &[usize]) -> RoaringBitmap { + let mut deleted = RoaringBitmap::new(); + let mut latest: HashMap = HashMap::new(); + let mut offset: u32 = 0; + for batch in batches { + for row in 0..batch.num_rows() { + let pk_hash = compute_pk_hash(batch, pk_indices, row); + if let Some(previous) = latest.insert(pk_hash, offset) { + // An earlier (older) occurrence of this PK is now superseded. + deleted.insert(previous); + } + offset += 1; + } + } + deleted +} + pub struct MemTableFlusher { object_store: Arc, base_path: Path, @@ -60,7 +89,6 @@ impl MemTableFlusher { /// Construct a full URI for a path within the base dataset. fn path_to_uri(&self, path: &Path) -> String { - // Remove base_path prefix from path to get relative path let path_str = path.as_ref(); let base_str = self.base_path.as_ref(); @@ -70,7 +98,6 @@ impl MemTableFlusher { path_str }; - // Combine base_uri with relative path let base = self.base_uri.trim_end_matches('/'); if relative.is_empty() { base.to_string() @@ -119,7 +146,15 @@ impl MemTableFlusher { memtable.batch_count() ); - let rows_flushed = self.write_data_file(&gen_path, memtable).await?; + let (rows_flushed, deleted) = self.write_data_file(&gen_path, memtable).await?; + + // Persist the within-generation deletion vector so the flushed + // generation exposes newest-per-PK on every read path. + if !deleted.is_empty() { + let uri = self.path_to_uri(&gen_path); + let dataset = Dataset::open(&uri).await?; + self.finalize_generation(&dataset, &deleted, None).await?; + } let bloom_path = gen_path.clone().join("bloom_filter.bin"); self.write_bloom_filter(&bloom_path, memtable.bloom_filter()) @@ -149,27 +184,61 @@ impl MemTableFlusher { }) } - /// Write data file with batches in reverse order (newest first). + /// Write the data file in insert (forward) order. /// - /// Returns the total number of rows written, which is needed for - /// reversing row positions in indexes. + /// Returns the total number of rows written and the within-generation + /// deletion vector marking every older duplicate of each primary key (see + /// [`compute_dedup_deletions`]). Forward order keeps the data file, the + /// incrementally-built indexes, and the deletion-vector offsets in one + /// position space (newest = largest offset) with no remap. #[instrument(name = "mt_write_data_file", level = "debug", skip_all, fields(path = %path))] - async fn write_data_file(&self, path: &Path, memtable: &MemTable) -> Result { + async fn write_data_file( + &self, + path: &Path, + memtable: &MemTable, + ) -> Result<(usize, RoaringBitmap)> { use arrow_array::RecordBatchIterator; use crate::dataset::WriteParams; if memtable.row_count() == 0 { - return Ok(0); + return Ok((0, RoaringBitmap::new())); } - // Scan batches in reverse order (newest first) so that the flushed - // data is ordered from newest to oldest. This enables more efficient - // K-way merge during LSM scan. - let (batches, total_rows) = memtable.scan_batches_reversed().await?; + let batches = memtable.scan_batches().await?; if batches.is_empty() { - return Ok(0); + return Ok((0, RoaringBitmap::new())); } + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + // Build the deletion vector before `batches` is moved into the writer. + let pk_columns: Vec = memtable + .lance_schema() + .unenforced_primary_key() + .iter() + .map(|f| f.name.clone()) + .collect(); + let deleted = if pk_columns.is_empty() { + RoaringBitmap::new() + } else { + let schema = batches[0].schema(); + // Match the read-path contract (create_dedup_plan): unsupported PK + // types must error here rather than hit compute_pk_hash's + // debug-format fallback, which can collapse distinct keys. + validate_pk_types(schema.as_ref(), &pk_columns)?; + let pk_indices = pk_columns + .iter() + .map(|c| { + schema.index_of(c).map_err(|_| { + Error::invalid_input(format!( + "Primary key column '{}' not found in flush schema", + c + )) + }) + }) + .collect::>>()?; + compute_dedup_deletions(&batches, &pk_indices) + }; let uri = self.path_to_uri(path); let reader = @@ -182,7 +251,60 @@ impl MemTableFlusher { }; Dataset::write(reader, &uri, Some(write_params)).await?; - Ok(total_rows) + Ok((total_rows, deleted)) + } + + /// Persist the within-generation deletion vector (and any indexes) onto the + /// just-written generation by rewriting its manifest in place. + /// + /// The generation dataset is brand-new and not yet published in the shard + /// manifest, so overwriting its v1 manifest is safe. A no-op when there is + /// neither a deletion vector nor an index to record. + async fn finalize_generation( + &self, + dataset: &Dataset, + deleted: &RoaringBitmap, + indexes: Option>, + ) -> Result<()> { + let indexes = indexes.filter(|i| !i.is_empty()); + if deleted.is_empty() && indexes.is_none() { + return Ok(()); + } + + let mut manifest = dataset.manifest().clone(); + let manifest_path = dataset.manifest_location().path.clone(); + + if !deleted.is_empty() { + let dv = DeletionVector::from(deleted.clone()); + let deletion_file = write_deletion_file( + &dataset.base, + 0, // 1 fragment per flushed generation + dataset.version().version, + &dv, + dataset.object_store.as_ref(), + ) + .await?; + let fragments = Arc::make_mut(&mut manifest.fragments); + if let Some(fragment) = fragments.first_mut() { + fragment.deletion_file = deletion_file; + } + } + + // Clear stale section offsets from the v1 manifest since the rewritten + // file has a different layout (added index/deletion metadata). + manifest.index_section = None; + manifest.transaction_section = None; + manifest.transaction_file = None; + write_manifest_file_to_path( + &self.object_store, + &mut manifest, + indexes, + &manifest_path, + None, + ) + .await + .map_err(|e| Error::io(format!("Failed to write generation manifest: {}", e)))?; + Ok(()) } async fn write_bloom_filter( @@ -237,7 +359,7 @@ impl MemTableFlusher { memtable.batch_count() ); - let total_rows = self.write_data_file(&gen_path, memtable).await?; + let (total_rows, deleted) = self.write_data_file(&gen_path, memtable).await?; // Open the dataset once for all index building. Dataset::write already // created a v1 manifest with the fragment data. @@ -249,7 +371,7 @@ impl MemTableFlusher { let mut all_indexes: Vec = Vec::new(); let btree_indexes = self - .create_indexes(&mut dataset, index_configs, memtable.indexes(), total_rows) + .create_indexes(&mut dataset, index_configs, memtable.indexes()) .await?; if !btree_indexes.is_empty() { info!( @@ -266,7 +388,7 @@ impl MemTableFlusher { && let Some(mem_index) = registry.get_hnsw(&hnsw_config.name) { let mut index_meta = self - .create_hnsw_index(&gen_path, hnsw_config, mem_index, total_rows) + .create_hnsw_index(&gen_path, hnsw_config, mem_index) .await?; let schema = dataset.schema(); @@ -305,26 +427,11 @@ impl MemTableFlusher { all_indexes.extend(fts_indexes); } - // Write a single manifest that includes both fragments and all indexes, - // overwriting the data-only v1 manifest created by Dataset::write. - if !all_indexes.is_empty() { - let mut manifest = dataset.manifest().clone(); - let manifest_path = dataset.manifest_location().path.clone(); - // Clear stale section offsets from the original v1 manifest since - // the new file has a different layout with the added index section. - manifest.index_section = None; - manifest.transaction_section = None; - manifest.transaction_file = None; - write_manifest_file_to_path( - &self.object_store, - &mut manifest, - Some(all_indexes), - &manifest_path, - None, - ) - .await - .map_err(|e| Error::io(format!("Failed to write manifest with indexes: {}", e)))?; - } + // Write a single manifest that records the fragments, the + // within-generation deletion vector, and all indexes, overwriting the + // data-only v1 manifest created by Dataset::write. + self.finalize_generation(&dataset, &deleted, Some(all_indexes)) + .await?; let bloom_path = gen_path.clone().join("bloom_filter.bin"); self.write_bloom_filter(&bloom_path, memtable.bloom_filter()) @@ -363,7 +470,6 @@ impl MemTableFlusher { dataset: &mut Dataset, index_configs: &[MemIndexConfig], mem_indexes: Option<&super::super::index::IndexStore>, - total_rows: usize, ) -> Result> { use arrow_array::RecordBatchIterator; @@ -397,10 +503,9 @@ impl MemTableFlusher { if let Some(registry) = mem_indexes && let Some(btree_index) = registry.get_btree(&btree_cfg.name) { - // Use reversed training batches since the flushed data is in reverse order. - // Row positions need to be mapped: reversed_pos = total_rows - original_pos - 1 - let training_batches = - btree_index.to_training_batches_reversed(8192, total_rows)?; + // Forward-written data: index row positions line up 1:1 with + // the data file, no remap needed. + let training_batches = btree_index.to_training_batches(8192)?; if !training_batches.is_empty() { let schema = training_batches[0].schema(); let reader = @@ -461,8 +566,7 @@ impl MemTableFlusher { let partition_id = uuid::Uuid::new_v4().as_u64_pair().0; - let mut inner_builder = - fts_index.to_index_builder_reversed(partition_id, total_rows)?; + let mut inner_builder = fts_index.to_index_builder(partition_id, total_rows)?; let index_uuid = uuid::Uuid::new_v4(); let index_dir = gen_path @@ -571,13 +675,11 @@ impl MemTableFlusher { /// * `gen_path` - Path to the flushed generation folder /// * `config` - HNSW index configuration /// * `mem_index` - In-memory HNSW index (snapshotted, not consumed) - /// * `total_rows` - Total number of rows in the flushed data (for row position reversal) async fn create_hnsw_index( &self, gen_path: &Path, config: &super::super::index::HnswIndexConfig, mem_index: &super::super::index::HnswMemIndex, - total_rows: usize, ) -> Result { use arrow_array::cast::AsArray; use arrow_array::types::Float32Type; @@ -615,8 +717,9 @@ impl MemTableFlusher { "HnswMemIndex has no inserted vectors; nothing to flush", )); } - let Some((hnsw, flat_storage_batch)) = mem_index.to_lance_hnsw(Some(total_rows as u64))? - else { + // Forward-written data: HNSW row ids line up 1:1 with the data file, so + // no position reversal (pass `None`). + let Some((hnsw, flat_storage_batch)) = mem_index.to_lance_hnsw(None)? else { return Err(Error::invalid_input( "HnswMemIndex is empty; nothing to flush", )); @@ -845,6 +948,21 @@ mod tests { ])) } + /// Schema with `id` marked as the unenforced primary key, so the flush + /// computes a within-generation deletion vector. + fn create_pk_schema() -> Arc { + let mut id_metadata = std::collections::HashMap::new(); + id_metadata.insert( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + ); + let id_field = Field::new("id", DataType::Int32, false).with_metadata(id_metadata); + Arc::new(ArrowSchema::new(vec![ + id_field, + Field::new("name", DataType::Utf8, true), + ])) + } + fn create_test_batch(schema: &ArrowSchema, num_rows: usize) -> RecordBatch { RecordBatch::try_new( Arc::new(schema.clone()), @@ -964,6 +1082,230 @@ mod tests { assert_eq!(updated_manifest.flushed_generations.len(), 1); } + /// Flushing a generation with within-generation duplicate PKs writes a + /// deletion vector so the flushed dataset exposes newest-per-PK on scan. + #[tokio::test] + async fn test_flush_writes_dedup_deletion_vector() { + use futures::TryStreamExt; + + let (store, base_path, base_uri, _temp_dir) = create_local_store().await; + let shard_id = Uuid::new_v4(); + let manifest_store = Arc::new(ShardManifestStore::new( + store.clone(), + &base_path, + shard_id, + 2, + )); + let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap(); + + let schema = create_pk_schema(); + let mut memtable = MemTable::new(schema.clone(), 1, vec![0]).unwrap(); + // Append order (newest last): id=1 a->a2, id=2 b, id=3 c->c2. + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 1, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c", "a2", "c2"])), + ], + ) + .unwrap(); + let frag_id = memtable.insert(batch).await.unwrap(); + memtable.mark_wal_flushed(&[frag_id], 1, &[0]); + + let flusher = MemTableFlusher::new( + store.clone(), + base_path, + base_uri.clone(), + shard_id, + manifest_store, + ); + let result = flusher.flush(&memtable, epoch, 1).await.unwrap(); + assert_eq!(result.rows_flushed, 5, "all physical rows are written"); + + // Scanning the flushed generation must honor the deletion vector and + // return only the newest version of each PK. + let gen_uri = format!( + "{}/_mem_wal/{}/{}", + base_uri.trim_end_matches('/'), + shard_id, + result.generation.path + ); + let dataset = Dataset::open(&gen_uri).await.unwrap(); + let batches: Vec = dataset + .scan() + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let mut rows = std::collections::HashMap::new(); + for b in &batches { + let ids = b + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = b + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + rows.insert(ids.value(i), names.value(i).to_string()); + } + } + + assert_eq!( + rows.len(), + 3, + "deletion vector should leave newest-per-PK, got {:?}", + rows + ); + assert_eq!(rows.get(&1), Some(&"a2".to_string())); + assert_eq!(rows.get(&2), Some(&"b".to_string())); + assert_eq!(rows.get(&3), Some(&"c2".to_string())); + } + + /// Covers `finalize_generation` writing both a deletion vector *and* + /// indexes into the same manifest — the deletion-only and index-only + /// paths are exercised by sibling tests. + #[tokio::test] + async fn test_flush_with_indexes_and_dedup_deletion_vector() { + use super::super::super::index::{BTreeIndexConfig, IndexStore}; + use crate::index::DatasetIndexExt; + use futures::TryStreamExt; + + let (store, base_path, base_uri, _temp_dir) = create_local_store().await; + let shard_id = Uuid::new_v4(); + let manifest_store = Arc::new(ShardManifestStore::new( + store.clone(), + &base_path, + shard_id, + 2, + )); + let (epoch, _manifest) = manifest_store.claim_epoch(0).await.unwrap(); + + // BTree on the non-PK `name` column so the index sees the dedup set. + let index_configs = vec![MemIndexConfig::BTree(BTreeIndexConfig { + name: "name_btree".to_string(), + field_id: 1, + column: "name".to_string(), + })]; + + let schema = create_pk_schema(); + let mut memtable = MemTable::new(schema.clone(), 1, vec![0]).unwrap(); + let registry = IndexStore::from_configs(&index_configs, 100_000, 1_000).unwrap(); + memtable.set_indexes(registry); + + // Duplicate PKs in append order: id=1 a->a2, id=2 b, id=3 c->c2. + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 1, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c", "a2", "c2"])), + ], + ) + .unwrap(); + let frag_id = memtable.insert(batch).await.unwrap(); + memtable.mark_wal_flushed(&[frag_id], 1, &[0]); + + let flusher = MemTableFlusher::new( + store.clone(), + base_path.clone(), + base_uri.clone(), + shard_id, + manifest_store.clone(), + ); + let result = flusher + .flush_with_indexes(&memtable, epoch, &index_configs, 1) + .await + .unwrap(); + assert_eq!(result.rows_flushed, 5, "all physical rows are written"); + + let gen_uri = format!( + "{}/_mem_wal/{}/{}", + base_uri.trim_end_matches('/'), + shard_id, + result.generation.path + ); + let dataset = Dataset::open(&gen_uri).await.unwrap(); + assert_eq!( + dataset.version().version, + 1, + "flushed dataset must be a single-version dataset" + ); + + // Index half of the combined manifest. + let indices = dataset.load_indices().await.unwrap(); + assert_eq!(indices.len(), 1); + assert_eq!(indices[0].name, "name_btree"); + + // Deletion-vector half: scan returns newest-per-PK. + let batches: Vec = dataset + .scan() + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + let mut rows = std::collections::HashMap::new(); + for b in &batches { + let ids = b + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = b + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + rows.insert(ids.value(i), names.value(i).to_string()); + } + } + assert_eq!( + rows.len(), + 3, + "deletion vector should leave newest-per-PK, got {:?}", + rows + ); + assert_eq!(rows.get(&1), Some(&"a2".to_string())); + assert_eq!(rows.get(&2), Some(&"b".to_string())); + assert_eq!(rows.get(&3), Some(&"c2".to_string())); + + // The BTree on `name` must not surface a stale value: a hit for the + // pre-update "a" would mean the indexed path ignored the deletion + // vector. + let stale_hits = dataset + .scan() + .filter("name = 'a'") + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!( + stale_hits.num_rows(), + 0, + "older name 'a' for id=1 must be filtered out by the deletion vector" + ); + let fresh_hits = dataset + .scan() + .filter("name = 'a2'") + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!(fresh_hits.num_rows(), 1); + } + #[tokio::test] async fn test_flusher_with_btree_index() { use super::super::super::index::{BTreeIndexConfig, IndexStore}; diff --git a/rust/lance/src/dataset/mem_wal/memtable/scanner/builder.rs b/rust/lance/src/dataset/mem_wal/memtable/scanner/builder.rs index f2a0b315536..1047c03a9c6 100644 --- a/rust/lance/src/dataset/mem_wal/memtable/scanner/builder.rs +++ b/rust/lance/src/dataset/mem_wal/memtable/scanner/builder.rs @@ -17,7 +17,11 @@ use lance_datafusion::expr::safe_coerce_scalar; use lance_datafusion::planner::Planner; use lance_linalg::distance::DistanceType; -use super::exec::{BTreeIndexExec, FtsIndexExec, MemTableScanExec, VectorIndexExec}; +use super::exec::{ + BTreeIndexExec, FtsIndexExec, MemTableBruteForceVectorExec, MemTableDedupScanExec, + MemTableScanExec, VectorIndexExec, +}; +use crate::dataset::mem_wal::scanner::exec::validate_pk_types; use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; /// Vector search query parameters. @@ -782,6 +786,55 @@ impl MemTableScanner { Ok(plan) } + /// Plan a newest-per-PK active-arm scan via `MemTableDedupScanExec` — + /// dedup runs before the predicate so a PK whose newest version fails the + /// filter cannot leak an older version that passes. Unlike + /// `plan_full_scan`, this never takes the BTree skip (dedup needs + /// every version) and never pushes a limit (the LSM caps results above + /// the cross-source merge). + pub async fn create_dedup_plan(&self, pk_columns: &[String]) -> Result> { + validate_pk_types(&self.schema, pk_columns)?; + + let pk_indices = pk_columns + .iter() + .map(|name| { + self.schema + .column_with_name(name) + .map(|(idx, _)| idx) + .ok_or_else(|| { + Error::invalid_input(format!( + "Primary key column '{}' not found in schema", + name + )) + }) + }) + .collect::>>()?; + + let projection_indices = self.compute_projection_indices()?; + + // optimize_expr() must run before create_physical_expr() for type coercion. + let (filter_predicate, filter_expr) = if let Some(ref filter) = self.filter { + let planner = Planner::new(self.schema.clone()); + let optimized = planner.optimize_expr(filter.clone())?; + let predicate = planner.create_physical_expr(&optimized)?; + (Some(predicate), Some(optimized)) + } else { + (None, None) + }; + + Ok(Arc::new(MemTableDedupScanExec::new( + self.batch_store.clone(), + self.max_visible_batch_position, + projection_indices, + self.output_schema(), + pk_indices, + self.with_row_id, + self.with_row_address, + filter_predicate, + filter_expr, + ))) + } + /// Plan a BTree index query. /// /// Uses the effective visibility (min of max_visible and max_indexed) to ensure @@ -812,26 +865,39 @@ impl MemTableScanner { /// Plan a vector similarity search. /// - /// Uses the effective visibility (min of max_visible and max_indexed) to ensure - /// queries only see indexed data. Falls back to full scan if no index exists. + /// Always emits a plan whose output schema includes `_distance`: dispatches + /// to [`VectorIndexExec`] when an HNSW exists for the column, otherwise to + /// [`MemTableBruteForceVectorExec`]. The brute-force arm exists because the + /// active memtable is the LSM's unindexed-rows path — when the HNSW config + /// hasn't reached this writer yet (cold-start, or rows written between an + /// index commit and the next memtable rotation), KNN must still produce + /// correct, distance-bearing results so the LSM-level merge stays sound. async fn plan_vector_search(&self, query: &VectorQuery) -> Result> { - if !self.has_vector_index(&query.column) { - return self.plan_full_scan().await; - } - let max_visible = self.max_visible_batch_position; let projection_indices = self.compute_projection_indices()?; - - let index_exec = VectorIndexExec::new( - self.batch_store.clone(), - self.indexes.clone(), - query.clone(), - max_visible, - projection_indices, - self.base_output_schema(), - self.with_row_id, - )?; - self.apply_post_index_ops(Arc::new(index_exec)).await + let base_schema = self.base_output_schema(); + + let exec: Arc = if self.has_vector_index(&query.column) { + Arc::new(VectorIndexExec::new( + self.batch_store.clone(), + self.indexes.clone(), + query.clone(), + max_visible, + projection_indices, + base_schema, + self.with_row_id, + )?) + } else { + Arc::new(MemTableBruteForceVectorExec::new( + self.batch_store.clone(), + query.clone(), + max_visible, + projection_indices, + base_schema, + self.with_row_id, + )?) + }; + self.apply_post_index_ops(exec).await } /// Plan a full-text search. @@ -1452,4 +1518,45 @@ mod tests { assert_eq!(row_addrs.value(i), i as u64); } } + + /// Regression: vector search against a column with no HNSW must still + /// emit a plan whose output schema contains `_distance`. The earlier + /// behaviour fell back to `plan_full_scan` (no `_distance`), which broke + /// the LSM caller's `sort_by_distance` chain. Now the planner dispatches + /// to `MemTableBruteForceVectorExec` instead — see + /// [`super::super::exec::MemTableBruteForceVectorExec`]. + #[tokio::test] + async fn test_plan_vector_search_without_hnsw_produces_distance_schema() { + use std::sync::Arc; + + const DISTANCE_COLUMN: &str = "_distance"; + + let schema: SchemaRef = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2), + true, + ), + ])); + + let batch_store = Arc::new(BatchStore::with_capacity(4)); + let indexes = Arc::new(IndexStore::new()); // intentionally no HNSW + + let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone()); + let query: Arc = + Arc::new(arrow_array::Float32Array::from(vec![0.0_f32, 0.0_f32])); + scanner.nearest("vector", query, 5); + + let plan = scanner + .create_plan() + .await + .expect("planner must produce a plan when no HNSW exists"); + let out_schema = plan.schema(); + assert!( + out_schema.field_with_name(DISTANCE_COLUMN).is_ok(), + "plan output schema missing `{DISTANCE_COLUMN}` — got {:?}", + out_schema + ); + } } diff --git a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec.rs b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec.rs index 53499f4bd2e..e6d79ba6b48 100644 --- a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec.rs +++ b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec.rs @@ -7,14 +7,19 @@ //! - `MemTableScanExec` - Full table scan with MVCC visibility //! - `BTreeIndexExec` - BTree index queries //! - `VectorIndexExec` - HNSW vector search +//! - `MemTableBruteForceVectorExec` - KNN over the active memtable without an HNSW //! - `FtsIndexExec` - Full-text search +mod brute_force_vector; mod btree; +mod dedup_scan; mod fts; mod scan; mod vector; +pub use brute_force_vector::MemTableBruteForceVectorExec; pub use btree::BTreeIndexExec; +pub use dedup_scan::MemTableDedupScanExec; pub use fts::FtsIndexExec; pub use scan::{MemTableScanExec, ROW_ADDRESS_COLUMN}; pub use vector::VectorIndexExec; diff --git a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/brute_force_vector.rs b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/brute_force_vector.rs new file mode 100644 index 00000000000..43169f92d2d --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/brute_force_vector.rs @@ -0,0 +1,640 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! MemTableBruteForceVectorExec — KNN over the active memtable without an HNSW. +//! +//! Mirrors [`super::VectorIndexExec`]'s output contract (same schema, same row +//! shape, same `_distance` / `_rowid` semantics) so the LSM caller can swap one +//! for the other based on whether the memtable's `IndexStore` has an HNSW for +//! the queried column. The active memtable is the LSM's unindexed-rows path: +//! whenever the HNSW config is absent (cold-start before the Indexer commits, +//! or new rows in the window between commit and next memtable rotation), this +//! exec keeps KNN correct by computing exact distances row-by-row. + +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use arrow_array::{Array, Float32Array, RecordBatch, UInt64Array, cast::AsArray}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::stats::Precision; +use datafusion::error::Result as DataFusionResult; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, +}; +use datafusion_physical_expr::EquivalenceProperties; +use futures::stream::{self, StreamExt}; +use lance_core::{Error, Result}; +use lance_linalg::distance::DistanceType; + +use super::super::builder::VectorQuery; +use super::vector::DISTANCE_COLUMN; +use crate::dataset::mem_wal::write::BatchStore; + +/// Distance metric used when [`VectorQuery::distance_type`] is `None`. The +/// indexed path defers to the index's own metric, but with no index there is +/// no inherent default — L2 matches what most callers configure and what the +/// flushed/base arms use when re-ranking unindexed candidates. +const DEFAULT_DISTANCE_TYPE: DistanceType = DistanceType::L2; + +/// Brute-force KNN over an active memtable without an HNSW. Produces the same +/// output schema as [`super::VectorIndexExec`]. +pub struct MemTableBruteForceVectorExec { + batch_store: Arc, + query: VectorQuery, + max_visible_batch_position: usize, + projection: Option>, + output_schema: SchemaRef, + properties: Arc, + metrics: ExecutionPlanMetricsSet, + with_row_id: bool, +} + +impl Debug for MemTableBruteForceVectorExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemTableBruteForceVectorExec") + .field("column", &self.query.column) + .field("k", &self.query.k) + .field( + "max_visible_batch_position", + &self.max_visible_batch_position, + ) + .field("with_row_id", &self.with_row_id) + .finish() + } +} + +impl MemTableBruteForceVectorExec { + /// Build the exec. `base_schema` is the post-projection row schema (no + /// `_distance`, no `_rowid`); `_distance` is appended unconditionally and + /// `_rowid` only when `with_row_id` is set, matching [`VectorIndexExec`]. + pub fn new( + batch_store: Arc, + query: VectorQuery, + max_visible_batch_position: usize, + projection: Option>, + base_schema: SchemaRef, + with_row_id: bool, + ) -> Result { + let mut fields: Vec = base_schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect(); + fields.push(Field::new(DISTANCE_COLUMN, DataType::Float32, true)); + if with_row_id { + fields.push(Field::new(lance_core::ROW_ID, DataType::UInt64, true)); + } + let output_schema = Arc::new(Schema::new(fields)); + + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(output_schema.clone()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + + Ok(Self { + batch_store, + query, + max_visible_batch_position, + projection, + output_schema, + properties, + metrics: ExecutionPlanMetricsSet::new(), + with_row_id, + }) + } + + /// Last row position visible under `max_visible_batch_position`, or `None` + /// if no batches are visible. Identical to `VectorIndexExec`'s helper so + /// both arms cut at the same MVCC boundary. + fn compute_max_visible_row(&self) -> Option { + let mut max_visible_row_exclusive: u64 = 0; + let mut current_row: u64 = 0; + + for (batch_position, stored_batch) in self.batch_store.iter().enumerate() { + let batch_end = current_row + stored_batch.num_rows as u64; + if batch_position <= self.max_visible_batch_position { + max_visible_row_exclusive = batch_end; + } + current_row = batch_end; + } + + if max_visible_row_exclusive > 0 { + Some(max_visible_row_exclusive - 1) + } else { + None + } + } + + /// Extract the flat per-element query vector. `arrow_batch_func` wants + /// `(from: &dyn Array, to: &FixedSizeListArray)` where `from` is the raw + /// primitive array of one vector (NOT an FSL), so unwrap an FSL with one + /// row if that's how the caller built it. + fn query_as_flat(&self) -> Result> { + let query_array = self.query.query_vector.as_ref(); + if let Some(fsl) = query_array.as_fixed_size_list_opt() { + if fsl.len() != 1 { + return Err(Error::invalid_input(format!( + "brute-force vector search expects a single query vector, got {}", + fsl.len() + ))); + } + return Ok(fsl.value(0)); + } + Ok(self.query.query_vector.clone()) + } + + /// Compute `(distance, row_position)` for every visible row, then top-k by + /// distance ascending. Rows where the vector column is null or where the + /// computed distance is non-finite are skipped — same convention as the + /// HNSW search (which filters on `result.distance.is_finite()`). + fn compute_topk(&self) -> Result> { + if self.query.k == 0 { + return Ok(Vec::new()); + } + let Some(max_visible_row) = self.compute_max_visible_row() else { + return Ok(Vec::new()); + }; + let query_flat = self.query_as_flat()?; + let column_name = self.query.column.as_str(); + let distance_type = self.query.distance_type.unwrap_or(DEFAULT_DISTANCE_TYPE); + let batch_func = distance_type.arrow_batch_func(); + + // Walk batches in append order. `current_row` is the global row offset + // of the *next* row about to be visited; rows past `max_visible_row` + // are dropped before they reach the heap. + let mut current_row: u64 = 0; + let mut candidates: Vec<(f32, u64)> = Vec::new(); + + for (batch_position, stored_batch) in self.batch_store.iter().enumerate() { + let n = stored_batch.num_rows; + if n == 0 { + continue; + } + if batch_position > self.max_visible_batch_position { + current_row += n as u64; + continue; + } + + let column = stored_batch + .data + .column_by_name(column_name) + .ok_or_else(|| { + Error::invalid_input(format!( + "Vector column '{}' not found in memtable schema", + column_name + )) + })?; + let column_fsl = column.as_fixed_size_list_opt().ok_or_else(|| { + Error::invalid_input(format!( + "Vector column '{}' must be FixedSizeList; got {:?}", + column_name, + column.data_type() + )) + })?; + + let distances = batch_func(query_flat.as_ref(), column_fsl).map_err(|e| { + Error::invalid_input(format!( + "brute-force distance computation failed for column '{}': {}", + column_name, e + )) + })?; + + for row in 0..n { + let pos = current_row + row as u64; + if pos > max_visible_row { + break; + } + if distances.is_null(row) { + continue; + } + let dist = distances.value(row); + if !dist.is_finite() { + continue; + } + candidates.push((dist, pos)); + } + + current_row += n as u64; + } + + // `partial_cmp` defaults Equal on NaN; we filtered non-finite above so + // every remaining value compares deterministically. + candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + + if self.query.distance_lower_bound.is_some() || self.query.distance_upper_bound.is_some() { + candidates.retain(|&(dist, _)| { + let above_lower = self.query.distance_lower_bound.is_none_or(|lb| dist >= lb); + let below_upper = self.query.distance_upper_bound.is_none_or(|ub| dist < ub); + above_lower && below_upper + }); + } + + candidates.truncate(self.query.k); + Ok(candidates) + } + + /// Materialize the top-k rows from the batch store, mirroring + /// `VectorIndexExec::materialize_rows`. Groups by batch so the per-batch + /// `take` is amortized; emits one output batch per source batch that + /// contributes. + fn materialize_rows(&self, results: &[(f32, u64)]) -> DataFusionResult> { + if results.is_empty() { + return Ok(vec![]); + } + + let mut batch_ranges = Vec::new(); + let mut current_row = 0usize; + for stored_batch in self.batch_store.iter() { + let start = current_row; + let end = current_row + stored_batch.num_rows; + batch_ranges.push((start, end)); + current_row = end; + } + + let mut batches_data: std::collections::HashMap> = + std::collections::HashMap::new(); + for &(distance, pos) in results { + let pos_usize = pos as usize; + for (batch_id, &(start, end)) in batch_ranges.iter().enumerate() { + if pos_usize >= start && pos_usize < end { + batches_data.entry(batch_id).or_default().push(( + pos_usize - start, + distance, + pos, + )); + break; + } + } + } + + let mut all_batches = Vec::new(); + for (batch_id, rows_with_dist) in batches_data { + if let Some(stored) = self.batch_store.get(batch_id) { + let rows: Vec = rows_with_dist.iter().map(|&(r, _, _)| r as u32).collect(); + let distances: Vec = rows_with_dist.iter().map(|&(_, d, _)| d).collect(); + let row_positions: Vec = + rows_with_dist.iter().map(|&(_, _, pos)| pos).collect(); + + let indices = arrow_array::UInt32Array::from(rows); + + let mut columns: Vec> = stored + .data + .columns() + .iter() + .map(|col| arrow_select::take::take(col.as_ref(), &indices, None).unwrap()) + .collect(); + + columns.push(Arc::new(Float32Array::from(distances))); + + let mut final_columns = if let Some(ref proj_indices) = self.projection { + let mut projected: Vec<_> = + proj_indices.iter().map(|&i| columns[i].clone()).collect(); + // Distance was just pushed onto `columns`; keep it last. + projected.push(columns.last().unwrap().clone()); + projected + } else { + columns + }; + + if self.with_row_id { + final_columns.push(Arc::new(UInt64Array::from(row_positions))); + } + + let batch = RecordBatch::try_new(self.output_schema.clone(), final_columns)?; + all_batches.push(batch); + } + } + + Ok(all_batches) + } +} + +impl DisplayAs for MemTableBruteForceVectorExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "MemTableBruteForceVectorExec: column={}, k={}, with_row_id={}", + self.query.column, self.query.k, self.with_row_id + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "MemTableBruteForceVectorExec\ncolumn={}\nk={}\nwith_row_id={}", + self.query.column, self.query.k, self.with_row_id + ) + } + } + } +} + +impl ExecutionPlan for MemTableBruteForceVectorExec { + fn name(&self) -> &str { + "MemTableBruteForceVectorExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.output_schema.clone() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if !children.is_empty() { + return Err(datafusion::error::DataFusionError::Internal( + "MemTableBruteForceVectorExec does not have children".to_string(), + )); + } + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DataFusionResult { + let results = self + .compute_topk() + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + let batches = self.materialize_rows(&results)?; + let stream = stream::iter(batches.into_iter().map(Ok)).boxed(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.output_schema.clone(), + stream, + ))) + } + + fn partition_statistics(&self, _partition: Option) -> DataFusionResult { + Ok(Statistics { + num_rows: Precision::Exact(self.query.k), + total_byte_size: Precision::Absent, + column_statistics: vec![], + }) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn supports_limit_pushdown(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{FixedSizeListArray, Float32Array, Int32Array}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::physical_plan::common::collect; + use datafusion::prelude::SessionContext; + + fn make_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2), + true, + ), + ])) + } + + fn make_batch(schema: SchemaRef, ids: &[i32], vectors: &[[f32; 2]]) -> RecordBatch { + let id_array = Arc::new(Int32Array::from(ids.to_vec())) as Arc; + let values: Vec = vectors.iter().flat_map(|v| v.iter().copied()).collect(); + let inner = Arc::new(Float32Array::from(values)); + let field = Arc::new(Field::new("item", DataType::Float32, true)); + let vec_array = + Arc::new(FixedSizeListArray::try_new(field, 2, inner, None).expect("build fsl")) + as Arc; + RecordBatch::try_new(schema, vec![id_array, vec_array]).expect("build batch") + } + + fn store_with_batches(batches: Vec) -> Arc { + let store = Arc::new(BatchStore::with_capacity(batches.len().max(1))); + for batch in batches { + store.append(batch).expect("append batch"); + } + store + } + + fn query_for(vector: [f32; 2], k: usize) -> VectorQuery { + let values = Arc::new(Float32Array::from(vector.to_vec())) as Arc; + VectorQuery { + column: "vector".to_string(), + query_vector: values, + k, + nprobes: 1, + maximum_nprobes: None, + distance_type: Some(DistanceType::L2), + ef: None, + refine_factor: None, + distance_lower_bound: None, + distance_upper_bound: None, + } + } + + async fn execute_to_batches(exec: Arc) -> Vec { + let ctx = SessionContext::new(); + let stream = exec.execute(0, ctx.task_ctx()).expect("execute"); + collect(stream).await.expect("collect") + } + + #[tokio::test] + async fn top_k_by_distance() { + // Five rows; query at (0,0); L2 distances are id² each — expect ids in + // ascending order, capped at k=3. + let schema = make_schema(); + let batch = make_batch( + schema.clone(), + &[0, 1, 2, 3, 4], + &[[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]], + ); + let store = store_with_batches(vec![batch]); + let query = query_for([0.0, 0.0], 3); + let exec = Arc::new( + MemTableBruteForceVectorExec::new( + store, + query, + /* max_visible_batch_position = */ usize::MAX, + None, + schema, + false, + ) + .expect("ctor"), + ); + let out = execute_to_batches(exec).await; + // Concat and check ids + distances in order. + let total: usize = out.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 3, "k=3 cap not honored: got {total} rows"); + + let mut id_dist: Vec<(i32, f32)> = Vec::new(); + for batch in &out { + let ids = batch + .column_by_name("id") + .unwrap() + .as_primitive::(); + let dists = batch + .column_by_name(DISTANCE_COLUMN) + .unwrap() + .as_primitive::(); + for i in 0..batch.num_rows() { + id_dist.push((ids.value(i), dists.value(i))); + } + } + id_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + assert_eq!(id_dist[0].0, 0); + assert_eq!(id_dist[1].0, 1); + assert_eq!(id_dist[2].0, 2); + } + + #[tokio::test] + async fn empty_memtable_returns_empty_with_distance_schema() { + let schema = make_schema(); + let store = Arc::new(BatchStore::with_capacity(4)); + let query = query_for([0.5, 0.5], 10); + let exec = Arc::new( + MemTableBruteForceVectorExec::new(store, query, usize::MAX, None, schema, false) + .expect("ctor"), + ); + let out_schema = exec.schema(); + assert!( + out_schema.field_with_name(DISTANCE_COLUMN).is_ok(), + "output schema must contain `{DISTANCE_COLUMN}` even with empty memtable; got {:?}", + out_schema + ); + let out = execute_to_batches(exec).await; + let total: usize = out.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 0); + } + + #[tokio::test] + async fn respects_max_visible_batch_position() { + // Two batches of two rows. Freeze at batch 0 — only ids 0,1 are + // visible candidates; the (closer) ids 2,3 in batch 1 are excluded. + let schema = make_schema(); + let b0 = make_batch(schema.clone(), &[0, 1], &[[5.0, 0.0], [6.0, 0.0]]); + let b1 = make_batch(schema.clone(), &[2, 3], &[[1.0, 0.0], [2.0, 0.0]]); + let store = store_with_batches(vec![b0, b1]); + let query = query_for([0.0, 0.0], 4); + let exec = Arc::new( + MemTableBruteForceVectorExec::new( + store, query, /* max_visible_batch_position = */ 0, None, schema, false, + ) + .expect("ctor"), + ); + let out = execute_to_batches(exec).await; + let mut returned_ids: Vec = Vec::new(); + for batch in &out { + let ids = batch + .column_by_name("id") + .unwrap() + .as_primitive::(); + for i in 0..batch.num_rows() { + returned_ids.push(ids.value(i)); + } + } + returned_ids.sort(); + assert_eq!(returned_ids, vec![0, 1]); + } + + #[tokio::test] + async fn applies_distance_bounds() { + let schema = make_schema(); + let batch = make_batch( + schema.clone(), + &[0, 1, 2, 3], + &[[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], + ); + let store = store_with_batches(vec![batch]); + let mut query = query_for([0.0, 0.0], 10); + // L2² distances: 0, 1, 4, 9. Keep only distances in [1, 5) — ids 1, 2. + query.distance_lower_bound = Some(1.0); + query.distance_upper_bound = Some(5.0); + let exec = Arc::new( + MemTableBruteForceVectorExec::new(store, query, usize::MAX, None, schema, false) + .expect("ctor"), + ); + let out = execute_to_batches(exec).await; + let mut ids: Vec = Vec::new(); + for batch in &out { + let id_arr = batch + .column_by_name("id") + .unwrap() + .as_primitive::(); + for i in 0..batch.num_rows() { + ids.push(id_arr.value(i)); + } + } + ids.sort(); + assert_eq!(ids, vec![1, 2]); + } + + #[tokio::test] + async fn populates_row_id_when_requested() { + let schema = make_schema(); + let batch = make_batch( + schema.clone(), + &[10, 11, 12], + &[[3.0, 0.0], [1.0, 0.0], [2.0, 0.0]], + ); + let store = store_with_batches(vec![batch]); + let query = query_for([0.0, 0.0], 3); + let exec = Arc::new( + MemTableBruteForceVectorExec::new( + store, + query, + usize::MAX, + None, + schema, + /* with_row_id = */ true, + ) + .expect("ctor"), + ); + let out_schema = exec.schema(); + assert!(out_schema.field_with_name(lance_core::ROW_ID).is_ok()); + + let out = execute_to_batches(exec).await; + let mut pairs: Vec<(i32, u64)> = Vec::new(); + for batch in &out { + let ids = batch + .column_by_name("id") + .unwrap() + .as_primitive::(); + let rowids = batch + .column_by_name(lance_core::ROW_ID) + .unwrap() + .as_primitive::(); + for i in 0..batch.num_rows() { + pairs.push((ids.value(i), rowids.value(i))); + } + } + // Row offsets are insert-order: id=10 → 0, id=11 → 1, id=12 → 2. + pairs.sort_by_key(|(id, _)| *id); + assert_eq!(pairs, vec![(10, 0), (11, 1), (12, 2)]); + } +} diff --git a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/dedup_scan.rs b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/dedup_scan.rs new file mode 100644 index 00000000000..ba5947e4b12 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/dedup_scan.rs @@ -0,0 +1,443 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! MemTableDedupScanExec — newest-per-PK scan over the active memtable that +//! fuses within-source dedup with the scalar predicate in a single pass. +//! +//! The active memtable is an append log, so a PK update is a later append +//! with the same key. [`super::MemTableScanExec`] pushes the filter into the +//! scan, which removes a PK's newest row *before* dedup runs — an older row +//! that still satisfies the predicate then leaks through (a "phantom"). +//! +//! This exec walks rows newest-first (batches reversed, rows iterated +//! back-to-front), seeds a seen-set from the *newest* occurrence of every PK +//! regardless of the predicate (so older versions stay suppressed even when +//! the newest row fails the filter), and records the keep/drop verdict into a +//! forward-aligned mask. A single `filter_record_batch` over the original +//! batch then emits the survivors with no per-column reverse copy. + +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array}; +use arrow_schema::SchemaRef; +use datafusion::common::stats::Precision; +use datafusion::error::Result as DataFusionResult; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, +}; +use datafusion::prelude::Expr; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExprRef}; +use futures::stream::{self, StreamExt}; + +use crate::dataset::mem_wal::scanner::exec::compute_pk_hash; +use crate::dataset::mem_wal::write::BatchStore; + +/// Scans the active memtable newest-first and emits the newest-per-PK rows +/// that satisfy the (optional) predicate. See the module doc. +pub struct MemTableDedupScanExec { + batch_store: Arc, + max_visible_batch_position: usize, + /// Column indices to project (into the source schema). + projection: Option>, + output_schema: SchemaRef, + /// Column indices of the primary key in the source schema. + pk_indices: Vec, + properties: Arc, + metrics: ExecutionPlanMetricsSet, + with_row_id: bool, + with_row_address: bool, + filter_predicate: Option, + /// Original filter expression, for display only. + filter_expr: Option, +} + +impl Debug for MemTableDedupScanExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MemTableDedupScanExec") + .field( + "max_visible_batch_position", + &self.max_visible_batch_position, + ) + .field("projection", &self.projection) + .field("pk_indices", &self.pk_indices) + .field("with_row_address", &self.with_row_address) + .field("has_filter", &self.filter_predicate.is_some()) + .finish() + } +} + +impl MemTableDedupScanExec { + /// Create a new fused dedup + predicate scan over the active memtable. + #[allow(clippy::too_many_arguments)] + pub fn new( + batch_store: Arc, + max_visible_batch_position: usize, + projection: Option>, + output_schema: SchemaRef, + pk_indices: Vec, + with_row_id: bool, + with_row_address: bool, + filter_predicate: Option, + filter_expr: Option, + ) -> Self { + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(output_schema.clone()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + + Self { + batch_store, + max_visible_batch_position, + projection, + output_schema, + pk_indices, + properties, + metrics: ExecutionPlanMetricsSet::new(), + with_row_id, + with_row_address, + filter_predicate, + filter_expr, + } + } +} + +impl DisplayAs for MemTableDedupScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result { + let projection_names: Vec<&str> = self + .output_schema + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect(); + let filter_str = self + .filter_expr + .as_ref() + .map(|e| format!(", filter={}", e)) + .unwrap_or_default(); + let row_addr_str = if self.with_row_address { + ", with_row_address=true" + } else { + "" + }; + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => write!( + f, + "MemTableDedupScanExec: projection=[{}], with_row_id={}{}{}", + projection_names.join(", "), + self.with_row_id, + row_addr_str, + filter_str + ), + DisplayFormatType::TreeRender => write!( + f, + "MemTableDedupScanExec\nprojection=[{}]\nwith_row_id={}{}{}", + projection_names.join(", "), + self.with_row_id, + row_addr_str, + filter_str + ), + } + } +} + +impl ExecutionPlan for MemTableDedupScanExec { + fn name(&self) -> &str { + "MemTableDedupScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.output_schema.clone() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if !children.is_empty() { + return Err(datafusion::error::DataFusionError::Internal( + "MemTableDedupScanExec does not have children".to_string(), + )); + } + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DataFusionResult { + // Newest-first iteration: reverse batches here, rows are walked + // back-to-front below. + let mut batches = self + .batch_store + .visible_batches_with_offsets(self.max_visible_batch_position); + batches.reverse(); + + let projection = self.projection.clone(); + let schema = self.output_schema.clone(); + let with_row_id = self.with_row_id; + let with_row_address = self.with_row_address; + let filter_predicate = self.filter_predicate.clone(); + let pk_indices = self.pk_indices.clone(); + let need_row_offsets = with_row_id || with_row_address; + + // Cross-batch seen-set: first time a PK hash is seen (newest-first) wins. + let mut seen: HashSet = HashSet::new(); + let mut out: Vec> = Vec::with_capacity(batches.len()); + + for (batch, row_offset) in batches { + let n = batch.num_rows(); + if n == 0 { + continue; + } + + // Predicate mask over the original (forward) rows; null counts as + // no-match. + let filter_array = match &filter_predicate { + Some(predicate) => { + let value = predicate.evaluate(&batch)?; + let array = value.into_array(n)?; + let Some(boolean) = array.as_any().downcast_ref::() else { + return Err(datafusion::error::DataFusionError::Internal( + "Filter predicate did not evaluate to boolean".to_string(), + )); + }; + Some(boolean.clone()) + } + None => None, + }; + + // Walk newest-first; first insertion into `seen` is the newest + // occurrence (keep), later ones are older (drop). `seen` is + // updated even when the newest row fails the predicate so its + // older versions stay suppressed (no phantom). + let mut emit_forward = vec![false; n]; + for j in (0..n).rev() { + let pk_hash = compute_pk_hash(&batch, &pk_indices, j); + let is_newest = seen.insert(pk_hash); + let passes = match &filter_array { + Some(mask) => mask.is_valid(j) && mask.value(j), + None => true, + }; + emit_forward[j] = is_newest && passes; + } + let emit_mask = BooleanArray::from(emit_forward); + + let emitted = arrow_select::filter::filter_record_batch(&batch, &emit_mask)?; + if emitted.num_rows() == 0 { + continue; + } + + let filtered_offsets: Vec = if need_row_offsets { + (0..n) + .filter(|&j| emit_mask.value(j)) + .map(|j| row_offset + j as u64) + .collect() + } else { + vec![] + }; + + let mut columns: Vec> = if let Some(ref indices) = projection { + indices.iter().map(|&i| emitted.column(i).clone()).collect() + } else { + emitted.columns().to_vec() + }; + if with_row_id { + columns.push(Arc::new(UInt64Array::from(filtered_offsets.clone()))); + } + if with_row_address { + columns.push(Arc::new(UInt64Array::from(filtered_offsets))); + } + + out.push( + RecordBatch::try_new(schema.clone(), columns) + .map_err(datafusion::error::DataFusionError::from), + ); + } + + let stream = stream::iter(out).boxed(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.output_schema.clone(), + stream, + ))) + } + + fn partition_statistics(&self, _partition: Option) -> DataFusionResult { + Ok(Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![], + }) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn supports_limit_pushdown(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::prelude::col; + use futures::TryStreamExt; + use lance_datafusion::planner::Planner; + use std::collections::HashMap; + + fn source_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Int32, true), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Int32, true), + Field::new( + crate::dataset::mem_wal::memtable::scanner::exec::ROW_ADDRESS_COLUMN, + DataType::UInt64, + true, + ), + ])) + } + + /// id PK + nullable value. Each tuple is one appended row, in order. + fn batch(rows: &[(i32, Option)]) -> RecordBatch { + let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); + let values: Vec> = rows.iter().map(|(_, v)| *v).collect(); + RecordBatch::try_new( + source_schema(), + vec![ + Arc::new(Int32Array::from(ids)), + Arc::new(Int32Array::from(values)), + ], + ) + .unwrap() + } + + /// Run the exec and collect (id -> (value, rowaddr)). + async fn run( + store: Arc, + max_visible: usize, + filter: Option, + ) -> HashMap, u64)> { + let filter_predicate = filter.map(|expr| { + let planner = Planner::new(source_schema()); + let optimized = planner.optimize_expr(expr).unwrap(); + planner.create_physical_expr(&optimized).unwrap() + }); + let filter_expr = None; + let exec = MemTableDedupScanExec::new( + store, + max_visible, + None, + output_schema(), + vec![0], + false, + true, + filter_predicate, + filter_expr, + ); + let ctx = Arc::new(TaskContext::default()); + let batches: Vec = exec.execute(0, ctx).unwrap().try_collect().await.unwrap(); + + let mut out = HashMap::new(); + for b in &batches { + let ids = b.column(0).as_any().downcast_ref::().unwrap(); + let values = b.column(1).as_any().downcast_ref::().unwrap(); + let addrs = b.column(2).as_any().downcast_ref::().unwrap(); + for i in 0..b.num_rows() { + let value = (!values.is_null(i)).then(|| values.value(i)); + let prev = out.insert(ids.value(i), (value, addrs.value(i))); + assert!(prev.is_none(), "duplicate PK {} in output", ids.value(i)); + } + } + out + } + + /// Within a single batch: insert + update of one PK collapses to newest, + /// and a predicate that the newest version fails must NOT resurrect the old. + #[tokio::test] + async fn within_batch_phantom_suppressed() { + let store = Arc::new(BatchStore::with_capacity(16)); + // id=10 inserted (100) then updated to NULL, all in one batch. + store.append(batch(&[(10, Some(100)), (10, None)])).unwrap(); + + let no_filter = run(store.clone(), 0, None).await; + assert_eq!(no_filter.len(), 1); + assert_eq!(no_filter[&10].0, None, "newest version of id=10 is NULL"); + + let not_null = run(store, 0, Some(col("value").is_not_null())).await; + assert!( + !not_null.contains_key(&10), + "id=10 newest is NULL; the stale value=100 must not leak under value IS NOT NULL" + ); + } + + /// Cross-batch dedup + predicate, mirroring the design doc worked example. + #[tokio::test] + async fn cross_batch_newest_per_pk_with_filter() { + let store = Arc::new(BatchStore::with_capacity(16)); + store + .append(batch(&[(10, Some(100)), (20, Some(200)), (10, None)])) + .unwrap(); + store.append(batch(&[(20, Some(999)), (30, None)])).unwrap(); + + // No filter: newest per PK = {10:NULL@2, 20:999@3, 30:NULL@4}. + let all = run(store.clone(), 1, None).await; + assert_eq!(all.len(), 3); + assert_eq!(all[&10], (None, 2)); + assert_eq!(all[&20], (Some(999), 3)); + assert_eq!(all[&30], (None, 4)); + + // value IS NOT NULL: only id=20 (newest 999) survives; 10 and 30 are + // newest-NULL so they must be absent (no stale leak). + let not_null = run(store, 1, Some(col("value").is_not_null())).await; + assert_eq!(not_null.len(), 1); + assert_eq!(not_null[&20], (Some(999), 3)); + } + + /// value IS NULL is the mirror case: a PK whose newest version is non-NULL + /// must not leak an older NULL version. + #[tokio::test] + async fn is_null_predicate_no_stale_leak() { + let store = Arc::new(BatchStore::with_capacity(16)); + // id=40 inserted NULL then updated to 400 (newest non-NULL). + store.append(batch(&[(40, None), (40, Some(400))])).unwrap(); + + let is_null = run(store, 0, Some(col("value").is_null())).await; + assert!( + !is_null.contains_key(&40), + "id=40 newest is 400; the stale NULL must not leak under value IS NULL" + ); + } +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec.rs b/rust/lance/src/dataset/mem_wal/scanner/exec.rs index 79ae94f1c50..88fd617dc0a 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec.rs @@ -7,30 +7,24 @@ //! for LSM tree query execution: //! //! - [`MemtableGenTagExec`]: Wraps a scan to add `_memtable_gen` column -//! - [`DeduplicateExec`]: Deduplicates by primary key, keeping newest version //! - [`BloomFilterGuardExec`]: Guards child execution with bloom filter check //! - [`CoalesceFirstExec`]: Returns first non-empty result with short-circuit -//! - [`LsmSourceTagExec`]: Tags rows with `_memtable_gen` + `_freshness` for the vector-search global dedup -//! - [`LsmGlobalPkDedupExec`]: Single-pass cross-source PK dedup over the merged vector-search stream -//! - [`WithinSourceDedupExec`]: Deduplicates rows with the same PK from a single source (used by point lookup) -//! - [`PkHashFilterExec`]: Drops rows whose PK hash was superseded by a newer generation (vector-search block-list) +//! - [`WithinSourceDedupExec`]: Deduplicates rows with the same PK from a single source +//! - [`PkHashFilterExec`]: Drops rows whose PK hash was superseded by a newer generation (the cross-generation block-list) mod bloom_guard; mod coalesce_first; -mod deduplicate; mod generation_tag; -mod global_pk_dedup; mod pk; mod pk_hash_filter; -mod source_tag; mod within_source_dedup; pub use bloom_guard::{BloomFilterGuardExec, compute_pk_hash_from_scalars}; pub use coalesce_first::CoalesceFirstExec; -pub use deduplicate::{DeduplicateExec, ROW_ADDRESS_COLUMN}; pub use generation_tag::{MEMTABLE_GEN_COLUMN, MemtableGenTagExec}; -pub use global_pk_dedup::LsmGlobalPkDedupExec; -pub use pk::{compute_pk_hash, resolve_pk_indices}; +pub use pk::{ + ROW_ADDRESS_COLUMN, compute_pk_hash, is_supported_pk_type, resolve_pk_indices, + validate_pk_types, +}; pub use pk_hash_filter::PkHashFilterExec; -pub use source_tag::{FRESHNESS_COLUMN, FreshnessPolarity, LsmSourceTagExec}; pub use within_source_dedup::{DedupDirection, WithinSourceDedupExec}; diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/bloom_guard.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/bloom_guard.rs index ab8be9f8b75..6039eed1629 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/bloom_guard.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/bloom_guard.rs @@ -206,60 +206,44 @@ impl datafusion::physical_plan::RecordBatchStream for EmptyStream { /// Compute hash for a primary key value. /// -/// This function should be consistent with the hash function used when -/// inserting keys into the bloom filter. +/// Must stay byte-for-byte consistent with [`super::compute_pk_hash`]: for each +/// scalar, hash `is_null` first, then hash the inner value only when not-null. +/// This includes the typed Option(None) branches — they represent a NULL of a +/// known type and must hash the same as a row-side NULL of the same column. pub fn compute_pk_hash_from_scalars(values: &[datafusion::common::ScalarValue]) -> u64 { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; let mut hasher = DefaultHasher::new(); + fn hash_opt(hasher: &mut DefaultHasher, v: &Option) { + v.is_none().hash(hasher); + if let Some(val) = v { + val.hash(hasher); + } + } + for value in values { match value { datafusion::common::ScalarValue::Null => { true.hash(&mut hasher); // is_null = true } - datafusion::common::ScalarValue::Int32(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } - datafusion::common::ScalarValue::Int64(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } - datafusion::common::ScalarValue::UInt32(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } - datafusion::common::ScalarValue::UInt64(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } + datafusion::common::ScalarValue::Int8(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::Int16(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::Int32(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::Int64(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::UInt8(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::UInt16(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::UInt32(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::UInt64(v) => hash_opt(&mut hasher, v), + datafusion::common::ScalarValue::Boolean(v) => hash_opt(&mut hasher, v), datafusion::common::ScalarValue::Utf8(v) - | datafusion::common::ScalarValue::LargeUtf8(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } + | datafusion::common::ScalarValue::LargeUtf8(v) => hash_opt(&mut hasher, v), datafusion::common::ScalarValue::Binary(v) - | datafusion::common::ScalarValue::LargeBinary(v) => { - false.hash(&mut hasher); - if let Some(val) = v { - val.hash(&mut hasher); - } - } - // Add more types as needed + | datafusion::common::ScalarValue::LargeBinary(v) => hash_opt(&mut hasher, v), + // Unsupported types: validated out at the scanner boundary, but + // distinguish by value rather than collapse if reached. _ => { - // For unsupported types, just hash the debug representation false.hash(&mut hasher); format!("{:?}", value).hash(&mut hasher); } diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/deduplicate.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/deduplicate.rs deleted file mode 100644 index 4a3492f2ad7..00000000000 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/deduplicate.rs +++ /dev/null @@ -1,725 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Deduplication execution node for LSM merge reads. - -use std::any::Any; -use std::fmt; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow_array::{Array, RecordBatch}; -use arrow_schema::{Field, Schema, SchemaRef, SortOptions}; -use datafusion::common::ScalarValue; -use datafusion::error::Result as DFResult; -use datafusion::execution::TaskContext; -use datafusion::physical_expr::expressions::Column; -use datafusion::physical_expr::{ - EquivalenceProperties, LexOrdering, Partitioning, PhysicalSortExpr, -}; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, - SendableRecordBatchStream, -}; -use futures::{Stream, StreamExt}; -use lance_core::{Error, Result}; - -use super::generation_tag::MEMTABLE_GEN_COLUMN; - -/// Column name for row address (used for ordering within generation). -pub const ROW_ADDRESS_COLUMN: &str = "_rowaddr"; - -/// Deduplicates rows by primary key, keeping the row with highest (_memtable_gen, _rowaddr). -/// -/// # Algorithm -/// -/// 1. Sort input by (pk_columns, _memtable_gen DESC, _rowaddr DESC) - if not already sorted -/// 2. Stream through sorted data, emit only first row per PK -/// -/// After sorting, the first occurrence of each PK has the highest (_memtable_gen, _rowaddr), -/// so we can deduplicate in a single streaming pass. -/// -/// # Pre-sorted Input Optimization -/// -/// When `input_sorted` is true, the input is assumed to already be sorted by -/// (pk_columns ASC, _memtable_gen DESC, _rowaddr DESC). This allows skipping the internal -/// sort, which is useful when the input comes from SortPreservingMergeExec that -/// has already merged K pre-sorted streams. -/// -/// # Memory Efficiency -/// -/// Uses DataFusion's SortExec for external sort when data exceeds memory. -/// The streaming deduplication pass requires O(1) memory per partition. -#[derive(Debug)] -pub struct DeduplicateExec { - /// Child plan (UnionExec of tagged scans). - input: Arc, - /// Primary key column names. - pk_columns: Vec, - /// Output schema. - schema: SchemaRef, - /// Whether to keep _memtable_gen in output. - with_memtable_gen: bool, - /// Whether to keep _rowaddr in output. - keep_row_address: bool, - /// Whether the input is already sorted by (pk, _memtable_gen DESC, _rowaddr DESC). - input_sorted: bool, - /// Plan properties. - properties: Arc, -} - -impl DeduplicateExec { - /// Create a new deduplication executor. - /// - /// # Arguments - /// - /// * `input` - Child plan producing tagged rows - /// * `pk_columns` - Primary key column names for deduplication - /// * `with_memtable_gen` - Whether to include _memtable_gen in output - /// * `keep_row_address` - Whether to include _rowaddr in output - pub fn new( - input: Arc, - pk_columns: Vec, - with_memtable_gen: bool, - keep_row_address: bool, - ) -> Result { - Self::new_with_sorted( - input, - pk_columns, - with_memtable_gen, - keep_row_address, - false, - ) - } - - /// Create a new deduplication executor with pre-sorted input. - /// - /// # Arguments - /// - /// * `input` - Child plan producing tagged rows - /// * `pk_columns` - Primary key column names for deduplication - /// * `with_memtable_gen` - Whether to include _memtable_gen in output - /// * `keep_row_address` - Whether to include _rowaddr in output - /// * `input_sorted` - Whether the input is already sorted by (pk, _memtable_gen DESC, _rowaddr DESC) - pub fn new_with_sorted( - input: Arc, - pk_columns: Vec, - with_memtable_gen: bool, - keep_row_address: bool, - input_sorted: bool, - ) -> Result { - let input_schema = input.schema(); - - // Validate that required columns exist - for col in &pk_columns { - if input_schema.column_with_name(col).is_none() { - return Err(Error::invalid_input(format!( - "Primary key column '{}' not found in input schema", - col - ))); - } - } - - if input_schema.column_with_name(MEMTABLE_GEN_COLUMN).is_none() { - return Err(Error::invalid_input(format!( - "Generation column '{}' not found in input schema", - MEMTABLE_GEN_COLUMN - ))); - } - - if input_schema.column_with_name(ROW_ADDRESS_COLUMN).is_none() { - return Err(Error::invalid_input(format!( - "Row address column '{}' not found in input schema", - ROW_ADDRESS_COLUMN - ))); - } - - // Build output schema (may exclude internal columns) - let output_fields: Vec> = input_schema - .fields() - .iter() - .filter(|f| { - let name = f.name(); - if name == MEMTABLE_GEN_COLUMN && !with_memtable_gen { - return false; - } - if name == ROW_ADDRESS_COLUMN && !keep_row_address { - return false; - } - true - }) - .cloned() - .collect(); - let schema = Arc::new(Schema::new(output_fields)); - - // Output is single partition after sort + dedup - let properties = Arc::new(PlanProperties::new( - EquivalenceProperties::new(schema.clone()), - Partitioning::UnknownPartitioning(1), - input.pipeline_behavior(), - input.boundedness(), - )); - - Ok(Self { - input, - pk_columns, - schema, - with_memtable_gen, - keep_row_address, - input_sorted, - properties, - }) - } - - /// Create a deduplication executor for pre-sorted input without _memtable_gen column. - /// - /// This is used when the input is already sorted by (pk ASC, _rowaddr DESC) with - /// newer generations appearing first (via stream ordering). The _memtable_gen column is - /// not required in the input schema unless `with_memtable_gen=true`. - /// - /// # Arguments - /// - /// * `input` - Child plan producing rows sorted by (pk ASC, _rowaddr DESC) - /// * `pk_columns` - Primary key column names for deduplication - /// * `with_memtable_gen` - Whether to include _memtable_gen in output (requires _memtable_gen in input) - /// * `keep_row_address` - Whether to include _rowaddr in output - pub fn new_sorted( - input: Arc, - pk_columns: Vec, - with_memtable_gen: bool, - keep_row_address: bool, - ) -> Result { - let input_schema = input.schema(); - - // Validate that required columns exist - for col in &pk_columns { - if input_schema.column_with_name(col).is_none() { - return Err(Error::invalid_input(format!( - "Primary key column '{}' not found in input schema", - col - ))); - } - } - - // _memtable_gen column is only required if with_memtable_gen=true - if with_memtable_gen && input_schema.column_with_name(MEMTABLE_GEN_COLUMN).is_none() { - return Err(Error::invalid_input(format!( - "Generation column '{}' not found in input schema (required when with_memtable_gen=true)", - MEMTABLE_GEN_COLUMN - ))); - } - - if input_schema.column_with_name(ROW_ADDRESS_COLUMN).is_none() { - return Err(Error::invalid_input(format!( - "Row address column '{}' not found in input schema", - ROW_ADDRESS_COLUMN - ))); - } - - // Build output schema (may exclude internal columns) - let output_fields: Vec> = input_schema - .fields() - .iter() - .filter(|f| { - let name = f.name(); - if name == MEMTABLE_GEN_COLUMN && !with_memtable_gen { - return false; - } - if name == ROW_ADDRESS_COLUMN && !keep_row_address { - return false; - } - true - }) - .cloned() - .collect(); - let schema = Arc::new(Schema::new(output_fields)); - - // Output is single partition after dedup - let properties = Arc::new(PlanProperties::new( - EquivalenceProperties::new(schema.clone()), - Partitioning::UnknownPartitioning(1), - input.pipeline_behavior(), - input.boundedness(), - )); - - Ok(Self { - input, - pk_columns, - schema, - with_memtable_gen, - keep_row_address, - input_sorted: true, - properties, - }) - } - - /// Get the primary key columns. - pub fn pk_columns(&self) -> &[String] { - &self.pk_columns - } - - /// Build sort expressions for deduplication ordering. - fn build_sort_exprs(&self) -> DFResult> { - let input_schema = self.input.schema(); - let mut sort_exprs = Vec::new(); - - // Sort by PK columns (ASC) to group duplicates together - for col in &self.pk_columns { - let (idx, _) = input_schema.column_with_name(col).ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!("Column '{}' not found", col)) - })?; - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(col, idx)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }); - } - - // Sort by _memtable_gen DESC (higher generation = newer) - let (gen_idx, _) = input_schema - .column_with_name(MEMTABLE_GEN_COLUMN) - .expect("_memtable_gen column validated in constructor"); - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(MEMTABLE_GEN_COLUMN, gen_idx)), - options: SortOptions { - descending: true, - nulls_first: false, - }, - }); - - // Sort by _rowaddr DESC (higher address = newer within generation) - let (addr_idx, _) = input_schema - .column_with_name(ROW_ADDRESS_COLUMN) - .expect("_rowaddr column validated in constructor"); - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(ROW_ADDRESS_COLUMN, addr_idx)), - options: SortOptions { - descending: true, - nulls_first: false, - }, - }); - - Ok(sort_exprs) - } - - /// Build the internal sorted execution plan. - fn build_sorted_plan(&self) -> DFResult> { - let sort_exprs = self.build_sort_exprs()?; - let lex_ordering = LexOrdering::new(sort_exprs).ok_or_else(|| { - datafusion::error::DataFusionError::Internal( - "Failed to create LexOrdering: empty sort expressions".to_string(), - ) - })?; - let sort_exec = SortExec::new(lex_ordering, self.input.clone()); - Ok(Arc::new(sort_exec)) - } - - /// Get column indices for PK comparison. - fn pk_indices(&self) -> Vec { - let schema = self.input.schema(); - self.pk_columns - .iter() - .map(|col| schema.column_with_name(col).unwrap().0) - .collect() - } - - /// Get column indices to keep in output. - fn output_indices(&self) -> Vec { - let input_schema = self.input.schema(); - input_schema - .fields() - .iter() - .enumerate() - .filter(|(_, f)| { - let name = f.name(); - if name == MEMTABLE_GEN_COLUMN && !self.with_memtable_gen { - return false; - } - if name == ROW_ADDRESS_COLUMN && !self.keep_row_address { - return false; - } - true - }) - .map(|(i, _)| i) - .collect() - } -} - -impl DisplayAs for DeduplicateExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default - | DisplayFormatType::Verbose - | DisplayFormatType::TreeRender => { - write!( - f, - "DeduplicateExec: pk=[{}], with_memtable_gen={}, keep_addr={}, input_sorted={}", - self.pk_columns.join(", "), - self.with_memtable_gen, - self.keep_row_address, - self.input_sorted - ) - } - } - } -} - -impl ExecutionPlan for DeduplicateExec { - fn name(&self) -> &str { - "DeduplicateExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn properties(&self) -> &Arc { - &self.properties - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> DFResult> { - if children.len() != 1 { - return Err(datafusion::error::DataFusionError::Internal( - "DeduplicateExec requires exactly one child".to_string(), - )); - } - Ok(Arc::new( - Self::new_with_sorted( - children[0].clone(), - self.pk_columns.clone(), - self.with_memtable_gen, - self.keep_row_address, - self.input_sorted, - ) - .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?, - )) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> DFResult { - // Either use input directly (if pre-sorted) or wrap in sort - let sorted_stream = if self.input_sorted { - // Input is already sorted, use directly - self.input.execute(partition, context)? - } else { - // Build and execute the sorted plan - let sorted_plan = self.build_sorted_plan()?; - sorted_plan.execute(partition, context)? - }; - - Ok(Box::pin(DeduplicateStream::new( - sorted_stream, - self.pk_indices(), - self.output_indices(), - self.schema.clone(), - ))) - } -} - -/// Streaming deduplication on sorted input. -struct DeduplicateStream { - input: SendableRecordBatchStream, - pk_indices: Vec, - output_indices: Vec, - schema: SchemaRef, - /// Last PK values seen (for comparison). - last_pk: Option>>, -} - -impl DeduplicateStream { - fn new( - input: SendableRecordBatchStream, - pk_indices: Vec, - output_indices: Vec, - schema: SchemaRef, - ) -> Self { - Self { - input, - pk_indices, - output_indices, - schema, - last_pk: None, - } - } - - /// Process a batch and return deduplicated rows. - fn process_batch(&mut self, batch: RecordBatch) -> DFResult { - if batch.num_rows() == 0 { - return Ok(RecordBatch::new_empty(self.schema.clone())); - } - - let mut keep_indices = Vec::new(); - - for row_idx in 0..batch.num_rows() { - let current_pk: Vec> = self - .pk_indices - .iter() - .map(|&col_idx| batch.column(col_idx).slice(row_idx, 1)) - .collect(); - - let is_new_pk = match &self.last_pk { - None => true, - Some(last) => !pk_equals(¤t_pk, last), - }; - - if is_new_pk { - // This is the first (newest) row for this PK - keep_indices.push(row_idx); - self.last_pk = Some(current_pk); - } - // Else: duplicate PK with lower gen/rowaddr, skip it - } - - // Build output batch with only kept rows - self.filter_batch(&batch, &keep_indices) - } - - /// Filter batch to only include specified row indices. - fn filter_batch(&self, batch: &RecordBatch, indices: &[usize]) -> DFResult { - if indices.is_empty() { - return Ok(RecordBatch::new_empty(self.schema.clone())); - } - - let indices_array = - arrow_array::UInt32Array::from(indices.iter().map(|&i| i as u32).collect::>()); - - // Select only output columns - let columns: Vec> = self - .output_indices - .iter() - .map(|&col_idx| { - let col = batch.column(col_idx); - arrow_select::take::take(col.as_ref(), &indices_array, None) - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) - }) - .collect::>>()?; - - RecordBatch::try_new(self.schema.clone(), columns) - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) - } -} - -/// Compare two PK tuples for equality. -fn pk_equals(a: &[Arc], b: &[Arc]) -> bool { - if a.len() != b.len() { - return false; - } - - for (col_a, col_b) in a.iter().zip(b.iter()) { - // Each array has 1 element (single row) - convert to ScalarValue for comparison - let val_a = ScalarValue::try_from_array(col_a.as_ref(), 0); - let val_b = ScalarValue::try_from_array(col_b.as_ref(), 0); - - match (val_a, val_b) { - (Ok(a), Ok(b)) => { - if a != b { - return false; - } - } - _ => return false, - } - } - - true -} - -impl Stream for DeduplicateStream { - type Item = DFResult; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.input.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - let result = self.process_batch(batch); - Poll::Ready(Some(result)) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl datafusion::physical_plan::RecordBatchStream for DeduplicateStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::{Int32Array, StringArray, UInt64Array}; - use datafusion::prelude::SessionContext; - use datafusion_physical_plan::test::TestMemoryExec; - - fn create_test_data() -> (SchemaRef, Vec) { - // Schema: id (PK), name, _memtable_gen, _rowaddr - let schema = Arc::new(Schema::new(vec![ - Field::new("id", arrow_schema::DataType::Int32, false), - Field::new("name", arrow_schema::DataType::Utf8, true), - Field::new(MEMTABLE_GEN_COLUMN, arrow_schema::DataType::UInt64, false), - Field::new(ROW_ADDRESS_COLUMN, arrow_schema::DataType::UInt64, false), - ])); - - // Data with duplicates: - // id=1: gen=0 (base), gen=2 (memtable) -> keep gen=2 - // id=2: gen=0 only -> keep gen=0 - // id=3: gen=1, gen=2 -> keep gen=2 - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 1, 3, 3])), - Arc::new(StringArray::from(vec![ - "old_1", "only_2", "new_1", "old_3", "new_3", - ])), - Arc::new(UInt64Array::from(vec![0, 0, 2, 1, 2])), - Arc::new(UInt64Array::from(vec![100, 200, 50, 10, 20])), - ], - ) - .unwrap(); - - (schema, vec![batch]) - } - - #[tokio::test] - async fn test_deduplicate_exec() { - let (schema, batches) = create_test_data(); - - let input = TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap(); - - let dedup = DeduplicateExec::new( - input, - vec!["id".to_string()], - false, // don't keep _memtable_gen - false, // don't keep _rowaddr - ) - .unwrap(); - - // Output schema should only have id, name - assert_eq!(dedup.schema().fields().len(), 2); - assert_eq!(dedup.schema().field(0).name(), "id"); - assert_eq!(dedup.schema().field(1).name(), "name"); - - let ctx = SessionContext::new(); - let stream = dedup.execute(0, ctx.task_ctx()).unwrap(); - let result_batches: Vec<_> = stream.collect::>().await; - - // Concatenate results - let mut all_ids = Vec::new(); - let mut all_names = Vec::new(); - for batch_result in result_batches { - let batch = batch_result.unwrap(); - if batch.num_rows() > 0 { - let ids = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - let names = batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..batch.num_rows() { - all_ids.push(ids.value(i)); - all_names.push(names.value(i).to_string()); - } - } - } - - // Should have 3 unique rows - assert_eq!(all_ids.len(), 3); - - // Find each id and verify the correct version was kept - for (id, name) in all_ids.iter().zip(all_names.iter()) { - match id { - 1 => assert_eq!(name, "new_1", "id=1 should keep gen=2 version"), - 2 => assert_eq!(name, "only_2", "id=2 has only one version"), - 3 => assert_eq!(name, "new_3", "id=3 should keep gen=2 version"), - _ => panic!("Unexpected id: {}", id), - } - } - } - - #[tokio::test] - async fn test_deduplicate_with_memtable_gen() { - let (schema, batches) = create_test_data(); - - let input = TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap(); - - let dedup = DeduplicateExec::new( - input, - vec!["id".to_string()], - true, // keep _memtable_gen - false, // don't keep _rowaddr - ) - .unwrap(); - - // Output schema should have id, name, _memtable_gen - assert_eq!(dedup.schema().fields().len(), 3); - assert_eq!(dedup.schema().field(2).name(), MEMTABLE_GEN_COLUMN); - } - - #[test] - fn test_deduplicate_missing_pk_column() { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", arrow_schema::DataType::Int32, false), - Field::new(MEMTABLE_GEN_COLUMN, arrow_schema::DataType::UInt64, false), - Field::new(ROW_ADDRESS_COLUMN, arrow_schema::DataType::UInt64, false), - ])); - - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(UInt64Array::from(vec![1])), - Arc::new(UInt64Array::from(vec![1])), - ], - ) - .unwrap(); - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); - - let result = DeduplicateExec::new(input, vec!["nonexistent".to_string()], false, false); - - assert!(result.is_err()); - } - - #[test] - fn test_display() { - let schema = Arc::new(Schema::new(vec![ - Field::new("id", arrow_schema::DataType::Int32, false), - Field::new("name", arrow_schema::DataType::Utf8, true), - Field::new(MEMTABLE_GEN_COLUMN, arrow_schema::DataType::UInt64, false), - Field::new(ROW_ADDRESS_COLUMN, arrow_schema::DataType::UInt64, false), - ])); - - let batch = RecordBatch::new_empty(schema.clone()); - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); - - let dedup = DeduplicateExec::new(input, vec!["id".to_string()], true, false).unwrap(); - - // Test Debug format - let debug_str = format!("{:?}", dedup); - assert!(debug_str.contains("DeduplicateExec")); - assert!(debug_str.contains("pk_columns")); - } -} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs deleted file mode 100644 index fdf9372cc4e..00000000000 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs +++ /dev/null @@ -1,411 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Global, exact primary-key deduplication for the LSM vector-search -//! pipeline. -//! -//! Replaces the older two-step `WithinSourceDedupExec` + `FilterStaleExec` -//! design with a single streaming hash-by-PK pass over the merged stream. -//! For each PK the row with the largest `(generation, freshness)` tuple -//! wins — generation is the source identity (base = 0, memtable gens 1..N, -//! active = N+1) and freshness is the per-source row order normalized so -//! that "larger = newer" (see [`super::LsmSourceTagExec`]). -//! -//! Compared with the bloom-based staleness filter this is: -//! -//! - Exact (no false-positive recall loss, no top-k under-fill, no -//! missing-bloom footgun). -//! - One node instead of two (no separate per-source dedup wrap). -//! - O(unique PKs in the merged stream) state — typically far smaller -//! than the n_sources · k upper bound because most PKs collide across -//! sources for typical LSM update workloads. - -use std::any::Any; -use std::collections::HashMap; -use std::fmt; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow_array::{Array, RecordBatch, UInt64Array}; -use arrow_schema::SchemaRef; -use datafusion::error::Result as DFResult; -use datafusion::execution::TaskContext; -use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, - SendableRecordBatchStream, -}; -use futures::{Stream, StreamExt, ready}; - -use super::pk::{compute_pk_hash, resolve_pk_indices}; - -/// Cross-source PK dedup. Keeps one row per primary key — the one with -/// the largest `(generation, freshness)` tuple. -/// -/// # Required input columns -/// -/// - `pk_columns` — the primary key columns. -/// - `generation_column` (UInt64, NOT NULL) — typically -/// [`super::MEMTABLE_GEN_COLUMN`]. -/// - `freshness_column` (UInt64, nullable) — typically -/// [`super::FRESHNESS_COLUMN`]. NULL-freshness rows are skipped (they -/// can't be ordered against real values). -/// -/// The output schema is unchanged from the input. Callers that need to -/// drop the generation / freshness columns from the final output should -/// compose this node with a downstream `project_to_canonical`. -#[derive(Debug)] -pub struct LsmGlobalPkDedupExec { - input: Arc, - pk_columns: Vec, - generation_column: String, - freshness_column: String, - schema: SchemaRef, - properties: Arc, -} - -impl LsmGlobalPkDedupExec { - pub fn new( - input: Arc, - pk_columns: Vec, - generation_column: impl Into, - freshness_column: impl Into, - ) -> Self { - let schema = input.schema(); - let properties = Arc::new(PlanProperties::new( - EquivalenceProperties::new(schema.clone()), - Partitioning::UnknownPartitioning(1), - input.pipeline_behavior(), - input.boundedness(), - )); - Self { - input, - pk_columns, - generation_column: generation_column.into(), - freshness_column: freshness_column.into(), - schema, - properties, - } - } - - pub fn pk_columns(&self) -> &[String] { - &self.pk_columns - } - - pub fn generation_column(&self) -> &str { - &self.generation_column - } - - pub fn freshness_column(&self) -> &str { - &self.freshness_column - } -} - -impl DisplayAs for LsmGlobalPkDedupExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default - | DisplayFormatType::Verbose - | DisplayFormatType::TreeRender => { - write!( - f, - "LsmGlobalPkDedupExec: pk=[{}], gen={}, freshness={}", - self.pk_columns.join(", "), - self.generation_column, - self.freshness_column, - ) - } - } - } -} - -impl ExecutionPlan for LsmGlobalPkDedupExec { - fn name(&self) -> &str { - "LsmGlobalPkDedupExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn properties(&self) -> &Arc { - &self.properties - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> DFResult> { - if children.len() != 1 { - return Err(datafusion::error::DataFusionError::Internal( - "LsmGlobalPkDedupExec requires exactly one child".to_string(), - )); - } - Ok(Arc::new(Self::new( - children[0].clone(), - self.pk_columns.clone(), - self.generation_column.clone(), - self.freshness_column.clone(), - ))) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> DFResult { - let input_stream = self.input.execute(partition, context)?; - Ok(Box::pin(GlobalPkDedupStream { - input: input_stream, - pk_columns: self.pk_columns.clone(), - generation_column: self.generation_column.clone(), - freshness_column: self.freshness_column.clone(), - schema: self.schema.clone(), - winners: HashMap::new(), - emitted: false, - })) - } -} - -struct Winner { - batch: RecordBatch, - generation: u64, - freshness: u64, -} - -struct GlobalPkDedupStream { - input: SendableRecordBatchStream, - pk_columns: Vec, - generation_column: String, - freshness_column: String, - schema: SchemaRef, - winners: HashMap, - emitted: bool, -} - -impl GlobalPkDedupStream { - fn consume_batch(&mut self, batch: RecordBatch) -> DFResult<()> { - if batch.num_rows() == 0 { - return Ok(()); - } - let pk_indices = resolve_pk_indices(&batch, &self.pk_columns)?; - let gen_arr = batch - .column_by_name(&self.generation_column) - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Generation column '{}' not found in batch", - self.generation_column - )) - })? - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Generation column '{}' is not UInt64", - self.generation_column - )) - })?; - let fresh_arr = batch - .column_by_name(&self.freshness_column) - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Freshness column '{}' not found in batch", - self.freshness_column - )) - })? - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Freshness column '{}' is not UInt64", - self.freshness_column - )) - })?; - - for row_idx in 0..batch.num_rows() { - if fresh_arr.is_null(row_idx) { - // A NULL freshness can't be ordered against a real value; - // skip rather than guess. Callers tag with a real value - // for every row eligible to win. - continue; - } - let generation = gen_arr.value(row_idx); - let fresh = fresh_arr.value(row_idx); - let pk_hash = compute_pk_hash(&batch, &pk_indices, row_idx); - - let take_row = match self.winners.get(&pk_hash) { - None => true, - Some(existing) => (generation, fresh) > (existing.generation, existing.freshness), - }; - - if take_row { - let single = batch.slice(row_idx, 1); - self.winners.insert( - pk_hash, - Winner { - batch: single, - generation, - freshness: fresh, - }, - ); - } - } - Ok(()) - } - - fn finalize(&mut self) -> DFResult { - if self.winners.is_empty() { - return Ok(RecordBatch::new_empty(self.schema.clone())); - } - let batches: Vec = self.winners.drain().map(|(_, w)| w.batch).collect(); - let batch_refs: Vec<&RecordBatch> = batches.iter().collect(); - arrow_select::concat::concat_batches(&self.schema, batch_refs) - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) - } -} - -impl Stream for GlobalPkDedupStream { - type Item = DFResult; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - if self.emitted { - return Poll::Ready(None); - } - match ready!(self.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - if let Err(e) = self.consume_batch(batch) { - self.emitted = true; - return Poll::Ready(Some(Err(e))); - } - } - Some(Err(e)) => { - self.emitted = true; - return Poll::Ready(Some(Err(e))); - } - None => { - self.emitted = true; - return Poll::Ready(Some(self.finalize())); - } - } - } - } -} - -impl datafusion::physical_plan::RecordBatchStream for GlobalPkDedupStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::Int32Array; - use arrow_schema::{DataType, Field, Schema}; - use datafusion::prelude::SessionContext; - use datafusion_physical_plan::test::TestMemoryExec; - use futures::TryStreamExt; - - fn test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("_memtable_gen", DataType::UInt64, false), - Field::new("_freshness", DataType::UInt64, true), - ])) - } - - fn batch(ids: &[i32], gens: &[u64], fresh: &[Option]) -> RecordBatch { - RecordBatch::try_new( - test_schema(), - vec![ - Arc::new(Int32Array::from(ids.to_vec())), - Arc::new(UInt64Array::from(gens.to_vec())), - Arc::new(UInt64Array::from(fresh.to_vec())), - ], - ) - .unwrap() - } - - async fn run(batches: Vec) -> Vec { - let schema = test_schema(); - let input = TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap(); - let exec = - LsmGlobalPkDedupExec::new(input, vec!["id".to_string()], "_memtable_gen", "_freshness"); - let ctx = SessionContext::new(); - let stream = exec.execute(0, ctx.task_ctx()).unwrap(); - stream.try_collect().await.unwrap() - } - - fn extract(batches: &[RecordBatch]) -> Vec<(i32, u64, Option)> { - let mut rows = Vec::new(); - for b in batches { - let ids = b.column(0).as_any().downcast_ref::().unwrap(); - let gens = b.column(1).as_any().downcast_ref::().unwrap(); - let fresh = b.column(2).as_any().downcast_ref::().unwrap(); - for i in 0..b.num_rows() { - rows.push(( - ids.value(i), - gens.value(i), - if fresh.is_null(i) { - None - } else { - Some(fresh.value(i)) - }, - )); - } - } - rows.sort_by_key(|r| r.0); - rows - } - - #[tokio::test] - async fn keeps_higher_freshness_within_single_generation() { - let b = batch(&[1, 1, 2], &[3, 3, 3], &[Some(10), Some(99), Some(5)]); - let rows = extract(&run(vec![b]).await); - assert_eq!(rows, vec![(1, 3, Some(99)), (2, 3, Some(5))]); - } - - #[tokio::test] - async fn higher_generation_beats_higher_freshness() { - let b = batch(&[1, 1, 2], &[1, 2, 2], &[Some(u64::MAX), Some(0), Some(5)]); - // id=1 in gen=2 with freshness 0 wins over gen=1 with freshness MAX. - let rows = extract(&run(vec![b]).await); - assert_eq!(rows, vec![(1, 2, Some(0)), (2, 2, Some(5))]); - } - - #[tokio::test] - async fn dedup_across_batches() { - let b1 = batch(&[1, 2], &[1, 2], &[Some(5), Some(5)]); - let b2 = batch(&[1, 3], &[3, 1], &[Some(0), Some(1)]); - // id=1: gen=3 wins. id=2: only gen=2 row. id=3: only gen=1 row. - let rows = extract(&run(vec![b1, b2]).await); - assert_eq!( - rows, - vec![(1, 3, Some(0)), (2, 2, Some(5)), (3, 1, Some(1))], - ); - } - - #[tokio::test] - async fn null_freshness_skipped() { - let b = batch(&[1, 1], &[5, 5], &[None, Some(0)]); - // The null-freshness row is dropped; the real one wins by default. - let rows = extract(&run(vec![b]).await); - assert_eq!(rows, vec![(1, 5, Some(0))]); - } - - #[tokio::test] - async fn empty_input() { - let total: usize = run(vec![]).await.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total, 0); - } -} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs index abb2653fa50..523dd30bf82 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs @@ -3,14 +3,20 @@ //! Shared primary-key helpers for the LSM scanner execution nodes. //! -//! Centralizes PK column resolution and per-row hashing so that every dedup -//! node ([`super::WithinSourceDedupExec`] and [`super::LsmGlobalPkDedupExec`]) +//! Centralizes PK column resolution and per-row hashing so that every +//! consumer (e.g. [`super::WithinSourceDedupExec`], [`super::PkHashFilterExec`]) //! resolves and hashes a primary key the same way. The row hash is kept //! consistent with the variants supported by [`super::compute_pk_hash_from_scalars`] //! so a single PK produces the same hash regardless of which exec consumes it. use arrow_array::{Array, RecordBatch}; +use arrow_schema::{DataType, Schema}; +use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result as DFResult}; +use lance_core::{Error, Result}; + +/// Column name for a row address (the in-source row offset). +pub const ROW_ADDRESS_COLUMN: &str = "_rowaddr"; /// Resolve the column index of each primary-key column in `batch`. pub fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResult> { @@ -28,8 +34,60 @@ pub fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResul .collect() } +/// Primary-key column types we can hash exactly in the fast path. +/// +/// Anything else is rejected by [`validate_pk_types`] at the scanner boundary, +/// so the hot hash path never silently collapses distinct keys to one hash +/// (which would over-block in the block-list and drop live rows). +pub fn is_supported_pk_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Boolean + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + ) +} + +/// Validate that every primary-key column has a type we can hash exactly. +/// +/// Rejects unsupported types with a descriptive error at the API boundary +/// rather than degrading to a constant hash. Call this where a scanner that +/// hashes primary keys is built. +pub fn validate_pk_types(schema: &Schema, pk_columns: &[String]) -> Result<()> { + for col in pk_columns { + let field = schema.field_with_name(col).map_err(|_| { + Error::invalid_input(format!("Primary key column '{}' not found in schema", col)) + })?; + if !is_supported_pk_type(field.data_type()) { + return Err(Error::invalid_input(format!( + "Primary key column '{}' has unsupported type {:?} for hashing; supported types: \ + Int8/16/32/64, UInt8/16/32/64, Boolean, Utf8/LargeUtf8, Binary/LargeBinary", + col, + field.data_type() + ))); + } + } + Ok(()) +} + /// Hash a single row's primary key, identified by the `pk_indices` column /// positions and `row_idx`. +/// +/// Must stay byte-for-byte consistent with +/// [`super::compute_pk_hash_from_scalars`] so a single PK hashes the same +/// regardless of which exec consumes it. Supported types are validated up +/// front by [`validate_pk_types`]; the trailing branch is a defensive, +/// value-distinguishing fallback that should be unreachable in validated plans. pub fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -41,18 +99,35 @@ pub fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize is_null.hash(&mut hasher); if !is_null { - if let Some(arr) = col.as_any().downcast_ref::() { + if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { + } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { + } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); } else if let Some(arr) = col.as_any().downcast_ref::() { arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Ok(scalar) = ScalarValue::try_from_array(col.as_ref(), row_idx) { + // Defensive fallback: distinguish by value rather than collapse. + format!("{:?}", scalar).hash(&mut hasher); } } } diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs deleted file mode 100644 index 29eac385381..00000000000 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs +++ /dev/null @@ -1,404 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Per-source tagging for the LSM vector-search dedup pipeline. -//! -//! `LsmSourceTagExec` appends two columns to each row of a per-source scan: -//! - `_memtable_gen` (UInt64): the source's generation number (base = 0, -//! flushed gens 1..N, active memtable = N+1). -//! - `_freshness` (UInt64): a within-source "newness" indicator normalized -//! so that *larger value = newer insert* regardless of which side -//! produced it. The active memtable stores rows in insert order -//! (`_freshness = _rowid`), while flushed memtables are reverse-written -//! (`_freshness = u64::MAX - _rowid`). -//! -//! Together, the two columns let [`super::LsmGlobalPkDedupExec`] decide a -//! winner per primary key via a single lexicographic `(gen, freshness)` -//! comparison across the merged stream — no separate within-source dedup -//! and no bloom-based staleness filtering needed. - -use std::any::Any; -use std::fmt; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow_array::{Array, RecordBatch, UInt64Array}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use datafusion::error::Result as DFResult; -use datafusion::execution::TaskContext; -use datafusion::physical_expr::EquivalenceProperties; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, - SendableRecordBatchStream, -}; -use futures::{Stream, StreamExt}; - -use crate::dataset::mem_wal::scanner::data_source::LsmGeneration; - -use super::generation_tag::MEMTABLE_GEN_COLUMN; - -/// Column name for the normalized within-source freshness. Higher = newer. -pub const FRESHNESS_COLUMN: &str = "_freshness"; - -/// Polarity for translating a source's row-id column into `_freshness`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum FreshnessPolarity { - /// `_freshness = row_id`. Used by sources that store rows in insert - /// order (active memtable; also base table where duplicates aren't - /// expected but the polarity must still be consistent). - InsertOrder, - /// `_freshness = u64::MAX - row_id`. Used by flushed memtables, which - /// are reverse-written so a smaller `_rowid` is the newer insert. - ReverseWrite, -} - -/// Tag every row of a per-source scan with `_memtable_gen` + `_freshness`. -/// -/// # Required input columns -/// -/// - `row_id_column` (UInt64) — typically `_rowid`. Must be present on -/// every row; NULLs are propagated as NULL `_freshness` and will be -/// skipped by the downstream dedup. -/// -/// # Output schema -/// -/// Input schema + `_memtable_gen` (UInt64, NOT NULL) + `_freshness` -/// (UInt64, nullable to mirror the source's `_rowid` nullability). -#[derive(Debug)] -pub struct LsmSourceTagExec { - input: Arc, - generation: LsmGeneration, - polarity: FreshnessPolarity, - row_id_column: String, - schema: SchemaRef, - properties: Arc, -} - -impl LsmSourceTagExec { - pub fn new( - input: Arc, - generation: LsmGeneration, - polarity: FreshnessPolarity, - row_id_column: impl Into, - ) -> Self { - let input_schema = input.schema(); - let mut fields: Vec> = input_schema.fields().iter().cloned().collect(); - fields.push(Arc::new(Field::new( - MEMTABLE_GEN_COLUMN, - DataType::UInt64, - false, - ))); - fields.push(Arc::new(Field::new( - FRESHNESS_COLUMN, - DataType::UInt64, - true, - ))); - let schema = Arc::new(Schema::new(fields)); - - let properties = Arc::new(PlanProperties::new( - EquivalenceProperties::new(schema.clone()), - input.output_partitioning().clone(), - input.pipeline_behavior(), - input.boundedness(), - )); - - Self { - input, - generation, - polarity, - row_id_column: row_id_column.into(), - schema, - properties, - } - } - - pub fn generation(&self) -> LsmGeneration { - self.generation - } - - pub fn polarity(&self) -> FreshnessPolarity { - self.polarity - } - - pub fn row_id_column(&self) -> &str { - &self.row_id_column - } -} - -impl DisplayAs for LsmSourceTagExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default - | DisplayFormatType::Verbose - | DisplayFormatType::TreeRender => { - write!( - f, - "LsmSourceTagExec: gen={}, polarity={:?}, row_id_col={}", - self.generation, self.polarity, self.row_id_column, - ) - } - } - } -} - -impl ExecutionPlan for LsmSourceTagExec { - fn name(&self) -> &str { - "LsmSourceTagExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn properties(&self) -> &Arc { - &self.properties - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> DFResult> { - if children.len() != 1 { - return Err(datafusion::error::DataFusionError::Internal( - "LsmSourceTagExec requires exactly one child".to_string(), - )); - } - Ok(Arc::new(Self::new( - children[0].clone(), - self.generation, - self.polarity, - self.row_id_column.clone(), - ))) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> DFResult { - let input_stream = self.input.execute(partition, context)?; - Ok(Box::pin(SourceTagStream { - input: input_stream, - generation: self.generation.as_u64(), - polarity: self.polarity, - row_id_column: self.row_id_column.clone(), - schema: self.schema.clone(), - })) - } -} - -struct SourceTagStream { - input: SendableRecordBatchStream, - generation: u64, - polarity: FreshnessPolarity, - row_id_column: String, - schema: SchemaRef, -} - -impl SourceTagStream { - fn tag_batch(&self, batch: RecordBatch) -> DFResult { - let num_rows = batch.num_rows(); - let gen_col: Arc = Arc::new(UInt64Array::from(vec![self.generation; num_rows])); - - let row_id_arr = batch - .column_by_name(&self.row_id_column) - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Row id column '{}' not found in batch — LsmSourceTagExec needs the per-source row id to derive _freshness", - self.row_id_column - )) - })? - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Row id column '{}' is not UInt64", - self.row_id_column - )) - })?; - - let freshness: Arc = match self.polarity { - FreshnessPolarity::InsertOrder => Arc::new(row_id_arr.clone()), - FreshnessPolarity::ReverseWrite => { - let mut builder = arrow_array::builder::UInt64Builder::with_capacity(num_rows); - for i in 0..num_rows { - if row_id_arr.is_null(i) { - builder.append_null(); - } else { - builder.append_value(u64::MAX - row_id_arr.value(i)); - } - } - Arc::new(builder.finish()) - } - }; - - let mut columns: Vec> = batch.columns().to_vec(); - columns.push(gen_col); - columns.push(freshness); - - RecordBatch::try_new(self.schema.clone(), columns) - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) - } -} - -impl Stream for SourceTagStream { - type Item = DFResult; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.input.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - let tagged = self.tag_batch(batch); - Poll::Ready(Some(tagged)) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl datafusion::physical_plan::RecordBatchStream for SourceTagStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::Int32Array; - use datafusion::prelude::SessionContext; - use datafusion_physical_plan::test::TestMemoryExec; - use futures::TryStreamExt; - - fn input_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("_rowid", DataType::UInt64, true), - ])) - } - - fn batch(ids: &[i32], row_ids: &[Option]) -> RecordBatch { - RecordBatch::try_new( - input_schema(), - vec![ - Arc::new(Int32Array::from(ids.to_vec())), - Arc::new(UInt64Array::from(row_ids.to_vec())), - ], - ) - .unwrap() - } - - async fn run( - b: RecordBatch, - generation: LsmGeneration, - polarity: FreshnessPolarity, - ) -> Vec { - let schema = b.schema(); - let input = TestMemoryExec::try_new_exec(&[vec![b]], schema, None).unwrap(); - let exec = LsmSourceTagExec::new(input, generation, polarity, "_rowid"); - let ctx = SessionContext::new(); - let stream = exec.execute(0, ctx.task_ctx()).unwrap(); - stream.try_collect().await.unwrap() - } - - fn columns(batches: &[RecordBatch]) -> (Vec, Vec>) { - let mut gens = Vec::new(); - let mut fresh = Vec::new(); - for b in batches { - let g = b - .column_by_name(MEMTABLE_GEN_COLUMN) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let f = b - .column_by_name(FRESHNESS_COLUMN) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..b.num_rows() { - gens.push(g.value(i)); - fresh.push(if f.is_null(i) { None } else { Some(f.value(i)) }); - } - } - (gens, fresh) - } - - #[tokio::test] - async fn insert_order_passes_row_id_through() { - let b = batch(&[1, 2, 3], &[Some(0), Some(5), Some(99)]); - let out = run( - b, - LsmGeneration::memtable(7), - FreshnessPolarity::InsertOrder, - ) - .await; - let (gens, fresh) = columns(&out); - assert_eq!(gens, vec![7, 7, 7]); - assert_eq!(fresh, vec![Some(0), Some(5), Some(99)]); - } - - #[tokio::test] - async fn reverse_write_flips_row_id() { - let b = batch(&[1, 2, 3], &[Some(0), Some(5), Some(99)]); - let out = run( - b, - LsmGeneration::memtable(2), - FreshnessPolarity::ReverseWrite, - ) - .await; - let (gens, fresh) = columns(&out); - assert_eq!(gens, vec![2, 2, 2]); - // Under reverse-write, smaller row_id = newer ⇒ larger _freshness. - assert_eq!( - fresh, - vec![Some(u64::MAX), Some(u64::MAX - 5), Some(u64::MAX - 99)], - ); - } - - #[tokio::test] - async fn null_row_id_yields_null_freshness() { - let b = batch(&[1, 2], &[None, Some(3)]); - let out = run( - b, - LsmGeneration::memtable(1), - FreshnessPolarity::ReverseWrite, - ) - .await; - let (_, fresh) = columns(&out); - assert_eq!(fresh, vec![None, Some(u64::MAX - 3)]); - } - - #[tokio::test] - async fn base_table_generation_is_zero() { - let b = batch(&[1], &[Some(0)]); - let out = run(b, LsmGeneration::BASE_TABLE, FreshnessPolarity::InsertOrder).await; - let (gens, _) = columns(&out); - assert_eq!(gens, vec![0]); - } - - #[tokio::test] - async fn empty_batch_passthrough() { - let schema = input_schema(); - let empty = RecordBatch::new_empty(schema); - let out = run( - empty, - LsmGeneration::memtable(1), - FreshnessPolarity::InsertOrder, - ) - .await; - let total: usize = out.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total, 0); - } -} diff --git a/rust/lance/src/dataset/mem_wal/scanner/planner.rs b/rust/lance/src/dataset/mem_wal/scanner/planner.rs index 98daf2bdda9..76013a99a5a 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/planner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/planner.rs @@ -5,20 +5,17 @@ use std::sync::Arc; -use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; -use datafusion::physical_expr::expressions::Column; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::union::UnionExec; use datafusion::physical_plan::{ExecutionPlan, limit::GlobalLimitExec}; use datafusion::prelude::Expr; -use lance_core::{Result, is_system_column}; +use lance_core::Result; use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; -use super::exec::{DeduplicateExec, MEMTABLE_GEN_COLUMN, MemtableGenTagExec, ROW_ADDRESS_COLUMN}; +use super::exec::{MEMTABLE_GEN_COLUMN, MemtableGenTagExec, PkHashFilterExec, ROW_ADDRESS_COLUMN}; use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical, @@ -80,27 +77,16 @@ impl LsmScanPlanner { /// * `with_memtable_gen` - Whether to include _memtable_gen in output /// * `keep_row_address` - Whether to include _rowaddr in output /// - /// # Query Plan Optimization + /// # Query plan /// - /// The planner uses an optimized execution strategy: - /// 1. Each data source is scanned and locally sorted by (pk ASC, _rowaddr DESC) - /// 2. Sources are ordered by _memtable_gen DESC (newest first) in the UnionExec - /// 3. K pre-sorted streams are merged using SortPreservingMergeExec - /// 4. DeduplicateExec performs streaming deduplication on the merged output - /// - /// Key insight: DataFusion's SortPreservingMergeExec uses stream index as a - /// tiebreaker when sort keys are equal. By ordering inputs with highest _memtable_gen - /// first (lowest stream index), the merge naturally prefers newer rows. - /// - /// This avoids needing a `_memtable_gen` column entirely - generation ordering is implicit - /// in the stream ordering. The `_memtable_gen` column is only added (via MemtableGenTagExec) - /// when `with_memtable_gen=true`. - /// - /// This is more efficient than the naive approach of Union + global Sort because: - /// - Local sorts are smaller and can often fit in memory - /// - SortPreservingMergeExec is O(N log K) where K is the number of sources - /// - Memory usage is bounded by the sum of K sort buffers rather than all data - /// - No extra column for _memtable_gen in the common case + /// Each source is independently newest-per-PK (active via the fused + /// [`MemTableDedupScanExec`](super::super::memtable::scanner), flushed via + /// its within-generation deletion vector) and a cross-generation block-list + /// ([`PkHashFilterExec`]) drops any PK superseded by a newer generation. + /// Each PK therefore survives in exactly one source, so a plain + /// `UnionExec` carries at most one row per PK — no cross-source dedup, + /// sort, or merge needed. `_memtable_gen` / `_rowaddr` are output-only and + /// only produced when the caller opts in. #[instrument(name = "lsm_plan_scan", level = "debug", skip_all, fields(has_filter = filter.is_some(), limit, offset))] pub async fn plan_scan( &self, @@ -128,104 +114,74 @@ impl LsmScanPlanner { return self.empty_plan(projection, with_memtable_gen, keep_row_address); } - // 2. Build scan plan for each source with local sorting - // Order of operations: scan -> local sort -> (optional) tag with generation - // - // IMPORTANT: Sources are collected in generation order (base=0, then memtables 1,2,3...) - // We reverse this to get _memtable_gen DESC order for the merge tiebreaker. + // Cross-generation block-list keyed by source: a hit drops any row + // whose PK lives in a newer generation, applied before the union. + // `Box::pin` keeps the future off `clippy::large_futures`. + let block_lists = Box::pin(super::block_list::compute_source_block_lists( + &sources, + &self.pk_columns, + self.session.as_ref(), + self.flushed_cache.as_ref(), + )) + .await?; + + // Reverse so the union lists the newest generation first. This is + // cosmetic — correctness comes from the per-source dedup and the + // cross-gen block-list, not from output ordering. let sources: Vec<_> = sources.into_iter().rev().collect(); - let mut sorted_plans = Vec::new(); + let mut source_plans = Vec::new(); for source in sources { let is_base = matches!(source, LsmDataSource::BaseTable { .. }); let scan = self.build_source_scan(&source, projection, filter).await?; - // Sort locally by (pk ASC, _rowaddr DESC) - let local_sort_exprs = self.build_local_sort_exprs(&scan)?; - let lex_ordering = LexOrdering::new(local_sort_exprs).ok_or_else(|| { - lance_core::Error::internal( - "Failed to create LexOrdering from sort expressions".to_string(), - ) - })?; - let sorted: Arc = Arc::new(SortExec::new(lex_ordering, scan)); - - // When `_rowaddr` will be surfaced to the caller, NULL it for - // non-base arms post-sort: only base values are meaningful (e.g. - // for `take_rows`); other arms carry per-source addresses that - // collide with base IDs. The schema is preserved so union/dedup - // still match (dedup picks rows by upstream order, not value). - // Skipped when `_rowaddr` would be stripped by dedup anyway, to - // avoid adding a no-op projection to the plan. - let after_sort: Arc = if !is_base && keep_row_address { - null_columns(sorted, &[ROW_ADDRESS_COLUMN])? + // Drop cross-generation stale rows (PKs superseded by a newer gen). + // `k = 0`: there is no top-k, so the under-fetch warning never fires. + let scan = match block_lists.get(&(source.shard_id(), source.generation())) { + Some(set) => Arc::new(PkHashFilterExec::new( + scan, + self.pk_columns.clone(), + set.clone(), + 0, + )) as Arc, + None => scan, + }; + + // When `_rowaddr` is surfaced, NULL it for non-base arms: only base + // values are meaningful (e.g. for `take_rows`); per-source addresses + // collide with base IDs. + let scan: Arc = if !is_base && keep_row_address { + null_columns(scan, &[ROW_ADDRESS_COLUMN])? } else { - sorted + scan }; - // Only tag with generation if user wants _memtable_gen in output + // Tag with generation only if the caller wants `_memtable_gen`. let plan: Arc = if with_memtable_gen { - Arc::new(MemtableGenTagExec::new(after_sort, source.generation())) + Arc::new(MemtableGenTagExec::new(scan, source.generation())) } else { - after_sort + scan }; - sorted_plans.push(plan); + source_plans.push(plan); } - // 3. Merge pre-sorted streams - // Merge using (pk ASC) only - NOT _rowaddr, because _rowaddr is different across tables - // for the same pk, which would break the stream index tiebreaker. - // - // DataFusion's SortPreservingMergeExec uses stream index as a tiebreaker when - // sort keys are equal (see merge.rs line 349: `ac.cmp(bc).then_with(|| a.cmp(&b))`). - // By ordering inputs with highest _memtable_gen first (lowest stream index), the merge - // naturally prefers newer rows when PKs are equal. - // - // Local sort uses (pk ASC, _rowaddr DESC) to order within each source, but the merge - // only considers pk for comparison. This ensures: - // 1. For the same pk, newer generation (lower stream index) comes first - // 2. Within the same pk and generation, higher _rowaddr comes first - let merged: Arc = if sorted_plans.len() == 1 { - sorted_plans.remove(0) + // Union, then coalesce into a single partition (UnionExec emits one + // per arm; downstream consumers only read partition 0). + let mut plan: Arc = if source_plans.len() == 1 { + source_plans.remove(0) } else { - // Use SortPreservingMergeExec to merge K pre-sorted streams - // IMPORTANT: Only merge by pk columns, not _rowaddr! - let merge_sort_exprs = self.build_merge_sort_exprs(&sorted_plans[0])?; - let lex_ordering = LexOrdering::new(merge_sort_exprs).ok_or_else(|| { - lance_core::Error::internal( - "Failed to create LexOrdering from sort expressions".to_string(), - ) - })?; - - // UnionExec to combine all partitions (ordered by _memtable_gen DESC) #[allow(deprecated)] - let union = Arc::new(UnionExec::new(sorted_plans)); - - // SortPreservingMergeExec merges pre-sorted partitions - Arc::new(SortPreservingMergeExec::new(lex_ordering, union)) + let union = Arc::new(UnionExec::new(source_plans)); + Arc::new(CoalescePartitionsExec::new(union)) }; - // 4. Add deduplication (input is already sorted by pk, newer rows first) - let dedup = DeduplicateExec::new_sorted( - merged, - self.pk_columns.clone(), - with_memtable_gen, - keep_row_address, + // Project to the canonical output schema, dropping `_rowaddr` / + // `_memtable_gen` unless the caller opted in. + plan = project_to_canonical( + plan, + &self.canonical_scan_schema(projection, with_memtable_gen, keep_row_address), )?; - let mut plan: Arc = Arc::new(dedup); - - // 5. Surface user-requested system columns at the requested position. - // Skipped otherwise so the plan shape stays unchanged for callers - // that don't opt in. - let user_wants_system = projection - .map(|p| p.iter().any(|c| is_system_column(c))) - .unwrap_or(false); - if user_wants_system { - plan = project_to_canonical( - plan, - &self.canonical_scan_schema(projection, with_memtable_gen, keep_row_address), - )?; - } // 6. Add limit if specified if let Some(limit) = limit { @@ -267,84 +223,6 @@ impl LsmScanPlanner { Arc::new(Schema::new(fields)) } - /// Build sort expressions for local sorting within a single source. - /// - /// Sort order: (pk_columns ASC, _rowaddr DESC) - /// Note: _memtable_gen is not included because it's constant within each source. - fn build_local_sort_exprs( - &self, - plan: &Arc, - ) -> Result> { - let schema = plan.schema(); - let mut sort_exprs = Vec::new(); - - // Sort by PK columns (ASC) to group duplicates together - for col in &self.pk_columns { - let (idx, _) = schema.column_with_name(col).ok_or_else(|| { - lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col)) - })?; - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(col, idx)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }); - } - - // Sort by _rowaddr DESC (higher address = newer within generation) - let (addr_idx, _) = schema.column_with_name(ROW_ADDRESS_COLUMN).ok_or_else(|| { - lance_core::Error::invalid_input(format!( - "Column '{}' not found in schema", - ROW_ADDRESS_COLUMN - )) - })?; - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(ROW_ADDRESS_COLUMN, addr_idx)), - options: SortOptions { - descending: true, - nulls_first: false, - }, - }); - - Ok(sort_exprs) - } - - /// Build sort expressions for merging streams. - /// - /// Sort order: (pk_columns ASC) only - /// - /// IMPORTANT: This does NOT include _rowaddr because _rowaddr values are different - /// across different tables for the same pk. Including _rowaddr would break the - /// stream index tiebreaker mechanism that ensures newer generations win. - /// - /// When pk is equal across streams, SortPreservingMergeExec uses stream index as - /// tiebreaker (lower index wins). Since streams are ordered by generation DESC - /// (newest first), this ensures newer rows come before older rows for the same pk. - fn build_merge_sort_exprs( - &self, - plan: &Arc, - ) -> Result> { - let schema = plan.schema(); - let mut sort_exprs = Vec::new(); - - // Sort by PK columns (ASC) only - NOT _rowaddr! - for col in &self.pk_columns { - let (idx, _) = schema.column_with_name(col).ok_or_else(|| { - lance_core::Error::invalid_input(format!("Column '{}' not found in schema", col)) - })?; - sort_exprs.push(PhysicalSortExpr { - expr: Arc::new(Column::new(col, idx)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }); - } - - Ok(sort_exprs) - } - /// Build scan plan for a single data source. async fn build_source_scan( &self, @@ -405,12 +283,14 @@ impl LsmScanPlanner { scanner.project(&cols.iter().map(|s| s.as_str()).collect::>()); scanner.with_row_address(); - // Apply filter - enables BTree index optimization for MemTable + // The dedup scan applies the filter post-dedup; pushing it + // into the raw scan would resurrect older versions of PKs + // whose newest version fails the predicate. if let Some(expr) = filter { scanner.filter_expr(expr.clone()); } - scanner.create_plan().await + scanner.create_dedup_plan(&self.pk_columns).await } } } @@ -490,7 +370,7 @@ mod integration_tests { use std::collections::HashMap; use std::sync::Arc; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_array::{Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use futures::TryStreamExt; use uuid::Uuid; @@ -632,24 +512,22 @@ mod integration_tests { let plan = scanner.create_plan().await.unwrap(); - // Verify plan structure showing all levels (gen DESC order: active -> gen2 -> gen1 -> base): - // - DeduplicateExec at top (with_memtable_gen=false means no MemtableGenTagExec) - // - SortPreservingMergeExec merging by pk only (enables stream index tiebreaker) - // - UnionExec combining 4 sorted streams - // - Each stream: SortExec -> MemTableScanExec or LanceRead + // Verify the plan (gen DESC order: active -> gen2 -> gen1 -> base): + // - plain UnionExec at top + // - active arm: MemTableDedupScanExec (newest gen, not block-listed) + // - older arms: PkHashFilterExec (cross-gen block-list) -> LanceRead assert_plan_node_equals( plan, - "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true - SortPreservingMergeExec: [id@0 ASC NULLS LAST] + "ProjectionExec:... + CoalescePartitionsExec UnionExec - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_2... - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_1... - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...base/data...refine_filter=--", + MemTableDedupScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_2... + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_1... + PkHashFilterExec: pk_cols=[id]... + LanceRead:...base/data...refine_filter=--", ) .await .unwrap(); @@ -669,32 +547,27 @@ mod integration_tests { let plan = scanner.create_plan().await.unwrap(); - // Verify plan structure with MemtableGenTagExec at each level (gen DESC order): - // - DeduplicateExec at top (with_memtable_gen=true) - // - SortPreservingMergeExec merging by pk only - // - UnionExec combining 4 streams - // - Each stream: MemtableGenTagExec -> SortExec -> data source - // - gen3 (active): MemtableGenTagExec: gen=gen3 -> MemTableScanExec - // - gen2 (flushed): MemtableGenTagExec: gen=gen2 -> LanceRead - // - gen1 (flushed): MemtableGenTagExec: gen=gen1 -> LanceRead - // - base: MemtableGenTagExec: gen=base -> LanceRead + // Verify the plan with `_memtable_gen` tags (gen DESC order): + // - plain UnionExec at top + // - each arm: MemtableGenTagExec -> (PkHashFilterExec ->) data source + // - gen3 (active): MemtableGenTagExec -> MemTableDedupScanExec + // - gen2/gen1/base: MemtableGenTagExec -> PkHashFilterExec -> LanceRead assert_plan_node_equals( plan, - "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=false, input_sorted=true - SortPreservingMergeExec: [id@0 ASC NULLS LAST] + "ProjectionExec:... + CoalescePartitionsExec UnionExec - MemtableGenTagExec: gen=gen3 - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true - MemtableGenTagExec: gen=gen2 - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_2... - MemtableGenTagExec: gen=gen1 - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_1... - MemtableGenTagExec: gen=base - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...base/data...refine_filter=--", + MemtableGenTagExec: gen=gen3 + MemTableDedupScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true + MemtableGenTagExec: gen=gen2 + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_2... + MemtableGenTagExec: gen=gen1 + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_1... + MemtableGenTagExec: gen=base + PkHashFilterExec: pk_cols=[id]... + LanceRead:...base/data...refine_filter=--", ) .await .unwrap(); @@ -760,6 +633,65 @@ mod integration_tests { assert_eq!(results.get(&7), Some(&"active_7".to_string())); } + /// The filtered-read plan applies the cross-generation block-list (older + /// generations whose PKs are superseded by a newer one are filtered), while + /// results stay newest-per-PK. + #[tokio::test] + async fn test_lsm_scan_filtered_read_applies_block_list() { + let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) = + setup_multi_level_lsm().await; + + let mut scanner = LsmScanner::new(base_dataset, shard_snapshots, pk_columns); + if let Some((shard_id, memtable)) = active_memtable { + scanner = scanner.with_in_memory_memtables(shard_id, memtable); + } + + // base/gen1/gen2 all hold PKs superseded by a newer generation, so each + // is wrapped in a `PkHashFilterExec`; the newest (active) arm is not. + let plan = scanner.create_plan().await.unwrap(); + let plan_str = format!( + "{}", + datafusion::physical_plan::displayable(plan.as_ref()).indent(true) + ); + assert!( + plan_str.contains("PkHashFilterExec"), + "filtered-read plan must apply the cross-gen block-list, got:\n{}", + plan_str + ); + + // Results stay correct (newest-per-PK across generations). + let batches: Vec = scanner + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + let mut results: HashMap = HashMap::new(); + for batch in batches { + let ids = batch + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = batch + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + results.insert(ids.value(i), names.value(i).to_string()); + } + } + assert_eq!(results.len(), 7); + assert_eq!(results.get(&3), Some(&"gen1_3".to_string())); + assert_eq!(results.get(&4), Some(&"gen2_4".to_string())); + assert_eq!(results.get(&5), Some(&"active_5".to_string())); + assert_eq!(results.get(&6), Some(&"active_6".to_string())); + } + /// Regression for the concurrent-read-vs-flush hole: a sealed /// (frozen-awaiting-flush) memtable is not yet recorded as a flushed /// generation, but its rows must still be in the scan's read union and @@ -928,16 +860,12 @@ mod integration_tests { let plan = scanner.create_plan().await.unwrap(); - // With only one source, should skip UnionExec and SortPreservingMergeExec - // Plan structure: - // - DeduplicateExec at top - // - SortExec (no merge needed) - // - LanceRead for base table only + // A single source collapses to just its scan: no union, no block-list + // (nothing supersedes the base), no dedup. assert_plan_node_equals( plan, - "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=false, input_sorted=true - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...base/data...refine_filter=--", + "ProjectionExec:... + LanceRead:...base/data...refine_filter=--", ) .await .unwrap(); @@ -1030,25 +958,24 @@ mod integration_tests { let plan = scanner.create_plan().await.unwrap(); // Verify plan with keep_addr=true (no _memtable_gen, so no MemtableGenTagExec). - // Non-base arms wrap their SortExec in a ProjectionExec that NULLs - // `_rowaddr` post-sort: per-source addresses are not meaningful to - // the caller. The base arm leaves `_rowaddr` real. + // Non-base arms wrap their scan in a ProjectionExec that NULLs `_rowaddr`: + // per-source addresses are not meaningful to the caller. The base arm + // leaves `_rowaddr` real. Older generations are block-list filtered. assert_plan_node_equals( plan, - "DeduplicateExec: pk=[id], with_memtable_gen=false, keep_addr=true, input_sorted=true - SortPreservingMergeExec: [id@0 ASC NULLS LAST] + "ProjectionExec:... + CoalescePartitionsExec UnionExec - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_2... - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_1... - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...base/data...refine_filter=--", + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + MemTableDedupScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_2... + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_1... + PkHashFilterExec: pk_cols=[id]... + LanceRead:...base/data...refine_filter=--", ) .await .unwrap(); @@ -1102,24 +1029,23 @@ mod integration_tests { // `take_rows`). MemtableGenTagExec sits above the NULL projection. assert_plan_node_equals( plan, - "DeduplicateExec: pk=[id], with_memtable_gen=true, keep_addr=true, input_sorted=true - SortPreservingMergeExec: [id@0 ASC NULLS LAST] + "ProjectionExec:... + CoalescePartitionsExec UnionExec - MemtableGenTagExec: gen=gen3 - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true - MemtableGenTagExec: gen=gen2 - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_2... - MemtableGenTagExec: gen=gen1 - ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...gen_1... - MemtableGenTagExec: gen=base - SortExec: expr=[id@0 ASC NULLS LAST, _rowaddr@2 DESC NULLS LAST]... - LanceRead:...base/data...refine_filter=--", + MemtableGenTagExec: gen=gen3 + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + MemTableDedupScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true + MemtableGenTagExec: gen=gen2 + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_2... + MemtableGenTagExec: gen=gen1 + ProjectionExec: expr=[id@0 as id, name@1 as name, NULL as _rowaddr] + PkHashFilterExec: pk_cols=[id]... + LanceRead:...gen_1... + MemtableGenTagExec: gen=base + PkHashFilterExec: pk_cols=[id]... + LanceRead:...base/data...refine_filter=--", ) .await .unwrap(); @@ -1241,20 +1167,27 @@ mod integration_tests { let plan_str = format!("{}", displayable(plan.as_ref()).indent(true)); // 1. Verify overall structure + assert!(plan_str.contains("UnionExec"), "Should have UnionExec"); assert!( - plan_str.contains("DeduplicateExec: pk=[id]"), - "Should have DeduplicateExec at top" + plan_str.contains("PkHashFilterExec"), + "older generations should be block-list filtered" ); assert!( - plan_str.contains("SortPreservingMergeExec"), - "Should use SortPreservingMergeExec for merging" + !plan_str.contains("DeduplicateExec"), + "filtered read must not use a cross-source DeduplicateExec" ); - assert!(plan_str.contains("UnionExec"), "Should have UnionExec"); - // 2. Verify BTree index optimization for active memtable + // 2. The active arm uses the fused dedup scan: it deduplicates to + // newest-per-PK *before* applying the predicate, so it deliberately + // forgoes the in-memory BTree skip (the dedup must see every + // version). See MemTableDedupScanExec. assert!( - plan_str.contains("BTreeIndexExec: predicate=Eq"), - "Active memtable should use BTreeIndexExec instead of MemTableScanExec" + plan_str.contains("MemTableDedupScanExec"), + "Active memtable should use the fused dedup scan" + ); + assert!( + !plan_str.contains("BTreeIndexExec"), + "Active filtered read no longer uses the BTree skip" ); // 3. Verify filter pushdown to flushed and base datasets @@ -1359,6 +1292,96 @@ mod integration_tests { assert_eq!(results.get(&3), Some(&"gen1_3".to_string())); } + /// End-to-end regression for the active within-generation phantom: a PK + /// inserted then updated in one memtable so its newest version fails the + /// predicate must NOT leak the older version that still passes. + #[tokio::test] + async fn test_lsm_scan_active_within_gen_phantom_suppressed() { + let schema = create_pk_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_path = temp_dir.path().to_str().unwrap(); + + // Base has an unrelated matching row, to prove real matches survive. + let base_uri = format!("{}/base", base_path); + let base_dataset = Arc::new( + create_dataset(&base_uri, vec![create_test_batch(&schema, &[1], "base")]).await, + ); + + let shard_id = Uuid::new_v4(); + let shard_snapshot = ShardSnapshot::new(shard_id).with_current_generation(1); + + // Active memtable: id=10 inserted ("keep") then updated to NULL within + // the same generation; id=20 ("active_20") is a control that matches. + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let active_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![10, 20, 10])), + Arc::new(StringArray::from(vec![ + Some("keep"), + Some("active_20"), + None, + ])), + ], + ) + .unwrap(); + batch_store.append(active_batch).unwrap(); + + let in_memory = InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: Arc::new(IndexStore::new()), + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }; + + let scanner = LsmScanner::new(base_dataset, vec![shard_snapshot], vec!["id".to_string()]) + .filter("name IS NOT NULL") + .unwrap() + .with_in_memory_memtables(shard_id, in_memory); + + let batches: Vec = scanner + .try_into_stream() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let mut results: HashMap> = HashMap::new(); + for batch in batches { + let ids = batch + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let names = batch + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + let name = (!names.is_null(i)).then(|| names.value(i).to_string()); + results.insert(ids.value(i), name); + } + } + + // id=10's newest version is NULL, so it must be absent. Pre-fix the + // predicate dropped the NULL before dedup and the stale "keep" leaked. + assert!( + !results.contains_key(&10), + "id=10 newest is NULL; stale 'keep' must not leak under name IS NOT NULL, got {:?}", + results + ); + assert_eq!(results.get(&1), Some(&Some("base_1".to_string()))); + assert_eq!(results.get(&20), Some(&Some("active_20".to_string()))); + assert_eq!(results.len(), 2); + } + #[tokio::test] async fn test_lsm_scan_without_base_table() { let (base_dataset, shard_snapshots, active_memtable, pk_columns, _temp_path) = @@ -1393,7 +1416,7 @@ mod integration_tests { plan_str ); assert!( - plan_str.contains("MemTableScanExec"), + plan_str.contains("MemTableDedupScanExec"), "Plan must scan the active memtable, got: {}", plan_str ); diff --git a/rust/lance/src/dataset/mem_wal/scanner/projection.rs b/rust/lance/src/dataset/mem_wal/scanner/projection.rs index 00c05056a18..20d1b1a403d 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/projection.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/projection.rs @@ -125,36 +125,6 @@ pub fn canonical_output_schema( Arc::new(Schema::new(fields)) } -/// Like [`canonical_output_schema`] but with the internal LSM bookkeeping -/// columns appended: `_memtable_gen` (UInt64, NOT NULL) and `_freshness` -/// (UInt64, nullable). Used by the vector-search pipeline to carry source -/// identity + per-source row order through the union and the global -/// dedup; both columns are dropped by a downstream `project_to_canonical` -/// before returning to the caller. -pub fn canonical_internal_schema( - user_projection: Option<&[String]>, - base_schema: &SchemaRef, - pk_columns: &[String], - include_distance: bool, -) -> SchemaRef { - use crate::dataset::mem_wal::scanner::exec::{FRESHNESS_COLUMN, MEMTABLE_GEN_COLUMN}; - - let canonical = - canonical_output_schema(user_projection, base_schema, pk_columns, include_distance); - let mut fields: Vec> = canonical.fields().iter().cloned().collect(); - fields.push(Arc::new(Field::new( - MEMTABLE_GEN_COLUMN, - DataType::UInt64, - false, - ))); - fields.push(Arc::new(Field::new( - FRESHNESS_COLUMN, - DataType::UInt64, - true, - ))); - Arc::new(Schema::new(fields)) -} - /// Wrap `plan` so the named columns become typed NULL literals; all /// other columns are forwarded unchanged. Schema is preserved (same /// fields, same dtypes). Useful for stripping the *value* of an diff --git a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs index 820988536aa..b6b1f952b25 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs @@ -15,7 +15,6 @@ use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion::physical_plan::ExecutionPlan; #[allow(deprecated)] use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; @@ -28,22 +27,23 @@ use crate::io::exec::TakeExec; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; -use super::exec::{FreshnessPolarity, LsmGlobalPkDedupExec, LsmSourceTagExec}; +use super::exec::{DedupDirection, WithinSourceDedupExec}; use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ - DISTANCE_COLUMN, build_scanner_projection, canonical_internal_schema, canonical_output_schema, - null_columns, project_to_canonical, wants_row_id, + DISTANCE_COLUMN, build_scanner_projection, canonical_output_schema, null_columns, + project_to_canonical, wants_row_id, }; use crate::session::Session; /// Plans vector search queries over LSM data. /// -/// Each source independently runs KNN, then results are unioned and run -/// through a single global PK dedup that picks the row with the largest -/// `(generation, freshness)` tuple per primary key. Generation is the -/// source identity (base = 0, memtable gens 1..N, active = N+1) and -/// freshness is the per-source row order normalized so larger = newer -/// (see [`LsmSourceTagExec`]). +/// Each source is independently newest-per-PK before the union — the active +/// memtable via an over-fetched KNN + within-source dedup, flushed generations +/// via their within-generation deletion vector — and the cross-generation +/// block-list ([`super::exec::PkHashFilterExec`]) drops any PK superseded by a +/// newer generation. So each PK reaches the union from exactly one source and a +/// distance-ordered merge yields the global top-k; no cross-source dedup is +/// needed. /// /// # Query Plan Structure /// @@ -51,22 +51,19 @@ use crate::session::Session; /// TakeExec (optional: fetch user-projected cols from base dataset) /// SortPreservingMergeExec: order_by=[_distance ASC], fetch=k /// SortExec: order_by=[_distance ASC], fetch=k (per partition, parallel) -/// ProjectionExec (drops _memtable_gen, _freshness) -/// LsmGlobalPkDedupExec: pk=[…], gen=_memtable_gen, freshness=_freshness -/// CoalescePartitionsExec -/// UnionExec -/// ProjectionExec (canonical internal schema) -/// ProjectionExec (null_columns _rowid) (non-base only) -/// LsmSourceTagExec: gen=N+1, polarity=InsertOrder (active) -/// KNNExec: active memtable, k=k -/// ProjectionExec (canonical internal schema) -/// ProjectionExec (null_columns _rowid) -/// LsmSourceTagExec: gen=N, polarity=ReverseWrite (flushed) -/// KNNExec: flushed gen N, k=k (fast_search) -/// … one per flushed gen … -/// ProjectionExec (canonical internal schema) -/// LsmSourceTagExec: gen=0, polarity=InsertOrder (base) -/// KNNExec: base table, k=k (fast_search)[.refine()?] +/// UnionExec +/// ProjectionExec (canonical output schema) +/// SortExec(_distance, fetch=k) +/// WithinSourceDedupExec: KeepMaxRowAddr (active) +/// KNNExec: active memtable, fetch=ceil(k*overfetch) +/// ProjectionExec (canonical output schema) +/// ProjectionExec (null_columns _rowid) +/// PkHashFilterExec: block-list (flushed) +/// KNNExec: flushed gen N, fetch=ceil(k*overfetch) (fast_search) +/// … one per flushed gen … +/// ProjectionExec (canonical output schema) +/// PkHashFilterExec: block-list (base) +/// KNNExec: base table, k (fast_search)[.refine()?] /// ``` /// /// # Index-Only Search (fast_search) @@ -76,18 +73,10 @@ use crate::session::Session; /// - Each flushed memtable has its own vector index built during flush. /// - The active memtable covers any unindexed data. /// - Searching unindexed data in base/flushed would be redundant. -/// -/// # Dedup semantics -/// -/// `LsmGlobalPkDedupExec` keeps the row whose `(generation, freshness)` -/// tuple is largest, so newer generations always win and ties within a -/// generation fall to the source-local freshness (larger row offset for -/// active memtables; smaller `_rowid` for flushed memtables, flipped by -/// `LsmSourceTagExec` so the comparison stays uniform). pub struct LsmVectorSearchPlanner { /// Data source collector. collector: LsmDataSourceCollector, - /// Primary key column names (used by the global dedup). + /// Primary key column names (used by within-source dedup and block-list). pk_columns: Vec, /// Schema of the base table. base_schema: SchemaRef, @@ -95,13 +84,10 @@ pub struct LsmVectorSearchPlanner { vector_column: String, /// Distance metric type (L2, Cosine, Dot, etc.). distance_type: lance_linalg::distance::DistanceType, - /// Base dataset reference for post-rerank take. - /// - /// After the global PK dedup and sort, a `TakeExec` against this - /// dataset materializes any user-projected columns that were not - /// part of the per-source KNN output. Rows from memtables already - /// carry all columns; the take only fetches additional data for - /// base-table rows (which have a real `_rowid`). + /// Base dataset for the post-rerank take: after the cross-source distance + /// merge, `TakeExec` materializes user-projected columns that weren't in + /// the per-source KNN output. Memtable rows already carry all columns; + /// the take only fetches additional data for base rows (real `_rowid`). dataset: Option>, /// Session threaded into flushed-generation opens (shared caches). session: Option>, @@ -218,26 +204,21 @@ impl LsmVectorSearchPlanner { return self.empty_plan(projection); } - // `overfetch_factor` is the single stale-filtering knob: a factor `>= 1.0` - // turns the per-source block-list / `PkHashFilterExec` on (and sets the - // over-fetch multiple); `< 1.0` turns it off entirely. See the doc above. - let filter_stale = overfetch_factor >= 1.0; - - // Per-source blocked PK-hash sets (`NEWER(G)`; base = union of all gens). - // A present entry → that source over-fetches and drops blocked candidates - // before the union. `Box::pin` avoids `clippy::large_futures`. Skipped - // entirely when stale filtering is disabled (no block-list, no filter). - let block_lists = if filter_stale { - Box::pin(super::block_list::compute_source_block_lists( - &sources, - &self.pk_columns, - self.session.as_ref(), - self.flushed_cache.as_ref(), - )) - .await? - } else { - Default::default() - }; + // The block-list is the sole cross-generation dedup mechanism, so it + // runs unconditionally; `overfetch_factor` only tunes the over-fetch + // multiple and is clamped to >= 1.0 so blocked sources still yield k + // live candidates after the post-filter. + let overfetch_factor = overfetch_factor.max(1.0); + + // Per-source PK-hash block sets (`NEWER(G)`; base = union of all gens). + // `Box::pin` keeps the future off `clippy::large_futures`. + let block_lists = Box::pin(super::block_list::compute_source_block_lists( + &sources, + &self.pk_columns, + self.session.as_ref(), + self.flushed_cache.as_ref(), + )) + .await?; let canonical_schema = canonical_output_schema( projection, @@ -245,29 +226,24 @@ impl LsmVectorSearchPlanner { &self.pk_columns, true, // include _distance — KNN always produces it ); - // The internal schema carries `_memtable_gen` + `_freshness` - // through the union and the global dedup; both are dropped - // afterwards by a project back to the canonical output schema. - let internal_schema = - canonical_internal_schema(projection, &self.base_schema, &self.pk_columns, true); - // Refine the base table when explicitly requested, or whenever stale - // filtering runs (it over-fetches the base's approximate-index candidates, - // so distances must be re-ranked to exact before the cross-source merge). - let refine_base = refine_base_table || filter_stale; + // Refine the base table when explicitly requested, or whenever the base + // is blocked (it then over-fetches its approximate-index candidates, so + // distances must be re-ranked to exact before the cross-source merge). + // `block_lists` is non-empty exactly when a newer generation exists. + let refine_base = refine_base_table || !block_lists.is_empty(); let mut knn_plans = Vec::new(); for source in &sources { let generation = source.generation(); let is_base = matches!(source, LsmDataSource::BaseTable { .. }); - // A blocked source fetches `ceil(k * overfetch_factor)` candidates so - // the post-filter still leaves k live ones (factor >= 1.0 ⇒ >= k). - // `block_lists` is non-empty only when filtering is on, so a present - // entry already implies `overfetch_factor >= 1.0`. Keyed per shard — - // generations are per-shard, so a source is only blocked by its own - // shard's newer generations. + let is_active = matches!(source, LsmDataSource::ActiveMemTable { .. }); + // Over-fetch when the post-source filter can drop candidates: a + // blocked source loses superseded rows; the active source's + // within-source dedup collapses duplicate-PK HNSW nodes. Block + // lookup is per shard — generations are per-shard. let blocked = block_lists.get(&(source.shard_id(), generation)); - let fetch_k = if blocked.is_some() { + let fetch_k = if blocked.is_some() || is_active { ((k as f64) * overfetch_factor).ceil() as usize } else { k @@ -281,68 +257,51 @@ impl LsmVectorSearchPlanner { is_base && refine_base, )) .await?; - // Drop superseded rows before the union — closes the top-k stale-read - // gap the global dedup can't. Within-gen dups (not in the blocked set) - // are left to the dedup's freshness tiebreaker. - let knn = match blocked { - Some(set) => Arc::new(super::exec::PkHashFilterExec::new( + // Make each source independently newest-per-PK before the union: + // * active: the append-only HNSW returns one node per inserted + // version, so collapse duplicate PKs to the newest insert + // (KeepMaxRowAddr on `_rowid`) and re-sort by distance. This + // stays probabilistic — a fresh version evicted from the + // over-fetched top-k still leaks. + // * flushed/base: drop cross-gen superseded rows via the + // block-list (within-gen is handled by the flushed DV). + let knn = if is_active { + let deduped: Arc = Arc::new(WithinSourceDedupExec::new( knn, self.pk_columns.clone(), - set.clone(), - k, - )) as Arc, - None => knn, - }; - // Tag rows with `(_memtable_gen, _freshness)`. Polarity differs - // per source — see [`LsmSourceTagExec`] / [`FreshnessPolarity`]: - // * active memtable: insert order, larger `_rowid` = newer - // * flushed memtable: reverse-written, smaller `_rowid` = newer - // * base table: no duplicates expected; polarity moot - let polarity = match source { - LsmDataSource::FlushedMemTable { .. } => FreshnessPolarity::ReverseWrite, - LsmDataSource::ActiveMemTable { .. } | LsmDataSource::BaseTable { .. } => { - FreshnessPolarity::InsertOrder + lance_core::ROW_ID, + DedupDirection::KeepMaxRowAddr, + )); + sort_by_distance(deduped, k)? + } else { + match blocked { + Some(set) => Arc::new(super::exec::PkHashFilterExec::new( + knn, + self.pk_columns.clone(), + set.clone(), + k, + )) as Arc, + None => knn, } }; - let tagged: Arc = Arc::new(LsmSourceTagExec::new( - knn, - generation, - polarity, - lance_core::ROW_ID, - )); - // Lance's `fast_search()` always produces `_rowid` whether or - // not we asked for it; the active arm also produces `_rowid` - // when we ask for it (to drive freshness). For non-base arms - // the per-source value would collide with base row ids in the - // canonical output, so NULL it before stitching into the - // internal schema. The dedup has already consumed it via - // `_freshness`. + // Lance's `fast_search()` and the active scan both produce a + // per-source `_rowid` that would collide with base row ids in the + // canonical output, so NULL it on non-base arms. The base arm keeps + // its real `_rowid` to drive the post-rerank take. let after_null = if is_base { - tagged + knn } else { - null_columns(tagged, &[lance_core::ROW_ID])? + null_columns(knn, &[lance_core::ROW_ID])? }; - // Normalize each source to the internal canonical schema - // (canonical user cols + `_memtable_gen` + `_freshness`). - let normalized = project_to_canonical(after_null, &internal_schema)?; + // Normalize each source to the canonical output schema. + let normalized = project_to_canonical(after_null, &canonical_schema)?; knn_plans.push(normalized); } + // No cross-source dedup needed (see struct doc): SortExec(per partition) + // + SortPreservingMerge does the p-way distance-ordered top-k merge. #[allow(deprecated)] - let union: Arc = Arc::new(UnionExec::new(knn_plans)); - - // LsmGlobalPkDedupExec declares one output partition but only - // reads partition 0 of its input — coalesce first or partitions - // past the base table get silently dropped. - let coalesced: Arc = Arc::new(CoalescePartitionsExec::new(union)); - let deduped: Arc = Arc::new(LsmGlobalPkDedupExec::new( - coalesced, - self.pk_columns.clone(), - super::exec::MEMTABLE_GEN_COLUMN, - super::exec::FRESHNESS_COLUMN, - )); - // Drop `_memtable_gen` and `_freshness` — they're internal-only. - let merged: Arc = project_to_canonical(deduped, &canonical_schema)?; + let merged: Arc = Arc::new(UnionExec::new(knn_plans)); let distance_idx = merged.schema().index_of(DISTANCE_COLUMN).map_err(|_| { lance_core::Error::invalid_input(format!( @@ -502,6 +461,29 @@ impl LsmVectorSearchPlanner { } } +/// Sort a single-partition plan by `_distance` ascending and cap at `k`. +/// +/// Used to re-order the active arm after its within-source dedup (which emits +/// rows unordered) so the cross-source distance merge sees a sorted stream. +fn sort_by_distance(plan: Arc, k: usize) -> Result> { + let idx = plan.schema().index_of(DISTANCE_COLUMN).map_err(|_| { + lance_core::Error::invalid_input(format!( + "Column '{}' not found in schema", + DISTANCE_COLUMN + )) + })?; + let sort_expr = vec![PhysicalSortExpr { + expr: Arc::new(Column::new(DISTANCE_COLUMN, idx)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let ordering = LexOrdering::new(sort_expr) + .ok_or_else(|| lance_core::Error::internal("Failed to create LexOrdering".to_string()))?; + Ok(Arc::new(SortExec::new(ordering, plan).with_fetch(Some(k)))) +} + /// Convert a (typically single-row) FixedSizeList query into the array shape /// `Scanner::nearest` expects: /// @@ -951,15 +933,11 @@ mod tests { #[tokio::test] async fn test_vector_search_strips_internal_columns_and_preserves_active_rows() { // Two regressions in one test: - // (1) `LsmGlobalPkDedupExec` consumes `_memtable_gen` and `_freshness` - // but the user-visible output must NOT contain them — the - // post-dedup `project_to_canonical` is what strips them, so a - // refactor that drops that projection would leak these columns. - // (2) `LsmGlobalPkDedupExec` declares one output partition but only - // reads partition 0 of its input. Without a `CoalescePartitionsExec` - // ahead of it, every union partition past partition 0 is silently - // dropped — i.e. active-memtable rows disappear when the union - // puts them in a non-zero partition. + // (1) The plan must not leak internal columns (`_memtable_gen`, + // `_freshness`) into the user-visible output. + // (2) Active-memtable rows must reach the output — the UnionExec puts + // them in non-zero partitions, and any downstream node that only + // reads partition 0 would silently drop them. use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; use datafusion::prelude::SessionContext; @@ -1016,14 +994,22 @@ mod tests { .await .expect("planner should produce a plan"); - // Plan must include the new global dedup (proves the pipeline is wired). + // Each arm is independently newest-per-PK (active within-source dedup, + // flushed DV) and the block-list handles cross-gen, merged by a + // distance SPM. No global PK dedup or source tag node is involved. let plan_str = format!( "{}", datafusion::physical_plan::displayable(plan.as_ref()).indent(true) ); assert!( - plan_str.contains("LsmGlobalPkDedupExec"), - "expected new global-dedup pipeline, got:\n{}", + !plan_str.contains("LsmGlobalPkDedupExec") && !plan_str.contains("LsmSourceTagExec"), + "vector plan must not contain a global PK dedup or source tag node, got:\n{}", + plan_str + ); + assert!( + plan_str.contains("WithinSourceDedupExec") + && plan_str.contains("SortPreservingMergeExec"), + "expected per-arm dedup + distance merge, got:\n{}", plan_str ); @@ -1035,10 +1021,7 @@ mod tests { let out_schema = batches[0].schema(); assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok()); - for internal in [ - super::super::exec::MEMTABLE_GEN_COLUMN, - super::super::exec::FRESHNESS_COLUMN, - ] { + for internal in [super::super::exec::MEMTABLE_GEN_COLUMN, "_freshness"] { assert!( out_schema.field_with_name(internal).is_err(), "`{}` leaked into output: {:?}", @@ -1051,11 +1034,11 @@ mod tests { ); } - // (2) Active-memtable rows must survive: collector emits base as - // partition 0 of the union and the active memtable as partition 1+. - // The active memtable holds ids 1..=4; the base holds id 10. At - // least one id in 1..=4 must appear in the output, otherwise the - // CoalescePartitionsExec was skipped and partitions 1+ were dropped. + // (2) Active-memtable rows must survive: the union emits base as + // partition 0 and the active memtable as partition 1+. The active + // memtable holds ids 1..=4; the base holds id 10. At least one id in + // 1..=4 must appear, otherwise the SortPreservingMerge dropped the + // non-zero partitions. let mut all_ids: Vec = Vec::new(); for batch in &batches { let id_col = batch @@ -1177,14 +1160,14 @@ mod tests { let pk1_count = ids.iter().filter(|i| **i == 1).count(); assert_eq!( pk1_count, 1, - "pk=1 must appear exactly once after cross-source dedup; got ids={:?}", + "pk=1 must appear exactly once in the merged top-k; got ids={:?}", ids, ); } #[tokio::test] async fn test_vector_search_system_columns_real_only_for_base() { - // Covers tests 1+2+3 from the PR review: + // Covers three properties of the per-source system columns: // 1. base-hit `_rowid`/`_rowaddr` carry real values // 2. flushed-memtable arm runs without erroring // 3. `_rowaddr` symmetry with `_rowid` (same code path, both are @@ -1530,6 +1513,18 @@ mod tests { .await .unwrap(); + // The active arm collapses duplicate-PK HNSW nodes itself via + // WithinSourceDedupExec — there is no cross-source dedup fallback. + let plan_str = format!( + "{}", + datafusion::physical_plan::displayable(plan.as_ref()).indent(true) + ); + assert!( + plan_str.contains("WithinSourceDedupExec"), + "active vector arm must self-dedup, got:\n{}", + plan_str + ); + let ctx = SessionContext::new(); let stream = plan.execute(0, ctx.task_ctx()).unwrap(); let batches: Vec = stream.try_collect().await.unwrap(); @@ -1556,33 +1551,15 @@ mod tests { #[tokio::test] async fn test_vector_search_stale_read_when_fresh_falls_out_of_top_k() { - // FAILING SPEC — exposes a stale-read gap in the per-source top-k → - // global-dedup pipeline (the design that replaced the bloom-based - // FilterStaleExec in #6881). - // - // `LsmGlobalPkDedupExec` keeps the row with the largest - // `(generation, freshness)` tuple PER PK, but it can only do so for - // PKs that actually appear in *some* source's top-k. If a PK's - // *fresh* version is pushed out of its own source's top-k by other - // (closer) rows, the dedup never sees it — so it cannot suppress the - // *stale* copy from an older source, which is then served. + // Regression for the cross-generation stale-read gap that the + // PkHashFilterExec block-list closes. // // Scenario: - // * Base table (gen 0): pk=1 with vector == query (distance ~0). - // This is the STALE copy. - // * Active memtable (gen 1): - // - pk=1 re-inserted with a FAR vector (the fresh value). - // - pk=2 with a vector closer to the query than fresh pk=1. - // - // With k=1 the active arm returns only pk=2 (closer than fresh - // pk=1), so fresh pk=1 never reaches the dedup. The base arm returns - // the stale pk=1 at distance ~0, which survives dedup unchallenged - // and wins top-1. - // - // Correct newest-wins behavior: pk=1's live vector is far, so the - // nearest live neighbor is pk=2. pk=1 must never be served at the - // stale ~0 distance of the superseded base-table copy. Today this - // FAILS — the stale pk=1 is returned. + // * Base (gen 0): stale pk=1 sitting on the query (distance ~0). + // * Active (gen 1): pk=1 updated to a far vector, plus pk=2 closer + // to the query than fresh pk=1. With k=1 the active arm surfaces + // pk=2 and drops fresh pk=1, so without the block-list the stale + // base copy of pk=1 wins top-1. use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; use crate::index::DatasetIndexExt; @@ -1705,27 +1682,27 @@ mod tests { rows ); - // Toggle off: with filter_stale=false the block-list is skipped, so the - // superseded base copy of pk=1 (distance ~0) is no longer suppressed. - // Fresh pk=1 never reaches the global dedup (evicted from the active - // top-k), so the stale copy resurfaces and wins top-1. - let unfiltered = planner + // The block-list is now unconditional: a sub-1.0 overfetch_factor is + // clamped to 1.0 and the stale base copy of pk=1 stays suppressed (the + // factor only tunes the over-fetch multiple, it cannot disable filtering). + let still_filtered = planner .plan_search(&query, 1, 1, None, false, 0.0) .await .unwrap(); - let unfiltered_rows = { - let stream = unfiltered + let still_filtered_rows = { + let stream = still_filtered .execute(0, SessionContext::new().task_ctx()) .unwrap(); let batches: Vec = stream.try_collect().await.unwrap(); collect_id_dist(&batches) }; assert!( - unfiltered_rows + still_filtered_rows .iter() - .any(|&(id, d)| id == 1 && d.abs() < 1e-3), - "filter_stale=false must surface the stale pk=1 (distance ~0); got {:?}", - unfiltered_rows + .all(|&(id, d)| !(id == 1 && d.abs() < 1e-3)), + "block-list is unconditional: stale pk=1 must stay suppressed even \ + with overfetch_factor < 1.0; got {:?}", + still_filtered_rows ); } diff --git a/rust/lance/src/dataset/mem_wal/write.rs b/rust/lance/src/dataset/mem_wal/write.rs index e1314d629e8..3d59735f7e8 100644 --- a/rust/lance/src/dataset/mem_wal/write.rs +++ b/rust/lance/src/dataset/mem_wal/write.rs @@ -4444,6 +4444,89 @@ mod shard_writer_tests { assert!(!defaults.contains_key("shard_spec_id")); } + /// A maintained index can be split across multiple physical segments once a + /// delta is appended over previously uncovered fragments (the distributed + /// indexer / `optimize_indices(append)` flow). `mem_wal_writer` must resolve + /// such an index by name without tripping the singular loader's "multiple + /// indices of the same name" error — it only reads the shared type/params, + /// which every segment carries identically. + #[tokio::test] + async fn test_mem_wal_writer_with_multi_segment_index() { + use lance_index::optimize::OptimizeOptions; + + let vector_dim = 32; + let schema = create_test_schema(vector_dim); + let uri = format!("memory://test_multi_segment_index_{}", Uuid::new_v4()); + + // Initial fragment + an IVF vector index covering it. + let initial = create_test_batch(&schema, 0, 256, vector_dim); + let batches = RecordBatchIterator::new([Ok(initial)], schema.clone()); + let mut dataset = Dataset::write(batches, &uri, Some(WriteParams::default())) + .await + .expect("Failed to create dataset"); + let vector_params = VectorIndexParams::ivf_flat(1, MetricType::L2); + dataset + .create_index( + &["vector"], + IndexType::Vector, + Some("vector_idx".to_string()), + &vector_params, + true, + ) + .await + .expect("Failed to create vector index"); + + // Append a second fragment and index it as a *delta* (no merge), so the + // index ends up with two physical segments sharing the name "vector_idx". + let appended = create_test_batch(&schema, 256, 256, vector_dim); + let append_batches = RecordBatchIterator::new([Ok(appended)], schema.clone()); + dataset + .append(append_batches, None) + .await + .expect("Failed to append fragment"); + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .expect("Failed to append index delta"); + + // Precondition: the index is genuinely multi-segment, so the singular + // `load_index_by_name` would error here. + assert_eq!( + dataset + .load_indices_by_name("vector_idx") + .await + .unwrap() + .len(), + 2, + "expected two physical segments for the maintained index" + ); + + dataset + .initialize_mem_wal() + .maintained_indexes(["vector_idx"]) + .execute() + .await + .expect("Failed to initialize MemWAL"); + + // The regression: loading the multi-segment maintained index must succeed. + let shard_id = Uuid::new_v4(); + let writer = dataset + .mem_wal_writer( + shard_id, + ShardWriterConfig::new(shard_id).with_durable_write(false), + ) + .await + .expect("mem_wal_writer must accept a multi-segment maintained index"); + + // And the resulting writer is functional. + writer + .put(vec![create_test_batch(&schema, 200, 10, vector_dim)]) + .await + .unwrap(); + assert_eq!(writer.memtable_stats().await.unwrap().row_count, 10); + writer.close().await.unwrap(); + } + #[tokio::test] async fn test_initialize_mem_wal_bucket_sharding() { let vector_dim = 128; diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 547707affcb..d0f5983b3ce 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -22,6 +22,7 @@ use datafusion::logical_expr::{Expr, ScalarUDF, col, lit}; use datafusion::physical_expr::PhysicalSortExpr; #[allow(deprecated)] use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::expressions; use datafusion::physical_plan::projection::ProjectionExec as DFProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -37,7 +38,7 @@ use datafusion::scalar::ScalarValue; use datafusion_expr::ExprSchemable; use datafusion_expr::execution_props::ExecutionProps; use datafusion_functions::core::getfield::GetFieldFunc; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{LexOrdering, Partitioning, PhysicalExpr, create_physical_expr}; use datafusion_physical_plan::joins::PartitionMode; use datafusion_physical_plan::projection::ProjectionExec; @@ -100,7 +101,9 @@ use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec}; use crate::io::exec::{ AddRowAddrExec, FilterPlan as ExprFilterPlan, KNNVectorDistanceExec, LancePushdownScanExec, LanceScanExec, Planner, PreFilterSource, ScanConfig, TakeExec, - knn::{KNN_INDEX_SCHEMA, new_knn_exec}, + knn::{ + KnnBatchParams, QUERY_INDEX_COL, knn_empty_result_schema, new_knn_exec, query_index_field, + }, project, }; use crate::io::exec::{AddRowOffsetExec, LanceFilterExec, LanceScanConfig, get_physical_optimizer}; @@ -768,6 +771,10 @@ pub struct Scanner { ordering: Option>, nearest: Option, + nearest_query_count: usize, + /// True when the query shape represents a batch of single-vector queries + /// (list-like query on a fixed-size vector column, or multiple concatenated vectors). + is_batch_nearest: bool, /// If false, do not use any scalar indices for the scan /// @@ -1023,6 +1030,8 @@ impl Scanner { offset: None, ordering: None, nearest: None, + nearest_query_count: 1, + is_batch_nearest: false, use_stats: true, ordered: true, fragments: None, @@ -1427,6 +1436,19 @@ impl Scanner { Ok(self) } + /// Returns true when `q` is a batch of single-vector queries. + /// + /// List-like queries against a [`DataType::List`] vector column are treated as one + /// multivector query. The same list-like query against a fixed-size vector column is + /// treated as a batch of single-vector queries. + fn is_batch_nearest_query(vector_type: &DataType, query_type: &DataType) -> bool { + matches!(vector_type, DataType::FixedSizeList(_, _)) + && matches!( + query_type, + DataType::List(_) | DataType::FixedSizeList(_, _) + ) + } + /// Find k-nearest neighbor within the vector column. /// the query can be a Float16Array, Float32Array, Float64Array, UInt8Array, /// or a ListArray/FixedSizeListArray of the above types. @@ -1448,16 +1470,10 @@ impl Scanner { // make sure the field exists let (vector_type, element_type) = get_vector_type(self.dataset.schema(), column)?; let dim = get_vector_dim(self.dataset.schema(), column)?; + let query_type = q.data_type().clone(); - let q = match q.data_type() { + let (q, query_count) = match &query_type { DataType::List(_) | DataType::FixedSizeList(_, _) => { - if !matches!(vector_type, DataType::List(_)) { - return Err(Error::invalid_input(format!( - "Query is multivector but column {}({})is not multivector", - column, vector_type, - ))); - } - if let Some(list_array) = q.as_list_opt::() { for i in 0..list_array.len() { let vec = list_array.value(i); @@ -1470,7 +1486,15 @@ impl Scanner { ))); } } - list_array.values().clone() + // A list-like query against a multivector column is one multivector query. + // The same list-like query against a fixed-size vector column is a batch + // of single-vector queries. + let query_count = if matches!(vector_type, DataType::List(_)) { + 1 + } else { + list_array.len() + }; + (list_array.values().clone(), query_count) } else { let fsl = q.as_fixed_size_list(); if fsl.value_length() as usize != dim { @@ -1481,7 +1505,15 @@ impl Scanner { dim, ))); } - fsl.values().clone() + // A list-like query against a multivector column is one multivector query. + // The same list-like query against a fixed-size vector column is a batch + // of single-vector queries. + let query_count = if matches!(vector_type, DataType::List(_)) { + 1 + } else { + fsl.len() + }; + (fsl.values().clone(), query_count) } } _ => { @@ -1493,10 +1525,17 @@ impl Scanner { dim, ))); } - q.slice(0, q.len()) + (q.slice(0, q.len()), 1) } }; + let is_batch_nearest = Self::is_batch_nearest_query(&vector_type, &query_type); + if is_batch_nearest && self.dataset.schema().field(QUERY_INDEX_COL).is_some() { + return Err(Error::invalid_input(format!( + "batch nearest neighbor search cannot be used on datasets with column '{QUERY_INDEX_COL}'" + ))); + } + let key = match &element_type { dt if dt == q.data_type() => q, dt if dt.is_floating() => coerce_float_vector( @@ -1528,6 +1567,8 @@ impl Scanner { query_parallelism: DEFAULT_QUERY_PARALLELISM, dist_q_c: 0.0, }); + self.nearest_query_count = query_count; + self.is_batch_nearest = is_batch_nearest; Ok(self) } @@ -1858,6 +1899,9 @@ impl Scanner { if self.nearest.as_ref().is_some() { extra_columns.push(ArrowField::new(DIST_COL, DataType::Float32, true)); + if self.is_batch_nearest { + extra_columns.push(query_index_field()); + } }; if self.full_text_query.is_some() { @@ -1919,6 +1963,23 @@ impl Scanner { } } + // Batch nearest queries expose the synthetic `query_index` discriminator as + // the first output column for compatibility with LanceDB batch vector search. + if self.is_batch_nearest { + let query_index_expr = if let Some(pos) = output_expr + .iter() + .position(|(_, name)| name == QUERY_INDEX_COL) + { + output_expr.remove(pos) + } else { + ( + expressions::col(QUERY_INDEX_COL, current_schema)?, + QUERY_INDEX_COL.to_string(), + ) + }; + output_expr.insert(0, query_index_expr); + } + if self.legacy_with_row_id { let row_id_pos = output_expr .iter() @@ -2777,6 +2838,10 @@ impl Scanner { read_options = read_options.with_io_buffer_size(io_buffer_size_bytes); } + if self.fast_search && filter_plan.has_index_query() { + read_options = read_options.with_only_indexed_fragments(); + } + let index_input = filter_plan.index_query.clone().map(|index_query| { Arc::new(ScalarIndexExec::new(self.dataset.clone(), index_query)) as Arc @@ -3500,6 +3565,7 @@ impl Scanner { } // ANN/KNN search execution node with optional prefilter + #[async_recursion] async fn vector_search( &self, filter_plan: &ExprFilterPlan, @@ -3651,6 +3717,10 @@ impl Scanner { }; if let Some((index_name, index_segments, index_metric)) = index_and_segments { + if self.is_batch_nearest { + return self.batch_indexed_vector_search(filter_plan, &q).await; + } + log::trace!("index found for vector search"); // Use the index's metric type q.metric_type = Some(index_metric); @@ -3688,7 +3758,9 @@ impl Scanner { Ok(knn_node) } else { if self.fast_search { - return Ok(Arc::new(EmptyExec::new(KNN_INDEX_SCHEMA.clone()))); + return Ok(Arc::new(EmptyExec::new(knn_empty_result_schema( + self.is_batch_nearest, + )))); } // Resolve metric type for flat search (use default if not specified) let metric = q @@ -3728,6 +3800,86 @@ impl Scanner { } } + async fn batch_indexed_vector_search( + &self, + filter_plan: &ExprFilterPlan, + q: &Query, + ) -> Result> { + let query_dim = q.key.len() / self.nearest_query_count; + let mut query_plans = Vec::with_capacity(self.nearest_query_count); + + for query_index in 0..self.nearest_query_count { + let mut single_query = q.clone(); + single_query.key = q.key.slice(query_index * query_dim, query_dim); + + let mut single_scanner = self.clone(); + single_scanner.nearest_query_count = 1; + single_scanner.is_batch_nearest = false; + single_scanner.nearest = Some(single_query.clone()); + + let single_plan = single_scanner + .vector_search(filter_plan, &single_query) + .await?; + query_plans.push(Self::add_query_index_column( + single_plan, + query_index as i32, + )?); + } + + let unioned = UnionExec::try_new(query_plans)?; + let unioned = Arc::new(RepartitionExec::try_new( + unioned, + Partitioning::RoundRobinBatch(1), + )?) as Arc; + + let query_index_sort = PhysicalSortExpr { + expr: expressions::col(QUERY_INDEX_COL, unioned.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + let distance_sort = PhysicalSortExpr { + expr: expressions::col(DIST_COL, unioned.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + let row_id_sort = PhysicalSortExpr { + expr: expressions::col(ROW_ID, unioned.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + + Ok(Arc::new(SortExec::new( + [query_index_sort, distance_sort, row_id_sort].into(), + unioned, + ))) + } + + fn add_query_index_column( + plan: Arc, + query_index: i32, + ) -> Result> { + let schema = plan.schema(); + let mut projection_exprs = Vec::with_capacity(schema.fields().len() + 1); + projection_exprs.push(( + Arc::new(Literal::new(ScalarValue::Int32(Some(query_index)))) as Arc, + QUERY_INDEX_COL.to_string(), + )); + for field in schema.fields() { + projection_exprs.push(( + Arc::new(Column::new_with_schema(field.name(), schema.as_ref())?) + as Arc, + field.name().clone(), + )); + } + Ok(Arc::new(ProjectionExec::try_new(projection_exprs, plan)?)) + } + /// Combine ANN results with KNN results for data appended after index creation async fn knn_combined( &self, @@ -4267,13 +4419,29 @@ impl Scanner { default_distance_type_for(&element_type) } }; - let flat_dist = Arc::new(KNNVectorDistanceExec::try_new( + let input = if self.is_batch_nearest { + Arc::new(CoalescePartitionsExec::new(input)) as Arc + } else { + input + }; + let flat_dist = Arc::new(KNNVectorDistanceExec::try_new_batch( input, &q.column, q.key.clone(), - metric_type, + KnnBatchParams { + is_batch: self.is_batch_nearest, + query_count: self.nearest_query_count, + k: q.k, + lower_bound: q.lower_bound, + upper_bound: q.upper_bound, + distance_type: metric_type, + }, )?); + if self.is_batch_nearest { + return Ok(flat_dist); + } + let lower: Option<(Expr, Arc)> = q .lower_bound .map(|v| -> Result<(Expr, Arc)> { @@ -4338,10 +4506,17 @@ impl Scanner { ) .with_fetch(Some(q.k)); - let logical_not_null = col(DIST_COL).is_not_null(); - let not_nulls = Arc::new(LanceFilterExec::try_new(logical_not_null, Arc::new(sort))?); + Self::flat_knn_not_null_filter(Arc::new(sort)) + } - Ok(not_nulls) + fn flat_knn_not_null_filter( + knn_plan: Arc, + ) -> Result> { + let logical_not_null = col(DIST_COL).is_not_null(); + Ok(Arc::new(LanceFilterExec::try_new( + logical_not_null, + knn_plan, + )?)) } fn get_fragments_as_bitmap(&self) -> RoaringBitmap { @@ -4551,12 +4726,13 @@ impl Scanner { .partition_frags_by_coverage(index_query, fragments.clone()) .await?; - if missing_frags.is_empty() { + if missing_frags.is_empty() || self.fast_search { log::trace!("prefilter entirely satisfied by exact index search"); // We can only avoid materializing the index for a prefilter if: // 1. The search is indexed // 2. The index search is an exact search with no recheck or refine - // 3. The indices cover at least the same fragments as the vector index + // 3. The indices cover at least the same fragments as the vector index, + // unless fast_search allows skipping uncovered fragments. return Ok(PreFilterSource::ScalarIndexQuery(Arc::new( ScalarIndexExec::new(self.dataset.clone(), index_query.clone()), ))); @@ -4991,10 +5167,10 @@ mod test { use arrow::array::as_primitive_array; use arrow::datatypes::{Float64Type, Int32Type, Int64Type}; use arrow_array::cast::AsArray; - use arrow_array::types::{Float32Type, UInt64Type}; + use arrow_array::types::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{ - ArrayRef, FixedSizeListArray, Float16Array, Int32Array, LargeStringArray, PrimitiveArray, - RecordBatchIterator, StringArray, StructArray, UInt8Array, + ArrayRef, BooleanArray, FixedSizeListArray, Float16Array, Int32Array, LargeStringArray, + PrimitiveArray, RecordBatchIterator, StringArray, StructArray, UInt8Array, UInt32Array, }; use arrow_ord::sort::sort_to_indices; @@ -5530,6 +5706,56 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_limit_with_scalar_index_and_refine_filter() { + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("id", DataType::Int32, false), + ArrowField::new("topic", DataType::Int32, false), + ArrowField::new("is_night", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..20)), + Arc::new(Int32Array::from_iter_values(std::iter::repeat_n(1, 20))), + Arc::new(Int32Array::from_iter_values( + (0..20).map(|i| if i < 10 { 0 } else { 1 }), + )), + ], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let mut dataset = Dataset::write(reader, "memory://", None).await.unwrap(); + dataset + .create_index( + &["topic"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + let actual = dataset + .scan() + .filter("topic = 1 AND is_night = 1") + .unwrap() + .limit(Some(10), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + assert_eq!(actual.num_rows(), 10); + let ids = actual + .column_by_name("id") + .unwrap() + .as_primitive::() + .values(); + assert_eq!(ids, &(10..20).collect::>()); + } + #[test_log::test(tokio::test)] async fn test_limit_cancel() { // If there is a filter and a limit and we can't use the index to satisfy @@ -5637,7 +5863,333 @@ mod test { assert_eq!(expected_i, actual_i); } - #[rstest] + fn batch_knn_two_queries() -> (FixedSizeListArray, Vec) { + let query_values = (32..96).map(|v| v as f32).collect::>(); + let queries = + FixedSizeListArray::try_new_from_values(Float32Array::from(query_values.clone()), 32) + .unwrap(); + (queries, query_values) + } + + fn assert_query_index_field(batch: &RecordBatch) { + let schema = batch.schema(); + let field = schema.field(0); + assert_eq!(field.name(), QUERY_INDEX_COL); + assert_eq!(field.data_type(), &DataType::Int32); + assert!(!field.is_nullable()); + } + + async fn assert_batch_matches_single_queries( + dataset: &Dataset, + batch: &RecordBatch, + query_values: &[f32], + k: usize, + use_index: bool, + distance_range: Option<(Option, Option)>, + ) { + let query_count = query_values.len() / 32; + assert_eq!(batch.num_rows(), query_count * k); + + for query_index in 0..query_count { + let query = + Float32Array::from(query_values[query_index * 32..(query_index + 1) * 32].to_vec()); + let mut scan = dataset.scan(); + scan.nearest("vec", &query, k).unwrap(); + scan.use_index(use_index); + if let Some((lower, upper)) = distance_range { + scan.distance_range(lower, upper); + } + scan.project(&["i"]).unwrap(); + let single = scan.try_into_batch().await.unwrap(); + + let query_indices = batch[QUERY_INDEX_COL].as_primitive::(); + let mask = BooleanArray::from_iter( + query_indices + .iter() + .map(|value| value.map(|value| value == query_index as i32)), + ); + let batch_slice = arrow::compute::filter_record_batch(batch, &mask).unwrap(); + assert_eq!( + batch_slice["i"].as_primitive::().values(), + single["i"].as_primitive::().values() + ); + assert_eq!( + batch_slice[DIST_COL].as_primitive::().values(), + single[DIST_COL].as_primitive::().values() + ); + } + } + + #[tokio::test] + async fn test_batch_knn_flat() { + let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + let dataset = &test_ds.dataset; + let k = 2; + + let (queries, query_values) = batch_knn_two_queries(); + let mut scan = dataset.scan(); + scan.nearest("vec", &queries, k).unwrap(); + scan.use_index(false); + scan.project(&["i"]).unwrap(); + + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + plan.contains("KNNVectorDistance: queries=2"), + "expected flat batch KNN plan, got:\n{}", + plan + ); + assert!( + !plan.contains("ANNSubIndex"), + "flat batch KNN should not use ANN index, got:\n{}", + plan + ); + assert!( + !plan.contains("SortExec: TopK(fetch="), + "batch flat KNN must not truncate to k rows globally, got:\n{}", + plan + ); + + let batch = scan.try_into_batch().await.unwrap(); + assert_query_index_field(&batch); + assert_eq!( + batch.num_rows(), + 2 * k, + "batch flat KNN must return k rows per query vector" + ); + assert_eq!( + batch[QUERY_INDEX_COL].as_primitive::().values(), + &[0, 0, 1, 1] + ); + let query_indices = batch[QUERY_INDEX_COL].as_primitive::(); + for query_index in 0..2 { + let rows_for_query = query_indices + .iter() + .filter(|value| *value == Some(query_index)) + .count(); + assert_eq!( + rows_for_query, k, + "query_index {query_index} should have exactly {k} rows" + ); + } + assert_batch_matches_single_queries(dataset, &batch, &query_values, k, false, None).await; + + let query_values_one = (32..64).map(|v| v as f32).collect::>(); + let queries_one = FixedSizeListArray::try_new_from_values( + Float32Array::from(query_values_one.clone()), + 32, + ) + .unwrap(); + let mut scan = dataset.scan(); + scan.nearest("vec", &queries_one, k).unwrap(); + scan.use_index(false); + scan.project(&["i"]).unwrap(); + + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + plan.contains("KNNVectorDistance: queries=1"), + "single-vector batch query should use batch KNN path, got:\n{}", + plan + ); + assert!( + !plan.contains("SortExec: TopK(fetch="), + "batch KNN must not apply per-query SortExec top-k, got:\n{}", + plan + ); + + let batch = scan.try_into_batch().await.unwrap(); + assert_query_index_field(&batch); + assert_eq!( + batch[QUERY_INDEX_COL].as_primitive::().values(), + &[0, 0] + ); + } + + #[tokio::test] + async fn test_primitive_query_length_multiple_of_dim_is_rejected() { + let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + let dataset = &test_ds.dataset; + let q: Float32Array = (32..96).map(|v| v as f32).collect(); + + let err = match dataset.scan().nearest("vec", &q, 2) { + Err(err) => err.to_string(), + Ok(_) => panic!("expected primitive query length mismatch error"), + }; + assert!( + err.contains("query dim(64) doesn't match the column vec vector dim(32)"), + "unexpected error: {err}" + ); + } + + async fn dataset_with_query_index_column() -> (TempStrDir, Dataset) { + let path = TempStrDir::default(); + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("i", DataType::Int32, true), + ArrowField::new( + "vec", + DataType::FixedSizeList( + Arc::new(ArrowField::new("item", DataType::Float32, true)), + 32, + ), + true, + ), + ArrowField::new(QUERY_INDEX_COL, DataType::UInt32, true), + ])); + let vector_values: Float32Array = (0..32 * 80).map(|v| v as f32).collect(); + let vectors = FixedSizeListArray::try_new_from_values(vector_values, 32).unwrap(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..80)), + Arc::new(vectors), + Arc::new(UInt32Array::from_iter((0..80).map(|v| v as u32))), + ], + ) + .unwrap(); + let dataset = Dataset::write( + RecordBatchIterator::new(std::iter::once(Ok(batch)), schema.clone()), + &path, + None, + ) + .await + .unwrap(); + (path, dataset) + } + + #[tokio::test] + async fn test_batch_knn_rejects_dataset_query_index_column() { + let (_tmp, dataset) = dataset_with_query_index_column().await; + let (queries, _) = batch_knn_two_queries(); + let err = match dataset.scan().nearest("vec", &queries, 2) { + Err(err) => err.to_string(), + Ok(_) => panic!("expected reserved query_index column error"), + }; + assert!(err.contains(QUERY_INDEX_COL), "unexpected error: {err}"); + } + + #[tokio::test] + async fn test_single_knn_projects_dataset_query_index_column() { + let (_tmp, dataset) = dataset_with_query_index_column().await; + let q: Float32Array = (32..64).map(|v| v as f32).collect(); + + let mut scan = dataset.scan(); + scan.nearest("vec", &q, 2).unwrap(); + scan.use_index(false); + scan.project(&["i"]).unwrap(); + let without_query_index = scan.try_into_batch().await.unwrap(); + + let mut scan = dataset.scan(); + scan.nearest("vec", &q, 2).unwrap(); + scan.use_index(false); + scan.project(&["i", QUERY_INDEX_COL]).unwrap(); + let with_query_index = scan.try_into_batch().await.unwrap(); + + assert_eq!(without_query_index.num_rows(), 2); + assert_eq!( + without_query_index["i"] + .as_primitive::() + .values(), + with_query_index["i"].as_primitive::().values() + ); + assert_eq!( + with_query_index[QUERY_INDEX_COL] + .as_primitive::() + .null_count(), + 0 + ); + } + + #[tokio::test] + async fn test_batch_knn_flat_respects_distance_range() { + let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + let dataset = &test_ds.dataset; + let (queries, query_values) = batch_knn_two_queries(); + + let batch = dataset + .scan() + .nearest("vec", &queries, 2) + .unwrap() + .use_index(false) + .distance_range(Some(1.0), None) + .project(&["i"]) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + assert_eq!( + batch[QUERY_INDEX_COL].as_primitive::().values(), + &[0, 0, 1, 1] + ); + assert_batch_matches_single_queries( + dataset, + &batch, + &query_values, + 2, + false, + Some((Some(1.0), None)), + ) + .await; + } + + #[tokio::test] + async fn test_batch_knn_indexed() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + test_ds.make_vector_index().await.unwrap(); + let dataset = &test_ds.dataset; + let (queries, query_values) = batch_knn_two_queries(); + + let mut scan = dataset.scan(); + scan.nearest("vec", &queries, 2).unwrap(); + scan.project(&["i"]).unwrap(); + + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + plan.contains("ANNSubIndex"), + "batch KNN should use the vector index when available, got:\n{}", + plan + ); + assert!( + !plan.contains("KNNVectorDistance: queries=2"), + "indexed batch KNN should not force the flat batch path, got:\n{}", + plan + ); + + let batch = scan.try_into_batch().await.unwrap(); + assert_query_index_field(&batch); + assert_eq!( + batch[QUERY_INDEX_COL].as_primitive::().values(), + &[0, 0, 1, 1] + ); + + let batch = dataset + .scan() + .nearest("vec", &queries, 2) + .unwrap() + .distance_range(Some(1.0), None) + .project(&["i"]) + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_batch_matches_single_queries( + dataset, + &batch, + &query_values, + 2, + true, + Some((Some(1.0), None)), + ) + .await; + } + #[tokio::test] async fn test_can_project_distance() { let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) @@ -9746,6 +10298,170 @@ full_filter=name LIKE Utf8(\"test%2\"), refine_filter=name LIKE Utf8(\"test%2\") assert_eq!(fast_rows, 0); } + #[tokio::test] + async fn test_batch_fast_search_without_index_returns_empty_with_query_index() { + let dataset = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + let query_values = (32..96).map(|v| v as f32).collect::>(); + let queries = + FixedSizeListArray::try_new_from_values(Float32Array::from(query_values), 32).unwrap(); + + let mut scanner = dataset.dataset.scan(); + scanner.nearest("vec", &queries, 2).unwrap().fast_search(); + let batch = scanner.try_into_batch().await.unwrap(); + + assert_eq!(batch.num_rows(), 0); + assert_query_index_field(&batch); + } + + #[rstest] + #[tokio::test] + async fn test_fast_search_scalar_index_skips_unindexed_fragments( + #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, + ) { + let mut dataset = TestVectorDataset::new(data_storage_version, false) + .await + .unwrap(); + dataset.make_scalar_index().await.unwrap(); + dataset.append_new_data().await.unwrap(); + + let mut scanner = dataset.dataset.scan(); + scanner.filter("i >= 395").unwrap().project(&["i"]).unwrap(); + let normal_batch = scanner.try_into_batch().await.unwrap(); + + let mut scanner = dataset.dataset.scan(); + scanner + .filter("i >= 395") + .unwrap() + .fast_search() + .project(&["i"]) + .unwrap(); + let fast_batch = scanner.try_into_batch().await.unwrap(); + + assert_eq!(normal_batch.num_rows(), 15); + assert_eq!(fast_batch.num_rows(), 5); + } + + fn make_scalar_filter_test_batch(schema: SchemaRef, start: i32, end: i32) -> RecordBatch { + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from_iter_values(start..end)), + Arc::new(Int32Array::from_iter_values(start..end)), + ], + ) + .unwrap() + } + + async fn make_scalar_filter_test_dataset( + data_storage_version: LanceFileVersion, + ) -> (TempStrDir, SchemaRef, Dataset) { + let tmp_dir = TempStrDir::default(); + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, false), + ])); + let batch = make_scalar_filter_test_batch(schema.clone(), 0, 100); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let dataset = Dataset::write( + reader, + &tmp_dir, + Some(WriteParams { + max_rows_per_file: 100, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + (tmp_dir, schema, dataset) + } + + async fn append_scalar_filter_test_data( + dataset: &mut Dataset, + schema: SchemaRef, + start: i32, + end: i32, + ) { + let batch = make_scalar_filter_test_batch(schema.clone(), start, end); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + dataset.append(reader, None).await.unwrap(); + } + + async fn create_scalar_index(dataset: &mut Dataset, column: &str) { + dataset + .create_index( + &[column], + IndexType::Scalar, + None, + &ScalarIndexParams::default(), + false, + ) + .await + .unwrap(); + } + + async fn scan_count(dataset: &Dataset, filter: &str, fast_search: bool) -> usize { + let mut scanner = dataset.scan(); + scanner + .filter(filter) + .unwrap() + .project(&["a", "b"]) + .unwrap(); + if fast_search { + scanner.fast_search(); + } + scanner.try_into_batch().await.unwrap().num_rows() + } + + #[rstest] + #[tokio::test] + async fn test_fast_search_scalar_index_filter_coverage_cases( + #[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, + ) { + let (_tmp_dir, schema, mut dataset) = + make_scalar_filter_test_dataset(data_storage_version).await; + create_scalar_index(&mut dataset, "a").await; + append_scalar_filter_test_data(&mut dataset, schema, 100, 110).await; + + // a is indexed and b is not. The indexed side finds candidates in covered + // fragments and b is applied as a refine filter. + assert_eq!(scan_count(&dataset, "a >= 95 AND b >= 95", false).await, 15); + assert_eq!(scan_count(&dataset, "a >= 95 AND b >= 95", true).await, 5); + + // OR cannot safely skip unindexed fragments: a row in an unindexed fragment + // may satisfy `b >= 105` even if `a` is not indexed there. Skipping it would + // silently drop valid results, so fast_search has no effect on OR queries where + // any branch lacks a scalar index. + assert_eq!(scan_count(&dataset, "a >= 105 OR b >= 105", false).await, 5); + assert_eq!(scan_count(&dataset, "a >= 105 OR b >= 105", true).await, 5); + + // A single-column indexed filter skips the appended fragment in fast mode. + assert_eq!(scan_count(&dataset, "a >= 95", false).await, 15); + assert_eq!(scan_count(&dataset, "a >= 95", true).await, 5); + + let (_tmp_dir, schema, mut dataset) = + make_scalar_filter_test_dataset(data_storage_version).await; + create_scalar_index(&mut dataset, "a").await; + append_scalar_filter_test_data(&mut dataset, schema, 100, 110).await; + create_scalar_index(&mut dataset, "b").await; + + // a and b are both indexed, but a only covers the original fragment while b + // covers both fragments. Fast search only reads the shared indexed coverage. + assert_eq!(scan_count(&dataset, "a >= 95 AND b >= 95", false).await, 15); + assert_eq!(scan_count(&dataset, "a >= 95 AND b >= 95", true).await, 5); + + let (_tmp_dir, schema, mut dataset) = + make_scalar_filter_test_dataset(data_storage_version).await; + append_scalar_filter_test_data(&mut dataset, schema, 100, 110).await; + + // With no scalar index query, fast_search must not enable indexed-fragment + // pruning. This guards against treating any ordinary filter as indexed. + assert_eq!(scan_count(&dataset, "a >= 95", false).await, 15); + assert_eq!(scan_count(&dataset, "a >= 95", true).await, 15); + } + #[rstest] #[tokio::test] pub async fn test_scan_planning_io( diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index e9432fbad27..54a71b685aa 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -270,6 +270,9 @@ pub enum WhenMatched { /// The row is updated (similar to UpdateAll) only for rows where the expression evaluates to /// true UpdateIf(String), + /// The row is updated (similar to UpdateAll) only for rows where the expression evaluates to + /// true + UpdateIfExpr(Expr), /// Fail the operation if a match is found /// /// This can be used to ensure that no existing rows are overwritten or modified after inserted. @@ -286,6 +289,10 @@ impl WhenMatched { // Store the expression string and defer parsing until we know which path to take Ok(Self::UpdateIf(expr.to_string())) } + + pub fn update_if_expr(expr: Expr) -> Self { + Self::UpdateIfExpr(expr) + } } /// Describes how rows should be handled when there is no matching row in the target table @@ -1635,6 +1642,7 @@ impl MergeInsertJob { self.params.when_matched, WhenMatched::UpdateAll | WhenMatched::UpdateIf(_) + | WhenMatched::UpdateIfExpr(_) | WhenMatched::Fail | WhenMatched::Delete | WhenMatched::DoNothing @@ -2082,22 +2090,27 @@ impl Merger { } else { None }; - let match_filter_expr = if let WhenMatched::UpdateIf(expr_str) = ¶ms.when_matched { - let combined_schema = Arc::new(combined_schema(&schema)); - let planner = Planner::new(combined_schema.clone()); - let expr = planner.parse_filter(expr_str)?; - let expr = planner.optimize_expr(expr)?; - let match_expr = planner.create_physical_expr(&expr)?; - let data_type = match_expr.data_type(combined_schema.as_ref())?; - if data_type != DataType::Boolean { - return Err(Error::invalid_input(format!( - "Merge insert conditions must be expressions that return a boolean value, received a 'when matched update if' expression ({}) which has data type {}", - expr, data_type - ))); + let match_filter_expr = match ¶ms.when_matched { + WhenMatched::UpdateIf(_) | WhenMatched::UpdateIfExpr(_) => { + let combined_schema = Arc::new(combined_schema(&schema)); + let planner = Planner::new(combined_schema.clone()); + let expr = match ¶ms.when_matched { + WhenMatched::UpdateIf(expr_str) => planner.parse_filter(expr_str)?, + WhenMatched::UpdateIfExpr(expr) => expr.clone(), + _ => unreachable!(), + }; + let expr = planner.optimize_expr(expr)?; + let match_expr = planner.create_physical_expr(&expr)?; + let data_type = match_expr.data_type(combined_schema.as_ref())?; + if data_type != DataType::Boolean { + return Err(Error::invalid_input(format!( + "Merge insert conditions must be expressions that return a boolean value, received a 'when matched update if' expression ({}) which has data type {}", + expr, data_type + ))); + } + Some(match_expr) } - Some(match_expr) - } else { - None + _ => None, }; let output_schema = if with_row_addr { Arc::new(schema.try_with_column(ROW_ADDR_FIELD.clone())?) @@ -8589,6 +8602,131 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n } } + // Regression test for GitHub issue #6877. + // + // Two sequential full-schema merge_insert UpdateAll calls against the same + // target row, on a dataset with stable_row_ids enabled and a BTREE scalar + // index on the join column, used to fail on the second call with + // "Ambiguous merge inserts are prohibited" — even though each call's + // source had exactly one row per key. + // + // Mechanism: with stable row ids the BTREE stores stable_row_ids (not + // physical addresses). After the first merge_insert, A's stable_row_id is + // preserved but its physical home moves to an unindexed fragment. The + // BTREE-side TakeExec resolves the stable_row_id to A's new location and + // emits a row; the unindexed-fragments scan also covers the new fragment + // and emits the same logical row. Both surface the same `_rowid`, so the + // merge_insert source-dedup HashSet sees a duplicate and aborts. + // + // Fix: thread `restrict_to_fragments` into `do_create_deletion_mask_row_id` + // so the allow-list only contains stable_row_ids whose current physical + // home is inside the index's fragment_bitmap. + #[tokio::test] + async fn test_issue_6877_repeated_merge_insert_stable_row_ids() { + use arrow_array::Int32Array; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ])); + + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A", "B", "C"])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + ) + .unwrap(); + + let mut ds = Dataset::write( + Box::new(RecordBatchIterator::new([Ok(initial)], schema.clone())), + "memory://test_6877", + Some(WriteParams { + mode: WriteMode::Overwrite, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + + ds.create_index( + &["id"], + IndexType::Scalar, + None, + &ScalarIndexParams::default(), + false, + ) + .await + .unwrap(); + + // First merge_insert: A 1 -> 11. + let update_a = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A"])), + Arc::new(Int32Array::from(vec![11])), + ], + ) + .unwrap(); + let (ds, _) = MergeInsertBuilder::try_new(Arc::new(ds), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap() + .execute_reader(Box::new(RecordBatchIterator::new( + [Ok(update_a)], + schema.clone(), + ))) + .await + .unwrap(); + + // Second merge_insert: A 11 -> 22. Used to fail before the fix. + let update_a_again = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A"])), + Arc::new(Int32Array::from(vec![22])), + ], + ) + .unwrap(); + let (ds, _) = MergeInsertBuilder::try_new(ds, vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap() + .execute_reader(Box::new(RecordBatchIterator::new( + [Ok(update_a_again)], + schema.clone(), + ))) + .await + .unwrap(); + + // Sanity check: A's value is now 22. + let batches = ds + .scan() + .filter("id = 'A'") + .unwrap() + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + let combined = concat_batches(&schema, &batches).unwrap(); + assert_eq!(combined.num_rows(), 1); + let values = combined + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), 22); + } + // Regression test: partial-schema merge_insert followed by update (deleting all rows // in a fragment) followed by partial merge_insert should not produce // "fragment id N does not exist" errors. @@ -8945,6 +9083,128 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n ); } + // Companion regression test for issue #6877 on the FTS path. + // + // The FTS prefilter shares `do_create_deletion_mask_row_id` with the + // scalar-index path, so the same stable-row-id bypass that produced + // duplicate rows in merge_insert can produce duplicate hits in FTS search + // after a merge_insert moves rows to unindexed fragments. This test pins + // the contract for the FTS consumer. + #[tokio::test] + async fn test_issue_6877_fts_no_duplicates_stable_row_ids() { + let rows_per_frag = 10usize; + let num_frags = 3usize; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("text", DataType::Utf8, false), + ])); + + let make_batch = |frag_idx: usize| { + let start = frag_idx * rows_per_frag; + let ids: Vec = (start..start + rows_per_frag) + .map(|j| format!("id-{j:04}")) + .collect(); + let texts: Vec = (start..start + rows_per_frag) + .map(|j| format!("common unique{j:04}")) + .collect(); + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(ids)), + Arc::new(StringArray::from(texts)), + ], + ) + .unwrap() + }; + + let batch0 = make_batch(0); + let reader = Box::new(RecordBatchIterator::new([Ok(batch0)], schema.clone())); + let mut ds = Dataset::write( + reader, + "memory://fts_stable_row_id_test", + Some(WriteParams { + mode: WriteMode::Overwrite, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + for frag_idx in 1..num_frags { + let batch = make_batch(frag_idx); + let reader = Box::new(RecordBatchIterator::new([Ok(batch)], schema.clone())); + ds.append(reader, None).await.unwrap(); + } + + let params = InvertedIndexParams::default(); + ds.create_index(&["text"], IndexType::Inverted, None, ¶ms, true) + .await + .unwrap(); + + // Full-schema merge_insert rewriting fragment 1's rows. After this, + // the original locations are tombstoned and the new locations live in + // a new (unindexed) fragment; the stable_row_ids are preserved. + let frag1_start = rows_per_frag; + let ids: Vec = (frag1_start..frag1_start + rows_per_frag) + .map(|j| format!("id-{j:04}")) + .collect(); + let texts: Vec = (frag1_start..frag1_start + rows_per_frag) + .map(|j| format!("common updated{j:04}")) + .collect(); + let update_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(ids)), + Arc::new(StringArray::from(texts)), + ], + ) + .unwrap(); + let reader = Box::new(RecordBatchIterator::new([Ok(update_batch)], schema.clone())); + let (ds, _) = MergeInsertBuilder::try_new(Arc::new(ds), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap() + .execute_reader(reader) + .await + .unwrap(); + + // FTS search for "common" — every row should match exactly once. + let query = FullTextSearchQuery::new("common".to_string()); + let results = ds + .scan() + .full_text_search(query) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let ids = results + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let unique_ids: std::collections::HashSet<&str> = + (0..ids.len()).map(|i| ids.value(i)).collect(); + assert_eq!( + unique_ids.len(), + ids.len(), + "Found duplicate ids in FTS results: {} unique out of {} total", + unique_ids.len(), + ids.len() + ); + assert_eq!( + unique_ids.len(), + rows_per_frag * num_frags, + "Expected {} rows but got {}", + rows_per_frag * num_frags, + unique_ids.len() + ); + } + // Regression test: after a partial-schema merge_insert invalidates a fragment, // compaction should succeed and subsequent searches should return correct results. // diff --git a/rust/lance/src/dataset/write/merge_insert/assign_action.rs b/rust/lance/src/dataset/write/merge_insert/assign_action.rs index 46bda38873f..5641d51203d 100644 --- a/rust/lance/src/dataset/write/merge_insert/assign_action.rs +++ b/rust/lance/src/dataset/write/merge_insert/assign_action.rs @@ -125,6 +125,12 @@ pub fn merge_insert_action( )); } } + WhenMatched::UpdateIfExpr(condition) => { + cases.push(( + matched.and(condition.clone()), + Action::UpdateAll.as_literal_expr(), + )); + } WhenMatched::DoNothing => {} WhenMatched::Fail => { cases.push((matched, Action::Fail.as_literal_expr())); diff --git a/rust/lance/src/dataset/write/merge_insert/exec/write.rs b/rust/lance/src/dataset/write/merge_insert/exec/write.rs index 47dd6f28bd8..8c57ebc9df3 100644 --- a/rust/lance/src/dataset/write/merge_insert/exec/write.rs +++ b/rust/lance/src/dataset/write/merge_insert/exec/write.rs @@ -739,6 +739,9 @@ impl DisplayAs for FullSchemaMergeInsertExec { crate::dataset::WhenMatched::UpdateIf(condition) => { format!("UpdateIf({})", condition) } + crate::dataset::WhenMatched::UpdateIfExpr(expr) => { + format!("UpdateIf({})", expr.human_display()) + } crate::dataset::WhenMatched::Fail => "Fail".to_string(), crate::dataset::WhenMatched::Delete => "Delete".to_string(), }; diff --git a/rust/lance/src/dataset/write/merge_insert/logical_plan.rs b/rust/lance/src/dataset/write/merge_insert/logical_plan.rs index f5c1770238c..7f67972ff7e 100644 --- a/rust/lance/src/dataset/write/merge_insert/logical_plan.rs +++ b/rust/lance/src/dataset/write/merge_insert/logical_plan.rs @@ -102,6 +102,7 @@ impl UserDefinedLogicalNodeCore for MergeInsertWriteNode { crate::dataset::WhenMatched::DoNothing => "DoNothing", crate::dataset::WhenMatched::UpdateAll => "UpdateAll", crate::dataset::WhenMatched::UpdateIf(_) => "UpdateIf", + crate::dataset::WhenMatched::UpdateIfExpr(_) => "UpdateIfExpr", crate::dataset::WhenMatched::Fail => "Fail", crate::dataset::WhenMatched::Delete => "Delete", }; diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index c65fe13a23e..01a7943f4f7 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -1049,12 +1049,12 @@ impl DatasetIndexExt for Dataset { Ok(merged_segment) } - async fn commit_existing_index_segments( - &mut self, + async fn build_existing_index_segments_transaction( + &self, index_name: &str, column: &str, segments: Vec, - ) -> Result<()> { + ) -> Result { let Some(field) = self.schema().field(column) else { return Err(Error::index(format!( "CreateIndex: column '{column}' does not exist" @@ -1094,8 +1094,7 @@ impl DatasetIndexExt for Dataset { .any(|idx| idx.fields != [field.id]) { return Err(Error::index(format!( - "Index name '{index_name}' already exists with different fields, \ - please specify a different name" + "Index name '{index_name}' already exists with different fields, please specify a different name" ))); } let removed_indices = existing_named_indices @@ -1138,14 +1137,25 @@ impl DatasetIndexExt for Dataset { .flatten() .collect::>(); - let transaction = Transaction::new( + Ok(Transaction::new( self.manifest.version, Operation::CreateIndex { new_indices, removed_indices, }, None, - ); + )) + } + + async fn commit_existing_index_segments( + &mut self, + index_name: &str, + column: &str, + segments: Vec, + ) -> Result<()> { + let transaction = self + .build_existing_index_segments_transaction(index_name, column, segments) + .await?; self.apply_commit(transaction, &Default::default(), &Default::default()) .await?; @@ -2417,7 +2427,7 @@ mod tests { use super::*; use crate::dataset::builder::DatasetBuilder; use crate::dataset::optimize::{CompactionOptions, compact_files}; - use crate::dataset::{WriteMode, WriteParams}; + use crate::dataset::{CommitBuilder, WriteMode, WriteParams}; use crate::index::vector::VectorIndexParams; use crate::session::Session; use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount, copy_test_data_to_tmp}; @@ -6369,6 +6379,89 @@ mod tests { ); } + #[tokio::test] + async fn test_build_existing_index_segments_transaction_does_not_commit() { + use lance_datagen::{BatchCount, RowCount, array}; + + let test_dir = tempfile::tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let reader = lance_datagen::gen_batch() + .col("id", array::step::()) + .col( + "vector", + array::rand_vec::(8.into()), + ) + .into_reader_rows(RowCount::from(20), BatchCount::from(2)); + + let dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + max_rows_per_file: 10, + max_rows_per_group: 10, + ..Default::default() + }), + ) + .await + .unwrap(); + + let field_id = dataset.schema().field("vector").unwrap().id; + let seg0 = write_vector_segment_metadata( + &dataset, + "vector_idx", + field_id, + Uuid::new_v4(), + [0_u32], + b"seg0", + ) + .await; + let seg1 = write_vector_segment_metadata( + &dataset, + "vector_idx", + field_id, + Uuid::new_v4(), + [1_u32], + b"seg1", + ) + .await; + + let transaction = dataset + .build_existing_index_segments_transaction( + "vector_idx", + "vector", + vec![segment_from_metadata(&seg0), segment_from_metadata(&seg1)], + ) + .await + .unwrap(); + + assert!( + dataset + .load_indices_by_name("vector_idx") + .await + .unwrap() + .is_empty(), + "building a transaction must not publish the index" + ); + assert_eq!(transaction.read_version, dataset.manifest.version); + let Operation::CreateIndex { + new_indices, + removed_indices, + } = &transaction.operation + else { + panic!("expected index creation transaction"); + }; + assert_eq!(new_indices.len(), 2); + assert!(removed_indices.is_empty()); + + let committed = CommitBuilder::new(Arc::new(dataset)) + .execute(transaction) + .await + .unwrap(); + let indices = committed.load_indices_by_name("vector_idx").await.unwrap(); + assert_eq!(indices.len(), 2); + } + #[tokio::test] async fn test_commit_existing_index_segments_rejects_duplicate_segment_ids() { use lance_datagen::{BatchCount, RowCount, array}; diff --git a/rust/lance/src/index/api.rs b/rust/lance/src/index/api.rs index 4ae652624c1..61cb24bfa8a 100644 --- a/rust/lance/src/index/api.rs +++ b/rust/lance/src/index/api.rs @@ -10,6 +10,7 @@ use lance_table::format::IndexMetadata; use roaring::RoaringBitmap; use uuid::Uuid; +use crate::dataset::transaction::Transaction; use crate::{Error, Result}; /// A single physical segment of a logical index. @@ -305,6 +306,19 @@ pub trait DatasetIndexExt { source_segments: Vec, ) -> Result; + /// Build a transaction that publishes existing physical index segments. + /// + /// This stages the same manifest update as [`Self::commit_existing_index_segments`] + /// without advancing the dataset version. Callers that need a strict + /// stage-then-commit workflow can pass the returned transaction to + /// [`crate::dataset::CommitBuilder`]. + async fn build_existing_index_segments_transaction( + &self, + index_name: &str, + column: &str, + segments: Vec, + ) -> Result; + /// Commit one or more existing physical index segments as a logical index. async fn commit_existing_index_segments( &mut self, diff --git a/rust/lance/src/index/append.rs b/rust/lance/src/index/append.rs index 16b17752ef4..4398928d3e2 100644 --- a/rust/lance/src/index/append.rs +++ b/rust/lance/src/index/append.rs @@ -515,6 +515,14 @@ pub async fn merge_indices_with_unindexed_frags<'a>( )) } it if it.is_scalar() => { + let num_to_merge = options + .num_indices_to_merge + .unwrap_or(1) + .min(old_indices.len()); + if unindexed.is_empty() && num_to_merge <= 1 { + return Ok(None); + } + // Use effective bitmap (intersected with existing dataset fragments) // to avoid carrying stale data from pruned indices. let effective_old_frags: RoaringBitmap = old_indices @@ -979,6 +987,7 @@ mod tests { .unwrap() .nearest("vector", array.value(0).as_primitive::(), 2) .unwrap() + .nprobes(2) .refine(1) .try_into_batch() .await @@ -1287,4 +1296,61 @@ mod tests { let dataset = DatasetBuilder::from_uri(test_uri).load().await.unwrap(); assert_eq!(query_id_count(&dataset, "song-42").await, 1); } + + #[tokio::test] + async fn test_optimize_scalar_no_unindexed_fragments() { + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); + + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); + let ids = StringArray::from_iter_values((0..32).map(|i| format!("song-{i}"))); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids)]).unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); + let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap(); + + dataset + .create_index( + &["id"], + IndexType::BTree, + Some("id_idx".into()), + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + let before = dataset.load_indices_by_name("id_idx").await.unwrap(); + assert_eq!(before.len(), 1); + let original_uuid = before[0].uuid; + let original_version = dataset.manifest.version; + + // `merge(1)` would historically rebuild the single existing segment + // (steady state, nothing unindexed) and replace its UUID; with the + // short-circuit it must skip work entirely. + dataset + .optimize_indices(&OptimizeOptions::merge(1)) + .await + .unwrap(); + + let after = dataset.load_indices_by_name("id_idx").await.unwrap(); + assert_eq!(after.len(), 1, "no new segment should be produced"); + assert_eq!( + after[0].uuid, original_uuid, + "no-op optimize must not churn the index UUID" + ); + assert_eq!( + dataset.manifest.version, original_version, + "no-op optimize must not advance the dataset version" + ); + + // The default options also short-circuit (num_to_merge defaults to 1 + // when there is a single old segment). + dataset + .optimize_indices(&OptimizeOptions::default()) + .await + .unwrap(); + let after_default = dataset.load_indices_by_name("id_idx").await.unwrap(); + assert_eq!(after_default[0].uuid, original_uuid); + assert_eq!(dataset.manifest.version, original_version); + } } diff --git a/rust/lance/src/index/prefilter.rs b/rust/lance/src/index/prefilter.rs index ebb4c600e65..071f1b8893d 100644 --- a/rust/lance/src/index/prefilter.rs +++ b/rust/lance/src/index/prefilter.rs @@ -127,13 +127,33 @@ impl DatasetPreFilter { } #[instrument(level = "debug", skip_all)] - async fn do_create_deletion_mask_row_id(dataset: Arc) -> Result> { - // This can only be computed as an allow list, since we have no idea - // what the row ids were in the missing fragments. + async fn do_create_deletion_mask_row_id( + dataset: Arc, + restrict_to: Option, + ) -> Result> { + // The mask is an allow-list of stable row ids. When `restrict_to` is + // set the iteration is limited to the listed fragments, so the + // resulting list excludes stable row ids whose *current* physical home + // is outside the restriction. This is the missing piece for the + // stable-row-id branch of #6563: without it, the merge_insert UNION + // (indexed scan ∪ unindexed-fragments scan) sees the same logical row + // twice — once via the BTREE (which holds the row's stable_row_id) and + // once via the unindexed scan (which holds the fragment the row now + // lives in). See issue #6877. async fn load_row_ids_and_deletions( dataset: &Dataset, + restrict_to: Option<&RoaringBitmap>, ) -> Result, Option>)>> { - stream::iter(dataset.get_fragments()) + let frags: Vec<_> = dataset + .get_fragments() + .into_iter() + .filter(|f| { + restrict_to + .map(|allow| allow.contains(f.id() as u32)) + .unwrap_or(true) + }) + .collect(); + stream::iter(frags) .map(|frag| async move { let row_ids = load_row_id_sequence(dataset, frag.metadata()); let deletion_vector = frag.get_deletion_vector(); @@ -145,16 +165,31 @@ impl DatasetPreFilter { .await } + let restrict_hash = restrict_to.as_ref().map(|b| { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut h = DefaultHasher::new(); + // RoaringBitmap doesn't implement Hash; serialize the sorted u32s. + for v in b.iter() { + v.hash(&mut h); + } + h.finish() + }); + let dataset_clone = dataset.clone(); + let restrict_for_load = restrict_to.clone(); let key = crate::session::caches::RowAddrMaskKey { version: dataset.manifest().version, + restrict_hash, }; dataset .metadata_cache .as_ref() .get_or_insert_with_key(key, move || { async move { - let row_ids_and_deletions = load_row_ids_and_deletions(&dataset_clone).await?; + let row_ids_and_deletions = + load_row_ids_and_deletions(&dataset_clone, restrict_for_load.as_ref()) + .await?; // The process of computing the final mask is CPU-bound, so we spawn it // on a blocking thread. @@ -268,7 +303,12 @@ impl DatasetPreFilter { if missing_frags.is_empty() && frags_with_deletion_files.is_empty() && !needs_allow_list { None } else if dataset.manifest.uses_stable_row_ids() { - Some(Self::do_create_deletion_mask_row_id(dataset.clone()).boxed()) + let restrict_to = if restrict_to_fragments { + Some(fragments) + } else { + None + }; + Some(Self::do_create_deletion_mask_row_id(dataset.clone(), restrict_to).boxed()) } else if missing_frags.is_empty() && frags_with_deletion_files.is_empty() { // No deletions to load, but the dataset has fragments outside the // index bitmap. Return a synchronous allow-list mask. @@ -529,4 +569,53 @@ mod test { let mask = mask.unwrap().await.unwrap(); assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(3)); // There were three rows left over; } + + // Regression test for issue #6877. + // + // `create_restricted_deletion_mask` on a stable-row-id dataset must honor + // the bitmap restriction by excluding stable row ids whose *current* + // physical home is outside the bitmap. Without this, the merge_insert + // UNION (indexed-scan ∪ unindexed-fragments scan) emits the same logical + // row twice — once via the BTREE (which holds the row's stable_row_id) + // and once via the unindexed scan. + #[tokio::test] + async fn test_restricted_deletion_mask_stable_row_id_honors_bitmap() { + // Dataset with three fragments, 3 rows each, stable_row_ids = {0..9}. + // Row x=8 is deleted, so the live stable_row_ids are {0..7}. + let datasets = test_datasets(true).await; + let ds = datasets.deletions_no_missing_frags.clone(); + + // Full bitmap: allow-list covers all currently-live stable row ids. + let mask = DatasetPreFilter::create_restricted_deletion_mask( + ds.clone(), + RoaringBitmap::from_iter(0..3), + ) + .expect("full-bitmap mask present on stable-row-id dataset with deletions") + .await + .unwrap(); + let expected_all = RowAddrTreeMap::from_iter(0..8); + assert_eq!(mask.allow_list(), Some(&expected_all)); + + // Restricted to fragments {0, 1}: allow-list must exclude stable row + // ids whose current home is in fragment 2 (rows 6, 7 — 8 was deleted). + let mask = DatasetPreFilter::create_restricted_deletion_mask( + ds.clone(), + RoaringBitmap::from_iter(0..2), + ) + .expect("restricted mask present") + .await + .unwrap(); + let expected_restricted = RowAddrTreeMap::from_iter(0..6); + assert_eq!(mask.allow_list(), Some(&expected_restricted)); + + // Restricted to empty bitmap: every BTREE-returned address is filtered + // out. Empty allow-list is the correct semantic ("no row's current home + // is in the restriction"). + let mask = + DatasetPreFilter::create_restricted_deletion_mask(ds.clone(), RoaringBitmap::new()) + .expect("empty-restriction mask present") + .await + .unwrap(); + assert_eq!(mask.allow_list().and_then(|x| x.len()), Some(0)); + } } diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 668daf8ff58..3046d0f3a83 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -10,7 +10,7 @@ use arrow::datatypes::DataType; use arrow_array::new_empty_array; use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, cast::AsArray}; use arrow_buffer::{Buffer, MutableBuffer}; -use futures::{Stream, StreamExt, stream}; +use futures::{Stream, StreamExt, TryStreamExt, stream}; use lance_arrow::DataTypeExt; use lance_core::datatypes::Schema; use lance_linalg::distance::DistanceType; @@ -261,6 +261,44 @@ fn infer_vector_element_type_impl( } } +async fn count_rows(dataset: &Dataset, fragment_ids: Option<&[u32]>) -> Result { + match fragment_ids { + None => dataset.count_rows(None).await, + Some(fragment_ids) => { + let sorted_ids: Vec; + let sorted_fragment_ids = if fragment_ids.windows(2).all(|w| w[0] <= w[1]) { + fragment_ids + } else { + sorted_ids = { + let mut v = fragment_ids.to_vec(); + v.sort_unstable(); + v + }; + &sorted_ids + }; + let fragments = dataset.get_frags_from_ordered_ids(sorted_fragment_ids); + let valid_fragments = fragments + .into_iter() + .enumerate() + .map(|(i, frag)| { + frag.ok_or_else(|| { + Error::index(format!( + "Unexpectedly missing fragment {}", + sorted_fragment_ids[i] + )) + }) + }) + .collect::>>()?; + let cnts = stream::iter(valid_fragments) + .map(|f| async move { f.count_rows(None).await }) + .buffer_unordered(16) + .try_collect::>() + .await?; + Ok(cnts.iter().sum::()) + } + } +} + /// Maybe sample training data from dataset, specified by column name. /// /// Returns a [FixedSizeListArray], containing the training dataset. @@ -271,13 +309,7 @@ pub async fn maybe_sample_training_data( sample_size_hint: usize, fragment_ids: Option<&[u32]>, ) -> Result { - let num_rows = if let Some(fragment_ids) = fragment_ids { - let mut scanner = dataset.scan(); - scanner.with_fragments(resolve_scan_fragments(dataset, fragment_ids)?); - scanner.count_rows().await? as usize - } else { - dataset.count_rows(None).await? - }; + let num_rows = count_rows(dataset, fragment_ids).await?; let vector_field = dataset.schema().field(column).ok_or(Error::index(format!( "Sample training data: column {} does not exist in schema", @@ -1224,4 +1256,75 @@ mod tests { .unwrap(); assert_eq!(n, 1030); } + + // Creates a dataset with three fragments holding 100, 200, and 150 rows. + async fn make_three_fragment_dataset() -> Dataset { + use arrow_array::{RecordBatch, RecordBatchIterator}; + use arrow_schema::Schema as ArrowSchema; + + let schema = Arc::new(ArrowSchema::new(vec![arrow_schema::Field::new( + "x", + arrow_schema::DataType::Float32, + false, + )])); + + let make_batch = |n: usize| -> RecordBatch { + let arr: ArrayRef = Arc::new(Float32Array::from_iter_values((0..n).map(|i| i as f32))); + RecordBatch::try_new(schema.clone(), vec![arr]).unwrap() + }; + + let mut dataset = InsertBuilder::new("memory://test_count_rows_util") + .execute(vec![make_batch(100)]) + .await + .unwrap(); + dataset + .append( + RecordBatchIterator::new(vec![Ok(make_batch(200))], schema.clone()), + None, + ) + .await + .unwrap(); + dataset + .append( + RecordBatchIterator::new(vec![Ok(make_batch(150))], schema.clone()), + None, + ) + .await + .unwrap(); + + dataset + } + + #[tokio::test] + async fn test_count_rows_none() { + let dataset = make_three_fragment_dataset().await; + assert_eq!(dataset.get_fragments().len(), 3); + assert_eq!(count_rows(&dataset, None).await.unwrap(), 450); + } + + #[tokio::test] + async fn test_count_rows_sorted_fragment_ids() { + let dataset = make_three_fragment_dataset().await; + let ids: Vec = dataset + .get_fragments() + .iter() + .map(|f| f.id() as u32) + .collect(); + // Skip the middle fragment (200 rows); expect 100 + 150 = 250. + let result = count_rows(&dataset, Some(&[ids[0], ids[2]])).await.unwrap(); + assert_eq!(result, 250); + } + + #[tokio::test] + async fn test_count_rows_unsorted_fragment_ids() { + let dataset = make_three_fragment_dataset().await; + let ids: Vec = dataset + .get_fragments() + .iter() + .map(|f| f.id() as u32) + .collect(); + // Pass the same two fragments in reverse (unsorted) order; result must match. + let result = count_rows(&dataset, Some(&[ids[2], ids[0]])).await.unwrap(); + assert_eq!(result, 250); + } } diff --git a/rust/lance/src/io/exec/filtered_read.rs b/rust/lance/src/io/exec/filtered_read.rs index 5a06408090b..0b83c000ef0 100644 --- a/rust/lance/src/io/exec/filtered_read.rs +++ b/rust/lance/src/io/exec/filtered_read.rs @@ -515,18 +515,31 @@ impl FilteredReadStream { evaluated_index: &Option>, options: &FilteredReadOptions, ) -> FilteredReadInternalPlan { - // For pushing down scan_range_after_filter + // For pushing down scan_range_after_filter. + // + // This is only valid when there is no refine filter left to evaluate. An exact scalar + // index result is exact for the indexed predicate, but not for the full predicate if a + // refine filter can still reject rows. + let can_push_down_scan_range_after_filter = options.refine_filter.is_none(); let mut scan_planned_with_limit_pushed_down = false; - let mut to_skip = options - .scan_range_after_filter - .as_ref() - .map(|r| r.start) - .unwrap_or(0); - let mut to_take = options - .scan_range_after_filter - .as_ref() - .map(|r| r.end - r.start) - .unwrap_or(u64::MAX); + let mut to_skip = if can_push_down_scan_range_after_filter { + options + .scan_range_after_filter + .as_ref() + .map(|r| r.start) + .unwrap_or(0) + } else { + 0 + }; + let mut to_take = if can_push_down_scan_range_after_filter { + options + .scan_range_after_filter + .as_ref() + .map(|r| r.end - r.start) + .unwrap_or(u64::MAX) + } else { + u64::MAX + }; // Full fragment ranges to read before applying scan_range_after_filter let mut fragments_to_read: BTreeMap>> = BTreeMap::new(); @@ -580,9 +593,10 @@ impl FilteredReadStream { &mut to_take, &mut fragments_to_read, &mut scan_push_down_fragments_to_read, + options.only_indexed_fragments, ); - if to_take == 0 { + if can_push_down_scan_range_after_filter && to_take == 0 { scan_planned_with_limit_pushed_down = true; fragments_to_read = scan_push_down_fragments_to_read; break; @@ -706,6 +720,7 @@ impl FilteredReadStream { to_take: &mut u64, fragments_to_read: &mut BTreeMap>>, scan_push_down_fragments_to_read: &mut BTreeMap>>, + only_indexed_fragments: bool, ) { let fragment_id = fragment.id() as u32; @@ -738,10 +753,13 @@ impl FilteredReadStream { } } } else { - // Fragment not indexed - add full fragment to unindexed_ranges - fragments_to_read.insert(fragment_id, to_read); + // Fragment not indexed. Normally we add the full fragment to keep + // results complete. Fast search intentionally accepts staleness. + if !only_indexed_fragments { + fragments_to_read.insert(fragment_id, to_read); + } } - } else { + } else if !only_indexed_fragments { // No index at all - add full fragment to unindexed_ranges fragments_to_read.insert(fragment_id, to_read); } @@ -1279,6 +1297,8 @@ pub struct FilteredReadOptions { pub threading_mode: FilteredReadThreadingMode, /// The size of the I/O buffer to use for the scan pub io_buffer_size_bytes: Option, + /// If true, skip fragments that are not covered by the scalar index result. + pub only_indexed_fragments: bool, } impl FilteredReadOptions { @@ -1307,6 +1327,7 @@ impl FilteredReadOptions { refine_filter: None, full_filter: None, io_buffer_size_bytes: None, + only_indexed_fragments: false, threading_mode: FilteredReadThreadingMode::OnePartitionMultipleThreads( get_num_compute_intensive_cpus(), ), @@ -1459,6 +1480,12 @@ impl FilteredReadOptions { self.io_buffer_size_bytes = Some(io_buffer_size); self } + + /// Only read fragments covered by a scalar index result. + pub fn with_only_indexed_fragments(mut self) -> Self { + self.only_indexed_fragments = true; + self + } } /// A plan node that reads a dataset, applying an optional filter and projection. @@ -2963,6 +2990,34 @@ mod tests { } } + #[tokio::test] + async fn test_with_fetch_limit_after_scalar_index_refine_filter() { + let fixture = Arc::new(TestFixture::new().await); + let base_options = FilteredReadOptions::basic_full_read(&fixture.dataset); + let filter_plan = fixture + .filter_plan("fully_indexed < 50 AND not_indexed >= 10", true) + .await; + let options = base_options.with_filter_plan(filter_plan); + let plan = fixture.make_plan(options).await; + + assert!(plan.index_input.is_some()); + assert!(plan.options().refine_filter.is_some()); + + let limited_plan = plan.with_fetch(Some(10)).unwrap(); + let limited_plan = limited_plan + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(limited_plan.options().scan_range_after_filter, Some(0..10)); + + let stream = limited_plan + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let batches = stream.try_collect::>().await.unwrap(); + let actual_values = get_fully_indexed_values(batches).await; + assert_eq!(actual_values, (10..20).collect::>()); + } + #[tokio::test] async fn test_limit_pushdown_comprehensive() { let fixture = Arc::new(TestFixture::new().await); diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index 093ca6cb585..c55a357fa0f 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -24,7 +24,7 @@ use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::metrics::{BaselineMetrics, Count}; use futures::future::try_join_all; use futures::stream::{self}; -use futures::{StreamExt, TryStreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_core::{ Error, ROW_ID, Result, @@ -496,7 +496,11 @@ impl ExecutionPlan for MatchQueryExec { let tokens = collect_query_tokens(&query.terms, &mut tokenizer); let base_scorer = match preset_base_scorer { Some(scorer) => scorer, - None => Arc::new(build_global_bm25_scorer(&indices, &tokens, ¶ms)?), + None => Arc::new( + build_global_bm25_scorer(&indices, &tokens, ¶ms) + .boxed() + .await?, + ), }; pre_filter.wait_for_ready().await?; @@ -1051,7 +1055,9 @@ impl ExecutionPlan for FlatMatchQueryExec { &indices, &query_tokens, &FtsSearchParams::new(), - )? + ) + .boxed() + .await? } }; (tokenizer, Some(base_scorer)) @@ -1363,7 +1369,11 @@ impl ExecutionPlan for PhraseQueryExec { let tokens = collect_query_tokens(&query.terms, &mut tokenizer); let base_scorer = match preset_base_scorer { Some(scorer) => scorer, - None => Arc::new(build_global_bm25_scorer(&indices, &tokens, ¶ms)?), + None => Arc::new( + build_global_bm25_scorer(&indices, &tokens, ¶ms) + .boxed() + .await?, + ), }; pre_filter.wait_for_ready().await?; @@ -2422,8 +2432,11 @@ mod tests { ); let mut tokenizer = indices[0].tokenizer(); let tokens = collect_query_tokens(&query.terms, &mut tokenizer); - let global_scorer = - Arc::new(build_global_bm25_scorer(&indices, &tokens, &search_params).unwrap()); + let global_scorer = Arc::new( + build_global_bm25_scorer(&indices, &tokens, &search_params) + .await + .unwrap(), + ); let override_exec = MatchQueryExec::new_with_segments( dataset.clone(), diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 49ee7be86bc..71239b4e34b 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -2,12 +2,13 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::any::Any; -use std::collections::{HashMap, HashSet}; +use std::cmp::Ordering as CmpOrdering; +use std::collections::{BinaryHeap, HashMap, HashSet}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock, Mutex}; use std::time::Instant; -use arrow::array::Float32Builder; +use arrow::array::{Float32Builder, Int32Builder}; use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{Array, Float32Array, UInt32Array, UInt64Array}; use arrow_array::{ @@ -16,6 +17,7 @@ use arrow_array::{ cast::AsArray, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_select::concat::concat_batches; use datafusion::physical_plan::PlanProperties; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ @@ -65,10 +67,16 @@ use crate::{Error, Result}; use lance_arrow::*; use super::utils::{ - FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedChildInputStream, PreFilterSource, - SelectionVectorToPrefilter, + FilteredRowIdsToPrefilter, IndexMetrics, InstrumentedChildInputStream, + InstrumentedRecordBatchStreamAdapter, PreFilterSource, SelectionVectorToPrefilter, }; +pub const QUERY_INDEX_COL: &str = "query_index"; + +pub fn query_index_field() -> Field { + Field::new(QUERY_INDEX_COL, DataType::Int32, false) +} + pub struct AnnPartitionMetrics { index_metrics: IndexMetrics, partitions_ranked: Count, @@ -141,23 +149,66 @@ pub struct KNNVectorDistanceExec { /// The vector query to execute. pub query: ArrayRef, + pub is_batch: bool, + pub query_count: usize, + pub k: usize, + pub lower_bound: Option, + pub upper_bound: Option, pub column: String, pub distance_type: DistanceType, + input_schema: SchemaRef, output_schema: SchemaRef, properties: Arc, metrics: ExecutionPlanMetricsSet, } +pub struct KnnBatchParams { + pub is_batch: bool, + pub query_count: usize, + pub k: usize, + pub lower_bound: Option, + pub upper_bound: Option, + pub distance_type: DistanceType, +} + +struct BatchKnnConfig { + input_schema: SchemaRef, + output_schema: SchemaRef, + column: String, + query: ArrayRef, + query_count: usize, + k: usize, + lower_bound: Option, + upper_bound: Option, + distance_type: DistanceType, +} + impl DisplayAs for KNNVectorDistanceExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "KNNVectorDistance: metric={}", self.distance_type,) + if self.is_batch { + write!( + f, + "KNNVectorDistance: queries={}, k={}, metric={}", + self.query_count, self.k, self.distance_type, + ) + } else { + write!(f, "KNNVectorDistance: metric={}", self.distance_type,) + } } DisplayFormatType::TreeRender => { - write!(f, "KNNVectorDistance\nmetric={}", self.distance_type,) + if self.is_batch { + write!( + f, + "KNNVectorDistance\nqueries={}\nk={}\nmetric={}", + self.query_count, self.k, self.distance_type, + ) + } else { + write!(f, "KNNVectorDistance\nmetric={}", self.distance_type,) + } } } } @@ -173,16 +224,76 @@ impl KNNVectorDistanceExec { query: ArrayRef, distance_type: DistanceType, ) -> Result { - let mut output_schema = input.schema().as_ref().clone(); - let (_, element_type) = get_vector_type(&(&output_schema).try_into()?, column)?; + Self::try_new_batch( + input, + column, + query, + KnnBatchParams { + is_batch: false, + query_count: 1, + k: 0, + lower_bound: None, + upper_bound: None, + distance_type, + }, + ) + } + + pub(crate) fn try_new_batch( + input: Arc, + column: &str, + query: ArrayRef, + params: KnnBatchParams, + ) -> Result { + let KnnBatchParams { + is_batch, + query_count, + k, + lower_bound, + upper_bound, + distance_type, + } = params; + if query_count == 0 { + return Err(Error::invalid_input( + "query_count must be positive for KNN".to_string(), + )); + } + if !query.len().is_multiple_of(query_count) { + return Err(Error::invalid_input(format!( + "query length ({}) must be divisible by query_count ({})", + query.len(), + query_count + ))); + } + if is_batch && k == 0 { + return Err(Error::invalid_input( + "k must be positive for batch KNN".to_string(), + )); + } + + let mut input_schema = input.schema().as_ref().clone(); + let (_, element_type) = get_vector_type(&(&input_schema).try_into()?, column)?; validate_distance_type_for(distance_type, &element_type)?; // FlatExec appends a distance column to the input schema. The input // may already have a distance column (possibly in the wrong position), so // we need to remove it before adding a new one. - if output_schema.column_with_name(DIST_COL).is_some() { - output_schema = output_schema.without_column(DIST_COL); + if input_schema.column_with_name(DIST_COL).is_some() { + input_schema = input_schema.without_column(DIST_COL); + } + if is_batch && input_schema.column_with_name(QUERY_INDEX_COL).is_some() { + return Err(Error::invalid_input(format!( + "batch KNN cannot run when the input already contains reserved column '{QUERY_INDEX_COL}'" + ))); } + let input_schema = Arc::new(input_schema); + let output_schema = if is_batch { + input_schema + .as_ref() + .try_with_column_at(0, query_index_field())? + } else { + input_schema.as_ref().clone() + }; let output_schema = Arc::new(output_schema.try_with_column(Field::new( DIST_COL, DataType::Float32, @@ -191,24 +302,163 @@ impl KNNVectorDistanceExec { // This node has the same partitioning & boundedness as the input node // but it destroys any ordering. - let properties = Arc::new( - input - .properties() - .as_ref() - .clone() - .with_eq_properties(EquivalenceProperties::new(output_schema.clone())), - ); + let properties = if is_batch { + Arc::new(PlanProperties::new( + EquivalenceProperties::new(output_schema.clone()), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + )) + } else { + Arc::new( + input + .properties() + .as_ref() + .clone() + .with_eq_properties(EquivalenceProperties::new(output_schema.clone())), + ) + }; Ok(Self { input, query, + is_batch, + query_count, + k, + lower_bound, + upper_bound, column: column.to_string(), distance_type, + input_schema, output_schema, properties, metrics: ExecutionPlanMetricsSet::new(), }) } + + async fn execute_batch( + input: SendableRecordBatchStream, + config: BatchKnnConfig, + ) -> DataFusionResult { + let BatchKnnConfig { + input_schema, + output_schema, + column, + query, + query_count, + k, + lower_bound, + upper_bound, + distance_type, + } = config; + let query_dim = query.len() / query_count; + let mut heaps = (0..query_count) + .map(|_| BinaryHeap::::with_capacity(k)) + .collect::>(); + let mut input = input; + + while let Some(batch) = input.next().await { + let batch = batch?; + if batch.num_rows() == 0 { + continue; + } + + let row_ids = batch + .column_by_name(ROW_ID) + .ok_or_else(|| { + DataFusionError::Internal( + "KNNVectorDistanceExec batch mode requires _rowid in input".to_string(), + ) + })? + .as_primitive::() + .clone(); + + for (query_index, heap) in heaps.iter_mut().enumerate().take(query_count) { + let key = query.slice(query_index * query_dim, query_dim); + let with_distances = compute_distance(key, distance_type, &column, batch.clone()) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let distances = with_distances[DIST_COL].as_primitive::(); + let distance_values = distances.values(); + for row_index in 0..distances.len() { + if !distances.is_valid(row_index) { + continue; + } + let distance = distance_values[row_index]; + if distance.is_nan() { + continue; + } + // Single-query flat KNN applies distance_range as a plan filter. + // Batch mode filters before insertion so top-k stays per query. + if lower_bound.is_some_and(|lower_bound| distance < lower_bound) + || upper_bound.is_some_and(|upper_bound| distance >= upper_bound) + { + continue; + } + let query_index = query_index as i32; + let row_id = row_ids.value(row_index); + let row_index = row_index as u32; + let candidate = BatchKnnCandidate { + query_index, + distance, + row_id, + batch: batch.clone(), + row_index, + }; + if heap.len() < k { + heap.push(candidate); + } else if heap + .peek() + .is_some_and(|worst| candidate.cmp(worst).is_lt()) + { + heap.pop(); + heap.push(candidate); + } + } + } + } + + let mut results = heaps + .into_iter() + .flat_map(BinaryHeap::into_vec) + .collect::>(); + results.sort_by(|left, right| { + left.query_index + .cmp(&right.query_index) + .then_with(|| left.distance.total_cmp(&right.distance)) + .then_with(|| left.row_id.cmp(&right.row_id)) + }); + + if results.is_empty() { + return Ok(RecordBatch::new_empty(output_schema)); + } + + let mut query_indices = Int32Builder::with_capacity(results.len()); + let mut distances = Float32Builder::with_capacity(results.len()); + let mut row_batches = Vec::with_capacity(results.len()); + for result in results { + query_indices.append_value(result.query_index); + distances.append_value(result.distance); + let indices = UInt32Array::from(vec![result.row_index]); + row_batches.push( + arrow_select::take::take_record_batch(&result.batch, &indices).map_err(|e| { + DataFusionError::ArrowError(Box::new(e), Some("take top-k row".to_string())) + })?, + ); + } + + let output = concat_batches(&input_schema, &row_batches) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + output + .try_with_column_at(0, query_index_field(), Arc::new(query_indices.finish())) + .and_then(|batch| { + batch.try_with_column( + Field::new(DIST_COL, DataType::Float32, true), + Arc::new(distances.finish()), + ) + }) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + } } impl ExecutionPlan for KNNVectorDistanceExec { @@ -239,11 +489,18 @@ impl ExecutionPlan for KNNVectorDistanceExec { )); } - Ok(Arc::new(Self::try_new( + Ok(Arc::new(Self::try_new_batch( children.pop().expect("length checked"), &self.column, self.query.clone(), - self.distance_type, + KnnBatchParams { + is_batch: self.is_batch, + query_count: self.query_count, + k: self.k, + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + distance_type: self.distance_type, + }, )?)) } @@ -253,7 +510,30 @@ impl ExecutionPlan for KNNVectorDistanceExec { context: Arc, ) -> DataFusionResult { let input_stream = self.input.execute(partition, context)?; - let input_schema = input_stream.schema(); + if self.is_batch { + let stream = stream::once(Self::execute_batch( + input_stream, + BatchKnnConfig { + input_schema: self.input_schema.clone(), + output_schema: self.output_schema.clone(), + column: self.column.clone(), + query: self.query.clone(), + query_count: self.query_count, + k: self.k, + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + distance_type: self.distance_type, + }, + )); + let schema = self.schema(); + return Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( + schema, + stream.boxed(), + partition, + &self.metrics, + )) as SendableRecordBatchStream); + } + let input_schema = self.input.schema(); let key = self.query.clone(); let column = self.column.clone(); let dt = self.distance_type; @@ -284,18 +564,17 @@ impl ExecutionPlan for KNNVectorDistanceExec { // Time around the .await to capture the spawn_blocking // distance work, which otherwise runs while this future is // Pending and is missed by the helper's own poll timer. - let start = std::time::Instant::now(); + let start = Instant::now(); let batch = compute_distance(key, dt, &column, batch) .await .map_err(|e| DataFusionError::External(Box::new(e)))?; elapsed_compute.add_duration(start.elapsed()); let distances = batch[DIST_COL].as_primitive::(); - let mask = BooleanArray::from_iter( - distances - .iter() - .map(|v| Some(v.map(|v| !v.is_nan()).unwrap_or(false))), - ); + let distance_values = distances.values(); + let mask = BooleanArray::from_iter((0..distances.len()).map(|row_index| { + Some(distances.is_valid(row_index) && !distance_values[row_index].is_nan()) + })); arrow::compute::filter_record_batch(&batch, &mask) .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) } @@ -326,8 +605,18 @@ impl ExecutionPlan for KNNVectorDistanceExec { .zip(schema.fields()) .filter(|(_, field)| field.name() != DIST_COL) .map(|(stats, _)| stats) - .chain(std::iter::once(dist_stats)) .collect::>(); + let column_statistics = if self.is_batch { + std::iter::once(ColumnStatistics::default()) + .chain(column_statistics) + .chain(std::iter::once(dist_stats)) + .collect::>() + } else { + column_statistics + .into_iter() + .chain(std::iter::once(dist_stats)) + .collect::>() + }; Ok(Statistics { num_rows: inner_stats.num_rows, column_statistics, @@ -346,14 +635,65 @@ impl ExecutionPlan for KNNVectorDistanceExec { fn supports_limit_pushdown(&self) -> bool { false } + + fn required_input_distribution(&self) -> Vec { + if self.is_batch { + vec![Distribution::SinglePartition] + } else { + vec![Distribution::UnspecifiedDistribution] + } + } } -pub static KNN_INDEX_SCHEMA: LazyLock = LazyLock::new(|| { - Arc::new(Schema::new(vec![ +#[derive(Clone)] +struct BatchKnnCandidate { + query_index: i32, + distance: f32, + row_id: u64, + batch: RecordBatch, + row_index: u32, +} + +impl PartialEq for BatchKnnCandidate { + fn eq(&self, other: &Self) -> bool { + self.query_index == other.query_index + && self.distance == other.distance + && self.row_id == other.row_id + && self.row_index == other.row_index + } +} + +impl Eq for BatchKnnCandidate {} + +impl PartialOrd for BatchKnnCandidate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for BatchKnnCandidate { + fn cmp(&self, other: &Self) -> CmpOrdering { + self.distance + .total_cmp(&other.distance) + .then_with(|| self.row_id.cmp(&other.row_id)) + .then_with(|| self.query_index.cmp(&other.query_index)) + .then_with(|| self.row_index.cmp(&other.row_index)) + } +} + +pub static KNN_INDEX_SCHEMA: LazyLock = LazyLock::new(|| knn_empty_result_schema(false)); + +/// Schema for empty vector-search results (e.g. `fast_search` with no index). +pub fn knn_empty_result_schema(include_query_index: bool) -> SchemaRef { + let mut fields = vec![ Field::new(DIST_COL, DataType::Float32, true), ROW_ID_FIELD.clone(), - ])) -}); + ]; + if include_query_index { + fields.insert(0, query_index_field()); + } + Arc::new(Schema::new(fields)) +} pub static KNN_PARTITION_SCHEMA: LazyLock = LazyLock::new(|| { Arc::new(Schema::new(vec![ diff --git a/rust/lance/src/session/caches.rs b/rust/lance/src/session/caches.rs index eab758418f7..82dc755f6c0 100644 --- a/rust/lance/src/session/caches.rs +++ b/rust/lance/src/session/caches.rs @@ -122,12 +122,19 @@ impl CacheKey for DeletionFileKey<'_> { #[derive(Debug)] pub struct RowAddrMaskKey { pub version: u64, + /// `Some(hash)` when the mask is restricted to a fragment subset; `None` + /// when it covers all fragments in the dataset. Two consumers that ask + /// for different subsets must not poison each other's cache entry. + pub restrict_hash: Option, } impl CacheKey for RowAddrMaskKey { type ValueType = RowAddrMask; fn key(&self) -> Cow<'_, str> { - Cow::Owned(format!("row_addr_mask/{}", self.version)) + match self.restrict_hash { + None => Cow::Owned(format!("row_addr_mask/{}", self.version)), + Some(h) => Cow::Owned(format!("row_addr_mask/{}/{:x}", self.version, h)), + } } fn type_name() -> &'static str { "RowAddrMask"