1 // This is an open source non-commercial project. Dear PVS-Studio, please check
2 // it. PVS-Studio Static Code Analyzer for C, C++ and C#: http://www.viva64.com
3 
4 // lua bindings for tree-sitter.
5 // NB: this file mostly contains a generic lua interface for tree-sitter
6 // trees and nodes, and could be broken out as a reusable lua package
7 
8 #include <assert.h>
9 #include <inttypes.h>
10 #include <lauxlib.h>
11 #include <lua.h>
12 #include <lualib.h>
13 #include <stdbool.h>
14 #include <stdlib.h>
15 #include <string.h>
16 
17 #include "nvim/api/private/helpers.h"
18 #include "nvim/buffer.h"
19 #include "nvim/lua/treesitter.h"
20 #include "nvim/memline.h"
21 #include "tree_sitter/api.h"
22 
23 #define TS_META_PARSER "treesitter_parser"
24 #define TS_META_TREE "treesitter_tree"
25 #define TS_META_NODE "treesitter_node"
26 #define TS_META_QUERY "treesitter_query"
27 #define TS_META_QUERYCURSOR "treesitter_querycursor"
28 #define TS_META_TREECURSOR "treesitter_treecursor"
29 
/// State kept next to a tree-sitter query cursor while iterating over results.
typedef struct {
  TSQueryCursor *cursor;  ///< the underlying tree-sitter query cursor
  int predicated_match;  ///< id of the match whose "active" flag is checked on
                         ///< the next query_next_capture() call, -1 when none
  int max_match_id;  ///< highest match id already reported by
                     ///< query_next_capture()
} TSLua_cursor;
35 
36 #ifdef INCLUDE_GENERATED_DECLARATIONS
37 # include "lua/treesitter.c.generated.h"
38 #endif
39 
// Methods and metamethods registered in the TS_META_PARSER metatable.
static struct luaL_Reg parser_meta[] = {
  { "__gc", parser_gc },
  { "__tostring", parser_tostring },
  { "parse", parser_parse },
  { "set_included_ranges", parser_set_ranges },
  { "included_ranges", parser_get_ranges },
  { NULL, NULL }
};
48 
// Methods and metamethods registered in the TS_META_TREE metatable.
static struct luaL_Reg tree_meta[] = {
  { "__gc", tree_gc },
  { "__tostring", tree_tostring },
  { "root", tree_root },
  { "edit", tree_edit },
  { "copy", tree_copy },
  { NULL, NULL }
};
57 
// Methods and metamethods registered in the TS_META_NODE metatable.
// There is no "__gc" entry here, unlike the parser, tree and cursor tables.
static struct luaL_Reg node_meta[] = {
  { "__tostring", node_tostring },
  { "__eq", node_eq },
  // the length operator reports the same value as child_count()
  { "__len", node_child_count },
  { "id", node_id },
  { "range", node_range },
  { "start", node_start },
  { "end_", node_end },
  { "type", node_type },
  { "symbol", node_symbol },
  { "field", node_field },
  { "named", node_named },
  { "missing", node_missing },
  { "has_error", node_has_error },
  { "sexpr", node_sexpr },
  { "child_count", node_child_count },
  { "named_child_count", node_named_child_count },
  { "child", node_child },
  { "named_child", node_named_child },
  { "descendant_for_range", node_descendant_for_range },
  { "named_descendant_for_range", node_named_descendant_for_range },
  { "parent", node_parent },
  { "iter_children", node_iter_children },
  { "_rawquery", node_rawquery },
  { "next_sibling", node_next_sibling },
  { "prev_sibling", node_prev_sibling },
  { "next_named_sibling", node_next_named_sibling },
  { "prev_named_sibling", node_prev_named_sibling },
  { NULL, NULL }
};
88 
// Methods and metamethods registered in the TS_META_QUERY metatable.
static struct luaL_Reg query_meta[] = {
  { "__gc", query_gc },
  { "__tostring", query_tostring },
  { "inspect", query_inspect },
  { NULL, NULL }
};
95 
// Query cursors are not exposed to Lua, but they still need garbage
// collection, hence a metatable holding only "__gc".
static struct luaL_Reg querycursor_meta[] = {
  { "__gc", querycursor_gc },
  { NULL, NULL }
};
101 
// Metatable for the tree cursor userdata created by node_iter_children();
// it only provides "__gc".
static struct luaL_Reg treecursor_meta[] = {
  { "__gc", treecursor_gc },
  { NULL, NULL }
};
106 
// Registered languages: maps a language name to its TSLanguage pointer.
static PMap(cstr_t) langs = MAP_INIT;
108 
/// Make sure the metatable named `tname` exists in the registry.
///
/// When the metatable is newly created, the functions listed in `meta` are
/// registered in it and its `__index` field is set to the metatable itself.
/// The Lua stack is left as it was found.
///
/// @param L      Lua state
/// @param tname  registry name of the metatable
/// @param meta   NULL-terminated list of functions to register
static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
{
  const bool created = luaL_newmetatable(L, tname) != 0;  // [meta]

  if (created) {
    luaL_register(L, NULL, meta);

    // meta.__index = meta
    lua_pushvalue(L, -1);  // [meta, meta]
    lua_setfield(L, -2, "__index");  // [meta]
  }

  // the metatable is not needed right now, drop it again
  lua_pop(L, 1);  // []
}
119 
120 /// init the tslua library
121 ///
122 /// all global state is stored in the regirstry of the lua_State
tslua_init(lua_State * L)123 void tslua_init(lua_State *L)
124 {
125   // type metatables
126   build_meta(L, TS_META_PARSER, parser_meta);
127   build_meta(L, TS_META_TREE, tree_meta);
128   build_meta(L, TS_META_NODE, node_meta);
129   build_meta(L, TS_META_QUERY, query_meta);
130   build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
131   build_meta(L, TS_META_TREECURSOR, treecursor_meta);
132 }
133 
/// Check whether a language has been registered.
///
/// Lua argument: the language name. Pushes a boolean.
int tslua_has_language(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);
  const bool registered = pmap_has(cstr_t)(&langs, name);

  lua_pushboolean(L, registered);
  return 1;
}
140 
/// Load a tree-sitter language from a shared library and register it.
///
/// Lua arguments: `path` (the shared library to open) and `lang_name`. The
/// language entry point is looked up as the symbol `tree_sitter_<lang_name>`.
///
/// @return no values when `lang_name` is already registered, otherwise one
///         value (`true`). Raises a Lua error when the library cannot be
///         opened, the symbol cannot be found, the entry point returns NULL,
///         or the language ABI version is outside the supported range. The
///         library is closed again on every error path.
int tslua_add_language(lua_State *L)
{
  const char *path = luaL_checkstring(L, 1);
  const char *lang_name = luaL_checkstring(L, 2);

  if (pmap_has(cstr_t)(&langs, lang_name)) {
    return 0;
  }

#define BUFSIZE 128
  char symbol_buf[BUFSIZE];
  snprintf(symbol_buf, BUFSIZE, "tree_sitter_%s", lang_name);
#undef BUFSIZE

  uv_lib_t lib;
  if (uv_dlopen(path, &lib)) {
    // copy the message before closing: it belongs to `lib`
    snprintf((char *)IObuff, IOSIZE, "Failed to load parser: uv_dlopen: %s",
             uv_dlerror(&lib));
    uv_dlclose(&lib);
    lua_pushstring(L, (char *)IObuff);
    return lua_error(L);
  }

  TSLanguage *(*lang_parser)(void);
  if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) {
    snprintf((char *)IObuff, IOSIZE, "Failed to load parser: uv_dlsym: %s",
             uv_dlerror(&lib));
    uv_dlclose(&lib);
    lua_pushstring(L, (char *)IObuff);
    return lua_error(L);
  }

  TSLanguage *lang = lang_parser();
  if (lang == NULL) {
    // nothing from the library is kept, so do not leak the handle
    uv_dlclose(&lib);
    return luaL_error(L, "Failed to load parser %s: internal error", path);
  }

  uint32_t lang_version = ts_language_version(lang);
  if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION
      || lang_version > TREE_SITTER_LANGUAGE_VERSION) {
    // the language is rejected, so the library is no longer needed
    uv_dlclose(&lib);
    // "%d" consumes an int, hence the cast
    return luaL_error(L,
                      "ABI version mismatch for %s: supported between %d and %d, found %d",
                      path,
                      TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION,
                      TREE_SITTER_LANGUAGE_VERSION, (int)lang_version);
  }

  // On success the library stays open: `lang` was obtained from it and is
  // kept in `langs`. The key is a private copy of the name.
  pmap_put(cstr_t)(&langs, xstrdup(lang_name), lang);

  lua_pushboolean(L, true);
  return 1;
}
193 
/// Describe a registered language.
///
/// Lua argument: the language name. Pushes one table with the fields
///  - `symbols`: symbol id -> `{ name, is_regular }`; auxiliary symbols are
///    left out
///  - `fields`: field id -> field name
///  - `_abi_version`: the language version
///
/// Raises a Lua error if the language is not registered.
int tslua_inspect_lang(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);

  TSLanguage *lang = pmap_get(cstr_t)(&langs, name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", name);
  }

  lua_createtable(L, 0, 2);  // [retval]

  // retval.symbols
  size_t symbol_count = (size_t)ts_language_symbol_count(lang);
  lua_createtable(L, symbol_count-1, 1);  // [retval, symbols]
  for (size_t id = 0; id < symbol_count; id++) {
    TSSymbolType kind = ts_language_symbol_type(lang, id);
    // auxiliary symbols are not used by the API
    if (kind != TSSymbolTypeAuxiliary) {
      lua_createtable(L, 2, 0);  // [retval, symbols, elem]
      lua_pushstring(L, ts_language_symbol_name(lang, id));
      lua_rawseti(L, -2, 1);
      lua_pushboolean(L, kind == TSSymbolTypeRegular);
      lua_rawseti(L, -2, 2);  // [retval, symbols, elem]
      lua_rawseti(L, -2, id);  // [retval, symbols]
    }
  }
  lua_setfield(L, -2, "symbols");  // [retval]

  // retval.fields
  size_t field_count = (size_t)ts_language_field_count(lang);
  lua_createtable(L, field_count, 1);  // [retval, fields]
  // Field IDs go from 1 to field_count inclusive (extra index 0 maps to NULL)
  for (size_t id = 1; id <= field_count; id++) {
    lua_pushstring(L, ts_language_field_name_for_id(lang, id));
    lua_rawseti(L, -2, id);  // [retval, fields]
  }
  lua_setfield(L, -2, "fields");  // [retval]

  // retval._abi_version
  uint32_t abi_version = ts_language_version(lang);
  lua_pushinteger(L, abi_version);  // [retval, version]
  lua_setfield(L, -2, "_abi_version");  // [retval]

  return 1;
}
240 
/// Create a parser for a registered language.
///
/// Lua argument: the language name. Pushes a new parser userdata. Raises a
/// Lua error if the language is not registered or cannot be assigned to the
/// parser.
int tslua_push_parser(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);

  TSLanguage *lang = pmap_get(cstr_t)(&langs, name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", name);
  }

  TSParser **ud = lua_newuserdata(L, sizeof(TSParser *));  // [udata]
  *ud = ts_parser_new();

  if (ts_parser_set_language(*ud, lang)) {
    lua_getfield(L, LUA_REGISTRYINDEX, TS_META_PARSER);  // [udata, meta]
    lua_setmetatable(L, -2);  // [udata]
    return 1;
  }

  // The userdata has no metatable yet, so the parser has to be deleted here.
  ts_parser_delete(*ud);
  return luaL_error(L, "Failed to load language : %s", name);
}
263 
parser_check(lua_State * L,uint16_t index)264 static TSParser **parser_check(lua_State *L, uint16_t index)
265 {
266   return luaL_checkudata(L, index, TS_META_PARSER);
267 }
268 
/// `__gc` metamethod of parsers: deletes the wrapped TSParser.
static int parser_gc(lua_State *L)
{
  TSParser **ud = parser_check(L, 1);

  if (ud != NULL) {
    ts_parser_delete(*ud);
  }
  return 0;
}
279 
/// `__tostring` metamethod of parsers: pushes a fixed placeholder string.
static int parser_tostring(lua_State *L)
{
  const char *repr = "<parser>";

  lua_pushstring(L, repr);
  return 1;
}
285 
/// Read callback handing buffer text to the parser in chunks.
///
/// @param payload          the buf_T to read from
/// @param byte_index       not used
/// @param position         row and column at which to start reading
/// @param[out] bytes_read  number of bytes available at the returned pointer
///
/// @return a static buffer holding up to 256 bytes of the requested line
///         (overwritten by the next call), or "" with `*bytes_read` set to 0
///         when `position` lies outside the text.
static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position,
                            uint32_t *bytes_read)
{
#define BUFSIZE 256
  static char buf[BUFSIZE];
  buf_T *bp = payload;

  // stays 0 on both early exits below
  *bytes_read = 0;

  if ((linenr_T)position.row >= bp->b_ml.ml_line_count) {
    return "";
  }

  char_u *text = ml_get_buf(bp, position.row+1, false);
  size_t text_len = STRLEN(text);
  if (position.column > text_len) {
    return "";
  }

  size_t chunk = MIN(text_len-position.column, BUFSIZE);
  memcpy(buf, text+position.column, chunk);
  // Translate embedded \n to NUL
  memchrsub(buf, '\n', '\0', chunk);

  if (chunk < BUFSIZE) {
    // The rest of the line fit: terminate it with the final \n.
    buf[chunk] = '\n';
    *bytes_read = (uint32_t)chunk + 1;
  } else {
    // The buffer is full; input_cb will be called again on the same line
    // with an advanced column.
    *bytes_read = (uint32_t)chunk;
  }
  return buf;
#undef BUFSIZE
}
318 
/// Push a list of ranges onto the Lua stack.
///
/// The pushed value is a list with one entry per range, each entry being
/// `{ start_row, start_col, end_row, end_col }`.
///
/// @param ranges  array of `length` ranges (not accessed when `length` is 0)
/// @param length  number of ranges
static void push_ranges(lua_State *L, const TSRange *ranges, const unsigned int length)
{
  lua_createtable(L, length, 0);  // [list]

  for (size_t i = 0; i < length; i++) {
    const TSRange *r = &ranges[i];
    const uint32_t coords[4] = {
      r->start_point.row,
      r->start_point.column,
      r->end_point.row,
      r->end_point.column,
    };

    lua_createtable(L, 4, 0);  // [list, entry]
    for (int k = 0; k < 4; k++) {
      lua_pushinteger(L, coords[k]);
      lua_rawseti(L, -2, k + 1);
    }

    lua_rawseti(L, -2, i+1);  // [list]
  }
}
336 
/// Lua method `parser:parse(old_tree, source)`.
///
/// `old_tree` is nil or a tree to reuse for incremental parsing. `source` is
/// either a string holding the text, or a number naming the buffer to parse.
///
/// @return two values: the new tree and the list of ranges that changed
///         relative to `old_tree` (an empty list when there is no old tree).
///         Raises a Lua error when `source` has another type, the buffer
///         handle is invalid, or parsing fails.
static int parser_parse(lua_State *L)
{
  TSParser **p = parser_check(L, 1);
  if (!p || !(*p)) {
    return 0;
  }

  TSTree *old_tree = NULL;
  if (!lua_isnil(L, 2)) {
    TSTree **tmp = tree_check(L, 2);
    old_tree = tmp ? *tmp : NULL;
  }

  TSTree *new_tree = NULL;
  size_t len;
  const char *str;
  long bufnr;
  buf_T *buf;
  TSInput input;

  // This switch is necessary because of the behavior of lua_isstring, which
  // considers numbers to be strings...
  switch (lua_type(L, 3)) {
  case LUA_TSTRING:
    str = lua_tolstring(L, 3, &len);
    new_tree = ts_parser_parse_string(*p, old_tree, str, len);
    break;

  case LUA_TNUMBER:
    bufnr = lua_tointeger(L, 3);
    buf = handle_get_buffer(bufnr);

    if (!buf) {
      // "%d" consumes an int from the varargs, so do not pass a long
      return luaL_error(L, "invalid buffer handle: %d", (int)bufnr);
    }

    input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8 };
    new_tree = ts_parser_parse(*p, old_tree, input);

    break;

  default:
    return luaL_error(L, "invalid argument to parser:parse()");
  }

  // Sometimes parsing fails (timeout, or wrong parser ABI)
  // In those cases, just return an error.
  if (!new_tree) {
    return luaL_error(L, "An error occurred when parsing.");
  }

  // The new tree will be pushed to the stack without a copy; ownership now
  // belongs to the lua GC.
  // Old tree is still owned by the lua GC.
  uint32_t n_ranges = 0;
  TSRange *changed = old_tree ?  ts_tree_get_changed_ranges(old_tree, new_tree, &n_ranges) : NULL;

  push_tree(L, new_tree, false);  // [tree]

  push_ranges(L, changed, n_ranges);  // [tree, ranges]

  xfree(changed);
  return 2;
}
401 
/// Lua method `tree:copy()`: pushes a new tree userdata wrapping a copy of
/// this tree. Returns nothing when the userdata holds no tree.
static int tree_copy(lua_State *L)
{
  TSTree **ud = tree_check(L, 1);

  if (*ud == NULL) {
    return 0;
  }

  push_tree(L, *ud, true);  // [tree]
  return 1;
}
413 
/// Lua method `tree:edit(...)`: apply an edit to the tree.
///
/// Expects ten arguments: the tree, then start_byte, old_end_byte,
/// new_end_byte, and the (row, column) pairs of the start point, the old end
/// point and the new end point. Raises a Lua error when fewer are given.
static int tree_edit(lua_State *L)
{
  if (lua_gettop(L) < 10) {
    lua_pushstring(L, "not enough args to tree:edit()");
    return lua_error(L);
  }

  TSTree **ud = tree_check(L, 1);
  if (*ud == NULL) {
    return 0;
  }

  const long start_byte = lua_tointeger(L, 2);
  const long old_end_byte = lua_tointeger(L, 3);
  const long new_end_byte = lua_tointeger(L, 4);

  TSInputEdit edit = {
    .start_byte = start_byte,
    .old_end_byte = old_end_byte,
    .new_end_byte = new_end_byte,
    .start_point = { .row = lua_tointeger(L, 5), .column = lua_tointeger(L, 6) },
    .old_end_point = { .row = lua_tointeger(L, 7), .column = lua_tointeger(L, 8) },
    .new_end_point = { .row = lua_tointeger(L, 9), .column = lua_tointeger(L, 10) },
  };

  ts_tree_edit(*ud, &edit);
  return 0;
}
440 
441 // Use the top of the stack (without popping it) to create a TSRange, it can be
442 // either a lua table or a TSNode
range_from_lua(lua_State * L,TSRange * range)443 static void range_from_lua(lua_State *L, TSRange *range)
444 {
445   TSNode node;
446 
447   if (lua_istable(L, -1)) {
448     // should be a table of 6 elements
449     if (lua_objlen(L, -1) != 6) {
450       goto error;
451     }
452 
453     uint32_t start_row, start_col, start_byte, end_row, end_col, end_byte;
454     lua_rawgeti(L, -1, 1);  // [ range, start_row]
455     start_row = luaL_checkinteger(L, -1);
456     lua_pop(L, 1);
457 
458     lua_rawgeti(L, -1, 2);  // [ range, start_col]
459     start_col = luaL_checkinteger(L, -1);
460     lua_pop(L, 1);
461 
462     lua_rawgeti(L, -1, 3);  // [ range, start_byte]
463     start_byte = luaL_checkinteger(L, -1);
464     lua_pop(L, 1);
465 
466     lua_rawgeti(L, -1, 4);  // [ range, end_row]
467     end_row = luaL_checkinteger(L, -1);
468     lua_pop(L, 1);
469 
470     lua_rawgeti(L, -1, 5);  // [ range, end_col]
471     end_col = luaL_checkinteger(L, -1);
472     lua_pop(L, 1);
473 
474     lua_rawgeti(L, -1, 6);  // [ range, end_byte]
475     end_byte = luaL_checkinteger(L, -1);
476     lua_pop(L, 1);  // [ range ]
477 
478     *range = (TSRange) {
479       .start_point = (TSPoint) {
480         .row = start_row,
481         .column = start_col
482       },
483       .end_point = (TSPoint) {
484         .row = end_row,
485         .column = end_col
486       },
487       .start_byte = start_byte,
488       .end_byte = end_byte,
489     };
490   } else if (node_check(L, -1, &node)) {
491     *range = (TSRange) {
492       .start_point = ts_node_start_point(node),
493       .end_point = ts_node_end_point(node),
494       .start_byte = ts_node_start_byte(node),
495       .end_byte = ts_node_end_byte(node)
496     };
497   } else {
498     goto error;
499   }
500   return;
501 error:
502   luaL_error(L,
503              "Ranges can only be made from 6 element long tables or nodes.");
504 }
505 
/// Lua method `parser:set_included_ranges(ranges)`.
///
/// `ranges` is a list whose entries are converted with range_from_lua().
/// Raises a Lua error when the argument is missing, is not a table, or one of
/// its entries cannot be converted to a range.
static int parser_set_ranges(lua_State *L)
{
  if (lua_gettop(L) < 2) {
    return luaL_error(L,
                      "not enough args to parser:set_included_ranges()");
  }

  TSParser **p = parser_check(L, 1);
  if (!p) {
    return 0;
  }

  if (!lua_istable(L, 2)) {
    return luaL_error(L,
                      "argument for parser:set_included_ranges() should be a table.");
  }

  size_t tbl_len = lua_objlen(L, 2);

  // range_from_lua() raises a Lua error on a malformed entry, which jumps out
  // of this function. The scratch array is therefore a userdata owned by the
  // Lua GC, so it cannot leak on that path.
  TSRange *ranges = lua_newuserdata(L, sizeof(TSRange) * tbl_len);  // [ ..., scratch ]

  for (size_t index = 0; index < tbl_len; index++) {
    lua_rawgeti(L, 2, index + 1);  // [ ..., scratch, range ]
    range_from_lua(L, ranges + index);
    lua_pop(L, 1);  // [ ..., scratch ]
  }

  // This memcpies ranges, thus the scratch array can be dropped afterwards
  ts_parser_set_included_ranges(*p, ranges, tbl_len);
  lua_pop(L, 1);  // [ ... ]

  return 0;
}
540 
/// Lua method `parser:included_ranges()`: pushes the parser's included ranges
/// as a list of `{ start_row, start_col, end_row, end_col }` entries.
static int parser_get_ranges(lua_State *L)
{
  TSParser **ud = parser_check(L, 1);
  if (ud == NULL) {
    return 0;
  }

  unsigned int count;
  const TSRange *included = ts_parser_included_ranges(*ud, &count);
  push_ranges(L, included, count);

  return 1;
}
554 
555 
556 // Tree methods
557 
558 /// push tree interface on lua stack.
559 ///
560 /// This makes a copy of the tree, so ownership of the argument is unaffected.
push_tree(lua_State * L,TSTree * tree,bool do_copy)561 void push_tree(lua_State *L, TSTree *tree, bool do_copy)
562 {
563   if (tree == NULL) {
564     lua_pushnil(L);
565     return;
566   }
567   TSTree **ud = lua_newuserdata(L, sizeof(TSTree *));  // [udata]
568 
569   if (do_copy) {
570     *ud = ts_tree_copy(tree);
571   } else {
572     *ud = tree;
573   }
574 
575   lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREE);  // [udata, meta]
576   lua_setmetatable(L, -2);  // [udata]
577 
578   // table used for node wrappers to keep a reference to tree wrapper
579   // NB: in lua 5.3 the uservalue for the node could just be the tree, but
580   // in lua 5.1 the uservalue (fenv) must be a table.
581   lua_createtable(L, 1, 0);  // [udata, reftable]
582   lua_pushvalue(L, -2);  // [udata, reftable, udata]
583   lua_rawseti(L, -2, 1);  // [udata, reftable]
584   lua_setfenv(L, -2);  // [udata]
585 }
586 
tree_check(lua_State * L,uint16_t index)587 static TSTree **tree_check(lua_State *L, uint16_t index)
588 {
589   TSTree **ud = luaL_checkudata(L, index, TS_META_TREE);
590   return ud;
591 }
592 
/// `__gc` metamethod of trees: deletes the wrapped TSTree.
static int tree_gc(lua_State *L)
{
  TSTree **ud = tree_check(L, 1);

  if (ud != NULL) {
    ts_tree_delete(*ud);
  }
  return 0;
}
603 
/// `__tostring` metamethod of trees: pushes a fixed placeholder string.
static int tree_tostring(lua_State *L)
{
  const char *repr = "<tree>";

  lua_pushstring(L, repr);
  return 1;
}
609 
/// Lua method `tree:root()`: pushes the root node of the tree.
static int tree_root(lua_State *L)
{
  TSTree **ud = tree_check(L, 1);
  if (ud == NULL) {
    return 0;
  }

  // stack index 1 is the tree userdata the new node takes its fenv from
  push_node(L, ts_tree_root_node(*ud), 1);
  return 1;
}
620 
621 // Node methods
622 
623 /// push node interface on lua stack
624 ///
625 /// top of stack must either be the tree this node belongs to or another node
626 /// of the same tree! This value is not popped. Can only be called inside a
627 /// cfunction with the tslua environment.
push_node(lua_State * L,TSNode node,int uindex)628 static void push_node(lua_State *L, TSNode node, int uindex)
629 {
630   assert(uindex > 0 || uindex < -LUA_MINSTACK);
631   if (ts_node_is_null(node)) {
632     lua_pushnil(L);  // [nil]
633     return;
634   }
635   TSNode *ud = lua_newuserdata(L, sizeof(TSNode));  // [udata]
636   *ud = node;
637   lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE);  // [udata, meta]
638   lua_setmetatable(L, -2);  // [udata]
639   lua_getfenv(L, uindex);  // [udata, reftable]
640   lua_setfenv(L, -2);  // [udata]
641 }
642 
/// Copy the node stored in the userdata at stack position `index` to `*res`.
///
/// The userdata is checked against the TS_META_NODE metatable.
///
/// @return true when `*res` was filled in, false otherwise.
static bool node_check(lua_State *L, int index, TSNode *res)
{
  const TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);

  if (ud == NULL) {
    return false;
  }
  *res = *ud;
  return true;
}
652 
653 
/// `__tostring` metamethod of nodes: pushes "<node TYPE>".
static int node_tostring(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const char *const parts[] = { "<node ", ts_node_type(self), ">" };
  for (int i = 0; i < 3; i++) {
    lua_pushstring(L, parts[i]);
  }
  lua_concat(L, 3);
  return 1;
}
666 
/// `__eq` metamethod of nodes: pushes whether both arguments are the same
/// node according to ts_node_eq().
static int node_eq(lua_State *L)
{
  TSNode lhs;
  TSNode rhs;

  // the second argument is only checked once the first one passed
  if (!node_check(L, 1, &lhs) || !node_check(L, 2, &rhs)) {
    return 0;
  }

  lua_pushboolean(L, ts_node_eq(lhs, rhs));
  return 1;
}
682 
/// Lua method `node:id()`: pushes the raw bytes of the node's `id` member as
/// a Lua string.
static int node_id(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushlstring(L, (const char *)&self.id, sizeof self.id);
    return 1;
  }
  return 0;
}
693 
/// Lua method `node:range()`: pushes start row, start column, end row and
/// end column of the node.
static int node_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = ts_node_start_point(self);
  const TSPoint to = ts_node_end_point(self);
  const uint32_t coords[4] = { from.row, from.column, to.row, to.column };

  for (int i = 0; i < 4; i++) {
    lua_pushnumber(L, coords[i]);
  }
  return 4;
}
708 
/// Lua method `node:start()`: pushes row, column and byte offset of the
/// node's start.
static int node_start(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint point = ts_node_start_point(self);
  const uint32_t byte = ts_node_start_byte(self);
  const uint32_t values[3] = { point.row, point.column, byte };

  for (int i = 0; i < 3; i++) {
    lua_pushnumber(L, values[i]);
  }
  return 3;
}
722 
/// Lua method `node:end_()`: pushes row, column and byte offset of the node's
/// end.
static int node_end(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint point = ts_node_end_point(self);
  const uint32_t byte = ts_node_end_byte(self);
  const uint32_t values[3] = { point.row, point.column, byte };

  for (int i = 0; i < 3; i++) {
    lua_pushnumber(L, values[i]);
  }
  return 3;
}
736 
/// Lua method `node:child_count()` (also the `__len` metamethod): pushes the
/// number of children of the node.
static int node_child_count(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const uint32_t n_children = ts_node_child_count(self);
    lua_pushnumber(L, n_children);
    return 1;
  }
  return 0;
}
747 
/// Lua method `node:named_child_count()`: pushes the number of named children
/// of the node.
static int node_named_child_count(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const uint32_t n_named = ts_node_named_child_count(self);
    lua_pushnumber(L, n_named);
    return 1;
  }
  return 0;
}
758 
/// Lua method `node:type()`: pushes the node's type name.
static int node_type(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const char *type_name = ts_node_type(self);
    lua_pushstring(L, type_name);
    return 1;
  }
  return 0;
}
768 
/// Lua method `node:symbol()`: pushes the node's symbol id as a number.
static int node_symbol(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const TSSymbol sym = ts_node_symbol(self);
    lua_pushnumber(L, sym);
    return 1;
  }
  return 0;
}
779 
/// Lua method `node:field(name)`: pushes a list of the direct children of the
/// node whose field name equals `name` (an empty list when there are none).
static int node_field(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  size_t wanted_len;
  const char *wanted = luaL_checklstring(L, 2, &wanted_len);

  TSTreeCursor cursor = ts_tree_cursor_new(self);

  lua_newtable(L);  // [table]
  unsigned int n_found = 0;

  // walk over all direct children of the node
  bool more = ts_tree_cursor_goto_first_child(&cursor);
  while (more) {
    const char *child_field = ts_tree_cursor_current_field_name(&cursor);

    if (child_field != NULL && STRCMP(wanted, child_field) == 0) {
      push_node(L, ts_tree_cursor_current_node(&cursor), 1);  // [table, node]
      lua_rawseti(L, -2, ++n_found);  // [table]
    }

    more = ts_tree_cursor_goto_next_sibling(&cursor);
  }

  ts_tree_cursor_delete(&cursor);
  return 1;
}
809 
/// Lua method `node:named()`: pushes the result of ts_node_is_named().
static int node_named(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_is_named(self));
    return 1;
  }
  return 0;
}
819 
/// Lua method `node:sexpr()`: pushes the string produced by ts_node_string()
/// for the node.
static int node_sexpr(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    // the string is freed here once it has been pushed
    char *sexpr = ts_node_string(self);
    lua_pushstring(L, sexpr);
    xfree(sexpr);
    return 1;
  }
  return 0;
}
831 
/// Lua method `node:missing()`: pushes the result of ts_node_is_missing().
static int node_missing(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_is_missing(self));
    return 1;
  }
  return 0;
}
841 
/// Lua method `node:has_error()`: pushes the result of ts_node_has_error().
static int node_has_error(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_has_error(self));
    return 1;
  }
  return 0;
}
851 
/// Lua method `node:child(index)`: pushes the child at `index` as returned by
/// ts_node_child(), or nil when that node is null.
static int node_child(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const long index = lua_tointeger(L, 2);
  push_node(L, ts_node_child(self, (uint32_t)index), 1);
  return 1;
}
864 
/// Lua method `node:named_child(index)`: pushes the named child at `index` as
/// returned by ts_node_named_child(), or nil when that node is null.
static int node_named_child(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const long index = lua_tointeger(L, 2);
  push_node(L, ts_node_named_child(self, (uint32_t)index), 1);
  return 1;
}
877 
/// Lua method `node:descendant_for_range(start_row, start_col, end_row,
/// end_col)`: pushes the node returned by
/// ts_node_descendant_for_point_range(), or nil when that node is null.
static int node_descendant_for_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = {
    .row = (uint32_t)lua_tointeger(L, 2),
    .column = (uint32_t)lua_tointeger(L, 3),
  };
  const TSPoint to = {
    .row = (uint32_t)lua_tointeger(L, 4),
    .column = (uint32_t)lua_tointeger(L, 5),
  };

  push_node(L, ts_node_descendant_for_point_range(self, from, to), 1);
  return 1;
}
893 
/// Lua method `node:named_descendant_for_range(start_row, start_col, end_row,
/// end_col)`: pushes the node returned by
/// ts_node_named_descendant_for_point_range(), or nil when that node is null.
static int node_named_descendant_for_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = {
    .row = (uint32_t)lua_tointeger(L, 2),
    .column = (uint32_t)lua_tointeger(L, 3),
  };
  const TSPoint to = {
    .row = (uint32_t)lua_tointeger(L, 4),
    .column = (uint32_t)lua_tointeger(L, 5),
  };

  push_node(L, ts_node_named_descendant_for_point_range(self, from, to), 1);
  return 1;
}
909 
/// Iterator function created by node_iter_children().
///
/// Upvalues: [ tree cursor, source node ]. Each call advances the cursor and
/// pushes the child it now points at, followed by that child's field name (or
/// nil). Returns nothing once there are no more children.
static int node_next_child(lua_State *L)
{
  TSTreeCursor *cursor = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR);
  if (cursor == NULL) {
    return 0;
  }

  TSNode source;
  if (!node_check(L, lua_upvalueindex(2), &source)) {
    return 0;
  }

  // While the cursor still points at the source node this is the first call,
  // which has to descend to the first child. Later calls move to the next
  // sibling.
  bool advanced;
  if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) {
    advanced = ts_tree_cursor_goto_first_child(cursor);
  } else {
    advanced = ts_tree_cursor_goto_next_sibling(cursor);
  }

  if (!advanced) {
    return 0;
  }

  push_node(L,
            ts_tree_cursor_current_node(cursor),
            lua_upvalueindex(2));  // [node]

  const char *field = ts_tree_cursor_current_field_name(cursor);
  if (field == NULL) {
    lua_pushnil(L);
  } else {
    lua_pushstring(L, field);
  }  // [node, field_name_or_nil]

  return 2;
}
950 
/// Lua method `node:iter_children()`: pushes an iterator function over the
/// children of the node.
///
/// The iterator is node_next_child() with a new tree cursor and the node
/// itself as upvalues.
static int node_iter_children(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  // upvalue 1: cursor userdata, released by treecursor_gc()
  TSTreeCursor *cursor = lua_newuserdata(L, sizeof(TSTreeCursor));  // [udata]
  *cursor = ts_tree_cursor_new(self);
  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR);  // [udata, mt]
  lua_setmetatable(L, -2);  // [udata]

  // upvalue 2: the node being iterated
  lua_pushvalue(L, 1);  // [udata, source_node]

  lua_pushcclosure(L, node_next_child, 2);  // [iterator]
  return 1;
}
968 
/// `__gc` metamethod of tree cursors: releases the TSTreeCursor stored in the
/// userdata.
static int treecursor_gc(lua_State *L)
{
  TSTreeCursor *cursor = luaL_checkudata(L, 1, TS_META_TREECURSOR);

  ts_tree_cursor_delete(cursor);
  return 0;
}
975 
/// Lua method `node:parent()`: pushes the parent node, or nil when
/// ts_node_parent() returns a null node.
static int node_parent(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_parent(self), 1);
    return 1;
  }
  return 0;
}
986 
/// Lua method `node:next_sibling()`: pushes the next sibling, or nil when
/// ts_node_next_sibling() returns a null node.
static int node_next_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_next_sibling(self), 1);
    return 1;
  }
  return 0;
}
997 
/// Lua method `node:prev_sibling()`: pushes the previous sibling, or nil when
/// ts_node_prev_sibling() returns a null node.
static int node_prev_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_prev_sibling(self), 1);
    return 1;
  }
  return 0;
}
1008 
/// Lua method `node:next_named_sibling()`: pushes the next named sibling, or
/// nil when ts_node_next_named_sibling() returns a null node.
static int node_next_named_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_next_named_sibling(self), 1);
    return 1;
  }
  return 0;
}
1019 
/// Lua method `node:prev_named_sibling()`: pushes the previous named sibling,
/// or nil when ts_node_prev_named_sibling() returns a null node.
static int node_prev_named_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_prev_named_sibling(self), 1);
    return 1;
  }
  return 0;
}
1030 
1031 /// assumes the match table being on top of the stack
set_match(lua_State * L,TSQueryMatch * match,int nodeidx)1032 static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
1033 {
1034   for (int i = 0; i < match->capture_count; i++) {
1035     push_node(L, match->captures[i].node, nodeidx);
1036     lua_rawseti(L, -2, match->captures[i].index+1);
1037   }
1038 }
1039 
/// Iterator function returning the next match of a query.
///
/// Reads upvalue 1 (TSLua_cursor userdata), upvalue 2 (passed to set_match()
/// as the node index) and upvalue 3 (the query). Pushes the 1-based pattern
/// index and a table of the captured nodes, or returns nothing when the
/// cursor has no more matches.
static int query_next_match(lua_State *L)
{
  TSLua_cursor *state = lua_touserdata(L, lua_upvalueindex(1));
  TSQueryCursor *cursor = state->cursor;
  TSQuery *query = query_check(L, lua_upvalueindex(3));

  TSQueryMatch match;
  if (!ts_query_cursor_next_match(cursor, &match)) {
    return 0;
  }

  lua_pushinteger(L, match.pattern_index+1);  // [index]
  lua_createtable(L, ts_query_capture_count(query), 2);  // [index, match]
  set_match(L, &match, lua_upvalueindex(2));
  return 2;
}
1055 
1056 
/// Iterator function returning the next capture of a query.
///
/// Pushes the 1-based capture index and the captured node. The first time a
/// match is seen whose pattern has predicates, the match table (upvalue 4) is
/// filled in for that match and pushed as a third value. Returns nothing when
/// the cursor has no more captures.
static int query_next_capture(lua_State *L)
{
  // Upvalues are:
  // [ cursor, node, query, current_match ]
  TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
  TSQueryCursor *cursor = ud->cursor;

  TSQuery *query = query_check(L, lua_upvalueindex(3));

  // A previous call handed out a match with "active" set to false. Unless the
  // flag holds a true value by now (presumably set by the Lua caller after
  // evaluating the predicates -- confirm against the caller), remove that
  // match from the cursor.
  if (ud->predicated_match > -1) {
    lua_getfield(L, lua_upvalueindex(4), "active");
    bool active = lua_toboolean(L, -1);
    lua_pop(L, 1);
    if (!active) {
      ts_query_cursor_remove_match(cursor, ud->predicated_match);
    }
    ud->predicated_match = -1;
  }

  TSQueryMatch match;
  uint32_t capture_index;
  if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
    TSQueryCapture capture = match.captures[capture_index];

    lua_pushinteger(L, capture.index+1);  // [index]
    push_node(L, capture.node, lua_upvalueindex(2));  // [index, node]

    // Now check if we need to run the predicates
    uint32_t n_pred;
    ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred);

    // The match table is only returned for match ids above max_match_id,
    // i.e. once per match.
    if (n_pred > 0 && (ud->max_match_id < (int)match.id)) {
      ud->max_match_id = match.id;

      lua_pushvalue(L, lua_upvalueindex(4));  // [index, node, match]
      set_match(L, &match, lua_upvalueindex(2));
      lua_pushinteger(L, match.pattern_index+1);
      lua_setfield(L, -2, "pattern");

      // For matches with several captures, remember the match id so that the
      // next call can check "active" and possibly remove the match.
      if (match.capture_count > 1) {
        ud->predicated_match = match.id;
        lua_pushboolean(L, false);
        lua_setfield(L, -2, "active");
      }
      return 3;
    }
    return 2;
  }
  return 0;
}
1107 
/// Lua method: runs a query on a node and returns an iterator closure.
///
/// Lua arguments:
///   1: node to run the query on
///   2: query userdata
///   3: truthy to iterate capture by capture (query_next_capture), falsy to
///      iterate match by match (query_next_match)
///   4: (optional) integer; the cursor is limited to points from (start, 0)
///   5: (optional) integer; ... up to (end, 0). Defaults to MAXLNUM.
///
/// @return 1 (the closure), or 0 when node_check() rejects argument 1.
static int node_rawquery(lua_State *L)
{
  TSNode node;
  if (!node_check(L, 1, &node)) {
    return 0;
  }
  TSQuery *query = query_check(L, 2);
  bool captures = lua_toboolean(L, 3);

  // Everything that can raise a Lua error is done *before* the cursor is
  // allocated: luaL_checkinteger() and lua_newuserdata() do not return on
  // failure, so a cursor created earlier would never be deleted.
  const int nargs = lua_gettop(L);
  const bool has_range = nargs >= 4;
  int start = 0;
  int end = 0;
  if (has_range) {
    start = luaL_checkinteger(L, 4);
    end = nargs >= 5 ? luaL_checkinteger(L, 5) : MAXLNUM;
  }

  TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud));  // [udata]
  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);  // [udata, meta]

  // TODO(bfredl): these are expensive allegedly,
  // use a reuse list later on?
  TSQueryCursor *cursor = ts_query_cursor_new();
  // TODO(clason): API introduced after tree-sitter release 0.19.5
  // remove guard when minimum ts version is bumped to 0.19.6+
