1
2import
3  intsets, ast, idents, algorithm, renderer, strutils,
4  msgs, modulegraphs, syntaxes, options, modulepaths,
5  lineinfos
6
# Dependency graph used to reorder the top-level statements of a module.
type
  DepN = ref object
    pnode: PNode        ## the top-level statement this node stands for
    id: int             ## position of the statement in the statement list
    idx: int            ## Tarjan discovery index, -1 until visited
    lowLink: int        ## Tarjan low-link value, -1 until visited
    onStack: bool       ## whether the node is on Tarjan's stack
    kids: seq[DepN]     ## statements this one must be placed after
    hAQ: int            ## cached hasAccQuotedDef: -1 unknown, else 0/1
    hIS: int            ## cached hasImportStmt: -1 unknown, else 0/1
    hB: int             ## cached hasBody: -1 unknown, else 0/1
    hCmd: int           ## cached hasCommand: -1 unknown, else 0/1
    when defined(debugReorder):
      expls: seq[string] ## debug: explanations recorded alongside `kids`
  DepG = seq[DepN]
17
when defined(debugReorder):
  # `newTable` and the `[]`/`[]=` accessors used on `idNames` come from the
  # `tables` module, which this file does not import otherwise; without this
  # import a `-d:debugReorder` build does not compile.
  import tables
  # Maps identifier ids back to their spelling so that debug explanations
  # can name the identifiers involved in a dependency.
  var idNames = newTable[int, string]()
20
proc newDepN(id: int, pnode: PNode): DepN =
  ## Creates an unvisited graph node for statement `pnode` at position `id`.
  ## The Tarjan bookkeeping (`idx`, `lowLink`) and the memoized predicate
  ## flags all start at -1, meaning "not computed yet".
  result = DepN(id: id, pnode: pnode,
                idx: -1, lowLink: -1, onStack: false,
                kids: @[],
                hAQ: -1, hIS: -1, hB: -1, hCmd: -1)
  when defined(debugReorder):
    result.expls = @[]
35
proc accQuoted(cache: IdentCache; n: PNode): PIdent =
  ## Joins the identifier pieces of the backquoted name `n` into a single
  ## string and interns it in `cache`. Children that yield no identifier
  ## are skipped.
  var joined = ""
  for part in n:
    let piece = part.getPIdent
    if piece != nil:
      joined.add piece.s
  result = getIdent(cache, joined)
42
proc addDecl(cache: IdentCache; n: PNode; declares: var IntSet) =
  ## Adds to `declares` the identifier id named by the declaration node `n`.
  ## Export markers (`nkPostfix`), pragma annotations (`nkPragmaExpr`) and
  ## enum field definitions (`nkEnumFieldDef`) are unwrapped to reach the
  ## name. Node kinds that carry no name are ignored.
  case n.kind
  of nkPostfix:
    addDecl(cache, n[1], declares)
  of nkPragmaExpr, nkEnumFieldDef:
    addDecl(cache, n[0], declares)
  of nkIdent, nkSym, nkAccQuoted:
    let name =
      case n.kind
      of nkIdent: n.ident
      of nkSym: n.sym.name
      else: accQuoted(cache, n)
    declares.incl name.id
    when defined(debugReorder):
      idNames[name.id] = name.s
  else:
    discard
