1 /*
2  * syntax.c:
3  *
4  * Copyright (C) 2007-2016 David Lutterkort
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
19  *
20  * Author: David Lutterkort <dlutter@redhat.com>
21  */
22 
23 #include <config.h>
24 
25 #include <assert.h>
26 #include <stdarg.h>
27 #include <limits.h>
28 #include <ctype.h>
29 #include <glob.h>
30 #include <argz.h>
31 #include <sys/types.h>
32 #include <sys/stat.h>
33 #include <unistd.h>
34 
35 #include "memory.h"
36 #include "syntax.h"
37 #include "augeas.h"
38 #include "transform.h"
39 #include "errcode.h"
40 
/* Extension of source files */
#define AUG_EXT ".aug"

/* Nonzero when the AUG_TYPE_CHECK flag is set on the augeas handle that
 * CTX (a struct ctx *) refers to */
#define LNS_TYPE_CHECK(ctx) ((ctx)->aug->flags & AUG_TYPE_CHECK)

/* Name of the module whose bindings are searched for unqualified names
 * (see lookup_internal) */
static const char *const builtin_module = "Builtin";

/* Statically allocated base types, shared by everybody. Their reference
 * count is UINT_MAX so that they are never freed.
 * NOTE(review): presumably UINT_MAX == REF_MAX, the "pinned" count that
 * ref()/unref() skip; confirm in ref.h */
static const struct type string_type    = { .ref = UINT_MAX, .tag = T_STRING };
static const struct type regexp_type    = { .ref = UINT_MAX, .tag = T_REGEXP };
static const struct type lens_type      = { .ref = UINT_MAX, .tag = T_LENS };
static const struct type tree_type      = { .ref = UINT_MAX, .tag = T_TREE };
static const struct type filter_type    = { .ref = UINT_MAX, .tag = T_FILTER };
static const struct type transform_type =
                                       { .ref = UINT_MAX, .tag = T_TRANSFORM };
static const struct type unit_type      = { .ref = UINT_MAX, .tag = T_UNIT };

/* Public handles on the base types above */
const struct type *const t_string    = &string_type;
const struct type *const t_regexp    = &regexp_type;
const struct type *const t_lens      = &lens_type;
const struct type *const t_tree      = &tree_type;
const struct type *const t_filter    = &filter_type;
const struct type *const t_transform = &transform_type;
const struct type *const t_unit      = &unit_type;

/* Printable type names. Entry I is the name of the type whose tag is I
 * (type_name relies on this); the list is NULL-terminated. */
static const char *const type_names[] = {
    "string", "regexp", "lens", "tree", "filter",
    "transform", "function", "unit", NULL
};

/* The anonymous identifier which we will never bind */
static const char anon_ident[] = "_";

/* Debug printer for values; defined further down */
static void print_value(FILE *out, struct value *v);
74 
/* The evaluation context with all loaded modules and the bindings for the
 * module we are working on in LOCAL
 */
struct ctx {
    const char     *name;     /* The module we are working on */
    struct augeas  *aug;      /* Handle whose modules are searched for names
                                 not found in LOCAL; may be NULL */
    struct binding *local;    /* Bindings of the current module; searched
                                 first by ctx_lookup_bnd */
};
83 
init_fatal_exn(struct error * error)84 static int init_fatal_exn(struct error *error) {
85     if (error->exn != NULL)
86         return 0;
87     error->exn = make_exn_value(ref(error->info), "Error during evaluation");
88     if (error->exn == NULL)
89         return -1;
90     error->exn->exn->seen = 1;
91     error->exn->exn->error = 1;
92     error->exn->exn->lines = NULL;
93     error->exn->exn->nlines = 0;
94     error->exn->ref = REF_MAX;
95     return 0;
96 }
97 
/* Record an error with CODE on INFO->error. The message is the formatted
 * location of INFO followed by FORMAT expanded with AP. Syntax errors
 * accumulate: a new message is appended on its own line to existing
 * details. Any other code replaces earlier details. If the combined
 * message cannot be built, the previous details are kept. */
static void format_error(struct info *info, aug_errcode_t code,
                         const char *format, va_list ap) {
    struct error *error = info->error;
    char *location = NULL;
    char *message = NULL;
    char *combined = NULL;
    const char *loc_text, *msg_text;
    int rc;

    error->code = code;
    if (code != AUG_ESYNTAX)
        FREE(error->details);

    location = format_info(info);
    if (vasprintf(&message, format, ap) < 0)
        message = NULL;

    loc_text = (location != NULL) ? location : "(no location)";
    msg_text = (message != NULL) ? message : "(no details)";

    if (error->details == NULL)
        rc = xasprintf(&combined, "%s%s", loc_text, msg_text);
    else
        rc = xasprintf(&combined, "%s\n%s%s", error->details,
                       loc_text, msg_text);

    if (rc >= 0) {
        free(error->details);
        error->details = combined;
    }
    free(location);
    free(message);
}
129 
syntax_error(struct info * info,const char * format,...)130 void syntax_error(struct info *info, const char *format, ...) {
131     struct error *error = info->error;
132     va_list ap;
133 
134     if (error->code != AUG_NOERROR && error->code != AUG_ESYNTAX)
135         return;
136 
137 	va_start(ap, format);
138     format_error(info, AUG_ESYNTAX, format, ap);
139     va_end(ap);
140 }
141 
fatal_error(struct info * info,const char * format,...)142 void fatal_error(struct info *info, const char *format, ...) {
143     struct error *error = info->error;
144     va_list ap;
145 
146     if (error->code == AUG_EINTERNAL)
147         return;
148 
149 	va_start(ap, format);
150     format_error(info, AUG_EINTERNAL, format, ap);
151     va_end(ap);
152 }
153 
free_param(struct param * param)154 static void free_param(struct param *param) {
155     if (param == NULL)
156         return;
157     assert(param->ref == 0);
158     unref(param->info, info);
159     unref(param->name, string);
160     unref(param->type, type);
161     free(param);
162 }
163 
free_term(struct term * term)164 void free_term(struct term *term) {
165     if (term == NULL)
166         return;
167     assert(term->ref == 0);
168     switch(term->tag) {
169     case A_MODULE:
170         free(term->mname);
171         free(term->autoload);
172         unref(term->decls, term);
173         break;
174     case A_BIND:
175         free(term->bname);
176         unref(term->exp, term);
177         break;
178     case A_COMPOSE:
179     case A_UNION:
180     case A_MINUS:
181     case A_CONCAT:
182     case A_APP:
183     case A_LET:
184         unref(term->left, term);
185         unref(term->right, term);
186         break;
187     case A_VALUE:
188         unref(term->value, value);
189         break;
190     case A_IDENT:
191         unref(term->ident, string);
192         break;
193     case A_BRACKET:
194         unref(term->brexp, term);
195         break;
196     case A_FUNC:
197         unref(term->param, param);
198         unref(term->body, term);
199         break;
200     case A_REP:
201         unref(term->rexp, term);
202         break;
203     case A_TEST:
204         unref(term->test, term);
205         unref(term->result, term);
206         break;
207     default:
208         assert(0);
209         break;
210     }
211     unref(term->next, term);
212     unref(term->info, info);
213     unref(term->type, type);
214     free(term);
215 }
216 
free_binding(struct binding * binding)217 static void free_binding(struct binding *binding) {
218     if (binding == NULL)
219         return;
220     assert(binding->ref == 0);
221     unref(binding->next, binding);
222     unref(binding->ident, string);
223     unref(binding->type, type);
224     unref(binding->value, value);
225     free(binding);
226 }
227 
free_module(struct module * module)228 void free_module(struct module *module) {
229     if (module == NULL)
230         return;
231     assert(module->ref == 0);
232     free(module->name);
233     unref(module->next, module);
234     unref(module->bindings, binding);
235     unref(module->autoload, transform);
236     free(module);
237 }
238 
free_type(struct type * type)239 void free_type(struct type *type) {
240     if (type == NULL)
241         return;
242     assert(type->ref == 0);
243 
244     if (type->tag == T_ARROW) {
245         unref(type->dom, type);
246         unref(type->img, type);
247     }
248     free(type);
249 }
250 
free_exn(struct exn * exn)251 static void free_exn(struct exn *exn) {
252     if (exn == NULL)
253         return;
254 
255     unref(exn->info, info);
256     free(exn->message);
257     for (int i=0; i < exn->nlines; i++) {
258         free(exn->lines[i]);
259     }
260     free(exn->lines);
261     free(exn);
262 }
263 
free_value(struct value * v)264 void free_value(struct value *v) {
265     if (v == NULL)
266         return;
267     assert(v->ref == 0);
268 
269     switch(v->tag) {
270     case V_STRING:
271         unref(v->string, string);
272         break;
273     case V_REGEXP:
274         unref(v->regexp, regexp);
275         break;
276     case V_LENS:
277         unref(v->lens, lens);
278         break;
279     case V_TREE:
280         free_tree(v->origin);
281         break;
282     case V_FILTER:
283         unref(v->filter, filter);
284         break;
285     case V_TRANSFORM:
286         unref(v->transform, transform);
287         break;
288     case V_NATIVE:
289         if (v->native)
290             unref(v->native->type, type);
291         free(v->native);
292         break;
293     case V_CLOS:
294         unref(v->func, term);
295         unref(v->bindings, binding);
296         break;
297     case V_EXN:
298         free_exn(v->exn);
299         break;
300     case V_UNIT:
301         break;
302     default:
303         assert(0);
304     }
305     unref(v->info, info);
306     free(v);
307 }
308 
309 /*
310  * Creation of (some) terms. Others are in parser.y
311  * Reference counted arguments are now owned by the returned object, i.e.
312  * the make_* functions do not increment the count.
 * Returned objects have a reference count of 1.
314  */
make_term(enum term_tag tag,struct info * info)315 struct term *make_term(enum term_tag tag, struct info *info) {
316   struct term *term;
317   if (make_ref(term) < 0) {
318       unref(info, info);
319   } else {
320       term->tag = tag;
321       term->info = info;
322   }
323   return term;
324 }
325 
make_param(char * name,struct type * type,struct info * info)326 struct term *make_param(char *name, struct type *type, struct info *info) {
327   struct term *term = make_term(A_FUNC, info);
328   if (term == NULL)
329       goto error;
330   make_ref_err(term->param);
331   term->param->info = ref(term->info);
332   make_ref_err(term->param->name);
333   term->param->name->str = name;
334   term->param->type = type;
335   return term;
336  error:
337   unref(term, term);
338   return NULL;
339 }
340 
make_value(enum value_tag tag,struct info * info)341 struct value *make_value(enum value_tag tag, struct info *info) {
342     struct value *value = NULL;
343     if (make_ref(value) < 0) {
344         unref(info, info);
345     } else {
346         value->tag = tag;
347         value->info = info;
348     }
349     return value;
350 }
351 
make_unit(struct info * info)352 struct value *make_unit(struct info *info) {
353     return make_value(V_UNIT, info);
354 }
355 
make_app_term(struct term * lambda,struct term * arg,struct info * info)356 struct term *make_app_term(struct term *lambda, struct term *arg,
357                            struct info *info) {
358   struct term *app = make_term(A_APP, info);
359   if (app == NULL) {
360       unref(lambda, term);
361       unref(arg, term);
362   } else {
363       app->left = lambda;
364       app->right = arg;
365   }
366   return app;
367 }
368 
make_app_ident(char * id,struct term * arg,struct info * info)369 struct term *make_app_ident(char *id, struct term *arg, struct info *info) {
370     struct term *ident = make_term(A_IDENT, ref(info));
371     ident->ident = make_string(id);
372     if (ident->ident == NULL) {
373         unref(arg, term);
374         unref(info, info);
375         unref(ident, term);
376         return NULL;
377     }
378     return make_app_term(ident, arg, info);
379 }
380 
build_func(struct term * params,struct term * exp)381 struct term *build_func(struct term *params, struct term *exp) {
382   assert(params->tag == A_FUNC);
383   if (params->next != NULL)
384     exp = build_func(params->next, exp);
385 
386   params->body = exp;
387   params->next = NULL;
388   return params;
389 }
390 
391 /* Ownership is taken as needed */
make_closure(struct term * func,struct binding * bnds)392 static struct value *make_closure(struct term *func, struct binding *bnds) {
393     struct value *v = NULL;
394     if (make_ref(v) == 0) {
395         v->tag  = V_CLOS;
396         v->info = ref(func->info);
397         v->func = ref(func);
398         v->bindings = ref(bnds);
399     }
400     return v;
401 }
402 
make_exn_value(struct info * info,const char * format,...)403 struct value *make_exn_value(struct info *info,
404                              const char *format, ...) {
405     va_list ap;
406     int r;
407     struct value *v;
408     char *message;
409 
410     va_start(ap, format);
411     r = vasprintf(&message, format, ap);
412     va_end(ap);
413     if (r == -1)
414         return NULL;
415 
416     v = make_value(V_EXN, ref(info));
417     if (ALLOC(v->exn) < 0)
418         return info->error->exn;
419     v->exn->info = info;
420     v->exn->message = message;
421 
422     return v;
423 }
424 
exn_add_lines(struct value * v,int nlines,...)425 void exn_add_lines(struct value *v, int nlines, ...) {
426     assert(v->tag == V_EXN);
427 
428     va_list ap;
429     if (REALLOC_N(v->exn->lines, v->exn->nlines + nlines) == -1)
430         return;
431     va_start(ap, nlines);
432     for (int i=0; i < nlines; i++) {
433         char *line = va_arg(ap, char *);
434         v->exn->lines[v->exn->nlines + i] = line;
435     }
436     va_end(ap);
437     v->exn->nlines += nlines;
438 }
439 
/* Format one printf-style line and append it to the exception EXN. If the
 * line cannot be formatted, nothing is added. */
