1%---------------------------------------------------------------------------%
2% vim: ft=mercury ts=4 sw=4 et
3%---------------------------------------------------------------------------%
4% Copyright (C) 1999-2000,2002-2007, 2009-2012 The University of Melbourne.
5% Copyright (C) 2015 The Mercury team.
6% This file may only be copied under the terms of the GNU General
7% Public License - see the file COPYING in the Mercury distribution.
8%---------------------------------------------------------------------------%
9%
10% Module: accumulator.m.
11% Main authors: petdr.
12%
% Attempts to transform a single proc to a tail recursive form by
% introducing accumulators. The algorithm can do this if the code after
% the recursive call has either the order-independent state update
% property or the associativity property.
17%
18% /* Order independent State update property */
19% :- promise all [A,B,S0,S]
20%   (
21%       (some[SA] (update(A, S0, SA), update(B, SA, S)))
22%   <=>
23%       (some[SB] (update(B, S0, SB), update(A, SB, S)))
24%   ).
25%
26% /* Associativity property */
27% :- promise all [A,B,C,ABC]
28%   (
29%       (some[AB] (assoc(A, B, AB), assoc(AB, C, ABC)))
30%   <=>
31%       (some[BC] (assoc(B, C, BC), assoc(A, BC, ABC)))
32%   ).
33%
34% XXX What about exceptions and non-termination?
35%
36% The promise declarations above only provide promises about the declarative
37% semantics, but in order to apply this optimization, we ought to check that
38% it will preserve the operational semantics (modulo whatever changes are
39% allowed by the language semantics options).
40%
41% Currently we check and respect the --fully-strict option, but not the
42% --no-reorder-conj option. XXX we should check --no-reorder-conj!
43% If --no-reorder-conj was set, it would still be OK to apply this
44% transformation, but ONLY in cases where the goals which get reordered
45% are guaranteed not to throw any exceptions.
46%
47% The algorithm implemented is a combination of the algorithms from
48% "Making Mercury Programs Tail Recursive" and
49% "State Update Transformation", which can be found at
50% <http://www.cs.mu.oz.au/research/mercury/information/papers.html>.
51%
% Note that currently the "State Update Transformation" paper resides only
% in the CVS papers archive, in the directory `update', but it has been
% submitted to PPDP '00.
55%
56% The transformation recognises predicates in the form
57%
58% p(In, OutUpdate, OutAssoc) :-
59%   minimal(In),
60%   initialize(OutUpdate),
61%   base(OutAssoc).
62% p(In, OutUpdate, OutAssoc) :-
63%   decompose(In, Current, Rest),
64%   p(Rest, OutUpdate0, OutAssoc0),
65%   update(Current, OutUpdate0, OutUpdate),
66%   assoc(Current, OutAssoc0, OutAssoc).
67%
68% which can be transformed by the algorithm in "State Update Transformation" to
69%
70% p(In, OutUpdate, OutAssoc) :-
71%   initialize(AccUpdate),
72%   p_acc(In, OutUpdate, OutAssoc, AccUpdate).
73%
74% p_acc(In, OutUpdate, OutAssoc, AccUpdate) :-
75%   minimal(In),
76%   base(OutAssoc),
77%   OutUpdate = AccUpdate.
78% p_acc(In, OutUpdate, OutAssoc, AccUpdate0) :-
79%   decompose(In, Current, Rest),
80%   update(Current, AccUpdate0, AccUpdate),
81%   p_acc(Rest, OutUpdate, OutAssoc0, AccUpdate),
82%   assoc(Current, OutAssoc0, OutAssoc).
83%
84% we then apply the algorithm from "Making Mercury Programs Tail Recursive"
85% to p_acc to obtain
86%
87% p_acc(In, OutUpdate, OutAssoc, AccUpdate) :-
88%   minimal(In),
89%   base(OutAssoc),
90%   OutUpdate = AccUpdate.
91% p_acc(In, OutUpdate, OutAssoc, AccUpdate0) :-
92%   decompose(In, Current, Rest),
93%   update(Current, AccUpdate0, AccUpdate),
94%   p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, Current).
95%
96% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :-
97%   minimal(In),
98%   base(Base),
99%   assoc(AccAssoc0, Base, OutAssoc),
100%   OutUpdate = AccUpdate0.
101% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :-
102%   decompose(In, Current, Rest),
103%   update(Current, AccUpdate0, AccUpdate),
104%   assoc(AccAssoc0, Current, AccAssoc),
105%   p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, AccAssoc).
106%
107% p_acc is no longer recursive and is only ever called from p, so we
108% inline p_acc into p to obtain the final schema.
109%
110% p(In, OutUpdate, OutAssoc) :-
111%   minimal(In),
112%   base(OutAssoc),
113%   initialize(AccUpdate),
114%   OutUpdate = AccUpdate.
115% p(In, OutUpdate, OutAssoc) :-
116%   decompose(In, Current, Rest),
117%   initialize(AccUpdate0),
118%   update(Current, AccUpdate0, AccUpdate),
119%   p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, Current).
120%
121% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :-
122%   minimal(In),
123%   base(Base),
124%   assoc(AccAssoc0, Base, OutAssoc),
125%   OutUpdate = AccUpdate0.
126% p_acc2(In, OutUpdate, OutAssoc, AccUpdate0, AccAssoc0) :-
127%   decompose(In, Current, Rest),
128%   update(Current, AccUpdate0, AccUpdate),
129%   assoc(AccAssoc0, Current, AccAssoc),
130%   p_acc2(Rest, OutUpdate, OutAssoc, AccUpdate, AccAssoc).
131%
132% The only real difficulty in this new transformation is identifying the
133% initialize/1 and base/1 goals from the original base case.
134%
135% Note that if the recursive clause contains multiple calls to p, the
136% transformation attempts to move each recursive call to the end
137% until one succeeds. This makes the order of independent recursive
138% calls in the body irrelevant.
139%
% XXX Replace calls to can_reorder_goals with calls to the version that
% uses the intermodule-analysis framework.
142%
143%---------------------------------------------------------------------------%
144
145:- module transform_hlds.accumulator.
146:- interface.
147
148:- import_module hlds.
149:- import_module hlds.hlds_module.
150:- import_module hlds.hlds_pred.
151
152:- import_module univ.
153
    % Attempt to transform a procedure into accumulator recursive form.
    % If we succeed, the procedure's proc_info is replaced by one that
    % calls the newly created accumulator version of the procedure, and
    % that new version is added to the module_info. If we fail, the
    % proc_info and module_info are returned unchanged.
    %
    % The univ "cookie" must contain a list(error_spec). When the
    % transformation succeeds but had to swap the arguments of calls to
    % associative predicates, we prepend one warning describing those
    % swaps to that list.
    %
:- pred accu_transform_proc(pred_proc_id::in, pred_info::in,
    proc_info::in, proc_info::out, module_info::in, module_info::out,
    univ::in, univ::out) is det.
162
163%---------------------------------------------------------------------------%
164%---------------------------------------------------------------------------%
165
166:- implementation.
167
168:- import_module hlds.assertion.
169:- import_module hlds.goal_util.
170:- import_module hlds.hlds_error_util.
171:- import_module hlds.hlds_goal.
172:- import_module hlds.hlds_out.
173:- import_module hlds.hlds_out.hlds_out_util.
174:- import_module hlds.hlds_promise.
175:- import_module hlds.instmap.
176:- import_module hlds.pred_table.
177:- import_module hlds.quantification.
178:- import_module hlds.status.
179:- import_module hlds.vartypes.
180:- import_module libs.
181:- import_module libs.globals.
182:- import_module libs.optimization_options.
183:- import_module libs.options.
184:- import_module mdbcomp.
185:- import_module mdbcomp.sym_name.
186:- import_module parse_tree.
187:- import_module parse_tree.error_util.
188:- import_module parse_tree.prog_data.
189:- import_module parse_tree.prog_mode.
190:- import_module parse_tree.prog_util.
191:- import_module parse_tree.set_of_var.
192:- import_module transform_hlds.goal_store.
193
194:- import_module assoc_list.
195:- import_module bool.
196:- import_module int.
197:- import_module io.
198:- import_module list.
199:- import_module map.
200:- import_module maybe.
201:- import_module pair.
202:- import_module require.
203:- import_module set.
204:- import_module solutions.
205:- import_module string.
206:- import_module term.
207:- import_module varset.
208
209%---------------------------------------------------------------------------%
210
    % The form of the goal around the base and recursive cases.
    %
    % The prefix names the kind of compound goal (a two-arm switch,
    % a two-disjunct disjunction, or an if-then-else). The suffix gives
    % the order in which the two cases appear in it. For example,
    % switch_rec_base means that the first switch arm is the recursive
    % case and the second is the base case, while ite_base_rec means
    % that the then-part is the base case and the else-part is the
    % recursive case.
    %
:- type top_level
    --->    switch_base_rec
    ;       switch_rec_base
    ;       disj_base_rec
    ;       disj_rec_base
    ;       ite_base_rec
    ;       ite_rec_base.
220
    % An accu_goal_id represents a goal. The first field says which conjunction
    % the goal came from (the base case or the recursive case), and the second
    % gives the location of the goal in that conjunction. Locations are
    % 1-based: initialize_goal_store numbers the first conjunct 1.
    %
:- type accu_goal_id
    --->    accu_goal_id(accu_case, int).

    % Which of the two top-level cases a goal belongs to.
    %
:- type accu_case
    --->    accu_base
    ;       accu_rec.
231
    % The goal_store associates a goal with each goal_id. As filled in
    % by accu_store, each entry pairs the goal with the instmap that
    % holds immediately before that goal.
    %
:- type accu_goal_store == goal_store(accu_goal_id).

    % A substitution from the first variable name to the second.
    % It is used, for example, to map head variables to the corresponding
    % arguments of the recursive call, and vice versa.
    %
:- type accu_subst == map(prog_var, prog_var).
239
    % Warn that two prog_vars in a call to pred_id at the given context
    % were swapped, which may cause an efficiency problem.
    % generate_warning turns each of these into one error_msg.
    %
:- type accu_warning
    --->    accu_warn(prog_context, pred_id, prog_var, prog_var).
            % Warn that two prog_vars in a call to pred_id at the given context
            % were swapped, which may cause an efficiency problem.
244
245%---------------------------------------------------------------------------%
246
accu_transform_proc(proc(PredId, ProcId), PredInfo, !ProcInfo, !ModuleInfo,
        !Cookie) :-
    module_info_get_globals(!.ModuleInfo, Globals),
    globals.lookup_bool_option(Globals, fully_strict, FullyStrict),
    globals.get_opt_tuple(Globals, OptTuple),
    DoLCMC = OptTuple ^ ot_opt_lcmc_accumulator,
    ( if
        should_attempt_accu_transform(!ModuleInfo, PredId, ProcId, PredInfo,
            !ProcInfo, FullyStrict, DoLCMC, Warnings)
    then
        % The transformation was applied; say so if we are very verbose.
        globals.lookup_bool_option(Globals, very_verbose, VeryVerbose),
        ( if VeryVerbose = yes then
            trace [io(!IO)] (
                module_info_get_name(!.ModuleInfo, ModuleName),
                get_progress_output_stream(Globals, ModuleName,
                    ProgressStream, !IO),
                PredStr = pred_id_to_string(!.ModuleInfo, PredId),
                io.format(ProgressStream,
                    "%% Accumulators introduced into %s\n", [s(PredStr)], !IO)
            )
        else
            true
        ),

        % If applying the transformation required swapping the arguments
        % of any calls, add one warning listing all those swaps to the
        % list(error_spec) carried in the cookie.
        ( if Warnings = [_ | _] then
            accu_swap_warning_msgs(!.ModuleInfo, PredInfo, PredId,
                !.ProcInfo, Warnings, Msgs),
            Severity = severity_conditional(warn_accumulator_swaps, yes,
                severity_warning, no),
            Spec = error_spec($pred, Severity, phase_accumulator_intro, Msgs),
            det_univ_to_type(!.Cookie, Specs0),
            type_to_univ([Spec | Specs0], !:Cookie)
        else
            true
        )
    else
        true
    ).

    % accu_swap_warning_msgs(ModuleInfo, PredInfo, PredId, ProcInfo,
    %   Warnings, Msgs):
    %
    % Build the messages of the warning that reports the argument swaps
    % in Warnings, which the caller guarantees is non-empty. Msgs consists
    % of a line naming the predicate, one message per swap, and a closing
    % message that asks the user to check performance and explains how to
    % suppress the warning.
    %
:- pred accu_swap_warning_msgs(module_info::in, pred_info::in, pred_id::in,
    proc_info::in, list(accu_warning)::in, list(error_msg)::out) is det.

accu_swap_warning_msgs(ModuleInfo, PredInfo, PredId, ProcInfo, Warnings,
        Msgs) :-
    pred_info_get_context(PredInfo, Context),
    PredPieces = describe_one_pred_name(ModuleInfo,
        should_module_qualify, PredId),
    InPieces = [words("In") | PredPieces] ++ [suffix(":"), nl],
    InMsg = simple_msg(Context,
        [option_is_set(warn_accumulator_swaps, yes, [always(InPieces)])]),

    proc_info_get_varset(ProcInfo, VarSet),
    generate_warnings(ModuleInfo, VarSet, Warnings, WarnMsgs),

    ( if Warnings = [_] then
        EnsurePieces = [words("Please ensure that this"),
            words("argument rearrangement does not introduce"),
            words("performance problems.")]
    else
        EnsurePieces = [words("Please ensure that these"),
            words("argument rearrangements do not introduce"),
            words("performance problems.")]
    ),
    SuppressPieces = [words("These warnings can be suppressed by"),
        quote("--no-warn-accumulator-swaps"), suffix(".")],
    VerbosePieces = [words("If a predicate has been declared"),
        words("associative"),
        words("via a"), quote("promise"), words("declaration,"),
        words("the compiler will rearrange the order of"),
        words("the arguments in calls to that predicate,"),
        words("if by so doing it makes the containing predicate"),
        words("tail recursive. In such situations, the compiler"),
        words("will issue this warning. If this reordering"),
        words("changes the performance characteristics"),
        words("of the call to the predicate, use"),
        quote("--no-accumulator-introduction"),
        words("to turn the optimization off, or "),
        quote("--no-warn-accumulator-swaps"),
        words("to turn off the warnings.")],
    EnsureSuppressMsg = simple_msg(Context,
        [option_is_set(warn_accumulator_swaps, yes,
            [always(EnsurePieces), always(SuppressPieces)]),
        verbose_only(verbose_once, VerbosePieces)]),

    Msgs = [InMsg | WarnMsgs] ++ [EnsureSuppressMsg].
330
331%---------------------------------------------------------------------------%
332%---------------------------------------------------------------------------%
333
334:- pred generate_warnings(module_info::in, prog_varset::in,
335    list(accu_warning)::in, list(error_msg)::out) is det.
336
337generate_warnings(_, _, [], []).
338generate_warnings(ModuleInfo, VarSet, [Warning | Warnings], [Msg | Msgs]) :-
339    generate_warning(ModuleInfo, VarSet, Warning, Msg),
340    generate_warnings(ModuleInfo, VarSet, Warnings, Msgs).
341
342:- pred generate_warning(module_info::in, prog_varset::in, accu_warning::in,
343    error_msg::out) is det.
344
345generate_warning(ModuleInfo, VarSet, Warning, Msg) :-
346    Warning = accu_warn(Context, PredId, VarA, VarB),
347    PredPieces = describe_one_pred_name(ModuleInfo, should_module_qualify,
348        PredId),
349
350    varset.lookup_name(VarSet, VarA, VarAName),
351    varset.lookup_name(VarSet, VarB, VarBName),
352
353    Pieces = [words("warning: the call to")] ++ PredPieces ++
354        [words("has had the location of the variables"),
355        quote(VarAName), words("and"), quote(VarBName),
356        words("swapped to allow accumulator introduction."), nl],
357    Msg = simplest_msg(Context, Pieces).
358
359%---------------------------------------------------------------------------%
360%---------------------------------------------------------------------------%
361
362    % should_attempt_accu_transform is only true iff the current proc
363    % has been transformed to call the newly created accumulator proc.
364    %
365:- pred should_attempt_accu_transform(module_info::in, module_info::out,
366    pred_id::in, proc_id::in, pred_info::in, proc_info::in, proc_info::out,
367    bool::in, maybe_opt_lcmc_accumulator::in,
368    list(accu_warning)::out) is semidet.
369
370should_attempt_accu_transform(!ModuleInfo, PredId, ProcId, PredInfo,
371        !ProcInfo, FullyStrict, DoLCMC, Warnings) :-
372    proc_info_get_goal(!.ProcInfo, Goal0),
373    proc_info_get_headvars(!.ProcInfo, HeadVars),
374    proc_info_get_initial_instmap(!.ModuleInfo, !.ProcInfo, InitialInstMap),
375    accu_standardize(Goal0, Goal),
376    identify_goal_type(PredId, ProcId, Goal, InitialInstMap,
377        TopLevel, Base, BaseInstMap, Rec, RecInstMap),
378
379    C = initialize_goal_store(Rec, RecInstMap, Base, BaseInstMap),
380    identify_recursive_calls(PredId, ProcId, C, RecCallIds),
381    list.length(Rec, M),
382
383    should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo, !ProcInfo,
384        HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC,
385        RecCallIds, C, M, Rec, Warnings).
386
387    % should_attempt_accu_transform_2 takes a list of locations of the
388    % recursive calls, and attempts to introduce accumulator into each of the
389    % recursive calls, stopping at the first one that succeeds.
390    % This catches the following case, as selecting the first recursive call
391    % allows the second recursive call to be moved before it, and
392    % OutA is in the correct spot in list.append.
393    %
394    %   p(InA, OutA),
395    %   p(InB, OutB),
396    %   list.append(OutB, OutA, Out)
397    %
398:- pred should_attempt_accu_transform_2(module_info::in, module_info::out,
399    pred_id::in, pred_info::in, proc_info::in, proc_info::out,
400    list(prog_var)::in, instmap::in, top_level::in, bool::in,
401    maybe_opt_lcmc_accumulator::in,
402    list(accu_goal_id)::in, accu_goal_store::in, int::in, list(hlds_goal)::in,
403    list(accu_warning)::out) is semidet.
404
405should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo, !ProcInfo,
406        HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC,
407        [Id | Ids], C, M, Rec, Warnings) :-
408    proc_info_get_vartypes(!.ProcInfo, VarTypes0),
409    identify_out_and_out_prime(!.ModuleInfo, VarTypes0, InitialInstMap,
410        Id, Rec, HeadVars, Out, OutPrime, HeadToCallSubst, CallToHeadSubst),
411    ( if
412        accu_stage1(!.ModuleInfo, VarTypes0, FullyStrict, DoLCMC, Id, M, C,
413            Sets),
414        accu_stage2(!.ModuleInfo, !.ProcInfo, Id, C, Sets, OutPrime, Out,
415            VarSet, VarTypes, Accs, BaseCase, BasePairs, Substs, CS,
416            WarningsPrime),
417        accu_stage3(Id, Accs, VarSet, VarTypes, C, CS, Substs,
418            HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, Out,
419            TopLevel, PredId, PredInfo, !ProcInfo, !ModuleInfo)
420    then
421        Warnings = WarningsPrime
422    else
423        should_attempt_accu_transform_2(!ModuleInfo, PredId, PredInfo,
424            !ProcInfo, HeadVars, InitialInstMap, TopLevel, FullyStrict, DoLCMC,
425            Ids, C, M, Rec, Warnings)
426    ).
427
428%---------------------------------------------------------------------------%
429%---------------------------------------------------------------------------%
430
431    % Transform the goal into a standard form that is amenable to
432    % introducing accumulators.
433    %
434    % At the moment all this does is remove any extra disj/conj wrappers
435    % around the top level goal.
436    %
437    % Future work is for this code to rearrange code with multiple base
438    % and recursive cases into a single base and recursive case.
439    %
440:- pred accu_standardize(hlds_goal::in, hlds_goal::out) is det.
441
442accu_standardize(Goal0, Goal) :-
443    ( if
444        Goal0 = hlds_goal(GoalExpr0, _),
445        (
446            GoalExpr0 = conj(plain_conj, [Goal1])
447        ;
448            GoalExpr0 = disj([Goal1])
449        )
450    then
451        accu_standardize(Goal1, Goal)
452    else
453        Goal = Goal0
454    ).
455
456%---------------------------------------------------------------------------%
457%---------------------------------------------------------------------------%
458
459    % This predicate takes the original goal and identifies the `shape'
460    % of the goal around the recursive and base cases.
461    %
462    % Note that the base case can contain a recursive call, as the
463    % transformation doesn't depend on what is in the base case.
464    %
465:- pred identify_goal_type(pred_id::in, proc_id::in, hlds_goal::in,
466    instmap::in, top_level::out, list(hlds_goal)::out, instmap::out,
467    list(hlds_goal)::out, instmap::out) is semidet.
468
469identify_goal_type(PredId, ProcId, Goal, InitialInstMap, Type,
470        Base, BaseInstMap, Rec, RecInstMap) :-
471    Goal = hlds_goal(GoalExpr, _GoalInfo),
472    (
473        GoalExpr = switch(_Var, _CanFail, Cases),
474        ( if
475            Cases = [case(_IdA, [], GoalA), case(_IdB, [], GoalB)],
476            goal_to_conj_list(GoalA, GoalAList),
477            goal_to_conj_list(GoalB, GoalBList)
478        then
479            ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then
480                Type = switch_rec_base,
481                Base = GoalBList,
482                Rec = GoalAList
483            else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then
484                Type = switch_base_rec,
485                Base = GoalAList,
486                Rec = GoalBList
487            else
488                fail
489            ),
490            BaseInstMap = InitialInstMap,
491            RecInstMap = InitialInstMap
492        else
493            fail
494        )
495    ;
496        GoalExpr = disj(Goals),
497        ( if
498            Goals = [GoalA, GoalB],
499            goal_to_conj_list(GoalA, GoalAList),
500            goal_to_conj_list(GoalB, GoalBList)
501        then
502            ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then
503                Type = disj_rec_base,
504                Base = GoalBList,
505                Rec = GoalAList
506            else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then
507                Type = disj_base_rec,
508                Base = GoalAList,
509                Rec = GoalBList
510            else
511                fail
512            ),
513            BaseInstMap = InitialInstMap,
514            RecInstMap = InitialInstMap
515        else
516            fail
517        )
518    ;
519        GoalExpr = if_then_else(_Vars, Cond, Then, Else),
520        Cond = hlds_goal(_CondGoalExpr, CondGoalInfo),
521        CondInstMapDelta = goal_info_get_instmap_delta(CondGoalInfo),
522
523        goal_to_conj_list(Then, GoalAList),
524        goal_to_conj_list(Else, GoalBList),
525        ( if is_recursive_case(GoalAList, proc(PredId, ProcId)) then
526            Type = ite_rec_base,
527            Base = GoalBList,
528            Rec = GoalAList,
529
530            BaseInstMap = InitialInstMap,
531            apply_instmap_delta(CondInstMapDelta, InitialInstMap, RecInstMap)
532        else if is_recursive_case(GoalBList, proc(PredId, ProcId)) then
533            Type = ite_base_rec,
534            Base = GoalAList,
535            Rec = GoalBList,
536
537            RecInstMap = InitialInstMap,
538            apply_instmap_delta(CondInstMapDelta, InitialInstMap, BaseInstMap)
539        else
540            fail
541        )
542    ).
543
544    % is_recursive_case(Gs, Id) is true iff the list of goals, Gs,
545    % contains a call to the procedure specified by Id, where the call
546    % is located in a position that can be used by the transformation
547    % (i.e. not hidden in a compound goal).
548    %
549:- pred is_recursive_case(list(hlds_goal)::in, pred_proc_id::in) is semidet.
550
551is_recursive_case(Goals, proc(PredId, ProcId)) :-
552    list.append(_Initial, [RecursiveCall | _Final], Goals),
553    RecursiveCall = hlds_goal(plain_call(PredId, ProcId, _, _, _, _), _).
554
555%---------------------------------------------------------------------------%
556%---------------------------------------------------------------------------%
557
    % The store info is folded over the list of goals which
    % represent the base and recursive case conjunctions.
    %
    % NOTE(review): accu_store (below) threads the location, instmap and
    % goal store as three separate accumulators rather than through this
    % type; confirm whether store_info is still referenced elsewhere.
