1 use crate::app::VERSION;
2 use crate::config::Config;
3 use anyhow::bail;
4 use anyhow::Error;
5 use anyhow::Result;
6 use mlua::Lua;
7 use mlua::LuaSerdeExt;
8 use mlua::SerializeOptions;
9 use serde::Deserialize;
10 use serde::Serialize;
11 use std::fs;
12
/// Default Lua script, embedded at compile time from `init.lua`; executed by `init`.
const DEFAULT_LUA_SCRIPT: &str = include_str!("init.lua");
/// Upgrade guide URL appended to the incompatible-version error in `check_version`.
const UPGRADE_GUIDE_LINK: &str = "https://xplr.dev/en/upgrade-guide.html";
15
serialize<'lua, T: Serialize + Sized>( lua: &'lua mlua::Lua, value: &T, ) -> Result<mlua::Value<'lua>>16 pub fn serialize<'lua, T: Serialize + Sized>(
17 lua: &'lua mlua::Lua,
18 value: &T,
19 ) -> Result<mlua::Value<'lua>> {
20 lua.to_value_with(
21 value,
22 SerializeOptions::new().serialize_none_to_null(false),
23 )
24 .map_err(Error::from)
25 }
26
parse_version(version: &str) -> Result<(u16, u16, u16, Option<u16>)>27 fn parse_version(version: &str) -> Result<(u16, u16, u16, Option<u16>)> {
28 let mut configv = version.split('.');
29
30 let major = configv.next().unwrap_or_default().parse::<u16>()?;
31 let minor = configv.next().unwrap_or_default().parse::<u16>()?;
32 let bugfix = configv
33 .next()
34 .and_then(|s| s.split('-').next())
35 .unwrap_or_default()
36 .parse::<u16>()?;
37
38 let beta = configv.next().unwrap_or_default().parse::<u16>().ok();
39
40 Ok((major, minor, bugfix, beta))
41 }
42
43 /// Check the config version and notify users.
check_version(version: &str, path: &str) -> Result<()>44 pub fn check_version(version: &str, path: &str) -> Result<()> {
45 // Until we're v1, let's ignore major versions
46 let (rmajor, rminor, rbugfix, rbeta) = parse_version(VERSION)?;
47 let (smajor, sminor, sbugfix, sbeta) = parse_version(version)?;
48
49 if rmajor == smajor
50 && rminor == sminor
51 && rbugfix >= sbugfix
52 && rbeta == sbeta
53 {
54 Ok(())
55 } else {
56 bail!(
57 "incompatible script version in: {}. The script version is: {}, the required version is: {}. Visit {}",
58 path,
59 version,
60 VERSION.to_string(),
61 UPGRADE_GUIDE_LINK,
62 )
63 }
64 }
65
66 /// Used to initialize Lua globals
init(lua: &Lua) -> Result<Config>67 pub fn init(lua: &Lua) -> Result<Config> {
68 let config = Config::default();
69 let globals = lua.globals();
70
71 let lua_xplr = lua.create_table()?;
72 lua_xplr.set("config", serialize(lua, &config)?)?;
73
74 let lua_xplr_fn = lua.create_table()?;
75 let lua_xplr_fn_builtin = lua.create_table()?;
76 let lua_xplr_fn_custom = lua.create_table()?;
77
78 lua_xplr_fn.set("builtin", lua_xplr_fn_builtin)?;
79 lua_xplr_fn.set("custom", lua_xplr_fn_custom)?;
80 lua_xplr.set("fn", lua_xplr_fn)?;
81 globals.set("xplr", lua_xplr)?;
82
83 lua.load(DEFAULT_LUA_SCRIPT).set_name("init")?.exec()?;
84
85 let lua_xplr: mlua::Table = globals.get("xplr")?;
86 let config: Config = lua.from_value(lua_xplr.get("config")?)?;
87 Ok(config)
88 }
89
90 /// Used to extend Lua globals
extend(lua: &Lua, path: &str) -> Result<Config>91 pub fn extend(lua: &Lua, path: &str) -> Result<Config> {
92 let globals = lua.globals();
93
94 let script = fs::read_to_string(path)?;
95
96 lua.load(&script).set_name("init")?.exec()?;
97
98 let version: String =
99 match globals.get("version").and_then(|v| lua.from_value(v)) {
100 Ok(v) => v,
101 Err(_) => bail!("'version' must be defined globally in {}", path),
102 };
103
104 check_version(&version, path)?;
105
106 let lua_xplr: mlua::Table = globals.get("xplr")?;
107
108 let config: Config = lua.from_value(lua_xplr.get("config")?)?;
109 Ok(config)
110 }
111
resolve_fn_recursive<'lua, 'a>( table: &mlua::Table<'lua>, mut path: impl Iterator<Item = &'a str>, ) -> Result<mlua::Function<'lua>>112 fn resolve_fn_recursive<'lua, 'a>(
113 table: &mlua::Table<'lua>,
114 mut path: impl Iterator<Item = &'a str>,
115 ) -> Result<mlua::Function<'lua>> {
116 if let Some(nxt) = path.next() {
117 match table.get(nxt)? {
118 mlua::Value::Table(t) => resolve_fn_recursive(&t, path),
119 mlua::Value::Function(f) => Ok(f),
120 t => bail!("{:?} is not a function", t),
121 }
122 } else {
123 bail!("Invalid path")
124 }
125 }
126
127 /// This function resolves paths like `builtin.func_foo`, `custom.func_bar` into lua functions.
resolve_fn<'lua>( globals: &mlua::Table<'lua>, path: &str, ) -> Result<mlua::Function<'lua>>128 pub fn resolve_fn<'lua>(
129 globals: &mlua::Table<'lua>,
130 path: &str,
131 ) -> Result<mlua::Function<'lua>> {
132 resolve_fn_recursive(globals, path.split('.'))
133 }
134
call<'lua, R: Deserialize<'lua>>( lua: &'lua Lua, func: &str, arg: mlua::Value<'lua>, ) -> Result<R>135 pub fn call<'lua, R: Deserialize<'lua>>(
136 lua: &'lua Lua,
137 func: &str,
138 arg: mlua::Value<'lua>,
139 ) -> Result<R> {
140 let func = format!("xplr.fn.{}", func);
141 let func = resolve_fn(&lua.globals(), &func)?;
142 let res: mlua::Value = func.call(arg)?;
143 let res: R = lua.from_value(res)?;
144 Ok(res)
145 }
146
#[cfg(test)]
mod tests {

