1%% -*- erlang-indent-level: 2 -*-
2%%
3%% Licensed under the Apache License, Version 2.0 (the "License");
4%% you may not use this file except in compliance with the License.
5%% You may obtain a copy of the License at
6%%
7%%     http://www.apache.org/licenses/LICENSE-2.0
8%%
9%% Unless required by applicable law or agreed to in writing, software
10%% distributed under the License is distributed on an "AS IS" BASIS,
11%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12%% See the License for the specific language governing permissions and
13%% limitations under the License.
14%%
15%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
16%%@doc
17%%	                BASIC BLOCK WEIGHTING
18%%
19%% Computes basic block weights by using branch probabilities as weights in a
20%% linear equation system, that is then solved using Gauss-Jordan Elimination.
21%%
22%% The equation system representation is intentionally sparse, since most blocks
23%% have at most two successors.
24-module(hipe_bb_weights).
25-export([compute/3, compute_fast/3, weight/2, call_exn_pred/0]).
26-export_type([bb_weights/0]).
27
28-compile(inline).
29
30%%-define(DO_ASSERT,1).
31%%-define(DEBUG,1).
32-include("../main/hipe.hrl").
33
34%% If the equation system is large, it might take too long to solve it exactly.
35%% Thus, if there are more than ?HEUR_MAX_SOLVE labels, we use the iterative
36%% approximation.
37-define(HEUR_MAX_SOLVE, 10000).
38
39-opaque bb_weights() :: #{label() => float()}.
40
41-type cfg() :: any().
42-type target_module() :: module().
43-type target_context() :: any().
44-type target() :: {target_module(), target_context()}.
45
46-type label()            :: integer().
47-type var()              :: label().
48-type assignment()       :: {var(), float()}.
49-type eq_assoc()         :: [{var(), key()}].
50-type solution()         :: [assignment()].
51
%% Predicted probability that a call results in an exception.
%% This is a fixed constant; it does not depend on any input.
-spec call_exn_pred() -> float().
call_exn_pred() ->
  0.01.
