1# SPDX-License-Identifier: Apache-2.0
2
3from __future__ import absolute_import
4from __future__ import division
5from __future__ import print_function
6from __future__ import unicode_literals
7
8import numpy as np  # type: ignore
9
10import onnx
11from ..base import Base
12from . import expect
13
14
class And(Base):
    """Backend test cases for the ONNX ``And`` (element-wise logical AND) node.

    Each case builds random boolean inputs, computes the reference output
    with ``np.logical_and`` and registers it through ``expect`` under a
    fixed test name.
    """

    @staticmethod
    def export():  # type: () -> None
        """Register same-shape ``And`` cases for 2-D, 3-D and 4-D inputs."""
        node = onnx.helper.make_node(
            'And',
            inputs=['x', 'y'],
            outputs=['and'],
        )

        # NOTE: the builtin ``bool`` is used instead of ``np.bool``; that alias
        # was deprecated in NumPy 1.20 and removed in NumPy 1.24, where it
        # raises AttributeError. ``astype(bool)`` still yields dtype bool_.

        # 2d
        x = (np.random.randn(3, 4) > 0).astype(bool)
        y = (np.random.randn(3, 4) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and2d')

        # 3d
        x = (np.random.randn(3, 4, 5) > 0).astype(bool)
        y = (np.random.randn(3, 4, 5) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and3d')

        # 4d
        x = (np.random.randn(3, 4, 5, 6) > 0).astype(bool)
        y = (np.random.randn(3, 4, 5, 6) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and4d')

    @staticmethod
    def export_and_broadcast():  # type: () -> None
        """Register ``And`` cases whose inputs differ in shape and broadcast.

        The reference output relies on NumPy broadcasting in
        ``np.logical_and``.
        """
        node = onnx.helper.make_node(
            'And',
            inputs=['x', 'y'],
            outputs=['and'],
        )

        # 3d vs 1d: y is broadcast along the two leading axes of x.
        x = (np.random.randn(3, 4, 5) > 0).astype(bool)
        y = (np.random.randn(5) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and_bcast3v1d')

        # 3d vs 2d
        x = (np.random.randn(3, 4, 5) > 0).astype(bool)
        y = (np.random.randn(4, 5) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and_bcast3v2d')

        # 4d vs 2d
        x = (np.random.randn(3, 4, 5, 6) > 0).astype(bool)
        y = (np.random.randn(5, 6) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and_bcast4v2d')

        # 4d vs 3d
        x = (np.random.randn(3, 4, 5, 6) > 0).astype(bool)
        y = (np.random.randn(4, 5, 6) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and_bcast4v3d')

        # 4d vs 4d: size-1 axes on both sides broadcast to a (3, 4, 5, 6) result.
        x = (np.random.randn(1, 4, 1, 6) > 0).astype(bool)
        y = (np.random.randn(3, 1, 5, 6) > 0).astype(bool)
        z = np.logical_and(x, y)
        expect(node, inputs=[x, y], outputs=[z],
               name='test_and_bcast4v4d')
88