:- type store_info
    --->    store_info(
                store_loc       :: int,
                                % The location of the goal in the conjunction.
                store_instmap   :: instmap,
                                % The instmap before the goal.
                store_goals     :: accu_goal_store
                                % The goals stored so far.
            ).
567
568    % Initialise the goal_store, which will hold the C_{a,b} goals.
569    %
570:- func initialize_goal_store(list(hlds_goal), instmap,
571    list(hlds_goal), instmap) = accu_goal_store.
572
573initialize_goal_store(Rec, RecInstMap, Base, BaseInstMap) = C :-
574    goal_store_init(C0),
575    list.foldl3(accu_store(accu_rec), Rec,
576        1, _, RecInstMap, _, C0, C1),
577    list.foldl3(accu_store(accu_base), Base,
578        1, _, BaseInstMap, _, C1, C).
579
580:- pred accu_store(accu_case::in, hlds_goal::in,
581    int::in, int::out, instmap::in, instmap::out,
582    accu_goal_store::in, accu_goal_store::out) is det.
583
584accu_store(Case, Goal, !N, !InstMap, !GoalStore) :-
585    Id = accu_goal_id(Case, !.N),
586    goal_store_det_insert(Id, stored_goal(Goal, !.InstMap), !GoalStore),
587
588    !:N = !.N + 1,
589    Goal = hlds_goal(_, GoalInfo),
590    InstMapDelta = goal_info_get_instmap_delta(GoalInfo),
591    apply_instmap_delta(InstMapDelta, !InstMap).
592
593%---------------------------------------------------------------------------%
594%---------------------------------------------------------------------------%
595
596    % Determine the k's which are recursive calls.
597    % Note that this doesn't find recursive calls which are `hidden'
598    % in compound goals, this is not a problem as currently we can't use
599    % these to do transformation.
600    %
601:- pred identify_recursive_calls(pred_id::in, proc_id::in,
602    accu_goal_store::in, list(accu_goal_id)::out) is det.
603
604identify_recursive_calls(PredId, ProcId, GoalStore, Ids) :-
605    P =
606        ( pred(Key::out) is nondet :-
607            goal_store_member(GoalStore, Key, stored_goal(Goal, _InstMap)),
608            Key = accu_goal_id(accu_rec, _),
609            Goal = hlds_goal(plain_call(PredId, ProcId, _, _, _, _), _)
610        ),
611    solutions.solutions(P, Ids).
612
613%---------------------------------------------------------------------------%
614%---------------------------------------------------------------------------%
615
616    % Determine the variables which are members of the sets Out and Out',
617    % and initialize the substitutions between the two sets.
618    %
619    % This is done by identifing those variables whose instantiatedness change
620    % in the goals after the recursive call and are headvars.
621    %
622    % Note that we are only identifying the output variables which will need
623    % to be accumulated, as there may be other output variables which are
624    % produced prior to the recursive call.
625    %
626:- pred identify_out_and_out_prime(module_info::in, vartypes::in, instmap::in,
627    accu_goal_id::in, list(hlds_goal)::in,
628    list(prog_var)::in, list(prog_var)::out, list(prog_var)::out,
629    accu_subst::out, accu_subst::out) is det.
630
631identify_out_and_out_prime(ModuleInfo, VarTypes, InitialInstMap, GoalId,
632        Rec, HeadVars, Out, OutPrime, HeadToCallSubst, CallToHeadSubst) :-
633    GoalId = accu_goal_id(_Case, K),
634    ( if
635        list.take(K, Rec, InitialGoals),
636        list.drop(K-1, Rec, FinalGoals),
637        FinalGoals = [hlds_goal(plain_call(_, _, Args, _, _, _), _) | Rest]
638    then
639        goal_list_instmap_delta(InitialGoals, InitInstMapDelta),
640        apply_instmap_delta( InitInstMapDelta,
641            InitialInstMap, InstMapBeforeRest),
642
643        goal_list_instmap_delta(Rest, InstMapDelta),
644        apply_instmap_delta(InstMapDelta, InstMapBeforeRest, InstMapAfterRest),
645
646        instmap_changed_vars(ModuleInfo, VarTypes,
647            InstMapBeforeRest, InstMapAfterRest, ChangedVars),
648
649        assoc_list.from_corresponding_lists(HeadVars, Args, HeadArg0),
650
651        Member =
652            ( pred(M::in) is semidet :-
653                M = HeadVar - _,
654                set_of_var.member(ChangedVars, HeadVar)
655            ),
656        list.filter(Member, HeadArg0, HeadArg),
657        list.map(fst, HeadArg, Out),
658        list.map(snd, HeadArg, OutPrime),
659
660        map.from_assoc_list(HeadArg, HeadToCallSubst),
661
662        list.map((pred(X-Y::in, Y-X::out) is det), HeadArg, ArgHead),
663        map.from_assoc_list(ArgHead, CallToHeadSubst)
664    else
665        unexpected($pred, "test failed")
666    ).
667
668%---------------------------------------------------------------------------%
669%---------------------------------------------------------------------------%
670
    % For each goal after the recursive call, we place that goal
    % into a set according to what properties that goal has.
    % The properties are decided by the predicates accu_before, accu_assoc,
    % accu_construct, accu_construct_assoc and accu_update. accu_stage1_2
    % tries them in that order and puts each goal into the set of the first
    % test it passes; a goal that passes none goes into as_reject.
    % accu_stage1 also adds every goal that precedes the recursive call
    % to as_before.
    %