63
proc computeDeps(cache: IdentCache; n: PNode, declares, uses: var IntSet; topLevel: bool) =
  ## Walks statement `n` and collects two sets of identifier ids:
  ## `declares` gets the names it declares and `uses` the names it references.
  ## Declarations are recorded only while `topLevel` is true. Statement
  ## lists, `when` branches and `static` blocks keep the flag. Everything
  ## else (parameters, bodies, types, initializers) is scanned for uses only.
  template deps(n) = computeDeps(cache, n, declares, uses, false)
  template decl(n) =
    if topLevel: addDecl(cache, n, declares)
  case n.kind
  of procDefs, nkMacroDef, nkTemplateDef:
    decl(n[0])
    for i in 1..bodyPos: deps(n[i])
  of nkLetSection, nkVarSection, nkUsingStmt:
    for a in n:
      if a.kind in {nkIdentDefs, nkVarTuple}:
        # all but the last two children are names; the last two are the
        # type and the initial value
        for j in 0..<a.len-2: decl(a[j])
        for j in a.len-2..<a.len: deps(a[j])
  of nkConstSection, nkTypeSection:
    for a in n:
      if a.len >= 3:
        decl(a[0])
        for i in 1..<a.len:
          if a[i].kind == nkEnumTy:
            # declare enum members; an explicit field value
            # (`a = someConst`) may reference other declarations, so it
            # is scanned for uses as well
            for b in a[i]:
              decl(b)
              if b.kind == nkEnumFieldDef:
                deps(b[1])
          else:
            deps(a[i])
  of nkIdentDefs:
    for i in 1..<n.len: # avoid members identifiers in object definition
      deps(n[i])
  of nkIdent: uses.incl n.ident.id
  of nkSym: uses.incl n.sym.name.id
  of nkAccQuoted: uses.incl accQuoted(cache, n).id
  of nkOpenSymChoice, nkClosedSymChoice:
    uses.incl n[0].sym.name.id
  of nkStmtList, nkStmtListExpr, nkWhenStmt, nkElifBranch, nkElse, nkStaticStmt:
    for i in 0..<n.len: computeDeps(cache, n[i], declares, uses, topLevel)
  of nkPragma:
    # an empty pragma (`{..}`) has no children; indexing n[0] would crash
    if n.len > 0:
      let a = n[0]
      if a.kind == nkExprColonExpr and a[0].kind == nkIdent and a[0].ident.s == "pragma":
        # user defined pragma `{.pragma: name, p1, p2.}`: declares `name`;
        # the remaining entries are references like in any other pragma
        decl(a[1])
        for i in 1..<n.len: deps(n[i])
      else:
        for i in 0..<n.len: deps(n[i])
  of nkMixinStmt, nkBindStmt: discard
  else:
    # XXX: for callables, this technically adds the return type dep before args
    for i in 0..<n.safeLen: deps(n[i])
109
proc hasIncludes(n: PNode): bool =
  ## Tells whether a direct child of `n` is an `include` statement.
  ## Nested statements are not inspected.
  result = false
  for child in n:
    if child.kind == nkIncludeStmt:
      result = true
      break
114
proc includeModule*(graph: ModuleGraph; s: PSym, fileIdx: FileIndex): PNode =
  ## Parses the file `fileIdx` and returns its AST. It then registers the
  ## inclusion with `graph`: a dependency of module `s` (`addDep`) and an
  ## include edge from the file at `s.position` (`addIncludeDep`).
  let parsed = syntaxes.parseFile(fileIdx, graph.cache, graph.config)
  addDep(graph, s, fileIdx)
  addIncludeDep(graph, FileIndex(s.position), fileIdx)
  result = parsed
119
proc expandIncludes(graph: ModuleGraph, module: PSym, n: PNode,
                    modulePath: string, includedFiles: var IntSet): PNode =
  ## Returns `n` with each top-level `include` statement replaced, in place,
  ## by the (recursively expanded) statements of the included files.
  ## If `n` has no include statement it is returned unchanged.
  ## `includedFiles` holds the files on the current include chain; meeting
  ## one of them again reports "recursive dependency" instead of recursing.
  ## `modulePath` is only forwarded to the recursive calls.
  if not hasIncludes(n):
    return n
  result = newNodeI(nkStmtList, n.info)
  for stmt in n:
    if stmt.kind != nkIncludeStmt:
      result.add stmt
      continue
    for target in stmt:
      let f = checkModuleName(graph.config, target)
      if f == InvalidFileIdx:
        continue
      if containsOrIncl(includedFiles, f.int):
        localError(graph.config, stmt.info, "recursive dependency: '$1'" %
          toMsgFilename(graph.config, f))
        continue
      let expanded = expandIncludes(graph, module,
                                    includeModule(graph, module, f),
                                    modulePath, includedFiles)
      # the file is no longer on the chain once its expansion is done
      excl(includedFiles, f.int)
      for inner in expanded:
        result.add inner
143
proc splitSections(n: PNode): PNode =
  ## Returns a new statement list built from `n`. Every `type`/`const`
  ## section holding several definitions is replaced by one single-entry
  ## section per definition, each carrying its entry's line info.
  ## All other statements are copied over as they are.
  assert n.kind == nkStmtList
  result = newNodeI(nkStmtList, n.info)
  for stmt in n:
    let splittable = stmt.kind in {nkTypeSection, nkConstSection} and
                     stmt.len > 1
    if not splittable:
      result.add stmt
      continue
    for entry in stmt:
      var single = newNode(stmt.kind)
      single.info = entry.info
      single.add entry
      result.add single
158
proc haveSameKind(dns: seq[DepN]): bool =
  ## Tells whether every node of the non-empty component `dns` wraps a
  ## statement of the same AST kind as its first node.
  let expected = dns[0].pnode.kind
  result = true
  for i in 1..<dns.len:
    if dns[i].pnode.kind != expected:
      result = false
      break
