1from itertools import product
2
3import numpy as np
4from numpy.testing import assert_array_equal, assert_equal
5import pytest
6
7from scipy.sparse import csr_matrix, coo_matrix, diags
8from scipy.sparse.csgraph import (
9    maximum_bipartite_matching, min_weight_full_bipartite_matching
10)
11
12
def test_maximum_bipartite_matching_raises_on_dense_input():
    """A dense (non-sparse) input is rejected with a TypeError."""
    graph = np.array([[0, 1], [0, 0]])
    # Keep only the call under test inside the context manager so that a
    # TypeError raised while building the input cannot make the test pass.
    with pytest.raises(TypeError):
        maximum_bipartite_matching(graph)
17
18
def test_maximum_bipartite_matching_empty_graph():
    """A 0 x 0 graph has nothing to match for either permutation type."""
    empty_graph = csr_matrix((0, 0))
    nothing = np.array([])
    for perm_type in ('row', 'column'):
        matching = maximum_bipartite_matching(empty_graph,
                                              perm_type=perm_type)
        assert_array_equal(nothing, matching)
26
27
def test_maximum_bipartite_matching_empty_left_partition():
    """Two rows and no columns: no vertex can be matched."""
    graph = csr_matrix((2, 0))
    # 'row' yields one entry per column (none here); 'column' yields one
    # entry per row, each -1 because there is nothing to match it with.
    expected = {
        'row': np.array([]),
        'column': np.array([-1, -1]),
    }
    for perm_type, want in expected.items():
        assert_array_equal(
            want, maximum_bipartite_matching(graph, perm_type=perm_type))
34
35
def test_maximum_bipartite_matching_empty_right_partition():
    """No rows and three columns: every column is left unmatched."""
    graph = csr_matrix((0, 3))
    by_row = maximum_bipartite_matching(graph, perm_type='row')
    by_column = maximum_bipartite_matching(graph, perm_type='column')
    unmatched_columns = np.array([-1, -1, -1])
    no_rows = np.array([])
    assert_array_equal(unmatched_columns, by_row)
    assert_array_equal(no_rows, by_column)
42
43
def test_maximum_bipartite_matching_graph_with_no_edges():
    """With no edges at all, every vertex is reported as unmatched (-1)."""
    edgeless = csr_matrix((2, 2))
    results = [maximum_bipartite_matching(edgeless, perm_type=kind)
               for kind in ('row', 'column')]
    for result in results:
        assert_array_equal(np.array([-1, -1]), result)
50
51
def test_maximum_bipartite_matching_graph_that_causes_augmentation():
    """The only perfect matching requires an augmenting path.

    With 0-based indices, row 0 is adjacent to columns 0 and 1 while row 1
    is adjacent to column 0 only. If column 0 is first handed to row 0, it
    must be reassigned so that row 1 can be matched as well.
    """
    graph = csr_matrix([[1, 1], [1, 0]])
    by_column = maximum_bipartite_matching(graph, perm_type='column')
    by_row = maximum_bipartite_matching(graph, perm_type='row')
    swapped = np.array([1, 0])
    for result in (by_column, by_row):
        assert_array_equal(swapped, result)
61
62
def test_maximum_bipartite_matching_graph_with_more_rows_than_columns():
    """Three rows, two columns: both columns are matched, row 1 is not."""
    graph = csr_matrix([[1, 1], [1, 0], [0, 1]])
    expected = {
        'column': np.array([0, -1, 1]),
        'row': np.array([0, 2]),
    }
    for perm_type, matching in expected.items():
        assert_array_equal(
            matching, maximum_bipartite_matching(graph, perm_type=perm_type))
69
70
def test_maximum_bipartite_matching_graph_with_more_columns_than_rows():
    """Two rows, three columns: both rows are matched, column 1 is not."""
    graph = csr_matrix([[1, 1, 0], [0, 0, 1]])
    column_for_each_row = maximum_bipartite_matching(graph,
                                                     perm_type='column')
    row_for_each_column = maximum_bipartite_matching(graph, perm_type='row')
    assert_array_equal(np.array([0, 2]), column_for_each_row)
    assert_array_equal(np.array([0, -1, 1]), row_for_each_column)
77
78
def test_maximum_bipartite_matching_explicit_zeros_count_as_edges():
    """Stored zeros are edges, so this all-zero graph is perfectly matched."""
    # CSR (data, indices, indptr): explicit zeros stored at (0, 1) and (1, 0).
    graph = csr_matrix(([0, 0], [1, 0], [0, 1, 2]), shape=(2, 2))
    want = np.array([1, 0])
    for perm_type in ('row', 'column'):
        assert_array_equal(
            want, maximum_bipartite_matching(graph, perm_type=perm_type))