void exn_printf_line(struct value *exn, const char *format, ...) {
    char *text;
    va_list args;
    int len;

    va_start(args, format);
    len = vasprintf(&text, format, args);
    va_end(args);

    if (len < 0)
        return;
    exn_add_lines(exn, 1, text);
}
451 
452 /*
453  * Modules
454  */
455 static int load_module(struct augeas *aug, const char *name);
456 static char *module_basename(const char *modname);
457 
module_create(const char * name)458 struct module *module_create(const char *name) {
459     struct module *module;
460     make_ref(module);
461     module->name = strdup(name);
462     return module;
463 }
464 
module_find(struct module * module,const char * name)465 static struct module *module_find(struct module *module, const char *name) {
466     list_for_each(e, module) {
467         if (STRCASEEQ(e->name, name))
468             return e;
469     }
470     return NULL;
471 }
472 
bnd_lookup(struct binding * bindings,const char * name)473 static struct binding *bnd_lookup(struct binding *bindings, const char *name) {
474     list_for_each(b, bindings) {
475         if (STREQ(b->ident->str, name))
476             return b;
477     }
478     return NULL;
479 }
480 
/* Return a newly allocated copy of the module part of the qualified name
 * QNAME (everything before the first '.'), or NULL if QNAME contains no
 * '.'. The caller must free the result. */
static char *modname_of_qname(const char *qname) {
    const char *sep = strchr(qname, '.');

    return (sep == NULL) ? NULL : strndup(qname, sep - qname);
}
488 
lookup_internal(struct augeas * aug,const char * ctx_modname,const char * name,struct binding ** bnd)489 static int lookup_internal(struct augeas *aug, const char *ctx_modname,
490                            const char *name, struct binding **bnd) {
491     char *modname = modname_of_qname(name);
492 
493     *bnd = NULL;
494 
495     if (modname == NULL) {
496         struct module *builtin =
497             module_find(aug->modules, builtin_module);
498         assert(builtin != NULL);
499         *bnd = bnd_lookup(builtin->bindings, name);
500         return 0;
501     }
502 
503  qual_lookup:
504     list_for_each(module, aug->modules) {
505         if (STRCASEEQ(module->name, modname)) {
506             *bnd = bnd_lookup(module->bindings, name + strlen(modname) + 1);
507             free(modname);
508             return 0;
509         }
510     }
511     /* Try to load the module */
512     if (streqv(modname, ctx_modname)) {
513         free(modname);
514         return 0;
515     }
516     int loaded = load_module(aug, modname) == 0;
517     if (loaded)
518         goto qual_lookup;
519 
520     free(modname);
521     return -1;
522 }
523 
lens_lookup(struct augeas * aug,const char * qname)524 struct lens *lens_lookup(struct augeas *aug, const char *qname) {
525     struct binding *bnd = NULL;
526 
527     if (lookup_internal(aug, NULL, qname, &bnd) < 0)
528         return NULL;
529     if (bnd == NULL || bnd->value->tag != V_LENS)
530         return NULL;
531     return bnd->value->lens;
532 }
533 
ctx_lookup_bnd(struct info * info,struct ctx * ctx,const char * name)534 static struct binding *ctx_lookup_bnd(struct info *info,
535                                       struct ctx *ctx, const char *name) {
536     struct binding *b = NULL;
537     int nlen = strlen(ctx->name);
538 
539     if (STREQLEN(ctx->name, name, nlen) && name[nlen] == '.')
540         name += nlen + 1;
541 
542     b = bnd_lookup(ctx->local, name);
543     if (b != NULL)
544         return b;
545 
546     if (ctx->aug != NULL) {
547         int r;
548         r = lookup_internal(ctx->aug, ctx->name, name, &b);
549         if (r == 0)
550             return b;
551         char *modname = modname_of_qname(name);
552         syntax_error(info, "Could not load module %s for %s",
553                      modname, name);
554         free(modname);
555         return NULL;
556     }
557     return NULL;
558 }
559 
ctx_lookup(struct info * info,struct ctx * ctx,struct string * ident)560 static struct value *ctx_lookup(struct info *info,
561                                 struct ctx *ctx, struct string *ident) {
562     struct binding *b = ctx_lookup_bnd(info, ctx, ident->str);
563     return b == NULL ? NULL : b->value;
564 }
565 
ctx_lookup_type(struct info * info,struct ctx * ctx,struct string * ident)566 static struct type *ctx_lookup_type(struct info *info,
567                                     struct ctx *ctx, struct string *ident) {
568     struct binding *b = ctx_lookup_bnd(info, ctx, ident->str);
569     return b == NULL ? NULL : b->type;
570 }
571 
572 /* Takes ownership as needed */
bind_type(struct binding ** bnds,const char * name,struct type * type)573 static struct binding *bind_type(struct binding **bnds,
574                                  const char *name, struct type *type) {
575     struct binding *binding;
576 
577     if (STREQ(name, anon_ident))
578         return NULL;
579     make_ref(binding);
580     make_ref(binding->ident);
581     binding->ident->str = strdup(name);
582     binding->type = ref(type);
583     list_cons(*bnds, binding);
584 
585     return binding;
586 }
587 
588 /* Takes ownership as needed */
bind_param(struct binding ** bnds,struct param * param,struct value * v)589 static void bind_param(struct binding **bnds, struct param *param,
590                        struct value *v) {
591     struct binding *b;
592     make_ref(b);
593     b->ident = ref(param->name);
594     b->type  = ref(param->type);
595     b->value = ref(v);
596     ref(*bnds);
597     list_cons(*bnds, b);
598 }
599 
/* Pop the binding that bind_param pushed for PARAM off the front of *BNDS
 * and drop it. PARAM is only used in assertions. */