167
when defined(debugReorder):
  proc describeCycle(cs: seq[DepN]): string =
    ## Debug helper: lists every dependency edge that stays inside the
    ## component `cs`, with the explanation `buildGraph` recorded for it.
    var members = initIntSet()
    for dn in cs: members.incl dn.id
    for src in cs:
      for ci, kid in src.kids:
        if kid.id in members:
          result.add "line " & $src.pnode.info.line &
            " depends on line " & $kid.pnode.info.line &
            ": " & src.expls[ci] & "\n"

proc mergeSections(conf: ConfigRef; comps: seq[seq[DepN]], res: PNode) =
  ## Appends the statements of the strongly connected components `comps` to
  ## `res`, in the order the components are given.
  ## - A single-node component is emitted as is.
  ## - A cycle made only of `type` sections (or only of `const` sections),
  ##   e.g. mutually recursive types, is merged back into one section.
  ## - Any other cycle cannot be ordered reliably. A warning is issued and
  ##   the nodes are emitted in their original relative order, with runs of
  ##   consecutive same-kind type/const sections re-merged.
  ## Merged sections take the line info of their first member.
  for c in comps:
    assert c.len > 0
    if c.len == 1:
      res.add c[0].pnode
      continue
    let kind = c[0].pnode.kind
    # always return to the original order when we got circular dependencies
    let cs = c.sortedByIt(it.id)
    if kind in {nkTypeSection, nkConstSection} and haveSameKind(cs):
      # Circular dependency between type or const sections, we just
      # need to merge them
      var sn = newNodeI(kind, cs[0].pnode.info)
      for dn in cs:
        sn.add dn.pnode[0]
      res.add sn
    else:
      # Problematic circular dependency, we arrange the nodes into
      # their original relative order and make sure to re-merge
      # consecutive type and const sections
      var wmsg = "Circular dependency detected. `codeReordering` pragma may not be able to" &
        " reorder some nodes properly"
      when defined(debugReorder):
        wmsg &= ":\n" & describeCycle(cs)
      message(conf, cs[0].pnode.info, warnUser, wmsg)

      var i = 0
      while i < cs.len:
        let ckind = cs[i].pnode.kind
        if ckind in {nkTypeSection, nkConstSection}:
          var sn = newNodeI(ckind, cs[i].pnode.info)
          while i < cs.len and cs[i].pnode.kind == ckind:
            sn.add cs[i].pnode[0]
            inc i
          res.add sn
        else:
          res.add cs[i].pnode
          inc i
224
proc hasImportStmt(n: PNode): bool =
  ## Tells whether `n` is an `import`, `from ... import` or
  ## `import ... except` statement, or contains one inside statement lists,
  ## `when` branches or `static` blocks.
  result = false
  case n.kind
  of nkImportStmt, nkFromStmt, nkImportExceptStmt:
    result = true
  of nkStmtList, nkStmtListExpr, nkWhenStmt, nkElifBranch, nkElse, nkStaticStmt:
    for child in n:
      if hasImportStmt(child):
        result = true
        break
  else:
    discard
237
proc hasImportStmt(n: DepN): bool =
  ## Memoized `hasImportStmt` of the wrapped statement. `hIS` is -1 before
  ## the first call and 0 or 1 afterwards.
  if n.hIS < 0:
    n.hIS = if hasImportStmt(n.pnode): 1 else: 0
  result = n.hIS != 0
242
proc hasCommand(n: PNode): bool =
  ## Tells whether `n` is a command or call (`foo x` / `foo(x)`), or contains
  ## one inside statement lists, `when`/`static` blocks, `let`/`var`/`const`
  ## sections or their identifier definitions.
  result = false
  case n.kind
  of nkCommand, nkCall:
    result = true
  of nkStmtList, nkStmtListExpr, nkWhenStmt, nkElifBranch, nkElse,
      nkStaticStmt, nkLetSection, nkConstSection, nkVarSection,
      nkIdentDefs:
    for child in n:
      if hasCommand(child):
        result = true
        break
  else:
    discard
257
proc hasCommand(n: DepN): bool =
  ## Memoized `hasCommand` of the wrapped statement. `hCmd` is -1 before the
  ## first call and 0 or 1 afterwards.
  if n.hCmd >= 0:
    return n.hCmd != 0
  let found = hasCommand(n.pnode)
  n.hCmd = ord(found)
  result = found
