1 // Copyright 2016-2021 Doug Moen
2 // Licensed under the Apache License, version 2.0
3 // See accompanying file LICENSE or https://www.apache.org/licenses/LICENSE-2.0
4 
5 #ifndef LIBCURV_PRIM_H
6 #define LIBCURV_PRIM_H
7 
8 // Curv is an array language, following APL and its successors.
9 // This means that scalar operations (on numbers and booleans)
10 // are generalized to operate on arrays of scalars, in two ways:
11 // element-wise operation, and broadcasting.
12 //   2 + 2 == 4                  -- a scalar operation
13 //   [3,4] + [10,20] == [13,24]  -- element-wise addition
14 //   1 + [10,20] == [11,21]      -- broadcasting
15 
16 #include <libcurv/bool.h>
17 #include <libcurv/context.h>
18 #include <libcurv/exception.h>
19 #include <libcurv/list.h>
20 #include <libcurv/meaning.h>
21 #include <libcurv/num.h>
22 #include <libcurv/reactive.h>
23 #include <libcurv/sc_compiler.h>
24 #include <libcurv/sc_context.h>
25 #include <libcurv/vec.h>
26 
27 namespace curv {
28 
29 struct Binary_Op
30 {
sc_callBinary_Op31     static SC_Value sc_call(
32         SC_Frame& fm, const Operation& a, const Operation& b,
33         Shared<const Phrase> syntax)
34     {
35         throw Exception(At_SC_Phrase(syntax, fm),
36             "operation not supported");
37     }
38 };
39 
40 //----------------------------------------------------------------------//
41 // Templates for converting an Array_Op to a unary or binary Expression //
42 //----------------------------------------------------------------------//
43 
44 template <class Op>
45 struct Unary_Op_Expr : public Prefix_Expr_Base
46 {
47     using Prefix_Expr_Base::Prefix_Expr_Base;
evalUnary_Op_Expr48     virtual Value eval(Frame& fm) const override
49       { return Op::call(Fail::hard, At_Phrase(*syntax_, fm), arg_->eval(fm)); }
sc_evalUnary_Op_Expr50     virtual SC_Value sc_eval(SC_Frame& fm) const override
51       { return Op::sc_op(At_SC_Phrase(syntax_,fm), *arg_, fm); }
printUnary_Op_Expr52     virtual void print(std::ostream& out) const override
53       { out<<Op::Prim::name()<<"("<<*arg_<<")"; }
54 };
55 
56 template <class Op>
57 struct Binary_Op_Expr : public Infix_Expr_Base
58 {
59     using Infix_Expr_Base::Infix_Expr_Base;
evalBinary_Op_Expr60     virtual Value eval(Frame& fm) const override {
61         return Op::call(Fail::hard, At_Phrase(*syntax_, fm),
62             arg1_->eval(fm), arg2_->eval(fm));
63     }
sc_evalBinary_Op_Expr64     virtual SC_Value sc_eval(SC_Frame& fm) const override
65       { return Op::sc_call(fm, *arg1_, *arg2_, syntax_); }
idchrBinary_Op_Expr66     static bool idchr(char c)
67         { return (c>='a'&&c<='z')||(c>='A'&&c<='Z')||c=='_'; }
printBinary_Op_Expr68     virtual void print(std::ostream& out) const override {
69         if (idchr(Op::Prim::name()[0]))
70             out<<Op::Prim::name()<<"["<<*arg1_<<","<<*arg2_<<"]";
71         else
72             out<<"("<<*arg1_<<Op::Prim::name()<<*arg2_<<")";
73     }
74 };
75 
76 //---------------------------------------------------------------//
77 // Templates for converting a Prim to a unary or binary Array_Op //
78 //---------------------------------------------------------------//
79 
80 template <class PRIM>
81 struct Unary_Array_Op
82 {
83     // TODO: optimize: move semantics. unique object reuse.
84 
85     using Prim = PRIM;
86 
87     static Value
callUnary_Array_Op88     call(Fail fl, const At_Syntax& cx, Value x)
89     {
90         typename Prim::scalar_t sx;
91         if (Prim::unbox(x, sx, cx)) {
92             Value r = Prim::call(sx, cx);
93             if (!r.is_missing()) return r;
94         } else if (x.is_ref()) {
95             Ref_Value& rx(x.to_ref_unsafe());
96             switch (rx.type_) {
97             case Ref_Value::ty_abstract_list:
98                 if (rx.subtype_ == Ref_Value::sty_list)
99                     return element_wise_op(fl, cx, (List&)rx);
100                 else
101                     break; // TODO strings are lists?
102             case Ref_Value::ty_reactive:
103                 return reactive_op(fl, cx, (Reactive_Value&)rx);
104             }
105         }
106         FAIL(fl, missing, cx, domain_error(x));
107     }
108     static SC_Value
sc_opUnary_Array_Op109     sc_op(const At_Syntax& cx, Operation& argx, SC_Frame& fm)
110     {
111         // TODO: add array support
112         auto a = sc_eval_op(fm, argx);
113         Prim::sc_check_arg(a, cx);
114         return Prim::sc_call(fm, a);
115     }
116 
117     static Value
element_wise_opUnary_Array_Op118     element_wise_op(Fail fl, const At_Syntax& cx, List& xs)
119     {
120         Shared<List> result = List::make(xs.size());
121         for (unsigned i = 0; i < xs.size(); ++i) {
122             TRY_DEF(r, call(fl, cx, xs[i]));
123             (*result)[i] = r;
124         }
125         return {result};
126     }
127 
128     // Argument x is reactive. Construct a Reactive_Expression.
129     static Value
reactive_opUnary_Array_Op130     reactive_op(Fail fl, const At_Syntax& cx, Reactive_Value &rx)
131     {
132         SC_Type rtype = Prim::sc_result_type(rx.sctype_);
133         if (rtype) {
134             return {make<Reactive_Expression>(
135                 rtype,
136                 make<Unary_Op_Expr<Unary_Array_Op>>(
137                     share(cx.syntax()), rx.expr()),
138                 cx)};
139         } else {
140             FAIL(fl, missing, cx, domain_error({share(rx)}));
141         }
142     }
143 
domain_errorUnary_Array_Op144     static Shared<const String> domain_error(Value x)
145     {
146         return stringify(x, ": domain error");
147     }
148 };
149 
150 template <class PRIM>
151 struct Binary_Array_Op
152 {
153     // TODO: optimize: move semantics. unique object reuse.
154     // TODO: optimize: faster fast path in `op` for number case.
155 
156     using Prim = PRIM;
157 
domain_errorBinary_Array_Op158     static Exception domain_error(
159         const At_Syntax& cx, unsigned i, Value x, Value y)
160     {
161         if (dynamic_cast<const At_Arg*>(&cx)) {
162             return Exception(At_Index(i,cx),
163                 stringify(i==0?x:y,": domain error"));
164         }
165         if (auto ap = dynamic_cast<const At_Phrase*>(&cx)) {
166             if (auto bin = cast<const Binary_Phrase>(ap->phrase_))
167                 return Exception(cx,
168                     stringify(x," ",bin->opname()," ",y,": domain error"));
169         }
170         return Exception(cx, stringify(i==0?x:y,": domain error"));
171     }
172 
domain_errorBinary_Array_Op173     static Exception domain_error(
174         const At_Syntax& cx, Value x, Value y)
175     {
176         if (auto ap = dynamic_cast<const At_Phrase*>(&cx)) {
177             if (auto bin = cast<const Binary_Phrase>(ap->phrase_))
178                 return Exception(cx,
179                     stringify(x," ",bin->opname()," ",y,": domain error"));
180         }
181         return Exception(cx, stringify("[",x,",",y,"]: domain error"));
182     }
183 
184     static Value
reduceBinary_Array_Op185     reduce(Fail fl, const At_Syntax& cx, Value zero, Value arg)
186     {
187         auto list = arg.to<List>(fl, cx);
188         if (list == nullptr) return missing;
189         unsigned n = list->size();
190         if (n == 0)
191             return {zero};
192         Value result = list->front();
193         for (unsigned i = 1; i < n; ++i) {
194             TRY_DEF(r, call(fl, cx, result, list->at(i)));
195             result = r;
196         }
197         return result;
198     }
199     static SC_Value
sc_reduceBinary_Array_Op200     sc_reduce(const At_Syntax& cx, Value zero, Operation& argx, SC_Frame& fm)
201     {
202         if (auto list = cast_list_expr(argx)) {
203             if (list->empty())
204                 return sc_eval_const(fm, zero, *argx.syntax_);
205             auto first = sc_eval_op(fm, *list->at(0));
206             if (list->size() == 1) {
207                 Prim::sc_check_arg(first, cx);
208                 return first;
209             }
210             for (unsigned i = 1; i < list->size(); ++i) {
211                 auto second = sc_eval_op(fm, *list->at(i));
212                 Prim::sc_check_args(fm, first, second, cx);
213                 first = Prim::sc_call(fm, first, second);
214             }
215             return first;
216         }
217         else {
218             // Reduce an array value that exists at GPU run time.
219             // TODO: For a large 1D array, use a GPU loop and call a function.
220             // TODO: Binary_Array_Op::sc_reduce: reduce a matrix.
221             // 2D arrays (SC_Type rank 2) are not supported, because you can't
222             // generate a rank 1 array at GPU runtime, for now at least.
223             // For a single Vec, this inline expansion of the loop is good.
224             auto arg = sc_eval_op(fm, argx);
225             if (arg.type.is_vec()) {
226                 auto first = sc_vec_element(fm, arg, 0);
227                 for (unsigned i = 1; i < arg.type.count(); ++i) {
228                     auto second = sc_vec_element(fm, arg, i);
229                     Prim::sc_check_args(fm, first, second, cx);
230                     first = Prim::sc_call(fm, first, second);
231                 }
232                 return first;
233             }
234             else {
235                 throw Exception(cx, "argument is not a vector");
236             }
237         }
238     }
239 
240     static Value
callBinary_Array_Op241     call(Fail fl, const At_Syntax& cx, Value x, Value y)
242     {
243         // fast path: both x and y are scalars
244         // remaining cases:
245         // - x is a scalar, y is a list
246         // - x is a list, y is a scalar
247         // - x and y are lists
248         // - either x, or y, or both, is reactive
249         typename Prim::left_t sx;
250         typename Prim::right_t sy;
251         if (Prim::unbox_left(x, sx, cx)) {
252             if (Prim::unbox_right(y, sy, cx)) {
253                 Value r = Prim::call(sx, sy, cx);
254                 if (!r.is_missing()) return r;
255             }
256             else if (y.is_ref()) {
257                 Ref_Value& ry(y.to_ref_unsafe());
258                 switch (ry.type_) {
259                 case Ref_Value::ty_abstract_list:
260                     if (ry.subtype_ == Ref_Value::sty_list)
261                         return broadcast_right(fl, cx, x, (List&)ry);
262                     else
263                         break; // TODO: strings are lists
264                 case Ref_Value::ty_reactive:
265                     return reactive_op(cx, x, y);
266                 }
267             }
268             throw domain_error(cx,x,y);
269         } else if (x.is_ref()) {
270             Ref_Value& rx(x.to_ref_unsafe());
271             switch (rx.type_) {
272             case Ref_Value::ty_abstract_list:
273                 if (Prim::unbox_right(y, sy, cx))
274                     return broadcast_left(fl, cx, (List&)rx, y);
275                 else if (rx.subtype_ == Ref_Value::sty_list && y.is_ref()) {
276                     Ref_Value& ry(y.to_ref_unsafe());
277                     switch (ry.type_) {
278                     case Ref_Value::ty_abstract_list:
279                         if (ry.subtype_ == Ref_Value::sty_list)
280                             return element_wise_op(fl, cx, (List&)rx, (List&)ry);
281                         else
282                             break; // TODO: strings are lists
283                     case Ref_Value::ty_reactive:
284                         return reactive_op(cx, x, y);
285                     }
286                 }
287                 throw domain_error(cx,1,x,y);
288             case Ref_Value::ty_reactive:
289                 return reactive_op(cx, x, y);
290             }
291         }
292         throw domain_error(cx,0,x,y);
293     }
294     static SC_Value
sc_opBinary_Array_Op295     sc_op(const At_Syntax& cx, Operation& argx, SC_Frame& fm)
296     {
297         auto list = cast_list_expr(argx);
298         if (list && list->size() == 2) {
299             auto first = sc_eval_op(fm, *list->at(0));
300             auto second = sc_eval_op(fm, *list->at(1));
301             Prim::sc_check_args(fm, first, second, cx);
302             return Prim::sc_call(fm, first, second);
303         }
304         // TODO: Binary_Array_Op::sc_op: accept a 2-vector, 2-array or mat2.
305         throw Exception(cx, "expected a list of size 2");
306     }
307     static SC_Value
sc_callBinary_Array_Op308     sc_call(SC_Frame& fm, Operation& ax, Operation& ay, Shared<const Phrase> ph)
309     {
310         auto x = sc_eval_op(fm, ax);
311         auto y = sc_eval_op(fm, ay);
312         Prim::sc_check_args(fm, x, y, At_SC_Phrase(ph, fm));
313         return Prim::sc_call(fm, x, y);
314     }
315 
316     static Value
broadcast_leftBinary_Array_Op317     broadcast_left(Fail fl, const At_Syntax& cx, List& xlist, Value y)
318     {
319         Shared<List> result = List::make(xlist.size());
320         for (unsigned i = 0; i < xlist.size(); ++i) {
321             TRY_DEF(r, call(fl, cx, xlist[i], y));
322             (*result)[i] = r;
323         }
324         return {result};
325     }
326 
327     static Value
broadcast_rightBinary_Array_Op328     broadcast_right(Fail fl, const At_Syntax& cx, Value x, List& ylist)
329     {
330         Shared<List> result = List::make(ylist.size());
331         for (unsigned i = 0; i < ylist.size(); ++i) {
332             TRY_DEF(r, call(fl, cx, x, ylist[i]));
333             (*result)[i] = r;
334         }
335         return {result};
336     }
337 
338     static Value
element_wise_opBinary_Array_Op339     element_wise_op(Fail fl, const At_Syntax& cx, List& xs, List& ys)
340     {
341         if (xs.size() != ys.size()) {
342             FAIL(fl, missing, cx, stringify(
343                 "mismatched list sizes (",
344                 xs.size(),",",ys.size(),") in array operation"));
345         }
346         Shared<List> result = List::make(xs.size());
347         for (unsigned i = 0; i < xs.size(); ++i) {
348             TRY_DEF(r, call(fl, cx, xs[i], ys[i]));
349             (*result)[i] = r;
350         }
351         return {result};
352     }
353 
354     // At least one of x and y is reactive. Construct a Reactive_Expression.
355     static Value
reactive_opBinary_Array_Op356     reactive_op(const At_Syntax& cx, Value x, Value y)
357     {
358         Shared<Operation> x_expr;
359         SC_Type x_type;
360         if (auto xr = x.maybe<Reactive_Value>()) {
361             x_expr = xr->expr();
362             x_type = xr->sctype_;
363         } else {
364             x_expr = make<Constant>(share(cx.syntax()), x);
365             x_type = sc_type_of(x);
366         }
367 
368         Shared<Operation> y_expr;
369         SC_Type y_type;
370         if (auto yr = y.maybe<Reactive_Value>()) {
371             y_expr = yr->expr();
372             y_type = yr->sctype_;
373         } else {
374             y_expr = make<Constant>(share(cx.syntax()), y);
375             y_type = sc_type_of(y);
376         }
377 
378         SC_Type rtype = Prim::sc_result_type(x_type, y_type);
379         if (rtype) {
380             return {make<Reactive_Expression>(
381                 rtype,
382                 make<Binary_Op_Expr<Binary_Array_Op>>(
383                     share(cx.syntax()), x_expr, y_expr),
384                 cx)};
385         } else {
386             throw domain_error(cx, x, y);
387         }
388     }
389 };
390 
391 //----------------------------------------------------------------------------//
392 // Base types for Prim classes.                                               //
393 // Each Prim class defines the semantics of a primitive operator or function. //
394 // Each base class defines the argument and result types of a set of Prims.   //
395 //----------------------------------------------------------------------------//
396 
397 // A primitive mapping Bool -> Num.
398 // In SubCurv, argument types are Bool or Bvec but not Bool32.
399 struct Unary_Bool_To_Num_Prim
400 {
401     typedef bool scalar_t;
unboxUnary_Bool_To_Num_Prim402     static bool unbox(Value a, scalar_t& b, const Context& cx)
403     {
404         if (a.is_bool()) {
405             b = a.to_bool_unsafe();
406             return true;
407         } else
408             return false;
409     }
sc_check_argUnary_Bool_To_Num_Prim410     static void sc_check_arg(SC_Value a, const Context& cx)
411     {
412         if (a.type.is_bool_or_vec()) return;
413         throw Exception(cx, "argument must be Bool or BVec");
414     }
sc_result_typeUnary_Bool_To_Num_Prim415     static SC_Type sc_result_type(SC_Type a)
416     {
417         if (a.is_bool_or_vec())
418             return SC_Type::Num(a.count());
419         else
420             return {};
421     }
422 };
423 
424 // maps bool -> bool
425 struct Unary_Bool_Prim
426 {
427     typedef bool scalar_t;
unboxUnary_Bool_Prim428     static bool unbox(Value a, scalar_t& b, const Context& cx)
429     {
430         if (a.is_bool()) {
431             b = a.to_bool_unsafe();
432             return true;
433         } else
434             return false;
435     }
sc_check_argUnary_Bool_Prim436     static void sc_check_arg(SC_Value a, const Context& cx)
437     {
438         if (a.type.is_bool_plex()) return;
439         throw Exception(cx, stringify("expected Bool or Bool32, got ",a.type));
440     }
sc_result_typeUnary_Bool_Prim441     static SC_Type sc_result_type(SC_Type a)
442     {
443         if (a.is_bool_plex())
444             return a;
445         else
446             return {};
447     }
448 };
449 
450 // Maps [bool,bool] -> bool.
451 // In SubCurv, accepts Bool and Bool32 arguments, and vec of same.
452 struct Binary_Bool_Prim
453 {
454     typedef bool left_t, right_t;
unbox_leftBinary_Bool_Prim455     static bool unbox_left(Value a, left_t& b, const Context&)
456     {
457         if (a.is_bool()) {
458             b = a.to_bool_unsafe();
459             return true;
460         } else
461             return false;
462     }
unbox_rightBinary_Bool_Prim463     static bool unbox_right(Value a, right_t& b, const Context&)
464     {
465         if (a.is_bool()) {
466             b = a.to_bool_unsafe();
467             return true;
468         } else
469             return false;
470     }
sc_check_argBinary_Bool_Prim471     static void sc_check_arg(SC_Value a, const Context& cx)
472     {
473         if (a.type.is_bool_plex()) return;
474         throw Exception(cx, "argument must be Bool or Bool32");
475     }
sc_check_argsBinary_Bool_Prim476     static void sc_check_args(
477         SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
478     {
479         if (a.type.is_bool_or_vec()) {
480             if (b.type.is_bool_or_vec()) {
481                 if (a.type.count() != b.type.count()
482                     && a.type.count() > 1 && b.type.count() > 1)
483                 {
484                     throw Exception(cx, stringify(
485                         "can't combine lists of different sizes (",
486                         a.type.count(), " and ", b.type.count(), ")"));
487                 }
488                 return;
489             }
490             else if (b.type.is_bool32_or_vec()) {
491                 // TODO: convert a to Bool32 via broadcasting
492             }
493         }
494         else if (a.type.is_bool32_or_vec()) {
495             if (b.type.is_bool32_or_vec()) {
496                 if (a.type.count() != b.type.count()
497                     && a.type.count() > 1 && b.type.count() > 1)
498                 {
499                     throw Exception(cx, stringify(
500                         "can't combine lists of different sizes (",
501                         a.type.count(), " and ", b.type.count(), ")"));
502                 }
503                 return;
504             }
505             if (b.type.is_bool()) {
506                 // TODO: convert b to Bool32 via broadcasting?
507             }
508         }
509         throw Exception(cx, stringify(
510             "arguments must be Bool or Bool32 (got ",
511             a.type, " and ", b.type, " instead)"));
512     }
sc_result_typeBinary_Bool_Prim513     static SC_Type sc_result_type(SC_Type a, SC_Type b)
514     {
515         if (a.is_bool_plex() && b.is_bool_plex())
516             return sc_unify_tensor_types(a,b);
517         else
518             return {};
519     }
520 };
521 
522 // A primitive mapping number -> number.
523 // The corresponding GLSL primitive accepts a number, vector or matrix.
524 struct Unary_Num_SCMat_Prim
525 {
526     typedef double scalar_t;
unboxUnary_Num_SCMat_Prim527     static bool unbox(Value a, scalar_t& b, const Context&)
528     {
529         if (a.is_num()) {
530             b = a.to_num_unsafe();
531             return true;
532         } else
533             return false;
534     }
sc_check_argUnary_Num_SCMat_Prim535     static void sc_check_arg(SC_Value a, const Context& cx)
536     {
537         if (!a.type.is_num_plex())
538             throw Exception(cx, "argument must be a Num, Vec or Mat");
539     }
sc_result_typeUnary_Num_SCMat_Prim540     static SC_Type sc_result_type(SC_Type a)
541     {
542         return a.is_num_plex() ? a : SC_Type{};
543     }
544 };
545 
546 // A primitive mapping number -> number.
547 // The corresponding GLSL primitive accepts a number or vector.
548 struct Unary_Num_SCVec_Prim : public Unary_Num_SCMat_Prim
549 {
sc_check_argUnary_Num_SCVec_Prim550     static void sc_check_arg(SC_Value a, const Context& cx)
551     {
552         if (!a.type.is_num_or_vec())
553             throw Exception(cx, "argument must be a Num or Vec");
554     }
sc_result_typeUnary_Num_SCVec_Prim555     static SC_Type sc_result_type(SC_Type a)
556     {
557         return a.is_num_or_vec() ? a : SC_Type{};
558     }
559 };
560 
561 // A primitive mapping number -> bool32.
562 // The corresponding GLSL primitive accepts a number or vector.
563 struct Unary_Num_To_Bool32_Prim : public Unary_Num_SCVec_Prim
564 {
sc_result_typeUnary_Num_To_Bool32_Prim565     static SC_Type sc_result_type(SC_Type a)
566     {
567         return a.is_num_or_vec() ? SC_Type::Bool32(a.count()) : SC_Type{};
568     }
569 };
570 
571 // Maps [num,num] -> num.
572 // The corresponding GLSL primitive accepts number, vector or matrix arguments.
573 struct Binary_Num_SCMat_Prim : public Unary_Num_SCMat_Prim
574 {
575     typedef double left_t;
576     typedef double right_t;
unbox_leftBinary_Num_SCMat_Prim577     static bool unbox_left(Value a, scalar_t& b, const Context& cx)
578     {
579         return unbox(a, b, cx);
580     }
unbox_rightBinary_Num_SCMat_Prim581     static bool unbox_right(Value a, scalar_t& b, const Context& cx)
582     {
583         return unbox(a, b, cx);
584     }
sc_check_argsBinary_Num_SCMat_Prim585     static void sc_check_args(
586         SC_Frame& fm, SC_Value& a, SC_Value& b, const Context& cx)
587     {
588         if (!a.type.is_num_plex()) {
589             throw Exception(At_Index(0, cx),
590                 stringify("argument expected to be Num, Vec or Mat; got ",
591                     a.type));
592         }
593         if (!b.type.is_num_plex()) {
594             throw Exception(At_Index(1, cx),
595                 stringify("argument expected to be Num, Vec or Mat; got ",
596                     a.type));
597         }
598         sc_plex_unify(fm, a, b, cx);
599     }
sc_result_typeBinary_Num_SCMat_Prim600     static SC_Type sc_result_type(SC_Type a, SC_Type b)
601     {
602         if (a.is_num_tensor() && b.is_num_tensor())
603             return sc_unify_tensor_types(a, b);
604         else
605             return {};
606     }
607 };
608 
609 // Maps [Num,Num] -> Num.
610 // The corresponding GLSL primitive accepts number or vector arguments.
611 struct Binary_Num_SCVec_Prim : public Unary_Num_SCVec_Prim
612 {
613     typedef double left_t;
614     typedef double right_t;
unbox_leftBinary_Num_SCVec_Prim615     static bool unbox_left(Value a, scalar_t& b, const Context& cx)
616     {
617         return unbox(a, b, cx);
618     }
unbox_rightBinary_Num_SCVec_Prim619     static bool unbox_right(Value a, scalar_t& b, const Context& cx)
620     {
621         return unbox(a, b, cx);
622     }
sc_check_argsBinary_Num_SCVec_Prim623     static void sc_check_args(
624         SC_Frame& fm, SC_Value& a, SC_Value& b, const Context& cx)
625     {
626         if (!a.type.is_num_or_vec()) {
627             throw Exception(At_Index(0, cx),
628                 stringify("argument expected to be Num or Vec; got ", a.type));
629         }
630         if (!b.type.is_num_or_vec()) {
631             throw Exception(At_Index(1, cx),
632                 stringify("argument expected to be Num or Vec; got ", a.type));
633         }
634         sc_plex_unify(fm, a, b, cx);
635     }
sc_result_typeBinary_Num_SCVec_Prim636     static SC_Type sc_result_type(SC_Type a, SC_Type b)
637     {
638         if (a.is_num_tensor() && b.is_num_tensor())
639             return sc_unify_tensor_types(a, b);
640         else
641             return {};
642     }
643 };
644 
645 // maps [num,num] -> bool
646 struct Binary_Num_To_Bool_Prim : public Binary_Num_SCVec_Prim
647 {
sc_result_typeBinary_Num_To_Bool_Prim648     static SC_Type sc_result_type(SC_Type a, SC_Type b)
649     {
650         if (a.is_num_or_vec() && b.is_num_or_vec()) {
651             SC_Type r = sc_unify_tensor_types(a,b);
652             if (r) return SC_Type::Bool(r.count());
653         }
654         return {};
655     }
656 };
657 
658 // The left operand is a non-empty list of booleans.
659 // The right operand is an integer >= 0 and < the size of the left operand.
660 // The result has the same type as the left operand.
661 // (These restrictions on the right operand conform to the definition
662 // of << and >> in the C/C++/GLSL languages.)
663 struct Shift_Prim
664 {
665     typedef Shared<const List> left_t;
666     typedef double right_t;
unbox_leftShift_Prim667     static bool unbox_left(Value a, left_t& b, const Context&)
668     {
669         b = a.maybe<const List>();
670         return b && !b->empty() && b->front().is_bool();
671     }
unbox_rightShift_Prim672     static bool unbox_right(Value a, right_t& b, const Context&)
673     {
674         b = a.to_num_or_nan();
675         return b == b;
676     }
sc_check_argsShift_Prim677     static void sc_check_args(
678         SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
679     {
680         if (!a.type.is_bool32_or_vec()) {
681             throw Exception(At_Index(0, cx),
682                 stringify("expected argument of type Bool32, got ", a.type));
683         }
684         if (b.type != SC_Type::Num()) {
685             throw Exception(At_Index(1, cx),
686                 stringify("expected argument of type Num, got ", b.type));
687         }
688     }
sc_result_typeShift_Prim689     static SC_Type sc_result_type(SC_Type a, SC_Type b)
690     {
691         if (a.is_bool32_or_vec() && b.is_num())
692             return a;
693         else
694             return {};
695     }
696 };
697 
698 struct Unary_Vec2_To_Num_Prim
699 {
700     typedef Vec2 scalar_t;
unboxUnary_Vec2_To_Num_Prim701     static bool unbox(Value a, scalar_t& b, const Context&)
702         { return unbox_vec2(a, b); }
sc_check_argUnary_Vec2_To_Num_Prim703     static void sc_check_arg(SC_Value a, const Context& cx)
704     {
705         if (a.type != SC_Type::Num(2)) {
706             throw Exception(cx, stringify("expected a Vec2; got ", a.type));
707         }
708     }
sc_result_typeUnary_Vec2_To_Num_Prim709     static SC_Type sc_result_type(SC_Type a)
710     {
711         if (a == SC_Type::Num(2))
712             return SC_Type::Num();
713         else
714             return {};
715     }
716 };
717 
718 struct Bool32_Prim
719 {
unbox_bool32Bool32_Prim720     static bool unbox_bool32(Value in, unsigned& out, const Context& cx)
721     {
722         auto li = in.maybe<const List>();
723         if (!li || li->size() != 32 || !li->front().is_bool())
724             return false;
725         out = bool32_to_nat(li, cx);
726         return true;
727     }
728 };
729 struct Unary_Bool32_To_Num_Prim : public Bool32_Prim
730 {
731     typedef unsigned scalar_t;
unboxUnary_Bool32_To_Num_Prim732     static bool unbox(Value a, scalar_t& b, const Context& cx)
733     {
734         return unbox_bool32(a, b, cx);
735     }
sc_check_argUnary_Bool32_To_Num_Prim736     static void sc_check_arg(SC_Value a, const Context& cx)
737     {
738         if (!a.type.is_bool32_or_vec())
739             throw Exception(cx, "argument must be a Bool32 or list of Bool32");
740     }
sc_result_typeUnary_Bool32_To_Num_Prim741     static SC_Type sc_result_type(SC_Type a)
742     {
743         if (a.is_bool32_or_vec())
744             return SC_Type::Num(a.count());
745         else
746             return {};
747     }
748 };
749 // maps [bool32,bool32] -> bool32
750 struct Binary_Bool32_Prim : public Bool32_Prim
751 {
752     typedef unsigned left_t;
753     typedef unsigned right_t;
unbox_leftBinary_Bool32_Prim754     static bool unbox_left(Value a, left_t& b, const Context& cx)
755     {
756         return unbox_bool32(a, b, At_Index(0, cx));
757     }
unbox_rightBinary_Bool32_Prim758     static bool unbox_right(Value a, right_t& b, const Context& cx)
759     {
760         return unbox_bool32(a, b, At_Index(1, cx));
761     }
sc_check_argBinary_Bool32_Prim762     static void sc_check_arg(SC_Value a, const Context& cx)
763     {
764         if (!a.type.is_bool32_or_vec())
765             throw Exception(cx, "argument must be a Bool32 or list of Bool32");
766     }
sc_check_argsBinary_Bool32_Prim767     static void sc_check_args(
768         SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
769     {
770         if (!a.type.is_bool32_or_vec()) {
771             throw Exception(At_Index(0, cx),
772                 stringify("expected argument of type Bool32, got ", a.type));
773         }
774         if (!b.type.is_bool32_or_vec()) {
775             throw Exception(At_Index(1, cx),
776                 stringify("expected argument of type Bool32, got ", b.type));
777         }
778         if (a.type != b.type) {
779             // Note, it's impossible to unify types of a and b
780             // with the current palette of plex types.
781             throw Exception(cx,
782                 stringify("mismatched argument types ",a.type," and ",b.type));
783         }
784     }
sc_result_typeBinary_Bool32_Prim785     static SC_Type sc_result_type(SC_Type a, SC_Type b)
786     {
787         if (a.is_bool32_or_vec() && b.is_bool32_or_vec() && a == b)
788             return a;
789         else
790             return {};
791     }
792 };
793 
794 // --------- //
795 // Utilities //
796 // --------- //
797 
798 // convert the result of Value::equal() from Ternary to Value
799 template <class Expr>
eqval(Ternary tr,Value a,Value b,const At_Syntax & cx)800 Value eqval(Ternary tr, Value a, Value b, const At_Syntax& cx)
801 {
802     if (tr != Ternary::Unknown)
803         return {tr.to_bool()};
804     else
805         return {make<Reactive_Expression>(
806             SC_Type::Bool(),
807             make<Expr>(
808                 share(cx.syntax()),
809                 to_expr(a, cx.syntax()),
810                 to_expr(b, cx.syntax())),
811             cx)};
812 }
813 
814 } // namespace
815 #endif // header guard
816