1-- Prosody IM
2-- Copyright (C) 2008-2010 Matthew Wild
3-- Copyright (C) 2008-2010 Waqas Hussain
4--
5-- This project is MIT/X11 licensed. Please see the
6-- COPYING file in the source package for more information.
7--
8
9local ipairs, pairs, getmetatable, setmetatable, next, tostring =
10      ipairs, pairs, getmetatable, setmetatable, next, tostring;
11local t_concat = table.concat;
12
13local _ENV = nil;
14-- luacheck: std none
15
-- Metatable attached to every set object returned by new().
local set_mt = { __name = "set" };

-- Calling a set steps through its backing item table with next(). The second
-- argument is ignored; the third is the previously returned item. This is
-- what allows `for item in some_set do ... end`.
set_mt.__call = function(self, _, previous)
	return next(self._items, previous);
end
20
-- Metatable for the table that stores a set's items.
local items_mt = {};

-- Calling the item table advances through its own keys with next(), so the
-- table can be used directly as the iterator of a generic 'for' loop. The
-- second argument is ignored; the third is the previously returned key.
items_mt.__call = function(self, _, previous)
	return next(self, previous);
end
25
-- Produce a plain array holding every item of the set.
-- The order of the array follows next() traversal and is therefore unspecified.
function set_mt:__freeze()
	local list, count = {}, 0;
	-- The item table is itself callable, so it serves as the loop iterator.
	for item in self._items do
		count = count + 1;
		list[count] = item;
	end
	return list;
end
33
-- Report whether 'o' is a set, i.e. whether its metatable is set_mt.
-- Returns a boolean.
local function is_set(o)
	return getmetatable(o) == set_mt;
end
38
-- Create a new set.
--
-- list: optional array; every value visited by ipairs(list) becomes a member.
--
-- Returns a set object. Its members are stored as keys (with value true) of
-- the table exposed as `set._items`, and it carries the methods defined below.
local function new(list)
	local items = setmetatable({}, items_mt);
	local set = { _items = items };

	-- We access the set through an upvalue in these methods, so ignore 'self' being unused
	--luacheck: ignore 212/self

	-- Add a single item to the set.
	function set:add(item)
		items[item] = true;
	end

	-- Membership test: returns true if the item is present, nil otherwise.
	function set:contains(item)
		return items[item];
	end

	-- Return an iterator/state pair, for use as `for item in set:items() do`.
	function set:items()
		return next, items;
	end

	-- Remove a single item; removing an absent item does nothing.
	function set:remove(item)
		items[item] = nil;
	end

	-- Add every value of an array (as visited by ipairs). A nil/false
	-- argument is accepted and ignored.
	function set:add_list(item_list)
		if item_list then
			for _, item in ipairs(item_list) do
				items[item] = true;
			end
		end
	end

	-- Add every item produced by iterating 'otherset' (which must be usable
	-- as the iterator of a generic 'for', as sets are).
	function set:include(otherset)
		for item in otherset do
			items[item] = true;
		end
	end

	-- Remove every item produced by iterating 'otherset'.
	function set:exclude(otherset)
		for item in otherset do
			items[item] = nil;
		end
	end

	-- Return true if the set has no members, false otherwise.
	function set:empty()
		-- Compare against nil instead of using 'not': false is a legal member,
		-- and next() would return it as the first key, which 'not' would
		-- misread as "no keys at all".
		return next(items) == nil;
	end

	if list then
		set:add_list(list);
	end

	return setmetatable(set, set_mt);
end
92
-- Build a new set containing every item found in set1 or in set2.
-- Neither input is modified.
local function union(set1, set2)
	local result = new();
	local merged = result._items;

	-- Mark every key of 'source' as a member of the result.
	local function absorb(source)
		for item in pairs(source) do
			merged[item] = true;
		end
	end

	absorb(set1._items);
	absorb(set2._items);

	return result;
end
107
-- Build a new set containing the items of set1 that are not in set2.
-- Neither input is modified.
local function difference(set1, set2)
	local result = new();
	local kept = result._items;

	for item in pairs(set1._items) do
		if not set2._items[item] then
			kept[item] = true;
		end
	end

	return result;
end
118
-- Build a new set containing only the items present in both set1 and set2.
-- Neither input is modified.
local function intersection(set1, set2)
	local result = new();
	local shared = result._items;

	local left, right = set1._items, set2._items;

	for item in pairs(left) do
		if right[item] then
			shared[item] = true;
		end
	end

	return result;
end
131
-- Symmetric difference: a new set of the items that are in exactly one of
-- set1 and set2, computed as (set1 union set2) minus (set1 intersect set2).
local function xor(set1, set2)
	local in_either = union(set1, set2);
	local in_both = intersection(set1, set2);
	-- The subtraction dispatches through the set's __sub metamethod.
	return in_either - in_both;
end
135
-- `a + b` on sets yields their union as a new set.
set_mt.__add = function(lhs, rhs)
	return union(lhs, rhs);
end
-- `a - b` on sets yields a new set with the items of a that are not in b.
set_mt.__sub = function(lhs, rhs)
	return difference(lhs, rhs);
end
-- `set / func` maps and filters: func is called once per item, and each
-- non-nil result becomes a member of the new set that is returned. Items for
-- which func returns nil are dropped.
function set_mt.__div(set, func)
	local result = new();
	local target = result._items;
	for original in pairs(set._items) do
		local mapped = func(original);
		if mapped ~= nil then
			target[mapped] = true;
		end
	end
	return result;
end
-- `a == b` on sets: true when both operands are sets and each contains every
-- item of the other; false otherwise.
function set_mt.__eq(set1, set2)
	if getmetatable(set1) ~= set_mt or getmetatable(set2) ~= set_mt then
		-- Lua 5.3+ tries __eq whenever both operands are tables, even if only
		-- one of them is a set. A non-set has no usable _items table, so
		-- answer "not equal" instead of raising an error below.
		return false;
	end

	set1, set2 = set1._items, set2._items;

	-- Every item of the first set must be in the second...
	for item in pairs(set1) do
		if not set2[item] then
			return false;
		end
	end

	-- ...and every item of the second must be in the first.
	for item in pairs(set2) do
		if not set1[item] then
			return false;
		end
	end

	return true;
end
-- Render the set as its items' tostring() forms joined by ", ".
-- Item order follows pairs() and is therefore unspecified.
function set_mt.__tostring(set)
	local parts, count = {}, 0;
	for item in pairs(set._items) do
		count = count + 1;
		parts[count] = tostring(item);
	end
	return t_concat(parts, ", ");
end
176
-- Public interface of the module.
return {
	new = new,
	is_set = is_set,
	union = union,
	difference = difference,
	intersection = intersection,
	xor = xor,
};
185