1local AH = wesnoth.require "ai/lua/ai_helper.lua"
2local BC = wesnoth.require "ai/lua/battle_calcs.lua"
3local M = wesnoth.map
4
--- Collect the current side's units that still have moves left, restricted
-- by the optional [filter] child of cfg.
-- @tparam table cfg micro-AI configuration table
-- @treturn table list of matching units (may be empty)
local function get_wolves(cfg)
    -- wml.get_child() is the last item of the constructor, so all of its
    -- return values end up in the tag table, exactly as if inlined below.
    local and_tag = { "and", wml.get_child(cfg, "filter") }
    local pack = AH.get_units_with_moves {
        side = wesnoth.current.side,
        and_tag
    }
    return pack
end
12
--- Collect the enemies the wolves may attack, optionally restricted by the
-- [filter_second] child of cfg.
-- @tparam table cfg micro-AI configuration table
-- @treturn table list of attackable enemy units (may be empty)
local function get_prey(cfg)
    -- wml.get_child() returns more than one value, while
    -- AH.get_attackable_enemies() accepts optional extra arguments. Storing
    -- the child in a local first keeps the extra return values from being
    -- forwarded as arguments.
    local second_filter = wml.get_child(cfg, "filter_second")
    local candidates = AH.get_attackable_enemies(second_filter)
    return candidates
end
20
-- Candidate-action table returned by this module; it holds the
-- evaluation/execution methods.
local ca_wolves_move = {}

--- Score this candidate action.
-- Prey is only looked up when at least one wolf can still move.
-- @tparam table cfg micro-AI configuration table
-- @return cfg.ca_score when there is a movable wolf and attackable prey, else 0
function ca_wolves_move:evaluation(cfg)
    local can_hunt = get_wolves(cfg)[1] and get_prey(cfg)[1]
    if can_hunt then
        return cfg.ca_score
    end
    return 0
end
28
--- Decide whether execution must stop so the AI can re-evaluate.
-- Stops when the last move did not complete (per AH.is_incomplete_move, e.g.
-- an ambush) or when an event invalidated one of the wolves or the target.
-- @param move_result value returned by AH.movefull_stopunit()
-- @tparam table wolves all wolves handled by this execution
-- @param target the prey unit the wolves are converging on
-- @treturn boolean true if execution should return early
local function wolves_move_interrupted(move_result, wolves, target)
    if AH.is_incomplete_move(move_result) then return true end
    for _, wolf in ipairs(wolves) do
        if not wolf.valid then return true end
    end
    -- The remaining wolves rate hexes using target.x/target.y, so a target
    -- removed by an event must also end this execution.
    return not target.valid
end

--- Move the wolf pack toward a common prey.
-- The prey with the smallest summed distance to all wolves becomes the target.
-- The wolf farthest from the target moves first, toward it, while being
-- penalized for ending near a map edge. Every other wolf then picks a hex
-- based on two things: its distances to the wolves already moved, and how
-- closely its distance to the target matches the first wolf's. Hexes that
-- units of type cfg.avoid_type can reach (via BC.get_attack_map) get a -1000
-- rating penalty.
-- Returns early if a move is incomplete or a wolf or the target becomes
-- invalid.
-- @tparam table cfg micro-AI configuration table
function ca_wolves_move:execution(cfg)
    local wolves = get_wolves(cfg)
    local prey = get_prey(cfg)
    -- evaluation() normally guarantees both lists are non-empty. Guard anyway
    -- so that an empty prey list cannot leave `target` nil below.
    if (not wolves[1]) or (not prey[1]) then return end

    local avoid_units = AH.get_attackable_enemies({ type = cfg.avoid_type })
    local avoid_map = BC.get_attack_map(avoid_units).units

    -- Find the prey closest to the wolves (smallest summed distance)
    local min_dist, target = math.huge, nil
    for _, prey_unit in ipairs(prey) do
        local dist = 0
        for _, wolf in ipairs(wolves) do
            dist = dist + M.distance_between(wolf.x, wolf.y, prey_unit.x, prey_unit.y)
        end
        if (dist < min_dist) then
            min_dist, target = dist, prey_unit
        end
    end

    -- Sort wolves from farthest to closest to the target
    table.sort(wolves, function(a, b)
        return M.distance_between(a.x, a.y, target.x, target.y) > M.distance_between(b.x, b.y, target.x, target.y)
    end)

    -- The first wolf moves toward the target. A hex within 5 of a map edge is
    -- penalized, and the penalty grows the closer it is to that edge.
    local width, height = wesnoth.get_map_size()
    local wolf1 = AH.find_best_move(wolves[1], function(x, y)
        local rating = - M.distance_between(x, y, target.x, target.y)
        if (x <= 5) then rating = rating - (6 - x) / 1.4 end
        if (y <= 5) then rating = rating - (6 - y) / 1.4 end
        if (width - x <= 5) then rating = rating - (6 - (width - x)) / 1.4 end
        if (height - y <= 5) then rating = rating - (6 - (height - y)) / 1.4 end

        -- Hexes that avoid_type units can reach get a massive penalty
        if avoid_map:get(x, y) then rating = rating - 1000 end

        return rating
    end)

    local move_result = AH.movefull_stopunit(ai, wolves[1], wolf1)
    -- If the wolf was ambushed, or an event removed a wolf or the target,
    -- return and let the AI reconsider
    if wolves_move_interrupted(move_result, wolves, target) then return end

    for i = 2, #wolves do
        -- Wolf 1's distance to the target is constant while this wolf's hexes
        -- are rated, so compute it once per wolf instead of once per hex
        local dist_1t = M.distance_between(wolf1[1], wolf1[2], target.x, target.y)

        -- `local` is required here; without it `move` leaks into the globals
        local move = AH.find_best_move(wolves[i], function(x, y)
            local rating = 0

            -- The preferred distance to wolf j is 2.7 * j hexes. The penalty
            -- for deviating is divided by j, so the constraint weakens with
            -- increasing wolf number.
            for j = 1, i - 1 do
                local dst = M.distance_between(x, y, wolves[j].x, wolves[j].y)
                rating = rating - (dst - 2.7 * j)^2 / j
            end

            -- Prefer the same distance to the target as wolf 1
            local dist_t = M.distance_between(x, y, target.x, target.y)
            rating = rating - (dist_t - dist_1t)^2

            -- Hexes that avoid_type units can reach get a massive penalty
            if avoid_map:get(x, y) then rating = rating - 1000 end

            return rating
        end)

        move_result = AH.movefull_stopunit(ai, wolves[i], move)
        -- Same early exit as after the first wolf's move
        if wolves_move_interrupted(move_result, wolves, target) then return end
    end
end
106
-- Module result: the candidate-action table defining :evaluation and :execution
return ca_wolves_move
108