Compare commits

..

108 Commits

Author SHA1 Message Date
bceaf4f981 Add a log on the time taken by the incremental facet updating 2024-01-25 17:48:31 +01:00
d29b301618 Disable the facet search 2024-01-25 17:47:33 +01:00
a6fa0b97ec Merge #4318
4318: Hide embedders r=ManyTheFish a=dureuill

Hides `embedders` when it is an empty dictionary.

Manual tests:

- getting settings with empty embedders: not displayed
- getting settings with non-empty embedders: displayed like before
- dump with empty embedders: can be imported
- dump with non-empty embedders: can be imported

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-01-15 09:37:31 +00:00
38abfec611 Fix tests 2024-01-11 21:35:30 +01:00
84a5c304fc Don't display the embedders setting when it is an empty dict 2024-01-11 21:35:06 +01:00
e93d36d5b9 Merge #4313
4313: Fix document formatting performances r=Kerollmops a=ManyTheFish

reduce the formatted option list to the attributes that should be formatted,
instead of all the attributes to display.
The time to compute the `format` list scales with the number of fields to format;
cumulated with `map_leaf_values` that iterates over all the nested fields, it gives a quadratic complexity:
`d*f` where `d` is the total number of fields to display and `f` is the total number of fields to format.

Co-authored-by: ManyTheFish <many@meilisearch.com>
2024-01-11 14:19:44 +00:00
95f8e21533 fix typos 2024-01-11 15:07:08 +01:00
68f197624e Merge #4314
4314: Fix proximity precision telemetry r=Kerollmops a=ManyTheFish

The proximity precision telemetry was partially missing in the global setting route.
This PR adds the missing field and return the default value when the value is not set.


Co-authored-by: ManyTheFish <many@meilisearch.com>
2024-01-11 13:50:03 +00:00
b79b03d4e2 Fix proximity precision telemetry 2024-01-11 13:24:26 +01:00
86270e6878 Transform fields contained into _format into strings 2024-01-11 12:44:56 +01:00
81b6128b29 Update tests 2024-01-11 12:28:32 +01:00
5f5a486895 Reduce formatting time 2024-01-11 11:36:41 +01:00
5f4fc6c955 Add timer logs 2024-01-11 09:44:16 +01:00
1f5e8fc072 Merge #4311
4311: Limit the number of values returned by the facet search r=dureuill a=Kerollmops

This PR fixes a bug where the number of values per facet returned by the `indexes/{index}/facet-search` route was not tacking the `faceting.maxValuePerFacet` setting. It also adds a test.

Co-authored-by: Clément Renault <clement@meilisearch.com>
2024-01-10 16:04:06 +00:00
3f3462ab62 Limit the number of values returned by the facet search 2024-01-10 16:54:08 +01:00
93363b0201 Merge #4308
4308: Fix hang on `/indexes` and `/stats` routes r=Kerollmops a=dureuill

# Pull Request

## Related issue
Fixes #4218 

## Context

- A previous fix added a field to the `IndexScheduler` to memorize the `currently_updating_index`, so that accessing it through the search would return the handle without trying to open it. This resolved a hang on the search, but #4218 reported further hangs on the `/indexes` and `/stats` routes
- These routes were shunting the `IndexScheduler` and using internal `IndexMapper` logic to access the indexes, again trying to reopen the updating index.

## What does this PR do?

- Moves the logic relative to the `currently_updating_index` from the `IndexScheduler` to the `IndexMapper`, so that any index request to the `IndexMapper` can benefit from it.

## Test

1. Follow reproducer from #4218 
2. Before this PR, notice a hang on `/stats` and `/indexes`, but not on `/indexes/<updating_index>/search`
3. After this PR, notice no hang on either of `/stats`, `/indexes` or `/indexes/<updating_index>/search`



Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-01-10 10:46:20 +00:00
97bb1ff9e2 Move currently_updating_index to IndexMapper 2024-01-09 15:37:27 +01:00
5ee1378856 Merge #4303
4303: Display default value when proximityPrecision is not set r=dureuill a=ManyTheFish

# Pull Request

## Related
Issue: #4187
Spec change requests: https://github.com/meilisearch/specifications/pull/261#discussion_r1441725272

## What does this PR do?
- Display default value when proximityPrecision is not set instead of Null


Co-authored-by: ManyTheFish <many@meilisearch.com>
2024-01-08 14:29:57 +00:00
e27b850b09 move the default display strategy on setting getter function 2024-01-08 14:03:47 +01:00
f75f22e026 Display default value when proximityPrecision is not set 2024-01-08 11:09:37 +01:00
6203f4acef Merge #4296
4296: Fix single element search r=irevoire a=dureuill

# Pull Request

Before this PR, indexing a single vector in a single document would result in the vector not being found by the vector search.

This PR adds a test case for this condition, and resolves it by bumping arroy to a version containing the fix.

# Test case

Output of the test before and after this PR:

```diff
diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs
index 2cd4b83e7..79819cab2 100644
--- a/meilisearch/tests/search/hybrid.rs on release-v1.6.0
+++ b/meilisearch/tests/search/hybrid.rs on fix-single-element-search
`@@` -171,5 +171,5 `@@` async fn single_document() {
     .await;

     snapshot!(code, `@"200` OK");
-    snapshot!(response["hits"][0], `@r###"{"title":"Shazam!","desc":"a` Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.0}"###);
+    snapshot!(response["hits"][0], `@r###"{"title":"Shazam!","desc":"a` Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0,"_semanticScore":1.0}"###);
 }
```



Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-01-03 15:01:43 +00:00
12edc2c20a Update arroy to a fixed version 2024-01-03 15:59:37 +01:00
94b9f3b310 Add test 2024-01-03 15:56:20 +01:00
da99a04eb3 Merge #4294
4294: fix compilation warnings for release v1.6 r=curquiza a=irevoire

# Pull Request

## Related issue
Fixes #4292

## What does this PR do?
- Removed unused imports

#4295 fixes the issue no main

Co-authored-by: Tamo <tamo@meilisearch.com>
2024-01-02 15:00:40 +00:00
54ae6951eb fix warning 2024-01-02 15:19:30 +01:00
658ec6e0a4 Merge #4279
4279: Check experimental feature on setting update query rather than in the task. r=ManyTheFish a=dureuill

Improve the UX by checking for the vector store feature and returning an error synchronously when sending a setting update, rather than in the indexing task.

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2023-12-22 11:36:12 +00:00
43e822e802 Merge #4238
4238: Task queue webhook r=dureuill a=irevoire

# Prototype `prototype-task-queue-webhook-1`

The prototype is available through Docker by using the following command:

```bash
docker run -p 7700:7700 -v $(pwd)/meili_data:/meili_data getmeili/meilisearch:prototype-task-queue-webhook-1
```

# Pull Request

Implements the task queue webhook.

## Related issue
Fixes https://github.com/meilisearch/meilisearch/issues/4236

## What does this PR do?
- Provide a new cli and env var for the webhook, respectively called `--task-webhook-url` and `MEILI_TASK_WEBHOOK_URL`
- Also supports sending the requests with a custom `Authorization` header by specifying the optional `--task-webhook-authorization-header` CLI parameter or `MEILI_TASK_WEBHOOK_AUTHORIZATION_HEADER` env variable.
- Throw an error if the specified URL is invalid
- Every time a batch is processed, send all the finished tasks into the webhook with our public `TaskView` type as a JSON Line GZIPed body.
- Add one test.

## PR checklist

### Before becoming ready to review
- [x] Add a test
- [x] Compress the data we send
- [x] Chunk and stream the data we send
- [x] Remove the unwrap in the index-scheduler when sending the data fails
- [x] The analytics are missing

### Before merging
- [x] Release a prototype



Co-authored-by: Tamo <tamo@meilisearch.com>
Co-authored-by: Clément Renault <clement@meilisearch.com>
2023-12-21 14:43:46 +00:00
ee54d3171e Check experimental feature at query time 2023-12-21 15:26:12 +01:00
a0e713c4e7 Merge #4277
4277: Update mini-dashboard to v0.2.12 r=curquiza a=mdubus

# Pull Request

## Related issue
Fixes #4276

## What does this PR do?
Upgrade mini-dashboard to version 0.2.12 ([see changes](https://github.com/meilisearch/mini-dashboard/releases/tag/v0.2.12))

## PR checklist
Please check if your PR fulfills the following requirements:
- [x] Does this PR fix an existing issue, or have you listed the changes applied in the PR description (and why they are needed)?
- [x] Have you read the contributing guidelines?
- [x] Have you made sure that the title is accurate and descriptive of the changes?

Thank you so much for contributing to Meilisearch!


Co-authored-by: Morgane Dubus <30866152+mdubus@users.noreply.github.com>
2023-12-21 11:03:46 +00:00
d4cb0a885b Merge #4275
4275: Flatten settings r=dureuill a=dureuill

# Pull Request

## Related issue
Initial internal feedback seems to indicate that the current shape of the `embedders` setting is undesirable: it has too much depth.

This PR changes this by flattening the structure of the embedders to the following:

```json5
// NEW structure
"embedders": {
  // still starts with the embedder name
  "default": {
    "source": "huggingFace", // now a string
    // properties of the source are all at the same level as the source
    "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd",
    "documentTemplate": "A product titled '{{doc.title}}'" // now a string
  }
}
```

By comparison, the old structure was:

```json5
// PREVIOUS version, no longer working with this PR
"embedders": {
  // still starts with the embedder name
  "default": {
    "source": {
      "huggingFace": {
        "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd"
      },
    "documentTemplate": { 
      "template": "A product titled '{{doc.title}}'" // now a string
    }
  }
}
```

The fields that are accepted in the new version of the `embedders` setting are depending on the value of the `source` field:

```json5
// huggingFace
"embedders": {
   "default": {
    "source": "huggingFace",
    "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd",
    "documentTemplate": "A product titled '{{doc.title}}'"
  }
}

// openAi
"embedders": {
   "default": {
    "source": "openAi",
    "model": "text-embedding-ada-002",
    "apiKey": "open_ai_api_key",
    "documentTemplate": "A product titled '{{doc.title}}'"
  }
}

// userProvided
"embedders": {
   "default": {
    "source": "userProvided",
    "dimensions": 42, // mandatory
  }
}
```

## What does this PR do?
- Flatten the settings structure
- Validate the prompt earlier to return a synchronous error on setting change rather than in the failing task
- Make it an error to pass a field for the wrong source (see above for allowed fields for each source)
- Not changed: It is still an error not to pass `dimensions` to the `userProvided` embedder
- If `source` was specified in the settings, validate the setting early to return a synchronous error in case of a missing mandatory field for the userProvided source (dimensions) or a forbidden field for the specified source.
- If `source` was not specified in the settings, still validate the setting, but only at indexing time, by using the source stored in the DB.
- Resets all values if the source changes, even if the user did not reset them explicitly.

## PR checklist
Please check if your PR fulfills the following requirements:
- [ ] Change the public facing guide for using the API
- [ ] Change examples of use in the changelog


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2023-12-21 09:58:01 +00:00
f52dee2b3b Update Cargo.toml
Update mini-dashboard with v0.2.12
2023-12-21 09:53:13 +01:00
0bf879fb88 Fix warning on rust stable 2023-12-20 17:48:09 +01:00
6ff81de401 Fix tests 2023-12-20 17:16:46 +01:00
2e4c9651df Validate settings in route 2023-12-20 17:16:46 +01:00
ec9649c922 Add function to validate settings in Meilisearch, to be used in the routes 2023-12-20 17:16:46 +01:00
9123370e90 Validate fused settings in settings task after fusing with existing setting 2023-12-20 17:16:46 +01:00
14b396d302 Add new errors 2023-12-20 17:16:45 +01:00
393216bf30 Flatten embedders settings 2023-12-20 17:16:43 +01:00
e249e4db7b Change Setting::apply function signature 2023-12-20 17:15:24 +01:00
de2ca7006e Merge #4272
4272: Don't pass default revision when the model is explicitly set in config r=Kerollmops a=dureuill

# Pull Request

## Related issue
Fixes #4271 

## What does this PR do?

- When the `model` is explicitly set in the `embedders` setting, we reset the `revision` to `None`, such that if the user doesn't specify a revision, the head of the model repository is chosen. 
- Not changed: If the user specifies a revision, it applies, like previously. 
- Not changed: If the user doesn't specify a model, the default model with the default revision applies, like previously.

## Manual testing on a fresh DB

1. Enable experimental feature:
```sh
curl \
  -X PATCH 'http://localhost:7700/experimental-features/' \
  -H 'Content-Type: application/json' -H 'Authorization: Bearer foo' \
--data-binary '{ "vectorStore": true
  }'
```
2. Send settings with a specified model but no specified revision:
```sh
curl \
-X PATCH 'http://localhost:7700/indexes/products/settings' \
-H 'Content-Type: application/json' --data-binary \
'{ "embedders": { "default": { "source": { "huggingFace": { "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" } }, "documentTemplate": { "template": "A product titled '{{doc.title}}'"} } } }'
```
3. Check that the task was successful:
```sh
curl 'http://localhost:7700/tasks/0'

