mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 13:06:27 +00:00 
			
		
		
		
	Small commit to add hybrid search and autoembedding
This commit is contained in:
		
							
								
								
									
										281
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										281
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -46,7 +46,7 @@ dependencies = [ | ||||
|  "actix-tls", | ||||
|  "actix-utils", | ||||
|  "ahash 0.8.3", | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
|  "bitflags 1.3.2", | ||||
|  "brotli", | ||||
|  "bytes", | ||||
| @@ -120,7 +120,7 @@ dependencies = [ | ||||
|  "futures-util", | ||||
|  "mio", | ||||
|  "num_cpus", | ||||
|  "socket2", | ||||
|  "socket2 0.4.9", | ||||
|  "tokio", | ||||
|  "tracing", | ||||
| ] | ||||
| @@ -201,7 +201,7 @@ dependencies = [ | ||||
|  "serde_json", | ||||
|  "serde_urlencoded", | ||||
|  "smallvec", | ||||
|  "socket2", | ||||
|  "socket2 0.4.9", | ||||
|  "time", | ||||
|  "url", | ||||
| ] | ||||
| @@ -365,6 +365,12 @@ dependencies = [ | ||||
|  "backtrace", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "anymap2" | ||||
| version = "0.13.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" | ||||
|  | ||||
| [[package]] | ||||
| name = "arbitrary" | ||||
| version = "1.3.0" | ||||
| @@ -455,9 +461,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" | ||||
|  | ||||
| [[package]] | ||||
| name = "base64" | ||||
| version = "0.21.2" | ||||
| version = "0.21.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" | ||||
| checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" | ||||
|  | ||||
| [[package]] | ||||
| name = "base64ct" | ||||
| @@ -508,6 +514,21 @@ dependencies = [ | ||||
|  "serde", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "bit-set" | ||||
| version = "0.5.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" | ||||
| dependencies = [ | ||||
|  "bit-vec", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "bit-vec" | ||||
| version = "0.6.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" | ||||
|  | ||||
| [[package]] | ||||
| name = "bitflags" | ||||
| version = "1.3.2" | ||||
| @@ -555,12 +576,12 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "bstr" | ||||
| version = "1.6.0" | ||||
| version = "1.8.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "6798148dccfbff0fae41c7574d2fa8f1ef3492fba0face179de5d8d447d67b05" | ||||
| checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" | ||||
| dependencies = [ | ||||
|  "memchr", | ||||
|  "regex-automata 0.3.6", | ||||
|  "regex-automata 0.4.3", | ||||
|  "serde", | ||||
| ] | ||||
|  | ||||
| @@ -1346,6 +1367,12 @@ dependencies = [ | ||||
|  "syn 2.0.28", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "doc-comment" | ||||
| version = "0.3.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" | ||||
|  | ||||
| [[package]] | ||||
| name = "doxygen-rs" | ||||
| version = "0.2.2" | ||||
| @@ -1562,6 +1589,16 @@ dependencies = [ | ||||
|  "cc", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "fancy-regex" | ||||
| version = "0.11.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" | ||||
| dependencies = [ | ||||
|  "bit-set", | ||||
|  "regex", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "fastrand" | ||||
| version = "2.0.0" | ||||
| @@ -1690,9 +1727,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" | ||||
| checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" | ||||
| dependencies = [ | ||||
|  "futures-channel", | ||||
|  "futures-core", | ||||
| @@ -1705,9 +1742,9 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-channel" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" | ||||
| checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" | ||||
| dependencies = [ | ||||
|  "futures-core", | ||||
|  "futures-sink", | ||||
| @@ -1715,15 +1752,15 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-core" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" | ||||
| checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-executor" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" | ||||
| checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" | ||||
| dependencies = [ | ||||
|  "futures-core", | ||||
|  "futures-task", | ||||
| @@ -1732,15 +1769,15 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-io" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" | ||||
| checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-macro" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" | ||||
| checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
| @@ -1749,21 +1786,21 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-sink" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" | ||||
| checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-task" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" | ||||
| checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" | ||||
|  | ||||
| [[package]] | ||||
| name = "futures-util" | ||||
| version = "0.3.28" | ||||
| version = "0.3.29" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" | ||||
| checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" | ||||
| dependencies = [ | ||||
|  "futures-channel", | ||||
|  "futures-core", | ||||
| @@ -2207,7 +2244,7 @@ dependencies = [ | ||||
|  "httpdate", | ||||
|  "itoa", | ||||
|  "pin-project-lite", | ||||
|  "socket2", | ||||
|  "socket2 0.4.9", | ||||
|  "tokio", | ||||
|  "tower-service", | ||||
|  "tracing", | ||||
| @@ -2949,7 +2986,7 @@ version = "8.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" | ||||
| dependencies = [ | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
|  "pem", | ||||
|  "ring", | ||||
|  "serde", | ||||
| @@ -2957,6 +2994,16 @@ dependencies = [ | ||||
|  "simple_asn1", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "kstring" | ||||
| version = "2.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747" | ||||
| dependencies = [ | ||||
|  "serde", | ||||
|  "static_assertions", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "language-tags" | ||||
| version = "0.3.2" | ||||
| @@ -2980,9 +3027,9 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "libc" | ||||
| version = "0.2.147" | ||||
| version = "0.2.150" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" | ||||
| checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" | ||||
|  | ||||
| [[package]] | ||||
| name = "libgit2-sys" | ||||
| @@ -3251,6 +3298,63 @@ version = "0.4.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" | ||||
|  | ||||
| [[package]] | ||||
| name = "liquid" | ||||
| version = "0.26.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609" | ||||
| dependencies = [ | ||||
|  "doc-comment", | ||||
|  "liquid-core", | ||||
|  "liquid-derive", | ||||
|  "liquid-lib", | ||||
|  "serde", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "liquid-core" | ||||
| version = "0.26.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f" | ||||
| dependencies = [ | ||||
|  "anymap2", | ||||
|  "itertools 0.10.5", | ||||
|  "kstring", | ||||
|  "liquid-derive", | ||||
|  "num-traits", | ||||
|  "pest", | ||||
|  "pest_derive", | ||||
|  "regex", | ||||
|  "serde", | ||||
|  "time", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "liquid-derive" | ||||
| version = "0.26.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn 2.0.28", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "liquid-lib" | ||||
| version = "0.26.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0" | ||||
| dependencies = [ | ||||
|  "itertools 0.10.5", | ||||
|  "liquid-core", | ||||
|  "once_cell", | ||||
|  "percent-encoding", | ||||
|  "regex", | ||||
|  "time", | ||||
|  "unicode-segmentation", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "litemap" | ||||
| version = "0.6.1" | ||||
| @@ -3483,7 +3587,7 @@ dependencies = [ | ||||
| name = "meilisearch-auth" | ||||
| version = "1.5.1" | ||||
| dependencies = [ | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
|  "enum-iterator", | ||||
|  "hmac", | ||||
|  "maplit", | ||||
| @@ -3544,9 +3648,9 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "memchr" | ||||
| version = "2.5.0" | ||||
| version = "2.6.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" | ||||
| checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" | ||||
|  | ||||
| [[package]] | ||||
| name = "memmap2" | ||||
| @@ -3589,6 +3693,7 @@ dependencies = [ | ||||
|  "filter-parser", | ||||
|  "flatten-serde-json", | ||||
|  "fst", | ||||
|  "futures", | ||||
|  "fxhash", | ||||
|  "geoutils", | ||||
|  "grenad", | ||||
| @@ -3600,6 +3705,7 @@ dependencies = [ | ||||
|  "itertools 0.11.0", | ||||
|  "json-depth-checker", | ||||
|  "levenshtein_automata", | ||||
|  "liquid", | ||||
|  "log", | ||||
|  "logging_timer", | ||||
|  "maplit", | ||||
| @@ -3607,6 +3713,7 @@ dependencies = [ | ||||
|  "meili-snap", | ||||
|  "memmap2", | ||||
|  "mimalloc", | ||||
|  "nolife", | ||||
|  "obkv", | ||||
|  "once_cell", | ||||
|  "ordered-float", | ||||
| @@ -3614,6 +3721,7 @@ dependencies = [ | ||||
|  "rand", | ||||
|  "rand_pcg", | ||||
|  "rayon", | ||||
|  "reqwest", | ||||
|  "roaring", | ||||
|  "rstar", | ||||
|  "serde", | ||||
| @@ -3624,8 +3732,10 @@ dependencies = [ | ||||
|  "smartstring", | ||||
|  "tempfile", | ||||
|  "thiserror", | ||||
|  "tiktoken-rs", | ||||
|  "time", | ||||
|  "tokenizers", | ||||
|  "tokio", | ||||
|  "uuid 1.5.0", | ||||
| ] | ||||
|  | ||||
| @@ -3671,9 +3781,9 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "mio" | ||||
| version = "0.8.8" | ||||
| version = "0.8.9" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" | ||||
| checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
|  "log", | ||||
| @@ -3725,6 +3835,12 @@ name = "nelson" | ||||
| version = "0.1.0" | ||||
| source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a" | ||||
|  | ||||
| [[package]] | ||||
| name = "nolife" | ||||
| version = "0.3.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc" | ||||
|  | ||||
| [[package]] | ||||
| name = "nom" | ||||
| version = "7.1.3" | ||||
| @@ -4480,6 +4596,12 @@ dependencies = [ | ||||
|  "regex-syntax", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "regex-automata" | ||||
| version = "0.4.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" | ||||
|  | ||||
| [[package]] | ||||
| name = "regex-syntax" | ||||
| version = "0.7.4" | ||||
| @@ -4488,11 +4610,11 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" | ||||
|  | ||||
| [[package]] | ||||
| name = "reqwest" | ||||
| version = "0.11.18" | ||||
| version = "0.11.22" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" | ||||
| checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" | ||||
| dependencies = [ | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
|  "bytes", | ||||
|  "encoding_rs", | ||||
|  "futures-core", | ||||
| @@ -4514,6 +4636,7 @@ dependencies = [ | ||||
|  "serde", | ||||
|  "serde_json", | ||||
|  "serde_urlencoded", | ||||
|  "system-configuration", | ||||
|  "tokio", | ||||
|  "tokio-rustls 0.24.1", | ||||
|  "tower-service", | ||||
| @@ -4521,7 +4644,7 @@ dependencies = [ | ||||
|  "wasm-bindgen", | ||||
|  "wasm-bindgen-futures", | ||||
|  "web-sys", | ||||
|  "webpki-roots 0.22.6", | ||||
|  "webpki-roots 0.25.3", | ||||
|  "winreg", | ||||
| ] | ||||
|  | ||||
| @@ -4582,6 +4705,12 @@ version = "0.1.23" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" | ||||
|  | ||||
| [[package]] | ||||
| name = "rustc-hash" | ||||
| version = "1.1.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" | ||||
|  | ||||
| [[package]] | ||||
| name = "rustc_version" | ||||
| version = "0.4.0" | ||||
| @@ -4648,7 +4777,7 @@ version = "1.0.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" | ||||
| dependencies = [ | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| @@ -4977,6 +5106,16 @@ dependencies = [ | ||||
|  "winapi", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "socket2" | ||||
| version = "0.5.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
|  "windows-sys 0.48.0", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "spin" | ||||
| version = "0.5.2" | ||||
| @@ -5097,6 +5236,27 @@ dependencies = [ | ||||
|  "winapi", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "system-configuration" | ||||
| version = "0.5.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" | ||||
| dependencies = [ | ||||
|  "bitflags 1.3.2", | ||||
|  "core-foundation", | ||||
|  "system-configuration-sys", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "system-configuration-sys" | ||||
| version = "0.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" | ||||
| dependencies = [ | ||||
|  "core-foundation-sys", | ||||
|  "libc", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tar" | ||||
| version = "0.4.40" | ||||
| @@ -5159,6 +5319,21 @@ dependencies = [ | ||||
|  "syn 2.0.28", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tiktoken-rs" | ||||
| version = "0.5.7" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "a4427b6b1c6b38215b92dd47a83a0ecc6735573d0a5a4c14acc0ac5b33b28adb" | ||||
| dependencies = [ | ||||
|  "anyhow", | ||||
|  "base64 0.21.5", | ||||
|  "bstr", | ||||
|  "fancy-regex", | ||||
|  "lazy_static", | ||||
|  "parking_lot", | ||||
|  "rustc-hash", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "time" | ||||
| version = "0.3.30" | ||||
| @@ -5258,11 +5433,10 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "tokio" | ||||
| version = "1.29.1" | ||||
| version = "1.34.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" | ||||
| checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" | ||||
| dependencies = [ | ||||
|  "autocfg", | ||||
|  "backtrace", | ||||
|  "bytes", | ||||
|  "libc", | ||||
| @@ -5271,16 +5445,16 @@ dependencies = [ | ||||
|  "parking_lot", | ||||
|  "pin-project-lite", | ||||
|  "signal-hook-registry", | ||||
|  "socket2", | ||||
|  "socket2 0.5.5", | ||||
|  "tokio-macros", | ||||
|  "windows-sys 0.48.0", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tokio-macros" | ||||
| version = "2.1.0" | ||||
| version = "2.2.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" | ||||
| checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
| @@ -5508,7 +5682,7 @@ version = "2.7.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" | ||||
| dependencies = [ | ||||
|  "base64 0.21.2", | ||||
|  "base64 0.21.5", | ||||
|  "flate2", | ||||
|  "log", | ||||
|  "native-tls", | ||||
| @@ -5758,6 +5932,12 @@ dependencies = [ | ||||
|  "rustls-webpki 0.100.2", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "webpki-roots" | ||||
| version = "0.25.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" | ||||
|  | ||||
| [[package]] | ||||
| name = "whatlang" | ||||
| version = "0.16.2" | ||||
| @@ -5942,11 +6122,12 @@ dependencies = [ | ||||
|  | ||||
| [[package]] | ||||
| name = "winreg" | ||||
| version = "0.10.1" | ||||
| version = "0.50.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" | ||||
| checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" | ||||
| dependencies = [ | ||||
|  "winapi", | ||||
|  "cfg-if", | ||||
|  "windows-sys 0.48.0", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
|   | ||||
| @@ -276,6 +276,7 @@ pub(crate) mod test { | ||||
|                 ), | ||||
|             }), | ||||
|             pagination: Setting::NotSet, | ||||
|             embedders: Setting::NotSet, | ||||
|             _kind: std::marker::PhantomData, | ||||
|         }; | ||||
|         settings.check() | ||||
|   | ||||
| @@ -378,6 +378,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> { | ||||
|                 v5::Setting::Reset => v6::Setting::Reset, | ||||
|                 v5::Setting::NotSet => v6::Setting::NotSet, | ||||
|             }, | ||||
|             embedders: v6::Setting::NotSet, | ||||
|             _kind: std::marker::PhantomData, | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -1202,6 +1202,10 @@ impl IndexScheduler { | ||||
|  | ||||
|                 let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; | ||||
|  | ||||
|                 let embedder_configs = index.embedding_configs(index_wtxn)?; | ||||
|                 // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) | ||||
|                 let embedders = self.embedders(embedder_configs)?; | ||||
|  | ||||
|                 let mut builder = milli::update::IndexDocuments::new( | ||||
|                     index_wtxn, | ||||
|                     index, | ||||
| @@ -1220,6 +1224,8 @@ impl IndexScheduler { | ||||
|                             let (new_builder, user_result) = builder.add_documents(reader)?; | ||||
|                             builder = new_builder; | ||||
|  | ||||
|                             builder = builder.with_embedders(embedders.clone()); | ||||
|  | ||||
|                             let received_documents = | ||||
|                                 if let Some(Details::DocumentAdditionOrUpdate { | ||||
|                                     received_documents, | ||||
| @@ -1345,6 +1351,9 @@ impl IndexScheduler { | ||||
|  | ||||
|                 for (task, (_, settings)) in tasks.iter_mut().zip(settings) { | ||||
|                     let checked_settings = settings.clone().check(); | ||||
|                     if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) { | ||||
|                         self.features().check_vector("Passing `embedders` in settings")? | ||||
|                     } | ||||
|                     if checked_settings.proximity_precision.set().is_some() { | ||||
|                         self.features.features().check_proximity_precision()?; | ||||
|                     } | ||||
|   | ||||
| @@ -56,12 +56,12 @@ impl RoFeatures { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn check_vector(&self) -> Result<()> { | ||||
|     pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> { | ||||
|         if self.runtime.vector_store { | ||||
|             Ok(()) | ||||
|         } else { | ||||
|             Err(FeatureNotEnabledError { | ||||
|                 disabled_action: "Passing `vector` as a query parameter", | ||||
|                 disabled_action, | ||||
|                 feature: "vector store", | ||||
|                 issue_link: "https://github.com/meilisearch/product/discussions/677", | ||||
|             } | ||||
|   | ||||
| @@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String { | ||||
|         planned_failures: _, | ||||
|         run_loop_iteration: _, | ||||
|         currently_updating_index: _, | ||||
|         embedders: _, | ||||
|     } = scheduler; | ||||
|  | ||||
|     let rtxn = env.read_txn().unwrap(); | ||||
|   | ||||
| @@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; | ||||
| use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; | ||||
| use meilisearch_types::milli::documents::DocumentsBatchBuilder; | ||||
| use meilisearch_types::milli::update::IndexerConfig; | ||||
| use meilisearch_types::milli::vector::{Embedder, EmbedderOptions}; | ||||
| use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; | ||||
| use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; | ||||
| use puffin::FrameView; | ||||
| @@ -341,6 +342,8 @@ pub struct IndexScheduler { | ||||
|     /// so that a handle to the index is available from other threads (search) in an optimized manner. | ||||
|     currently_updating_index: Arc<RwLock<Option<(String, Index)>>>, | ||||
|  | ||||
|     embedders: Arc<RwLock<HashMap<EmbedderOptions, std::sync::Arc<Embedder>>>>, | ||||
|  | ||||
|     // ================= test | ||||
|     // The next entry is dedicated to the tests. | ||||
|     /// Provide a way to set a breakpoint in multiple part of the scheduler. | ||||
| @@ -386,6 +389,7 @@ impl IndexScheduler { | ||||
|             auth_path: self.auth_path.clone(), | ||||
|             version_file_path: self.version_file_path.clone(), | ||||
|             currently_updating_index: self.currently_updating_index.clone(), | ||||
|             embedders: self.embedders.clone(), | ||||
|             #[cfg(test)] | ||||
|             test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), | ||||
|             #[cfg(test)] | ||||
| @@ -484,6 +488,7 @@ impl IndexScheduler { | ||||
|             auth_path: options.auth_path, | ||||
|             version_file_path: options.version_file_path, | ||||
|             currently_updating_index: Arc::new(RwLock::new(None)), | ||||
|             embedders: Default::default(), | ||||
|  | ||||
|             #[cfg(test)] | ||||
|             test_breakpoint_sdr, | ||||
| @@ -1333,6 +1338,42 @@ impl IndexScheduler { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // TODO: consider using a type alias or a struct embedder/template | ||||
|     /// Resolves named embedding configurations into ready-to-use embedders and prompts. | ||||
|     /// | ||||
|     /// For each `(name, EmbeddingConfig)` pair, the prompt is converted with `try_into` and the | ||||
|     /// embedder is looked up in `self.embedders`, keyed by its `embedder_options`. When no | ||||
|     /// embedder is stored for these options, a new one is built with `Embedder::new` and | ||||
|     /// inserted into `self.embedders` before being returned. | ||||
|     /// | ||||
|     /// Returns a map from the configuration name to its `(embedder, prompt)` pair. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns the first error hit while converting a prompt or building an embedder; both are | ||||
|     /// converted into `milli::Error` before being propagated with `?`. | ||||
|     /// | ||||
|     /// # Panics | ||||
|     /// | ||||
|     /// The `unwrap()` calls on `read()`/`write()` panic if acquiring the lock returns an error | ||||
|     /// (presumably a poisoned `std::sync::RwLock` -- confirm which `RwLock` is imported). | ||||
|     #[allow(clippy::type_complexity)] | ||||
|     pub fn embedders( | ||||
|         &self, | ||||
|         embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, | ||||
|     ) -> Result<HashMap<String, (Arc<milli::vector::Embedder>, Arc<milli::prompt::Prompt>)>> { | ||||
|         let res: Result<_> = embedding_configs | ||||
|             .into_iter() | ||||
|             .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { | ||||
|                 let prompt = | ||||
|                     Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?); | ||||
|                 // Fast path: reuse an embedder already stored for these exact options. The read guard lives only inside this inner block. | ||||
|                 { | ||||
|                     let embedders = self.embedders.read().unwrap(); | ||||
|                     if let Some(embedder) = embedders.get(&embedder_options) { | ||||
|                         return Ok((name, (embedder.clone(), prompt))); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 // Slow path: build the missing embedder while no guard on `self.embedders` is held. NOTE(review): the map is not re-checked before the insert below, so a concurrent caller may build the same embedder and one entry replaces the other -- confirm this is acceptable. | ||||
|                 let embedder = Arc::new( | ||||
|                     Embedder::new(embedder_options.clone()) | ||||
|                         .map_err(meilisearch_types::milli::vector::Error::from) | ||||
|                         .map_err(meilisearch_types::milli::UserError::from) | ||||
|                         .map_err(meilisearch_types::milli::Error::from)?, | ||||
|                 ); | ||||
|                 { | ||||
|                     let mut embedders = self.embedders.write().unwrap(); | ||||
|                     embedders.insert(embedder_options, embedder.clone()); | ||||
|                 } | ||||
|                 Ok((name, (embedder, prompt))) | ||||
|             }) | ||||
|             .collect(); | ||||
|         res | ||||
|     } | ||||
|  | ||||
|     /// Blocks the thread until the test handle asks to progress to/through this breakpoint. | ||||
|     /// | ||||
|     /// Two messages are sent through the channel for each breakpoint. | ||||
|   | ||||
| @@ -256,6 +256,7 @@ InvalidSettingsProximityPrecision     , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsFaceting               , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsFilterableAttributes   , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsPagination             , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsEmbedders              , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsRankingRules           , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsSearchableAttributes   , InvalidRequest       , BAD_REQUEST ; | ||||
| InvalidSettingsSortableAttributes     , InvalidRequest       , BAD_REQUEST ; | ||||
| @@ -303,7 +304,8 @@ TaskNotFound                          , InvalidRequest       , NOT_FOUND ; | ||||
| TooManyOpenFiles                      , System               , UNPROCESSABLE_ENTITY ; | ||||
| UnretrievableDocument                 , Internal             , BAD_REQUEST ; | ||||
| UnretrievableErrorCode                , InvalidRequest       , BAD_REQUEST ; | ||||
| UnsupportedMediaType                  , InvalidRequest       , UNSUPPORTED_MEDIA_TYPE | ||||
| UnsupportedMediaType                  , InvalidRequest       , UNSUPPORTED_MEDIA_TYPE ; | ||||
| VectorEmbeddingError                  , InvalidRequest       , BAD_REQUEST | ||||
| } | ||||
|  | ||||
| impl ErrorCode for JoinError { | ||||
| @@ -336,6 +338,9 @@ impl ErrorCode for milli::Error { | ||||
|                     UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { | ||||
|                         Code::InvalidDocumentId | ||||
|                     } | ||||
|                     UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, | ||||
|                     UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, | ||||
|                     UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, | ||||
|                     UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, | ||||
|                     UserError::MultiplePrimaryKeyCandidatesFound { .. } => { | ||||
|                         Code::IndexPrimaryKeyMultipleCandidatesFound | ||||
| @@ -358,6 +363,7 @@ impl ErrorCode for milli::Error { | ||||
|                     UserError::InvalidMinTypoWordLenSetting(_, _) => { | ||||
|                         Code::InvalidSettingsTypoTolerance | ||||
|                     } | ||||
|                     UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|   | ||||
| @@ -199,6 +199,10 @@ pub struct Settings<T> { | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)] | ||||
|     pub pagination: Setting<PaginationSettings>, | ||||
|  | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)] | ||||
|     pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>, | ||||
|  | ||||
|     #[serde(skip)] | ||||
|     #[deserr(skip)] | ||||
|     pub _kind: PhantomData<T>, | ||||
| @@ -222,6 +226,7 @@ impl Settings<Checked> { | ||||
|             typo_tolerance: Setting::Reset, | ||||
|             faceting: Setting::Reset, | ||||
|             pagination: Setting::Reset, | ||||
|             embedders: Setting::Reset, | ||||
|             _kind: PhantomData, | ||||
|         } | ||||
|     } | ||||
| @@ -243,6 +248,7 @@ impl Settings<Checked> { | ||||
|             typo_tolerance, | ||||
|             faceting, | ||||
|             pagination, | ||||
|             embedders, | ||||
|             .. | ||||
|         } = self; | ||||
|  | ||||
| @@ -262,6 +268,7 @@ impl Settings<Checked> { | ||||
|             typo_tolerance, | ||||
|             faceting, | ||||
|             pagination, | ||||
|             embedders, | ||||
|             _kind: PhantomData, | ||||
|         } | ||||
|     } | ||||
| @@ -307,6 +314,7 @@ impl Settings<Unchecked> { | ||||
|             typo_tolerance: self.typo_tolerance, | ||||
|             faceting: self.faceting, | ||||
|             pagination: self.pagination, | ||||
|             embedders: self.embedders, | ||||
|             _kind: PhantomData, | ||||
|         } | ||||
|     } | ||||
| @@ -490,6 +498,12 @@ pub fn apply_settings_to_builder( | ||||
|         Setting::Reset => builder.reset_pagination_max_total_hits(), | ||||
|         Setting::NotSet => (), | ||||
|     } | ||||
|  | ||||
|     match settings.embedders.clone() { | ||||
|         Setting::Set(value) => builder.set_embedder_settings(value), | ||||
|         Setting::Reset => builder.reset_embedder_settings(), | ||||
|         Setting::NotSet => (), | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn settings( | ||||
| @@ -571,6 +585,12 @@ pub fn settings( | ||||
|         ), | ||||
|     }; | ||||
|  | ||||
|     let embedders = index | ||||
|         .embedding_configs(rtxn)? | ||||
|         .into_iter() | ||||
|         .map(|(name, config)| (name, Setting::Set(config.into()))) | ||||
|         .collect(); | ||||
|  | ||||
|     Ok(Settings { | ||||
|         displayed_attributes: match displayed_attributes { | ||||
|             Some(attrs) => Setting::Set(attrs), | ||||
| @@ -599,6 +619,7 @@ pub fn settings( | ||||
|         typo_tolerance: Setting::Set(typo_tolerance), | ||||
|         faceting: Setting::Set(faceting), | ||||
|         pagination: Setting::Set(pagination), | ||||
|         embedders: Setting::Set(embedders), | ||||
|         _kind: PhantomData, | ||||
|     }) | ||||
| } | ||||
| @@ -747,6 +768,7 @@ pub(crate) mod test { | ||||
|             typo_tolerance: Setting::NotSet, | ||||
|             faceting: Setting::NotSet, | ||||
|             pagination: Setting::NotSet, | ||||
|             embedders: Setting::NotSet, | ||||
|             _kind: PhantomData::<Unchecked>, | ||||
|         }; | ||||
|  | ||||
| @@ -772,6 +794,7 @@ pub(crate) mod test { | ||||
|             typo_tolerance: Setting::NotSet, | ||||
|             faceting: Setting::NotSet, | ||||
|             pagination: Setting::NotSet, | ||||
|             embedders: Setting::NotSet, | ||||
|             _kind: PhantomData::<Unchecked>, | ||||
|         }; | ||||
|  | ||||
|   | ||||
| @@ -686,7 +686,7 @@ impl SearchAggregator { | ||||
|             ret.max_terms_number = q.split_whitespace().count(); | ||||
|         } | ||||
|  | ||||
|         if let Some(ref vector) = vector { | ||||
|         if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { | ||||
|             ret.max_vector_size = vector.len(); | ||||
|         } | ||||
|  | ||||
|   | ||||
| @@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; | ||||
| /// does all the setup before meilisearch is launched | ||||
| fn setup(opt: &Opt) -> anyhow::Result<()> { | ||||
|     let mut log_builder = env_logger::Builder::new(); | ||||
|     log_builder.parse_filters(&opt.log_level.to_string()); | ||||
|     let log_filters = format!( | ||||
|         "{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn", | ||||
|         opt.log_level | ||||
|     ); | ||||
|     log_builder.parse_filters(&log_filters); | ||||
|  | ||||
|     log_builder.init(); | ||||
|  | ||||
|   | ||||
| @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; | ||||
| use meilisearch_types::error::deserr_codes::*; | ||||
| use meilisearch_types::error::ResponseError; | ||||
| use meilisearch_types::index_uid::IndexUid; | ||||
| use meilisearch_types::milli::VectorQuery; | ||||
| use serde_json::Value; | ||||
|  | ||||
| use crate::analytics::{Analytics, FacetSearchAggregator}; | ||||
| @@ -117,7 +118,7 @@ impl From<FacetSearchQuery> for SearchQuery { | ||||
|             highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), | ||||
|             crop_marker: DEFAULT_CROP_MARKER(), | ||||
|             matching_strategy, | ||||
|             vector, | ||||
|             vector: vector.map(VectorQuery::Vector), | ||||
|             attributes_to_search_on, | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -2,12 +2,13 @@ use actix_web::web::Data; | ||||
| use actix_web::{web, HttpRequest, HttpResponse}; | ||||
| use deserr::actix_web::{AwebJson, AwebQueryParameter}; | ||||
| use index_scheduler::IndexScheduler; | ||||
| use log::debug; | ||||
| use log::{debug, warn}; | ||||
| use meilisearch_types::deserr::query_params::Param; | ||||
| use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; | ||||
| use meilisearch_types::error::deserr_codes::*; | ||||
| use meilisearch_types::error::ResponseError; | ||||
| use meilisearch_types::index_uid::IndexUid; | ||||
| use meilisearch_types::milli::VectorQuery; | ||||
| use meilisearch_types::serde_cs::vec::CS; | ||||
| use serde_json::Value; | ||||
|  | ||||
| @@ -88,7 +89,7 @@ impl From<SearchQueryGet> for SearchQuery { | ||||
|  | ||||
|         Self { | ||||
|             q: other.q, | ||||
|             vector: other.vector.map(CS::into_inner), | ||||
|             vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), | ||||
|             offset: other.offset.0, | ||||
|             limit: other.limit.0, | ||||
|             page: other.page.as_deref().copied(), | ||||
| @@ -193,6 +194,9 @@ pub async fn search_with_post( | ||||
|     let index = index_scheduler.index(&index_uid)?; | ||||
|  | ||||
|     let features = index_scheduler.features(); | ||||
|  | ||||
|     embed(&mut query, index_scheduler.get_ref(), &index).await?; | ||||
|  | ||||
|     let search_result = | ||||
|         tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; | ||||
|     if let Ok(ref search_result) = search_result { | ||||
| @@ -206,6 +210,38 @@ pub async fn search_with_post( | ||||
|     Ok(HttpResponse::Ok().json(search_result)) | ||||
| } | ||||
|  | ||||
| /// Turns a textual `vector` query into its embedding, in place. | ||||
| /// | ||||
| /// If `query.vector` is a [`VectorQuery::String`], the prompt is embedded with the index's | ||||
| /// embedder named `default` and `query.vector` is replaced by the resulting | ||||
| /// [`VectorQuery::Vector`]. Any other value of `query.vector` is left as it was. | ||||
| /// | ||||
| /// # Errors | ||||
| /// | ||||
| /// Fails when the embedding configuration cannot be read from the index, when the embedders | ||||
| /// cannot be obtained from the scheduler, when no embedder named `default` exists, or when the | ||||
| /// embedding request itself fails. | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// Panics if the embedder succeeds but returns no embedding for the prompt. | ||||
| pub async fn embed( | ||||
|     query: &mut SearchQuery, | ||||
|     index_scheduler: &IndexScheduler, | ||||
|     index: &meilisearch_types::milli::Index, | ||||
| ) -> Result<(), ResponseError> { | ||||
|     let prompt = match query.vector.take() { | ||||
|         Some(VectorQuery::String(prompt)) => prompt, | ||||
|         other => { | ||||
|             // `take()` emptied `query.vector`; restore anything that is not a prompt (such as a | ||||
|             // vector supplied directly by the user) so that it is not silently discarded. | ||||
|             query.vector = other; | ||||
|             return Ok(()); | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let embedder_configs = index.embedding_configs(&index.read_txn()?)?; | ||||
|     let embedders = index_scheduler.embedders(embedder_configs)?; | ||||
|  | ||||
|     // FIXME: support multiple embedders | ||||
|     let embedder = embedders.get("default").ok_or_else(|| { | ||||
|         ResponseError::from_msg( | ||||
|             "Cannot embed the `vector` query: this index has no embedder named `default`." | ||||
|                 .to_string(), | ||||
|             meilisearch_types::error::Code::VectorEmbeddingError, | ||||
|         ) | ||||
|     })?; | ||||
|  | ||||
|     let embeddings = embedder | ||||
|         .0 | ||||
|         .embed(vec![prompt]) | ||||
|         .await | ||||
|         .map_err(meilisearch_types::milli::vector::Error::from) | ||||
|         .map_err(meilisearch_types::milli::UserError::from) | ||||
|         .map_err(meilisearch_types::milli::Error::from)? | ||||
|         .pop() | ||||
|         .expect("No vector returned from embedding"); | ||||
|  | ||||
|     let vector = if embeddings.iter().nth(1).is_some() { | ||||
|         warn!("Ignoring embeddings past the first one in long search query"); | ||||
|         // A second embedding exists, so the first one necessarily does too. | ||||
|         embeddings.iter().next().unwrap().to_vec() | ||||
|     } else { | ||||
|         embeddings.into_inner() | ||||
|     }; | ||||
|     query.vector = Some(VectorQuery::Vector(vector)); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|     use super::*; | ||||
|   | ||||
| @@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; | ||||
| use crate::extractors::authentication::policies::ActionPolicy; | ||||
| use crate::extractors::authentication::{AuthenticationError, GuardedData}; | ||||
| use crate::extractors::sequential_extractor::SeqHandler; | ||||
| use crate::routes::indexes::search::embed; | ||||
| use crate::search::{ | ||||
|     add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, | ||||
| }; | ||||
| @@ -74,6 +75,8 @@ pub async fn multi_search_with_post( | ||||
|                 }) | ||||
|                 .with_index(query_index)?; | ||||
|  | ||||
|             embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?; | ||||
|  | ||||
|             let search_result = | ||||
|                 tokio::task::spawn_blocking(move || perform_search(&index, query, features)) | ||||
|                     .await | ||||
|   | ||||
| @@ -16,6 +16,7 @@ use meilisearch_types::index_uid::IndexUid; | ||||
| use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; | ||||
| use meilisearch_types::milli::{ | ||||
|     dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, | ||||
|     VectorQuery, | ||||
| }; | ||||
| use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; | ||||
| use meilisearch_types::{milli, Document}; | ||||
| @@ -46,7 +47,7 @@ pub struct SearchQuery { | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||
|     pub q: Option<String>, | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] | ||||
|     pub vector: Option<Vec<f32>>, | ||||
|     pub vector: Option<milli::VectorQuery>, | ||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||
|     pub offset: usize, | ||||
|     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] | ||||
| @@ -105,7 +106,7 @@ pub struct SearchQueryWithIndex { | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||
|     pub q: Option<String>, | ||||
|     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] | ||||
|     pub vector: Option<Vec<f32>>, | ||||
|     pub vector: Option<VectorQuery>, | ||||
|     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] | ||||
|     pub offset: usize, | ||||
|     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] | ||||
| @@ -339,11 +340,18 @@ fn prepare_search<'t>( | ||||
|     let mut search = index.search(rtxn); | ||||
|  | ||||
|     if query.vector.is_some() && query.q.is_some() { | ||||
|         warn!("Ignoring the query string `q` when used with the `vector` parameter."); | ||||
|         warn!("Attempting hybrid search"); | ||||
|     } | ||||
|  | ||||
|     if let Some(ref vector) = query.vector { | ||||
|         search.vector(vector.clone()); | ||||
|         match vector { | ||||
|             VectorQuery::Vector(vector) => { | ||||
|                 search.vector(vector.clone()); | ||||
|             } | ||||
|             VectorQuery::String(_) => { | ||||
|                 panic!("Failed while preparing search; caller did not generate embedding for query") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if let Some(ref query) = query.q { | ||||
| @@ -375,7 +383,7 @@ fn prepare_search<'t>( | ||||
|     } | ||||
|  | ||||
|     if query.vector.is_some() { | ||||
|         features.check_vector()?; | ||||
|         features.check_vector("Passing `vector` as a query parameter")?; | ||||
|     } | ||||
|  | ||||
|     // compute the offset on the limit depending on the pagination mode. | ||||
| @@ -429,7 +437,11 @@ pub fn perform_search( | ||||
|         prepare_search(index, &rtxn, &query, features)?; | ||||
|  | ||||
|     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = | ||||
|         search.execute()?; | ||||
|         if query.q.is_some() && query.vector.is_some() { | ||||
|             search.execute_hybrid()? | ||||
|         } else { | ||||
|             search.execute()? | ||||
|         }; | ||||
|  | ||||
|     let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); | ||||
|  | ||||
| @@ -538,13 +550,13 @@ pub fn perform_search( | ||||
|             insert_geo_distance(sort, &mut document); | ||||
|         } | ||||
|  | ||||
|         let semantic_score = match query.vector.as_ref() { | ||||
|         let semantic_score = /*match query.vector.as_ref() { | ||||
|             Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { | ||||
|                 Some(vectors) => compute_semantic_score(vector, vectors)?, | ||||
|                 None => None, | ||||
|             }, | ||||
|             None => None, | ||||
|         }; | ||||
|         };*/ None; | ||||
|  | ||||
|         let ranking_score = | ||||
|             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); | ||||
| @@ -629,7 +641,8 @@ pub fn perform_search( | ||||
|         hits: documents, | ||||
|         hits_info, | ||||
|         query: query.q.unwrap_or_default(), | ||||
|         vector: query.vector, | ||||
|         // FIXME: display input vector | ||||
|         vector: None, | ||||
|         processing_time_ms: before_search.elapsed().as_millis(), | ||||
|         facet_distribution, | ||||
|         facet_stats, | ||||
|   | ||||
| @@ -27,10 +27,13 @@ fst = "0.4.7" | ||||
| fxhash = "0.2.1" | ||||
| geoutils = "0.5.1" | ||||
| grenad = { version = "0.4.5", default-features = false, features = [ | ||||
|     "rayon", "tempfile" | ||||
|     "rayon", | ||||
|     "tempfile", | ||||
| ] } | ||||
| heed = { version = "0.20.0-alpha.9", default-features = false, features = [ | ||||
|     "serde-json", "serde-bincode", "read-txn-no-tls" | ||||
|     "serde-json", | ||||
|     "serde-bincode", | ||||
|     "read-txn-no-tls", | ||||
| ] } | ||||
| indexmap = { version = "2.0.0", features = ["serde"] } | ||||
| instant-distance = { version = "0.6.1", features = ["with-serde"] } | ||||
| @@ -77,6 +80,15 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi | ||||
| candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } | ||||
| tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } | ||||
| hf-hub = "0.3.2" | ||||
| tokio = { version = "1.34.0", features = ["rt"] } | ||||
| futures = "0.3.29" | ||||
| nolife = { version = "0.3.1" } | ||||
| reqwest = { version = "0.11.16", features = [ | ||||
|     "rustls-tls", | ||||
|     "json", | ||||
| ], default-features = false } | ||||
| tiktoken-rs = "0.5.7" | ||||
| liquid = "0.26.4" | ||||
|  | ||||
| [dev-dependencies] | ||||
| mimalloc = { version = "0.1.37", default-features = false } | ||||
| @@ -88,7 +100,15 @@ meili-snap = { path = "../meili-snap" } | ||||
| rand = { version = "0.8.5", features = ["small_rng"] } | ||||
|  | ||||
| [features] | ||||
| all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] | ||||
| all-tokenizations = [ | ||||
|     "charabia/chinese", | ||||
|     "charabia/hebrew", | ||||
|     "charabia/japanese", | ||||
|     "charabia/thai", | ||||
|     "charabia/korean", | ||||
|     "charabia/greek", | ||||
|     "charabia/khmer", | ||||
| ] | ||||
|  | ||||
| # Use POSIX semaphores instead of SysV semaphores in LMDB | ||||
| # For more information on this feature, see heed's Cargo.toml | ||||
|   | ||||
| @@ -5,8 +5,8 @@ use std::time::Instant; | ||||
|  | ||||
| use heed::EnvOpenOptions; | ||||
| use milli::{ | ||||
|     execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, | ||||
|     TermsMatchingStrategy, | ||||
|     execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, | ||||
|     SearchLogger, TermsMatchingStrategy, | ||||
| }; | ||||
|  | ||||
| #[global_allocator] | ||||
| @@ -49,14 +49,15 @@ fn main() -> Result<(), Box<dyn Error>> { | ||||
|             let start = Instant::now(); | ||||
|  | ||||
|             let mut ctx = SearchContext::new(&index, &txn); | ||||
|             let universe = filtered_universe(&ctx, &None)?; | ||||
|  | ||||
|             let docs = execute_search( | ||||
|                 &mut ctx, | ||||
|                 &(!query.trim().is_empty()).then(|| query.trim().to_owned()), | ||||
|                 &None, | ||||
|                 (!query.trim().is_empty()).then(|| query.trim()), | ||||
|                 TermsMatchingStrategy::Last, | ||||
|                 milli::score_details::ScoringStrategy::Skip, | ||||
|                 false, | ||||
|                 &None, | ||||
|                 universe, | ||||
|                 &None, | ||||
|                 GeoSortStrategy::default(), | ||||
|                 0, | ||||
|   | ||||
| @@ -180,6 +180,14 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco | ||||
|     UnknownInternalDocumentId { document_id: DocumentId }, | ||||
|     #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] | ||||
|     InvalidMinTypoWordLenSetting(u8, u8), | ||||
|     #[error(transparent)] | ||||
|     VectorEmbeddingError(#[from] crate::vector::Error), | ||||
|     #[error(transparent)] | ||||
|     MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), | ||||
|     #[error(transparent)] | ||||
|     InvalidPrompt(#[from] crate::prompt::error::NewPromptError), | ||||
|     #[error("Invalid prompt in for embeddings with name '{0}': {1}")] | ||||
|     InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), | ||||
| } | ||||
|  | ||||
| #[derive(Error, Debug)] | ||||
| @@ -336,6 +344,26 @@ impl From<HeedError> for Error { | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Classifies where the fault behind an error lies. | ||||
| /// | ||||
| /// The `Display` implementation turns each variant into a short label that can be used as a | ||||
| /// prefix in error messages. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub enum FaultSource { | ||||
|     /// Displayed as `user error`. | ||||
|     User, | ||||
|     /// Displayed as `runtime error`. | ||||
|     Runtime, | ||||
|     /// Displayed as `coding error`. | ||||
|     Bug, | ||||
|     /// Displayed as plain `error`; presumably used when the fault cannot be attributed. | ||||
|     Undecided, | ||||
| } | ||||
|  | ||||
| impl std::fmt::Display for FaultSource { | ||||
|     /// Writes the short human-readable label associated with this fault source. | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         f.write_str(match self { | ||||
|             Self::User => "user error", | ||||
|             Self::Runtime => "runtime error", | ||||
|             Self::Bug => "coding error", | ||||
|             Self::Undecided => "error", | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| fn conditionally_lookup_for_error_message() { | ||||
|     let prefix = "Attribute `name` is not sortable."; | ||||
|   | ||||
| @@ -23,6 +23,7 @@ use crate::heed_codec::{ | ||||
| }; | ||||
| use crate::proximity::ProximityPrecision; | ||||
| use crate::readable_slices::ReadableSlices; | ||||
| use crate::vector::EmbeddingConfig; | ||||
| use crate::{ | ||||
|     default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, | ||||
|     FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, | ||||
| @@ -74,6 +75,7 @@ pub mod main_key { | ||||
|     pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; | ||||
|     pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; | ||||
|     pub const PROXIMITY_PRECISION: &str = "proximity-precision"; | ||||
|     pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; | ||||
| } | ||||
|  | ||||
| pub mod db_name { | ||||
| @@ -1528,6 +1530,33 @@ impl Index { | ||||
|  | ||||
|         Ok(script_language) | ||||
|     } | ||||
|  | ||||
|     /// Writes `configs` into the main database under the `EMBEDDING_CONFIGS` key, using the | ||||
|     /// `SerdeJson` codec for the value. | ||||
|     pub(crate) fn put_embedding_configs( | ||||
|         &self, | ||||
|         wtxn: &mut RwTxn<'_>, | ||||
|         configs: Vec<(String, EmbeddingConfig)>, | ||||
|     ) -> heed::Result<()> { | ||||
|         let database = self.main.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>(); | ||||
|         database.put(wtxn, main_key::EMBEDDING_CONFIGS, &configs) | ||||
|     } | ||||
|  | ||||
|     /// Deletes the `EMBEDDING_CONFIGS` entry of the main database and forwards the boolean | ||||
|     /// returned by the underlying `delete` call. | ||||
|     pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> { | ||||
|         let database = self.main.remap_key_type::<Str>(); | ||||
|         database.delete(wtxn, main_key::EMBEDDING_CONFIGS) | ||||
|     } | ||||
|  | ||||
|     /// Reads the named embedding configurations stored under the `EMBEDDING_CONFIGS` key. | ||||
|     /// | ||||
|     /// Returns an empty list when nothing is stored under that key. | ||||
|     pub fn embedding_configs( | ||||
|         &self, | ||||
|         rtxn: &RoTxn<'_>, | ||||
|     ) -> Result<Vec<(String, crate::vector::EmbeddingConfig)>> { | ||||
|         let stored = self | ||||
|             .main | ||||
|             .remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>() | ||||
|             .get(rtxn, main_key::EMBEDDING_CONFIGS)?; | ||||
|         Ok(stored.unwrap_or_default()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
|   | ||||
| @@ -17,11 +17,13 @@ pub mod facet; | ||||
| mod fields_ids_map; | ||||
| pub mod heed_codec; | ||||
| pub mod index; | ||||
| pub mod prompt; | ||||
| pub mod proximity; | ||||
| mod readable_slices; | ||||
| pub mod score_details; | ||||
| mod search; | ||||
| pub mod update; | ||||
| pub mod vector; | ||||
|  | ||||
| #[cfg(test)] | ||||
| #[macro_use] | ||||
| @@ -37,8 +39,8 @@ pub use filter_parser::{Condition, FilterCondition, Span, Token}; | ||||
| use fxhash::{FxHasher32, FxHasher64}; | ||||
| pub use grenad::CompressionType; | ||||
| pub use search::new::{ | ||||
|     execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, | ||||
|     VisualSearchLogger, | ||||
|     execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, | ||||
|     SearchLogger, VisualSearchLogger, | ||||
| }; | ||||
| use serde_json::Value; | ||||
| pub use {charabia as tokenizer, heed}; | ||||
| @@ -60,7 +62,7 @@ pub use self::index::Index; | ||||
| pub use self::search::{ | ||||
|     FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, | ||||
|     MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, | ||||
|     DEFAULT_VALUES_PER_FACET, | ||||
|     VectorQuery, DEFAULT_VALUES_PER_FACET, | ||||
| }; | ||||
|  | ||||
| pub type Result<T> = std::result::Result<T, error::Error>; | ||||
|   | ||||
							
								
								
									
										97
									
								
								milli/src/prompt/context.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								milli/src/prompt/context.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| use liquid::model::{ | ||||
|     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||
| }; | ||||
| use liquid::{ObjectView, ValueView}; | ||||
|  | ||||
| use super::document::Document; | ||||
| use super::fields::Fields; | ||||
| use crate::FieldsIdsMap; | ||||
|  | ||||
| /// Liquid object exposing exactly two keys: `doc` (the document) and `fields` (a view built | ||||
| /// from the document and the fields-ids map). | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Context<'a> { | ||||
|     // Value returned for the `doc` key. | ||||
|     document: &'a Document<'a>, | ||||
|     // Value returned for the `fields` key. | ||||
|     fields: Fields<'a>, | ||||
| } | ||||
|  | ||||
| impl<'a> Context<'a> { | ||||
|     /// Builds a context over `document`, deriving the `fields` view from the same document and | ||||
|     /// `field_id_map`. | ||||
|     pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { | ||||
|         let fields = Fields::new(document, field_id_map); | ||||
|         Self { document, fields } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> ObjectView for Context<'a> { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// The context always has exactly two entries: `doc` and `fields`. | ||||
|     fn size(&self) -> i64 { | ||||
|         2 | ||||
|     } | ||||
|  | ||||
|     /// Yields the keys in the same order as [`Self::values`]: `doc`, then `fields`. | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         let doc = std::iter::once(KStringCow::from_static("doc")); | ||||
|         let fields = std::iter::once(KStringCow::from_static("fields")); | ||||
|         Box::new(doc.chain(fields)) | ||||
|     } | ||||
|  | ||||
|     /// Yields the document view, then the fields view. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         let doc = std::iter::once(self.document.as_value()); | ||||
|         let fields = std::iter::once(self.fields.as_value()); | ||||
|         Box::new(doc.chain(fields)) | ||||
|     } | ||||
|  | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         let pairs = self.keys().zip(self.values()); | ||||
|         Box::new(pairs) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: &str) -> bool { | ||||
|         matches!(index, "doc" | "fields") | ||||
|     } | ||||
|  | ||||
|     /// Looks up one of the two known keys; any other key yields `None`. | ||||
|     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||
|         let view = match index { | ||||
|             "doc" => self.document.as_value(), | ||||
|             "fields" => self.fields.as_value(), | ||||
|             _ => return None, | ||||
|         }; | ||||
|         Some(view) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> ValueView for Context<'a> { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         let rendered = ObjectRender::new(self); | ||||
|         DisplayCow::Owned(Box::new(rendered)) | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         let source = ObjectSource::new(self); | ||||
|         DisplayCow::Owned(Box::new(source)) | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     /// The context is truthy, and is never the default value, empty or blank. | ||||
|     fn query_state(&self, state: liquid::model::State) -> bool { | ||||
|         matches!(state, State::Truthy) | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||
|         KStringCow::from_string(ObjectRender::new(self).to_string()) | ||||
|     } | ||||
|  | ||||
|     /// Copies both entries into an owned liquid object. | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Object( | ||||
|             self.iter().map(|(key, view)| (key.to_string().into(), view.to_value())).collect(), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										131
									
								
								milli/src/prompt/document.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								milli/src/prompt/document.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,131 @@ | ||||
| use std::cell::OnceCell; | ||||
| use std::collections::BTreeMap; | ||||
|  | ||||
| use liquid::model::{ | ||||
|     DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||
| }; | ||||
| use liquid::{ObjectView, ValueView}; | ||||
|  | ||||
| use crate::update::del_add::{DelAdd, KvReaderDelAdd}; | ||||
| use crate::FieldsIdsMap; | ||||
|  | ||||
| /// Liquid view over a document: maps each field name to the raw bytes of the field together | ||||
| /// with a cell caching the liquid value parsed from those bytes. | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>); | ||||
|  | ||||
| /// Write-once cell holding the liquid value of a field; it starts empty and is filled the first | ||||
| /// time `get` is called. | ||||
| #[derive(Debug, Clone)] | ||||
| struct ParsedValue(std::cell::OnceCell<LiquidValue>); | ||||
|  | ||||
| impl ParsedValue { | ||||
|     /// Creates a cell that has not been parsed yet. | ||||
|     fn empty() -> ParsedValue { | ||||
|         Self(OnceCell::new()) | ||||
|     } | ||||
|  | ||||
|     /// Returns the cached liquid value, parsing `raw` as JSON on the first call. | ||||
|     /// | ||||
|     /// Panics if `raw` is not valid JSON or cannot be converted to a liquid value. | ||||
|     fn get(&self, raw: &[u8]) -> &LiquidValue { | ||||
|         let parse = || { | ||||
|             let json: serde_json::Value = serde_json::from_slice(raw).unwrap(); | ||||
|             liquid::model::to_value(&json).unwrap() | ||||
|         }; | ||||
|         self.0.get_or_init(parse) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> Document<'a> { | ||||
|     /// Collects the fields of `data` that have a value on the requested `side` and whose id has | ||||
|     /// a name in `inverted_field_map`; every other field is skipped. | ||||
|     pub fn new( | ||||
|         data: obkv::KvReaderU16<'a>, | ||||
|         side: DelAdd, | ||||
|         inverted_field_map: &'a FieldsIdsMap, | ||||
|     ) -> Self { | ||||
|         let mut fields = BTreeMap::new(); | ||||
|         for (fid, versions) in data { | ||||
|             // Look at the requested side first, and resolve the name only if it has a value. | ||||
|             let entry = KvReaderDelAdd::new(versions) | ||||
|                 .get(side) | ||||
|                 .and_then(|raw| inverted_field_map.name(fid).map(|name| (name, raw))); | ||||
|             if let Some((name, raw)) = entry { | ||||
|                 fields.insert(name, (raw, ParsedValue::empty())); | ||||
|             } | ||||
|         } | ||||
|         Self(fields) | ||||
|     } | ||||
|  | ||||
|     /// Whether the document holds no field at all. | ||||
|     fn is_empty(&self) -> bool { | ||||
|         self.0.is_empty() | ||||
|     } | ||||
|  | ||||
|     /// Number of fields held by the document. | ||||
|     fn len(&self) -> usize { | ||||
|         self.0.len() | ||||
|     } | ||||
|  | ||||
|     /// Iterates over owned copies of every field name and its parsed value. | ||||
|     fn iter(&self) -> impl Iterator<Item = (KString, LiquidValue)> + '_ { | ||||
|         self.0.iter().map(|(&name, (raw, parsed))| { | ||||
|             let key: KString = name.to_owned().into(); | ||||
|             (key, parsed.get(raw).clone()) | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> ObjectView for Document<'a> { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn size(&self) -> i64 { | ||||
|         self.len() as i64 | ||||
|     } | ||||
|  | ||||
|     /// Field names, in the order of the underlying `BTreeMap`. | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         Box::new(self.0.keys().map(|&name| KStringCow::from(name))) | ||||
|     } | ||||
|  | ||||
|     /// Field values, parsed on demand. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         let views = self.0.values().map(|(raw, parsed)| parsed.get(raw) as &dyn ValueView); | ||||
|         Box::new(views) | ||||
|     } | ||||
|  | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         let pairs = self.0.iter().map(|(&name, (raw, parsed))| { | ||||
|             (KStringCow::from(name), parsed.get(raw) as &dyn ValueView) | ||||
|         }); | ||||
|         Box::new(pairs) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: &str) -> bool { | ||||
|         self.0.contains_key(index) | ||||
|     } | ||||
|  | ||||
|     /// Returns the parsed value of the field named `index`, if the document has one. | ||||
|     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||
|         let (raw, parsed) = self.0.get(index)?; | ||||
|         Some(parsed.get(raw) as &dyn ValueView) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> ValueView for Document<'a> { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         let rendered = ObjectRender::new(self); | ||||
|         DisplayCow::Owned(Box::new(rendered)) | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         let source = ObjectSource::new(self); | ||||
|         DisplayCow::Owned(Box::new(source)) | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     /// The document is always truthy; it counts as default, empty or blank only when it holds | ||||
|     /// no field. | ||||
|     fn query_state(&self, state: liquid::model::State) -> bool { | ||||
|         matches!(state, State::Truthy) || self.is_empty() | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||
|         KStringCow::from_string(ObjectRender::new(self).to_string()) | ||||
|     } | ||||
|  | ||||
|     /// Copies every field into an owned liquid object. | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         let fields = self.iter().collect(); | ||||
|         LiquidValue::Object(fields) | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										56
									
								
								milli/src/prompt/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								milli/src/prompt/error.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| use crate::error::FaultSource; | ||||
|  | ||||
| /// Error returned when a prompt cannot be created. | ||||
| /// | ||||
| /// Displayed as `"{fault}: {kind}"`. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("{fault}: {kind}")] | ||||
| pub struct NewPromptError { | ||||
|     /// What went wrong. | ||||
|     pub kind: NewPromptErrorKind, | ||||
|     /// Which party the failure is attributed to. | ||||
|     pub fault: FaultSource, | ||||
| } | ||||
|  | ||||
| // Surfaces prompt-creation failures as `UserError::InvalidPrompt` in the crate-wide error type. | ||||
| impl From<NewPromptError> for crate::Error { | ||||
|     fn from(value: NewPromptError) -> Self { | ||||
|         crate::Error::UserError(crate::UserError::InvalidPrompt(value)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl NewPromptError { | ||||
|     /// Wraps a liquid parse failure; the fault is attributed to the user. | ||||
|     pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError { | ||||
|         let kind = NewPromptErrorKind::CannotParseTemplate(inner); | ||||
|         NewPromptError { kind, fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     /// Wraps a liquid failure caused by unsupported fields in the template; the fault is | ||||
|     /// attributed to the user. | ||||
|     pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError { | ||||
|         let kind = NewPromptErrorKind::InvalidFieldsInTemplate(inner); | ||||
|         NewPromptError { kind, fault: FaultSource::User } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// The reasons a prompt can fail to be created. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| pub enum NewPromptErrorKind { | ||||
|     /// The template text is not valid liquid. | ||||
|     #[error("cannot parse template: {0}")] | ||||
|     CannotParseTemplate(liquid::Error), | ||||
|     /// The template refers to something other than the supported `doc` / `fields` entries. | ||||
|     #[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")] | ||||
|     InvalidFieldsInTemplate(liquid::Error), | ||||
| } | ||||
|  | ||||
| /// Error returned when rendering a prompt fails. | ||||
| /// | ||||
| /// Displayed as `"{fault}: {kind}"`. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("{fault}: {kind}")] | ||||
| pub struct RenderPromptError { | ||||
|     /// What went wrong. | ||||
|     pub kind: RenderPromptErrorKind, | ||||
|     /// Which party the failure is attributed to. | ||||
|     pub fault: FaultSource, | ||||
| } | ||||
| impl RenderPromptError { | ||||
|     /// Wraps a liquid render failure as a "missing context" error attributed to the user. | ||||
|     pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError { | ||||
|         let kind = RenderPromptErrorKind::MissingContext(inner); | ||||
|         RenderPromptError { kind, fault: FaultSource::User } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// The reasons rendering a prompt can fail. | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| pub enum RenderPromptErrorKind { | ||||
|     /// liquid reported an error while rendering; reported to the user as a missing field. | ||||
|     #[error("missing field in document: {0}")] | ||||
|     MissingContext(liquid::Error), | ||||
| } | ||||
|  | ||||
| // Surfaces render failures as `UserError::MissingDocumentField` in the crate-wide error type. | ||||
| impl From<RenderPromptError> for crate::Error { | ||||
|     fn from(value: RenderPromptError) -> Self { | ||||
|         crate::Error::UserError(crate::UserError::MissingDocumentField(value)) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										172
									
								
								milli/src/prompt/fields.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								milli/src/prompt/fields.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,172 @@ | ||||
| use liquid::model::{ | ||||
|     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||
| }; | ||||
| use liquid::{ObjectView, ValueView}; | ||||
|  | ||||
| use super::document::Document; | ||||
| use crate::FieldsIdsMap; | ||||
| /// The list of `FieldValue`s exposed to templates; it is viewed by liquid as an array. | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Fields<'a>(Vec<FieldValue<'a>>); | ||||
|  | ||||
| impl<'a> Fields<'a> { | ||||
|     /// Builds one `FieldValue` per entry of `field_id_map`, each pointing back at `document`. | ||||
|     /// | ||||
|     /// Entries keep the iteration order of `field_id_map`; the field ids are ignored and | ||||
|     /// only the names are kept. | ||||
|     pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { | ||||
|         // `document` is a shared reference (Copy), so the closure can hand it to every entry | ||||
|         // directly; no need to zip against an infinite `repeat` iterator. | ||||
|         Self(field_id_map.iter().map(|(_fid, name)| FieldValue { document, name }).collect()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A single field exposed to templates as an object with a `name` and a `value` entry. | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub struct FieldValue<'a> { | ||||
|     /// Name of the field. | ||||
|     name: &'a str, | ||||
|     /// Document in which the value is looked up by `name`. | ||||
|     document: &'a Document<'a>, | ||||
| } | ||||
|  | ||||
| // Generic liquid value behaviour for `FieldValue`; it always presents itself as an "object". | ||||
| impl<'a> ValueView for FieldValue<'a> { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Rendering is delegated to liquid's `ObjectRender` helper. | ||||
|     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         DisplayCow::Owned(Box::new(ObjectRender::new(self))) | ||||
|     } | ||||
|  | ||||
|     /// Source output is delegated to liquid's `ObjectSource` helper. | ||||
|     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         DisplayCow::Owned(Box::new(ObjectSource::new(self))) | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     /// Always truthy. The other states follow `is_empty()`, which compares `size()` to 0; | ||||
|     /// `size()` is the constant 2, so they are always false. | ||||
|     fn query_state(&self, state: liquid::model::State) -> bool { | ||||
|         match state { | ||||
|             State::Truthy => true, | ||||
|             State::DefaultValue | State::Empty | State::Blank => self.is_empty(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// String form: the `ObjectRender` output collected into an owned string. | ||||
|     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||
|         let s = ObjectRender::new(self).to_string(); | ||||
|         KStringCow::from_string(s) | ||||
|     } | ||||
|  | ||||
|     /// Builds an owned liquid object holding the `name` and `value` entries. | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Object( | ||||
|             self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> FieldValue<'a> { | ||||
|     /// Returns a reference to the stored name. | ||||
|     /// | ||||
|     /// The return type is a reference to the `&str` itself; callers in this file cast it to | ||||
|     /// `&dyn ValueView`. | ||||
|     pub fn name(&self) -> &&'a str { | ||||
|         &self.name | ||||
|     } | ||||
|  | ||||
|     /// Returns the document's value for this field, or liquid `Nil` when the document has no | ||||
|     /// entry under this name. | ||||
|     pub fn value(&self) -> &dyn ValueView { | ||||
|         self.document.get(self.name).unwrap_or(&LiquidValue::Nil) | ||||
|     } | ||||
|  | ||||
|     /// True when `size()` is 0. | ||||
|     pub fn is_empty(&self) -> bool { | ||||
|         self.size() == 0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Object view with exactly two keys, "name" and "value", always in that order. | ||||
| impl<'a> ObjectView for FieldValue<'a> { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Always 2: the `name` and `value` entries. | ||||
|     fn size(&self) -> i64 { | ||||
|         2 | ||||
|     } | ||||
|  | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x))) | ||||
|     } | ||||
|  | ||||
|     /// Yields the name first, then the value, matching the order of `keys()`. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         Box::new( | ||||
|             std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     /// Pairs `keys()` with `values()` positionally. | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         Box::new(self.keys().zip(self.values())) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: &str) -> bool { | ||||
|         index == "name" || index == "value" | ||||
|     } | ||||
|  | ||||
|     /// Returns the name or the value for the matching key, `None` for any other key. | ||||
|     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||
|         match index { | ||||
|             "name" => Some(self.name()), | ||||
|             "value" => Some(self.value()), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Array view of `Fields`: every method forwards to the wrapped `Vec<FieldValue>`. | ||||
| impl<'a> ArrayView for Fields<'a> { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self.0.as_value() | ||||
|     } | ||||
|  | ||||
|     fn size(&self) -> i64 { | ||||
|         self.0.len() as i64 | ||||
|     } | ||||
|  | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         self.0.values() | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: i64) -> bool { | ||||
|         self.0.contains_key(index) | ||||
|     } | ||||
|  | ||||
|     // Called through the trait path, naming `ArrayView::get` explicitly. | ||||
|     fn get(&self, index: i64) -> Option<&dyn ValueView> { | ||||
|         ArrayView::get(&self.0, index) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Generic liquid value behaviour for `Fields`: everything except `as_debug` and `as_array` | ||||
| // forwards to the wrapped `Vec<FieldValue>`. | ||||
| impl<'a> ValueView for Fields<'a> { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         self.0.render() | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         self.0.source() | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         self.0.type_name() | ||||
|     } | ||||
|  | ||||
|     fn query_state(&self, state: liquid::model::State) -> bool { | ||||
|         self.0.query_state(state) | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||
|         self.0.to_kstr() | ||||
|     } | ||||
|  | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         self.0.to_value() | ||||
|     } | ||||
|  | ||||
|     /// `Fields` is always usable as an array. | ||||
|     fn as_array(&self) -> Option<&dyn ArrayView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										144
									
								
								milli/src/prompt/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								milli/src/prompt/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| mod context; | ||||
| mod document; | ||||
| pub(crate) mod error; | ||||
| mod fields; | ||||
| mod template_checker; | ||||
|  | ||||
| use std::convert::TryFrom; | ||||
|  | ||||
| use error::{NewPromptError, RenderPromptError}; | ||||
|  | ||||
| use self::context::Context; | ||||
| use self::document::Document; | ||||
| use crate::update::del_add::DelAdd; | ||||
| use crate::FieldsIdsMap; | ||||
|  | ||||
| /// A parsed liquid template together with the text it was parsed from and its fallback | ||||
| /// settings. | ||||
| pub struct Prompt { | ||||
|     /// Parsed form of `template_text`. | ||||
|     template: liquid::Template, | ||||
|     /// Source text of the template; `Clone` re-parses the template from it. | ||||
|     template_text: String, | ||||
|     // NOTE(review): `strategy` and `fallback` are stored but not read by `render` in this | ||||
|     // file - confirm where they are consumed. | ||||
|     strategy: PromptFallbackStrategy, | ||||
|     fallback: String, | ||||
| } | ||||
|  | ||||
| /// Serializable counterpart of `Prompt`: it carries the template text instead of the parsed | ||||
| /// template. | ||||
| #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | ||||
| pub struct PromptData { | ||||
|     /// Liquid template text. | ||||
|     pub template: String, | ||||
|     pub strategy: PromptFallbackStrategy, | ||||
|     pub fallback: String, | ||||
| } | ||||
|  | ||||
| // Keeps the template text and fallback settings; the parsed template is dropped. | ||||
| impl From<Prompt> for PromptData { | ||||
|     fn from(value: Prompt) -> Self { | ||||
|         let Prompt { template: _, template_text, strategy, fallback } = value; | ||||
|         PromptData { template: template_text, strategy, fallback } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Rebuilds a `Prompt` through `Prompt::new`, so the template text is parsed and checked again. | ||||
| impl TryFrom<PromptData> for Prompt { | ||||
|     type Error = NewPromptError; | ||||
|  | ||||
|     fn try_from(value: PromptData) -> Result<Self, Self::Error> { | ||||
|         let PromptData { template, strategy, fallback } = value; | ||||
|         Prompt::new(template, Some(strategy), Some(fallback)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Hand-written clone: the parsed template is rebuilt from the stored text instead of copied. | ||||
| impl Clone for Prompt { | ||||
|     fn clone(&self) -> Self { | ||||
|         // Panics if the stored text no longer parses. | ||||
|         let template = new_template(&self.template_text).unwrap(); | ||||
|         Self { | ||||
|             template, | ||||
|             template_text: self.template_text.clone(), | ||||
|             strategy: self.strategy, | ||||
|             fallback: self.fallback.clone(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Parses `text` with a liquid parser built with the standard library enabled. | ||||
| /// | ||||
| /// Panics if the parser itself cannot be built; parse failures are returned as `Err`. | ||||
| fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> { | ||||
|     liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text) | ||||
| } | ||||
|  | ||||
| /// Parses the default template text; panics if that text does not parse. | ||||
| fn default_template() -> liquid::Template { | ||||
|     new_template(default_template_text()).unwrap() | ||||
| } | ||||
|  | ||||
| /// Text of the default template: a loop over `fields` emitting one `name: value` line per | ||||
| /// field. | ||||
| fn default_template_text() -> &'static str { | ||||
|     "{% for field in fields %} \ | ||||
|     {{ field.name }}: {{ field.value }}\n\ | ||||
|     {% endfor %}" | ||||
| } | ||||
|  | ||||
| /// Fallback text used by the `Default` impls of `Prompt` and `PromptData`. | ||||
| fn default_fallback() -> &'static str { | ||||
|     "<MISSING>" | ||||
| } | ||||
|  | ||||
| // Default prompt: the default template, the default strategy and the default fallback text. | ||||
| impl Default for Prompt { | ||||
|     fn default() -> Self { | ||||
|         Self { | ||||
|             template: default_template(), | ||||
|             template_text: default_template_text().into(), | ||||
|             strategy: Default::default(), | ||||
|             fallback: default_fallback().into(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Same defaults as `Prompt::default()`, with the template kept as text. | ||||
| impl Default for PromptData { | ||||
|     fn default() -> Self { | ||||
|         Self { | ||||
|             template: default_template_text().into(), | ||||
|             strategy: Default::default(), | ||||
|             fallback: default_fallback().into(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Prompt { | ||||
|     /// Parses `template` and checks that it only refers to supported entries. | ||||
|     /// | ||||
|     /// `strategy` and `fallback` fall back to their `Default` values when `None`, which | ||||
|     /// for `fallback` is the empty string. | ||||
|     // NOTE(review): `Prompt::default()` uses `default_fallback()` instead of the empty | ||||
|     // string - confirm the difference is intended. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// - `CannotParseTemplate` when the text is not valid liquid. | ||||
|     /// - `InvalidFieldsInTemplate` when the trial render against `TemplateChecker` fails. | ||||
|     pub fn new( | ||||
|         template: String, | ||||
|         strategy: Option<PromptFallbackStrategy>, | ||||
|         fallback: Option<String>, | ||||
|     ) -> Result<Self, NewPromptError> { | ||||
|         let this = Self { | ||||
|             // Same parser configuration as every other template built in this module. | ||||
|             template: new_template(&template).map_err(NewPromptError::cannot_parse_template)?, | ||||
|             template_text: template, | ||||
|             strategy: strategy.unwrap_or_default(), | ||||
|             fallback: fallback.unwrap_or_default(), | ||||
|         }; | ||||
|  | ||||
|         // render template with special object that's OK with `doc.*` and `fields.*` | ||||
|         // FIXME: doesn't work for nested objects e.g. `doc.a.b` | ||||
|         this.template | ||||
|             .render(&template_checker::TemplateChecker) | ||||
|             .map_err(NewPromptError::invalid_fields_in_template)?; | ||||
|  | ||||
|         Ok(this) | ||||
|     } | ||||
|  | ||||
|     /// Renders the template for one document. | ||||
|     /// | ||||
|     /// A `Document` is built from `document`, `side` and `field_id_map`, wrapped in a | ||||
|     /// `Context` and passed to liquid. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Any liquid render error is returned as `RenderPromptError` of kind `MissingContext`. | ||||
|     pub fn render( | ||||
|         &self, | ||||
|         document: obkv::KvReaderU16<'_>, | ||||
|         side: DelAdd, | ||||
|         field_id_map: &FieldsIdsMap, | ||||
|     ) -> Result<String, RenderPromptError> { | ||||
|         let document = Document::new(document, side, field_id_map); | ||||
|         let context = Context::new(&document, field_id_map); | ||||
|  | ||||
|         self.template.render(&context).map_err(RenderPromptError::missing_context) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Strategy stored alongside a prompt; (de)serialized with camelCase variant names. | ||||
| // NOTE(review): presumably selects what happens when rendering cannot produce a value, but | ||||
| // nothing in this file reads it - confirm against the callers. | ||||
| #[derive( | ||||
|     Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr, | ||||
| )] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub enum PromptFallbackStrategy { | ||||
|     Fallback, | ||||
|     Skip, | ||||
|     /// The default variant. | ||||
|     #[default] | ||||
|     Error, | ||||
| } | ||||
							
								
								
									
										282
									
								
								milli/src/prompt/template_checker.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										282
									
								
								milli/src/prompt/template_checker.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,282 @@ | ||||
| use liquid::model::{ | ||||
|     ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, | ||||
| }; | ||||
| use liquid::{ObjectView, ValueView}; | ||||
|  | ||||
| /// Stand-in render context exposing exactly two entries, `doc` and `fields`, both backed by | ||||
| /// dummy values. | ||||
| #[derive(Debug)] | ||||
| pub struct TemplateChecker; | ||||
|  | ||||
| /// Stand-in for `doc`: an object that claims to contain every key. | ||||
| #[derive(Debug)] | ||||
| pub struct DummyDoc; | ||||
|  | ||||
| /// Stand-in for `fields`: an array that claims to contain every index. | ||||
| #[derive(Debug)] | ||||
| pub struct DummyFields; | ||||
|  | ||||
| /// Stand-in for one element of `fields`: an object with only `name` and `value` keys. | ||||
| #[derive(Debug)] | ||||
| pub struct DummyField; | ||||
|  | ||||
| /// Placeholder value (liquid `Nil`) returned by every dummy lookup in this module. | ||||
| const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil; | ||||
|  | ||||
| // Object view accepting only the "name" and "value" keys; both resolve to `DUMMY_VALUE`. | ||||
| impl ObjectView for DummyField { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn size(&self) -> i64 { | ||||
|         2 | ||||
|     } | ||||
|  | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s))) | ||||
|     } | ||||
|  | ||||
|     // NOTE(review): yields nothing although `size()` is 2 and `keys()` yields two keys. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     // NOTE(review): yields nothing, like `values()`. | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: &str) -> bool { | ||||
|         index == "name" || index == "value" | ||||
|     } | ||||
|  | ||||
|     /// `DUMMY_VALUE` for "name" and "value", `None` for any other key. | ||||
|     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||
|         if self.contains_key(index) { | ||||
|             Some(DUMMY_VALUE.as_view()) | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Value behaviour borrowed from `DUMMY_VALUE`, except that the type name is "object" and | ||||
| // `as_object` returns `self`. | ||||
| impl ValueView for DummyField { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.render() | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.source() | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     fn query_state(&self, state: State) -> bool { | ||||
|         DUMMY_VALUE.query_state(state) | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> KStringCow<'_> { | ||||
|         DUMMY_VALUE.to_kstr() | ||||
|     } | ||||
|  | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Nil | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Value behaviour borrowed from `DUMMY_VALUE`, except that the type name is "array" and | ||||
| // `as_array` returns `self`. | ||||
| impl ValueView for DummyFields { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.render() | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.source() | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "array" | ||||
|     } | ||||
|  | ||||
|     fn query_state(&self, state: State) -> bool { | ||||
|         DUMMY_VALUE.query_state(state) | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> KStringCow<'_> { | ||||
|         DUMMY_VALUE.to_kstr() | ||||
|     } | ||||
|  | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Nil | ||||
|     } | ||||
|  | ||||
|     fn as_array(&self) -> Option<&dyn ArrayView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Array view that accepts any index and answers each lookup with a `DummyField`. | ||||
| impl ArrayView for DummyFields { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Reports the largest possible size. | ||||
|     fn size(&self) -> i64 { | ||||
|         i64::MAX | ||||
|     } | ||||
|  | ||||
|     // NOTE(review): yields nothing despite the reported size; if liquid's `for` tag walks | ||||
|     // `values()`, loop bodies over `fields` are presumably not exercised - TODO confirm. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, _index: i64) -> bool { | ||||
|         true | ||||
|     } | ||||
|  | ||||
|     fn get(&self, _index: i64) -> Option<&dyn ValueView> { | ||||
|         Some(DummyField.as_value()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Object view that accepts any key and answers each lookup with `DUMMY_VALUE`. | ||||
| impl ObjectView for DummyDoc { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     // Hard-coded size; the iterators below yield nothing regardless. | ||||
|     fn size(&self) -> i64 { | ||||
|         1000 | ||||
|     } | ||||
|  | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         Box::new(std::iter::empty()) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, _index: &str) -> bool { | ||||
|         true | ||||
|     } | ||||
|  | ||||
|     /// Always `Some(DUMMY_VALUE)`, whatever the key. | ||||
|     fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { | ||||
|         Some(DUMMY_VALUE.as_view()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Value behaviour borrowed from `DUMMY_VALUE`, except that the type name is "object" and | ||||
| // `as_object` returns `self`. | ||||
| impl ValueView for DummyDoc { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn render(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.render() | ||||
|     } | ||||
|  | ||||
|     fn source(&self) -> DisplayCow<'_> { | ||||
|         DUMMY_VALUE.source() | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     fn query_state(&self, state: State) -> bool { | ||||
|         DUMMY_VALUE.query_state(state) | ||||
|     } | ||||
|  | ||||
|     fn to_kstr(&self) -> KStringCow<'_> { | ||||
|         DUMMY_VALUE.to_kstr() | ||||
|     } | ||||
|  | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Nil | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Root object of the check: only "doc" and "fields" exist, so a template that looks up any | ||||
| // other top-level name gets `None` from `get`. | ||||
| impl ObjectView for TemplateChecker { | ||||
|     fn as_value(&self) -> &dyn ValueView { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     fn size(&self) -> i64 { | ||||
|         2 | ||||
|     } | ||||
|  | ||||
|     fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> { | ||||
|         Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) | ||||
|     } | ||||
|  | ||||
|     /// Yields the dummy doc first, then the dummy fields, matching the order of `keys()`. | ||||
|     fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { | ||||
|         Box::new( | ||||
|             std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     /// Pairs `keys()` with `values()` positionally. | ||||
|     fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { | ||||
|         Box::new(self.keys().zip(self.values())) | ||||
|     } | ||||
|  | ||||
|     fn contains_key(&self, index: &str) -> bool { | ||||
|         index == "doc" || index == "fields" | ||||
|     } | ||||
|  | ||||
|     fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { | ||||
|         match index { | ||||
|             "doc" => Some(DummyDoc.as_value()), | ||||
|             "fields" => Some(DummyFields.as_value()), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Generic liquid value behaviour for the root object; it is never empty or blank. | ||||
| impl ValueView for TemplateChecker { | ||||
|     fn as_debug(&self) -> &dyn std::fmt::Debug { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Rendering is delegated to liquid's `ObjectRender` helper. | ||||
|     fn render(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         DisplayCow::Owned(Box::new(ObjectRender::new(self))) | ||||
|     } | ||||
|  | ||||
|     /// Source output is delegated to liquid's `ObjectSource` helper. | ||||
|     fn source(&self) -> liquid::model::DisplayCow<'_> { | ||||
|         DisplayCow::Owned(Box::new(ObjectSource::new(self))) | ||||
|     } | ||||
|  | ||||
|     fn type_name(&self) -> &'static str { | ||||
|         "object" | ||||
|     } | ||||
|  | ||||
|     /// Truthy; never default, empty or blank. | ||||
|     fn query_state(&self, state: liquid::model::State) -> bool { | ||||
|         match state { | ||||
|             State::Truthy => true, | ||||
|             State::DefaultValue | State::Empty | State::Blank => false, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// String form: the `ObjectRender` output collected into an owned string. | ||||
|     fn to_kstr(&self) -> liquid::model::KStringCow<'_> { | ||||
|         let s = ObjectRender::new(self).to_string(); | ||||
|         KStringCow::from_string(s) | ||||
|     } | ||||
|  | ||||
|     /// Builds an owned liquid object from the `doc` and `fields` entries. | ||||
|     fn to_value(&self) -> LiquidValue { | ||||
|         LiquidValue::Object( | ||||
|             self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     fn as_object(&self) -> Option<&dyn ObjectView> { | ||||
|         Some(self) | ||||
|     } | ||||
| } | ||||
| @@ -1,3 +1,6 @@ | ||||
| use std::cmp::Ordering; | ||||
|  | ||||
| use itertools::Itertools; | ||||
| use serde::Serialize; | ||||
|  | ||||
| use crate::distance_between_two_points; | ||||
| @@ -12,9 +15,24 @@ pub enum ScoreDetails { | ||||
|     ExactAttribute(ExactAttribute), | ||||
|     ExactWords(ExactWords), | ||||
|     Sort(Sort), | ||||
|     Vector(Vector), | ||||
|     GeoSort(GeoSort), | ||||
| } | ||||
|  | ||||
| /// One item of the sequence produced by `ScoreDetails::score_values`: either a numeric score | ||||
| /// or a borrowed sort / geo-sort detail. | ||||
| #[derive(Clone, Copy)] | ||||
| pub enum ScoreValue<'a> { | ||||
|     Score(f64), | ||||
|     Sort(&'a Sort), | ||||
|     GeoSort(&'a GeoSort), | ||||
| } | ||||
|  | ||||
| /// Intermediate form used while building `ScoreValue`s: ranks are kept as `Rank` so that | ||||
| /// adjacent ones can be merged before being turned into a score. | ||||
| enum RankOrValue<'a> { | ||||
|     Rank(Rank), | ||||
|     Sort(&'a Sort), | ||||
|     GeoSort(&'a GeoSort), | ||||
|     Score(f64), | ||||
| } | ||||
|  | ||||
| impl ScoreDetails { | ||||
|     pub fn local_score(&self) -> Option<f64> { | ||||
|         self.rank().map(Rank::local_score) | ||||
| @@ -31,11 +49,55 @@ impl ScoreDetails { | ||||
|             ScoreDetails::ExactWords(details) => Some(details.rank()), | ||||
|             ScoreDetails::Sort(_) => None, | ||||
|             ScoreDetails::GeoSort(_) => None, | ||||
|             ScoreDetails::Vector(_) => None, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 { | ||||
|         Rank::global_score(details.filter_map(Self::rank)) | ||||
|     /// Returns the first numeric score found in `score_values(details)`, or `1.0` when the | ||||
|     /// sequence contains no numeric score. | ||||
|     pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 { | ||||
|         Self::score_values(details) | ||||
|             .find_map(|value| match value { | ||||
|                 ScoreValue::Score(score) => Some(score), | ||||
|                 ScoreValue::Sort(_) | ScoreValue::GeoSort(_) => None, | ||||
|             }) | ||||
|             .unwrap_or(1.0f64) | ||||
|     } | ||||
|  | ||||
|     /// Converts a sequence of score details into `ScoreValue`s. | ||||
|     /// | ||||
|     /// Adjacent rank-based details are folded into a single rank with `Rank::merge` and then | ||||
|     /// emitted as one `Score`. Sort, geo-sort and already-numeric scores are emitted as they | ||||
|     /// are and end the current run of ranks. | ||||
|     pub fn score_values<'a>( | ||||
|         details: impl Iterator<Item = &'a Self> + 'a, | ||||
|     ) -> impl Iterator<Item = ScoreValue<'a>> + 'a { | ||||
|         details | ||||
|             .map(ScoreDetails::rank_or_value) | ||||
|             // `Ok` merges the pair into one item; `Err` keeps both items separate. | ||||
|             .coalesce(|left, right| match (left, right) { | ||||
|                 (RankOrValue::Rank(left), RankOrValue::Rank(right)) => { | ||||
|                     Ok(RankOrValue::Rank(Rank::merge(left, right))) | ||||
|                 } | ||||
|                 (left, right) => Err((left, right)), | ||||
|             }) | ||||
|             .map(|rank_or_value| match rank_or_value { | ||||
|                 RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()), | ||||
|                 RankOrValue::Sort(s) => ScoreValue::Sort(s), | ||||
|                 RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g), | ||||
|                 RankOrValue::Score(s) => ScoreValue::Score(s), | ||||
|             }) | ||||
|     } | ||||
|  | ||||
|     /// Classifies this detail as a rank, a sort / geo-sort reference or a numeric score. | ||||
|     /// | ||||
|     /// A vector detail yields its similarity as the score, or `0.0` when it has none. | ||||
|     fn rank_or_value(&self) -> RankOrValue<'_> { | ||||
|         match self { | ||||
|             // Details that compute their rank. | ||||
|             ScoreDetails::Words(words) => RankOrValue::Rank(words.rank()), | ||||
|             ScoreDetails::Typo(typo) => RankOrValue::Rank(typo.rank()), | ||||
|             ScoreDetails::ExactAttribute(exact) => RankOrValue::Rank(exact.rank()), | ||||
|             ScoreDetails::ExactWords(exact) => RankOrValue::Rank(exact.rank()), | ||||
|             // Details that already are a rank. | ||||
|             ScoreDetails::Proximity(rank) | ||||
|             | ScoreDetails::Fid(rank) | ||||
|             | ScoreDetails::Position(rank) => RankOrValue::Rank(*rank), | ||||
|             ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), | ||||
|             ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), | ||||
|             ScoreDetails::Vector(vector) => { | ||||
|                 let similarity = vector | ||||
|                     .value_similarity | ||||
|                     .as_ref() | ||||
|                     .map_or(0.0f64, |(_, similarity)| *similarity as f64); | ||||
|                 RankOrValue::Score(similarity) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Panics | ||||
| @@ -181,6 +243,19 @@ impl ScoreDetails { | ||||
|                     details_map.insert(sort, sort_details); | ||||
|                     order += 1; | ||||
|                 } | ||||
|                 ScoreDetails::Vector(s) => { | ||||
|                     let vector = format!("vectorSort({:?})", s.target_vector); | ||||
|                     let value = s.value_similarity.as_ref().map(|(v, _)| v); | ||||
|                     let similarity = s.value_similarity.as_ref().map(|(_, s)| s); | ||||
|  | ||||
|                     let details = serde_json::json!({ | ||||
|                         "order": order, | ||||
|                         "value": value, | ||||
|                         "similarity": similarity, | ||||
|                     }); | ||||
|                     details_map.insert(vector, details); | ||||
|                     order += 1; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         details_map | ||||
| @@ -297,15 +372,21 @@ impl Rank { | ||||
|     pub fn global_score(details: impl Iterator<Item = Self>) -> f64 { | ||||
|         let mut rank = Rank { rank: 1, max_rank: 1 }; | ||||
|         for inner_rank in details { | ||||
|             rank.rank -= 1; | ||||
|  | ||||
|             rank.rank *= inner_rank.max_rank; | ||||
|             rank.max_rank *= inner_rank.max_rank; | ||||
|  | ||||
|             rank.rank += inner_rank.rank; | ||||
|             rank = Rank::merge(rank, inner_rank); | ||||
|         } | ||||
|         rank.local_score() | ||||
|     } | ||||
|  | ||||
|     /// Folds `inner` into `outer` and returns the combined rank. | ||||
|     /// | ||||
|     /// The combined `rank` is `(outer.rank - 1) * inner.max_rank + inner.rank`, with the | ||||
|     /// subtraction saturating at zero; the combined `max_rank` is the product of both | ||||
|     /// `max_rank`s. | ||||
|     pub fn merge(mut outer: Rank, inner: Rank) -> Rank { | ||||
|         let outer_offset = outer.rank.saturating_sub(1); | ||||
|         outer.rank = outer_offset * inner.max_rank + inner.rank; | ||||
|         outer.max_rank *= inner.max_rank; | ||||
|         outer | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] | ||||
| @@ -335,13 +416,78 @@ pub struct Sort { | ||||
|     pub value: serde_json::Value, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] | ||||
| // Orders two sort details so that the better-ranked one compares as `Greater`. | ||||
| // | ||||
| // Details are only comparable when they sort on the same field in the same direction. | ||||
| // A null value ranks below everything else, and a number ranks above a string. | ||||
| impl PartialOrd for Sort { | ||||
|     fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | ||||
|         use serde_json::Value; | ||||
|  | ||||
|         if self.field_name != other.field_name || self.ascending != other.ascending { | ||||
|             return None; | ||||
|         } | ||||
|         match (&self.value, &other.value) { | ||||
|             (Value::Null, Value::Null) => Some(Ordering::Equal), | ||||
|             (Value::Null, _) => Some(Ordering::Less), | ||||
|             (_, Value::Null) => Some(Ordering::Greater), | ||||
|             // numbers are always before strings | ||||
|             (Value::Number(_), Value::String(_)) => Some(Ordering::Greater), | ||||
|             (Value::String(_), Value::Number(_)) => Some(Ordering::Less), | ||||
|             (Value::Number(left), Value::Number(right)) => { | ||||
|                 // A number with no `f64` form makes the pair incomparable instead of | ||||
|                 // panicking. | ||||
|                 let order = left.as_f64()?.partial_cmp(&right.as_f64()?)?; | ||||
|                 // 12 < 42, and when ascending, we want to see 12 first, so the smallest. | ||||
|                 // Hence, when ascending, smaller is better | ||||
|                 Some(if self.ascending { order.reverse() } else { order }) | ||||
|             } | ||||
|             (Value::String(left), Value::String(right)) => { | ||||
|                 let order = left.cmp(right); | ||||
|                 // Taking e.g. "a" and "z" | ||||
|                 // "a" < "z", and when ascending, we want to see "a" first, so the smallest. | ||||
|                 // Hence, when ascending, smaller is better | ||||
|                 Some(if self.ascending { order.reverse() } else { order }) | ||||
|             } | ||||
|             // Any other combination of value kinds is incomparable. | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Score value of a geo sort: a document point relative to a target point. | ||||
| #[derive(Debug, Clone, Copy, PartialEq)] | ||||
| pub struct GeoSort { | ||||
|     /// Point that distances are measured from. | ||||
|     // NOTE(review): presumably `[lat, lng]` — confirm against the callers. | ||||
|     pub target_point: [f64; 2], | ||||
|     /// Sort direction; when `true`, the smallest distance has the best score. | ||||
|     pub ascending: bool, | ||||
|     /// Point of the document, if any; without it no distance can be computed. | ||||
|     pub value: Option<[f64; 2]>, | ||||
| } | ||||
|  | ||||
| impl PartialOrd for GeoSort { | ||||
|     /// Orders two geo sort values so that the *greater* one is the better ranked one. | ||||
|     /// | ||||
|     /// Values are only comparable when they share the same target point and | ||||
|     /// direction. A value without a distance ranks below any value with one. | ||||
|     fn partial_cmp(&self, other: &Self) -> Option<Ordering> { | ||||
|         let same_sort = | ||||
|             self.target_point == other.target_point && self.ascending == other.ascending; | ||||
|         if !same_sort { | ||||
|             return None; | ||||
|         } | ||||
|  | ||||
|         let order = match (self.distance(), other.distance()) { | ||||
|             (None, None) => Ordering::Equal, | ||||
|             (None, Some(_)) => Ordering::Less, | ||||
|             (Some(_), None) => Ordering::Greater, | ||||
|             (Some(left), Some(right)) => { | ||||
|                 let by_distance = left.partial_cmp(&right)?; | ||||
|                 // when ascending, the one with the smallest distance has the best score | ||||
|                 if self.ascending { | ||||
|                     by_distance.reverse() | ||||
|                 } else { | ||||
|                     by_distance | ||||
|                 } | ||||
|             } | ||||
|         }; | ||||
|         Some(order) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Score details attached to a vector ranking step. | ||||
| // NOTE(review): the field descriptions below are inferred from names and types | ||||
| // only; no producer or consumer is visible here — confirm before relying on them. | ||||
| #[derive(Debug, Clone, PartialEq, PartialOrd)] | ||||
| pub struct Vector { | ||||
|     /// Presumably the query vector the document was compared against. | ||||
|     pub target_vector: Vec<f32>, | ||||
|     /// Presumably the document vector together with its similarity to the | ||||
|     /// target, or `None` when the document has none. | ||||
|     pub value_similarity: Option<(Vec<f32>, f32)>, | ||||
| } | ||||
|  | ||||
| impl GeoSort { | ||||
|     pub fn distance(&self) -> Option<f64> { | ||||
|         self.value.map(|value| distance_between_two_points(&self.target_point, &value)) | ||||
|   | ||||
							
								
								
									
										336
									
								
								milli/src/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										336
									
								
								milli/src/search/hybrid.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,336 @@ | ||||
| use std::cmp::Ordering; | ||||
| use std::collections::HashMap; | ||||
|  | ||||
| use itertools::Itertools; | ||||
| use roaring::RoaringBitmap; | ||||
|  | ||||
| use super::new::{execute_vector_search, PartialSearchResult}; | ||||
| use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; | ||||
| use crate::{ | ||||
|     execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult, | ||||
| }; | ||||
|  | ||||
| /// Results of one search (the "main" one) where each hit also carries, when | ||||
| /// available, the score the other ("ancillary") search gave to the same document. | ||||
| struct CombinedSearchResult { | ||||
|     /// Matching words of the main search. | ||||
|     matching_words: MatchingWords, | ||||
|     /// Candidates of the main search. | ||||
|     candidates: RoaringBitmap, | ||||
|     /// `(docid, scores)` pairs, kept sorted from best to worst combined score. | ||||
|     document_scores: Vec<(u32, CombinedScore)>, | ||||
| } | ||||
|  | ||||
| /// Score details from the main search, plus the score details from the | ||||
| /// ancillary search when that search also returned the document. | ||||
| type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>); | ||||
|  | ||||
| /// Compares two combined scores, `Ordering::Greater` meaning that `left` is better. | ||||
| /// | ||||
| /// Both sides are walked in lockstep: at each step the best of the main and | ||||
| /// ancillary score values is taken on each side (see `take_best_score`) and the | ||||
| /// two are compared. The first step that is not a tie decides the result. | ||||
| /// | ||||
| /// - A side that runs out of values first loses. | ||||
| /// - Relevancy scores closer than `f64::EPSILON` are considered tied. | ||||
| /// - A relevancy score beats a sort or geo sort value. | ||||
| /// - Sort or geo sort values that cannot be ordered (`partial_cmp` is `None`) | ||||
| ///   are treated as a tie, so the next values decide, instead of panicking. | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// Panics if a sort value has to be compared with a geo sort value. | ||||
| fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { | ||||
|     let mut left_main_it = ScoreDetails::score_values(left.0.iter()); | ||||
|     let mut left_sub_it = | ||||
|         ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); | ||||
|  | ||||
|     let mut right_main_it = ScoreDetails::score_values(right.0.iter()); | ||||
|     let mut right_sub_it = | ||||
|         ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten()); | ||||
|  | ||||
|     let mut left_main = left_main_it.next(); | ||||
|     let mut left_sub = left_sub_it.next(); | ||||
|     let mut right_main = right_main_it.next(); | ||||
|     let mut right_sub = right_sub_it.next(); | ||||
|  | ||||
|     loop { | ||||
|         let left = | ||||
|             take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); | ||||
|  | ||||
|         let right = | ||||
|             take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it); | ||||
|  | ||||
|         match (left, right) { | ||||
|             (None, None) => return Ordering::Equal, | ||||
|             (None, Some(_)) => return Ordering::Less, | ||||
|             (Some(_), None) => return Ordering::Greater, | ||||
|             (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { | ||||
|                 if (left - right).abs() <= f64::EPSILON { | ||||
|                     continue; | ||||
|                 } | ||||
|                 // `total_cmp` agrees with `partial_cmp` on the values that reach | ||||
|                 // this point and, unlike it, cannot fail on NaN. | ||||
|                 return left.total_cmp(&right); | ||||
|             } | ||||
|             (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { | ||||
|                 match left.partial_cmp(right) { | ||||
|                     // incomparable values cannot decide: defer to the next values | ||||
|                     None | Some(Ordering::Equal) => continue, | ||||
|                     Some(order) => return order, | ||||
|                 } | ||||
|             } | ||||
|             (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { | ||||
|                 match left.partial_cmp(right) { | ||||
|                     // incomparable values cannot decide: defer to the next values | ||||
|                     None | Some(Ordering::Equal) => continue, | ||||
|                     Some(order) => return order, | ||||
|                 } | ||||
|             } | ||||
|             (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater, | ||||
|             (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less, | ||||
|             // if we have this, we're bad | ||||
|             (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) | ||||
|             | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { | ||||
|                 unreachable!("Unexpected geo and sort comparison") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Picks the next score value to compare out of a main and an ancillary stream. | ||||
| /// | ||||
| /// `main_score` and `sub_score` hold the current head of each stream, while | ||||
| /// `main_it` and `sub_it` yield the values that follow. Whenever a head is | ||||
| /// consumed, it is replaced with the next value of its iterator. | ||||
| /// | ||||
| /// Returns `None` when both heads are empty, and also when one head is a sort | ||||
| /// value and the other a geo sort value (in which case nothing is advanced). | ||||
| fn take_best_score<'a>( | ||||
|     main_score: &mut Option<ScoreValue<'a>>, | ||||
|     sub_score: &mut Option<ScoreValue<'a>>, | ||||
|     main_it: &mut impl Iterator<Item = ScoreValue<'a>>, | ||||
|     sub_it: &mut impl Iterator<Item = ScoreValue<'a>>, | ||||
| ) -> Option<ScoreValue<'a>> { | ||||
|     match (*main_score, *sub_score) { | ||||
|         // only one stream still has a value: take it and advance that stream | ||||
|         (Some(main), None) => { | ||||
|             *main_score = main_it.next(); | ||||
|             Some(main) | ||||
|         } | ||||
|         (None, Some(sub)) => { | ||||
|             *sub_score = sub_it.next(); | ||||
|             Some(sub) | ||||
|         } | ||||
|         (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => { | ||||
|             // take max, both advance | ||||
|             *main_score = main_it.next(); | ||||
|             *sub_score = sub_it.next(); | ||||
|             if main_f >= sub_v { | ||||
|                 main | ||||
|             } else { | ||||
|                 sub | ||||
|             } | ||||
|         } | ||||
|         // a relevancy score facing a (geo) sort value: take the score and only | ||||
|         // advance the stream it came from | ||||
|         (main @ Some(ScoreValue::Score(_)), _) => { | ||||
|             *main_score = main_it.next(); | ||||
|             main | ||||
|         } | ||||
|         (_, sub @ Some(ScoreValue::Score(_))) => { | ||||
|             *sub_score = sub_it.next(); | ||||
|             sub | ||||
|         } | ||||
|         (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => { | ||||
|             // take best advance both | ||||
|             *main_score = main_it.next(); | ||||
|             *sub_score = sub_it.next(); | ||||
|             if main_geo >= sub_geo { | ||||
|                 main | ||||
|             } else { | ||||
|                 sub | ||||
|             } | ||||
|         } | ||||
|         (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => { | ||||
|             // take best advance both | ||||
|             *main_score = main_it.next(); | ||||
|             *sub_score = sub_it.next(); | ||||
|             if main_sort >= sub_sort { | ||||
|                 main | ||||
|             } else { | ||||
|                 sub | ||||
|             } | ||||
|         } | ||||
|         // a sort value facing a geo sort value: nothing is taken nor advanced | ||||
|         ( | ||||
|             Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), | ||||
|             Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), | ||||
|         ) => None, | ||||
|  | ||||
|         (None, None) => None, | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl CombinedSearchResult { | ||||
|     /// Attaches to each hit of `main_results` the score that `ancillary_results` | ||||
|     /// gives to the same document, then sorts the hits from best to worst. | ||||
|     /// | ||||
|     /// Documents that only appear in `ancillary_results` are ignored. | ||||
|     fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self { | ||||
|         let SearchResult { matching_words, candidates, documents_ids, document_scores } = | ||||
|             main_results; | ||||
|  | ||||
|         // every main hit starts without an ancillary score | ||||
|         let mut scores_by_docid: HashMap<_, _> = documents_ids | ||||
|             .into_iter() | ||||
|             .zip(document_scores) | ||||
|             .map(|(docid, main_score)| (docid, (main_score, None))) | ||||
|             .collect(); | ||||
|  | ||||
|         let ancillary_hits = | ||||
|             ancillary_results.documents_ids.into_iter().zip(ancillary_results.document_scores); | ||||
|         for (docid, ancillary_score) in ancillary_hits { | ||||
|             if let Some((_main_score, slot)) = scores_by_docid.get_mut(&docid) { | ||||
|                 *slot = Some(ancillary_score); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let mut document_scores: Vec<_> = scores_by_docid.into_iter().collect(); | ||||
|         // best score first | ||||
|         document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse()); | ||||
|  | ||||
|         Self { matching_words, candidates, document_scores } | ||||
|     } | ||||
|  | ||||
|     /// Merges two sorted result lists into a single page of results. | ||||
|     /// | ||||
|     /// The lists are interleaved from best to worst, a document is only kept the | ||||
|     /// first time it is seen, and `from` documents are then skipped before at | ||||
|     /// most `length` are returned. Only the main score of each hit is reported; | ||||
|     /// the matching words are those of `left`. | ||||
|     fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { | ||||
|         let capacity = left.document_scores.len() + right.document_scores.len(); | ||||
|         let mut documents_ids = Vec::with_capacity(capacity); | ||||
|         let mut document_scores = Vec::with_capacity(capacity); | ||||
|  | ||||
|         let mut documents_seen = RoaringBitmap::new(); | ||||
|         let page = left | ||||
|             .document_scores | ||||
|             .into_iter() | ||||
|             // the first value is the one with the greatest score | ||||
|             .merge_by(right.document_scores, |(_, left), (_, right)| { | ||||
|                 compare_scores(left, right).is_ge() | ||||
|             }) | ||||
|             // `insert` is false for documents we already saw | ||||
|             .filter(|(docid, _)| documents_seen.insert(*docid)) | ||||
|             // pagination happens **after** the deduplication | ||||
|             .skip(from) | ||||
|             .take(length); | ||||
|  | ||||
|         // TODO: pass both scores to documents_score in some way? | ||||
|         for (docid, (main_score, _sub_score)) in page { | ||||
|             documents_ids.push(docid); | ||||
|             document_scores.push(main_score); | ||||
|         } | ||||
|  | ||||
|         SearchResult { | ||||
|             matching_words: left.matching_words, | ||||
|             candidates: left.candidates | right.candidates, | ||||
|             documents_ids, | ||||
|             document_scores, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> Search<'a> { | ||||
|     /// Runs a keyword search and, when needed, a semantic (vector) search, then | ||||
|     /// merges both into the page of results described by `offset` and `limit`. | ||||
|     /// | ||||
|     /// The semantic search is skipped when there is no vector query or when the | ||||
|     /// keyword results are already good enough. In both cases the page is cut | ||||
|     /// out of the keyword results alone. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Forwards any error raised by the underlying searches. | ||||
|     pub fn execute_hybrid(&self) -> Result<SearchResult> { | ||||
|         // TODO: find classier way to achieve that than to reset vector and query params | ||||
|         // create separate keyword and semantic searches | ||||
|         let mut search = Search { | ||||
|             query: self.query.clone(), | ||||
|             // the first search is keyword only | ||||
|             vector: None, | ||||
|             filter: self.filter.clone(), | ||||
|             // fetch everything up to the end of the requested page: pagination | ||||
|             // can only be applied once both result lists are merged | ||||
|             offset: 0, | ||||
|             limit: self.limit.saturating_add(self.offset), | ||||
|             sort_criteria: self.sort_criteria.clone(), | ||||
|             searchable_attributes: self.searchable_attributes, | ||||
|             geo_strategy: self.geo_strategy, | ||||
|             terms_matching_strategy: self.terms_matching_strategy, | ||||
|             scoring_strategy: ScoringStrategy::Detailed, | ||||
|             words_limit: self.words_limit, | ||||
|             exhaustive_number_hits: self.exhaustive_number_hits, | ||||
|             rtxn: self.rtxn, | ||||
|             index: self.index, | ||||
|         }; | ||||
|  | ||||
|         let keyword_query = self.query.as_deref(); | ||||
|  | ||||
|         let keyword_results = search.execute()?; | ||||
|  | ||||
|         // skip semantic search if we don't have a vector query (placeholder search) | ||||
|         let Some(vector_query) = self.vector.as_deref() else { | ||||
|             return Ok(self.paginate_keyword_results(keyword_results)); | ||||
|         }; | ||||
|  | ||||
|         // completely skip semantic search if the results of the keyword search are good enough | ||||
|         if self.results_good_enough(&keyword_results) { | ||||
|             return Ok(self.paginate_keyword_results(keyword_results)); | ||||
|         } | ||||
|  | ||||
|         search.vector = Some(vector_query.to_vec()); | ||||
|         search.query = None; | ||||
|  | ||||
|         // TODO: would be better to have two distinct functions at this point | ||||
|         let vector_results = search.execute()?; | ||||
|  | ||||
|         // Compute keyword scores for vector_results | ||||
|         let keyword_results_for_vector = | ||||
|             self.keyword_results_for_vector(keyword_query, &vector_results)?; | ||||
|  | ||||
|         // compute vector scores for keyword_results | ||||
|         let vector_results_for_keyword = | ||||
|             self.vector_results_for_keyword(vector_query, &keyword_results)?; | ||||
|  | ||||
|         let keyword_results = | ||||
|             CombinedSearchResult::new(keyword_results, vector_results_for_keyword); | ||||
|         let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); | ||||
|  | ||||
|         let merge_results = | ||||
|             CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); | ||||
|         // guaranteed by the `take(length)` of the merge | ||||
|         debug_assert!(merge_results.documents_ids.len() <= self.limit); | ||||
|         Ok(merge_results) | ||||
|     } | ||||
|  | ||||
|     /// Cuts the requested page out of keyword results that were fetched from | ||||
|     /// offset 0, so that they can be returned without going through the merge. | ||||
|     fn paginate_keyword_results(&self, mut keyword_results: SearchResult) -> SearchResult { | ||||
|         let skipped = self.offset.min(keyword_results.documents_ids.len()); | ||||
|         keyword_results.documents_ids.drain(..skipped); | ||||
|         keyword_results.documents_ids.truncate(self.limit); | ||||
|  | ||||
|         // scores are parallel to the document ids; lengths are clamped | ||||
|         // separately so that a mismatch cannot panic | ||||
|         let skipped = self.offset.min(keyword_results.document_scores.len()); | ||||
|         keyword_results.document_scores.drain(..skipped); | ||||
|         keyword_results.document_scores.truncate(self.limit); | ||||
|  | ||||
|         keyword_results | ||||
|     } | ||||
|  | ||||
|     /// Runs the vector search restricted to the documents of `keyword_results`, | ||||
|     /// to obtain a vector score for each keyword hit. | ||||
|     fn vector_results_for_keyword( | ||||
|         &self, | ||||
|         vector: &[f32], | ||||
|         keyword_results: &SearchResult, | ||||
|     ) -> Result<PartialSearchResult> { | ||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); | ||||
|  | ||||
|         if let Some(searchable_attributes) = self.searchable_attributes { | ||||
|             ctx.searchable_attributes(searchable_attributes)?; | ||||
|         } | ||||
|  | ||||
|         let universe = keyword_results.documents_ids.iter().collect(); | ||||
|  | ||||
|         execute_vector_search( | ||||
|             &mut ctx, | ||||
|             vector, | ||||
|             ScoringStrategy::Detailed, | ||||
|             universe, | ||||
|             &self.sort_criteria, | ||||
|             self.geo_strategy, | ||||
|             0, | ||||
|             self.limit.saturating_add(self.offset), | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     /// Runs the keyword search restricted to the documents of `vector_results`, | ||||
|     /// to obtain a keyword score for each vector hit. | ||||
|     fn keyword_results_for_vector( | ||||
|         &self, | ||||
|         query: Option<&str>, | ||||
|         vector_results: &SearchResult, | ||||
|     ) -> Result<PartialSearchResult> { | ||||
|         let mut ctx = SearchContext::new(self.index, self.rtxn); | ||||
|  | ||||
|         if let Some(searchable_attributes) = self.searchable_attributes { | ||||
|             ctx.searchable_attributes(searchable_attributes)?; | ||||
|         } | ||||
|  | ||||
|         let universe = vector_results.documents_ids.iter().collect(); | ||||
|  | ||||
|         execute_search( | ||||
|             &mut ctx, | ||||
|             query, | ||||
|             self.terms_matching_strategy, | ||||
|             ScoringStrategy::Detailed, | ||||
|             self.exhaustive_number_hits, | ||||
|             universe, | ||||
|             &self.sort_criteria, | ||||
|             self.geo_strategy, | ||||
|             0, | ||||
|             self.limit.saturating_add(self.offset), | ||||
|             Some(self.words_limit), | ||||
|             &mut DefaultSearchLogger, | ||||
|             &mut DefaultSearchLogger, | ||||
|         ) | ||||
|     } | ||||
|  | ||||
|     /// Tells whether the keyword results alone can be returned: enough of them | ||||
|     /// to reach the end of the requested page, each with a global score of at | ||||
|     /// least `GOOD_ENOUGH_SCORE`. | ||||
|     fn results_good_enough(&self, keyword_results: &SearchResult) -> bool { | ||||
|         const GOOD_ENOUGH_SCORE: f64 = 0.9; | ||||
|  | ||||
|         // 1. we check that we got a sufficient number of results | ||||
|         if keyword_results.document_scores.len() < self.limit.saturating_add(self.offset) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         // 2. and that all results have a good enough score. | ||||
|         // we need to check all results because due to sort like rules, they're not necessarily in relevancy order | ||||
|         for score in &keyword_results.document_scores { | ||||
|             let score = ScoreDetails::global_score(score.iter()); | ||||
|             if score < GOOD_ENOUGH_SCORE { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|         true | ||||
|     } | ||||
| } | ||||
| @@ -3,6 +3,7 @@ use std::ops::ControlFlow; | ||||
|  | ||||
| use charabia::normalizer::NormalizerOption; | ||||
| use charabia::Normalize; | ||||
| use deserr::{DeserializeError, Deserr, Sequence}; | ||||
| use fst::automaton::{Automaton, Str}; | ||||
| use fst::{IntoStreamer, Streamer}; | ||||
| use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; | ||||
| @@ -12,12 +13,13 @@ use roaring::bitmap::RoaringBitmap; | ||||
|  | ||||
| pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; | ||||
| pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; | ||||
| use self::new::PartialSearchResult; | ||||
| use self::new::{execute_vector_search, PartialSearchResult}; | ||||
| use crate::error::UserError; | ||||
| use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; | ||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||
| use crate::{ | ||||
|     execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, | ||||
|     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, | ||||
|     Result, SearchContext, | ||||
| }; | ||||
|  | ||||
| // Building these factories is not free. | ||||
| @@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100; | ||||
|  | ||||
| pub mod facet; | ||||
| mod fst_utils; | ||||
| pub mod hybrid; | ||||
| pub mod new; | ||||
|  | ||||
| pub struct Search<'a> { | ||||
| @@ -50,6 +53,53 @@ pub struct Search<'a> { | ||||
|     index: &'a Index, | ||||
| } | ||||
|  | ||||
| /// Vector part of a search request, as received from the user: either a | ||||
| /// sequence of numbers or a string. | ||||
| #[derive(Debug, Clone, PartialEq)] | ||||
| pub enum VectorQuery { | ||||
|     /// Deserialized from a sequence of floats and/or integers. | ||||
|     Vector(Vec<f32>), | ||||
|     /// Deserialized from a string. | ||||
|     // NOTE(review): presumably text to be embedded into a vector — confirm. | ||||
|     String(String), | ||||
| } | ||||
|  | ||||
| impl<E> Deserr<E> for VectorQuery | ||||
| where | ||||
|     E: DeserializeError, | ||||
| { | ||||
|     /// Accepts either a string or a sequence of numbers. | ||||
|     /// | ||||
|     /// Floats and integers of a sequence are converted to `f32`. Any other | ||||
|     /// element is reported as an error located at its index, and any other | ||||
|     /// kind of value as an error located at `location`. | ||||
|     fn deserialize_from_value<V: deserr::IntoValue>( | ||||
|         value: deserr::Value<V>, | ||||
|         location: deserr::ValuePointerRef, | ||||
|     ) -> std::result::Result<Self, E> { | ||||
|         let seq = match value { | ||||
|             deserr::Value::String(s) => return Ok(VectorQuery::String(s)), | ||||
|             deserr::Value::Sequence(seq) => seq, | ||||
|             other => { | ||||
|                 return Err(deserr::take_cf_content(E::error::<V>( | ||||
|                     None, | ||||
|                     deserr::ErrorKind::IncorrectValueKind { | ||||
|                         actual: other, | ||||
|                         accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence], | ||||
|                     }, | ||||
|                     location, | ||||
|                 ))) | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         let mut vector = Vec::new(); | ||||
|         for (index, element) in seq.into_iter().enumerate() { | ||||
|             let component = match element.into_value() { | ||||
|                 deserr::Value::Float(f) => f as f32, | ||||
|                 deserr::Value::Integer(i) => i as f32, | ||||
|                 // stop at the first element that is not a number | ||||
|                 other => { | ||||
|                     return Err(deserr::take_cf_content(E::error::<V>( | ||||
|                         None, | ||||
|                         deserr::ErrorKind::IncorrectValueKind { | ||||
|                             actual: other, | ||||
|                             accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer], | ||||
|                         }, | ||||
|                         location.push_index(index), | ||||
|                     ))) | ||||
|                 } | ||||
|             }; | ||||
|             vector.push(component); | ||||
|         } | ||||
|         Ok(VectorQuery::Vector(vector)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'a> Search<'a> { | ||||
|     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { | ||||
|         Search { | ||||
| @@ -75,8 +125,8 @@ impl<'a> Search<'a> { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> { | ||||
|         self.vector = Some(vector.into()); | ||||
|     pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> { | ||||
|         self.vector = Some(vector); | ||||
|         self | ||||
|     } | ||||
|  | ||||
| @@ -140,23 +190,35 @@ impl<'a> Search<'a> { | ||||
|             ctx.searchable_attributes(searchable_attributes)?; | ||||
|         } | ||||
|  | ||||
|         let universe = filtered_universe(&ctx, &self.filter)?; | ||||
|         let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = | ||||
|             execute_search( | ||||
|                 &mut ctx, | ||||
|                 &self.query, | ||||
|                 &self.vector, | ||||
|                 self.terms_matching_strategy, | ||||
|                 self.scoring_strategy, | ||||
|                 self.exhaustive_number_hits, | ||||
|                 &self.filter, | ||||
|                 &self.sort_criteria, | ||||
|                 self.geo_strategy, | ||||
|                 self.offset, | ||||
|                 self.limit, | ||||
|                 Some(self.words_limit), | ||||
|                 &mut DefaultSearchLogger, | ||||
|                 &mut DefaultSearchLogger, | ||||
|             )?; | ||||
|             match self.vector.as_ref() { | ||||
|                 Some(vector) => execute_vector_search( | ||||
|                     &mut ctx, | ||||
|                     vector, | ||||
|                     self.scoring_strategy, | ||||
|                     universe, | ||||
|                     &self.sort_criteria, | ||||
|                     self.geo_strategy, | ||||
|                     self.offset, | ||||
|                     self.limit, | ||||
|                 )?, | ||||
|                 None => execute_search( | ||||
|                     &mut ctx, | ||||
|                     self.query.as_deref(), | ||||
|                     self.terms_matching_strategy, | ||||
|                     self.scoring_strategy, | ||||
|                     self.exhaustive_number_hits, | ||||
|                     universe, | ||||
|                     &self.sort_criteria, | ||||
|                     self.geo_strategy, | ||||
|                     self.offset, | ||||
|                     self.limit, | ||||
|                     Some(self.words_limit), | ||||
|                     &mut DefaultSearchLogger, | ||||
|                     &mut DefaultSearchLogger, | ||||
|                 )?, | ||||
|             }; | ||||
|  | ||||
|         // consume context and located_query_terms to build MatchingWords. | ||||
|         let matching_words = match located_query_terms { | ||||
|   | ||||
| @@ -498,19 +498,19 @@ mod tests { | ||||
|  | ||||
|     use super::*; | ||||
|     use crate::index::tests::TempIndex; | ||||
|     use crate::{execute_search, SearchContext}; | ||||
|     use crate::{execute_search, filtered_universe, SearchContext}; | ||||
|  | ||||
|     impl<'a> MatcherBuilder<'a> { | ||||
|         fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { | ||||
|             let mut ctx = SearchContext::new(index, rtxn); | ||||
|             let universe = filtered_universe(&ctx, &None).unwrap(); | ||||
|             let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( | ||||
|                 &mut ctx, | ||||
|                 &Some(query.to_string()), | ||||
|                 &None, | ||||
|                 Some(query), | ||||
|                 crate::TermsMatchingStrategy::default(), | ||||
|                 crate::score_details::ScoringStrategy::Skip, | ||||
|                 false, | ||||
|                 &None, | ||||
|                 universe, | ||||
|                 &None, | ||||
|                 crate::search::new::GeoSortStrategy::default(), | ||||
|                 0, | ||||
|   | ||||
| @@ -16,6 +16,7 @@ mod small_bitmap; | ||||
|  | ||||
| mod exact_attribute; | ||||
| mod sort; | ||||
| mod vector_sort; | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests; | ||||
| @@ -28,7 +29,6 @@ use db_cache::DatabaseCache; | ||||
| use exact_attribute::ExactAttribute; | ||||
| use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; | ||||
| use heed::RoTxn; | ||||
| use instant_distance::Search; | ||||
| use interner::{DedupInterner, Interner}; | ||||
| pub use logger::visual::VisualSearchLogger; | ||||
| pub use logger::{DefaultSearchLogger, SearchLogger}; | ||||
| @@ -46,7 +46,7 @@ use self::geo_sort::GeoSort; | ||||
| pub use self::geo_sort::Strategy as GeoSortStrategy; | ||||
| use self::graph_based_ranking_rule::Words; | ||||
| use self::interner::Interned; | ||||
| use crate::distance::NDotProductPoint; | ||||
| use self::vector_sort::VectorSort; | ||||
| use crate::error::FieldIdMapMissingEntry; | ||||
| use crate::score_details::{ScoreDetails, ScoringStrategy}; | ||||
| use crate::search::new::distinct::apply_distinct_rule; | ||||
| @@ -258,6 +258,70 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( | ||||
|     Ok(ranking_rules) | ||||
| } | ||||
|  | ||||
| /// Return the list of initialised ranking rules to be used for a vector search. | ||||
| /// | ||||
| /// The ranking rules of the index settings are walked in order: | ||||
| /// - the first of `words`, `typo`, `proximity`, `attribute` or `exactness` is | ||||
| ///   replaced with a single vector sort towards `target`, the next ones are ignored; | ||||
| /// - the first `sort` is resolved from `sort_criteria`, the next ones are ignored; | ||||
| /// - `asc`/`desc` rules are added once per field. | ||||
| fn get_ranking_rules_for_vector<'ctx>( | ||||
|     ctx: &SearchContext<'ctx>, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_strategy: geo_sort::Strategy, | ||||
|     target: &[f32], | ||||
| ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> { | ||||
|     let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![]; | ||||
|  | ||||
|     // what has already been added to `ranking_rules` | ||||
|     let mut vector_rule_added = false; | ||||
|     let mut sort_rule_resolved = false; | ||||
|     let mut geo_sorted = false; | ||||
|     let mut sorted_fields = HashSet::new(); | ||||
|  | ||||
|     for criterion in ctx.index.criteria(ctx.txn)? { | ||||
|         match criterion { | ||||
|             crate::Criterion::Words | ||||
|             | crate::Criterion::Typo | ||||
|             | crate::Criterion::Proximity | ||||
|             | crate::Criterion::Attribute | ||||
|             | crate::Criterion::Exactness => { | ||||
|                 if vector_rule_added { | ||||
|                     continue; | ||||
|                 } | ||||
|                 let vector_candidates = ctx.index.documents_ids(ctx.txn)?; | ||||
|                 let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?; | ||||
|                 ranking_rules.push(Box::new(vector_sort)); | ||||
|                 vector_rule_added = true; | ||||
|             } | ||||
|             crate::Criterion::Sort => { | ||||
|                 if !sort_rule_resolved { | ||||
|                     resolve_sort_criteria( | ||||
|                         sort_criteria, | ||||
|                         ctx, | ||||
|                         &mut ranking_rules, | ||||
|                         &mut sorted_fields, | ||||
|                         &mut geo_sorted, | ||||
|                         geo_strategy, | ||||
|                     )?; | ||||
|                     sort_rule_resolved = true; | ||||
|                 } | ||||
|             } | ||||
|             crate::Criterion::Asc(field_name) => { | ||||
|                 // `insert` is false when the field is already sorted on | ||||
|                 if sorted_fields.insert(field_name.clone()) { | ||||
|                     ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?)); | ||||
|                 } | ||||
|             } | ||||
|             crate::Criterion::Desc(field_name) => { | ||||
|                 // `insert` is false when the field is already sorted on | ||||
|                 if sorted_fields.insert(field_name.clone()) { | ||||
|                     ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?)); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(ranking_rules) | ||||
| } | ||||
|  | ||||
| /// Return the list of initialised ranking rules to be used for a query graph search. | ||||
| fn get_ranking_rules_for_query_graph_search<'ctx>( | ||||
|     ctx: &SearchContext<'ctx>, | ||||
| @@ -422,15 +486,62 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| /// Returns the documents a search is allowed to return: those matching | ||||
| /// `filters` when there are some, all the documents of the index otherwise. | ||||
| pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> { | ||||
|     let universe = match filters { | ||||
|         Some(filters) => filters.evaluate(ctx.txn, ctx.index)?, | ||||
|         None => ctx.index.documents_ids(ctx.txn)?, | ||||
|     }; | ||||
|     Ok(universe) | ||||
| } | ||||
|  | ||||
| /// Ranks the documents of `universe` against `vector` and returns the | ||||
| /// `length` results that follow the first `from` ones. | ||||
| /// | ||||
| /// The ranking rules come from `get_ranking_rules_for_vector` and are run | ||||
| /// through a bucket sort as a placeholder query, which is why the result | ||||
| /// carries no located query terms. | ||||
| /// | ||||
| /// # Errors | ||||
| /// | ||||
| /// Fails when `sort_criteria` does not pass `check_sort_criteria`, or when | ||||
| /// building or running the ranking rules fails. | ||||
| #[allow(clippy::too_many_arguments)] | ||||
| pub fn execute_vector_search( | ||||
|     ctx: &mut SearchContext, | ||||
|     vector: &[f32], | ||||
|     scoring_strategy: ScoringStrategy, | ||||
|     universe: RoaringBitmap, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_strategy: geo_sort::Strategy, | ||||
|     from: usize, | ||||
|     length: usize, | ||||
| ) -> Result<PartialSearchResult> { | ||||
|     check_sort_criteria(ctx, sort_criteria.as_ref())?; | ||||
|  | ||||
|     // FIXME: input universe = universe & documents_with_vectors | ||||
|     // for now if we're computing embeddings for ALL documents, we can assume that this is just universe | ||||
|     let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?; | ||||
|  | ||||
|     let mut placeholder_search_logger = logger::DefaultSearchLogger; | ||||
|     let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> = | ||||
|         &mut placeholder_search_logger; | ||||
|  | ||||
|     let BucketSortOutput { docids, scores, all_candidates } = bucket_sort( | ||||
|         ctx, | ||||
|         ranking_rules, | ||||
|         &PlaceholderQuery, | ||||
|         &universe, | ||||
|         from, | ||||
|         length, | ||||
|         scoring_strategy, | ||||
|         placeholder_search_logger, | ||||
|     )?; | ||||
|  | ||||
|     Ok(PartialSearchResult { | ||||
|         candidates: all_candidates, | ||||
|         document_scores: scores, | ||||
|         documents_ids: docids, | ||||
|         located_query_terms: None, | ||||
|     }) | ||||
| } | ||||
|  | ||||
| #[allow(clippy::too_many_arguments)] | ||||
| pub fn execute_search( | ||||
|     ctx: &mut SearchContext, | ||||
|     query: &Option<String>, | ||||
|     vector: &Option<Vec<f32>>, | ||||
|     query: Option<&str>, | ||||
|     terms_matching_strategy: TermsMatchingStrategy, | ||||
|     scoring_strategy: ScoringStrategy, | ||||
|     exhaustive_number_hits: bool, | ||||
|     filters: &Option<Filter>, | ||||
|     mut universe: RoaringBitmap, | ||||
|     sort_criteria: &Option<Vec<AscDesc>>, | ||||
|     geo_strategy: geo_sort::Strategy, | ||||
|     from: usize, | ||||
| @@ -439,60 +550,8 @@ pub fn execute_search( | ||||
|     placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, | ||||
|     query_graph_logger: &mut dyn SearchLogger<QueryGraph>, | ||||
| ) -> Result<PartialSearchResult> { | ||||
|     let mut universe = if let Some(filters) = filters { | ||||
|         filters.evaluate(ctx.txn, ctx.index)? | ||||
|     } else { | ||||
|         ctx.index.documents_ids(ctx.txn)? | ||||
|     }; | ||||
|  | ||||
|     check_sort_criteria(ctx, sort_criteria.as_ref())?; | ||||
|  | ||||
|     if let Some(vector) = vector { | ||||
|         let mut search = Search::default(); | ||||
|         let docids = match ctx.index.vector_hnsw(ctx.txn)? { | ||||
|             Some(hnsw) => { | ||||
|                 if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { | ||||
|                     if vector.len() != expected_size { | ||||
|                         return Err(UserError::InvalidVectorDimensions { | ||||
|                             expected: expected_size, | ||||
|                             found: vector.len(), | ||||
|                         } | ||||
|                         .into()); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 let vector = NDotProductPoint::new(vector.clone()); | ||||
|  | ||||
|                 let neighbors = hnsw.search(&vector, &mut search); | ||||
|  | ||||
|                 let mut docids = Vec::new(); | ||||
|                 let mut uniq_docids = RoaringBitmap::new(); | ||||
|                 for instant_distance::Item { distance: _, pid, point: _ } in neighbors { | ||||
|                     let index = pid.into_inner(); | ||||
|                     let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); | ||||
|                     if universe.contains(docid) && uniq_docids.insert(docid) { | ||||
|                         docids.push(docid); | ||||
|                         if docids.len() == (from + length) { | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 // return the nearest documents that are also part of the candidates | ||||
|                 // along with a dummy list of scores that are useless in this context. | ||||
|                 docids.into_iter().skip(from).take(length).collect() | ||||
|             } | ||||
|             None => Vec::new(), | ||||
|         }; | ||||
|  | ||||
|         return Ok(PartialSearchResult { | ||||
|             candidates: universe, | ||||
|             document_scores: vec![Vec::new(); docids.len()], | ||||
|             documents_ids: docids, | ||||
|             located_query_terms: None, | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     let mut located_query_terms = None; | ||||
|     let query_terms = if let Some(query) = query { | ||||
|         // We make sure that the analyzer is aware of the stop words | ||||
| @@ -546,7 +605,7 @@ pub fn execute_search( | ||||
|             terms_matching_strategy, | ||||
|         )?; | ||||
|  | ||||
|         universe = | ||||
|         universe &= | ||||
|             resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; | ||||
|  | ||||
|         bucket_sort( | ||||
|   | ||||
							
								
								
									
										150
									
								
								milli/src/search/new/vector_sort.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								milli/src/search/new/vector_sort.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,150 @@ | ||||
| use std::future::Future; | ||||
| use std::iter::FromIterator; | ||||
| use std::pin::Pin; | ||||
|  | ||||
| use nolife::DynBoxScope; | ||||
| use roaring::RoaringBitmap; | ||||
|  | ||||
| use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; | ||||
| use crate::distance::NDotProductPoint; | ||||
| use crate::index::Hnsw; | ||||
| use crate::score_details::{self, ScoreDetails}; | ||||
| use crate::{Result, SearchContext, SearchLogger, UserError}; | ||||
|  | ||||
| /// Ranking rule that yields documents in the order produced by an HNSW search | ||||
| /// around a target vector. | ||||
| pub struct VectorSort<Q: RankingRuleQueryTrait> { | ||||
|     // Query of the iteration in progress: set by `start_iteration`, reset by `end_iteration`. | ||||
|     query: Option<Q>, | ||||
|     // Vector the search is performed around. | ||||
|     target: Vec<f32>, | ||||
|     // Documents that can still be returned; intersected with the universe on every call. | ||||
|     vector_candidates: RoaringBitmap, | ||||
|     // Scope built around `search_scope`; `next_bucket` enters it to reach the neighbor iterator. | ||||
|     scope: nolife::DynBoxScope<SearchFamily>, | ||||
| } | ||||
|  | ||||
| // A neighbor returned by the HNSW search, specialised to the point type stored in the index. | ||||
| type Item<'a> = instant_distance::Item<'a, NDotProductPoint>; | ||||
| // Boxed, pinned future driving the search scope; its output type is `nolife::Never`. | ||||
| type SearchFut = Pin<Box<dyn Future<Output = nolife::Never>>>; | ||||
|  | ||||
| // Marker type describing what the scope exposes: for any lifetime `'a`, | ||||
| // a boxed iterator over the neighbors found by the search. | ||||
| struct SearchFamily; | ||||
| impl<'a> nolife::Family<'a> for SearchFamily { | ||||
|     type Family = Box<dyn Iterator<Item = Item<'a>> + 'a>; | ||||
| } | ||||
|  | ||||
| // Body of the scope: owns the HNSW and the search state, starts the search for `target`, | ||||
| // then hands a mutable reference to the neighbor iterator to the time capsule forever. | ||||
| // NOTE(review): presumably this is what lets the iterator borrow from `hnsw` and `search` | ||||
| // while being stored next to them in `VectorSort` — confirm against the `nolife` documentation. | ||||
| async fn search_scope( | ||||
|     mut time_capsule: nolife::TimeCapsule<SearchFamily>, | ||||
|     hnsw: Hnsw, | ||||
|     target: Vec<f32>, | ||||
| ) -> nolife::Never { | ||||
|     let mut search = instant_distance::Search::default(); | ||||
|     let it = Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search)); | ||||
|     // Erase the concrete iterator type so that it matches `SearchFamily::Family`. | ||||
|     let mut it: Box<dyn Iterator<Item = Item>> = it; | ||||
|     // The loop has no exit, consistent with the `nolife::Never` return type. | ||||
|     loop { | ||||
|         time_capsule.freeze(&mut it).await; | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<Q: RankingRuleQueryTrait> VectorSort<Q> { | ||||
|     /// Create a vector sort ranking rule searching around `target`. | ||||
|     /// | ||||
|     /// `vector_candidates` is the initial set of documents the rule may return. | ||||
|     /// When the index stores no HNSW, an empty one is used instead. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// - Propagates the error raised while reading the HNSW from the index. | ||||
|     /// - Returns `UserError::InvalidVectorDimensions` when the length of `target` | ||||
|     ///   differs from the length of the first point stored in the HNSW. | ||||
|     pub fn new( | ||||
|         ctx: &SearchContext, | ||||
|         target: Vec<f32>, | ||||
|         vector_candidates: RoaringBitmap, | ||||
|     ) -> Result<Self> { | ||||
|         // Only build the empty fallback HNSW when the index does not have one. | ||||
|         let hnsw = ctx | ||||
|             .index | ||||
|             .vector_hnsw(ctx.txn)? | ||||
|             .unwrap_or_else(|| Hnsw::builder().build_hnsw(Vec::default()).0); | ||||
|  | ||||
|         // The first stored point, if any, gives the dimension every query vector must match. | ||||
|         if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { | ||||
|             if target.len() != expected_size { | ||||
|                 return Err(UserError::InvalidVectorDimensions { | ||||
|                     expected: expected_size, | ||||
|                     found: target.len(), | ||||
|                 } | ||||
|                 .into()); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // The scope takes ownership of the HNSW and of its own copy of the target, | ||||
|         // while `self.target` is kept to fill the score details. | ||||
|         let target_clone = target.clone(); | ||||
|         let producer = move |time_capsule| -> SearchFut { | ||||
|             Box::pin(search_scope(time_capsule, hnsw, target_clone)) | ||||
|         }; | ||||
|         let scope = DynBoxScope::new(producer); | ||||
|  | ||||
|         Ok(Self { query: None, target, vector_candidates, scope }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // Ranking rule implementation: each bucket contains at most one document taken from | ||||
| // the neighbor iterator, or the whole remaining universe when no neighbor is available. | ||||
| impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> { | ||||
|     // Constant identifier of this ranking rule. | ||||
|     fn id(&self) -> String { | ||||
|         "vector_sort".to_owned() | ||||
|     } | ||||
|  | ||||
|     // Remember the query of this iteration and restrict the candidates to `universe`. | ||||
|     // Panics if an iteration is already in progress. | ||||
|     fn start_iteration( | ||||
|         &mut self, | ||||
|         _ctx: &mut SearchContext<'ctx>, | ||||
|         _logger: &mut dyn SearchLogger<Q>, | ||||
|         universe: &RoaringBitmap, | ||||
|         query: &Q, | ||||
|     ) -> Result<()> { | ||||
|         assert!(self.query.is_none()); | ||||
|  | ||||
|         self.query = Some(query.clone()); | ||||
|         self.vector_candidates &= universe; | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     // Return the next bucket: the next neighbor that is still a candidate, alone in its bucket, | ||||
|     // with its vector and similarity in the score details. When there is no candidate left or | ||||
|     // the neighbor iterator is exhausted, return the whole `universe` without similarity. | ||||
|     // Panics if called outside of an iteration (no query stored). | ||||
|     #[allow(clippy::only_used_in_recursion)] | ||||
|     fn next_bucket( | ||||
|         &mut self, | ||||
|         ctx: &mut SearchContext<'ctx>, | ||||
|         _logger: &mut dyn SearchLogger<Q>, | ||||
|         universe: &RoaringBitmap, | ||||
|     ) -> Result<Option<RankingRuleOutput<Q>>> { | ||||
|         let query = self.query.as_ref().unwrap().clone(); | ||||
|         self.vector_candidates &= universe; | ||||
|  | ||||
|         // Nothing left to rank by vector: hand the universe over untouched. | ||||
|         if self.vector_candidates.is_empty() { | ||||
|             return Ok(Some(RankingRuleOutput { | ||||
|                 query, | ||||
|                 candidates: universe.clone(), | ||||
|                 score: ScoreDetails::Vector(score_details::Vector { | ||||
|                     target_vector: self.target.clone(), | ||||
|                     value_similarity: None, | ||||
|                 }), | ||||
|             })); | ||||
|         } | ||||
|  | ||||
|         // Borrow the fields separately so the closure below does not capture `self` as a whole. | ||||
|         let scope = &mut self.scope; | ||||
|         let target = &self.target; | ||||
|         let vector_candidates = &self.vector_candidates; | ||||
|  | ||||
|         scope.enter(|it| { | ||||
|             // `by_ref` keeps the iterator position across calls: neighbors already consumed | ||||
|             // are not visited again. | ||||
|             for item in it.by_ref() { | ||||
|                 let item: Item = item; | ||||
|                 let index = item.pid.into_inner(); | ||||
|                 // Panics if the vector id has no document id associated in the database. | ||||
|                 let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); | ||||
|  | ||||
|                 if vector_candidates.contains(docid) { | ||||
|                     return Ok(Some(RankingRuleOutput { | ||||
|                         query, | ||||
|                         candidates: RoaringBitmap::from_iter([docid]), | ||||
|                         score: ScoreDetails::Vector(score_details::Vector { | ||||
|                             target_vector: target.clone(), | ||||
|                             // NOTE(review): `1.0 - distance` presumably turns the distance into | ||||
|                             // a similarity — confirm against `NDotProductPoint`. | ||||
|                             value_similarity: Some(( | ||||
|                                 item.point.clone().into_inner(), | ||||
|                                 1.0 - item.distance, | ||||
|                             )), | ||||
|                         }), | ||||
|                     })); | ||||
|                 } | ||||
|             } | ||||
|             // The iterator is exhausted: return what remains of the universe without similarity. | ||||
|             Ok(Some(RankingRuleOutput { | ||||
|                 query, | ||||
|                 candidates: universe.clone(), | ||||
|                 score: ScoreDetails::Vector(score_details::Vector { | ||||
|                     target_vector: target.clone(), | ||||
|                     value_similarity: None, | ||||
|                 }), | ||||
|             })) | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     // Forget the query so that a new iteration can be started. | ||||
|     fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) { | ||||
|         self.query = None; | ||||
|     } | ||||
| } | ||||
| @@ -1,9 +1,10 @@ | ||||
| use std::cmp::Ordering; | ||||
| use std::convert::TryFrom; | ||||
| use std::convert::{TryFrom, TryInto}; | ||||
| use std::fs::File; | ||||
| use std::io::{self, BufReader, BufWriter}; | ||||
| use std::mem::size_of; | ||||
| use std::str::from_utf8; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use bytemuck::cast_slice; | ||||
| use grenad::Writer; | ||||
| @@ -13,13 +14,56 @@ use serde_json::{from_slice, Value}; | ||||
|  | ||||
| use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; | ||||
| use crate::error::UserError; | ||||
| use crate::prompt::Prompt; | ||||
| use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; | ||||
| use crate::update::index_documents::helpers::try_split_at; | ||||
| use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; | ||||
| use crate::vector::Embedder; | ||||
| use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors}; | ||||
|  | ||||
| /// The length of the elements that are always in the buffer when inserting new values. | ||||
| const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | ||||
|  | ||||
| /// The three grenad readers produced by `extract_vector_points`. | ||||
| pub struct ExtractedVectorPoints { | ||||
|     // docid, _index -> KvWriterDelAdd -> Vector | ||||
|     pub manual_vectors: grenad::Reader<BufReader<File>>, | ||||
|     // docid -> () | ||||
|     pub remove_vectors: grenad::Reader<BufReader<File>>, | ||||
|     // docid -> prompt | ||||
|     pub prompts: grenad::Reader<BufReader<File>>, | ||||
| } | ||||
|  | ||||
| // Describes how the vectors of a single document change during an indexing operation. | ||||
| enum VectorStateDelta { | ||||
|     // Nothing to remove, no prompt to embed, no manual vector to delete or add | ||||
|     NoChange, | ||||
|     // Remove all vectors, generated or manual, from this document | ||||
|     NowRemoved, | ||||
|  | ||||
|     // Add the manually specified vectors, passed in the other grenad | ||||
|     // Remove any previously generated vectors | ||||
|     // Note: changing the value of the manually specified vector **should not record** this delta | ||||
|     WasGeneratedNowManual(Vec<Vec<f32>>), | ||||
|  | ||||
|     // Manual vectors before and after the change: (vectors to delete, vectors to add) | ||||
|     // Does not request the removal of all the vectors of the document | ||||
|     ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>), | ||||
|  | ||||
|     // Add the vector computed from the specified prompt | ||||
|     // Remove any previous vector | ||||
|     // Note: changing the value of the prompt **does require** recording this delta | ||||
|     NowGenerated(String), | ||||
| } | ||||
|  | ||||
| impl VectorStateDelta { | ||||
|     // Flatten the delta into `(must_remove, prompt, (del_vectors, add_vectors))`. | ||||
|     // | ||||
|     // - `must_remove`: whether every existing vector of the document must be removed, | ||||
|     // - `prompt`: the prompt to embed, empty when there is none, | ||||
|     // - `del_vectors` / `add_vectors`: the manual vectors to delete and to add. | ||||
|     fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) { | ||||
|         let mut must_remove = false; | ||||
|         let mut prompt = String::new(); | ||||
|         let mut del_vectors = Vec::new(); | ||||
|         let mut add_vectors = Vec::new(); | ||||
|  | ||||
|         match self { | ||||
|             Self::NoChange => {} | ||||
|             Self::NowRemoved => must_remove = true, | ||||
|             Self::WasGeneratedNowManual(add) => { | ||||
|                 must_remove = true; | ||||
|                 add_vectors = add; | ||||
|             } | ||||
|             Self::ManualDelta(del, add) => { | ||||
|                 del_vectors = del; | ||||
|                 add_vectors = add; | ||||
|             } | ||||
|             Self::NowGenerated(new_prompt) => { | ||||
|                 must_remove = true; | ||||
|                 prompt = new_prompt; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         (must_remove, prompt, (del_vectors, add_vectors)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Extracts the embedding vector contained in each document under the `_vectors` field. | ||||
| /// | ||||
| /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32> | ||||
| @@ -27,16 +71,34 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>(); | ||||
| pub fn extract_vector_points<R: io::Read + io::Seek>( | ||||
|     obkv_documents: grenad::Reader<R>, | ||||
|     indexer: GrenadParameters, | ||||
|     vectors_fid: FieldId, | ||||
| ) -> Result<grenad::Reader<BufReader<File>>> { | ||||
|     field_id_map: FieldsIdsMap, | ||||
|     prompt: Option<&Prompt>, | ||||
| ) -> Result<ExtractedVectorPoints> { | ||||
|     puffin::profile_function!(); | ||||
|  | ||||
|     let mut writer = create_writer( | ||||
|     // (docid, _index) -> KvWriterDelAdd -> Vector | ||||
|     let mut manual_vectors_writer = create_writer( | ||||
|         indexer.chunk_compression_type, | ||||
|         indexer.chunk_compression_level, | ||||
|         tempfile::tempfile()?, | ||||
|     ); | ||||
|  | ||||
|     // (docid) -> (prompt) | ||||
|     let mut prompts_writer = create_writer( | ||||
|         indexer.chunk_compression_type, | ||||
|         indexer.chunk_compression_level, | ||||
|         tempfile::tempfile()?, | ||||
|     ); | ||||
|  | ||||
|     // (docid) -> () | ||||
|     let mut remove_vectors_writer = create_writer( | ||||
|         indexer.chunk_compression_type, | ||||
|         indexer.chunk_compression_level, | ||||
|         tempfile::tempfile()?, | ||||
|     ); | ||||
|  | ||||
|     let vectors_fid = field_id_map.id("_vectors"); | ||||
|  | ||||
|     let mut key_buffer = Vec::new(); | ||||
|     let mut cursor = obkv_documents.into_cursor()?; | ||||
|     while let Some((key, value)) = cursor.move_on_next()? { | ||||
| @@ -53,43 +115,148 @@ pub fn extract_vector_points<R: io::Read + io::Seek>( | ||||
|         // lazily get it when needed | ||||
|         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; | ||||
|  | ||||
|         // first we retrieve the _vectors field | ||||
|         if let Some(value) = obkv.get(vectors_fid) { | ||||
|         let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) { | ||||
|             let vectors_obkv = KvReaderDelAdd::new(value); | ||||
|             match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) { | ||||
|                 (Some(old), Some(new)) => { | ||||
|                     // no autogeneration | ||||
|                     let del_vectors = extract_vectors(old, document_id)?; | ||||
|                     let add_vectors = extract_vectors(new, document_id)?; | ||||
|  | ||||
|             // then we extract the values | ||||
|             let del_vectors = vectors_obkv | ||||
|                 .get(DelAdd::Deletion) | ||||
|                 .map(|vectors| extract_vectors(vectors, document_id)) | ||||
|                 .transpose()? | ||||
|                 .flatten(); | ||||
|             let add_vectors = vectors_obkv | ||||
|                 .get(DelAdd::Addition) | ||||
|                 .map(|vectors| extract_vectors(vectors, document_id)) | ||||
|                 .transpose()? | ||||
|                 .flatten(); | ||||
|                     VectorStateDelta::ManualDelta( | ||||
|                         del_vectors.unwrap_or_default(), | ||||
|                         add_vectors.unwrap_or_default(), | ||||
|                     ) | ||||
|                 } | ||||
|                 (None, Some(new)) => { | ||||
|                     // was possibly autogenerated, remove all vectors for that document | ||||
|                     let add_vectors = extract_vectors(new, document_id)?; | ||||
|  | ||||
|             // and we finally push the unique vectors into the writer | ||||
|             push_vectors_diff( | ||||
|                 &mut writer, | ||||
|                 &mut key_buffer, | ||||
|                 del_vectors.unwrap_or_default(), | ||||
|                 add_vectors.unwrap_or_default(), | ||||
|             )?; | ||||
|         } | ||||
|                     VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default()) | ||||
|                 } | ||||
|                 (Some(_old), None) => { | ||||
|                     // Do we keep this document? | ||||
|                     let document_is_kept = obkv | ||||
|                         .iter() | ||||
|                         .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) | ||||
|                         .any(|deladd| deladd.get(DelAdd::Addition).is_some()); | ||||
|                     if document_is_kept { | ||||
|                         // becomes autogenerated | ||||
|                         match prompt { | ||||
|                             Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( | ||||
|                                 obkv, | ||||
|                                 DelAdd::Addition, | ||||
|                                 &field_id_map, | ||||
|                             )?), | ||||
|                             None => VectorStateDelta::NowRemoved, | ||||
|                         } | ||||
|                     } else { | ||||
|                         VectorStateDelta::NowRemoved | ||||
|                     } | ||||
|                 } | ||||
|                 (None, None) => { | ||||
|                     // Do we keep this document? | ||||
|                     let document_is_kept = obkv | ||||
|                         .iter() | ||||
|                         .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) | ||||
|                         .any(|deladd| deladd.get(DelAdd::Addition).is_some()); | ||||
|  | ||||
|                     if document_is_kept { | ||||
|                         match prompt { | ||||
|                             Some(prompt) => { | ||||
|                                 // Don't give up if the old prompt was failing | ||||
|                                 let old_prompt = prompt | ||||
|                                     .render(obkv, DelAdd::Deletion, &field_id_map) | ||||
|                                     .unwrap_or_default(); | ||||
|                                 let new_prompt = | ||||
|                                     prompt.render(obkv, DelAdd::Addition, &field_id_map)?; | ||||
|                                 if old_prompt != new_prompt { | ||||
|                                     log::trace!( | ||||
|                                         "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" | ||||
|                                     ); | ||||
|                                     VectorStateDelta::NowGenerated(new_prompt) | ||||
|                                 } else { | ||||
|                                     VectorStateDelta::NoChange | ||||
|                                 } | ||||
|                             } | ||||
|                             // We no longer have a prompt, so we need to remove any existing vector | ||||
|                             None => VectorStateDelta::NowRemoved, | ||||
|                         } | ||||
|                     } else { | ||||
|                         VectorStateDelta::NowRemoved | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|             // Do we keep this document? | ||||
|             let document_is_kept = obkv | ||||
|                 .iter() | ||||
|                 .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) | ||||
|                 .any(|deladd| deladd.get(DelAdd::Addition).is_some()); | ||||
|  | ||||
|             if document_is_kept { | ||||
|                 match prompt { | ||||
|                     Some(prompt) => { | ||||
|                         // Don't give up if the old prompt was failing | ||||
|                         let old_prompt = prompt | ||||
|                             .render(obkv, DelAdd::Deletion, &field_id_map) | ||||
|                             .unwrap_or_default(); | ||||
|                         let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; | ||||
|                         if old_prompt != new_prompt { | ||||
|                             log::trace!( | ||||
|                                 "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" | ||||
|                             ); | ||||
|                             VectorStateDelta::NowGenerated(new_prompt) | ||||
|                         } else { | ||||
|                             VectorStateDelta::NoChange | ||||
|                         } | ||||
|                     } | ||||
|                     None => VectorStateDelta::NowRemoved, | ||||
|                 } | ||||
|             } else { | ||||
|                 VectorStateDelta::NowRemoved | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         // and we finally push the unique vectors into the writer | ||||
|         push_vectors_diff( | ||||
|             &mut remove_vectors_writer, | ||||
|             &mut prompts_writer, | ||||
|             &mut manual_vectors_writer, | ||||
|             &mut key_buffer, | ||||
|             delta, | ||||
|         )?; | ||||
|     } | ||||
|  | ||||
|     writer_into_reader(writer) | ||||
|     Ok(ExtractedVectorPoints { | ||||
|         // docid, _index -> KvWriterDelAdd -> Vector | ||||
|         manual_vectors: writer_into_reader(manual_vectors_writer)?, | ||||
|         // docid -> () | ||||
|         remove_vectors: writer_into_reader(remove_vectors_writer)?, | ||||
|         // docid -> prompt | ||||
|         prompts: writer_into_reader(prompts_writer)?, | ||||
|     }) | ||||
| } | ||||
|  | ||||
| /// Computes the diff between both Del and Add numbers and | ||||
| /// only inserts the parts that differ in the sorter. | ||||
| fn push_vectors_diff( | ||||
|     writer: &mut Writer<BufWriter<File>>, | ||||
|     remove_vectors_writer: &mut Writer<BufWriter<File>>, | ||||
|     prompts_writer: &mut Writer<BufWriter<File>>, | ||||
|     manual_vectors_writer: &mut Writer<BufWriter<File>>, | ||||
|     key_buffer: &mut Vec<u8>, | ||||
|     mut del_vectors: Vec<Vec<f32>>, | ||||
|     mut add_vectors: Vec<Vec<f32>>, | ||||
|     delta: VectorStateDelta, | ||||
| ) -> Result<()> { | ||||
|     let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values(); | ||||
|     if must_remove { | ||||
|         key_buffer.truncate(TRUNCATE_SIZE); | ||||
|         remove_vectors_writer.insert(&key_buffer, [])?; | ||||
|     } | ||||
|     if !prompt.is_empty() { | ||||
|         key_buffer.truncate(TRUNCATE_SIZE); | ||||
|         prompts_writer.insert(&key_buffer, prompt.as_bytes())?; | ||||
|     } | ||||
|  | ||||
|     // We sort and dedup the vectors | ||||
|     del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); | ||||
|     add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); | ||||
| @@ -114,7 +281,7 @@ fn push_vectors_diff( | ||||
|                 let mut obkv = KvWriterDelAdd::memory(); | ||||
|                 obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; | ||||
|                 let bytes = obkv.into_inner()?; | ||||
|                 writer.insert(&key_buffer, bytes)?; | ||||
|                 manual_vectors_writer.insert(&key_buffer, bytes)?; | ||||
|             } | ||||
|             EitherOrBoth::Right(vector) => { | ||||
|                 // We insert only the Add part of the Obkv to inform | ||||
| @@ -122,7 +289,7 @@ fn push_vectors_diff( | ||||
|                 let mut obkv = KvWriterDelAdd::memory(); | ||||
|                 obkv.insert(DelAdd::Addition, cast_slice(&vector))?; | ||||
|                 let bytes = obkv.into_inner()?; | ||||
|                 writer.insert(&key_buffer, bytes)?; | ||||
|                 manual_vectors_writer.insert(&key_buffer, bytes)?; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @@ -146,3 +313,102 @@ fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Opti | ||||
|         .into()), | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Embeds every prompt of `prompt_reader` and writes the result in a new grenad. | ||||
| /// | ||||
| /// Prompts are grouped in chunks of `prompt_count_in_chunk_hint` prompts, and chunks are sent | ||||
| /// to the embedder by batches of `chunk_count_hint` chunks. The remaining full chunks and the | ||||
| /// last partial chunk are sent once the reader is exhausted. | ||||
| /// | ||||
| /// Returns the reader associating each docid (big-endian bytes) with its embedding, along with | ||||
| /// the dimension of the last embedding written, or `None` when no embedding was produced. | ||||
| /// | ||||
| /// # Errors | ||||
| /// | ||||
| /// Propagates I/O errors (runtime creation, temporary file, grenad) and embedding errors. | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// Panics if a key of `prompt_reader` does not have the size of a `DocumentId`. | ||||
| #[logging_timer::time] | ||||
| pub fn extract_embeddings<R: io::Read + io::Seek>( | ||||
|     // docid, prompt | ||||
|     prompt_reader: grenad::Reader<R>, | ||||
|     indexer: GrenadParameters, | ||||
|     embedder: Arc<Embedder>, | ||||
| ) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> { | ||||
|     let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; | ||||
|  | ||||
|     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism | ||||
|     let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk | ||||
|  | ||||
|     // docid, state with embedding | ||||
|     let mut state_writer = create_writer( | ||||
|         indexer.chunk_compression_type, | ||||
|         indexer.chunk_compression_level, | ||||
|         tempfile::tempfile()?, | ||||
|     ); | ||||
|  | ||||
|     let mut chunks = Vec::with_capacity(n_chunks); | ||||
|     let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk); | ||||
|     let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk); | ||||
|     let mut chunks_ids = Vec::with_capacity(n_chunks); | ||||
|     let mut cursor = prompt_reader.into_cursor()?; | ||||
|  | ||||
|     let mut expected_dimension = None; | ||||
|  | ||||
|     while let Some((key, value)) = cursor.move_on_next()? { | ||||
|         let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||
|         // SAFETY: precondition, the grenad value was saved from a string | ||||
|         let prompt = unsafe { std::str::from_utf8_unchecked(value) }; | ||||
|  | ||||
|         // The limits are compared to the hints rather than to `Vec::capacity`, which is only | ||||
|         // guaranteed to be *at least* the requested capacity. An empty chunk is never flushed, | ||||
|         // so a hint of 0 degrades to chunks of one prompt. | ||||
|         if !current_chunk.is_empty() && current_chunk.len() >= n_vectors_per_chunk { | ||||
|             chunks.push(std::mem::replace( | ||||
|                 &mut current_chunk, | ||||
|                 Vec::with_capacity(n_vectors_per_chunk), | ||||
|             )); | ||||
|             chunks_ids.push(std::mem::replace( | ||||
|                 &mut current_chunk_ids, | ||||
|                 Vec::with_capacity(n_vectors_per_chunk), | ||||
|             )); | ||||
|         } | ||||
|         current_chunk.push(prompt.to_owned()); | ||||
|         current_chunk_ids.push(docid); | ||||
|  | ||||
|         // Enough full chunks are available: embed them in a single batch. | ||||
|         if !chunks.is_empty() && chunks.len() >= n_chunks { | ||||
|             let chunked_embeds = rt | ||||
|                 .block_on( | ||||
|                     embedder | ||||
|                         .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), | ||||
|                 ) | ||||
|                 .map_err(crate::vector::Error::from) | ||||
|                 .map_err(crate::UserError::from) | ||||
|                 .map_err(crate::Error::from)?; | ||||
|  | ||||
|             // Docids and embeddings are flattened in the same order, so they can be zipped. | ||||
|             for (docid, embeddings) in chunks_ids | ||||
|                 .iter() | ||||
|                 .flat_map(|docids| docids.iter()) | ||||
|                 .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) | ||||
|             { | ||||
|                 state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||
|                 expected_dimension = Some(embeddings.dimension()); | ||||
|             } | ||||
|             chunks_ids.clear(); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // send last chunk | ||||
|     if !chunks.is_empty() { | ||||
|         let chunked_embeds = rt | ||||
|             .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) | ||||
|             .map_err(crate::vector::Error::from) | ||||
|             .map_err(crate::UserError::from) | ||||
|             .map_err(crate::Error::from)?; | ||||
|         for (docid, embeddings) in chunks_ids | ||||
|             .iter() | ||||
|             .flat_map(|docids| docids.iter()) | ||||
|             .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) | ||||
|         { | ||||
|             state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||
|             expected_dimension = Some(embeddings.dimension()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // The prompts that did not fill a whole chunk are embedded last, which keeps the docids | ||||
|     // inserted in the order they were read. | ||||
|     if !current_chunk.is_empty() { | ||||
|         let embeds = rt | ||||
|             .block_on(embedder.embed(std::mem::take(&mut current_chunk))) | ||||
|             .map_err(crate::vector::Error::from) | ||||
|             .map_err(crate::UserError::from) | ||||
|             .map_err(crate::Error::from)?; | ||||
|  | ||||
|         for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { | ||||
|             state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; | ||||
|             expected_dimension = Some(embeddings.dimension()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok((writer_into_reader(state_writer)?, expected_dimension)) | ||||
| } | ||||
|   | ||||
| @@ -9,9 +9,10 @@ mod extract_word_docids; | ||||
| mod extract_word_pair_proximity_docids; | ||||
| mod extract_word_position_docids; | ||||
|  | ||||
| use std::collections::HashSet; | ||||
| use std::collections::{HashMap, HashSet}; | ||||
| use std::fs::File; | ||||
| use std::io::BufReader; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use crossbeam_channel::Sender; | ||||
| use log::debug; | ||||
| @@ -23,7 +24,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids; | ||||
| use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; | ||||
| use self::extract_fid_word_count_docids::extract_fid_word_count_docids; | ||||
| use self::extract_geo_points::extract_geo_points; | ||||
| use self::extract_vector_points::extract_vector_points; | ||||
| use self::extract_vector_points::{ | ||||
|     extract_embeddings, extract_vector_points, ExtractedVectorPoints, | ||||
| }; | ||||
| use self::extract_word_docids::extract_word_docids; | ||||
| use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; | ||||
| use self::extract_word_position_docids::extract_word_position_docids; | ||||
| @@ -32,8 +35,10 @@ use super::helpers::{ | ||||
|     MergeFn, MergeableReader, | ||||
| }; | ||||
| use super::{helpers, TypedChunk}; | ||||
| use crate::prompt::Prompt; | ||||
| use crate::proximity::ProximityPrecision; | ||||
| use crate::{FieldId, Result}; | ||||
| use crate::vector::Embedder; | ||||
| use crate::{FieldId, FieldsIdsMap, Result}; | ||||
|  | ||||
| /// Extract data for each databases from obkv documents in parallel. | ||||
| /// Send data in grenad file over provided Sender. | ||||
| @@ -47,13 +52,14 @@ pub(crate) fn data_from_obkv_documents( | ||||
|     faceted_fields: HashSet<FieldId>, | ||||
|     primary_key_id: FieldId, | ||||
|     geo_fields_ids: Option<(FieldId, FieldId)>, | ||||
|     vectors_field_id: Option<FieldId>, | ||||
|     field_id_map: FieldsIdsMap, | ||||
|     stop_words: Option<fst::Set<&[u8]>>, | ||||
|     allowed_separators: Option<&[&str]>, | ||||
|     dictionary: Option<&[&str]>, | ||||
|     max_positions_per_attributes: Option<u32>, | ||||
|     exact_attributes: HashSet<FieldId>, | ||||
|     proximity_precision: ProximityPrecision, | ||||
|     embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>, | ||||
| ) -> Result<()> { | ||||
|     puffin::profile_function!(); | ||||
|  | ||||
| @@ -64,7 +70,8 @@ pub(crate) fn data_from_obkv_documents( | ||||
|                 original_documents_chunk, | ||||
|                 indexer, | ||||
|                 lmdb_writer_sx.clone(), | ||||
|                 vectors_field_id, | ||||
|                 field_id_map.clone(), | ||||
|                 embedders.clone(), | ||||
|             ) | ||||
|         }) | ||||
|         .collect::<Result<()>>()?; | ||||
| @@ -276,24 +283,42 @@ fn send_original_documents_data( | ||||
|     original_documents_chunk: Result<grenad::Reader<BufReader<File>>>, | ||||
|     indexer: GrenadParameters, | ||||
|     lmdb_writer_sx: Sender<Result<TypedChunk>>, | ||||
|     vectors_field_id: Option<FieldId>, | ||||
|     field_id_map: FieldsIdsMap, | ||||
|     embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>, | ||||
| ) -> Result<()> { | ||||
|     let original_documents_chunk = | ||||
|         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; | ||||
|  | ||||
|     if let Some(vectors_field_id) = vectors_field_id { | ||||
|         let documents_chunk_cloned = original_documents_chunk.clone(); | ||||
|         let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); | ||||
|         rayon::spawn(move || { | ||||
|             let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); | ||||
|             let _ = match result { | ||||
|                 Ok(vector_points) => { | ||||
|                     lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) | ||||
|                 } | ||||
|                 Err(error) => lmdb_writer_sx_cloned.send(Err(error)), | ||||
|             }; | ||||
|         }); | ||||
|     } | ||||
|     let documents_chunk_cloned = original_documents_chunk.clone(); | ||||
|     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); | ||||
|     rayon::spawn(move || { | ||||
|         let (embedder, prompt) = embedders.get("default").cloned().unzip(); | ||||
|         let result = | ||||
|             extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); | ||||
|         let _ = match result { | ||||
|             Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { | ||||
|                 /// FIXME: support multiple embedders | ||||
|                 let results = embedder.and_then(|embedder| { | ||||
|                     match extract_embeddings(prompts, indexer, embedder.clone()) { | ||||
|                         Ok(results) => Some(results), | ||||
|                         Err(error) => { | ||||
|                             let _ = lmdb_writer_sx_cloned.send(Err(error)); | ||||
|                             None | ||||
|                         } | ||||
|                     } | ||||
|                 }); | ||||
|                 let (embeddings, expected_dimension) = results.unzip(); | ||||
|                 let expected_dimension = expected_dimension.flatten(); | ||||
|                 lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { | ||||
|                     remove_vectors, | ||||
|                     embeddings, | ||||
|                     expected_dimension, | ||||
|                     manual_vectors, | ||||
|                 })) | ||||
|             } | ||||
|             Err(error) => lmdb_writer_sx_cloned.send(Err(error)), | ||||
|         }; | ||||
|     }); | ||||
|  | ||||
|     // TODO: create a custom internal error | ||||
|     lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); | ||||
|   | ||||
| @@ -4,11 +4,12 @@ mod helpers; | ||||
| mod transform; | ||||
| mod typed_chunk; | ||||
|  | ||||
| use std::collections::HashSet; | ||||
| use std::collections::{HashMap, HashSet}; | ||||
| use std::io::{Cursor, Read, Seek}; | ||||
| use std::iter::FromIterator; | ||||
| use std::num::NonZeroU32; | ||||
| use std::result::Result as StdResult; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use crossbeam_channel::{Receiver, Sender}; | ||||
| use heed::types::Str; | ||||
| @@ -32,10 +33,12 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; | ||||
| pub use self::transform::{Transform, TransformOutput}; | ||||
| use crate::documents::{obkv_to_object, DocumentsBatchReader}; | ||||
| use crate::error::{Error, InternalError, UserError}; | ||||
| use crate::prompt::Prompt; | ||||
| pub use crate::update::index_documents::helpers::CursorClonableMmap; | ||||
| use crate::update::{ | ||||
|     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, | ||||
| }; | ||||
| use crate::vector::Embedder; | ||||
| use crate::{CboRoaringBitmapCodec, Index, Result}; | ||||
|  | ||||
| static MERGED_DATABASE_COUNT: usize = 7; | ||||
| @@ -78,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { | ||||
|     should_abort: FA, | ||||
|     added_documents: u64, | ||||
|     deleted_documents: u64, | ||||
|     embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>, | ||||
| } | ||||
|  | ||||
| #[derive(Default, Debug, Clone)] | ||||
| @@ -121,6 +125,7 @@ where | ||||
|             index, | ||||
|             added_documents: 0, | ||||
|             deleted_documents: 0, | ||||
|             embedders: Default::default(), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
| @@ -167,6 +172,14 @@ where | ||||
|         Ok((self, Ok(indexed_documents))) | ||||
|     } | ||||
|  | ||||
|     pub fn with_embedders( | ||||
|         mut self, | ||||
|         embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>, | ||||
|     ) -> Self { | ||||
          // Builder-style setter: consumes the builder, replaces its whole `embedders` map
          // with the given one (the previous map is discarded, not merged) and returns
          // the builder so that calls can be chained.
          // NOTE(review): keys look like embedder names; the extraction code looks up
          // the "default" key. Confirm against callers.
|         self.embedders = embedders; | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Remove a batch of documents from the current builder. | ||||
|     /// | ||||
|     /// Returns the number of documents deleted from the builder. | ||||
| @@ -322,17 +335,18 @@ where | ||||
|         // get filterable fields for facet databases | ||||
|         let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; | ||||
|         // get the fid of the `_geo.lat` and `_geo.lng` fields. | ||||
|         let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { | ||||
|         let mut field_id_map = self.index.fields_ids_map(self.wtxn)?; | ||||
|  | ||||
|         // self.index.fields_ids_map($a)? ==>> field_id_map | ||||
|         let geo_fields_ids = match field_id_map.id("_geo") { | ||||
|             Some(gfid) => { | ||||
|                 let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); | ||||
|                 let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); | ||||
|                 // if `_geo` is faceted then we get the `lat` and `lng` | ||||
|                 if is_sortable || is_filterable { | ||||
|                     let field_ids = self | ||||
|                         .index | ||||
|                         .fields_ids_map(self.wtxn)? | ||||
|                     let field_ids = field_id_map | ||||
|                         .insert("_geo.lat") | ||||
|                         .zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) | ||||
|                         .zip(field_id_map.insert("_geo.lng")) | ||||
|                         .ok_or(UserError::AttributeLimitReached)?; | ||||
|                     Some(field_ids) | ||||
|                 } else { | ||||
| @@ -341,8 +355,6 @@ where | ||||
|             } | ||||
|             None => None, | ||||
|         }; | ||||
|         // get the fid of the `_vectors` field. | ||||
|         let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); | ||||
|  | ||||
|         let stop_words = self.index.stop_words(self.wtxn)?; | ||||
|         let separators = self.index.allowed_separators(self.wtxn)?; | ||||
| @@ -364,6 +376,8 @@ where | ||||
|             self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB | ||||
|         let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; | ||||
|  | ||||
|         let cloned_embedder = self.embedders.clone(); | ||||
|  | ||||
|         // Run extraction pipeline in parallel. | ||||
|         pool.install(|| { | ||||
|             puffin::profile_scope!("extract_and_send_grenad_chunks"); | ||||
| @@ -387,13 +401,14 @@ where | ||||
|                     faceted_fields, | ||||
|                     primary_key_id, | ||||
|                     geo_fields_ids, | ||||
|                     vectors_field_id, | ||||
|                     field_id_map, | ||||
|                     stop_words, | ||||
|                     separators.as_deref(), | ||||
|                     dictionary.as_deref(), | ||||
|                     max_positions_per_attributes, | ||||
|                     exact_attributes, | ||||
|                     proximity_precision, | ||||
|                     cloned_embedder, | ||||
|                 ) | ||||
|             }); | ||||
|  | ||||
| @@ -2505,7 +2520,7 @@ mod tests { | ||||
|             .unwrap(); | ||||
|  | ||||
|         let rtxn = index.read_txn().unwrap(); | ||||
|         let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); | ||||
|         let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); | ||||
|         assert_eq!(res.documents_ids.len(), 3); | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -47,7 +47,12 @@ pub(crate) enum TypedChunk { | ||||
|     FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>), | ||||
|     FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>), | ||||
|     GeoPoints(grenad::Reader<BufReader<File>>), | ||||
|     VectorPoints(grenad::Reader<BufReader<File>>), | ||||
|     VectorPoints { | ||||
|         remove_vectors: grenad::Reader<BufReader<File>>, | ||||
|         embeddings: Option<grenad::Reader<BufReader<File>>>, | ||||
|         expected_dimension: Option<usize>, | ||||
|         manual_vectors: grenad::Reader<BufReader<File>>, | ||||
|     }, | ||||
|     ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), | ||||
| } | ||||
|  | ||||
| @@ -100,8 +105,8 @@ impl TypedChunk { | ||||
|             TypedChunk::GeoPoints(grenad) => { | ||||
|                 format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) | ||||
|             } | ||||
|             TypedChunk::VectorPoints(grenad) => { | ||||
|                 format!("VectorPoints {{ number_of_entries: {} }}", grenad.len()) | ||||
|             TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { | ||||
|                 format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension.unwrap_or_default()) | ||||
|             } | ||||
|             TypedChunk::ScriptLanguageDocids(sl_map) => { | ||||
|                 format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) | ||||
| @@ -355,19 +360,64 @@ pub(crate) fn write_typed_chunk_into_index( | ||||
|             index.put_geo_rtree(wtxn, &rtree)?; | ||||
|             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; | ||||
|         } | ||||
|         TypedChunk::VectorPoints(vector_points) => { | ||||
|             let mut vectors_set = HashSet::new(); | ||||
|         TypedChunk::VectorPoints { | ||||
|             remove_vectors, | ||||
|             manual_vectors, | ||||
|             embeddings, | ||||
|             expected_dimension, | ||||
|         } => { | ||||
|             if remove_vectors.is_empty() | ||||
|                 && manual_vectors.is_empty() | ||||
|                 && embeddings.as_ref().map_or(true, |e| e.is_empty()) | ||||
|             { | ||||
|                 return Ok((RoaringBitmap::new(), is_merged_database)); | ||||
|             } | ||||
|  | ||||
|             let mut docid_vectors_map: HashMap<DocumentId, HashSet<Vec<OrderedFloat<f32>>>> = | ||||
|                 HashMap::new(); | ||||
|  | ||||
|             // We extract and store the previous vectors | ||||
|             if let Some(hnsw) = index.vector_hnsw(wtxn)? { | ||||
|                 for (pid, point) in hnsw.iter() { | ||||
|                     let pid_key = pid.into_inner(); | ||||
|                     let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); | ||||
|                     let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); | ||||
|                     vectors_set.insert((docid, vector)); | ||||
|                     docid_vectors_map.entry(docid).or_default().insert(vector); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             let mut cursor = vector_points.into_cursor()?; | ||||
|             // remove vectors for docids we want them removed | ||||
|             let mut cursor = remove_vectors.into_cursor()?; | ||||
|             while let Some((key, _)) = cursor.move_on_next()? { | ||||
|                 let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||
|  | ||||
|                 docid_vectors_map.remove(&docid); | ||||
|             } | ||||
|  | ||||
|             // add generated embeddings | ||||
|             if let Some((embeddings, expected_dimension)) = embeddings.zip(expected_dimension) { | ||||
|                 let mut cursor = embeddings.into_cursor()?; | ||||
|                 while let Some((key, value)) = cursor.move_on_next()? { | ||||
|                     let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); | ||||
|                     let data: Vec<OrderedFloat<_>> = | ||||
|                         pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); | ||||
|                     // it is a code error to have embeddings and not expected_dimension | ||||
|                     let embeddings = | ||||
|                         crate::vector::Embeddings::from_inner(data, expected_dimension) | ||||
|                             // code error if we somehow got the wrong dimension | ||||
|                             .unwrap(); | ||||
|  | ||||
|                     let mut set = HashSet::new(); | ||||
|                     for embedding in embeddings.iter() { | ||||
|                         set.insert(embedding.to_vec()); | ||||
|                     } | ||||
|  | ||||
|                     docid_vectors_map.insert(docid, set); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // perform the manual diff | ||||
|             let mut cursor = manual_vectors.into_cursor()?; | ||||
|             while let Some((key, value)) = cursor.move_on_next()? { | ||||
|                 // convert the key back to a u32 (4 bytes) | ||||
|                 let (left, _index) = try_split_array_at(key).unwrap(); | ||||
| @@ -376,23 +426,30 @@ pub(crate) fn write_typed_chunk_into_index( | ||||
|                 let vector_deladd_obkv = KvReaderDelAdd::new(value); | ||||
|                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { | ||||
|                     // convert the vector back to a Vec<f32> | ||||
|                     let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); | ||||
|                     let key = (docid, vector); | ||||
|                     if !vectors_set.remove(&key) { | ||||
|                         error!("Unable to delete the vector: {:?}", key.1); | ||||
|                     } | ||||
|                     let vector: Vec<OrderedFloat<_>> = | ||||
|                         pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); | ||||
|                     docid_vectors_map.entry(docid).and_modify(|v| { | ||||
|                         if !v.remove(&vector) { | ||||
|                             error!("Unable to delete the vector: {:?}", vector); | ||||
|                         } | ||||
|                     }); | ||||
|                 } | ||||
|                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { | ||||
|                     // convert the vector back to a Vec<f32> | ||||
|                     let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); | ||||
|                     vectors_set.insert((docid, vector)); | ||||
|                     // Create the per-document set on first insertion: `and_modify` alone is a | ||||
|                     // no-op for a docid that has no entry yet, which silently dropped the manual | ||||
|                     // vectors of documents without previously stored or generated vectors. | ||||
|                     docid_vectors_map.entry(docid).or_default().insert(vector); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Extract the most common vector dimension | ||||
|             let expected_dimension_size = { | ||||
|                 let mut dims = HashMap::new(); | ||||
|                 vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); | ||||
|                 docid_vectors_map | ||||
|                     .values() | ||||
|                     .flat_map(|v| v.iter()) | ||||
|                     .for_each(|v| *dims.entry(v.len()).or_insert(0) += 1); | ||||
|                 dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) | ||||
|             }; | ||||
|  | ||||
| @@ -400,7 +457,10 @@ pub(crate) fn write_typed_chunk_into_index( | ||||
|             // prepare the vectors before inserting them in the HNSW. | ||||
|             let mut points = Vec::new(); | ||||
|             let mut docids = Vec::new(); | ||||
|             for (docid, vector) in vectors_set { | ||||
|             for (docid, vector) in docid_vectors_map | ||||
|                 .into_iter() | ||||
|                 .flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors)) | ||||
|             { | ||||
|                 if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { | ||||
|                     return Err(UserError::InvalidVectorDimensions { | ||||
|                         expected: expected_dimension_size.unwrap_or(vector.len()), | ||||
|   | ||||
| @@ -3,7 +3,7 @@ use std::result::Result as StdResult; | ||||
|  | ||||
| use charabia::{Normalize, Tokenizer, TokenizerBuilder}; | ||||
| use deserr::{DeserializeError, Deserr}; | ||||
| use itertools::Itertools; | ||||
| use itertools::{EitherOrBoth, Itertools}; | ||||
| use serde::{Deserialize, Deserializer, Serialize, Serializer}; | ||||
| use time::OffsetDateTime; | ||||
|  | ||||
| @@ -15,6 +15,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS | ||||
| use crate::proximity::ProximityPrecision; | ||||
| use crate::update::index_documents::IndexDocumentsMethod; | ||||
| use crate::update::{IndexDocuments, UpdateIndexingStep}; | ||||
| use crate::vector::settings::{EmbeddingSettings, PromptSettings}; | ||||
| use crate::vector::EmbeddingConfig; | ||||
| use crate::{FieldsIdsMap, Index, OrderBy, Result}; | ||||
|  | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Copy)] | ||||
| @@ -73,6 +75,13 @@ impl<T> Setting<T> { | ||||
|             otherwise => otherwise, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
          // Overwrites `self` with `new`, except when `new` is `Setting::NotSet`,
          // in which case `self` is left untouched.
          // Any other variant of `new` (including `Setting::Reset`) replaces `self` entirely.
|         if let Setting::NotSet = new { | ||||
|             return; | ||||
|         } | ||||
|         *self = new; | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<T: Serialize> Serialize for Setting<T> { | ||||
| @@ -129,6 +138,7 @@ pub struct Settings<'a, 't, 'i> { | ||||
|     sort_facet_values_by: Setting<HashMap<String, OrderBy>>, | ||||
|     pagination_max_total_hits: Setting<usize>, | ||||
|     proximity_precision: Setting<ProximityPrecision>, | ||||
|     embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>, | ||||
| } | ||||
|  | ||||
| impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
| @@ -161,6 +171,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|             sort_facet_values_by: Setting::NotSet, | ||||
|             pagination_max_total_hits: Setting::NotSet, | ||||
|             proximity_precision: Setting::NotSet, | ||||
|             embedder_settings: Setting::NotSet, | ||||
|             indexer_config, | ||||
|         } | ||||
|     } | ||||
| @@ -343,6 +354,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|         self.proximity_precision = Setting::Reset; | ||||
|     } | ||||
|  | ||||
|     pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) { | ||||
          // Records the requested per-embedder settings as `Setting::Set(value)`.
          // Nothing is written to the index here; the stored value is consumed later
          // by `update_embedding_configs`.
|         self.embedder_settings = Setting::Set(value); | ||||
|     } | ||||
|  | ||||
|     pub fn reset_embedder_settings(&mut self) { | ||||
          // Marks the embedder settings as `Setting::Reset`; `update_embedding_configs`
          // handles that variant by deleting the embedding configs stored in the index.
|         self.embedder_settings = Setting::Reset; | ||||
|     } | ||||
|  | ||||
|     fn reindex<FP, FA>( | ||||
|         &mut self, | ||||
|         progress_callback: &FP, | ||||
| @@ -890,6 +909,60 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|         Ok(changed) | ||||
|     } | ||||
|  | ||||
|     fn update_embedding_configs(&mut self) -> Result<bool> { | ||||
          // Applies the pending `embedder_settings` to the index and returns whether the
          // embedding configuration is considered changed (the caller uses this flag to
          // decide whether to reindex).
          //
          // - `Set(configs)`: merges `configs` into the stored configs, by embedder name.
          // - `Reset`: deletes all stored configs and reports a change.
          // - `NotSet`: does nothing and reports no change.
          //
          // `std::mem::take` moves the pending settings out, leaving the type's default
          // value in `self.embedder_settings`.
|         let update = match std::mem::take(&mut self.embedder_settings) { | ||||
|             Setting::Set(configs) => { | ||||
|                 let mut changed = false; | ||||
|                 let old_configs = self.index.embedding_configs(self.wtxn)?; | ||||
                  // Wrap the stored configs as `Setting::Set(..)` so that old and new
                  // entries share the same type during the merge below.
|                 let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> = | ||||
|                     old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect(); | ||||
|  | ||||
|                 let mut new_configs = BTreeMap::new(); | ||||
                  // Both sides are `BTreeMap`s, so they are iterated in ascending key order;
                  // `merge_join_by` needs both inputs sorted by the compared key.
|                 for joined in old_configs | ||||
|                     .into_iter() | ||||
|                     .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) | ||||
|                 { | ||||
|                     match joined { | ||||
                          // Name present on both sides: the new setting overrides the old one
                          // unless it is `NotSet` (see `Setting::apply`), then the prompt is validated.
                          // NOTE(review): `changed` is set unconditionally here, even when the
                          // resulting setting is identical to the stored one.
|                         EitherOrBoth::Both((name, mut old), (_, new)) => { | ||||
|                             old.apply(new); | ||||
|                             let new = validate_prompt(&name, old)?; | ||||
|                             changed = true; | ||||
|                             new_configs.insert(name, new); | ||||
|                         } | ||||
                          // Name only in the stored configs: kept as is, not counted as a change.
|                         EitherOrBoth::Left((name, setting)) => { | ||||
|                             new_configs.insert(name, setting); | ||||
|                         } | ||||
                          // Name only in the new settings: validated and added as a new embedder.
|                         EitherOrBoth::Right((name, setting)) => { | ||||
|                             let setting = validate_prompt(&name, setting)?; | ||||
|                             changed = true; | ||||
|                             new_configs.insert(name, setting); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
                  // Convert back to concrete configs: `Reset` entries are dropped (which removes
                  // that embedder), `NotSet` entries fall back to `EmbeddingSettings::default()`.
|                 let new_configs: Vec<(String, EmbeddingConfig)> = new_configs | ||||
|                     .into_iter() | ||||
|                     .filter_map(|(name, setting)| match setting { | ||||
|                         Setting::Set(value) => Some((name, value.into())), | ||||
|                         Setting::Reset => None, | ||||
|                         Setting::NotSet => Some((name, EmbeddingSettings::default().into())), | ||||
|                     }) | ||||
|                     .collect(); | ||||
                  // An empty result deletes the stored entry instead of storing an empty list.
|                 if new_configs.is_empty() { | ||||
|                     self.index.delete_embedding_configs(self.wtxn)?; | ||||
|                 } else { | ||||
|                     self.index.put_embedding_configs(self.wtxn, new_configs)?; | ||||
|                 } | ||||
|                 changed | ||||
|             } | ||||
|             Setting::Reset => { | ||||
|                 self.index.delete_embedding_configs(self.wtxn)?; | ||||
|                 true | ||||
|             } | ||||
|             Setting::NotSet => false, | ||||
|         }; | ||||
|         Ok(update) | ||||
|     } | ||||
|  | ||||
|     pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()> | ||||
|     where | ||||
|         FP: Fn(UpdateIndexingStep) + Sync, | ||||
| @@ -927,6 +1000,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|         let searchable_updated = self.update_searchable()?; | ||||
|         let exact_attributes_updated = self.update_exact_attributes()?; | ||||
|         let proximity_precision = self.update_proximity_precision()?; | ||||
|         // TODO: very rough approximation of the needs for reindexing where any change will result in | ||||
|         // a full reindexing. | ||||
|         // What can be done instead: | ||||
|         // 1. Only change the distance on a distance change | ||||
|         // 2. Only change the name -> embedder mapping on a name change | ||||
|         // 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage | ||||
|         let embedding_configs_updated = self.update_embedding_configs()?; | ||||
|  | ||||
|         if stop_words_updated | ||||
|             || non_separator_tokens_updated | ||||
| @@ -937,6 +1017,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|             || searchable_updated | ||||
|             || exact_attributes_updated | ||||
|             || proximity_precision | ||||
|             || embedding_configs_updated | ||||
|         { | ||||
|             self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; | ||||
|         } | ||||
| @@ -945,6 +1026,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn validate_prompt( | ||||
|     name: &str, | ||||
|     new: Setting<EmbeddingSettings>, | ||||
| ) -> Result<Setting<EmbeddingSettings>> { | ||||
      // Validates the prompt template of an embedder setting, when one is explicitly set.
      //
      // `name` is only used to label the error. If `new` is `Set` with a `Set` prompt whose
      // template is `Set`, the template is parsed with `Prompt::new`; a parse failure is
      // returned as `UserError::InvalidPromptForEmbeddings(name, ..)`. Every other shape of
      // `new` is returned unchanged.
|     match new { | ||||
|         Setting::Set(EmbeddingSettings { | ||||
|             embedder_options, | ||||
|             prompt: | ||||
|                 Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), | ||||
|         }) => { | ||||
|             // validate | ||||
              // On success the template is read back from the parsed prompt through `PromptData`.
              // NOTE(review): presumably this round trip yields the same (or a normalized)
              // template text; confirm in `crate::prompt`.
|             let template = crate::prompt::Prompt::new(template, None, None) | ||||
|                 .map(|prompt| crate::prompt::PromptData::from(prompt).template) | ||||
|                 .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; | ||||
|  | ||||
              // Rebuild the setting with the validated template; the other fields pass through.
|             Ok(Setting::Set(EmbeddingSettings { | ||||
|                 embedder_options, | ||||
|                 prompt: Setting::Set(PromptSettings { | ||||
|                     template: Setting::Set(template), | ||||
|                     strategy, | ||||
|                     fallback, | ||||
|                 }), | ||||
|             })) | ||||
|         } | ||||
|         new => Ok(new), | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use big_s::S; | ||||
| @@ -1763,6 +1872,7 @@ mod tests { | ||||
|                     sort_facet_values_by, | ||||
|                     pagination_max_total_hits, | ||||
|                     proximity_precision, | ||||
|                     embedder_settings, | ||||
|                 } = settings; | ||||
|                 assert!(matches!(searchable_fields, Setting::NotSet)); | ||||
|                 assert!(matches!(displayed_fields, Setting::NotSet)); | ||||
| @@ -1785,6 +1895,7 @@ mod tests { | ||||
|                 assert!(matches!(sort_facet_values_by, Setting::NotSet)); | ||||
|                 assert!(matches!(pagination_max_total_hits, Setting::NotSet)); | ||||
|                 assert!(matches!(proximity_precision, Setting::NotSet)); | ||||
|                 assert!(matches!(embedder_settings, Setting::NotSet)); | ||||
|             }) | ||||
|             .unwrap(); | ||||
|     } | ||||
|   | ||||
							
								
								
									
										229
									
								
								milli/src/vector/error.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										229
									
								
								milli/src/vector/error.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,229 @@ | ||||
| use std::path::PathBuf; | ||||
|  | ||||
| use hf_hub::api::sync::ApiError; | ||||
|  | ||||
| use crate::error::FaultSource; | ||||
| use crate::vector::openai::OpenAiError; | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("Error while generating embeddings: {inner}")] | ||||
  /// Error type of the vector module: a thin wrapper around a boxed [`ErrorKind`].
  ///
  /// Its display output is the inner error's message prefixed with
  /// "Error while generating embeddings: ".
| pub struct Error { | ||||
      /// The actual error. NOTE(review): presumably boxed to keep `Error` small; confirm intent.
|     pub inner: Box<ErrorKind>, | ||||
| } | ||||
|  | ||||
| impl<I: Into<ErrorKind>> From<I> for Error { | ||||
      // Blanket conversion: any value convertible into `ErrorKind` (e.g. through the
      // `#[from]` variants of that enum) converts into `Error` by boxing the kind.
|     fn from(value: I) -> Self { | ||||
|         Self { inner: Box::new(value.into()) } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Error { | ||||
      /// Returns the `FaultSource` recorded in the wrapped error, whichever variant it is.
|     pub fn fault(&self) -> FaultSource { | ||||
          // Deliberately has no catch-all arm, so adding an `ErrorKind` variant forces an update here.
|         match &*self.inner { | ||||
|             ErrorKind::NewEmbedderError(inner) => inner.fault, | ||||
|             ErrorKind::EmbedError(inner) => inner.fault, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
  /// The two families of errors wrapped by [`Error`].
  ///
  /// Both variants are `transparent`, so their display output is that of the wrapped error,
  /// and `#[from]` provides the `From` conversions from the wrapped types.
| pub enum ErrorKind { | ||||
      /// Wraps a `NewEmbedderError`.
|     #[error(transparent)] | ||||
|     NewEmbedderError(#[from] NewEmbedderError), | ||||
      /// Wraps an [`EmbedError`].
|     #[error(transparent)] | ||||
|     EmbedError(#[from] EmbedError), | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| #[error("{fault}: {kind}")] | ||||
  /// An embedding error: the specific cause (`kind`) paired with who is at fault (`fault`).
  ///
  /// Displayed as "{fault}: {kind}".
| pub struct EmbedError { | ||||
      /// What went wrong.
|     pub kind: EmbedErrorKind, | ||||
      /// The source of the fault, as chosen by the `EmbedError` constructor that built this value.
|     pub fault: FaultSource, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
  /// Specific causes of an [`EmbedError`].
  ///
  /// The variants fall into three groups by payload: tokenization (a boxed error),
  /// tensor and model failures (carrying a `candle_core::Error`), and OpenAI failures
  /// (carrying a `reqwest::Error`, an `OpenAiError` or a raw HTTP status code).
  /// The `#[error]` attribute on each variant gives its display message.
| pub enum EmbedErrorKind { | ||||
|     #[error("could not tokenize: {0}")] | ||||
|     Tokenize(Box<dyn std::error::Error + Send + Sync>), | ||||
|     #[error("unexpected tensor shape: {0}")] | ||||
|     TensorShape(candle_core::Error), | ||||
|     #[error("unexpected tensor value: {0}")] | ||||
|     TensorValue(candle_core::Error), | ||||
|     #[error("could not run model: {0}")] | ||||
|     ModelForward(candle_core::Error), | ||||
|     #[error("could not reach OpenAI: {0}")] | ||||
|     OpenAiNetwork(reqwest::Error), | ||||
|     #[error("unexpected response from OpenAI: {0}")] | ||||
|     OpenAiUnexpected(reqwest::Error), | ||||
|     #[error("could not authenticate against OpenAI: {0}")] | ||||
|     OpenAiAuth(OpenAiError), | ||||
|     #[error("sent too many requests to OpenAI: {0}")] | ||||
|     OpenAiTooManyRequests(OpenAiError), | ||||
|     #[error("received internal error from OpenAI: {0}")] | ||||
|     OpenAiInternalServerError(OpenAiError), | ||||
|     #[error("sent too many tokens in a request to OpenAI: {0}")] | ||||
|     OpenAiTooManyTokens(OpenAiError), | ||||
      // The payload is the raw HTTP status code that had no dedicated variant.
|     #[error("received unhandled HTTP status code {0} from OpenAI")] | ||||
|     OpenAiUnhandledStatusCode(u16), | ||||
| } | ||||
|  | ||||
| impl EmbedError { | ||||
|     pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self { | ||||
|         Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn tensor_shape(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub fn tensor_value(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub fn model_forward(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_network(inner: reqwest::Error) -> Self { | ||||
|         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { | ||||
|         Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } | ||||
|     } | ||||
| } | ||||
|  | ||||
/// Failure raised while creating an embedder; displayed as `{fault}: {kind}`.
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewEmbedderError {
    /// What went wrong.
    pub kind: NewEmbedderErrorKind,
    /// Who the failure is attributed to.
    pub fault: FaultSource,
}
|  | ||||
| impl NewEmbedderError { | ||||
|     pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError { | ||||
|         let open_config = OpenConfig { filename: config_filename, inner }; | ||||
|  | ||||
|         Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn deserialize_config( | ||||
|         config: String, | ||||
|         config_filename: PathBuf, | ||||
|         inner: serde_json::Error, | ||||
|     ) -> NewEmbedderError { | ||||
|         let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; | ||||
|         Self { | ||||
|             kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn open_tokenizer( | ||||
|         tokenizer_filename: PathBuf, | ||||
|         inner: Box<dyn std::error::Error + Send + Sync>, | ||||
|     ) -> NewEmbedderError { | ||||
|         let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner }; | ||||
|         Self { | ||||
|             kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer), | ||||
|             fault: FaultSource::Runtime, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn new_api_fail(inner: ApiError) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug } | ||||
|     } | ||||
|  | ||||
|     pub fn api_get(inner: ApiError) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided } | ||||
|     } | ||||
|  | ||||
|     pub fn pytorch_weight(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn safetensor_weight(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn load_model(inner: candle_core::Error) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } | ||||
|     } | ||||
|  | ||||
|     pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { | ||||
|         Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } | ||||
|     } | ||||
| } | ||||
|  | ||||
/// Payload of [`NewEmbedderErrorKind::OpenConfig`]: the configuration file could not be read.
#[derive(Debug, thiserror::Error)]
#[error("could not open config at {filename:?}: {inner}")]
pub struct OpenConfig {
    /// Path of the configuration file.
    pub filename: PathBuf,
    /// The I/O error that was returned.
    pub inner: std::io::Error,
}
|  | ||||
/// Payload of [`NewEmbedderErrorKind::DeserializeConfig`]: the configuration was read
/// but could not be deserialized. The message embeds the whole configuration text.
#[derive(Debug, thiserror::Error)]
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
pub struct DeserializeConfig {
    /// Raw content of the configuration file.
    pub config: String,
    /// Path of the configuration file.
    pub filename: PathBuf,
    /// The deserialization error that was returned.
    pub inner: serde_json::Error,
}
|  | ||||
/// Payload of [`NewEmbedderErrorKind::OpenTokenizer`]: the tokenizer file could not be loaded.
#[derive(Debug, thiserror::Error)]
#[error("could not open tokenizer at {filename}: {inner}")]
pub struct OpenTokenizer {
    /// Path of the tokenizer file.
    pub filename: PathBuf,
    /// The underlying error, also exposed as this error's `source`.
    #[source]
    pub inner: Box<dyn std::error::Error + Send + Sync>,
}
|  | ||||
| #[derive(Debug, thiserror::Error)] | ||||
| pub enum NewEmbedderErrorKind { | ||||
|     // hf | ||||
|     #[error(transparent)] | ||||
|     OpenConfig(OpenConfig), | ||||
|     #[error(transparent)] | ||||
|     DeserializeConfig(DeserializeConfig), | ||||
|     #[error(transparent)] | ||||
|     OpenTokenizer(OpenTokenizer), | ||||
|     #[error("could not build weights from Pytorch weights: {0}")] | ||||
|     PytorchWeight(candle_core::Error), | ||||
|     #[error("could not build weights from Safetensor weights: {0}")] | ||||
|     SafetensorWeight(candle_core::Error), | ||||
|     #[error("could not spawn HG_HUB API client: {0}")] | ||||
|     NewApiFail(ApiError), | ||||
|     #[error("fetching file from HG_HUB failed: {0}")] | ||||
|     ApiGet(ApiError), | ||||
|     #[error("loading model failed: {0}")] | ||||
|     LoadModel(candle_core::Error), | ||||
|     // openai | ||||
|     #[error("initializing web client for sending embedding requests failed: {0}")] | ||||
|     InitWebClient(reqwest::Error), | ||||
|     #[error("The API key passed to Authorization error was in an invalid format: {0}")] | ||||
|     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), | ||||
| } | ||||
							
								
								
									
										192
									
								
								milli/src/vector/hf.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								milli/src/vector/hf.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,192 @@ | ||||
| use candle_core::Tensor; | ||||
| use candle_nn::VarBuilder; | ||||
| use candle_transformers::models::bert::{BertModel, Config, DTYPE}; | ||||
| // FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself | ||||
| use hf_hub::api::sync::Api; | ||||
| use hf_hub::{Repo, RepoType}; | ||||
| use tokenizers::{PaddingParams, Tokenizer}; | ||||
|  | ||||
| pub use super::error::{EmbedError, Error, NewEmbedderError}; | ||||
| use super::{Embedding, Embeddings}; | ||||
|  | ||||
/// File format the model weights are loaded from.
///
/// (De)serialized with camelCase variant names; unknown fields are rejected.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Deserialize,
    serde::Serialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum WeightSource {
    /// Safetensors weights; this is the `Default` variant.
    #[default]
    Safetensors,
    /// PyTorch weights.
    Pytorch,
}
|  | ||||
/// Settings describing which model to load and how to post-process its output.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    /// Identifier of the model repository.
    pub model: String,
    /// Revision of the repository to use; `None` requests the repository without
    /// an explicit revision.
    pub revision: Option<String>,
    /// Which weight file format to load.
    pub weight_source: WeightSource,
    /// Whether embeddings are L2-normalized before being returned.
    pub normalize_embeddings: bool,
}
|  | ||||
| impl EmbedderOptions { | ||||
|     pub fn new() -> Self { | ||||
|         Self { | ||||
|             //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(), | ||||
|             model: "BAAI/bge-base-en-v1.5".to_string(), | ||||
|             //revision: Some("refs/pr/21".to_string()), | ||||
|             revision: None, | ||||
|             //weight_source: Default::default(), | ||||
|             weight_source: WeightSource::Pytorch, | ||||
|             normalize_embeddings: true, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for EmbedderOptions { | ||||
|     fn default() -> Self { | ||||
|         Self::new() | ||||
|     } | ||||
| } | ||||
|  | ||||
/// Performs embedding of documents and queries with a locally-run BERT model.
pub struct Embedder {
    /// The loaded model.
    model: BertModel,
    /// Tokenizer turning text into the token ids fed to the model.
    tokenizer: Tokenizer,
    /// The options this embedder was created with.
    options: EmbedderOptions,
}
|  | ||||
| impl std::fmt::Debug for Embedder { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         f.debug_struct("Embedder") | ||||
|             .field("model", &self.options.model) | ||||
|             .field("tokenizer", &self.tokenizer) | ||||
|             .field("options", &self.options) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||
|         let device = candle_core::Device::Cpu; | ||||
|         let repo = match options.revision.clone() { | ||||
|             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), | ||||
|             None => Repo::model(options.model.clone()), | ||||
|         }; | ||||
|         let (config_filename, tokenizer_filename, weights_filename) = { | ||||
|             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; | ||||
|             let api = api.repo(repo); | ||||
|             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; | ||||
|             let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; | ||||
|             let weights = match options.weight_source { | ||||
|                 WeightSource::Pytorch => { | ||||
|                     api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)? | ||||
|                 } | ||||
|                 WeightSource::Safetensors => { | ||||
|                     api.get("model.safetensors").map_err(NewEmbedderError::api_get)? | ||||
|                 } | ||||
|             }; | ||||
|             (config, tokenizer, weights) | ||||
|         }; | ||||
|  | ||||
|         let config = std::fs::read_to_string(&config_filename) | ||||
|             .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; | ||||
|         let config: Config = serde_json::from_str(&config).map_err(|inner| { | ||||
|             NewEmbedderError::deserialize_config(config, config_filename, inner) | ||||
|         })?; | ||||
|         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) | ||||
|             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; | ||||
|  | ||||
|         let vb = match options.weight_source { | ||||
|             WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) | ||||
|                 .map_err(NewEmbedderError::pytorch_weight)?, | ||||
|             WeightSource::Safetensors => unsafe { | ||||
|                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) | ||||
|                     .map_err(NewEmbedderError::safetensor_weight)? | ||||
|             }, | ||||
|         }; | ||||
|  | ||||
|         let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; | ||||
|  | ||||
|         if let Some(pp) = tokenizer.get_padding_mut() { | ||||
|             pp.strategy = tokenizers::PaddingStrategy::BatchLongest | ||||
|         } else { | ||||
|             let pp = PaddingParams { | ||||
|                 strategy: tokenizers::PaddingStrategy::BatchLongest, | ||||
|                 ..Default::default() | ||||
|             }; | ||||
|             tokenizer.with_padding(Some(pp)); | ||||
|         } | ||||
|  | ||||
|         Ok(Self { model, tokenizer, options }) | ||||
|     } | ||||
|  | ||||
|     pub async fn embed( | ||||
|         &self, | ||||
|         mut texts: Vec<String>, | ||||
|     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         let tokens = match texts.len() { | ||||
|             1 => vec![self | ||||
|                 .tokenizer | ||||
|                 .encode(texts.pop().unwrap(), true) | ||||
|                 .map_err(EmbedError::tokenize)?], | ||||
|             _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, | ||||
|         }; | ||||
|         let token_ids = tokens | ||||
|             .iter() | ||||
|             .map(|tokens| { | ||||
|                 let tokens = tokens.get_ids().to_vec(); | ||||
|                 Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) | ||||
|             }) | ||||
|             .collect::<Result<Vec<_>, EmbedError>>()?; | ||||
|  | ||||
|         let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; | ||||
|         let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; | ||||
|         let embeddings = | ||||
|             self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; | ||||
|  | ||||
|         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) | ||||
|         let (_n_sentence, n_tokens, _hidden_size) = | ||||
|             embeddings.dims3().map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) | ||||
|             .map_err(EmbedError::tensor_shape)?; | ||||
|  | ||||
|         let embeddings: Tensor = if self.options.normalize_embeddings { | ||||
|             normalize_l2(&embeddings).map_err(EmbedError::tensor_value)? | ||||
|         } else { | ||||
|             embeddings | ||||
|         }; | ||||
|  | ||||
|         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; | ||||
|         Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) | ||||
|     } | ||||
|  | ||||
|     pub async fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) | ||||
|             .await | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         1 | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> { | ||||
|     v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) | ||||
| } | ||||
							
								
								
									
										142
									
								
								milli/src/vector/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								milli/src/vector/mod.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,142 @@ | ||||
| use self::error::{EmbedError, NewEmbedderError}; | ||||
| use crate::prompt::PromptData; | ||||
|  | ||||
| pub mod error; | ||||
| pub mod hf; | ||||
| pub mod openai; | ||||
| pub mod settings; | ||||
|  | ||||
| pub use self::error::Error; | ||||
|  | ||||
| pub type Embedding = Vec<f32>; | ||||
|  | ||||
/// One or more embeddings of identical dimension, stored back to back in a flat buffer.
pub struct Embeddings<F> {
    /// Concatenated embedding values; its length is a multiple of `dimension`.
    data: Vec<F>,
    /// Number of values in a single embedding.
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Creates an empty collection for embeddings of the given `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self { data: Vec::new(), dimension }
    }

    /// Creates a collection holding exactly `embedding`, whose length becomes the dimension.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        Self { dimension: embedding.len(), data: embedding }
    }

    /// Creates a collection from already concatenated embeddings.
    ///
    /// # Errors
    ///
    /// Gives `data` back when its length is not a multiple of `dimension`.
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        let mut this = Self::new(dimension);
        this.append(data)?;
        Ok(this)
    }

    /// Number of values in a single embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Consumes the collection and returns the flat buffer.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// Borrows the flat buffer.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the stored embeddings, one slice of `dimension` values at a time.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        // `chunks_exact` panics on a chunk size of 0. With a dimension of 0 the buffer
        // is necessarily empty, so any non-zero chunk size yields nothing.
        self.data.chunks_exact(self.dimension.max(1))
    }

    /// Adds a single embedding.
    ///
    /// # Errors
    ///
    /// Gives `embedding` back when its length differs from the dimension.
    pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() != self.dimension {
            return Err(embedding);
        }
        self.data.append(&mut embedding);
        Ok(())
    }

    /// Adds several concatenated embeddings at once.
    ///
    /// # Errors
    ///
    /// Gives `embeddings` back when its length is not a multiple of the dimension.
    pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
        let is_whole_number_of_embeddings = match self.dimension {
            // `len % 0` would panic: with a dimension of 0 only "nothing" can be appended.
            0 => embeddings.is_empty(),
            dimension => embeddings.len() % dimension == 0,
        };
        if !is_whole_number_of_embeddings {
            return Err(embeddings);
        }
        self.data.append(&mut embeddings);
        Ok(())
    }
}
|  | ||||
/// An embedder backed by one of the supported implementations.
#[derive(Debug)]
pub enum Embedder {
    /// Embedder implemented in the `hf` module.
    HuggingFace(hf::Embedder),
    /// Embedder implemented in the `openai` module.
    OpenAi(openai::Embedder),
}
|  | ||||
/// Configuration of an embedder together with the prompt data used with it.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
    /// Which embedder to build and with which options.
    pub embedder_options: EmbedderOptions,
    /// Prompt data associated with this embedder.
    pub prompt: PromptData,
    // TODO: add metrics and anything needed
}
|  | ||||
/// Options for one of the supported embedder implementations.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
    /// Options for the embedder of the `hf` module.
    HuggingFace(hf::EmbedderOptions),
    /// Options for the embedder of the `openai` module.
    OpenAi(openai::EmbedderOptions),
}
|  | ||||
| impl Default for EmbedderOptions { | ||||
|     fn default() -> Self { | ||||
|         Self::HuggingFace(Default::default()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl EmbedderOptions { | ||||
|     pub fn huggingface() -> Self { | ||||
|         Self::HuggingFace(hf::EmbedderOptions::new()) | ||||
|     } | ||||
|  | ||||
|     pub fn openai(api_key: String) -> Self { | ||||
|         Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { | ||||
|         Ok(match options { | ||||
|             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), | ||||
|             EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub async fn embed( | ||||
|         &self, | ||||
|         texts: Vec<String>, | ||||
|     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.embed(texts).await, | ||||
|             Embedder::OpenAi(embedder) => embedder.embed(texts).await, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub async fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, | ||||
|             Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), | ||||
|             Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         match self { | ||||
|             Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||
|             Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
							
								
								
									
										416
									
								
								milli/src/vector/openai.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										416
									
								
								milli/src/vector/openai.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,416 @@ | ||||
| use std::fmt::Display; | ||||
|  | ||||
| use reqwest::StatusCode; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use super::error::{EmbedError, NewEmbedderError}; | ||||
| use super::{Embedding, Embeddings}; | ||||
|  | ||||
/// Embedder sending its embedding requests to OpenAI.
#[derive(Debug)]
pub struct Embedder {
    /// HTTP client carrying the default headers (authorization, content type).
    client: reqwest::Client,
    /// Tokenizer used when a text has to be split into token windows.
    tokenizer: tiktoken_rs::CoreBPE,
    /// The options this embedder was created with.
    options: EmbedderOptions,
}
|  | ||||
/// Settings of the OpenAI embedder.
// NOTE(review): `Debug` and `Serialize` are derived, so `api_key` appears verbatim in
// debug output and in serialized settings. Confirm that this is intended.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    /// Key sent as a bearer token in the `Authorization` header.
    pub api_key: String,
    /// Which OpenAI embedding model to request.
    pub embedding_model: EmbeddingModel,
}
|  | ||||
/// The OpenAI embedding models that can be requested.
///
/// (De)serialized with camelCase variant names; unknown fields are rejected.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
    /// The `text-embedding-ada-002` model; this is the `Default` variant.
    #[default]
    TextEmbeddingAda002,
}
|  | ||||
| impl EmbeddingModel { | ||||
|     pub fn max_token(&self) -> usize { | ||||
|         match self { | ||||
|             EmbeddingModel::TextEmbeddingAda002 => 8191, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn dimensions(&self) -> usize { | ||||
|         match self { | ||||
|             EmbeddingModel::TextEmbeddingAda002 => 1536, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn name(&self) -> &'static str { | ||||
|         match self { | ||||
|             EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn from_name(name: &'static str) -> Option<Self> { | ||||
|         match name { | ||||
|             "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
/// URL that embedding requests are POSTed to.
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|  | ||||
| impl EmbedderOptions { | ||||
|     pub fn with_default_model(api_key: String) -> Self { | ||||
|         Self { api_key, embedding_model: Default::default() } | ||||
|     } | ||||
|  | ||||
|     pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { | ||||
|         Self { api_key, embedding_model } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Embedder { | ||||
|     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { | ||||
|         let mut headers = reqwest::header::HeaderMap::new(); | ||||
|         headers.insert( | ||||
|             reqwest::header::AUTHORIZATION, | ||||
|             reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) | ||||
|                 .map_err(NewEmbedderError::openai_invalid_api_key_format)?, | ||||
|         ); | ||||
|         headers.insert( | ||||
|             reqwest::header::CONTENT_TYPE, | ||||
|             reqwest::header::HeaderValue::from_static("application/json"), | ||||
|         ); | ||||
|         let client = reqwest::ClientBuilder::new() | ||||
|             .default_headers(headers) | ||||
|             .build() | ||||
|             .map_err(NewEmbedderError::openai_initialize_web_client)?; | ||||
|  | ||||
|         // looking at the code it is very unclear that this can actually fail. | ||||
|         let tokenizer = tiktoken_rs::cl100k_base().unwrap(); | ||||
|  | ||||
|         Ok(Self { options, client, tokenizer }) | ||||
|     } | ||||
|  | ||||
|     pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { | ||||
|         let mut tokenized = false; | ||||
|  | ||||
|         for attempt in 0..7 { | ||||
|             let result = if tokenized { | ||||
|                 self.try_embed_tokenized(&texts).await | ||||
|             } else { | ||||
|                 self.try_embed(&texts).await | ||||
|             }; | ||||
|  | ||||
|             let retry_duration = match result { | ||||
|                 Ok(embeddings) => return Ok(embeddings), | ||||
|                 Err(retry) => { | ||||
|                     log::warn!("Failed: {}", retry.error); | ||||
|                     tokenized |= retry.must_tokenize(); | ||||
|                     retry.into_duration(attempt) | ||||
|                 } | ||||
|             }?; | ||||
|             log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); | ||||
|             tokio::time::sleep(retry_duration).await; | ||||
|         } | ||||
|  | ||||
|         let result = if tokenized { | ||||
|             self.try_embed_tokenized(&texts).await | ||||
|         } else { | ||||
|             self.try_embed(&texts).await | ||||
|         }; | ||||
|  | ||||
|         result.map_err(Retry::into_error) | ||||
|     } | ||||
|  | ||||
|     async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> { | ||||
|         if !response.status().is_success() { | ||||
|             match response.status() { | ||||
|                 StatusCode::UNAUTHORIZED => { | ||||
|                     let error_response: OpenAiErrorResponse = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::openai_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|  | ||||
|                     return Err(Retry::give_up(EmbedError::openai_auth_error( | ||||
|                         error_response.error, | ||||
|                     ))); | ||||
|                 } | ||||
|                 StatusCode::TOO_MANY_REQUESTS => { | ||||
|                     let error_response: OpenAiErrorResponse = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::openai_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|  | ||||
|                     return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( | ||||
|                         error_response.error, | ||||
|                     ))); | ||||
|                 } | ||||
|                 StatusCode::INTERNAL_SERVER_ERROR => { | ||||
|                     let error_response: OpenAiErrorResponse = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::openai_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|                     return Err(Retry::retry_later(EmbedError::openai_internal_server_error( | ||||
|                         error_response.error, | ||||
|                     ))); | ||||
|                 } | ||||
|                 StatusCode::SERVICE_UNAVAILABLE => { | ||||
|                     let error_response: OpenAiErrorResponse = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::openai_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|                     return Err(Retry::retry_later(EmbedError::openai_internal_server_error( | ||||
|                         error_response.error, | ||||
|                     ))); | ||||
|                 } | ||||
|                 StatusCode::BAD_REQUEST => { | ||||
|                     // Most probably, one text contained too many tokens | ||||
|                     let error_response: OpenAiErrorResponse = response | ||||
|                         .json() | ||||
|                         .await | ||||
|                         .map_err(EmbedError::openai_unexpected) | ||||
|                         .map_err(Retry::retry_later)?; | ||||
|  | ||||
|                     log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt."); | ||||
|  | ||||
|                     return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( | ||||
|                         error_response.error, | ||||
|                     ))); | ||||
|                 } | ||||
|                 code => { | ||||
|                     return Err(Retry::give_up(EmbedError::openai_unhandled_status_code( | ||||
|                         code.as_u16(), | ||||
|                     ))); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         Ok(response) | ||||
|     } | ||||
|  | ||||
|     /// Requests embeddings for `texts` from the OpenAI embeddings endpoint in one call. | ||||
|     /// | ||||
|     /// Every prompt is logged at trace level before the request is sent. A network failure | ||||
|     /// or a response body that cannot be decoded is reported as an error to retry later; | ||||
|     /// any error raised by `check_response` is forwarded unchanged. On success, each entry | ||||
|     /// of the response's `data` is wrapped into its own single-embedding `Embeddings`. | ||||
|     async fn try_embed<S: AsRef<str> + serde::Serialize>( | ||||
|         &self, | ||||
|         texts: &[S], | ||||
|     ) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||
|         for prompt in texts { | ||||
|             log::trace!("Received prompt: {}", prompt.as_ref()); | ||||
|         } | ||||
|  | ||||
|         let model = self.options.embedding_model.name(); | ||||
|         let payload = OpenAiRequest { model, input: texts }; | ||||
|  | ||||
|         let sent = self.client.post(OPENAI_EMBEDDINGS_URL).json(&payload).send().await; | ||||
|         let raw = sent.map_err(|error| Retry::retry_later(EmbedError::openai_network(error)))?; | ||||
|  | ||||
|         let checked = Self::check_response(raw).await?; | ||||
|  | ||||
|         let decoded: Result<OpenAiResponse, _> = checked.json().await; | ||||
|         let parsed = | ||||
|             decoded.map_err(|error| Retry::retry_later(EmbedError::openai_unexpected(error)))?; | ||||
|  | ||||
|         log::trace!("response: {:?}", parsed.data); | ||||
|  | ||||
|         let embeddings: Vec<_> = parsed | ||||
|             .data | ||||
|             .into_iter() | ||||
|             .map(|item| Embeddings::from_single_embedding(item.embedding)) | ||||
|             .collect(); | ||||
|         Ok(embeddings) | ||||
|     } | ||||
|  | ||||
|     /// Embeds each prompt of `text`, splitting prompts that reach the model's token limit. | ||||
|     /// | ||||
|     /// A prompt whose token count is below the limit is sent as text through `try_embed`. | ||||
|     /// Any other prompt is cut into windows of at most `max_token()` tokens, consecutive | ||||
|     /// windows sharing `OVERLAP_SIZE` tokens, and every window is embedded through | ||||
|     /// `embed_tokens`; the resulting embeddings are grouped in a single `Embeddings`. | ||||
|     /// | ||||
|     /// NOTE(review): the `unwrap` on `push` panics if `push` reports an error, presumably a | ||||
|     /// dimension mismatch -- confirm against `Embeddings::push`. | ||||
|     async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> { | ||||
|         const OVERLAP_SIZE: usize = 200; | ||||
|         let token_limit = self.options.embedding_model.max_token(); | ||||
|         let mut results = Vec::with_capacity(text.len()); | ||||
|  | ||||
|         for prompt in text { | ||||
|             let encoded = self.tokenizer.encode_ordinary(prompt.as_str()); | ||||
|             if encoded.len() < token_limit { | ||||
|                 // Short enough: let the API tokenize the text itself. | ||||
|                 results.extend(self.try_embed(&[prompt]).await?); | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             let mut prompt_embeddings = | ||||
|                 Embeddings::new(self.options.embedding_model.dimensions()); | ||||
|             let mut remaining = encoded.as_slice(); | ||||
|             while remaining.len() > token_limit { | ||||
|                 let embedding = self.embed_tokens(&remaining[..token_limit]).await?; | ||||
|                 prompt_embeddings.push(embedding).unwrap(); | ||||
|  | ||||
|                 // Step forward while keeping `OVERLAP_SIZE` tokens of the previous window. | ||||
|                 remaining = &remaining[token_limit - OVERLAP_SIZE..]; | ||||
|             } | ||||
|  | ||||
|             // Whatever is left fits in a single request. | ||||
|             let tail = self.embed_tokens(remaining).await?; | ||||
|             prompt_embeddings.push(tail).unwrap(); | ||||
|  | ||||
|             results.push(prompt_embeddings); | ||||
|         } | ||||
|  | ||||
|         Ok(results) | ||||
|     } | ||||
|  | ||||
|     /// Embeds one window of tokens, retrying failed attempts. | ||||
|     /// | ||||
|     /// Up to nine attempts are made through `try_embed_tokens`, sleeping between attempts | ||||
|     /// for the delay computed by `Retry::into_duration`. One final attempt follows, whose | ||||
|     /// failure is always reported with the give-up strategy. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns a give-up `Retry` as soon as an attempt fails with the give-up strategy, or | ||||
|     /// when the final attempt fails. | ||||
|     async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { | ||||
|         for attempt in 0..9 { | ||||
|             let retry = match self.try_embed_tokens(tokens).await { | ||||
|                 Ok(embedding) => return Ok(embedding), | ||||
|                 Err(retry) => retry, | ||||
|             }; | ||||
|  | ||||
|             // `into_duration` only fails for the give-up strategy: keep reporting the error | ||||
|             // as a give-up instead of disguising it as something worth retrying. | ||||
|             let duration = retry.into_duration(attempt).map_err(Retry::give_up)?; | ||||
|  | ||||
|             tokio::time::sleep(duration).await; | ||||
|         } | ||||
|  | ||||
|         self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) | ||||
|     } | ||||
|  | ||||
|     /// Sends an already tokenized input to the OpenAI embeddings endpoint, once. | ||||
|     /// | ||||
|     /// A network failure or a response body that cannot be decoded is reported as an error | ||||
|     /// to retry later; any error raised by `check_response` is forwarded unchanged. | ||||
|     /// Returns the embedding of the last entry of the response's `data`, or a default | ||||
|     /// (empty) embedding when `data` is empty. | ||||
|     async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { | ||||
|         let model = self.options.embedding_model.name(); | ||||
|         let payload = OpenAiTokensRequest { model, input: tokens }; | ||||
|  | ||||
|         let sent = self.client.post(OPENAI_EMBEDDINGS_URL).json(&payload).send().await; | ||||
|         let raw = sent.map_err(|error| Retry::retry_later(EmbedError::openai_network(error)))?; | ||||
|  | ||||
|         let checked = Self::check_response(raw).await?; | ||||
|  | ||||
|         let decoded: Result<OpenAiResponse, _> = checked.json().await; | ||||
|         let mut parsed = | ||||
|             decoded.map_err(|error| Retry::retry_later(EmbedError::openai_unexpected(error)))?; | ||||
|  | ||||
|         match parsed.data.pop() { | ||||
|             Some(last) => Ok(last.embedding), | ||||
|             None => Ok(Embedding::default()), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Embeds every chunk of prompts through `embed`, awaiting all chunks together. | ||||
|     /// | ||||
|     /// The futures are joined with `try_join_all`, so the first chunk that fails makes the | ||||
|     /// whole call fail with that chunk's error. | ||||
|     pub async fn embed_chunks( | ||||
|         &self, | ||||
|         text_chunks: Vec<Vec<String>>, | ||||
|     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { | ||||
|         let pending = text_chunks.into_iter().map(|chunk| self.embed(chunk)); | ||||
|         futures::future::try_join_all(pending).await | ||||
|     } | ||||
|  | ||||
|     /// Returns the fixed chunk-count hint of this embedder, currently 10. | ||||
|     /// | ||||
|     /// Presumably the number of chunks worth passing to `embed_chunks` at once -- confirm | ||||
|     /// against the callers. | ||||
|     pub fn chunk_count_hint(&self) -> usize { | ||||
|         const CHUNK_COUNT: usize = 10; | ||||
|         CHUNK_COUNT | ||||
|     } | ||||
|  | ||||
|     /// Returns the fixed prompts-per-chunk hint of this embedder, currently 10. | ||||
|     /// | ||||
|     /// Presumably the number of prompts worth grouping in one chunk -- confirm against the | ||||
|     /// callers. | ||||
|     pub fn prompt_count_in_chunk_hint(&self) -> usize { | ||||
|         const PROMPTS_PER_CHUNK: usize = 10; | ||||
|         PROMPTS_PER_CHUNK | ||||
|     } | ||||
| } | ||||
|  | ||||
| // retrying in case of failure | ||||
|  | ||||
| /// A failed embedding attempt: the error itself and how to react to it. | ||||
| struct Retry { | ||||
|     /// The error that made the attempt fail. | ||||
|     error: EmbedError, | ||||
|     /// Whether, and after which delay, a new attempt should be made. | ||||
|     strategy: RetryStrategy, | ||||
| } | ||||
|  | ||||
| /// What to do after a failed embedding attempt. | ||||
| enum RetryStrategy { | ||||
|     /// Do not retry: `Retry::into_duration` returns the error instead of a delay. | ||||
|     GiveUp, | ||||
|     /// Retry after `10^attempt` milliseconds. | ||||
|     Retry, | ||||
|     /// Retry after 1 millisecond; `Retry::must_tokenize` reports `true` for this strategy. | ||||
|     RetryTokenized, | ||||
|     /// Retry after `100 + 10^attempt` milliseconds. | ||||
|     RetryAfterRateLimit, | ||||
| } | ||||
|  | ||||
| impl Retry { | ||||
|     /// Wraps `error` as a failure that must not be retried. | ||||
|     fn give_up(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::GiveUp } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` as a failure to retry after an exponentially growing delay. | ||||
|     fn retry_later(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::Retry } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` as a failure to retry almost immediately on a tokenized input. | ||||
|     fn retry_tokenized(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryTokenized } | ||||
|     } | ||||
|  | ||||
|     /// Wraps `error` as a failure caused by rate limiting. | ||||
|     fn rate_limited(error: EmbedError) -> Self { | ||||
|         Self { error, strategy: RetryStrategy::RetryAfterRateLimit } | ||||
|     } | ||||
|  | ||||
|     /// Converts this failure into the delay to wait before attempt number `attempt + 1`. | ||||
|     /// | ||||
|     /// The exponential part of the delay is `10^attempt` milliseconds. It saturates at | ||||
|     /// `u64::MAX` instead of overflowing, so any value of `attempt` is safe. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Returns the wrapped error when the strategy is to give up. | ||||
|     fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> { | ||||
|         let backoff_ms = 10u64.saturating_pow(attempt); | ||||
|         match self.strategy { | ||||
|             RetryStrategy::GiveUp => Err(self.error), | ||||
|             RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis(backoff_ms)), | ||||
|             RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), | ||||
|             RetryStrategy::RetryAfterRateLimit => { | ||||
|                 Ok(tokio::time::Duration::from_millis(100u64.saturating_add(backoff_ms))) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Returns `true` when the next attempt must be made on a tokenized input. | ||||
|     fn must_tokenize(&self) -> bool { | ||||
|         matches!(self.strategy, RetryStrategy::RetryTokenized) | ||||
|     } | ||||
|  | ||||
|     /// Discards the strategy and returns the underlying error. | ||||
|     fn into_error(self) -> EmbedError { | ||||
|         self.error | ||||
|     } | ||||
| } | ||||
|  | ||||
| // openai api structs | ||||
|  | ||||
| /// JSON body of an embeddings request whose inputs are texts. | ||||
| #[derive(Debug, Serialize)] | ||||
| struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { | ||||
|     /// Name of the embedding model to use. | ||||
|     model: &'a str, | ||||
|     /// The texts to embed. | ||||
|     input: &'a [S], | ||||
| } | ||||
|  | ||||
| /// JSON body of an embeddings request whose input is a sequence of tokens. | ||||
| #[derive(Debug, Serialize)] | ||||
| struct OpenAiTokensRequest<'a> { | ||||
|     /// Name of the embedding model to use. | ||||
|     model: &'a str, | ||||
|     /// The tokens to embed. | ||||
|     input: &'a [usize], | ||||
| } | ||||
|  | ||||
| /// The part of a successful embeddings response that is read back. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiResponse { | ||||
|     /// The returned embeddings. | ||||
|     data: Vec<OpenAiEmbedding>, | ||||
| } | ||||
|  | ||||
| /// Body of an error response: a single `error` object. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiErrorResponse { | ||||
|     /// The error details. | ||||
|     error: OpenAiError, | ||||
| } | ||||
|  | ||||
| /// Error details deserialized from the `error` object of an error response. | ||||
| #[derive(Debug, Deserialize)] | ||||
| pub struct OpenAiError { | ||||
|     /// Human-readable description of the error. | ||||
|     message: String, | ||||
|     // type: String, | ||||
|     /// Optional error code; shown in parentheses after the message when displayed. | ||||
|     code: Option<String>, | ||||
| } | ||||
|  | ||||
| impl Display for OpenAiError { | ||||
|     /// Writes the message, followed by the code in parentheses when there is one. | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         let Self { message, code } = self; | ||||
|         if let Some(code) = code { | ||||
|             write!(f, "{} ({})", message, code) | ||||
|         } else { | ||||
|             write!(f, "{}", message) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// One entry of the `data` array of an embeddings response. | ||||
| #[derive(Debug, Deserialize)] | ||||
| struct OpenAiEmbedding { | ||||
|     /// The embedding vector. | ||||
|     embedding: Embedding, | ||||
|     // object: String, | ||||
|     // index: usize, | ||||
| } | ||||
							
								
								
									
										308
									
								
								milli/src/vector/settings.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										308
									
								
								milli/src/vector/settings.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,308 @@ | ||||
| use deserr::Deserr; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use crate::prompt::{PromptData, PromptFallbackStrategy}; | ||||
| use crate::update::Setting; | ||||
| use crate::vector::hf::WeightSource; | ||||
| use crate::vector::EmbeddingConfig; | ||||
|  | ||||
| /// User-facing settings of one embedder; every field may be left unset. | ||||
| #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub struct EmbeddingSettings { | ||||
|     /// Which embedder to use and its options; (de)serialized under the name `source`. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] | ||||
|     #[deserr(default, rename = "source")] | ||||
|     pub embedder_options: Setting<EmbedderSettings>, | ||||
|     /// Settings of the prompt. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub prompt: Setting<PromptSettings>, | ||||
| } | ||||
|  | ||||
| impl EmbeddingSettings { | ||||
|     /// Merges `new` into `self` by forwarding each field of `new` to the `apply` method | ||||
|     /// of the matching field of `self`. | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
|         // Destructuring keeps this method in sync with the struct: a new field that is not | ||||
|         // handled here fails to compile. | ||||
|         let Self { embedder_options: incoming_options, prompt: incoming_prompt } = new; | ||||
|  | ||||
|         self.embedder_options.apply(incoming_options); | ||||
|         self.prompt.apply(incoming_prompt); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<EmbeddingConfig> for EmbeddingSettings { | ||||
|     /// Exposes a configuration as settings in which both fields are explicitly set. | ||||
|     fn from(value: EmbeddingConfig) -> Self { | ||||
|         let embedder_options = Setting::Set(EmbedderSettings::from(value.embedder_options)); | ||||
|         let prompt = Setting::Set(PromptSettings::from(value.prompt)); | ||||
|  | ||||
|         Self { embedder_options, prompt } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<EmbeddingSettings> for EmbeddingConfig { | ||||
|     /// Builds a configuration from settings: fields that are set override the matching | ||||
|     /// field of `EmbeddingConfig::default()`, the others keep their default value. | ||||
|     fn from(value: EmbeddingSettings) -> Self { | ||||
|         let mut config = Self::default(); | ||||
|  | ||||
|         if let Some(options) = value.embedder_options.set() { | ||||
|             config.embedder_options = crate::vector::EmbedderOptions::from(options); | ||||
|         } | ||||
|         if let Some(prompt) = value.prompt.set() { | ||||
|             config.prompt = PromptData::from(prompt); | ||||
|         } | ||||
|  | ||||
|         config | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// User-facing settings of the prompt; every field may be left unset. | ||||
| #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub struct PromptSettings { | ||||
|     /// The prompt template. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub template: Setting<String>, | ||||
|     /// The fallback strategy of the prompt. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub strategy: Setting<PromptFallbackStrategy>, | ||||
|     /// The fallback text -- presumably used according to `strategy`; confirm in `crate::prompt`. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub fallback: Setting<String>, | ||||
| } | ||||
|  | ||||
| impl PromptSettings { | ||||
|     /// Merges `new` into `self` by forwarding each field of `new` to the `apply` method | ||||
|     /// of the matching field of `self`. | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
|         let Self { | ||||
|             template: incoming_template, | ||||
|             strategy: incoming_strategy, | ||||
|             fallback: incoming_fallback, | ||||
|         } = new; | ||||
|  | ||||
|         self.template.apply(incoming_template); | ||||
|         self.strategy.apply(incoming_strategy); | ||||
|         self.fallback.apply(incoming_fallback); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<PromptData> for PromptSettings { | ||||
|     /// Exposes prompt data as settings in which every field is explicitly set. | ||||
|     fn from(value: PromptData) -> Self { | ||||
|         let template = Setting::Set(value.template); | ||||
|         let strategy = Setting::Set(value.strategy); | ||||
|         let fallback = Setting::Set(value.fallback); | ||||
|  | ||||
|         Self { template, strategy, fallback } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<PromptSettings> for PromptData { | ||||
|     /// Builds prompt data from settings: fields that are set override the matching field | ||||
|     /// of `PromptData::default()`, the others keep their default value. | ||||
|     fn from(value: PromptSettings) -> Self { | ||||
|         let mut data = Self::default(); | ||||
|  | ||||
|         if let Some(template) = value.template.set() { | ||||
|             data.template = template; | ||||
|         } | ||||
|         if let Some(strategy) = value.strategy.set() { | ||||
|             data.strategy = strategy; | ||||
|         } | ||||
|         if let Some(fallback) = value.fallback.set() { | ||||
|             data.fallback = fallback; | ||||
|         } | ||||
|  | ||||
|         data | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Choice of embedder together with its settings. | ||||
| /// | ||||
| /// With serde, variants are named in camelCase (`huggingFace`, `openAi`). | ||||
| #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| pub enum EmbedderSettings { | ||||
|     /// Settings of the Hugging Face embedder. | ||||
|     HuggingFace(Setting<HfEmbedderSettings>), | ||||
|     /// Settings of the OpenAI embedder. | ||||
|     OpenAi(Setting<OpenAiEmbedderSettings>), | ||||
| } | ||||
|  | ||||
| impl<E> Deserr<E> for EmbedderSettings | ||||
| where | ||||
|     E: deserr::DeserializeError, | ||||
| { | ||||
|     /// Deserializes the settings from a map holding exactly one key, either | ||||
|     /// `huggingFace` or `openAi`, whose value holds the settings of that embedder. | ||||
|     /// | ||||
|     /// # Errors | ||||
|     /// | ||||
|     /// Fails when `value` is not a map, when the map does not hold exactly one field, | ||||
|     /// when the key is unknown, or when the nested settings fail to deserialize. | ||||
|     fn deserialize_from_value<V: deserr::IntoValue>( | ||||
|         value: deserr::Value<V>, | ||||
|         location: deserr::ValuePointerRef, | ||||
|     ) -> Result<Self, E> { | ||||
|         // Guard: anything but a map is rejected right away. | ||||
|         let map = match value { | ||||
|             deserr::Value::Map(map) => map, | ||||
|             _ => { | ||||
|                 return Err(deserr::take_cf_content(E::error::<V>( | ||||
|                     None, | ||||
|                     deserr::ErrorKind::IncorrectValueKind { | ||||
|                         actual: value, | ||||
|                         accepted: &[deserr::ValueKind::Map], | ||||
|                     }, | ||||
|                     location, | ||||
|                 ))); | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         // Guard: the map must name a single embedder. | ||||
|         if deserr::Map::len(&map) != 1 { | ||||
|             return Err(deserr::take_cf_content(E::error::<V>( | ||||
|                 None, | ||||
|                 deserr::ErrorKind::Unexpected { | ||||
|                     msg: format!( | ||||
|                         "Expected a single field, got {} fields", | ||||
|                         deserr::Map::len(&map) | ||||
|                     ), | ||||
|                 }, | ||||
|                 location, | ||||
|             ))); | ||||
|         } | ||||
|  | ||||
|         // Cannot panic: the map was just checked to hold exactly one entry. | ||||
|         let (key, inner) = deserr::Map::into_iter(map).next().unwrap(); | ||||
|  | ||||
|         match key.as_str() { | ||||
|             "huggingFace" => { | ||||
|                 let settings = HfEmbedderSettings::deserialize_from_value( | ||||
|                     inner.into_value(), | ||||
|                     location.push_key(&key), | ||||
|                 )?; | ||||
|                 Ok(EmbedderSettings::HuggingFace(Setting::Set(settings))) | ||||
|             } | ||||
|             "openAi" => { | ||||
|                 let settings = OpenAiEmbedderSettings::deserialize_from_value( | ||||
|                     inner.into_value(), | ||||
|                     location.push_key(&key), | ||||
|                 )?; | ||||
|                 Ok(EmbedderSettings::OpenAi(Setting::Set(settings))) | ||||
|             } | ||||
|             unknown => Err(deserr::take_cf_content(E::error::<V>( | ||||
|                 None, | ||||
|                 deserr::ErrorKind::UnknownKey { | ||||
|                     key: unknown, | ||||
|                     accepted: &["huggingFace", "openAi"], | ||||
|                 }, | ||||
|                 location, | ||||
|             ))), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for EmbedderSettings { | ||||
|     /// Defaults to the Hugging Face embedder wrapping the default `Setting`. | ||||
|     fn default() -> Self { | ||||
|         let hf_settings = Setting::default(); | ||||
|         Self::HuggingFace(hf_settings) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<crate::vector::EmbedderOptions> for EmbedderSettings { | ||||
|     /// Exposes embedder options as explicitly set settings of the same embedder kind. | ||||
|     fn from(value: crate::vector::EmbedderOptions) -> Self { | ||||
|         use crate::vector::EmbedderOptions as Options; | ||||
|  | ||||
|         match value { | ||||
|             Options::HuggingFace(options) => { | ||||
|                 Self::HuggingFace(Setting::Set(HfEmbedderSettings::from(options))) | ||||
|             } | ||||
|             Options::OpenAi(options) => { | ||||
|                 Self::OpenAi(Setting::Set(OpenAiEmbedderSettings::from(options))) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<EmbedderSettings> for crate::vector::EmbedderOptions { | ||||
|     /// Builds embedder options of the kind selected by the settings. | ||||
|     /// | ||||
|     /// Settings that are not `Set` fall back to the default Hugging Face options, or to | ||||
|     /// the default OpenAI model with the API key returned by `infer_api_key`. | ||||
|     fn from(value: EmbedderSettings) -> Self { | ||||
|         match value { | ||||
|             EmbedderSettings::HuggingFace(setting) => match setting { | ||||
|                 Setting::Set(hf) => Self::HuggingFace(hf.into()), | ||||
|                 _ => Self::HuggingFace(Default::default()), | ||||
|             }, | ||||
|             EmbedderSettings::OpenAi(setting) => match setting { | ||||
|                 Setting::Set(ai) => Self::OpenAi(ai.into()), | ||||
|                 _ => Self::OpenAi( | ||||
|                     crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), | ||||
|                 ), | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// User-facing settings of the Hugging Face embedder; every field may be left unset. | ||||
| #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub struct HfEmbedderSettings { | ||||
|     /// The model to use. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub model: Setting<String>, | ||||
|     /// The revision of the model. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub revision: Setting<String>, | ||||
|     /// Where the weights of the model come from. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub weight_source: Setting<WeightSource>, | ||||
|     /// Whether embeddings are normalized -- presumably after inference; confirm in `crate::vector::hf`. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub normalize_embeddings: Setting<bool>, | ||||
| } | ||||
|  | ||||
| impl HfEmbedderSettings { | ||||
|     /// Merges `new` into `self` by forwarding each field of `new` to the `apply` method | ||||
|     /// of the matching field of `self`. | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
|         let Self { | ||||
|             model: incoming_model, | ||||
|             revision: incoming_revision, | ||||
|             weight_source: incoming_weight_source, | ||||
|             normalize_embeddings: incoming_normalize_embeddings, | ||||
|         } = new; | ||||
|  | ||||
|         self.model.apply(incoming_model); | ||||
|         self.revision.apply(incoming_revision); | ||||
|         self.weight_source.apply(incoming_weight_source); | ||||
|         self.normalize_embeddings.apply(incoming_normalize_embeddings); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings { | ||||
|     /// Exposes Hugging Face options as settings: every field is explicitly set, except | ||||
|     /// `revision` which stays `NotSet` when the options hold no revision. | ||||
|     fn from(value: crate::vector::hf::EmbedderOptions) -> Self { | ||||
|         let revision = match value.revision { | ||||
|             Some(revision) => Setting::Set(revision), | ||||
|             None => Setting::NotSet, | ||||
|         }; | ||||
|  | ||||
|         Self { | ||||
|             model: Setting::Set(value.model), | ||||
|             revision, | ||||
|             weight_source: Setting::Set(value.weight_source), | ||||
|             normalize_embeddings: Setting::Set(value.normalize_embeddings), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions { | ||||
|     /// Builds Hugging Face options from settings: fields that are set override the | ||||
|     /// matching field of the default options, the others keep their default value. | ||||
|     fn from(value: HfEmbedderSettings) -> Self { | ||||
|         let mut options = Self::default(); | ||||
|  | ||||
|         if let Some(model) = value.model.set() { | ||||
|             options.model = model; | ||||
|         } | ||||
|         if let Some(revision) = value.revision.set() { | ||||
|             options.revision = Some(revision); | ||||
|         } | ||||
|         if let Some(weight_source) = value.weight_source.set() { | ||||
|             options.weight_source = weight_source; | ||||
|         } | ||||
|         if let Some(normalize_embeddings) = value.normalize_embeddings.set() { | ||||
|             options.normalize_embeddings = normalize_embeddings; | ||||
|         } | ||||
|  | ||||
|         options | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// User-facing settings of the OpenAI embedder; every field may be left unset. | ||||
| #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] | ||||
| #[serde(deny_unknown_fields, rename_all = "camelCase")] | ||||
| #[deserr(rename_all = camelCase, deny_unknown_fields)] | ||||
| pub struct OpenAiEmbedderSettings { | ||||
|     /// The API key; when unset, conversion to options falls back to `infer_api_key`. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub api_key: Setting<String>, | ||||
|     /// The embedding model; when unset, conversion to options uses the default model. | ||||
|     #[serde(default, skip_serializing_if = "Setting::is_not_set")] | ||||
|     #[deserr(default)] | ||||
|     pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>, | ||||
| } | ||||
|  | ||||
| impl OpenAiEmbedderSettings { | ||||
|     /// Merges `new` into `self` by forwarding each field of `new` to the `apply` method | ||||
|     /// of the matching field of `self`. | ||||
|     pub fn apply(&mut self, new: Self) { | ||||
|         let Self { api_key: incoming_api_key, embedding_model: incoming_embedding_model } = new; | ||||
|  | ||||
|         self.api_key.apply(incoming_api_key); | ||||
|         self.embedding_model.apply(incoming_embedding_model); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings { | ||||
|     /// Exposes OpenAI options as settings in which both fields are explicitly set. | ||||
|     fn from(value: crate::vector::openai::EmbedderOptions) -> Self { | ||||
|         let api_key = Setting::Set(value.api_key); | ||||
|         let embedding_model = Setting::Set(value.embedding_model); | ||||
|  | ||||
|         Self { api_key, embedding_model } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions { | ||||
|     /// Builds OpenAI options from settings. An unset API key is replaced by the result | ||||
|     /// of `infer_api_key`, an unset model by the default embedding model. | ||||
|     fn from(value: OpenAiEmbedderSettings) -> Self { | ||||
|         let api_key = match value.api_key.set() { | ||||
|             Some(key) => key, | ||||
|             None => infer_api_key(), | ||||
|         }; | ||||
|         let embedding_model = value.embedding_model.set().unwrap_or_default(); | ||||
|  | ||||
|         Self { api_key, embedding_model } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Reads the OpenAI API key from the `MEILI_OPENAI_API_KEY` environment variable. | ||||
| /// | ||||
| /// Returns an empty string when the variable is missing or is not valid Unicode. | ||||
| fn infer_api_key() -> String { | ||||
|     // FIXME: get key from instance options? | ||||
|     std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() | ||||
| } | ||||
		Reference in New Issue
	
	Block a user