Create and use real error types in the codecs

This commit is contained in:
Clément Renault
2023-11-28 10:11:17 +01:00
parent d32eb11329
commit 548c8247c2
7 changed files with 43 additions and 38 deletions

View File

@@ -3,6 +3,7 @@ use std::marker::PhantomData;
use heed::{BoxedError, BytesDecode, BytesEncode}; use heed::{BoxedError, BytesDecode, BytesEncode};
use crate::heed_codec::SliceTooShortError;
use crate::{try_split_array_at, DocumentId, FieldId}; use crate::{try_split_array_at, DocumentId, FieldId};
pub struct FieldDocIdFacetCodec<C>(PhantomData<C>); pub struct FieldDocIdFacetCodec<C>(PhantomData<C>);
@@ -14,10 +15,10 @@ where
type DItem = (FieldId, DocumentId, C::DItem); type DItem = (FieldId, DocumentId, C::DItem);
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
let (field_id_bytes, bytes) = try_split_array_at(bytes).unwrap(); let (field_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?;
let field_id = u16::from_be_bytes(field_id_bytes); let field_id = u16::from_be_bytes(field_id_bytes);
let (document_id_bytes, bytes) = try_split_array_at(bytes).unwrap(); let (document_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?;
let document_id = u32::from_be_bytes(document_id_bytes); let document_id = u32::from_be_bytes(document_id_bytes);
let value = C::bytes_decode(bytes)?; let value = C::bytes_decode(bytes)?;

View File

@@ -2,8 +2,10 @@ use std::borrow::Cow;
use std::convert::TryInto; use std::convert::TryInto;
use heed::{BoxedError, BytesDecode}; use heed::{BoxedError, BytesDecode};
use thiserror::Error;
use crate::facet::value_encoding::f64_into_bytes; use crate::facet::value_encoding::f64_into_bytes;
use crate::heed_codec::SliceTooShortError;
pub struct OrderedF64Codec; pub struct OrderedF64Codec;
@@ -12,7 +14,7 @@ impl<'a> BytesDecode<'a> for OrderedF64Codec {
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
if bytes.len() < 16 { if bytes.len() < 16 {
Err(BoxedError::from("invalid slice length")) Err(SliceTooShortError.into())
} else { } else {
bytes[8..].try_into().map(f64::from_be_bytes).map_err(Into::into) bytes[8..].try_into().map(f64::from_be_bytes).map_err(Into::into)
} }
@@ -26,8 +28,7 @@ impl heed::BytesEncode<'_> for OrderedF64Codec {
let mut buffer = [0u8; 16]; let mut buffer = [0u8; 16];
// write the globally ordered float // write the globally ordered float
let bytes = f64_into_bytes(*f) let bytes = f64_into_bytes(*f).ok_or(InvalidGloballyOrderedFloatError { float: *f })?;
.ok_or_else(|| BoxedError::from("cannot generate a globally ordered float"))?;
buffer[..8].copy_from_slice(&bytes[..]); buffer[..8].copy_from_slice(&bytes[..]);
// Then the f64 value just to be able to read it back // Then the f64 value just to be able to read it back
let bytes = f.to_be_bytes(); let bytes = f.to_be_bytes();
@@ -36,3 +37,9 @@ impl heed::BytesEncode<'_> for OrderedF64Codec {
Ok(Cow::Owned(buffer.to_vec())) Ok(Cow::Owned(buffer.to_vec()))
} }
} }
/// Error produced by the `OrderedF64Codec` encoder when `f64_into_bytes`
/// returns `None` for the float being encoded.
///
/// The offending value is kept so that it shows up in the error message.
// NOTE(review): which floats `f64_into_bytes` rejects is not visible from this
// file (presumably non-finite values) — confirm against its definition.
#[derive(Error, Debug)]
#[error("the float {float} cannot be converted to a globally ordered representation")]
pub struct InvalidGloballyOrderedFloatError {
/// The float that could not be converted.
float: f64,
}

View File

@@ -2,6 +2,7 @@ use std::borrow::Cow;
use heed::BoxedError; use heed::BoxedError;
use super::SliceTooShortError;
use crate::{try_split_array_at, FieldId}; use crate::{try_split_array_at, FieldId};
pub struct FieldIdWordCountCodec; pub struct FieldIdWordCountCodec;
@@ -10,11 +11,9 @@ impl<'a> heed::BytesDecode<'a> for FieldIdWordCountCodec {
type DItem = (FieldId, u8); type DItem = (FieldId, u8);
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
let (field_id_bytes, bytes) = let (field_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?;
try_split_array_at(bytes).ok_or("invalid slice length").map_err(BoxedError::from)?;
let field_id = u16::from_be_bytes(field_id_bytes); let field_id = u16::from_be_bytes(field_id_bytes);
let ([word_count], _nothing) = let ([word_count], _nothing) = try_split_array_at(bytes).ok_or(SliceTooShortError)?;
try_split_array_at(bytes).ok_or("invalid slice length").map_err(BoxedError::from)?;
Ok((field_id, word_count)) Ok((field_id, word_count))
} }
} }

View File

