diff --git a/crates/index-scheduler/src/scheduler/process_snapshot_creation.rs b/crates/index-scheduler/src/scheduler/process_snapshot_creation.rs
index cb3e675bc..407954282 100644
--- a/crates/index-scheduler/src/scheduler/process_snapshot_creation.rs
+++ b/crates/index-scheduler/src/scheduler/process_snapshot_creation.rs
@@ -285,7 +285,19 @@ impl IndexScheduler {
         use async_compression::tokio::write::GzipEncoder;
         use async_compression::Level;
+        use bytes::{Bytes, BytesMut};
+        use rusty_s3::actions::UploadPart;
         use tokio::fs::File;
+        use tokio::io::AsyncReadExt;
+        use tokio::task::JoinHandle;
+
+        const ONE_HOUR: Duration = Duration::from_secs(3600);
+        // default part size is 250MiB
+        const MIN_PART_SIZE: usize = 250 * 1024 * 1024;
+        // 10MiB
+        const TEN_MIB: usize = 10 * 1024 * 1024;
+        // The maximum number of parts that can be uploaded to a single multipart upload.
+        const MAX_NUMBER_PARTS: usize = 10_000;
 
         // The maximum number of parts that can be uploaded in parallel.
         const S3_MAX_IN_FLIGHT_PARTS: &str = "MEILI_S3_MAX_IN_FLIGHT_PARTS";
@@ -299,69 +311,185 @@ impl IndexScheduler {
         let url = bucket_url.parse().unwrap();
         let bucket = Bucket::new(url, UrlStyle::Path, bucket_name, bucket_region).unwrap();
         let credential = Credentials::new(access_key, secret_key);
+        // TODO change this and use the database name like in the original version
+        let object = "data.ms.snapshot";
 
-        let (writer, reader) = tokio::net::unix::pipe::pipe()?;
-        let compressed_writer = GzipEncoder::with_quality(writer, Level::Fastest);
-        let mut tarball = tokio_tar::Builder::new(compressed_writer);
+        // TODO implement exponential backoff on upload requests: https://docs.rs/backoff
+        // TODO return a result with actual errors
+        // TODO sign for longer than an hour?
+        // TODO Use a better thing than a String for the object path
+        let (writer, mut reader) = tokio::net::unix::pipe::pipe()?;
+        let uploader_task = tokio::spawn(async move {
+            let action = bucket.create_multipart_upload(Some(&credential), object);
+            // TODO Question: If it is only signed for an hour and a snapshot takes longer than an hour, what happens?
+            //      If the part is deleted (like a TTL) we should sign it for at least 24 hours.
+            let url = action.sign(ONE_HOUR);
+            let resp = client.post(url).send().await.unwrap().error_for_status().unwrap();
+            let body = resp.text().await.unwrap();
 
-        // 1. Snapshot the version file
-        tarball.append_path_with_name(&self.scheduler.version_file_path, VERSION_FILE_NAME).await?;
+            let multipart = CreateMultipartUpload::parse_response(&body).unwrap();
+            let mut etags = Vec::<String>::new();
 
-        // 2. Snapshot the index scheduler LMDB env
-        progress.update_progress(SnapshotCreationProgress::SnapshotTheIndexScheduler);
-        let mut tasks_env_file = self.env.try_clone_inner_file().map(File::from_std)?;
-        let path = Path::new("tasks").join("data.mdb");
-        tarball.append_file(path, &mut tasks_env_file).await?;
-        drop(tasks_env_file);
+            let mut in_flight = VecDeque::<(
+                JoinHandle<reqwest::Result<reqwest::Response>>,
+                Bytes,
+            )>::with_capacity(max_in_flight_parts);
+            for part_number in 1u16.. {
+                let part_upload = bucket.upload_part(
+                    Some(&credential),
+                    object,
+                    part_number,
+                    multipart.upload_id(),
+                );
+                let url = part_upload.sign(ONE_HOUR);
 
-        // 2.3 Create a read transaction on the index-scheduler
-        let rtxn = self.env.read_txn()?;
+                // Wait for a buffer to be ready if there are in-flight parts that landed
+                let mut buffer = if in_flight.len() >= max_in_flight_parts {
+                    let (request, buffer) = in_flight.pop_front().unwrap();
+                    let resp = request.await.unwrap().unwrap().error_for_status().unwrap();
+                    let etag =
+                        resp.headers().get(ETAG).expect("every UploadPart request returns an Etag");
+                    // TODO use bumpalo to reduce the number of allocations
+                    etags.push(etag.to_str().unwrap().to_owned());
+                    // The request has completed and dropped its clone of the buffer,
+                    // so we are the only owner again and can reclaim the allocation.
+                    let mut buffer = buffer.try_into_mut().expect("to convert into a mut buffer");
+                    buffer.clear();
+                    buffer
+                } else {
+                    // TODO Base this on the available memory
+                    BytesMut::with_capacity(MIN_PART_SIZE)
+                };
 
-        // 2.4 Create the update files directory
-        //     And only copy the update files of the enqueued tasks
-        progress.update_progress(SnapshotCreationProgress::SnapshotTheUpdateFiles);
-        let enqueued = self.queue.tasks.get_status(&rtxn, Status::Enqueued)?;
-        let (atomic, update_file_progress) = AtomicUpdateFileStep::new(enqueued.len() as u32);
-        progress.update_progress(update_file_progress);
-        let update_files_dir = Path::new("update_files");
-        for task_id in enqueued {
-            let task =
-                self.queue.tasks.get_task(&rtxn, task_id)?.ok_or(Error::CorruptedTaskQueue)?;
-            if let Some(content_uuid) = task.content_uuid() {
-                let src = self.queue.file_store.update_path(content_uuid);
-                let mut update_file = File::open(src).await?;
-                let path = update_files_dir.join(content_uuid.to_string());
-                tarball.append_file(path, &mut update_file).await?;
+                while buffer.len() < (MIN_PART_SIZE / 2) {
+                    if reader.read_buf(&mut buffer).await? == 0 {
+                        break;
+                    }
+                }
+
+                if buffer.is_empty() {
+                    break;
+                }
+
+                let body = buffer.freeze();
+                let task = tokio::spawn(client.put(url).body(body.clone()).send());
+                in_flight.push_back((task, body));
             }
-            atomic.fetch_add(1, Ordering::Relaxed);
-        }
 
-        // 3. Snapshot every indexes
-        progress.update_progress(SnapshotCreationProgress::SnapshotTheIndexes);
-        let index_mapping = self.index_mapper.index_mapping;
-        let nb_indexes = index_mapping.len(&rtxn)? as u32;
-        let indexes_dir = Path::new("indexes");
-        for (i, result) in index_mapping.iter(&rtxn)?.enumerate() {
-            let (name, uuid) = result?;
-            progress.update_progress(VariableNameStep::<SnapshotCreationProgress>::new(
-                name, i as u32, nb_indexes,
-            ));
-            let index = self.index_mapper.index(&rtxn, name)?;
-            let path = indexes_dir.join(uuid.to_string());
-            let mut index_file = index.try_clone_inner_file().map(File::from_std).unwrap();
-            tarball.append_file(path, &mut index_file).await?;
-        }
+            for (join_handle, _buffer) in in_flight {
+                let resp = join_handle.await.unwrap().unwrap().error_for_status().unwrap();
+                let etag =
+                    resp.headers().get(ETAG).expect("every UploadPart request returns an Etag");
+                // TODO use bumpalo to reduce the number of allocations
+                etags.push(etag.to_str().unwrap().to_owned());
+            }
 
-        drop(rtxn);
+            let action = bucket.complete_multipart_upload(
+                Some(&credential),
+                object,
+                multipart.upload_id(),
+                etags.iter().map(AsRef::as_ref),
+            );
+            let url = action.sign(ONE_HOUR);
+            let resp = client
+                .post(url)
+                .body(action.body())
+                .send()
+                .await
+                .unwrap()
+                .error_for_status()
+                .unwrap();
 
-        // 4. Snapshot the auth LMDB env
-        progress.update_progress(SnapshotCreationProgress::SnapshotTheApiKeys);
-        let mut auth_env_file =
-            self.scheduler.auth_env.try_clone_inner_file().map(File::from_std).unwrap();
-        let path = Path::new("auth").join("data.mdb");
-        tarball.append_file(path, &mut auth_env_file).await?;
+            // TODO do a better check and do not assert
+            assert!(resp.status().is_success());
 
-        tarball.finish().await?;
+            Result::<_, Error>::Ok(())
+        });
+
+        // TODO not a big fan of this clone;
+        //      remove it and get all the necessary data from the scheduler instead
+        let index_scheduler = IndexScheduler::private_clone(self);
+        let builder_task = tokio::task::spawn_local(async move {
+            let compressed_writer = GzipEncoder::with_quality(writer, Level::Fastest);
+            let mut tarball = tokio_tar::Builder::new(compressed_writer);
+
+            // 1. Snapshot the version file
+            tarball
+                .append_path_with_name(
+                    &index_scheduler.scheduler.version_file_path,
+                    VERSION_FILE_NAME,
+                )
+                .await?;
+
+            // 2. Snapshot the index scheduler LMDB env
+            progress.update_progress(SnapshotCreationProgress::SnapshotTheIndexScheduler);
+            let mut tasks_env_file =
+                index_scheduler.env.try_clone_inner_file().map(File::from_std)?;
+            let path = Path::new("tasks").join("data.mdb");
+            tarball.append_file(path, &mut tasks_env_file).await?;
+            drop(tasks_env_file);
+
+            // 2.3 Create a read transaction on the index-scheduler
+            let rtxn = index_scheduler.env.read_txn()?;
+
+            // 2.4 Create the update files directory
+            //     and only copy the update files of the enqueued tasks
+            progress.update_progress(SnapshotCreationProgress::SnapshotTheUpdateFiles);
+            let enqueued = index_scheduler.queue.tasks.get_status(&rtxn, Status::Enqueued)?;
+            let (atomic, update_file_progress) = AtomicUpdateFileStep::new(enqueued.len() as u32);
+            progress.update_progress(update_file_progress);
+            let update_files_dir = Path::new("update_files");
+            for task_id in enqueued {
+                let task = index_scheduler
+                    .queue
+                    .tasks
+                    .get_task(&rtxn, task_id)?
+                    .ok_or(Error::CorruptedTaskQueue)?;
+                if let Some(content_uuid) = task.content_uuid() {
+                    let src = index_scheduler.queue.file_store.update_path(content_uuid);
+                    let mut update_file = File::open(src).await?;
+                    let path = update_files_dir.join(content_uuid.to_string());
+                    tarball.append_file(path, &mut update_file).await?;
+                }
+                atomic.fetch_add(1, Ordering::Relaxed);
+            }
+
+            // 3. Snapshot every index
+            progress.update_progress(SnapshotCreationProgress::SnapshotTheIndexes);
+            let index_mapping = index_scheduler.index_mapper.index_mapping;
+            let nb_indexes = index_mapping.len(&rtxn)? as u32;
+            let indexes_dir = Path::new("indexes");
+            for (i, result) in index_mapping.iter(&rtxn)?.enumerate() {
+                let (name, uuid) = result?;
+                progress.update_progress(VariableNameStep::<SnapshotCreationProgress>::new(
+                    name, i as u32, nb_indexes,
+                ));
+                let index = index_scheduler.index_mapper.index(&rtxn, name)?;
+                let path = indexes_dir.join(uuid.to_string());
+                let mut index_file = index.try_clone_inner_file().map(File::from_std).unwrap();
+                tarball.append_file(path, &mut index_file).await?;
+            }
+
+            drop(rtxn);
+
+            // 4. Snapshot the auth LMDB env
+            progress.update_progress(SnapshotCreationProgress::SnapshotTheApiKeys);
+            let mut auth_env_file = index_scheduler
+                .scheduler
+                .auth_env
+                .try_clone_inner_file()
+                .map(File::from_std)
+                .unwrap();
+            let path = Path::new("auth").join("data.mdb");
+            tarball.append_file(path, &mut auth_env_file).await?;
+
+            tarball.finish().await?;
+
+            Result::<_, Error>::Ok(())
+        });
+
+        let (uploader_result, builder_result) = tokio::join!(uploader_task, builder_task);
+
+        uploader_result.unwrap()?;
+        builder_result.unwrap()?;
 
         for task in &mut tasks {
             task.status = Status::Succeeded;
@@ -370,91 +498,3 @@ impl IndexScheduler {
         Ok(tasks)
     }
 }
-
-// TODO implement exponential backoff on upload requests: https://docs.rs/backoff
-// TODO return a result with actual errors
-// TODO sign for longer than an hour?
-// TODO Use a better thing than a String for the object path
-async fn multipart_upload(
-    bucket: &Bucket,
-    client: &Client,
-    credential: Option<&Credentials>,
-    max_in_flight_parts: usize,
-    bytes: bytes::Bytes,
-    object: &str,
-) -> Result<()> {
-    const ONE_HOUR: Duration = Duration::from_secs(3600);
-    // default part size is 250MiB
-    const MIN_PART_SIZE: usize = 250 * 1024 * 1024;
-    // 10MiB
-    const TEN_MIB: usize = 10 * 1024 * 1024;
-    // The maximum number of parts that can be uploaded to a single multipart upload.
-    const MAX_NUMBER_PARTS: usize = 10_000;
-
-    let action = bucket.create_multipart_upload(credential, object);
-    // TODO Question: If it is only signed for an hour and a snapshot takes longer than an hour, what happens?
-    //      If the part is deleted (like a TTL) we should sign it for at least 24 hours.
-    let url = action.sign(ONE_HOUR);
-    let resp = client.post(url).send().await.unwrap().error_for_status().unwrap();
-    let body = resp.text().await.unwrap();
-
-    let multipart = CreateMultipartUpload::parse_response(&body).unwrap();
-    let mut etags = Vec::<String>::new();
-
-    let part_size = bytes.len() / MAX_NUMBER_PARTS;
-    let part_size = if part_size < TEN_MIB { MIN_PART_SIZE } else { part_size };
-
-    let mut in_flight_parts = VecDeque::with_capacity(max_in_flight_parts);
-    let number_of_parts = bytes.len().div_ceil(part_size);
-    for i in 0..number_of_parts {
-        let part_number = u16::try_from(i).unwrap().checked_add(1).unwrap();
-        let part_upload =
-            bucket.upload_part(credential, object, part_number, multipart.upload_id());
-        let url = part_upload.sign(ONE_HOUR);
-
-        // Make sure we do not read out of bound
-        let body = if bytes.len() < part_size * (i + 1) {
-            bytes.slice(part_size * i..)
-        } else {
-            bytes.slice(part_size * i..part_size * (i + 1))
-        };
-
-        let task = tokio::spawn(client.put(url).body(body).send());
-        in_flight_parts.push_back(task);
-
-        if in_flight_parts.len() == max_in_flight_parts {
-            let resp = in_flight_parts
-                .pop_front()
-                .unwrap()
-                .await
-                .unwrap()
-                .unwrap()
-                .error_for_status()
-                .unwrap();
-            let etag = resp.headers().get(ETAG).expect("every UploadPart request returns an Etag");
-            // TODO use bumpalo to reduce the number of allocations
-            etags.push(etag.to_str().unwrap().to_owned());
-        }
-    }
-
-    for join_handle in in_flight_parts {
-        let resp = join_handle.await.unwrap().unwrap().error_for_status().unwrap();
-        let etag = resp.headers().get(ETAG).expect("every UploadPart request returns an Etag");
-        // TODO use bumpalo to reduce the number of allocations
-        etags.push(etag.to_str().unwrap().to_owned());
-    }
-
-    let action = bucket.complete_multipart_upload(
-        credential,
-        object,
-        multipart.upload_id(),
-        etags.iter().map(AsRef::as_ref),
-    );
-    let url = action.sign(ONE_HOUR);
-    let resp =
-        client.post(url).body(action.body()).send().await.unwrap().error_for_status().unwrap();
-
-    assert!(resp.status().is_success());
-
-    Ok(())
-}
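The core of the new design is a single anonymous pipe: the builder task streams the gzipped tarball into the writer half while the uploader accumulates at least `MIN_PART_SIZE / 2` bytes from the reader half before shipping each part. A self-contained sketch of that producer/consumer shape, assuming a recent Tokio (1.35+, `net` feature, Unix target); the `CHUNK` size and the payload are placeholders:

```rust
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

// Illustrative chunk size; the patch waits for MIN_PART_SIZE / 2 instead.
const CHUNK: usize = 64 * 1024;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let (mut writer, mut reader) = tokio::net::unix::pipe::pipe()?;

    // Producer: stands in for the gzipped tarball being appended to.
    let producer = tokio::spawn(async move {
        for _ in 0..100 {
            writer.write_all(&[0xAB; 8192]).await?;
        }
        // Dropping the writer closes the pipe; the reader observes EOF.
        std::io::Result::Ok(())
    });

    // Consumer: accumulate at least CHUNK bytes per "part", like the uploader.
    let mut parts = 0;
    loop {
        let mut buffer = BytesMut::with_capacity(CHUNK);
        while buffer.len() < CHUNK {
            if reader.read_buf(&mut buffer).await? == 0 {
                break; // producer hung up
            }
        }
        if buffer.is_empty() {
            break; // nothing left to upload
        }
        parts += 1; // the real code freezes the buffer and PUTs it here
    }

    producer.await.unwrap()?;
    println!("would have uploaded {parts} parts");
    Ok(())
}
```

This is why the old `bytes: bytes::Bytes` parameter disappears: no stage ever holds the whole snapshot in memory, only up to `max_in_flight_parts` buffers at once.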
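The uploader keeps the `Bytes` handle of every in-flight part next to its `JoinHandle` so the part buffer can be recycled instead of reallocated. `Bytes::try_into_mut` only succeeds once the reference count is back to one, i.e. after the spawned request has finished and dropped the clone passed to `.body()`, which is why the response is awaited before reclaiming the buffer. A small demonstration of those semantics (needs `bytes` 1.4+ for `try_into_mut`):

```rust
use bytes::BytesMut;

fn main() {
    let mut buffer = BytesMut::with_capacity(1024);
    buffer.extend_from_slice(b"part payload");

    // Freeze the buffer to send it; the clone stands in for the copy the
    // HTTP client owns while the request is in flight.
    let body = buffer.freeze();
    let in_flight = body.clone();

    // With another handle alive, the allocation cannot be reclaimed yet.
    let body = body.try_into_mut().unwrap_err();

    // Once the request side drops its handle, reclaiming succeeds without
    // touching the allocator, and the buffer is ready for the next part.
    drop(in_flight);
    let mut buffer = body.try_into_mut().expect("we are the only owner now");
    buffer.clear();
    assert!(buffer.is_empty());
}
```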
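For the `https://docs.rs/backoff` TODO, a hedged sketch of what retrying a single part upload could look like, assuming the `backoff` crate with its `tokio` feature; the helper name, the ten-minute cap, and the "retry 5xx, fail 4xx" policy are illustrative choices, not part of the patch. Because `Bytes` clones are reference-counted, handing a fresh clone to each attempt costs nothing:

```rust
use std::time::Duration;

use backoff::{future::retry, Error as BackoffError, ExponentialBackoff};
use bytes::Bytes;
use reqwest::Client;

/// Hypothetical helper: PUT one presigned part, retrying transient failures.
async fn upload_part_with_backoff(
    client: &Client,
    url: reqwest::Url,
    body: Bytes,
) -> Result<reqwest::Response, reqwest::Error> {
    let policy = ExponentialBackoff {
        // Give up after ten minutes instead of retrying forever.
        max_elapsed_time: Some(Duration::from_secs(600)),
        ..ExponentialBackoff::default()
    };
    retry(policy, || {
        // `Client`, `Url`, and `Bytes` clones are all cheap (refcounted).
        let client = client.clone();
        let url = url.clone();
        let body = body.clone();
        async move {
            let resp =
                client.put(url).body(body).send().await.map_err(BackoffError::transient)?;
            match resp.error_for_status() {
                Ok(resp) => Ok(resp),
                // A 5xx is worth retrying; a 4xx will not get better.
                Err(err) if err.status().is_some_and(|s| s.is_server_error()) => {
                    Err(BackoffError::transient(err))
                }
                Err(err) => Err(BackoffError::permanent(err)),
            }
        }
    })
    .await
}
```

The ETag is still only recorded once the final attempt succeeds, so a wrapper like this composes with the in-flight queue unchanged.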
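One detail worth keeping in mind if the deleted `multipart_upload` helper ever comes back: `bytes.len() / MAX_NUMBER_PARTS` truncates, so the resulting part size can leave a sliver that pushes the part count to 10_001, one over the S3 limit. The arithmetic, worked through:

```rust
fn main() {
    const MIN_PART_SIZE: usize = 250 * 1024 * 1024; // 250 MiB default
    const TEN_MIB: usize = 10 * 1024 * 1024;
    const MAX_NUMBER_PARTS: usize = 10_000; // S3's per-upload hard limit

    // 1 GiB payload: 1 GiB / 10_000 is about 105 KiB, under the 10 MiB floor,
    // so the 250 MiB default applies and the upload needs only 5 parts.
    let len: usize = 1024 * 1024 * 1024;
    let part_size = len / MAX_NUMBER_PARTS;
    let part_size = if part_size < TEN_MIB { MIN_PART_SIZE } else { part_size };
    assert_eq!(len.div_ceil(part_size), 5);

    // 200 GiB payload: the truncating division yields ~20.48 MiB parts, and
    // part_size * 10_000 falls 4_800 bytes short of `len`...
    let len: usize = 200 * 1024 * 1024 * 1024;
    let part_size = len / MAX_NUMBER_PARTS; // 21_474_836, already >= TEN_MIB
    assert_eq!(len.div_ceil(part_size), 10_001); // ...one part over the limit

    // Rounding the division *up* keeps the count at exactly 10_000.
    let part_size = len.div_ceil(MAX_NUMBER_PARTS);
    assert_eq!(len.div_ceil(part_size), 10_000);
}
```

The streaming rewrite sidesteps the question entirely, since it sizes parts up front rather than dividing a known total length.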