1 #include "doc_parser.h"
2 
3 #include "parse/parselo.h"
4 #include "libs/antlr/ErrorListener.h"
5 
6 #undef TRUE
7 #undef FALSE
8 
9 PUSH_SUPPRESS_WARNINGS
10 #include "arg_parser/generated/ArgumentListLexer.h"
11 #include "arg_parser/generated/ArgumentListParser.h"
12 #include "arg_parser/generated/ArgumentListVisitor.h"
13 POP_SUPPRESS_WARNINGS
14 
15 namespace scripting {
16 
17 namespace {
18 
/**
 * Combines two types into a single alternative ("left | right") type.
 *
 * A simple type contributes itself as one alternative; any non-simple type contributes the contents of its
 * elements() list instead, so chained merges produce one flat alternative list rather than nested ones.
 *
 * NOTE(review): the flattening applies to every non-simple type, not only to alternatives. If compound types
 * such as maps or tuples expose their components through elements(), those components would be split into
 * separate alternatives here — confirm this is intended.
 *
 * @param left  The type on the left side of the alternative.
 * @param right The type on the right side of the alternative.
 * @return An alternative type holding the collected element types, left ones first.
 */
ade_type_info merge_alternatives(const ade_type_info& left, const ade_type_info& right)
{
	SCP_vector<ade_type_info> flattened;

	const auto append = [&flattened](const ade_type_info& part) {
		if (!part.isSimple()) {
			const auto& parts = part.elements();
			flattened.insert(flattened.end(), parts.begin(), parts.end());
			return;
		}
		flattened.push_back(part);
	};

	append(left);
	append(right);

	return ade_type_info(ade_type_alternative(flattened));
}
35 
36 class BaseVisitor : public ArgumentListVisitor {
failure()37 	static void failure()
38 	{
39 		throw std::runtime_error("Unhandled node encountered!");
40 	}
41 
42   public:
visitArg_list(ArgumentListParser::Arg_listContext *)43 	antlrcpp::Any visitArg_list(ArgumentListParser::Arg_listContext* /*context*/) override
44 	{
45 		failure();
46 		return antlrcpp::Any();
47 	}
visitStandalone_type(ArgumentListParser::Standalone_typeContext *)48 	antlrcpp::Any visitStandalone_type(ArgumentListParser::Standalone_typeContext* /*context*/) override
49 	{
50 		failure();
51 		return antlrcpp::Any();
52 	}
visitMap_type(ArgumentListParser::Map_typeContext *)53 	antlrcpp::Any visitMap_type(ArgumentListParser::Map_typeContext* /*context*/) override
54 	{
55 		failure();
56 		return antlrcpp::Any();
57 	}
visitIterator_type(ArgumentListParser::Iterator_typeContext *)58 	antlrcpp::Any visitIterator_type(ArgumentListParser::Iterator_typeContext* /*context*/) override
59 	{
60 		failure();
61 		return antlrcpp::Any();
62 	}
visitSimple_type(ArgumentListParser::Simple_typeContext *)63 	antlrcpp::Any visitSimple_type(ArgumentListParser::Simple_typeContext* /*context*/) override
64 	{
65 		failure();
66 		return antlrcpp::Any();
67 	}
visitType(ArgumentListParser::TypeContext *)68 	antlrcpp::Any visitType(ArgumentListParser::TypeContext* /*context*/) override
69 	{
70 		failure();
71 		return antlrcpp::Any();
72 	}
visitBoolean(ArgumentListParser::BooleanContext *)73 	antlrcpp::Any visitBoolean(ArgumentListParser::BooleanContext* /*context*/) override
74 	{
75 		failure();
76 		return antlrcpp::Any();
77 	}
visitValue(ArgumentListParser::ValueContext *)78 	antlrcpp::Any visitValue(ArgumentListParser::ValueContext* /*context*/) override
79 	{
80 		failure();
81 		return antlrcpp::Any();
82 	}
visitActual_argument(ArgumentListParser::Actual_argumentContext *)83 	antlrcpp::Any visitActual_argument(ArgumentListParser::Actual_argumentContext* /*context*/) override
84 	{
85 		failure();
86 		return antlrcpp::Any();
87 	}
visitOptional_argument(ArgumentListParser::Optional_argumentContext *)88 	antlrcpp::Any visitOptional_argument(ArgumentListParser::Optional_argumentContext* /*context*/) override
89 	{
90 		failure();
91 		return antlrcpp::Any();
92 	}
visitArgument(ArgumentListParser::ArgumentContext *)93 	antlrcpp::Any visitArgument(ArgumentListParser::ArgumentContext* /*context*/) override
94 	{
95 		failure();
96 		return antlrcpp::Any();
97 	}
visitFunc_arg(ArgumentListParser::Func_argContext *)98 	antlrcpp::Any visitFunc_arg(ArgumentListParser::Func_argContext* /*context*/) override
99 	{
100 		failure();
101 		return antlrcpp::Any();
102 	}
visitFunc_arglist(ArgumentListParser::Func_arglistContext *)103 	antlrcpp::Any visitFunc_arglist(ArgumentListParser::Func_arglistContext* /*context*/) override
104 	{
105 		failure();
106 		return antlrcpp::Any();
107 	}
visitFunction_type(ArgumentListParser::Function_typeContext *)108 	antlrcpp::Any visitFunction_type(ArgumentListParser::Function_typeContext* /*context*/) override
109 	{
110 		failure();
111 		return antlrcpp::Any();
112 	}
visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext *)113 	antlrcpp::Any visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext* /*context*/) override
114 	{
115 		failure();
116 		return antlrcpp::Any();
117 	}
118 };
119 
120 class ValueVisitor : public BaseVisitor {
121   public:
visitBoolean(ArgumentListParser::BooleanContext * context)122 	antlrcpp::Any visitBoolean(ArgumentListParser::BooleanContext* context) override { return visitChildren(context); }
visitValue(ArgumentListParser::ValueContext * context)123 	antlrcpp::Any visitValue(ArgumentListParser::ValueContext* context) override { return visitChildren(context); }
visitTerminal(antlr4::tree::TerminalNode * node)124 	antlrcpp::Any visitTerminal(antlr4::tree::TerminalNode* node) override { return node->getText(); }
125 };
126 
127 class ArglistVisitor : public BaseVisitor {
128   public:
129 	antlrcpp::Any visitFunc_arg(ArgumentListParser::Func_argContext* context) override;
visitFunc_arglist(ArgumentListParser::Func_arglistContext * context)130 	antlrcpp::Any visitFunc_arglist(ArgumentListParser::Func_arglistContext* context) override
131 	{
132 		SCP_vector<ade_type_info> argTypes;
133 
134 		size_t n = context->children.size();
135 		for (size_t i = 0; i < n; i++) {
136 			antlrcpp::Any childResult = context->children[i]->accept(this);
137 
138 			if (childResult.isNotNull()) {
139 				argTypes.push_back(childResult.as<ade_type_info>());
140 			}
141 		}
142 
143 		return argTypes;
144 	}
145 
146   protected:
aggregateResult(antlrcpp::Any any,const antlrcpp::Any & nextResult)147 	antlrcpp::Any aggregateResult(antlrcpp::Any any, const antlrcpp::Any& nextResult) override
148 	{
149 		if (any.isNull()) {
150 		}
151 		return AbstractParseTreeVisitor::aggregateResult(any, nextResult);
152 	}
153 };
154 
155 class TypeVisitor : public BaseVisitor {
156   public:
visitSimple_type(ArgumentListParser::Simple_typeContext * context)157 	antlrcpp::Any visitSimple_type(ArgumentListParser::Simple_typeContext* context) override
158 	{
159 		if (context->NIL() != nullptr) {
160 			return ade_type_info("nil");
161 		} else {
162 			return ade_type_info(context->ID()->getText());
163 		}
164 	}
visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext * context)165 	antlrcpp::Any visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext* context) override
166 	{
167 		auto retType = visit(context->simple_type()).as<ade_type_info>();
168 		if (context->VARARGS_SPECIFIER() != nullptr) {
169 			return ade_type_info(ade_type_varargs(std::move(retType)));
170 		}
171 		return retType;
172 	}
173 
visitType(ArgumentListParser::TypeContext * context)174 	antlrcpp::Any visitType(ArgumentListParser::TypeContext* context) override
175 	{
176 		return visitChildren(context);
177 	}
visitStandalone_type(ArgumentListParser::Standalone_typeContext * context)178 	antlrcpp::Any visitStandalone_type(ArgumentListParser::Standalone_typeContext* context) override
179 	{
180 		const auto typeValue = visitChildren(context).as<ade_type_info>();
181 
182 		// We can't properly distinguish between alternate types and tuple types in aggregateResult so instead we do
183 		// that here. For a single type, we just return that type
184 		if (!context->COMMA().empty()) {
185 			// We saw a comma in our rule so this is a tuple type
186 			return ade_type_info(ade_type_tuple(typeValue.elements()));
187 		}
188 
189 		// Otherwise just pass through the originally parsed type
190 		return typeValue;
191 	}
192 
visitFunction_type(ArgumentListParser::Function_typeContext * context)193 	antlrcpp::Any visitFunction_type(ArgumentListParser::Function_typeContext* context) override
194 	{
195 		auto retType = visit(context->type()).as<ade_type_info>();
196 
197 		ArglistVisitor arglistVisitor;
198 		auto argTypes = context->func_arglist()->accept(&arglistVisitor).as<SCP_vector<ade_type_info>>();
199 
200 		return ade_type_info(ade_type_function(std::move(retType), std::move(argTypes)));
201 	}
202 
visitMap_type(ArgumentListParser::Map_typeContext * context)203 	antlrcpp::Any visitMap_type(ArgumentListParser::Map_typeContext* context) override
204 	{
205 		const auto keyType = visit(context->type(0)).as<ade_type_info>();
206 		const auto valueType = visit(context->type(1)).as<ade_type_info>();
207 
208 		return ade_type_info(ade_type_map(keyType, valueType));
209 	}
210 
visitIterator_type(ArgumentListParser::Iterator_typeContext * context)211 	antlrcpp::Any visitIterator_type(ArgumentListParser::Iterator_typeContext* context) override
212 	{
213 		const auto valueType = visit(context->type()).as<ade_type_info>();
214 
215 		return ade_type_info(ade_type_iterator(valueType));
216 	}
217 
visitErrorNode(antlr4::tree::ErrorNode *)218 	antlrcpp::Any visitErrorNode(antlr4::tree::ErrorNode* /*node*/) override { return ade_type_info("<error type>"); }
219 
220   protected:
aggregateResult(antlrcpp::Any previous,const antlrcpp::Any & nextResult)221 	antlrcpp::Any aggregateResult(antlrcpp::Any previous, const antlrcpp::Any& nextResult) override
222 	{
223 		// This happens while visiting terminals, ignore those
224 		if (nextResult.isNull()) {
225 			return previous;
226 		}
227 
228 		if (previous.isNotNull()) {
229 			const auto& previousType = previous.as<ade_type_info>();
230 			const auto& nextType     = nextResult.as<ade_type_info>();
231 			return merge_alternatives(previousType, nextType);
232 		} else {
233 			return nextResult.as<ade_type_info>();
234 		}
235 	}
236 };
237 
visitFunc_arg(ArgumentListParser::Func_argContext * context)238 antlrcpp::Any scripting::ArglistVisitor::visitFunc_arg(ArgumentListParser::Func_argContext* context)
239 {
240 	TypeVisitor typeVisit;
241 	auto argType = context->type()->accept(&typeVisit).as<ade_type_info>();
242 	argType.setName(context->ID()->getText());
243 
244 	return argType;
245 }
246 
/**
 * Extracts the body text of an argument comment token.
 *
 * Strips a two-character delimiter from both ends of the token text and then trims it with drop_white_space().
 *
 * @param content The raw token text, including its delimiters.
 * @return The trimmed comment body, or an empty string if the text is too short to hold both delimiters.
 */
