1local a = vim.api
2local query = require'vim.treesitter.query'
3local language = require'vim.treesitter.language'
4
5local LanguageTree = {}
6LanguageTree.__index = LanguageTree
7
--- Creates a LanguageTree: a single treesitter parser for one language.
--- A tree can own child trees for languages injected within its ranges,
--- hence the name.
---
---@param source Can be a bufnr or a string of text to parse
---@param lang The language this tree represents
---@param opts Options table
---@param opts.injections A table of language to injection query strings.
---                      When it has an entry for `lang`, that string is parsed
---                      and used instead of the "injections" query looked up
---                      for the language.
function LanguageTree.new(source, lang, opts)
  language.require_language(lang)
  opts = opts or {}

  -- 'queries' is the deprecated spelling: report it, then treat it as 'injections'.
  if opts.queries then
    a.nvim_err_writeln("'queries' is no longer supported. Use 'injections' now")
    opts.injections = opts.queries
  end

  local injections = opts.injections or {}

  -- Prefer the caller-supplied query string; fall back to the looked-up
  -- "injections" query when none was given (or parsing yielded nothing).
  local injection_query = nil
  local override = injections[lang]
  if override then
    injection_query = query.parse_query(lang, override)
  end
  if not injection_query then
    injection_query = query.get_query(lang, "injections")
  end

  local instance = {
    _source = source,
    _lang = lang,
    _children = {},
    _regions = {},
    _trees = {},
    _opts = opts,
    _injection_query = injection_query,
    _valid = false,
    _parser = vim._create_ts_parser(lang),
    _callbacks = {
      changedtree = {},
      bytes = {},
      detach = {},
      child_added = {},
      child_removed = {}
    },
  }

  return setmetatable(instance, LanguageTree)
end
52
--- Invalidates this parser and all its children.
---
---@param reload When truthy, the stored trees of this tree and of every
---              child are dropped as well, so the next parse starts from scratch.
function LanguageTree:invalidate(reload)
  self._valid = false

  -- buffer was reloaded, reparse all trees
  if reload then
    self._trees = {}
  end

  -- `_children` is a map keyed by language name, so it must be walked with
  -- `pairs`; `ipairs` would visit nothing and leave the children untouched.
  for _, child in pairs(self._children) do
    child:invalidate(reload)
  end
end
66
--- Returns the list of trees held by this language tree.
--- Trees belonging to child languages are not included.
function LanguageTree:trees()
  local own_trees = self._trees
  return own_trees
end
72
--- Returns the language this tree was created for.
function LanguageTree:lang()
  local lang_name = self._lang
  return lang_name
end
77
--- Reports whether this tree is currently valid.
--- When it is not, `parse()` must be called to obtain updated trees.
function LanguageTree:is_valid()
  local valid = self._valid
  return valid
end
84
--- Returns the child trees as a map of language name to child LanguageTree.
function LanguageTree:children()
  local child_map = self._children
  return child_map
end
89
--- Returns the source this tree parses (a bufnr or a string).
function LanguageTree:source()
  local src = self._source
  return src