55
%% Computes a weight for every basic block of CFG.
%%
%% When the CFG has more than ?HEUR_MAX_SOLVE labels the exact solution is
%% considered too expensive and compute_fast/3 is used instead. Otherwise the
%% branch probabilities are turned into a linear equation system which is
%% solved exactly.
%%
%% solve/2 may return `error' (triangelisation gives up when a pivot
%% coefficient is too close to zero). In that case we also fall back to
%% compute_fast/3, so that the caller always gets a weight map back rather
%% than a case_clause crash.
-spec compute(cfg(), target_module(), target_context()) -> bb_weights().
compute(CFG, TgtMod, TgtCtx) ->
  Target = {TgtMod, TgtCtx},
  Labels = labels(CFG, Target),
  case length(Labels) of
    NumLabels when NumLabels > ?HEUR_MAX_SOLVE ->
      ?debug_msg("~w: Too many labels (~w), approximating.~n",
                 [?MODULE, NumLabels]),
      compute_fast(CFG, TgtMod, TgtCtx);
    _ ->
      {EqSys, EqAssoc} = build_eq_system(CFG, Labels, Target),
      case solve(EqSys, EqAssoc) of
        {ok, Solution} ->
          maps:from_list(Solution);
        error ->
          %% No usable exact solution; approximate instead of crashing.
          ?debug_msg("~w: Could not solve equation system, approximating.~n",
                     [?MODULE]),
          compute_fast(CFG, TgtMod, TgtCtx)
      end
  end.
71
%% Builds the equation system for CFG: one equation per label in Labels, plus
%% the association list mapping each label to the key of its equation.
%% The equation of the start label gets its constant set to -1.0.
-spec build_eq_system(cfg(), [label()], target()) -> {eq_system(), eq_assoc()}.
build_eq_system(CFG, Labels, Target) ->
  Start = hipe_gen_cfg:start_label(CFG),
  {Eqs0, Assoc} = build_eq_system(Labels, CFG, Target, [], eqs_new()),
  {Start, StartKey} = lists:keyfind(Start, 1, Assoc),
  %% -1.0 because the start label's own coefficient in its row is -1.0
  StartRow = row_set_const(-1.0, eqs_get(StartKey, Eqs0)),
  {eqs_put(StartKey, StartRow, Eqs0), Assoc}.
82
%% Inserts one equation per label into EQS0. The row for label L has
%% coefficient -1.0 for L itself and the branch probability for each of its
%% predecessors; the constant is 0.0. Returns the new system together with
%% the {Label, Key} pairs (prepended to Map0 and then reversed as a whole).
build_eq_system(Labels, CFG, Target, Map0, EQS0) ->
  AddLabel =
    fun(L, {MapAcc, EqsAcc0}) ->
        Coefs = [{L, -1.0} | pred_prob(L, CFG, Target)],
        {Key, EqsAcc} = eqs_insert(row_new(Coefs, 0.0), EqsAcc0),
        {[{L, Key} | MapAcc], EqsAcc}
    end,
  {RevMap, EQS} = lists:foldl(AddLabel, {Map0, EQS0}, Labels),
  {EQS, lists:reverse(RevMap)}.
88
%% For every predecessor Pred of L in CFG, returns {Pred, Prob} where Prob is
%% the probability that branch_preds/2 reports for the last instruction of
%% Pred jumping to L. Crashes if that instruction has no float probability
%% for L.
pred_prob(L, CFG, Target) ->
  ProbFrom =
    fun(Pred) ->
        Last = hipe_bb:last(bb(CFG, Pred, Target)),
        Ps = branch_preds(Last, Target),
        ?ASSERT(length(lists:ukeysort(1, Ps))
                =:= length(hipe_gen_cfg:succ(CFG, Pred))),
        case lists:keyfind(L, 1, Ps) of
          {L, Prob} when is_float(Prob) -> {Pred, Prob}
        end
    end,
  lists:map(ProbFrom, hipe_gen_cfg:pred(CFG, L)).
99
100%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Triangelises the equation system. VKs pairs each variable with the key of
%% its equation; the returned association list holds the same kind of pairs
%% in the order in which the equations were processed.
-spec triangelise(eq_system(), eq_assoc()) -> {eq_system(), eq_assoc()}.
triangelise(EQS, VKs) ->
  TIX = mk_triix(EQS, VKs),
  triangelise_1(TIX, []).
104
%% Main loop of triangelisation. Repeatedly pops one of the smallest remaining
%% equations, normalises it on its variable and eliminates that variable from
%% all other equations. Throws `error' if the pivot coefficient lies strictly
%% between -0.0001 and 0.0001.
triangelise_1(TIX, Done) ->
  triangelise_1(triix_is_empty(TIX), TIX, Done).

triangelise_1(true, TIX, Done) ->
  {triix_eqs(TIX), lists:reverse(Done)};
triangelise_1(false, TIX0, Done) ->
  {V, Key, TIX1} = triix_pop_smallest(TIX0),
  Row0 = triix_get(Key, TIX1),
  Pivot = row_get(V, Row0),
  if Pivot > -0.0001, Pivot < 0.0001 ->
      throw(error);
     true ->
      Row = row_normalise(V, Row0),
      TIX2 = triix_put(Key, Row, TIX1),
      TIX = eliminate_triix(V, Key, Row, TIX2),
      triangelise_1(TIX, [{V, Key} | Done])
  end.
121
122%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
123%% Triangelisation maintains its own index, outside of eqs. This index is
124%% essentially a BST (used as a heap) of all equations by size, with {Key,Var}
125%% as the values and only containing a subset of all the keys in the whole
126%% equation system. The key operation is triix_pop_smallest/1, which pops a
127%% {Key,Var} from the heap corresponding to one of the smallest equations. This
128%% is critical in order to prevent the equations from growing during
129%% triangelisation, which would make the algorithm O(n^2) in the common case.
130-type tri_eq_system() :: {eq_system(),
131			  gb_trees:tree(non_neg_integer(),
132					gb_trees:tree(key(), var()))}.
133
%% Read-only accessors for the triangelisation index {EQS, SizeTree}.
%% triix_eqs/1 returns the wrapped equation system; triix_get/2 and
%% triix_lookup/2 forward to eqs_get/2 and eqs_lookup/2 on it;
%% triix_is_empty/1 tells whether the size tree has no entries left.
triix_eqs({Eqs, _SizeTree}) -> Eqs.
triix_get(Key, {Eqs, _SizeTree}) -> eqs_get(Key, Eqs).
triix_is_empty({_Eqs, SizeTree}) -> gb_trees:is_empty(SizeTree).
triix_lookup(Var, {Eqs, _SizeTree}) -> eqs_lookup(Var, Eqs).
138
%% Creates a triangelisation index for EQS: every {Var, Key} pair in VKs is
%% entered into the size tree under the current size of the row stored at Key.
mk_triix(EQS, VKs) ->
  Index =
    fun({V, Key}, Tree) ->
        sitree_insert(row_size(eqs_get(Key, EQS)), Key, V, Tree)
    end,
  {EQS, lists:foldl(Index, gb_trees:empty(), VKs)}.
