1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import print_function
18import os
19import sys
20import mxnet as mx
21import mxnet.ndarray as nd
22import numpy as np
23from mxnet import gluon
24from mxnet.base import MXNetError
25from mxnet.gluon.data.vision import transforms
26from mxnet.test_utils import assert_almost_equal, set_default_context
27from mxnet.test_utils import almost_equal, same
28curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
29sys.path.insert(0, os.path.join(curr_path, '../unittest'))
30from common import assertRaises, setup_module, with_seed, teardown
31from test_gluon_data_vision import test_to_tensor, test_normalize, test_crop_resize
32
# Make GPU 0 the default test context for everything in this module, so the
# shared transform tests imported above exercise the GPU code paths.
set_default_context(mx.gpu(device_id=0))
34
@with_seed()
def test_normalize_gpu():
    """Re-run the shared ``test_normalize`` case under this module's GPU default context."""
    test_normalize()
38
39
@with_seed()
def test_to_tensor_gpu():
    """Re-run the shared ``test_to_tensor`` case under this module's GPU default context."""
    test_to_tensor()
43
44
def _bilinear_resize_nhwc_reference(data, height, width):
    """Resize channel-last image data with ``BilinearResize2D`` to build an expected result.

    ``BilinearResize2D`` works on channel-first (NCHW) input, so the channel
    axis is moved in front of the spatial axes, the data is resized with
    ``align_corners=False``, and the channel axis is moved back to the end.
    A 3D input is treated as a batch of one image and returned as 3D again.

    Parameters
    ----------
    data : NDArray
        Image data laid out as (H, W, C) or (N, H, W, C).
    height : int
        Target output height.
    width : int
        Target output width.

    Returns
    -------
    NDArray
        The resized data, with the same rank and channel-last layout as ``data``.
    """
    is_single_image = data.ndim == 3
    batch = nd.expand_dims(data, axis=0) if is_single_image else data
    batch_nchw = nd.moveaxis(batch, 3, 1)
    resized_nchw = nd.contrib.BilinearResize2D(batch_nchw, height=height, width=width,
                                               align_corners=False)
    resized_nhwc = nd.moveaxis(resized_nchw, 1, 3)
    return resized_nhwc[0] if is_single_image else resized_nhwc


@with_seed()
def test_resize_gpu():
    """Check ``transforms.Resize`` against a ``BilinearResize2D`` reference.

    Covers three cases:

    * a single float image in HWC layout;
    * a float batch in NHWC layout;
    * an invalid transform configuration, which must raise ``MXNetError``.
    """
    # Normal case: 3D (HWC) float input.
    data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
    out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
    data_expected_3d = _bilinear_resize_nhwc_reference(data_in_3d, 100, 100)
    assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy())

    # Normal case: 4D (NHWC) float input.
    data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3))
    out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
    data_expected_4d = _bilinear_resize_nhwc_reference(data_in_4d, 100, 100)
    assert_almost_equal(out_nd_4d.asnumpy(), data_expected_4d.asnumpy())

    # Invalid interpolation must raise.
    # NOTE(review): size=-150 is also invalid, so the error may be triggered by
    # the size rather than by interpolation=2. A positive size would isolate the
    # interpolation check; confirm the GPU op rejects interpolation=2 before
    # changing it.
    data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
    invalid_transform = transforms.Resize(-150, keep_ratio=False, interpolation=2)
    assertRaises(MXNetError, invalid_transform, data_in_3d)
91
@with_seed()
def test_crop_resize_gpu():
    """Re-run the shared ``test_crop_resize`` case under this module's GPU default context."""
    test_crop_resize()
95