:- type accu_sets
    --->    accu_sets(
                as_before           ::  set(accu_goal_id),
                as_assoc            ::  set(accu_goal_id),
                as_construct_assoc  ::  set(accu_goal_id),
                as_construct        ::  set(accu_goal_id),
                as_update           ::  set(accu_goal_id),
                as_reject           ::  set(accu_goal_id)
            ).
685
686    % Stage 1 is responsible for identifying which goals are associative,
687    % which can be moved before the recursive call and so on.
688    %
689:- pred accu_stage1(module_info::in, vartypes::in, bool::in,
690    maybe_opt_lcmc_accumulator::in, accu_goal_id::in, int::in,
691    accu_goal_store::in, accu_sets::out) is semidet.
692
693accu_stage1(ModuleInfo, VarTypes, FullyStrict, DoLCMC, GoalId, M, GoalStore,
694        Sets) :-
695    GoalId = accu_goal_id(Case, K),
696    NextGoalId = accu_goal_id(Case, K + 1),
697    accu_sets_init(Sets0),
698    accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
699        GoalStore, Sets0, Sets1),
700    Sets1 = accu_sets(Before, Assoc,
701        ConstructAssoc, Construct, Update, Reject),
702    Sets = accu_sets(Before `set.union` set_upto(Case, K - 1), Assoc,
703        ConstructAssoc, Construct, Update, Reject),
704
705    % Continue the transformation only if the set reject is empty and
706    % the set assoc or update contains something that needs to be moved
707    % before the recursive call.
708    set.is_empty(Reject),
709    (
710        not set.is_empty(Assoc)
711    ;
712        not set.is_empty(Update)
713    ),
714    (
715        DoLCMC = do_not_opt_lcmc_accumulator,
716        % If LCMC is not turned on, then there must be no construction
717        % unifications after the recursive call.
718        set.is_empty(Construct),
719        set.is_empty(ConstructAssoc)
720    ;
721        DoLCMC = opt_lcmc_accumulator
722    ).
723
724    % For each goal after the recursive call decide which set
725    % the goal belongs to.
726    %
727:- pred accu_stage1_2(module_info::in, vartypes::in, bool::in,
728    accu_goal_id::in, int::in, int::in, accu_goal_store::in,
729    accu_sets::in, accu_sets::out) is det.
730
731accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, GoalId, K, M, GoalStore,
732        !Sets) :-
733    GoalId = accu_goal_id(Case, I),
734    NextGoalId = accu_goal_id(Case, I + 1),
735    ( if I > M then
736        true
737    else
738        ( if
739            accu_before(ModuleInfo, VarTypes, FullyStrict, GoalId, K,
740                GoalStore, !.Sets)
741        then
742            !Sets ^ as_before := set.insert(!.Sets ^ as_before, GoalId),
743            accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
744                GoalStore, !Sets)
745        else if
746            accu_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K,
747                GoalStore, !.Sets)
748        then
749            !Sets ^ as_assoc := set.insert(!.Sets ^ as_assoc, GoalId),
750            accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
751                GoalStore, !Sets)
752        else if
753            accu_construct(ModuleInfo, VarTypes, FullyStrict, GoalId, K,
754                GoalStore, !.Sets)
755        then
756            !Sets ^ as_construct := set.insert(!.Sets ^ as_construct, GoalId),
757            accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
758                GoalStore, !Sets)
759        else if
760            accu_construct_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K,
761                GoalStore, !.Sets)
762        then
763            !Sets ^ as_construct_assoc :=
764                set.insert(!.Sets ^ as_construct_assoc, GoalId),
765            accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
766                GoalStore, !Sets)
767        else if
768            accu_update(ModuleInfo, VarTypes, FullyStrict, GoalId, K,
769                GoalStore, !.Sets)
770        then
771            !Sets ^ as_update := set.insert(!.Sets ^ as_update, GoalId),
772            accu_stage1_2(ModuleInfo, VarTypes, FullyStrict, NextGoalId, K, M,
773                GoalStore, !Sets)
774        else
775            !Sets ^ as_reject := set.insert(!.Sets ^ as_reject, GoalId)
776        )
777    ).
778
779%---------------------------------------------------------------------------%
780
:- pred accu_sets_init(accu_sets::out) is det.

accu_sets_init(Sets) :-
    % Nothing has been classified yet, so every one of the six sets
    % (before, assoc, construct_assoc, construct, update, reject)
    % starts out empty.
    set.init(Empty),
    Sets = accu_sets(Empty, Empty, Empty, Empty, Empty, Empty).
