1import numpy as np
2import skimage.graph.mcp as mcp
3
4from skimage._shared.testing import assert_array_equal
5
6
# 8x8 float32 cost array whose rows alternate 1.0 (even rows) and 2.0 (odd rows).
a = np.tile(np.array([[1.0], [2.0]], dtype=np.float32), (4, 8))
9
10
class FlexibleMCP(mcp.MCP_Flexible):
    """MCP_Flexible subclass used to exercise the overridable hooks.

    The front may only advance a bounded number of steps away from the
    seed point, and every step costs the same constant amount no matter
    what values the cost array holds.
    """

    def _reset(self):
        # Let the base class clear its own bookkeeping first, then start a
        # fresh flat step counter with one float32 entry per cell of an
        # 8x8 grid.
        super()._reset()
        self._distance = np.zeros(8 * 8, dtype=np.float32)

    def goal_reached(self, index, cumcost):
        # Once a popped node lies more than 4 steps from the seed, return 2
        # so the search halts; otherwise return 0 and keep expanding.
        return 2 if self._distance[index] > 4 else 0

    def travel_cost(self, index, new_index, offset_length):
        # Constant step cost; the cost-array values are deliberately ignored.
        return 1.0

    def examine_neighbor(self, index, new_index, offset_length):
        # Intentionally a no-op: this hook is not exercised by the test.
        return None

    def update_node(self, index, new_index, offset_length):
        # Record that new_index sits one step further from the seed than
        # the node it was reached from.
        steps = self._distance
        steps[new_index] = steps[index] + 1
35
36
def test_flexible():
    """Run FlexibleMCP from the top-left corner and check its hooks.

    Verifies that the constant ``travel_cost`` shapes the cumulative costs
    near the seed, and that ``goal_reached`` halts the search before the
    front reaches the far edges of the array.
    """
    # Create the MCP and run find_costs. Use a local name that does not
    # shadow the imported ``mcp`` module; the traceback array is not needed.
    flexible = FlexibleMCP(a)
    costs, _ = flexible.find_costs([(0, 0)])

    # Check that the inner part is correct. This basically
    # tests whether travel_cost works.
    assert_array_equal(costs[:4, :4], [[1, 2, 3, 4],
                                       [2, 2, 3, 4],
                                       [3, 3, 3, 4],
                                       [4, 4, 4, 4]])

    # Test that the algorithm stopped at the right distance.
    # Note that some of the costs are filled in but not yet frozen,
    # so we take a bit of margin: only the last two rows/columns must
    # still be unreached (+inf). Comparing against the scalar broadcasts
    # it and reports exactly which elements differ on failure.
    assert_array_equal(costs[-2:, :], np.inf)
    assert_array_equal(costs[:, -2:], np.inf)
54