static void unbind_param(struct binding **bnds, ATTRIBUTE_UNUSED struct param *param) {
    struct binding *top = *bnds;

    assert(top->ident == param->name);
    assert(top->next != *bnds);
    *bnds = top->next;
    unref(top, binding);
}
607 
608 /* Takes ownership of VALUE */
bind(struct binding ** bnds,const char * name,struct type * type,struct value * value)609 static void bind(struct binding **bnds,
610                  const char *name, struct type *type, struct value *value) {
611     struct binding *b = NULL;
612 
613     if (STRNEQ(name, anon_ident)) {
614         b = bind_type(bnds, name, type);
615         b->value = ref(value);
616     }
617 }
618 
619 /*
620  * Some debug printing
621  */
622 
623 static char *type_string(struct type *t);
624 
dump_bindings(struct binding * bnds)625 static void dump_bindings(struct binding *bnds) {
626     list_for_each(b, bnds) {
627         char *st = type_string(b->type);
628         fprintf(stderr, "    %s: %s", b->ident->str, st);
629         fprintf(stderr, " = ");
630         print_value(stderr, b->value);
631         fputc('\n', stderr);
632         free(st);
633     }
634 }
635 
dump_module(struct module * module)636 static void dump_module(struct module *module) {
637     if (module == NULL)
638         return;
639     fprintf(stderr, "Module %s\n:", module->name);
640     dump_bindings(module->bindings);
641     dump_module(module->next);
642 }
643 
644 ATTRIBUTE_UNUSED
dump_ctx(struct ctx * ctx)645 static void dump_ctx(struct ctx *ctx) {
646     fprintf(stderr, "Context: %s\n", ctx->name);
647     dump_bindings(ctx->local);
648     if (ctx->aug != NULL) {
649         list_for_each(m, ctx->aug->modules)
650             dump_module(m);
651     }
652 }
653 
654 /*
655  * Values
656  */
print_tree_braces(FILE * out,int indent,struct tree * tree)657 void print_tree_braces(FILE *out, int indent, struct tree *tree) {
658     if (tree == NULL) {
659         fprintf(out, "(null tree)\n");
660         return;
661     }
662     list_for_each(t, tree) {
663         for (int i=0; i < indent; i++) fputc(' ', out);
664         fprintf(out, "{ ");
665         if (t->label != NULL)
666             fprintf(out, "\"%s\"", t->label);
667         if (t->value != NULL)
668             fprintf(out, " = \"%s\"", t->value);
669         if (t->children != NULL) {
670             fputc('\n', out);
671             print_tree_braces(out, indent + 2, t->children);
672             for (int i=0; i < indent; i++) fputc(' ', out);
673         } else {
674             fputc(' ', out);
675         }
676         fprintf(out, "}\n");
677     }
678 }
679 
print_value(FILE * out,struct value * v)680 static void print_value(FILE *out, struct value *v) {
681     if (v == NULL) {
682         fprintf(out, "<null>");
683         return;
684     }
685 
686     switch(v->tag) {
687     case V_STRING:
688         fprintf(out, "\"%s\"", v->string->str);
689         break;
690     case V_REGEXP:
691         fprintf(out, "/%s/", v->regexp->pattern->str);
692         break;
693     case V_LENS:
694         fprintf(out, "<lens:");
695         print_info(out, v->lens->info);
696         fprintf(out, ">");
697         break;
698     case V_TREE:
699         print_tree_braces(out, 0, v->origin);
700         break;
701     case V_FILTER:
702         fprintf(out, "<filter:");
703         list_for_each(f, v->filter) {
704             fprintf(out, "%c%s%c", f->include ? '+' : '-', f->glob->str,
705                    (f->next != NULL) ? ':' : '>');
706         }
707         break;
708     case V_TRANSFORM:
709         fprintf(out, "<transform:");
710         print_info(out, v->transform->lens->info);
711         fprintf(out, ">");
712         break;
713     case V_NATIVE:
714         fprintf(out, "<native:");
715         print_info(out, v->info);
716         fprintf(out, ">");
717         break;
718     case V_CLOS:
719         fprintf(out, "<closure:");
720         print_info(out, v->func->info);
721         fprintf(out, ">");
722         break;
723     case V_EXN:
724         if (! v->exn->seen) {
725             print_info(out, v->exn->info);
726             fprintf(out, "exception: %s\n", v->exn->message);
727             for (int i=0; i < v->exn->nlines; i++) {
728                 fprintf(out, "    %s\n", v->exn->lines[i]);
729             }
730             v->exn->seen = 1;
731         }
732         break;
733     case V_UNIT:
734         fprintf(out, "()");
735         break;
736     default:
737         assert(0);
738         break;
739     }
740 }
741 
value_equal(struct value * v1,struct value * v2)742 static int value_equal(struct value *v1, struct value *v2) {
743     if (v1 == NULL && v2 == NULL)
744         return 1;
745     if (v1 == NULL || v2 == NULL)
746         return 0;
747     if (v1->tag != v2->tag)
748         return 0;
749     switch (v1->tag) {
750     case V_STRING:
751         return STREQ(v1->string->str, v2->string->str);
752         break;
753     case V_REGEXP:
754         // FIXME: Should probably build FA's and compare them
755         return STREQ(v1->regexp->pattern->str, v2->regexp->pattern->str);
756         break;
757     case V_LENS:
758         return v1->lens == v2->lens;
759         break;
760     case V_TREE:
761         return tree_equal(v1->origin->children, v2->origin->children);
762         break;
763     case V_FILTER:
764         return v1->filter == v2->filter;
765         break;
766     case V_TRANSFORM:
767         return v1->transform == v2->transform;
768         break;
769     case V_NATIVE:
770         return v1->native == v2->native;
771         break;
772     case V_CLOS:
773         return v1->func == v2->func && v1->bindings == v2->bindings;
774         break;
775     default:
776         assert(0);
777         abort();
778         break;
779     }
780 }
781 
782 /*
783  * Types
784  */
make_arrow_type(struct type * dom,struct type * img)785 struct type *make_arrow_type(struct type *dom, struct type *img) {
786   struct type *type;
787   make_ref(type);
788   type->tag = T_ARROW;
789   type->dom = ref(dom);
790   type->img = ref(img);
791   return type;
792 }
793 
make_base_type(enum type_tag tag)794 struct type *make_base_type(enum type_tag tag) {
795     if (tag == T_STRING)
796         return (struct type *) t_string;
797     else if (tag == T_REGEXP)
798         return (struct type *) t_regexp;
799     else if (tag == T_LENS)
800         return (struct type *) t_lens;
801     else if (tag == T_TREE)
802         return (struct type *) t_tree;
803     else if (tag == T_FILTER)
804         return (struct type *) t_filter;
805     else if (tag == T_TRANSFORM)
806         return (struct type *) t_transform;
807     else if (tag == T_UNIT)
808         return (struct type *) t_unit;
809     assert(0);
810     abort();
811 }
812 
type_name(struct type * t)813 static const char *type_name(struct type *t) {
814     for (int i = 0; type_names[i] != NULL; i++)
815         if (i == t->tag)
816             return type_names[i];
817     assert(0);
818     abort();
819 }
820 
type_string(struct type * t)821 static char *type_string(struct type *t) {
822     if (t->tag == T_ARROW) {
823         char *s = NULL;
824         int r;
825         char *sd = type_string(t->dom);
826         char *si = type_string(t->img);
827         if (t->dom->tag == T_ARROW)
828             r = asprintf(&s, "(%s) -> %s", sd, si);
829         else
830             r = asprintf(&s, "%s -> %s", sd, si);
831         free(sd);
832         free(si);
833         return (r == -1) ? NULL : s;
834     } else {
835         return strdup(type_name(t));
836     }
837 }
838 
839 /* Decide whether T1 is a subtype of T2. The only subtype relations are
840  * T_STRING <: T_REGEXP and the usual subtyping of functions based on
841  * comparing domains/images
842  *
843  * Return 1 if T1 is a subtype of T2, 0 otherwise
844  */
subtype(struct type * t1,struct type * t2)845 static int subtype(struct type *t1, struct type *t2) {
846     if (t1 == t2)
847         return 1;
848     /* We only promote T_STRING => T_REGEXP, no automatic conversion
849        of strings/regexps to lenses (yet) */
850     if (t1->tag == T_STRING)
851         return (t2->tag == T_STRING || t2->tag == T_REGEXP);
852     if (t1->tag == T_ARROW && t2->tag == T_ARROW) {
853         return subtype(t2->dom, t1->dom)
854             && subtype(t1->img, t2->img);
855     }
856     return t1->tag == t2->tag;
857 }
858 
/* Two types are equal when they are the same object or each is a subtype
 * of the other. */
