# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

import unittest

from traits.trait_list_object import TraitList
from traits.testing.api import UnittestTools
from traits.testing.optional_dependencies import numpy as np, requires_numpy

from pyface.data_view.abstract_data_model import DataViewSetError
from pyface.data_view.abstract_value_type import AbstractValueType
from pyface.data_view.value_types.api import (
    FloatValue, IntValue, TextValue, no_value
)
from pyface.data_view.data_models.data_accessors import (
    AttributeDataAccessor, IndexDataAccessor, KeyDataAccessor
)
from pyface.data_view.data_models.row_table_data_model import RowTableDataModel


class DataItem:
    """ Plain object with attributes ``a``, ``b`` and ``c``, used as a row.
    """

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c


class TestRowTableDataModel(UnittestTools, unittest.TestCase):
    """ Tests for RowTableDataModel backed by a list of row objects.

    The default fixture has 10 ``DataItem`` rows, the ``a`` attribute as the
    row header and the ``b`` and ``c`` attributes as the two columns.
    """

    def setUp(self):
        super().setUp()
        self.data = [
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ]
        self.model = RowTableDataModel(
            data=self.data,
            row_header_data=AttributeDataAccessor(
                attr='a',
                value_type=IntValue(),
            ),
            column_data=[
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                )
            ]
        )
        # The most recent event of each kind is recorded so that tests can
        # inspect the payload after the assertTraitChanges context exits.
        self.values_changed_event = None
        self.structure_changed_event = None
        self.model.observe(self.model_values_changed, 'values_changed')
        self.model.observe(self.model_structure_changed, 'structure_changed')

    def tearDown(self):
        self.model.observe(
            self.model_values_changed, 'values_changed', remove=True)
        self.model.observe(
            self.model_structure_changed, 'structure_changed', remove=True)
        self.values_changed_event = None
        self.structure_changed_event = None
        super().tearDown()

    def model_values_changed(self, event):
        """ Record the latest ``values_changed`` event. """
        self.values_changed_event = event

    def model_structure_changed(self, event):
        """ Record the latest ``structure_changed`` event. """
        self.structure_changed_event = event

    def test_no_data(self):
        """ A default-constructed model has no columns and no rows. """
        model = RowTableDataModel()
        self.assertEqual(model.get_column_count(), 0)
        self.assertTrue(model.can_have_children(()))
        self.assertEqual(model.get_row_count(()), 0)

    def test_get_column_count(self):
        result = self.model.get_column_count()
        self.assertEqual(result, 2)

    def test_can_have_children(self):
        """ Only the root row ``()`` can have children. """
        for row in self.model.iter_rows():
            with self.subTest(row=row):
                result = self.model.can_have_children(row)
                if len(row) == 0:
                    self.assertEqual(result, True)
                else:
                    self.assertEqual(result, False)

    def test_get_row_count(self):
        """ The root has one child per data item; data rows have none. """
        for row in self.model.iter_rows():
            with self.subTest(row=row):
                result = self.model.get_row_count(row)
                if len(row) == 0:
                    self.assertEqual(result, 10)
                else:
                    self.assertEqual(result, 0)

    def test_get_value(self):
        """ Each cell kind reads its value through the matching accessor. """
        header_attr = self.model.row_header_data.attr
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                result = self.model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, 'A')
                elif len(row) == 0:
                    attr = self.model.column_data[column[0]].attr
                    self.assertEqual(result, attr.title())
                elif len(column) == 0:
                    # Compare against the data itself rather than the row
                    # index, so the check does not depend on ``a == index``.
                    self.assertEqual(
                        result,
                        getattr(self.data[row[0]], header_attr),
                    )
                else:
                    attr = self.model.column_data[column[0]].attr
                    self.assertEqual(
                        result,
                        getattr(self.data[row[0]], attr)
                    )

    def test_set_value(self):
        """ Header titles reject writes; data and row-header cells accept
        them and fire ``values_changed`` for exactly that cell.
        """
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                if len(row) == 0:
                    # Both the corner cell and the column titles raise.
                    with self.assertRaises(DataViewSetError):
                        self.model.set_value(row, column, 0)
                elif len(column) == 0:
                    value = 6.0 * row[0]
                    with self.assertTraitChanges(self.model, "values_changed"):
                        self.model.set_value(row, column, value)
                    self.assertEqual(self.data[row[0]].a, value)
                    self.assertEqual(
                        self.values_changed_event.new,
                        (row, column, row, column)
                    )
                else:
                    value = 6.0 * row[-1] + 2 * column[0]
                    with self.assertTraitChanges(self.model, "values_changed"):
                        self.model.set_value(row, column, value)
                    attr = self.model.column_data[column[0]].attr
                    self.assertEqual(
                        getattr(self.data[row[0]], attr),
                        value,
                    )
                    self.assertEqual(
                        self.values_changed_event.new,
                        (row, column, row, column)
                    )

    def test_get_value_type(self):
        """ Each cell kind uses the title or value type of its accessor. """
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                result = self.model.get_value_type(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertIsInstance(result, AbstractValueType)
                    self.assertIs(
                        result,
                        self.model.row_header_data.title_type,
                    )
                elif len(row) == 0:
                    self.assertIsInstance(result, AbstractValueType)
                    self.assertIs(
                        result,
                        self.model.column_data[column[0]].title_type,
                    )
                elif len(column) == 0:
                    self.assertIsInstance(result, AbstractValueType)
                    self.assertIs(
                        result,
                        self.model.row_header_data.value_type,
                    )
                else:
                    self.assertIsInstance(result, AbstractValueType)
                    self.assertIs(
                        result,
                        self.model.column_data[column[0]].value_type,
                    )

    def test_data_updated(self):
        """ Replacing ``data`` wholesale fires ``structure_changed``. """
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.data = [
                DataItem(a=i+1, b=20*(i+1), c=str(i)) for i in range(10)
            ]
        self.assertTrue(self.structure_changed_event.new)

    def test_data_items_updated_item_added(self):
        """ Growing a TraitList in place fires ``structure_changed``. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.data += [DataItem(a=100, b=200, c="a string")]
        self.assertTrue(self.structure_changed_event.new)

    def test_data_items_updated_item_replaced(self):
        """ Replacing one item fires ``values_changed`` for that row. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1] = DataItem(a=100, b=200, c="a string")
        self.assertEqual(self.values_changed_event.new, ((1,), (), (1,), ()))

    def test_data_items_updated_item_replaced_negative(self):
        """ A negative index is reported as the equivalent positive row. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[-2] = DataItem(a=100, b=200, c="a string")
        self.assertEqual(self.values_changed_event.new, ((8,), (), (8,), ()))

    def test_data_items_updated_items_replaced(self):
        """ Replacing a contiguous slice reports its first and last rows. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1:3] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        self.assertEqual(self.values_changed_event.new, ((1,), (), (2,), ()))

    def test_data_items_updated_slice_replaced(self):
        """ An extended slice reports the span from its first to last row. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1:4:2] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        self.assertEqual(self.values_changed_event.new, ((1,), (), (3,), ()))

    def test_data_items_updated_reverse_slice_replaced(self):
        """ A reversed slice is reported as an ascending row span. """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[3:1:-1] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        self.assertEqual(self.values_changed_event.new, ((2,), (), (3,), ()))

    def test_row_header_data_updated(self):
        """ Swapping the row-header accessor fires ``values_changed``. """
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.row_header_data = AttributeDataAccessor(attr='b')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (), (), ())
        )

    def test_row_header_data_values_updated(self):
        """ A 'value' update on the row header spans every data row. """
        accessor = self.model.row_header_data
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'value')
        self.assertEqual(
            self.values_changed_event.new,
            ((0,), (), (9,), ())
        )

    def test_row_header_data_title_updated(self):
        """ A 'title' update on the row header targets the corner cell. """
        accessor = self.model.row_header_data
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'title')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (), (), ())
        )

    def test_no_data_row_header_data_update(self):
        """ Changing the row-header accessor with no data fires nothing. """
        model = RowTableDataModel(
            row_header_data=AttributeDataAccessor(
                attr='a',
                value_type=IntValue(),
            ),
            column_data=[
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                )
            ]
        )

        # check that updating accessors is safe with empty data
        with self.assertTraitDoesNotChange(model, 'values_changed'):
            model.row_header_data.attr = 'b'

    def test_column_data_updated(self):
        """ Replacing ``column_data`` fires ``structure_changed``. """
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.column_data = [
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                ),
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
            ]
        self.assertTrue(self.structure_changed_event.new)

    def test_column_data_items_updated(self):
        """ Removing a column in place fires ``structure_changed``. """
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.column_data.pop()
        self.assertTrue(self.structure_changed_event.new)

    def test_column_data_value_updated(self):
        """ A 'value' update on a column spans every row of that column. """
        accessor = self.model.column_data[0]
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'value')
        self.assertEqual(
            self.values_changed_event.new,
            ((0,), (0,), (9,), (0,))
        )

    def test_no_data_column_data_update(self):
        """ Changing a column accessor with no data fires nothing. """
        model = RowTableDataModel(
            row_header_data=AttributeDataAccessor(
                attr='a',
                value_type=IntValue(),
            ),
            column_data=[
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                )
            ]
        )

        with self.assertTraitDoesNotChange(model, 'values_changed'):
            model.column_data[0].attr = 'a'

    def test_column_data_title_updated(self):
        """ A 'title' update on a column targets that column's header. """
        accessor = self.model.column_data[0]
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'title')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (0,), (), (0,))
        )

    def test_list_tuple_data(self):
        """ Rows given as tuples are read through IndexDataAccessor. """
        data = [
            (i, 10*i, str(i)) for i in range(10)
        ]
        model = RowTableDataModel(
            data=data,
            row_header_data=IndexDataAccessor(
                index=0,
                value_type=IntValue(),
            ),
            column_data=[
                IndexDataAccessor(
                    index=1,
                    value_type=IntValue(),
                ),
                IndexDataAccessor(
                    index=2,
                    value_type=TextValue(),
                )
            ]
        )

        header_index = model.row_header_data.index
        for row, column in model.iter_items():
            with self.subTest(row=row, column=column):
                result = model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, '0')
                elif len(row) == 0:
                    index = model.column_data[column[0]].index
                    self.assertEqual(result, str(index))
                elif len(column) == 0:
                    # Compare against the data itself rather than the row
                    # index, so the check does not depend on ``t[0] == index``.
                    self.assertEqual(result, data[row[0]][header_index])
                else:
                    index = model.column_data[column[0]].index
                    self.assertEqual(
                        result,
                        data[row[0]][index]
                    )

    def test_list_dict_data(self):
        """ Rows given as dicts are read through KeyDataAccessor. """
        data = [
            {'a': i, 'b': 10*i, 'c': str(i)} for i in range(10)
        ]
        model = RowTableDataModel(
            data=data,
            row_header_data=KeyDataAccessor(
                key='a',
                value_type=IntValue(),
            ),
            column_data=[
                KeyDataAccessor(
                    key='b',
                    value_type=IntValue(),
                ),
                KeyDataAccessor(
                    key='c',
                    value_type=TextValue(),
                )
            ]
        )

        for row, column in model.iter_items():
            with self.subTest(row=row, column=column):
                result = model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, 'A')
                elif len(row) == 0:
                    key = model.column_data[column[0]].key
                    self.assertEqual(result, str(key).title())
                elif len(column) == 0:
                    self.assertEqual(result, data[row[0]]['a'])
                else:
                    key = model.column_data[column[0]].key
                    self.assertEqual(
                        result,
                        data[row[0]][key]
                    )