"""Tests for the bipartite matching routines in ``scipy.sparse.csgraph``.

Covers ``maximum_bipartite_matching`` and
``min_weight_full_bipartite_matching``. The linear sum assignment helpers at
the bottom of the module are solver-agnostic so that the same cases can also
be used for ``scipy.optimize.linear_sum_assignment``.
"""
from itertools import product

import numpy as np
from numpy.testing import assert_array_equal, assert_equal
import pytest

from scipy.sparse import csr_matrix, coo_matrix, diags
from scipy.sparse.csgraph import (
    maximum_bipartite_matching, min_weight_full_bipartite_matching
)


def test_maximum_bipartite_matching_raises_on_dense_input():
    # Build the dense input outside the ``raises`` block so that only the
    # call under test can satisfy the expected TypeError.
    graph = np.array([[0, 1], [0, 0]])
    with pytest.raises(TypeError):
        maximum_bipartite_matching(graph)


def test_maximum_bipartite_matching_empty_graph():
    graph = csr_matrix((0, 0))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    expected_matching = np.array([])
    assert_array_equal(expected_matching, x)
    assert_array_equal(expected_matching, y)


def test_maximum_bipartite_matching_empty_left_partition():
    # Two rows, no columns: the 'row' result has one entry per column (none)
    # and the 'column' result one entry per row, all unmatched (-1).
    graph = csr_matrix((2, 0))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    assert_array_equal(np.array([]), x)
    assert_array_equal(np.array([-1, -1]), y)


def test_maximum_bipartite_matching_empty_right_partition():
    # No rows, three columns: every column is reported unmatched.
    graph = csr_matrix((0, 3))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    assert_array_equal(np.array([-1, -1, -1]), x)
    assert_array_equal(np.array([]), y)


def test_maximum_bipartite_matching_graph_with_no_edges():
    graph = csr_matrix((2, 2))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    assert_array_equal(np.array([-1, -1]), x)
    assert_array_equal(np.array([-1, -1]), y)


def test_maximum_bipartite_matching_graph_that_causes_augmentation():
    # Using 0-based indices: row 1's only neighbour is column 0, so a
    # matching that first pairs column 0 with row 0 must be augmented
    # (row 0 moves to column 1) to make room for row 1.
    graph = csr_matrix([[1, 1], [1, 0]])
    x = maximum_bipartite_matching(graph, perm_type='column')
    y = maximum_bipartite_matching(graph, perm_type='row')
    expected_matching = np.array([1, 0])
    assert_array_equal(expected_matching, x)
    assert_array_equal(expected_matching, y)


def test_maximum_bipartite_matching_graph_with_more_rows_than_columns():
    graph = csr_matrix([[1, 1], [1, 0], [0, 1]])
    x = maximum_bipartite_matching(graph, perm_type='column')
    y = maximum_bipartite_matching(graph, perm_type='row')
    assert_array_equal(np.array([0, -1, 1]), x)
    assert_array_equal(np.array([0, 2]), y)


def test_maximum_bipartite_matching_graph_with_more_columns_than_rows():
    graph = csr_matrix([[1, 1, 0], [0, 0, 1]])
    x = maximum_bipartite_matching(graph, perm_type='column')
    y = maximum_bipartite_matching(graph, perm_type='row')
    assert_array_equal(np.array([0, 2]), x)
    assert_array_equal(np.array([0, -1, 1]), y)


def test_maximum_bipartite_matching_explicit_zeros_count_as_edges():
    # Both stored values are 0, yet they must still be treated as edges.
    data = [0, 0]
    indices = [1, 0]
    indptr = [0, 1, 2]
    graph = csr_matrix((data, indices, indptr), shape=(2, 2))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    expected_matching = np.array([1, 0])
    assert_array_equal(expected_matching, x)
    assert_array_equal(expected_matching, y)