792
793    % set_upto(Case, K) returns the set
794    % {accu_goal_id(Case, 1) .. accu_goal_id(Case, K)}.
795    %
796:- func set_upto(accu_case, int) = set(accu_goal_id).
797
798set_upto(Case, K) = Set :-
799    ( if K =< 0 then
800        set.init(Set)
801    else
802        Set0 = set_upto(Case, K - 1),
803        set.insert(accu_goal_id(Case, K), Set0, Set)
804    ).
805
806%---------------------------------------------------------------------------%
807
808    % A goal is a member of the before set iff the goal only depends on goals
809    % which are before the recursive call or can be moved before the recursive
810    % call (member of the before set).
811    %
812:- pred accu_before(module_info::in, vartypes::in, bool::in,
813    accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet.
814
815accu_before(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :-
816    GoalId = accu_goal_id(Case, _I),
817    Before = Sets ^ as_before,
818    goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)),
819    (
820        member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId,
821            stored_goal(EarlierGoal, EarlierInstMap)),
822        not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict,
823            EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal)
824    )
825    =>
826    (
827        set.member(LessThanGoalId, set_upto(Case, K - 1) `union` Before)
828    ).
829
830    % A goal is a member of the assoc set iff the goal only depends on goals
831    % upto and including the recursive call and goals which can be moved
832    % before the recursive call (member of the before set) AND the goal
833    % is associative.
834    %
835:- pred accu_assoc(module_info::in, vartypes::in, bool::in,
836    accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet.
837
838accu_assoc(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :-
839    GoalId = accu_goal_id(Case, _I),
840    Before = Sets ^ as_before,
841    goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)),
842    LaterGoal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _),
843    accu_is_associative(ModuleInfo, PredId, Args, _),
844    (
845        % XXX LessThanGoalId was _N - J, not N - J: it ignored the case.
846        % See the diff with the previous version.
847        member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId,
848            stored_goal(EarlierGoal, EarlierInstMap)),
849        not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict,
850            EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal)
851    )
852    =>
853    (
854        set.member(LessThanGoalId, set_upto(Case, K) `union` Before)
855    ).
856
857    % A goal is a member of the construct set iff the goal only depends
858    % on goals upto and including the recursive call and goals which
859    % can be moved before the recursive call (member of the before set)
860    % AND the goal is construction unification.
861    %
862:- pred accu_construct(module_info::in, vartypes::in, bool::in,
863    accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet.
864
865accu_construct(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore,
866        Sets) :-
867    GoalId = accu_goal_id(Case, _I),
868    Before = Sets ^ as_before,
869    Construct = Sets ^ as_construct,
870    goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)),
871    LaterGoal = hlds_goal(unify(_, _, _, Unify, _), _GoalInfo),
872    Unify = construct(_, _, _, _, _, _, _),
873    (
874        % XXX LessThanGoalId was _N - J, not N - J: it ignored the case.
875        % See the diff with the previous version.
876        member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId,
877            stored_goal(EarlierGoal, EarlierInstMap)),
878        not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict,
879            EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal)
880    )
881    =>
882    (
883        set.member(LessThanGoalId,
884            set_upto(Case, K) `union` Before `union` Construct)
885    ).
886
887    % A goal is a member of the construct_assoc set iff the goal depends only
888    % on goals upto and including the recursive call and goals which can be
889    % moved before the recursive call (member of the before set) and goals
890    % which are associative AND the goal is construction unification AND
891    % there is only one member of the assoc set which the construction
892    % unification depends on AND the construction unification can be expressed
893    % as a call to the member of the assoc set which the construction
894    % unification depends on.
895    %
896:- pred accu_construct_assoc(module_info::in, vartypes::in, bool::in,
897    accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet.
898
899accu_construct_assoc(ModuleInfo, VarTypes, FullyStrict,
900        GoalId, K, GoalStore, Sets) :-
901    GoalId = accu_goal_id(Case, _I),
902    Before = Sets ^ as_before,
903    Assoc = Sets ^ as_assoc,
904    ConstructAssoc = Sets ^ as_construct_assoc,
905    goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)),
906    LaterGoal = hlds_goal(unify(_, _, _, Unify, _), _GoalInfo),
907    Unify = construct(_, ConsId, _, _, _, _, _),
908
909    goal_store_all_ancestors(GoalStore, GoalId, VarTypes, ModuleInfo,
910        FullyStrict, Ancestors),
911
912    set.is_singleton(Assoc `intersect` Ancestors, AssocId),
913    goal_store_lookup(GoalStore, AssocId,
914        stored_goal(AssocGoal, _AssocInstMap)),
915    AssocGoal = hlds_goal(plain_call(PredId, _, _, _, _, _), _),
916
917    is_associative_construction(ModuleInfo, PredId, ConsId),
918    (
919        % XXX LessThanGoalId was _N - J, not N - J: it ignored the case.
920        % See the diff with the previous version.
921        member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId,
922            stored_goal(EarlierGoal, EarlierInstMap)),
923        not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict,
924            EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal)
925    )
926    =>
927    (
928        set.member(LessThanGoalId,
929            set_upto(Case, K) `union` Before `union` Assoc
930            `union` ConstructAssoc)
931    ).
932
933    % A goal is a member of the update set iff the goal only depends
934    % on goals upto and including the recursive call and goals which
935    % can be moved before the recursive call (member of the before set)
936    % AND the goal updates some state.
937    %
938:- pred accu_update(module_info::in, vartypes::in, bool::in,
939    accu_goal_id::in, int::in, accu_goal_store::in, accu_sets::in) is semidet.
940
941accu_update(ModuleInfo, VarTypes, FullyStrict, GoalId, K, GoalStore, Sets) :-
942    GoalId = accu_goal_id(Case, _I),
943    Before = Sets ^ as_before,
944    goal_store_lookup(GoalStore, GoalId, stored_goal(LaterGoal, LaterInstMap)),
945    LaterGoal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _),
946    accu_is_update(ModuleInfo, PredId, Args, _),
947    (
948        % XXX LessThanGoalId was _N - J, not N - J: it ignored the case.
949        % See the diff with the previous version.
950        member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId,
951            stored_goal(EarlierGoal, EarlierInstMap)),
952        not can_reorder_goals_old(ModuleInfo, VarTypes, FullyStrict,
953            EarlierInstMap, EarlierGoal, LaterInstMap, LaterGoal)
954    )
955    =>
956    (
957        set.member(LessThanGoalId, set_upto(Case, K) `union` Before)
958    ).
959
960    % member_lessthan_goalid(GS, IdA, IdB, GB) is true iff the goal_id, IdB,
961    % and its associated goal, GB, is a member of the goal_store, GS,
962    % and IdB is less than IdA.
963    %
964:- pred member_lessthan_goalid(accu_goal_store::in,
965    accu_goal_id::in, accu_goal_id::out, stored_goal::out) is nondet.
966
967member_lessthan_goalid(GoalStore, GoalId, LessThanGoalId, LessThanGoal) :-
968    goal_store_member(GoalStore, LessThanGoalId, LessThanGoal),
969    GoalId = accu_goal_id(Case, I),
970    LessThanGoalId = accu_goal_id(Case, J),
971    J < I.
972
973%---------------------------------------------------------------------------%
974
:- type accu_assoc
    % Describes the associative arguments of a call, as returned by
    % accu_is_associative.
    --->    accu_assoc(
                set_of_progvar,     % the associative input args
                prog_var,           % the corresponding output arg
                bool                % is the predicate commutative?
                                    % (yes iff exactly one commutativity
                                    % assertion was also found for it)
            ).
981
982    % If accu_is_associative is true, it returns the two arguments which are
983    % associative and the variable which depends on those two arguments,
984    % and an indicator of whether or not the predicate is commutative.
985    %
986:- pred accu_is_associative(module_info::in, pred_id::in, list(prog_var)::in,
987    accu_assoc::out) is semidet.
988
989accu_is_associative(ModuleInfo, PredId, Args, Result) :-
990    module_info_pred_info(ModuleInfo, PredId, PredInfo),
991    pred_info_get_assertions(PredInfo, Assertions),
992    AssertionsList = set.to_sorted_list(Assertions),
993    associativity_assertion(ModuleInfo, AssertionsList, Args,
994        AssociativeVarsOutputVar),
995    ( if
996        commutativity_assertion(ModuleInfo, AssertionsList, Args,
997            _CommutativeVars)
998    then
999        IsCommutative = yes
1000    else
1001        IsCommutative = no
1002    ),
1003    AssociativeVarsOutputVar =
1004        associative_vars_output_var(AssociativeVars, OutputVar),
1005    Result = accu_assoc(AssociativeVars, OutputVar, IsCommutative).
1006
1007    % Does there exist one (and only one) associativity assertion for the
1008    % current predicate?
1009    % The 'and only one condition' is required because we currently
1010    % do not handle the case of predicates which have individual parts
1011    % which are associative, because then we do not know which variable
1012    % is descended from which.
1013    %
1014:- pred associativity_assertion(module_info::in, list(assert_id)::in,
1015    list(prog_var)::in, associative_vars_output_var::out) is semidet.
1016
1017associativity_assertion(ModuleInfo, [AssertId | AssertIds], Args0,
1018        AssociativeVarsOutputVar) :-
1019    ( if
1020        assertion.is_associativity_assertion(ModuleInfo, AssertId,
1021            Args0, AssociativeVarsOutputVarPrime)
1022    then
1023        AssociativeVarsOutputVar = AssociativeVarsOutputVarPrime,
1024        not associativity_assertion(ModuleInfo, AssertIds, Args0, _)
1025    else
1026        associativity_assertion(ModuleInfo, AssertIds, Args0,
1027            AssociativeVarsOutputVar)
1028    ).
1029
1030    % Does there exist one (and only one) commutativity assertion for the
1031    % current predicate?
1032    % The 'and only one condition' is required because we currently
1033    % do not handle the case of predicates which have individual
1034    % parts which are commutative, because then we do not know which variable
1035    % is descended from which.
1036    %
1037:- pred commutativity_assertion(module_info::in,list(assert_id)::in,
1038    list(prog_var)::in, set_of_progvar::out) is semidet.
1039
1040commutativity_assertion(ModuleInfo, [AssertId | AssertIds], Args0,
1041        CommutativeVars) :-
1042    ( if
1043        assertion.is_commutativity_assertion(ModuleInfo, AssertId,
1044            Args0, CommutativeVarsPrime)
1045    then
1046        CommutativeVars = CommutativeVarsPrime,
1047        not commutativity_assertion(ModuleInfo, AssertIds, Args0, _)
1048    else
1049        commutativity_assertion(ModuleInfo, AssertIds, Args0,
1050            CommutativeVars)
1051    ).
1052
1053%---------------------------------------------------------------------------%
1054
1055    % Does the current predicate update some state?
1056    %
1057:- pred accu_is_update(module_info::in, pred_id::in, list(prog_var)::in,
1058    state_update_vars::out) is semidet.
1059
1060accu_is_update(ModuleInfo, PredId, Args, ResultStateVars) :-
1061    module_info_pred_info(ModuleInfo, PredId, PredInfo),
1062    pred_info_get_assertions(PredInfo, Assertions),
1063    list.filter_map(
1064        ( pred(AssertId::in, StateVars::out) is semidet :-
1065            assertion.is_update_assertion(ModuleInfo, AssertId,
1066                PredId, Args, StateVars)
1067        ),
1068        set.to_sorted_list(Assertions), Result),
1069    % XXX Maybe we should just match on the first result,
1070    % just in case there are duplicate promises.
1071    Result = [ResultStateVars].
1072
1073%---------------------------------------------------------------------------%
1074
1075    % Can the construction unification be expressed as a call to the
1076    % specified predicate.
1077    %
1078:- pred is_associative_construction(module_info::in, pred_id::in, cons_id::in)
1079    is semidet.
1080
1081is_associative_construction(ModuleInfo, PredId, ConsId) :-
1082    module_info_pred_info(ModuleInfo, PredId, PredInfo),
1083    pred_info_get_assertions(PredInfo, Assertions),
1084    list.filter(
1085        ( pred(AssertId::in) is semidet :-
1086            assertion.is_construction_equivalence_assertion(ModuleInfo,
1087                AssertId, ConsId, PredId)
1088        ),
1089        set.to_sorted_list(Assertions), Result),
1090    Result = [_ | _].
1091
1092%---------------------------------------------------------------------------%
1093%---------------------------------------------------------------------------%
1094
:- type accu_substs
    % The variable substitutions computed in stage 2 (accu_stage2).
    % accu_substs_init creates them; accu_process_assoc_set and
    % accu_process_update_set extend them.
    --->    accu_substs(
                % Maps each variable that needs an initial accumulator to
                % its "A_" accumulator variable; for update goals it also
                % maps each output accumulator to its input accumulator.
                acc_var_subst       :: accu_subst,

                % Judging by the name, this is used when renaming the
                % recursive call -- NOTE(review): its use is not visible
                % here; confirm.
                rec_call_subst      :: accu_subst,

                % Judging by the name, this is used when renaming the
                % associative calls -- NOTE(review): confirm.
                assoc_call_subst    :: accu_subst,

                % Maps the state variables of update calls to their
                % "Acc_" accumulator variables.
                update_subst        :: accu_subst
            ).
1102
:- type accu_base
    % The division of the base case goals computed by accu_divide_base_case.
    % The sets are not necessarily disjoint.
    --->    accu_base(
                % The base case goals which initialize the variables
                % used by the update goals.
                init_update         :: set(accu_goal_id),

                % The base case goals which initialize the variables
                % used by the assoc goals.
                init_assoc          :: set(accu_goal_id),

                % The other base case goals (and their base case ancestors).
                other               :: set(accu_goal_id)
            ).
