1local H = wesnoth.require "helper"
2local AH = wesnoth.require "ai/lua/ai_helper.lua"
3local LS = wesnoth.require "location_set"
4local M = wesnoth.map
5
6-- This is a collection of Lua functions used for custom AI development.
7-- Note that this is still work in progress with significant changes occurring
8-- frequently. Backward compatibility cannot be guaranteed at this time in
9-- development releases, but it is of course easily possible to copy a function
10-- from a previous release directly into an add-on if it is needed there.
11
local battle_calcs = {}

function battle_calcs.unit_attack_info(unit, cache)
    -- Gather the attack-related properties of @unit into one table.
    -- Reading unit.__cfg is slow, so the result is stored in @cache (if given)
    -- and handed back directly on later calls for the same unit.
    --
    -- Fields of the returned table:
    --  - attacks: one table per [attack] of the unit, holding its scalar (number/string)
    --      keys plus a boolean flag for each weapon special found
    --  - resist_mod: resistance multipliers (wesnoth.unit_resistance / 100), keyed by damage type
    --  - alignment: the unit's alignment, copied from unit.__cfg

    -- Cache key: id + max_hitpoints (the unit can level up) + side (MP leaders
    -- sometimes share the same id when the game is started from the command line)
    local key = 'UI-' .. unit.id .. unit.max_hitpoints .. unit.side

    local cached = cache and cache[key]
    if cached then return cached end

    local cfg = unit.__cfg
    local info = { attacks = {}, resist_mod = {}, alignment = cfg.alignment }

    for attack_cfg in wml.child_range(cfg, 'attack') do
        local attack = {}

        -- Weapon specials are recorded before the scalar keys, so that an [attack] key
        -- with the same name as a (custom) special overrides the special's flag
        for specials in wml.child_range(attack_cfg, 'specials') do
            for _, special in ipairs(specials) do
                local tag = special[1]
                if (tag == 'damage') then
                    -- backstab and charge are both [damage] specials, told apart by their id
                    local id = special[2].id
                    if (id == 'backstab') then
                        attack.backstab = true
                    elseif (id == 'charge') then
                        attack.charge = true
                    end
                elseif (tag == 'chance_to_hit') then
                    -- magical, marksman
                    attack[special[2].id or 'no_id'] = true
                else
                    attack[tag] = true
                end
            end
        end

        for key_name, value in pairs(attack_cfg) do
            local value_type = type(value)
            if (value_type == 'number') or (value_type == 'string') then
                attack[key_name] = value
            end
        end

        -- [attack]number= defaults to zero; battle_calcs.best_weapons() needs it to be set
        attack.number = attack.number or 0

        info.attacks[#info.attacks + 1] = attack
    end

    for _, damage_type in ipairs({ "arcane", "blade", "cold", "fire", "impact", "pierce" }) do
        info.resist_mod[damage_type] = wesnoth.unit_resistance(unit, damage_type) / 100.
    end

    if cache then cache[key] = info end

    return info
end
86
-- Scale @base_damage by @multiplier and round the result the way Wesnoth does:
-- .5 values are rounded towards the base damage (down if the multiplier is greater
-- than 1, up otherwise), and a non-zero base damage is never reduced below 1.
local function round_strike_damage(base_damage, multiplier)
    if (base_damage == 0) then return 0 end

    local damage
    if (multiplier > 1) then
        damage = H.round(base_damage * multiplier - 0.001)
    else
        damage = H.round(base_damage * multiplier + 0.001)
    end

    -- Resistances can push the scaled value below 0.5, but a hit always does at least 1 damage
    return math.max(1, damage)
end

function battle_calcs.strike_damage(attacker, defender, att_weapon, def_weapon, dst, cache)
    -- Return the single strike damage of an attack by @attacker on @defender
    -- Also returns the attack tables of both weapons (since we're accessing the information already anyway)
    -- Here, @att_weapon and @def_weapon are the weapon numbers in Lua counts, i.e., counts start at 1
    -- If @def_weapon = 0, return 0 for defender damage (and nil for the defender attack table)
    -- This can be used for defenders that do not have the right kind of weapon, or if
    -- only the attacker damage is of interest
    -- @dst: attack location, to take terrain time of day, illumination etc. into account
    -- For the defender, the current location is assumed
    --
    -- 'cache' can be given to cache strike damage and to pass through to battle_calcs.unit_attack_info()
    --
    -- Returns: att_damage, def_damage, att_attack, def_attack

    -- Set up a cache index. We use id+max_hitpoints+side for each unit, since the
    -- unit can level up.
    -- Also need to add the weapons and lawful_bonus values for each unit
    local att_lawful_bonus = wesnoth.get_time_of_day({ dst[1], dst[2], true }).lawful_bonus
    local def_lawful_bonus = wesnoth.get_time_of_day({ defender.x, defender.y, true }).lawful_bonus

    local cind = 'SD-' .. attacker.id .. attacker.max_hitpoints .. attacker.side
    cind = cind .. 'x' .. defender.id .. defender.max_hitpoints .. defender.side
    cind = cind .. '-' .. att_weapon .. 'x' .. def_weapon
    cind = cind .. '-' .. att_lawful_bonus .. 'x' .. def_lawful_bonus

    -- If cache for this combination exists, return it
    if cache and cache[cind] then
        return cache[cind].att_damage, cache[cind].def_damage, cache[cind].att_attack, cache[cind].def_attack
    end

    local attacker_info = battle_calcs.unit_attack_info(attacker, cache)
    local defender_info = battle_calcs.unit_attack_info(defender, cache)
    local att_attack = attacker_info.attacks[att_weapon]
    local def_attack = defender_info.attacks[def_weapon]  -- nil if def_weapon == 0

    -- Attacker base damage, opponent resistance modifier and TOD modifier
    local att_damage = att_attack.damage
    local att_multiplier = defender_info.resist_mod[att_attack.type] or 1
    att_multiplier = att_multiplier * AH.get_unit_time_of_day_bonus(attacker_info.alignment, att_lawful_bonus)

    -- Now do all this for the defender, if def_weapon ~= 0
    local def_damage, def_multiplier = 0, 1.
    if (def_weapon ~= 0) then
        def_damage = def_attack.damage
        def_multiplier = attacker_info.resist_mod[def_attack.type] or 1
        def_multiplier = def_multiplier * AH.get_unit_time_of_day_bonus(defender_info.alignment, def_lawful_bonus)
    end

    -- Take 'charge' into account: it doubles the damage done by both sides
    if att_attack.charge then
        att_damage = att_damage * 2
        def_damage = def_damage * 2
    end

    att_damage = round_strike_damage(att_damage, att_multiplier)
    if (def_weapon ~= 0) then
        def_damage = round_strike_damage(def_damage, def_multiplier)
    end

    if cache then
        cache[cind] = {
            att_damage = att_damage,
            def_damage = def_damage,
            att_attack = att_attack,
            def_attack = def_attack
        }
    end

    return att_damage, def_damage, att_attack, def_attack
end
173
function battle_calcs.best_weapons(attacker, defender, dst, cache)
    -- Return the numbers (indices) of the best weapons for @attacker and @defender
    -- @dst: attack location, to take terrain time of day, illumination etc. into account
    -- For the defender, the current location is assumed
    -- Ideally, we would do a full attack_rating here for all combinations,
    -- but that would take too long. So we simply define the best weapons
    -- as those that have the biggest difference between
    -- damage done and damage received (the latter divided by 2)
    -- Returns 0 for the defender weapon if the defender does not have a weapon for this range,
    -- and 0 for both if the attacker has no weapons at all
    --
    -- 'cache' can be given to cache best weapons

    -- Set up a cache index. We use id+max_hitpoints+side for each unit, since the
    -- unit can level up.
    -- Also need to add the lawful_bonus values for each unit
    local att_lawful_bonus = wesnoth.get_time_of_day({ dst[1], dst[2], true }).lawful_bonus
    local def_lawful_bonus = wesnoth.get_time_of_day({ defender.x, defender.y, true }).lawful_bonus

    local cind = 'BW-' .. attacker.id .. attacker.max_hitpoints .. attacker.side
    cind = cind .. 'x' .. defender.id .. defender.max_hitpoints .. defender.side
    cind = cind .. '-' .. att_lawful_bonus .. 'x' .. def_lawful_bonus

    -- If cache for this combination exists, return it
    if cache and cache[cind] then
        return cache[cind].best_att_weapon, cache[cind].best_def_weapon
    end

    local attacker_info = battle_calcs.unit_attack_info(attacker, cache)
    local defender_info = battle_calcs.unit_attack_info(defender, cache)

    local max_rating, best_att_weapon, best_def_weapon = -9e99, 0, 0
    for att_weapon_number,att_weapon in ipairs(attacker_info.attacks) do
        local att_damage = battle_calcs.strike_damage(attacker, defender, att_weapon_number, 0, { dst[1], dst[2] }, cache)

        -- Best defender weapon against this attacker weapon (same range only)
        local max_def_rating, tmp_best_def_weapon = -9e99, 0
        for def_weapon_number,def_weapon in ipairs(defender_info.attacks) do
            if (def_weapon.range == att_weapon.range) then
                -- The counterattack damage is calculated for this exact attacker/defender weapon
                -- pair, so that 'charge' is applied only if the attacker's weapon has it
                -- (it doubles the counterattack damage, but has no effect on defense)
                local _, def_damage = battle_calcs.strike_damage(attacker, defender, att_weapon_number, def_weapon_number, { dst[1], dst[2] }, cache)
                local def_rating = def_damage * def_weapon.number
                if (def_rating > max_def_rating) then
                    max_def_rating, tmp_best_def_weapon = def_rating, def_weapon_number
                end
            end
        end

        local rating = att_damage * att_weapon.number
        if (max_def_rating > -9e99) then rating = rating - max_def_rating / 2. end

        if (rating > max_rating) then
            max_rating, best_att_weapon, best_def_weapon = rating, att_weapon_number, tmp_best_def_weapon
        end
    end

    if cache then
        cache[cind] = { best_att_weapon = best_att_weapon, best_def_weapon = best_def_weapon }
    end

    return best_att_weapon, best_def_weapon
end
233
function battle_calcs.add_next_strike(cfg, arr, n_att, n_def, att_strike, hit_miss_counts, hit_miss_str)
    -- Recursive function that sets up the sequences of strikes (misses and hits)
    -- Each call corresponds to one strike of one of the combatants and can be
    -- either a miss (value 0) or a hit (1)
    --
    -- Inputs:
    -- - @cfg: config table with sub-tables att/def for the attacker/defender with the following fields:
    --   - strikes: total number of strikes (may be 0)
    --   - max_hits: maximum number of hits the unit can survive
    --   - firststrike: set to true if attack has firststrike special
    -- - @arr: an empty array that will hold the output table. Each element is of form
    --     { hit_miss_str = string, hit_miss_counts = { att misses, att hits, def misses, def hits } }
    -- - Other parameters are for recursion purposes only and are initialized below
    --
    -- If neither unit has any strikes, a single element with all counts equal to 0 is added.

    -- On the first call of this function, initialize variables
    -- Counts for misses/hits by both units:
    --  - Indices 1 & 2: miss/hit for attacker
    --  - Indices 3 & 4: miss/hit for defender
    hit_miss_counts = hit_miss_counts or { 0, 0, 0, 0 }
    hit_miss_str = hit_miss_str or ''  -- string with the hit/miss sequence; for visualization only

    -- Strike counts
    --  - n_att/n_def = number of strikes taken by attacker/defender, including the current one
    --  - att_strike: if true, the current strike is the attacker's, otherwise it's the defender's
    if (not n_att) then
        -- Nothing to simulate if neither unit can strike
        if (cfg.att.strikes + cfg.def.strikes <= 0) then
            table.insert(arr, { hit_miss_str = hit_miss_str, hit_miss_counts = hit_miss_counts })
            return
        end

        -- The defender goes first if it has firststrike (and the attacker does not),
        -- but a unit without any strikes can never take the first strike
        local def_first = cfg.def.firststrike and (not cfg.att.firststrike)
        if (cfg.att.strikes <= 0) or (def_first and (cfg.def.strikes > 0)) then
            n_att, n_def, att_strike = 0, 1, false
        else
            n_att, n_def, att_strike = 1, 0, true
        end
    else
        -- Strikes alternate; once one unit has used all of its strikes,
        -- the other one continues with its remaining strikes
        if att_strike then
            if (n_def < cfg.def.strikes) then
                n_def = n_def + 1
                att_strike = false
            else
                n_att = n_att + 1
            end
        else
            if (n_att < cfg.att.strikes) then
                n_att = n_att + 1
                att_strike = true
            else
                n_def = n_def + 1
            end
        end
    end

    -- Create both a miss and a hit
    for i = 0,1 do  -- 0: miss, 1: hit
        -- hit/miss counts and string for this call
        local tmp_hmc = AH.table_copy(hit_miss_counts)
        local tmp_hmstr

        -- Flag whether the opponent was killed by this strike
        local killed_opp = false
        if att_strike then
            tmp_hmstr = hit_miss_str .. i  -- attacker miss/hit in string: 0 or 1
            tmp_hmc[i+1] = tmp_hmc[i+1] + 1  -- Increment attacker miss/hit count
            if (tmp_hmc[2] > cfg.def.max_hits) then killed_opp = true end
        else
            tmp_hmstr = hit_miss_str .. (i+2)  -- defender miss/hit in string: 2 or 3
            tmp_hmc[i+3] = tmp_hmc[i+3] + 1  -- Increment defender miss/hit count
            if (tmp_hmc[4] > cfg.att.max_hits) then killed_opp = true end
        end

        -- If strikes are left and the opponent survived, call the next recursion level.
        -- Otherwise the battle is over: add this hit/miss combination to the table
        if (n_att + n_def < cfg.att.strikes + cfg.def.strikes) and (not killed_opp) then
            battle_calcs.add_next_strike(cfg, arr, n_att, n_def, att_strike, tmp_hmc, tmp_hmstr)
        else
            table.insert(arr, { hit_miss_str = tmp_hmstr, hit_miss_counts = tmp_hmc })
        end
    end
end
316
-- Exponent field names used in the coefficient tables, indexed the same way as the
-- hit/miss count arrays of battle_calcs.add_next_strike():
-- 1: attacker misses, 2: attacker hits, 3: defender misses, 4: defender hits
local hit_miss_exponent_keys = { 'am', 'ah', 'dm', 'dh' }

-- Build the coefficient table for the HP distribution of one of the units.
-- @hit_miss_counts: output array of battle_calcs.add_next_strike()
-- @group_miss_i/@group_hit_i: indices (1-4) of the opponent's misses/hits, i.e. of the strikes
--   aimed at this unit. The result is indexed by the number of those hits.
-- @sum_miss_i/@sum_hit_i: indices of this unit's own misses/hits, which are summed over
-- The entry with the most terms gets the flag skip = true.
local function hp_outcome_coefficients(hit_miss_counts, group_miss_i, group_hit_i, sum_miss_i, sum_hit_i)
    -- 'counts' is an array 4 layers deep: counts[group misses][group hits][summed misses][summed hits]
    -- The element value is the number of times we get the given combination of hits/misses
    -- This is so that they can be grouped by number of opponent hits/misses, for
    -- subsequent simplification
    local counts = {}
    for _,count in ipairs(hit_miss_counts) do
        local hmc = count.hit_miss_counts
        local i1, i2 = hmc[group_miss_i], hmc[group_hit_i]
        local i3, i4 = hmc[sum_miss_i], hmc[sum_hit_i]
        if not counts[i1] then counts[i1] = {} end
        if not counts[i1][i2] then counts[i1][i2] = {} end
        if not counts[i1][i2][i3] then counts[i1][i2][i3] = {} end
        counts[i1][i2][i3][i4] = (counts[i1][i2][i3][i4] or 0) + 1
    end

    local group_miss_key, group_hit_key = hit_miss_exponent_keys[group_miss_i], hit_miss_exponent_keys[group_hit_i]
    local sum_miss_key, sum_hit_key = hit_miss_exponent_keys[sum_miss_i], hit_miss_exponent_keys[sum_hit_i]

    -- Two very different hit probabilities used to test whether terms can be combined
    local hp1, hp2 = 0.6, 0.137

    local coeffs = {}
    for gm,v1 in pairs(counts) do  -- opponent miss count
        for gh,v2 in pairs(v1) do  -- opponent hit count
            -- Exponent coefficients for the opponent's hits/misses
            -- Only populate those indices that have exponents > 0
            local exp = {}
            if (gm > 0) then exp[group_miss_key] = gm end
            if (gh > 0) then exp[group_hit_key] = gh end

            -- We combine results by testing whether they produce the same sum with both
            -- hit probabilities. This only happens if the terms add up to a multiple of 1
            local sum1, sum2 = 0, 0
            for sm,v3 in pairs(v2) do  -- own miss count
                for sh,num in pairs(v3) do  -- own hit count
                    sum1 = sum1 + num * hp1^sh * (1 - hp1)^sm
                    sum2 = sum2 + num * hp2^sh * (1 - hp2)^sm
                end
            end

            -- Coefficients are collected per number of hits landed on this unit; all of them
            -- need to be added up to get the probability of receiving this number of hits
            if (not coeffs[gh]) then coeffs[gh] = {} end

            if (math.abs(sum1 - sum2) < 1e-9) then
                -- The own strike terms sum to a constant, so they can be combined into one term
                exp.num = sum1
                table.insert(coeffs[gh], exp)
            else
                -- Otherwise all terms need to be calculated one by one
                for sm,v3 in pairs(v2) do  -- own miss count
                    for sh,num in pairs(v3) do  -- own hit count
                        local tmp_exp = AH.table_copy(exp)
                        tmp_exp.num = num
                        if (sm > 0) then tmp_exp[sum_miss_key] = sm end
                        if (sh > 0) then tmp_exp[sum_hit_key] = sh end
                        table.insert(coeffs[gh], tmp_exp)
                    end
                end
            end
        end
    end

    -- The probability for the number of hits with the most terms can be skipped
    -- and 1-sum(other_terms) can be used instead. Set a flag for which term to skip
    local max_number, biggest_equation = 0, -1
    for hits,terms in pairs(coeffs) do
        if (#terms > max_number) then
            max_number, biggest_equation = #terms, hits
        end
    end
    coeffs[biggest_equation].skip = true

    return coeffs
end

function battle_calcs.battle_outcome_coefficients(cfg, cache)
    -- Determine the coefficients needed to calculate the hitpoint probability distribution
    -- of a given battle
    -- Inputs:
    -- - @cfg: config table with sub-tables att/def for the attacker/defender with the following fields:
    --   - strikes: total number of strikes
    --   - max_hits: maximum number of hits the unit can survive
    --   - firststrike: whether the unit has firststrike weapon special on this attack
    -- The result can be cached if variable 'cache' is given
    --
    -- Output: coeffs_att, coeffs_def: tables with the coefficients needed to calculate the
    -- HP distribution of the attacker and the defender, respectively.
    -- First index: number of hits landed on that unit. Each of those contains an array of
    -- coefficient tables, of format:
    -- { num = value, am = value, ah = value, dm = value, dh = value }
    -- This gives one term in a sum of form:
    -- num * ahp^ah * (1-ahp)^am * dhp^dh * (1-dhp)^dm
    -- where ahp is the probability that the attacker will land a hit
    -- and dhp is the same for the defender
    -- Terms that have exponents of 0 are omitted
    -- For each unit, the number of hits with the most terms is flagged with skip = true

    -- Set up the cache id
    local cind = 'coeff-' .. cfg.att.strikes .. '-' .. cfg.att.max_hits
    if cfg.att.firststrike then cind = cind .. 'fs' end
    cind = cind .. 'x' .. cfg.def.strikes .. '-' .. cfg.def.max_hits
    if cfg.def.firststrike then cind = cind .. 'fs' end

    -- If cache for this configuration exists, return it
    if cache and cache[cind] then
        return cache[cind].coeffs_att, cache[cind].coeffs_def
    end

    -- Get the hit/miss counts for the battle
    local hit_miss_counts = {}
    battle_calcs.add_next_strike(cfg, hit_miss_counts)

    -- Defender HP distribution: sorted by attacker misses/hits (indices 1, 2),
    -- summed over the defender's own strikes (indices 3, 4)
    local coeffs_def = hp_outcome_coefficients(hit_miss_counts, 1, 2, 3, 4)

    -- Attacker HP distribution: sorted by defender misses/hits, summed over the attacker's strikes
    local coeffs_att = hp_outcome_coefficients(hit_miss_counts, 3, 4, 1, 2)

    if cache then cache[cind] = { coeffs_att = coeffs_att, coeffs_def = coeffs_def } end

    return coeffs_att, coeffs_def
end
506
function battle_calcs.print_coefficients(cfg)
    -- Print out the set of coefficients for a given number of attacker and defender strikes
    -- Also print numerical values for given hit probabilities
    -- This function is for debugging purposes only
    --
    -- @cfg (optional): table whose fields override the default settings:
    --   - attacker_strikes, defender_strikes: number of strikes (defaults: 3, 3)
    --   - att_hit_prob, def_hit_prob: probability of landing a hit attacker/defender (defaults: 0.8, 0.4)
    --   - attacker_coeffs: attacker coefficients if true, defender coefficients otherwise (default: true)
    --   - attacker_firststrike, defender_firststrike: firststrike flags (defaults: false, true)
    -- Calling the function without argument uses the defaults for everything.

    cfg = cfg or {}

    -- Explicit nil check, so that boolean settings can be set to false
    local function setting(key, default)
        if (cfg[key] == nil) then return default end
        return cfg[key]
    end

    local attacker_strikes = setting('attacker_strikes', 3)
    local defender_strikes = setting('defender_strikes', 3)
    local att_hit_prob = setting('att_hit_prob', 0.8)
    local def_hit_prob = setting('def_hit_prob', 0.4)
    local attacker_coeffs = setting('attacker_coeffs', true)
    local attacker_firststrike = setting('attacker_firststrike', false)
    local defender_firststrike = setting('defender_firststrike', true)

    -- Attacker coefficients are indexed by the number of hits landed on the attacker,
    -- defender coefficients by the number of hits landed on the defender
    local hits_label = 'Chance of hits on defender: '
    if attacker_coeffs then hits_label = 'Chance of hits on attacker: ' end

    -- Go through all combinations of maximum hits either attacker or defender can survive
    -- Note how this has to be crossed between ahits and defender_strikes and vice versa
    for ahits = defender_strikes,0,-1 do
        for dhits = attacker_strikes,0,-1 do
            -- Get the coefficients for this case
            local battle_cfg = {
                att = { strikes = attacker_strikes, max_hits = ahits, firststrike = attacker_firststrike },
                def = { strikes = defender_strikes, max_hits = dhits, firststrike = defender_firststrike }
            }

            local coeffs_att, coeffs_def = battle_calcs.battle_outcome_coefficients(battle_cfg)
            local coeffs = coeffs_def
            if attacker_coeffs then coeffs = coeffs_att end

            std_print()
            std_print('Attacker: ' .. battle_cfg.att.strikes .. ' strikes, can survive ' .. battle_cfg.att.max_hits .. ' hits')
            std_print('Defender: ' .. battle_cfg.def.strikes .. ' strikes, can survive ' .. battle_cfg.def.max_hits .. ' hits')
            std_print(hits_label)

            -- The first indices of coeffs are the possible numbers of hits landed on the unit
            for hits = 0,#coeffs do
                local hit_prob = 0.  -- probability for this number of hits
                local str = ''  -- output string

                local combs = coeffs[hits]  -- the combinations of coefficients to be evaluated
                for i,exp in ipairs(combs) do  -- exp: exponents (and factor) for a set
                    local prob = exp.num  -- probability for this set
                    str = str .. exp.num
                    if exp.am then
                        prob = prob * (1 - att_hit_prob) ^ exp.am
                        str = str .. ' pma^' .. exp.am
                    end
                    if exp.ah then
                        prob = prob * att_hit_prob ^ exp.ah
                        str = str .. ' pha^' .. exp.ah
                    end
                    if exp.dm then
                        prob = prob * (1 - def_hit_prob) ^ exp.dm
                        str = str .. ' pmd^' .. exp.dm
                    end
                    if exp.dh then
                        prob = prob * def_hit_prob ^ exp.dh
                        str = str .. ' phd^' .. exp.dh
                    end

                    hit_prob = hit_prob + prob  -- total probability for this number of hits landed
                    if (i ~= #combs) then str = str .. '  +  ' end
                end

                local skip_str = ''
                if combs.skip then skip_str = ' (skip)' end

                std_print(hits .. skip_str .. ':  ' .. str)
                std_print('      = ' .. hit_prob)
            end
        end
    end
end
579
function battle_calcs.hp_distribution(coeffs, att_hit_prob, def_hit_prob, starting_hp, damage, opp_attack)
    -- Multiply out the coefficients from battle_calcs.battle_outcome_coefficients()
    -- for the given attacker and defender hit probabilities @att_hit_prob and @def_hit_prob
    -- Also needed: the starting HP @starting_hp of the unit, the damage per hit @damage done
    -- by the opponent, and the opponent attack information @opp_attack (may be nil)
    --
    -- Returns a table with fields:
    --  - hp_chance: probability of each possible final HP value (hp_chance[0] is always set)
    --  - average_hp: expected final HP
    --  - poisoned/slowed: if @opp_attack has poison/slow, the probability of ending up
    --      below @starting_hp; 0 otherwise

    local stats = { hp_chance = {}, average_hp = 0 }

    -- Add probability @prob of ending up at @hp. Different numbers of hits can result in
    -- the same HP (e.g. if @damage is 0), so probabilities are summed up, not overwritten
    local function add_outcome(hp, prob)
        stats.hp_chance[hp] = (stats.hp_chance[hp] or 0) + prob
        stats.average_hp = stats.average_hp + hp * prob
    end

    -- Calculation of the outcome flagged with 'skip' (the one with the most terms) is skipped;
    -- its probability is 1 - sum(probabilities of all other outcomes)
    local skip_hp, skip_prob = nil, 1
    for hits = 0,#coeffs do
        local hp = starting_hp - hits * damage
        if (hp < 0) then hp = 0 end

        if coeffs[hits].skip then
            skip_hp = hp
        else
            local hp_prob = 0.  -- probability for this number of hits
            for _,exp in ipairs(coeffs[hits]) do  -- exp: exponents (and factor) for a set
                local prob = exp.num  -- probability for this set
                if exp.am then prob = prob * (1 - att_hit_prob) ^ exp.am end
                if exp.ah then prob = prob * att_hit_prob ^ exp.ah end
                if exp.dm then prob = prob * (1 - def_hit_prob) ^ exp.dm end
                if exp.dh then prob = prob * def_hit_prob ^ exp.dh end

                hp_prob = hp_prob + prob  -- total probability for this number of hits landed
            end

            add_outcome(hp, hp_prob)
            skip_prob = skip_prob - hp_prob
        end
    end

    -- Add in the outcome that was skipped (if the coefficients flagged one)
    if skip_hp then add_outcome(skip_hp, skip_prob) end

    -- And always set hp_chance[0] since it is of such importance in the analysis
    stats.hp_chance[0] = stats.hp_chance[0] or 0

    -- Add poison probability
    if opp_attack and opp_attack.poison then
        stats.poisoned = 1. - stats.hp_chance[starting_hp]
    else
        stats.poisoned = 0
    end

    -- Add slow probability
    if opp_attack and opp_attack.slow then
        stats.slowed = 1. - stats.hp_chance[starting_hp]
    else
        stats.slowed = 0
    end

    return stats
end
638
-- Maximum number of hits a unit with @hitpoints HP can survive against an opponent doing
-- @damage per hit with @opp_strikes strikes, capped at @opp_strikes.
-- An opponent doing no damage can never kill, so the cap is returned directly in that case
-- (dividing by 0 would give inf, or NaN for a unit with 1 HP).
local function max_survivable_hits(hitpoints, damage, opp_strikes)
    if (damage <= 0) then return opp_strikes end

    local max_hits = math.floor((hitpoints - 1) / damage)
    if (max_hits > opp_strikes) then max_hits = opp_strikes end
    return max_hits
end

function battle_calcs.battle_outcome(attacker, defender, cfg, cache)
    -- Calculate the stats of a combat by @attacker vs. @defender
    -- @cfg: optional input parameters
    --  - att_weapon/def_weapon: attacker/defender weapon number
    --      if not given, get "best" weapon (Note: both must be given, or they will both be determined)
    --  - dst: { x, y }: the attack location; defaults to { attacker.x, attacker.y }
    -- @cache: to be passed on to other functions. battle_outcome itself is not cached, too many factors enter
    --
    -- Returns att_stats, def_stats as produced by battle_calcs.hp_distribution()

    cfg = cfg or {}

    local dst = cfg.dst or { attacker.x, attacker.y }

    local att_weapon, def_weapon = 0, 0
    if (not cfg.att_weapon) or (not cfg.def_weapon) then
        att_weapon, def_weapon = battle_calcs.best_weapons(attacker, defender, dst, cache)
    else
        att_weapon, def_weapon = cfg.att_weapon, cfg.def_weapon
    end

    -- Collect all the information needed for the calculation
    -- Strike damage and numbers
    local att_damage, def_damage, att_attack, def_attack =
        battle_calcs.strike_damage(attacker, defender, att_weapon, def_weapon, { dst[1], dst[2] }, cache)

    -- The defender only strikes back if it does damage
    local att_strikes, def_strikes = att_attack.number, 0
    if (def_damage > 0) then
        def_strikes = def_attack.number
    end

    -- Take swarm into account
    if att_attack.swarm then
        att_strikes = math.floor(att_strikes * attacker.hitpoints / attacker.max_hitpoints)
    end
    if def_attack and def_attack.swarm then
        def_strikes = math.floor(def_strikes * defender.hitpoints / defender.max_hitpoints)
    end

    -- Maximum number of hits that either unit can survive
    local att_max_hits = max_survivable_hits(attacker.hitpoints, def_damage, def_strikes)
    local def_max_hits = max_survivable_hits(defender.hitpoints, att_damage, att_strikes)

    -- Probability of landing a hit
    local att_hit_prob = wesnoth.unit_defense(defender, wesnoth.get_terrain(defender.x, defender.y)) / 100.
    local def_hit_prob = wesnoth.unit_defense(attacker, wesnoth.get_terrain(dst[1], dst[2])) / 100.

    -- Magical: attack and defense, and under all circumstances
    if att_attack.magical then att_hit_prob = 0.7 end
    if def_attack and def_attack.magical then def_hit_prob = 0.7 end

    -- Marksman: attack only, and only if terrain defense is less
    if att_attack.marksman and (att_hit_prob < 0.6) then
        att_hit_prob = 0.6
    end

    -- Get the coefficients for this kind of combat
    local def_firststrike = false
    if def_attack and def_attack.firststrike then def_firststrike = true end

    local coeff_cfg = {
        att = { strikes = att_strikes, max_hits = att_max_hits, firststrike = att_attack.firststrike },
        def = { strikes = def_strikes, max_hits = def_max_hits, firststrike = def_firststrike }
    }
    local att_coeffs, def_coeffs = battle_calcs.battle_outcome_coefficients(coeff_cfg, cache)

    -- And multiply out the factors
    -- Note that att_hit_prob, def_hit_prob need to be in that order for both calls
    local att_stats = battle_calcs.hp_distribution(att_coeffs, att_hit_prob, def_hit_prob, attacker.hitpoints, def_damage, def_attack)
    local def_stats = battle_calcs.hp_distribution(def_coeffs, att_hit_prob, def_hit_prob, defender.hitpoints, att_damage, att_attack)

    return att_stats, def_stats
end
712
function battle_calcs.simulate_combat_fake()
    -- Debug helper: hands back a fixed pair of (attacker, defender) stats tables
    -- with the same fields as the simulate_combat results, without computing anything.
    -- Handy for timing how much of the total cost simulate_combat accounts for.
    -- It declares no parameters, so it can be called with whatever arguments the
    -- other simulate_combat functions take (they are simply ignored).
    local att_stats = {
        hp_chance = { [0] = 0, [21] = 0.125, [23] = 0.375, [25] = 0.375, [27] = 0.125 },
        poisoned = 0.875,
        slowed = 0,
        average_hp = 24
    }
    local def_stats = {
        hp_chance = { [0] = 0.09, [2] = 0.42, [10] = 0.49 },
        poisoned = 0,
        slowed = 0,
        average_hp = 1.74
    }
    return att_stats, def_stats
end
728
function battle_calcs.simulate_combat_loc(attacker, dst, defender, weapon)
    -- Run wesnoth.simulate_combat for @attacker vs. @defender as if the attacker
    -- stood on hex @dst ({ x, y }). A copy of the attacker is moved there, so the
    -- original unit is left untouched.
    -- @weapon: optional Lua (1-based) weapon index; when omitted, simulate_combat
    --   chooses the weapon itself.
    local moved_copy = wesnoth.copy_unit(attacker)
    moved_copy.x, moved_copy.y = dst[1], dst[2]

    if not weapon then
        return wesnoth.simulate_combat(moved_copy, defender)
    end
    return wesnoth.simulate_combat(moved_copy, weapon, defender)
end
743
-- Experience gained for killing an opponent of level @level:
-- 8 XP per level, but 4 XP for a level-0 opponent (default Wesnoth rules)
local function kill_experience(level)
    if (level == 0) then return 4 end
    return 8 * level
end

function battle_calcs.attack_rating(attacker, defender, dst, cfg, cache)
    -- Returns a common (but configurable) rating for attacks
    -- Inputs:
    -- @attacker: attacker unit
    -- @defender: defender unit
    -- @dst: the attack location in form { x, y }
    -- @cfg: table of optional inputs and configurable rating parameters
    --  Optional inputs:
    --    - att_stats, def_stats: if given, use these stats, otherwise calculate them here
    --        Note: these are calculated in combination, that is they either both need to be passed or both be omitted
    --    - att_weapon/def_weapon: the attacker/defender weapon to be used if calculating battle stats here
    --        This parameter is meaningless (unused) if att_stats/def_stats are passed
    --        Defaults to weapon that does most damage to the opponent
    --        Note: as with the stats, they either both need to be passed or both be omitted
    --  Rating weights (defaults in parentheses): enemy_leader_weight (5), defender_starting_damage_weight (0.33),
    --    xp_weight (0.25), level_weight (1), defender_level_weight (1), distance_leader_weight (0.002),
    --    defense_weight (0.1), occupied_hex_penalty (-0.001), own_value_weight (1)
    -- @cache: cache table to be passed to battle_calcs.battle_outcome
    --
    -- Returns:
    --   - Overall rating for the attack or attack combo
    --   - Defender rating: not additive for attack combos; needs to be calculated for the
    --     defender stats of the last attack in a combo (that works for everything except
    --     the rating whether the defender is about to level in the attack combo)
    --   - Attacker rating: this one is split up into two terms:
    --     - a term that is additive for individual attacks in a combo
    --     - a term that needs to be average for the individual attacks in a combo
    --   - att_stats, def_stats: useful if they were calculated here, rather than passed down

    cfg = cfg or {}

    -- Set up the config parameters for the rating
    local enemy_leader_weight = cfg.enemy_leader_weight or 5.
    local defender_starting_damage_weight = cfg.defender_starting_damage_weight or 0.33
    local xp_weight = cfg.xp_weight or 0.25
    local level_weight = cfg.level_weight or 1.0
    local defender_level_weight = cfg.defender_level_weight or 1.0
    local distance_leader_weight = cfg.distance_leader_weight or 0.002
    local defense_weight = cfg.defense_weight or 0.1
    local occupied_hex_penalty = cfg.occupied_hex_penalty or -0.001
    local own_value_weight = cfg.own_value_weight or 1.0

    -- Get att_stats, def_stats; use those passed in cfg if both are present
    local att_stats, def_stats
    if cfg.att_stats and cfg.def_stats then
        att_stats, def_stats = cfg.att_stats, cfg.def_stats
    else
        -- If cfg specifies the weapons use those, otherwise use "best" weapons
        -- In the latter case, cfg.???_weapon will be nil, which will be passed on
        local battle_cfg = { att_weapon = cfg.att_weapon, def_weapon = cfg.def_weapon, dst = dst }
        att_stats, def_stats = battle_calcs.battle_outcome(attacker, defender, battle_cfg, cache)
    end

    -- We also need the leader (well, the location at least)
    -- because if there's no other difference, prefer location _between_ the leader and the defender
    local leader = wesnoth.get_units { side = attacker.side, canrecruit = 'yes' }[1]

    -- Terrain lookups used by several terms below
    local dst_terrain = wesnoth.get_terrain(dst[1], dst[2])
    local defender_on_village = wesnoth.get_terrain_info(wesnoth.get_terrain(defender.x, defender.y)).village

    ------ All the attacker contributions: ------
    -- Add up rating for the attacking unit
    -- We add this up in units of fraction of max_hitpoints
    -- It is multiplied by unit cost later, to get a gold equivalent value

    -- Average damage to unit is negative rating
    local att_damage = attacker.hitpoints - att_stats.average_hp
    -- Count poisoned as additional 8 HP damage times probability of being poisoned
    if (att_stats.poisoned ~= 0) then
        att_damage = att_damage + 8 * (att_stats.poisoned - att_stats.hp_chance[0])
    end
    -- Count slowed as additional 6 HP damage times probability of being slowed
    if (att_stats.slowed ~= 0) then
        att_damage = att_damage + 6 * (att_stats.slowed - att_stats.hp_chance[0])
    end

    -- If attack is from a village, we count that as a 10 HP bonus
    if wesnoth.get_terrain_info(dst_terrain).village then
        att_damage = att_damage - 10.
    end

    -- If attack is adjacent to an unoccupied village, that's bad
    for xa,ya in H.adjacent_tiles(dst[1], dst[2]) do
        local is_adjacent_village = wesnoth.get_terrain_info(wesnoth.get_terrain(xa, ya)).village
        if is_adjacent_village and (not wesnoth.get_unit(xa, ya)) then
            att_damage = att_damage + 10
        end
    end

    if (att_damage < 0) then att_damage = 0 end

    -- Fraction damage (= fractional value of the unit)
    local att_value_fraction = - att_damage / attacker.max_hitpoints

    -- Additional, subtract the chance to die, in order to (de)emphasize units that might die
    att_value_fraction = att_value_fraction - att_stats.hp_chance[0]

    -- In addition, potentially leveling up in this attack is a huge bonus,
    -- proportional to the chance of it happening and the chance of not dying itself.
    -- Fighting gives XP equal to the opponent's level; killing gives kill_experience(level).
    local level_bonus = 0.
    local defender_level = wesnoth.unit_types[defender.type].level
    local att_xp_to_level = attacker.max_experience - attacker.experience
    if (att_xp_to_level <= defender_level) then
        level_bonus = 1. - att_stats.hp_chance[0]
    elseif (att_xp_to_level <= kill_experience(defender_level)) then
        level_bonus = (1. - att_stats.hp_chance[0]) * def_stats.hp_chance[0]
    end
    att_value_fraction = att_value_fraction + level_bonus * level_weight

    -- Now convert this into gold-equivalent value
    local attacker_value = wesnoth.unit_types[attacker.type].cost

    -- Being closer to leveling is good (this makes AI prefer units with lots of XP)
    local att_xp_bonus = attacker.experience / attacker.max_experience
    attacker_value = attacker_value * (1. + att_xp_bonus * xp_weight)

    local attacker_rating = att_value_fraction * attacker_value

    ------ Now (most of) the same for the defender ------
    -- Average damage to defender is positive rating
    local def_damage = defender.hitpoints - def_stats.average_hp
    -- Count poisoned as additional 8 HP damage times probability of being poisoned
    if (def_stats.poisoned ~= 0) then
        def_damage = def_damage + 8 * (def_stats.poisoned - def_stats.hp_chance[0])
    end
    -- Count slowed as additional 6 HP damage times probability of being slowed
    if (def_stats.slowed ~= 0) then
        def_damage = def_damage + 6 * (def_stats.slowed - def_stats.hp_chance[0])
    end

    -- If defender is on a village, we count that as a 10 HP bonus
    if defender_on_village then
        def_damage = def_damage - 10.
    end

    if (def_damage < 0) then def_damage = 0. end

    -- Fraction damage (= fractional value of the unit)
    local def_value_fraction = def_damage / defender.max_hitpoints

    -- Additional, add the chance to kill, in order to emphasize enemies we might be able to kill
    def_value_fraction = def_value_fraction + def_stats.hp_chance[0]

    -- In addition, the defender potentially leveling up in this attack is a huge penalty,
    -- proportional to the chance of it happening and the chance of not dying itself
    local defender_level_penalty = 0.
    local attacker_level = wesnoth.unit_types[attacker.type].level
    local def_xp_to_level = defender.max_experience - defender.experience
    if (def_xp_to_level <= attacker_level) then
        defender_level_penalty = 1. - def_stats.hp_chance[0]
    elseif (def_xp_to_level <= kill_experience(attacker_level)) then
        defender_level_penalty = (1. - def_stats.hp_chance[0]) * att_stats.hp_chance[0]
    end
    def_value_fraction = def_value_fraction - defender_level_penalty * defender_level_weight

    -- Now convert this into gold-equivalent value
    local defender_value = wesnoth.unit_types[defender.type].cost

    -- If this is the enemy leader, make damage to it much more important
    if defender.canrecruit then
        defender_value = defender_value * enemy_leader_weight
    end

    -- And prefer to attack already damaged enemies
    local defender_starting_damage_fraction = (defender.max_hitpoints - defender.hitpoints) / defender.max_hitpoints
    defender_value = defender_value * (1. + defender_starting_damage_fraction * defender_starting_damage_weight)

    -- Being closer to leveling is good, we want to get rid of those enemies first
    local def_xp_bonus = defender.experience / defender.max_experience
    defender_value = defender_value * (1. + def_xp_bonus * xp_weight)

    -- If defender is on a village, add a bonus rating (we want to get rid of those preferentially)
    -- So yes, this is positive, even though it's a plus for the defender
    -- Defenders on villages also got a negative damage rating above (these don't exactly cancel each other though)
    if defender_on_village then
        -- NOTE(review): this scales by the attacker's max_hitpoints, while the matching
        -- -10 HP term above is relative to the defender's max_hitpoints; confirm intent.
        defender_value = defender_value * (1. + 10. / attacker.max_hitpoints)
    end

    -- We also add a few contributions that are not directly attack/damage dependent
    -- These are added to the defender rating for two reasons:
    --   1. Defender rating is positive (and thus contributions can be made positive)
    --   2. It is then independent of value of aggression (cfg.own_value_weight)
    --
    -- These are kept small though, so they mostly only serve as tie breakers
    -- And yes, they might bring the overall rating from slightly negative to slightly positive
    -- or vice versa, but as that is only approximate anyway, we keep it this way for simplicity

    -- We don't need a bonus for good terrain for the attacker, as that is covered in the damage calculation
    -- However, we add a small bonus for good terrain defense of the _defender_ on the _attack_ hex
    -- This is in order to take good terrain away from defender on next move, all else being equal
    local defender_defense = - wesnoth.unit_defense(defender, dst_terrain) / 100.
    defender_value = defender_value + defender_defense * defense_weight

    -- Get a very small bonus for hexes in between defender and AI leader
    -- 'relative_distances' is larger for attack hexes closer to the side leader (possible values: -1 .. 1)
    if leader then
        local relative_distances =
            M.distance_between(defender.x, defender.y, leader.x, leader.y)
            - M.distance_between(dst[1], dst[2], leader.x, leader.y)
        defender_value = defender_value + relative_distances * distance_leader_weight
    end

    -- Add a very small penalty for attack hexes occupied by other units
    -- Note: it must be checked previously that the unit on the hex can move away
    if (dst[1] ~= attacker.x) or (dst[2] ~= attacker.y) then
        if wesnoth.get_unit(dst[1], dst[2]) then
            defender_value = defender_value + occupied_hex_penalty
        end
    end

    local defender_rating = def_value_fraction * defender_value

    -- Finally apply factor of own unit weight to defender unit weight
    attacker_rating = attacker_rating * own_value_weight

    local rating = defender_rating + attacker_rating

    return rating, defender_rating, attacker_rating, att_stats, def_stats
end
964
function battle_calcs.attack_combo_stats(tmp_attackers, tmp_dsts, defender, cache, cache_this_move)
    -- Calculate attack combination outcomes using
    -- @tmp_attackers: array of attacker units (this is done so that
    --   the units need not be found here, as likely doing it in the
    --   calling function is more efficient (because of repetition)
    -- @tmp_dsts: array of the hexes (format { x, y }) from which the attackers attack
    --   must be in same order as @attackers
    -- @defender: the unit being attacked
    -- @cache: the cache table to be passed through to other battle_calcs functions
    --   attack_combo_stats itself is not cached, except for in cache_this_move below
    -- @cache_this_move: an optional table of pre-calculated attack outcomes
    --   - This is different from the other cache tables used in this file
    --   - This table may only persist for this move (move, not turn !!!), as otherwise too many things change
    --   - Indexed as [defender x*1000+y][attacker x*1000+y][dst x*1000+y]
    --
    -- Return values:
    --   - The rating for this attack combination calculated from battle_calcs.attack_rating() results
    --   - The sorted attackers and dsts arrays
    --   - att_stats: an array of stats for each attacker, in the same order as 'attackers'
    --   - defender combo stats: one set of stats containing the defender stats after the attack combination
    --   - def_stats: an array of defender stats for each individual attack, in the same order as 'attackers'

    cache_this_move = cache_this_move or {}

    -- We first simulate and rate the individual attacks
    local ratings, tmp_attacker_ratings = {}, {}
    local tmp_att_stats, tmp_def_stats = {}, {}
    local defender_ind = defender.x * 1000 + defender.y
    for i,attacker in ipairs(tmp_attackers) do
        -- Initialize or use the 'cache_this_move' table
        local att_ind = attacker.x * 1000 + attacker.y
        local dst_ind = tmp_dsts[i][1] * 1000 + tmp_dsts[i][2]
        if (not cache_this_move[defender_ind]) then cache_this_move[defender_ind] = {} end
        if (not cache_this_move[defender_ind][att_ind]) then cache_this_move[defender_ind][att_ind] = {} end

        local cached = cache_this_move[defender_ind][att_ind][dst_ind]
        if (not cached) then
            -- Get the base rating
            local base_rating, def_rating, att_rating, att_stats, def_stats =
                battle_calcs.attack_rating(attacker, defender, tmp_dsts[i], {}, cache)
            tmp_attacker_ratings[i] = att_rating
            tmp_att_stats[i], tmp_def_stats[i] = att_stats, def_stats

            -- An outcome-variance term (so that high-uncertainty attacks go early and we
            -- can change our mind after an unfavorable outcome) was considered here but
            -- is currently disabled; the rating is just the base rating.
            local rating = base_rating

            -- If attacker has attack with 'slow' special, it should always go first
            -- Almost, bonus should not be quite as high as a really high CTK
            -- This isn't quite true in reality, but can be refined later
            if AH.has_weapon_special(attacker, "slow") then
                rating = rating + wesnoth.unit_types[defender.type].cost / 2.
            end

            ratings[i] = { i, rating, base_rating, def_rating, att_rating }

            -- Now add this attack to the cache_this_move table, so that next time around, we don't have to do this again
            cache_this_move[defender_ind][att_ind][dst_ind] = {
                rating = { -1, rating, base_rating, def_rating, att_rating },  -- Cannot use { i, rating, ... } here, as 'i' might be different next time
                attacker_ratings = tmp_attacker_ratings[i],
                att_stats = tmp_att_stats[i],
                def_stats = tmp_def_stats[i]
            }
        else
            local tmp_rating = cached.rating
            tmp_rating[1] = i
            ratings[i] = tmp_rating
            tmp_attacker_ratings[i] = cached.attacker_ratings
            tmp_att_stats[i] = cached.att_stats
            tmp_def_stats[i] = cached.def_stats
        end
    end

    -- Now sort all the arrays based on this rating
    -- This will give the order in which the individual attacks are executed
    table.sort(ratings, function(a, b) return a[2] > b[2] end)

    -- Reorder attackers, dsts in this order
    local attackers, dsts, att_stats, def_stats, attacker_ratings = {}, {}, {}, {}, {}
    for i,rating in ipairs(ratings) do
        attackers[i], dsts[i] = tmp_attackers[rating[1]], tmp_dsts[rating[1]]
    end
    -- Only keep the stats/ratings for the first attacker, the rest needs to be recalculated
    att_stats[1], def_stats[1] = tmp_att_stats[ratings[1][1]], tmp_def_stats[ratings[1][1]]
    attacker_ratings[1] = tmp_attacker_ratings[ratings[1][1]]

    -- Then we go through all the other attacks and calculate the outcomes
    -- based on all the possible outcomes of the previous attacks
    for i = 2,#attackers do
        local att_i, def_i = { hp_chance = {} }, { hp_chance = {} }
        att_stats[i], def_stats[i] = att_i, def_i
        local prev_def = def_stats[i-1]

        local att_ind_i = attackers[i].x * 1000 + attackers[i].y
        local dst_ind = dsts[i][1] * 1000 + dsts[i][2]
        -- Always exists: the first loop filled it for every (attacker, dst) pair
        local hp_cache = cache_this_move[defender_ind][att_ind_i][dst_ind]

        for hp1,prob1 in pairs(prev_def.hp_chance) do -- Note: need pairs(), not ipairs() !!
            if (hp1 == 0) then
                -- Defender already dead: this attack does not happen
                local att_hp = attackers[i].hitpoints
                att_i.hp_chance[att_hp] = (att_i.hp_chance[att_hp] or 0) + prob1
                def_i.hp_chance[0] = (def_i.hp_chance[0] or 0) + prob1
            else
                if (not hp_cache[hp1]) then
                    -- Temporarily set the defender HP to this outcome of the previous attacks
                    local org_hp = defender.hitpoints
                    defender.hitpoints = hp1
                    local ast, dst = battle_calcs.battle_outcome(attackers[i], defender, { dst = dsts[i] } , cache)
                    defender.hitpoints = org_hp
                    hp_cache[hp1] = { ast = ast, dst = dst }
                end
                local outcome = hp_cache[hp1]

                for hp2,prob2 in pairs(outcome.ast.hp_chance) do
                    att_i.hp_chance[hp2] = (att_i.hp_chance[hp2] or 0) + prob1 * prob2
                end
                for hp2,prob2 in pairs(outcome.dst.hp_chance) do
                    def_i.hp_chance[hp2] = (def_i.hp_chance[hp2] or 0) + prob1 * prob2
                end

                -- Also do poisoned, slowed (approximated from the first surviving outcome found)
                if (not att_i.poisoned) then
                    att_i.poisoned = outcome.ast.poisoned
                    att_i.slowed = outcome.ast.slowed
                    def_i.poisoned = 1. - (1. - outcome.dst.poisoned) * (1. - prev_def.poisoned)
                    def_i.slowed = 1. - (1. - outcome.dst.slowed) * (1. - prev_def.slowed)
                end
            end
        end

        -- If the defender was dead in every outcome of the previous attacks (or a
        -- sub-result had no hp_chance[0] entry), some fields are still unset here.
        -- attack_rating() and the next iteration read them, so fill in neutral values:
        -- the attacker is unaffected and the defender keeps its previous status.
        att_i.hp_chance[0] = att_i.hp_chance[0] or 0
        def_i.hp_chance[0] = def_i.hp_chance[0] or 0
        att_i.poisoned = att_i.poisoned or 0
        att_i.slowed = att_i.slowed or 0
        def_i.poisoned = def_i.poisoned or prev_def.poisoned or 0
        def_i.slowed = def_i.slowed or prev_def.slowed or 0

        -- Get the average HP
        local att_av_hp = 0
        for hp,prob in pairs(att_i.hp_chance) do att_av_hp = att_av_hp + hp * prob end
        att_i.average_hp = att_av_hp
        local def_av_hp = 0
        for hp,prob in pairs(def_i.hp_chance) do def_av_hp = def_av_hp + hp * prob end
        def_i.average_hp = def_av_hp
    end

    -- Get the total rating for this attack combo:
    --   = sum of all the attacker ratings and the defender rating with the final def_stats
    -- Rating for first attack exists already
    local def_rating = ratings[1][4]
    local att_rating = ratings[1][5]

    -- The others need to be calculated with the new stats
    for i = 2,#attackers do
        local cfg = { att_stats = att_stats[i], def_stats = def_stats[i] }
        local _, dr, ar = battle_calcs.attack_rating(attackers[i], defender, dsts[i], cfg, cache)

        def_rating = dr
        att_rating = att_rating + ar
    end

    local rating = def_rating + att_rating

    return rating, attackers, dsts, att_stats, def_stats[#attackers], def_stats
end
1141
function battle_calcs.get_attack_map_unit(unit, cfg)
    -- Get all hexes that @unit can attack: every hex it can reach plus the hexes adjacent to those
    -- Return value is a table with two location sets:
    --   - units: 1 for every such hex
    --   - hitpoints: @unit's current hitpoints for every such hex
    -- @cfg: table with config parameters, also passed on to wesnoth.find_reach
    --  max_moves: if set use max_moves for units (this setting is always used for units on other sides)
    --
    -- For a unit not on the current side, the current side's units that still have
    -- moves left are temporarily taken off the map during pathfinding. They, and the
    -- unit's moves, are restored even if wesnoth.find_reach raises an error.

    cfg = cfg or {}

    local on_current_side = (unit.side == wesnoth.current.side)

    -- For unit on current side: use current moves by default, or max moves if cfg.max_moves
    -- For unit on any other side, only max_moves=true makes sense
    local max_moves = cfg.max_moves
    if (not on_current_side) then max_moves = true end

    local old_moves = unit.moves
    if max_moves then unit.moves = unit.max_moves end

    local reach = {}
    reach.units = LS.create()
    reach.hitpoints = LS.create()

    -- For units on the other side, take all units on this side with
    -- MP left off the map (for enemy pathfinding)
    local units_MP = {}
    if (not on_current_side) then
        for _,own_unit in ipairs(wesnoth.get_units { side = wesnoth.current.side }) do
            if (own_unit.moves > 0) then
                table.insert(units_MP, own_unit)
                wesnoth.extract_unit(own_unit)
            end
        end
    end

    -- Find hexes the unit can reach (protected, so the map state is always restored)
    local ok, initial_reach = pcall(wesnoth.find_reach, unit, cfg)

    -- Put the units back out there and restore the moves
    for _,uMP in ipairs(units_MP) do wesnoth.put_unit(uMP) end
    if max_moves then unit.moves = old_moves end

    if (not ok) then error(initial_reach, 0) end

    for _,loc in ipairs(initial_reach) do
        reach.units:insert(loc[1], loc[2], 1)
        reach.hitpoints:insert(loc[1], loc[2], unit.hitpoints)
        for xa,ya in H.adjacent_tiles(loc[1], loc[2]) do
            reach.units:insert(xa, ya, 1)
            reach.hitpoints:insert(xa, ya, unit.hitpoints)
        end
    end

    return reach
end
1200
function battle_calcs.get_attack_map(units, cfg)
    -- Get all hexes that @units can attack, by summing battle_calcs.get_attack_map_unit()
    -- over all units in @units
    -- Return value is a table with two location sets:
    --   - units: the number of units in @units that can attack each hex
    --   - hitpoints: the combined hitpoints of those units
    -- @cfg: table with config parameters, passed on to get_attack_map_unit()
    --  max_moves: if set use max_moves for units (this setting is always used for units on other sides)

    -- Merge function for union_merge: sum the values, treating a missing entry as 0
    local function add_values(x, y, v1, v2)
        return (v1 or 0) + v2
    end

    local attack_map = {}
    attack_map.units = LS.create()
    attack_map.hitpoints = LS.create()

    for _,unit in ipairs(units) do
        local unit_map = battle_calcs.get_attack_map_unit(unit, cfg)
        attack_map.units:union_merge(unit_map.units, add_values)
        attack_map.hitpoints:union_merge(unit_map.hitpoints, add_values)
    end

    return attack_map
end
1227
-- Highest defender_rating (from battle_calcs.attack_rating, with enemy_leader_weight = 1)
-- that @attacker achieves from its current hex against any unit in @targets.
-- Returns that rating and the target producing it, or 0, nil if there is no target
-- (so that no sentinel value leaks into the damage maps).
local function best_defender_rating(attacker, targets, cache)
    local max_rating, best_target = -9e99, nil
    for _,target in ipairs(targets) do
        local _, defender_rating =
            battle_calcs.attack_rating(attacker, target, { attacker.x, attacker.y }, { enemy_leader_weight = 1 }, cache)
        if (defender_rating > max_rating) then
            max_rating, best_target = defender_rating, target
        end
    end
    if (not best_target) then return 0, nil end
    return max_rating, best_target
end

function battle_calcs.relative_damage_map(units, enemies, cache)
    -- Returns a location set map containing the relative damage of
    -- @units vs. @enemies on the part of the map that the combined units
    -- can reach. The damage is calculated as the sum of defender_rating
    -- from attack_rating(), and thus (roughly) in gold units.
    -- Also returns the same maps for the own and enemy units only
    -- (with the enemy_damage_map having positive sign, while in the
    -- overall damage map it is subtracted)
    -- A unit with no possible opponents (empty @enemies or @units) contributes 0.

    -- Get the attack maps for each unit in 'units' and 'enemies'
    local my_attack_maps, enemy_attack_maps = {}, {}
    for i,unit in ipairs(units) do
        my_attack_maps[i] = battle_calcs.get_attack_map_unit(unit)
    end
    for i,enemy in ipairs(enemies) do
        enemy_attack_maps[i] = battle_calcs.get_attack_map_unit(enemy)
    end

    -- Get the damage rating for each unit in 'units'. It is the maximum
    -- defender_rating (roughly the damage that it can do in units of gold)
    -- against any of the enemy units
    local unit_ratings = {}
    for i,unit in ipairs(units) do
        local rating, best_enemy = best_defender_rating(unit, enemies, cache)
        unit_ratings[i] = { rating = rating, unit_id = unit.id, enemy_id = best_enemy and best_enemy.id }
    end

    -- Then we want the same thing for all of the enemy units (for the counter attack on enemy turn)
    local enemy_ratings = {}
    for i,enemy in ipairs(enemies) do
        local rating, best_unit = best_defender_rating(enemy, units, cache)
        enemy_ratings[i] = { rating = rating, unit_id = best_unit and best_unit.id, enemy_id = enemy.id }
    end

    -- The damage map is now the sum of these ratings for each unit that can attack a given hex,
    -- counting own-unit ratings as positive, enemy ratings as negative
    local damage_map, own_damage_map, enemy_damage_map = LS.create(), LS.create(), LS.create()
    for i,attack_map in ipairs(my_attack_maps) do
        local rating = unit_ratings[i].rating
        attack_map.units:iter(function(x, y)
            own_damage_map:insert(x, y, (own_damage_map:get(x, y) or 0) + rating)
            damage_map:insert(x, y, (damage_map:get(x, y) or 0) + rating)
        end)
    end
    for i,attack_map in ipairs(enemy_attack_maps) do
        local rating = enemy_ratings[i].rating
        attack_map.units:iter(function(x, y)
            enemy_damage_map:insert(x, y, (enemy_damage_map:get(x, y) or 0) + rating)
            damage_map:insert(x, y, (damage_map:get(x, y) or 0) - rating)
        end)
    end

    return damage_map, own_damage_map, enemy_damage_map
end
1300
function battle_calcs.best_defense_map(units, cfg)
    -- Build a map of the best terrain defense available on every hex that at
    -- least one unit of @units can reach.
    -- The value stored for a hex is the highest defense (100 minus the value
    -- returned by wesnoth.unit_defense) among all units able to reach it.
    -- @cfg: optional table of configuration parameters; it is also handed on
    --   unchanged to wesnoth.find_reach()
    --  max_moves: if set, reach is evaluated with the units' full movement points
    --    (this is always done for units that are not on the current side)
    --  ignore_these_units: units that are temporarily taken off the map so that
    --    their ZoC does not affect route finding; they are put back afterward
    -- Returns a location set mapping hex -> best defense value

    cfg = cfg or {}
    local ignored = cfg.ignore_these_units

    local defense_map = LS.create()

    -- Remove the ignored units from the map for the duration of the calculation
    if ignored then
        for _,u in ipairs(ignored) do wesnoth.extract_unit(u) end
    end

    for _,u in ipairs(units) do
        -- Units of other sides only make sense with full MP
        local use_max_moves = cfg.max_moves or (u.side ~= wesnoth.current.side)

        local saved_moves = u.moves
        if use_max_moves then u.moves = u.max_moves end
        local reach = wesnoth.find_reach(u, cfg)
        if use_max_moves then u.moves = saved_moves end

        for _,hex in ipairs(reach) do
            local x, y = hex[1], hex[2]
            local defense = 100 - wesnoth.unit_defense(u, wesnoth.get_terrain(x, y))

            -- Keep only the best value found so far for this hex
            local best_so_far = defense_map:get(x, y)
            if (not best_so_far) or (defense > best_so_far) then
                defense_map:insert(x, y, defense)
            end
        end
    end

    -- Put the temporarily removed units back onto the map
    if ignored then
        for _,u in ipairs(ignored) do wesnoth.put_unit(u) end
    end

    return defense_map
end
1341
function battle_calcs.get_attack_combos_subset(units, enemy, cfg)
    -- Calculate combinations of attacks by @units on @enemy
    -- This method does *not* produce all possible attack combinations, but is
    -- meant to have a good chance to find either the best combination,
    -- or something close to it, by only considering a subset of all possibilities.
    -- It is also configurable to stop accumulating combinations when certain criteria are met.
    --
    -- The return value is an array of attack combinations, where each element is another
    -- array of tables containing 'dst' and 'src' fields of the attacking units. Both are
    -- hex locations encoded as x * 1000 + y ('src' being the current position of the unit,
    -- 'dst' the hex it attacks from). It can be specified whether the order of the
    -- attacks matters or not (see below).
    --
    -- Note: This function is optimized for speed, not elegance
    --
    -- Note 2: The structure of the returned table is different from the (current) return value
    -- of ai_helper.get_attack_combos(), since the order of attacks never matters for the latter.
    -- TODO: consider making the two consistent (not sure yet whether that is advantageous)
    --
    -- @cfg: Table of optional configuration parameters. The table is only read; defaults
    --   and internal bookkeeping are stored in a private copy, so the caller's table is
    --   never modified.
    --   - order_matters: if set, keep attack combos that use the same units on the same
    --       hexes, but in different attack order (default: false)
    --   - max_combos: stop adding attack combos if this number of combos has been reached
    --       default: assemble all possible combinations
    --   - max_time: stop adding attack combos if this much time (in seconds) has passed
    --       default: assemble all possible combinations
    --       note: this counts the time from the first call to add_attack(), not from the
    --         call to get_attack_combos_subset(), so the setup work is not included.
    --         This is done to prevent the return of no combos at all
    --         Note 2: there is some overhead involved in reading the time from the system,
    --           so don't use this unless it's needed
    --         Note 3: the time is only checked on entry to each recursion level, so a few
    --           more combos may still be added after the limit has been passed
    --   - skip_presort: by default, the units are presorted in order of the unit with
    --       the highest rating first. This has the advantage of likely finding the best
    --       (or at least close to the best) attack combination earlier, but it adds overhead,
    --       so it's actually a disadvantage for small numbers of combinations. skip_presort
    --       specifies the number of units up to which the presorting is skipped. Default: 5

    cfg = cfg or {}

    -- Work on a private copy of the options: writing the defaults (and the
    -- start_time bookkeeping below) into @cfg would modify the caller's table
    local opts = {
        order_matters = cfg.order_matters or false,
        max_combos = cfg.max_combos or 9e99,
        max_time = cfg.max_time or false,
        skip_presort = cfg.skip_presort or 5
    }

    ----- begin add_attack() -----
    -- Recursive local function adding another attack to the current combo
    -- and adding the current combo to the overall attack_combos array
    local function add_attack(attacks, reachable_hexes, n_reach, attack_combos, combos_str, current_combo, hexes_used, opts)

        local time_up = false
        if opts.max_time and (wesnoth.get_time_stamp() / 1000. - opts.start_time >= opts.max_time) then
            time_up = true
        end

        -- Go through all the units
        for ind_att,attack in ipairs(attacks) do  -- 'attack' is array of all attacks for the unit

            -- Then go through the individual attacks of the unit ...
            for _,att in ipairs(attack) do
                -- But only if this hex is not used yet and
                -- the cutoff criteria are not met
                if (not hexes_used[att.dst]) and (not time_up) and (#attack_combos < opts.max_combos) then

                    -- Mark this hex as used by this unit
                    hexes_used[att.dst] = attack.src

                    -- Set up a string uniquely identifying the unit/attack hex pairs
                    -- for current_combo. This is used to exclude pairs that already
                    -- exist in a different order (if 'opts.order_matters' is not set)
                    -- For this, we also add the numerical value of the attack_hex to
                    -- the 'hexes_used' table (in addition to the line above)
                    -- The hex indices 1..n_reach cannot collide with the x*1000+y keys.
                    local str = ''
                    if (not opts.order_matters) then
                        hexes_used[reachable_hexes[att.dst]] = attack.src
                        for ind_hex = 1,n_reach do
                            if hexes_used[ind_hex] then
                                str = str .. hexes_used[ind_hex] .. '-'
                            else
                                str = str .. '0-'
                            end
                        end
                    end

                    -- 'combos_str' contains all the strings of previous combos
                    -- (if 'opts.order_matters' is not set)
                    -- Only add this combo if it does not yet exist
                    if (not combos_str[str]) then

                        -- Add the string identifier to the array
                        if (not opts.order_matters) then
                            combos_str[str] = true
                        end

                        -- Add the attack to 'current_combo'
                        table.insert(current_combo, { dst = att.dst, src = attack.src })

                        -- And *copy* the content of 'current_combo' into 'attack_combos'
                        local n_combos = #attack_combos + 1
                        attack_combos[n_combos] = {}
                        for ind_combo,combo in pairs(current_combo) do attack_combos[n_combos][ind_combo] = combo end

                        -- Finally, remove the current unit from 'attacks' for the call to the next recursion level
                        table.remove(attacks, ind_att)

                        add_attack(attacks, reachable_hexes, n_reach, attack_combos, combos_str, current_combo, hexes_used, opts)

                        -- Reinsert the unit at the same position, so that the
                        -- surrounding ipairs() loop continues correctly
                        table.insert(attacks, ind_att, attack)

                        -- And remove the last element (current attack) from 'current_combo'
                        table.remove(current_combo)
                    end

                    -- And mark the hex as usable again
                    if (not opts.order_matters) then
                        hexes_used[reachable_hexes[att.dst]] = nil
                    end
                    hexes_used[att.dst] = nil

                    -- *** Important ***: We *only* consider one attack hex per unit, the
                    -- first that is found in the array of attacks for the unit. As they
                    -- are sorted by terrain defense, we simply use the first in the table
                    -- the unit can reach that is not occupied
                    -- That's what the 'break' does here:
                    break
                end
            end
        end
    end
    ----- end add_attack() -----

    -- For units on the current side, we need to make sure that
    -- there isn't a unit of the same side in the way that cannot move any more
    -- Set up an array of hexes blocked in such a way
    -- For units on other sides we always assume that they can move away
    local blocked_hexes = LS.create()
    if units[1] and (units[1].side == wesnoth.current.side) then
        for xa,ya in H.adjacent_tiles(enemy.x, enemy.y) do
            local unit_in_way = wesnoth.get_unit(xa, ya)
            if unit_in_way then
                -- Units on the same side are blockers if they cannot move away
                if (unit_in_way.side == wesnoth.current.side) then
                    local reach = wesnoth.find_reach(unit_in_way)
                    if (#reach <= 1) then
                        blocked_hexes:insert(unit_in_way.x, unit_in_way.y)
                    end
                else  -- Units on other sides are always blockers
                    blocked_hexes:insert(unit_in_way.x, unit_in_way.y)
                end
            end
        end
    end

    -- For sides other than the current, we always use max_moves,
    -- for the current side we always use current moves
    local old_moves = {}
    for i,unit in ipairs(units) do
        if (unit.side ~= wesnoth.current.side) then
            old_moves[i] = unit.moves
            unit.moves = unit.max_moves
        end
    end

    -- Now set up an array containing the attack locations for each unit
    local attacks = {}
    -- We also need a numbered array of the possible attack hex coordinates
    -- The order doesn't matter, as long as it is fixed
    local reachable_hexes = {}
    for i,unit in ipairs(units) do

        local locs = {}  -- attack locations for this unit

        for xa,ya in H.adjacent_tiles(enemy.x, enemy.y) do

            local loc = {}  -- attack location information for this unit for this hex

            -- Make sure the hex is not occupied by unit that cannot move out of the way
            if (not blocked_hexes:get(xa, ya) or ((xa == unit.x) and (ya == unit.y))) then

                -- Check whether the unit can get to the hex
                -- wesnoth.map.distance_between() is much faster than wesnoth.find_path()
                --> pre-filter using the former
                local cost = M.distance_between(unit.x, unit.y, xa, ya)

                -- If the distance is <= the unit's MP, then see if it can actually get there
                -- This also means that only short paths have to be evaluated (in most situations)
                if (cost <= unit.moves) then
                    -- Only the path cost (second return value) is needed here
                    cost = select(2, AH.find_path_with_shroud(unit, xa, ya))
                end

                -- If the unit can get to this hex
                if (cost <= unit.moves) then
                    -- Store information about it in 'loc' and add this to 'locs'
                    -- Want coordinates (dst) and terrain defense (for sorting)
                    loc.dst = xa * 1000 + ya
                    loc.hit_prob = wesnoth.unit_defense(unit, wesnoth.get_terrain(xa, ya))
                    table.insert(locs, loc)

                    -- Also mark this hex as usable
                    reachable_hexes[loc.dst] = true
                end
            end
        end

        -- Also add some top-level information for the unit
        if locs[1] then
            locs.src = unit.x * 1000 + unit.y  -- The current position of the unit
            locs.unit_i = i  -- The position of the unit in the 'units' array

            -- Now sort the possible attack locations for this unit by terrain defense
            table.sort(locs, function(a, b) return a.hit_prob < b.hit_prob end)

            -- Finally, add the attack locations for this unit to the 'attacks' array
            table.insert(attacks, locs)
        end

    end

    -- Reset moves for all units
    for i,unit in ipairs(units) do
        if (unit.side ~= wesnoth.current.side) then
            unit.moves = old_moves[i]
        end
    end

    -- If the number of units that can attack is greater than opts.skip_presort:
    -- We also sort the attackers by their attack rating on their favorite hex
    -- The motivation is that by starting with the strongest unit, we'll find the
    -- best attack combo earlier, and it is more likely to find the best (or at
    -- least a good combo) even when not all attack combinations are collected.
    if (#attacks > opts.skip_presort) then
        for _,attack in ipairs(attacks) do
            local dst = attack[1].dst
            local x, y = math.floor(dst / 1000), dst % 1000
            attack.rating = battle_calcs.attack_rating(units[attack.unit_i], enemy, { x, y })
        end
        table.sort(attacks, function(a, b) return a.rating > b.rating end)
    end

    -- To simplify and speed things up in the following, the field values of the
    -- 'reachable_hexes' table need to be consecutive integers
    -- We also want a variable containing the number of elements in the array
    -- (#reachable_hexes doesn't work because the keys are location indices)
    -- Assigning to existing keys during pairs() traversal is allowed in Lua
    local n_reach = 0
    for k in pairs(reachable_hexes) do
        n_reach = n_reach + 1
        reachable_hexes[k] = n_reach
    end

    -- If opts.max_time is set, record the start time
    -- For convenience, we store this in the private options table
    if opts.max_time then
        opts.start_time = wesnoth.get_time_stamp() / 1000.
    end


    -- All this was just setting up the required information, now we call the
    -- recursive function setting up the array of attack combinations
    local attack_combos = {}  -- This will contain the return value
    -- Temporary arrays (but need to be persistent across the recursion levels)
    local combos_str, current_combo, hexes_used  = {}, {}, {}

    add_attack(attacks, reachable_hexes, n_reach, attack_combos, combos_str, current_combo, hexes_used, opts)

    return attack_combos
end
1608
-- Return the module table holding all battle_calcs functions to whoever requires this file
return battle_calcs
1610