end
94
--- Parses every defined region with the treesitter parser of this tree's
--- language, then runs the injection query to create, update or remove
--- child language trees.
---
--- Returns the cached trees right away when the tree is still valid; in that
--- case no change list is returned.
---
---@returns The list of trees, followed by the accumulated list of changes
---         (this tree's changes plus those propagated from children).
function LanguageTree:parse()
  if self._valid then
    return self._trees
  end

  local parser = self._parser
  local changes = {}

  local previous_trees = self._trees
  self._trees = {}

  -- Runs one parse with the parser's current included ranges, notifies the
  -- 'changedtree' listeners, then records the tree and its changes.
  local function parse_one(old_tree)
    local tree, tree_changes = parser:parse(old_tree, self._source)
    self:_do_callback('changedtree', tree_changes, tree)

    table.insert(self._trees, tree)
    vim.list_extend(changes, tree_changes)
  end

  local regions = self._regions

  if regions and #regions > 0 then
    -- One tree per region, each parsed with that region's ranges.
    for index, ranges in ipairs(regions) do
      parser:set_included_ranges(ranges)
      parse_one(previous_trees[index])
    end
  else
    -- No regions: a single parse, leaving the parser's included ranges as they are.
    parse_one(previous_trees[1])
  end

  local seen_langs = {}

  for lang, injection_ranges in pairs(self:_get_injections()) do
    -- Child language trees should just be ignored if not found, since
    -- they can depend on the text of a node. Intermediate strings
    -- would cause errors for unknown parsers.
    if language.require_language(lang, nil, true) then
      local child = self._children[lang] or self:add_child(lang)

      child:set_included_regions(injection_ranges)

      local _, child_changes = child:parse()

      -- Fold the child's changes into the list handed back to the caller.
      if child_changes then
        vim.list_extend(changes, child_changes)
      end

      seen_langs[lang] = true
    end
  end

  -- Drop children whose language was not injected during this parse.
  for lang in pairs(self._children) do
    if not seen_langs[lang] then
      self:remove_child(lang)
    end
  end

  self._valid = true

  return self._trees, changes
end
171
--- Invokes the callback for each child LanguageTree, recursively.
---
--- Descendants are always passed to `fn`; `include_self` only decides
--- whether the tree this is called on is passed as well.
---
---@param fn The function to invoke. This is invoked with arguments (tree: LanguageTree, lang: string)
---@param include_self Whether to include the invoking tree in the results.
function LanguageTree:for_each_child(fn, include_self)
  if include_self then
    fn(self, self._lang)
  end

  local children = self._children
  for _, subtree in pairs(children) do
    subtree:for_each_child(fn, true)
  end
end
185
--- Invokes the callback for every treesitter tree, recursively.
---
--- The invoking language tree's own trees are visited first, followed by
--- the trees of each child language.
---
---@param fn The callback to invoke. The callback is invoked with arguments
---         (tree: TSTree, languageTree: LanguageTree)
function LanguageTree:for_each_tree(fn)
  local own_trees = self._trees
  for _, ts_tree in ipairs(own_trees) do
    fn(ts_tree, self)
  end

  local children = self._children
  for _, subtree in pairs(children) do
    subtree:for_each_tree(fn)
  end
end
201
--- Adds a child language to this tree.
---
--- An existing child for the same language is removed first. The new child
--- shares this tree's source and options. Adding a child invalidates this
--- tree and fires the `child_added` callbacks with the new child.
---
---@param lang The language to add.
---@returns The child stored for `lang` after the callbacks have run.
function LanguageTree:add_child(lang)
  if self._children[lang] then
    self:remove_child(lang)
  end

  local child = LanguageTree.new(self._source, lang, self._opts)
  self._children[lang] = child

  self:invalidate()
  self:_do_callback('child_added', child)

  -- Looked up again on purpose: this returns whatever is registered for
  -- `lang` once the callbacks have run.
  return self._children[lang]
end
219
--- Removes a child language from this tree.
---
--- Does nothing when no child exists for `lang`. Otherwise the child is
--- unregistered and destroyed, this tree is invalidated, and the
--- `child_removed` callbacks are fired with the removed child.
---
---@param lang The language to remove.
function LanguageTree:remove_child(lang)
  local removed = self._children[lang]

  if not removed then
    return
  end

  self._children[lang] = nil
  removed:destroy()
  self:invalidate()
  self:_do_callback('child_removed', removed)
end
233
--- Destroys this language tree and all its children.
---
--- Any cleanup logic should be performed here.
--- Note, this DOES NOT remove this tree from a parent.
--- `remove_child` must be called on the parent to remove it.
function LanguageTree:destroy()
  -- Cleanup here

  -- `_children` is a map keyed by language name, so it has to be traversed
  -- with `pairs`; `ipairs` would never reach any child.
  for _, child in pairs(self._children) do
    child:destroy()
  end
end
245
--- Sets the included regions that should be parsed by this parser.
--- A region is a set of nodes and/or ranges that will be parsed in the same context.
---
--- For example, `{ { node1 }, { node2} }` is two separate regions.
--- Each is parsed in its own context, which results in two separate trees.
---
--- `{ { node1, node2 } }` is a single region consisting of two nodes.
--- Both are parsed in one context, which results in a single tree.
---
--- This allows for embedded languages to be parsed together across different
--- nodes, which is useful for templating languages like ERB and EJS.
---
--- When the source is a buffer, 4-element range tables inside `regions` are
--- replaced IN PLACE by their 6-element form (with byte offsets).
---
--- Note, this call discards the current trees, invalidates the tree and
--- requires it to be parsed again.
---
---@param regions A list of regions this tree should manage and parse.
function LanguageTree:set_included_regions(regions)
  local source = self._source

  -- TODO(vigoux): I don't think string parsers are useful for now
  if type(source) == "number" then
    for _, region in ipairs(regions) do
      for index, range in ipairs(region) do
        -- Only { start_row, start_col, end_row, end_col } tables are expanded.
        local is_short_range = type(range) == "table" and #range == 4

        if is_short_range then
          local start_row, start_col, end_row, end_col = unpack(range)
          -- Easy case, this is a buffer parser
          -- TODO(vigoux): proper byte computation here, and account for EOL ?
          local start_byte = a.nvim_buf_get_offset(source, start_row) + start_col
          local end_byte = a.nvim_buf_get_offset(source, end_row) + end_col

          region[index] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
        end
      end
    end
  end

  self._regions = regions

  -- The old trees were built for the previous regions, so they are dropped.
  -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the
  --                          old trees for incremental parsing. Currently, this only
  --                          effects injected languages.
  self._trees = {}
  self:invalidate()
end
290
--- Returns the regions most recently set through `set_included_regions`
--- (an empty list for a freshly created tree).
function LanguageTree:included_regions()
  local regions = self._regions
  return regions
end
295
--- Gets language injection points by language.
---
--- This is where most of the injection processing occurs: the injection
--- query is run over every tree of this language tree and the matches are
--- grouped by tree, language and query pattern.
---
--- Matches for which no language could be resolved are skipped.
---
--- TODO: Allow for an offset predicate to tailor the injection range
---       instead of using the entire nodes range.
---@returns A map of language name to a list of regions. Each region is a
---         list of nodes/ranges that should be parsed together. An empty
---         table is returned when there is no injection query.
---@private
function LanguageTree:_get_injections()
  if not self._injection_query then return {} end

  local injections = {}

  for tree_index, tree in ipairs(self._trees) do
    -- Each tree index should be isolated from the other nodes.
    -- The slot is created up front, even for trees that yield no match, so
    -- that `injections` stays a gap-free list: the `ipairs` traversal below
    -- would otherwise stop at the first tree without injections and drop
    -- the injections of all following trees.
    injections[tree_index] = {}

    local root_node = tree:root()
    local start_line, _, end_line, _ = root_node:range()

    for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
      local lang = nil
      local ranges = {}
      local combined = metadata.combined

      -- Directives can configure how injections are captured as well as actual node captures.
      -- This allows more advanced processing for determining ranges and language resolution.
      if metadata.content then
        local content = metadata.content

        -- Allow for captured nodes to be used
        if type(content) == "number" then
          content = {match[content]}
        end

        if content then
          vim.list_extend(ranges, content)
        end
      end

      if metadata.language then
        lang = metadata.language
      end

      -- You can specify the content and language together
      -- using a tag with the language, for example
      -- @javascript
      for id, node in pairs(match) do
        local name = self._injection_query.captures[id]

        -- Lang should override any other language tag
        if name == "language" and not lang then
          lang = query.get_node_text(node, self._source)
        elseif name == "combined" then
          combined = true
        elseif name == "content" and #ranges == 0 then
          table.insert(ranges, node)
        -- Ignore any tags that start with "_"
        -- Allows for other tags to be used in matches
        elseif string.sub(name, 1, 1) ~= "_" then
          if not lang then
            lang = name
          end

          if #ranges == 0 then
            table.insert(ranges, node)
          end
        end
      end

      -- A match that resolved no language cannot be injected anywhere, and
      -- using `nil` as a table key would raise an error, so it is skipped.
      if lang then
        local by_lang = injections[tree_index]

        if not by_lang[lang] then
          by_lang[lang] = {}
        end

        -- Key this by pattern. If combined is set to true all captures of this pattern
        -- will be parsed by treesitter as the same "source".
        -- If combined is false, each "region" will be parsed as a single source.
        if not by_lang[lang][pattern] then
          by_lang[lang][pattern] = { combined = combined, regions = {} }
        end

        table.insert(by_lang[lang][pattern].regions, ranges)
      end
    end
  end

  local result = {}

  -- Generate a map by lang of node lists.
  -- Each list is a set of ranges that should be parsed together.
  for _, lang_map in ipairs(injections) do
    for lang, patterns in pairs(lang_map) do
      if not result[lang] then
        result[lang] = {}
      end

      for _, entry in pairs(patterns) do
        if entry.combined then
          -- All regions of a combined pattern collapse into a single region.
          table.insert(result[lang], vim.tbl_flatten(entry.regions))
        else
          for _, ranges in ipairs(entry.regions) do
            table.insert(result[lang], ranges)
          end
        end
      end
    end
  end

  return result
end
406
--- Calls every callback registered under `cb_name`, in registration order,
--- forwarding the remaining arguments unchanged.
---@param cb_name Key into the callbacks table (e.g. 'bytes', 'changedtree').
---@private
function LanguageTree:_do_callback(cb_name, ...)
  local handlers = self._callbacks[cb_name]

  for _, handler in ipairs(handlers) do
    handler(...)
  end
end
413
--- Handles a byte-level change: invalidates this tree, applies the edit to
--- every tree (children included) and then fires the 'bytes' callbacks with
--- the arguments exactly as received.
---@private
function LanguageTree:_on_bytes(bufnr, changed_tick,
                          start_row, start_col, start_byte,
                          old_row, old_col, old_byte,
                          new_row, new_col, new_byte)
  self:invalidate()

  -- An end column is offset by the start column only when the old/new
  -- extent spans zero rows, i.e. it ends on the starting row.
  local old_end_col = old_col
  if old_row == 0 then
    old_end_col = old_end_col + start_col
  end

  local new_end_col = new_col
  if new_row == 0 then
    new_end_col = new_end_col + start_col
  end

  local old_end_byte = start_byte + old_byte
  local new_end_byte = start_byte + new_byte
  local old_end_row = start_row + old_row
  local new_end_row = start_row + new_row

  -- Edit all trees recursively, together BEFORE emitting a bytes callback.
  -- In most cases this callback should only be called from the root tree.
  self:for_each_tree(function(ts_tree)
    ts_tree:edit(start_byte, old_end_byte, new_end_byte,
      start_row, start_col,
      old_end_row, old_end_col,
      new_end_row, new_end_col)
  end)

  self:_do_callback('bytes', bufnr, changed_tick,
      start_row, start_col, start_byte,
      old_row, old_col, old_byte,
      new_row, new_col, new_byte)
end
438
--- Reload handler: invalidates this tree with `reload` set, which also
--- discards the stored trees.
---@private
function LanguageTree:_on_reload()
  local reload = true
  self:invalidate(reload)
end
443
444
--- Detach handler: invalidates this tree with `reload` set (discarding the
--- stored trees), then fires the 'detach' callbacks with the arguments received.
---@private
function LanguageTree:_on_detach(...)
  local reload = true
  self:invalidate(reload)
  self:_do_callback('detach', ...)
end
450
--- Registers callbacks for the parser. Does nothing when `cbs` is nil.
---@param cbs An `nvim_buf_attach`-like table argument with the following keys :
---  `on_bytes` : see `nvim_buf_attach`; it is called _after_ the trees have been edited.
---  `on_changedtree` : called for every tree produced by a parse, with the table
---      of changes reported for that tree followed by the tree itself.
---  `on_detach` : called on detach with the arguments the detach handler received.
---  `on_child_added` : emitted when a child is added to the tree.
---  `on_child_removed` : emitted when a child is removed from the tree.
function LanguageTree:register_cbs(cbs)
  if not cbs then return end

  -- { key in `cbs`, name of the callback list it is appended to }
  local registrations = {
    { 'on_changedtree', 'changedtree' },
    { 'on_bytes', 'bytes' },
    { 'on_detach', 'detach' },
    { 'on_child_added', 'child_added' },
    { 'on_child_removed', 'child_removed' },
  }

  for _, registration in ipairs(registrations) do
    local callback = cbs[registration[1]]

    if callback then
      table.insert(self._callbacks[registration[2]], callback)
    end
  end
end
482
--- Checks whether the root node of `tree` spans the whole of `range`.
---@param tree A tree exposing `root()`, whose node exposes `range()`.
---@param range A `{ start_line, start_col, end_line, end_col }` table.
---@returns true when the root starts at or before the range start and ends
---         at or after the range end, false otherwise.
---@private
local function tree_contains(tree, range)
  local root_start_row, root_start_col, root_end_row, root_end_col = tree:root():range()
  local range_start_row, range_start_col = range[1], range[2]
  local range_end_row, range_end_col = range[3], range[4]

  local starts_before = root_start_row < range_start_row
    or (root_start_row == range_start_row and root_start_col <= range_start_col)
  local ends_after = root_end_row > range_end_row
    or (root_end_row == range_end_row and root_end_col >= range_end_col)

  return starts_before and ends_after
end
495
--- Determines whether `range` is contained in one of this language tree's
--- own trees.
---
--- Only the trees held directly by this language tree are checked.
---
---@param range A range, that is a `{ start_line, start_col, end_line, end_col }` table.
---@returns true when at least one tree contains the range, false otherwise.
function LanguageTree:contains(range)
  local found = false

  for _, ts_tree in pairs(self._trees) do
    if tree_contains(ts_tree, range) then
      found = true
      break
    end
  end

  return found
end
510
--- Gets the language tree that contains `range`.
---
--- The first child found to contain the range is asked the same question
--- recursively; when no child contains it, this tree itself is returned.
---
---@param range A text range, see |LanguageTree:contains|
function LanguageTree:language_for_range(range)
  local owner = self

  for _, subtree in pairs(self._children) do
    if subtree:contains(range) then
      owner = subtree:language_for_range(range)
      break
    end
  end

  return owner
end
523
524return LanguageTree
525