1 // Copyright (c) ONNX Project Contributors.
2 // Licensed under the MIT license.
3
4 #include "onnx/defs/schema.h"
5 namespace ONNX_NAMESPACE {
6 using SupportType = OpSchema::SupportType;
7
handle_negative_axis_validate(const std::string & attrib,int axis,int rank)8 int handle_negative_axis_validate(
9 const std::string& attrib,
10 int axis,
11 int rank) {
12 if (!(-rank <= axis && axis < rank))
13 fail_shape_inference(
14 attrib,
15 " axis value ",
16 axis,
17 " is invalid for a tensor of rank ",
18 rank);
19 return (axis >= 0 ? axis : axis + rank);
20 }
21
ScanInferenceFunction(InferenceContext & ctx)22 void ScanInferenceFunction(InferenceContext& ctx) {
23 auto num_inputs = ctx.getNumInputs();
24 auto num_scan_inputs =
25 narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
26 auto num_loop_state_vars = num_inputs - num_scan_inputs;
27 auto num_outputs = ctx.getNumOutputs();
28 auto num_scan_outputs = num_outputs - num_loop_state_vars;
29
30 std::vector<int64_t> axes, output_axes;
31 if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
32 if (axes.size() != num_scan_inputs)
33 fail_shape_inference(
34 "Number of scan input axes specified (",
35 axes.size(),
36 ") is not equal to number of scan inputs (",
37 num_scan_inputs,
38 ").");
39 } else {
40 axes.insert(axes.end(), num_scan_inputs, 0);
41 }
42
43 if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
44 if (output_axes.size() != num_scan_outputs)
45 fail_shape_inference(
46 "Number of scan output axes specified (",
47 output_axes.size(),
48 ") is not equal to number of scan outputs (",
49 num_scan_outputs,
50 ").");
51 } else {
52 output_axes.insert(output_axes.end(), num_scan_outputs, 0);
53 }
54
55 std::vector<TypeProto> temporary_type_protos;
56 temporary_type_protos.reserve(num_inputs);
57
58 std::vector<const TypeProto*> subgraph_input_types;
59
60 TensorShapeProto_Dimension sequence_len_dim;
61
62 for (size_t i = 0; i < num_inputs; ++i) {
63 bool is_loop_state_var = i < num_loop_state_vars;
64 bool has_shape = hasInputShape(ctx, i);
65 const auto* input_type = ctx.getInputType(i);
66
67 // Enforce type constraint for inputs
68 if (!input_type || !input_type->has_tensor_type()) {
69 fail_type_inference("Scan input ", i, " was not a tensor.");
70 }
71
72 if (is_loop_state_var) {
73 // If it's a loop state variable we can propagate type and shape 1:1 to
74 // the matching Scan output.
75 // We can also pass through the type and shape to the subgraph but need to
76 // remove the batch size dimension from the shape.
77 propagateElemTypeFromInputToOutput(ctx, i, i);
78 if (has_shape)
79 propagateShapeFromInputToOutput(ctx, i, i);
80
81 subgraph_input_types.push_back(input_type);
82 } else {
83 // For other inputs there is no fixed relationships to the Scan outputs,
84 // so we don't propagate type/shape information.
85 // We can pass through the type and shape to the subgraph inputs but
86 // need to remove the sequence length dimensions from the shape.
87 if (has_shape) {
88 const auto& shape = input_type->tensor_type().shape();
89
90 // remove sequence length dimensions and add to subgraph_input_types
91 int axis = static_cast<int>(axes[i - num_loop_state_vars]);
92 axis = handle_negative_axis_validate(
93 "scan_input_axes", axis, shape.dim_size());
94
95 // update sequence_len if a value is available
96
97 const auto& dims = shape.dim();
98 mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);
99
100 temporary_type_protos.push_back(
101 RemoveIthDimensionFromShape(*input_type, axis));
102 subgraph_input_types.push_back(&temporary_type_protos.back());
103
104 } else {
105 subgraph_input_types.push_back(input_type);
106 }
107 }
108 }
109
110 // Run inferencing on the subgraph
111 std::vector<const TypeProto*> output_types;
112
113 GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
114 if (graphInferencer) {
115 std::vector<const TensorProto*> input_data;
116 for (size_t i = 0; i < num_inputs; ++i) {
117 // ctx.getInputData(i), the input to scan, does not represent the input to
118 // scan body. So, we pass in null, to represent an unknown value.
119 input_data.push_back(nullptr);
120 }
121
122 output_types =
123 graphInferencer->doInferencing(subgraph_input_types, input_data);
124 }
125
126 // if empty(), assume inferencing was skipped
127 if (!output_types.empty()) {
128 if (output_types.size() != num_outputs) {
129 fail_type_inference(
130 "Graph attribute inferencing returned type information for ",
131 output_types.size(),
132 " outputs. Expected ",
133 num_outputs);
134 }
135
136 // propagate type/shape information for loop state variables and outputs
137 for (size_t i = 0; i < num_outputs; ++i) {
138 const bool is_loop_state_var = i < num_loop_state_vars;
139 auto* subgraph_output_type = output_types[i];
140 auto* scan_output_type = ctx.getOutputType(i);
141 auto* mutable_scan_output_tensor_type =
142 scan_output_type->mutable_tensor_type();
143
144 if (!subgraph_output_type->has_tensor_type()) {
145 fail_type_inference(
146 "Scan 'body' subgraph outputs should all be tensors but output ",
147 i,
148 " was not");
149 }
150 auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();
151
152 if (is_loop_state_var) {
153 // merge shape; type already propagated
154 mergeInShapeInfo(
155 subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
156 } else {
157 scan_output_type->mutable_tensor_type()->set_elem_type(
158 subgraph_output_tensor_type.elem_type());
159
160 // propagate shape
161 if (subgraph_output_tensor_type.has_shape()) {
162 // infer shape of scan-output from the shape of scan-output-element
163 // by adding sequence-length at the correct axis position
164 const TensorShapeProto& subgraph_output_shape =
165 subgraph_output_tensor_type.shape();
166 TensorShapeProto inferred_shape;
167
168 auto subgraph_output_rank = subgraph_output_shape.dim_size();
169 auto output_rank = subgraph_output_rank + 1;
170 int output_axis =
171 static_cast<int>(output_axes[i - num_loop_state_vars]);
172 output_axis = handle_negative_axis_validate(
173 "scan_output_axes", output_axis, output_rank);
174
175 for (int j = 0; j < output_axis; ++j)
176 *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
177 *(inferred_shape.add_dim()) = sequence_len_dim;
178 for (int j = output_axis; j < subgraph_output_rank; ++j)
179 *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
180
181 // Merge inferred shape with existing shape information
182 mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
183 }
184 }
185 }
186 }
187 }
188
IfInferenceFunction(InferenceContext & ctx)189 void IfInferenceFunction(InferenceContext& ctx) {
190 // there are no inputs so we just need to run the subgraph inferencing for
191 // then/else subgraphs and apply those to the outputs.
192 std::vector<const TypeProto*> subgraph_input_types; // none
193 std::vector<const TensorProto*> input_data; // none
194
195 std::vector<const TypeProto*> then_output_types;
196 std::vector<const TypeProto*> else_output_types;
197
198 // Run inferencing on the subgraph
199 GraphInferencer* graphInferencer =
200 ctx.getGraphAttributeInferencer("then_branch");
201 if (graphInferencer) {
202 then_output_types =
203 graphInferencer->doInferencing(subgraph_input_types, input_data);
204 }
205
206 graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
207 if (graphInferencer) {
208 else_output_types =
209 graphInferencer->doInferencing(subgraph_input_types, input_data);
210 }
211
212 auto num_outputs = ctx.getNumOutputs();
213 auto num_then_outputs = then_output_types.size();
214 auto num_else_outputs = else_output_types.size();
215
216 // the output types for then and else should be the same
217 if (num_then_outputs != num_else_outputs) {
218 fail_type_inference(
219 "then_branch and else_branch produce different number of outputs. ",
220 num_then_outputs,
221 " != ",
222 num_else_outputs);
223 }
224
225 if (num_then_outputs != num_outputs) {
226 fail_type_inference(
227 "If node has ",
228 num_outputs,
229 " but subgraphs produce ",
230 num_then_outputs);
231 }
232
233 for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
234 auto then_output = then_output_types[i];
235 auto else_output = else_output_types[i];
236
237 if (then_output->value_case() != else_output->value_case()) {
238 fail_type_inference(
239 "Mismatched type for output ",
240 i,
241 " then=",
242 then_output->value_case(),
243 " else=",
244 else_output->value_case());
245 }
246
247 auto* if_output = ctx.getOutputType(i);
248 *if_output = *then_output;
249
250 if (then_output->has_tensor_type()) {
251 auto then_elem_type = then_output->tensor_type().elem_type();
252 auto else_elem_type = else_output->tensor_type().elem_type();
253
254 if (then_elem_type != else_elem_type) {
255 fail_type_inference(
256 "Mismatched tensor element type for output ",
257 i,
258 " then=",
259 then_elem_type,
260 " else=",
261 else_elem_type);
262 }
263
264 // merge the 'else' shape information to check it's consistent and
265 // augment the 'if' output if possible
266 mergeInShapeInfo(
267 else_output->tensor_type(), *if_output->mutable_tensor_type());
268 }
269 }
270 }
271
LoopInferenceFunction(InferenceContext & ctx)272 void LoopInferenceFunction(InferenceContext& ctx) {
273 auto num_inputs = ctx.getNumInputs();
274 auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'
275
276 std::vector<const TypeProto*> subgraph_input_types;
277
278 std::vector<TypeProto> temporary_type_protos;
279 temporary_type_protos.reserve(num_inputs - 2);
280
281 // create TypeProto to validate iteration number type is the same as the
282 // optional 'M' input for max iterations.
283 TypeProto iter_num_type;
284 iter_num_type.mutable_tensor_type()->set_elem_type(
285 TensorProto_DataType_INT64);
286 subgraph_input_types.push_back(&iter_num_type);
287
288 // 'cond'
289 subgraph_input_types.push_back(ctx.getInputType(1));
290
291 // loop state value types get propagated to outputs, but shape may change
292 // across iterations so don't propagate it to the outputs and don't pass it
293 // into the subgraph inferencing
294 for (size_t i = 2; i < num_inputs; ++i) {
295 propagateElemTypeFromInputToOutput(ctx, i, i - 2);
296
297 // copy so we can remove the shape before passing to the subgraph
298 // inferencing
299 temporary_type_protos.push_back(*ctx.getInputType(i));
300 auto& input_type = temporary_type_protos.back();
301 input_type.mutable_tensor_type()->clear_shape();
302
303 subgraph_input_types.push_back(&input_type);
304 }
305
306 // Run inferencing on the subgraph
307 std::vector<const TypeProto*> subgraph_output_types;
308
309 GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
310 if (graphInferencer) {
311 std::vector<const TensorProto*> input_data;
312 input_data.push_back(nullptr); // iteration number
313 for (size_t i = 1; i < num_inputs; ++i) {
314 input_data.push_back(ctx.getInputData(i));
315 }
316
317 subgraph_output_types =
318 graphInferencer->doInferencing(subgraph_input_types, input_data);
319 }
320
321 // if empty(), assume inferencing was skipped
322 if (!subgraph_output_types.empty()) {
323 auto num_outputs = ctx.getNumOutputs();
324
325 // subgraph outputs the condition value first but that is only used
326 // internally and not returned by Loop.
327 if (subgraph_output_types.size() != num_outputs + 1) {
328 fail_type_inference(
329 "Graph attribute inferencing returned type information for ",
330 subgraph_output_types.size(),
331 " outputs. Expected ",
332 num_outputs + 1);
333 }
334
335 // check loop state values match. we should already have type/shape info
336 for (size_t i = 0; i < num_outputs; ++i) {
337 auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
338 auto* loop_output_type = ctx.getOutputType(i);
339
340 const bool is_loop_state_var = i < num_loop_state_vars;
341
342 if (!subgraph_output_type->has_tensor_type()) {
343 fail_type_inference(
344 "Loop 'body' subgraph outputs should all be tensors but output ",
345 i,
346 " was ",
347 subgraph_output_type->value_case());
348 }
349
350 // if there's an existing type check it matches. otherwise propagate
351 propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);
352
353 if (is_loop_state_var) {
354 // shape may change across iterations so ignore.
355 } else {
356 // per iteration output. first dimension will be number of iterations
357 // but we don't know that value yet
358 TypeProto inferred_type(*subgraph_output_type);
359 auto* mutable_inferred_tensor_type =
360 inferred_type.mutable_tensor_type();
361 auto* mutable_inferred_shape =
362 mutable_inferred_tensor_type->mutable_shape();
363
364 mutable_inferred_shape->clear_dim();
365
366 // add empty dimension for number of iterations
367 mutable_inferred_shape->add_dim();
368
369 // add dimensions from subgraph output shape
370 for (const auto& dim :
371 subgraph_output_type->tensor_type().shape().dim()) {
372 (*mutable_inferred_shape->add_dim()) = dim;
373 }
374
375 mergeInShapeInfo(
376 *mutable_inferred_tensor_type,
377 *loop_output_type->mutable_tensor_type());
378 }
379 }
380 }
381 }
382
// Schema registration for the If operator (second macro argument: 1).
//   - input 0 'cond': constrained to "B" (tensor(bool)).
//   - output 0 'outputs': variadic, constrained to "V" (all tensor types).
//   - attributes 'then_branch' and 'else_branch': graphs.
//   - type/shape inference is delegated to IfInferenceFunction.
ONNX_OPERATOR_SET_SCHEMA(
    If,
    1,
    OpSchema()
        .SetDoc("If conditional")
        .Input(0, "cond", "Condition for the if", "B")
        .Output(
            0,
            "outputs",
            "Values that are live-out to the enclosing scope. The return values in "
            "the `then_branch` and `else_branch` must be of the same shape and same "
            "data type.",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "then_branch",
            "Graph to run if condition is true. Has N outputs: values you wish to "
            "be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the else_branch.",
            AttributeProto::GRAPH)
        .Attr(
            "else_branch",
            "Graph to run if condition is false. Has N outputs: values you wish to"
            " be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the then_branch.",
            AttributeProto::GRAPH)
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeConstraint("B", {"tensor(bool)"}, "Only bool")
        .TypeAndShapeInferenceFunction(IfInferenceFunction));
