Compare commits


6 Commits

Author SHA1 Message Date
Paul de Nonancourt
79ee4367ad Authenticate to AWS STS if web identity token file and role ARN are provided
This makes it possible to generate short-lived credentials when
snapshotting to S3 directly.
2025-12-16 15:35:14 +01:00
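
How the two new variables select the auth mode, as a minimal Rust sketch mirroring extract_credentials_from_options in the diff below (the environment variable names are the real ones added by this change; the helper function here is invented for illustration):

use std::env;
use std::path::PathBuf;

// Hedged sketch: picks the credentials source the same way the new code
// does, by zipping each pair and requiring exactly one pair to be present.
fn auth_mode() -> &'static str {
    let static_creds =
        env::var("MEILI_S3_ACCESS_KEY").ok().zip(env::var("MEILI_S3_SECRET_KEY").ok());
    let web_identity = env::var("MEILI_EXPERIMENTAL_S3_ROLE_ARN")
        .ok()
        .zip(env::var("MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE").ok().map(PathBuf::from));
    match (static_creds, web_identity) {
        (Some(_), None) => "static access/secret keys",
        (None, Some(_)) => "short-lived STS credentials via AssumeRoleWithWebIdentity",
        _ => "invalid: configure exactly one auth method",
    }
}

fn main() {
    println!("auth mode: {}", auth_mode());
}
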
Paul de Nonancourt
865fda8503 Add experimental values to support EKS IRSA authentication
- Add MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE and MEILI_EXPERIMENTAL_S3_ROLE_ARN,
  which conflict with the existing MEILI_S3_ACCESS_KEY and MEILI_S3_SECRET_KEY.
- Remove default serde for S3SnapshotOpts:
  - Having all fields defaulted caused the TOML deserializer to consider that the Option<S3Opts> was always Some.
  - This was an issue because the new fields are optional, so they would end up as None.

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2025-12-16 15:35:14 +01:00
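
To make the serde issue above concrete, a minimal sketch under the assumption that the S3 options struct is flattened into the top-level config with #[serde(flatten)] (the exact Meilisearch wiring may differ): when every inner field has a default, the flattened Option deserializes successfully from an empty remainder and therefore always comes back Some.

use serde::Deserialize;

#[derive(Debug, Default, Deserialize)]
struct S3Opts {
    #[serde(default)]
    s3_bucket_url: String,
    #[serde(default)]
    s3_access_key: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Opts {
    db_path: String,
    // With flatten and all-defaulted inner fields, this is always Some.
    #[serde(flatten)]
    s3: Option<S3Opts>,
}

fn main() {
    // The TOML sets no S3 key at all, yet `s3` is Some(S3Opts::default()).
    let opts: Opts = toml::from_str(r#"db_path = "./data""#).unwrap();
    assert!(opts.s3.is_some());
    println!("{:?}", opts.s3);
}
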
Clément Renault
2b6b4284bb Merge pull request #6000 from meilisearch/change-network-topology-2
Allow changing network topology
2025-12-15 11:09:56 +00:00
Louis Dureuil
018cad1781 add batch reason
2025-12-15 11:06:25 +01:00
Clément Renault
26e368b116 Merge pull request #6041 from meilisearch/fix-workflow-injection
Remove risk of command injection
2025-12-09 17:04:58 +00:00
curquiza
ba95ac0915 Remove risk of command injection
2025-12-09 17:06:41 +01:00
6 changed files with 251 additions and 37 deletions

View File

@@ -25,14 +25,18 @@ jobs:
       - uses: actions/checkout@v5
       - name: Define the Docker image we need to use
        id: define-image
+        env:
+          EVENT_NAME: ${{ github.event_name }}
+          DOCKER_IMAGE_INPUT: ${{ github.event.inputs.docker_image }}
         run: |
-          event=${{ github.event_name }}
           echo "docker-image=nightly" >> $GITHUB_OUTPUT
-          if [[ $event == 'workflow_dispatch' ]]; then
-            echo "docker-image=${{ github.event.inputs.docker_image }}" >> $GITHUB_OUTPUT
+          if [[ "$EVENT_NAME" == 'workflow_dispatch' ]]; then
+            echo "docker-image=$DOCKER_IMAGE_INPUT" >> $GITHUB_OUTPUT
           fi
       - name: Docker image is ${{ steps.define-image.outputs.docker-image }}
-        run: echo "Docker image is ${{ steps.define-image.outputs.docker-image }}"
+        env:
+          DOCKER_IMAGE: ${{ steps.define-image.outputs.docker-image }}
+        run: echo "Docker image is $DOCKER_IMAGE"

   ##########
   ## SDKs ##

View File

@@ -745,6 +745,7 @@ impl IndexScheduler {
         mut current_batch: ProcessingBatch,
     ) -> Result<Option<(Batch, ProcessingBatch)>> {
         current_batch.processing(Some(&mut task));
+        current_batch.reason(BatchStopReason::NetworkTask { id: task.uid });
         let change_version =
             task.network.as_ref().map(|network| network.network_version()).unwrap_or_default();
@@ -777,11 +778,16 @@ impl IndexScheduler {
                 task_version >= change_version
             });
-            let (batch, current_batch) = res?;
+            let (batch, mut current_batch) = res?;
             let batch = match batch {
                 Some(batch) => {
                     let inner_batch = Box::new(batch);
+                    let inner_reason = current_batch.reason.to_string();
+                    current_batch.reason(BatchStopReason::NetworkTaskOlderTasks {
+                        id: task.uid,
+                        inner_reason,
+                    });
                     Batch::NetworkIndexBatch { network_task: task, inner_batch }
                 }
@@ -819,10 +825,15 @@ impl IndexScheduler {
                 task_version != change_version
             });
-            let (batch, current_batch) = res?;
+            let (batch, mut current_batch) = res?;
             let batch = batch.map(|batch| {
                 let inner_batch = Box::new(batch);
+                let inner_reason = current_batch.reason.to_string();
+                current_batch.reason(BatchStopReason::NetworkTaskImportTasks {
+                    id: task.uid,
+                    inner_reason,
+                });
                 (Batch::NetworkIndexBatch { network_task: task, inner_batch }, current_batch)
             });

View File

@@ -1,5 +1,7 @@
+use std::env::VarError;
 use std::ffi::OsStr;
 use std::fs;
 use std::path::{Path, PathBuf};
 use std::sync::atomic::Ordering;

 use meilisearch_types::heed::CompactionOption;
@@ -14,6 +16,34 @@ use crate::{Error, IndexScheduler, Result};
 const UPDATE_FILES_DIR_NAME: &str = "update_files";

+#[derive(Debug, Clone, serde::Deserialize)]
+struct StsCredentials {
+    #[serde(rename = "AccessKeyId")]
+    access_key_id: String,
+    #[serde(rename = "SecretAccessKey")]
+    secret_access_key: String,
+    #[serde(rename = "SessionToken")]
+    session_token: String,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct AssumeRoleWithWebIdentityResult {
+    #[serde(rename = "Credentials")]
+    credentials: StsCredentials,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct AssumeRoleWithWebIdentityResponse {
+    #[serde(rename = "AssumeRoleWithWebIdentityResult")]
+    result: AssumeRoleWithWebIdentityResult,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct StsResponse {
+    #[serde(rename = "AssumeRoleWithWebIdentityResponse")]
+    response: AssumeRoleWithWebIdentityResponse,
+}
+
 /// # Safety
 ///
 /// See [`EnvOpenOptions::open`].
@@ -231,6 +261,78 @@ impl IndexScheduler {
         Ok(tasks)
     }

+    #[cfg(unix)]
+    async fn assume_role_with_web_identity(
+        role_arn: &str,
+        web_identity_token_file: &Path,
+    ) -> anyhow::Result<StsCredentials> {
+        let token = tokio::fs::read_to_string(web_identity_token_file)
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to read web identity token file: {e}"))?;
+
+        let duration: u32 =
+            match std::env::var("MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_DURATION_SECONDS") {
+                Ok(s) => s.parse()?,
+                Err(VarError::NotPresent) => 3600,
+                Err(VarError::NotUnicode(e)) => {
+                    anyhow::bail!("Invalid duration: {e:?}")
+                }
+            };
+
+        let form_data = [
+            ("Action", "AssumeRoleWithWebIdentity"),
+            ("Version", "2011-06-15"),
+            ("RoleArn", role_arn),
+            ("RoleSessionName", "meilisearch-snapshot-session"),
+            ("WebIdentityToken", &token),
+            ("DurationSeconds", &duration.to_string()),
+        ];
+
+        let client = reqwest::Client::new();
+        let response = client
+            .post("https://sts.amazonaws.com/")
+            .header(reqwest::header::ACCEPT, "application/json")
+            .header(reqwest::header::CONTENT_TYPE, "application/x-www-form-urlencoded")
+            .form(&form_data)
+            .send()
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to send STS request: {e}"))?;
+
+        let status = response.status();
+        let body = response
+            .text()
+            .await
+            .map_err(|e| anyhow::anyhow!("Failed to read STS response body: {e}"))?;
+
+        if !status.is_success() {
+            return Err(anyhow::anyhow!("STS request failed with status {status}: {body}"));
+        }
+
+        let sts_response: StsResponse = serde_json::from_str(&body)
+            .map_err(|e| anyhow::anyhow!("Failed to deserialize STS response: {e}"))?;
+
+        Ok(sts_response.response.result.credentials)
+    }
+
+    async fn extract_credentials_from_options(
+        s3_access_key: Option<String>,
+        s3_secret_key: Option<String>,
+        s3_role_arn: Option<String>,
+        s3_web_identity_token_file: Option<PathBuf>,
+    ) -> anyhow::Result<(String, String, Option<String>)> {
+        let static_credentials = s3_access_key.zip(s3_secret_key);
+        let web_identity = s3_role_arn.zip(s3_web_identity_token_file);
+        match (static_credentials, web_identity) {
+            (Some((access_key, secret_key)), None) => Ok((access_key, secret_key, None)),
+            (None, Some((role_arn, token_file))) => {
+                let StsCredentials { access_key_id, secret_access_key, session_token } =
+                    Self::assume_role_with_web_identity(&role_arn, &token_file).await?;
+                Ok((access_key_id, secret_access_key, Some(session_token)))
+            }
+            (_, _) => anyhow::bail!("Clap must pass valid auth parameters"),
+        }
+    }
+
     #[cfg(unix)]
     pub(super) async fn process_snapshot_to_s3(
         &self,
@@ -247,6 +349,8 @@ impl IndexScheduler {
             s3_snapshot_prefix,
             s3_access_key,
             s3_secret_key,
+            s3_role_arn,
+            s3_web_identity_token_file,
             s3_max_in_flight_parts,
             s3_compression_level: level,
             s3_signature_duration,
@@ -262,21 +366,33 @@ impl IndexScheduler {
         };

         let (reader, writer) = std::io::pipe()?;
-        let uploader_task = tokio::spawn(multipart_stream_to_s3(
-            s3_bucket_url,
-            s3_bucket_region,
-            s3_bucket_name,
-            s3_snapshot_prefix,
-            s3_access_key,
-            s3_secret_key,
-            s3_max_in_flight_parts,
-            s3_signature_duration,
-            s3_multipart_part_size,
-            must_stop_processing,
-            retry_backoff,
-            db_name,
-            reader,
-        ));
+        let uploader_task = tokio::spawn(async move {
+            let (s3_access_key, s3_secret_key, s3_token) = Self::extract_credentials_from_options(
+                s3_access_key,
+                s3_secret_key,
+                s3_role_arn,
+                s3_web_identity_token_file,
+            )
+            .await?;
+            multipart_stream_to_s3(
+                s3_bucket_url,
+                s3_bucket_region,
+                s3_bucket_name,
+                s3_snapshot_prefix,
+                s3_access_key,
+                s3_secret_key,
+                s3_token,
+                s3_max_in_flight_parts,
+                s3_signature_duration,
+                s3_multipart_part_size,
+                must_stop_processing,
+                retry_backoff,
+                db_name,
+                reader,
+            )
+            .await
+        });

         let index_scheduler = IndexScheduler::private_clone(self);
         let builder_task = tokio::task::spawn_blocking(move || {
@@ -430,6 +546,7 @@ async fn multipart_stream_to_s3(
     s3_snapshot_prefix: String,
     s3_access_key: String,
     s3_secret_key: String,
+    s3_token: Option<String>,
     s3_max_in_flight_parts: std::num::NonZero<usize>,
     s3_signature_duration: std::time::Duration,
     s3_multipart_part_size: u64,
@@ -456,7 +573,10 @@ async fn multipart_stream_to_s3(
         s3_bucket_url.parse().map_err(BucketError::ParseError).map_err(Error::S3BucketError)?;
     let bucket = Bucket::new(url, UrlStyle::Path, s3_bucket_name, s3_bucket_region)
         .map_err(Error::S3BucketError)?;
-    let credential = Credentials::new(s3_access_key, s3_secret_key);
+    let credential = match s3_token {
+        Some(token) => Credentials::new_with_token(s3_access_key, s3_secret_key, token),
+        None => Credentials::new(s3_access_key, s3_secret_key),
+    };

     // Note for the future (rust 1.91+): use with_added_extension, it's prettier
     let object_path = s3_snapshot_prefix.join(format!("{db_name}.snapshot"));
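
For reference, the nested JSON shape those serde structs expect when STS is queried with Accept: application/json; a self-contained parse sketch in which only the key names come from the code above and every value is invented:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct StsCredentials {
    #[serde(rename = "AccessKeyId")]
    access_key_id: String,
    #[serde(rename = "SecretAccessKey")]
    secret_access_key: String,
    #[serde(rename = "SessionToken")]
    session_token: String,
}

#[derive(Debug, Deserialize)]
struct AssumeRoleWithWebIdentityResult {
    #[serde(rename = "Credentials")]
    credentials: StsCredentials,
}

#[derive(Debug, Deserialize)]
struct AssumeRoleWithWebIdentityResponse {
    #[serde(rename = "AssumeRoleWithWebIdentityResult")]
    result: AssumeRoleWithWebIdentityResult,
}

#[derive(Debug, Deserialize)]
struct StsResponse {
    #[serde(rename = "AssumeRoleWithWebIdentityResponse")]
    response: AssumeRoleWithWebIdentityResponse,
}

fn main() {
    // Invented sample payload; the nesting mirrors the rename attributes.
    let body = r#"{
        "AssumeRoleWithWebIdentityResponse": {
            "AssumeRoleWithWebIdentityResult": {
                "Credentials": {
                    "AccessKeyId": "ASIAEXAMPLE",
                    "SecretAccessKey": "example-secret-key",
                    "SessionToken": "example-session-token"
                }
            }
        }
    }"#;
    let parsed: StsResponse = serde_json::from_str(body).unwrap();
    assert_eq!(parsed.response.result.credentials.access_key_id, "ASIAEXAMPLE");
}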

View File

@@ -899,6 +899,17 @@ pub enum BatchStopReason {
     SettingsWithDocumentOperation {
         id: TaskId,
     },
+    NetworkTask {
+        id: TaskId,
+    },
+    NetworkTaskOlderTasks {
+        id: TaskId,
+        inner_reason: String,
+    },
+    NetworkTaskImportTasks {
+        id: TaskId,
+        inner_reason: String,
+    },
 }

 impl BatchStopReason {
@@ -987,6 +998,24 @@ impl Display for BatchStopReason {
                     "stopped before task with id {id} because it is a document operation which cannot be batched with settings changes"
                 )
             }
+            BatchStopReason::NetworkTask { id } => {
+                write!(
+                    f,
+                    "stopped after task with id {id} because it is a network topology change task"
+                )
+            }
+            BatchStopReason::NetworkTaskOlderTasks { id, inner_reason } => {
+                write!(
+                    f,
+                    "stopped after batching network task with id {id} and a batch of older tasks: {inner_reason}"
+                )
+            }
+            BatchStopReason::NetworkTaskImportTasks { id, inner_reason } => {
+                write!(
+                    f,
+                    "stopped after batching network task with id {id} and a batch of import tasks: {inner_reason}"
+                )
+            }
         }
     }
 }
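
To see how the nested inner_reason reads once formatted, a tiny sketch using a free function as a stand-in for the Display arm above and an invented inner reason string:

// Hypothetical stand-in: `inner_reason` is the stringified stop reason of
// the inner batch that was bundled with the network task.
fn older_tasks_reason(id: u32, inner_reason: &str) -> String {
    format!("stopped after batching network task with id {id} and a batch of older tasks: {inner_reason}")
}

fn main() {
    // Prints the outer reason with the inner one appended after the colon.
    println!("{}", older_tasks_reason(42, "batched all enqueued tasks"));
}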

View File

@@ -85,6 +85,9 @@ const MEILI_S3_BUCKET_NAME: &str = "MEILI_S3_BUCKET_NAME";
 const MEILI_S3_SNAPSHOT_PREFIX: &str = "MEILI_S3_SNAPSHOT_PREFIX";
 const MEILI_S3_ACCESS_KEY: &str = "MEILI_S3_ACCESS_KEY";
 const MEILI_S3_SECRET_KEY: &str = "MEILI_S3_SECRET_KEY";
+const MEILI_EXPERIMENTAL_S3_ROLE_ARN: &str = "MEILI_EXPERIMENTAL_S3_ROLE_ARN";
+const MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE: &str =
+    "MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE";
 const MEILI_EXPERIMENTAL_S3_MAX_IN_FLIGHT_PARTS: &str = "MEILI_EXPERIMENTAL_S3_MAX_IN_FLIGHT_PARTS";
 const MEILI_EXPERIMENTAL_S3_COMPRESSION_LEVEL: &str = "MEILI_EXPERIMENTAL_S3_COMPRESSION_LEVEL";
 const MEILI_EXPERIMENTAL_S3_SIGNATURE_DURATION_SECONDS: &str =
@@ -942,37 +945,65 @@ impl TryFrom<&IndexerOpts> for IndexerConfig {
 // This group is a bit tricky but makes it possible to require all listed fields if one of them
 // is specified. It lets us keep an Option for the S3SnapshotOpts configuration.
 // <https://github.com/clap-rs/clap/issues/5092#issuecomment-2616986075>
-#[group(requires_all = ["s3_bucket_url", "s3_bucket_region", "s3_bucket_name", "s3_snapshot_prefix", "s3_access_key", "s3_secret_key"])]
+#[group(requires_all = ["s3_bucket_url", "s3_bucket_region", "s3_bucket_name", "s3_snapshot_prefix", "s3_auth"])]
 pub struct S3SnapshotOpts {
     /// The S3 bucket URL in the format https://s3.<region>.amazonaws.com.
     #[clap(long, env = MEILI_S3_BUCKET_URL, required = false)]
     #[serde(default)]
     pub s3_bucket_url: String,

     /// The region in the format us-east-1.
     #[clap(long, env = MEILI_S3_BUCKET_REGION, required = false)]
     #[serde(default)]
     pub s3_bucket_region: String,

     /// The bucket name.
     #[clap(long, env = MEILI_S3_BUCKET_NAME, required = false)]
     #[serde(default)]
     pub s3_bucket_name: String,

     /// The prefix path where to put the snapshot, uses normal slashes (/).
     #[clap(long, env = MEILI_S3_SNAPSHOT_PREFIX, required = false)]
     #[serde(default)]
     pub s3_snapshot_prefix: String,

-    /// The S3 access key.
-    #[clap(long, env = MEILI_S3_ACCESS_KEY, required = false)]
+    /// The S3 access key. Conflicts with --experimental-s3-role-arn and --experimental-s3-web-identity-token-file.
+    #[clap(
+        long,
+        env = MEILI_S3_ACCESS_KEY,
+        required = false,
+        group = "s3_auth",
+        requires = "s3_secret_key"
+    )]
     #[serde(default)]
-    pub s3_access_key: String,
+    pub s3_access_key: Option<String>,

-    /// The S3 secret key.
-    #[clap(long, env = MEILI_S3_SECRET_KEY, required = false)]
+    /// The S3 secret key. Conflicts with --experimental-s3-role-arn and --experimental-s3-web-identity-token-file.
+    #[clap(
+        long,
+        env = MEILI_S3_SECRET_KEY,
+        required = false,
+        conflicts_with_all = ["experimental_s3_role_arn", "experimental_s3_web_identity_token_file"]
+    )]
     #[serde(default)]
-    pub s3_secret_key: String,
+    pub s3_secret_key: Option<String>,

+    /// The IAM role ARN for web identity federation. Conflicts with --s3-access-key and --s3-secret-key.
+    #[clap(
+        long,
+        env = MEILI_EXPERIMENTAL_S3_ROLE_ARN,
+        required = false,
+        group = "s3_auth",
+        requires = "experimental_s3_web_identity_token_file"
+    )]
+    #[serde(default)]
+    pub experimental_s3_role_arn: Option<String>,
+
+    /// The path to the web identity token file. Conflicts with --s3-access-key and --s3-secret-key.
+    #[clap(
+        long,
+        env = MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE,
+        required = false,
+        conflicts_with_all = ["s3_access_key", "s3_secret_key"]
+    )]
+    #[serde(default)]
+    pub experimental_s3_web_identity_token_file: Option<PathBuf>,
+
     /// The maximum number of parts that can be uploaded in parallel.
     ///
@@ -1017,6 +1048,8 @@ impl S3SnapshotOpts {
             s3_snapshot_prefix,
             s3_access_key,
             s3_secret_key,
+            experimental_s3_role_arn,
+            experimental_s3_web_identity_token_file,
             experimental_s3_max_in_flight_parts,
             experimental_s3_compression_level,
             experimental_s3_signature_duration_seconds,
@@ -1027,8 +1060,18 @@ impl S3SnapshotOpts {
         export_to_env_if_not_present(MEILI_S3_BUCKET_REGION, s3_bucket_region);
         export_to_env_if_not_present(MEILI_S3_BUCKET_NAME, s3_bucket_name);
         export_to_env_if_not_present(MEILI_S3_SNAPSHOT_PREFIX, s3_snapshot_prefix);
-        export_to_env_if_not_present(MEILI_S3_ACCESS_KEY, s3_access_key);
-        export_to_env_if_not_present(MEILI_S3_SECRET_KEY, s3_secret_key);
+        if let Some(key) = s3_access_key {
+            export_to_env_if_not_present(MEILI_S3_ACCESS_KEY, key);
+        }
+        if let Some(key) = s3_secret_key {
+            export_to_env_if_not_present(MEILI_S3_SECRET_KEY, key);
+        }
+        if let Some(arn) = experimental_s3_role_arn {
+            export_to_env_if_not_present(MEILI_EXPERIMENTAL_S3_ROLE_ARN, arn);
+        }
+        if let Some(path) = experimental_s3_web_identity_token_file {
+            export_to_env_if_not_present(MEILI_EXPERIMENTAL_S3_WEB_IDENTITY_TOKEN_FILE, path);
+        }
         export_to_env_if_not_present(
             MEILI_EXPERIMENTAL_S3_MAX_IN_FLIGHT_PARTS,
             experimental_s3_max_in_flight_parts.to_string(),
@@ -1059,6 +1102,8 @@ impl TryFrom<S3SnapshotOpts> for S3SnapshotOptions {
             s3_snapshot_prefix,
             s3_access_key,
             s3_secret_key,
+            experimental_s3_role_arn,
+            experimental_s3_web_identity_token_file,
             experimental_s3_max_in_flight_parts,
             experimental_s3_compression_level,
             experimental_s3_signature_duration_seconds,
@@ -1072,6 +1117,8 @@ impl TryFrom<S3SnapshotOpts> for S3SnapshotOptions {
             s3_snapshot_prefix,
             s3_access_key,
             s3_secret_key,
+            s3_role_arn: experimental_s3_role_arn,
+            s3_web_identity_token_file: experimental_s3_web_identity_token_file,
             s3_max_in_flight_parts: experimental_s3_max_in_flight_parts,
             s3_compression_level: experimental_s3_compression_level,
             s3_signature_duration: Duration::from_secs(experimental_s3_signature_duration_seconds),
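
The grouping trick described in the comment above the struct, reduced to a hedged standalone sketch (argument names shortened; only the mechanism is the same): giving any argument makes the bucket field and exactly one auth method mandatory.

use clap::Parser;

#[derive(Debug, Parser)]
#[group(requires_all = ["bucket", "auth"])]
struct Opts {
    #[clap(long, required = false)]
    bucket: Option<String>,

    // Static-credentials arm of the `auth` group; the keys travel in pairs.
    #[clap(long, required = false, group = "auth", requires = "secret_key")]
    access_key: Option<String>,
    #[clap(long, required = false, conflicts_with = "token_file")]
    secret_key: Option<String>,

    // Web-identity arm: membership in the same group makes it mutually
    // exclusive with access_key, since a group accepts one argument by default.
    #[clap(long, required = false, group = "auth")]
    token_file: Option<String>,
}

fn main() {
    // `--bucket b` alone is rejected (no auth method), as is
    // `--bucket b --access-key k` (missing --secret-key).
    println!("{:?}", Opts::parse());
}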

View File

@@ -1,4 +1,5 @@
 use std::num::NonZeroUsize;
+use std::path::PathBuf;
 use std::time::Duration;

 use grenad::CompressionType;
@@ -47,8 +48,10 @@ pub struct S3SnapshotOptions {
     pub s3_bucket_region: String,
     pub s3_bucket_name: String,
     pub s3_snapshot_prefix: String,
-    pub s3_access_key: String,
-    pub s3_secret_key: String,
+    pub s3_access_key: Option<String>,
+    pub s3_secret_key: Option<String>,
+    pub s3_role_arn: Option<String>,
+    pub s3_web_identity_token_file: Option<PathBuf>,
     pub s3_max_in_flight_parts: NonZeroUsize,
     pub s3_compression_level: u32,
     pub s3_signature_duration: Duration,