From d4ea35d273ba66d0deb6ca860a167a9a7e5feda2 Mon Sep 17 00:00:00 2001 From: Lucas Schwiderski Date: Fri, 25 Nov 2022 16:08:58 +0100 Subject: [PATCH] feat: Implement deserialization --- CHANGELOG.adoc | 2 + Cargo.toml | 2 + src/de.rs | 677 ++++++++++++++++++++++++++++++++++++++++++------- src/error.rs | 73 ++++-- src/lib.rs | 5 +- src/parser.rs | 305 ++++++++++++++++++++++ src/ser.rs | 2 +- 7 files changed, 947 insertions(+), 119 deletions(-) create mode 100644 src/parser.rs diff --git a/CHANGELOG.adoc b/CHANGELOG.adoc index 9849837..71ad984 100644 --- a/CHANGELOG.adoc +++ b/CHANGELOG.adoc @@ -6,6 +6,8 @@ == [Unreleased] +* parsing & deserialization + == [v0.1.0] - 2022-11-18 === Added diff --git a/Cargo.toml b/Cargo.toml index 9cee7c1..38f79cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ description = "An SJSON serialization file format" categories = ["encoding", "parser-implementations"] [dependencies] +nom = "7.1.1" +nom_locate = "4.0.0" serde = { version = "1.0.147", default-features = false } [dev-dependencies] diff --git a/src/de.rs b/src/de.rs index 3d56983..e3b2286 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,233 +1,734 @@ -pub struct Deserializer {} +use nom::IResult; +use serde::de::{EnumAccess, IntoDeserializer, VariantAccess}; +use serde::Deserialize; -impl serde::Deserializer for Deserializer { - type Error; +use crate::error::{Error, ErrorCode, Result}; +use crate::parser::*; - fn deserialize_any(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - todo!() +pub struct Deserializer<'de> { + input: Span<'de>, + is_top_level: bool, +} + +impl<'de> Deserializer<'de> { + #![allow(clippy::should_implement_trait)] + pub fn from_str(input: &'de str) -> Self { + Self { + input: Span::from(input), + is_top_level: true, + } } - fn deserialize_bool(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - todo!() + fn parse(&mut self, f: &dyn Fn(Span) -> IResult) -> Result { + f(self.input) + .map(|(span, token)| { + self.input = span; + token + }) + .map_err(|err| self.error(ErrorCode::Message(err.to_string()))) } - fn deserialize_i8(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - todo!() + fn next_token(&mut self) -> Result { + match parse_next_token(self.input) { + Ok((span, token)) => { + self.input = span; + Ok(token) + } + Err(err) => Err(self.error(ErrorCode::Message(err.to_string()))), + } } - fn deserialize_i16(self, visitor: V) -> Result - where - V: serde::de::Visitor<'de>, - { - todo!() + fn peek_token(&mut self) -> Result { + match parse_next_token(self.input) { + Ok((_, token)) => Ok(token), + Err(err) => Err(self.error(ErrorCode::Message(err.to_string()))), + } } - fn deserialize_i32(self, visitor: V) -> Result + fn error(&self, code: ErrorCode) -> Error { + Error::new( + code, + self.input.location_line(), + self.input.get_utf8_column(), + Some(self.input.fragment().to_string()), + ) + } +} + +pub fn from_str<'a, T>(input: &'a str) -> Result +where + T: Deserialize<'a>, +{ + let mut de = Deserializer::from_str(input); + let t = T::deserialize(&mut de)?; + if de.input.is_empty() || parse_trailing_characters(de.input).is_ok() { + Ok(t) + } else { + Err(de.error(ErrorCode::TrailingCharacters)) + } +} + +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + match self.next_token()? { + Token::Boolean(val) => visitor.visit_bool::(val), + Token::Float(val) => visitor.visit_f64(val), + Token::Integer(val) => visitor.visit_i64(val), + Token::Null => visitor.visit_unit(), + Token::String(val) => visitor.visit_str(&val), + _ => Err(self.error(ErrorCode::ExpectedValue)), + } } - fn deserialize_i64(self, visitor: V) -> Result + fn deserialize_bool(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if let Ok(Token::Boolean(val)) = self.parse(&parse_bool) { + visitor.visit_bool(val) + } else { + Err(self.error(ErrorCode::ExpectedBoolean)) + } } - fn deserialize_u8(self, visitor: V) -> Result + fn deserialize_i8(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_u16(self, visitor: V) -> Result + fn deserialize_i16(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_u32(self, visitor: V) -> Result + fn deserialize_i32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_u64(self, visitor: V) -> Result + fn deserialize_i64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if let Ok(Token::Integer(val)) = self.parse(&parse_integer) { + visitor.visit_i64(val) + } else { + Err(self.error(ErrorCode::ExpectedInteger)) + } } - fn deserialize_f32(self, visitor: V) -> Result + fn deserialize_u8(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_f64(self, visitor: V) -> Result + fn deserialize_u16(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_char(self, visitor: V) -> Result + fn deserialize_u32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_str(self, visitor: V) -> Result + fn deserialize_u64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_i64(visitor) } - fn deserialize_string(self, visitor: V) -> Result + fn deserialize_f32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_f64(visitor) } - fn deserialize_bytes(self, visitor: V) -> Result + fn deserialize_f64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if let Ok(Token::Float(val)) = self.parse(&parse_float) { + visitor.visit_f64(val) + } else { + Err(self.error(ErrorCode::ExpectedFloat)) + } } - fn deserialize_byte_buf(self, visitor: V) -> Result + fn deserialize_char(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_str(visitor) } - fn deserialize_option(self, visitor: V) -> Result + fn deserialize_str(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if let Ok(Token::String(val)) = self.parse(&parse_string) { + visitor.visit_str(&val) + } else { + Err(self.error(ErrorCode::ExpectedString)) + } } - fn deserialize_unit(self, visitor: V) -> Result + fn deserialize_string(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_str(visitor) } - fn deserialize_unit_struct( - self, - name: &'static str, - visitor: V, - ) -> Result + fn deserialize_bytes(self, _visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + unimplemented!() } - fn deserialize_newtype_struct( - self, - name: &'static str, - visitor: V, - ) -> Result + fn deserialize_byte_buf(self, _visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + unimplemented!() } - fn deserialize_seq(self, visitor: V) -> Result + fn deserialize_option(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if self.peek_token()? == Token::Null { + let _ = self.next_token()?; + visitor.visit_none() + } else { + visitor.visit_some(self) + } } - fn deserialize_tuple(self, len: usize, visitor: V) -> Result + fn deserialize_unit(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if let Ok(Token::Null) = self.parse(&parse_null) { + visitor.visit_unit() + } else { + Err(self.error(ErrorCode::ExpectedNull)) + } + } + + fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + if self.is_top_level { + return Err(self.error(ErrorCode::ExpectedTopLevelObject)); + } + + if self.next_token()? != Token::ArrayStart { + return Err(self.error(ErrorCode::ExpectedArray)); + } + + let value = visitor.visit_seq(Separated::new(self))?; + + if self.next_token()? == Token::ArrayEnd { + Ok(value) + } else { + Err(self.error(ErrorCode::ExpectedArrayEnd)) + } + } + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_seq(visitor) } fn deserialize_tuple_struct( self, - name: &'static str, - len: usize, + _name: &'static str, + _len: usize, visitor: V, - ) -> Result + ) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_seq(visitor) } - fn deserialize_map(self, visitor: V) -> Result + fn deserialize_map(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if self.is_top_level { + self.is_top_level = false; + + visitor.visit_map(Separated::new(self)) + } else { + if self.next_token()? != Token::ObjectStart { + return Err(self.error(ErrorCode::ExpectedMap)); + } + + let value = visitor.visit_map(Separated::new(self))?; + if self.next_token()? == Token::ObjectEnd { + Ok(value) + } else { + Err(self.error(ErrorCode::ExpectedMapEnd)) + } + } } fn deserialize_struct( self, - name: &'static str, - fields: &'static [&'static str], + _name: &'static str, + _fields: &'static [&'static str], visitor: V, - ) -> Result + ) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_map(visitor) } fn deserialize_enum( self, - name: &'static str, - variants: &'static [&'static str], + _name: &'static str, + _variants: &'static [&'static str], visitor: V, - ) -> Result + ) -> Result where V: serde::de::Visitor<'de>, { - todo!() + match self.next_token()? { + Token::String(val) => visitor.visit_enum(val.into_deserializer()), + Token::ObjectStart => { + let value = visitor.visit_enum(Enum::new(self))?; + + if self.next_token()? == Token::ObjectEnd { + Ok(value) + } else { + Err(self.error(ErrorCode::ExpectedMapEnd)) + } + } + _ => Err(self.error(ErrorCode::ExpectedEnum)), + } } - fn deserialize_identifier(self, visitor: V) -> Result + fn deserialize_identifier(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + if let Ok(Token::String(val)) = self.parse(&parse_identifier) { + visitor.visit_str(&val) + } else { + Err(self.error(ErrorCode::ExpectedString)) + } } - fn deserialize_ignored_any(self, visitor: V) -> Result + fn deserialize_ignored_any(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { - todo!() + self.deserialize_any(visitor) } } -pub fn from_str() -> crate::Result {} +struct Separated<'a, 'de: 'a> { + de: &'a mut Deserializer<'de>, + first: bool, +} + +impl<'a, 'de: 'a> Separated<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>) -> Self { + Self { de, first: true } + } +} + +impl<'de, 'a> serde::de::SeqAccess<'de> for Separated<'a, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: serde::de::DeserializeSeed<'de>, + { + if self.de.peek_token()? == Token::ArrayEnd { + return Ok(None); + } + + if !self.first && self.de.parse(&parse_separator)? != Token::Separator { + return Err(self.de.error(ErrorCode::ExpectedArraySeparator)); + } + + self.first = false; + + // TODO: Shouldn't I check that this is a valid value? + seed.deserialize(&mut *self.de).map(Some) + } +} + +impl<'de, 'a> serde::de::MapAccess<'de> for Separated<'a, 'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if matches!(self.de.peek_token()?, Token::ObjectEnd | Token::Eof) { + return Ok(None); + } + + if !self.first && self.de.parse(&parse_separator)? != Token::Separator { + return Err(self.de.error(ErrorCode::ExpectedMapSeparator)); + } + + self.first = false; + + // TODO: Shouldn't I check that this is a valid identifier? + seed.deserialize(&mut *self.de).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + if self.de.next_token()? != Token::Equals { + return Err(self.de.error(ErrorCode::ExpectedMapEquals)); + } + + // TODO: Shouldn't I check that this is a valid value? + seed.deserialize(&mut *self.de) + } +} + +struct Enum<'a, 'de: 'a> { + de: &'a mut Deserializer<'de>, +} + +impl<'a, 'de> Enum<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>) -> Self { + Self { de } + } +} + +impl<'de, 'a> EnumAccess<'de> for Enum<'a, 'de> { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: serde::de::DeserializeSeed<'de>, + { + let val = seed.deserialize(&mut *self.de)?; + + if self.de.next_token()? == Token::Equals { + Ok((val, self)) + } else { + Err(self.de.error(ErrorCode::ExpectedMapEquals)) + } + } +} + +impl<'de, 'a> VariantAccess<'de> for Enum<'a, 'de> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Err(self.de.error(ErrorCode::ExpectedString)) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self.de) + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + serde::Deserializer::deserialize_seq(self.de, visitor) + } + + fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + serde::Deserializer::deserialize_map(self.de, visitor) + } +} + +#[cfg(test)] +mod test { + use std::path::PathBuf; + + use crate::error::{Error, ErrorCode}; + use crate::from_str; + + macro_rules! assert_value_ok { + ($type:ty, $json:expr) => { + assert_value_ok!($type, Default::default(), $json) + }; + ($type:ty, $expected:expr, $json:expr) => {{ + #[derive(Debug, serde::Deserialize, PartialEq)] + struct Value { + value: $type, + } + + let expected = Value { value: $expected }; + + let json = format!("value = {}", $json); + let actual = from_str::(&json).unwrap(); + assert_eq!(actual, expected); + }}; + } + + macro_rules! assert_ok { + ($type:ty, $expected:expr, $json:expr) => {{ + let actual = from_str::<$type>($json).unwrap(); + assert_eq!(actual, $expected); + }}; + } + + macro_rules! assert_value_err { + ($type:ty, $expected:expr, $json:expr) => {{ + #[derive(Debug, serde::Deserialize, PartialEq)] + struct Value { + value: $type, + } + + let json = format!("value = {}", $json); + let actual = from_str::(&json); + assert_eq!(actual, Err($expected)); + }}; + } + + #[test] + fn deserialize_null() { + assert_value_ok!((), "null"); + + let err = Error::new(ErrorCode::ExpectedNull, 1, 8, Some(" foo".to_string())); + assert_value_err!((), err, "foo"); + } + + #[test] + fn deserialize_bool() { + assert_value_ok!(bool, true, "true"); + assert_value_ok!(bool, false, "false"); + + let err = Error::new(ErrorCode::ExpectedBoolean, 1, 8, Some(" foo".to_string())); + assert_value_err!(bool, err, "foo"); + } + + #[test] + fn deserialize_integer() { + assert_value_ok!(i64, 0, "0"); + assert_value_ok!(i64, -1, "-1"); + assert_value_ok!(i64, i64::MAX, i64::MAX.to_string()); + assert_value_ok!(i64, i64::MIN, i64::MIN.to_string()); + + assert_value_ok!(i8, 0, "0"); + assert_value_ok!(i8, 102, "102"); + assert_value_ok!(i8, -102, "-102"); + assert_value_ok!(u8, 102, "102"); + assert_value_ok!(i16, 256, "256"); + + let err = Error::new(ErrorCode::ExpectedInteger, 1, 8, Some(" foo".to_string())); + assert_value_err!(i64, err, "foo"); + } + + #[test] + fn deserialize_float() { + assert_value_ok!(f64, 0.0, "0"); + assert_value_ok!(f64, 0.0, "0.0"); + assert_value_ok!(f64, -1.0, "-1"); + assert_value_ok!(f64, -1.0, "-1.0"); + assert_value_ok!(f64, f64::MAX, f64::MAX.to_string()); + assert_value_ok!(f64, f64::MIN, f64::MIN.to_string()); + } + + #[test] + fn deserialize_vec() { + assert_value_ok!(Vec, vec![1, 2, 3], "[1, 2, 3]"); + assert_value_ok!( + Vec, + vec![1, 2, 3], + "\ +[ + 1 + 2 + 3 +]" + ); + } + + #[test] + fn deserialize_enum() { + #[derive(Debug, serde::Deserialize, PartialEq)] + enum Animal { + Mouse, + Dog { name: String }, + Cat(u64), + } + + assert_value_ok!(Animal, Animal::Mouse, "Mouse"); + assert_value_ok!(Animal, Animal::Cat(9), "{ Cat = 9 }"); + assert_value_ok!( + Animal, + Animal::Dog { + name: String::from("Buddy") + }, + "{ Dog = { name = Buddy }}" + ); + } + + // Checks the example from + // https://help.autodesk.com/view/Stingray/ENU/?guid=__stingray_help_managing_content_sjson_html + #[test] + fn deserialize_stingray_example() { + #[derive(Debug, serde::Deserialize, PartialEq)] + struct Win32Settings { + query_performance_counter_affinity_mask: u64, + } + + #[derive(Debug, serde::Deserialize, PartialEq)] + struct Settings { + boot_script: String, + console_port: u16, + win32: Win32Settings, + render_config: PathBuf, + } + + let expected = Settings { + boot_script: String::from("boot"), + console_port: 14030, + win32: Win32Settings { + query_performance_counter_affinity_mask: 0, + }, + render_config: PathBuf::from("core/rendering/renderer"), + }; + + let json = r#" +// The script that should be started when the application runs. +boot_script = "boot" + +// The port on which the console server runs. +console_port = 14030 + +// Settings for the win32 platform +win32 = { + /* Sets the affinity mask for + QueryPerformanceCounter() */ + query_performance_counter_affinity_mask = 0 +} + +render_config = "core/rendering/renderer" +"#; + + assert_ok!(Settings, expected, json); + } + + #[test] + fn deserialize_missing_top_level_struct() { + let json = "0"; + let err = Error::new( + ErrorCode::ExpectedTopLevelObject, + 1, + 1, + Some(json.to_string()), + ); + let actual = from_str::(json); + assert_eq!(actual, Err(err)); + + let json = "1.23"; + let err = Error::new( + ErrorCode::ExpectedTopLevelObject, + 1, + 1, + Some(json.to_string()), + ); + let actual = from_str::(json); + assert_eq!(actual, Err(err)); + + let json = "true"; + let err = Error::new( + ErrorCode::ExpectedTopLevelObject, + 1, + 1, + Some(json.to_string()), + ); + let actual = from_str::(json); + assert_eq!(actual, Err(err)); + + let json = "null"; + let err = Error::new( + ErrorCode::ExpectedTopLevelObject, + 1, + 1, + Some(json.to_string()), + ); + let actual = from_str::<()>(json); + assert_eq!(actual, Err(err)); + } +} diff --git a/src/error.rs b/src/error.rs index 4cddad0..6e604b4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,36 +1,39 @@ -use std::{fmt, io}; +use std::fmt; pub type Result = std::result::Result; +#[derive(PartialEq)] pub struct Error { inner: Box, } +#[derive(PartialEq)] struct ErrorImpl { code: ErrorCode, - line: usize, + line: u32, column: usize, + fragment: Option, } -// TODO: Remove once they are constructed -#[allow(dead_code)] +#[derive(PartialEq)] pub(crate) enum ErrorCode { // Generic error built from a message or different error Message(String), - // Wrap inner I/O errors - Io(io::Error), - Eof, - Syntax, - ExpectedTopLevelObject, - ExpectedBoolean, - ExpectedInteger, - ExpectedString, - ExpectedNull, ExpectedArray, ExpectedArrayEnd, + ExpectedArraySeparator, + ExpectedBoolean, + ExpectedEnum, + ExpectedFloat, + ExpectedInteger, ExpectedMap, - ExpectedMapEquals, ExpectedMapEnd, + ExpectedMapEquals, + ExpectedMapSeparator, + ExpectedNull, + ExpectedString, + ExpectedTopLevelObject, + ExpectedValue, TrailingCharacters, } @@ -38,19 +41,25 @@ impl fmt::Display for ErrorCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ErrorCode::Message(msg) => f.write_str(msg), - ErrorCode::Io(err) => fmt::Display::fmt(err, f), - ErrorCode::Eof => f.write_str("unexpected end of input"), - ErrorCode::Syntax => f.write_str("syntax error"), - ErrorCode::ExpectedTopLevelObject => f.write_str("expected object at the top level"), - ErrorCode::ExpectedBoolean => f.write_str("expected a boolean value"), - ErrorCode::ExpectedInteger => f.write_str("expected an integer value"), - ErrorCode::ExpectedString => f.write_str("expected a string value"), - ErrorCode::ExpectedNull => f.write_str("expected null"), ErrorCode::ExpectedArray => f.write_str("expected an array value"), ErrorCode::ExpectedArrayEnd => f.write_str("expected an array end delimiter"), - ErrorCode::ExpectedMap => f.write_str("expected an object value"), - ErrorCode::ExpectedMapEquals => f.write_str("expected a '=' between key and value"), + ErrorCode::ExpectedArraySeparator => { + f.write_str("expected comma or newline between array entries") + } + ErrorCode::ExpectedBoolean => f.write_str("expected a boolean value"), + ErrorCode::ExpectedEnum => f.write_str("expected string or object"), + ErrorCode::ExpectedFloat => f.write_str("expected floating point number"), + ErrorCode::ExpectedInteger => f.write_str("expected an integer value"), + ErrorCode::ExpectedMap => f.write_str("expected an object"), ErrorCode::ExpectedMapEnd => f.write_str("expected an object end delimiter"), + ErrorCode::ExpectedMapEquals => f.write_str("expected a '=' between key and value"), + ErrorCode::ExpectedMapSeparator => { + f.write_str("expected comma or newline between object entries") + } + ErrorCode::ExpectedNull => f.write_str("expected null"), + ErrorCode::ExpectedString => f.write_str("expected a string value"), + ErrorCode::ExpectedTopLevelObject => f.write_str("expected object at the top level"), + ErrorCode::ExpectedValue => f.write_str("expected a value"), ErrorCode::TrailingCharacters => f.write_str("unexpected trailing characters"), } } @@ -80,10 +89,11 @@ impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "Error({:?}, line: {}, column: {})", + "Error({:?}, line: {}, column: {}, fragment: {:?})", self.inner.code.to_string(), self.inner.line, - self.inner.column + self.inner.column, + self.inner.fragment, ) } } @@ -97,6 +107,7 @@ impl serde::de::Error for Error { code: ErrorCode::Message(msg.to_string()), line: 0, column: 0, + fragment: None, }); Self { inner } } @@ -111,6 +122,7 @@ impl serde::ser::Error for Error { code: ErrorCode::Message(msg.to_string()), line: 0, column: 0, + fragment: None, }); Self { inner } } @@ -119,9 +131,14 @@ impl serde::ser::Error for Error { impl std::error::Error for Error {} impl Error { - pub(crate) fn new(code: ErrorCode, line: usize, column: usize) -> Self { + pub(crate) fn new(code: ErrorCode, line: u32, column: usize, fragment: Option) -> Self { Self { - inner: Box::new(ErrorImpl { code, line, column }), + inner: Box::new(ErrorImpl { + code, + line, + column, + fragment, + }), } } } diff --git a/src/lib.rs b/src/lib.rs index 80b500b..440118a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ -// mod de; +mod de; mod error; +mod parser; mod ser; -// pub use de::{from_str, Deserializer}; +pub use de::{from_str, Deserializer}; pub use error::{Error, Result}; pub use ser::{to_string, Serializer}; diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..0371881 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,305 @@ +use nom::branch::alt; +use nom::bytes::complete::{escaped, tag, take_until}; +use nom::character::complete::{ + alpha1, alphanumeric1, char, digit1, none_of, not_line_ending, one_of, +}; +use nom::combinator::{cut, eof, map, map_res, opt, recognize, value}; +use nom::multi::{many0_count, many1_count}; +use nom::number::complete::double; +use nom::sequence::{delimited, pair, preceded, terminated, tuple}; +use nom::IResult; +use nom_locate::LocatedSpan; + +pub(crate) type Span<'a> = LocatedSpan<&'a str>; + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum Token { + ArrayEnd, + ArrayStart, + Boolean(bool), + Eof, + Equals, + Float(f64), + Integer(i64), + Null, + ObjectEnd, + ObjectStart, + Separator, + String(String), +} + +fn horizontal_whitespace(input: Span) -> IResult { + one_of(" \t")(input) +} + +fn whitespace(input: Span) -> IResult { + one_of(" \n\r\t")(input) +} + +fn null(input: Span) -> IResult { + value((), tag("null"))(input) +} + +fn separator(input: Span) -> IResult { + map(alt((tag(","), tag("\n"))), |val: Span| *val.fragment())(input) +} + +fn bool(input: Span) -> IResult { + alt((value(true, tag("true")), value(false, tag("false"))))(input) +} + +fn integer(input: Span) -> IResult { + map_res(recognize(tuple((opt(char('-')), digit1))), |val: Span| { + val.fragment().parse::() + })(input) +} + +fn float(input: Span) -> IResult { + double(input) +} + +fn identifier(input: Span) -> IResult { + let leading = alt((alpha1, tag("_"))); + let trailing = many0_count(alt((alphanumeric1, tag("_")))); + let ident = pair(leading, trailing); + + map(recognize(ident), |val: Span| *val.fragment())(input) +} + +fn string_content(input: Span) -> IResult { + // TODO: Handle Unicode escapes + map( + alt(( + escaped(none_of("\n\\\""), '\\', one_of(r#""rtn\"#)), + tag(""), + )), + |val: Span| *val.fragment(), + )(input) +} + +fn delimited_string(input: Span) -> IResult { + preceded(char('"'), cut(terminated(string_content, char('"'))))(input) +} + +fn string(input: Span) -> IResult { + alt((identifier, delimited_string))(input) +} + +fn line_comment(input: Span) -> IResult { + map( + preceded(tag("//"), alt((not_line_ending, eof))), + |val: Span| *val.fragment(), + )(input) +} + +fn block_comment(input: Span) -> IResult { + map( + delimited(tag("/*"), take_until("*/"), tag("*/")), + |val: Span| *val.fragment(), + )(input) +} + +fn comment(input: Span) -> IResult { + alt((line_comment, block_comment))(input) +} + +fn optional(input: Span) -> IResult { + let whitespace = value((), whitespace); + let comment = value((), comment); + let empty = value((), tag("")); + let content = value((), many1_count(alt((whitespace, comment)))); + + alt((content, empty))(input) +} + +pub(crate) fn parse_next_token(input: Span) -> IResult { + preceded( + opt(optional), + alt(( + // Order is important here. + // Certain valid strings like "null", "true" or "false" need to be + // matched to their special value. + // Integer-like numbers need to be matched to that, but are valid floats, too. + value(Token::Eof, eof), + value(Token::Separator, separator), + value(Token::ObjectStart, tag("{")), + value(Token::ObjectEnd, tag("}")), + value(Token::ArrayStart, tag("[")), + value(Token::ArrayEnd, tag("]")), + value(Token::Equals, tag("=")), + value(Token::Null, null), + map(bool, Token::Boolean), + map(integer, Token::Integer), + map(float, Token::Float), + map(string, |val| Token::String(val.to_string())), + )), + )(input) +} + +pub(crate) fn parse_trailing_characters(input: Span) -> IResult { + value((), optional)(input) +} + +pub(crate) fn parse_null(input: Span) -> IResult { + preceded(optional, value(Token::Null, null))(input) +} + +pub(crate) fn parse_separator(input: Span) -> IResult { + preceded( + opt(horizontal_whitespace), + value(Token::Separator, separator), + )(input) +} + +pub(crate) fn parse_bool(input: Span) -> IResult { + preceded(optional, map(bool, Token::Boolean))(input) +} + +pub(crate) fn parse_integer(input: Span) -> IResult { + preceded(optional, map(integer, Token::Integer))(input) +} + +pub(crate) fn parse_float(input: Span) -> IResult { + preceded(optional, map(float, Token::Float))(input) +} + +pub(crate) fn parse_identifier(input: Span) -> IResult { + preceded( + optional, + map(identifier, |val| Token::String(val.to_string())), + )(input) +} + +pub(crate) fn parse_string(input: Span) -> IResult { + preceded(optional, map(string, |val| Token::String(val.to_string())))(input) +} + +#[cfg(test)] +mod test { + use nom::error::{Error, ErrorKind}; + use nom::Err; + + use super::*; + + macro_rules! assert_ok { + ($input:expr, $parser:ident, $remain:expr, $output:expr) => {{ + let res = super::$parser(Span::from($input)); + assert_eq!( + res.map(|(span, res)| { (*span, res) }), + Ok(($remain, $output)) + ); + }}; + } + + macro_rules! assert_err { + ($input:expr, $parser:ident, $kind:expr) => {{ + { + let input = Span::from($input); + assert_eq!( + super::$parser(input), + Err(Err::Error(Error::new(input, $kind))) + ); + } + }}; + } + + #[test] + fn parse_optional() { + assert_ok!("\n", whitespace, "", '\n'); + assert_ok!("\t", whitespace, "", '\t'); + assert_ok!(" ", whitespace, " ", ' '); + assert_ok!("/* foo bar */", comment, "", " foo bar "); + assert_ok!("// foo", comment, "", " foo"); + assert_ok!("// foo\n", comment, "\n", " foo"); + + assert_ok!("", optional, "", ()); + assert_ok!("\t\n", optional, "", ()); + assert_ok!("\n\t", optional, "", ()); + assert_ok!("// foo", optional, "", ()); + assert_ok!("\n\t// foo\n\t/* foo\n\tbar */\n", optional, "", ()); + } + + #[test] + fn parse_integer() { + assert_ok!("3", integer, "", 3); + assert_ok!("12345", integer, "", 12345); + assert_ok!("-12345", integer, "", -12345); + assert_ok!("12345 ", integer, " ", 12345); + + assert_err!(" 12345", integer, ErrorKind::Digit); + + assert_ok!(" 12345", parse_integer, "", Token::Integer(12345)); + assert_ok!("\n12345", parse_integer, "", Token::Integer(12345)); + assert_ok!("\t12345", parse_integer, "", Token::Integer(12345)); + } + + #[test] + fn parse_float() { + assert_ok!("3", float, "", 3.0); + assert_ok!("3.0", float, "", 3.0); + assert_ok!("3.1415", float, "", 3.1415); + assert_ok!("-123.456789", float, "", -123.456789); + assert_err!(" 1.23", float, ErrorKind::Float); + assert_ok!("1.23 ", float, " ", 1.23); + } + + #[test] + fn parse_raw_string() { + assert_ok!("foo", identifier, "", "foo"); + assert_ok!("foo123", identifier, "", "foo123"); + assert_ok!("foo_bar", identifier, "", "foo_bar"); + assert_ok!("_foo", identifier, "", "_foo"); + assert_ok!("foo bar", identifier, " bar", "foo"); + + assert_err!("123", identifier, ErrorKind::Tag); + assert_err!("1foo", identifier, ErrorKind::Tag); + assert_err!("\"foo\"", identifier, ErrorKind::Tag); + } + + #[test] + fn parse_delimited_string() { + assert_ok!(r#""""#, delimited_string, "", ""); + assert_ok!(r#""foo""#, delimited_string, "", "foo"); + assert_ok!(r#""\"foo""#, delimited_string, "", r#"\"foo"#); + assert_ok!(r#""foo bar""#, delimited_string, "", "foo bar"); + assert_ok!(r#""foo123""#, delimited_string, "", "foo123"); + assert_ok!(r#""123foo""#, delimited_string, "", "123foo"); + assert_ok!(r#""foo\"bar""#, delimited_string, "", "foo\\\"bar"); + + assert_err!("foo\"", delimited_string, ErrorKind::Char); + + { + let input = Span::from("\"foo"); + assert_eq!( + delimited_string(input), + Err(Err::Failure(Error::new( + unsafe { Span::new_from_raw_offset(4, 1, "", ()) }, + ErrorKind::Char + ))) + ); + } + + { + let input = Span::from("\"foo\nbar\""); + assert_eq!( + delimited_string(input), + Err(Err::Failure(Error::new( + unsafe { Span::new_from_raw_offset(4, 1, "\nbar\"", ()) }, + ErrorKind::Char + ))) + ); + } + } + + #[test] + fn parse_line_comment() { + assert_ok!("// foo", line_comment, "", " foo"); + assert_ok!("// foo\n", line_comment, "\n", " foo"); + } + + #[test] + fn parse_block_comment() { + assert_ok!("/* foo */", block_comment, "", " foo "); + assert_ok!("/*\n\tfoo\nbar\n*/", block_comment, "", "\n\tfoo\nbar\n"); + } +} diff --git a/src/ser.rs b/src/ser.rs index 605a57b..51c9fd5 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -33,7 +33,7 @@ impl Serializer { fn ensure_top_level_struct(&self) -> Result<()> { if self.level == 0 { - return Err(Error::new(ErrorCode::ExpectedTopLevelObject, 0, 0)); + return Err(Error::new(ErrorCode::ExpectedTopLevelObject, 0, 0, None)); } Ok(())