262
proc hasAccQuoted(n: PNode): bool =
  ## Tells whether `n` itself, or any node below it, is a backquoted
  ## identifier (`nkAccQuoted`).
  result = n.kind == nkAccQuoted
  if not result:
    for child in n:
      if hasAccQuoted(child):
        result = true
        break
269
# Routine definition kinds: everything in `procDefs` plus macros and
# templates.
const extendedProcDefs = {nkMacroDef, nkTemplateDef} + procDefs
271
proc hasAccQuotedDef(n: PNode): bool =
  ## Tells whether `n` is a routine definition (proc, macro, template, ...)
  ## whose name part contains a backquoted identifier, or contains such a
  ## definition inside statement lists, `when` branches or `static` blocks.
  if n.kind in extendedProcDefs:
    return hasAccQuoted(n[0])
  if n.kind in {nkStmtList, nkStmtListExpr, nkWhenStmt, nkElifBranch,
                nkElse, nkStaticStmt}:
    for child in n:
      if hasAccQuotedDef(child):
        return true
  result = false
284
proc hasAccQuotedDef(n: DepN): bool =
  ## Memoized `hasAccQuotedDef` of the wrapped statement. `hAQ` is -1 before
  ## the first call and 0 or 1 afterwards.
  if n.hAQ < 0:
    n.hAQ = if hasAccQuotedDef(n.pnode): 1 else: 0
  result = n.hAQ != 0
289
proc hasBody(n: PNode): bool =
  ## Tells whether `n` counts as "having a body":
  ## - it is a command or call;
  ## - it is a routine definition whose last child is a statement list; or
  ## - it contains one of those inside statement lists, `when` branches or
  ##   `static` blocks.
  if n.kind in {nkCommand, nkCall}:
    return true
  if n.kind in extendedProcDefs:
    return n[^1].kind == nkStmtList
  if n.kind in {nkStmtList, nkStmtListExpr, nkWhenStmt, nkElifBranch,
                nkElse, nkStaticStmt}:
    for child in n:
      if hasBody(child):
        return true
  result = false
304
proc hasBody(n: DepN): bool =
  ## Memoized `hasBody` of the wrapped statement. `hB` is -1 before the
  ## first call and 0 or 1 afterwards.
  if n.hB >= 0:
    return n.hB != 0
  let found = hasBody(n.pnode)
  n.hB = ord(found)
  result = found
309
proc intersects(s1, s2: IntSet): bool =
  ## Tells whether `s1` and `s2` share at least one element.
  result = false
  for elem in s1.items:
    if elem in s2:
      result = true
      break
314
proc buildGraph(n: PNode, deps: seq[(IntSet, IntSet)]): DepG =
  ## Builds the dependency graph of the statements of `n`.
  ## `deps[i]` holds the (declared, used) identifier ids of `n[i]`.
  ## Node `i` gets node `j` as a kid (must stay after it) when any of the
  ## following holds, checked in this order:
  ## - `j < i` and both contain commands/calls (their order is preserved);
  ## - `j < i` and `j` contains an import statement;
  ## - `j < i`, `i` has a body and `j` defines a routine with a quoted name;
  ## - `j < i`, `i` has a body, `j` has none, and both declare a common name
  ##   (a forward declaration stays before its definition);
  ## - otherwise, `i` uses a name that `j` declares.
  ## Each kid is added at most once per (i, j) pair. In debug builds exactly
  ## one explanation is recorded per kid, so `expls` stays index-aligned
  ## with `kids`.
  result = newSeqOfCap[DepN](deps.len)
  for i in 0..<deps.len:
    result.add newDepN(i, n[i])
  for i in 0..<deps.len:
    let ni = result[i]
    let uses = deps[i][1]
    let niHasBody = ni.hasBody
    let niHasCmd = ni.hasCommand
    for j in 0..<deps.len:
      if i == j: continue
      let nj = result[j]
      let declares = deps[j][0]
      if j < i and nj.hasCommand and niHasCmd:
        # Preserve order for commands and calls
        ni.kids.add nj
        when defined(debugReorder):
          ni.expls.add "both have commands and one comes after the other"
      elif j < i and nj.hasImportStmt:
        # Every node that comes after an import statement must
        # depend on that import
        ni.kids.add nj
        when defined(debugReorder):
          ni.expls.add "parent is, or contains, an import statement and child comes after it"
      elif j < i and niHasBody and nj.hasAccQuotedDef:
        # Every function, macro, template... with a body depends
        # on precedent function declarations that have quoted names.
        # That's because it is hard to detect the use of functions
        # like "[]=", "[]", "or" ... in their bodies.
        ni.kids.add nj
        when defined(debugReorder):
          ni.expls.add "one declares a quoted identifier and the other has a body and comes after it"
      elif j < i and niHasBody and not nj.hasBody and
          intersects(deps[i][0], declares):
        # Keep function declaration before function definition
        ni.kids.add nj
        when defined(debugReorder):
          # a single explanation naming every shared identifier, so that
          # `expls` does not drift out of step with `kids`
          var shared: seq[string] = @[]
          for dep in deps[i][0]:
            if dep in declares:
              shared.add idNames[dep]
          ni.expls.add "one declares \"" & shared.join("\", \"") &
            "\" and the other defines it"
      else:
        for d in declares:
          if uses.contains(d):
            # one edge is enough: further matches would only add duplicate
            # kids, which change nothing in the component computation
            ni.kids.add nj
            when defined(debugReorder):
              ni.expls.add "one declares \"" & idNames[d] & "\" and the other uses it"
            break
