1import numpy as np
2import pytest
3
4from keras_preprocessing.image import affine_transformations
5
6
def test_random_transforms():
    """Check that each random transform returns an array of the input shape."""
    expected_shape = (2, 28, 28)
    x = np.random.random(expected_shape)
    # Each entry pairs a transform with the extra positional arguments
    # passed after the input array; they run in the listed order.
    cases = (
        (affine_transformations.random_rotation, (45,)),
        (affine_transformations.random_shift, (1, 1)),
        (affine_transformations.random_shear, (20,)),
        (affine_transformations.random_channel_shift, (20,)),
    )
    for transform, extra_args in cases:
        assert transform(x, *extra_args).shape == expected_shape
13
14
def test_deterministic_transform():
    """Apply a theta=45 affine transform to an all-ones array.

    The result is compared against a fixed expected array using
    ``np.allclose``.
    """
    x = np.ones((3, 3, 3))

    # Expected result: all zeros, except index 2 along the second axis
    # (for every first-axis index) and index 1 along the second axis of
    # the middle first-axis slice, which are ones.
    expected = np.zeros((3, 3, 3))
    expected[:, 2, :] = 1.
    expected[1, 1, :] = 1.

    result = affine_transformations.apply_affine_transform(
        x, theta=45, channel_axis=2, fill_mode='constant')
    assert np.allclose(result, expected)
28
29
def test_random_zoom():
    """Check random_zoom output shape and the (1, 1) zoom-range case.

    A (5, 5) range must keep the input shape; a (1, 1) range must return
    values close to the input.
    """
    shape = (2, 28, 28)
    x = np.random.random(shape)

    zoomed = affine_transformations.random_zoom(x, (5, 5))
    assert zoomed.shape == shape

    unit_zoomed = affine_transformations.random_zoom(x, (1, 1))
    assert np.allclose(x, unit_zoomed)
34
35
def test_random_zoom_error():
    """random_zoom must raise ValueError for a one-element zoom_range."""
    bad_zoom_range = [0]
    with pytest.raises(ValueError):
        affine_transformations.random_zoom(0, zoom_range=bad_zoom_range)
39
40
def test_apply_brightness_shift_error(monkeypatch):
    """apply_brightness_shift must raise ImportError without ImageEnhance.

    The module's ``ImageEnhance`` attribute is replaced by ``None`` for
    the duration of the test.
    """
    patched_module = affine_transformations
    monkeypatch.setattr(patched_module, 'ImageEnhance', None)

    with pytest.raises(ImportError):
        patched_module.apply_brightness_shift(0, [0])
45
46
def test_random_brightness(monkeypatch):
    """Check what random_brightness passes to apply_brightness_shift.

    ``apply_brightness_shift`` is replaced by a stub that echoes its two
    arguments, so the return value exposes them. With the range (3, 3)
    the second argument is expected to equal 3.
    """
    def echo_arguments(x, y):
        # Stand-in for apply_brightness_shift: return the inputs as-is.
        return (x, y)

    monkeypatch.setattr(affine_transformations,
                        'apply_brightness_shift', echo_arguments)

    outcome = affine_transformations.random_brightness(0, (3, 3))
    assert (0, 3.) == outcome
51
52
def test_random_brightness_error():
    """random_brightness must raise ValueError for a one-element range."""
    bad_brightness_range = [0]
    with pytest.raises(ValueError):
        affine_transformations.random_brightness(0, bad_brightness_range)
56
57
def test_apply_affine_transform_error(monkeypatch):
    """apply_affine_transform must raise ImportError without scipy.

    The module's ``scipy`` attribute is replaced by ``None`` for the
    duration of the test.
    """
    patched_module = affine_transformations
    monkeypatch.setattr(patched_module, 'scipy', None)

    with pytest.raises(ImportError):
        patched_module.apply_affine_transform(0)
62