413
// Documentation text for the Loop operator; passed to OpSchema::SetDoc in the
// Loop schema registration. The text is a raw string literal, so every
// character between the DOC delimiters (including line breaks) is part of the
// emitted documentation.
static const char* Loop_ver1_doc = R"DOC(
Generic Looping construct. This loop has multiple termination conditions:

1) Trip count. Iteration count specified at runtime. Set by
specifying the input M. Optional. Set to empty string to omit.
Note that a static trip count (specified at graph construction time) can be
specified by passing in a constant node for input M.
2) Loop termination condition. This is an input to the op that determines
whether to run the first iteration and also a loop-carried dependency for
the body graph. The body graph must yield a value for the condition variable,
whether this input is provided or not.

This table summarizes the operating modes of this operator with equivalent
C-style code:

Operator inputs defined as (max_trip_count, condition_var).

input ("", ""):
for (int i=0; ; ++i) {
cond = ... // Note this value is ignored, but is required in the body
}

input ("", cond) // Note this is analogous to a while loop
bool cond = ...;
for (int i=0; cond; ++i) {
cond = ...;
}

input ("", 1) // Note this is analogous to a do-while loop
bool cond = true
for (int i=0; cond; ++i) {
cond = ...;
}

input (trip_count, "") // Note this is analogous to a for loop
int trip_count = ...
for (int i=0; i < trip_count; ++i) {
cond = ...; // ignored
}