def test_maximum_bipartite_matching_feasibility_of_result():
    # This is a regression test for GitHub issue #11458
    data = np.ones(50, dtype=int)
    indices = [11, 12, 19, 22, 23, 5, 22, 3, 8, 10, 5, 6, 11, 12, 13, 5, 13,
               14, 20, 22, 3, 15, 3, 13, 14, 11, 12, 19, 22, 23, 5, 22, 3, 8,
               10, 5, 6, 11, 12, 13, 5, 13, 14, 20, 22, 3, 15, 3, 13, 14]
    indptr = [0, 5, 7, 10, 10, 15, 20, 22, 22, 23, 25, 30, 32, 35, 35, 40, 45,
              47, 47, 48, 50]
    graph = csr_matrix((data, indices, indptr), shape=(20, 25))
    x = maximum_bipartite_matching(graph, perm_type='row')
    y = maximum_bipartite_matching(graph, perm_type='column')
    assert (x != -1).sum() == 13
    assert (y != -1).sum() == 13
    # Ensure that each element of the matching is in fact an edge in the
    # graph. ``y[row]`` is the column matched to ``row`` and ``x[col]`` is
    # the row matched to ``col``; -1 marks an unmatched vertex.
    for row, col in enumerate(y):
        if col != -1:
            assert graph[row, col]
    for col, row in enumerate(x):
        if row != -1:
            assert graph[row, col]


def _permutation_matrix(rows, cols, n):
    """Return an ``n x n`` CSR matrix with a 1 at each ``(rows[k], cols[k])``."""
    return coo_matrix((np.ones(n, dtype=int), (rows, cols)),
                      shape=(n, n)).tocsr()


def test_matching_large_random_graph_with_one_edge_incident_to_each_vertex():
    n = 25
    # A local generator reproduces the permutations of the former global
    # ``np.random.seed(42)`` without mutating numpy's global random state.
    rng = np.random.RandomState(42)
    identity = np.arange(n)
    A = diags(np.ones(n), offsets=0, format='csr')
    Rmat = _permutation_matrix(identity, rng.permutation(n), n)
    Cmat = _permutation_matrix(rng.permutation(n), identity, n)
    # Randomly permute the rows and columns of the identity matrix, so B has
    # exactly one edge incident to every row and every column.
    B = Rmat @ A @ Cmat

    # Row permute: perm[j] is the row matched to column j, so
    # C1[j, j] == B[perm[j], j], which must be an edge.
    perm = maximum_bipartite_matching(B, perm_type='row')
    C1 = _permutation_matrix(identity, perm, n) @ B

    # Column permute: perm2[i] is the column matched to row i, so
    # C2[i, i] == B[i, perm2[i]], which must be an edge.
    perm2 = maximum_bipartite_matching(B, perm_type='column')
    C2 = B @ _permutation_matrix(perm2, identity, n)

    # Each permutation must move an edge onto every diagonal position.
    assert_equal(np.any(C1.diagonal() == 0), False)
    assert_equal(np.any(C2.diagonal() == 0), False)


@pytest.mark.parametrize('num_rows,num_cols', [(0, 0), (2, 0), (0, 3)])
def test_min_weight_full_matching_trivial_graph(num_rows, num_cols):
    # When either side of the graph is empty, the result is an empty matching.
    biadjacency_matrix = csr_matrix((num_rows, num_cols))
    row_ind, col_ind = min_weight_full_bipartite_matching(biadjacency_matrix)
    assert len(row_ind) == 0
    assert len(col_ind) == 0


@pytest.mark.parametrize('biadjacency_matrix', [
    [[1, 1, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 1, 1], [0, 0, 1], [0, 0, 1]],
    [[1, 0, 0], [2, 0, 0]],
    [[0, 1, 0], [0, 2, 0]],
    [[1, 0], [2, 0], [5, 0]],
])
def test_min_weight_full_matching_infeasible_problems(biadjacency_matrix):
    # In every graph above the smaller side cannot be fully matched, so the
    # solver must reject the problem.
    with pytest.raises(ValueError):
        min_weight_full_bipartite_matching(csr_matrix(biadjacency_matrix))