{"uid":0,"indexUid":"products","status":"succeeded","type":"settingsUpdate","canceledBy":null,"details":{"embedders":{"default":{"source":{"huggingFace":{"model":"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"}},"documentTemplate":{"template":"A product titled {{doc.title}}"}}}},"error":null,"duration":"PT0.001892S","enqueuedAt":"2023-12-20T09:17:01.73789Z","startedAt":"2023-12-20T09:17:01.73854Z","finishedAt":"2023-12-20T09:17:01.740432Z"}
```
4. Send documents to index:
```sh
curl 'https://localhost:7700/indexes/products/documents' -H 'Content-Type: application/json' --data-binary '{"id": 0, "title": "Best product"}'
```

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2023-12-20 14:27:51 +00:00
333ce12eb2 Fixed issue where the default revision is always the one we picked for the default model 2023-12-20 10:17:49 +01:00
fb9db1eba6 Merge #4269
4269: Remove dependency that requires libstdc++ r=dureuill a=dureuill

Removes the dependency that caused the additional runtime dependency on libstdc++ by disabling the default features of the hf tokenizer.

## Discussion

- This removes a feature that is using a C++ dependency and is supposed to accelerate the tokenizer. As the tokenizer is likely to be a significant bottleneck for embedding texts using a HF model, this is an issue.
- We should at least rerun the movies vector indexing and check that it still works correctly and that it has a runtime in the ballpark of what it used to be.

Co-authored-by: Louis Dureuil <louis.dureuil@xinra.net>
2023-12-19 12:26:48 +00:00
fa2b96b9a5 Add an Authorization Header along with the webhook calls 2023-12-19 12:18:45 +01:00
19736cefe8 add the analytics 2023-12-19 10:36:04 +01:00
4fb25b8782 fix clippy 2023-12-19 10:35:51 +01:00
c83a33017e stream and chunk the data 2023-12-19 10:35:51 +01:00
be72326c0a gzip the tasks 2023-12-19 10:35:51 +01:00
547379abb0 parse the url correctly 2023-12-19 10:35:51 +01:00
0b2fff27f2 update and fix the test 2023-12-19 10:35:51 +01:00
3adbc2b942 return a task view instead of a task 2023-12-19 10:35:51 +01:00
fbea721378 add a first working test with actixweb 2023-12-19 10:35:51 +01:00
391eb72137 start writing a test with actix but it doesn't works 2023-12-19 10:35:50 +01:00
d78ad51082 Implement the webhook 2023-12-19 10:35:50 +01:00
1956045a06 add the option 2023-12-19 10:23:56 +01:00
b2193e612f Revert "Add libstdc++ in Dockerfile" as it is no longer needed
This reverts commit 9df8cfc013.
2023-12-18 22:17:29 +01:00
942d49314c Remove dependency that requires libstdc++ 2023-12-18 22:17:18 +01:00
9a846e82bc Merge #4268
4268: Add libstdc++ in Dockerfile r=curquiza a=sanders41

# Pull Request

## Related issue
Fixes #4267

## What does this PR do?
- Add libstdc++ in the Dockerfile

## PR checklist
Please check if your PR fulfills the following requirements:
- [x] Does this PR fix an existing issue, or have you listed the changes applied in the PR description (and why they are needed)?
- [x] Have you read the contributing guidelines?
- [x] Have you made sure that the title is accurate and descriptive of the changes?

Thank you so much for contributing to Meilisearch!


Co-authored-by: Paul Sanders <psanders1@gmail.com>
2023-12-18 18:35:53 +00:00
9df8cfc013 Add libstdc++ in Dockerfile 2023-12-18 13:05:46 -05:00
248aaa6d45 Merge #4262
4262: Update version for the next release (v1.6.0) in Cargo.toml r=curquiza a=meili-bot

⚠️ This PR is automatically generated. Check the new version is the expected one and Cargo.lock has been updated before merging.

Co-authored-by: curquiza <curquiza@users.noreply.github.com>
2023-12-18 14:00:19 +00:00
50d6317ec0 Update version for the next release (v1.6.0) in Cargo.toml 2023-12-18 13:57:46 +00:00
b734bd9891 Merge #4261
4261: Set rust toolchain to 1.71.1 in dockerfile r=curquiza a=dureuill

Fixes docker [CI](https://github.com/meilisearch/meilisearch/actions/workflows/publish-docker-images.yml)

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2023-12-18 12:32:26 +00:00
9800d5a103 Set rust toolchain to 1.71.1 in dockerfile 2023-12-18 10:59:25 +01:00
7c4ed07617 Merge #4257
4257: Change proximity precision settings r=dureuill a=ManyTheFish

- [x] Add proximity_precision value into the analytics
- [x] Change the naming of `attributeScale` and `wordScale` into `byAttribute` and `byWord`
- [x] Remove proximityPrecision from the experimental feature

Co-authored-by: ManyTheFish <many@meilisearch.com>
Co-authored-by: Many the fish <many@meilisearch.com>
2023-12-18 09:07:28 +00:00
3a99a555a2 Fix experimental features snapshots in tests 2023-12-18 10:05:51 +01:00
9e1b458010 Merge branch 'main' into change-proximity-precision-settings 2023-12-18 09:08:47 +01:00
2aede03bc2 Merge #4226
4226: Hybrid search r=dureuill a=dureuill

Allows to perform hybrid search requests that combine the results of semantic and keyword search and automatically generate embeddings.

## How to use

See [feature description](https://meilisearch.notion.site/v1-6-Hybrid-Search-Embedders-ea42c82f90cc4bc0be1eeb917c1118c8)

## Changes

- work is based on #4213 
- milli::new search now takes an input universe directly, rather than computing it from a filter. This adds flexibility to require results on a subset of documents
- vector search is now a regular ranking rule (akin to sort and geosort) and reports its score as a ScoreDetail
- separate keyword search and vector search functions, vector search now respects (geo)sort ranking rules
- add automatic embedding
- add hybrid search

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
Co-authored-by: ManyTheFish <many@meilisearch.com>
2023-12-14 16:24:56 +00:00
e741bc1c62 Add proximity_precision value into the analytics 2023-12-14 16:48:06 +01:00
6425996e36 Change the naming of attributeScale and wordScale into byAttribute and byWord 2023-12-14 16:31:00 +01:00
eb5cb91da2 Switch default from hf to openai 2023-12-14 16:19:46 +01:00
87bba98bd8 Various changes
- fixed seed for arroy
- check vector dimensions as soon as it is provided to search
- don't embed whitespace
2023-12-14 16:08:42 +01:00
217105b7da hybrid search uses semantic ratio, error handling 2023-12-14 16:08:42 +01:00
1b7c164a55 Pass the semantic ratio to milli 2023-12-14 16:08:42 +01:00
f3f3944469 Fix error checking 2023-12-14 16:08:42 +01:00
93dcbf598d Deserialize semantic ratio 2023-12-14 16:08:42 +01:00
ac68f33194 Add simple test 2023-12-14 16:08:42 +01:00
9991152bbe Add TODOs 2023-12-14 16:08:42 +01:00
a4536b1381 Small adjustments to respect the spec 2023-12-14 16:08:42 +01:00
5b51cb04af Remove some settings 2023-12-14 16:08:42 +01:00
3c1a14f1cd Add settings routes 2023-12-14 16:08:42 +01:00
b8e4709dfa Remove prompt strategy and fallback 2023-12-14 16:08:41 +01:00
806e5b6899 Tests pass 2023-12-14 16:08:41 +01:00
61bd2fb7a9 Update arroy 2023-12-14 16:08:41 +01:00
e0cc775dc4 Various changes
- DistributionShift in Search object (to be set from model in embed?)
- Fix issue where embedder index wasn't computed at search time
- Accept as default embedder either the "default" one, or the only embedder when there is only one
2023-12-14 16:08:41 +01:00
12940d79a9 WIP
- manual embedder
- multi embedders OK
- clippy + tests OK
2023-12-14 16:08:41 +01:00
922a640188 WIP multi embedders
fixed template bugs
2023-12-14 16:08:41 +01:00
abbe131084 Cosmetic change 2023-12-14 16:08:41 +01:00
d4715e0c4d Fix same vector sort bug 2023-12-14 16:08:41 +01:00
11e2a2c1aa Fix geosort bug 2023-12-14 16:08:41 +01:00
65e49b7092 Remove stuff, add distribution shift (WIP) 2023-12-14 16:08:38 +01:00
e56f160032 Actually pass embedders on reindex 2023-12-14 16:07:49 +01:00
687d92f217 prompt bifluor+ 2023-12-14 16:07:49 +01:00
fb539f61fe WIP 2023-12-14 16:07:49 +01:00
cb4ebe163e WIP 2023-12-14 16:07:49 +01:00
dde3a04679 WIP arroy integration 2023-12-14 16:07:49 +01:00
13c2c6c16b Small commit to add hybrid search and autoembedding 2023-12-14 16:07:48 +01:00
21bcf32109 Add candle and hg_hub, updating a lot of deps in the process 2023-12-14 16:07:48 +01:00
35e1981488 Remove proximityPrecision form the experimental feature 2023-12-14 15:52:42 +01:00
e0f712b9d3 Merge #4254
4254: Bring back v1.5.1 changes into main r=ManyTheFish a=Kerollmops

This pull request brings back changes from the _release-v1.5.1_ branch into _main_.

Co-authored-by: ManyTheFish <many@meilisearch.com>
Co-authored-by: meili-bors[bot] <89034592+meili-bors[bot]@users.noreply.github.com>
Co-authored-by: curquiza <curquiza@users.noreply.github.com>
Co-authored-by: Clément Renault <clement@meilisearch.com>
2023-12-14 09:41:57 +00:00
56571f762a Merge remote-tracking branch 'origin/main' into tmp-release-v1.5.1 2023-12-13 11:57:01 +01:00
005800634d Merge pull request #4249 from meilisearch/flag-limit-batch-size
Introduce parameters to limit the number of batched tasks
2023-12-13 10:32:14 +01:00
976af4fa8f Add the default commented experimental batched tasks limit parameter to the config file 2023-12-12 10:59:00 +01:00
99fec27788 Make the --max-number-of-batched-tasks argument experimental 2023-12-12 10:55:39 +01:00
afa8f273a8 Merge #4250
4250: Update version for the next release (v1.5.1) in Cargo.toml r=dureuill a=meili-bot

⚠️ This PR is automatically generated. Check the new version is the expected one and Cargo.lock has been updated before merging.

Co-authored-by: curquiza <curquiza@users.noreply.github.com>
2023-12-12 08:26:06 +00:00
4b644f6bc0 Update version for the next release (v1.5.1) in Cargo.toml 2023-12-11 17:15:11 +00:00
c95d68e244 Merge #4233
4233: Add test reproducing #4232 r=dureuill a=ManyTheFish

- add a test reproducing the bug
- fix the bug by creating 2 different restricting lists of attributes, one for the exact attributes, and the other for the tolerant attributes

## Related issue
Fixes #4232


Co-authored-by: ManyTheFish <many@meilisearch.com>
2023-11-29 08:47:17 +00:00
3b3fa38f27 Put the restrict list in a sub-struct 2023-11-28 18:37:57 +01:00
d6c2ee15a9 Filter on attributes before computing the docids when attribute restriction is on 2023-11-28 14:55:29 +01:00
dc07790133 Add test reproducing #4232 2023-11-27 11:39:11 +01:00
83 changed files with 6677 additions and 1160 deletions

1169
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -19,7 +19,7 @@ members = [
]
[workspace.package]
version = "1.5.0"
version = "1.6.0"
authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"]
description = "Meilisearch HTTP server"
homepage = "https://meilisearch.com"

View File

@ -1,5 +1,5 @@
# Compile
FROM rust:alpine3.16 AS compiler
FROM rust:1.71.1-alpine3.18 AS compiler
RUN apk add -q --update-cache --no-cache build-base openssl-dev

View File

@ -129,3 +129,6 @@ experimental_enable_metrics = false
# Experimental RAM reduction during indexing, do not use in production, see: <https://github.com/meilisearch/product/discussions/652>
experimental_reduce_indexing_memory_usage = false
# Experimentally reduces the maximum number of tasks that will be processed at once, see: <https://github.com/orgs/meilisearch/discussions/713>
# experimental_max_number_of_batched_tasks = 100

View File

@ -276,6 +276,7 @@ pub(crate) mod test {
),
}),
pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: std::marker::PhantomData,
};
settings.check()

View File

@ -378,6 +378,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
v5::Setting::Reset => v6::Setting::Reset,
v5::Setting::NotSet => v6::Setting::NotSet,
},
embedders: v6::Setting::NotSet,
_kind: std::marker::PhantomData,
}
}

View File

@ -18,6 +18,7 @@ derive_builder = "0.12.0"
dump = { path = "../dump" }
enum-iterator = "1.4.0"
file-store = { path = "../file-store" }
flate2 = "1.0.28"
log = "0.4.17"
meilisearch-auth = { path = "../meilisearch-auth" }
meilisearch-types = { path = "../meilisearch-types" }
@ -30,6 +31,7 @@ synchronoise = "1.0.1"
tempfile = "3.5.0"
thiserror = "1.0.40"
time = { version = "0.3.20", features = ["serde-well-known", "formatting", "parsing", "macros"] }
ureq = "2.9.1"
uuid = { version = "1.3.1", features = ["serde", "v4"] }
[dev-dependencies]

View File

@ -936,8 +936,8 @@ impl IndexScheduler {
};
// the index operation can take a long time, so save this handle to make it available to the search for the duration of the tick
*self.currently_updating_index.write().unwrap() =
Some((index_uid.clone(), index.clone()));
self.index_mapper
.set_currently_updating_index(Some((index_uid.clone(), index.clone())));
let mut index_wtxn = index.write_txn()?;
let tasks = self.apply_index_operation(&mut index_wtxn, &index, op)?;
@ -1202,6 +1202,10 @@ impl IndexScheduler {
let config = IndexDocumentsConfig { update_method: method, ..Default::default() };
let embedder_configs = index.embedding_configs(index_wtxn)?;
// TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense)
let embedders = self.embedders(embedder_configs)?;
let mut builder = milli::update::IndexDocuments::new(
index_wtxn,
index,
@ -1220,6 +1224,8 @@ impl IndexScheduler {
let (new_builder, user_result) = builder.add_documents(reader)?;
builder = new_builder;
builder = builder.with_embedders(embedders.clone());
let received_documents =
if let Some(Details::DocumentAdditionOrUpdate {
received_documents,
@ -1345,9 +1351,6 @@ impl IndexScheduler {
for (task, (_, settings)) in tasks.iter_mut().zip(settings) {
let checked_settings = settings.clone().check();
if checked_settings.proximity_precision.set().is_some() {
self.features.features().check_proximity_precision()?;
}
task.details = Some(Details::SettingsUpdate { settings: Box::new(settings) });
apply_settings_to_builder(&checked_settings, &mut builder);

View File

@ -56,12 +56,12 @@ impl RoFeatures {
}
}
pub fn check_vector(&self) -> Result<()> {
pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> {
if self.runtime.vector_store {
Ok(())
} else {
Err(FeatureNotEnabledError {
disabled_action: "Passing `vector` as a query parameter",
disabled_action,
feature: "vector store",
issue_link: "https://github.com/meilisearch/product/discussions/677",
}
@ -81,19 +81,6 @@ impl RoFeatures {
.into())
}
}
pub fn check_proximity_precision(&self) -> Result<()> {
if self.runtime.proximity_precision {
Ok(())
} else {
Err(FeatureNotEnabledError {
disabled_action: "Using `proximityPrecision` index setting",
feature: "proximity precision",
issue_link: "https://github.com/orgs/meilisearch/discussions/710",
}
.into())
}
}
}
impl FeatureData {

View File

@ -69,6 +69,10 @@ pub struct IndexMapper {
/// Whether we open a meilisearch index with the MDB_WRITEMAP option or not.
enable_mdb_writemap: bool,
pub indexer_config: Arc<IndexerConfig>,
/// A few types of long running batches of tasks that act on a single index set this field
/// so that a handle to the index is available from other threads (search) in an optimized manner.
currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
}
/// Whether the index is available for use or is forbidden to be inserted back in the index map
@ -151,6 +155,7 @@ impl IndexMapper {
index_growth_amount,
enable_mdb_writemap,
indexer_config: Arc::new(indexer_config),
currently_updating_index: Default::default(),
})
}
@ -303,6 +308,14 @@ impl IndexMapper {
/// Return an index, may open it if it wasn't already opened.
pub fn index(&self, rtxn: &RoTxn, name: &str) -> Result<Index> {
if let Some((current_name, current_index)) =
self.currently_updating_index.read().unwrap().as_ref()
{
if current_name == name {
return Ok(current_index.clone());
}
}
let uuid = self
.index_mapping
.get(rtxn, name)?
@ -474,4 +487,8 @@ impl IndexMapper {
pub fn indexer_config(&self) -> &IndexerConfig {
&self.indexer_config
}
pub fn set_currently_updating_index(&self, index: Option<(String, Index)>) {
*self.currently_updating_index.write().unwrap() = index;
}
}

View File

@ -37,10 +37,12 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
snapshots_path: _,
auth_path: _,
version_file_path: _,
webhook_url: _,
webhook_authorization_header: _,
test_breakpoint_sdr: _,
planned_failures: _,
run_loop_iteration: _,
currently_updating_index: _,
embedders: _,
} = scheduler;
let rtxn = env.read_txn().unwrap();

View File

@ -34,6 +34,7 @@ pub type TaskId = u32;
use std::collections::{BTreeMap, HashMap};
use std::fs::File;
use std::io::{self, BufReader, Read};
use std::ops::{Bound, RangeBounds};
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicBool;
@ -45,6 +46,8 @@ use dump::{KindDump, TaskDump, UpdateFile};
pub use error::Error;
pub use features::RoFeatures;
use file_store::FileStore;
use flate2::bufread::GzEncoder;
use flate2::Compression;
use meilisearch_types::error::ResponseError;
use meilisearch_types::features::{InstanceTogglableFeatures, RuntimeTogglableFeatures};
use meilisearch_types::heed::byteorder::BE;
@ -52,7 +55,9 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
use meilisearch_types::milli::update::IndexerConfig;
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
use meilisearch_types::task_view::TaskView;
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
use puffin::FrameView;
use roaring::RoaringBitmap;
@ -169,8 +174,8 @@ impl ProcessingTasks {
}
/// Set the processing tasks to an empty list
fn stop_processing(&mut self) {
self.processing = RoaringBitmap::new();
fn stop_processing(&mut self) -> RoaringBitmap {
std::mem::take(&mut self.processing)
}
/// Returns `true` if there, at least, is one task that is currently processing that we must stop.
@ -240,6 +245,10 @@ pub struct IndexSchedulerOptions {
pub snapshots_path: PathBuf,
/// The path to the folder containing the dumps.
pub dumps_path: PathBuf,
/// The URL on which we must send the tasks statuses
pub webhook_url: Option<String>,
/// The value we will send into the Authorization HTTP header on the webhook URL
pub webhook_authorization_header: Option<String>,
/// The maximum size, in bytes, of the task index.
pub task_db_size: usize,
/// The size, in bytes, with which a meilisearch index is opened the first time of each meilisearch index.
@ -322,6 +331,11 @@ pub struct IndexScheduler {
/// The maximum number of tasks that will be batched together.
pub(crate) max_number_of_batched_tasks: usize,
/// The webhook url we should send tasks to after processing every batches.
pub(crate) webhook_url: Option<String>,
/// The Authorization header to send to the webhook URL.
pub(crate) webhook_authorization_header: Option<String>,
/// A frame to output the indexation profiling files to disk.
pub(crate) puffin_frame: Arc<puffin::GlobalFrameView>,
@ -337,9 +351,7 @@ pub struct IndexScheduler {
/// The path to the version file of Meilisearch.
pub(crate) version_file_path: PathBuf,
/// A few types of long running batches of tasks that act on a single index set this field
/// so that a handle to the index is available from other threads (search) in an optimized manner.
currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
embedders: Arc<RwLock<HashMap<EmbedderOptions, Arc<Embedder>>>>,
// ================= test
// The next entry is dedicated to the tests.
@ -385,7 +397,9 @@ impl IndexScheduler {
dumps_path: self.dumps_path.clone(),
auth_path: self.auth_path.clone(),
version_file_path: self.version_file_path.clone(),
currently_updating_index: self.currently_updating_index.clone(),
webhook_url: self.webhook_url.clone(),
webhook_authorization_header: self.webhook_authorization_header.clone(),
embedders: self.embedders.clone(),
#[cfg(test)]
test_breakpoint_sdr: self.test_breakpoint_sdr.clone(),
#[cfg(test)]
@ -483,7 +497,9 @@ impl IndexScheduler {
snapshots_path: options.snapshots_path,
auth_path: options.auth_path,
version_file_path: options.version_file_path,
currently_updating_index: Arc::new(RwLock::new(None)),
webhook_url: options.webhook_url,
webhook_authorization_header: options.webhook_authorization_header,
embedders: Default::default(),
#[cfg(test)]
test_breakpoint_sdr,
@ -666,13 +682,6 @@ impl IndexScheduler {
/// If you need to fetch information from or perform an action on all indexes,
/// see the `try_for_each_index` function.
pub fn index(&self, name: &str) -> Result<Index> {
if let Some((current_name, current_index)) =
self.currently_updating_index.read().unwrap().as_ref()
{
if current_name == name {
return Ok(current_index.clone());
}
}
let rtxn = self.env.read_txn()?;
self.index_mapper.index(&rtxn, name)
}
@ -1153,7 +1162,7 @@ impl IndexScheduler {
};
// Reset the currently updating index to relinquish the index handle
*self.currently_updating_index.write().unwrap() = None;
self.index_mapper.set_currently_updating_index(None);
#[cfg(test)]
self.maybe_fail(tests::FailureLocation::AcquiringWtxn)?;
@ -1246,19 +1255,99 @@ impl IndexScheduler {
}
}
self.processing_tasks.write().unwrap().stop_processing();
let processed = self.processing_tasks.write().unwrap().stop_processing();
#[cfg(test)]
self.maybe_fail(tests::FailureLocation::CommittingWtxn)?;
wtxn.commit().map_err(Error::HeedTransaction)?;
// We shouldn't crash the tick function if we can't send data to the webhook.
let _ = self.notify_webhook(&processed);
#[cfg(test)]
self.breakpoint(Breakpoint::AfterProcessing);
Ok(TickOutcome::TickAgain(processed_tasks))
}
/// Once the tasks changes have been commited we must send all the tasks that were updated to our webhook if there is one.
fn notify_webhook(&self, updated: &RoaringBitmap) -> Result<()> {
if let Some(ref url) = self.webhook_url {
struct TaskReader<'a, 'b> {
rtxn: &'a RoTxn<'a>,
index_scheduler: &'a IndexScheduler,
tasks: &'b mut roaring::bitmap::Iter<'b>,
buffer: Vec<u8>,
written: usize,
}
impl<'a, 'b> Read for TaskReader<'a, 'b> {
fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
if self.buffer.is_empty() {
match self.tasks.next() {
None => return Ok(0),
Some(task_id) => {
let task = self
.index_scheduler
.get_task(self.rtxn, task_id)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
Error::CorruptedTaskQueue,
)
})?;
serde_json::to_writer(
&mut self.buffer,
&TaskView::from_task(&task),
)?;
self.buffer.push(b'\n');
}
}
}
let mut to_write = &self.buffer[self.written..];
let wrote = io::copy(&mut to_write, &mut buf)?;
self.written += wrote as usize;
// we wrote everything and must refresh our buffer on the next call
if self.written == self.buffer.len() {
self.written = 0;
self.buffer.clear();
}
Ok(wrote as usize)
}
}
let rtxn = self.env.read_txn()?;
let task_reader = TaskReader {
rtxn: &rtxn,
index_scheduler: self,
tasks: &mut updated.into_iter(),
buffer: Vec::with_capacity(50), // on average a task is around ~100 bytes
written: 0,
};
// let reader = GzEncoder::new(BufReader::new(task_reader), Compression::default());
let reader = GzEncoder::new(BufReader::new(task_reader), Compression::default());
let request = ureq::post(url).set("Content-Encoding", "gzip");
let request = match &self.webhook_authorization_header {
Some(header) => request.set("Authorization", header),
None => request,
};
if let Err(e) = request.send(reader) {
log::error!("While sending data to the webhook: {e}");
}
}
Ok(())
}
/// Register a task to cleanup the task queue if needed
fn cleanup_task_queue(&self) -> Result<()> {
let rtxn = self.env.read_txn().map_err(Error::HeedTransaction)?;
@ -1333,6 +1422,40 @@ impl IndexScheduler {
}
}
// TODO: consider using a type alias or a struct embedder/template
pub fn embedders(
&self,
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
let prompt =
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
// optimistically return existing embedder
{
let embedders = self.embedders.read().unwrap();
if let Some(embedder) = embedders.get(&embedder_options) {
return Ok((name, (embedder.clone(), prompt)));
}
}
// add missing embedder
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::Error::from)?,
);
{
let mut embedders = self.embedders.write().unwrap();
embedders.insert(embedder_options, embedder.clone());
}
Ok((name, (embedder, prompt)))
})
.collect();
res.map(EmbeddingConfigs::new)
}
/// Blocks the thread until the test handle asks to progress to/through this breakpoint.
///
/// Two messages are sent through the channel for each breakpoint.
@ -1638,6 +1761,8 @@ mod tests {
indexes_path: tempdir.path().join("indexes"),
snapshots_path: tempdir.path().join("snapshots"),
dumps_path: tempdir.path().join("dumps"),
webhook_url: None,
webhook_authorization_header: None,
task_db_size: 1000 * 1000, // 1 MB, we don't use MiB on purpose.
index_base_map_size: 1000 * 1000, // 1 MB, we don't use MiB on purpose.
enable_mdb_writemap: false,

View File

@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError);
merge_with_error_impl_take_error_message!(ParseTaskKindError);
merge_with_error_impl_take_error_message!(ParseTaskStatusError);
merge_with_error_impl_take_error_message!(IndexUidFormatError);
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);

View File

@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
InvalidEmbedder , InvalidRequest , BAD_REQUEST ;
InvalidHybridQuery , InvalidRequest , BAD_REQUEST ;
InvalidIndexLimit , InvalidRequest , BAD_REQUEST ;
InvalidIndexOffset , InvalidRequest , BAD_REQUEST ;
InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ;
@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
@ -256,6 +259,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ;
InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ;
InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ;
InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ;
InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ;
InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ;
InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ;
@ -295,15 +299,18 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
MissingIndexUid , InvalidRequest , BAD_REQUEST ;
MissingMasterKey , Auth , UNAUTHORIZED ;
MissingPayload , InvalidRequest , BAD_REQUEST ;
MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
MissingTaskFilters , InvalidRequest , BAD_REQUEST ;
NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY;
PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ;
TaskNotFound , InvalidRequest , NOT_FOUND ;
TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
TooManyVectors , InvalidRequest , BAD_REQUEST ;
UnretrievableDocument , Internal , BAD_REQUEST ;
UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
VectorEmbeddingError , InvalidRequest , BAD_REQUEST
}
impl ErrorCode for JoinError {
@ -336,6 +343,13 @@ impl ErrorCode for milli::Error {
UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => {
Code::InvalidDocumentId
}
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
UserError::InvalidFieldForSource { .. }
| UserError::MissingFieldForSource { .. }
| UserError::InvalidOpenAiModel { .. }
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
Code::IndexPrimaryKeyMultipleCandidatesFound
@ -353,11 +367,15 @@ impl ErrorCode for milli::Error {
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType,
UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
UserError::TooManyVectors(_, _) => Code::TooManyVectors,
UserError::SortError(_) => Code::InvalidSearchSort,
UserError::InvalidMinTypoWordLenSetting(_, _) => {
Code::InvalidSettingsTypoTolerance
}
UserError::InvalidEmbedder(_) => Code::InvalidEmbedder,
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
}
}
}
@ -445,6 +463,15 @@ impl fmt::Display for DeserrParseIntError {
}
}
impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
#[macro_export]
macro_rules! internal_error {
($target:ty : $($other:path), *) => {

View File

@ -7,7 +7,6 @@ pub struct RuntimeTogglableFeatures {
pub vector_store: bool,
pub metrics: bool,
pub export_puffin_reports: bool,
pub proximity_precision: bool,
}
#[derive(Default, Debug, Clone, Copy)]

View File

@ -9,6 +9,7 @@ pub mod index_uid_pattern;
pub mod keys;
pub mod settings;
pub mod star_or;
pub mod task_view;
pub mod tasks;
pub mod versioning;
pub use milli::{heed, Index};

View File

@ -199,6 +199,10 @@ pub struct Settings<T> {
#[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)]
pub pagination: Setting<PaginationSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)]
pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>,
#[serde(skip)]
#[deserr(skip)]
pub _kind: PhantomData<T>,
@ -222,6 +226,7 @@ impl Settings<Checked> {
typo_tolerance: Setting::Reset,
faceting: Setting::Reset,
pagination: Setting::Reset,
embedders: Setting::Reset,
_kind: PhantomData,
}
}
@ -243,6 +248,7 @@ impl Settings<Checked> {
typo_tolerance,
faceting,
pagination,
embedders,
..
} = self;
@ -262,6 +268,7 @@ impl Settings<Checked> {
typo_tolerance,
faceting,
pagination,
embedders,
_kind: PhantomData,
}
}
@ -307,9 +314,25 @@ impl Settings<Unchecked> {
typo_tolerance: self.typo_tolerance,
faceting: self.faceting,
pagination: self.pagination,
embedders: self.embedders,
_kind: PhantomData,
}
}
pub fn validate(self) -> Result<Self, milli::Error> {
self.validate_embedding_settings()
}
fn validate_embedding_settings(mut self) -> Result<Self, milli::Error> {
let Setting::Set(mut configs) = self.embedders else { return Ok(self) };
for (name, config) in configs.iter_mut() {
let config_to_check = std::mem::take(config);
let checked_config = milli::update::validate_embedding_settings(config_to_check, name)?;
*config = checked_config
}
self.embedders = Setting::Set(configs);
Ok(self)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -490,6 +513,12 @@ pub fn apply_settings_to_builder(
Setting::Reset => builder.reset_pagination_max_total_hits(),
Setting::NotSet => (),
}
match settings.embedders.clone() {
Setting::Set(value) => builder.set_embedder_settings(value),
Setting::Reset => builder.reset_embedder_settings(),
Setting::NotSet => (),
}
}
pub fn settings(
@ -571,6 +600,13 @@ pub fn settings(
),
};
let embedders: BTreeMap<_, _> = index
.embedding_configs(rtxn)?
.into_iter()
.map(|(name, config)| (name, Setting::Set(config.into())))
.collect();
let embedders = if embedders.is_empty() { Setting::NotSet } else { Setting::Set(embedders) };
Ok(Settings {
displayed_attributes: match displayed_attributes {
Some(attrs) => Setting::Set(attrs),
@ -591,14 +627,12 @@ pub fn settings(
Some(field) => Setting::Set(field),
None => Setting::Reset,
},
proximity_precision: match proximity_precision {
Some(precision) => Setting::Set(precision),
None => Setting::Reset,
},
proximity_precision: Setting::Set(proximity_precision.unwrap_or_default()),
synonyms: Setting::Set(synonyms),
typo_tolerance: Setting::Set(typo_tolerance),
faceting: Setting::Set(faceting),
pagination: Setting::Set(pagination),
embedders,
_kind: PhantomData,
})
}
@ -699,27 +733,28 @@ impl From<RankingRuleView> for Criterion {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserr, Serialize, Deserialize)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Deserr, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(error = DeserrJsonError<InvalidSettingsProximityPrecision>, rename_all = camelCase, deny_unknown_fields)]
pub enum ProximityPrecisionView {
WordScale,
AttributeScale,
#[default]
ByWord,
ByAttribute,
}
impl From<ProximityPrecision> for ProximityPrecisionView {
fn from(value: ProximityPrecision) -> Self {
match value {
ProximityPrecision::WordScale => ProximityPrecisionView::WordScale,
ProximityPrecision::AttributeScale => ProximityPrecisionView::AttributeScale,
ProximityPrecision::ByWord => ProximityPrecisionView::ByWord,
ProximityPrecision::ByAttribute => ProximityPrecisionView::ByAttribute,
}
}
}
impl From<ProximityPrecisionView> for ProximityPrecision {
fn from(value: ProximityPrecisionView) -> Self {
match value {
ProximityPrecisionView::WordScale => ProximityPrecision::WordScale,
ProximityPrecisionView::AttributeScale => ProximityPrecision::AttributeScale,
ProximityPrecisionView::ByWord => ProximityPrecision::ByWord,
ProximityPrecisionView::ByAttribute => ProximityPrecision::ByAttribute,
}
}
}
@ -747,6 +782,7 @@ pub(crate) mod test {
typo_tolerance: Setting::NotSet,
faceting: Setting::NotSet,
pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: PhantomData::<Unchecked>,
};
@ -772,6 +808,7 @@ pub(crate) mod test {
typo_tolerance: Setting::NotSet,
faceting: Setting::NotSet,
pagination: Setting::NotSet,
embedders: Setting::NotSet,
_kind: PhantomData::<Unchecked>,
};

View File

@ -0,0 +1,139 @@
use serde::Serialize;
use time::{Duration, OffsetDateTime};
use crate::error::ResponseError;
use crate::settings::{Settings, Unchecked};
use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId};
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskView {
pub uid: TaskId,
#[serde(default)]
pub index_uid: Option<String>,
pub status: Status,
#[serde(rename = "type")]
pub kind: Kind,
pub canceled_by: Option<TaskId>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<DetailsView>,
pub error: Option<ResponseError>,
#[serde(serialize_with = "serialize_duration", default)]
pub duration: Option<Duration>,
#[serde(with = "time::serde::rfc3339")]
pub enqueued_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option", default)]
pub started_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option", default)]
pub finished_at: Option<OffsetDateTime>,
}
impl TaskView {
pub fn from_task(task: &Task) -> TaskView {
TaskView {
uid: task.uid,
index_uid: task.index_uid().map(ToOwned::to_owned),
status: task.status,
kind: task.kind.as_kind(),
canceled_by: task.canceled_by,
details: task.details.clone().map(DetailsView::from),
error: task.error.clone(),
duration: task.started_at.zip(task.finished_at).map(|(start, end)| end - start),
enqueued_at: task.enqueued_at,
started_at: task.started_at,
finished_at: task.finished_at,
}
}
}
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DetailsView {
#[serde(skip_serializing_if = "Option::is_none")]
pub received_documents: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub indexed_documents: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub primary_key: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub provided_ids: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_documents: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_tasks: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub canceled_tasks: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_tasks: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub original_filter: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dump_uid: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(flatten)]
pub settings: Option<Box<Settings<Unchecked>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub swaps: Option<Vec<IndexSwap>>,
}
impl From<Details> for DetailsView {
fn from(details: Details) -> Self {
match details {
Details::DocumentAdditionOrUpdate { received_documents, indexed_documents } => {
DetailsView {
received_documents: Some(received_documents),
indexed_documents: Some(indexed_documents),
..DetailsView::default()
}
}
Details::SettingsUpdate { settings } => {
DetailsView { settings: Some(settings), ..DetailsView::default() }
}
Details::IndexInfo { primary_key } => {
DetailsView { primary_key: Some(primary_key), ..DetailsView::default() }
}
Details::DocumentDeletion {
provided_ids: received_document_ids,
deleted_documents,
} => DetailsView {
provided_ids: Some(received_document_ids),
deleted_documents: Some(deleted_documents),
original_filter: Some(None),
..DetailsView::default()
},
Details::DocumentDeletionByFilter { original_filter, deleted_documents } => {
DetailsView {
provided_ids: Some(0),
original_filter: Some(Some(original_filter)),
deleted_documents: Some(deleted_documents),
..DetailsView::default()
}
}
Details::ClearAll { deleted_documents } => {
DetailsView { deleted_documents: Some(deleted_documents), ..DetailsView::default() }
}
Details::TaskCancelation { matched_tasks, canceled_tasks, original_filter } => {
DetailsView {
matched_tasks: Some(matched_tasks),
canceled_tasks: Some(canceled_tasks),
original_filter: Some(Some(original_filter)),
..DetailsView::default()
}
}
Details::TaskDeletion { matched_tasks, deleted_tasks, original_filter } => {
DetailsView {
matched_tasks: Some(matched_tasks),
deleted_tasks: Some(deleted_tasks),
original_filter: Some(Some(original_filter)),
..DetailsView::default()
}
}
Details::Dump { dump_uid } => {
DetailsView { dump_uid: Some(dump_uid), ..DetailsView::default() }
}
Details::IndexSwap { swaps } => {
DetailsView { swaps: Some(swaps), ..Default::default() }
}
}
}
}

View File

@ -104,6 +104,7 @@ walkdir = "2.3.3"
yaup = "0.2.1"
serde_urlencoded = "0.7.1"
termcolor = "1.2.0"
url = { version = "2.5.0", features = ["serde"] }
[dev-dependencies]
actix-rt = "2.8.0"
@ -153,5 +154,5 @@ greek = ["meilisearch-types/greek"]
khmer = ["meilisearch-types/khmer"]
[package.metadata.mini-dashboard]
assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.11/build.zip"
sha1 = "83cd44ed1e5f97ecb581dc9f958a63f4ccc982d9"
assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.12/build.zip"
sha1 = "acfe9a018c93eb0604ea87ee87bff7df5474e18e"

View File

@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats};
use crate::search::{
FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO,
};
use crate::Opt;
@ -251,6 +251,7 @@ struct Infos {
env: String,
experimental_enable_metrics: bool,
experimental_reduce_indexing_memory_usage: bool,
experimental_max_number_of_batched_tasks: usize,
db_path: bool,
import_dump: bool,
dump_dir: bool,
@ -263,7 +264,8 @@ struct Infos {
ignore_snapshot_if_db_exists: bool,
http_addr: bool,
http_payload_size_limit: Byte,
max_number_of_batched_tasks: usize,
task_queue_webhook: bool,
task_webhook_authorization_header: bool,
log_level: String,
max_indexing_memory: MaxMemory,
max_indexing_threads: MaxThreads,
@ -286,13 +288,15 @@ impl From<Opt> for Infos {
db_path,
experimental_enable_metrics,
experimental_reduce_indexing_memory_usage,
experimental_max_number_of_batched_tasks,
http_addr,
master_key: _,
env,
task_webhook_url,
task_webhook_authorization_header,
max_index_size: _,
max_task_db_size: _,
http_payload_size_limit,
max_number_of_batched_tasks,
ssl_cert_path,
ssl_key_path,
ssl_auth_path,
@ -342,7 +346,9 @@ impl From<Opt> for Infos {
ignore_snapshot_if_db_exists,
http_addr: http_addr != default_http_addr(),
http_payload_size_limit,
max_number_of_batched_tasks,
experimental_max_number_of_batched_tasks,
task_queue_webhook: task_webhook_url.is_some(),
task_webhook_authorization_header: task_webhook_authorization_header.is_some(),
log_level: log_level.to_string(),
max_indexing_memory,
max_indexing_threads,
@ -586,6 +592,11 @@ pub struct SearchAggregator {
// vector
// The maximum number of floats in a vector request
max_vector_size: usize,
// Whether the semantic ratio passed to a hybrid search equals the default ratio.
semantic_ratio: bool,
// Whether a non-default embedder was specified
embedder: bool,
hybrid: bool,
// every time a search is done, we increment the counter linked to the used settings
matching_strategy: HashMap<String, usize>,
@ -639,6 +650,7 @@ impl SearchAggregator {
crop_marker,
matching_strategy,
attributes_to_search_on,
hybrid,
} = query;
let mut ret = Self::default();
@ -712,6 +724,12 @@ impl SearchAggregator {
ret.show_ranking_score = *show_ranking_score;
ret.show_ranking_score_details = *show_ranking_score_details;
if let Some(hybrid) = hybrid {
ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO();
ret.embedder = hybrid.embedder.is_some();
ret.hybrid = true;
}
ret
}
@ -765,6 +783,9 @@ impl SearchAggregator {
facets_total_number_of_facets,
show_ranking_score,
show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = other;
if self.timestamp.is_none() {
@ -810,6 +831,9 @@ impl SearchAggregator {
// vector
self.max_vector_size = self.max_vector_size.max(max_vector_size);
self.semantic_ratio |= semantic_ratio;
self.hybrid |= hybrid;
self.embedder |= embedder;
// pagination
self.max_limit = self.max_limit.max(max_limit);
@ -878,6 +902,9 @@ impl SearchAggregator {
facets_total_number_of_facets,
show_ranking_score,
show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = self;
if total_received == 0 {
@ -917,6 +944,11 @@ impl SearchAggregator {
"vector": {
"max_vector_size": max_vector_size,
},
"hybrid": {
"enabled": hybrid,
"semantic_ratio": semantic_ratio,
"embedder": embedder,
},
"pagination": {
"max_limit": max_limit,
"max_offset": max_offset,
@ -1012,6 +1044,7 @@ impl MultiSearchAggregator {
crop_marker: _,
matching_strategy: _,
attributes_to_search_on: _,
hybrid: _,
} = query;
index_uid.as_str()
@ -1158,6 +1191,7 @@ impl FacetSearchAggregator {
filter,
matching_strategy,
attributes_to_search_on,
hybrid,
} = query;
let mut ret = Self::default();
@ -1171,7 +1205,8 @@ impl FacetSearchAggregator {
|| vector.is_some()
|| filter.is_some()
|| *matching_strategy != MatchingStrategy::default()
|| attributes_to_search_on.is_some();
|| attributes_to_search_on.is_some()
|| hybrid.is_some();
ret
}

View File

@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
DocumentFormat(#[from] DocumentFormatError),
#[error(transparent)]
Join(#[from] JoinError),
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
MissingSearchHybrid,
}
impl ErrorCode for MeilisearchHttpError {
@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::FileStore(_) => Code::Internal,
MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
MeilisearchHttpError::Join(_) => Code::Internal,
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
}
}
}

View File

@ -228,13 +228,15 @@ fn open_or_create_database_unchecked(
indexes_path: opt.db_path.join("indexes"),
snapshots_path: opt.snapshot_dir.clone(),
dumps_path: opt.dump_dir.clone(),
webhook_url: opt.task_webhook_url.as_ref().map(|url| url.to_string()),
webhook_authorization_header: opt.task_webhook_authorization_header.clone(),
task_db_size: opt.max_task_db_size.get_bytes() as usize,
index_base_map_size: opt.max_index_size.get_bytes() as usize,
enable_mdb_writemap: opt.experimental_reduce_indexing_memory_usage,
indexer_config: (&opt.indexer_options).try_into()?,
autobatching_enabled: true,
max_number_of_tasks: 1_000_000,
max_number_of_batched_tasks: opt.max_number_of_batched_tasks,
max_number_of_batched_tasks: opt.experimental_max_number_of_batched_tasks,
index_growth_amount: byte_unit::Byte::from_str("10GiB").unwrap().get_bytes() as usize,
index_count: DEFAULT_INDEX_COUNT,
instance_features,

View File

@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// does all the setup before meilisearch is launched
fn setup(opt: &Opt) -> anyhow::Result<()> {
let mut log_builder = env_logger::Builder::new();
log_builder.parse_filters(&opt.log_level.to_string());
let log_filters = format!(
"{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn",
opt.log_level
);
log_builder.parse_filters(&log_filters);
log_builder.init();

View File

@ -21,6 +21,7 @@ use rustls::RootCertStore;
use rustls_pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use serde::{Deserialize, Serialize};
use sysinfo::{RefreshKind, System, SystemExt};
use url::Url;
const POSSIBLE_ENV: [&str; 2] = ["development", "production"];
@ -28,9 +29,10 @@ const MEILI_DB_PATH: &str = "MEILI_DB_PATH";
const MEILI_HTTP_ADDR: &str = "MEILI_HTTP_ADDR";
const MEILI_MASTER_KEY: &str = "MEILI_MASTER_KEY";
const MEILI_ENV: &str = "MEILI_ENV";
const MEILI_TASK_WEBHOOK_URL: &str = "MEILI_TASK_WEBHOOK_URL";
const MEILI_TASK_WEBHOOK_AUTHORIZATION_HEADER: &str = "MEILI_TASK_WEBHOOK_AUTHORIZATION_HEADER";
#[cfg(feature = "analytics")]
const MEILI_NO_ANALYTICS: &str = "MEILI_NO_ANALYTICS";
const MEILI_MAX_NUMBER_OF_BATCHED_TASKS: &str = "MEILI_MAX_NUMBER_OF_BATCHED_TASKS";
const MEILI_HTTP_PAYLOAD_SIZE_LIMIT: &str = "MEILI_HTTP_PAYLOAD_SIZE_LIMIT";
const MEILI_SSL_CERT_PATH: &str = "MEILI_SSL_CERT_PATH";
const MEILI_SSL_KEY_PATH: &str = "MEILI_SSL_KEY_PATH";
@ -52,6 +54,8 @@ const MEILI_LOG_LEVEL: &str = "MEILI_LOG_LEVEL";
const MEILI_EXPERIMENTAL_ENABLE_METRICS: &str = "MEILI_EXPERIMENTAL_ENABLE_METRICS";
const MEILI_EXPERIMENTAL_REDUCE_INDEXING_MEMORY_USAGE: &str =
"MEILI_EXPERIMENTAL_REDUCE_INDEXING_MEMORY_USAGE";
const MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS: &str =
"MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS";
const DEFAULT_CONFIG_FILE_PATH: &str = "./config.toml";
const DEFAULT_DB_PATH: &str = "./data.ms";
@ -155,6 +159,14 @@ pub struct Opt {
#[serde(default = "default_env")]
pub env: String,
/// Called whenever a task finishes so a third party can be notified.
#[clap(long, env = MEILI_TASK_WEBHOOK_URL)]
pub task_webhook_url: Option<Url>,
/// The Authorization header to send on the webhook URL whenever a task finishes so a third party can be notified.
#[clap(long, env = MEILI_TASK_WEBHOOK_AUTHORIZATION_HEADER)]
pub task_webhook_authorization_header: Option<String>,
/// Deactivates Meilisearch's built-in telemetry when provided.
///
/// Meilisearch automatically collects data from all instances that do not opt out using this flag.
@ -176,11 +188,6 @@ pub struct Opt {
#[serde(skip, default = "default_max_task_db_size")]
pub max_task_db_size: Byte,
/// Defines the maximum number of tasks that will be processed at once.
#[clap(long, env = MEILI_MAX_NUMBER_OF_BATCHED_TASKS, default_value_t = default_limit_batched_tasks())]
#[serde(default = "default_limit_batched_tasks")]
pub max_number_of_batched_tasks: usize,
/// Sets the maximum size of accepted payloads. Value must be given in bytes or explicitly stating a
/// base unit (for instance: 107374182400, '107.7Gb', or '107374 Mb').
#[clap(long, env = MEILI_HTTP_PAYLOAD_SIZE_LIMIT, default_value_t = default_http_payload_size_limit())]
@ -307,6 +314,11 @@ pub struct Opt {
#[serde(default)]
pub experimental_reduce_indexing_memory_usage: bool,
/// Experimentally reduces the maximum number of tasks that will be processed at once, see: <https://github.com/orgs/meilisearch/discussions/713>
#[clap(long, env = MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS, default_value_t = default_limit_batched_tasks())]
#[serde(default = "default_limit_batched_tasks")]
pub experimental_max_number_of_batched_tasks: usize,
#[serde(flatten)]
#[clap(flatten)]
pub indexer_options: IndexerOpts,
@ -374,10 +386,12 @@ impl Opt {
http_addr,
master_key,
env,
task_webhook_url,
task_webhook_authorization_header,
max_index_size: _,
max_task_db_size: _,
http_payload_size_limit,
max_number_of_batched_tasks,
experimental_max_number_of_batched_tasks,
ssl_cert_path,
ssl_key_path,
ssl_auth_path,
@ -408,6 +422,16 @@ impl Opt {
export_to_env_if_not_present(MEILI_MASTER_KEY, master_key);
}
export_to_env_if_not_present(MEILI_ENV, env);
if let Some(task_webhook_url) = task_webhook_url {
export_to_env_if_not_present(MEILI_TASK_WEBHOOK_URL, task_webhook_url.to_string());
}
if let Some(task_webhook_authorization_header) = task_webhook_authorization_header {
export_to_env_if_not_present(
MEILI_TASK_WEBHOOK_AUTHORIZATION_HEADER,
task_webhook_authorization_header,
);
}
#[cfg(feature = "analytics")]
{
export_to_env_if_not_present(MEILI_NO_ANALYTICS, no_analytics.to_string());
@ -417,8 +441,8 @@ impl Opt {
http_payload_size_limit.to_string(),
);
export_to_env_if_not_present(
MEILI_MAX_NUMBER_OF_BATCHED_TASKS,
max_number_of_batched_tasks.to_string(),
MEILI_EXPERIMENTAL_MAX_NUMBER_OF_BATCHED_TASKS,
experimental_max_number_of_batched_tasks.to_string(),
);
if let Some(ssl_cert_path) = ssl_cert_path {
export_to_env_if_not_present(MEILI_SSL_CERT_PATH, ssl_cert_path);

View File

@ -48,8 +48,6 @@ pub struct RuntimeTogglableFeatures {
pub metrics: Option<bool>,
#[deserr(default)]
pub export_puffin_reports: Option<bool>,
#[deserr(default)]
pub proximity_precision: Option<bool>,
}
async fn patch_features(
@ -72,10 +70,6 @@ async fn patch_features(
.0
.export_puffin_reports
.unwrap_or(old_features.export_puffin_reports),
proximity_precision: new_features
.0
.proximity_precision
.unwrap_or(old_features.proximity_precision),
};
// explicitly destructure for analytics rather than using the `Serialize` implementation, because
@ -86,7 +80,6 @@ async fn patch_features(
vector_store,
metrics,
export_puffin_reports,
proximity_precision,
} = new_features;
analytics.publish(
@ -96,7 +89,6 @@ async fn patch_features(
"vector_store": vector_store,
"metrics": metrics,
"export_puffin_reports": export_puffin_reports,
"proximity_precision": proximity_precision,
}),
Some(&req),
);

View File

@ -13,9 +13,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator};
use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData;
use crate::search::{
add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
};
pub fn configure(cfg: &mut web::ServiceConfig) {
@ -36,6 +36,8 @@ pub struct FacetSearchQuery {
pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
pub filter: Option<Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)]
@ -95,6 +97,7 @@ impl From<FacetSearchQuery> for SearchQuery {
filter,
matching_strategy,
attributes_to_search_on,
hybrid,
} = value;
SearchQuery {
@ -119,6 +122,7 @@ impl From<FacetSearchQuery> for SearchQuery {
matching_strategy,
vector,
attributes_to_search_on,
hybrid,
}
}
}

View File

@ -2,12 +2,14 @@ use actix_web::web::Data;
use actix_web::{web, HttpRequest, HttpResponse};
use deserr::actix_web::{AwebJson, AwebQueryParameter};
use index_scheduler::IndexScheduler;
use log::debug;
use log::{debug, warn};
use meilisearch_types::deserr::query_params::Param;
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli;
use meilisearch_types::milli::vector::DistributionShift;
use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value;
@ -16,9 +18,9 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{
add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
};
pub fn configure(cfg: &mut web::ServiceConfig) {
@ -74,6 +76,31 @@ pub struct SearchQueryGet {
matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
pub attributes_to_search_on: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)]
pub hybrid_embedder: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)]
pub hybrid_semantic_ratio: Option<SemanticRatioGet>,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)]
#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatioGet(SemanticRatio);
impl std::convert::TryFrom<String> for SemanticRatioGet {
type Error = InvalidSearchSemanticRatio;
fn try_from(s: String) -> Result<Self, Self::Error> {
let f: f32 = s.parse().map_err(|_| InvalidSearchSemanticRatio)?;
Ok(SemanticRatioGet(SemanticRatio::try_from(f)?))
}
}
impl std::ops::Deref for SemanticRatioGet {
type Target = SemanticRatio;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<SearchQueryGet> for SearchQuery {
@ -86,6 +113,20 @@ impl From<SearchQueryGet> for SearchQuery {
None => None,
};
let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) {
(None, None) => None,
(None, Some(semantic_ratio)) => {
Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: None })
}
(Some(embedder), None) => Some(HybridQuery {
semantic_ratio: DEFAULT_SEMANTIC_RATIO(),
embedder: Some(embedder),
}),
(Some(embedder), Some(semantic_ratio)) => {
Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: Some(embedder) })
}
};
Self {
q: other.q,
vector: other.vector.map(CS::into_inner),
@ -108,6 +149,7 @@ impl From<SearchQueryGet> for SearchQuery {
crop_marker: other.crop_marker,
matching_strategy: other.matching_strategy,
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
hybrid,
}
}
}
@ -158,8 +200,12 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
.await?;
if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result);
}
@ -193,8 +239,12 @@ pub async fn search_with_post(
let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
.await?;
if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result);
}
@ -206,6 +256,80 @@ pub async fn search_with_post(
Ok(HttpResponse::Ok().json(search_result))
}
pub async fn embed(
query: &mut SearchQuery,
index_scheduler: &IndexScheduler,
index: &milli::Index,
) -> Result<Option<DistributionShift>, ResponseError> {
match (&query.hybrid, &query.vector, &query.q) {
(Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q))
if !q.trim().is_empty() =>
{
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder = if let Some(embedder_name) = embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
let distribution = embedder.distribution();
let embeddings = embedder
.embed(vec![q.to_owned()])
.await
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
.expect("No vector returned from embedding");
if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query");
query.vector = Some(embeddings.iter().next().unwrap().to_vec());
} else {
query.vector = Some(embeddings.into_inner());
}
Ok(distribution)
}
(Some(hybrid), vector, _) => {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder = if let Some(embedder_name) = &hybrid.embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
if let Some(vector) = vector {
if vector.len() != embedder.dimensions() {
return Err(meilisearch_types::milli::Error::UserError(
meilisearch_types::milli::UserError::InvalidVectorDimensions {
expected: embedder.dimensions(),
found: vector.len(),
},
)
.into());
}
}
Ok(embedder.distribution())
}
_ => Ok(None),
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::ResponseError;
use meilisearch_types::facet_values_sort::FacetValuesSort;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked};
use meilisearch_types::tasks::KindWithContent;
use serde_json::json;
@ -89,6 +90,11 @@ macro_rules! make_setting_route {
..Default::default()
};
let new_settings = $crate::routes::indexes::settings::validate_settings(
new_settings,
&index_scheduler,
)?;
let allow_index_creation =
index_scheduler.filters().allow_index_creation(&index_uid);
@ -452,6 +458,7 @@ make_setting_route!(
json!({
"proximity_precision": {
"set": precision.is_some(),
"value": precision.unwrap_or_default(),
}
}),
Some(req),
@ -545,6 +552,67 @@ make_setting_route!(
}
);
make_setting_route!(
"/embedders",
patch,
std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>,
meilisearch_types::deserr::DeserrJsonError<
meilisearch_types::error::deserr_codes::InvalidSettingsEmbedders,
>,
embedders,
"embedders",
analytics,
|setting: &Option<std::collections::BTreeMap<String, Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>>>, req: &HttpRequest| {
analytics.publish(
"Embedders Updated".to_string(),
serde_json::json!({"embedders": crate::routes::indexes::settings::embedder_analytics(setting.as_ref())}),
Some(req),
);
}
);
fn embedder_analytics(
setting: Option<
&std::collections::BTreeMap<
String,
Setting<meilisearch_types::milli::vector::settings::EmbeddingSettings>,
>,
>,
) -> serde_json::Value {
let mut sources = std::collections::HashSet::new();
if let Some(s) = &setting {
for source in s
.values()
.filter_map(|config| config.clone().set())
.filter_map(|config| config.source.set())
{
use meilisearch_types::milli::vector::settings::EmbedderSource;
match source {
EmbedderSource::OpenAi => sources.insert("openAi"),
EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSource::UserProvided => sources.insert("userProvided"),
};
}
};
let document_template_used = setting.as_ref().map(|map| {
map.values()
.filter_map(|config| config.clone().set())
.any(|config| config.document_template.set().is_some())
});
json!(
{
"total": setting.as_ref().map(|s| s.len()),
"sources": sources,
"document_template_used": document_template_used,
}
)
}
macro_rules! generate_configure {
($($mod:ident),*) => {
pub fn configure(cfg: &mut web::ServiceConfig) {
@ -574,7 +642,8 @@ generate_configure!(
ranking_rules,
typo_tolerance,
pagination,
faceting
faceting,
embedders
);
pub async fn update_all(
@ -587,6 +656,7 @@ pub async fn update_all(
let index_uid = IndexUid::try_from(index_uid.into_inner())?;
let new_settings = body.into_inner();
let new_settings = validate_settings(new_settings, &index_scheduler)?;
analytics.publish(
"Settings Updated".to_string(),
@ -620,7 +690,8 @@ pub async fn update_all(
"set": new_settings.distinct_attribute.as_ref().set().is_some()
},
"proximity_precision": {
"set": new_settings.proximity_precision.as_ref().set().is_some()
"set": new_settings.proximity_precision.as_ref().set().is_some(),
"value": new_settings.proximity_precision.as_ref().set().copied().unwrap_or_default()
},
"typo_tolerance": {
"enabled": new_settings.typo_tolerance
@ -681,6 +752,7 @@ pub async fn update_all(
"synonyms": {
"total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()),
},
"embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set())
}),
Some(&req),
);
@ -735,3 +807,13 @@ pub async fn delete_all(
debug!("returns: {:?}", task);
Ok(HttpResponse::Accepted().json(task))
}
fn validate_settings(
settings: Settings<Unchecked>,
index_scheduler: &IndexScheduler,
) -> Result<Settings<Unchecked>, ResponseError> {
if matches!(settings.embedders, Setting::Set(_)) {
index_scheduler.features().check_vector("Passing `embedders` in settings")?
}
Ok(settings.validate()?)
}

View File

@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator};
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{AuthenticationError, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::routes::indexes::search::embed;
use crate::search::{
add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
};
@ -74,10 +75,15 @@ pub async fn multi_search_with_post(
})
.with_index(query_index)?;
let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features))
.await
.with_index(query_index)?;
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)
.await
.with_index(query_index)?;
let search_result = tokio::task::spawn_blocking(move || {
perform_search(&index, query, features, distribution)
})
.await
.with_index(query_index)?;
search_results.push(SearchResultWithIndex {
index_uid: index_uid.into_inner(),

View File

@ -8,11 +8,9 @@ use meilisearch_types::deserr::DeserrQueryParamError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::{InvalidTaskDateError, ResponseError};
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::settings::{Settings, Unchecked};
use meilisearch_types::star_or::{OptionStarOr, OptionStarOrList};
use meilisearch_types::tasks::{
serialize_duration, Details, IndexSwap, Kind, KindWithContent, Status, Task,
};
use meilisearch_types::task_view::TaskView;
use meilisearch_types::tasks::{Kind, KindWithContent, Status};
use serde::Serialize;
use serde_json::json;
use time::format_description::well_known::Rfc3339;
@ -37,140 +35,6 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
.service(web::resource("/cancel").route(web::post().to(SeqHandler(cancel_tasks))))
.service(web::resource("/{task_id}").route(web::get().to(SeqHandler(get_task))));
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskView {
pub uid: TaskId,
#[serde(default)]
pub index_uid: Option<String>,
pub status: Status,
#[serde(rename = "type")]
pub kind: Kind,
pub canceled_by: Option<TaskId>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<DetailsView>,
pub error: Option<ResponseError>,
#[serde(serialize_with = "serialize_duration", default)]
pub duration: Option<Duration>,
#[serde(with = "time::serde::rfc3339")]
pub enqueued_at: OffsetDateTime,
#[serde(with = "time::serde::rfc3339::option", default)]
pub started_at: Option<OffsetDateTime>,
#[serde(with = "time::serde::rfc3339::option", default)]
pub finished_at: Option<OffsetDateTime>,
}
impl TaskView {
pub fn from_task(task: &Task) -> TaskView {
TaskView {
uid: task.uid,
index_uid: task.index_uid().map(ToOwned::to_owned),
status: task.status,
kind: task.kind.as_kind(),
canceled_by: task.canceled_by,
details: task.details.clone().map(DetailsView::from),
error: task.error.clone(),
duration: task.started_at.zip(task.finished_at).map(|(start, end)| end - start),
enqueued_at: task.enqueued_at,
started_at: task.started_at,
finished_at: task.finished_at,
}
}
}
#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct DetailsView {
#[serde(skip_serializing_if = "Option::is_none")]
pub received_documents: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub indexed_documents: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub primary_key: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub provided_ids: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_documents: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_tasks: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub canceled_tasks: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub deleted_tasks: Option<Option<u64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub original_filter: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dump_uid: Option<Option<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(flatten)]
pub settings: Option<Box<Settings<Unchecked>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub swaps: Option<Vec<IndexSwap>>,
}
impl From<Details> for DetailsView {
fn from(details: Details) -> Self {
match details {
Details::DocumentAdditionOrUpdate { received_documents, indexed_documents } => {
DetailsView {
received_documents: Some(received_documents),
indexed_documents: Some(indexed_documents),
..DetailsView::default()
}
}
Details::SettingsUpdate { settings } => {
DetailsView { settings: Some(settings), ..DetailsView::default() }
}
Details::IndexInfo { primary_key } => {
DetailsView { primary_key: Some(primary_key), ..DetailsView::default() }
}
Details::DocumentDeletion {
provided_ids: received_document_ids,
deleted_documents,
} => DetailsView {
provided_ids: Some(received_document_ids),
deleted_documents: Some(deleted_documents),
original_filter: Some(None),
..DetailsView::default()
},
Details::DocumentDeletionByFilter { original_filter, deleted_documents } => {
DetailsView {
provided_ids: Some(0),
original_filter: Some(Some(original_filter)),
deleted_documents: Some(deleted_documents),
..DetailsView::default()
}
}
Details::ClearAll { deleted_documents } => {
DetailsView { deleted_documents: Some(deleted_documents), ..DetailsView::default() }
}
Details::TaskCancelation { matched_tasks, canceled_tasks, original_filter } => {
DetailsView {
matched_tasks: Some(matched_tasks),
canceled_tasks: Some(canceled_tasks),
original_filter: Some(Some(original_filter)),
..DetailsView::default()
}
}
Details::TaskDeletion { matched_tasks, deleted_tasks, original_filter } => {
DetailsView {
matched_tasks: Some(matched_tasks),
deleted_tasks: Some(deleted_tasks),
original_filter: Some(Some(original_filter)),
..DetailsView::default()
}
}
Details::Dump { dump_uid } => {
DetailsView { dump_uid: Some(dump_uid), ..DetailsView::default() }
}
Details::IndexSwap { swaps } => {
DetailsView { swaps: Some(swaps), ..Default::default() }
}
}
}
}
#[derive(Debug, Deserr)]
#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)]
pub struct TasksFilterQuery {

View File

@ -7,24 +7,21 @@ use deserr::Deserr;
use either::Either;
use index_scheduler::RoFeatures;
use indexmap::IndexMap;
use log::warn;
use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::{
dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues,
};
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::vector::DistributionShift;
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder;
use milli::{
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET,
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
};
use ordered_float::OrderedFloat;
use regex::Regex;
use serde::Serialize;
use serde_json::{json, Value};
@ -39,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10;
pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string();
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5);
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
@ -47,6 +45,8 @@ pub struct SearchQuery {
pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@ -87,6 +87,48 @@ pub struct SearchQuery {
pub attributes_to_search_on: Option<Vec<String>>,
}
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
/// TODO validate that sementic ratio is between 0.0 and 1,0
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
pub semantic_ratio: SemanticRatio,
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
pub embedder: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
impl Default for SemanticRatio {
fn default() -> Self {
DEFAULT_SEMANTIC_RATIO()
}
}
impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
Ok(SemanticRatio(f))
}
}
}
impl std::ops::Deref for SemanticRatio {
type Target = f32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl SearchQuery {
pub fn is_finite_pagination(&self) -> bool {
self.page.or(self.hits_per_page).is_some()
@ -106,6 +148,8 @@ pub struct SearchQueryWithIndex {
pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@ -171,6 +215,7 @@ impl SearchQueryWithIndex {
crop_marker,
matching_strategy,
attributes_to_search_on,
hybrid,
} = self;
(
index_uid,
@ -196,6 +241,7 @@ impl SearchQueryWithIndex {
crop_marker,
matching_strategy,
attributes_to_search_on,
hybrid,
// do not use ..Default::default() here,
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
},
@ -335,19 +381,44 @@ fn prepare_search<'t>(
rtxn: &'t RoTxn,
query: &'t SearchQuery,
features: RoFeatures,
distribution: Option<DistributionShift>,
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn);
if query.vector.is_some() && query.q.is_some() {
warn!("Ignoring the query string `q` when used with the `vector` parameter.");
if query.vector.is_some() {
features.check_vector("Passing `vector` as a query parameter")?;
}
if query.hybrid.is_some() {
features.check_vector("Passing `hybrid` as a query parameter")?;
}
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
return Err(MeilisearchHttpError::MissingSearchHybrid);
}
search.distribution_shift(distribution);
if let Some(ref vector) = query.vector {
search.vector(vector.clone());
match &query.hybrid {
// If semantic ratio is 0.0, only the query search will impact the search results,
// skip the vector
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
_otherwise => {
search.vector(vector.clone());
}
}
}
if let Some(ref query) = query.q {
search.query(query);
if let Some(ref q) = query.q {
match &query.hybrid {
// If semantic ratio is 1.0, only the vector search will impact the search results,
// skip the query
Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (),
_otherwise => {
search.query(q);
}
}
}
if let Some(ref searchable) = query.attributes_to_search_on {
@ -374,8 +445,8 @@ fn prepare_search<'t>(
features.check_score_details()?;
}
if query.vector.is_some() {
features.check_vector()?;
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
search.embedder_name(embedder);
}
// compute the offset on the limit depending on the pagination mode.
@ -421,15 +492,22 @@ pub fn perform_search(
index: &Index,
query: SearchQuery,
features: RoFeatures,
distribution: Option<DistributionShift>,
) -> Result<SearchResult, MeilisearchHttpError> {
let before_search = Instant::now();
let rtxn = index.read_txn()?;
let (search, is_finite_pagination, max_total_hits, offset) =
prepare_search(index, &rtxn, &query, features)?;
prepare_search(index, &rtxn, &query, features, distribution)?;
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
search.execute()?;
match &query.hybrid {
Some(hybrid) => match *hybrid.semantic_ratio {
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
ratio => search.execute_hybrid(ratio)?,
},
None => search.execute()?,
};
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
@ -538,13 +616,17 @@ pub fn perform_search(
insert_geo_distance(sort, &mut document);
}
let semantic_score = match query.vector.as_ref() {
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
Some(vectors) => compute_semantic_score(vector, vectors)?,
None => None,
},
None => None,
};
let mut semantic_score = None;
for details in &score {
if let ScoreDetails::Vector(score_details::Vector {
target_vector: _,
value_similarity: Some((_matching_vector, similarity)),
}) = details
{
semantic_score = Some(*similarity);
break;
}
}
let ranking_score =
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
@ -647,11 +729,15 @@ pub fn perform_facet_search(
let before_search = Instant::now();
let rtxn = index.read_txn()?;
let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?;
let mut facet_search = SearchForFacetValues::new(facet_name, search);
let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?;
let mut facet_search =
SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
if let Some(facet_query) = &facet_query {
facet_search.query(facet_query);
}
if let Some(max_facets) = index.max_values_per_facet(&rtxn)? {
facet_search.max_values(max_facets as usize);
}
Ok(FacetSearchResult {
facet_hits: facet_search.execute()?,
@ -676,18 +762,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
}
}
fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
let vectors = serde_json::from_value(vectors)
.map(VectorOrArrayOfVectors::into_array_of_vectors)
.map_err(InternalError::SerdeJson)?;
Ok(vectors
.into_iter()
.flatten()
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
.max()
.map(OrderedFloat::into_inner))
}
fn compute_formatted_options(
attr_to_highlight: &HashSet<String>,
attr_to_crop: &[String],
@ -815,22 +889,6 @@ fn make_document(
Ok(document)
}
/// Extract the JSON value under the field name specified
/// but doesn't support nested objects.
fn extract_field(
field_name: &str,
field_ids_map: &FieldsIdsMap,
obkv: obkv::KvReaderU16,
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
match field_ids_map.id(field_name) {
Some(fid) => match obkv.get(fid) {
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
None => Ok(None),
},
None => Ok(None),
}
}
fn format_fields<'a>(
document: &Document,
field_ids_map: &FieldsIdsMap,
@ -842,6 +900,14 @@ fn format_fields<'a>(
let mut matches_position = compute_matches.then(BTreeMap::new);
let mut document = document.clone();
// reduce the formatted option list to the attributes that should be formatted,
// instead of all the attributes to display.
let formatting_fields_options: Vec<_> = formatted_options
.iter()
.filter(|(_, option)| option.should_format())
.map(|(fid, option)| (field_ids_map.name(*fid).unwrap(), option))
.collect();
// select the attributes to retrieve
let displayable_names =
displayable_ids.iter().map(|&fid| field_ids_map.name(fid).expect("Missing field name"));
@ -850,13 +916,15 @@ fn format_fields<'a>(
// to the value and merge them together. eg. If a user said he wanted to highlight `doggo`
// and crop `doggo.name`. `doggo.name` needs to be highlighted + cropped while `doggo.age` is only
// highlighted.
let format = formatted_options
// Warn: The time to compute the format list scales with the number of fields to format;
// cumulated with map_leaf_values that iterates over all the nested fields, it gives a quadratic complexity:
// d*f where d is the total number of fields to display and f is the total number of fields to format.
let format = formatting_fields_options
.iter()
.filter(|(field, _option)| {
let name = field_ids_map.name(**field).unwrap();
.filter(|(name, _option)| {
milli::is_faceted_by(name, key) || milli::is_faceted_by(key, name)
})
.map(|(_, option)| *option)
.map(|(_, option)| **option)
.reduce(|acc, option| acc.merge(option));
let mut infos = Vec::new();
@ -953,7 +1021,7 @@ fn format_value<'a>(
let value = matcher.format(format_options);
Value::String(value.into_owned())
}
None => Value::Number(number),
None => Value::String(s),
}
}
value => value,

View File

@ -59,7 +59,7 @@ async fn import_dump_v1_movie_raw() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -220,7 +220,7 @@ async fn import_dump_v1_movie_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -367,7 +367,7 @@ async fn import_dump_v1_rubygems_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -500,7 +500,7 @@ async fn import_dump_v2_movie_raw() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -645,7 +645,7 @@ async fn import_dump_v2_movie_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -789,7 +789,7 @@ async fn import_dump_v2_rubygems_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -922,7 +922,7 @@ async fn import_dump_v3_movie_raw() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1067,7 +1067,7 @@ async fn import_dump_v3_movie_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1211,7 +1211,7 @@ async fn import_dump_v3_rubygems_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1344,7 +1344,7 @@ async fn import_dump_v4_movie_raw() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1489,7 +1489,7 @@ async fn import_dump_v4_movie_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1633,7 +1633,7 @@ async fn import_dump_v4_rubygems_with_settings() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": null,
"proximityPrecision": "byWord",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {
@ -1848,8 +1848,7 @@ async fn import_dump_v6_containing_experimental_features() {
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -1878,7 +1877,7 @@ async fn import_dump_v6_containing_experimental_features() {
"dictionary": [],
"synonyms": {},
"distinctAttribute": null,
"proximityPrecision": "attributeScale",
"proximityPrecision": "byAttribute",
"typoTolerance": {
"enabled": true,
"minWordSizeForTypos": {

View File

@ -21,8 +21,7 @@ async fn experimental_features() {
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -34,8 +33,7 @@ async fn experimental_features() {
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -47,8 +45,7 @@ async fn experimental_features() {
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -61,8 +58,7 @@ async fn experimental_features() {
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -75,8 +71,7 @@ async fn experimental_features() {
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
}
@ -96,8 +91,7 @@ async fn experimental_feature_metrics() {
"scoreDetails": false,
"vectorStore": false,
"metrics": true,
"exportPuffinReports": false,
"proximityPrecision": false
"exportPuffinReports": false
}
"###);
@ -152,7 +146,7 @@ async fn errors() {
meili_snap::snapshot!(code, @"400 Bad Request");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"message": "Unknown field `NotAFeature`: expected one of `scoreDetails`, `vectorStore`, `metrics`, `exportPuffinReports`, `proximityPrecision`",
"message": "Unknown field `NotAFeature`: expected one of `scoreDetails`, `vectorStore`, `metrics`, `exportPuffinReports`",
"code": "bad_request",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#bad_request"

View File

@ -105,6 +105,24 @@ async fn more_advanced_facet_search() {
snapshot!(response["facetHits"].as_array().unwrap().len(), @"1");
}
#[actix_rt::test]
async fn simple_facet_search_with_max_values() {
let server = Server::new().await;
let index = server.index("test");
let documents = DOCUMENTS.clone();
index.update_settings_faceting(json!({ "maxValuesPerFacet": 1 })).await;
index.update_settings_filterable_attributes(json!(["genres"])).await;
index.add_documents(documents, None).await;
index.wait_task(2).await;
let (response, code) =
index.facet_search(json!({"facetName": "genres", "facetQuery": "a"})).await;
assert_eq!(code, 200, "{}", response);
assert_eq!(dbg!(response)["facetHits"].as_array().unwrap().len(), 1);
}
#[actix_rt::test]
async fn non_filterable_facet_search_error() {
let server = Server::new().await;

View File

@ -0,0 +1,175 @@
use meili_snap::snapshot;
use once_cell::sync::Lazy;
use crate::common::index::Index;
use crate::common::{Server, Value};
use crate::json;
async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Index<'a> {
let index = server.index("test");
let (response, code) = server.set_features(json!({"vectorStore": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": true,
"metrics": false,
"exportPuffinReports": false
}
"###);
let (response, code) = index
.update_settings(json!({ "embedders": {"default": {
"source": "userProvided",
"dimensions": 2}}} ))
.await;
assert_eq!(202, code, "{:?}", response);
index.wait_task(response.uid()).await;
let (response, code) = index.add_documents(documents.clone(), None).await;
assert_eq!(202, code, "{:?}", response);
index.wait_task(response.uid()).await;
index
}
static SIMPLE_SEARCH_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
json!([
{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {"default": [1.0, 3.0]},
},
{
"title": "Captain Planet",
"desc": "He's not part of the Marvel Cinematic Universe",
"id": "2",
"_vectors": {"default": [1.0, 2.0]},
},
{
"title": "Captain Marvel",
"desc": "a Shazam ersatz",
"id": "3",
"_vectors": {"default": [2.0, 3.0]},
}])
});
static SINGLE_DOCUMENT: Lazy<Value> = Lazy::new(|| {
json!([{
"title": "Shazam!",
"desc": "a Captain Marvel ersatz",
"id": "1",
"_vectors": {"default": [1.0, 3.0]},
}])
});
#[actix_rt::test]
async fn simple_search() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}),
)
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###);
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}),
)
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
}
#[actix_rt::test]
async fn invalid_semantic_ratio() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.2}}),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_post(
json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": -0.8}}),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_get(
&yaup::to_string(
&json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": 1.2}),
)
.unwrap(),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
let (response, code) = index
.search_get(
&yaup::to_string(
&json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": -0.2}),
)
.unwrap(),
)
.await;
snapshot!(code, @"400 Bad Request");
snapshot!(response, @r###"
{
"message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_semantic_ratio",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio"
}
"###);
}
#[actix_rt::test]
async fn single_document() {
let server = Server::new().await;
let index = index_with_documents(&server, &SINGLE_DOCUMENT).await;
let (response, code) = index
.search_post(
json!({"vector": [1.0, 3.0], "hybrid": {"semanticRatio": 1.0}, "showRankingScore": true}),
)
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"][0], @r###"{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.0,"_semanticScore":1.0}"###);
}

View File

@ -6,6 +6,7 @@ mod errors;
mod facet_search;
mod formatted;
mod geo;
mod hybrid;
mod multi;
mod pagination;
mod restrict_searchable;
@ -20,22 +21,27 @@ static DOCUMENTS: Lazy<Value> = Lazy::new(|| {
{
"title": "Shazam!",
"id": "287947",
"_vectors": { "manual": [1, 2, 3]},
},
{
"title": "Captain Marvel",
"id": "299537",
"_vectors": { "manual": [1, 2, 54] },
},
{
"title": "Escape Room",
"id": "522681",
"_vectors": { "manual": [10, -23, 32] },
},
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428",
"_vectors": { "manual": [-100, 231, 32] },
},
{
"title": "Gläss",
"id": "450465",
"_vectors": { "manual": [-100, 340, 90] },
}
])
});
@ -57,6 +63,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
},
],
"cattos": "pésti",
"_vectors": { "manual": [1, 2, 3]},
},
{
"id": 654,
@ -69,12 +76,14 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
},
],
"cattos": ["simba", "pestiféré"],
"_vectors": { "manual": [1, 2, 54] },
},
{
"id": 750,
"father": "romain",
"mother": "michelle",
"cattos": ["enigma"],
"_vectors": { "manual": [10, 23, 32] },
},
{
"id": 951,
@ -91,6 +100,7 @@ static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
},
],
"cattos": ["moumoute", "gomez"],
"_vectors": { "manual": [10, 23, 32] },
},
])
});
@ -802,6 +812,13 @@ async fn experimental_feature_score_details() {
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
},
"_rankingScoreDetails": {
"words": {
"order": 0,
@ -823,7 +840,7 @@ async fn experimental_feature_score_details() {
"order": 3,
"attributeRankingOrderScore": 1.0,
"queryWordDistanceScore": 0.8095238095238095,
"score": 0.9365079365079364
"score": 0.9727891156462584
},
"exactness": {
"order": 4,
@ -870,13 +887,100 @@ async fn experimental_feature_vector_store() {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(response["vectorStore"], @"true");
let (response, code) = index
.update_settings(json!({"embedders": {
"manual": {
"source": "userProvided",
"dimensions": 3,
}
}}))
.await;
meili_snap::snapshot!(response, @r###"
{
"taskUid": 1,
"indexUid": "test",
"status": "enqueued",
"type": "settingsUpdate",
"enqueuedAt": "[date]"
}
"###);
meili_snap::snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await;
meili_snap::snapshot!(meili_snap::json_string!(response["status"]), @"\"succeeded\"");
let (response, code) = index
.search_post(json!({
"vector": [1.0, 2.0, 3.0],
}))
.await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]");
// vector search returns all documents that don't have vectors in the last bucket, like all sorts
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Shazam!",
"id": "287947",
"_vectors": {
"manual": [
1,
2,
3
]
},
"_semanticScore": 1.0
},
{
"title": "Captain Marvel",
"id": "299537",
"_vectors": {
"manual": [
1,
2,
54
]
},
"_semanticScore": 0.9129112
},
{
"title": "Gläss",
"id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
},
"_semanticScore": 0.8106413
},
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
},
"_semanticScore": 0.74120104
},
{
"title": "Escape Room",
"id": "522681",
"_vectors": {
"manual": [
10,
-23,
32
]
}
}
]
"###);
}
#[cfg(feature = "default")]
@ -1126,7 +1230,14 @@ async fn simple_search_with_strange_synonyms() {
[
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428"
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
}
]
"###);
@ -1140,7 +1251,14 @@ async fn simple_search_with_strange_synonyms() {
[
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428"
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
}
]
"###);
@ -1154,7 +1272,14 @@ async fn simple_search_with_strange_synonyms() {
[
{
"title": "How to Train Your Dragon: The Hidden World",
"id": "166428"
"id": "166428",
"_vectors": {
"manual": [
-100,
231,
32
]
}
}
]
"###);

View File

@ -72,7 +72,14 @@ async fn simple_search_single_index() {
"hits": [
{
"title": "Gläss",
"id": "450465"
"id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
}
}
],
"query": "glass",
@ -86,7 +93,14 @@ async fn simple_search_single_index() {
"hits": [
{
"title": "Captain Marvel",
"id": "299537"
"id": "299537",
"_vectors": {
"manual": [
1,
2,
54
]
}
}
],
"query": "captain",
@ -177,7 +191,14 @@ async fn simple_search_two_indexes() {
"hits": [
{
"title": "Gläss",
"id": "450465"
"id": "450465",
"_vectors": {
"manual": [
-100,
340,
90
]
}
}
],
"query": "glass",
@ -203,7 +224,14 @@ async fn simple_search_two_indexes() {
"age": 4
}
],
"cattos": "pésti"
"cattos": "pésti",
"_vectors": {
"manual": [
1,
2,
3
]
}
},
{
"id": 654,
@ -218,7 +246,14 @@ async fn simple_search_two_indexes() {
"cattos": [
"simba",
"pestiféré"
]
],
"_vectors": {
"manual": [
1,
2,
54
]
}
}
],
"query": "pésti",

View File

@ -335,3 +335,35 @@ async fn exactness_ranking_rule_order() {
})
.await;
}
#[actix_rt::test]
async fn search_on_exact_field() {
let server = Server::new().await;
let index = index_with_documents(
&server,
&json!([
{
"title": "Captain Marvel",
"exact": "Captain Marivel",
"id": "1",
},
{
"title": "Captain Marivel",
"exact": "Captain the Marvel",
"id": "2",
}]),
)
.await;
let (response, code) =
index.update_settings_typo_tolerance(json!({ "disableOnAttributes": ["exact"] })).await;
assert_eq!(202, code, "{:?}", response);
index.wait_task(1).await;
// Searching on an exact attribute should only return the document matching without typo.
index
.search(json!({"q": "Marvel", "attributesToSearchOn": ["exact"]}), |response, code| {
snapshot!(code, @"200 OK");
snapshot!(response["hits"].as_array().unwrap().len(), @"1");
})
.await;
}

View File

@ -83,6 +83,7 @@ async fn get_settings() {
"maxTotalHits": 1000,
})
);
assert_eq!(settings["proximityPrecision"], json!("byWord"));
}
#[actix_rt::test]

View File

@ -27,17 +27,6 @@ static DOCUMENTS: Lazy<crate::common::Value> = Lazy::new(|| {
#[actix_rt::test]
async fn attribute_scale_search() {
let server = Server::new().await;
let (response, code) = server.set_features(json!({"proximityPrecision": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": true
}
"###);
let index = server.index("test");
index.add_documents(DOCUMENTS.clone(), None).await;
@ -45,7 +34,7 @@ async fn attribute_scale_search() {
let (response, code) = index
.update_settings(json!({
"proximityPrecision": "attributeScale",
"proximityPrecision": "byAttribute",
"rankingRules": ["words", "typo", "proximity"],
}))
.await;
@ -111,17 +100,6 @@ async fn attribute_scale_search() {
#[actix_rt::test]
async fn attribute_scale_phrase_search() {
let server = Server::new().await;
let (response, code) = server.set_features(json!({"proximityPrecision": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": true
}
"###);
let index = server.index("test");
index.add_documents(DOCUMENTS.clone(), None).await;
@ -129,7 +107,7 @@ async fn attribute_scale_phrase_search() {
let (_response, _code) = index
.update_settings(json!({
"proximityPrecision": "attributeScale",
"proximityPrecision": "byAttribute",
"rankingRules": ["words", "typo", "proximity"],
}))
.await;
@ -190,17 +168,6 @@ async fn attribute_scale_phrase_search() {
#[actix_rt::test]
async fn word_scale_set_and_reset() {
let server = Server::new().await;
let (response, code) = server.set_features(json!({"proximityPrecision": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": true
}
"###);
let index = server.index("test");
index.add_documents(DOCUMENTS.clone(), None).await;
@ -209,7 +176,7 @@ async fn word_scale_set_and_reset() {
// Set and reset the setting ensuring the swap between the 2 settings is applied.
let (_response, _code) = index
.update_settings(json!({
"proximityPrecision": "attributeScale",
"proximityPrecision": "byAttribute",
"rankingRules": ["words", "typo", "proximity"],
}))
.await;
@ -217,7 +184,7 @@ async fn word_scale_set_and_reset() {
let (_response, _code) = index
.update_settings(json!({
"proximityPrecision": "wordScale",
"proximityPrecision": "byWord",
"rankingRules": ["words", "typo", "proximity"],
}))
.await;
@ -316,17 +283,6 @@ async fn word_scale_set_and_reset() {
#[actix_rt::test]
async fn attribute_scale_default_ranking_rules() {
let server = Server::new().await;
let (response, code) = server.set_features(json!({"proximityPrecision": true})).await;
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response), @r###"
{
"scoreDetails": false,
"vectorStore": false,
"metrics": false,
"exportPuffinReports": false,
"proximityPrecision": true
}
"###);
let index = server.index("test");
index.add_documents(DOCUMENTS.clone(), None).await;
@ -334,7 +290,7 @@ async fn attribute_scale_default_ranking_rules() {
let (response, code) = index
.update_settings(json!({
"proximityPrecision": "attributeScale"
"proximityPrecision": "byAttribute"
}))
.await;
assert_eq!("202", code.as_str(), "{:?}", response);

View File

@ -1,4 +1,5 @@
mod errors;
mod webhook;
use meili_snap::insta::assert_json_snapshot;
use time::format_description::well_known::Rfc3339;

View File

@ -0,0 +1,123 @@
//! To test the webhook, we need to spawn a new server with a URL listening for
//! post requests. The webhook handle starts a server and forwards all the
//! received requests into a channel for you to handle.
use std::sync::Arc;
use actix_http::body::MessageBody;
use actix_web::dev::{ServiceFactory, ServiceResponse};
use actix_web::web::{Bytes, Data};
use actix_web::{post, App, HttpResponse, HttpServer};
use meili_snap::{json_string, snapshot};
use meilisearch::Opt;
use tokio::sync::mpsc;
use url::Url;
use crate::common::{default_settings, Server};
use crate::json;
#[post("/")]
async fn forward_body(sender: Data<mpsc::UnboundedSender<Vec<u8>>>, body: Bytes) -> HttpResponse {
let body = body.to_vec();
sender.send(body).unwrap();
HttpResponse::Ok().into()
}
fn create_app(
sender: Arc<mpsc::UnboundedSender<Vec<u8>>>,
) -> actix_web::App<
impl ServiceFactory<
actix_web::dev::ServiceRequest,
Config = (),
Response = ServiceResponse<impl MessageBody>,
Error = actix_web::Error,
InitError = (),
>,
> {
App::new().service(forward_body).app_data(Data::from(sender))
}
struct WebhookHandle {
pub server_handle: tokio::task::JoinHandle<Result<(), std::io::Error>>,
pub url: String,
pub receiver: mpsc::UnboundedReceiver<Vec<u8>>,
}
async fn create_webhook_server() -> WebhookHandle {
let mut log_builder = env_logger::Builder::new();
log_builder.parse_filters("info");
log_builder.init();
let (sender, receiver) = mpsc::unbounded_channel();
let sender = Arc::new(sender);
// By listening on the port 0, the system will give us any available port.
let server =
HttpServer::new(move || create_app(sender.clone())).bind(("127.0.0.1", 0)).unwrap();
let (ip, scheme) = server.addrs_with_scheme()[0];
let url = format!("{scheme}://{ip}/");
let server_handle = tokio::spawn(server.run());
WebhookHandle { server_handle, url, receiver }
}
#[actix_web::test]
async fn test_basic_webhook() {
let WebhookHandle { server_handle, url, mut receiver } = create_webhook_server().await;
let db_path = tempfile::tempdir().unwrap();
let server = Server::new_with_options(Opt {
task_webhook_url: Some(Url::parse(&url).unwrap()),
..default_settings(db_path.path())
})
.await
.unwrap();
let index = server.index("tamo");
// May be flaky: we're relying on the fact that while the first document addition is processed, the other
// operations will be received and will be batched together. If it doesn't happen it's not a problem
// the rest of the test won't assume anything about the number of tasks per batch.
for i in 0..5 {
let (_, _status) = index.add_documents(json!({ "id": i, "doggo": "bone" }), None).await;
}
let mut nb_tasks = 0;
while let Some(payload) = receiver.recv().await {
let payload = String::from_utf8(payload).unwrap();
let jsonl = payload.split('\n');
for json in jsonl {
if json.is_empty() {
break; // we reached EOF
}
nb_tasks += 1;
let json: serde_json::Value = serde_json::from_str(json).unwrap();
snapshot!(
json_string!(json, { ".uid" => "[uid]", ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
@r###"
{
"uid": "[uid]",
"indexUid": "tamo",
"status": "succeeded",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 1
},
"error": null,
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
}
if nb_tasks == 5 {
break;
}
}
assert!(nb_tasks == 5, "We should have received the 5 tasks but only received {nb_tasks}");
server_handle.abort();
}

View File

@ -27,13 +27,15 @@ fst = "0.4.7"
fxhash = "0.2.1"
geoutils = "0.5.1"
grenad = { version = "0.4.5", default-features = false, features = [
"rayon", "tempfile"
"rayon",
"tempfile",
] }
heed = { version = "0.20.0-alpha.9", default-features = false, features = [
"serde-json", "serde-bincode", "read-txn-no-tls"
"serde-json",
"serde-bincode",
"read-txn-no-tls",
] }
indexmap = { version = "2.0.0", features = ["serde"] }
instant-distance = { version = "0.6.1", features = ["with-serde"] }
json-depth-checker = { path = "../json-depth-checker" }
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
memmap2 = "0.7.1"
@ -72,6 +74,23 @@ puffin = "0.16.0"
log = "0.4.17"
logging_timer = "1.1.0"
csv = "1.2.1"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1", default_features = false, features = ["onig"] }
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
"online",
] }
tokio = { version = "1.34.0", features = ["rt"] }
futures = "0.3.29"
reqwest = { version = "0.11.16", features = [
"rustls-tls",
"json",
], default-features = false }
tiktoken-rs = "0.5.7"
liquid = "0.26.4"
arroy = { git = "https://github.com/meilisearch/arroy.git", version = "0.1.0" }
rand = "0.8.5"
[dev-dependencies]
mimalloc = { version = "0.1.37", default-features = false }
@ -83,7 +102,15 @@ meili-snap = { path = "../meili-snap" }
rand = { version = "0.8.5", features = ["small_rng"] }
[features]
all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"]
all-tokenizations = [
"charabia/chinese",
"charabia/hebrew",
"charabia/japanese",
"charabia/thai",
"charabia/korean",
"charabia/greek",
"charabia/khmer",
]
# Use POSIX semaphores instead of SysV semaphores in LMDB
# For more information on this feature, see heed's Cargo.toml

View File

@ -5,8 +5,8 @@ use std::time::Instant;
use heed::EnvOpenOptions;
use milli::{
execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger,
TermsMatchingStrategy,
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext,
SearchLogger, TermsMatchingStrategy,
};
#[global_allocator]
@ -49,14 +49,15 @@ fn main() -> Result<(), Box<dyn Error>> {
let start = Instant::now();
let mut ctx = SearchContext::new(&index, &txn);
let universe = filtered_universe(&ctx, &None)?;
let docs = execute_search(
&mut ctx,
&(!query.trim().is_empty()).then(|| query.trim().to_owned()),
&None,
(!query.trim().is_empty()).then(|| query.trim()),
TermsMatchingStrategy::Last,
milli::score_details::ScoringStrategy::Skip,
false,
&None,
universe,
&None,
GeoSortStrategy::default(),
0,

View File

@ -1,41 +0,0 @@
use std::ops;
use instant_distance::Point;
use serde::{Deserialize, Serialize};
use crate::normalize_vector;
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct NDotProductPoint(Vec<f32>);
impl NDotProductPoint {
pub fn new(point: Vec<f32>) -> Self {
NDotProductPoint(normalize_vector(point))
}
pub fn into_inner(self) -> Vec<f32> {
self.0
}
}
impl ops::Deref for NDotProductPoint {
type Target = [f32];
fn deref(&self) -> &Self::Target {
self.0.as_slice()
}
}
impl Point for NDotProductPoint {
fn distance(&self, other: &Self) -> f32 {
let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
debug_assert!(!dist.is_nan());
dist
}
}
/// Returns the dot product similarity score that will between 0.0 and 1.0
/// if both vectors are normalized. The higher the more similar the vectors are.
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}

View File

@ -61,6 +61,10 @@ pub enum InternalError {
AbortedIndexation,
#[error("The matching words list contains at least one invalid member.")]
InvalidMatchingWords,
#[error(transparent)]
ArroyError(#[from] arroy::Error),
#[error(transparent)]
VectorEmbeddingError(#[from] crate::vector::Error),
}
#[derive(Error, Debug)]
@ -110,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
InvalidGeoField(#[from] GeoError),
#[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
InvalidVectorDimensions { expected: usize, found: usize },
#[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
InvalidVectorsType { document_id: Value, value: Value },
#[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
InvalidVectorsType { document_id: Value, value: Value, subfield: String },
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
InvalidVectorsMapType { document_id: Value, value: Value },
#[error("{0}")]
InvalidFilter(String),
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
@ -180,6 +186,76 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
UnknownInternalDocumentId { document_id: DocumentId },
#[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")]
InvalidMinTypoWordLenSetting(u8, u8),
#[error(transparent)]
VectorEmbeddingError(#[from] crate::vector::Error),
#[error(transparent)]
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
#[error(transparent)]
InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
#[error("`.embedders.{0}.documentTemplate`: Invalid template: {1}.")]
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
TooManyEmbedders(usize),
#[error("Cannot find embedder with name {0}.")]
InvalidEmbedder(String),
#[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
TooManyVectors(String, usize),
#[error("`.embedders.{embedder_name}`: Field `{field}` unavailable for source `{source_}` (only available for sources: {}). Available fields: {}",
allowed_sources_for_field
.iter()
.map(|accepted| format!("`{}`", accepted))
.collect::<Vec<String>>()
.join(", "),
allowed_fields_for_source
.iter()
.map(|accepted| format!("`{}`", accepted))
.collect::<Vec<String>>()
.join(", ")
)]
InvalidFieldForSource {
embedder_name: String,
source_: crate::vector::settings::EmbedderSource,
field: &'static str,
allowed_fields_for_source: &'static [&'static str],
allowed_sources_for_field: &'static [crate::vector::settings::EmbedderSource],
},
#[error("`.embedders.{embedder_name}.model`: Invalid model `{model}` for OpenAI. Supported models: {:?}", crate::vector::openai::EmbeddingModel::supported_models())]
InvalidOpenAiModel { embedder_name: String, model: String },
#[error("`.embedders.{embedder_name}`: Missing field `{field}` (note: this field is mandatory for source {source_})")]
MissingFieldForSource {
field: &'static str,
source_: crate::vector::settings::EmbedderSource,
embedder_name: String,
},
}
impl From<crate::vector::Error> for Error {
fn from(value: crate::vector::Error) -> Self {
match value.fault() {
FaultSource::User => Error::UserError(value.into()),
FaultSource::Runtime => Error::InternalError(value.into()),
FaultSource::Bug => Error::InternalError(value.into()),
FaultSource::Undecided => Error::InternalError(value.into()),
}
}
}
impl From<arroy::Error> for Error {
fn from(value: arroy::Error) -> Self {
match value {
arroy::Error::Heed(heed) => heed.into(),
arroy::Error::Io(io) => io.into(),
arroy::Error::InvalidVecDimension { expected, received } => {
Error::UserError(UserError::InvalidVectorDimensions { expected, found: received })
}
arroy::Error::DatabaseFull
| arroy::Error::InvalidItemAppend
| arroy::Error::UnmatchingDistance { .. }
| arroy::Error::MissingMetadata => {
Error::InternalError(InternalError::ArroyError(value))
}
}
}
}
#[derive(Error, Debug)]
@ -336,6 +412,26 @@ impl From<HeedError> for Error {
}
}
#[derive(Debug, Clone, Copy)]
pub enum FaultSource {
User,
Runtime,
Bug,
Undecided,
}
impl std::fmt::Display for FaultSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
FaultSource::User => "user error",
FaultSource::Runtime => "runtime error",
FaultSource::Bug => "coding error",
FaultSource::Undecided => "error",
};
f.write_str(s)
}
}
#[test]
fn conditionally_lookup_for_error_message() {
let prefix = "Attribute `name` is not sortable.";

View File

@ -10,7 +10,6 @@ use roaring::RoaringBitmap;
use rstar::RTree;
use time::OffsetDateTime;
use crate::distance::NDotProductPoint;
use crate::documents::PrimaryKey;
use crate::error::{InternalError, UserError};
use crate::fields_ids_map::FieldsIdsMap;
@ -22,7 +21,7 @@ use crate::heed_codec::{
BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
};
use crate::proximity::ProximityPrecision;
use crate::readable_slices::ReadableSlices;
use crate::vector::EmbeddingConfig;
use crate::{
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
@ -30,9 +29,6 @@ use crate::{
BEU32, BEU64,
};
/// The HNSW data-structure that we serialize, fill and search in.
pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
@ -48,10 +44,6 @@ pub mod main_key {
pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
pub const GEO_RTREE_KEY: &str = "geo-rtree";
/// The prefix of the key that is used to store the, potential big, HNSW structure.
/// It is concatenated with a big-endian encoded number (non-human readable).
/// e.g. vector-hnsw0x0032.
pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw";
pub const PRIMARY_KEY_KEY: &str = "primary-key";
pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields";
pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields";
@ -74,6 +66,7 @@ pub mod main_key {
pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by";
pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits";
pub const PROXIMITY_PRECISION: &str = "proximity-precision";
pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
}
pub mod db_name {
@ -99,7 +92,8 @@ pub mod db_name {
pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst";
pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id";
pub const VECTOR_ARROY: &str = "vector-arroy";
pub const DOCUMENTS: &str = "documents";
pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
}
@ -166,8 +160,10 @@ pub struct Index {
/// Maps the document id, the facet field id and the strings.
pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
/// Maps a vector id to the document id that have it.
pub vector_id_docid: Database<BEU32, BEU32>,
/// Maps an embedder name to its id in the arroy store.
pub embedder_category_id: Database<Str, U8>,
/// Vector store based on arroy™.
pub vector_arroy: arroy::Database<arroy::distances::Angular>,
/// Maps the document id to the document as an obkv store.
pub(crate) documents: Database<BEU32, ObkvCodec>,
@ -182,7 +178,7 @@ impl Index {
) -> Result<Index> {
use db_name::*;
options.max_dbs(24);
options.max_dbs(25);
let env = options.open(path)?;
let mut wtxn = env.write_txn()?;
@ -222,7 +218,11 @@ impl Index {
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
let field_id_docid_facet_strings =
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
// vector stuff
let embedder_category_id =
env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?;
let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?;
let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
wtxn.commit()?;
@ -252,7 +252,8 @@ impl Index {
facet_id_is_empty_docids,
field_id_docid_facet_f64s,
field_id_docid_facet_strings,
vector_id_docid,
vector_arroy,
embedder_category_id,
documents,
})
}
@ -475,63 +476,6 @@ impl Index {
None => Ok(RoaringBitmap::new()),
}
}
/* vector HNSW */
/// Writes the provided `hnsw`.
pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
// We must delete all the chunks before we write the new HNSW chunks.
self.delete_vector_hnsw(wtxn)?;
let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB
let bytes = bincode::serialize(hnsw).map_err(Into::into).map_err(heed::Error::Encoding)?;
for (i, chunk) in bytes.chunks(chunk_size).enumerate() {
let i = i as u32;
let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec();
key.extend_from_slice(&i.to_be_bytes());
self.main.remap_types::<Bytes, Bytes>().put(wtxn, &key, chunk)?;
}
Ok(())
}
/// Delete the `hnsw`.
pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
let mut iter = self
.main
.remap_types::<Bytes, DecodeIgnore>()
.prefix_iter_mut(wtxn, main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes())?;
let mut deleted = false;
while iter.next().transpose()?.is_some() {
// We do not keep a reference to the key or the value.
unsafe { deleted |= iter.del_current()? };
}
Ok(deleted)
}
/// Returns the `hnsw`.
pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
let mut slices = Vec::new();
for result in self
.main
.remap_types::<Str, Bytes>()
.prefix_iter(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)?
{
let (_, slice) = result?;
slices.push(slice);
}
if slices.is_empty() {
Ok(None)
} else {
let readable_slices: ReadableSlices<_> = slices.into_iter().collect();
Ok(Some(
bincode::deserialize_from(readable_slices)
.map_err(Into::into)
.map_err(heed::Error::Decoding)?,
))
}
}
/* field distribution */
/// Writes the field distribution which associates every field name with
@ -1528,6 +1472,41 @@ impl Index {
Ok(script_language)
}
pub(crate) fn put_embedding_configs(
&self,
wtxn: &mut RwTxn<'_>,
configs: Vec<(String, EmbeddingConfig)>,
) -> heed::Result<()> {
self.main.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>().put(
wtxn,
main_key::EMBEDDING_CONFIGS,
&configs,
)
}
pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
self.main.remap_key_type::<Str>().delete(wtxn, main_key::EMBEDDING_CONFIGS)
}
pub fn embedding_configs(
&self,
rtxn: &RoTxn<'_>,
) -> Result<Vec<(String, crate::vector::EmbeddingConfig)>> {
Ok(self
.main
.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>()
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
.unwrap_or_default())
}
pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> {
let configs = self.embedding_configs(rtxn)?;
Ok(match configs.as_slice() {
[(ref first_name, _)] => first_name.clone(),
_ => "default".to_owned(),
})
}
}
#[cfg(test)]

View File

@ -10,18 +10,18 @@ pub mod documents;
mod asc_desc;
mod criterion;
pub mod distance;
mod error;
mod external_documents_ids;
pub mod facet;
mod fields_ids_map;
pub mod heed_codec;
pub mod index;
pub mod prompt;
pub mod proximity;
mod readable_slices;
pub mod score_details;
mod search;
pub mod update;
pub mod vector;
#[cfg(test)]
#[macro_use]
@ -32,13 +32,12 @@ use std::convert::{TryFrom, TryInto};
use std::hash::BuildHasherDefault;
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
pub use distance::dot_product_similarity;
pub use filter_parser::{Condition, FilterCondition, Span, Token};
use fxhash::{FxHasher32, FxHasher64};
pub use grenad::CompressionType;
pub use search::new::{
execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger,
VisualSearchLogger,
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext,
SearchLogger, VisualSearchLogger,
};
use serde_json::Value;
pub use {charabia as tokenizer, heed};

View File

@ -0,0 +1,97 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use super::document::Document;
use super::fields::Fields;
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Context<'a> {
document: &'a Document<'a>,
fields: Fields<'a>,
}
impl<'a> Context<'a> {
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
Self { document, fields: Fields::new(document, field_id_map) }
}
}
impl<'a> ObjectView for Context<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(self.document.as_value())
.chain(std::iter::once(self.fields.as_value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "doc" || index == "fields"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"doc" => Some(self.document.as_value()),
"fields" => Some(self.fields.as_value()),
_ => None,
}
}
}
impl<'a> ValueView for Context<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => false,
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

View File

@ -0,0 +1,131 @@
use std::cell::OnceCell;
use std::collections::BTreeMap;
use liquid::model::{
DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>);
#[derive(Debug, Clone)]
struct ParsedValue(std::cell::OnceCell<LiquidValue>);
impl ParsedValue {
fn empty() -> ParsedValue {
ParsedValue(OnceCell::new())
}
fn get(&self, raw: &[u8]) -> &LiquidValue {
self.0.get_or_init(|| {
let value: serde_json::Value = serde_json::from_slice(raw).unwrap();
liquid::model::to_value(&value).unwrap()
})
}
}
impl<'a> Document<'a> {
pub fn new(
data: obkv::KvReaderU16<'a>,
side: DelAdd,
inverted_field_map: &'a FieldsIdsMap,
) -> Self {
let mut out_data = BTreeMap::new();
for (fid, raw) in data {
let obkv = KvReaderDelAdd::new(raw);
let Some(raw) = obkv.get(side) else {
continue;
};
let Some(name) = inverted_field_map.name(fid) else {
continue;
};
out_data.insert(name, (raw, ParsedValue::empty()));
}
Self(out_data)
}
fn is_empty(&self) -> bool {
self.0.is_empty()
}
fn len(&self) -> usize {
self.0.len()
}
fn iter(&self) -> impl Iterator<Item = (KString, LiquidValue)> + '_ {
self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned()))
}
}
impl<'a> ObjectView for Document<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
self.len() as i64
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
let keys = BTreeMap::keys(&self.0).map(|&s| s.into());
Box::new(keys)
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView))
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView)))
}
fn contains_key(&self, index: &str) -> bool {
self.0.contains_key(index)
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView)
}
}
impl<'a> ValueView for Document<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(self.iter().collect())
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

56
milli/src/prompt/error.rs Normal file
View File

@ -0,0 +1,56 @@
use crate::error::FaultSource;
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewPromptError {
pub kind: NewPromptErrorKind,
pub fault: FaultSource,
}
impl From<NewPromptError> for crate::Error {
fn from(value: NewPromptError) -> Self {
crate::Error::UserError(crate::UserError::InvalidPrompt(value))
}
}
impl NewPromptError {
pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError {
Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User }
}
pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError {
Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
pub enum NewPromptErrorKind {
#[error("cannot parse template: {0}")]
CannotParseTemplate(liquid::Error),
#[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")]
InvalidFieldsInTemplate(liquid::Error),
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct RenderPromptError {
pub kind: RenderPromptErrorKind,
pub fault: FaultSource,
}
impl RenderPromptError {
pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError {
Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
pub enum RenderPromptErrorKind {
#[error("missing field in document: {0}")]
MissingContext(liquid::Error),
}
impl From<RenderPromptError> for crate::Error {
fn from(value: RenderPromptError) -> Self {
crate::Error::UserError(crate::UserError::MissingDocumentField(value))
}
}

172
milli/src/prompt/fields.rs Normal file
View File

@ -0,0 +1,172 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use super::document::Document;
use crate::FieldsIdsMap;
#[derive(Debug, Clone)]
pub struct Fields<'a>(Vec<FieldValue<'a>>);
impl<'a> Fields<'a> {
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
Self(
std::iter::repeat(document)
.zip(field_id_map.iter())
.map(|(document, (_fid, name))| FieldValue { document, name })
.collect(),
)
}
}
#[derive(Debug, Clone, Copy)]
pub struct FieldValue<'a> {
name: &'a str,
document: &'a Document<'a>,
}
impl<'a> ValueView for FieldValue<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl<'a> FieldValue<'a> {
pub fn name(&self) -> &&'a str {
&self.name
}
pub fn value(&self) -> &dyn ValueView {
self.document.get(self.name).unwrap_or(&LiquidValue::Nil)
}
pub fn is_empty(&self) -> bool {
self.size() == 0
}
}
impl<'a> ObjectView for FieldValue<'a> {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "name" || index == "value"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"name" => Some(self.name()),
"value" => Some(self.value()),
_ => None,
}
}
}
impl<'a> ArrayView for Fields<'a> {
fn as_value(&self) -> &dyn ValueView {
self.0.as_value()
}
fn size(&self) -> i64 {
self.0.len() as i64
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
self.0.values()
}
fn contains_key(&self, index: i64) -> bool {
self.0.contains_key(index)
}
fn get(&self, index: i64) -> Option<&dyn ValueView> {
ArrayView::get(&self.0, index)
}
}
impl<'a> ValueView for Fields<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
self.0.render()
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
self.0.source()
}
fn type_name(&self) -> &'static str {
self.0.type_name()
}
fn query_state(&self, state: liquid::model::State) -> bool {
self.0.query_state(state)
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
self.0.to_kstr()
}
fn to_value(&self) -> LiquidValue {
self.0.to_value()
}
fn as_array(&self) -> Option<&dyn ArrayView> {
Some(self)
}
}

176
milli/src/prompt/mod.rs Normal file
View File

@ -0,0 +1,176 @@
mod context;
mod document;
pub(crate) mod error;
mod fields;
mod template_checker;
use std::convert::TryFrom;
use error::{NewPromptError, RenderPromptError};
use self::context::Context;
use self::document::Document;
use crate::update::del_add::DelAdd;
use crate::FieldsIdsMap;
pub struct Prompt {
template: liquid::Template,
template_text: String,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PromptData {
pub template: String,
}
impl From<Prompt> for PromptData {
fn from(value: Prompt) -> Self {
Self { template: value.template_text }
}
}
impl TryFrom<PromptData> for Prompt {
type Error = NewPromptError;
fn try_from(value: PromptData) -> Result<Self, Self::Error> {
Prompt::new(value.template)
}
}
impl Clone for Prompt {
fn clone(&self) -> Self {
let template_text = self.template_text.clone();
Self { template: new_template(&template_text).unwrap(), template_text }
}
}
fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text)
}
fn default_template() -> liquid::Template {
new_template(default_template_text()).unwrap()
}
fn default_template_text() -> &'static str {
"{% for field in fields %} \
{{ field.name }}: {{ field.value }}\n\
{% endfor %}"
}
impl Default for Prompt {
fn default() -> Self {
Self { template: default_template(), template_text: default_template_text().into() }
}
}
impl Default for PromptData {
fn default() -> Self {
Self { template: default_template_text().into() }
}
}
impl Prompt {
pub fn new(template: String) -> Result<Self, NewPromptError> {
let this = Self {
template: liquid::ParserBuilder::with_stdlib()
.build()
.unwrap()
.parse(&template)
.map_err(NewPromptError::cannot_parse_template)?,
template_text: template,
};
// render template with special object that's OK with `doc.*` and `fields.*`
this.template
.render(&template_checker::TemplateChecker)
.map_err(NewPromptError::invalid_fields_in_template)?;
Ok(this)
}
pub fn render(
&self,
document: obkv::KvReaderU16<'_>,
side: DelAdd,
field_id_map: &FieldsIdsMap,
) -> Result<String, RenderPromptError> {
let document = Document::new(document, side, field_id_map);
let context = Context::new(&document, field_id_map);
self.template.render(&context).map_err(RenderPromptError::missing_context)
}
}
#[cfg(test)]
mod test {
use super::Prompt;
use crate::error::FaultSource;
use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
#[test]
fn default_template() {
// does not panic
Prompt::default();
}
#[test]
fn empty_template() {
Prompt::new("".into()).unwrap();
}
#[test]
fn template_ok() {
Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap();
}
#[test]
fn template_syntax() {
assert!(matches!(
Prompt::new("{{doc.title: {{doc.overview}}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::CannotParseTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_missing_doc() {
assert!(matches!(
Prompt::new("{{title}}: {{overview}}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_nested_doc() {
Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap();
}
#[test]
fn template_fields() {
Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap();
}
#[test]
fn template_fields_ok() {
Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into())
.unwrap();
}
#[test]
fn template_fields_invalid() {
assert!(matches!(
// intentionally garbled field
Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
}

View File

@ -0,0 +1,301 @@
use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{Object, ObjectView, ValueView};
#[derive(Debug)]
pub struct TemplateChecker;
#[derive(Debug)]
pub struct DummyDoc;
#[derive(Debug)]
pub struct DummyFields;
#[derive(Debug)]
pub struct DummyField;
const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil;
impl ObjectView for DummyField {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter())
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "name" || index == "value"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
if self.contains_key(index) {
Some(DUMMY_VALUE.as_view())
} else {
None
}
}
}
impl ValueView for DummyField {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
let mut this = Object::new();
this.insert("name".into(), LiquidValue::Nil);
this.insert("value".into(), LiquidValue::Nil);
LiquidValue::Object(this)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl ValueView for DummyFields {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"array"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Array(vec![DummyField.to_value()])
}
fn as_array(&self) -> Option<&dyn ArrayView> {
Some(self)
}
}
impl ArrayView for DummyFields {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
u16::MAX as i64
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::once(DummyField.as_value()))
}
fn contains_key(&self, index: i64) -> bool {
index < self.size()
}
fn get(&self, _index: i64) -> Option<&dyn ValueView> {
Some(DummyField.as_value())
}
}
impl ObjectView for DummyDoc {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
1000
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(std::iter::empty())
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::empty())
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(std::iter::empty())
}
fn contains_key(&self, _index: &str) -> bool {
true
}
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
// Recursively sends itself
Some(self)
}
}
impl ValueView for DummyDoc {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> DisplayCow<'_> {
DUMMY_VALUE.render()
}
fn source(&self) -> DisplayCow<'_> {
DUMMY_VALUE.source()
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
}
fn to_kstr(&self) -> KStringCow<'_> {
DUMMY_VALUE.to_kstr()
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Nil
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}
impl ObjectView for TemplateChecker {
fn as_value(&self) -> &dyn ValueView {
self
}
fn size(&self) -> i64 {
2
}
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
}
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(
std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())),
)
}
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(self.keys().zip(self.values()))
}
fn contains_key(&self, index: &str) -> bool {
index == "doc" || index == "fields"
}
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
match index {
"doc" => Some(DummyDoc.as_value()),
"fields" => Some(DummyFields.as_value()),
_ => None,
}
}
}
impl ValueView for TemplateChecker {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => false,
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

View File

@ -32,6 +32,6 @@ pub fn path_proximity(path: &[Position]) -> u32 {
#[serde(rename_all = "camelCase")]
pub enum ProximityPrecision {
#[default]
WordScale,
AttributeScale,
ByWord,
ByAttribute,
}

View File

@ -1,85 +0,0 @@
use std::io::{self, Read};
use std::iter::FromIterator;
pub struct ReadableSlices<A> {
inner: Vec<A>,
pos: u64,
}
impl<A> FromIterator<A> for ReadableSlices<A> {
fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
ReadableSlices { inner: iter.into_iter().collect(), pos: 0 }
}
}
impl<A: AsRef<[u8]>> Read for ReadableSlices<A> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let original_buf_len = buf.len();
// We explore the list of slices to find the one where we must start reading.
let mut pos = self.pos;
let index = match self
.inner
.iter()
.map(|s| s.as_ref().len() as u64)
.position(|size| pos.checked_sub(size).map(|p| pos = p).is_none())
{
Some(index) => index,
None => return Ok(0),
};
let mut inner_pos = pos as usize;
for slice in &self.inner[index..] {
let slice = &slice.as_ref()[inner_pos..];
if buf.len() > slice.len() {
// We must exhaust the current slice and go to the next one there is not enough here.
buf[..slice.len()].copy_from_slice(slice);
buf = &mut buf[slice.len()..];
inner_pos = 0;
} else {
// There is enough in this slice to fill the remaining bytes of the buffer.
// Let's break just after filling it.
buf.copy_from_slice(&slice[..buf.len()]);
buf = &mut [];
break;
}
}
let written = original_buf_len - buf.len();
self.pos += written as u64;
Ok(written)
}
}
#[cfg(test)]
mod test {
use std::io::Read;
use super::ReadableSlices;
#[test]
fn basic() {
let data: Vec<_> = (0..100).collect();
let splits: Vec<_> = data.chunks(3).collect();
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
let mut output = Vec::new();
let length = rdslices.read_to_end(&mut output).unwrap();
assert_eq!(length, data.len());
assert_eq!(output, data);
}
#[test]
fn small_reads() {
let data: Vec<_> = (0..u8::MAX).collect();
let splits: Vec<_> = data.chunks(27).collect();
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
let buffer = &mut [0; 45];
let length = rdslices.read(buffer).unwrap();
let expected: Vec<_> = (0..buffer.len() as u8).collect();
assert_eq!(length, buffer.len());
assert_eq!(buffer, &expected[..]);
}
}

View File

@ -1,3 +1,6 @@
use std::cmp::Ordering;
use itertools::Itertools;
use serde::Serialize;
use crate::distance_between_two_points;
@ -12,9 +15,24 @@ pub enum ScoreDetails {
ExactAttribute(ExactAttribute),
ExactWords(ExactWords),
Sort(Sort),
Vector(Vector),
GeoSort(GeoSort),
}
#[derive(Clone, Copy)]
pub enum ScoreValue<'a> {
Score(f64),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
}
enum RankOrValue<'a> {
Rank(Rank),
Sort(&'a Sort),
GeoSort(&'a GeoSort),
Score(f64),
}
impl ScoreDetails {
pub fn local_score(&self) -> Option<f64> {
self.rank().map(Rank::local_score)
@ -31,11 +49,55 @@ impl ScoreDetails {
ScoreDetails::ExactWords(details) => Some(details.rank()),
ScoreDetails::Sort(_) => None,
ScoreDetails::GeoSort(_) => None,
ScoreDetails::Vector(_) => None,
}
}
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 {
Rank::global_score(details.filter_map(Self::rank))
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 {
Self::score_values(details)
.find_map(|x| {
let ScoreValue::Score(score) = x else {
return None;
};
Some(score)
})
.unwrap_or(1.0f64)
}
pub fn score_values<'a>(
details: impl Iterator<Item = &'a Self> + 'a,
) -> impl Iterator<Item = ScoreValue<'a>> + 'a {
details
.map(ScoreDetails::rank_or_value)
.coalesce(|left, right| match (left, right) {
(RankOrValue::Rank(left), RankOrValue::Rank(right)) => {
Ok(RankOrValue::Rank(Rank::merge(left, right)))
}
(left, right) => Err((left, right)),
})
.map(|rank_or_value| match rank_or_value {
RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()),
RankOrValue::Sort(s) => ScoreValue::Sort(s),
RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g),
RankOrValue::Score(s) => ScoreValue::Score(s),
})
}
fn rank_or_value(&self) -> RankOrValue<'_> {
match self {
ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()),
ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()),
ScoreDetails::Proximity(p) => RankOrValue::Rank(*p),
ScoreDetails::Fid(f) => RankOrValue::Rank(*f),
ScoreDetails::Position(p) => RankOrValue::Rank(*p),
ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()),
ScoreDetails::Sort(sort) => RankOrValue::Sort(sort),
ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort),
ScoreDetails::Vector(vector) => RankOrValue::Score(
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
),
}
}
/// Panics
@ -181,6 +243,19 @@ impl ScoreDetails {
details_map.insert(sort, sort_details);
order += 1;
}
ScoreDetails::Vector(s) => {
let vector = format!("vectorSort({:?})", s.target_vector);
let value = s.value_similarity.as_ref().map(|(v, _)| v);
let similarity = s.value_similarity.as_ref().map(|(_, s)| s);
let details = serde_json::json!({
"order": order,
"value": value,
"similarity": similarity,
});
details_map.insert(vector, details);
order += 1;
}
}
}
details_map
@ -297,15 +372,21 @@ impl Rank {
pub fn global_score(details: impl Iterator<Item = Self>) -> f64 {
let mut rank = Rank { rank: 1, max_rank: 1 };
for inner_rank in details {
rank.rank -= 1;
rank.rank *= inner_rank.max_rank;
rank.max_rank *= inner_rank.max_rank;
rank.rank += inner_rank.rank;
rank = Rank::merge(rank, inner_rank);
}
rank.local_score()
}
pub fn merge(mut outer: Rank, inner: Rank) -> Rank {
outer.rank = outer.rank.saturating_sub(1);
outer.rank *= inner.max_rank;
outer.max_rank *= inner.max_rank;
outer.rank += inner.rank;
outer
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
@ -335,13 +416,78 @@ pub struct Sort {
pub value: serde_json::Value,
}
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
impl PartialOrd for Sort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.field_name != other.field_name {
return None;
}
if self.ascending != other.ascending {
return None;
}
match (&self.value, &other.value) {
(serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal),
(serde_json::Value::Null, _) => Some(Ordering::Less),
(_, serde_json::Value::Null) => Some(Ordering::Greater),
// numbers are always before strings
(serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater),
(serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less),
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
// FIXME: unwrap permitted here?
let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?;
// 12 < 42, and when ascending, we want to see 12 first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
let order = left.cmp(right);
// Taking e.g. "a" and "z"
// "a" < "z", and when ascending, we want to see "a" first, so the smallest.
// Hence, when ascending, smaller is better
Some(if self.ascending { order.reverse() } else { order })
}
_ => None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeoSort {
pub target_point: [f64; 2],
pub ascending: bool,
pub value: Option<[f64; 2]>,
}
impl PartialOrd for GeoSort {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.target_point != other.target_point {
return None;
}
if self.ascending != other.ascending {
return None;
}
Some(match (self.distance(), other.distance()) {
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Less,
(Some(_), None) => Ordering::Greater,
(Some(left), Some(right)) => {
let order = left.partial_cmp(&right)?;
if self.ascending {
// when ascending, the one with the smallest distance has the best score
order.reverse()
} else {
order
}
}
})
}
}
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct Vector {
pub target_vector: Vec<f32>,
pub value_similarity: Option<(Vec<f32>, f32)>,
}
impl GeoSort {
pub fn distance(&self) -> Option<f64> {
self.value.map(|value| distance_between_two_points(&self.target_point, &value))

183
milli/src/search/hybrid.rs Normal file
View File

@ -0,0 +1,183 @@
use std::cmp::Ordering;
use itertools::Itertools;
use roaring::RoaringBitmap;
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
use crate::{MatchingWords, Result, Search, SearchResult};
struct ScoreWithRatioResult {
matching_words: MatchingWords,
candidates: RoaringBitmap,
document_scores: Vec<(u32, ScoreWithRatio)>,
}
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
fn compare_scores(
&(ref left_scores, left_ratio): &ScoreWithRatio,
&(ref right_scores, right_ratio): &ScoreWithRatio,
) -> Ordering {
let mut left_it = ScoreDetails::score_values(left_scores.iter());
let mut right_it = ScoreDetails::score_values(right_scores.iter());
loop {
let left = left_it.next();
let right = right_it.next();
match (left, right) {
(None, None) => return Ordering::Equal,
(None, Some(_)) => return Ordering::Less,
(Some(_), None) => return Ordering::Greater,
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
let left = left * left_ratio as f64;
let right = right * right_ratio as f64;
if (left - right).abs() <= f64::EPSILON {
continue;
}
return left.partial_cmp(&right).unwrap();
}
(Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => {
match left.partial_cmp(right).unwrap() {
Ordering::Equal => continue,
order => return order,
}
}
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
// if we have this, we're bad
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
unreachable!("Unexpected geo and sort comparison")
}
}
}
}
impl ScoreWithRatioResult {
fn new(results: SearchResult, ratio: f32) -> Self {
let document_scores = results
.documents_ids
.into_iter()
.zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
.collect();
Self {
matching_words: results.matching_words,
candidates: results.candidates,
document_scores,
}
}
fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult {
let mut documents_ids =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut document_scores =
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
let mut documents_seen = RoaringBitmap::new();
for (docid, (main_score, _sub_score)) in left
.document_scores
.into_iter()
.merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| {
// the first value is the one with the greatest score
compare_scores(left, right).is_ge()
})
// remove documents we already saw
.filter(|(docid, _)| documents_seen.insert(*docid))
// start skipping **after** the filter
.skip(from)
// take **after** skipping
.take(length)
{
documents_ids.push(docid);
// TODO: pass both scores to documents_score in some way?
document_scores.push(main_score);
}
SearchResult {
matching_words: left.matching_words,
candidates: left.candidates | right.candidates,
documents_ids,
document_scores,
}
}
}
impl<'a> Search<'a> {
pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
// TODO: find classier way to achieve that than to reset vector and query params
// create separate keyword and semantic searches
let mut search = Search {
query: self.query.clone(),
vector: self.vector.clone(),
filter: self.filter.clone(),
offset: 0,
limit: self.limit + self.offset,
sort_criteria: self.sort_criteria.clone(),
searchable_attributes: self.searchable_attributes,
geo_strategy: self.geo_strategy,
terms_matching_strategy: self.terms_matching_strategy,
scoring_strategy: ScoringStrategy::Detailed,
words_limit: self.words_limit,
exhaustive_number_hits: self.exhaustive_number_hits,
rtxn: self.rtxn,
index: self.index,
distribution_shift: self.distribution_shift,
embedder_name: self.embedder_name.clone(),
};
let vector_query = search.vector.take();
let keyword_results = search.execute()?;
// skip semantic search if we don't have a vector query (placeholder search)
let Some(vector_query) = vector_query else {
return Ok(keyword_results);
};
// completely skip semantic search if the results of the keyword search are good enough
if self.results_good_enough(&keyword_results, semantic_ratio) {
return Ok(keyword_results);
}
search.vector = Some(vector_query);
search.query = None;
// TODO: would be better to have two distinct functions at this point
let vector_results = search.execute()?;
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
let merge_results =
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
assert!(merge_results.documents_ids.len() <= self.limit);
Ok(merge_results)
}
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
const GOOD_ENOUGH_SCORE: f64 = 0.45;
// 1. we check that we got a sufficient number of results
if keyword_results.document_scores.len() < self.limit + self.offset {
return false;
}
// 2. and that all results have a good enough score.
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
for score in &keyword_results.document_scores {
let score = ScoreDetails::global_score(score.iter());
if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
return false;
}
}
true
}
}

View File

@ -12,12 +12,14 @@ use roaring::bitmap::RoaringBitmap;
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::PartialSearchResult;
use self::new::{execute_vector_search, PartialSearchResult};
use crate::error::UserError;
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::vector::DistributionShift;
use crate::{
execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext,
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
Result, SearchContext,
};
// Building these factories is not free.
@ -25,11 +27,12 @@ static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
/// The maximum number of facets returned by the facet search route.
const MAX_NUMBER_OF_FACETS: usize = 100;
/// The maximum number of values per facet returned by the facet search route.
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
pub mod facet;
mod fst_utils;
pub mod hybrid;
pub mod new;
pub struct Search<'a> {
@ -46,8 +49,11 @@ pub struct Search<'a> {
scoring_strategy: ScoringStrategy,
words_limit: usize,
exhaustive_number_hits: bool,
/// TODO: Add semantic ratio or pass it directly to execute_hybrid()
rtxn: &'a heed::RoTxn<'a>,
index: &'a Index,
distribution_shift: Option<DistributionShift>,
embedder_name: Option<String>,
}
impl<'a> Search<'a> {
@ -67,6 +73,8 @@ impl<'a> Search<'a> {
words_limit: 10,
rtxn,
index,
distribution_shift: None,
embedder_name: None,
}
}
@ -75,8 +83,8 @@ impl<'a> Search<'a> {
self
}
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
self.vector = Some(vector.into());
pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
self.vector = Some(vector);
self
}
@ -133,30 +141,75 @@ impl<'a> Search<'a> {
self
}
pub fn distribution_shift(
&mut self,
distribution_shift: Option<DistributionShift>,
) -> &mut Search<'a> {
self.distribution_shift = distribution_shift;
self
}
pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
self.embedder_name = Some(embedder_name.into());
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn);
filtered_universe(&ctx, &self.filter)
} else {
Ok(self.execute()?.candidates)
}
}
pub fn execute(&self) -> Result<SearchResult> {
let embedder_name;
let embedder_name = match &self.embedder_name {
Some(embedder_name) => embedder_name,
None => {
embedder_name = self.index.default_embedding_name(self.rtxn)?;
&embedder_name
}
};
let mut ctx = SearchContext::new(self.index, self.rtxn);
if let Some(searchable_attributes) = self.searchable_attributes {
ctx.searchable_attributes(searchable_attributes)?;
}
let universe = filtered_universe(&ctx, &self.filter)?;
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
execute_search(
&mut ctx,
&self.query,
&self.vector,
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
&self.filter,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?;
match self.vector.as_ref() {
Some(vector) => execute_vector_search(
&mut ctx,
vector,
self.scoring_strategy,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
self.distribution_shift,
embedder_name,
)?,
None => execute_search(
&mut ctx,
self.query.as_deref(),
self.terms_matching_strategy,
self.scoring_strategy,
self.exhaustive_number_hits,
universe,
&self.sort_criteria,
self.geo_strategy,
self.offset,
self.limit,
Some(self.words_limit),
&mut DefaultSearchLogger,
&mut DefaultSearchLogger,
)?,
};
// consume context and located_query_terms to build MatchingWords.
let matching_words = match located_query_terms {
@ -185,6 +238,8 @@ impl fmt::Debug for Search<'_> {
exhaustive_number_hits,
rtxn: _,
index: _,
distribution_shift,
embedder_name,
} = self;
f.debug_struct("Search")
.field("query", query)
@ -198,6 +253,8 @@ impl fmt::Debug for Search<'_> {
.field("scoring_strategy", scoring_strategy)
.field("exhaustive_number_hits", exhaustive_number_hits)
.field("words_limit", words_limit)
.field("distribution_shift", distribution_shift)
.field("embedder_name", embedder_name)
.finish()
}
}
@ -249,11 +306,23 @@ pub struct SearchForFacetValues<'a> {
query: Option<String>,
facet: String,
search_query: Search<'a>,
max_values: usize,
is_hybrid: bool,
}
impl<'a> SearchForFacetValues<'a> {
pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> {
SearchForFacetValues { query: None, facet, search_query }
pub fn new(
facet: String,
search_query: Search<'a>,
is_hybrid: bool,
) -> SearchForFacetValues<'a> {
SearchForFacetValues {
query: None,
facet,
search_query,
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
is_hybrid,
}
}
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
@ -261,6 +330,11 @@ impl<'a> SearchForFacetValues<'a> {
self
}
pub fn max_values(&mut self, max: usize) -> &mut Self {
self.max_values = max;
self
}
fn one_original_value_of(
&self,
field_id: FieldId,
@ -303,7 +377,9 @@ impl<'a> SearchForFacetValues<'a> {
None => return Ok(vec![]),
};
let search_candidates = self.search_query.execute()?.candidates;
let search_candidates = self
.search_query
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
match self.query.as_ref() {
Some(query) => {
@ -398,7 +474,7 @@ impl<'a> SearchForFacetValues<'a> {
.unwrap_or_else(|| left_bound.to_string());
results.push(FacetValueHit { value, count });
}
if results.len() >= MAX_NUMBER_OF_FACETS {
if results.len() >= self.max_values {
break;
}
}
@ -443,7 +519,7 @@ impl<'a> SearchForFacetValues<'a> {
.unwrap_or_else(|| query.to_string());
results.push(FacetValueHit { value, count });
}
if results.len() >= MAX_NUMBER_OF_FACETS {
if results.len() >= self.max_values {
return Ok(ControlFlow::Break(()));
}
}

View File

@ -15,6 +15,7 @@ pub struct BucketSortOutput {
// TODO: would probably be good to regroup some of these inside of a struct?
#[allow(clippy::too_many_arguments)]
#[logging_timer::time]
pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
ctx: &mut SearchContext<'ctx>,
mut ranking_rules: Vec<BoxRankingRule<'ctx, Q>>,

View File

@ -162,7 +162,8 @@ impl<'ctx> SearchContext<'ctx> {
match &self.restricted_fids {
Some(restricted_fids) => {
let interned = self.word_interner.get(word).as_str();
let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect();
let keys: Vec<_> =
restricted_fids.tolerant.iter().map(|fid| (interned, *fid)).collect();
DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>(
self.txn,
@ -187,13 +188,29 @@ impl<'ctx> SearchContext<'ctx> {
&mut self,
word: Interned<String>,
) -> Result<Option<RoaringBitmap>> {
DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
word,
self.word_interner.get(word).as_str(),
&mut self.db_cache.exact_word_docids,
self.index.exact_word_docids.remap_data_type::<Bytes>(),
)
match &self.restricted_fids {
Some(restricted_fids) => {
let interned = self.word_interner.get(word).as_str();
let keys: Vec<_> =
restricted_fids.exact.iter().map(|fid| (interned, *fid)).collect();
DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>(
self.txn,
word,
&keys[..],
&mut self.db_cache.exact_word_docids,
self.index.word_fid_docids.remap_data_type::<Bytes>(),
merge_cbo_roaring_bitmaps,
)
}
None => DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
word,
self.word_interner.get(word).as_str(),
&mut self.db_cache.exact_word_docids,
self.index.exact_word_docids.remap_data_type::<Bytes>(),
),
}
}
pub fn word_prefix_docids(&mut self, prefix: Word) -> Result<Option<RoaringBitmap>> {
@ -224,7 +241,8 @@ impl<'ctx> SearchContext<'ctx> {
match &self.restricted_fids {
Some(restricted_fids) => {
let interned = self.word_interner.get(prefix).as_str();
let keys: Vec<_> = restricted_fids.iter().map(|fid| (interned, *fid)).collect();
let keys: Vec<_> =
restricted_fids.tolerant.iter().map(|fid| (interned, *fid)).collect();
DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>(
self.txn,
@ -249,13 +267,29 @@ impl<'ctx> SearchContext<'ctx> {
&mut self,
prefix: Interned<String>,
) -> Result<Option<RoaringBitmap>> {
DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
prefix,
self.word_interner.get(prefix).as_str(),
&mut self.db_cache.exact_word_prefix_docids,
self.index.exact_word_prefix_docids.remap_data_type::<Bytes>(),
)
match &self.restricted_fids {
Some(restricted_fids) => {
let interned = self.word_interner.get(prefix).as_str();
let keys: Vec<_> =
restricted_fids.exact.iter().map(|fid| (interned, *fid)).collect();
DatabaseCache::get_value_from_keys::<_, _, CboRoaringBitmapCodec>(
self.txn,
prefix,
&keys[..],
&mut self.db_cache.exact_word_prefix_docids,
self.index.word_prefix_fid_docids.remap_data_type::<Bytes>(),
merge_cbo_roaring_bitmaps,
)
}
None => DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
prefix,
self.word_interner.get(prefix).as_str(),
&mut self.db_cache.exact_word_prefix_docids,
self.index.exact_word_prefix_docids.remap_data_type::<Bytes>(),
),
}
}
pub fn get_db_word_pair_proximity_docids(
@ -265,9 +299,9 @@ impl<'ctx> SearchContext<'ctx> {
proximity: u8,
) -> Result<Option<RoaringBitmap>> {
match self.index.proximity_precision(self.txn)?.unwrap_or_default() {
ProximityPrecision::AttributeScale => {
ProximityPrecision::ByAttribute => {
// Force proximity to 0 because:
// in AttributeScale, there are only 2 possible distances:
// in ByAttribute, there are only 2 possible distances:
// 1. words in same attribute: in that the DB contains (0, word1, word2)
// 2. words in different attributes: no DB entry for these two words.
let proximity = 0;
@ -311,19 +345,17 @@ impl<'ctx> SearchContext<'ctx> {
Ok(docids)
}
ProximityPrecision::WordScale => {
DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
(proximity, word1, word2),
&(
proximity,
self.word_interner.get(word1).as_str(),
self.word_interner.get(word2).as_str(),
),
&mut self.db_cache.word_pair_proximity_docids,
self.index.word_pair_proximity_docids.remap_data_type::<Bytes>(),
)
}
ProximityPrecision::ByWord => DatabaseCache::get_value::<_, _, CboRoaringBitmapCodec>(
self.txn,
(proximity, word1, word2),
&(
proximity,
self.word_interner.get(word1).as_str(),
self.word_interner.get(word2).as_str(),
),
&mut self.db_cache.word_pair_proximity_docids,
self.index.word_pair_proximity_docids.remap_data_type::<Bytes>(),
),
}
}
@ -334,10 +366,10 @@ impl<'ctx> SearchContext<'ctx> {
proximity: u8,
) -> Result<Option<u64>> {
match self.index.proximity_precision(self.txn)?.unwrap_or_default() {
ProximityPrecision::AttributeScale => Ok(self
ProximityPrecision::ByAttribute => Ok(self
.get_db_word_pair_proximity_docids(word1, word2, proximity)?
.map(|d| d.len())),
ProximityPrecision::WordScale => {
ProximityPrecision::ByWord => {
DatabaseCache::get_value::<_, _, CboRoaringBitmapLenCodec>(
self.txn,
(proximity, word1, word2),
@ -360,9 +392,9 @@ impl<'ctx> SearchContext<'ctx> {
mut proximity: u8,
) -> Result<Option<RoaringBitmap>> {
let proximity_precision = self.index.proximity_precision(self.txn)?.unwrap_or_default();
if proximity_precision == ProximityPrecision::AttributeScale {
if proximity_precision == ProximityPrecision::ByAttribute {
// Force proximity to 0 because:
// in AttributeScale, there are only 2 possible distances:
// in ByAttribute, there are only 2 possible distances:
// 1. words in same attribute: in that the DB contains (0, word1, word2)
// 2. words in different attributes: no DB entry for these two words.
proximity = 0;
@ -374,7 +406,7 @@ impl<'ctx> SearchContext<'ctx> {
docids.clone()
} else {
let prefix_docids = match proximity_precision {
ProximityPrecision::AttributeScale => {
ProximityPrecision::ByAttribute => {
// Compute the distance at the attribute level and store it in the cache.
let fids = if let Some(fids) = self.index.searchable_fields_ids(self.txn)? {
fids
@ -395,7 +427,7 @@ impl<'ctx> SearchContext<'ctx> {
}
prefix_docids
}
ProximityPrecision::WordScale => {
ProximityPrecision::ByWord => {
// compute docids using prefix iter and store the result in the cache.
let key = U8StrStrCodec::bytes_encode(&(
proximity,

View File

@ -107,12 +107,16 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
/// Refill the internal buffer of cached docids based on the strategy.
/// Drop the rtree if we don't need it anymore.
fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> {
fn fill_buffer(
&mut self,
ctx: &mut SearchContext,
geo_candidates: &RoaringBitmap,
) -> Result<()> {
debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng");
debug_assert!(self.cached_sorted_docids.is_empty());
// lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree`
let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) {
let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) {
if let Some(rtree) = self.rtree.as_ref() {
// get rtree from cache
Some(rtree)
@ -131,7 +135,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
if self.ascending {
let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
@ -143,7 +147,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// and we insert the points in reverse order they get reversed when emptying the cache later on
let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
@ -155,8 +159,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// the iterative version
let [lat, lng] = self.field_ids.unwrap();
let mut documents = self
.geo_candidates
let mut documents = geo_candidates
.iter()
.map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) })
.collect::<Result<Vec<(u32, [f64; 2])>>>()?;
@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
assert!(self.query.is_none());
self.query = Some(query.clone());
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() {
let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(());
}
@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat");
let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng");
self.field_ids = Some([lat, lng]);
self.fill_buffer(ctx)?;
self.fill_buffer(ctx, &geo_candidates)?;
Ok(())
}
@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
self.geo_candidates &= universe;
if self.geo_candidates.is_empty() {
let geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
}
};
while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if self.geo_candidates.contains(id) {
if geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([id]),
@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx)?;
self.fill_buffer(ctx, &geo_candidates)?;
self.next_bucket(ctx, logger, universe)
}