input (trip_count, cond)
int trip_count = ...;
bool cond = ...;
for (int i=0; i < trip_count && cond; ++i) {
cond = ...;
}


*Sample usage - cond as well as trip count*

graph predict-net {
%a = Constant[value = <Scalar Tensor [3]>]()
%b = Constant[value = <Scalar Tensor [6]>]()
%keepgoing = Constant[value = <Scalar Tensor [1]>]()
%max_trip_count = Constant[value = <Scalar Tensor [10]>]()
%keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
return
}

graph body-net (
%i[INT32, scalar]
%keepgoing[BOOL, scalar]
%b[INT32, scalar]
) {
%my_local = Add(%a, %b)
%b_out = Sub(%a, %b)
%keepgoing_out = Greater(%my_local, %b_out)
%user_defined_vals = Add(%b, %b)
return %keepgoing_out, %b_out, %user_defined_vals
}

*Sample equivalent C code*

{
/* User-defined code (enclosing scope) */
int a = 3, b = 6;
bool keepgoing = true; // Analogous to input cond
/* End user-defined code */

/* Implicitly-defined code */
const int max_trip_count = 10; // Analogous to input M
int user_defined_vals[]; // Imagine this is resizable
/* End implicitly-defined code */
for (int i=0; i < max_trip_count && keepgoing; ++i) {
/* User-defined code (loop body) */
int my_local = a + b; // Reading values in the enclosing scope is fine
b = a - b; // writes fine if we specify b as a loop-carried dependency
keepgoing = my_local > b; // keepgoing is a loop-carried dependency
user_defined_vals[i] = b + b;
/* End user-defined code */
}
// my_local = 123; // Can't do this. my_local was defined in the the body

// These below values are live-out from the loop and therefore accessible
b_out; user_defined_vals; keepgoing_out;
}