1114
1115    % Stage 2 is responsible for identifying the substitutions which
1116    % are needed to mimic the unfold/fold process that was used as
1117    % the justification of the algorithm in the paper.
1118    % It is also responsible for ensuring that the reordering of arguments
1119    % doesn't worsen the big-O complexity of the procedure.
1120    % It also divides the base case into goals that initialize the
1121    % variables used by the update goals, and those used by the assoc
1122    % goals and then all the rest.
1123    %
1124:- pred accu_stage2(module_info::in, proc_info::in,
1125    accu_goal_id::in, accu_goal_store::in, accu_sets::in,
1126    list(prog_var)::in, list(prog_var)::in, prog_varset::out, vartypes::out,
1127    list(prog_var)::out, accu_base::out, list(pair(prog_var))::out,
1128    accu_substs::out, accu_goal_store::out, list(accu_warning)::out)
1129    is semidet.
1130
1131accu_stage2(ModuleInfo, ProcInfo0, GoalId, GoalStore, Sets, OutPrime, Out,
1132        !:VarSet, !:VarTypes, Accs, BaseCase, BasePairs, !:Substs,
1133        CS, Warnings) :-
1134    Sets = accu_sets(Before0, Assoc, ConstructAssoc, Construct, Update, _),
1135    GoalId = accu_goal_id(Case, K),
1136    Before = Before0 `union` set_upto(Case, K-1),
1137
1138    % Note Update set is not placed in the after set, as the after set is used
1139    % to determine the variables that need to be accumulated for the
1140    % associative calls.
1141    After = Assoc `union` ConstructAssoc `union` Construct,
1142
1143    P =
1144        ( pred(Id::in, Set0::in, Set::out) is det :-
1145            goal_store_lookup(GoalStore, Id, stored_goal(Goal, _InstMap)),
1146            Goal = hlds_goal(_GoalExpr, GoalInfo),
1147            NonLocals = goal_info_get_nonlocals(GoalInfo),
1148            set_of_var.union(NonLocals, Set0, Set)
1149        ),
1150    list.foldl(P, set.to_sorted_list(Before),
1151        set_of_var.init, BeforeNonLocals),
1152    list.foldl(P, set.to_sorted_list(After),
1153        set_of_var.init, AfterNonLocals),
1154    InitAccs = set_of_var.intersect(BeforeNonLocals, AfterNonLocals),
1155
1156    proc_info_get_varset(ProcInfo0, !:VarSet),
1157    proc_info_get_vartypes(ProcInfo0, !:VarTypes),
1158
1159    accu_substs_init(set_of_var.to_sorted_list(InitAccs), !VarSet, !VarTypes,
1160        !:Substs),
1161
1162    set_of_var.list_to_set(OutPrime, OutPrimeSet),
1163    accu_process_assoc_set(ModuleInfo, GoalStore, set.to_sorted_list(Assoc),
1164        OutPrimeSet, !Substs, !VarSet, !VarTypes, CS, Warnings),
1165
1166    accu_process_update_set(ModuleInfo, GoalStore, set.to_sorted_list(Update),
1167        OutPrimeSet, !Substs, !VarSet, !VarTypes, UpdateOut, UpdateAccOut,
1168        BasePairs),
1169
1170    Accs = set_of_var.to_sorted_list(InitAccs) ++ UpdateAccOut,
1171
1172    accu_divide_base_case(ModuleInfo, !.VarTypes, GoalStore, UpdateOut, Out,
1173        UpdateBase, AssocBase, OtherBase),
1174
1175    BaseCase = accu_base(UpdateBase, AssocBase, OtherBase).
1176
1177%---------------------------------------------------------------------------%
1178
:- pred accu_substs_init(list(prog_var)::in, prog_varset::in, prog_varset::out,
    vartypes::in, vartypes::out, accu_substs::out) is det.

accu_substs_init(InitAccs, !VarSet, !VarTypes, Substs) :-
    % Only the accumulator-variable substitution starts out non-empty:
    % it maps each of InitAccs to a fresh accumulator variable. The other
    % three start empty and are extended by accu_process_assoc_set and
    % accu_process_update_set.
    acc_var_subst_init(InitAccs, !VarSet, !VarTypes, AccVarSubst),
    map.init(Empty),
    Substs = accu_substs(AccVarSubst, Empty, Empty, Empty).
1190
1191    % Initialise the acc_var_subst to be from Var to A_Var where Var is a
1192    % member of InitAccs and A_Var is a fresh variable of the same type of Var.
1193    %
1194:- pred acc_var_subst_init(list(prog_var)::in,
1195    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out,
1196    accu_subst::out) is det.
1197
1198acc_var_subst_init([], !VarSet, !VarTypes, map.init).
1199acc_var_subst_init([Var | Vars], !VarSet, !VarTypes, Subst) :-
1200    create_new_var(Var, "A_", AccVar, !VarSet, !VarTypes),
1201    acc_var_subst_init(Vars, !VarSet, !VarTypes, Subst0),
1202    map.det_insert(Var, AccVar, Subst0, Subst).
1203
1204    % Create a fresh variable which is the same type as the old variable
1205    % and has the same name except that it begins with the prefix.
1206    %
1207:- pred create_new_var(prog_var::in, string::in, prog_var::out,
1208    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out) is det.
1209
1210create_new_var(OldVar, Prefix, NewVar, !VarSet, !VarTypes) :-
1211    varset.lookup_name(!.VarSet, OldVar, OldName),
1212    string.append(Prefix, OldName, NewName),
1213    varset.new_named_var(NewName, NewVar, !VarSet),
1214    lookup_var_type(!.VarTypes, OldVar, Type),
1215    add_var_type(NewVar, Type, !VarTypes).
1216
1217%---------------------------------------------------------------------------%
1218
    % For each member of the assoc set determine the substitutions needed,
    % and also check the efficiency of the procedure isn't worsened
    % by reordering the arguments to a call.
    %
    % Of the two associative input variables of each call, exactly one must
    % be in OutPrime (DuringAssocVar); the other is BeforeAssocVar, whose
    % accumulator is looked up in acc_var_subst. Each call is recorded in
    % CS, with its associative inputs swapped unless the predicate is known
    % to be commutative. A warning is produced for each non-commutative
    % call for which we have no complexity heuristic.
    % Fails if DuringAssocVar/BeforeAssocVar cannot be identified, or if the
    % heuristic says the swap would hurt the running time.
    %
:- pred accu_process_assoc_set(module_info::in, accu_goal_store::in,
    list(accu_goal_id)::in, set_of_progvar::in,
    accu_substs::in, accu_substs::out,
    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out,
    accu_goal_store::out, list(accu_warning)::out) is semidet.

accu_process_assoc_set(_ModuleInfo, _GS, [], _OutPrime, !Substs,
        !VarSet, !VarTypes, CS, []) :-
    goal_store_init(CS).
accu_process_assoc_set(ModuleInfo, GS, [Id | Ids], OutPrime, !Substs,
        !VarSet, !VarTypes, CS, Warnings) :-
    !.Substs = accu_substs(AccVarSubst, RecCallSubst0, AssocCallSubst0,
        UpdateSubst),

    lookup_call(GS, Id, stored_goal(Goal, InstMap)),

    Goal = hlds_goal(plain_call(PredId, _, Args, _, _, _), GoalInfo),
    accu_is_associative(ModuleInfo, PredId, Args, AssocInfo),
    AssocInfo = accu_assoc(Vars, AssocOutput, IsCommutative),
    % Split the associative inputs into the one computed by the recursive
    % call (in OutPrime) and the one available before it.
    OutPrimeVars = set_of_var.intersect(Vars, OutPrime),
    set_of_var.is_singleton(OutPrimeVars, DuringAssocVar),
    set_of_var.is_singleton(set_of_var.difference(Vars, OutPrimeVars),
        BeforeAssocVar),

    map.lookup(AccVarSubst, BeforeAssocVar, AccVar),
    create_new_var(BeforeAssocVar, "NewAcc_", NewAcc, !VarSet, !VarTypes),

    map.det_insert(DuringAssocVar, AccVar, AssocCallSubst0, AssocCallSubst1),
    map.det_insert(AssocOutput, NewAcc, AssocCallSubst1, AssocCallSubst),
    map.det_insert(DuringAssocVar, AssocOutput, RecCallSubst0, RecCallSubst1),
    map.det_insert(BeforeAssocVar, NewAcc, RecCallSubst1, RecCallSubst),

    !:Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst,
        UpdateSubst),

    % ONLY swap the order of the variables if the goal is
    % associative and not commutative.
    (
        IsCommutative = yes,
        CSGoal = stored_goal(Goal, InstMap),
        CurWarnings = []
    ;
        IsCommutative = no,

        % Ensure that the reordering doesn't cause an efficiency problem.
        module_info_pred_info(ModuleInfo, PredId, PredInfo),
        ModuleName = pred_info_module(PredInfo),
        PredName = pred_info_name(PredInfo),
        Arity = pred_info_orig_arity(PredInfo),
        ( if accu_has_heuristic(ModuleName, PredName, Arity) then
            % Only do the transformation if the accumulator variable is
            % *not* in a position where it will control the running time
            % of the predicate. (The arguments are swapped below, so it is
            % DuringAssocVar that must be in a controlling position now.)
            accu_heuristic(ModuleName, PredName, Arity, Args,
                PossibleDuringAssocVars),
            set_of_var.member(PossibleDuringAssocVars, DuringAssocVar),
            CurWarnings = []
        else
            % We cannot tell whether the swap is harmless; do it anyway,
            % but warn the user.
            ProgContext = goal_info_get_context(GoalInfo),
            CurWarnings = [accu_warn(ProgContext, PredId, BeforeAssocVar,
                DuringAssocVar)]
        ),
        % Swap the arguments. This fails unless there are exactly two
        % associative input variables.
        [A, B] = set_of_var.to_sorted_list(Vars),
        map.from_assoc_list([A - B, B - A], Subst),
        rename_some_vars_in_goal(Subst, Goal, SwappedGoal),
        CSGoal = stored_goal(SwappedGoal, InstMap)
    ),

    accu_process_assoc_set(ModuleInfo, GS, Ids, OutPrime, !Substs,
        !VarSet, !VarTypes, CS0, Warnings0),
    goal_store_det_insert(Id, CSGoal, CS0, CS),
    Warnings = Warnings0 ++ CurWarnings.
1296
:- pred accu_has_heuristic(module_name::in, string::in, arity::in) is semidet.

    % Succeeds iff accu_heuristic (below) has an entry for the predicate
    % ModuleName.PredName/Arity, i.e. iff we know which of its arguments
    % determine its running time. Currently this is only list.append/3.
    %
accu_has_heuristic(unqualified("list"), "append", 3).
1300
    % accu_heuristic(ModuleName, PredName, Arity, Args, Set):
    %
    % Given the argument list Args of a call to ModuleName.PredName/Arity,
    % return in Set those arguments that are important in the running time
    % of the called predicate. Args includes any type_info arguments.
    % For list.append/3, whose call has a type_info argument before its
    % three visible arguments, this is the first list argument.
    %
:- pred accu_heuristic(module_name::in, string::in, arity::in,
    list(prog_var)::in, set_of_progvar::out) is semidet.

accu_heuristic(unqualified("list"), "append", 3, [_Typeinfo, A, _B, _C],
        Set) :-
    set_of_var.make_singleton(A, Set).
1310
1311%---------------------------------------------------------------------------%
1312
    % For each member of the update set determine the substitutions needed
    % (creating the accumulator variables when needed).
    % Also associate with each Output variable which accumulator variable
    % to get the result from.
    %
    % For each update call, the state variable that is in OutPrime is taken
    % as its input and the other one as its output, and fresh variables
    % Acc0 and Acc (named "Acc_" followed by the name of the input and the
    % output state variable respectively) are created. The three output
    % lists are the output state variables, the new Acc variables, and
    % pairs OutputStateVar - Acc0. They are in the REVERSE order of Ids.
    % Fails if some call is not recognised by accu_is_update. Aborts
    % (via lookup_call) if some goal is not a call.
    %