def test_explicit_zero_causes_warning():
    # (data, indices, indptr) form: row 0 stores an explicit 0 at column 1.
    # Constructed outside ``warns`` so only the solver call can emit it.
    biadjacency_matrix = csr_matrix(((2, 0, 3), (0, 1, 1), (0, 2, 3)))
    with pytest.warns(UserWarning):
        min_weight_full_bipartite_matching(biadjacency_matrix)


# General test for linear sum assignment solvers to make it possible to rely
# on the same tests for scipy.optimize.linear_sum_assignment.
def linear_sum_assignment_assertions(
    solver, array_type, sign, test_case
):
    """Check ``solver`` on ``test_case`` and on its transpose.

    Parameters
    ----------
    solver : callable
        Called as ``solver(cost_matrix, maximize=...)`` and expected to
        return ``(row_ind, col_ind)`` with ``row_ind`` in sorted order.
    array_type : callable
        Converts the nested-list cost matrix into the solver's input type
        (e.g. ``csr_matrix``).
    sign : {-1, 1}
        Multiplier applied to all costs. ``-1`` negates them and runs the
        solver with ``maximize=True``; the chosen costs must then be the
        negated expected costs.
    test_case : tuple
        ``(cost_matrix, expected_cost)``, where ``expected_cost`` lists the
        selected costs in row order.
    """
    cost_matrix, expected_cost = test_case
    maximize = sign == -1
    cost_matrix = sign * array_type(cost_matrix)
    expected_cost = sign * np.array(expected_cost)

    row_ind, col_ind = solver(cost_matrix, maximize=maximize)
    assert_array_equal(row_ind, np.sort(row_ind))
    assert_array_equal(expected_cost,
                       np.array(cost_matrix[row_ind, col_ind]).flatten())

    # The transposed problem picks the same costs, but in a different row
    # order, so only the sorted costs are comparable.
    cost_matrix = cost_matrix.T
    row_ind, col_ind = solver(cost_matrix, maximize=maximize)
    assert_array_equal(row_ind, np.sort(row_ind))
    # Flatten before sorting: the indexed result may be 2-D (hence the
    # ``.flatten()``), and sorting must act on the selected costs themselves.
    chosen_costs = np.array(cost_matrix[row_ind, col_ind]).flatten()
    assert_array_equal(np.sort(expected_cost), np.sort(chosen_costs))


# Materialised as a list: ``product`` returns a one-shot iterator, and this
# object is meant to be shared with other test modules, each of which
# iterates it when parametrizing.
linear_sum_assignment_test_cases = list(product(
    [-1, 1],
    [
        # Square
        ([[400, 150, 400],
          [400, 450, 600],
          [300, 225, 300]],
         [150, 400, 300]),

        # Rectangular variant
        ([[400, 150, 400, 1],
          [400, 450, 600, 2],
          [300, 225, 300, 3]],
         [150, 2, 300]),

        # Square
        ([[10, 10, 8],
          [9, 8, 1],
          [9, 7, 4]],
         [10, 1, 7]),

        # Rectangular variant
        ([[10, 10, 8, 11],
          [9, 8, 1, 1],
          [9, 7, 4, 10]],
         [10, 1, 4]),

        # Square with positive infinities
        ([[10, float("inf"), float("inf")],
          [float("inf"), float("inf"), 1],
          [float("inf"), 7, float("inf")]],
         [10, 1, 7]),
    ],
))


@pytest.mark.parametrize('sign,test_case', linear_sum_assignment_test_cases)
def test_min_weight_full_matching_small_inputs(sign, test_case):
    linear_sum_assignment_assertions(
        min_weight_full_bipartite_matching, csr_matrix, sign, test_case)