1 // Copyright 2016-2021 Doug Moen
2 // Licensed under the Apache License, version 2.0
3 // See accompanying file LICENSE or https://www.apache.org/licenses/LICENSE-2.0
4
5 #ifndef LIBCURV_PRIM_H
6 #define LIBCURV_PRIM_H
7
8 // Curv is an array language, following APL and its successors.
9 // This means that scalar operations (on numbers and booleans)
10 // are generalized to operate on arrays of scalars, in two ways:
11 // element-wise operation, and broadcasting.
12 // 2 + 2 == 4 -- a scalar operation
13 // [3,4] + [10,20] == [13,24] -- element-wise addition
14 // 1 + [10,20] == [11,21] -- broadcasting
15
16 #include <libcurv/bool.h>
17 #include <libcurv/context.h>
18 #include <libcurv/exception.h>
19 #include <libcurv/list.h>
20 #include <libcurv/meaning.h>
21 #include <libcurv/num.h>
22 #include <libcurv/reactive.h>
23 #include <libcurv/sc_compiler.h>
24 #include <libcurv/sc_context.h>
25 #include <libcurv/vec.h>
26
27 namespace curv {
28
29 struct Binary_Op
30 {
sc_callBinary_Op31 static SC_Value sc_call(
32 SC_Frame& fm, const Operation& a, const Operation& b,
33 Shared<const Phrase> syntax)
34 {
35 throw Exception(At_SC_Phrase(syntax, fm),
36 "operation not supported");
37 }
38 };
39
40 //----------------------------------------------------------------------//
41 // Templates for converting an Array_Op to a unary or binary Expression //
42 //----------------------------------------------------------------------//
43
44 template <class Op>
45 struct Unary_Op_Expr : public Prefix_Expr_Base
46 {
47 using Prefix_Expr_Base::Prefix_Expr_Base;
evalUnary_Op_Expr48 virtual Value eval(Frame& fm) const override
49 { return Op::call(Fail::hard, At_Phrase(*syntax_, fm), arg_->eval(fm)); }
sc_evalUnary_Op_Expr50 virtual SC_Value sc_eval(SC_Frame& fm) const override
51 { return Op::sc_op(At_SC_Phrase(syntax_,fm), *arg_, fm); }
printUnary_Op_Expr52 virtual void print(std::ostream& out) const override
53 { out<<Op::Prim::name()<<"("<<*arg_<<")"; }
54 };
55
56 template <class Op>
57 struct Binary_Op_Expr : public Infix_Expr_Base
58 {
59 using Infix_Expr_Base::Infix_Expr_Base;
evalBinary_Op_Expr60 virtual Value eval(Frame& fm) const override {
61 return Op::call(Fail::hard, At_Phrase(*syntax_, fm),
62 arg1_->eval(fm), arg2_->eval(fm));
63 }
sc_evalBinary_Op_Expr64 virtual SC_Value sc_eval(SC_Frame& fm) const override
65 { return Op::sc_call(fm, *arg1_, *arg2_, syntax_); }
idchrBinary_Op_Expr66 static bool idchr(char c)
67 { return (c>='a'&&c<='z')||(c>='A'&&c<='Z')||c=='_'; }
printBinary_Op_Expr68 virtual void print(std::ostream& out) const override {
69 if (idchr(Op::Prim::name()[0]))
70 out<<Op::Prim::name()<<"["<<*arg1_<<","<<*arg2_<<"]";
71 else
72 out<<"("<<*arg1_<<Op::Prim::name()<<*arg2_<<")";
73 }
74 };
75
76 //---------------------------------------------------------------//
77 // Templates for converting a Prim to a unary or binary Array_Op //
78 //---------------------------------------------------------------//
79
80 template <class PRIM>
81 struct Unary_Array_Op
82 {
83 // TODO: optimize: move semantics. unique object reuse.
84
85 using Prim = PRIM;
86
87 static Value
callUnary_Array_Op88 call(Fail fl, const At_Syntax& cx, Value x)
89 {
90 typename Prim::scalar_t sx;
91 if (Prim::unbox(x, sx, cx)) {
92 Value r = Prim::call(sx, cx);
93 if (!r.is_missing()) return r;
94 } else if (x.is_ref()) {
95 Ref_Value& rx(x.to_ref_unsafe());
96 switch (rx.type_) {
97 case Ref_Value::ty_abstract_list:
98 if (rx.subtype_ == Ref_Value::sty_list)
99 return element_wise_op(fl, cx, (List&)rx);
100 else
101 break; // TODO strings are lists?
102 case Ref_Value::ty_reactive:
103 return reactive_op(fl, cx, (Reactive_Value&)rx);
104 }
105 }
106 FAIL(fl, missing, cx, domain_error(x));
107 }
108 static SC_Value
sc_opUnary_Array_Op109 sc_op(const At_Syntax& cx, Operation& argx, SC_Frame& fm)
110 {
111 // TODO: add array support
112 auto a = sc_eval_op(fm, argx);
113 Prim::sc_check_arg(a, cx);
114 return Prim::sc_call(fm, a);
115 }
116
117 static Value
element_wise_opUnary_Array_Op118 element_wise_op(Fail fl, const At_Syntax& cx, List& xs)
119 {
120 Shared<List> result = List::make(xs.size());
121 for (unsigned i = 0; i < xs.size(); ++i) {
122 TRY_DEF(r, call(fl, cx, xs[i]));
123 (*result)[i] = r;
124 }
125 return {result};
126 }
127
128 // Argument x is reactive. Construct a Reactive_Expression.
129 static Value
reactive_opUnary_Array_Op130 reactive_op(Fail fl, const At_Syntax& cx, Reactive_Value &rx)
131 {
132 SC_Type rtype = Prim::sc_result_type(rx.sctype_);
133 if (rtype) {
134 return {make<Reactive_Expression>(
135 rtype,
136 make<Unary_Op_Expr<Unary_Array_Op>>(
137 share(cx.syntax()), rx.expr()),
138 cx)};
139 } else {
140 FAIL(fl, missing, cx, domain_error({share(rx)}));
141 }
142 }
143
domain_errorUnary_Array_Op144 static Shared<const String> domain_error(Value x)
145 {
146 return stringify(x, ": domain error");
147 }
148 };
149
150 template <class PRIM>
151 struct Binary_Array_Op
152 {
153 // TODO: optimize: move semantics. unique object reuse.
154 // TODO: optimize: faster fast path in `op` for number case.
155
156 using Prim = PRIM;
157
domain_errorBinary_Array_Op158 static Exception domain_error(
159 const At_Syntax& cx, unsigned i, Value x, Value y)
160 {
161 if (dynamic_cast<const At_Arg*>(&cx)) {
162 return Exception(At_Index(i,cx),
163 stringify(i==0?x:y,": domain error"));
164 }
165 if (auto ap = dynamic_cast<const At_Phrase*>(&cx)) {
166 if (auto bin = cast<const Binary_Phrase>(ap->phrase_))
167 return Exception(cx,
168 stringify(x," ",bin->opname()," ",y,": domain error"));
169 }
170 return Exception(cx, stringify(i==0?x:y,": domain error"));
171 }
172
domain_errorBinary_Array_Op173 static Exception domain_error(
174 const At_Syntax& cx, Value x, Value y)
175 {
176 if (auto ap = dynamic_cast<const At_Phrase*>(&cx)) {
177 if (auto bin = cast<const Binary_Phrase>(ap->phrase_))
178 return Exception(cx,
179 stringify(x," ",bin->opname()," ",y,": domain error"));
180 }
181 return Exception(cx, stringify("[",x,",",y,"]: domain error"));
182 }
183
184 static Value
reduceBinary_Array_Op185 reduce(Fail fl, const At_Syntax& cx, Value zero, Value arg)
186 {
187 auto list = arg.to<List>(fl, cx);
188 if (list == nullptr) return missing;
189 unsigned n = list->size();
190 if (n == 0)
191 return {zero};
192 Value result = list->front();
193 for (unsigned i = 1; i < n; ++i) {
194 TRY_DEF(r, call(fl, cx, result, list->at(i)));
195 result = r;
196 }
197 return result;
198 }
199 static SC_Value
sc_reduceBinary_Array_Op200 sc_reduce(const At_Syntax& cx, Value zero, Operation& argx, SC_Frame& fm)
201 {
202 if (auto list = cast_list_expr(argx)) {
203 if (list->empty())
204 return sc_eval_const(fm, zero, *argx.syntax_);
205 auto first = sc_eval_op(fm, *list->at(0));
206 if (list->size() == 1) {
207 Prim::sc_check_arg(first, cx);
208 return first;
209 }
210 for (unsigned i = 1; i < list->size(); ++i) {
211 auto second = sc_eval_op(fm, *list->at(i));
212 Prim::sc_check_args(fm, first, second, cx);
213 first = Prim::sc_call(fm, first, second);
214 }
215 return first;
216 }
217 else {
218 // Reduce an array value that exists at GPU run time.
219 // TODO: For a large 1D array, use a GPU loop and call a function.
220 // TODO: Binary_Array_Op::sc_reduce: reduce a matrix.
221 // 2D arrays (SC_Type rank 2) are not supported, because you can't
222 // generate a rank 1 array at GPU runtime, for now at least.
223 // For a single Vec, this inline expansion of the loop is good.
224 auto arg = sc_eval_op(fm, argx);
225 if (arg.type.is_vec()) {
226 auto first = sc_vec_element(fm, arg, 0);
227 for (unsigned i = 1; i < arg.type.count(); ++i) {
228 auto second = sc_vec_element(fm, arg, i);
229 Prim::sc_check_args(fm, first, second, cx);
230 first = Prim::sc_call(fm, first, second);
231 }
232 return first;
233 }
234 else {
235 throw Exception(cx, "argument is not a vector");
236 }
237 }
238 }
239
240 static Value
callBinary_Array_Op241 call(Fail fl, const At_Syntax& cx, Value x, Value y)
242 {
243 // fast path: both x and y are scalars
244 // remaining cases:
245 // - x is a scalar, y is a list
246 // - x is a list, y is a scalar
247 // - x and y are lists
248 // - either x, or y, or both, is reactive
249 typename Prim::left_t sx;
250 typename Prim::right_t sy;
251 if (Prim::unbox_left(x, sx, cx)) {
252 if (Prim::unbox_right(y, sy, cx)) {
253 Value r = Prim::call(sx, sy, cx);
254 if (!r.is_missing()) return r;
255 }
256 else if (y.is_ref()) {
257 Ref_Value& ry(y.to_ref_unsafe());
258 switch (ry.type_) {
259 case Ref_Value::ty_abstract_list:
260 if (ry.subtype_ == Ref_Value::sty_list)
261 return broadcast_right(fl, cx, x, (List&)ry);
262 else
263 break; // TODO: strings are lists
264 case Ref_Value::ty_reactive:
265 return reactive_op(cx, x, y);
266 }
267 }
268 throw domain_error(cx,x,y);
269 } else if (x.is_ref()) {
270 Ref_Value& rx(x.to_ref_unsafe());
271 switch (rx.type_) {
272 case Ref_Value::ty_abstract_list:
273 if (Prim::unbox_right(y, sy, cx))
274 return broadcast_left(fl, cx, (List&)rx, y);
275 else if (rx.subtype_ == Ref_Value::sty_list && y.is_ref()) {
276 Ref_Value& ry(y.to_ref_unsafe());
277 switch (ry.type_) {
278 case Ref_Value::ty_abstract_list:
279 if (ry.subtype_ == Ref_Value::sty_list)
280 return element_wise_op(fl, cx, (List&)rx, (List&)ry);
281 else
282 break; // TODO: strings are lists
283 case Ref_Value::ty_reactive:
284 return reactive_op(cx, x, y);
285 }
286 }
287 throw domain_error(cx,1,x,y);
288 case Ref_Value::ty_reactive:
289 return reactive_op(cx, x, y);
290 }
291 }
292 throw domain_error(cx,0,x,y);
293 }
294 static SC_Value
sc_opBinary_Array_Op295 sc_op(const At_Syntax& cx, Operation& argx, SC_Frame& fm)
296 {
297 auto list = cast_list_expr(argx);
298 if (list && list->size() == 2) {
299 auto first = sc_eval_op(fm, *list->at(0));
300 auto second = sc_eval_op(fm, *list->at(1));
301 Prim::sc_check_args(fm, first, second, cx);
302 return Prim::sc_call(fm, first, second);
303 }
304 // TODO: Binary_Array_Op::sc_op: accept a 2-vector, 2-array or mat2.
305 throw Exception(cx, "expected a list of size 2");
306 }
307 static SC_Value
sc_callBinary_Array_Op308 sc_call(SC_Frame& fm, Operation& ax, Operation& ay, Shared<const Phrase> ph)
309 {
310 auto x = sc_eval_op(fm, ax);
311 auto y = sc_eval_op(fm, ay);
312 Prim::sc_check_args(fm, x, y, At_SC_Phrase(ph, fm));
313 return Prim::sc_call(fm, x, y);
314 }
315
316 static Value
broadcast_leftBinary_Array_Op317 broadcast_left(Fail fl, const At_Syntax& cx, List& xlist, Value y)
318 {
319 Shared<List> result = List::make(xlist.size());
320 for (unsigned i = 0; i < xlist.size(); ++i) {
321 TRY_DEF(r, call(fl, cx, xlist[i], y));
322 (*result)[i] = r;
323 }
324 return {result};
325 }
326
327 static Value
broadcast_rightBinary_Array_Op328 broadcast_right(Fail fl, const At_Syntax& cx, Value x, List& ylist)
329 {
330 Shared<List> result = List::make(ylist.size());
331 for (unsigned i = 0; i < ylist.size(); ++i) {
332 TRY_DEF(r, call(fl, cx, x, ylist[i]));
333 (*result)[i] = r;
334 }
335 return {result};
336 }
337
338 static Value
element_wise_opBinary_Array_Op339 element_wise_op(Fail fl, const At_Syntax& cx, List& xs, List& ys)
340 {
341 if (xs.size() != ys.size()) {
342 FAIL(fl, missing, cx, stringify(
343 "mismatched list sizes (",
344 xs.size(),",",ys.size(),") in array operation"));
345 }
346 Shared<List> result = List::make(xs.size());
347 for (unsigned i = 0; i < xs.size(); ++i) {
348 TRY_DEF(r, call(fl, cx, xs[i], ys[i]));
349 (*result)[i] = r;
350 }
351 return {result};
352 }
353
354 // At least one of x and y is reactive. Construct a Reactive_Expression.
355 static Value
reactive_opBinary_Array_Op356 reactive_op(const At_Syntax& cx, Value x, Value y)
357 {
358 Shared<Operation> x_expr;
359 SC_Type x_type;
360 if (auto xr = x.maybe<Reactive_Value>()) {
361 x_expr = xr->expr();
362 x_type = xr->sctype_;
363 } else {
364 x_expr = make<Constant>(share(cx.syntax()), x);
365 x_type = sc_type_of(x);
366 }
367
368 Shared<Operation> y_expr;
369 SC_Type y_type;
370 if (auto yr = y.maybe<Reactive_Value>()) {
371 y_expr = yr->expr();
372 y_type = yr->sctype_;
373 } else {
374 y_expr = make<Constant>(share(cx.syntax()), y);
375 y_type = sc_type_of(y);
376 }
377
378 SC_Type rtype = Prim::sc_result_type(x_type, y_type);
379 if (rtype) {
380 return {make<Reactive_Expression>(
381 rtype,
382 make<Binary_Op_Expr<Binary_Array_Op>>(
383 share(cx.syntax()), x_expr, y_expr),
384 cx)};
385 } else {
386 throw domain_error(cx, x, y);
387 }
388 }
389 };
390
391 //----------------------------------------------------------------------------//
392 // Base types for Prim classes. //
393 // Each Prim class defines the semantics of a primitive operator or function. //
394 // Each base class defines the argument and result types of a set of Prims. //
395 //----------------------------------------------------------------------------//
396
397 // A primitive mapping Bool -> Num.
398 // In SubCurv, argument types are Bool or Bvec but not Bool32.
399 struct Unary_Bool_To_Num_Prim
400 {
401 typedef bool scalar_t;
unboxUnary_Bool_To_Num_Prim402 static bool unbox(Value a, scalar_t& b, const Context& cx)
403 {
404 if (a.is_bool()) {
405 b = a.to_bool_unsafe();
406 return true;
407 } else
408 return false;
409 }
sc_check_argUnary_Bool_To_Num_Prim410 static void sc_check_arg(SC_Value a, const Context& cx)
411 {
412 if (a.type.is_bool_or_vec()) return;
413 throw Exception(cx, "argument must be Bool or BVec");
414 }
sc_result_typeUnary_Bool_To_Num_Prim415 static SC_Type sc_result_type(SC_Type a)
416 {
417 if (a.is_bool_or_vec())
418 return SC_Type::Num(a.count());
419 else
420 return {};
421 }
422 };
423
424 // maps bool -> bool
425 struct Unary_Bool_Prim
426 {
427 typedef bool scalar_t;
unboxUnary_Bool_Prim428 static bool unbox(Value a, scalar_t& b, const Context& cx)
429 {
430 if (a.is_bool()) {
431 b = a.to_bool_unsafe();
432 return true;
433 } else
434 return false;
435 }
sc_check_argUnary_Bool_Prim436 static void sc_check_arg(SC_Value a, const Context& cx)
437 {
438 if (a.type.is_bool_plex()) return;
439 throw Exception(cx, stringify("expected Bool or Bool32, got ",a.type));
440 }
sc_result_typeUnary_Bool_Prim441 static SC_Type sc_result_type(SC_Type a)
442 {
443 if (a.is_bool_plex())
444 return a;
445 else
446 return {};
447 }
448 };
449
450 // Maps [bool,bool] -> bool.
451 // In SubCurv, accepts Bool and Bool32 arguments, and vec of same.
452 struct Binary_Bool_Prim
453 {
454 typedef bool left_t, right_t;
unbox_leftBinary_Bool_Prim455 static bool unbox_left(Value a, left_t& b, const Context&)
456 {
457 if (a.is_bool()) {
458 b = a.to_bool_unsafe();
459 return true;
460 } else
461 return false;
462 }
unbox_rightBinary_Bool_Prim463 static bool unbox_right(Value a, right_t& b, const Context&)
464 {
465 if (a.is_bool()) {
466 b = a.to_bool_unsafe();
467 return true;
468 } else
469 return false;
470 }
sc_check_argBinary_Bool_Prim471 static void sc_check_arg(SC_Value a, const Context& cx)
472 {
473 if (a.type.is_bool_plex()) return;
474 throw Exception(cx, "argument must be Bool or Bool32");
475 }
sc_check_argsBinary_Bool_Prim476 static void sc_check_args(
477 SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
478 {
479 if (a.type.is_bool_or_vec()) {
480 if (b.type.is_bool_or_vec()) {
481 if (a.type.count() != b.type.count()
482 && a.type.count() > 1 && b.type.count() > 1)
483 {
484 throw Exception(cx, stringify(
485 "can't combine lists of different sizes (",
486 a.type.count(), " and ", b.type.count(), ")"));
487 }
488 return;
489 }
490 else if (b.type.is_bool32_or_vec()) {
491 // TODO: convert a to Bool32 via broadcasting
492 }
493 }
494 else if (a.type.is_bool32_or_vec()) {
495 if (b.type.is_bool32_or_vec()) {
496 if (a.type.count() != b.type.count()
497 && a.type.count() > 1 && b.type.count() > 1)
498 {
499 throw Exception(cx, stringify(
500 "can't combine lists of different sizes (",
501 a.type.count(), " and ", b.type.count(), ")"));
502 }
503 return;
504 }
505 if (b.type.is_bool()) {
506 // TODO: convert b to Bool32 via broadcasting?
507 }
508 }
509 throw Exception(cx, stringify(
510 "arguments must be Bool or Bool32 (got ",
511 a.type, " and ", b.type, " instead)"));
512 }
sc_result_typeBinary_Bool_Prim513 static SC_Type sc_result_type(SC_Type a, SC_Type b)
514 {
515 if (a.is_bool_plex() && b.is_bool_plex())
516 return sc_unify_tensor_types(a,b);
517 else
518 return {};
519 }
520 };
521
522 // A primitive mapping number -> number.
523 // The corresponding GLSL primitive accepts a number, vector or matrix.
524 struct Unary_Num_SCMat_Prim
525 {
526 typedef double scalar_t;
unboxUnary_Num_SCMat_Prim527 static bool unbox(Value a, scalar_t& b, const Context&)
528 {
529 if (a.is_num()) {
530 b = a.to_num_unsafe();
531 return true;
532 } else
533 return false;
534 }
sc_check_argUnary_Num_SCMat_Prim535 static void sc_check_arg(SC_Value a, const Context& cx)
536 {
537 if (!a.type.is_num_plex())
538 throw Exception(cx, "argument must be a Num, Vec or Mat");
539 }
sc_result_typeUnary_Num_SCMat_Prim540 static SC_Type sc_result_type(SC_Type a)
541 {
542 return a.is_num_plex() ? a : SC_Type{};
543 }
544 };
545
546 // A primitive mapping number -> number.
547 // The corresponding GLSL primitive accepts a number or vector.
548 struct Unary_Num_SCVec_Prim : public Unary_Num_SCMat_Prim
549 {
sc_check_argUnary_Num_SCVec_Prim550 static void sc_check_arg(SC_Value a, const Context& cx)
551 {
552 if (!a.type.is_num_or_vec())
553 throw Exception(cx, "argument must be a Num or Vec");
554 }
sc_result_typeUnary_Num_SCVec_Prim555 static SC_Type sc_result_type(SC_Type a)
556 {
557 return a.is_num_or_vec() ? a : SC_Type{};
558 }
559 };
560
561 // A primitive mapping number -> bool32.
562 // The corresponding GLSL primitive accepts a number or vector.
563 struct Unary_Num_To_Bool32_Prim : public Unary_Num_SCVec_Prim
564 {
sc_result_typeUnary_Num_To_Bool32_Prim565 static SC_Type sc_result_type(SC_Type a)
566 {
567 return a.is_num_or_vec() ? SC_Type::Bool32(a.count()) : SC_Type{};
568 }
569 };
570
571 // Maps [num,num] -> num.
572 // The corresponding GLSL primitive accepts number, vector or matrix arguments.
573 struct Binary_Num_SCMat_Prim : public Unary_Num_SCMat_Prim
574 {
575 typedef double left_t;
576 typedef double right_t;
unbox_leftBinary_Num_SCMat_Prim577 static bool unbox_left(Value a, scalar_t& b, const Context& cx)
578 {
579 return unbox(a, b, cx);
580 }
unbox_rightBinary_Num_SCMat_Prim581 static bool unbox_right(Value a, scalar_t& b, const Context& cx)
582 {
583 return unbox(a, b, cx);
584 }
sc_check_argsBinary_Num_SCMat_Prim585 static void sc_check_args(
586 SC_Frame& fm, SC_Value& a, SC_Value& b, const Context& cx)
587 {
588 if (!a.type.is_num_plex()) {
589 throw Exception(At_Index(0, cx),
590 stringify("argument expected to be Num, Vec or Mat; got ",
591 a.type));
592 }
593 if (!b.type.is_num_plex()) {
594 throw Exception(At_Index(1, cx),
595 stringify("argument expected to be Num, Vec or Mat; got ",
596 a.type));
597 }
598 sc_plex_unify(fm, a, b, cx);
599 }
sc_result_typeBinary_Num_SCMat_Prim600 static SC_Type sc_result_type(SC_Type a, SC_Type b)
601 {
602 if (a.is_num_tensor() && b.is_num_tensor())
603 return sc_unify_tensor_types(a, b);
604 else
605 return {};
606 }
607 };
608
609 // Maps [Num,Num] -> Num.
610 // The corresponding GLSL primitive accepts number or vector arguments.
611 struct Binary_Num_SCVec_Prim : public Unary_Num_SCVec_Prim
612 {
613 typedef double left_t;
614 typedef double right_t;
unbox_leftBinary_Num_SCVec_Prim615 static bool unbox_left(Value a, scalar_t& b, const Context& cx)
616 {
617 return unbox(a, b, cx);
618 }
unbox_rightBinary_Num_SCVec_Prim619 static bool unbox_right(Value a, scalar_t& b, const Context& cx)
620 {
621 return unbox(a, b, cx);
622 }
sc_check_argsBinary_Num_SCVec_Prim623 static void sc_check_args(
624 SC_Frame& fm, SC_Value& a, SC_Value& b, const Context& cx)
625 {
626 if (!a.type.is_num_or_vec()) {
627 throw Exception(At_Index(0, cx),
628 stringify("argument expected to be Num or Vec; got ", a.type));
629 }
630 if (!b.type.is_num_or_vec()) {
631 throw Exception(At_Index(1, cx),
632 stringify("argument expected to be Num or Vec; got ", a.type));
633 }
634 sc_plex_unify(fm, a, b, cx);
635 }
sc_result_typeBinary_Num_SCVec_Prim636 static SC_Type sc_result_type(SC_Type a, SC_Type b)
637 {
638 if (a.is_num_tensor() && b.is_num_tensor())
639 return sc_unify_tensor_types(a, b);
640 else
641 return {};
642 }
643 };
644
645 // maps [num,num] -> bool
646 struct Binary_Num_To_Bool_Prim : public Binary_Num_SCVec_Prim
647 {
sc_result_typeBinary_Num_To_Bool_Prim648 static SC_Type sc_result_type(SC_Type a, SC_Type b)
649 {
650 if (a.is_num_or_vec() && b.is_num_or_vec()) {
651 SC_Type r = sc_unify_tensor_types(a,b);
652 if (r) return SC_Type::Bool(r.count());
653 }
654 return {};
655 }
656 };
657
658 // The left operand is a non-empty list of booleans.
659 // The right operand is an integer >= 0 and < the size of the left operand.
660 // The result has the same type as the left operand.
661 // (These restrictions on the right operand conform to the definition
662 // of << and >> in the C/C++/GLSL languages.)
663 struct Shift_Prim
664 {
665 typedef Shared<const List> left_t;
666 typedef double right_t;
unbox_leftShift_Prim667 static bool unbox_left(Value a, left_t& b, const Context&)
668 {
669 b = a.maybe<const List>();
670 return b && !b->empty() && b->front().is_bool();
671 }
unbox_rightShift_Prim672 static bool unbox_right(Value a, right_t& b, const Context&)
673 {
674 b = a.to_num_or_nan();
675 return b == b;
676 }
sc_check_argsShift_Prim677 static void sc_check_args(
678 SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
679 {
680 if (!a.type.is_bool32_or_vec()) {
681 throw Exception(At_Index(0, cx),
682 stringify("expected argument of type Bool32, got ", a.type));
683 }
684 if (b.type != SC_Type::Num()) {
685 throw Exception(At_Index(1, cx),
686 stringify("expected argument of type Num, got ", b.type));
687 }
688 }
sc_result_typeShift_Prim689 static SC_Type sc_result_type(SC_Type a, SC_Type b)
690 {
691 if (a.is_bool32_or_vec() && b.is_num())
692 return a;
693 else
694 return {};
695 }
696 };
697
698 struct Unary_Vec2_To_Num_Prim
699 {
700 typedef Vec2 scalar_t;
unboxUnary_Vec2_To_Num_Prim701 static bool unbox(Value a, scalar_t& b, const Context&)
702 { return unbox_vec2(a, b); }
sc_check_argUnary_Vec2_To_Num_Prim703 static void sc_check_arg(SC_Value a, const Context& cx)
704 {
705 if (a.type != SC_Type::Num(2)) {
706 throw Exception(cx, stringify("expected a Vec2; got ", a.type));
707 }
708 }
sc_result_typeUnary_Vec2_To_Num_Prim709 static SC_Type sc_result_type(SC_Type a)
710 {
711 if (a == SC_Type::Num(2))
712 return SC_Type::Num();
713 else
714 return {};
715 }
716 };
717
718 struct Bool32_Prim
719 {
unbox_bool32Bool32_Prim720 static bool unbox_bool32(Value in, unsigned& out, const Context& cx)
721 {
722 auto li = in.maybe<const List>();
723 if (!li || li->size() != 32 || !li->front().is_bool())
724 return false;
725 out = bool32_to_nat(li, cx);
726 return true;
727 }
728 };
729 struct Unary_Bool32_To_Num_Prim : public Bool32_Prim
730 {
731 typedef unsigned scalar_t;
unboxUnary_Bool32_To_Num_Prim732 static bool unbox(Value a, scalar_t& b, const Context& cx)
733 {
734 return unbox_bool32(a, b, cx);
735 }
sc_check_argUnary_Bool32_To_Num_Prim736 static void sc_check_arg(SC_Value a, const Context& cx)
737 {
738 if (!a.type.is_bool32_or_vec())
739 throw Exception(cx, "argument must be a Bool32 or list of Bool32");
740 }
sc_result_typeUnary_Bool32_To_Num_Prim741 static SC_Type sc_result_type(SC_Type a)
742 {
743 if (a.is_bool32_or_vec())
744 return SC_Type::Num(a.count());
745 else
746 return {};
747 }
748 };
749 // maps [bool32,bool32] -> bool32
750 struct Binary_Bool32_Prim : public Bool32_Prim
751 {
752 typedef unsigned left_t;
753 typedef unsigned right_t;
unbox_leftBinary_Bool32_Prim754 static bool unbox_left(Value a, left_t& b, const Context& cx)
755 {
756 return unbox_bool32(a, b, At_Index(0, cx));
757 }
unbox_rightBinary_Bool32_Prim758 static bool unbox_right(Value a, right_t& b, const Context& cx)
759 {
760 return unbox_bool32(a, b, At_Index(1, cx));
761 }
sc_check_argBinary_Bool32_Prim762 static void sc_check_arg(SC_Value a, const Context& cx)
763 {
764 if (!a.type.is_bool32_or_vec())
765 throw Exception(cx, "argument must be a Bool32 or list of Bool32");
766 }
sc_check_argsBinary_Bool32_Prim767 static void sc_check_args(
768 SC_Frame& /*fm*/, SC_Value& a, SC_Value& b, const Context& cx)
769 {
770 if (!a.type.is_bool32_or_vec()) {
771 throw Exception(At_Index(0, cx),
772 stringify("expected argument of type Bool32, got ", a.type));
773 }
774 if (!b.type.is_bool32_or_vec()) {
775 throw Exception(At_Index(1, cx),
776 stringify("expected argument of type Bool32, got ", b.type));
777 }
778 if (a.type != b.type) {
779 // Note, it's impossible to unify types of a and b
780 // with the current palette of plex types.
781 throw Exception(cx,
782 stringify("mismatched argument types ",a.type," and ",b.type));
783 }
784 }
sc_result_typeBinary_Bool32_Prim785 static SC_Type sc_result_type(SC_Type a, SC_Type b)
786 {
787 if (a.is_bool32_or_vec() && b.is_bool32_or_vec() && a == b)
788 return a;
789 else
790 return {};
791 }
792 };
793
794 // --------- //
795 // Utilities //
796 // --------- //
797
798 // convert the result of Value::equal() from Ternary to Value
799 template <class Expr>
eqval(Ternary tr,Value a,Value b,const At_Syntax & cx)800 Value eqval(Ternary tr, Value a, Value b, const At_Syntax& cx)
801 {
802 if (tr != Ternary::Unknown)
803 return {tr.to_bool()};
804 else
805 return {make<Reactive_Expression>(
806 SC_Type::Bool(),
807 make<Expr>(
808 share(cx.syntax()),
809 to_expr(a, cx.syntax()),
810 to_expr(b, cx.syntax())),
811 cx)};
812 }
813
814 } // namespace
815 #endif // header guard
816