View File

@ -72,7 +72,7 @@ impl<'m> MatcherBuilder<'m> {
}
}
#[derive(Copy, Clone, Default)]
#[derive(Copy, Clone, Default, Debug)]
pub struct FormatOptions {
pub highlight: bool,
pub crop: Option<usize>,
@ -82,6 +82,10 @@ impl FormatOptions {
pub fn merge(self, other: Self) -> Self {
Self { highlight: self.highlight || other.highlight, crop: self.crop.or(other.crop) }
}
pub fn should_format(&self) -> bool {
self.highlight || self.crop.is_some()
}
}
#[derive(Clone, Debug)]
@ -498,19 +502,19 @@ mod tests {
use super::*;
use crate::index::tests::TempIndex;
use crate::{execute_search, SearchContext};
use crate::{execute_search, filtered_universe, SearchContext};
impl<'a> MatcherBuilder<'a> {
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
let mut ctx = SearchContext::new(index, rtxn);
let universe = filtered_universe(&ctx, &None).unwrap();
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
&mut ctx,
&Some(query.to_string()),
&None,
Some(query),
crate::TermsMatchingStrategy::default(),
crate::score_details::ScoringStrategy::Skip,
false,
&None,
universe,
&None,
crate::search::new::GeoSortStrategy::default(),
0,

View File

@ -16,6 +16,7 @@ mod small_bitmap;
mod exact_attribute;
mod sort;
mod vector_sort;
#[cfg(test)]
mod tests;
@ -28,7 +29,6 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn;
use instant_distance::Search;
use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger};
@ -46,11 +46,14 @@ use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words;
use self::interner::Interned;
use crate::distance::NDotProductPoint;
use self::vector_sort::VectorSort;
use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule;
use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError};
use crate::vector::DistributionShift;
use crate::{
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
};
/// A structure used throughout the execution of a search query.
pub struct SearchContext<'ctx> {
@ -61,7 +64,7 @@ pub struct SearchContext<'ctx> {
pub phrase_interner: DedupInterner<Phrase>,
pub term_interner: Interner<QueryTerm>,
pub phrase_docids: PhraseDocIdsCache,
pub restricted_fids: Option<Vec<u16>>,
pub restricted_fids: Option<RestrictedFids>,
}
impl<'ctx> SearchContext<'ctx> {
@ -81,8 +84,9 @@ impl<'ctx> SearchContext<'ctx> {
pub fn searchable_attributes(&mut self, searchable_attributes: &'ctx [String]) -> Result<()> {
let fids_map = self.index.fields_ids_map(self.txn)?;
let searchable_names = self.index.searchable_fields(self.txn)?;
let exact_attributes_ids = self.index.exact_attributes_ids(self.txn)?;
let mut restricted_fids = Vec::new();
let mut restricted_fids = RestrictedFids::default();
let mut contains_wildcard = false;
for field_name in searchable_attributes {
if field_name == "*" {
@ -121,7 +125,11 @@ impl<'ctx> SearchContext<'ctx> {
}
};
restricted_fids.push(fid);
if exact_attributes_ids.contains(&fid) {
restricted_fids.exact.push(fid);
} else {
restricted_fids.tolerant.push(fid);
};
}
self.restricted_fids = (!contains_wildcard).then_some(restricted_fids);
@ -145,6 +153,18 @@ impl Word {
}
}
#[derive(Debug, Clone, Default)]
pub struct RestrictedFids {
pub tolerant: Vec<FieldId>,
pub exact: Vec<FieldId>,
}
impl RestrictedFids {
pub fn contains(&self, fid: &FieldId) -> bool {
self.tolerant.contains(fid) || self.exact.contains(fid)
}
}
/// Apply the [`TermsMatchingStrategy`] to the query graph and resolve it.
fn resolve_maximally_reduced_query_graph(
ctx: &mut SearchContext,
@ -171,6 +191,7 @@ fn resolve_maximally_reduced_query_graph(
Ok(docids)
}
#[logging_timer::time]
fn resolve_universe(
ctx: &mut SearchContext,
initial_universe: &RoaringBitmap,
@ -239,6 +260,80 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
Ok(ranking_rules)
}
fn get_ranking_rules_for_vector<'ctx>(
ctx: &SearchContext<'ctx>,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
limit_plus_offset: usize,
target: &[f32],
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
// query graph search
let mut sort = false;
let mut sorted_fields = HashSet::new();
let mut geo_sorted = false;
let mut vector = false;
let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![];
let settings_ranking_rules = ctx.index.criteria(ctx.txn)?;
for rr in settings_ranking_rules {
match rr {
crate::Criterion::Words
| crate::Criterion::Typo
| crate::Criterion::Proximity
| crate::Criterion::Attribute
| crate::Criterion::Exactness => {
if !vector {
let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
let vector_sort = VectorSort::new(
ctx,
target.to_vec(),
vector_candidates,
limit_plus_offset,
distribution_shift,
embedder_name,
)?;
ranking_rules.push(Box::new(vector_sort));
vector = true;
}
}
crate::Criterion::Sort => {
if sort {
continue;
}
resolve_sort_criteria(
sort_criteria,
ctx,
&mut ranking_rules,
&mut sorted_fields,
&mut geo_sorted,
geo_strategy,
)?;
sort = true;
}
crate::Criterion::Asc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?));
}
crate::Criterion::Desc(field_name) => {
if sorted_fields.contains(&field_name) {
continue;
}
sorted_fields.insert(field_name.clone());
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?));
}
}
}
Ok(ranking_rules)
}
/// Return the list of initialised ranking rules to be used for a query graph search.
fn get_ranking_rules_for_query_graph_search<'ctx>(
ctx: &SearchContext<'ctx>,
@ -403,15 +498,73 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
Ok(())
}
pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> {
Ok(if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
})
}
#[allow(clippy::too_many_arguments)]
pub fn execute_vector_search(
ctx: &mut SearchContext,
vector: &[f32],
scoring_strategy: ScoringStrategy,
universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
length: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?;
// FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector(
ctx,
sort_criteria,
geo_strategy,
from + length,
vector,
distribution_shift,
embedder_name,
)?;
let mut placeholder_search_logger = logger::DefaultSearchLogger;
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
&mut placeholder_search_logger;
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
ctx,
ranking_rules,
&PlaceholderQuery,
&universe,
from,
length,
scoring_strategy,
placeholder_search_logger,
)?;
Ok(PartialSearchResult {
candidates: all_candidates,
document_scores: scores,
documents_ids: docids,
located_query_terms: None,
})
}
#[allow(clippy::too_many_arguments)]
#[logging_timer::time]
pub fn execute_search(
ctx: &mut SearchContext,
query: &Option<String>,
vector: &Option<Vec<f32>>,
query: Option<&str>,
terms_matching_strategy: TermsMatchingStrategy,
scoring_strategy: ScoringStrategy,
exhaustive_number_hits: bool,
filters: &Option<Filter>,
mut universe: RoaringBitmap,
sort_criteria: &Option<Vec<AscDesc>>,
geo_strategy: geo_sort::Strategy,
from: usize,
@ -420,60 +573,8 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
) -> Result<PartialSearchResult> {
let mut universe = if let Some(filters) = filters {
filters.evaluate(ctx.txn, ctx.index)?
} else {
ctx.index.documents_ids(ctx.txn)?
};
check_sort_criteria(ctx, sort_criteria.as_ref())?;
if let Some(vector) = vector {
let mut search = Search::default();
let docids = match ctx.index.vector_hnsw(ctx.txn)? {
Some(hnsw) => {
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
if vector.len() != expected_size {
return Err(UserError::InvalidVectorDimensions {
expected: expected_size,
found: vector.len(),
}
.into());
}
}
let vector = NDotProductPoint::new(vector.clone());
let neighbors = hnsw.search(&vector, &mut search);
let mut docids = Vec::new();
let mut uniq_docids = RoaringBitmap::new();
for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
let index = pid.into_inner();
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
if universe.contains(docid) && uniq_docids.insert(docid) {
docids.push(docid);
if docids.len() == (from + length) {
break;
}
}
}
// return the nearest documents that are also part of the candidates
// along with a dummy list of scores that are useless in this context.
docids.into_iter().skip(from).take(length).collect()
}
None => Vec::new(),
};
return Ok(PartialSearchResult {
candidates: universe,
document_scores: vec![Vec::new(); docids.len()],
documents_ids: docids,
located_query_terms: None,
});
}
let mut located_query_terms = None;
let query_terms = if let Some(query) = query {
// We make sure that the analyzer is aware of the stop words
@ -527,7 +628,7 @@ pub fn execute_search(
terms_matching_strategy,
)?;
universe =
universe &=
resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?;
bucket_sort(

View File

@ -5,6 +5,7 @@ use super::*;
use crate::{Result, SearchContext, MAX_WORD_LENGTH};
/// Convert the tokenised search query into a list of located query terms.
#[logging_timer::time]
pub fn located_query_terms_from_tokens(
ctx: &mut SearchContext,
query: NormalizedTokenIter,

View File

@ -0,0 +1,170 @@
use std::iter::FromIterator;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::score_details::{self, ScoreDetails};
use crate::vector::DistributionShift;
use crate::{DocumentId, Result, SearchContext, SearchLogger};
pub struct VectorSort<Q: RankingRuleQueryTrait> {
query: Option<Q>,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_index: u8,
}
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
pub fn new(
ctx: &SearchContext,
target: Vec<f32>,
vector_candidates: RoaringBitmap,
limit: usize,
distribution_shift: Option<DistributionShift>,
embedder_name: &str,
) -> Result<Self> {
let embedder_index = ctx
.index
.embedder_category_id
.get(ctx.txn, embedder_name)?
.ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?;
Ok(Self {
query: None,
target,
vector_candidates,
cached_sorted_docids: Default::default(),
limit,
distribution_shift,
embedder_index,
})
}
fn fill_buffer(
&mut self,
ctx: &mut SearchContext<'_>,
vector_candidates: &RoaringBitmap,
) -> Result<()> {
let writer_index = (self.embedder_index as u16) << 8;
let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map_while(|k| {
arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
.map(Some)
.or_else(|e| match e {
arroy::Error::MissingMetadata => Ok(None),
e => Err(e),
})
.transpose()
})
.collect();
let readers = readers?;
let target = &self.target;
let mut results = Vec::new();
for reader in readers.iter() {
let nns_by_vector =
reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?;
let vectors: std::result::Result<Vec<_>, _> = nns_by_vector
.iter()
.map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap())
.collect();
let vectors = vectors?;
results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z)));
}
results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance));
self.cached_sorted_docids = results.into_iter();
Ok(())
}
}
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
fn id(&self) -> String {
"vector_sort".to_owned()
}
fn start_iteration(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
query: &Q,
) -> Result<()> {
assert!(self.query.is_none());
self.query = Some(query.clone());
let vector_candidates = &self.vector_candidates & universe;
self.fill_buffer(ctx, &vector_candidates)?;
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn next_bucket(
&mut self,
ctx: &mut SearchContext<'ctx>,
_logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
let vector_candidates = &self.vector_candidates & universe;
if vector_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
for (docid, distance, vector) in self.cached_sorted_docids.by_ref() {
if vector_candidates.contains(docid) {
let score = 1.0 - distance;
let score = self
.distribution_shift
.map(|distribution| distribution.shift(score))
.unwrap_or(score);
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([docid]),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: Some((vector, score)),
}),
}));
}
}
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx, &vector_candidates)?;
// we tried filling the buffer, but it remained empty 😢
// it means we don't actually have any document remaining in the universe with a vector.
// => exit
if self.cached_sorted_docids.len() == 0 {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::Vector(score_details::Vector {
target_vector: self.target.clone(),
value_similarity: None,
}),
}));
}
self.next_bucket(ctx, _logger, universe)
}
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
self.query = None;
}
}