static int type_equal(struct type *t1, struct type *t2) {
    if (t1 == t2)
        return 1;
    return subtype(t1, t2) && subtype(t2, t1);
}
862 
863 /* Return a type T with subtype(T, T1) && subtype(T, T2) */
864 static struct type *type_meet(struct type *t1, struct type *t2);
865 
866 /* Return a type T with subtype(T1, T) && subtype(T2, T) */
type_join(struct type * t1,struct type * t2)867 static struct type *type_join(struct type *t1, struct type *t2) {
868     if (t1->tag == T_STRING) {
869         if (t2->tag == T_STRING)
870             return ref(t1);
871         else if (t2->tag == T_REGEXP)
872             return ref(t2);
873     } else if (t1->tag == T_REGEXP) {
874         if (t2->tag == T_STRING || t2->tag == T_REGEXP)
875             return ref(t1);
876     } else if (t1->tag == T_ARROW) {
877         if (t2->tag != T_ARROW)
878             return NULL;
879         struct type *dom = type_meet(t1->dom, t2->dom);
880         struct type *img = type_join(t1->img, t2->img);
881         if (dom == NULL || img == NULL) {
882             unref(dom, type);
883             unref(img, type);
884             return NULL;
885         }
886         return make_arrow_type(dom, img);
887     } else if (type_equal(t1, t2)) {
888         return ref(t1);
889     }
890     return NULL;
891 }
892 
893 /* Return a type T with subtype(T, T1) && subtype(T, T2) */
type_meet(struct type * t1,struct type * t2)894 static struct type *type_meet(struct type *t1, struct type *t2) {
895     if (t1->tag == T_STRING) {
896         if (t2->tag == T_STRING || t2->tag == T_REGEXP)
897             return ref(t1);
898     } else if (t1->tag == T_REGEXP) {
899         if (t2->tag == T_STRING || t2->tag == T_REGEXP)
900             return ref(t2);
901     } else if (t1->tag == T_ARROW) {
902         if (t2->tag != T_ARROW)
903             return NULL;
904         struct type *dom = type_join(t1->dom, t2->dom);
905         struct type *img = type_meet(t1->img, t2->img);
906         if (dom == NULL || img == NULL) {
907             unref(dom, type);
908             unref(img, type);
909             return NULL;
910         }
911         return make_arrow_type(dom, img);
912     } else if (type_equal(t1, t2)) {
913         return ref(t1);
914     }
915     return NULL;
916 }
917 
value_type(struct value * v)918 static struct type *value_type(struct value *v) {
919     switch(v->tag) {
920     case V_STRING:
921         return make_base_type(T_STRING);
922     case V_REGEXP:
923         return make_base_type(T_REGEXP);
924     case V_LENS:
925         return make_base_type(T_LENS);
926     case V_TREE:
927         return make_base_type(T_TREE);
928     case V_FILTER:
929         return make_base_type(T_FILTER);
930     case V_TRANSFORM:
931         return make_base_type(T_TRANSFORM);
932     case V_UNIT:
933         return make_base_type(T_UNIT);
934     case V_NATIVE:
935         return ref(v->native->type);
936     case V_CLOS:
937         return ref(v->func->type);
938     case V_EXN:   /* Fail on exceptions */
939     default:
940         assert(0);
941         abort();
942     }
943 }
944 
945 /* Coerce V to the type T. Currently, only T_STRING can be coerced to
946  * T_REGEXP. Returns a value that is owned by the caller. Trying to perform
947  * an impossible coercion is a fatal error. Receives ownership of V.
948  */
coerce(struct value * v,struct type * t)949 static struct value *coerce(struct value *v, struct type *t) {
950     struct type *vt = value_type(v);
951     if (type_equal(vt, t)) {
952         unref(vt, type);
953         return v;
954     }
955     if (vt->tag == T_STRING && t->tag == T_REGEXP) {
956         struct value *rxp = make_value(V_REGEXP, ref(v->info));
957         rxp->regexp = make_regexp_literal(v->info, v->string->str);
958         if (rxp->regexp == NULL) {
959             report_error(v->info->error, AUG_ENOMEM, NULL);
960         };
961         unref(v, value);
962         unref(vt, type);
963         return rxp;
964     }
965     return make_exn_value(v->info, "Type %s can not be coerced to %s",
966                           type_name(vt), type_name(t));
967 }
968 
/* Return one of the expected types in ALLOWED (an array of NTYPES
   entries): the first one that ACT is a subtype of.

   If ACT matches none of them, report a syntax error at INFO that names
   the allowed types and the actual type, and return NULL. NULL is also
   returned, without a message, if the message buffer cannot be
   allocated.

   Does not give ownership of the returned type */
static struct type *expect_types_arr(struct info *info,
                                     struct type *act,
                                     int ntypes, struct type *allowed[]) {
    static const char sep[] = ", ";
    static const char last_sep[] = ", or ";
    char *allowed_names = NULL;
    char *act_str = NULL;
    size_t len = 1;             /* terminating NUL */

    for (int i = 0; i < ntypes; i++) {
        if (subtype(act, allowed[i]))
            return allowed[i];
    }

    /* Size the buffer using the separator actually emitted before each
       name: ", or " (5 chars) before the last one, ", " otherwise.
       Budgeting a flat 4 chars per separator overflows by one byte when
       ntypes == 2. */
    for (int i = 0; i < ntypes; i++) {
        len += strlen(type_name(allowed[i]));
        if (i > 0)
            len += (i == ntypes - 1) ? sizeof(last_sep) - 1
                                     : sizeof(sep) - 1;
    }

    if (ALLOC_N(allowed_names, len) < 0)
        return NULL;
    /* strcat needs a terminated empty string to start from */
    allowed_names[0] = '\0';
    for (int i = 0; i < ntypes; i++) {
        if (i > 0)
            strcat(allowed_names, (i == ntypes - 1) ? last_sep : sep);
        strcat(allowed_names, type_name(allowed[i]));
    }

    act_str = type_string(act);
    syntax_error(info, "type error: expected %s but found %s",
                 allowed_names, act_str);
    free(act_str);
    free(allowed_names);
    return NULL;
}
1004 
expect_types(struct info * info,struct type * act,int ntypes,...)1005 static struct type *expect_types(struct info *info,
1006                                  struct type *act, int ntypes, ...) {
1007     va_list ap;
1008     struct type *allowed[ntypes];
1009 
1010     va_start(ap, ntypes);
1011     for (int i=0; i < ntypes; i++)
1012         allowed[i] = va_arg(ap, struct type *);
1013     va_end(ap);
1014     return expect_types_arr(info, act, ntypes, allowed);
1015 }
1016 
/* Forward declaration: apply() is defined further down, after the
   compile_* helpers. */
static struct value *apply(struct term *app, struct ctx *ctx);

/* Function pointer types for native implementations taking 0 to 5 value
   arguments after the struct info.
   NOTE(review): native_call below invokes func->impl(info, argv) with a
   NULL-terminated argument array rather than any of these typedefs;
   confirm whether they are still referenced elsewhere in this file. */
typedef struct value *(*impl0)(struct info *);
typedef struct value *(*impl1)(struct info *, struct value *);
typedef struct value *(*impl2)(struct info *, struct value *, struct value *);
typedef struct value *(*impl3)(struct info *, struct value *, struct value *,
                               struct value *);
typedef struct value *(*impl4)(struct info *, struct value *, struct value *,
                               struct value *, struct value *);
typedef struct value *(*impl5)(struct info *, struct value *, struct value *,
                               struct value *, struct value *, struct value *);
1028 
1029 static struct value *native_call(struct info *info,
1030                                  struct native *func, struct ctx *ctx) {
1031     struct value *argv[func->argc + 1];
1032     struct binding *b = ctx->local;
1033 
1034     for (int i = func->argc - 1; i >= 0; i--) {
1035         argv[i] = b->value;
1036         b = b->next;
1037     }
1038     argv[func->argc] = NULL;
1039 
1040     return func->impl(info, argv);
1041 }
1042 
/* Report a type error at INFO. MSG is a format string with a single %s,
   which receives the printed form of TYPE. */