There are several things of note in this code snippet:

1) Values from the enclosing scope (i.e. variable a here) are in scope and can
be referenced in the inputs of the loop.
2) Any variables which you wish to make available in the enclosing scope (i.e.
the variables b and keepgoing) must be declared as either loop-carried
dependencies (both at the op inputs and output and at the body net input and
output) or scan_outputs.
3) Values created in the body cannot be accessed in the enclosing scope.

Note that the semantics of this op support "diagonal" or "wavefront" execution.
(See Step 3 here for an example:
https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
Frontends should emit multi-layer RNNs as a series of While operators (with
time being the inner looping dimension), with each successive layer consuming
the scan_outputs from the previous layer, possibly going through several
point-wise operators (e.g. dropout, residual connections, linear layer).
)DOC";
529
// Schema registration for the Loop operator (second macro argument: 1).
//   - input 0 'M': optional, constrained to "I" (tensor(int64)).
//   - input 1 'cond': optional, constrained to "B" (tensor(bool)).
//   - input 2 'v_initial': variadic, constrained to "V" (all tensor types).
//   - output 0 'v_final_and_scan_outputs': variadic, constrained to "V".
//   - attribute 'body': graph.
//   - type/shape inference is delegated to LoopInferenceFunction.
ONNX_OPERATOR_SET_SCHEMA(
    Loop,
    1,
    OpSchema()
        .SetDoc(Loop_ver1_doc)
        .Input(
            0,
            "M",
            "A maximum trip-count for the loop specified at runtime. Optional."
            " Pass empty string to skip.",
            "I",
            OpSchema::Optional)
        .Input(
            1,
            "cond",
            "A boolean termination condition. Optional. Pass empty string to skip.",
            "B",
            OpSchema::Optional)
        .Input(
            2,
            "v_initial",
            "The initial values of any loop-carried dependencies (values that "
            "change across loop iterations)",
            "V",
            OpSchema::Variadic,
            false)
        .Output(
            0,
            "v_final_and_scan_outputs",
            "Final N loop carried dependency values then K scan_outputs",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has 2+N inputs: (iteration_num, "
            "condition, loop carried dependencies...). It has 1+N+K outputs: "
            "(condition, loop carried dependencies..., scan_outputs...). Each "
            "scan_output is created by concatenating the value of the specified "
            "output value at the end of each iteration of the loop. It is an error"
            " if the dimensions or data type of these scan_outputs change across loop"
            " iterations.",
            AttributeProto::GRAPH)
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeConstraint(
            "I",
            {"tensor(int64)"},
            "tensor of int64, which should be a scalar.")
        .TypeConstraint(
            "B",
            {"tensor(bool)"},
            "tensor of bool, which should be a scalar.")
        .TypeAndShapeInferenceFunction(LoopInferenceFunction));
583
// Documentation text for the Scan operator; passed to OpSchema::SetDoc in the
// Scan schema registration. The text is a raw string literal, so every
// character between the DOC delimiters (including line breaks) is part of the
// emitted documentation.
static const char* scan_9_doc = R"DOC(
Scan can be used to iterate over one or more scan_input tensors,
constructing zero or more scan_output tensors. It combines ideas from general recurrences,
functional programming constructs such as scan, fold, map, and zip and is intended to enable
generalizations of RNN-like constructs for sequence-to-sequence processing.
Other tensors (referred to as state_variables here) can be used to carry a state
when iterating from one element to another (similar to hidden-state in RNNs, also referred
to as loop-carried dependences in the context of loops).
Many common usages involve a single scan_input tensor (where functionality
similar to scan, fold and map can be obtained). When more than one scan_input is used,
a behavior similar to zip is obtained.

The attribute body must be a graph, specifying the computation to be performed in
every iteration. It takes as input the current values of the state_variables and
the current iterated element of the scan_inputs. It must return the (updated) values
of the state_variables and zero or more scan_output_element tensors. The values of the
scan_output_element tensors are concatenated over all the iterations to produce the
scan_output values of the scan construct (similar to the concatenated intermediate
hidden-state values of RNN-like constructs). All the output tensors (state_variables as
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).

Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.

The scan operation returns the final values of the state_variables as well as the
scan_outputs.

The optional attribute scan_input_directions specifies the direction (forward or backward)
for each scan input. If this attribute is omitted, all sequences are scanned in the forward
direction. A bidirectional scan may be performed by specifying the same tensor input twice
in the scan_inputs, once with a forward direction, and once with a backward direction.

The scan_output of the operation is produced by concatenating the scan_output_element
values produced by the body in each iteration. The optional attribute scan_output_directions
specifies the direction in which scan_output is constructed (by appending or prepending the
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.

The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.

The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.

Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
The attribute num_scan_inputs indicates the number M of scan-inputs.

The behavior of

Scan <
num_scan_inputs = m,
body = loop-body,
scan_input_axes = [axis_1, ..., axis_m]
> (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:

// scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
// scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
sequence_length = scan_1.shape[axis_1];

// initialize state-variables
st_1 = init_1; ... st_n = init_n;
// initialize scan-output variables: [] denotes an empty tensor
scan_out_1 = []; ...; scan_out_k = [];
// identify number of iterations:

// execute loop
for (int t = 0; t < sequence_length; ++t) {
// generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
// of rank one less than T obtained by indexing T at position t along axis k.
si_1 = scan_1<axis=axis_1>[t];
... ;
si_m = scan_m<axis=axis_m>[t];
// execute loop-body
st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
// accumulate the scan-output elements
scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
}

return st_1, ..., st_n, scan_out_1, ..., scan_out_k;

*Sample usage: Encoding RNN using a Scan*

The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
%Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
values are computed in the outer graph, they need to be passed in as extra state_variables.

graph rnn-encoding {
%H_0 = ...
%X = ...
%Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
return %Y, %Y_h
}

graph rnn-cell-1 (
%H_tminus1[FLOAT, tensor]
%X_t[FLOAT, tensor]
) {
%Wi = ...
%Ri = ...
%Wbi = ...
%Rbi = ...
%t1 = X_t * (Wi^T)
%t2 = H_tminus1*(Ri^T)
%t3 = Add(%t1, %t2)
%t4 = Add(%t3, %Wbi)
%t5 = Add(%t4, %Rbi)
%Ht = Tanh(%t5)
%Accumulate = Identity(%Ht)
return %Ht, %Accumulate
}

)DOC";
707
// Schema registration for the Scan operator (second macro argument: 9).
//   - input 0 'initial_state_and_scan_inputs': variadic, constrained to "V".
//   - output 0 'final_state_and_scan_outputs': variadic, constrained to "V".
//   - attributes 'body' (graph) and 'num_scan_inputs' (int) are registered
//     with a trailing `true` argument; 'scan_input_directions',
//     'scan_output_directions', 'scan_input_axes' and 'scan_output_axes'
//     (int lists) are registered with `false`
//     (presumably the "required" flag - TODO confirm against OpSchema::Attr).
//   - type/shape inference is delegated to ScanInferenceFunction.
// NOTE(review): type constraint "I" is declared but no input or output in
// this registration refers to it.
ONNX_OPERATOR_SET_SCHEMA(
    Scan,
    9,
    OpSchema()
        .SetDoc(scan_9_doc)
        .Input(
            0,
            "initial_state_and_scan_inputs",
            "Initial values of the loop's N state variables followed by M scan_inputs",
            "V",
            OpSchema::Variadic,
            false)
        .Output(
            0,
            "final_state_and_scan_outputs",
            "Final values of the loop's N state variables followed by K scan_outputs",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has N+M inputs: "
            "(loop state variables..., scan_input_elts...). It has N+K outputs: "
            "(loop state variables..., scan_output_elts...). Each "
            "scan_output is created by concatenating the value of the specified "
            "scan_output_elt value at the end of each iteration of the loop. It is an error"
            " if the dimensions of these values change across loop iterations.",
            AttributeProto::GRAPH,
            true)
        .Attr(
            "num_scan_inputs",
            "An attribute specifying the number of scan_inputs M. ",
            AttributeProto::INT,
            true)
        .Attr(
            "scan_input_directions",
            "An optional list of M flags. The i-th element of the list specifies the direction "
            "to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 "
            "indicates reverse direction. "
            "If omitted, all scan_input tensors will be scanned in the forward direction.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_directions",
            "An optional list of K flags, one for each scan_output. The i-th element of the list "
            "specifies whether the i-th scan_output should be constructed by appending or "
            "prepending a new value in each iteration: 0 indicates appending and 1 "
            "indicates prepending. "
            "If omitted, all scan_output tensors will be produced by appending a value "
            "in each iteration.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_input_axes",
            "An optional list of M flags. The i-th element of the list specifies the axis "
            "to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will "
            "be used as the scan axis for every scan_input.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_axes",
            "An optional list of K flags. The i-th element of the list specifies the axis "
            "for the i-th scan_output. The scan outputs are accumulated along the specified "
            "axis. If omitted, 0 will be used as the scan axis for every scan_output.",
            AttributeProto::INTS,
            false)
        .TypeConstraint("I", {"tensor(int64)"}, "Int64 tensor")
        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
        .TypeAndShapeInferenceFunction(ScanInferenceFunction));
777 } // namespace ONNX_NAMESPACE
778