1 use std::fmt;
2 
3 use serde::de::{Error, Unexpected, Visitor};
4 use serde::{Deserialize, Deserializer, Serialize, Serializer};
5 
6 use ascii_str::AsciiStr;
7 use ascii_string::AsciiString;
8 
9 impl Serialize for AsciiString {
10     #[inline]
serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>11     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
12         serializer.serialize_str(self.as_str())
13     }
14 }
15 
16 struct AsciiStringVisitor;
17 
18 impl<'de> Visitor<'de> for AsciiStringVisitor {
19     type Value = AsciiString;
20 
expecting(&self, f: &mut fmt::Formatter) -> fmt::Result21     fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
22         f.write_str("an ascii string")
23     }
24 
visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E>25     fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
26         AsciiString::from_ascii(v).map_err(|_| Error::invalid_value(Unexpected::Str(v), &self))
27     }
28 
visit_string<E: Error>(self, v: String) -> Result<Self::Value, E>29     fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> {
30         AsciiString::from_ascii(v.as_bytes())
31             .map_err(|_| Error::invalid_value(Unexpected::Str(&v), &self))
32     }
33 
visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E>34     fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
35         AsciiString::from_ascii(v).map_err(|_| Error::invalid_value(Unexpected::Bytes(&v), &self))
36     }
37 
visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E>38     fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
39         AsciiString::from_ascii(v.as_slice())
40             .map_err(|_| Error::invalid_value(Unexpected::Bytes(&v), &self))
41     }
42 }
43 
44 struct AsciiStringInPlaceVisitor<'a>(&'a mut AsciiString);
45 
46 impl<'a, 'de> Visitor<'de> for AsciiStringInPlaceVisitor<'a> {
47     type Value = ();
48 
expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result49     fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
50         formatter.write_str("an ascii string")
51     }
52 
visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E>53     fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
54         let ascii_str = match AsciiStr::from_ascii(v.as_bytes()) {
55             Ok(ascii_str) => ascii_str,
56             Err(_) => return Err(Error::invalid_value(Unexpected::Str(v), &self)),
57         };
58         self.0.clear();
59         self.0.push_str(ascii_str);
60         Ok(())
61     }
62 
visit_string<E: Error>(self, v: String) -> Result<Self::Value, E>63     fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> {
64         let ascii_string = match AsciiString::from_ascii(v.as_bytes()) {
65             Ok(ascii_string) => ascii_string,
66             Err(_) => return Err(Error::invalid_value(Unexpected::Str(&v), &self)),
67         };
68         *self.0 = ascii_string;
69         Ok(())
70     }
71 
visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E>72     fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
73         let ascii_str = match AsciiStr::from_ascii(v) {
74             Ok(ascii_str) => ascii_str,
75             Err(_) => return Err(Error::invalid_value(Unexpected::Bytes(v), &self)),
76         };
77         self.0.clear();
78         self.0.push_str(ascii_str);
79         Ok(())
80     }
81 
visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E>82     fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
83         let ascii_string = match AsciiString::from_ascii(v.as_slice()) {
84             Ok(ascii_string) => ascii_string,
85             Err(_) => return Err(Error::invalid_value(Unexpected::Bytes(&v), &self)),
86         };
87         *self.0 = ascii_string;
88         Ok(())
89     }
90 }
91 
92 impl<'de> Deserialize<'de> for AsciiString {
deserialize<D>(deserializer: D) -> Result<AsciiString, D::Error> where D: Deserializer<'de>,93     fn deserialize<D>(deserializer: D) -> Result<AsciiString, D::Error>
94     where
95         D: Deserializer<'de>,
96     {
97         deserializer.deserialize_string(AsciiStringVisitor)
98     }
99 
deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error> where D: Deserializer<'de>,100     fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error>
101     where
102         D: Deserializer<'de>,
103     {
104         deserializer.deserialize_string(AsciiStringInPlaceVisitor(place))
105     }
106 }
107 
#[cfg(test)]
mod tests {
    use super::*;

    // Pure-ASCII sample that must round-trip.
    #[cfg(feature = "serde_test")]
    const ASCII: &str = "Francais";
    // Sample containing a non-ASCII character that must be rejected.
    #[cfg(feature = "serde_test")]
    const UNICODE: &str = "Français";

    /// Compile-time check that `AsciiString` implements both serde traits.
    #[test]
    fn basic() {
        fn requires_serialize<T: Serialize>() {}
        fn requires_deserialize<'de, T: Deserialize<'de>>() {}

        requires_serialize::<AsciiString>();
        requires_deserialize::<AsciiString>();
    }

    /// Checks the value against each of the three string token kinds.
    #[test]
    #[cfg(feature = "serde_test")]
    fn serialize() {
        use serde_test::{assert_tokens, Token};

        let expected = AsciiString::from_ascii(ASCII).unwrap();

        let owned = [Token::String(ASCII)];
        let transient = [Token::Str(ASCII)];
        let borrowed = [Token::BorrowedStr(ASCII)];

        assert_tokens(&expected, &owned);
        assert_tokens(&expected, &transient);
        assert_tokens(&expected, &borrowed);
    }

    /// Byte tokens must deserialize to the same value; a non-ASCII string
    /// must be rejected with the visitor's "expected" message.
    #[test]
    #[cfg(feature = "serde_test")]
    fn deserialize() {
        use serde_test::{assert_de_tokens, assert_de_tokens_error, Token};

        let expected = AsciiString::from_ascii(ASCII).unwrap();
        let raw = ASCII.as_bytes();

        let transient = [Token::Bytes(raw)];
        let borrowed = [Token::BorrowedBytes(raw)];
        let owned = [Token::ByteBuf(raw)];

        assert_de_tokens(&expected, &transient);
        assert_de_tokens(&expected, &borrowed);
        assert_de_tokens(&expected, &owned);

        let non_ascii = [Token::String(UNICODE)];
        assert_de_tokens_error::<AsciiString>(
            &non_ascii,
            "invalid value: string \"Français\", expected an ascii string",
        );
    }
}
150