89
90
def test_maximum_bipartite_matching_feasibility_of_result():
    """Regression test for GitHub issue #11458.

    The returned matchings must have the expected size, consist only of
    edges present in the graph, and use every vertex at most once.
    """
    data = np.ones(50, dtype=int)
    indices = [11, 12, 19, 22, 23, 5, 22, 3, 8, 10, 5, 6, 11, 12, 13, 5, 13,
               14, 20, 22, 3, 15, 3, 13, 14, 11, 12, 19, 22, 23, 5, 22, 3, 8,
               10, 5, 6, 11, 12, 13, 5, 13, 14, 20, 22, 3, 15, 3, 13, 14]
    indptr = [0, 5, 7, 10, 10, 15, 20, 22, 22, 23, 25, 30, 32, 35, 35, 40, 45,
              47, 47, 48, 50]
    graph = csr_matrix((data, indices, indptr), shape=(20, 25))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    assert (x != -1).sum() == 13
    assert (y != -1).sum() == 13
    # y[u] is the column matched to row u; x[v] is the row matched to
    # column v. Every matched pair must be an actual edge of the graph.
    for u, v in enumerate(y):
        if v != -1:
            assert graph[u, v]
    for v, u in enumerate(x):
        if u != -1:
            assert graph[u, v]
    # A matching never uses the same vertex twice.
    matched_columns = y[y != -1]
    matched_rows = x[x != -1]
    assert len(np.unique(matched_columns)) == len(matched_columns)
    assert len(np.unique(matched_rows)) == len(matched_rows)
111
112
def test_matching_large_random_graph_with_one_edge_incident_to_each_vertex():
    """Recover a randomly permuted identity matrix via bipartite matching.

    Both the row permutation and the column permutation returned by
    ``maximum_bipartite_matching`` must move a nonzero onto every diagonal
    position.
    """
    n = 25

    def permutation_matrix(rows, cols):
        """Return the n x n CSR matrix with ones at (rows[k], cols[k])."""
        data = np.ones(n, dtype=int)
        return coo_matrix((data, (rows, cols)), shape=(n, n)).tocsr()

    # A local generator yields the same permutations np.random.seed(42)
    # would, without clobbering the global random state for other tests.
    rng = np.random.RandomState(42)
    A = diags(np.ones(n), offsets=0, format='csr')
    rand_perm = rng.permutation(n)
    rand_perm2 = rng.permutation(n)

    # Randomly permute the rows and columns of the identity matrix.
    B = (permutation_matrix(np.arange(n), rand_perm) @ A
         @ permutation_matrix(rand_perm2, np.arange(n)))

    # Row permute: perm[j] is the row matched to column j, so moving row
    # perm[j] to position j places a nonzero at (j, j).
    perm = maximum_bipartite_matching(B, perm_type='row')
    C1 = permutation_matrix(np.arange(n), perm) @ B

    # Column permute: perm2[i] is the column matched to row i, so moving
    # column perm2[i] to position i places a nonzero at (i, i).
    perm2 = maximum_bipartite_matching(B, perm_type='column')
    C2 = B @ permutation_matrix(perm2, np.arange(n))

    # Should get a matrix with a zero-free diagonal back.
    assert np.all(C1.diagonal() != 0)
    assert np.all(C2.diagonal() != 0)
150
151
@pytest.mark.parametrize('num_rows,num_cols', [(0, 0), (2, 0), (0, 3)])
def test_min_weight_full_matching_trivial_graph(num_rows, num_cols):
    """A graph without rows or without columns has an empty full matching."""
    # Shape is (rows, columns), matching the parameter names.
    biadjacency_matrix = csr_matrix((num_rows, num_cols))
    row_ind, col_ind = min_weight_full_bipartite_matching(biadjacency_matrix)
    assert len(row_ind) == 0
    assert len(col_ind) == 0
158
159
@pytest.mark.parametrize('biadjacency_matrix',
                         [
                            [[1, 1, 1], [1, 0, 0], [1, 0, 0]],
                            [[1, 1, 1], [0, 0, 1], [0, 0, 1]],
                            [[1, 0, 0], [2, 0, 0]],
                            [[0, 1, 0], [0, 2, 0]],
                            [[1, 0], [2, 0], [5, 0]]
                         ])
