mirror of
				https://github.com/meilisearch/meilisearch.git
				synced 2025-10-25 21:16:28 +00:00 
			
		
		
		
	Better CSV support
This commit is contained in:
		| @@ -1,9 +1,8 @@ | ||||
| use std::fmt::{self, Debug, Display}; | ||||
| use std::fs::File; | ||||
| use std::io::{self, BufReader, BufWriter, Seek, Write}; | ||||
| use std::io::{self, BufWriter}; | ||||
| use std::marker::PhantomData; | ||||
|  | ||||
| use csv::StringRecord; | ||||
| use memmap2::Mmap; | ||||
| use milli::documents::Error; | ||||
| use milli::update::new::TopLevelMap; | ||||
| @@ -11,13 +10,13 @@ use milli::Object; | ||||
| use serde::de::{SeqAccess, Visitor}; | ||||
| use serde::{Deserialize, Deserializer}; | ||||
| use serde_json::error::Category; | ||||
| use serde_json::{Map, Value}; | ||||
|  | ||||
| use crate::error::deserr_codes::MalformedPayload; | ||||
| use crate::error::{Code, ErrorCode}; | ||||
|  | ||||
| type Result<T> = std::result::Result<T, DocumentFormatError>; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| pub enum PayloadType { | ||||
|     Ndjson, | ||||
|     Json, | ||||
| @@ -101,6 +100,16 @@ impl From<(PayloadType, serde_json::Error)> for DocumentFormatError { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<(PayloadType, csv::Error)> for DocumentFormatError { | ||||
|     fn from((ty, error): (PayloadType, csv::Error)) -> Self { | ||||
|         if error.is_io_error() { | ||||
|             Self::Io(error.into()) | ||||
|         } else { | ||||
|             Self::MalformedPayload(Error::Csv(error), ty) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<io::Error> for DocumentFormatError { | ||||
|     fn from(error: io::Error) -> Self { | ||||
|         Self::Io(error) | ||||
| @@ -140,78 +149,63 @@ fn parse_csv_header(header: &str) -> (&str, AllowedType) { | ||||
|  | ||||
| /// Reads CSV from input and write an obkv batch to writer. | ||||
| pub fn read_csv(input: &File, output: impl io::Write, delimiter: u8) -> Result<u64> { | ||||
|     use serde_json::{Map, Value}; | ||||
|  | ||||
|     let ptype = PayloadType::Csv { delimiter }; | ||||
|     let mut output = BufWriter::new(output); | ||||
|     let mut reader = csv::ReaderBuilder::new().delimiter(delimiter).from_reader(input); | ||||
|  | ||||
|     // TODO manage error correctly | ||||
|     // Make sure that we insert the fields ids in order as the obkv writer has this requirement. | ||||
|     let mut typed_fields: Vec<_> = reader | ||||
|         .headers() | ||||
|         .unwrap() | ||||
|         .into_iter() | ||||
|         .map(parse_csv_header) | ||||
|         .map(|(f, t)| (f.to_string(), t)) | ||||
|         .collect(); | ||||
|     let headers = reader.headers().map_err(|e| DocumentFormatError::from((ptype, e)))?.clone(); | ||||
|     let typed_fields: Vec<_> = headers.iter().map(parse_csv_header).collect(); | ||||
|     let mut object: Map<_, _> = headers.iter().map(|k| (k.to_string(), Value::Null)).collect(); | ||||
|  | ||||
|     let mut object: Map<_, _> = | ||||
|         reader.headers().unwrap().iter().map(|k| (k.to_string(), Value::Null)).collect(); | ||||
|  | ||||
|     let mut line: usize = 0; | ||||
|     let mut line = 0; | ||||
|     let mut record = csv::StringRecord::new(); | ||||
|     while reader.read_record(&mut record).unwrap() { | ||||
|         // We increment here and not at the end of the while loop to take | ||||
|         // the header offset into account. | ||||
|     while reader.read_record(&mut record).map_err(|e| DocumentFormatError::from((ptype, e)))? { | ||||
|         // We increment here and not at the end of the loop | ||||
|         // to take the header offset into account. | ||||
|         line += 1; | ||||
|  | ||||
|         // Reset the document to write | ||||
|         // Reset the document values | ||||
|         object.iter_mut().for_each(|(_, v)| *v = Value::Null); | ||||
|  | ||||
|         for (i, (name, type_)) in typed_fields.iter().enumerate() { | ||||
|         for (i, (name, atype)) in typed_fields.iter().enumerate() { | ||||
|             let value = &record[i]; | ||||
|             let trimmed_value = value.trim(); | ||||
|             let value = match type_ { | ||||
|             let value = match atype { | ||||
|                 AllowedType::Number if trimmed_value.is_empty() => Value::Null, | ||||
|                 AllowedType::Number => match trimmed_value.parse::<i64>() { | ||||
|                     Ok(integer) => Value::from(integer), | ||||
|                     Err(_) => { | ||||
|                         match trimmed_value.parse::<f64>() { | ||||
|                     Err(_) => match trimmed_value.parse::<f64>() { | ||||
|                         Ok(float) => Value::from(float), | ||||
|                         Err(error) => { | ||||
|                                 panic!("bad float") | ||||
|                                 // return Err(Error::ParseFloat { | ||||
|                                 //     error, | ||||
|                                 //     line, | ||||
|                                 //     value: value.to_string(), | ||||
|                                 // }); | ||||
|                             } | ||||
|                         } | ||||
|                             return Err(DocumentFormatError::MalformedPayload( | ||||
|                                 Error::ParseFloat { error, line, value: value.to_string() }, | ||||
|                                 ptype, | ||||
|                             )) | ||||
|                         } | ||||
|                     }, | ||||
|                 }, | ||||
|                 AllowedType::Boolean if trimmed_value.is_empty() => Value::Null, | ||||
|                 AllowedType::Boolean => match trimmed_value.parse::<bool>() { | ||||
|                     Ok(bool) => Value::from(bool), | ||||
|                     Err(error) => { | ||||
|                         panic!("bad bool") | ||||
|                         // return Err(Error::ParseBool { | ||||
|                         //     error, | ||||
|                         //     line, | ||||
|                         //     value: value.to_string(), | ||||
|                         // }); | ||||
|                         return Err(DocumentFormatError::MalformedPayload( | ||||
|                             Error::ParseBool { error, line, value: value.to_string() }, | ||||
|                             ptype, | ||||
|                         )) | ||||
|                     } | ||||
|                 }, | ||||
|                 AllowedType::String if value.is_empty() => Value::Null, | ||||
|                 AllowedType::String => Value::from(value), | ||||
|             }; | ||||
|  | ||||
|             *object.get_mut(name).unwrap() = value; | ||||
|             *object.get_mut(*name).expect("encountered an unknown field") = value; | ||||
|         } | ||||
|  | ||||
|         serde_json::to_writer(&mut output, &object).unwrap(); | ||||
|         serde_json::to_writer(&mut output, &object) | ||||
|             .map_err(|e| DocumentFormatError::from((ptype, e)))?; | ||||
|     } | ||||
|  | ||||
|     Ok(line.saturating_sub(1) as u64) | ||||
|     Ok(line as u64) | ||||
| } | ||||
|  | ||||
| /// Reads JSON from temporary file and write an obkv batch to writer. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user