:- pred accu_process_update_set(module_info::in, accu_goal_store::in,
    list(accu_goal_id)::in, set_of_progvar::in,
    accu_substs::in, accu_substs::out,
    prog_varset::in, prog_varset::out, vartypes::in, vartypes::out,
    list(prog_var)::out, list(prog_var)::out, list(pair(prog_var))::out)
    is semidet.

accu_process_update_set(_ModuleInfo, _GS, [], _OutPrime, !Substs,
        !VarSet, !VarTypes, [], [], []).
accu_process_update_set(ModuleInfo, GS, [Id | Ids], OutPrime, !Substs,
        !VarSet, !VarTypes, StateOutputVars, Accs, BasePairs) :-
    !.Substs = accu_substs(AccVarSubst0, RecCallSubst0, AssocCallSubst,
        UpdateSubst0),
    lookup_call(GS, Id, stored_goal(Goal, _InstMap)),

    Goal = hlds_goal(plain_call(PredId, _, Args, _, _, _), _GoalInfo),
    accu_is_update(ModuleInfo, PredId, Args, StateVars),
    StateVars = state_update_vars(StateVarA, StateVarB),

    % The state variable computed by the recursive call is the input.
    ( if set_of_var.member(OutPrime, StateVarA) then
        StateInputVar = StateVarA,
        StateOutputVar = StateVarB
    else
        StateInputVar = StateVarB,
        StateOutputVar = StateVarA
    ),

    create_new_var(StateInputVar, "Acc_", Acc0, !VarSet, !VarTypes),
    create_new_var(StateOutputVar, "Acc_", Acc, !VarSet, !VarTypes),

    map.det_insert(StateInputVar, Acc0, UpdateSubst0, UpdateSubst1),
    map.det_insert(StateOutputVar, Acc, UpdateSubst1, UpdateSubst),
    map.det_insert(StateInputVar, StateOutputVar, RecCallSubst0, RecCallSubst),
    map.det_insert(Acc, Acc0, AccVarSubst0, AccVarSubst),
    !:Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst,
        UpdateSubst),

    accu_process_update_set(ModuleInfo, GS, Ids, OutPrime, !Substs,
        !VarSet, !VarTypes, StateOutputVars0, Accs0, BasePairs0),

    % Rather than prepending to the start of the list we append to the end
    % of the list. This allows the accumulator introduction to be applied
    % as the heuristic will succeed (remember after transforming the two
    % input variables will have their order swapped, so they must be in the
    % inefficient order to start with).

    StateOutputVars = StateOutputVars0 ++ [StateOutputVar],
    Accs = Accs0 ++ [Acc],
    BasePairs = BasePairs0 ++ [StateOutputVar - Acc0].
1367
1368%---------------------------------------------------------------------------%
1369
1370    % divide_base_case(UpdateOut, Out, U, A, O) is true iff given the output
1371    % variables which are instantiated by update goals, UpdateOut, and all
1372    % the variables that need to be accumulated, Out, divide the base case up
1373    % into three sets, those base case goals which initialize the variables
1374    % used by update calls, U, those which initialize variables used by
1375    % assoc calls, A, and the rest of the goals, O. Note that the sets
1376    % are not necessarily disjoint, as the result of a goal may be used
1377    % to initialize a variable in both U and A, so both U and A will contain
1378    % the same goal_id.
1379    %
1380:- pred accu_divide_base_case(module_info::in, vartypes::in,
1381    accu_goal_store::in, list(prog_var)::in, list(prog_var)::in,
1382    set(accu_goal_id)::out, set(accu_goal_id)::out, set(accu_goal_id)::out)
1383    is det.
1384
1385accu_divide_base_case(ModuleInfo, VarTypes, C, UpdateOut, Out,
1386        UpdateBase, AssocBase, OtherBase) :-
1387    list.delete_elems(Out, UpdateOut, AssocOut),
1388
1389    list.map(accu_related(ModuleInfo, VarTypes, C), UpdateOut, UpdateBaseList),
1390    list.map(accu_related(ModuleInfo, VarTypes, C), AssocOut, AssocBaseList),
1391    UpdateBase = set.power_union(set.list_to_set(UpdateBaseList)),
1392    AssocBase = set.power_union(set.list_to_set(AssocBaseList)),
1393
1394    Set = base_case_ids_set(C) `difference` (UpdateBase `union` AssocBase),
1395    set.to_sorted_list(Set, List),
1396
1397    list.map(
1398        ( pred(GoalId::in, Ancestors::out) is det :-
1399            goal_store_all_ancestors(C, GoalId, VarTypes,
1400                ModuleInfo, no, Ancestors)
1401        ), List, OtherBaseList),
1402
1403    OtherBase = set.list_to_set(List) `union`
1404        (base_case_ids_set(C) `intersect`
1405        set.power_union(set.list_to_set(OtherBaseList))).
1406
    % accu_related(ModuleInfo, VarTypes, GoalStore, Var, Related):
    %
    % From GoalStore, return all the goal_ids, Related, which are needed
    % to initialize Var.
    %
    % Exactly one base case goal in GoalStore must change the instantiation
    % of Var and of no other variable; we abort if there is none or more
    % than one. Related is that goal together with those of its ancestors
    % that are also base case goals.
    %
:- pred accu_related(module_info::in, vartypes::in, accu_goal_store::in,
    prog_var::in, set(accu_goal_id)::out) is det.

accu_related(ModuleInfo, VarTypes, GoalStore, Var, Related) :-
    % Find the base case goals whose only changed variable is Var.
    solutions.solutions(
        ( pred(Key::out) is nondet :-
            goal_store_member(GoalStore, Key, stored_goal(Goal, InstMap0)),
            Key = accu_goal_id(accu_base, _),
            Goal = hlds_goal(_GoalExpr, GoalInfo),
            InstMapDelta = goal_info_get_instmap_delta(GoalInfo),
            apply_instmap_delta(InstMapDelta, InstMap0, InstMap),
            instmap_changed_vars(ModuleInfo, VarTypes,
                InstMap0, InstMap, ChangedVars),
            set_of_var.is_singleton(ChangedVars, Var)
        ), Ids),
    (
        Ids = [],
        unexpected($pred, "no Id")
    ;
        Ids = [Id],
        goal_store_all_ancestors(GoalStore, Id, VarTypes, ModuleInfo, no,
            Ancestors),
        % Keep only the goals that belong to the base case.
        list.filter((pred(accu_goal_id(accu_base, _)::in) is semidet),
            set.to_sorted_list(set.insert(Ancestors, Id)), RelatedList),
        Related = set.list_to_set(RelatedList)
    ;
        Ids = [_, _ | _],
        unexpected($pred, "more than one Id")
    ).
1441
1442%---------------------------------------------------------------------------%
1443
:- inst stored_goal_plain_call for goal_store.stored_goal/0
    --->    stored_goal(goal_plain_call, ground).
    % The inst of a stored_goal whose goal has inst goal_plain_call, i.e.
    % is known to be a plain call (see lookup_call); the stored inst map
    % is just ground.
1446
1447    % Do a goal_store_lookup where the result is known to be a call.
1448    %
1449:- pred lookup_call(accu_goal_store::in, accu_goal_id::in,
1450    stored_goal::out(stored_goal_plain_call)) is det.
1451
1452lookup_call(GoalStore, Id, stored_goal(Call, InstMap)) :-
1453    goal_store_lookup(GoalStore, Id, stored_goal(Goal, InstMap)),
1454    ( if
1455        Goal = hlds_goal(GoalExpr, GoalInfo),
1456        GoalExpr = plain_call(_, _, _, _, _, _)
1457    then
1458        Call = hlds_goal(GoalExpr, GoalInfo)
1459    else
1460        unexpected($pred, "not a call")
1461    ).
1462
1463%---------------------------------------------------------------------------%
1464%---------------------------------------------------------------------------%
1465
    % accu_stage3 creates the accumulator version of the predicate using
    % the substitutions determined in stage2. It also redefines the
    % original procedure to call the accumulator version of the procedure.
    %
    % The steps are ordered as follows:
    %   1. Build the pred_info/proc_info of the new accumulator predicate.
    %   2. Insert it into the predicate table. This yields the AccPredId
    %      that accu_create_goal needs to build calls to it.
    %   3. Build the new base/recursive cases for both procedures and plug
    %      them into the original top-level goal structure (TopLevel).
    %   4. Update the original procedure (goal, varset, vartypes), and
    %      requantify it.
    %   5. Finally, install the body of the accumulator procedure.
    %
    % C and CS are the goal stores passed on to accu_create_goal; the
    % recursive call is looked up in C under RecCallId. Out is passed to
    % acc_pred_info, which records it in the new predicate's origin.
    %
:- pred accu_stage3(accu_goal_id::in, list(prog_var)::in, prog_varset::in,
    vartypes::in, accu_goal_store::in, accu_goal_store::in,
    accu_substs::in, accu_subst::in, accu_subst::in,
    accu_base::in, list(pair(prog_var))::in, accu_sets::in,
    list(prog_var)::in, top_level::in, pred_id::in, pred_info::in,
    proc_info::in, proc_info::out, module_info::in, module_info::out) is det.

accu_stage3(RecCallId, Accs, VarSet, VarTypes, C, CS, Substs,
        HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, Out,
        TopLevel, OrigPredId, OrigPredInfo, !OrigProcInfo, !ModuleInfo) :-
    acc_proc_info(Accs, VarSet, VarTypes, Substs, !.OrigProcInfo,
        AccTypes, AccProcInfo),
    acc_pred_info(AccTypes, Out, AccProcInfo, OrigPredId, OrigPredInfo,
        AccProcId, AccPredInfo),
    AccName = unqualified(pred_info_name(AccPredInfo)),

    % The new predicate must be in the table before we build calls to it.
    % Its body is filled in later by update_accumulator_pred.
    module_info_get_predicate_table(!.ModuleInfo, PredTable0),
    predicate_table_insert(AccPredInfo, AccPredId, PredTable0, PredTable),
    module_info_set_predicate_table(PredTable, !ModuleInfo),
    accu_create_goal(RecCallId, Accs, AccPredId, AccProcId, AccName, Substs,
        HeadToCallSubst, CallToHeadSubst, BaseCase, BasePairs, Sets, C, CS,
        OrigBaseGoal, OrigRecGoal, AccBaseGoal, AccRecGoal),

    % Reuse the original top-level structure (switch/disj/ite) for both
    % the rewritten original procedure and the accumulator procedure.
    proc_info_get_goal(!.OrigProcInfo, OrigGoal0),
    accu_top_level(TopLevel, OrigGoal0, OrigBaseGoal, OrigRecGoal,
        AccBaseGoal, AccRecGoal, OrigGoal, AccGoal),

    % The varset and vartypes passed in replace the original procedure's.
    % NOTE(review): presumably they include the accumulator variables
    % introduced in stage2; confirm.
    proc_info_set_goal(OrigGoal, !OrigProcInfo),
    proc_info_set_varset(VarSet, !OrigProcInfo),
    proc_info_set_vartypes(VarTypes, !OrigProcInfo),

    % The goal was rebuilt, so recompute quantification (nonlocals).
    requantify_proc_general(ordinary_nonlocals_no_lambda, !OrigProcInfo),
    update_accumulator_pred(AccPredId, AccProcId, AccGoal, !ModuleInfo).