145
%% Adds Key => V to the subtree stored under Size in SiTree, creating that
%% subtree if there is none yet. Key must not already be in the subtree
%% (gb_trees:insert/3 crashes otherwise).
sitree_insert(Size, Key, V, SiTree) ->
  Bucket0 =
    case gb_trees:is_defined(Size, SiTree) of
      true -> gb_trees:get(Size, SiTree);
      false -> gb_trees:empty()
    end,
  Bucket = gb_trees:insert(Key, V, Bucket0),
  gb_trees:enter(Size, Bucket, SiTree).
154
%% Replaces the subtree stored under Size with SubTree. An empty SubTree is
%% not stored; instead the Size entry is deleted from SiTree. Size must
%% already be present in SiTree.
sitree_update_subtree(Size, SubTree, SiTree) ->
  case gb_trees:size(SubTree) of
    0 -> gb_trees:delete(Size, SiTree);
    _ -> gb_trees:update(Size, SubTree, SiTree)
  end.
160
%% Stores Row under Key in the equation system. If the size of the row
%% changed, Key is also moved to the subtree for the new size in the size
%% tree. Keys that are not found in the size tree under their old size are
%% left out of the tree; only the equation system is updated for them.
triix_put(Key, Row, {EQS, Tree0}) ->
  OldSize = row_size(eqs_get(Key, EQS)),
  NewSize = row_size(Row),
  Tree =
    if NewSize =:= OldSize -> Tree0;
       true -> sitree_move(Key, OldSize, NewSize, Tree0)
    end,
  {eqs_put(Key, Row, EQS), Tree}.

%% Moves Key from the OldSize subtree to the NewSize subtree, keeping its
%% value. Returns Tree unchanged if Key is not in the OldSize subtree.
sitree_move(Key, OldSize, NewSize, Tree) ->
  case gb_trees:lookup(OldSize, Tree) of
    none -> Tree;
    {value, SubTree0} ->
      case gb_trees:lookup(Key, SubTree0) of
        none -> Tree;
        {value, V} ->
          SubTree = gb_trees:delete(Key, SubTree0),
          Tree1 = sitree_update_subtree(OldSize, SubTree, Tree),
          sitree_insert(NewSize, Key, V, Tree1)
      end
  end.
180
%% Removes and returns an entry from the subtree with the smallest size; within
%% that subtree the entry with the smallest key is taken.
%% Returns {Var, Key, UpdatedIndex}. The index must not be empty.
triix_pop_smallest({EQS, Tree0}) ->
  {Size, Bucket0} = gb_trees:smallest(Tree0),
  {Key, V, Bucket} = gb_trees:take_smallest(Bucket0),
  Tree = sitree_update_subtree(Size, Bucket, Tree0),
  {V, Key, {EQS, Tree}}.
185
186%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
187
%% Scales Row so that the coefficient of Var becomes 1.0.
row_normalise(Var, Row) ->
  Scaled = row_scale(Row, 1.0 / row_get(Var, Row)),
  %% Set the coefficient explicitly so that it is exactly 1.0, free of any
  %% rounding error from the scaling above.
  row_set_coef(Var, 1.0, Scaled).
192
%% Eliminates Var from all equations in the system other than the one at Key,
%% by adding a suitable multiple of Row to each of them.
%% Precondition: Row must be normalised, i.e. the coefficient of Var must be
%% 1.0 (modulo rounding errors).
-spec eliminate(var(), key(), row(), eq_system()) -> eq_system().
eliminate(Var, Key, Row, EQS0) ->
  Get = fun eqs_get/2,
  Lookup = fun eqs_lookup/2,
  Put = fun eqs_put/3,
  eliminate_abstr(Var, Key, Row, EQS0, Get, Lookup, Put).