362
proc strongConnect(v: var DepN, idx: var int, s: var seq[DepN],
                   res: var seq[seq[DepN]]) =
  ## Recursive step of Tarjan's strongly-connected-components algorithm.
  ## Visits `v` and, depth-first, every not-yet-visited kid reachable from it.
  ## - `idx` is the next discovery index to hand out.
  ## - `s` is the stack of visited nodes whose component is still open.
  ## - `res` receives each component once its root (the node whose
  ##   `lowLink` equals its `idx`) is finished.
  ## NOTE(review): recursion depth grows with the longest chain of unvisited
  ## kids, bounded by the number of graph nodes (top-level statements).
  # number the node on discovery and push it onto the open stack
  v.idx = idx
  v.lowLink = idx
  inc idx
  s.add v
  v.onStack = true
  for w in v.kids.mitems:
    if w.idx < 0:
      # unvisited kid: recurse, then inherit its low-link
      strongConnect(w, idx, s, res)
      v.lowLink = min(v.lowLink, w.lowLink)
    elif w.onStack:
      # edge back into the still-open part of the stack
      v.lowLink = min(v.lowLink, w.idx)
  if v.lowLink == v.idx:
    # `v` roots a component: pop nodes until `v` itself has been popped
    var comp = newSeq[DepN]()
    while true:
      var w = s.pop
      w.onStack = false
      comp.add w
      if w.id == v.id: break
    res.add comp
385
proc getStrongComponents(g: var DepG): seq[seq[DepN]] =
  ## Tarjan's algorithm: returns the strongly connected components of `g`.
  ## Kids are the statements a node depends on, and Tarjan completes a
  ## component only after everything reachable from it. So each component
  ## is listed after the components it depends on, which is the reordered
  ## statement order.
  var stack: seq[DepN] = @[]
  var counter = 0
  for i in 0..<g.len:
    if g[i].idx < 0:
      strongConnect(g[i], counter, stack, result)
394
proc hasForbiddenPragma(n: PNode): bool =
  ## Tells whether a direct child of `n` is a pragma statement whose first
  ## entry is `push`. Such pragmas do not play well with reordering (like
  ## the push/pop pair), so the caller leaves the module untouched.
  for a in n:
    # an empty pragma (`{..}`) has no first entry to inspect; indexing
    # a[0] on it would raise an IndexDefect
    if a.kind == nkPragma and a.len > 0 and a[0].kind == nkIdent and
        a[0].ident.s == "push":
      return true
402
proc reorder*(graph: ModuleGraph, n: PNode, module: PSym): PNode =
  ## Implements the `codeReordering` pragma. Returns the top-level
  ## statements `n` of `module` rearranged so that each statement follows
  ## the statements it depends on.
  ## - Includes are expanded first, and multi-entry type/const sections are
  ##   split before dependencies are computed.
  ## - Modules with a top-level `push` pragma are returned unchanged.
  if hasForbiddenPragma(n):
    return n
  var includedFiles = initIntSet()
  let modulePath = toFullPath(graph.config, module.fileIdx)
  let stmts = splitSections(expandIncludes(graph, module, n, modulePath,
                                           includedFiles))
  result = newNodeI(nkStmtList, stmts.info)
  var deps = newSeq[(IntSet, IntSet)](stmts.len)
  for i in 0..<stmts.len:
    var declared = initIntSet()
    var used = initIntSet()
    computeDeps(graph.cache, stmts[i], declared, used, true)
    deps[i] = (declared, used)
  var g = buildGraph(stmts, deps)
  mergeSections(graph.config, getStrongComponents(g), result)
420