@@ -15,6 +15,7 @@ mod str_str_u8_codec;
pub use byte_slice_ref::BytesRefCodec; pub use byte_slice_ref::BytesRefCodec;
use heed::BoxedError; use heed::BoxedError;
pub use str_ref::StrRefCodec; pub use str_ref::StrRefCodec;
use thiserror::Error;
pub use self::beu16_str_codec::BEU16StrCodec; pub use self::beu16_str_codec::BEU16StrCodec;
pub use self::beu32_str_codec::BEU32StrCodec; pub use self::beu32_str_codec::BEU32StrCodec;
@@ -34,3 +35,7 @@ pub trait BytesDecodeOwned {
fn bytes_decode_owned(bytes: &[u8]) -> Result<Self::DItem, BoxedError>; fn bytes_decode_owned(bytes: &[u8]) -> Result<Self::DItem, BoxedError>;
} }
/// Error shared by the codecs in this module: returned from a decoder when the
/// input byte slice holds fewer bytes than the value being decoded requires
/// (for example when a fixed-size prefix or footer cannot be split off).
///
/// It carries no payload; it is converted into a `BoxedError` with `?` or
/// `.into()` at the call sites.
#[derive(Error, Debug)]
#[error("the slice is too short")]
pub struct SliceTooShortError;

View File

@@ -1,4 +1,5 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::ffi::CStr;
use std::str; use std::str;
use charabia::{Language, Script}; use charabia::{Language, Script};
@@ -10,17 +11,12 @@ impl<'a> heed::BytesDecode<'a> for ScriptLanguageCodec {
type DItem = (Script, Language); type DItem = (Script, Language);
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
let sep = bytes let cstr = CStr::from_bytes_until_nul(bytes)?;
.iter() let script = cstr.to_str()?;
.position(|b| *b == 0)
.ok_or("cannot find nul byte")
.map_err(BoxedError::from)?;
let (s_bytes, l_bytes) = bytes.split_at(sep);
let script = str::from_utf8(s_bytes)?;
let script_name = Script::from_name(script); let script_name = Script::from_name(script);
let lan = str::from_utf8(l_bytes)?;
// skip '\0' byte between the two strings. // skip '\0' byte between the two strings.
let lan_name = Language::from_name(&lan[1..]); let lan = str::from_utf8(&bytes[script.len() + 1..])?;
let lan_name = Language::from_name(lan);
Ok((script_name, lan_name)) Ok((script_name, lan_name))
} }

View File

@@ -5,6 +5,8 @@ use std::str;
use heed::BoxedError; use heed::BoxedError;
use super::SliceTooShortError;
pub struct StrBEU32Codec; pub struct StrBEU32Codec;
impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { impl<'a> heed::BytesDecode<'a> for StrBEU32Codec {
@@ -14,7 +16,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU32Codec {
let footer_len = size_of::<u32>(); let footer_len = size_of::<u32>();
if bytes.len() < footer_len { if bytes.len() < footer_len {
return Err(BoxedError::from("cannot extract footer from bytes")); return Err(SliceTooShortError.into());
} }
let (word, bytes) = bytes.split_at(bytes.len() - footer_len); let (word, bytes) = bytes.split_at(bytes.len() - footer_len);
@@ -48,7 +50,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU16Codec {
let footer_len = size_of::<u16>(); let footer_len = size_of::<u16>();
if bytes.len() < footer_len + 1 { if bytes.len() < footer_len + 1 {
return Err(BoxedError::from("cannot extract footer from bytes")); return Err(SliceTooShortError.into());
} }
let (word_plus_nul_byte, bytes) = bytes.split_at(bytes.len() - footer_len); let (word_plus_nul_byte, bytes) = bytes.split_at(bytes.len() - footer_len);

View File

@@ -1,24 +1,22 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::ffi::CStr;
use std::str; use std::str;
use heed::BoxedError; use heed::BoxedError;
use super::SliceTooShortError;
pub struct U8StrStrCodec; pub struct U8StrStrCodec;
impl<'a> heed::BytesDecode<'a> for U8StrStrCodec { impl<'a> heed::BytesDecode<'a> for U8StrStrCodec {
type DItem = (u8, &'a str, &'a str); type DItem = (u8, &'a str, &'a str);
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
let (n, bytes) = bytes.split_first().ok_or("not enough bytes").map_err(BoxedError::from)?; let (n, bytes) = bytes.split_first().ok_or(SliceTooShortError)?;
let s1_end = bytes let cstr = CStr::from_bytes_until_nul(bytes)?;
.iter() let s1 = cstr.to_str()?;
.position(|b| *b == 0) // skip '\0' byte between the two strings.
.ok_or("cannot find nul byte") let s2 = str::from_utf8(&bytes[s1.len() + 1..])?;
.map_err(BoxedError::from)?;
let (s1_bytes, rest) = bytes.split_at(s1_end);
let s2_bytes = &rest[1..];
let s1 = str::from_utf8(s1_bytes)?;
let s2 = str::from_utf8(s2_bytes)?;
Ok((*n, s1, s2)) Ok((*n, s1, s2))
} }
} }
@@ -41,14 +39,11 @@ impl<'a> heed::BytesDecode<'a> for UncheckedU8StrStrCodec {
type DItem = (u8, &'a [u8], &'a [u8]); type DItem = (u8, &'a [u8], &'a [u8]);
fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> { fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, BoxedError> {
let (n, bytes) = bytes.split_first().ok_or("not enough bytes").map_err(BoxedError::from)?; let (n, bytes) = bytes.split_first().ok_or(SliceTooShortError)?;
let s1_end = bytes let cstr = CStr::from_bytes_until_nul(bytes)?;
.iter() let s1_bytes = cstr.to_bytes();
.position(|b| *b == 0) // skip '\0' byte between the two strings.
.ok_or("cannot find nul byte") let s2_bytes = &bytes[s1_bytes.len() + 1..];
.map_err(BoxedError::from)?;
let (s1_bytes, rest) = bytes.split_at(s1_end);
let s2_bytes = &rest[1..];
Ok((*n, s1_bytes, s2_bytes)) Ok((*n, s1_bytes, s2_bytes))
} }
} }