View File

@ -42,7 +42,8 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_is_empty_docids,
field_id_docid_facet_f64s,
field_id_docid_facet_strings,
vector_id_docid,
vector_arroy,
embedder_category_id: _,
documents,
} = self.index;
@ -58,7 +59,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
self.index.delete_geo_rtree(self.wtxn)?;
self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
self.index.delete_vector_hnsw(self.wtxn)?;
// Clear the other databases.
external_documents_ids.clear(self.wtxn)?;
@ -82,7 +82,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_string_docids.clear(self.wtxn)?;
field_id_docid_facet_f64s.clear(self.wtxn)?;
field_id_docid_facet_strings.clear(self.wtxn)?;
vector_id_docid.clear(self.wtxn)?;
// vector
vector_arroy.clear(self.wtxn)?;
documents.clear(self.wtxn)?;
Ok(number_of_documents)

View File

@ -61,6 +61,7 @@ impl FacetsUpdateIncremental {
}
}
#[logging_timer::time("FacetsUpdateIncremental::{}")]
pub fn execute(self, wtxn: &mut RwTxn) -> crate::Result<()> {
let mut cursor = self.delta_data.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {

View File

@ -170,91 +170,91 @@ impl<'i> FacetsUpdate<'i> {
incremental_update.execute(wtxn)?;
}
// We clear the list of normalized-for-search facets
// and the previous FSTs to compute everything from scratch
self.index.facet_id_normalized_string_strings.clear(wtxn)?;
self.index.facet_id_string_fst.clear(wtxn)?;
// // We clear the list of normalized-for-search facets
// // and the previous FSTs to compute everything from scratch
// self.index.facet_id_normalized_string_strings.clear(wtxn)?;
// self.index.facet_id_string_fst.clear(wtxn)?;
// As we can't use the same write transaction to read and write in two different databases
// we must create a temporary sorter that we will write into LMDB afterward.
// As multiple unnormalized facet values can become the same normalized facet value
// we must merge them together.
let mut sorter = create_sorter(
SortAlgorithm::Unstable,
merge_btreeset_string,
CompressionType::None,
None,
None,
None,
);
// // As we can't use the same write transaction to read and write in two different databases
// // we must create a temporary sorter that we will write into LMDB afterward.
// // As multiple unnormalized facet values can become the same normalized facet value
// // we must merge them together.
// let mut sorter = create_sorter(
// SortAlgorithm::Unstable,
// merge_btreeset_string,
// CompressionType::None,
// None,
// None,
// None,
// );
// We iterate on the list of original, semi-normalized, facet values
// and normalize them for search, inserting them in LMDB in any given order.
let options = NormalizerOption { lossy: true, ..Default::default() };
let database = self.index.facet_id_string_docids.remap_data_type::<DecodeIgnore>();
for result in database.iter(wtxn)? {
let (facet_group_key, ()) = result?;
if let FacetGroupKey { field_id, level: 0, left_bound } = facet_group_key {
let mut normalized_facet = left_bound.normalize(&options);
let normalized_truncated_facet: String;
if normalized_facet.len() > MAX_FACET_VALUE_LENGTH {
normalized_truncated_facet = normalized_facet
.char_indices()
.take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
.map(|(_, c)| c)
.collect();
normalized_facet = normalized_truncated_facet.into();
}
let set = BTreeSet::from_iter(std::iter::once(left_bound));
let key = (field_id, normalized_facet.as_ref());
let key = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
sorter.insert(key, val)?;
}
}
// // We iterate on the list of original, semi-normalized, facet values
// // and normalize them for search, inserting them in LMDB in any given order.
// let options = NormalizerOption { lossy: true, ..Default::default() };
// let database = self.index.facet_id_string_docids.remap_data_type::<DecodeIgnore>();
// for result in database.iter(wtxn)? {
// let (facet_group_key, ()) = result?;
// if let FacetGroupKey { field_id, level: 0, left_bound } = facet_group_key {
// let mut normalized_facet = left_bound.normalize(&options);
// let normalized_truncated_facet: String;
// if normalized_facet.len() > MAX_FACET_VALUE_LENGTH {
// normalized_truncated_facet = normalized_facet
// .char_indices()
// .take_while(|(idx, _)| *idx < MAX_FACET_VALUE_LENGTH)
// .map(|(_, c)| c)
// .collect();
// normalized_facet = normalized_truncated_facet.into();
// }
// let set = BTreeSet::from_iter(std::iter::once(left_bound));
// let key = (field_id, normalized_facet.as_ref());
// let key = BEU16StrCodec::bytes_encode(&key).map_err(heed::Error::Encoding)?;
// let val = SerdeJson::bytes_encode(&set).map_err(heed::Error::Encoding)?;
// sorter.insert(key, val)?;
// }
// }
// In this loop we don't need to take care of merging bitmaps
// as the grenad sorter already merged them for us.
let mut merger_iter = sorter.into_stream_merger_iter()?;
while let Some((key_bytes, btreeset_bytes)) = merger_iter.next()? {
self.index.facet_id_normalized_string_strings.remap_types::<Bytes, Bytes>().put(
wtxn,
key_bytes,
btreeset_bytes,
)?;
}
// // In this loop we don't need to take care of merging bitmaps
// // as the grenad sorter already merged them for us.
// let mut merger_iter = sorter.into_stream_merger_iter()?;
// while let Some((key_bytes, btreeset_bytes)) = merger_iter.next()? {
// self.index.facet_id_normalized_string_strings.remap_types::<Bytes, Bytes>().put(
// wtxn,
// key_bytes,
// btreeset_bytes,
// )?;
// }
// We compute one FST by string facet
let mut text_fsts = vec![];
let mut current_fst: Option<(u16, fst::SetBuilder<Vec<u8>>)> = None;
let database =
self.index.facet_id_normalized_string_strings.remap_data_type::<DecodeIgnore>();
for result in database.iter(wtxn)? {
let ((field_id, normalized_facet), _) = result?;
current_fst = match current_fst.take() {
Some((fid, fst_builder)) if fid != field_id => {
let fst = fst_builder.into_set();
text_fsts.push((fid, fst));
Some((field_id, fst::SetBuilder::memory()))
}
Some((field_id, fst_builder)) => Some((field_id, fst_builder)),
None => Some((field_id, fst::SetBuilder::memory())),
};
// // We compute one FST by string facet
// let mut text_fsts = vec![];
// let mut current_fst: Option<(u16, fst::SetBuilder<Vec<u8>>)> = None;
// let database =
// self.index.facet_id_normalized_string_strings.remap_data_type::<DecodeIgnore>();
// for result in database.iter(wtxn)? {
// let ((field_id, normalized_facet), _) = result?;
// current_fst = match current_fst.take() {
// Some((fid, fst_builder)) if fid != field_id => {
// let fst = fst_builder.into_set();
// text_fsts.push((fid, fst));
// Some((field_id, fst::SetBuilder::memory()))
// }
// Some((field_id, fst_builder)) => Some((field_id, fst_builder)),
// None => Some((field_id, fst::SetBuilder::memory())),
// };
if let Some((_, fst_builder)) = current_fst.as_mut() {
fst_builder.insert(normalized_facet)?;
}
}
// if let Some((_, fst_builder)) = current_fst.as_mut() {
// fst_builder.insert(normalized_facet)?;
// }
// }
if let Some((field_id, fst_builder)) = current_fst {
let fst = fst_builder.into_set();
text_fsts.push((field_id, fst));
}
// if let Some((field_id, fst_builder)) = current_fst {
// let fst = fst_builder.into_set();
// text_fsts.push((field_id, fst));
// }
// We write those FSTs in LMDB now
for (field_id, fst) in text_fsts {
self.index.facet_id_string_fst.put(wtxn, &field_id, &fst)?;
}
// // We write those FSTs in LMDB now
// for (field_id, fst) in text_fsts {
// self.index.facet_id_string_fst.put(wtxn, &field_id, &fst)?;
// }
Ok(())
}

View File

@ -1,9 +1,10 @@
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::fs::File;
use std::io::{self, BufReader, BufWriter};
use std::mem::size_of;
use std::str::from_utf8;
use std::sync::Arc;
use bytemuck::cast_slice;
use grenad::Writer;
@ -13,13 +14,56 @@ use serde_json::{from_slice, Value};
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::error::UserError;
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
use crate::vector::Embedder;
use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub struct ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
pub manual_vectors: grenad::Reader<BufReader<File>>,
// docid -> ()
pub remove_vectors: grenad::Reader<BufReader<File>>,
// docid -> prompt
pub prompts: grenad::Reader<BufReader<File>>,
}
enum VectorStateDelta {
NoChange,
// Remove all vectors, generated or manual, from this document
NowRemoved,
// Add the manually specified vectors, passed in the other grenad
// Remove any previously generated vectors
// Note: changing the value of the manually specified vector **should not record** this delta
WasGeneratedNowManual(Vec<Vec<f32>>),
ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),
// Add the vector computed from the specified prompt
// Remove any previous vector
// Note: changing the value of the prompt **does require** recording this delta
NowGenerated(String),
}
impl VectorStateDelta {
fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
match self {
VectorStateDelta::NoChange => Default::default(),
VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()),
VectorStateDelta::WasGeneratedNowManual(add) => {
(true, Default::default(), (Default::default(), add))
}
VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)),
VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()),
}
}
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
@ -27,16 +71,35 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
vectors_fid: FieldId,
) -> Result<grenad::Reader<BufReader<File>>> {
field_id_map: &FieldsIdsMap,
prompt: &Prompt,
embedder_name: &str,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!();
let mut writer = create_writer(
// (docid, _index) -> KvWriterDelAdd -> Vector
let mut manual_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> (prompt)
let mut prompts_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> ()
let mut remove_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let vectors_fid = field_id_map.id("_vectors");
let mut key_buffer = Vec::new();
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
@ -53,43 +116,157 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
// lazily get it when needed
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
// first we retrieve the _vectors field
if let Some(value) = obkv.get(vectors_fid) {
let vectors_obkv = KvReaderDelAdd::new(value);
let vectors_field = vectors_fid
.and_then(|vectors_fid| obkv.get(vectors_fid))
.map(KvReaderDelAdd::new)
.map(|obkv| to_vector_maps(obkv, document_id))
.transpose()?;
// then we extract the values
let del_vectors = vectors_obkv
.get(DelAdd::Deletion)
.map(|vectors| extract_vectors(vectors, document_id))
.transpose()?
.flatten();
let add_vectors = vectors_obkv
.get(DelAdd::Addition)
.map(|vectors| extract_vectors(vectors, document_id))
.transpose()?
.flatten();
let (del_map, add_map) = vectors_field.unzip();
let del_map = del_map.flatten();
let add_map = add_map.flatten();
// and we finally push the unique vectors into the writer
push_vectors_diff(
&mut writer,
&mut key_buffer,
del_vectors.unwrap_or_default(),
add_vectors.unwrap_or_default(),
)?;
}
let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
let delta = match (del_value, add_value) {
(Some(old), Some(new)) => {
// no autogeneration
let del_vectors = extract_vectors(old, document_id, embedder_name)?;
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::ManualDelta(del_vectors, add_vectors)
}
(Some(_old), None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
}
(None, Some(new)) => {
// was possibly autogenerated, remove all vectors for that document
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::WasGeneratedNowManual(add_vectors)
}
(None, None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
}
}
};
// and we finally push the unique vectors into the writer
push_vectors_diff(
&mut remove_vectors_writer,
&mut prompts_writer,
&mut manual_vectors_writer,
&mut key_buffer,
delta,
)?;
}
writer_into_reader(writer)
Ok(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
})
}
fn to_vector_maps(
obkv: KvReaderDelAdd,
document_id: impl Fn() -> Value,
) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
Ok((del, add))
}
fn to_vector_map(
obkv: KvReaderDelAdd,
side: DelAdd,
document_id: &impl Fn() -> Value,
) -> Result<Option<serde_json::Map<String, Value>>> {
Ok(if let Some(value) = obkv.get(side) {
let Ok(value) = from_slice(value) else {
let value = from_slice(value).map_err(InternalError::SerdeJson)?;
return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
document_id: document_id(),
value,
}));
};
Some(value)
} else {
None
})
}
/// Computes the diff between both Del and Add numbers and
/// only inserts the parts that differ in the sorter.
fn push_vectors_diff(
writer: &mut Writer<BufWriter<File>>,
remove_vectors_writer: &mut Writer<BufWriter<File>>,
prompts_writer: &mut Writer<BufWriter<File>>,
manual_vectors_writer: &mut Writer<BufWriter<File>>,
key_buffer: &mut Vec<u8>,
mut del_vectors: Vec<Vec<f32>>,
mut add_vectors: Vec<Vec<f32>>,
delta: VectorStateDelta,
) -> Result<()> {
let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
if must_remove {
key_buffer.truncate(TRUNCATE_SIZE);
remove_vectors_writer.insert(&key_buffer, [])?;
}
if !prompt.is_empty() {
key_buffer.truncate(TRUNCATE_SIZE);
prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
}
// We sort and dedup the vectors
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
@ -114,7 +291,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
EitherOrBoth::Right(vector) => {
// We insert only the Add part of the Obkv to inform
@ -122,7 +299,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
}
}
@ -136,13 +313,112 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
}
/// Extracts the vectors from a JSON value.
fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
match from_slice(value) {
Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)),
fn extract_vectors(
value: Value,
document_id: impl Fn() -> Value,
name: &str,
) -> Result<Vec<Vec<f32>>> {
// FIXME: ugly clone of the vectors here
match serde_json::from_value(value.clone()) {
Ok(vectors) => {
Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
}
Err(_) => Err(UserError::InvalidVectorsType {
document_id: document_id(),
value: from_slice(value).map_err(InternalError::SerdeJson)?,
value,
subfield: name.to_owned(),
}
.into()),
}
}
#[logging_timer::time]
pub fn extract_embeddings<R: io::Read + io::Seek>(
// docid, prompt
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
// docid, state with embedding
let mut state_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let mut chunks = Vec::with_capacity(n_chunks);
let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk);
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
let prompt = unsafe { std::str::from_utf8_unchecked(value) };
if current_chunk.len() == current_chunk.capacity() {
chunks.push(std::mem::replace(
&mut current_chunk,
Vec::with_capacity(n_vectors_per_chunk),
));
chunks_ids.push(std::mem::replace(
&mut current_chunk_ids,
Vec::with_capacity(n_vectors_per_chunk),
));
};
current_chunk.push(prompt.to_owned());
current_chunk_ids.push(docid);
if chunks.len() == chunks.capacity() {
let chunked_embeds = rt
.block_on(
embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
chunks_ids.clear();
}
}
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
if !current_chunk.is_empty() {
let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
writer_into_reader(state_writer)
}

