1 use std::fmt; 2 3 use serde::de::{Error, Unexpected, Visitor}; 4 use serde::{Deserialize, Deserializer, Serialize, Serializer}; 5 6 use ascii_str::AsciiStr; 7 use ascii_string::AsciiString; 8 9 impl Serialize for AsciiString { 10 #[inline] serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>11 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { 12 serializer.serialize_str(self.as_str()) 13 } 14 } 15 16 struct AsciiStringVisitor; 17 18 impl<'de> Visitor<'de> for AsciiStringVisitor { 19 type Value = AsciiString; 20 expecting(&self, f: &mut fmt::Formatter) -> fmt::Result21 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { 22 f.write_str("an ascii string") 23 } 24 visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E>25 fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> { 26 AsciiString::from_ascii(v).map_err(|_| Error::invalid_value(Unexpected::Str(v), &self)) 27 } 28 visit_string<E: Error>(self, v: String) -> Result<Self::Value, E>29 fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> { 30 AsciiString::from_ascii(v.as_bytes()) 31 .map_err(|_| Error::invalid_value(Unexpected::Str(&v), &self)) 32 } 33 visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E>34 fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> { 35 AsciiString::from_ascii(v).map_err(|_| Error::invalid_value(Unexpected::Bytes(&v), &self)) 36 } 37 visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E>38 fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> { 39 AsciiString::from_ascii(v.as_slice()) 40 .map_err(|_| Error::invalid_value(Unexpected::Bytes(&v), &self)) 41 } 42 } 43 44 struct AsciiStringInPlaceVisitor<'a>(&'a mut AsciiString); 45 46 impl<'a, 'de> Visitor<'de> for AsciiStringInPlaceVisitor<'a> { 47 type Value = (); 48 expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result49 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 50 formatter.write_str("an ascii string") 51 } 52 visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E>53 fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> { 54 let ascii_str = match AsciiStr::from_ascii(v.as_bytes()) { 55 Ok(ascii_str) => ascii_str, 56 Err(_) => return Err(Error::invalid_value(Unexpected::Str(v), &self)), 57 }; 58 self.0.clear(); 59 self.0.push_str(ascii_str); 60 Ok(()) 61 } 62 visit_string<E: Error>(self, v: String) -> Result<Self::Value, E>63 fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> { 64 let ascii_string = match AsciiString::from_ascii(v.as_bytes()) { 65 Ok(ascii_string) => ascii_string, 66 Err(_) => return Err(Error::invalid_value(Unexpected::Str(&v), &self)), 67 }; 68 *self.0 = ascii_string; 69 Ok(()) 70 } 71 visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E>72 fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> { 73 let ascii_str = match AsciiStr::from_ascii(v) { 74 Ok(ascii_str) => ascii_str, 75 Err(_) => return Err(Error::invalid_value(Unexpected::Bytes(v), &self)), 76 }; 77 self.0.clear(); 78 self.0.push_str(ascii_str); 79 Ok(()) 80 } 81 visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E>82 fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> { 83 let ascii_string = match AsciiString::from_ascii(v.as_slice()) { 84 Ok(ascii_string) => ascii_string, 85 Err(_) => return Err(Error::invalid_value(Unexpected::Bytes(&v), &self)), 86 }; 87 *self.0 = ascii_string; 88 Ok(()) 89 } 90 } 91 92 impl<'de> Deserialize<'de> for AsciiString { deserialize<D>(deserializer: D) -> Result<AsciiString, D::Error> where D: Deserializer<'de>,93 fn deserialize<D>(deserializer: D) -> Result<AsciiString, D::Error> 94 where 95 D: Deserializer<'de>, 96 { 97 deserializer.deserialize_string(AsciiStringVisitor) 98 } 99 deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error> where D: Deserializer<'de>,100 fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error> 101 where 102 D: Deserializer<'de>, 103 { 104 deserializer.deserialize_string(AsciiStringInPlaceVisitor(place)) 105 } 106 } 107 108 #[cfg(test)] 109 mod tests { 110 use super::*; 111 112 #[cfg(feature = "serde_test")] 113 const ASCII: &str = "Francais"; 114 #[cfg(feature = "serde_test")] 115 const UNICODE: &str = "Français"; 116 117 #[test] basic()118 fn basic() { 119 fn assert_serialize<T: Serialize>() {} 120 assert_serialize::<AsciiString>(); 121 fn assert_deserialize<'de, T: Deserialize<'de>>() {} 122 assert_deserialize::<AsciiString>(); 123 } 124 125 #[test] 126 #[cfg(feature = "serde_test")] serialize()127 fn serialize() { 128 use serde_test::{assert_tokens, Token}; 129 130 let ascii_string = AsciiString::from_ascii(ASCII).unwrap(); 131 assert_tokens(&ascii_string, &[Token::String(ASCII)]); 132 assert_tokens(&ascii_string, &[Token::Str(ASCII)]); 133 assert_tokens(&ascii_string, &[Token::BorrowedStr(ASCII)]); 134 } 135 136 #[test] 137 #[cfg(feature = "serde_test")] deserialize()138 fn deserialize() { 139 use serde_test::{assert_de_tokens, assert_de_tokens_error, Token}; 140 let ascii_string = AsciiString::from_ascii(ASCII).unwrap(); 141 assert_de_tokens(&ascii_string, &[Token::Bytes(ASCII.as_bytes())]); 142 assert_de_tokens(&ascii_string, &[Token::BorrowedBytes(ASCII.as_bytes())]); 143 assert_de_tokens(&ascii_string, &[Token::ByteBuf(ASCII.as_bytes())]); 144 assert_de_tokens_error::<AsciiString>( 145 &[Token::String(UNICODE)], 146 "invalid value: string \"Français\", expected an ascii string", 147 ); 148 } 149 } 150