199
%% Same as eliminate/4, but operating on a triangelisation index, so that the
%% size tree is kept up to date through triix_put/3.
-spec eliminate_triix(var(), key(), row(), tri_eq_system()) -> tri_eq_system().
eliminate_triix(Var, Key, Row, TIX0) ->
  Get = fun triix_get/2,
  Lookup = fun triix_lookup/2,
  Put = fun triix_put/3,
  eliminate_abstr(Var, Key, Row, TIX0, Get, Lookup, Put).
204
%% Elimination implemented once for two data types, eqs and triix; the three
%% funs supply the get, lookup and put operations of the data type in use.
%% For every row that mentions Var, except the one at Key, Row multiplied by
%% minus that row's Var coefficient is added, which cancels Var there.
%% Afterwards Key must be the only equation still mentioning Var.
-compile({inline, eliminate_abstr/7}).
-spec eliminate_abstr(var(), key(), row(), ADT, fun((key(), ADT) -> row()),
                      fun((var(), ADT) -> [key()]),
                      fun((key(), row(), ADT) -> ADT)) -> ADT.
eliminate_abstr(Var, Key, Row, ADT0, GetFun, LookupFun, PutFun) ->
  ?ASSERT(1.0 =:= row_get(Var, Row)),
  OtherKeys = [RK || RK <- LookupFun(Var, ADT0), RK =/= Key],
  Cancel =
    fun(RK, Acc) ->
        R = GetFun(RK, Acc),
        PutFun(RK, row_addmul(R, Row, -row_get(Var, R)), Acc)
    end,
  ADT = lists:foldl(Cancel, ADT0, OtherKeys),
  [Key] = LookupFun(Var, ADT), % assertion
  ADT.
220
%% Solves the equation system. Returns {ok, Solution} with one {Var, Value}
%% pair per entry of the triangelised association list, or `error' if
%% triangelisation threw `error'.
-spec solve(eq_system(), eq_assoc()) -> error | {ok, solution()}.
solve(EQS0, EqAssoc0) ->
  try triangelise(EQS0, EqAssoc0) of
    {EQS1, EqAssoc} ->
      VarEqs = maps:from_list(EqAssoc),
      {ok, solve_1(EqAssoc, VarEqs, EQS1, [])}
  catch
    throw:error -> error
  end.
228
%% Processes the {Var, Key} pairs in order. For each pair, all other variables
%% are removed from the row at Key, the row is normalised on Var, Var is
%% eliminated from the remaining equations, and the row is dropped from the
%% system. The row's constant becomes the value assigned to Var.
solve_1([], _VarEqs, _EQS, Acc) ->
  Acc;
solve_1([{V, K} | Rest], VarEqs, EQS0, Acc) ->
  Row0 = eqs_get(K, EQS0),
  Others = [Var || {Var, _} <- row_coefs(Row0), Var =/= V],
  Row1 = kill_vars(Others, VarEqs, EQS0, Row0),
  [{V, _}] = row_coefs(Row1), % assertion
  Row = row_normalise(V, Row1),
  [{V, 1.0}] = row_coefs(Row), % assertion
  EQS1 = eliminate(V, K, Row, EQS0),
  [K] = eqs_lookup(V, EQS1),
  EQS = eqs_remove(K, EQS1),
  solve_1(Rest, VarEqs, EQS, [{V, row_const(Row)} | Acc]).
240
%% Removes each variable in Vars from Row0. For a variable V, the equation
%% that VarEqs associates with V is fetched from EQS, normalised on V, and
%% subtracted from the row, scaled by the row's V coefficient.
kill_vars(Vars, VarEqs, EQS, Row0) ->
  Kill =
    fun(V, RowAcc) ->
        VRow = row_normalise(V, eqs_get(maps:get(V, VarEqs), EQS)),
        ?ASSERT(1.0 =:= row_get(V, VRow)),
        Killed = row_addmul(RowAcc, VRow, -row_get(V, RowAcc)),
        ?ASSERT(0.0 =:= row_get(V, Killed)), % V has been killed
        Killed
    end,
  lists:foldl(Kill, Row0, Vars).
249
%% Returns the weight recorded for basic block Lbl in Weights.
%% Crashes with {badkey, Lbl} if Lbl has no entry.
-spec weight(label(), bb_weights()) -> float().
weight(Lbl, Weights) -> maps:get(Lbl, Weights).
253
254%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
255%% Row datatype
%% Invariant: No 0.0 coefficients!
%% The row with no coefficients and constant 0.0.
-spec row_empty() -> row().
row_empty() ->
  NoCoefs = orddict:new(),
  {NoCoefs, 0.0}.