View File

@ -23,7 +23,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
use self::extract_geo_points::extract_geo_points;
use self::extract_vector_points::extract_vector_points;
use self::extract_vector_points::{
extract_embeddings, extract_vector_points, ExtractedVectorPoints,
};
use self::extract_word_docids::extract_word_docids;
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
use self::extract_word_position_docids::extract_word_position_docids;
@ -33,7 +35,8 @@ use super::helpers::{
};
use super::{helpers, TypedChunk};
use crate::proximity::ProximityPrecision;
use crate::{FieldId, Result};
use crate::vector::EmbeddingConfigs;
use crate::{FieldId, FieldsIdsMap, Result};
/// Extract data for each databases from obkv documents in parallel.
/// Send data in grenad file over provided Sender.
@ -47,13 +50,14 @@ pub(crate) fn data_from_obkv_documents(
faceted_fields: HashSet<FieldId>,
primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>,
vectors_field_id: Option<FieldId>,
field_id_map: FieldsIdsMap,
stop_words: Option<fst::Set<&[u8]>>,
allowed_separators: Option<&[&str]>,
dictionary: Option<&[&str]>,
max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>,
proximity_precision: ProximityPrecision,
embedders: EmbeddingConfigs,
) -> Result<()> {
puffin::profile_function!();
@ -64,7 +68,8 @@ pub(crate) fn data_from_obkv_documents(
original_documents_chunk,
indexer,
lmdb_writer_sx.clone(),
vectors_field_id,
field_id_map.clone(),
embedders.clone(),
)
})
.collect::<Result<()>>()?;
@ -152,7 +157,7 @@ pub(crate) fn data_from_obkv_documents(
});
}
if proximity_precision == ProximityPrecision::WordScale {
if proximity_precision == ProximityPrecision::ByWord {
spawn_extraction_task::<_, _, Vec<grenad::Reader<BufReader<File>>>>(
docid_word_positions_chunks.clone(),
indexer,
@ -276,24 +281,53 @@ fn send_original_documents_data(
original_documents_chunk: Result<grenad::Reader<BufReader<File>>>,
indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>,
vectors_field_id: Option<FieldId>,
field_id_map: FieldsIdsMap,
embedders: EmbeddingConfigs,
) -> Result<()> {
let original_documents_chunk =
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
if let Some(vectors_field_id) = vectors_field_id {
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
let _ = match result {
Ok(vector_points) => {
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
documents_chunk_cloned.clone(),
indexer,
&field_id_map,
&prompt,
&name,
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
None
}
};
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name: name,
}));
}
}
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
};
});
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
}
}
});
// TODO: create a custom internal error
lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap();

