1-- Prosody IM
2-- Copyright (C) 2008-2010 Matthew Wild
3-- Copyright (C) 2008-2010 Waqas Hussain
4-- Copyright (C) 2018 Kim Alvefur
5--
6-- This project is MIT/X11 licensed. Please see the
7-- COPYING file in the source package for more information.
8--
9
10local getmetatable = getmetatable;
11local next, type = next, type;
12local s_format = string.format;
13local s_gsub = string.gsub;
14local s_rep = string.rep;
15local s_char = string.char;
16local s_match = string.match;
17local t_concat = table.concat;
18
19local pcall = pcall;
20local envload = require"util.envload".envload;
21
local pos_inf = math.huge;
local neg_inf = -pos_inf;

-- Number subtype classifier: math.type where available (Lua 5.3+), else a
-- fallback that calls a number "integer" when it is integral and lies within
-- +/-2^53, and "float" otherwise (including NaN and infinities).
-- luacheck: ignore 143/math
local m_type = math.type;
if not m_type then
	local max_exact = 9007199254740992; -- 2^53
	m_type = function (n)
		if n % 1 == 0 and n >= -max_exact and n <= max_exact then
			return "integer";
		end
		return "float";
	end;
end
27
-- Lookup table for string.gsub: maps each byte (as a one-character string)
-- to its two-digit lowercase hexadecimal form.
local char_to_hex = {};
for byte = 0, 255 do
	local hex_pair = s_format("%02x", byte);
	char_to_hex[s_char(byte)] = hex_pair;
end

--- Encode a string as lowercase hexadecimal, two digits per byte.
-- @tparam string s input bytes
-- @treturn string hex digits (twice the length of `s`)
local function to_hex(s)
	-- Keep only gsub's first result; the substitution count is discarded.
	local encoded = s_gsub(s, ".", char_to_hex);
	return encoded;
end
36
--- Fallback that aborts serialization by raising an error.
-- @param obj the value that could not be serialized
-- @param[opt] why reason; appended after ": " when given
local function fatal_error(obj, why)
	local message = "Can't serialize " .. type(obj);
	if why then
		message = message .. ": " .. why;
	end
	error(message);
end
40
--- Fallback that emits a placeholder table constructor instead of failing,
-- e.g. {__type="function",__error="fail"}.
-- @param x the value that could not be serialized
-- @param[opt] why reason; "fail" when omitted
-- @treturn string Lua table constructor describing the failure
local function nonfatal_fallback(x, why)
	local reason = why or "fail";
	return s_format("{__type=%q,__error=%q}", type(x), reason);
end
44
-- Replacement table for string.gsub: maps every byte to the escape sequence
-- used inside a double-quoted Lua string literal.
local string_escapes = {};
do
	-- Control characters that have a mnemonic letter escape.
	local mnemonic = {
		a = "\a"; b = "\b"; f = "\f"; n = "\n";
		r = "\r"; t = "\t"; v = "\v";
	};
	for letter, ch in next, mnemonic do
		string_escapes[ch] = "\\" .. letter;
	end
	-- Backslash and both quote characters are escaped by a leading backslash.
	for _, ch in next, { "\\", "\"", "'" } do
		string_escapes[ch] = "\\" .. ch;
	end
	-- Every other byte gets a three-digit decimal escape; the fixed width
	-- stops a following digit from being read as part of the escape.
	for byte = 0, 255 do
		local ch = s_char(byte);
		string_escapes[ch] = string_escapes[ch] or s_format("\\%03d", byte);
	end
end
59
-- Lua reserved words. A string key equal to one of these can never be
-- written as a bare identifier key (e.g. `{ end = 1 }` does not parse).
local default_keywords = {};
for word in string.gmatch([[
	and break do else elseif end false for function goto if in
	local nil not or repeat return then true until while
]], "%S+") do
	default_keywords[word] = true;