259
%% Creates a row from a list of {Var, Coef} pairs, in any order, and a
%% constant. Pairs for the same variable are summed, and coefficients equal
%% to 0.0 are dropped so that the row invariant holds.
-spec row_new([{var(), float()}], float()) -> row().
row_new(Coefs, Const) when is_float(Const) ->
  Sorted = lists:keysort(1, Coefs),
  Merged = row_squash_multiples(Sorted),
  row_ensure_invar({Merged, Const}).
263
%% Merges adjacent pairs that have the same key by summing their values.
%% The input is expected to be sorted on the key so that duplicates are
%% adjacent. The relative order of the remaining pairs is kept.
row_squash_multiples(Pairs) ->
  row_squash_multiples(Pairs, []).

row_squash_multiples([{K, C1}, {K, C2} | Ps], Acc) ->
  row_squash_multiples([{K, C1 + C2} | Ps], Acc);
row_squash_multiples([P | Ps], Acc) ->
  row_squash_multiples(Ps, [P | Acc]);
row_squash_multiples([], Acc) ->
  lists:reverse(Acc).
268
%% Enforces the row invariant by dropping every coefficient that is 0.0.
%% Crashes if a coefficient is not a float.
row_ensure_invar({Coefs, Const}) ->
  IsNonZero =
    fun(_Var, 0.0) -> false;
       (_Var, C) when is_float(C) -> true
    end,
  {orddict:filter(IsNonZero, Coefs), Const}.
272
%% Row accessors: the constant, the list of {Var, Coef} pairs, and the number
%% of coefficients of a row.
row_const({_Coefs, Const}) -> Const.
row_coefs({Coefs, _Const}) -> orddict:to_list(Coefs).
row_size({Coefs, _Const}) -> orddict:size(Coefs).
276
%% Returns the coefficient of Var in the row, or 0.0 if Var does not occur.
row_get(Var, {Coefs, _Const}) ->
  case lists:keyfind(Var, 1, Coefs) of
    {_, Coef} -> Coef;
    false -> 0.0
  end.
282
%% Sets the coefficient of Var. A coefficient of 0.0 removes Var from the row
%% so that no 0.0 coefficient is ever stored.
row_set_coef(Var, Coef, {Coefs0, Const}) ->
  Coefs =
    case Coef of
      0.0 -> orddict:erase(Var, Coefs0);
      _ -> orddict:store(Var, Coef, Coefs0)
    end,
  {Coefs, Const}.

%% Replaces the constant of the row, leaving the coefficients untouched.
row_set_const(Const, {Coefs, _OldConst}) ->
  {Coefs, Const}.
289
%% Computes Lhs + Rhs*Factor, for the coefficients as well as the constant.
-spec row_addmul(row(), row(), float()) -> row().
row_addmul({LCoefs, LConst}, {RCoefs, RConst}, Factor)
  when is_float(Factor) ->
  Coefs = row_addmul_coefs(LCoefs, RCoefs, Factor),
  {Coefs, LConst + RConst * Factor}.
297
%% Merges two coefficient lists, each sorted on the variable, into the
%% coefficient list of Lhs + Rhs*Factor. Variables that occur on one side
%% only are copied (Lhs) or scaled (Rhs); for variables on both sides the
%% coefficients are combined, and the term is dropped when the result is 0.0.
row_addmul_coefs(Lhs, [], Factor) when is_float(Factor) ->
  Lhs;
row_addmul_coefs([], Rhs, Factor) when is_float(Factor) ->
  row_scale_coefs(Rhs, Factor);
row_addmul_coefs([{LVar, _} = LPair | LRest], [{RVar, _} | _] = Rhs, Factor)
  when LVar < RVar, is_float(Factor) ->
  [LPair | row_addmul_coefs(LRest, Rhs, Factor)];
row_addmul_coefs([{LVar, _} | _] = Lhs, [{RVar, RCoef} | RRest], Factor)
  when LVar > RVar, is_float(RCoef), is_float(Factor) ->
  [{RVar, RCoef * Factor} | row_addmul_coefs(Lhs, RRest, Factor)];
row_addmul_coefs([{Var, LCoef} | LRest], [{Var, RCoef} | RRest], Factor)
  when is_float(LCoef), is_float(RCoef), is_float(Factor) ->
  Sum = LCoef + RCoef * Factor,
  Rest = row_addmul_coefs(LRest, RRest, Factor),
  case Sum of
    0.0 -> Rest;
    _ -> [{Var, Sum} | Rest]
  end.