View File

@ -16,7 +16,7 @@ pub use merge_functions::{
keep_first, keep_latest_obkv, merge_btreeset_string, merge_cbo_roaring_bitmaps,
merge_deladd_cbo_roaring_bitmaps, merge_deladd_cbo_roaring_bitmaps_into_cbo_roaring_bitmap,
merge_roaring_bitmaps, obkvs_keep_last_addition_merge_deletions,
obkvs_merge_additions_and_deletions, serialize_roaring_bitmap, MergeFn,
obkvs_merge_additions_and_deletions, MergeFn,
};
use crate::MAX_WORD_LENGTH;

View File

@ -4,7 +4,7 @@ mod helpers;
mod transform;
mod typed_chunk;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::io::{Cursor, Read, Seek};
use std::iter::FromIterator;
use std::num::NonZeroU32;
@ -14,13 +14,14 @@ use crossbeam_channel::{Receiver, Sender};
use heed::types::Str;
use heed::Database;
use log::debug;
use rand::SeedableRng;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use slice_group_by::GroupBy;
use typed_chunk::{write_typed_chunk_into_index, TypedChunk};
use self::enrich::enrich_documents_batch;
pub use self::enrich::{extract_finite_float_from_value, validate_geo_from_json, DocumentId};
pub use self::enrich::{extract_finite_float_from_value, DocumentId};
pub use self::helpers::{
as_cloneable_grenad, create_sorter, create_writer, fst_stream_into_hashset,
fst_stream_into_vec, merge_btreeset_string, merge_cbo_roaring_bitmaps,
@ -36,6 +37,7 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
};
use crate::vector::EmbeddingConfigs;
use crate::{CboRoaringBitmapCodec, Index, Result};
static MERGED_DATABASE_COUNT: usize = 7;
@ -78,6 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA,
added_documents: u64,
deleted_documents: u64,
embedders: EmbeddingConfigs,
}
#[derive(Default, Debug, Clone)]
@ -121,6 +124,7 @@ where
index,
added_documents: 0,
deleted_documents: 0,
embedders: Default::default(),
})
}
@ -167,6 +171,11 @@ where
Ok((self, Ok(indexed_documents)))
}
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
self.embedders = embedders;
self
}
/// Remove a batch of documents from the current builder.
///
/// Returns the number of documents deleted from the builder.
@ -322,17 +331,18 @@ where
// get filterable fields for facet databases
let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?;
// get the fid of the `_geo.lat` and `_geo.lng` fields.
let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") {
let mut field_id_map = self.index.fields_ids_map(self.wtxn)?;
// self.index.fields_ids_map($a)? ==>> field_id_map
let geo_fields_ids = match field_id_map.id("_geo") {
Some(gfid) => {
let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid);
let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid);
// if `_geo` is faceted then we get the `lat` and `lng`
if is_sortable || is_filterable {
let field_ids = self
.index
.fields_ids_map(self.wtxn)?
let field_ids = field_id_map
.insert("_geo.lat")
.zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng"))
.zip(field_id_map.insert("_geo.lng"))
.ok_or(UserError::AttributeLimitReached)?;
Some(field_ids)
} else {
@ -341,8 +351,6 @@ where
}
None => None,
};
// get the fid of the `_vectors` field.
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
let stop_words = self.index.stop_words(self.wtxn)?;
let separators = self.index.allowed_separators(self.wtxn)?;
@ -364,6 +372,8 @@ where
self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;
let cloned_embedder = self.embedders.clone();
// Run extraction pipeline in parallel.
pool.install(|| {
puffin::profile_scope!("extract_and_send_grenad_chunks");
@ -387,13 +397,14 @@ where
faceted_fields,
primary_key_id,
geo_fields_ids,
vectors_field_id,
field_id_map,
stop_words,
separators.as_deref(),
dictionary.as_deref(),
max_positions_per_attributes,
exact_attributes,
proximity_precision,
cloned_embedder,
)
});
@ -402,7 +413,7 @@ where
}
// needs to be dropped to avoid channel waiting lock.
drop(lmdb_writer_sx)
drop(lmdb_writer_sx);
});
let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0;
@ -419,6 +430,8 @@ where
let mut word_docids = None;
let mut exact_word_docids = None;
let mut dimension = HashMap::new();
for result in lmdb_writer_rx {
if (self.should_abort)() {
return Err(Error::InternalError(InternalError::AbortedIndexation));
@ -448,6 +461,22 @@ where
word_position_docids = Some(cloneable_chunk);
TypedChunk::WordPositionDocids(chunk)
}
TypedChunk::VectorPoints {
expected_dimension,
remove_vectors,
embeddings,
manual_vectors,
embedder_name,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension,
manual_vectors,
embedder_name,
}
}
otherwise => otherwise,
};
@ -480,6 +509,33 @@ where
// We write the primary key field id into the main database
self.index.put_primary_key(self.wtxn, &primary_key)?;
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
for (embedder_name, dimension) in dimension {
let wtxn = &mut *self.wtxn;
let vector_arroy = self.index.vector_arroy;
let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
pool.install(|| {
let writer_index = (embedder_index as u16) << 8;
for k in 0..=u8::MAX {
let writer = arroy::Writer::prepare(
wtxn,
vector_arroy,
writer_index | (k as u16),
dimension,
)?;
if writer.is_empty(wtxn)? {
break;
}
writer.build(wtxn, &mut rng, None)?;
}
Result::Ok(())
})?;
}
self.execute_prefix_databases(
word_docids,
@ -694,6 +750,8 @@ fn execute_word_prefix_docids(
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use big_s::S;
use fst::IntoStreamer;
use heed::RwTxn;
@ -703,6 +761,7 @@ mod tests {
use crate::documents::documents_batch_reader_from_objects;
use crate::index::tests::TempIndex;
use crate::search::TermsMatchingStrategy;
use crate::update::Setting;
use crate::{db_snap, Filter, Search};
#[test]
@ -2494,18 +2553,41 @@ mod tests {
/// Vectors must be of the same length.
#[test]
fn test_multiple_vectors() {
use crate::vector::settings::EmbeddingSettings;
let index = TempIndex::new();
index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap();
index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap();
index
.add_documents(
documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]),
)
.update_settings(|settings| {
let mut embedders = BTreeMap::default();
embedders.insert(
"manual".to_string(),
Setting::Set(EmbeddingSettings {
source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(3),
document_template: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);
})
.unwrap();
index
.add_documents(
documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]),
)
.unwrap();
index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap();
index
.add_documents(
documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]),
)
.unwrap();
let rtxn = index.read_txn().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();
assert_eq!(res.documents_ids.len(), 3);
}

