1 use crate::app::VERSION;
2 use crate::config::Config;
3 use anyhow::bail;
4 use anyhow::Error;
5 use anyhow::Result;
6 use mlua::Lua;
7 use mlua::LuaSerdeExt;
8 use mlua::SerializeOptions;
9 use serde::Deserialize;
10 use serde::Serialize;
11 use std::fs;
12 
/// Bundled Lua script, embedded at compile time, that [`init`] executes after
/// installing the `xplr` global table.
const DEFAULT_LUA_SCRIPT: &str = include_str!("init.lua");
/// Upgrade guide URL included in the errors produced by [`check_version`].
const UPGRADE_GUIDE_LINK: &str = "https://xplr.dev/en/upgrade-guide.html";
15 
serialize<'lua, T: Serialize + Sized>( lua: &'lua mlua::Lua, value: &T, ) -> Result<mlua::Value<'lua>>16 pub fn serialize<'lua, T: Serialize + Sized>(
17     lua: &'lua mlua::Lua,
18     value: &T,
19 ) -> Result<mlua::Value<'lua>> {
20     lua.to_value_with(
21         value,
22         SerializeOptions::new().serialize_none_to_null(false),
23     )
24     .map_err(Error::from)
25 }
26 
parse_version(version: &str) -> Result<(u16, u16, u16, Option<u16>)>27 fn parse_version(version: &str) -> Result<(u16, u16, u16, Option<u16>)> {
28     let mut configv = version.split('.');
29 
30     let major = configv.next().unwrap_or_default().parse::<u16>()?;
31     let minor = configv.next().unwrap_or_default().parse::<u16>()?;
32     let bugfix = configv
33         .next()
34         .and_then(|s| s.split('-').next())
35         .unwrap_or_default()
36         .parse::<u16>()?;
37 
38     let beta = configv.next().unwrap_or_default().parse::<u16>().ok();
39 
40     Ok((major, minor, bugfix, beta))
41 }
42 
43 /// Check the config version and notify users.
check_version(version: &str, path: &str) -> Result<()>44 pub fn check_version(version: &str, path: &str) -> Result<()> {
45     // Until we're v1, let's ignore major versions
46     let (rmajor, rminor, rbugfix, rbeta) = parse_version(VERSION)?;
47     let (smajor, sminor, sbugfix, sbeta) = parse_version(version)?;
48 
49     if rmajor == smajor
50         && rminor == sminor
51         && rbugfix >= sbugfix
52         && rbeta == sbeta
53     {
54         Ok(())
55     } else {
56         bail!(
57             "incompatible script version in: {}. The script version is: {}, the required version is: {}. Visit {}",
58             path,
59             version,
60             VERSION.to_string(),
61             UPGRADE_GUIDE_LINK,
62         )
63     }
64 }
65 
66 /// Used to initialize Lua globals
init(lua: &Lua) -> Result<Config>67 pub fn init(lua: &Lua) -> Result<Config> {
68     let config = Config::default();
69     let globals = lua.globals();
70 
71     let lua_xplr = lua.create_table()?;
72     lua_xplr.set("config", serialize(lua, &config)?)?;
73 
74     let lua_xplr_fn = lua.create_table()?;
75     let lua_xplr_fn_builtin = lua.create_table()?;
76     let lua_xplr_fn_custom = lua.create_table()?;
77 
78     lua_xplr_fn.set("builtin", lua_xplr_fn_builtin)?;
79     lua_xplr_fn.set("custom", lua_xplr_fn_custom)?;
80     lua_xplr.set("fn", lua_xplr_fn)?;
81     globals.set("xplr", lua_xplr)?;
82 
83     lua.load(DEFAULT_LUA_SCRIPT).set_name("init")?.exec()?;
84 
85     let lua_xplr: mlua::Table = globals.get("xplr")?;
86     let config: Config = lua.from_value(lua_xplr.get("config")?)?;
87     Ok(config)
88 }
89 
90 /// Used to extend Lua globals
extend(lua: &Lua, path: &str) -> Result<Config>91 pub fn extend(lua: &Lua, path: &str) -> Result<Config> {
92     let globals = lua.globals();
93 
94     let script = fs::read_to_string(path)?;
95 
96     lua.load(&script).set_name("init")?.exec()?;
97 
98     let version: String =
99         match globals.get("version").and_then(|v| lua.from_value(v)) {
100             Ok(v) => v,
101             Err(_) => bail!("'version' must be defined globally in {}", path),
102         };
103 
104     check_version(&version, path)?;
105 
106     let lua_xplr: mlua::Table = globals.get("xplr")?;
107 
108     let config: Config = lua.from_value(lua_xplr.get("config")?)?;
109     Ok(config)
110 }
111 
resolve_fn_recursive<'lua, 'a>( table: &mlua::Table<'lua>, mut path: impl Iterator<Item = &'a str>, ) -> Result<mlua::Function<'lua>>112 fn resolve_fn_recursive<'lua, 'a>(
113     table: &mlua::Table<'lua>,
114     mut path: impl Iterator<Item = &'a str>,
115 ) -> Result<mlua::Function<'lua>> {
116     if let Some(nxt) = path.next() {
117         match table.get(nxt)? {
118             mlua::Value::Table(t) => resolve_fn_recursive(&t, path),
119             mlua::Value::Function(f) => Ok(f),
120             t => bail!("{:?} is not a function", t),
121         }
122     } else {
123         bail!("Invalid path")
124     }
125 }
126 
127 /// This function resolves paths like `builtin.func_foo`, `custom.func_bar` into lua functions.
resolve_fn<'lua>( globals: &mlua::Table<'lua>, path: &str, ) -> Result<mlua::Function<'lua>>128 pub fn resolve_fn<'lua>(
129     globals: &mlua::Table<'lua>,
130     path: &str,
131 ) -> Result<mlua::Function<'lua>> {
132     resolve_fn_recursive(globals, path.split('.'))
133 }
134 
call<'lua, R: Deserialize<'lua>>( lua: &'lua Lua, func: &str, arg: mlua::Value<'lua>, ) -> Result<R>135 pub fn call<'lua, R: Deserialize<'lua>>(
136     lua: &'lua Lua,
137     func: &str,
138     arg: mlua::Value<'lua>,
139 ) -> Result<R> {
140     let func = format!("xplr.fn.{}", func);
141     let func = resolve_fn(&lua.globals(), &func)?;
142     let res: mlua::Value = func.call(arg)?;
143     let res: R = lua.from_value(res)?;
144     Ok(res)
145 }
146 
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_compatibility() {
        assert!(check_version(VERSION, "foo path").is_ok());

        // Current release is OK
        assert!(check_version("0.17.0", "foo path").is_ok());

        // Prev major release is ERR
        // - Not testable yet: the current major version is 0.

        // Prev minor release is ERR (Change when we get to v1)
        assert!(check_version("0.16.0", "foo path").is_err());

        // Prev bugfix release is OK
        // - Not testable yet: the current bugfix version is 0.

        // Next major release is ERR
        assert!(check_version("1.17.0", "foo path").is_err());

        // Next minor release is ERR
        assert!(check_version("0.18.0", "foo path").is_err());

        // Next bugfix release is ERR (Change when we get to v1)
        assert!(check_version("0.17.1", "foo path").is_err());

        // Beta script on a non-beta release is ERR
        assert!(check_version("0.17.0-beta.1", "foo path").is_err());

        // Malformed version is ERR
        assert!(check_version("0.17", "foo path").is_err());
        assert!(check_version("", "foo path").is_err());
    }

    #[test]
    fn test_parse_version() {
        assert_eq!(parse_version("0.17.0").unwrap(), (0, 17, 0, None));
        assert_eq!(
            parse_version("0.17.0-beta.1").unwrap(),
            (0, 17, 0, Some(1))
        );
        assert!(parse_version("0.17").is_err());
        assert!(parse_version("").is_err());
        assert!(parse_version("a.b.c").is_err());
    }
}
178