static void type_error1(struct info *info, const char *msg, struct type *type) {
    char *printed = type_string(type);

    syntax_error(info, "Type error: ");
    syntax_error(info, msg, printed);

    free(printed);
}
1049 
type_error2(struct info * info,const char * msg,struct type * type1,struct type * type2)1050 static void type_error2(struct info *info, const char *msg,
1051                         struct type *type1, struct type *type2) {
1052     char *s1 = type_string(type1);
1053     char *s2 = type_string(type2);
1054     syntax_error(info, "Type error: ");
1055     syntax_error(info, msg, s1, s2);
1056     free(s1);
1057     free(s2);
1058 }
1059 
type_error_binop(struct info * info,const char * opname,struct type * type1,struct type * type2)1060 static void type_error_binop(struct info *info, const char *opname,
1061                              struct type *type1, struct type *type2) {
1062     char *s1 = type_string(type1);
1063     char *s2 = type_string(type2);
1064     syntax_error(info, "Type error: ");
1065     syntax_error(info, "%s of %s and %s is not possible", opname, s1, s2);
1066     free(s1);
1067     free(s2);
1068 }
1069 
1070 static int check_exp(struct term *term, struct ctx *ctx);
1071 
require_exp_type(struct term * term,struct ctx * ctx,int ntypes,struct type * allowed[])1072 static struct type *require_exp_type(struct term *term, struct ctx *ctx,
1073                                      int ntypes, struct type *allowed[]) {
1074     int r = 1;
1075 
1076     if (term->type == NULL) {
1077         r = check_exp(term, ctx);
1078         if (! r)
1079             return NULL;
1080     }
1081 
1082     return expect_types_arr(term->info, term->type, ntypes, allowed);
1083 }
1084 
check_compose(struct term * term,struct ctx * ctx)1085 static int check_compose(struct term *term, struct ctx *ctx) {
1086     struct type *tl = NULL, *tr = NULL;
1087 
1088     if (! check_exp(term->left, ctx))
1089         return 0;
1090     tl = term->left->type;
1091 
1092     if (tl->tag == T_ARROW) {
1093         /* Composition of functions f: a -> b and g: c -> d is defined as
1094            (f . g) x = g (f x) and is type correct if b <: c yielding a
1095            function with type a -> d */
1096         if (! check_exp(term->right, ctx))
1097             return 0;
1098         tr = term->right->type;
1099         if (tr->tag != T_ARROW)
1100             goto print_error;
1101         if (! subtype(tl->img, tr->dom))
1102             goto print_error;
1103         term->type = make_arrow_type(tl->dom, tr->img);
1104     } else if (tl->tag == T_UNIT) {
1105         if (! check_exp(term->right, ctx))
1106             return 0;
1107         term->type = ref(term->right->type);
1108     } else {
1109         goto print_error;
1110     }
1111     return 1;
1112  print_error:
1113     type_error_binop(term->info,
1114                      "composition", term->left->type, term->right->type);
1115     return 0;
1116 }
1117 
check_binop(const char * opname,struct term * term,struct ctx * ctx,int ntypes,...)1118 static int check_binop(const char *opname, struct term *term,
1119                        struct ctx *ctx, int ntypes, ...) {
1120     va_list ap;
1121     struct type *allowed[ntypes];
1122     struct type *tl = NULL, *tr = NULL;
1123 
1124     va_start(ap, ntypes);
1125     for (int i=0; i < ntypes; i++)
1126         allowed[i] = va_arg(ap, struct type *);
1127     va_end(ap);
1128 
1129     tl = require_exp_type(term->left, ctx, ntypes, allowed);
1130     if (tl == NULL)
1131         return 0;
1132 
1133     tr = require_exp_type(term->right, ctx, ntypes, allowed);
1134     if (tr == NULL)
1135         return 0;
1136 
1137     term->type = type_join(tl, tr);
1138     if (term->type == NULL)
1139         goto print_error;
1140     return 1;
1141  print_error:
1142     type_error_binop(term->info, opname, term->left->type, term->right->type);
1143     return 0;
1144 }
1145 
check_value(struct term * term)1146 static int check_value(struct term *term) {
1147     const char *msg;
1148     struct value *v = term->value;
1149 
1150     if (v->tag == V_REGEXP) {
1151         /* The only literal that needs checking are regular expressions,
1152            where we need to make sure the regexp is syntactically
1153            correct */
1154         if (regexp_check(v->regexp, &msg) == -1) {
1155             syntax_error(v->info, "Invalid regular expression: %s", msg);
1156             return 0;
1157         }
1158         term->type = make_base_type(T_REGEXP);
1159     } else if (v->tag == V_EXN) {
1160         /* Exceptions can't be typed */
1161         return 0;
1162     } else {
1163         /* There are cases where we generate values internally, and
1164            those have their type already set; we don't want to
1165            overwrite that */
1166         if (term->type == NULL) {
1167             term->type = value_type(v);
1168         }
1169     }
1170     return 1;
1171 }
1172 
/* Typecheck the expression TERM in context CTX and store its type in
 * TERM->type.
 *
 * A term that already has a type is not checked again unless it is a
 * value (A_VALUE). The assertion only lets already-typed non-value terms
 * through when they are shared (ref > 1).
 *
 * A successfully checked A_LET term is rewritten in place into an A_APP
 * term.
 *
 * Return 1 if TERM passes, 0 otherwise */
