1 // This is an open source non-commercial project. Dear PVS-Studio, please check
2 // it. PVS-Studio Static Code Analyzer for C, C++ and C#: http://www.viva64.com
3
4 // lua bindings for tree-sitter.
5 // NB: this file mostly contains a generic lua interface for tree-sitter
6 // trees and nodes, and could be broken out as a reusable lua package
7
8 #include <assert.h>
9 #include <inttypes.h>
10 #include <lauxlib.h>
11 #include <lua.h>
12 #include <lualib.h>
13 #include <stdbool.h>
14 #include <stdlib.h>
15 #include <string.h>
16
17 #include "nvim/api/private/helpers.h"
18 #include "nvim/buffer.h"
19 #include "nvim/lua/treesitter.h"
20 #include "nvim/memline.h"
21 #include "tree_sitter/api.h"
22
// Names under which the metatables of the userdata types used by this file
// are stored in the Lua registry.
#define TS_META_PARSER      "treesitter_parser"
#define TS_META_TREE        "treesitter_tree"
#define TS_META_NODE        "treesitter_node"
#define TS_META_QUERY       "treesitter_query"
#define TS_META_QUERYCURSOR "treesitter_querycursor"
#define TS_META_TREECURSOR  "treesitter_treecursor"
29
30 typedef struct {
31 TSQueryCursor *cursor;
32 int predicated_match;
33 int max_match_id;
34 } TSLua_cursor;
35
36 #ifdef INCLUDE_GENERATED_DECLARATIONS
37 # include "lua/treesitter.c.generated.h"
38 #endif
39
40 static struct luaL_Reg parser_meta[] = {
41 { "__gc", parser_gc },
42 { "__tostring", parser_tostring },
43 { "parse", parser_parse },
44 { "set_included_ranges", parser_set_ranges },
45 { "included_ranges", parser_get_ranges },
46 { NULL, NULL }
47 };
48
49 static struct luaL_Reg tree_meta[] = {
50 { "__gc", tree_gc },
51 { "__tostring", tree_tostring },
52 { "root", tree_root },
53 { "edit", tree_edit },
54 { "copy", tree_copy },
55 { NULL, NULL }
56 };
57
58 static struct luaL_Reg node_meta[] = {
59 { "__tostring", node_tostring },
60 { "__eq", node_eq },
61 { "__len", node_child_count },
62 { "id", node_id },
63 { "range", node_range },
64 { "start", node_start },
65 { "end_", node_end },
66 { "type", node_type },
67 { "symbol", node_symbol },
68 { "field", node_field },
69 { "named", node_named },
70 { "missing", node_missing },
71 { "has_error", node_has_error },
72 { "sexpr", node_sexpr },
73 { "child_count", node_child_count },
74 { "named_child_count", node_named_child_count },
75 { "child", node_child },
76 { "named_child", node_named_child },
77 { "descendant_for_range", node_descendant_for_range },
78 { "named_descendant_for_range", node_named_descendant_for_range },
79 { "parent", node_parent },
80 { "iter_children", node_iter_children },
81 { "_rawquery", node_rawquery },
82 { "next_sibling", node_next_sibling },
83 { "prev_sibling", node_prev_sibling },
84 { "next_named_sibling", node_next_named_sibling },
85 { "prev_named_sibling", node_prev_named_sibling },
86 { NULL, NULL }
87 };
88
89 static struct luaL_Reg query_meta[] = {
90 { "__gc", query_gc },
91 { "__tostring", query_tostring },
92 { "inspect", query_inspect },
93 { NULL, NULL }
94 };
95
96 // cursors are not exposed, but still needs garbage collection
97 static struct luaL_Reg querycursor_meta[] = {
98 { "__gc", querycursor_gc },
99 { NULL, NULL }
100 };
101
102 static struct luaL_Reg treecursor_meta[] = {
103 { "__gc", treecursor_gc },
104 { NULL, NULL }
105 };
106
107 static PMap(cstr_t) langs = MAP_INIT;
108
/// Create the registry metatable `tname` and fill it with `meta`.
///
/// Nothing is registered when a metatable of that name already exists.
/// The Lua stack is left unchanged.
///
/// @param L      Lua state
/// @param tname  registry name of the metatable
/// @param meta   NULL-terminated list of functions to install
static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
{
  const bool created = luaL_newmetatable(L, tname);  // [meta]
  if (created) {
    luaL_register(L, NULL, meta);

    // meta.__index = meta, so the registered functions act as methods
    lua_pushvalue(L, -1);  // [meta, meta]
    lua_setfield(L, -2, "__index");  // [meta]
  }
  // The metatable is fetched from the registry later; drop it for now.
  lua_pop(L, 1);  // []
}
119
120 /// init the tslua library
121 ///
122 /// all global state is stored in the regirstry of the lua_State
tslua_init(lua_State * L)123 void tslua_init(lua_State *L)
124 {
125 // type metatables
126 build_meta(L, TS_META_PARSER, parser_meta);
127 build_meta(L, TS_META_TREE, tree_meta);
128 build_meta(L, TS_META_NODE, node_meta);
129 build_meta(L, TS_META_QUERY, query_meta);
130 build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
131 build_meta(L, TS_META_TREECURSOR, treecursor_meta);
132 }
133
/// Lua function: report whether a language name is registered.
///
/// Lua argument 1 is the language name (string).
///
/// @return 1 (a boolean is pushed)
int tslua_has_language(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);
  const bool known = pmap_has(cstr_t)(&langs, name);
  lua_pushboolean(L, known);
  return 1;
}
140
/// Lua function: load a parser from a shared library and register it.
///
/// Lua arguments: (path, lang_name). The library at `path` must export a
/// function named `tree_sitter_<lang_name>` returning the TSLanguage.
///
/// Raises a Lua error when the library cannot be opened, the symbol cannot be
/// resolved, the entry point returns NULL, or the language ABI version is
/// outside the supported range. The library handle is closed on every one of
/// these failure paths.
///
/// @return 0 when the language is already registered (nothing is pushed),
///         1 after a successful registration (`true` is pushed)
int tslua_add_language(lua_State *L)
{
  const char *path = luaL_checkstring(L, 1);
  const char *lang_name = luaL_checkstring(L, 2);

  if (pmap_has(cstr_t)(&langs, lang_name)) {
    return 0;
  }

#define BUFSIZE 128
  char symbol_buf[BUFSIZE];
  snprintf(symbol_buf, BUFSIZE, "tree_sitter_%s", lang_name);
#undef BUFSIZE

  uv_lib_t lib;
  if (uv_dlopen(path, &lib)) {
    snprintf((char *)IObuff, IOSIZE, "Failed to load parser: uv_dlopen: %s",
             uv_dlerror(&lib));
    uv_dlclose(&lib);
    lua_pushstring(L, (char *)IObuff);
    return lua_error(L);
  }

  TSLanguage *(*lang_parser)(void);
  if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) {
    snprintf((char *)IObuff, IOSIZE, "Failed to load parser: uv_dlsym: %s",
             uv_dlerror(&lib));
    uv_dlclose(&lib);
    lua_pushstring(L, (char *)IObuff);
    return lua_error(L);
  }

  TSLanguage *lang = lang_parser();
  if (lang == NULL) {
    // luaL_error() does not return, so release the library first.
    uv_dlclose(&lib);
    return luaL_error(L, "Failed to load parser %s: internal error", path);
  }

  const uint32_t lang_version = ts_language_version(lang);
  if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION
      || lang_version > TREE_SITTER_LANGUAGE_VERSION) {
    // The language is rejected and `lang` is not used past this point, so the
    // library can be unloaded before the error is raised.
    uv_dlclose(&lib);
    return luaL_error(L,
                      "ABI version mismatch for %s: supported between %d and %d, found %d",
                      path,
                      TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION,
                      TREE_SITTER_LANGUAGE_VERSION, (int)lang_version);
  }

  // The map keeps its own copy of the name; the Lua string may be collected.
  pmap_put(cstr_t)(&langs, xstrdup(lang_name), lang);

  lua_pushboolean(L, true);
  return 1;
}
193
/// Lua function: describe a registered language.
///
/// Lua argument 1 is the language name. Raises a Lua error when the language
/// is not registered.
///
/// The returned table has the fields:
///   - symbols: symbol id -> { name, is_regular }; auxiliary symbols are
///              left out
///   - fields: field id (1..nfields) -> field name
///   - _abi_version: value of ts_language_version()
///
/// @return 1 (the table is pushed)
int tslua_inspect_lang(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);

  TSLanguage *lang = pmap_get(cstr_t)(&langs, name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", name);
  }

  lua_createtable(L, 0, 2);  // [retval]

  // retval.symbols
  const size_t symbol_count = (size_t)ts_language_symbol_count(lang);
  lua_createtable(L, symbol_count-1, 1);  // [retval, symbols]
  for (size_t sym = 0; sym < symbol_count; sym++) {
    const TSSymbolType kind = ts_language_symbol_type(lang, sym);
    if (kind != TSSymbolTypeAuxiliary) {  // auxiliary: not used by the API
      lua_createtable(L, 2, 0);  // [retval, symbols, elem]
      lua_pushstring(L, ts_language_symbol_name(lang, sym));
      lua_rawseti(L, -2, 1);
      lua_pushboolean(L, kind == TSSymbolTypeRegular);
      lua_rawseti(L, -2, 2);  // [retval, symbols, elem]
      lua_rawseti(L, -2, sym);  // [retval, symbols]
    }
  }
  lua_setfield(L, -2, "symbols");  // [retval]

  // retval.fields
  // Field IDs go from 1 to field_count inclusive (extra index 0 maps to NULL)
  const size_t field_count = (size_t)ts_language_field_count(lang);
  lua_createtable(L, field_count, 1);  // [retval, fields]
  for (size_t id = 1; id <= field_count; id++) {
    lua_pushstring(L, ts_language_field_name_for_id(lang, id));
    lua_rawseti(L, -2, id);  // [retval, fields]
  }
  lua_setfield(L, -2, "fields");  // [retval]

  // retval._abi_version
  const uint32_t abi = ts_language_version(lang);
  lua_pushinteger(L, abi);  // [retval, version]
  lua_setfield(L, -2, "_abi_version");  // [retval]

  return 1;
}
240
/// Lua function: create a parser for a registered language.
///
/// Lua argument 1 is the language name. Raises a Lua error when the language
/// is unknown or cannot be assigned to the new parser.
///
/// @return 1 (the parser userdata is pushed)
int tslua_push_parser(lua_State *L)
{
  const char *name = luaL_checkstring(L, 1);

  TSLanguage *lang = pmap_get(cstr_t)(&langs, name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", name);
  }

  TSParser **ud = lua_newuserdata(L, sizeof(TSParser *));  // [udata]
  TSParser *parser = ts_parser_new();
  *ud = parser;

  if (!ts_parser_set_language(parser, lang)) {
    // No metatable is attached yet, so the parser must be freed by hand.
    ts_parser_delete(parser);
    return luaL_error(L, "Failed to load language : %s", name);
  }

  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_PARSER);  // [udata, meta]
  lua_setmetatable(L, -2);  // [udata]
  return 1;
}
263
parser_check(lua_State * L,uint16_t index)264 static TSParser **parser_check(lua_State *L, uint16_t index)
265 {
266 return luaL_checkudata(L, index, TS_META_PARSER);
267 }
268
/// `__gc` metamethod of parser userdata: deletes the wrapped TSParser.
///
/// @return 0 (nothing is pushed)
static int parser_gc(lua_State *L)
{
  TSParser **ud = parser_check(L, 1);
  if (ud != NULL) {
    ts_parser_delete(*ud);
  }
  return 0;
}
279
/// `__tostring` metamethod of parser userdata.
///
/// @return 1 (a fixed placeholder string is pushed)
static int parser_tostring(lua_State *L)
{
  lua_pushliteral(L, "<parser>");
  return 1;
}
285
/// TSInput read callback feeding buffer text to the parser in chunks.
///
/// @param payload          the buf_T to read from
/// @param byte_index       unused; the position is taken from `position`
/// @param position         zero-based row and byte column to read from
/// @param[out] bytes_read  number of valid bytes in the returned chunk
///
/// @return pointer to a static chunk buffer (overwritten by the next call),
///         or "" with `*bytes_read == 0` when `position` lies past the last
///         line or past the end of its line
static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position,
                            uint32_t *bytes_read)
{
  static char chunk[256];
  buf_T *const buffer = payload;

  if ((linenr_T)position.row >= buffer->b_ml.ml_line_count) {
    *bytes_read = 0;
    return "";
  }

  const char_u *text = ml_get_buf(buffer, position.row+1, false);
  const size_t line_len = STRLEN(text);
  if (position.column > line_len) {
    *bytes_read = 0;
    return "";
  }

  const size_t remaining = line_len - position.column;
  const size_t ncopy = MIN(remaining, sizeof(chunk));
  memcpy(chunk, text + position.column, ncopy);
  // Translate embedded \n to NUL
  memchrsub(chunk, '\n', '\0', ncopy);

  uint32_t nread = (uint32_t)ncopy;
  if (ncopy < sizeof(chunk)) {
    // Room is left for the line terminator. When it does not fit, the
    // callback is invoked again on the same line with an advanced column.
    chunk[ncopy] = '\n';
    nread++;
  }
  *bytes_read = nread;
  return chunk;
}
318
/// Push a Lua list describing `length` ranges.
///
/// Every element is a list { start_row, start_col, end_row, end_col }.
///
/// @param L       Lua state
/// @param ranges  array of `length` ranges (not read when `length` is 0)
/// @param length  number of ranges
static void push_ranges(lua_State *L, const TSRange *ranges, const unsigned int length)
{
  lua_createtable(L, length, 0);  // [list]
  for (size_t i = 0; i < length; i++) {
    const TSRange *r = &ranges[i];
    const uint32_t coords[4] = {
      r->start_point.row,
      r->start_point.column,
      r->end_point.row,
      r->end_point.column,
    };

    lua_createtable(L, 4, 0);  // [list, item]
    for (int k = 0; k < 4; k++) {
      lua_pushinteger(L, coords[k]);
      lua_rawseti(L, -2, k + 1);
    }

    lua_rawseti(L, -2, i+1);  // [list]
  }
}
336
/// Lua method `parser:parse(old_tree, source)`.
///
/// Lua arguments:
///   1. the parser
///   2. a previous tree to reuse, or nil
///   3. the source: a string to parse, or a buffer handle (number)
///
/// Raises a Lua error for an invalid buffer handle, a source of any other
/// type, or when tree-sitter fails to produce a tree.
///
/// @return 2: the new tree userdata and the list of ranges that changed
///         relative to `old_tree` (an empty list when no old tree was given)
static int parser_parse(lua_State *L)
{
  TSParser **p = parser_check(L, 1);
  if (!p || !(*p)) {
    return 0;
  }

  TSTree *old_tree = NULL;
  if (!lua_isnil(L, 2)) {
    TSTree **tmp = tree_check(L, 2);
    old_tree = tmp ? *tmp : NULL;
  }

  TSTree *new_tree = NULL;

  // lua_type() is used instead of lua_isstring() because the latter also
  // accepts numbers, which here denote buffer handles.
  switch (lua_type(L, 3)) {
  case LUA_TSTRING: {
    size_t len;
    const char *str = lua_tolstring(L, 3, &len);
    new_tree = ts_parser_parse_string(*p, old_tree, str, len);
    break;
  }

  case LUA_TNUMBER: {
    // Kept as `int` so that it matches the "%d" conversion below; the
    // original passed a `long` to "%d".
    const int bufnr = (int)lua_tointeger(L, 3);
    buf_T *buf = handle_get_buffer(bufnr);

    if (!buf) {
      return luaL_error(L, "invalid buffer handle: %d", bufnr);
    }

    TSInput input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8 };
    new_tree = ts_parser_parse(*p, old_tree, input);
    break;
  }

  default:
    return luaL_error(L, "invalid argument to parser:parse()");
  }

  // Sometimes parsing fails (timeout, or wrong parser ABI).
  // In those cases, just return an error.
  if (!new_tree) {
    return luaL_error(L, "An error occurred when parsing.");
  }

  // The new tree is pushed to the stack without a copy: ownership passes to
  // the Lua GC. The old tree is still owned by the Lua GC.
  uint32_t n_ranges = 0;
  TSRange *changed = old_tree
    ? ts_tree_get_changed_ranges(old_tree, new_tree, &n_ranges)
    : NULL;

  push_tree(L, new_tree, false);  // [tree]

  push_ranges(L, changed, n_ranges);  // [tree, ranges]

  xfree(changed);
  return 2;
}
401
/// Lua method `tree:copy()`: push a new tree userdata wrapping a copy of the
/// tree.
///
/// @return 1, or 0 when the userdata holds a NULL tree
static int tree_copy(lua_State *L)
{
  TSTree *src = *tree_check(L, 1);
  if (src == NULL) {
    return 0;
  }

  push_tree(L, src, true);  // [tree]
  return 1;
}
413
/// Lua method `tree:edit(...)`: apply an edit to the tree.
///
/// Lua arguments, after the tree itself:
///   start_byte, old_end_byte, new_end_byte,
///   start_row, start_col, old_end_row, old_end_col, new_end_row, new_end_col
///
/// Raises a Lua error when fewer than 10 arguments are passed.
///
/// @return 0 (nothing is pushed)
static int tree_edit(lua_State *L)
{
  if (lua_gettop(L) < 10) {
    lua_pushstring(L, "not enough args to tree:edit()");
    return lua_error(L);
  }

  TSTree *tree = *tree_check(L, 1);
  if (tree == NULL) {
    return 0;
  }

  const long start_byte = lua_tointeger(L, 2);
  const long old_end_byte = lua_tointeger(L, 3);
  const long new_end_byte = lua_tointeger(L, 4);

  const TSInputEdit edit = {
    .start_byte = start_byte,
    .old_end_byte = old_end_byte,
    .new_end_byte = new_end_byte,
    .start_point = { .row = lua_tointeger(L, 5), .column = lua_tointeger(L, 6) },
    .old_end_point = { .row = lua_tointeger(L, 7), .column = lua_tointeger(L, 8) },
    .new_end_point = { .row = lua_tointeger(L, 9), .column = lua_tointeger(L, 10) },
  };

  ts_tree_edit(tree, &edit);
  return 0;
}
440
441 // Use the top of the stack (without popping it) to create a TSRange, it can be
442 // either a lua table or a TSNode
range_from_lua(lua_State * L,TSRange * range)443 static void range_from_lua(lua_State *L, TSRange *range)
444 {
445 TSNode node;
446
447 if (lua_istable(L, -1)) {
448 // should be a table of 6 elements
449 if (lua_objlen(L, -1) != 6) {
450 goto error;
451 }
452
453 uint32_t start_row, start_col, start_byte, end_row, end_col, end_byte;
454 lua_rawgeti(L, -1, 1); // [ range, start_row]
455 start_row = luaL_checkinteger(L, -1);
456 lua_pop(L, 1);
457
458 lua_rawgeti(L, -1, 2); // [ range, start_col]
459 start_col = luaL_checkinteger(L, -1);
460 lua_pop(L, 1);
461
462 lua_rawgeti(L, -1, 3); // [ range, start_byte]
463 start_byte = luaL_checkinteger(L, -1);
464 lua_pop(L, 1);
465
466 lua_rawgeti(L, -1, 4); // [ range, end_row]
467 end_row = luaL_checkinteger(L, -1);
468 lua_pop(L, 1);
469
470 lua_rawgeti(L, -1, 5); // [ range, end_col]
471 end_col = luaL_checkinteger(L, -1);
472 lua_pop(L, 1);
473
474 lua_rawgeti(L, -1, 6); // [ range, end_byte]
475 end_byte = luaL_checkinteger(L, -1);
476 lua_pop(L, 1); // [ range ]
477
478 *range = (TSRange) {
479 .start_point = (TSPoint) {
480 .row = start_row,
481 .column = start_col
482 },
483 .end_point = (TSPoint) {
484 .row = end_row,
485 .column = end_col
486 },
487 .start_byte = start_byte,
488 .end_byte = end_byte,
489 };
490 } else if (node_check(L, -1, &node)) {
491 *range = (TSRange) {
492 .start_point = ts_node_start_point(node),
493 .end_point = ts_node_end_point(node),
494 .start_byte = ts_node_start_byte(node),
495 .end_byte = ts_node_end_byte(node)
496 };
497 } else {
498 goto error;
499 }
500 return;
501 error:
502 luaL_error(L,
503 "Ranges can only be made from 6 element long tables or nodes.");
504 }
505
/// Lua method `parser:set_included_ranges(ranges)`.
///
/// `ranges` is a list whose elements are either 6-element lists or nodes
/// (see range_from_lua()). Raises a Lua error when arguments are missing,
/// when `ranges` is not a table, or when an element cannot be converted.
///
/// @return 0 (nothing is pushed)
static int parser_set_ranges(lua_State *L)
{
  if (lua_gettop(L) < 2) {
    return luaL_error(L,
                      "not enough args to parser:set_included_ranges()");
  }

  TSParser **p = parser_check(L, 1);
  if (!p) {
    return 0;
  }

  if (!lua_istable(L, 2)) {
    return luaL_error(L,
                      "argument for parser:set_included_ranges() should be a table.");
  }

  const size_t tbl_len = lua_objlen(L, 2);

  // The scratch array is a userdata owned by the Lua GC. range_from_lua() may
  // raise a Lua error (a longjmp), which would skip an explicit free and leak
  // a heap-allocated array.
  TSRange *ranges = lua_newuserdata(L, sizeof(TSRange) * tbl_len);  // [ ..., scratch ]

  // The range table is addressed by its absolute index 2, so the scratch
  // userdata sitting on the stack does not interfere.
  for (size_t index = 0; index < tbl_len; index++) {
    lua_rawgeti(L, 2, index + 1);  // [ ..., scratch, range ]
    range_from_lua(L, ranges + index);
    lua_pop(L, 1);  // [ ..., scratch ]
  }

  // This memcpies ranges, so the scratch array can be collected afterwards.
  ts_parser_set_included_ranges(*p, ranges, tbl_len);
  lua_pop(L, 1);  // drop scratch

  return 0;
}
540
/// Lua method `parser:included_ranges()`: push the parser's included ranges
/// in the list format produced by push_ranges().
///
/// @return 1 (the list is pushed)
static int parser_get_ranges(lua_State *L)
{
  TSParser **ud = parser_check(L, 1);
  if (ud == NULL) {
    return 0;
  }

  unsigned int count;
  const TSRange *included = ts_parser_included_ranges(*ud, &count);
  push_ranges(L, included, count);
  return 1;
}
554
555
556 // Tree methods
557
558 /// push tree interface on lua stack.
559 ///
560 /// This makes a copy of the tree, so ownership of the argument is unaffected.
push_tree(lua_State * L,TSTree * tree,bool do_copy)561 void push_tree(lua_State *L, TSTree *tree, bool do_copy)
562 {
563 if (tree == NULL) {
564 lua_pushnil(L);
565 return;
566 }
567 TSTree **ud = lua_newuserdata(L, sizeof(TSTree *)); // [udata]
568
569 if (do_copy) {
570 *ud = ts_tree_copy(tree);
571 } else {
572 *ud = tree;
573 }
574
575 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREE); // [udata, meta]
576 lua_setmetatable(L, -2); // [udata]
577
578 // table used for node wrappers to keep a reference to tree wrapper
579 // NB: in lua 5.3 the uservalue for the node could just be the tree, but
580 // in lua 5.1 the uservalue (fenv) must be a table.
581 lua_createtable(L, 1, 0); // [udata, reftable]
582 lua_pushvalue(L, -2); // [udata, reftable, udata]
583 lua_rawseti(L, -2, 1); // [udata, reftable]
584 lua_setfenv(L, -2); // [udata]
585 }
586
tree_check(lua_State * L,uint16_t index)587 static TSTree **tree_check(lua_State *L, uint16_t index)
588 {
589 TSTree **ud = luaL_checkudata(L, index, TS_META_TREE);
590 return ud;
591 }
592
/// `__gc` metamethod of tree userdata: deletes the wrapped TSTree.
///
/// @return 0 (nothing is pushed)
static int tree_gc(lua_State *L)
{
  TSTree **ud = tree_check(L, 1);
  if (ud != NULL) {
    ts_tree_delete(*ud);
  }
  return 0;
}
603
/// `__tostring` metamethod of tree userdata.
///
/// @return 1 (a fixed placeholder string is pushed)
static int tree_tostring(lua_State *L)
{
  lua_pushliteral(L, "<tree>");
  return 1;
}
609
/// Lua method `tree:root()`: push the root node of the tree.
///
/// @return 1 (the node, or nil for a null node, is pushed)
static int tree_root(lua_State *L)
{
  TSTree **ud = tree_check(L, 1);
  if (ud == NULL) {
    return 0;
  }

  // The new node shares the environment table of the tree at index 1.
  push_node(L, ts_tree_root_node(*ud), 1);
  return 1;
}
620
621 // Node methods
622
623 /// push node interface on lua stack
624 ///
625 /// top of stack must either be the tree this node belongs to or another node
626 /// of the same tree! This value is not popped. Can only be called inside a
627 /// cfunction with the tslua environment.
push_node(lua_State * L,TSNode node,int uindex)628 static void push_node(lua_State *L, TSNode node, int uindex)
629 {
630 assert(uindex > 0 || uindex < -LUA_MINSTACK);
631 if (ts_node_is_null(node)) {
632 lua_pushnil(L); // [nil]
633 return;
634 }
635 TSNode *ud = lua_newuserdata(L, sizeof(TSNode)); // [udata]
636 *ud = node;
637 lua_getfield(L, LUA_REGISTRYINDEX, TS_META_NODE); // [udata, meta]
638 lua_setmetatable(L, -2); // [udata]
639 lua_getfenv(L, uindex); // [udata, reftable]
640 lua_setfenv(L, -2); // [udata]
641 }
642
/// Copy the node stored in the node userdata at stack position `index`.
///
/// The userdata is type-checked against the node metatable via
/// luaL_checkudata().
///
/// @param[out] res  receives the node on success
/// @return true when `*res` was filled in, false when no userdata was obtained
static bool node_check(lua_State *L, int index, TSNode *res)
{
  const TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);
  if (ud == NULL) {
    return false;
  }
  *res = *ud;
  return true;
}
652
653
/// `__tostring` metamethod of node userdata: "<node TYPE>".
///
/// @return 1 (the string is pushed)
static int node_tostring(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    // Three pieces joined by a single concat.
    lua_pushstring(L, "<node ");
    lua_pushstring(L, ts_node_type(self));
    lua_pushstring(L, ">");
    lua_concat(L, 3);
    return 1;
  }
  return 0;
}
666
/// `__eq` metamethod of node userdata, comparing with ts_node_eq().
///
/// @return 1 (a boolean is pushed)
static int node_eq(lua_State *L)
{
  TSNode lhs;
  TSNode rhs;
  // Operands are checked left to right; the second check is skipped when the
  // first one fails.
  if (!node_check(L, 1, &lhs) || !node_check(L, 2, &rhs)) {
    return 0;
  }

  lua_pushboolean(L, ts_node_eq(lhs, rhs));
  return 1;
}
682
/// Lua method `node:id()`: push the raw bytes of the node's `id` member as a
/// Lua string.
///
/// @return 1 (the string is pushed)
static int node_id(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushlstring(L, (const char *)&self.id, sizeof self.id);
    return 1;
  }
  return 0;
}
693
/// Lua method `node:range()`.
///
/// @return 4 (start row, start column, end row, end column are pushed)
static int node_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = ts_node_start_point(self);
  const TSPoint to = ts_node_end_point(self);

  lua_pushnumber(L, from.row);
  lua_pushnumber(L, from.column);
  lua_pushnumber(L, to.row);
  lua_pushnumber(L, to.column);
  return 4;
}
708
/// Lua method `node:start()`.
///
/// @return 3 (start row, start column, start byte are pushed)
static int node_start(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint point = ts_node_start_point(self);
  const uint32_t byte = ts_node_start_byte(self);

  lua_pushnumber(L, point.row);
  lua_pushnumber(L, point.column);
  lua_pushnumber(L, byte);
  return 3;
}
722
/// Lua method `node:end_()`.
///
/// @return 3 (end row, end column, end byte are pushed)
static int node_end(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint point = ts_node_end_point(self);
  const uint32_t byte = ts_node_end_byte(self);

  lua_pushnumber(L, point.row);
  lua_pushnumber(L, point.column);
  lua_pushnumber(L, byte);
  return 3;
}
736
/// Lua method `node:child_count()`, also used as the `__len` metamethod.
///
/// @return 1 (the number of children is pushed)
static int node_child_count(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const uint32_t n = ts_node_child_count(self);
    lua_pushnumber(L, n);
    return 1;
  }
  return 0;
}
747
/// Lua method `node:named_child_count()`.
///
/// @return 1 (the number of named children is pushed)
static int node_named_child_count(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const uint32_t n = ts_node_named_child_count(self);
    lua_pushnumber(L, n);
    return 1;
  }
  return 0;
}
758
/// Lua method `node:type()`.
///
/// @return 1 (the node's type name is pushed)
static int node_type(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushstring(L, ts_node_type(self));
    return 1;
  }
  return 0;
}
768
/// Lua method `node:symbol()`.
///
/// @return 1 (the node's symbol id is pushed as a number)
static int node_symbol(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    const TSSymbol sym = ts_node_symbol(self);
    lua_pushnumber(L, sym);
    return 1;
  }
  return 0;
}
779
/// Lua method `node:field(name)`: collect the direct children whose field
/// name equals `name`.
///
/// @return 1 (a list of nodes, possibly empty, is pushed)
static int node_field(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const char *wanted = luaL_checkstring(L, 2);

  TSTreeCursor walker = ts_tree_cursor_new(self);

  lua_newtable(L);  // [table]
  unsigned int found = 0;

  // Visit every direct child in order.
  for (bool more = ts_tree_cursor_goto_first_child(&walker);
       more;
       more = ts_tree_cursor_goto_next_sibling(&walker)) {
    const char *child_field = ts_tree_cursor_current_field_name(&walker);
    if (child_field == NULL || STRCMP(wanted, child_field) != 0) {
      continue;
    }
    push_node(L, ts_tree_cursor_current_node(&walker), 1);  // [table, node]
    lua_rawseti(L, -2, ++found);  // [table]
  }

  ts_tree_cursor_delete(&walker);
  return 1;
}
809
/// Lua method `node:named()`.
///
/// @return 1 (the result of ts_node_is_named() is pushed as a boolean)
static int node_named(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_is_named(self));
    return 1;
  }
  return 0;
}
819
/// Lua method `node:sexpr()`: push the string produced by ts_node_string().
///
/// @return 1 (the string is pushed)
static int node_sexpr(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  // ts_node_string() returns a buffer this function owns; Lua copies it.
  char *text = ts_node_string(self);
  lua_pushstring(L, text);
  xfree(text);
  return 1;
}
831
/// Lua method `node:missing()`.
///
/// @return 1 (the result of ts_node_is_missing() is pushed as a boolean)
static int node_missing(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_is_missing(self));
    return 1;
  }
  return 0;
}
841
/// Lua method `node:has_error()`.
///
/// @return 1 (the result of ts_node_has_error() is pushed as a boolean)
static int node_has_error(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    lua_pushboolean(L, ts_node_has_error(self));
    return 1;
  }
  return 0;
}
851
/// Lua method `node:child(index)`: push the child at `index`, passed through
/// to ts_node_child() unchanged.
///
/// @return 1 (the child node, or nil for a null node, is pushed)
static int node_child(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const long index = lua_tointeger(L, 2);
  push_node(L, ts_node_child(self, (uint32_t)index), 1);
  return 1;
}
864
/// Lua method `node:named_child(index)`: push the named child at `index`,
/// passed through to ts_node_named_child() unchanged.
///
/// @return 1 (the child node, or nil for a null node, is pushed)
static int node_named_child(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const long index = lua_tointeger(L, 2);
  push_node(L, ts_node_named_child(self, (uint32_t)index), 1);
  return 1;
}
877
/// Lua method `node:descendant_for_range(start_row, start_col, end_row,
/// end_col)`.
///
/// @return 1 (the result of ts_node_descendant_for_point_range(), or nil for
///         a null node, is pushed)
static int node_descendant_for_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = {
    .row = (uint32_t)lua_tointeger(L, 2),
    .column = (uint32_t)lua_tointeger(L, 3),
  };
  const TSPoint to = {
    .row = (uint32_t)lua_tointeger(L, 4),
    .column = (uint32_t)lua_tointeger(L, 5),
  };

  push_node(L, ts_node_descendant_for_point_range(self, from, to), 1);
  return 1;
}
893
/// Lua method `node:named_descendant_for_range(start_row, start_col, end_row,
/// end_col)`.
///
/// @return 1 (the result of ts_node_named_descendant_for_point_range(), or
///         nil for a null node, is pushed)
static int node_named_descendant_for_range(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  const TSPoint from = {
    .row = (uint32_t)lua_tointeger(L, 2),
    .column = (uint32_t)lua_tointeger(L, 3),
  };
  const TSPoint to = {
    .row = (uint32_t)lua_tointeger(L, 4),
    .column = (uint32_t)lua_tointeger(L, 5),
  };

  push_node(L, ts_node_named_descendant_for_point_range(self, from, to), 1);
  return 1;
}
909
/// Iterator function behind `node:iter_children()`.
///
/// Upvalues: [ tree cursor userdata, source node ].
///
/// @return 2 (next child node, its field name or nil) while children remain,
///         0 once the iteration is finished
static int node_next_child(lua_State *L)
{
  TSTreeCursor *walker = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR);
  if (walker == NULL) {
    return 0;
  }

  TSNode source;
  if (!node_check(L, lua_upvalueindex(2), &source)) {
    return 0;
  }

  // While the cursor still rests on the source node this is the first call:
  // descend to the first child. Afterwards step to the next sibling.
  bool advanced;
  if (ts_node_eq(source, ts_tree_cursor_current_node(walker))) {
    advanced = ts_tree_cursor_goto_first_child(walker);
  } else {
    advanced = ts_tree_cursor_goto_next_sibling(walker);
  }
  if (!advanced) {
    return 0;
  }

  push_node(L,
            ts_tree_cursor_current_node(walker),
            lua_upvalueindex(2));  // [node]

  const char *field = ts_tree_cursor_current_field_name(walker);
  if (field == NULL) {
    lua_pushnil(L);
  } else {
    lua_pushstring(L, ts_tree_cursor_current_field_name(walker));
  }  // [node, field_name_or_nil]
  return 2;
}
950
/// Lua method `node:iter_children()`: push an iterator closure over the
/// node's children.
///
/// The closure is node_next_child() with a fresh tree cursor and the node
/// itself as upvalues.
///
/// @return 1 (the closure is pushed)
static int node_iter_children(lua_State *L)
{
  TSNode self;
  if (!node_check(L, 1, &self)) {
    return 0;
  }

  TSTreeCursor *walker = lua_newuserdata(L, sizeof(TSTreeCursor));  // [udata]
  *walker = ts_tree_cursor_new(self);

  // The metatable's __gc deletes the cursor.
  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR);  // [udata, mt]
  lua_setmetatable(L, -2);  // [udata]

  lua_pushvalue(L, 1);  // [udata, source_node]
  lua_pushcclosure(L, node_next_child, 2);  // [closure]
  return 1;
}
968
/// `__gc` metamethod of tree cursor userdata: releases the cursor.
///
/// @return 0 (nothing is pushed)
static int treecursor_gc(lua_State *L)
{
  TSTreeCursor *walker = luaL_checkudata(L, 1, TS_META_TREECURSOR);
  ts_tree_cursor_delete(walker);
  return 0;
}
975
/// Lua method `node:parent()`.
///
/// @return 1 (the parent node, or nil for a null node, is pushed)
static int node_parent(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_parent(self), 1);
    return 1;
  }
  return 0;
}
986
/// Lua method `node:next_sibling()`.
///
/// @return 1 (the next sibling, or nil for a null node, is pushed)
static int node_next_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_next_sibling(self), 1);
    return 1;
  }
  return 0;
}
997
/// Lua method `node:prev_sibling()`.
///
/// @return 1 (the previous sibling, or nil for a null node, is pushed)
static int node_prev_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_prev_sibling(self), 1);
    return 1;
  }
  return 0;
}
1008
/// Lua method `node:next_named_sibling()`.
///
/// @return 1 (the next named sibling, or nil for a null node, is pushed)
static int node_next_named_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_next_named_sibling(self), 1);
    return 1;
  }
  return 0;
}
1019
/// Lua method `node:prev_named_sibling()`.
///
/// @return 1 (the previous named sibling, or nil for a null node, is pushed)
static int node_prev_named_sibling(lua_State *L)
{
  TSNode self;
  if (node_check(L, 1, &self)) {
    push_node(L, ts_node_prev_named_sibling(self), 1);
    return 1;
  }
  return 0;
}
1030
1031 /// assumes the match table being on top of the stack
set_match(lua_State * L,TSQueryMatch * match,int nodeidx)1032 static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
1033 {
1034 for (int i = 0; i < match->capture_count; i++) {
1035 push_node(L, match->captures[i].node, nodeidx);
1036 lua_rawseti(L, -2, match->captures[i].index+1);
1037 }
1038 }
1039
/// Iterator function for whole matches.
///
/// Upvalues: [ cursor state, node, query ].
///
/// @return 2 (pattern index plus one, table of captured nodes) while matches
///         remain, 0 afterwards
static int query_next_match(lua_State *L)
{
  TSLua_cursor *state = lua_touserdata(L, lua_upvalueindex(1));
  TSQueryCursor *qcursor = state->cursor;
  TSQuery *query = query_check(L, lua_upvalueindex(3));

  TSQueryMatch match;
  if (!ts_query_cursor_next_match(qcursor, &match)) {
    return 0;
  }

  lua_pushinteger(L, match.pattern_index+1);  // [index]
  lua_createtable(L, ts_query_capture_count(query), 2);  // [index, match]
  set_match(L, &match, lua_upvalueindex(2));
  return 2;
}
1055
1056
/// Iterator function for single captures.
///
/// Upvalues: [ cursor state, node, query, current_match ].
///
/// @return 0 when the cursor is exhausted; otherwise 2 (capture index plus
///         one, node), or 3 when the shared match table is pushed as well.
///         The table is pushed the first time a match of a pattern with
///         predicates is seen.
static int query_next_capture(lua_State *L)
{
  TSLua_cursor *state = lua_touserdata(L, lua_upvalueindex(1));
  TSQueryCursor *qcursor = state->cursor;

  TSQuery *query = query_check(L, lua_upvalueindex(3));

  // A match handed out with active=false on an earlier step is removed from
  // the cursor unless its "active" field reads true by now.
  const int pending = state->predicated_match;
  if (pending > -1) {
    lua_getfield(L, lua_upvalueindex(4), "active");
    const bool keep = lua_toboolean(L, -1);
    lua_pop(L, 1);
    if (!keep) {
      ts_query_cursor_remove_match(qcursor, pending);
    }
    state->predicated_match = -1;
  }

  TSQueryMatch match;
  uint32_t capture_index;
  if (!ts_query_cursor_next_capture(qcursor, &match, &capture_index)) {
    return 0;
  }

  const TSQueryCapture capture = match.captures[capture_index];

  lua_pushinteger(L, capture.index+1);  // [index]
  push_node(L, capture.node, lua_upvalueindex(2));  // [index, node]

  // Now check if we need to run the predicates
  uint32_t n_pred;
  ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred);

  const bool unseen_match = state->max_match_id < (int)match.id;
  if (n_pred == 0 || !unseen_match) {
    return 2;
  }

  state->max_match_id = match.id;

  lua_pushvalue(L, lua_upvalueindex(4));  // [index, node, match]
  set_match(L, &match, lua_upvalueindex(2));
  lua_pushinteger(L, match.pattern_index+1);
  lua_setfield(L, -2, "pattern");

  if (match.capture_count > 1) {
    state->predicated_match = match.id;
    lua_pushboolean(L, false);
    lua_setfield(L, -2, "active");
  }
  return 3;
}
1107
/// Lua method `node:_rawquery(query, captures[, start[, end]])`.
///
/// Runs `query` on the node and pushes an iterator closure:
/// query_next_capture() when `captures` is truthy, query_next_match()
/// otherwise. When `start` is given, the cursor is restricted to the point
/// range (start, 0) .. (end, 0), with `end` defaulting to MAXLNUM.
///
/// All arguments are validated and the userdata is allocated before the
/// query cursor is created, so a Lua error raised by either step cannot leak
/// the cursor.
///
/// @return 1 (the closure is pushed)
static int node_rawquery(lua_State *L)
{
  TSNode node;
  if (!node_check(L, 1, &node)) {
    return 0;
  }
  TSQuery *query = query_check(L, 2);

  const bool captures = lua_toboolean(L, 3);

  // luaL_checkinteger() raises on a non-number, so it runs before any
  // resource is acquired.
  const int nargs = lua_gettop(L);
  const bool has_range = nargs >= 4;
  int start = 0;
  int end = 0;
  if (has_range) {
    start = luaL_checkinteger(L, 4);
    end = nargs >= 5 ? luaL_checkinteger(L, 5) : MAXLNUM;
  }

  // The allocation may raise on out-of-memory, so it also precedes the
  // cursor. The metatable (and with it __gc) is attached only once `cursor`
  // holds a valid pointer.
  TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud));  // [udata]

  // TODO(bfredl): these are expensive allegedly,
  // use a reuse list later on?
  TSQueryCursor *cursor = ts_query_cursor_new();
  // TODO(clason): API introduced after tree-sitter release 0.19.5
  // remove guard when minimum ts version is bumped to 0.19.6+
