1 //! Support for dynamic class objects in Rust
2 
3 use std::any::TypeId;
4 use std::collections::HashMap;
5 use std::fmt;
6 use std::sync::Arc;
7 
8 use crate::errors::{InvalidCallError, OsoError};
9 
10 use super::class_method::{
11     AttributeGetter, ClassMethod, Constructor, InstanceMethod, RegisterHook,
12 };
13 use super::from_polar::FromPolarList;
14 use super::method::{Function, Method};
15 use super::to_polar::ToPolarResult;
16 use super::Host;
17 use super::PolarValue;
18 
19 type Attributes = HashMap<&'static str, AttributeGetter>;
20 type RegisterHooks = Vec<RegisterHook>;
21 type ClassMethods = HashMap<&'static str, ClassMethod>;
22 type InstanceMethods = HashMap<&'static str, InstanceMethod>;
23 
equality_not_supported( ) -> Box<dyn Fn(&Host, &Instance, &Instance) -> crate::Result<bool> + Send + Sync>24 fn equality_not_supported(
25 ) -> Box<dyn Fn(&Host, &Instance, &Instance) -> crate::Result<bool> + Send + Sync> {
26     let eq = move |host: &Host, lhs: &Instance, _: &Instance| -> crate::Result<bool> {
27         Err(OsoError::UnsupportedOperation {
28             operation: String::from("equals"),
29             type_name: lhs.name(host).to_owned(),
30         })
31     };
32 
33     Box::new(eq)
34 }
35 
iterator_not_supported( ) -> Box<dyn Fn(&Host, &Instance) -> crate::Result<crate::host::PolarIterator> + Send + Sync>36 fn iterator_not_supported(
37 ) -> Box<dyn Fn(&Host, &Instance) -> crate::Result<crate::host::PolarIterator> + Send + Sync> {
38     let into_iter = move |host: &Host, instance: &Instance| {
39         Err(OsoError::UnsupportedOperation {
40             operation: String::from("in"),
41             type_name: instance.name(host).to_owned(),
42         })
43     };
44 
45     Box::new(into_iter)
46 }
47 
48 #[derive(Clone)]
49 pub struct Class {
50     /// The class name. Defaults to the `std::any::type_name`
51     pub name: String,
52     pub type_id: TypeId,
53     /// A wrapped method that constructs an instance of `T` from `PolarValue`s
54     constructor: Option<Constructor>,
55     /// Methods that return simple attribute lookups on an instance of `T`
56     attributes: Attributes,
57     /// Instance methods on `T` that expect a list of `PolarValue`s, and an instance of `&T`
58     instance_methods: InstanceMethods,
59     /// Class methods on `T`
60     class_methods: ClassMethods,
61 
62     /// A method to check whether the supplied `TypeId` matches this class
63     /// (This isn't using `type_id` because we might want to register other types here
64     /// in order to check inheritance)
65     class_check: Arc<dyn Fn(TypeId) -> bool + Send + Sync>,
66 
67     /// A function that accepts arguments of this class and compares them for equality.
68     /// Limitation: Only works on comparisons of the same type.
69     equality_check: Arc<dyn Fn(&Host, &Instance, &Instance) -> crate::Result<bool> + Send + Sync>,
70 
71     into_iter:
72         Arc<dyn Fn(&Host, &Instance) -> crate::Result<crate::host::PolarIterator> + Send + Sync>,
73 
74     // Hooks to be called on the class once it's been registered with host.
75     pub register_hooks: RegisterHooks,
76 }
77 
78 impl Class {
builder<T: 'static>() -> ClassBuilder<T>79     pub fn builder<T: 'static>() -> ClassBuilder<T> {
80         ClassBuilder::new()
81     }
82 
init(&self, fields: Vec<PolarValue>) -> crate::Result<Instance>83     pub fn init(&self, fields: Vec<PolarValue>) -> crate::Result<Instance> {
84         if let Some(constructor) = &self.constructor {
85             constructor.invoke(fields)
86         } else {
87             Err(crate::OsoError::Custom {
88                 message: format!("MissingConstructorError: {} has no constructor", self.name),
89             })
90         }
91     }
92 
93     /// Call class method `attr` on `self` with arguments from `args`.
94     ///
95     /// Returns: The result as a `PolarValue`
call(&self, attr: &str, args: Vec<PolarValue>) -> crate::Result<PolarValue>96     pub fn call(&self, attr: &str, args: Vec<PolarValue>) -> crate::Result<PolarValue> {
97         let attr =
98             self.class_methods
99                 .get(attr)
100                 .ok_or_else(|| InvalidCallError::ClassMethodNotFound {
101                     method_name: attr.to_owned(),
102                     type_name: self.name.clone(),
103                 })?;
104 
105         attr.clone().invoke(args)
106     }
107 
get_method(&self, name: &str) -> Option<InstanceMethod>108     fn get_method(&self, name: &str) -> Option<InstanceMethod> {
109         tracing::trace!({class=%self.name, name}, "get_method");
110         if self.type_id == TypeId::of::<Class>() {
111             // all methods on `Class` redirect by looking up the class method
112             Some(InstanceMethod::from_class_method(name.to_string()))
113         } else {
114             self.instance_methods.get(name).cloned()
115         }
116     }
117 
equals(&self, host: &Host, lhs: &Instance, rhs: &Instance) -> crate::Result<bool>118     fn equals(&self, host: &Host, lhs: &Instance, rhs: &Instance) -> crate::Result<bool> {
119         // equality checking is currently only supported for exactly matching types
120         // TODO: support multiple dispatch for equality
121         if lhs.type_id() != rhs.type_id() {
122             Ok(false)
123         } else {
124             (self.equality_check)(host, lhs, rhs)
125         }
126     }
127 }
128 
129 #[derive(Clone)]
130 pub struct ClassBuilder<T> {
131     class: Class,
132     /// A type marker. Used to ensure methods have the correct type.
133     ty: std::marker::PhantomData<T>,
134 }
135 
136 impl<T> ClassBuilder<T>
137 where
138     T: 'static,
139 {
140     /// Create a new class builder.
new() -> Self141     fn new() -> Self {
142         let fq_name = std::any::type_name::<T>().to_string();
143         let short_name = fq_name.split("::").last().expect("type has invalid name");
144         Self {
145             class: Class {
146                 name: short_name.to_string(),
147                 constructor: None,
148                 attributes: HashMap::new(),
149                 instance_methods: InstanceMethods::new(),
150                 class_methods: ClassMethods::new(),
151                 class_check: Arc::new(|type_id| TypeId::of::<T>() == type_id),
152                 equality_check: Arc::from(equality_not_supported()),
153                 into_iter: Arc::from(iterator_not_supported()),
154                 type_id: TypeId::of::<T>(),
155                 register_hooks: RegisterHooks::new(),
156             },
157             ty: std::marker::PhantomData,
158         }
159     }
160 
161     /// Create a new class builder for a type that implements Default and use that as the constructor.
with_default() -> Self where T: std::default::Default, T: Send + Sync,162     pub fn with_default() -> Self
163     where
164         T: std::default::Default,
165         T: Send + Sync,
166     {
167         Self::with_constructor::<_, _>(T::default)
168     }
169 
170     /// Create a new class builder with a given constructor.
with_constructor<F, Args>(f: F) -> Self where F: Function<Args, Result = T>, T: Send + Sync, Args: FromPolarList,171     pub fn with_constructor<F, Args>(f: F) -> Self
172     where
173         F: Function<Args, Result = T>,
174         T: Send + Sync,
175         Args: FromPolarList,
176     {
177         let mut class: ClassBuilder<T> = ClassBuilder::new();
178         class = class.set_constructor(f);
179         class
180     }
181 
182     /// Set the constructor function to use for polar `new` statements.
set_constructor<F, Args>(mut self, f: F) -> Self where F: Function<Args, Result = T>, T: Send + Sync, Args: FromPolarList,183     pub fn set_constructor<F, Args>(mut self, f: F) -> Self
184     where
185         F: Function<Args, Result = T>,
186         T: Send + Sync,
187         Args: FromPolarList,
188     {
189         self.class.constructor = Some(Constructor::new(f));
190         self
191     }
192 
193     /// Set an equality function to be used for polar `==` statements.
set_equality_check<F>(mut self, f: F) -> Self where F: Fn(&T, &T) -> bool + Send + Sync + 'static,194     pub fn set_equality_check<F>(mut self, f: F) -> Self
195     where
196         F: Fn(&T, &T) -> bool + Send + Sync + 'static,
197     {
198         self.class.equality_check = Arc::new(move |host, a, b| {
199             tracing::trace!("equality check");
200 
201             let a = a.downcast(Some(host)).map_err(|e| e.user())?;
202             let b = b.downcast(Some(host)).map_err(|e| e.user())?;
203 
204             Ok((f)(a, b))
205         });
206 
207         self
208     }
209 
210     /// Set a method to convert instances into iterators
set_into_iter<F, I, V>(mut self, f: F) -> Self where F: Fn(&T) -> I + Send + Sync + 'static, I: Iterator<Item = V> + Clone + Send + Sync + 'static, V: ToPolarResult,211     pub fn set_into_iter<F, I, V>(mut self, f: F) -> Self
212     where
213         F: Fn(&T) -> I + Send + Sync + 'static,
214         I: Iterator<Item = V> + Clone + Send + Sync + 'static,
215         V: ToPolarResult,
216     {
217         self.class.into_iter = Arc::new(move |host, instance| {
218             tracing::trace!("iter check");
219 
220             let instance = instance.downcast(Some(host)).map_err(|e| e.user())?;
221 
222             Ok(crate::host::PolarIterator::new(f(instance)))
223         });
224 
225         self
226     }
227 
228     /// Use the existing `IntoIterator` implementation to convert instances into iterators
with_iter<V>(self) -> Self where T: IntoIterator<Item = V> + Clone, <T as IntoIterator>::IntoIter: Clone + Send + Sync + 'static, V: ToPolarResult,229     pub fn with_iter<V>(self) -> Self
230     where
231         T: IntoIterator<Item = V> + Clone,
232         <T as IntoIterator>::IntoIter: Clone + Send + Sync + 'static,
233         V: ToPolarResult,
234     {
235         self.set_into_iter(|t| t.clone().into_iter())
236     }
237 
238     /// Use PartialEq::eq as the equality check for polar `==` statements.
with_equality_check(self) -> Self where T: PartialEq<T>,239     pub fn with_equality_check(self) -> Self
240     where
241         T: PartialEq<T>,
242     {
243         self.set_equality_check(|a, b| PartialEq::eq(a, b))
244     }
245 
246     /// Add an attribute getter for statments like `foo.bar`
247     /// `class.add_attribute_getter("bar", |instance| instance.bar)
add_attribute_getter<F, R>(mut self, name: &'static str, f: F) -> Self where F: Fn(&T) -> R + Send + Sync + 'static, R: crate::ToPolar, T: 'static,248     pub fn add_attribute_getter<F, R>(mut self, name: &'static str, f: F) -> Self
249     where
250         F: Fn(&T) -> R + Send + Sync + 'static,
251         R: crate::ToPolar,
252         T: 'static,
253     {
254         self.class.attributes.insert(name, AttributeGetter::new(f));
255         self
256     }
257 
258     /// Set the name of the polar class.
name(mut self, name: &str) -> Self259     pub fn name(mut self, name: &str) -> Self {
260         self.class.name = name.to_string();
261         self
262     }
263 
264     /// Add a RegisterHook on the class that will register the given constant once the class is registered.
add_constant<V: crate::ToPolar + Clone + Send + Sync + 'static>( mut self, value: V, name: &'static str, ) -> Self265     pub fn add_constant<V: crate::ToPolar + Clone + Send + Sync + 'static>(
266         mut self,
267         value: V,
268         name: &'static str,
269     ) -> Self {
270         let register_hook = move |oso: &mut crate::Oso| oso.register_constant(value.clone(), name);
271         self.class
272             .register_hooks
273             .push(RegisterHook::new(register_hook));
274         self
275     }
276 
277     /// Add a method for polar method calls like `foo.plus(1)
278     /// `class.add_attribute_getter("bar", |instance, n| instance.foo + n)
add_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self where Args: FromPolarList, F: Method<T, Args, Result = R>, R: ToPolarResult + 'static,279     pub fn add_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self
280     where
281         Args: FromPolarList,
282         F: Method<T, Args, Result = R>,
283         R: ToPolarResult + 'static,
284     {
285         self.class
286             .instance_methods
287             .insert(name, InstanceMethod::new(f));
288         self
289     }
290 
291     /// A method that returns multiple values. Every element in the iterator returned by the method will
292     /// be a separate polar return value.
add_iterator_method<F, Args, I>(mut self, name: &'static str, f: F) -> Self where Args: FromPolarList, F: Method<T, Args>, F::Result: IntoIterator<Item = I>, I: ToPolarResult + 'static, <<F as Method<T, Args>>::Result as IntoIterator>::IntoIter: Iterator<Item = I> + Clone + Send + Sync + 'static, T: 'static,293     pub fn add_iterator_method<F, Args, I>(mut self, name: &'static str, f: F) -> Self
294     where
295         Args: FromPolarList,
296         F: Method<T, Args>,
297         F::Result: IntoIterator<Item = I>,
298         I: ToPolarResult + 'static,
299         <<F as Method<T, Args>>::Result as IntoIterator>::IntoIter:
300             Iterator<Item = I> + Clone + Send + Sync + 'static,
301         T: 'static,
302     {
303         self.class
304             .instance_methods
305             .insert(name, InstanceMethod::new_iterator(f));
306         self
307     }
308 
309     /// A method that's called on the type instead of an instance.
310     /// eg `Foo.pi`
add_class_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self where F: Function<Args, Result = R>, Args: FromPolarList, R: ToPolarResult + 'static,311     pub fn add_class_method<F, Args, R>(mut self, name: &'static str, f: F) -> Self
312     where
313         F: Function<Args, Result = R>,
314         Args: FromPolarList,
315         R: ToPolarResult + 'static,
316     {
317         self.class.class_methods.insert(name, ClassMethod::new(f));
318         self
319     }
320 
321     /// Finish building a build the class
build(self) -> Class322     pub fn build(self) -> Class {
323         self.class
324     }
325 }
326 
/// A type-erased value of some Rust type, wrapped so it can be handed to Polar.
///
/// The value is not guaranteed to belong to a *registered* class: the class is
/// looked up on demand through `Instance::class`, and the `ToPolar`
/// implementation for `PolarClass` registers the class the first time it is seen.
///
/// Use `Instance::downcast` to get a reference to the underlying value back.
#[derive(Clone)]
pub struct Instance {
    /// The wrapped value. Held in an `Arc`, so cloning an `Instance` shares the
    /// value rather than copying it.
    inner: Arc<dyn std::any::Any + Send + Sync>,

