// Copyright (c) ONNX Project Contributors.
// Licensed under the MIT license.

#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
using SupportType = OpSchema::SupportType;

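// Checks that 'axis' is within the valid range [-rank, rank) for the
// attribute named 'attrib' and normalizes a negative value by adding 'rank'.
// For example, axis -1 with rank 3 resolves to 2.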
int handle_negative_axis_validate(
    const std::string& attrib,
    int axis,
    int rank) {
  if (!(-rank <= axis && axis < rank))
    fail_shape_inference(
        attrib,
        " axis value ",
        axis,
        " is invalid for a tensor of rank ",
        rank);
  return (axis >= 0 ? axis : axis + rank);
}

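// Type and shape inference for the Scan operator. Propagates type/shape for
// the loop state variables, strips the scanned (sequence) axis from each
// scan input, runs inference on the 'body' subgraph, and reconstructs the
// scan output shapes by re-inserting the sequence-length dimension at the
// requested output axis.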
void ScanInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  auto num_scan_inputs =
      narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
  auto num_loop_state_vars = num_inputs - num_scan_inputs;
  auto num_outputs = ctx.getNumOutputs();
  auto num_scan_outputs = num_outputs - num_loop_state_vars;

  std::vector<int64_t> axes, output_axes;
  if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
    if (axes.size() != num_scan_inputs)
      fail_shape_inference(
          "Number of scan input axes specified (",
          axes.size(),
          ") is not equal to number of scan inputs (",
          num_scan_inputs,
          ").");
  } else {
    axes.insert(axes.end(), num_scan_inputs, 0);
  }

  if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
    if (output_axes.size() != num_scan_outputs)
      fail_shape_inference(
          "Number of scan output axes specified (",
          output_axes.size(),
          ") is not equal to number of scan outputs (",
          num_scan_outputs,
          ").");
  } else {
    output_axes.insert(output_axes.end(), num_scan_outputs, 0);
  }

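  // Temporary TypeProtos for modified input types. Reserving the full size up
  // front keeps the vector from reallocating, since subgraph_input_types
  // stores raw pointers into it.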
  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs);

  std::vector<const TypeProto*> subgraph_input_types;

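  // Sequence-length dimension, merged from the scanned axis of every scan
  // input that provides shape information.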
  TensorShapeProto_Dimension sequence_len_dim;

  for (size_t i = 0; i < num_inputs; ++i) {
    bool is_loop_state_var = i < num_loop_state_vars;
    bool has_shape = hasInputShape(ctx, i);
    const auto* input_type = ctx.getInputType(i);

    // Enforce type constraint for inputs
    if (!input_type || !input_type->has_tensor_type()) {
      fail_type_inference("Scan input ", i, " was not a tensor.");
    }

    if (is_loop_state_var) {
      // If it's a loop state variable we can propagate type and shape 1:1 to
      // the matching Scan output.
      // The type and shape can also be passed through unchanged to the
      // corresponding subgraph input.
      propagateElemTypeFromInputToOutput(ctx, i, i);
      if (has_shape)
        propagateShapeFromInputToOutput(ctx, i, i);

      subgraph_input_types.push_back(input_type);
    } else {
      // For other inputs there is no fixed relationship to the Scan outputs,
      // so we don't propagate type/shape information.
      // We can pass the type and shape through to the subgraph inputs but
      // need to remove the sequence-length dimension from the shape.
      if (has_shape) {
        const auto& shape = input_type->tensor_type().shape();

        // remove sequence length dimensions and add to subgraph_input_types
        int axis = static_cast<int>(axes[i - num_loop_state_vars]);
        axis = handle_negative_axis_validate(
            "scan_input_axes", axis, shape.dim_size());

        // update sequence_len if a value is available

        const auto& dims = shape.dim();
        mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);

        temporary_type_protos.push_back(
            RemoveIthDimensionFromShape(*input_type, axis));
        subgraph_input_types.push_back(&temporary_type_protos.back());

      } else {
        subgraph_input_types.push_back(input_type);
      }
    }
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    for (size_t i = 0; i < num_inputs; ++i) {
      // ctx.getInputData(i), the input to Scan, does not represent the input
      // to the Scan body, so we pass in null to represent an unknown value.
      input_data.push_back(nullptr);
    }

    output_types =
        graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!output_types.empty()) {
    if (output_types.size() != num_outputs) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          output_types.size(),
          " outputs. Expected ",
          num_outputs);
    }

    // propagate type/shape information for loop state variables and outputs
    for (size_t i = 0; i < num_outputs; ++i) {
      const bool is_loop_state_var = i < num_loop_state_vars;
      auto* subgraph_output_type = output_types[i];
      auto* scan_output_type = ctx.getOutputType(i);
      auto* mutable_scan_output_tensor_type =
          scan_output_type->mutable_tensor_type();

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Scan 'body' subgraph outputs should all be tensors but output ",
            i,
            " was not");
      }
      auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();

      if (is_loop_state_var) {
        // merge shape; type already propagated
        mergeInShapeInfo(
            subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
      } else {
        scan_output_type->mutable_tensor_type()->set_elem_type(
            subgraph_output_tensor_type.elem_type());

        // propagate shape
        if (subgraph_output_tensor_type.has_shape()) {
          // infer shape of scan-output from the shape of scan-output-element
          // by adding sequence-length at the correct axis position
          const TensorShapeProto& subgraph_output_shape =
              subgraph_output_tensor_type.shape();
          TensorShapeProto inferred_shape;

          auto subgraph_output_rank = subgraph_output_shape.dim_size();
          auto output_rank = subgraph_output_rank + 1;
          int output_axis =
              static_cast<int>(output_axes[i - num_loop_state_vars]);
          output_axis = handle_negative_axis_validate(
              "scan_output_axes", output_axis, output_rank);

          for (int j = 0; j < output_axis; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          *(inferred_shape.add_dim()) = sequence_len_dim;
          for (int j = output_axis; j < subgraph_output_rank; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);

          // Merge inferred shape with existing shape information
          mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
        }
      }
    }
  }
}

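// Type and shape inference for the If operator. Runs inference on the
// 'then_branch' and 'else_branch' subgraphs, checks that both produce the
// same number of outputs with matching types, and merges the branches' shape
// information into the If outputs where possible.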
void IfInferenceFunction(InferenceContext& ctx) {
  // The then/else subgraphs take no inputs, so we just need to run subgraph
  // inferencing for both branches and apply the results to the outputs.
  std::vector<const TypeProto*> subgraph_input_types; // none
  std::vector<const TensorProto*> input_data; // none

  std::vector<const TypeProto*> then_output_types;
  std::vector<const TypeProto*> else_output_types;

  // Run inferencing on the subgraph
  GraphInferencer* graphInferencer =
      ctx.getGraphAttributeInferencer("then_branch");
  if (graphInferencer) {
    then_output_types =
        graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
  if (graphInferencer) {
    else_output_types =
        graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  auto num_outputs = ctx.getNumOutputs();
  auto num_then_outputs = then_output_types.size();
  auto num_else_outputs = else_output_types.size();

  // the output types for then and else should be the same
  if (num_then_outputs != num_else_outputs) {
    fail_type_inference(
        "then_branch and else_branch produce different number of outputs. ",
        num_then_outputs,
        " != ",
        num_else_outputs);
  }

  if (num_then_outputs != num_outputs) {
    fail_type_inference(
        "If node has ",
        num_outputs,
        " outputs but subgraphs produce ",
        num_then_outputs);
  }

  for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
    auto then_output = then_output_types[i];
    auto else_output = else_output_types[i];

    if (then_output->value_case() != else_output->value_case()) {
      fail_type_inference(
          "Mismatched type for output ",
          i,
          " then=",
          then_output->value_case(),
          " else=",
          else_output->value_case());
    }

    auto* if_output = ctx.getOutputType(i);
    *if_output = *then_output;

    if (then_output->has_tensor_type()) {
      auto then_elem_type = then_output->tensor_type().elem_type();
      auto else_elem_type = else_output->tensor_type().elem_type();

      if (then_elem_type != else_elem_type) {
        fail_type_inference(
            "Mismatched tensor element type for output ",
            i,
            " then=",
            then_elem_type,
            " else=",
            else_elem_type);
      }

      // merge the 'else' shape information to check it's consistent and
      // augment the 'if' output if possible
      mergeInShapeInfo(
          else_output->tensor_type(), *if_output->mutable_tensor_type());
    }
  }
}

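// Type and shape inference for the Loop operator. Propagates loop state
// variable element types to the corresponding outputs, runs inference on the
// 'body' subgraph with the state variable shapes removed (they may change
// across iterations), and prepends an unknown iteration-count dimension to
// each per-iteration scan output.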
void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

  std::vector<const TypeProto*> subgraph_input_types;

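  // Copies of the loop state variable types with the shape cleared. Reserving
  // up front keeps the pointers stored in subgraph_input_types valid.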
  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs - 2);

  // create a TypeProto for the iteration number input; it has the same type
  // as the optional 'M' input for the maximum number of iterations (int64).
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(
      TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);

  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));

  // Loop state value types get propagated to the outputs, but the shape may
  // change across iterations, so don't propagate the shape to the outputs and
  // don't pass it into the subgraph inferencing.
  for (size_t i = 2; i < num_inputs; ++i) {
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);

    // copy so we can remove the shape before passing to the subgraph
    // inferencing
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();
    input_type.mutable_tensor_type()->clear_shape();

    subgraph_input_types.push_back(&input_type);
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.push_back(nullptr); // iteration number
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }

    subgraph_output_types =
        graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();

    // subgraph outputs the condition value first but that is only used
    // internally and not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          subgraph_output_types.size(),
          " outputs. Expected ",
          num_outputs + 1);
    }

    // check loop state values match. we should already have type/shape info
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);

      const bool is_loop_state_var = i < num_loop_state_vars;

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      // if there's an existing type check it matches. otherwise propagate
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);

      if (is_loop_state_var) {
        // shape may change across iterations so ignore.
      } else {
        // Per-iteration output. The first dimension will be the number of
        // iterations, but we don't know that value yet.
        TypeProto inferred_type(*subgraph_output_type);
        auto* mutable_inferred_tensor_type =
            inferred_type.mutable_tensor_type();
        auto* mutable_inferred_shape =
            mutable_inferred_tensor_type->mutable_shape();

        mutable_inferred_shape->clear_dim();

        // add empty dimension for number of iterations
        mutable_inferred_shape->add_dim();

        // add dimensions from subgraph output shape
        for (const auto& dim :
             subgraph_output_type->tensor_type().shape().dim()) {
          (*mutable_inferred_shape->add_dim()) = dim;
        }

        mergeInShapeInfo(
            *mutable_inferred_tensor_type,
            *loop_output_type->mutable_tensor_type());
      }
    }
  }
}

ONNX_OPERATOR_SET_SCHEMA(
    If,
    1,
    OpSchema()
        .SetDoc("If conditional")
        .Input(0, "cond", "Condition for the if", "B")
        .Output(
            0,
            "outputs",
            "Values that are live-out to the enclosing scope. The return values in "
            "the `then_branch` and `else_branch` must be of the same shape and same "
            "data type.",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "then_branch",
            "Graph to run if condition is true. Has N outputs: values you wish to "
            "be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the else_branch.",
            AttributeProto::GRAPH)
        .Attr(
            "else_branch",
            "Graph to run if condition is false. Has N outputs: values you wish to"
            " be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the then_branch.",
            AttributeProto::GRAPH)
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeConstraint("B", {"tensor(bool)"}, "Only bool")
        .TypeAndShapeInferenceFunction(IfInferenceFunction));

static const char* Loop_ver1_doc = R"DOC(
Generic Looping construct. This loop has multiple termination conditions:

1) Trip count. Iteration count specified at runtime. Set by
   specifying the input M. Optional. Set to empty string to omit.
   Note that a static trip count (specified at graph construction time) can be
   specified by passing in a constant node for input M.
2) Loop termination condition. This is an input to the op that determines
   whether to run the first iteration and also a loop-carried dependency for
   the body graph. The body graph must yield a value for the condition variable,
   whether this input is provided or not.

This table summarizes the operating modes of this operator with equivalent
C-style code:

    Operator inputs defined as (max_trip_count, condition_var).

    input ("", ""):
        for (int i=0; ; ++i) {
          cond = ... // Note this value is ignored, but is required in the body
        }

    input ("", cond) // Note this is analogous to a while loop
        bool cond = ...;
        for (int i=0; cond; ++i) {
          cond = ...;
        }

    input ("", 1) // Note this is analogous to a do-while loop
        bool cond = true
        for (int i=0; cond; ++i) {
          cond = ...;
        }

    input (trip_count, "") // Note this is analogous to a for loop
        int trip_count = ...
        for (int i=0; i < trip_count; ++i) {
          cond = ...; // ignored
        }

    input (trip_count, cond)
        int trip_count = ...;
        bool cond = ...;
        for (int i=0; i < trip_count && cond; ++i) {
          cond = ...;
        }


*Sample usage - cond as well as trip count*

    graph predict-net {
      %a = Constant[value = <Scalar Tensor [3]>]()
      %b = Constant[value = <Scalar Tensor [6]>]()
      %keepgoing = Constant[value = <Scalar Tensor [1]>]()
      %max_trip_count = Constant[value = <Scalar Tensor [10]>]()
      %keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
      return
    }

    graph body-net (
      %i[INT32, scalar]
      %keepgoing[BOOL, scalar]
      %b[INT32, scalar]
    ) {
      %my_local = Add(%a, %b)
      %b_out = Sub(%a, %b)
      %keepgoing_out = Greater(%my_local, %b_out)
      %user_defined_vals = Add(%b, %b)
      return %keepgoing_out, %b_out, %user_defined_vals
    }

*Sample equivalent C code*

    {
      /* User-defined code (enclosing scope) */
      int a = 3, b = 6;
      bool keepgoing = true; // Analogous to input cond
      /* End user-defined code */

      /* Implicitly-defined code */
      const int max_trip_count = 10; // Analogous to input M
      int user_defined_vals[]; // Imagine this is resizable
      /* End implicitly-defined code */
      for (int i=0; i < max_trip_count && keepgoing; ++i) {
        /* User-defined code (loop body) */
        int my_local = a + b; // Reading values in the enclosing scope is fine
        b = a - b; // writes fine if we specify b as a loop-carried dependency
        keepgoing = my_local > b; // keepgoing is a loop-carried dependency
        user_defined_vals[i] = b + b;
        /* End user-defined code */
      }
      // my_local = 123; // Can't do this. my_local was defined in the body

      // These below values are live-out from the loop and therefore accessible
      b_out; user_defined_vals; keepgoing_out;
    }

There are several things of note in this code snippet:

1) Values from the enclosing scope (i.e. variable a here) are in scope and can
   be referenced in the inputs of the loop.
2) Any variables which you wish to make available in the enclosing scope (i.e.
   the variables b and keepgoing) must be declared as either loop-carried
   dependencies (both at the op inputs and output and at the body net input and
   output) or scan_outputs.
3) Values created in the body cannot be accessed in the enclosing scope.

Note that the semantics of this op support "diagonal" or "wavefront" execution.
(See Step 3 here for an example:
https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
Frontends should emit multi-layer RNNs as a series of While operators (with
time being the inner looping dimension), with each successive layer consuming
the scan_outputs from the previous layer, possibly going through several
point-wise operators (e.g. dropout, residual connections, linear layer).
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
    Loop,
    1,
    OpSchema()
        .SetDoc(Loop_ver1_doc)
        .Input(
            0,
            "M",
            "A maximum trip-count for the loop specified at runtime. Optional."
            " Pass empty string to skip.",
            "I",
            OpSchema::Optional)
        .Input(
            1,
            "cond",
            "A boolean termination condition. Optional. Pass empty string to skip.",
            "B",
            OpSchema::Optional)
        .Input(
            2,
            "v_initial",
            "The initial values of any loop-carried dependencies (values that "
            "change across loop iterations)",
            "V",
            OpSchema::Variadic,
            false)
        .Output(
            0,
            "v_final_and_scan_outputs",
            "Final N loop carried dependency values then K scan_outputs",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has 2+N inputs: (iteration_num, "
            "condition, loop carried dependencies...). It has 1+N+K outputs: "
            "(condition, loop carried dependencies..., scan_outputs...). Each "
            "scan_output is created by concatenating the value of the specified "
            "output value at the end of each iteration of the loop. It is an error"
            " if the dimensions or data type of these scan_outputs change across loop"
            " iterations.",
            AttributeProto::GRAPH)
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeConstraint(
            "I",
            {"tensor(int64)"},
            "tensor of int64, which should be a scalar.")
        .TypeConstraint(
            "B",
            {"tensor(bool)"},
            "tensor of bool, which should be a scalar.")
        .TypeAndShapeInferenceFunction(LoopInferenceFunction));

static const char* scan_9_doc = R"DOC(
Scan can be used to iterate over one or more scan_input tensors,
constructing zero or more scan_output tensors. It combines ideas from general recurrences,
functional programming constructs such as scan, fold, map, and zip and is intended to enable
generalizations of RNN-like constructs for sequence-to-sequence processing.
Other tensors (referred to as state_variables here) can be used to carry a state
when iterating from one element to another (similar to hidden-state in RNNs, also referred
to as loop-carried dependences in the context of loops).
Many common usages involve a single scan_input tensor (where functionality
similar to scan, fold and map can be obtained). When more than one scan_input is used,
a behavior similar to zip is obtained.

The attribute body must be a graph, specifying the computation to be performed in
every iteration. It takes as input the current values of the state_variables and
the current iterated element of the scan_inputs. It must return the (updated) values
of the state_variables and zero or more scan_output_element tensors. The values of the
scan_output_element tensors are concatenated over all the iterations to produce the
scan_output values of the scan construct (similar to the concatenated intermediate
hidden-state values of RNN-like constructs). All the output tensors (state_variables as
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).

Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.

The scan operation returns the final values of the state_variables as well as the
scan_outputs.

The optional attribute scan_input_directions specifies the direction (forward or backward)
for each scan input. If this attribute is omitted, all sequences are scanned in the forward
direction. A bidirectional scan may be performed by specifying the same tensor input twice
in the scan_inputs, once with a forward direction, and once with a backward direction.

The scan_output of the operation is produced by concatenating the scan_output_element
values produced by the body in each iteration.  The optional attribute scan_output_directions
specifies the direction in which scan_output is constructed (by appending or prepending the
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.

The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.

The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.

Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
The attribute num_scan_inputs indicates the number M of scan-inputs.

The behavior of

    Scan <
        num_scan_inputs = m,
        body = loop-body,
        scan_input_axes = [axis_1, ..., axis_m]
    > (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:

    // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
    // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
    sequence_length = scan_1.shape[axis_1];

    // initialize state-variables
    st_1 = init_1; ... st_n = init_n;
    // initialize scan-output variables: [] denotes an empty tensor
    scan_out_1 = []; ...; scan_out_k = [];
    // identify number of iterations:

    // execute loop
    for (int t = 0; t < sequence_length; ++t) {
        // generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
        // of rank one less than T obtained by indexing T at position t along axis k.
        si_1 = scan_1<axis=axis_1>[t];
        ... ;
        si_m = scan_m<axis=axis_m>[t];
        // execute loop-body
        st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
        // accumulate the scan-output elements
        scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
    }

    return st_1, ..., st_n, scan_out_1, ..., scan_out_k;

*Sample usage: Encoding RNN using a Scan*

The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
%Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
values are computed in the outer graph, they need to be passed in as extra state_variables.

    graph rnn-encoding {
      %H_0 = ...
      %X = ...
      %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
      return %Y, %Y_h
    }

    graph rnn-cell-1 (
      %H_tminus1[FLOAT, tensor]
      %X_t[FLOAT, tensor]
    ) {
      %Wi = ...
      %Ri = ...
      %Wbi = ...
      %Rbi = ...
      %t1 = X_t * (Wi^T)
      %t2 = H_tminus1*(Ri^T)
      %t3 = Add(%t1, %t2)
      %t4 = Add(%t3, %Wbi)
      %t5 = Add(%t4, %Rbi)
      %Ht = Tanh(%t5)
      %Accumulate = Identity(%Ht)
      return %Ht, %Accumulate
    }

)DOC";

ONNX_OPERATOR_SET_SCHEMA(
    Scan,
    9,
    OpSchema()
        .SetDoc(scan_9_doc)
        .Input(
            0,
            "initial_state_and_scan_inputs",
            "Initial values of the loop's N state variables followed by M scan_inputs",
            "V",
            OpSchema::Variadic,
            false)
        .Output(
            0,
            "final_state_and_scan_outputs",
            "Final values of the loop's N state variables followed by K scan_outputs",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has N+M inputs: "
            "(loop state variables..., scan_input_elts...). It has N+K outputs: "
            "(loop state variables..., scan_output_elts...). Each "
            "scan_output is created by concatenating the value of the specified "
            "scan_output_elt value at the end of each iteration of the loop. It is an error"
            " if the dimensions of these values change across loop iterations.",
            AttributeProto::GRAPH,
            true)
        .Attr(
            "num_scan_inputs",
            "An attribute specifying the number of scan_inputs M. ",
            AttributeProto::INT,
            true)
        .Attr(
            "scan_input_directions",
            "An optional list of M flags. The i-th element of the list specifies the direction "
            "to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 "
            "indicates reverse direction. "
            "If omitted, all scan_input tensors will be scanned in the forward direction.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_directions",
            "An optional list of K flags, one for each scan_output. The i-th element of the list "
            "specifies whether the i-th scan_output should be constructed by appending or "
            "prepending a new value in each iteration: 0 indicates appending and 1 "
            "indicates prepending. "
            "If omitted, all scan_output tensors will be produced by appending a value "
            "in each iteration.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_input_axes",
            "An optional list of M values. The i-th element of the list specifies the axis "
            "to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will "
            "be used as the scan axis for every scan_input.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_axes",
            "An optional list of K values. The i-th element of the list specifies the axis "
            "for the i-th scan_output. The scan outputs are accumulated along the specified "
            "axis. If omitted, 0 will be used as the scan axis for every scan_output.",
            AttributeProto::INTS,
            false)
        .TypeConstraint("I", {"tensor(int64)"}, "Int64 tensor")
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeAndShapeInferenceFunction(ScanInferenceFunction));
} // namespace ONNX_NAMESPACE