def test_min_weight_full_matching_infeasible_problems(biadjacency_matrix):
    """Graphs that admit no full matching must raise ValueError."""
    # Build the input outside the context manager so that only the solver
    # call is allowed to produce the expected ValueError.
    graph = csr_matrix(biadjacency_matrix)
    with pytest.raises(ValueError):
        min_weight_full_bipartite_matching(graph)
171
172
def test_explicit_zero_causes_warning():
    """An explicitly stored zero weight triggers a UserWarning."""
    # CSR (data, indices, indptr): row 0 stores columns 0 and 1 with weights
    # 2 and an explicit 0; row 1 stores column 1 with weight 3.
    data = (2, 0, 3)
    indices = (0, 1, 1)
    indptr = (0, 2, 3)
    biadjacency_matrix = csr_matrix((data, indices, indptr), shape=(2, 2))
    # Only the solver call sits inside the context manager so that the
    # warning is attributed to it and not to matrix construction.
    with pytest.warns(UserWarning):
        min_weight_full_bipartite_matching(biadjacency_matrix)
177
178
# General test for linear sum assignment solvers to make it possible to rely
# on the same tests for scipy.optimize.linear_sum_assignment.
def linear_sum_assignment_assertions(
    solver, array_type, sign, test_case
):
    """Check that ``solver`` finds an optimal assignment for ``test_case``.

    Parameters
    ----------
    solver : callable
        Called as ``solver(cost_matrix, maximize=...)``; must return the
        pair ``(row_ind, col_ind)``.
    array_type : callable
        Converts the nested-list cost matrix into the container the solver
        accepts (e.g. ``np.array`` or ``csr_matrix``).
    sign : {-1, 1}
        With ``-1`` the costs are negated and the problem is solved with
        ``maximize=True``, so the selected costs are the negated optimum.
        With ``1`` the problem is a plain minimization.
    test_case : tuple
        ``(cost_matrix, expected_cost)``, where ``expected_cost`` lists the
        selected costs in the order of the (sorted) returned row indices.

    The problem is solved as given and again transposed; in both cases the
    returned row indices must be sorted.
    """
    cost_matrix, expected_cost = test_case
    maximize = sign == -1
    cost_matrix = sign * array_type(cost_matrix)
    expected_cost = sign * np.array(expected_cost)

    row_ind, col_ind = solver(cost_matrix, maximize=maximize)
    assert_array_equal(row_ind, np.sort(row_ind))
    assert_array_equal(expected_cost,
                       np.asarray(cost_matrix[row_ind, col_ind]).ravel())

    # Transposing swaps rows and columns: the selected costs are the same
    # multiset, but they may come back in a different order.
    cost_matrix = cost_matrix.T
    row_ind, col_ind = solver(cost_matrix, maximize=maximize)
    assert_array_equal(row_ind, np.sort(row_ind))
    # Flatten *before* sorting: fancy indexing may return a 2-D result
    # depending on the container, and sorting along its last axis is only a
    # full sort when that result happens to be a single row.
    selected = np.asarray(cost_matrix[row_ind, col_ind]).ravel()
    assert_array_equal(np.sort(expected_cost), np.sort(selected))
200
201
# Materialized as a list so the cases can be iterated more than once (e.g. by
# several ``pytest.mark.parametrize`` decorators or by modules that import
# them); a bare ``itertools.product`` iterator is exhausted after one pass.
linear_sum_assignment_test_cases = list(product(
    [-1, 1],
    [
        # Square
        ([[400, 150, 400],
          [400, 450, 600],
          [300, 225, 300]],
         [150, 400, 300]),

        # Rectangular variant
        ([[400, 150, 400, 1],
          [400, 450, 600, 2],
          [300, 225, 300, 3]],
         [150, 2, 300]),

        # Square
        ([[10, 10, 8],
          [9, 8, 1],
          [9, 7, 4]],
         [10, 1, 7]),

        # Rectangular variant
        ([[10, 10, 8, 11],
          [9, 8, 1, 1],
          [9, 7, 4, 10]],
         [10, 1, 4]),

        # Square with infinite entries: only one finite assignment exists
        ([[10, float("inf"), float("inf")],
          [float("inf"), float("inf"), 1],
          [float("inf"), 7, float("inf")]],
         [10, 1, 7])
    ]))
234
235
@pytest.mark.parametrize('sign,test_case', linear_sum_assignment_test_cases)
def test_min_weight_full_matching_small_inputs(sign, test_case):
    """Run the shared linear-sum-assignment checks on the sparse solver."""
    linear_sum_assignment_assertions(
        solver=min_weight_full_bipartite_matching,
        array_type=csr_matrix,
        sign=sign,
        test_case=test_case,
    )
240