    use super::*;

    /// Render a version tuple in the `X.Y.Z[-beta.N]` form read by `parse_version`.
    fn render(major: u16, minor: u16, bugfix: u16, beta: Option<u16>) -> String {
        match beta {
            Some(b) => format!("{}.{}.{}-beta.{}", major, minor, bugfix, b),
            None => format!("{}.{}.{}", major, minor, bugfix),
        }
    }

    #[test]
    fn test_parse_version() {
        assert_eq!(parse_version("0.17.0").unwrap(), (0, 17, 0, None));
        assert_eq!(parse_version("1.2.3-beta.4").unwrap(), (1, 2, 3, Some(4)));
        assert!(parse_version("1.2").is_err());
        assert!(parse_version("x.2.3").is_err());
    }

    #[test]
    fn test_compatibility() {
        // All cases are derived from the running VERSION, so this test does not need
        // editing on every release.
        let (major, minor, bugfix, beta) = parse_version(VERSION).unwrap();
        let path = "foo path";

        // Current release is OK
        assert!(check_version(VERSION, path).is_ok());
        assert!(check_version(&render(major, minor, bugfix, beta), path).is_ok());

        // Prev major release is ERR
        if let Some(prev) = major.checked_sub(1) {
            assert!(check_version(&render(prev, minor, bugfix, beta), path).is_err());
        }

        // Prev minor release is ERR (Change when we get to v1)
        if let Some(prev) = minor.checked_sub(1) {
            assert!(check_version(&render(major, prev, bugfix, beta), path).is_err());
        }

        // Prev bugfix release is OK
        if let Some(prev) = bugfix.checked_sub(1) {
            assert!(check_version(&render(major, minor, prev, beta), path).is_ok());
        }

        // Next major release is ERR
        assert!(check_version(&render(major + 1, minor, bugfix, beta), path).is_err());

        // Next minor release is ERR
        assert!(check_version(&render(major, minor + 1, bugfix, beta), path).is_err());

        // Next bugfix release is ERR (Change when we get to v1)
        assert!(check_version(&render(major, minor, bugfix + 1, beta), path).is_err());

        // Beta mismatch is ERR
        let other_beta = if beta.is_some() { None } else { Some(1) };
        assert!(check_version(&render(major, minor, bugfix, other_beta), path).is_err());
    }
}
178