1503
1504%---------------------------------------------------------------------------%
1505
1506    % Construct a proc_info for the introduced predicate.
1507    %
1508:- pred acc_proc_info(list(prog_var)::in, prog_varset::in, vartypes::in,
1509    accu_substs::in, proc_info::in, list(mer_type)::out, proc_info::out)
1510    is det.
1511
1512acc_proc_info(Accs0, VarSet, VarTypes, Substs, OrigProcInfo,
1513        AccTypes, AccProcInfo) :-
1514    % ProcInfo Stuff that must change.
1515    proc_info_get_headvars(OrigProcInfo, HeadVars0),
1516    proc_info_get_argmodes(OrigProcInfo, HeadModes0),
1517
1518    proc_info_get_inst_varset(OrigProcInfo, InstVarSet),
1519    proc_info_get_inferred_determinism(OrigProcInfo, Detism),
1520    proc_info_get_goal(OrigProcInfo, Goal),
1521    proc_info_get_context(OrigProcInfo, Context),
1522    proc_info_get_rtti_varmaps(OrigProcInfo, RttiVarMaps),
1523    proc_info_get_is_address_taken(OrigProcInfo, IsAddressTaken),
1524    proc_info_get_has_parallel_conj(OrigProcInfo, HasParallelConj),
1525    proc_info_get_var_name_remap(OrigProcInfo, VarNameRemap),
1526
1527    Substs = accu_substs(AccVarSubst, _RecCallSubst, _AssocCallSubst,
1528        _UpdateSubst),
1529    list.map(map.lookup(AccVarSubst), Accs0, Accs),
1530
1531    % We place the extra accumulator variables at the start, because placing
1532    % them at the end breaks the convention that the last variable of a
1533    % function is the output variable.
1534    HeadVars = Accs ++ HeadVars0,
1535
1536    % XXX we don't want to use the inst of the var as it can be more specific
1537    % than it should be. ie int_const(1) when it should be any integer.
1538    % However this will no longer handle partially instantiated data
1539    % structures.
1540    Inst = ground(shared, none_or_default_func),
1541    inst_lists_to_mode_list([Inst], [Inst], Mode),
1542    list.duplicate(list.length(Accs), list.det_head(Mode), AccModes),
1543    HeadModes = AccModes ++ HeadModes0,
1544
1545    lookup_var_types(VarTypes, Accs, AccTypes),
1546
1547    SeqNum = item_no_seq_num,
1548    proc_info_create(Context, SeqNum, VarSet, VarTypes, HeadVars,
1549        InstVarSet, HeadModes, detism_decl_none, Detism, Goal, RttiVarMaps,
1550        IsAddressTaken, HasParallelConj, VarNameRemap, AccProcInfo).
1551
1552%---------------------------------------------------------------------------%
1553
1554    % Construct the pred_info for the introduced predicate.
1555    %
1556:- pred acc_pred_info(list(mer_type)::in, list(prog_var)::in, proc_info::in,
1557    pred_id::in, pred_info::in, proc_id::out, pred_info::out) is det.
1558
1559acc_pred_info(NewTypes, OutVars, NewProcInfo, OrigPredId, OrigPredInfo,
1560        NewProcId, NewPredInfo) :-
1561    % PredInfo stuff that must change.
1562    pred_info_get_arg_types(OrigPredInfo, TypeVarSet, ExistQVars, Types0),
1563
1564    ModuleName = pred_info_module(OrigPredInfo),
1565    Name = pred_info_name(OrigPredInfo),
1566    PredOrFunc = pred_info_is_pred_or_func(OrigPredInfo),
1567    pred_info_get_context(OrigPredInfo, PredContext),
1568    pred_info_get_markers(OrigPredInfo, Markers),
1569    pred_info_get_class_context(OrigPredInfo, ClassContext),
1570    pred_info_get_origin(OrigPredInfo, OldOrigin),
1571    pred_info_get_var_name_remap(OrigPredInfo, VarNameRemap),
1572
1573    set.init(Assertions),
1574
1575    proc_info_get_context(NewProcInfo, Context),
1576    term.context_line(Context, Line),
1577    Counter = 0,
1578
1579    Types = NewTypes ++ Types0,
1580
1581    make_pred_name_with_context(ModuleName, "AccFrom", PredOrFunc, Name,
1582        Line, Counter, SymName),
1583
1584    OutVarNums = list.map(term.var_to_int, OutVars),
1585    Origin = origin_transformed(transform_accumulator(OutVarNums),
1586        OldOrigin, OrigPredId),
1587    GoalType = goal_not_for_promise(np_goal_type_none),
1588    pred_info_create(ModuleName, SymName, PredOrFunc, PredContext, Origin,
1589        pred_status(status_local), Markers, Types, TypeVarSet,
1590        ExistQVars, ClassContext, Assertions, VarNameRemap, GoalType,
1591        NewProcInfo, NewProcId, NewPredInfo).
1592
1593%---------------------------------------------------------------------------%
1594
1595    % create_goal creates the new base and recursive case of the
1596    % original procedure (OrigBaseGoal and OrigRecGoal) and the base
1597    % and recursive cases of accumulator version (AccBaseGoal and
1598    % AccRecGoal).
1599    %
1600:- pred accu_create_goal(accu_goal_id::in, list(prog_var)::in,
1601    pred_id::in, proc_id::in, sym_name::in, accu_substs::in,
1602    accu_subst::in, accu_subst::in, accu_base::in,
1603    list(pair(prog_var))::in, accu_sets::in,
1604    accu_goal_store::in, accu_goal_store::in,
1605    hlds_goal::out, hlds_goal::out, hlds_goal::out, hlds_goal::out) is det.
1606
1607accu_create_goal(RecCallId, Accs, AccPredId, AccProcId, AccName, Substs,
1608        HeadToCallSubst, CallToHeadSubst, BaseIds, BasePairs,
1609        Sets, C, CS, OrigBaseGoal, OrigRecGoal, AccBaseGoal, AccRecGoal) :-
1610    lookup_call(C, RecCallId, stored_goal(OrigCall, _InstMap)),
1611    Call = create_acc_call(OrigCall, Accs, AccPredId, AccProcId, AccName),
1612    create_orig_goal(Call, Substs, HeadToCallSubst, CallToHeadSubst,
1613        BaseIds, Sets, C, OrigBaseGoal, OrigRecGoal),
1614    create_acc_goal(Call, Substs, HeadToCallSubst, BaseIds, BasePairs,
1615        Sets, C, CS, AccBaseGoal, AccRecGoal).
1616
1617    % create_acc_call takes the original call and generates a call to the
1618    % accumulator version of the call, which can have the substitutions
1619    % applied to it easily.
1620    %
1621:- func create_acc_call(hlds_goal::in(goal_plain_call), list(prog_var)::in,
1622    pred_id::in, proc_id::in, sym_name::in) = (hlds_goal::out(goal_plain_call))
1623    is det.
1624
1625create_acc_call(OrigCall, Accs, AccPredId, AccProcId, AccName) = Call :-
1626    OrigCall = hlds_goal(OrigCallExpr, GoalInfo),
1627    OrigCallExpr = plain_call(_PredId, _ProcId, Args, Builtin, Context, _Name),
1628    CallExpr = plain_call(AccPredId, AccProcId, Accs ++ Args, Builtin,
1629        Context, AccName),
1630    Call = hlds_goal(CallExpr, GoalInfo).
1631
1632    % Create the goals which are to replace the original predicate.
1633    %
1634:- pred create_orig_goal(hlds_goal::in, accu_substs::in,
1635    accu_subst::in, accu_subst::in, accu_base::in, accu_sets::in,
1636    accu_goal_store::in, hlds_goal::out, hlds_goal::out) is det.
1637
1638create_orig_goal(Call, Substs, HeadToCallSubst, CallToHeadSubst,
1639        BaseIds, Sets, C, OrigBaseGoal, OrigRecGoal) :-
1640    Substs = accu_substs(_AccVarSubst, _RecCallSubst, _AssocCallSubst,
1641        UpdateSubst),
1642
1643    BaseIds = accu_base(UpdateBase, _AssocBase, _OtherBase),
1644    Before = Sets ^ as_before,
1645    Update = Sets ^ as_update,
1646
1647    U = create_new_orig_recursive_goals(UpdateBase, Update,
1648        HeadToCallSubst, UpdateSubst, C),
1649
1650    rename_some_vars_in_goal(CallToHeadSubst, Call, BaseCall),
1651    Cbefore = accu_goal_list(set.to_sorted_list(Before), C),
1652    Uupdate = accu_goal_list(set.to_sorted_list(UpdateBase) ++
1653        set.to_sorted_list(Update), U),
1654    Cbase = accu_goal_list(base_case_ids(C), C),
1655    calculate_goal_info(conj(plain_conj, Cbefore ++ Uupdate ++ [BaseCall]),
1656        OrigRecGoal),
1657    calculate_goal_info(conj(plain_conj, Cbase), OrigBaseGoal).
1658
1659    % Create the goals which are to go in the new accumulator version
1660    % of the predicate.
1661    %
1662:- pred create_acc_goal(hlds_goal::in, accu_substs::in, accu_subst::in,
1663    accu_base::in, list(pair(prog_var))::in, accu_sets::in,
1664    accu_goal_store::in, accu_goal_store::in,
1665    hlds_goal::out, hlds_goal::out) is det.
1666
1667create_acc_goal(Call, Substs, HeadToCallSubst, BaseIds, BasePairs, Sets,
1668        C, CS, AccBaseGoal, AccRecGoal) :-
1669    Substs = accu_substs(AccVarSubst, RecCallSubst, AssocCallSubst,
1670        UpdateSubst),
1671
1672    BaseIds = accu_base(_UpdateBase, AssocBase, OtherBase),
1673    Sets = accu_sets(Before, Assoc, ConstructAssoc, Construct, Update,
1674        _Reject),
1675
1676    rename_some_vars_in_goal(RecCallSubst, Call, RecCall),
1677
1678    Cbefore = accu_goal_list(set.to_sorted_list(Before), C),
1679
1680    % Create the goals which will be used in the new recursive case.
1681    R = create_new_recursive_goals(Assoc, Construct `union` ConstructAssoc,
1682        Update, AssocCallSubst, AccVarSubst, UpdateSubst, C, CS),
1683
1684    Rassoc = accu_goal_list(set.to_sorted_list(Assoc), R),
1685    Rupdate = accu_goal_list(set.to_sorted_list(Update), R),
1686    Rconstruct = accu_goal_list(set.to_sorted_list(Construct `union`
1687        ConstructAssoc), R),
1688
1689    % Create the goals which will be used in the new base case.
1690    B = create_new_base_goals(Assoc `union` Construct `union`
1691        ConstructAssoc, C, AccVarSubst, HeadToCallSubst),
1692    Bafter = set.to_sorted_list(Assoc `union`
1693        Construct `union` ConstructAssoc),
1694
1695    BaseCase = accu_goal_list(set.to_sorted_list(AssocBase `union` OtherBase)
1696        ++ Bafter, B),
1697
1698    list.map(acc_unification, BasePairs, UpdateBase),
1699
1700    calculate_goal_info(conj(plain_conj, Cbefore ++ Rassoc ++ Rupdate
1701        ++ [RecCall] ++ Rconstruct), AccRecGoal),
1702    calculate_goal_info(conj(plain_conj, UpdateBase ++ BaseCase), AccBaseGoal).
1703
1704    % Create the U set of goals (those that will be used in the original
1705    % recursive case) by renaming all the goals which are used to initialize
1706    % the update state variable using the head_to_call followed by the
1707    % update_subst, and rename all the update goals using the update_subst.
1708    %
1709:- func create_new_orig_recursive_goals(set(accu_goal_id), set(accu_goal_id),
1710    accu_subst, accu_subst, accu_goal_store) = accu_goal_store.
1711
1712create_new_orig_recursive_goals(UpdateBase, Update, HeadToCallSubst,
1713        UpdateSubst, C)
1714        = accu_rename(set.to_sorted_list(Update), UpdateSubst, C, Ubase) :-
1715    Ubase = accu_rename(set.to_sorted_list(UpdateBase),
1716        chain_subst(HeadToCallSubst, UpdateSubst), C, goal_store_init).
1717
1718    % Create the R set of goals (those that will be used in the new
1719    % recursive case) by renaming all the members of assoc in CS
1720    % using assoc_call_subst and all the members of (construct U
1721    % construct_assoc) in C with acc_var_subst.
1722    %
1723:- func create_new_recursive_goals(set(accu_goal_id), set(accu_goal_id),
1724    set(accu_goal_id), accu_subst, accu_subst, accu_subst,
1725    accu_goal_store, accu_goal_store) = accu_goal_store.
1726
1727create_new_recursive_goals(Assoc, Constructs, Update,
1728        AssocCallSubst, AccVarSubst, UpdateSubst, C, CS)
1729        = accu_rename(set.to_sorted_list(Constructs), AccVarSubst, C, RBase) :-
1730    RBase0 = accu_rename(set.to_sorted_list(Assoc), AssocCallSubst, CS,
1731        goal_store_init),
1732    RBase = accu_rename(set.to_sorted_list(Update), UpdateSubst, C, RBase0).
1733
1734    % Create the B set of goals (those that will be used in the new base case)
1735    % by renaming all the base case goals of C with head_to_call and all the
1736    % members of (assoc U construct U construct_assoc) of C with acc_var_subst.
1737    %
1738:- func create_new_base_goals(set(accu_goal_id), accu_goal_store,
1739    accu_subst, accu_subst) = accu_goal_store.
1740
1741create_new_base_goals(Ids, C, AccVarSubst, HeadToCallSubst)
1742        = accu_rename(set.to_sorted_list(Ids), AccVarSubst, C, Bbase) :-
1743    Bbase = accu_rename(base_case_ids(C), HeadToCallSubst, C, goal_store_init).
1744
1745    % acc_unification(O-A, G):
1746    %
1747    % is true if G represents the assignment unification Out = Acc.
1748    %
1749:- pred acc_unification(pair(prog_var)::in, hlds_goal::out) is det.
1750
1751acc_unification(Out - Acc, Goal) :-
1752    UnifyMode = unify_modes_li_lf_ri_rf(free, ground_inst,
1753        ground_inst, ground_inst),
1754    Context = unify_context(umc_explicit, []),
1755    Expr = unify(Out, rhs_var(Acc), UnifyMode, assign(Out,Acc), Context),
1756    set_of_var.list_to_set([Out, Acc], NonLocalVars),
1757    InstMapDelta = instmap_delta_bind_var(Out),
1758    goal_info_init(NonLocalVars, InstMapDelta, detism_det, purity_pure, Info),
1759    Goal = hlds_goal(Expr, Info).
1760
1761%---------------------------------------------------------------------------%
1762
1763    % Given the top level structure of the goal create new version
1764    % with new base and recursive cases plugged in.
1765    %
1766:- pred accu_top_level(top_level::in, hlds_goal::in,
1767    hlds_goal::in, hlds_goal::in, hlds_goal::in,
1768    hlds_goal::in, hlds_goal::out, hlds_goal::out) is det.
1769
1770accu_top_level(TopLevel, Goal, OrigBaseGoal, OrigRecGoal,
1771        NewBaseGoal, NewRecGoal, OrigGoal, NewGoal) :-
1772    (
1773        TopLevel = switch_base_rec,
1774        ( if
1775            Goal = hlds_goal(switch(Var, CanFail, Cases0), GoalInfo),
1776            Cases0 = [case(IdA, [], _), case(IdB, [], _)]
1777        then
1778            OrigCases = [case(IdA, [], OrigBaseGoal),
1779                case(IdB, [], OrigRecGoal)],
1780            OrigGoal = hlds_goal(switch(Var, CanFail, OrigCases), GoalInfo),
1781
1782            NewCases = [case(IdA, [], NewBaseGoal), case(IdB, [], NewRecGoal)],
1783            NewGoal = hlds_goal(switch(Var, CanFail, NewCases), GoalInfo)
1784        else
1785            unexpected($pred, "not the correct top level")
1786        )
1787    ;
1788        TopLevel = switch_rec_base,
1789        ( if
1790            Goal = hlds_goal(switch(Var, CanFail, Cases0), GoalInfo),
1791            Cases0 = [case(IdA, [], _), case(IdB, [], _)]
1792        then
1793            OrigCases = [case(IdA, [], OrigRecGoal),
1794                case(IdB, [], OrigBaseGoal)],
1795            OrigGoal = hlds_goal(switch(Var, CanFail, OrigCases), GoalInfo),
1796
1797            NewCases = [case(IdA, [], NewRecGoal), case(IdB, [], NewBaseGoal)],
1798            NewGoal = hlds_goal(switch(Var, CanFail, NewCases), GoalInfo)
1799        else
1800            unexpected($pred, "not the correct top level")
1801        )
1802    ;
1803        TopLevel = disj_base_rec,
1804        ( if
1805            Goal = hlds_goal(disj(Goals), GoalInfo),
1806            Goals = [_, _]
1807        then
1808            OrigGoals = [OrigBaseGoal, OrigRecGoal],
1809            OrigGoal = hlds_goal(disj(OrigGoals), GoalInfo),
1810
1811            NewGoals = [NewBaseGoal, NewRecGoal],
1812            NewGoal = hlds_goal(disj(NewGoals), GoalInfo)
1813        else
1814            unexpected($pred, "not the correct top level")
1815        )
1816    ;
1817        TopLevel = disj_rec_base,
1818        ( if
1819            Goal = hlds_goal(disj(Goals), GoalInfo),
1820            Goals = [_, _]
1821        then
1822            OrigGoals = [OrigRecGoal, OrigBaseGoal],
1823            OrigGoal = hlds_goal(disj(OrigGoals), GoalInfo),
1824
1825            NewGoals = [NewRecGoal, NewBaseGoal],
1826            NewGoal = hlds_goal(disj(NewGoals), GoalInfo)
1827        else
1828            unexpected($pred, "not the correct top level")
1829        )
1830    ;
1831        TopLevel = ite_base_rec,
1832        ( if Goal = hlds_goal(if_then_else(Vars, Cond, _, _), GoalInfo) then
1833            OrigGoal = hlds_goal(if_then_else(Vars, Cond,
1834                OrigBaseGoal, OrigRecGoal), GoalInfo),
1835            NewGoal = hlds_goal(if_then_else(Vars, Cond,
1836                NewBaseGoal, NewRecGoal), GoalInfo)
1837        else
1838            unexpected($pred, "not the correct top level")
1839        )
1840    ;
1841        TopLevel = ite_rec_base,
1842        ( if Goal = hlds_goal(if_then_else(Vars, Cond, _, _), GoalInfo) then
1843            OrigGoal = hlds_goal(if_then_else(Vars, Cond,
1844                OrigRecGoal, OrigBaseGoal), GoalInfo),
1845            NewGoal = hlds_goal(if_then_else(Vars, Cond,
1846                NewRecGoal, NewBaseGoal), GoalInfo)
1847        else
1848            unexpected($pred, "not the correct top level")
1849        )
1850    ).
1851
1852%---------------------------------------------------------------------------%
1853
1854    % Place the accumulator version of the predicate in the HLDS.
1855    %
1856:- pred update_accumulator_pred(pred_id::in, proc_id::in,
1857    hlds_goal::in, module_info::in, module_info::out) is det.
1858
1859update_accumulator_pred(NewPredId, NewProcId, AccGoal, !ModuleInfo) :-
1860    module_info_pred_proc_info(!.ModuleInfo, NewPredId, NewProcId,
1861        PredInfo, ProcInfo0),
1862    proc_info_set_goal(AccGoal, ProcInfo0, ProcInfo1),
1863    requantify_proc_general(ordinary_nonlocals_no_lambda, ProcInfo1, ProcInfo),
1864    module_info_set_pred_proc_info(NewPredId, NewProcId,
1865        PredInfo, ProcInfo, !ModuleInfo).
1866
1867%---------------------------------------------------------------------------%
1868%---------------------------------------------------------------------------%
1869
1870    % accu_rename(Ids, Subst, From, Initial):
1871    %
1872    % Return a goal_store, Final, which is the result of looking up each
1873    % member of set of goal_ids, Ids, in the goal_store, From, applying
1874    % the substitution and then storing the goal into the goal_store, Initial.
1875    %
1876:- func accu_rename(list(accu_goal_id), accu_subst,
1877    accu_goal_store, accu_goal_store) = accu_goal_store.
1878
1879accu_rename(Ids, Subst, From, Initial) = Final :-
1880    list.foldl(
1881        ( pred(Id::in, GS0::in, GS::out) is det :-
1882            goal_store_lookup(From, Id, stored_goal(Goal0, InstMap)),
1883            rename_some_vars_in_goal(Subst, Goal0, Goal),
1884            goal_store_det_insert(Id, stored_goal(Goal, InstMap), GS0, GS)
1885        ), Ids, Initial, Final).
1886
1887    % Return all the goal_ids which belong in the base case.
1888    %
1889:- func base_case_ids(accu_goal_store) = list(accu_goal_id).
1890
1891base_case_ids(GS) = Base :-
1892    solutions.solutions(
1893        ( pred(Key::out) is nondet :-
1894            goal_store_member(GS, Key, _Goal),
1895            Key = accu_goal_id(accu_base, _)
1896        ), Base).
1897
1898:- func base_case_ids_set(accu_goal_store) = set(accu_goal_id).
1899
1900base_case_ids_set(GS) = set.list_to_set(base_case_ids(GS)).
1901
1902    % Given a list of goal_ids, return the list of hlds_goals from
1903    % the goal_store.
1904    %
1905:- func accu_goal_list(list(accu_goal_id), accu_goal_store) = list(hlds_goal).
1906
1907accu_goal_list(Ids, GS) = Goals :-
1908    list.map(
1909        ( pred(Key::in, G::out) is det :-
1910            goal_store_lookup(GS, Key, stored_goal(G, _))
1911        ), Ids, Goals).
1912
1913%---------------------------------------------------------------------------%
1914%---------------------------------------------------------------------------%
1915
1916:- pred calculate_goal_info(hlds_goal_expr::in, hlds_goal::out) is det.
1917
1918calculate_goal_info(GoalExpr, hlds_goal(GoalExpr, GoalInfo)) :-
1919    ( if GoalExpr = conj(plain_conj, GoalList) then
1920        goal_list_nonlocals(GoalList, NonLocals),
1921        goal_list_instmap_delta(GoalList, InstMapDelta),
1922        goal_list_determinism(GoalList, Detism),
1923
1924        goal_info_init(NonLocals, InstMapDelta, Detism, purity_pure, GoalInfo)
1925    else
1926        unexpected($pred, "not a conj")
1927    ).
1928
1929%---------------------------------------------------------------------------%
1930%---------------------------------------------------------------------------%
1931
1932:- func chain_subst(accu_subst, accu_subst) = accu_subst.
1933
1934chain_subst(AtoB, BtoC) = AtoC :-
1935    map.keys(AtoB, Keys),
1936    chain_subst_2(Keys, AtoB, BtoC, AtoC).
1937
1938:- pred chain_subst_2(list(A)::in, map(A, B)::in, map(B, C)::in,
1939    map(A, C)::out) is det.
1940
1941chain_subst_2([], _, _, AtoC) :-
1942    map.init(AtoC).
1943chain_subst_2([A | As], AtoB, BtoC, AtoC) :-
1944    chain_subst_2(As, AtoB, BtoC, AtoC0),
1945    map.lookup(AtoB, A, B),
1946    ( if map.search(BtoC, B, C) then
1947        map.det_insert(A, C, AtoC0, AtoC)
1948    else
1949        AtoC = AtoC0
1950    ).
1951
1952%---------------------------------------------------------------------------%
1953:- end_module transform_hlds.accumulator.
1954%---------------------------------------------------------------------------%
1955
% NOTE(review): this declaration comes after the line
% `:- end_module transform_hlds.accumulator.`; confirm that it really
% belongs in this file and not in some other module.
%
% unravel_univ(Univ, X): X is the value stored in Univ, with its type T
% existentially quantified. It is exported to C, C# and Java under the
% name "ML_unravel_univ".
%
:- some [T] pred unravel_univ(univ::in, T::out) is det.
:- pragma foreign_export("C", unravel_univ(in, out), "ML_unravel_univ").
:- pragma foreign_export("C#", unravel_univ(in, out), "ML_unravel_univ").
:- pragma foreign_export("Java", unravel_univ(in, out), "ML_unravel_univ").

unravel_univ(Univ, X) :-
    univ_value(Univ) = X.
1963