SCP_string getCommentContent(const SCP_string& content)
{
	const size_t delimiterLength = 2;

	// Without this guard, "length - 4" underflows for short input, and substr() throws std::out_of_range
	// when the text is shorter than the leading delimiter.
	if (content.length() < 2 * delimiterLength) {
		return SCP_string();
	}

	// Strip leading and trailing delimiters
	auto base = content.substr(delimiterLength, content.length() - 2 * delimiterLength);

	drop_white_space(base);

	return base;
}
254 
255 class ArgumentCollectorVisitor : public BaseVisitor {
256 	bool saw_optional = false;
257 
258   public:
259 	SCP_vector<argument_def> args;
260 
visitArg_list(ArgumentListParser::Arg_listContext * context)261 	antlrcpp::Any visitArg_list(ArgumentListParser::Arg_listContext* context) override
262 	{
263 		visitChildren(context);
264 		return antlrcpp::Any();
265 	}
visitActual_argument(ArgumentListParser::Actual_argumentContext * context)266 	antlrcpp::Any visitActual_argument(ArgumentListParser::Actual_argumentContext* context) override
267 	{
268 		argument_def argdef;
269 		argdef.optional = saw_optional;
270 
271 		TypeVisitor typeVisit;
272 		const auto typeAny = context->type()->accept(&typeVisit);
273 		argdef.type        = typeAny.as<ade_type_info>();
274 
275 		if (context->ID() != nullptr) {
276 			argdef.name = context->ID()->getText();
277 		}
278 
279 		if (context->value() != nullptr) {
280 			ValueVisitor valueVisit;
281 			const auto valueAny = context->value()->accept(&valueVisit);
282 			argdef.def_val      = valueAny.as<SCP_string>();
283 			argdef.optional     = true;
284 		}
285 
286 		if (context->ARG_COMMENT() != nullptr) {
287 			argdef.comment = getCommentContent(context->ARG_COMMENT()->getText());
288 		}
289 
290 		args.push_back(std::move(argdef));
291 
292 		if (context->argument()) {
293 			context->argument()->accept(this);
294 		}
295 
296 		return antlrcpp::Any();
297 	}
visitOptional_argument(ArgumentListParser::Optional_argumentContext * context)298 	antlrcpp::Any visitOptional_argument(ArgumentListParser::Optional_argumentContext* context) override
299 	{
300 		saw_optional = true;
301 
302 		visitChildren(context);
303 		return antlrcpp::Any();
304 	}
visitArgument(ArgumentListParser::ArgumentContext * context)305 	antlrcpp::Any visitArgument(ArgumentListParser::ArgumentContext* context) override
306 	{
307 		visitChildren(context);
308 		return antlrcpp::Any();
309 	}
310 };
311 
312 class TypeCheckVisitor : public BaseVisitor {
313 	ArgumentListParser* _parser = nullptr;
314 	const SCP_unordered_set<SCP_string>& _validTypeNames;
315 
316   public:
TypeCheckVisitor(ArgumentListParser * parser,const SCP_unordered_set<SCP_string> & validTypeNames)317 	TypeCheckVisitor(ArgumentListParser* parser, const SCP_unordered_set<SCP_string>& validTypeNames)
318 		: _parser(parser), _validTypeNames(validTypeNames)
319 	{
320 	}
321 
visitSimple_type(ArgumentListParser::Simple_typeContext * context)322 	antlrcpp::Any visitSimple_type(ArgumentListParser::Simple_typeContext* context) override
323 	{
324 		// Nil is always valid
325 		if (context->NIL() != nullptr) {
326 			return antlrcpp::Any();
327 		}
328 
329 		if (_validTypeNames.find(context->ID()->getText()) == _validTypeNames.end()) {
330 			_parser->notifyErrorListeners(context->ID()->getSymbol(),
331 				"Invalid type name <" + context->ID()->getText() + ">",
332 				nullptr);
333 		}
334 
335 		return antlrcpp::Any();
336 	}
visitType(ArgumentListParser::TypeContext * context)337 	antlrcpp::Any visitType(ArgumentListParser::TypeContext* context) override
338 	{
339 		return visitChildren(context);
340 	}
visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext * context)341 	antlrcpp::Any visitVarargs_or_simple_type(ArgumentListParser::Varargs_or_simple_typeContext* context) override
342 	{
343 		return visitChildren(context);
344 	}
visitStandalone_type(ArgumentListParser::Standalone_typeContext * context)345 	antlrcpp::Any visitStandalone_type(ArgumentListParser::Standalone_typeContext* context) override
346 	{
347 		return visitChildren(context);
348 	}
visitMap_type(ArgumentListParser::Map_typeContext * context)349 	antlrcpp::Any visitMap_type(ArgumentListParser::Map_typeContext* context) override
350 	{
351 		return visitChildren(context);
352 	}
visitIterator_type(ArgumentListParser::Iterator_typeContext * context)353 	antlrcpp::Any visitIterator_type(ArgumentListParser::Iterator_typeContext* context) override
354 	{
355 		return visitChildren(context);
356 	}
357 
visitFunc_arg(ArgumentListParser::Func_argContext * context)358 	antlrcpp::Any visitFunc_arg(ArgumentListParser::Func_argContext* context) override
359 	{
360 		visit(context->type());
361 		return antlrcpp::Any();
362 	}
visitFunc_arglist(ArgumentListParser::Func_arglistContext * context)363 	antlrcpp::Any visitFunc_arglist(ArgumentListParser::Func_arglistContext* context) override
364 	{
365 		visitChildren(context);
366 		return antlrcpp::Any();
367 	}
368 
visitFunction_type(ArgumentListParser::Function_typeContext * context)369 	antlrcpp::Any visitFunction_type(ArgumentListParser::Function_typeContext* context) override
370 	{
371 		visitChildren(context);
372 		return antlrcpp::Any();
373 	}
374 
visitActual_argument(ArgumentListParser::Actual_argumentContext * context)375 	antlrcpp::Any visitActual_argument(ArgumentListParser::Actual_argumentContext* context) override
376 	{
377 		visit(context->type());
378 
379 		// Now also look at the remaining arguments
380 		if (context->argument() != nullptr)
381 		{
382 			visit(context->argument());
383 		}
384 		return antlrcpp::Any();
385 	}
386 
visitOptional_argument(ArgumentListParser::Optional_argumentContext * context)387 	antlrcpp::Any visitOptional_argument(ArgumentListParser::Optional_argumentContext* context) override
388 	{
389 		visitChildren(context);
390 		return antlrcpp::Any();
391 	}
visitArgument(ArgumentListParser::ArgumentContext * context)392 	antlrcpp::Any visitArgument(ArgumentListParser::ArgumentContext* context) override
393 	{
394 		visitChildren(context);
395 		return antlrcpp::Any();
396 	}
visitArg_list(ArgumentListParser::Arg_listContext * context)397 	antlrcpp::Any visitArg_list(ArgumentListParser::Arg_listContext* context) override
398 	{
399 		visitChildren(context);
400 		return antlrcpp::Any();
401 	}
402 };
403 
404 } // namespace
405 
argument_list_parser(const SCP_vector<SCP_string> & validTypeNames)406 argument_list_parser::argument_list_parser(const SCP_vector<SCP_string>& validTypeNames)
407 	: _validTypeNames(validTypeNames.begin(), validTypeNames.end())
408 {
409 }
410 
parse(const SCP_string & argumentList)411 bool argument_list_parser::parse(const SCP_string& argumentList)
412 {
413 	antlr4::ANTLRInputStream input(argumentList);
414 	ArgumentListLexer lexer(&input);
415 	antlr4::CommonTokenStream tokens(&lexer);
416 
417 	ArgumentListParser parser(&tokens);
418 	// By default we log to stderr which we do not want
419 	parser.removeErrorListeners();
420 	libs::antlr::ErrorListener errListener;
421 	parser.addErrorListener(&errListener);
422 
423 	antlr4::tree::ParseTree* tree = parser.arg_list();
424 
425 	TypeCheckVisitor typeChecker(&parser, _validTypeNames);
426 	tree->accept(&typeChecker);
427 
428 	// If there were errors, output them
429 	if (!errListener.diagnostics.empty()) {
430 		SCP_stringstream outStream;
431 		for (const auto& diag : errListener.diagnostics) {
432 			SCP_string tokenUnderline;
433 			if (diag.tokenLength > 1) {
434 				tokenUnderline = SCP_string(diag.tokenLength - 1, '~');
435 			}
436 			outStream << argumentList << "\n"
437 					  << SCP_string(diag.columnInLine, ' ') << "^" << tokenUnderline << "\n"
438 					  << diag.errorMessage << "\n";
439 		}
440 
441 		_errorMessage = outStream.str();
442 	} else {
443 		// Only look at the parameters if we know that we have valid input to avoid some of the error handling while
444 		// traversing the parse tree
445 		ArgumentCollectorVisitor argCollector;
446 		tree->accept(&argCollector);
447 
448 		_argList = argCollector.args;
449 	}
450 
451 	return errListener.diagnostics.empty();
452 }
getArgList() const453 const SCP_vector<scripting::argument_def>& argument_list_parser::getArgList() const
454 {
455 	return _argList;
456 }
getErrorMessage() const457 const SCP_string& argument_list_parser::getErrorMessage() const
458 {
459 	return _errorMessage;
460 }
461 
type_parser(const SCP_vector<SCP_string> & validTypeNames)462 type_parser::type_parser(const SCP_vector<SCP_string>& validTypeNames)
463 	: _validTypeNames(validTypeNames.begin(), validTypeNames.end())
464 {
465 }
parse(const SCP_string & type)466 bool type_parser::parse(const SCP_string& type)
467 {
468 	antlr4::ANTLRInputStream input(type);
469 	ArgumentListLexer lexer(&input);
470 	antlr4::CommonTokenStream tokens(&lexer);
471 
472 	ArgumentListParser parser(&tokens);
473 	// By default we log to stderr which we do not want
474 	parser.removeErrorListeners();
475 	libs::antlr::ErrorListener errListener;
476 	parser.addErrorListener(&errListener);
477 
478 	// Tuple type is the entry point for a standalone type
479 	antlr4::tree::ParseTree* tree = parser.standalone_type();
480 
481 	TypeCheckVisitor typeChecker(&parser, _validTypeNames);
482 	tree->accept(&typeChecker);
483 
484 	// If there were errors, output them
485 	if (!errListener.diagnostics.empty()) {
486 		SCP_stringstream outStream;
487 		for (const auto& diag : errListener.diagnostics) {
488 			SCP_string tokenUnderline;
489 			if (diag.tokenLength > 1) {
490 				tokenUnderline = SCP_string(diag.tokenLength - 1, '~');
491 			}
492 			outStream << type << "\n"
493 					  << SCP_string(diag.columnInLine, ' ') << "^" << tokenUnderline << "\n"
494 					  << diag.errorMessage << "\n";
495 		}
496 
497 		_errorMessage = outStream.str();
498 	} else {
499 		TypeVisitor typeVisitor;
500 		const auto typeAny = tree->accept(&typeVisitor);
501 
502 		_parsedType = typeAny.as<ade_type_info>();
503 	}
504 
505 	return errListener.diagnostics.empty();
506 }
getType() const507 const ade_type_info& type_parser::getType() const
508 {
509 	return _parsedType;
510 }
getErrorMessage() const511 const SCP_string& type_parser::getErrorMessage() const
512 {
513 	return _errorMessage;
514 }
515 
516 } // namespace scripting
517