313
%% Multiplies every coefficient and the constant of a row by Factor.
%% Scaling by 0.0 yields the empty row directly, so that no 0.0 coefficients
%% are produced.
row_scale(_Row, 0.0) ->
  row_empty();
row_scale({Coefs, Const}, Factor) when is_float(Factor) ->
  Scaled = row_scale_coefs(Coefs, Factor),
  {Scaled, Const * Factor}.
317
%% Multiplies the coefficient of every {Var, Coef} pair by Factor.
%% Both Factor and all coefficients must be floats.
row_scale_coefs(Coefs, Factor) when is_float(Factor) ->
  Scale = fun({V, C}) when is_float(C) -> {V, C * Factor} end,
  lists:map(Scale, Coefs).
322
323%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
324%% Equation system ADT
325%%
326%% Stores a linear equation system, allowing for efficient updates and efficient
327%% queries for all equations mentioning a variable.
328%%
329%% It is sort of like a "database" table of {Primary, Terms, Const} indexed both
330%% on Primary as well as the vars (map keys) in Terms.
331-type row()       :: {Terms :: orddict:orddict(var(), float()),
332		      Const :: float()}.
333-type key()       :: non_neg_integer().
334-type rev_index() :: #{var() => ordsets:ordset(key())}.
335-record(eq_system, {
336	  rows = #{}              :: #{key() => row()},
337	  revidx = revidx_empty() :: rev_index(),
338	  next_key = 0            :: key()
339	 }).
340-type eq_system() :: #eq_system{}.
341
%% Creates an empty equation system; all fields take their record defaults.
eqs_new() ->
  #eq_system{}.