View File

@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::convert::TryInto;
use std::fs::File;
use std::io::{self, BufReader};
@ -8,9 +8,7 @@ use charabia::{Language, Script};
use grenad::MergerBuilder;
use heed::types::Bytes;
use heed::{PutFlags, RwTxn};
use log::error;
use obkv::{KvReader, KvWriter};
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::helpers::{
@ -18,16 +16,15 @@ use super::helpers::{
valid_lmdb_key, CursorClonableMmap,
};
use super::{ClonableMmap, MergeFn};
use crate::distance::NDotProductPoint;
use crate::error::UserError;
use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind};
use crate::facet::FacetType;
use crate::index::db_name::DOCUMENTS;
use crate::index::Hnsw;
use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError};
use crate::{
lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError,
};
pub(crate) enum TypedChunk {
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@ -47,7 +44,13 @@ pub(crate) enum TypedChunk {
FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>),
FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>),
GeoPoints(grenad::Reader<BufReader<File>>),
VectorPoints(grenad::Reader<BufReader<File>>),
VectorPoints {
remove_vectors: grenad::Reader<BufReader<File>>,
embeddings: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
}
@ -100,8 +103,8 @@ impl TypedChunk {
TypedChunk::GeoPoints(grenad) => {
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
}
TypedChunk::VectorPoints(grenad) => {
format!("VectorPoints {{ number_of_entries: {} }}", grenad.len())
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
}
TypedChunk::ScriptLanguageDocids(sl_map) => {
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@ -355,19 +358,77 @@ pub(crate) fn write_typed_chunk_into_index(
index.put_geo_rtree(wtxn, &rtree)?;
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
}
TypedChunk::VectorPoints(vector_points) => {
let mut vectors_set = HashSet::new();
// We extract and store the previous vectors
if let Some(hnsw) = index.vector_hnsw(wtxn)? {
for (pid, point) in hnsw.iter() {
let pid_key = pid.into_inner();
let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap();
let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect();
vectors_set.insert((docid, vector));
TypedChunk::VectorPoints {
remove_vectors,
manual_vectors,
embeddings,
expected_dimension,
embedder_name,
} => {
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
let writer_index = (embedder_index as u16) << 8;
// FIXME: allow customizing distance
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map(|k| {
arroy::Writer::prepare(
wtxn,
index.vector_arroy,
writer_index | (k as u16),
expected_dimension,
)
})
.collect();
let writers = writers?;
// remove vectors for docids we want them removed
let mut cursor = remove_vectors.into_cursor()?;
while let Some((key, _)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
for writer in &writers {
// Uses invariant: vectors are packed in the first writers.
if !writer.del_item(wtxn, docid)? {
break;
}
}
}
let mut cursor = vector_points.into_cursor()?;
// add generated embeddings
if let Some(embeddings) = embeddings {
let mut cursor = embeddings.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
let data = pod_collect_to_vec(value);
// it is a code error to have embeddings and not expected_dimension
let embeddings =
crate::vector::Embeddings::from_inner(data, expected_dimension)
// code error if we somehow got the wrong dimension
.unwrap();
if embeddings.embedding_count() > u8::MAX.into() {
let external_docid = if let Ok(Some(Ok(index))) = index
.external_id_of(wtxn, std::iter::once(docid))
.map(|it| it.into_iter().next())
{
index
} else {
format!("internal docid={docid}")
};
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
external_docid,
embeddings.embedding_count(),
)));
}
for (embedding, writer) in embeddings.iter().zip(&writers) {
writer.add_item(wtxn, docid, embedding)?;
}
}
}
// perform the manual diff
let mut cursor = manual_vectors.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// convert the key back to a u32 (4 bytes)
let (left, _index) = try_split_array_at(key).unwrap();
@ -375,58 +436,52 @@ pub(crate) fn write_typed_chunk_into_index(
let vector_deladd_obkv = KvReaderDelAdd::new(value);
if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
// convert the vector back to a Vec<f32>
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
let key = (docid, vector);
if !vectors_set.remove(&key) {
error!("Unable to delete the vector: {:?}", key.1);
let vector: Vec<f32> = pod_collect_to_vec(value);
let mut deleted_index = None;
for (index, writer) in writers.iter().enumerate() {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
// uses invariant: vectors are packed in the first writers.
break;
};
if candidate == vector {
writer.del_item(wtxn, docid)?;
deleted_index = Some(index);
}
}
// 🥲 enforce invariant: vectors are packed in the first writers.
if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
break;
};
last_index_with_a_vector = Some((index, candidate));
}
if let Some((last_index, vector)) = last_index_with_a_vector {
// unwrap: computed the index from the list of writers
let writer = writers.get(last_index).unwrap();
writer.del_item(wtxn, docid)?;
writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?;
}
}
}
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
// convert the vector back to a Vec<f32>
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
vectors_set.insert((docid, vector));
}
}
let vector = pod_collect_to_vec(value);
// Extract the most common vector dimension
let expected_dimension_size = {
let mut dims = HashMap::new();
vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1);
dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len)
};
// Ensure that the vector lengths are correct and
// prepare the vectors before inserting them in the HNSW.
let mut points = Vec::new();
let mut docids = Vec::new();
for (docid, vector) in vectors_set {
if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
return Err(UserError::InvalidVectorDimensions {
expected: expected_dimension_size.unwrap_or(vector.len()),
found: vector.len(),
// overflow was detected during vector extraction.
for writer in &writers {
if !writer.contains_item(wtxn, docid)? {
writer.add_item(wtxn, docid, &vector)?;
break;
}
}
.into());
} else {
let vector = vector.into_iter().map(OrderedFloat::into_inner).collect();
points.push(NDotProductPoint::new(vector));
docids.push(docid);
}
}
let hnsw_length = points.len();
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
assert_eq!(docids.len(), pids.len());
// Store the vectors in the point-docid relation database
index.vector_id_docid.clear(wtxn)?;
for (docid, pid) in docids.into_iter().zip(pids) {
index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?;
}
log::debug!("There are {} entries in the HNSW so far", hnsw_length);
index.put_vector_hnsw(wtxn, &new_hnsw)?;
log::debug!("Finished vector chunk for {}", embedder_name);
}
TypedChunk::ScriptLanguageDocids(sl_map) => {
for (key, (deletion, addition)) in sl_map {

View File

@ -8,7 +8,7 @@ pub use self::index_documents::{
MergeFn,
};
pub use self::indexer_config::IndexerConfig;
pub use self::settings::{Setting, Settings};
pub use self::settings::{validate_embedding_settings, Setting, Settings};
pub use self::update_step::UpdateIndexingStep;
pub use self::word_prefix_docids::WordPrefixDocids;
pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids;

View File

@ -1,9 +1,11 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::convert::TryInto;
use std::result::Result as StdResult;
use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr};
use itertools::Itertools;
use itertools::{EitherOrBoth, Itertools};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
@ -15,6 +17,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -73,6 +77,21 @@ impl<T> Setting<T> {
otherwise => otherwise,
}
}
/// Returns `true` if applying the new setting changed this setting
pub fn apply(&mut self, new: Self) -> bool
where
T: PartialEq + Eq,
{
if let Setting::NotSet = new {
return false;
}
if self == &new {
return false;
}
*self = new;
true
}
}
impl<T: Serialize> Serialize for Setting<T> {
@ -129,6 +148,7 @@ pub struct Settings<'a, 't, 'i> {
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
pagination_max_total_hits: Setting<usize>,
proximity_precision: Setting<ProximityPrecision>,
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
}
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@ -161,6 +181,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
sort_facet_values_by: Setting::NotSet,
pagination_max_total_hits: Setting::NotSet,
proximity_precision: Setting::NotSet,
embedder_settings: Setting::NotSet,
indexer_config,
}
}
@ -343,6 +364,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.proximity_precision = Setting::Reset;
}
pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) {
self.embedder_settings = Setting::Set(value);
}
pub fn reset_embedder_settings(&mut self) {
self.embedder_settings = Setting::Reset;
}
fn reindex<FP, FA>(
&mut self,
progress_callback: &FP,
@ -377,6 +406,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
fields_ids_map,
)?;
let embedder_configs = self.index.embedding_configs(self.wtxn)?;
let embedders = self.embedders(embedder_configs)?;
// We index the generated `TransformOutput` which must contain
// all the documents with fields in the newly defined searchable order.
let indexing_builder = IndexDocuments::new(
@ -387,11 +419,33 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
&progress_callback,
&should_abort,
)?;
let indexing_builder = indexing_builder.with_embedders(embedders);
indexing_builder.execute_raw(output)?;
Ok(())
}
fn embedders(
&self,
embedding_configs: Vec<(String, EmbeddingConfig)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt)))
})
.collect();
res.map(EmbeddingConfigs::new)
}
fn update_displayed(&mut self) -> Result<bool> {
match self.displayed_fields {
Setting::Set(ref fields) => {
@ -890,6 +944,79 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok(changed)
}
fn update_embedding_configs(&mut self) -> Result<bool> {
let update = match std::mem::take(&mut self.embedder_settings) {
Setting::Set(configs) => {
let mut changed = false;
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> =
old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect();
let mut new_configs = BTreeMap::new();
for joined in old_configs
.into_iter()
.merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
{
match joined {
// updated config
EitherOrBoth::Both((name, mut old), (_, new)) => {
changed |= old.apply(new);
let new = validate_embedding_settings(old, &name)?;
new_configs.insert(name, new);
}
// unchanged config
EitherOrBoth::Left((name, setting)) => {
new_configs.insert(name, setting);
}
// new config
EitherOrBoth::Right((name, mut setting)) => {
// apply the default source in case the source was not set so that it gets validated
crate::vector::settings::EmbeddingSettings::apply_default_source(
&mut setting,
);
let setting = validate_embedding_settings(setting, &name)?;
changed = true;
new_configs.insert(name, setting);
}
}
}
let new_configs: Vec<(String, EmbeddingConfig)> = new_configs
.into_iter()
.filter_map(|(name, setting)| match setting {
Setting::Set(value) => Some((name, value.into())),
Setting::Reset => None,
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
})
.collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
embedder_name,
&index
.try_into()
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
)?;
}
if new_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?;
} else {
self.index.put_embedding_configs(self.wtxn, new_configs)?;
}
changed
}
Setting::Reset => {
self.index.delete_embedding_configs(self.wtxn)?;
true
}
Setting::NotSet => false,
};
Ok(update)
}
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
where
FP: Fn(UpdateIndexingStep) + Sync,
@ -927,6 +1054,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let searchable_updated = self.update_searchable()?;
let exact_attributes_updated = self.update_exact_attributes()?;
let proximity_precision = self.update_proximity_precision()?;
// TODO: very rough approximation of the needs for reindexing where any change will result in
// a full reindexing.
// What can be done instead:
// 1. Only change the distance on a distance change
// 2. Only change the name -> embedder mapping on a name change
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
let embedding_configs_updated = self.update_embedding_configs()?;
if stop_words_updated
|| non_separator_tokens_updated
@ -937,6 +1071,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|| searchable_updated
|| exact_attributes_updated
|| proximity_precision
|| embedding_configs_updated
{
self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?;
}
@ -945,6 +1080,90 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
}
}
fn validate_prompt(
name: &str,
new: Setting<EmbeddingSettings>,
) -> Result<Setting<EmbeddingSettings>> {
match new {
Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template: Setting::Set(template),
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
.map(|prompt| crate::prompt::PromptData::from(prompt).template)
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
Ok(Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template: Setting::Set(template),
}))
}
new => Ok(new),
}
}
pub fn validate_embedding_settings(
settings: Setting<EmbeddingSettings>,
name: &str,
) -> Result<Setting<EmbeddingSettings>> {
let settings = validate_prompt(name, settings)?;
let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
settings;
let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
}));
};
match inferred_source {
EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
if let Setting::Set(model) = &model {
crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or(
crate::error::UserError::InvalidOpenAiModel {
embedder_name: name.to_owned(),
model: model.clone(),
},
)?;
}
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
}
EmbedderSource::UserProvided => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
check_set(&dimensions, "dimensions", inferred_source, name)?;
}
}
Ok(Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
}))
}
#[cfg(test)]
mod tests {
use big_s::S;
@ -1763,6 +1982,7 @@ mod tests {
sort_facet_values_by,
pagination_max_total_hits,
proximity_precision,
embedder_settings,
} = settings;
assert!(matches!(searchable_fields, Setting::NotSet));
assert!(matches!(displayed_fields, Setting::NotSet));
@ -1785,6 +2005,7 @@ mod tests {
assert!(matches!(sort_facet_values_by, Setting::NotSet));
assert!(matches!(pagination_max_total_hits, Setting::NotSet));
assert!(matches!(proximity_precision, Setting::NotSet));
assert!(matches!(embedder_settings, Setting::NotSet));
})
.unwrap();
}