    /// `std::any::type_name` of the wrapped value, for debugging purposes only.
    /// Use `Instance::name` to get the name registered with the host.
    debug_type_name: &'static str,
}
344 
345 impl fmt::Debug for Instance {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result346     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
347         write!(f, "Instance<{}>", self.debug_type_name)
348     }
349 }
350 
351 impl Instance {
352     /// Create a new instance
new<T: Send + Sync + 'static>(instance: T) -> Self353     pub fn new<T: Send + Sync + 'static>(instance: T) -> Self {
354         Self {
355             inner: Arc::new(instance),
356             debug_type_name: std::any::type_name::<T>(),
357         }
358     }
359 
360     /// Check whether this is an instance of `class`
instance_of(&self, class: &Class) -> bool361     pub fn instance_of(&self, class: &Class) -> bool {
362         self.type_id() == class.type_id
363     }
364 
type_id(&self) -> std::any::TypeId365     pub fn type_id(&self) -> std::any::TypeId {
366         self.inner.as_ref().type_id()
367     }
368 
369     /// Looks up the `Class` for this instance on the provided `host`
class<'a>(&self, host: &'a Host) -> crate::Result<&'a Class>370     pub fn class<'a>(&self, host: &'a Host) -> crate::Result<&'a Class> {
371         host.get_class_by_type_id(self.inner.as_ref().type_id())
372             .map_err(|_| OsoError::MissingClassError {
373                 name: self.debug_type_name.to_string(),
374             })
375     }
376 
377     /// Get the canonical name of this instance.
378     ///
379     /// The canonical name is the registered name on host *if* if it registered.
380     /// Otherwise, the debug name is returned.
name<'a>(&self, host: &'a Host) -> &'a str381     pub fn name<'a>(&self, host: &'a Host) -> &'a str {
382         self.class(host)
383             .map(|class| class.name.as_ref())
384             .unwrap_or_else(|_| self.debug_type_name)
385     }
386 
387     /// Lookup an attribute on the instance via the registered `Class`
get_attr(&self, name: &str, host: &mut Host) -> crate::Result<PolarValue>388     pub fn get_attr(&self, name: &str, host: &mut Host) -> crate::Result<PolarValue> {
389         tracing::trace!({ method = %name }, "get_attr");
390         let attr = self
391             .class(host)
392             .and_then(|c| {
393                 c.attributes.get(name).ok_or_else(|| {
394                     InvalidCallError::AttributeNotFound {
395                         attribute_name: name.to_owned(),
396                         type_name: self.name(&host).to_owned(),
397                     }
398                     .into()
399                 })
400             })?
401             .clone();
402         attr.invoke(self, host)
403     }
404 
405     /// Call the named method on the instance via the registered `Class`
406     ///
407     /// Returns: A PolarValue, or an Error if the method cannot be called.
call( &self, name: &str, args: Vec<PolarValue>, host: &mut Host, ) -> crate::Result<PolarValue>408     pub fn call(
409         &self,
410         name: &str,
411         args: Vec<PolarValue>,
412         host: &mut Host,
413     ) -> crate::Result<PolarValue> {
414         tracing::trace!({method = %name, ?args}, "call");
415         let method = self.class(host).and_then(|c| {
416             c.get_method(name).ok_or_else(|| {
417                 InvalidCallError::MethodNotFound {
418                     method_name: name.to_owned(),
419                     type_name: self.name(&host).to_owned(),
420                 }
421                 .into()
422             })
423         })?;
424         method.invoke(self, args, host)
425     }
426 
as_iter(&self, host: &Host) -> crate::Result<crate::host::PolarIterator>427     pub fn as_iter(&self, host: &Host) -> crate::Result<crate::host::PolarIterator> {
428         self.class(host).and_then(|c| (c.into_iter)(host, self))
429     }
430 
431     /// Return `true` if the `instance` of self equals the instance of `other`.
equals(&self, other: &Self, host: &Host) -> crate::Result<bool>432     pub fn equals(&self, other: &Self, host: &Host) -> crate::Result<bool> {
433         tracing::trace!("equals");
434         self.class(host)
435             .and_then(|class| class.equals(host, &self, other))
436     }
437 
438     /// Attempt to downcast the inner type of the instance to a reference to the type `T`
439     /// This should be the _only_ place using downcast to avoid mistakes.
440     ///
441     /// # Arguments
442     ///
443     /// * `host`: Pass host if possible to improve error handling.
downcast<T: 'static>( &self, host: Option<&Host>, ) -> Result<&T, crate::errors::TypeError>444     pub fn downcast<T: 'static>(
445         &self,
446         host: Option<&Host>,
447     ) -> Result<&T, crate::errors::TypeError> {
448         let name = host
449             .map(|h| self.name(h).to_owned())
450             .unwrap_or_else(|| self.debug_type_name.to_owned());
451 
452         let expected_name = host
453             .and_then(|h| {
454                 h.get_class_by_type_id(std::any::TypeId::of::<T>())
455                     .map(|class| class.name.clone())
456                     .ok()
457             })
458             .unwrap_or_else(|| std::any::type_name::<T>().to_owned());
459 
460         self.inner
461             .as_ref()
462             .downcast_ref()
463             .ok_or_else(|| crate::errors::TypeError::expected(expected_name).got(name))
464     }
465 }
466 
#[cfg(test)]
mod test {
    use super::*;

    /// An instance only matches the class built for its own concrete type.
    #[test]
    fn test_instance_of() {
        struct Foo;
        struct Bar;

        let instance = Instance::new(Foo);
        let own_class = Class::builder::<Foo>().build();
        let other_class = Class::builder::<Bar>().build();

        assert!(
            instance.instance_of(&own_class),
            "Foo instance should match the Foo class"
        );
        assert!(
            !instance.instance_of(&other_class),
            "Foo instance should not match the Bar class"
        );
    }
}
484