343
%% Adds Row to the system under a freshly allocated key and returns that key
%% together with the updated system.
-spec eqs_insert(row(), eq_system()) -> {key(), eq_system()}.
eqs_insert(Row, EQS0 = #eq_system{next_key = Key}) ->
  EQS = EQS0#eq_system{next_key = Key + 1},
  {Key, eqs_insert(Key, Row, EQS)}.

%% Stores Row under the given Key and registers Key in the reverse index for
%% every variable of the row.
eqs_insert(Key, Row, EQS = #eq_system{rows = Rows0, revidx = RevIdx0}) ->
  RevIdx = revidx_add(Key, Row, RevIdx0),
  Rows = maps:put(Key, Row, Rows0),
  EQS#eq_system{rows = Rows, revidx = RevIdx}.
353
%% Replaces the row stored under Key with Row. The old row is removed first so
%% that the reverse index no longer refers to variables only the old row had.
%% Key must already be present in the system.
eqs_put(Key, Row, EQS0) ->
  EQS1 = eqs_remove(Key, EQS0),
  eqs_insert(Key, Row, EQS1).
356
%% Deletes the row stored under Key, both from the row table and from the
%% reverse index. Crashes with {badkey, Key} if there is no such row.
eqs_remove(Key, EQS = #eq_system{rows = Rows0, revidx = RevIdx0}) ->
  Row = maps:get(Key, Rows0),
  RevIdx = revidx_remove(Key, Row, RevIdx0),
  Rows = maps:remove(Key, Rows0),
  EQS#eq_system{rows = Rows, revidx = RevIdx}.
361
%% Returns the row stored under Key; crashes with {badkey, Key} if absent.
-spec eqs_get(key(), eq_system()) -> row().
eqs_get(Key, #eq_system{rows = Rows}) ->
  maps:get(Key, Rows).
364
%% Keys of all equations containing a nonzero coefficient for Var.
%% Crashes with {badkey, Var} if the reverse index has no entry for Var.
-spec eqs_lookup(var(), eq_system()) -> ordsets:ordset(key()).
eqs_lookup(Var, #eq_system{revidx = RevIdx}) ->
  maps:get(Var, RevIdx).
368
369%% eqs_rows(#eq_system{rows=Rows}) -> maps:to_list(Rows).
370
371%% eqs_print(EQS) ->
372%%   lists:foreach(fun({_, Row}) ->
373%% 		    row_print(Row)
374%% 		end, lists:sort(eqs_rows(EQS))).
375
376%% row_print(Row) ->
377%%   CoefStrs = [io_lib:format("~wl~w", [Coef, Var])
378%% 	      || {Var, Coef} <- row_coefs(Row)],
379%%   CoefStr = lists:join(" + ", CoefStrs),
380%%   io:format("~w = ~s~n", [row_const(Row), CoefStr]).
381
%% The reverse index with no entries.
revidx_empty() ->
  maps:new().
383
%% Registers Key in the reverse index under every variable that occurs in Row.
-spec revidx_add(key(), row(), rev_index()) -> rev_index().
revidx_add(Key, Row, RevIdx0) ->
  AddKey =
    fun(Var, _Coef, Idx) ->
        ?ASSERT(_Coef /= 0.0),
        Keys = maps:get(Var, Idx, ordsets:new()),
        Idx#{Var => ordsets:add_element(Key, Keys)}
    end,
  orddict:fold(AddKey, RevIdx0, row_coefs(Row)).
391
%% Unregisters Key from the reverse index for every variable that occurs in
%% the row. A variable whose key set becomes empty is removed from the index
%% altogether. Every variable of the row must have an entry in the index.
-spec revidx_remove(key(), row(), rev_index()) -> rev_index().
revidx_remove(Key, {Coefs, _Const}, RevIdx0) ->
  DropKey =
    fun(Var, _Coef, Idx) ->
        case Idx of
          #{Var := Keys0} ->
            case ordsets:del_element(Key, Keys0) of
              [] -> maps:remove(Var, Idx);
              Keys -> Idx#{Var := Keys}
            end
        end
    end,
  orddict:fold(DropKey, RevIdx0, Coefs).
403
404%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
405-define(FAST_ITERATIONS, 5).
406
%% @doc Computes a rough approximation of BB weights by a fixed number
%% (?FAST_ITERATIONS) of passes over the blocks in reverse postorder. The
%% approximation is particularly poor (converges slowly) for recursive
%% functions and loops.
-spec compute_fast(cfg(), target_module(), target_context()) -> bb_weights().
compute_fast(CFG, TgtMod, TgtCtx) ->
  Target = {TgtMod, TgtCtx},
  StartLb = hipe_gen_cfg:start_label(CFG),
  RPO = reverse_postorder(CFG, Target),
  %% The start label is left out, so its weight is never recomputed.
  PredProbs = [{L, pred_prob(L, CFG, Target)} || L <- RPO, L =/= StartLb],
  Zeroes = maps:from_list([{L, 0.0} || L <- RPO]),
  Initial = Zeroes#{StartLb := 1.0},
  fast_iterate(?FAST_ITERATIONS, PredProbs, Initial).
417
%% Applies fast_one/2 to the weights Iters times.
fast_iterate(0, _Pred, Probs) ->
  Probs;
fast_iterate(Iters, Pred, Probs0) ->
  Remaining = Iters - 1,
  Probs = fast_one(Pred, Probs0),
  fast_iterate(Remaining, Pred, Probs).
422
%% One approximation pass. In list order, the weight of each label is replaced
%% by the probability-weighted sum of the current weights of its
%% predecessors. Weights updated earlier in the pass are visible to later
%% labels.
fast_one(PredProbs, Probs0) ->
  Update =
    fun({L, Pred}, Probs) ->
        Probs#{L => fast_sum(Pred, Probs, 0.0)}
    end,
  lists:foldl(Update, Probs0, PredProbs).
429
%% Adds to Acc, for every {Pred, EdgeWeight} pair, the weight of Pred in Probs
%% multiplied by EdgeWeight. Every Pred must have a float weight in Probs.
fast_sum([], _Probs, Acc) when is_float(Acc) ->
  Acc;
fast_sum([{P, EWt} | Rest], Probs, Acc) when is_float(EWt), is_float(Acc) ->
  case Probs of
    #{P := PWt} when is_float(PWt) ->
      fast_sum(Rest, Probs, Acc + PWt * EWt)
  end.
437
438%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
439%% Target module interface functions
440%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Each wrapper takes the target as a {Module, Context} pair in its last
%% argument and calls the function of the same name in Module, passing
%% Context as the last argument.
bb(A1, A2, {M, C}) -> M:bb(A1, A2, C).

branch_preds(A1, {M, C}) -> M:branch_preds(A1, C).

labels(A1, {M, C}) -> M:labels(A1, C).

reverse_postorder(A1, {M, C}) -> M:reverse_postorder(A1, C).
450