244
milli/src/vector/error.rs Normal file
View File

@ -0,0 +1,244 @@
use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;
#[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")]
pub struct Error {
pub inner: Box<ErrorKind>,
}
impl<I: Into<ErrorKind>> From<I> for Error {
fn from(value: I) -> Self {
Self { inner: Box::new(value.into()) }
}
}
impl Error {
pub fn fault(&self) -> FaultSource {
match &*self.inner {
ErrorKind::NewEmbedderError(inner) => inner.fault,
ErrorKind::EmbedError(inner) => inner.fault,
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ErrorKind {
#[error(transparent)]
NewEmbedderError(#[from] NewEmbedderError),
#[error(transparent)]
EmbedError(#[from] EmbedError),
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct EmbedError {
pub kind: EmbedErrorKind,
pub fault: FaultSource,
}
#[derive(Debug, thiserror::Error)]
pub enum EmbedErrorKind {
#[error("could not tokenize: {0}")]
Tokenize(Box<dyn std::error::Error + Send + Sync>),
#[error("unexpected tensor shape: {0}")]
TensorShape(candle_core::Error),
#[error("unexpected tensor value: {0}")]
TensorValue(candle_core::Error),
#[error("could not run model: {0}")]
ModelForward(candle_core::Error),
#[error("could not reach OpenAI: {0}")]
OpenAiNetwork(reqwest::Error),
#[error("unexpected response from OpenAI: {0}")]
OpenAiUnexpected(reqwest::Error),
#[error("could not authenticate against OpenAI: {0}")]
OpenAiAuth(OpenAiError),
#[error("sent too many requests to OpenAI: {0}")]
OpenAiTooManyRequests(OpenAiError),
#[error("received internal error from OpenAI: {0}")]
OpenAiInternalServerError(OpenAiError),
#[error("sent too many tokens in a request to OpenAI: {0}")]
OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String),
}
impl EmbedError {
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
}
pub fn tensor_shape(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
}
pub fn tensor_value(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
}
pub fn model_forward(inner: candle_core::Error) -> Self {
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
}
pub fn openai_network(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
}
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
}
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
}
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
}
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
#[error("{fault}: {kind}")]
pub struct NewEmbedderError {
pub kind: NewEmbedderErrorKind,
pub fault: FaultSource,
}
impl NewEmbedderError {
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
let open_config = OpenConfig { filename: config_filename, inner };
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
}
pub fn deserialize_config(
config: String,
config_filename: PathBuf,
inner: serde_json::Error,
) -> NewEmbedderError {
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
Self {
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
fault: FaultSource::Runtime,
}
}
pub fn open_tokenizer(
tokenizer_filename: PathBuf,
inner: Box<dyn std::error::Error + Send + Sync>,
) -> NewEmbedderError {
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
Self {
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
fault: FaultSource::Runtime,
}
}
pub fn new_api_fail(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
}
pub fn api_get(inner: ApiError) -> Self {
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
}
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
}
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
}
pub fn load_model(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
}
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
#[error("could not open config at {filename:?}: {inner}")]
pub struct OpenConfig {
pub filename: PathBuf,
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
pub struct DeserializeConfig {
pub config: String,
pub filename: PathBuf,
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not open tokenizer at {filename}: {inner}")]
pub struct OpenTokenizer {
pub filename: PathBuf,
#[source]
pub inner: Box<dyn std::error::Error + Send + Sync>,
}
#[derive(Debug, thiserror::Error)]
pub enum NewEmbedderErrorKind {
// hf
#[error(transparent)]
OpenConfig(OpenConfig),
#[error(transparent)]
DeserializeConfig(DeserializeConfig),
#[error(transparent)]
OpenTokenizer(OpenTokenizer),
#[error("could not build weights from Pytorch weights: {0}")]
PytorchWeight(candle_core::Error),
#[error("could not build weights from Safetensor weights: {0}")]
SafetensorWeight(candle_core::Error),
#[error("could not spawn HG_HUB API client: {0}")]
NewApiFail(ApiError),
#[error("fetching file from HG_HUB failed: {0}")]
ApiGet(ApiError),
#[error("could not determine model dimensions: test embedding failed with {0}")]
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")]
LoadModel(candle_core::Error),
// openai
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
}

195
milli/src/vector/hf.rs Normal file
View File

@ -0,0 +1,195 @@
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Deserialize,
serde::Serialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
enum WeightSource {
#[default]
Safetensors,
Pytorch,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub model: String,
pub revision: Option<String>,
}
impl EmbedderOptions {
pub fn new() -> Self {
Self {
model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
}
}
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::new()
}
}
/// Perform embedding of documents and queries
pub struct Embedder {
model: BertModel,
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
}
impl std::fmt::Debug for Embedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Embedder")
.field("model", &self.options.model)
.field("tokenizer", &self.tokenizer)
.field("options", &self.options)
.finish()
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = candle_core::Device::Cpu;
let repo = match options.revision.clone() {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),
};
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
let api = api.repo(repo);
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
let (weights, source) = {
api.get("pytorch_model.bin")
.map(|filename| (filename, WeightSource::Pytorch))
.or_else(|_| {
api.get("model.safetensors")
.map(|filename| (filename, WeightSource::Safetensors))
})
.map_err(NewEmbedderError::api_get)?
};
(config, tokenizer, weights, source)
};
let config = std::fs::read_to_string(&config_filename)
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
let config: Config = serde_json::from_str(&config).map_err(|inner| {
NewEmbedderError::deserialize_config(config, config_filename, inner)
})?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
let vb = match weight_source {
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
.map_err(NewEmbedderError::pytorch_weight)?,
WeightSource::Safetensors => unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
.map_err(NewEmbedderError::safetensor_weight)?
},
};
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
}
pub fn embed(
&self,
mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
let tokens = match texts.len() {
1 => vec![self
.tokenizer
.encode(texts.pop().unwrap(), true)
.map_err(EmbedError::tokenize)?],
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
};
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
})
.collect::<Result<Vec<_>, EmbedError>>()?;
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
let embeddings =
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn chunk_count_hint(&self) -> usize {
1
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
} else {
None
}
}
}

View File

@ -0,0 +1,34 @@
use super::error::EmbedError;
use super::Embeddings;
#[derive(Debug, Clone, Copy)]
pub struct Embedder {
dimensions: usize,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub dimensions: usize,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions }
}
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let Some(text) = texts.pop() else { return Ok(Default::default()) };
Err(EmbedError::embed_on_manual_embedder(text))
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
}

257
milli/src/vector/mod.rs Normal file
View File

@ -0,0 +1,257 @@
use std::collections::HashMap;
use std::sync::Arc;
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
pub mod error;
pub mod hf;
pub mod manual;
pub mod openai;
pub mod settings;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
pub struct Embeddings<F> {
data: Vec<F>,
dimension: usize,
}
impl<F> Embeddings<F> {
pub fn new(dimension: usize) -> Self {
Self { data: Default::default(), dimension }
}
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
Self { dimension: embedding.len(), data: embedding }
}
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
let mut this = Self::new(dimension);
this.append(data)?;
Ok(this)
}
pub fn embedding_count(&self) -> usize {
self.data.len() / self.dimension
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn into_inner(self) -> Vec<F> {
self.data
}
pub fn as_inner(&self) -> &[F] {
&self.data
}
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
self.data.as_slice().chunks_exact(self.dimension)
}
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
if embedding.len() != self.dimension {
return Err(embedding);
}
self.data.append(&mut embedding);
Ok(())
}
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
if embeddings.len() % self.dimension != 0 {
return Err(embeddings);
}
self.data.append(&mut embeddings);
Ok(())
}
}
#[derive(Debug)]
pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
pub embedder_options: EmbedderOptions,
pub prompt: PromptData,
// TODO: add metrics and anything needed
}
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
impl EmbeddingConfigs {
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
Self(data)
}
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.0.get(name).cloned()
}
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.get_default_embedder_name().and_then(|default| self.get(&default))
}
pub fn get_default_embedder_name(&self) -> Option<String> {
let mut it = self.0.keys();
let first_name = it.next();
let second_name = it.next();
match (first_name, second_name) {
(None, _) => None,
(Some(first), None) => Some(first.to_owned()),
(Some(_), Some(_)) => Some("default".to_owned()),
}
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>));
type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
}
}
impl EmbedderOptions {
pub fn huggingface() -> Self {
Self::HuggingFace(hf::EmbedderOptions::new())
}
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
})
}
pub async fn embed(
&self,
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
pub async fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
pub fn chunk_count_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
pub fn distribution(&self) -> Option<DistributionShift> {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
}
}
}
#[derive(Debug, Clone, Copy)]
pub struct DistributionShift {
pub current_mean: f32,
pub current_sigma: f32,
}
impl DistributionShift {
/// `None` if sigma <= 0.
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: mean, current_sigma: sigma })
}
}
pub fn shift(&self, score: f32) -> f32 {
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
let target_mean = 0.5;
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean);
let mut score = factor * score + offset;
// clamp the final score in the ]0, 1] interval.
if score <= 0.0 {
score = f32::EPSILON;
}
if score > 1.0 {
score = 1.0;
}
score
}
}

452
milli/src/vector/openai.rs Normal file
View File

@ -0,0 +1,452 @@
use std::fmt::Display;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use super::error::{EmbedError, NewEmbedderError};
use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)]
pub struct Embedder {
client: reqwest::Client,
tokenizer: tiktoken_rs::CoreBPE,
options: EmbedderOptions,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
}
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel {
// # WARNING
//
// If ever adding a model, make sure to add it to the list of supported models below.
#[default]
#[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
}
impl EmbeddingModel {
pub fn supported_models() -> &'static [&'static str] {
&["text-embedding-ada-002"]
}
pub fn max_token(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 8191,
}
}
pub fn dimensions(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 1536,
}
}
pub fn name(&self) -> &'static str {
match self {
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
}
}
pub fn from_name(name: &str) -> Option<Self> {
match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
_ => None,
}
}
fn distribution(&self) -> Option<DistributionShift> {
match self {
EmbeddingModel::TextEmbeddingAda002 => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
}
}
}
}
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() }
}
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model }
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
);
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"),
);
let client = reqwest::ClientBuilder::new()
.default_headers(headers)
.build()
.map_err(NewEmbedderError::openai_initialize_web_client)?;
// looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, client, tokenizer })
}
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut tokenized = false;
for attempt in 0..7 {
let result = if tokenized {
self.try_embed_tokenized(&texts).await
} else {
self.try_embed(&texts).await
};
let retry_duration = match result {
Ok(embeddings) => return Ok(embeddings),
Err(retry) => {
log::warn!("Failed: {}", retry.error);
tokenized |= retry.must_tokenize();
retry.into_duration(attempt)
}
}?;
log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
tokio::time::sleep(retry_duration).await;
}
let result = if tokenized {
self.try_embed_tokenized(&texts).await
} else {
self.try_embed(&texts).await
};
result.map_err(Retry::into_error)
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
match response.status() {
StatusCode::UNAUTHORIZED => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::give_up(EmbedError::openai_auth_error(
error_response.error,
)));
}
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
error_response.error,
)));
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.error,
)));
}
StatusCode::SERVICE_UNAVAILABLE => {
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
error_response.error,
)));
}
StatusCode::BAD_REQUEST => {
// Most probably, one text contained too many tokens
let error_response: OpenAiErrorResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
error_response.error,
)));
}
code => {
return Err(Retry::give_up(EmbedError::openai_unhandled_status_code(
code.as_u16(),
)));
}
}
}
Ok(response)
}
async fn try_embed<S: AsRef<str> + serde::Serialize>(
&self,
texts: &[S],
) -> Result<Vec<Embeddings<f32>>, Retry> {
for text in texts {
log::trace!("Received prompt: {}", text.as_ref())
}
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
let response = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
log::trace!("response: {:?}", response.data);
Ok(response
.data
.into_iter()
.map(|data| Embeddings::from_single_embedding(data.embedding))
.collect())
}
async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
pub const OVERLAP_SIZE: usize = 200;
let mut all_embeddings = Vec::with_capacity(text.len());
for text in text {
let max_token_count = self.options.embedding_model.max_token();
let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len();
if len < max_token_count {
all_embeddings.append(&mut self.try_embed(&[text]).await?);
continue;
}
let mut tokens = encoded.as_slice();
let mut embeddings_for_prompt =
Embeddings::new(self.options.embedding_model.dimensions());
while tokens.len() > max_token_count {
let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
}
// end of text
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
all_embeddings.push(embeddings_for_prompt);
}
Ok(all_embeddings)
}
async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
for attempt in 0..9 {
let duration = match self.try_embed_tokens(tokens).await {
Ok(embedding) => return Ok(embedding),
Err(retry) => retry.into_duration(attempt),
}
.map_err(Retry::retry_later)?;
tokio::time::sleep(duration).await;
}
self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
}
async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
let request =
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
let response = self
.client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let mut response: OpenAiResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
}
pub async fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
.await
}
pub fn chunk_count_hint(&self) -> usize {
10
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
}
pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
}
}
// retrying in case of failure
struct Retry {
error: EmbedError,
strategy: RetryStrategy,
}
enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
RetryAfterRateLimit,
}
impl Retry {
fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
RetryStrategy::RetryAfterRateLimit => {
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
}
}
}
fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
fn into_error(self) -> EmbedError {
self.error
}
}
// openai api structs
#[derive(Debug, Serialize)]
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str,
input: &'a [S],
}
#[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> {
model: &'a str,
input: &'a [usize],
}
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
data: Vec<OpenAiEmbedding>,
}
#[derive(Debug, Deserialize)]
struct OpenAiErrorResponse {
error: OpenAiError,
}
#[derive(Debug, Deserialize)]
pub struct OpenAiError {
message: String,
// type: String,
code: Option<String>,
}
impl Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.code {
Some(code) => write!(f, "{} ({})", self.message, code),
None => write!(f, "{}", self.message),
}
}
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Embedding,
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}

View File

@ -0,0 +1,244 @@
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
use crate::UserError;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub source: Setting<EmbedderSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub dimensions: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template: Setting<String>,
}
pub fn check_unset<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::NotSet) {
Ok(())
} else {
Err(UserError::InvalidFieldForSource {
embedder_name: embedder_name.to_owned(),
source_: source,
field,
allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
})
}
}
pub fn check_set<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::Set(_)) {
Ok(())
} else {
Err(UserError::MissingFieldForSource {
field,
source_: source,
embedder_name: embedder_name.to_owned(),
})
}
}
impl EmbeddingSettings {
pub const SOURCE: &'static str = "source";
pub const MODEL: &'static str = "model";
pub const REVISION: &'static str = "revision";
pub const API_KEY: &'static str = "apiKey";
pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
}
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
_other => unreachable!("unknown field"),
}
}
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
match source {
EmbedderSource::OpenAi => {
&[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
}
}
pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
if let Setting::Set(EmbeddingSettings {
source: source @ (Setting::NotSet | Setting::Reset),
..
}) = setting
{
*source = Setting::Set(EmbedderSource::default())
}
}
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbedderSource {
#[default]
OpenAi,
HuggingFace,
UserProvided,
}
impl std::fmt::Display for EmbedderSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
EmbedderSource::OpenAi => "openAi",
EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided",
};
f.write_str(s)
}
}
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
new;
let old_source = self.source;
self.source.apply(source);
// Reinitialize the whole setting object on a source change
if old_source != self.source {
*self = EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
};
return;
}
self.model.apply(model);
self.revision.apply(revision);
self.api_key.apply(api_key);
self.dimensions.apply(dimensions);
self.document_template.apply(document_template);
}
}
impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt } = value;
match embedder_options {
super::EmbedderOptions::HuggingFace(options) => Self {
source: Setting::Set(EmbedderSource::HuggingFace),
model: Setting::Set(options.model),
revision: options.revision.map(Setting::Set).unwrap_or_default(),
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
model: Setting::Set(options.embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions),
document_template: Setting::NotSet,
},
}
}
}
impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default();
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
value;
if let Some(source) = source.set() {
match source {
EmbedderSource::OpenAi => {
let mut options = super::openai::EmbedderOptions::with_default_model(None);
if let Some(model) = model.set() {
if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
options.embedding_model = model;
}
}
if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key);
}
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::HuggingFace => {
let mut options = super::hf::EmbedderOptions::default();
if let Some(model) = model.set() {
options.model = model;
// Reset the revision if we are setting the model.
// This allows the following:
// "huggingFace": {} -> default model with default revision
// "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
// "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
options.revision = None;
}
if let Some(revision) = revision.set() {
options.revision = Some(revision);
}
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
}
EmbedderSource::UserProvided => {
this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(),
});
}
}
}
if let Setting::Set(template) = document_template {
this.prompt = PromptData { template }
}
this
}
}