1local logger = require "rspamd_logger"
2local telescope = require "telescope"
3local util  = require 'lua_util'
4
--- Predicate for the `rspamd_eq` assertion.
-- @param tbl table carrying the `expect` and `actual` fields to compare
-- @return true when both fields compare equal with `==`, false otherwise
local function rspamd_assert_equals(tbl)
  local expected, actual = tbl.expect, tbl.actual
  if expected == actual then
    return true
  end
  return false
end
8
--- Builds the failure text shown when an equality assertion does not hold.
-- The first argument is ignored; only the table with the compared values
-- is used.
-- @param tbl table carrying the `actual` and `expect` fields
-- @return the result of `logger.slog`, with `actual` substituted for %1
--         and `expect` for %2
local function rspamd_assert_equals_msg(_, tbl)
  local template =
    "Failed asserting that \n  (actual)   : %1 \n equals to\n  (expected) : %2"
  local actual, expected = tbl.actual, tbl.expect
  return logger.slog(template, actual, expected)
end
15
--- Predicate for the `rspamd_table_eq` assertion.
-- Delegates the comparison of `tbl.expect` and `tbl.actual` to
-- `util.table_cmp` and returns whatever it returns.
-- @param tbl table carrying the `expect` and `actual` fields
local function rspamd_assert_table_equals(tbl)
  local expected, actual = tbl.expect, tbl.actual
  return util.table_cmp(expected, actual)
end
19
--- Predicate for the `rspamd_table_eq_sorted` assertion.
-- Both inputs are passed through `util.deepcopy` and the copies through
-- `util.deepsort` before being compared with `util.table_cmp`.
-- NOTE(review): presumably the copies exist so that sorting does not
-- reorder the caller's tables -- confirm against lua_util.
-- @param tbl table carrying the `expect` and `actual` fields
-- @return the result of `util.table_cmp` on the two sorted copies
local function rspamd_assert_table_equals_sorted(tbl)
  local expected_copy, actual_copy =
    util.deepcopy(tbl.expect), util.deepcopy(tbl.actual)

  util.deepsort(expected_copy)
  util.deepsort(actual_copy)

  return util.table_cmp(expected_copy, actual_copy)
