use nom::IResult; use serde::de::{EnumAccess, IntoDeserializer, VariantAccess}; use serde::Deserialize; use crate::error::{Error, ErrorCode, Result}; use crate::parser::*; pub struct Deserializer<'de> { input: Span<'de>, is_top_level: bool, } impl<'de> Deserializer<'de> { #![allow(clippy::should_implement_trait)] pub fn from_str(input: &'de str) -> Self { Self { input: Span::from(input), is_top_level: true, } } fn parse(&mut self, f: &dyn Fn(Span) -> IResult) -> Result { f(self.input) .map(|(span, token)| { self.input = span; token }) .map_err(|err| self.error(ErrorCode::Message(err.to_string()))) } fn next_token(&mut self) -> Result { match parse_next_token(self.input) { Ok((span, token)) => { self.input = span; Ok(token) } Err(err) => Err(self.error(ErrorCode::Message(err.to_string()))), } } fn peek_token(&mut self) -> Result { match parse_next_token(self.input) { Ok((_, token)) => Ok(token), Err(err) => Err(self.error(ErrorCode::Message(err.to_string()))), } } fn error(&self, code: ErrorCode) -> Error { Error::new( code, self.input.location_line(), self.input.get_utf8_column(), Some(self.input.fragment().to_string()), ) } fn error_with_token(&self, code: ErrorCode, token: Token) -> Error { Error::with_token( code, self.input.location_line(), self.input.get_utf8_column(), Some(self.input.fragment().to_string()), token, ) } } pub fn from_str<'a, T>(input: &'a str) -> Result where T: Deserialize<'a>, { let mut de = Deserializer::from_str(input); let t = T::deserialize(&mut de)?; if de.input.is_empty() || parse_trailing_characters(de.input).is_ok() { Ok(t) } else { Err(de.error(ErrorCode::TrailingCharacters)) } } impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; fn deserialize_any(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } match self.peek_token()? { Token::Boolean(_) => self.deserialize_bool(visitor), Token::Float(_) => self.deserialize_f64(visitor), Token::Integer(_) => self.deserialize_i64(visitor), Token::Null => self.deserialize_unit(visitor), Token::String(_) => self.deserialize_str(visitor), Token::ArrayStart => self.deserialize_seq(visitor), Token::ObjectStart => self.deserialize_map(visitor), token => Err(self.error_with_token(ErrorCode::ExpectedValue, token)), } } fn deserialize_bool(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if let Ok(Token::Boolean(val)) = self.parse(&parse_bool) { visitor.visit_bool(val) } else { Err(self.error(ErrorCode::ExpectedBoolean)) } } fn deserialize_i8(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_i16(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_i32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_i64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if let Ok(Token::Integer(val)) = self.parse(&parse_integer) { visitor.visit_i64(val) } else { Err(self.error(ErrorCode::ExpectedInteger)) } } fn deserialize_u8(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_u16(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_u32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_u64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_i64(visitor) } fn deserialize_f32(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_f64(visitor) } fn deserialize_f64(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if let Ok(Token::Float(val)) = self.parse(&parse_float) { visitor.visit_f64(val) } else { Err(self.error(ErrorCode::ExpectedFloat)) } } fn deserialize_char(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_str(visitor) } fn deserialize_str(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if let Ok(Token::String(val)) = self.parse(&parse_string) { visitor.visit_str(&val) } else { Err(self.error(ErrorCode::ExpectedString)) } } fn deserialize_string(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_str(visitor) } fn deserialize_bytes(self, _visitor: V) -> Result where V: serde::de::Visitor<'de>, { unimplemented!() } fn deserialize_byte_buf(self, _visitor: V) -> Result where V: serde::de::Visitor<'de>, { unimplemented!() } fn deserialize_option(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if self.peek_token()? == Token::Null { let _ = self.next_token()?; visitor.visit_none() } else { visitor.visit_some(self) } } fn deserialize_unit(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if let Ok(Token::Null) = self.parse(&parse_null) { visitor.visit_unit() } else { Err(self.error(ErrorCode::ExpectedNull)) } } fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_unit(visitor) } fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result where V: serde::de::Visitor<'de>, { visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { return Err(self.error(ErrorCode::ExpectedTopLevelObject)); } if self.next_token()? != Token::ArrayStart { return Err(self.error(ErrorCode::ExpectedArray)); } let value = visitor.visit_seq(Separated::new(self))?; if self.next_token()? == Token::ArrayEnd { Ok(value) } else { Err(self.error(ErrorCode::ExpectedArrayEnd)) } } fn deserialize_tuple(self, _len: usize, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_seq(visitor) } fn deserialize_tuple_struct( self, _name: &'static str, _len: usize, visitor: V, ) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_seq(visitor) } fn deserialize_map(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if self.is_top_level { self.is_top_level = false; visitor.visit_map(Separated::new(self)) } else { if self.next_token()? != Token::ObjectStart { return Err(self.error(ErrorCode::ExpectedMap)); } let value = visitor.visit_map(Separated::new(self))?; if self.next_token()? == Token::ObjectEnd { Ok(value) } else { Err(self.error(ErrorCode::ExpectedMapEnd)) } } } fn deserialize_struct( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_map(visitor) } fn deserialize_enum( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result where V: serde::de::Visitor<'de>, { match self.next_token()? { Token::String(val) => visitor.visit_enum(val.into_deserializer()), Token::ObjectStart => { let value = visitor.visit_enum(Enum::new(self))?; if self.next_token()? == Token::ObjectEnd { Ok(value) } else { Err(self.error(ErrorCode::ExpectedMapEnd)) } } _ => Err(self.error(ErrorCode::ExpectedEnum)), } } fn deserialize_identifier(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { if let Ok(Token::String(val)) = self.parse(&parse_identifier) { visitor.visit_str(&val) } else { Err(self.error(ErrorCode::ExpectedString)) } } fn deserialize_ignored_any(self, visitor: V) -> Result where V: serde::de::Visitor<'de>, { self.deserialize_any(visitor) } } struct Separated<'a, 'de: 'a> { de: &'a mut Deserializer<'de>, first: bool, } impl<'a, 'de: 'a> Separated<'a, 'de> { fn new(de: &'a mut Deserializer<'de>) -> Self { Self { de, first: true } } } impl<'de, 'a> serde::de::SeqAccess<'de> for Separated<'a, 'de> { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result> where T: serde::de::DeserializeSeed<'de>, { if self.de.peek_token()? == Token::ArrayEnd { return Ok(None); } if !self.first && self.de.parse(&parse_separator)? != Token::Separator { return Err(self.de.error(ErrorCode::ExpectedArraySeparator)); } self.first = false; // TODO: Shouldn't I check that this is a valid value? seed.deserialize(&mut *self.de).map(Some) } } impl<'de, 'a> serde::de::MapAccess<'de> for Separated<'a, 'de> { type Error = Error; fn next_key_seed(&mut self, seed: K) -> Result> where K: serde::de::DeserializeSeed<'de>, { if matches!(self.de.peek_token()?, Token::ObjectEnd | Token::Eof) { return Ok(None); } if !self.first && self.de.parse(&parse_separator)? != Token::Separator { return Err(self.de.error(ErrorCode::ExpectedMapSeparator)); } self.first = false; // TODO: Shouldn't I check that this is a valid identifier? seed.deserialize(&mut *self.de).map(Some) } fn next_value_seed(&mut self, seed: V) -> Result where V: serde::de::DeserializeSeed<'de>, { if self.de.next_token()? != Token::Equals { return Err(self.de.error(ErrorCode::ExpectedMapEquals)); } // TODO: Shouldn't I check that this is a valid value? seed.deserialize(&mut *self.de) } } struct Enum<'a, 'de: 'a> { de: &'a mut Deserializer<'de>, } impl<'a, 'de> Enum<'a, 'de> { fn new(de: &'a mut Deserializer<'de>) -> Self { Self { de } } } impl<'de, 'a> EnumAccess<'de> for Enum<'a, 'de> { type Error = Error; type Variant = Self; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> where V: serde::de::DeserializeSeed<'de>, { let val = seed.deserialize(&mut *self.de)?; if self.de.next_token()? == Token::Equals { Ok((val, self)) } else { Err(self.de.error(ErrorCode::ExpectedMapEquals)) } } } impl<'de, 'a> VariantAccess<'de> for Enum<'a, 'de> { type Error = Error; fn unit_variant(self) -> Result<()> { Err(self.de.error(ErrorCode::ExpectedString)) } fn newtype_variant_seed(self, seed: T) -> Result where T: serde::de::DeserializeSeed<'de>, { seed.deserialize(self.de) } fn tuple_variant(self, _len: usize, visitor: V) -> Result where V: serde::de::Visitor<'de>, { serde::Deserializer::deserialize_seq(self.de, visitor) } fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result where V: serde::de::Visitor<'de>, { serde::Deserializer::deserialize_map(self.de, visitor) } } #[cfg(test)] mod test { use std::path::PathBuf; use crate::error::{Error, ErrorCode}; use crate::from_str; macro_rules! assert_value_ok { ($type:ty, $json:expr) => { assert_value_ok!($type, Default::default(), $json) }; ($type:ty, $expected:expr, $json:expr) => {{ #[derive(Debug, serde::Deserialize, PartialEq)] struct Value { value: $type, } let expected = Value { value: $expected }; let json = format!("value = {}", $json); let actual = from_str::(&json).unwrap(); assert_eq!(actual, expected); }}; } macro_rules! assert_ok { ($type:ty, $expected:expr, $json:expr) => {{ let actual = from_str::<$type>($json).unwrap(); assert_eq!(actual, $expected); }}; } macro_rules! assert_value_err { ($type:ty, $expected:expr, $json:expr) => {{ #[derive(Debug, serde::Deserialize, PartialEq)] struct Value { value: $type, } let json = format!("value = {}", $json); let actual = from_str::(&json); assert_eq!(actual, Err($expected)); }}; } #[test] fn deserialize_null() { assert_value_ok!((), "null"); let err = Error::new(ErrorCode::ExpectedNull, 1, 8, Some(" foo".to_string())); assert_value_err!((), err, "foo"); } #[test] fn deserialize_bool() { assert_value_ok!(bool, true, "true"); assert_value_ok!(bool, false, "false"); let err = Error::new(ErrorCode::ExpectedBoolean, 1, 8, Some(" foo".to_string())); assert_value_err!(bool, err, "foo"); } #[test] fn deserialize_integer() { assert_value_ok!(i64, 0, "0"); assert_value_ok!(i64, -1, "-1"); assert_value_ok!(i64, i64::MAX, i64::MAX.to_string()); assert_value_ok!(i64, i64::MIN, i64::MIN.to_string()); assert_value_ok!(i8, 0, "0"); assert_value_ok!(i8, 102, "102"); assert_value_ok!(i8, -102, "-102"); assert_value_ok!(u8, 102, "102"); assert_value_ok!(i16, 256, "256"); let err = Error::new(ErrorCode::ExpectedInteger, 1, 8, Some(" foo".to_string())); assert_value_err!(i64, err, "foo"); } #[test] fn deserialize_float() { assert_value_ok!(f64, 0.0, "0"); assert_value_ok!(f64, 0.0, "0.0"); assert_value_ok!(f64, -1.0, "-1"); assert_value_ok!(f64, -1.0, "-1.0"); assert_value_ok!(f64, f64::MAX, f64::MAX.to_string()); assert_value_ok!(f64, f64::MIN, f64::MIN.to_string()); } #[test] fn deserialize_vec() { assert_value_ok!(Vec, vec![1, 2, 3], "[1, 2, 3]"); assert_value_ok!( Vec, vec![1, 2, 3], "\ [ 1 2 3 ]" ); } #[test] fn deserialize_enum() { #[derive(Debug, serde::Deserialize, PartialEq)] enum Animal { Mouse, Dog { name: String }, Cat(u64), } assert_value_ok!(Animal, Animal::Mouse, "Mouse"); assert_value_ok!(Animal, Animal::Cat(9), "{ Cat = 9 }"); assert_value_ok!( Animal, Animal::Dog { name: String::from("Buddy") }, "{ Dog = { name = Buddy }}" ); } // Checks the example from // https://help.autodesk.com/view/Stingray/ENU/?guid=__stingray_help_managing_content_sjson_html #[test] fn deserialize_stingray_example() { #[derive(Debug, serde::Deserialize, PartialEq)] struct Win32Settings { query_performance_counter_affinity_mask: u64, } #[derive(Debug, serde::Deserialize, PartialEq)] struct Settings { boot_script: String, console_port: u16, win32: Win32Settings, render_config: PathBuf, } let expected = Settings { boot_script: String::from("boot"), console_port: 14030, win32: Win32Settings { query_performance_counter_affinity_mask: 0, }, render_config: PathBuf::from("core/rendering/renderer"), }; let json = r#" // The script that should be started when the application runs. boot_script = "boot" // The port on which the console server runs. console_port = 14030 // Settings for the win32 platform win32 = { /* Sets the affinity mask for QueryPerformanceCounter() */ query_performance_counter_affinity_mask = 0 } render_config = "core/rendering/renderer" "#; assert_ok!(Settings, expected, json); } #[test] fn deserialize_missing_top_level_struct() { let json = "0"; let err = Error::new( ErrorCode::ExpectedTopLevelObject, 1, 1, Some(json.to_string()), ); let actual = from_str::(json); assert_eq!(actual, Err(err)); let json = "1.23"; let err = Error::new( ErrorCode::ExpectedTopLevelObject, 1, 1, Some(json.to_string()), ); let actual = from_str::(json); assert_eq!(actual, Err(err)); let json = "true"; let err = Error::new( ErrorCode::ExpectedTopLevelObject, 1, 1, Some(json.to_string()), ); let actual = from_str::(json); assert_eq!(actual, Err(err)); let json = "null"; let err = Error::new( ErrorCode::ExpectedTopLevelObject, 1, 1, Some(json.to_string()), ); let actual = from_str::<()>(json); assert_eq!(actual, Err(err)); } #[test] fn deserialize_array() { #[derive(Debug, Default, serde::Deserialize, PartialEq)] struct Data { array: Vec, } let expected = Data { array: vec![String::from("foo")], }; let sjson = r#" array = [ "foo" ] "#; assert_ok!(Data, expected, sjson); } // Regression test for #1 (https://git.sclu1034.dev/lucas/serde_sjson/issues/1) #[test] fn deserialize_dtmt_config() { #[derive(Debug, Default, serde::Deserialize, PartialEq)] struct DtmtConfig { name: String, #[serde(default)] description: String, version: Option, } let sjson = r#" name = "test-mod" description = "A dummy project to test things with" version = "0.1.0" packages = [ "packages/test-mod" ] "#; let expected = DtmtConfig { name: String::from("test-mod"), description: String::from("A dummy project to test things with"), version: Some(String::from("0.1.0")), }; assert_ok!(DtmtConfig, expected, sjson); } }