static int check_exp(struct term *term, struct ctx *ctx) {
    int result = 1;
    assert(term->type == NULL || term->tag == A_VALUE || term->ref > 1);
    if (term->type != NULL && term->tag != A_VALUE)
        return 1;

    switch (term->tag) {
    case A_UNION:
        /* Both operands must be regexps or lenses */
        result = check_binop("union", term, ctx, 2, t_regexp, t_lens);
        break;
    case A_MINUS:
        /* Only regexps can be subtracted */
        result = check_binop("minus", term, ctx, 1, t_regexp);
        break;
    case A_COMPOSE:
        result = check_compose(term, ctx);
        break;
    case A_CONCAT:
        result = check_binop("concatenation", term, ctx,
                             4, t_string, t_regexp, t_lens, t_filter);
        break;
    case A_LET:
        {
            /* A let is represented as the function TERM->left applied to
               the bound expression TERM->right. Check the bound
               expression first, hand its type to the function's
               parameter, then check the function itself. */
            result = check_exp(term->right, ctx);
            if (result) {
                struct term *func = term->left;
                assert(func->tag == A_FUNC);
                assert(func->param->type == NULL);
                func->param->type = ref(term->right->type);

                result = check_exp(func, ctx);
                if (result) {
                    /* From here on the let is a plain application */
                    term->tag = A_APP;
                    term->type = ref(func->type->img);
                }
            }
        }
        break;
    case A_APP:
        /* Bitwise '&' (not '&&') so that both sides are checked, and
           report their errors, even when the left side fails */
        result = check_exp(term->left, ctx) & check_exp(term->right, ctx);
        if (result) {
            if (term->left->type->tag != T_ARROW) {
                type_error1(term->info,
                            "expected function in application but found %s",
                            term->left->type);
                result = 0;
            };
        }
        if (result) {
            /* The argument type must be a subtype of the function's
               domain */
            result = expect_types(term->info,
                                  term->right->type,
                                  1, term->left->type->dom) != NULL;
            if (! result) {
                type_error_binop(term->info, "application",
                                 term->left->type, term->right->type);
                result = 0;
            }
        }
        if (result)
            term->type = ref(term->left->type->img);
        break;
    case A_VALUE:
        result = check_value(term);
        break;
    case A_IDENT:
        {
            /* Use the type bound to the identifier in CTX */
            struct type *t = ctx_lookup_type(term->info, ctx, term->ident);
            if (t == NULL) {
                syntax_error(term->info, "Undefined variable %s",
                             term->ident->str);
                result = 0;
            } else {
                term->type = ref(t);
            }
        }
        break;
    case A_BRACKET:
        /* [..] only applies to lenses; the result is a lens */
        result = check_exp(term->brexp, ctx);
        if (result) {
            term->type = ref(expect_types(term->info, term->brexp->type,
                                          1, t_lens));
            if (term->type == NULL) {
                type_error1(term->info,
                             "[..] is only defined for lenses, not for %s",
                            term->brexp->type);
                result = 0;
            }
        }
        break;
    case A_FUNC:
        {
            /* The parameter is bound in the local context only while the
               body is checked; the function type is param -> body */
            bind_param(&ctx->local, term->param, NULL);
            result = check_exp(term->body, ctx);
            if (result) {
                term->type =
                    make_arrow_type(term->param->type, term->body->type);
            }
            unbind_param(&ctx->local, term->param);
        }
        break;
    case A_REP:
        /* Repetition applies to regexps and lenses; the result has the
           allowed type the operand matched */
        result = check_exp(term->exp, ctx);
        if (result) {
            term->type = ref(expect_types(term->info, term->exp->type, 2,
                                          t_regexp, t_lens));
            if (term->type == NULL) {
                type_error1(term->info,
                            "Incompatible types: repetition is only defined"
                            " for regexp and lens, not for %s",
                            term->exp->type);
                result = 0;
            }
        }
        break;
    default:
        /* Any other tag is a programming error */
        assert(0);
        break;
    }
    assert(!result || term->type != NULL);
    return result;
}
1294 
check_decl(struct term * term,struct ctx * ctx)1295 static int check_decl(struct term *term, struct ctx *ctx) {
1296     assert(term->tag == A_BIND || term->tag == A_TEST);
1297 
1298     if (term->tag == A_BIND) {
1299         if (!check_exp(term->exp, ctx))
1300             return 0;
1301         term->type = ref(term->exp->type);
1302 
1303         if (bnd_lookup(ctx->local, term->bname) != NULL) {
1304             syntax_error(term->info,
1305                          "the name %s is already defined", term->bname);
1306             return 0;
1307         }
1308         bind_type(&ctx->local, term->bname, term->type);
1309     } else if (term->tag == A_TEST) {
1310         if (!check_exp(term->test, ctx))
1311             return 0;
1312         if (term->result != NULL) {
1313             if (!check_exp(term->result, ctx))
1314                 return 0;
1315             if (! type_equal(term->test->type, term->result->type)) {
1316                 type_error2(term->info,
1317                             "expected test result of type %s but got %s",
1318                             term->result->type, term->test->type);
1319                 return 0;
1320             }
1321         } else {
1322             if (expect_types(term->info, term->test->type, 2,
1323                              t_string, t_tree) == NULL)
1324                 return 0;
1325         }
1326         term->type = ref(term->test->type);
1327     } else {
1328         assert(0);
1329     }
1330     return 1;
1331 }
1332 
typecheck(struct term * term,struct augeas * aug)1333 static int typecheck(struct term *term, struct augeas *aug) {
1334     int ok = 1;
1335     struct ctx ctx;
1336     char *fname;
1337     const char *basenam;
1338 
1339     assert(term->tag == A_MODULE);
1340 
1341     /* Check that the module name is consistent with the filename */
1342     fname = module_basename(term->mname);
1343 
1344     basenam = strrchr(term->info->filename->str, SEP);
1345     if (basenam == NULL)
1346         basenam = term->info->filename->str;
1347     else
1348         basenam += 1;
1349     if (STRNEQ(fname, basenam)) {
1350         syntax_error(term->info,
1351                      "The module %s must be in a file named %s",
1352                      term->mname, fname);
1353         free(fname);
1354         return 0;
1355     }
1356     free(fname);
1357 
1358     ctx.aug = aug;
1359     ctx.local = NULL;
1360     ctx.name = term->mname;
1361     list_for_each(dcl, term->decls) {
1362         ok &= check_decl(dcl, &ctx);
1363     }
1364     unref(ctx.local, binding);
1365     return ok;
1366 }
1367 
1368 static struct value *compile_exp(struct info *, struct term *, struct ctx *);
1369 
compile_union(struct term * exp,struct ctx * ctx)1370 static struct value *compile_union(struct term *exp, struct ctx *ctx) {
1371     struct value *v1 = compile_exp(exp->info, exp->left, ctx);
1372     if (EXN(v1))
1373         return v1;
1374     struct value *v2 = compile_exp(exp->info, exp->right, ctx);
1375     if (EXN(v2)) {
1376         unref(v1, value);
1377         return v2;
1378     }
1379 
1380     struct type *t = exp->type;
1381     struct info *info = exp->info;
1382     struct value *v = NULL;
1383 
1384     v1 = coerce(v1, t);
1385     if (EXN(v1))
1386         return v1;
1387     v2 = coerce(v2, t);
1388     if (EXN(v2)) {
1389         unref(v1, value);
1390         return v2;
1391     }
1392 
1393     if (t->tag == T_REGEXP) {
1394         v = make_value(V_REGEXP, ref(info));
1395         v->regexp = regexp_union(info, v1->regexp, v2->regexp);
1396     } else if (t->tag == T_LENS) {
1397         struct lens *l1 = v1->lens;
1398         struct lens *l2 = v2->lens;
1399         v = lns_make_union(ref(info), ref(l1), ref(l2), LNS_TYPE_CHECK(ctx));
1400     } else {
1401         fatal_error(info, "Tried to union a %s and a %s to yield a %s",
1402                     type_name(exp->left->type), type_name(exp->right->type),
1403                     type_name(t));
1404     }
1405     unref(v1, value);
1406     unref(v2, value);
1407     return v;
1408 }
1409 
compile_minus(struct term * exp,struct ctx * ctx)1410 static struct value *compile_minus(struct term *exp, struct ctx *ctx) {
1411     struct value *v1 = compile_exp(exp->info, exp->left, ctx);
1412     if (EXN(v1))
1413         return v1;
1414     struct value *v2 = compile_exp(exp->info, exp->right, ctx);
1415     if (EXN(v2)) {
1416         unref(v1, value);
1417         return v2;
1418     }
1419 
1420     struct type *t = exp->type;
1421     struct info *info = exp->info;
1422     struct value *v;
1423 
1424     v1 = coerce(v1, t);
1425     v2 = coerce(v2, t);
1426     if (t->tag == T_REGEXP) {
1427         struct regexp *re1 = v1->regexp;
1428         struct regexp *re2 = v2->regexp;
1429         struct regexp *re = regexp_minus(info, re1, re2);
1430         if (re == NULL) {
1431             v = make_exn_value(ref(info),
1432                    "Regular expression subtraction 'r1 - r2' failed");
1433             exn_printf_line(v, "r1: /%s/", re1->pattern->str);
1434             exn_printf_line(v, "r2: /%s/", re2->pattern->str);
1435         } else {
1436             v = make_value(V_REGEXP, ref(info));
1437             v->regexp = re;
1438         }
1439     } else {
1440         v = NULL;
1441         fatal_error(info, "Tried to subtract a %s and a %s to yield a %s",
1442                     type_name(exp->left->type), type_name(exp->right->type),
1443                     type_name(t));
1444     }
1445     unref(v1, value);
1446     unref(v2, value);
1447     return v;
1448 }
1449 
compile_compose(struct term * exp,struct ctx * ctx)1450 static struct value *compile_compose(struct term *exp, struct ctx *ctx) {
1451     struct info *info = exp->info;
1452     struct value *v;
1453 
1454     if (exp->left->type->tag == T_ARROW) {
1455         // FIXME: This is really crufty, and should be desugared in the
1456         // parser so that we don't have to do all this manual type
1457         // computation. Should we write function compostion as
1458         // concatenation instead of using a separate syntax ?
1459 
1460         /* Build lambda x: exp->right (exp->left x) as a closure */
1461         char *var = strdup("@0");
1462         struct term *func = make_param(var, ref(exp->left->type->dom),
1463                                        ref(info));
1464         func->type = make_arrow_type(exp->left->type->dom,
1465                                      exp->right->type->img);
1466         struct term *ident = make_term(A_IDENT, ref(info));
1467         ident->ident = ref(func->param->name);
1468         ident->type = ref(func->param->type);
1469         struct term *app = make_app_term(ref(exp->left), ident, ref(info));
1470         app->type = ref(app->left->type->img);
1471         app = make_app_term(ref(exp->right), app, ref(info));
1472         app->type = ref(app->right->type->img);
1473 
1474         build_func(func, app);
1475 
1476         if (!type_equal(func->type, exp->type)) {
1477             char *f = type_string(func->type);
1478             char *e = type_string(exp->type);
1479             fatal_error(info,
1480               "Composition has type %s but should have type %s", f, e);
1481             free(f);
1482             free(e);
1483             unref(func, term);
1484             return info->error->exn;
1485         }
1486         v = make_closure(func, ctx->local);
1487         unref(func, term);
1488     } else {
1489         v = compile_exp(exp->info, exp->left, ctx);
1490         unref(v, value);
1491         v = compile_exp(exp->info, exp->right, ctx);
1492     }
1493     return v;
1494 }
1495 
compile_concat(struct term * exp,struct ctx * ctx)1496 static struct value *compile_concat(struct term *exp, struct ctx *ctx) {
1497     struct value *v1 = compile_exp(exp->info, exp->left, ctx);
1498     if (EXN(v1))
1499         return v1;
1500     struct value *v2 = compile_exp(exp->info, exp->right, ctx);
1501     if (EXN(v2)) {
1502         unref(v1, value);
1503         return v2;
1504     }
1505 
1506     struct type *t = exp->type;
1507     struct info *info = exp->info;
1508     struct value *v;
1509 
1510     v1 = coerce(v1, t);
1511     v2 = coerce(v2, t);
1512     if (t->tag == T_STRING) {
1513         const char *s1 = v1->string->str;
1514         const char *s2 = v2->string->str;
1515         v = make_value(V_STRING, ref(info));
1516         make_ref(v->string);
1517         if (ALLOC_N(v->string->str, strlen(s1) + strlen(s2) + 1) < 0)
1518             goto error;
1519         char *s = v->string->str;
1520         strcpy(s, s1);
1521         strcat(s, s2);
1522     } else if (t->tag == T_REGEXP) {
1523         v = make_value(V_REGEXP, ref(info));
1524         v->regexp = regexp_concat(info, v1->regexp, v2->regexp);
1525     } else if (t->tag == T_FILTER) {
1526         struct filter *f1 = v1->filter;
1527         struct filter *f2 = v2->filter;
1528         v = make_value(V_FILTER, ref(info));
1529         if (v2->ref == 1 && f2->ref == 1) {
1530             list_append(f2, ref(f1));
1531             v->filter = ref(f2);
1532         } else if (v1->ref == 1 && f1->ref == 1) {
1533             list_append(f1, ref(f2));
1534             v->filter = ref(f1);
1535         } else {
1536             struct filter *cf1, *cf2;
1537             cf1 = make_filter(ref(f1->glob), f1->include);
1538             cf2 = make_filter(ref(f2->glob), f2->include);
1539             cf1->next = ref(f1->next);
1540             cf2->next = ref(f2->next);
1541             list_append(cf1, cf2);
1542             v->filter = cf1;
1543         }
1544     } else if (t->tag == T_LENS) {
1545         struct lens *l1 = v1->lens;
1546         struct lens *l2 = v2->lens;
1547         v = lns_make_concat(ref(info), ref(l1), ref(l2), LNS_TYPE_CHECK(ctx));
1548     } else {
1549         v = NULL;
1550         fatal_error(info, "Tried to concat a %s and a %s to yield a %s",
1551                     type_name(exp->left->type), type_name(exp->right->type),
1552                     type_name(t));
1553     }
1554     unref(v1, value);
1555     unref(v2, value);
1556     return v;
1557  error:
1558     return exp->info->error->exn;
1559 }
1560 
apply(struct term * app,struct ctx * ctx)1561 static struct value *apply(struct term *app, struct ctx *ctx) {
1562     struct value *f = compile_exp(app->info, app->left, ctx);
1563     struct value *result = NULL;
1564     struct ctx lctx;
1565 
1566     if (EXN(f))
1567         return f;
1568 
1569     struct value *arg = compile_exp(app->info, app->right, ctx);
1570     if (EXN(arg)) {
1571         unref(f, value);
1572         return arg;
1573     }
1574 
1575     assert(f->tag == V_CLOS);
1576 
1577     lctx.aug = ctx->aug;
1578     lctx.local = ref(f->bindings);
1579     lctx.name = ctx->name;
1580 
1581     arg = coerce(arg, f->func->param->type);
1582     if (arg == NULL)
1583         goto done;
1584 
1585     bind_param(&lctx.local, f->func->param, arg);
1586     result = compile_exp(app->info, f->func->body, &lctx);
1587     unref(result->info, info);
1588     result->info = ref(app->info);
1589     unbind_param(&lctx.local, f->func->param);
1590 
1591  done:
1592     unref(lctx.local, binding);
1593     unref(arg, value);
1594     unref(f, value);
1595     return result;
1596 }
1597 
compile_bracket(struct term * exp,struct ctx * ctx)1598 static struct value *compile_bracket(struct term *exp, struct ctx *ctx) {
1599     struct value *arg = compile_exp(exp->info, exp->brexp, ctx);
1600     if (EXN(arg))
1601         return arg;
1602     assert(arg->tag == V_LENS);
1603 
1604     struct value *v = lns_make_subtree(ref(exp->info), ref(arg->lens));
1605     unref(arg, value);
1606 
1607     return v;
1608 }
1609 
compile_rep(struct term * rep,struct ctx * ctx)1610 static struct value *compile_rep(struct term *rep, struct ctx *ctx) {
1611     struct value *arg = compile_exp(rep->info, rep->rexp, ctx);
1612     struct value *v = NULL;
1613 
1614     if (EXN(arg))
1615         return arg;
1616 
1617     arg = coerce(arg, rep->type);
1618     if (rep->type->tag == T_REGEXP) {
1619         int min, max;
1620         if (rep->quant == Q_STAR) {
1621             min = 0; max = -1;
1622         } else if (rep->quant == Q_PLUS) {
1623             min = 1; max = -1;
1624         } else if (rep->quant == Q_MAYBE) {
1625             min = 0; max = 1;
1626         } else {
1627             assert(0);
1628             abort();
1629         }
1630         v = make_value(V_REGEXP, ref(rep->info));
1631         v->regexp = regexp_iter(rep->info, arg->regexp, min, max);
1632     } else if (rep->type->tag == T_LENS) {
1633         int c = LNS_TYPE_CHECK(ctx);
1634         if (rep->quant == Q_STAR) {
1635             v = lns_make_star(ref(rep->info), ref(arg->lens), c);
1636         } else if (rep->quant == Q_PLUS) {
1637             v = lns_make_plus(ref(rep->info), ref(arg->lens), c);
1638         } else if (rep->quant == Q_MAYBE) {
1639             v = lns_make_maybe(ref(rep->info), ref(arg->lens), c);
1640         } else {
1641             assert(0);
1642         }
1643     } else {
1644         fatal_error(rep->info, "Tried to repeat a %s to yield a %s",
1645                     type_name(rep->rexp->type), type_name(rep->type));
1646     }
1647     unref(arg, value);
1648     return v;
1649 }
1650 
compile_exp(struct info * info,struct term * exp,struct ctx * ctx)1651 static struct value *compile_exp(struct info *info,
1652                                  struct term *exp, struct ctx *ctx) {
1653     struct value *v = NULL;
1654 
1655     switch (exp->tag) {
1656     case A_COMPOSE:
1657         v = compile_compose(exp, ctx);
1658         break;
1659     case A_UNION:
1660         v = compile_union(exp, ctx);
1661         break;
1662     case A_MINUS:
1663         v = compile_minus(exp, ctx);
1664         break;
1665     case A_CONCAT:
1666         v = compile_concat(exp, ctx);
1667         break;
1668     case A_APP:
1669         v = apply(exp, ctx);
1670         break;
1671     case A_VALUE:
1672         if (exp->value->tag == V_NATIVE) {
1673             v = native_call(info, exp->value->native, ctx);
1674         } else {
1675             v = ref(exp->value);
1676         }
1677         break;
1678     case A_IDENT:
1679         v = ref(ctx_lookup(exp->info, ctx, exp->ident));
1680         break;
1681     case A_BRACKET:
1682         v = compile_bracket(exp, ctx);
1683         break;
1684     case A_FUNC:
1685         v = make_closure(exp, ctx->local);
1686         break;
1687     case A_REP:
1688         v = compile_rep(exp, ctx);
1689         break;
1690     default:
1691         assert(0);
1692         break;
1693     }
1694 
1695     return v;
1696 }
1697 
compile_test(struct term * term,struct ctx * ctx)1698 static int compile_test(struct term *term, struct ctx *ctx) {
1699     struct value *actual = compile_exp(term->info, term->test, ctx);
1700     struct value *expect = NULL;
1701     int ret = 1;
1702 
1703     if (term->tr_tag == TR_EXN) {
1704         if (!EXN(actual)) {
1705             print_info(stdout, term->info);
1706             printf("Test run should have produced exception, but produced\n");
1707             print_value(stdout, actual);
1708             printf("\n");
1709             ret = 0;
1710         }
1711     } else {
1712         if (EXN(actual)) {
1713             print_info(stdout, term->info);
1714             printf("exception thrown in test\n");
1715             print_value(stdout, actual);
1716             printf("\n");
1717             ret = 0;
1718         } else if (term->tr_tag == TR_CHECK) {
1719             expect = compile_exp(term->info, term->result, ctx);
1720             if (EXN(expect))
1721                 goto done;
1722             if (! value_equal(actual, expect)) {
1723                 printf("Test failure:");
1724                 print_info(stdout, term->info);
1725                 printf("\n");
1726                 printf(" Expected:\n");
1727                 print_value(stdout, expect);
1728                 printf("\n");
1729                 printf(" Actual:\n");
1730                 print_value(stdout, actual);
1731                 printf("\n");
1732                 ret = 0;
1733             }
1734         } else {
1735             printf("Test result: ");
1736             print_info(stdout, term->info);
1737             printf("\n");
1738             if (actual->tag == V_TREE) {
1739                 print_tree_braces(stdout, 2, actual->origin->children);
1740             } else {
1741                 print_value(stdout, actual);
1742             }
1743             printf("\n");
1744         }
1745     }
1746  done:
1747     reset_error(term->info->error);
1748     unref(actual, value);
1749     unref(expect, value);
1750     return ret;
1751 }
1752 
compile_decl(struct term * term,struct ctx * ctx)1753 static int compile_decl(struct term *term, struct ctx *ctx) {
1754     if (term->tag == A_BIND) {
1755         int result;
1756 
1757         struct value *v = compile_exp(term->info, term->exp, ctx);
1758         bind(&ctx->local, term->bname, term->type, v);
1759 
1760         if (EXN(v) && !v->exn->seen) {
1761             struct error *error = term->info->error;
1762             struct memstream ms;
1763 
1764             init_memstream(&ms);
1765 
1766             syntax_error(term->info, "Failed to compile %s",
1767                          term->bname);
1768             fprintf(ms.stream, "%s\n", error->details);
1769             print_value(ms.stream, v);
1770             close_memstream(&ms);
1771 
1772             v->exn->seen = 1;
1773             free(error->details);
1774             error->details = ms.buf;
1775         }
1776         result = !(EXN(v) || HAS_ERR(ctx->aug));
1777         unref(v, value);
1778         return result;
1779     } else if (term->tag == A_TEST) {
1780         return compile_test(term, ctx);
1781     }
1782     assert(0);
1783     abort();
1784 }
1785 
compile(struct term * term,struct augeas * aug)1786 static struct module *compile(struct term *term, struct augeas *aug) {
1787     struct ctx ctx;
1788     struct transform *autoload = NULL;
1789     assert(term->tag == A_MODULE);
1790 
1791     ctx.aug = aug;
1792     ctx.local = NULL;
1793     ctx.name = term->mname;
1794     list_for_each(dcl, term->decls) {
1795         if (!compile_decl(dcl, &ctx))
1796             goto error;
1797     }
1798 
1799     if (term->autoload != NULL) {
1800         struct binding *bnd = bnd_lookup(ctx.local, term->autoload);
1801         if (bnd == NULL) {
1802             syntax_error(term->info, "Undefined transform in autoload %s",
1803                          term->autoload);
1804             goto error;
1805         }
1806         if (expect_types(term->info, bnd->type, 1, t_transform) == NULL)
1807             goto error;
1808         autoload = bnd->value->transform;
1809     }
1810     struct module *module = module_create(term->mname);
1811     module->bindings = ctx.local;
1812     module->autoload = ref(autoload);
1813     return module;
1814  error:
1815     unref(ctx.local, binding);
1816     return NULL;
1817 }
1818 
1819 /*
1820  * Defining native functions
1821  */
1822 static struct info *
make_native_info(struct error * error,const char * fname,int line)1823 make_native_info(struct error *error, const char *fname, int line) {
1824     struct info *info;
1825     if (make_ref(info) < 0)
1826         goto error;
1827     info->first_line = info->last_line = line;
1828     info->first_column = info->last_column = 0;
1829     info->error = error;
1830     if (make_ref(info->filename) < 0)
1831         goto error;
1832     info->filename->str = strdup(fname);
1833     return info;
1834  error:
1835     unref(info, info);
1836     return NULL;
1837 }
1838 
/*
 * Define the native (C-implemented) function NAME in MODULE.
 *
 * The variadic arguments after IMPL are ARGC (1..5) enum type_tag values
 * giving the parameter types, followed by one more enum type_tag giving
 * the return type.  FILE and LINE locate the C definition; together with
 * ERROR they form the info attached to every generated term.
 *
 * The function is built as a lambda over parameters named "@0", "@1", ...
 * whose body is a V_NATIVE value wrapping IMPL.  It is typechecked against
 * MODULE's current bindings and bound under NAME in MODULE->BINDINGS.  A
 * typecheck failure is fatal.
 *
 * Returns 0 on success and -1 if one of the allocations checked here
 * fails.  On failure MODULE->BINDINGS is left unchanged.
 */