#ifdef NVIM_TS_HAS_SET_MATCH_LIMIT
  ts_query_cursor_set_match_limit(cursor, 32);
#endif
  ts_query_cursor_exec(cursor, query, node);

  if (has_range) {
    ts_query_cursor_set_point_range(cursor,
                                    (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
  }

  ud->cursor = cursor;
  ud->predicated_match = -1;
  ud->max_match_id = -1;

  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);
  lua_setmetatable(L, -2);  // [udata]
  lua_pushvalue(L, 1);  // [udata, node]

  // include query separately, as to keep a ref to it for gc
  lua_pushvalue(L, 2);  // [udata, node, query]

  if (captures) {
    // placeholder for match state
    lua_createtable(L, ts_query_capture_count(query), 2);  // [u, n, q, match]
    lua_pushcclosure(L, query_next_capture, 4);  // [closure]
  } else {
    lua_pushcclosure(L, query_next_match, 3);  // [closure]
  }

  return 1;
}
1156
// Release the tree-sitter query cursor wrapped by a TS_META_QUERYCURSOR
// userdata (Lua argument 1). Returns no values.
static int querycursor_gc(lua_State *L)
{
  TSLua_cursor *state = luaL_checkudata(L, 1, TS_META_QUERYCURSOR);
  TSQueryCursor *owned = state->cursor;
  ts_query_cursor_delete(owned);
  return 0;
}
1163
1164 // Query methods
1165
// Compile a tree-sitter query and hand it to Lua as a userdata.
//
// Lua arguments:
//   1: name of a language, looked up in the `langs` map
//   2: query source text
//
// Returns one value: a userdata holding the TSQuery pointer, tagged with the
// TS_META_QUERY metatable from the registry. Raises a Lua error when an
// argument is missing or not a string, when the language is unknown, or when
// ts_query_new() rejects the source (the message carries the error kind and
// the reported offset).
int tslua_parse_query(lua_State *L)
{
  bool args_ok = lua_gettop(L) >= 2 && lua_isstring(L, 1) && lua_isstring(L, 2);
  if (!args_ok) {
    return luaL_error(L, "string expected");
  }

  const char *lang_name = lua_tostring(L, 1);
  TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
  if (lang == NULL) {
    return luaL_error(L, "no such language: %s", lang_name);
  }

  size_t src_len;
  const char *src = lua_tolstring(L, 2, &src_len);

  uint32_t err_pos;
  TSQueryError err_kind;
  TSQuery *compiled = ts_query_new(lang, src, src_len, &err_pos, &err_kind);
  if (compiled == NULL) {
    const char *reason = query_err_string(err_kind);
    return luaL_error(L, "query: %s at position %d", reason, (int)err_pos);
  }

  TSQuery **slot = lua_newuserdata(L, sizeof(TSQuery *));  // [udata]
  *slot = compiled;
  lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERY);  // [udata, meta]
  lua_setmetatable(L, -2);  // [udata]
  return 1;
}
1196
1197
// Map a TSQueryError value to a short human-readable description.
// Values without a dedicated message yield the generic "error".
static const char *query_err_string(TSQueryError err)
{
  const char *description = "error";

  switch (err) {
  case TSQueryErrorSyntax:
    description = "invalid syntax";
    break;
  case TSQueryErrorNodeType:
    description = "invalid node type";
    break;
  case TSQueryErrorField:
    description = "invalid field";
    break;
  case TSQueryErrorCapture:
    description = "invalid capture";
    break;
  default:
    break;
  }

  return description;
}
1213
query_check(lua_State * L,int index)1214 static TSQuery *query_check(lua_State *L, int index)
1215 {
1216 TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
1217 return *ud;
1218 }
1219
// Delete the TSQuery held by the query userdata at Lua argument 1.
// A NULL pointer in the userdata is ignored. Returns no values.
static int query_gc(lua_State *L)
{
  TSQuery *doomed = query_check(L, 1);
  if (doomed != NULL) {
    ts_query_delete(doomed);
  }
  return 0;
}
1230
// Push the fixed textual representation of a query object.
// Returns one value: the string "<query>".
static int query_tostring(lua_State *L)
{
  static const char repr[] = "<query>";
  lua_pushstring(L, repr);
  return 1;
}
1236
// Describe a query (Lua argument 1) as a Lua table.
//
// The returned table has two fields:
//   patterns: indexed by 1-based pattern number. Each entry is a list of
//             predicates and each predicate is a list of items: strings for
//             string steps, 1-based capture ids (numbers) for capture steps.
//             Patterns that have no predicate steps get no entry.
//   captures: list of the query's capture names, in capture id order.
//
// Returns nothing when the userdata holds a NULL query. An unknown predicate
// step type aborts the process.
static int query_inspect(lua_State *L)
{
  TSQuery *query = query_check(L, 1);
  if (query == NULL) {
    return 0;
  }

  lua_createtable(L, 0, 2);  // [retval]

  uint32_t pattern_total = ts_query_pattern_count(query);
  lua_createtable(L, pattern_total, 1);  // [retval, patterns]
  for (size_t pat = 0; pat < pattern_total; pat++) {
    uint32_t step_total;
    const TSQueryPredicateStep *steps =
      ts_query_predicates_for_pattern(query, pat, &step_total);
    if (step_total == 0) {
      continue;
    }

    lua_createtable(L, step_total/4, 1);  // [retval, patterns, pat]
    lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
    int pred_slot = 1;
    int item_slot = 1;
    for (size_t s = 0; s < step_total; s++) {
      const TSQueryPredicateStep *cur = &steps[s];
      switch (cur->type) {
      case TSQueryPredicateStepTypeDone:
        // Close the current predicate and open a fresh one for what follows.
        lua_rawseti(L, -2, pred_slot++);  // [retval, patterns, pat]
        lua_createtable(L, 3, 0);  // [retval, patterns, pat, pred]
        item_slot = 1;
        break;
      case TSQueryPredicateStepTypeString: {
        uint32_t text_len;
        const char *text = ts_query_string_value_for_id(query, cur->value_id,
                                                        &text_len);
        lua_pushlstring(L, text, text_len);  // [..., pat, pred, item]
        lua_rawseti(L, -2, item_slot++);  // [retval, patterns, pat, pred]
        break;
      }
      case TSQueryPredicateStepTypeCapture:
        lua_pushnumber(L, cur->value_id+1);  // [..., pat, pred, item]
        lua_rawseti(L, -2, item_slot++);  // [retval, patterns, pat, pred]
        break;
      default:
        abort();
      }
    }
    // The last predicate should have ended with TypeDone, so the table on
    // top is the unused one opened after it; discard it.
    lua_pop(L, 1);  // [retval, patterns, pat]
    lua_rawseti(L, -2, pat+1);  // [retval, patterns]
  }
  lua_setfield(L, -2, "patterns");  // [retval]

  uint32_t capture_total = ts_query_capture_count(query);
  lua_createtable(L, capture_total, 0);  // [retval, captures]
  for (size_t cap = 0; cap < capture_total; cap++) {
    uint32_t name_len;
    const char *name = ts_query_capture_name_for_id(query, cap, &name_len);
    lua_pushlstring(L, name, name_len);  // [retval, captures, capture]
    lua_rawseti(L, -2, cap+1);  // [retval, captures]
  }
  lua_setfield(L, -2, "captures");  // [retval]

  return 1;
}
1296