end
68
--- Create a serializer configured by `opt`.
--
-- `opt` is either an options table or a preset name; a non-table value
-- (including nil) is treated as `{ preset = opt }`. Recognised options:
--
--  * `preset`: "debug" (one line, `freeze`, non-fatal, unquoted keys),
--    "oneline" or "compact". A preset only supplies defaults: any option
--    the caller sets explicitly takes precedence over it.
--  * `fallback`: `function(value, why)` returning the text to emit for a
--    value that cannot be serialized. When absent, a placeholder generator
--    is used if `fatal == false`, otherwise an error is raised.
--  * `keywords`: set of strings never written as unquoted keys.
--  * `indentwith`, `itemstart`, `itemsep`, `itemlast`, `tstart`, `tend`,
--    `kstart`, `kend`, `equals`: text fragments used to lay out tables.
--  * `unquoted`: `true`, or a Lua pattern; string keys matching it (and not
--    in `keywords`) are written as bare identifiers.
--  * `hex`: if a string, a string value is written as
--    `hex .. '"' .. hexdigits .. '"'` when that is shorter than escaping.
--  * `freeze`: if `true`, a table whose metatable has a `__freeze` function
--    is replaced by that function's result (a resulting table is prefixed
--    by the metatable's `__name` when it is a string).
--  * `maxdepth`: maximum table nesting depth (default 127).
--  * `multiref`: if truthy, a table may be referenced more than once; only
--    cycles are rejected.
--
-- The caller's options table is not modified.
-- @param opt options table or preset name
-- @treturn function serializer: takes one value, returns a string
local function new(opt)
	if type(opt) ~= "table" then
		opt = { preset = opt };
	end

	-- Placeholders; each entry is replaced by its serializer function below.
	local types = {
		table = true;
		string = true;
		number = true;
		boolean = true;
		["nil"] = true;
	};

	-- Presets are resolved into this private table of defaults rather than
	-- being written into `opt`, so the caller's table is left untouched and
	-- options the caller set explicitly win over the preset.
	local preset = opt.preset;
	local defaults = {};
	if preset == "debug" then
		preset = "oneline";
		defaults.freeze = true;
		defaults.fatal = false;
		defaults.unquoted = true;
	end
	if preset == "oneline" then
		defaults.indentwith = "";
		defaults.itemstart = " ";
		defaults.itemlast = "";
		defaults.tend = " }";
	elseif preset == "compact" then
		defaults.indentwith = "";
		defaults.itemstart = "";
		defaults.itemlast = "";
		defaults.equals = "=";
		defaults.unquoted = true;
	end

	-- Lookup for options where an explicit `false` is meaningful: the
	-- caller's value unless it is nil, otherwise the preset default.
	local function flag(name)
		local v = opt[name];
		if v == nil then
			v = defaults[name];
		end
		return v;
	end

	local fatal = flag("fatal");
	local fallback = opt.fallback or fatal == false and nonfatal_fallback or fatal_error;

	-- Dispatch on the value's type; unsupported types go to the fallback.
	local function ser(v)
		return (types[type(v)] or fallback)(v);
	end

	local keywords = opt.keywords or default_keywords;

	-- Layout fragments; the built-in defaults give an indented, multi-line
	-- format. String options treat `false` like "not set".
	local indentwith = opt.indentwith or defaults.indentwith or "\t";
	local itemstart = opt.itemstart or defaults.itemstart or "\n";
	local itemsep = opt.itemsep or ";";
	local itemlast = opt.itemlast or defaults.itemlast or ";\n";
	local tstart = opt.tstart or "{";
	local tend = opt.tend or defaults.tend or "}";
	local kstart = opt.kstart or "[";
	local kend = opt.kend or "]";
	local equals = opt.equals or defaults.equals or " = ";
	local unquoted = flag("unquoted");
	if unquoted == true then
		unquoted = "^[%a_][%w_]*$";
	end
	local hex = opt.hex;
	local freeze = flag("freeze");
	local maxdepth = opt.maxdepth or 127;
	local multirefs = opt.multiref;

	-- Serialize one table, recursively.
	-- t - table being serialized
	-- o - array where tokens are added; concatenate to get the final result.
	--     Its table-keyed entries also record tables seen so far, to detect
	--     cycles (and repeated references unless `multiref` is set).
	-- l - position in o of where to insert the next token
	-- d - depth, used for indentation and the maxdepth limit
	-- Returns the next free position in o.
	local function serialize_table(t, o, l, d)
		if o[t] then
			o[l], l = fallback(t, "table has multiple references"), l + 1;
			return l;
		elseif d > maxdepth then
			o[l], l = fallback(t, "max table depth reached"), l + 1;
			return l;
		end

		-- Keep track of table loops
		local ot = t; -- reference pre-freeze
		o[ot] = true;

		if freeze == true then
			-- opportunity to do pre-serialization
			local mt = getmetatable(t);
			if type(mt) == "table" then
				local tag = mt.__name;
				local fr = mt.__freeze;

				if type(fr) == "function" then
					t = fr(t);
					if type(t) ~= "table" then
						-- __freeze produced a plain value: serialize it as such
						-- instead of failing when iterating over it below.
						o[l], l = ser(t), l + 1;
						if multirefs then
							o[ot] = nil;
						end
						return l;
					end
					-- Track the frozen table as well as the original.
					o[t] = true;
					if type(tag) == "string" then
						o[l], l = tag, l + 1;
					end
				end
			end
		end

		o[l], l = tstart, l + 1;
		local indent = s_rep(indentwith, d);
		local numkey = 1;
		local ktyp, vtyp;
		for k,v in next,t do
			o[l], l = itemstart, l + 1;
			o[l], l = indent, l + 1;
			ktyp, vtyp = type(k), type(v);
			if k == numkey then
				-- next index in array part: written positionally, since each
				-- positional constructor item takes the next consecutive index
				-- (assuming these are found in order)
				numkey = numkey + 1;
			elseif unquoted and ktyp == "string" and
				not keywords[k] and s_match(k, unquoted) then
				-- unquoted keys
				o[l], l = k, l + 1;
				o[l], l = equals, l + 1;
			else
				-- quoted keys
				o[l], l = kstart, l + 1;
				if ktyp == "table" then
					l = serialize_table(k, o, l, d+1);
				else
					o[l], l = ser(k), l + 1;
				end
				-- =
				o[l], o[l+1], l = kend, equals, l + 2;
			end

			-- the value
			if vtyp == "table" then
				l = serialize_table(v, o, l, d+1);
			else
				o[l], l = ser(v), l + 1;
			end
			-- last item?
			if next(t, k) ~= nil then
				o[l], l = itemsep, l + 1;
			else
				o[l], l = itemlast, l + 1;
			end
		end
		if next(t) ~= nil then
			o[l], l = s_rep(indentwith, d-1), l + 1;
		end
		o[l], l = tend, l +1;

		-- With multiref, only tables still being serialized count as seen,
		-- so a finished table may be serialized again elsewhere.
		if multirefs then
			o[t] = nil;
			o[ot] = nil;
		end

		return l;
	end

	function types.table(t)
		local o = {};
		serialize_table(t, o, 1, 1);
		return t_concat(o);
	end

	-- Quote a string, escaping control characters, quotes, backslash and
	-- all bytes >= 127 via string_escapes.
	local function serialize_string(s)
		return '"' .. s_gsub(s, "[%z\1-\31\"\'\\\127-\255]", string_escapes) .. '"';
	end

	if type(hex) == "string" then
		-- Use the hex form whenever it is shorter than the escaped form.
		function types.string(s)
			local esc = serialize_string(s);
			if #esc > (#s*2+2+#hex) then
				return hex .. '"' .. to_hex(s) .. '"';
			end
			return esc;
		end
	else
		types.string = serialize_string;
	end

	-- Integers are written exactly; infinities and NaN as expressions that
	-- evaluate to them. NOTE: "%.18g" writes integral floats without a
	-- decimal point (e.g. 2.0 -> "2"), so on Lua 5.3+ they read back as
	-- integers.
	function types.number(t)
		if m_type(t) == "integer" then
			return s_format("%d", t);
		elseif t == pos_inf then
			return "(1/0)";
		elseif t == neg_inf then
			return "(-1/0)";
		elseif t ~= t then
			return "(0/0)";
		end
		return s_format("%.18g", t);
	end

	-- Are these faster than tostring?
	types["nil"] = function()
		return "nil";
	end

	function types.boolean(t)
		return t and "true" or "false";
	end

	return ser;
end
264
--- Turn a serialized string back into a Lua value.
--
-- The text is compiled as the expression of a `return` statement through
-- util.envload, passing an empty table (presumably the chunk environment,
-- so globals are unreachable), and then run under pcall.
-- NOTE(review): the input is still executed as Lua code (it can, e.g., loop
-- forever), so only use this on trusted data.
-- @tparam string str serialized data
-- @return the value, or nil and an error message on failure
local function deserialize(str)
	if type(str) ~= "string" then
		-- Report the reason, like the other failure paths below.
		return nil, "string expected, got " .. type(str);
	end
	local f, err = envload("return " .. str, "=serialized data", {});
	if not f then return nil, err; end
	local success, ret = pcall(f);
	if not success then return nil, ret; end
	return ret;
end
274
-- Serializer with all-default options, built once and reused.
local default = new();

--- Serialize a value with the default serializer, or with a new one built
-- from `opt` (an options table or preset name) when `opt` is given.
local function serialize(x, opt)
	if opt ~= nil then
		return new(opt)(x);
	end
	return default(x);
end

return {
	new = new;
	serialize = serialize;
	deserialize = deserialize;
};
287