Add response updating logic

This commit is contained in:
Mubelotix
2025-08-26 10:59:12 +02:00
parent b2a72b0363
commit 4290901dea
10 changed files with 233 additions and 85 deletions

View File

@ -3,21 +3,22 @@ use std::io::{Read as _, Seek as _, Write as _};
use anyhow::{bail, Context};
use futures_util::TryStreamExt as _;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use super::client::Client;
#[derive(Deserialize, Clone)]
#[derive(Serialize, Deserialize, Clone)]
pub struct Asset {
pub local_location: Option<String>,
pub remote_location: Option<String>,
#[serde(default)]
#[serde(default, skip_serializing_if = "AssetFormat::is_default")]
pub format: AssetFormat,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub sha256: Option<String>,
}
#[derive(Deserialize, Default, Copy, Clone)]
#[derive(Serialize, Deserialize, Default, Copy, Clone)]
pub enum AssetFormat {
#[default]
Auto,
@ -27,6 +28,10 @@ pub enum AssetFormat {
}
impl AssetFormat {
fn is_default(&self) -> bool {
matches!(self, AssetFormat::Auto)
}
pub fn to_content_type(self, filename: &str) -> &'static str {
match self {
AssetFormat::Auto => Self::auto_detect(filename).to_content_type(filename),

View File

@ -1,5 +1,5 @@
use anyhow::Context;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
pub struct Client {
@ -61,7 +61,7 @@ impl Client {
}
}
#[derive(Debug, Clone, Copy, Deserialize)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Method {
Get,

View File

@ -1,24 +1,32 @@
use std::collections::BTreeMap;
use std::fmt::Display;
use std::io::Read as _;
use std::sync::Arc;
use anyhow::{bail, Context as _};
use serde::Deserialize;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::common::assets::{fetch_asset, Asset};
use crate::common::client::{Client, Method};
#[derive(Clone, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Command {
pub route: String,
pub method: Method,
#[serde(default)]
pub body: Body,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub expected_status: Option<u16>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub expected_response: Option<serde_json::Value>,
#[serde(default)]
synchronous: SyncMode,
}
#[derive(Default, Clone, Deserialize)]
#[derive(Default, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Body {
Inline {
@ -63,7 +71,7 @@ impl Display for Command {
}
}
#[derive(Default, Debug, Clone, Copy, Deserialize)]
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
enum SyncMode {
DontWait,
#[default]
@ -72,41 +80,42 @@ enum SyncMode {
}
async fn run_batch(
client: &Client,
batch: &[Command],
assets: &BTreeMap<String, Asset>,
asset_folder: &str,
) -> anyhow::Result<()> {
let [.., last] = batch else { return Ok(()) };
client: &Arc<Client>,
batch: Vec<Command>,
assets: &Arc<BTreeMap<String, Asset>>,
asset_folder: &'static str,
return_response: bool,
) -> anyhow::Result<Vec<(Value, StatusCode)>> {
let [.., last] = batch.as_slice() else { return Ok(Vec::new()) };
let sync = last.synchronous;
let batch_len = batch.len();
let mut tasks = tokio::task::JoinSet::new();
for command in batch {
// FIXME: you probably don't want to copy assets everytime here
tasks.spawn({
let client = client.clone();
let command = command.clone();
let assets = assets.clone();
let asset_folder = asset_folder.to_owned();
async move { run(client, command, &assets, &asset_folder).await }
});
let mut tasks = Vec::with_capacity(batch.len());
for batch in batch {
let client2 = Arc::clone(&client);
let assets2 = Arc::clone(&assets);
tasks.push(tokio::spawn(async move {
run(&client2, &batch, &assets2, asset_folder, return_response).await
}));
}
while let Some(result) = tasks.join_next().await {
result
.context("panicked while executing command")?
.context("error while executing command")?;
let mut outputs = Vec::with_capacity(if return_response { batch_len } else { 0 });
for task in tasks {
let output = task.await.context("task panicked")??;
if let Some(output) = output {
if return_response {
outputs.push(output);
}
}
}
match sync {
SyncMode::DontWait => {}
SyncMode::WaitForResponse => {}
SyncMode::WaitForTask => wait_for_tasks(client).await?,
SyncMode::WaitForTask => wait_for_tasks(&client).await?,
}
Ok(())
Ok(outputs)
}
async fn wait_for_tasks(client: &Client) -> anyhow::Result<()> {
@ -150,13 +159,16 @@ async fn wait_for_tasks(client: &Client) -> anyhow::Result<()> {
#[tracing::instrument(skip(client, command, assets, asset_folder), fields(command = %command))]
pub async fn run(
client: Client,
mut command: Command,
client: &Client,
command: &Command,
assets: &BTreeMap<String, Asset>,
asset_folder: &str,
) -> anyhow::Result<()> {
return_value: bool,
) -> anyhow::Result<Option<(Value, StatusCode)>> {
// memtake the body here to leave an empty body in its place, so that command is not partially moved-out
let body = std::mem::take(&mut command.body)
let body = command
.body
.clone()
.get(assets, asset_folder)
.with_context(|| format!("while getting body for command {command}"))?;
@ -172,7 +184,17 @@ pub async fn run(
request.send().await.with_context(|| format!("error sending command: {}", command))?;
let code = response.status();
if code.is_client_error() {
if let Some(expected_status) = command.expected_status {
if code.as_u16() != expected_status {
let response = response
.text()
.await
.context("could not read response body as text")
.context("reading response body when checking expected status")?;
bail!("unexpected status code: got {}, expected {expected_status}, response body: '{response}'", code.as_u16());
}
} else if code.is_client_error() {
tracing::error!(%command, %code, "error in workload file");
let response: serde_json::Value = response
.json()
@ -190,22 +212,44 @@ pub async fn run(
bail!("server error: server responded with error code {code} and '{response}'")
}
Ok(())
if return_value {
let response: serde_json::Value = response
.json()
.await
.context("could not deserialize response as JSON")
.context("parsing response when recording expected response")?;
return Ok(Some((response, code)));
} else if let Some(expected_response) = &command.expected_response {
let response: serde_json::Value = response
.json()
.await
.context("could not deserialize response as JSON")
.context("parsing response when checking expected response")?;
if &response != expected_response {
bail!("unexpected response: got '{response}', expected '{expected_response}'");
}
}
Ok(None)
}
pub async fn run_commands(
client: &Client,
client: &Arc<Client>,
commands: &[Command],
assets: &BTreeMap<String, Asset>,
asset_folder: &str,
) -> anyhow::Result<()> {
assets: &Arc<BTreeMap<String, Asset>>,
asset_folder: &'static str,
return_response: bool,
) -> anyhow::Result<Vec<(Value, StatusCode)>> {
let mut responses = Vec::new();
for batch in
commands.split_inclusive(|command| !matches!(command.synchronous, SyncMode::DontWait))
{
run_batch(client, batch, assets, asset_folder).await?;
let mut new_responses =
run_batch(client, batch.to_vec(), assets, asset_folder, return_response).await?;
responses.append(&mut new_responses);
}
Ok(())
Ok(responses)
}
pub fn health_command() -> Command {
@ -214,5 +258,7 @@ pub fn health_command() -> Command {
method: crate::common::client::Method::Get,
body: Default::default(),
synchronous: SyncMode::WaitForResponse,
expected_status: None,
expected_response: None,
}
}

View File

@ -1,8 +1,8 @@
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use crate::{bench::BenchWorkload, test::TestWorkload};
#[derive(Deserialize)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "camelCase")]
pub enum Workload {