#ifdef NVIM_TS_HAS_SET_MATCH_LIMIT
  ts_query_cursor_set_match_limit(cursor, 32);
#endif
  ts_query_cursor_exec(cursor, query, node);
  if (has_range) {
    ts_query_cursor_set_point_range(cursor,
                                    (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
  }

  ud->cursor = cursor;
  ud->predicated_match = -1;
  ud->max_match_id = -1;
  // The metatable (and with it the __gc handler) is attached only once the
  // userdata holds a valid cursor; from here on the Lua GC owns the cursor.
  lua_setmetatable(L, -2);  // [udata]

  lua_pushvalue(L, 1);  // [udata, node]

  // include query separately, as to keep a ref to it for gc
  lua_pushvalue(L, 2);  // [udata, node, query]

  if (captures) {
    // placeholder for match state
    lua_createtable(L, ts_query_capture_count(query), 2);  // [u, n, q, match]
    lua_pushcclosure(L, query_next_capture, 4);  // [closure]
  } else {
    lua_pushcclosure(L, query_next_match, 3);  // [closure]
  }

  return 1;
}
1156 
/// `__gc` handler for TS_META_QUERYCURSOR userdata: deletes the tree-sitter
/// query cursor stored in it.
///
/// @return 0 (nothing is pushed)
static int querycursor_gc(lua_State *L)
{
  TSLua_cursor *state = luaL_checkudata(L, 1, TS_META_QUERYCURSOR);
  TSQueryCursor *owned = state->cursor;
  ts_query_cursor_delete(owned);
  return 0;
}
1163 
1164 // Query methods
1165 
/// Compiles a tree-sitter query and returns it as TS_META_QUERY userdata.
///
/// Lua arguments:
///   1: language name, looked up in `langs`
///   2: query source text
///
/// Raises a Lua error when either argument is not a string, when the language
/// is not found, or when ts_query_new() fails (the message then includes the
/// error kind and offset).
///
/// @return 1 (the query userdata)
int tslua_parse_query(lua_State *L)
{
  const bool args_ok = lua_gettop(L) >= 2
                       && lua_isstring(L, 1)
                       && lua_isstring(L, 2);
  if (!args_ok) {
    return luaL_error(L, "string expected");
  }

  const char *lang_name = lua_tostring(L, 1);
  TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", lang_name);
  }

  size_t src_len;
  const char *src = lua_tolstring(L, 2, &src_len);

  uint32_t err_pos;
  TSQueryError err_kind;
  TSQuery *compiled = ts_query_new(lang, src, (uint32_t)src_len,
                                   &err_pos, &err_kind);
  if (compiled == NULL) {
    return luaL_error(L, "query: %s at position %d",
                      query_err_string(err_kind), (int)err_pos);
  }

  // The userdata holds just the pointer; the TS_META_QUERY metatable is
  // attached to it below.
  TSQuery **slot = lua_newuserdata(L, sizeof(*slot));  // [udata]
  *slot = compiled;
  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERY);  // [udata, meta]
  lua_setmetatable(L, -2);  // [udata]
  return 1;
}
1196 
1197 
/// Maps a TSQueryError to a short human-readable description.
///
/// @return a string literal; any value without a dedicated case yields the
///         generic "error".
static const char *query_err_string(TSQueryError err)
{
  const char *text = "error";
  switch (err) {
  case TSQueryErrorSyntax:
    text = "invalid syntax";
    break;
  case TSQueryErrorNodeType:
    text = "invalid node type";
    break;
  case TSQueryErrorField:
    text = "invalid field";
    break;
  case TSQueryErrorCapture:
    text = "invalid capture";
    break;
  default:
    break;
  }
  return text;
}
1213 
query_check(lua_State * L,int index)1214 static TSQuery *query_check(lua_State *L, int index)
1215 {
1216   TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
1217   return *ud;
1218 }
1219 
/// `__gc` handler for query userdata: deletes the wrapped TSQuery unless the
/// stored pointer is NULL.
///
/// @return 0 (nothing is pushed)
static int query_gc(lua_State *L)
{
  TSQuery *owned = query_check(L, 1);
  if (owned != NULL) {
    ts_query_delete(owned);
  }
  return 0;
}
1230 
/// `__tostring` handler for query userdata: always pushes the fixed text
/// "<query>".
///
/// @return 1 (the string)
static int query_tostring(lua_State *L)
{
  static const char repr[] = "<query>";
  lua_pushstring(L, repr);
  return 1;
}
1236 
/// Lua method: builds a table describing a query.
///
/// The result has two fields:
///   - "patterns": indexed by 1-based pattern number; patterns without
///     predicate steps are left unset. Each entry is a list of predicates,
///     and each predicate is a list whose items are strings (string steps) or
///     numbers (capture steps, as 1-based capture ids).
///   - "captures": list of capture names, indexed by 1-based capture id.
///
/// @return 1 (the table), or 0 when the query pointer is NULL.
static int query_inspect(lua_State *L)
{
  TSQuery *query = query_check(L, 1);
  if (!query) {
    return 0;
  }

  lua_createtable(L, 0, 2);  // [retval]

  const uint32_t pattern_total = ts_query_pattern_count(query);
  lua_createtable(L, pattern_total, 1);  // [retval, patterns]
  for (size_t pat = 0; pat < pattern_total; pat++) {
    uint32_t step_total;
    const TSQueryPredicateStep *steps
      = ts_query_predicates_for_pattern(query, pat, &step_total);
    if (step_total == 0) {
      continue;
    }

    lua_createtable(L, step_total/4, 1);  // [retval, patterns, pat]
    lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
    int pred_slot = 1;
    int item_slot = 1;
    for (size_t k = 0; k < step_total; k++) {
      const TSQueryPredicateStep *cur = &steps[k];

      if (cur->type == TSQueryPredicateStepTypeDone) {
        // Predicate finished: file it under the pattern and open a fresh
        // table for whatever follows.
        lua_rawseti(L, -2, pred_slot++);  // [retval, patterns, pat]
        lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
        item_slot = 1;
        continue;
      }

      switch (cur->type) {
      case TSQueryPredicateStepTypeString: {
        uint32_t text_len;
        const char *text = ts_query_string_value_for_id(query, cur->value_id,
                                                        &text_len);
        lua_pushlstring(L, text, text_len);  // [..., pat, pred, item]
        break;
      }
      case TSQueryPredicateStepTypeCapture:
        lua_pushnumber(L, cur->value_id+1);  // [..., pat, pred, item]
        break;
      default:
        abort();
      }
      lua_rawseti(L, -2, item_slot++);  // [retval, patterns, pat, pred]
    }

    // The last predicate should have ended with TypeDone, so the table on
    // top is the empty one opened after it: discard it.
    lua_pop(L, 1);  // [retval, patterns, pat]
    lua_rawseti(L, -2, pat+1);  // [retval, patterns]
  }
  lua_setfield(L, -2, "patterns");  // [retval]

  const uint32_t capture_total = ts_query_capture_count(query);
  lua_createtable(L, capture_total, 0);  // [retval, captures]
  for (size_t id = 0; id < capture_total; id++) {
    uint32_t name_len;
    const char *name = ts_query_capture_name_for_id(query, id, &name_len);
    lua_pushlstring(L, name, name_len);  // [retval, captures, capture]
    lua_rawseti(L, -2, id+1);  // [retval, captures]
  }
  lua_setfield(L, -2, "captures");  // [retval]

  return 1;
}
1296