int define_native_intl(const char *file, int line,
                       struct error *error,
                       struct module *module, const char *name,
                       int argc, func_impl impl, ...) {
    assert(argc > 0);  /* We have no unit type */
    assert(argc <= 5);
    va_list ap;
    enum type_tag tag;
    struct term *params = NULL, *body = NULL, *func = NULL;
    struct type *type;
    struct value *v = NULL;
    struct info *info = NULL;
    struct ctx ctx;

    /* We hold one reference to INFO for the whole function; every term
     * and value that needs it takes its own reference */
    info = make_native_info(error, file, line);
    if (info == NULL)
        goto error;

    va_start(ap, impl);
    for (int i=0; i < argc; i++) {
        struct term *pterm;
        char ident[10];
        tag = va_arg(ap, enum type_tag);
        type = make_base_type(tag);
        snprintf(ident, sizeof(ident), "@%d", i);
        pterm = make_param(strdup(ident), type, ref(info));
        list_append(params, pterm);
    }
    /* The last variadic argument is the return type */
    tag = va_arg(ap, enum type_tag);
    va_end(ap);

    type = make_base_type(tag);

    make_ref(v);
    if (v == NULL)
        goto error;
    v->tag = V_NATIVE;
    v->info = ref(info);

    if (ALLOC(v->native) < 0)
        goto error;
    v->native->argc = argc;
    v->native->type = type;
    v->native->impl = impl;

    make_ref(body);
    if (body == NULL)
        goto error;
    body->info = ref(info);
    body->type = ref(type);
    body->tag = A_VALUE;
    body->value = v;
    v = NULL;

    func = build_func(params, body);
    params = NULL;
    body = NULL;

    ctx.aug = NULL;
    ctx.local = ref(module->bindings);
    ctx.name = module->name;
    if (! check_exp(func, &ctx)) {
        fatal_error(info, "Typechecking native %s failed",
                    name);
        abort();
    }
    v = make_closure(func, ctx.local);
    if (v == NULL) {
        /* Drop only the reference taken for CTX; MODULE->BINDINGS must
         * stay intact */
        unref(ctx.local, binding);
        goto error;
    }
    bind(&ctx.local, name, func->type, v);
    unref(v, value);
    unref(func, term);
    unref(info, info);
    unref(module->bindings, binding);

    module->bindings = ctx.local;
    return 0;
 error:
    /* Release the parameter list one node at a time.  Save and detach
     * NEXT before releasing a node so that nothing is read from it
     * afterwards */
    while (params != NULL) {
        struct term *next = params->next;
        params->next = NULL;
        unref(params, term);
        params = next;
    }
    unref(v, value);
    unref(body, term);
    unref(func, term);
    unref(info, info);
    return -1;
}
1927 
1928 
/* Defined in parser.y.  Parses the module source in the file NAME; the
 * parsed term is returned through TERM (presumably left NULL on failure -
 * confirm in parser.y).  load_module_file ignores the return value and
 * checks the error state of AUG instead. */