end
27
--- Collects all keys of a table into a sorted array.
-- Keys of different types cannot be compared with `<` (a plain `table.sort`
-- raises "attempt to compare ..." for e.g. `{1, 2, foo = 'x'}`), so the
-- comparator orders by type name first, then by value for numbers and
-- strings, and by `tostring` for every other key type. For tables whose
-- keys are all numbers or all strings the result is the same as with the
-- default sort order.
-- @param t table whose keys are collected
-- @return array of the keys of `t` in ascending order
local function table_keys_sorted(t)
  local keys = {}

  for k in pairs(t) do
    keys[#keys + 1] = k
  end

  table.sort(keys, function(a, b)
    local type_a, type_b = type(a), type(b)
    if type_a ~= type_b then
      return type_a < type_b
    end
    if type_a == 'number' or type_a == 'string' then
      return a < b
    end
    -- booleans, tables, functions, ...: no native ordering, use the text form
    return tostring(a) < tostring(b)
  end)

  return keys
end
37
--- Renders the diff text for a single key.
-- When both values are equal a single unchanged line is produced. Otherwise
-- a "-" line is emitted for the expected value and a "+" line for the actual
-- value, each only if that value is present (not nil).
-- @param level nesting depth, used to compute the indentation
-- @param key the table key being reported
-- @param v_expect expected value, or nil when the key exists only in `actual`
-- @param v_actual actual value, or nil when the key exists only in `expect`
-- @return the formatted line(s); "-" and "+" lines are joined with "\n"
local function format_line(level, key, v_expect, v_actual)
  if v_expect == v_actual then
    local prefix = string.rep(' ', level * 2 + 1)
    return logger.slog("%s[%s] = %s", prefix, key, v_expect)
  end

  -- One space less than for unchanged lines: the leading "-"/"+" marker
  -- takes up the first column.
  local prefix = string.rep(' ', level * 2)
  local ret = {}

  -- Test against nil explicitly: `false` is a legitimate table value and
  -- must still be reported, only nil means "no value on this side".
  if v_expect ~= nil then
    ret[#ret + 1] = logger.slog("-%s[%s] = %s: %s", prefix, key,
        type(v_expect), v_expect)
  end
  if v_actual ~= nil then
    ret[#ret + 1] = logger.slog("+%s[%s] = %s: %s", prefix,
        key, type(v_actual), v_actual)
  end

  return table.concat(ret, "\n")
end
57
--- Renders the opening line of a nested table in the diff output.
-- @param level nesting depth; the line is indented by `level * 2 + 1` spaces
-- @param key key under which the nested table is stored (passed through
--            `tostring`)
-- @return string of the form "<indent>[<key>] = {"
local function format_table_begin(level, key)
  local indent = (' '):rep(level * 2 + 1)
  local key_text = tostring(key)
  return ("%s[%s] = {"):format(indent, key_text)
end
62
--- Renders the closing line of a nested table in the diff output.
-- @param level nesting depth; the brace is indented by `level * 2 + 1` spaces
-- @return string of the form "<indent>}"
local function format_table_end(level)
  local indent = (' '):rep(level * 2 + 1)
  return ("%s}"):format(indent)
end
67
--- Builds the failure message for the table assertions: the generic
-- "expected/actual" text followed by a key-by-key diff of the two tables.
-- Lines starting with "-" show entries of `tbl.expect`, lines starting with
-- "+" show entries of `tbl.actual`; unchanged entries carry no marker.
-- The first argument is only forwarded to `rspamd_assert_equals_msg`.
-- @param tbl table carrying the `expect` and `actual` fields
-- @return the complete failure message as a string
local function rspamd_assert_table_diff_msg(_, tbl)
  local msg = rspamd_assert_equals_msg(_, tbl)

  -- A key-by-key diff needs two tables. If either side is something else
  -- (e.g. nil), `pairs` would raise while building the message, so fall
  -- back to the plain message that already shows both values.
  if type(tbl.expect) ~= 'table' or type(tbl.actual) ~= 'table' then
    return msg
  end

  -- `actual` tables already visited, so self-referencing tables terminate
  local avoid_loops = {}
  local diff = {}

  -- Strict ordering over keys of arbitrary type: by type name first, then
  -- by value for numbers and strings, by `tostring` otherwise. For keys of
  -- one comparable type this equals the default `<` order.
  local function key_less(a, b)
    local type_a, type_b = type(a), type(b)
    if type_a ~= type_b then
      return type_a < type_b
    end
    if type_a == 'number' or type_a == 'string' then
      return a < b
    end
    return tostring(a) < tostring(b)
  end

  local function recurse(expect, actual, level)
    if avoid_loops[actual] then
      return
    end
    avoid_loops[actual] = true

    local keys_expect = table_keys_sorted(expect)
    local keys_actual = table_keys_sorted(actual)
    local n_expect, n_actual = #keys_expect, #keys_actual

    -- Walk both sorted key lists in parallel using numeric indices; the
    -- traversal order of `next` is not guaranteed by the language.
    local i_expect, i_actual = 1, 1

    while i_expect <= n_expect and i_actual <= n_actual do
      local k_expect = keys_expect[i_expect]
      local k_actual = keys_actual[i_actual]
      local v_expect = expect[k_expect]
      local v_actual = actual[k_actual]

      if k_expect == k_actual then
        -- table keys are the same: compare values
        if type(v_expect) == 'table' and type(v_actual) == 'table' then
          if util.table_cmp(v_expect, v_actual) then
            -- equal subtables: pass the same value for both sides so that
            -- a single unchanged line is printed
            diff[#diff + 1] = format_line(level, k_expect, v_expect, v_expect)
          else
            diff[#diff + 1] = format_table_begin(level, k_expect)
            recurse(v_expect, v_actual, level + 1)
            diff[#diff + 1] = format_table_end(level)
          end
        else
          diff[#diff + 1] = format_line(level, k_expect, v_expect, v_actual)
        end

        i_expect = i_expect + 1
        i_actual = i_actual + 1
      elseif key_less(k_expect, k_actual) then
        -- The keys (not the values) decide which side is behind: the
        -- smaller key cannot appear later in the other sorted list, so it
        -- exists only in `expect`.
        diff[#diff + 1] = format_line(level, k_expect, v_expect, nil)
        i_expect = i_expect + 1
      else
        -- key exists only in `actual`
        diff[#diff + 1] = format_line(level, k_actual, nil, v_actual)
        i_actual = i_actual + 1
      end
    end

    -- leftovers: keys present on one side only
    for i = i_expect, n_expect do
      local k_expect = keys_expect[i]
      diff[#diff + 1] = format_line(level, k_expect, expect[k_expect], nil)
    end

    for i = i_actual, n_actual do
      local k_actual = keys_actual[i]
      diff[#diff + 1] = format_line(level, k_actual, nil, actual[k_actual])
    end
  end
  recurse(tbl.expect, tbl.actual, 0)

  return string.format("%s\n===== diff (-expect, +actual) ======\n%s", msg, table.concat(diff, "\n"))
end
132
-- Custom assertions registered with telescope. Each entry is
-- { assertion name, failure-message builder, predicate } and is passed to
-- telescope.make_assertion in that order.
-- Alternative for "rspamd_table_eq" with the plain message (no diff):
--   { "rspamd_table_eq", rspamd_assert_equals_msg, rspamd_assert_table_equals }
local rspamd_assertions = {
  { "rspamd_eq",              rspamd_assert_equals_msg,     rspamd_assert_equals },
  { "rspamd_table_eq",        rspamd_assert_table_diff_msg, rspamd_assert_table_equals },
  { "rspamd_table_eq_sorted", rspamd_assert_table_diff_msg, rspamd_assert_table_equals_sorted },
}

for _, spec in ipairs(rspamd_assertions) do
  telescope.make_assertion(spec[1], spec[2], spec[3])
end
138
139