int augl_parse_file(struct augeas *aug, const char *name, struct term **term);
1931 
module_basename(const char * modname)1932 static char *module_basename(const char *modname) {
1933     char *fname;
1934 
1935     if (asprintf(&fname, "%s" AUG_EXT, modname) == -1)
1936         return NULL;
1937     for (int i=0; i < strlen(modname); i++)
1938         fname[i] = tolower(fname[i]);
1939     return fname;
1940 }
1941 
module_filename(struct augeas * aug,const char * modname)1942 static char *module_filename(struct augeas *aug, const char *modname) {
1943     char *dir = NULL;
1944     char *filename = NULL;
1945     char *name = module_basename(modname);
1946 
1947     /* Module names that contain slashes can fool us into finding and
1948      * loading a module in another directory, but once loaded we won't find
1949      * it under MODNAME so that we will later try and load it over and
1950      * over */
1951     if (index(modname, '/') != NULL)
1952         goto error;
1953 
1954     while ((dir = argz_next(aug->modpathz, aug->nmodpath, dir)) != NULL) {
1955         int len = strlen(name) + strlen(dir) + 2;
1956         struct stat st;
1957 
1958         if (REALLOC_N(filename, len) == -1)
1959             goto error;
1960         sprintf(filename, "%s/%s", dir, name);
1961         if (stat(filename, &st) == 0)
1962             goto done;
1963     }
1964  error:
1965     FREE(filename);
1966  done:
1967     free(name);
1968     return filename;
1969 }
1970 
/*
 * Parse, typecheck and compile the module source in FILENAME, then append
 * the resulting module to AUG->MODULES.
 *
 * If compilation fails and NAME is not NULL, an empty placeholder module
 * called NAME is appended instead.  lens_release is called on every lens
 * value bound in the appended module.  When AUG_TRACE_MODULE_LOADING is
 * set, progress is printed to stdout.
 *
 * Returns 0 on success and -1 on failure.  A compilation failure is raised
 * as AUG_ESYNTAX ("Failed to load ...") after the placeholder, if any, has
 * been appended.
 */
int load_module_file(struct augeas *aug, const char *filename,
                     const char *name) {
    struct term *term = NULL;
    int result = -1;

    if (aug->flags & AUG_TRACE_MODULE_LOADING)
        printf("Module %s", filename);
    /* The return value is not used; parse failures are presumably
     * detected through the error state of AUG checked by ERR_BAIL below */
    augl_parse_file(aug, filename, &term);
    if (aug->flags & AUG_TRACE_MODULE_LOADING)
        printf(HAS_ERR(aug) ? " failed\n" : " loaded\n");
    ERR_BAIL(aug);

    if (! typecheck(term, aug))
        goto error;

    struct module *module = compile(term, aug);
    bool bad_module = (module == NULL);
    if (bad_module && name != NULL) {
        /* Put an empty placeholder on the module list so that
         * we don't retry loading this module every time it's mentioned
         */
        module = module_create(name);
    }
    if (module != NULL) {
        list_append(aug->modules, module);
        /* Call lens_release on every lens bound in the module
         * NOTE(review): see lens_release for what it actually releases */
        list_for_each(bnd, module->bindings) {
            if (bnd->value->tag == V_LENS) {
                lens_release(bnd->value->lens);
            }
        }
    }
    /* Still a failure for the caller even if a placeholder was added */
    ERR_THROW(bad_module, aug, AUG_ESYNTAX, "Failed to load %s", filename);

    result = 0;
 error:
    // FIXME: This leads to a bad free of a string used in a del lens
    // To reproduce run lenses/tests/test_yum.aug
    unref(term, term);
    return result;
}
2011 
load_module(struct augeas * aug,const char * name)2012 static int load_module(struct augeas *aug, const char *name) {
2013     char *filename = NULL;
2014 
2015     if (module_find(aug->modules, name) != NULL)
2016         return 0;
2017 
2018     if ((filename = module_filename(aug, name)) == NULL)
2019         return -1;
2020 
2021     if (load_module_file(aug, filename, name) == -1)
2022         goto error;
2023 
2024     free(filename);
2025     return 0;
2026 
2027  error:
2028     free(filename);
2029     return -1;
2030 }
2031 
interpreter_init(struct augeas * aug)2032 int interpreter_init(struct augeas *aug) {
2033     int r;
2034 
2035     r = init_fatal_exn(aug->error);
2036     if (r < 0)
2037         return -1;
2038 
2039     aug->modules = builtin_init(aug->error);
2040     if (aug->flags & AUG_NO_MODL_AUTOLOAD)
2041         return 0;
2042 
2043     // For now, we just load every file on the search path
2044     const char *dir = NULL;
2045     glob_t globbuf;
2046     int gl_flags = GLOB_NOSORT;
2047 
2048     MEMZERO(&globbuf, 1);
2049 
2050     while ((dir = argz_next(aug->modpathz, aug->nmodpath, dir)) != NULL) {
2051         char *globpat;
2052         r = asprintf(&globpat, "%s/*.aug", dir);
2053         ERR_NOMEM(r < 0, aug);
2054 
2055         r = glob(globpat, gl_flags, NULL, &globbuf);
2056         if (r != 0 && r != GLOB_NOMATCH) {
2057             /* This really has to be an allocation failure; glob is not
2058              * supposed to return GLOB_ABORTED here */
2059             aug_errcode_t code =
2060                 r == GLOB_NOSPACE ? AUG_ENOMEM : AUG_EINTERNAL;
2061             ERR_REPORT(aug, code, "glob failure for %s", globpat);
2062             free(globpat);
2063             goto error;
2064         }
2065         gl_flags |= GLOB_APPEND;
2066         free(globpat);
2067     }
2068 
2069     for (int i=0; i < globbuf.gl_pathc; i++) {
2070         char *name, *p, *q;
2071         int res;
2072         p = strrchr(globbuf.gl_pathv[i], SEP);
2073         if (p == NULL)
2074             p = globbuf.gl_pathv[i];
2075         else
2076             p += 1;
2077         q = strchr(p, '.');
2078         name = strndup(p, q - p);
2079         name[0] = toupper(name[0]);
2080         res = load_module(aug, name);
2081         free(name);
2082         if (res == -1)
2083             goto error;
2084     }
2085     globfree(&globbuf);
2086     return 0;
2087  error:
2088     globfree(&globbuf);
2089     return -1;
2090 }
2091 
2092 /*
2093  * Local variables:
2094  *  indent-tabs-mode: nil
2095  *  c-indent-level: 4
2096  *  c-basic-offset: 4
2097  *  tab-width: 4
2098  * End:
2099  */
2100