1# (C) Copyright 2005-2020 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only under
6# the conditions described in the aforementioned license. The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8#
9# Thanks for using Enthought open source!
10
11import unittest
12
13from traits.trait_list_object import TraitList
14from traits.testing.api import UnittestTools
15from traits.testing.optional_dependencies import numpy as np, requires_numpy
16
17from pyface.data_view.abstract_data_model import DataViewSetError
18from pyface.data_view.abstract_value_type import AbstractValueType
19from pyface.data_view.value_types.api import (
20    FloatValue, IntValue, TextValue, no_value
21)
22from pyface.data_view.data_models.data_accessors import (
23    AttributeDataAccessor, IndexDataAccessor, KeyDataAccessor
24)
25from pyface.data_view.data_models.row_table_data_model import RowTableDataModel
26
27
class DataItem:
    """ Plain row object whose attributes ``a``, ``b`` and ``c`` are read
    and written by the attribute accessors under test.
    """

    def __init__(self, a, b, c):
        # Stored verbatim; no validation or conversion is performed.
        self.a, self.b, self.c = a, b, c
34
35
class TestRowTableDataModel(UnittestTools, unittest.TestCase):
    """ Tests for RowTableDataModel backed by a list of DataItem rows.

    The fixture model shows attribute ``a`` in the row header and
    attributes ``b`` and ``c`` as columns.  The most recent
    ``values_changed`` / ``structure_changed`` events fired by the model
    are captured in ``values_changed_event`` / ``structure_changed_event``.
    """

    def setUp(self):
        super().setUp()
        self.data = [
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ]
        self.model = self._create_model(data=self.data)
        self.values_changed_event = None
        self.structure_changed_event = None
        self.model.observe(self.model_values_changed, 'values_changed')
        self.model.observe(self.model_structure_changed, 'structure_changed')

    def tearDown(self):
        self.model.observe(
            self.model_values_changed, 'values_changed', remove=True)
        self.model.observe(
            self.model_structure_changed, 'structure_changed', remove=True)
        self.values_changed_event = None
        self.structure_changed_event = None
        super().tearDown()

    # Helpers ---------------------------------------------------------------

    def _create_model(self, **traits):
        """ Create a model with the fixture's row header and columns.

        Any keyword arguments (e.g. ``data``) are passed to the
        RowTableDataModel constructor ahead of the accessor traits, matching
        the argument order used by the original fixture.
        """
        return RowTableDataModel(
            **traits,
            row_header_data=AttributeDataAccessor(
                attr='a',
                value_type=IntValue(),
            ),
            column_data=[
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                ),
            ],
        )

    def _set_trait_list_data(self):
        """ Replace the model's data with a TraitList of fresh DataItems.

        Using a TraitList lets the tests mutate the data in place and check
        how the model reports those item-level changes.
        """
        self.model.data = TraitList([
            DataItem(a=i, b=10*i, c=str(i)) for i in range(10)
        ])

    def model_values_changed(self, event):
        self.values_changed_event = event

    def model_structure_changed(self, event):
        self.structure_changed_event = event

    # Tests -----------------------------------------------------------------

    def test_no_data(self):
        model = RowTableDataModel()
        self.assertEqual(model.get_column_count(), 0)
        self.assertTrue(model.can_have_children(()))
        self.assertEqual(model.get_row_count(()), 0)

    def test_get_column_count(self):
        result = self.model.get_column_count()
        self.assertEqual(result, 2)

    def test_can_have_children(self):
        for row in self.model.iter_rows():
            with self.subTest(row=row):
                result = self.model.can_have_children(row)
                # Only the root row () has children in a flat table.
                self.assertEqual(result, len(row) == 0)

    def test_get_row_count(self):
        for row in self.model.iter_rows():
            with self.subTest(row=row):
                result = self.model.get_row_count(row)
                if len(row) == 0:
                    self.assertEqual(result, 10)
                else:
                    self.assertEqual(result, 0)

    def test_get_value(self):
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                result = self.model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, 'A')
                elif len(row) == 0:
                    attr = self.model.column_data[column[0]].attr
                    self.assertEqual(result, attr.title())
                elif len(column) == 0:
                    self.assertEqual(result, row[0])
                else:
                    attr = self.model.column_data[column[0]].attr
                    self.assertEqual(
                        result,
                        getattr(self.data[row[0]], attr)
                    )

    def test_set_value(self):
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                if len(row) == 0:
                    # The corner cell and the column titles are read-only.
                    with self.assertRaises(DataViewSetError):
                        self.model.set_value(row, column, 0)
                else:
                    if len(column) == 0:
                        value = 6.0 * row[0]
                        attr = self.model.row_header_data.attr
                    else:
                        value = 6.0 * row[0] + 2 * column[0]
                        attr = self.model.column_data[column[0]].attr
                    with self.assertTraitChanges(
                            self.model, "values_changed"):
                        self.model.set_value(row, column, value)
                    self.assertEqual(
                        getattr(self.data[row[0]], attr),
                        value,
                    )
                    self.assertEqual(
                        self.values_changed_event.new,
                        (row, column, row, column)
                    )

    def test_get_value_type(self):
        for row, column in self.model.iter_items():
            with self.subTest(row=row, column=column):
                result = self.model.get_value_type(row, column)
                if len(row) == 0 and len(column) == 0:
                    expected = self.model.row_header_data.title_type
                elif len(row) == 0:
                    expected = self.model.column_data[column[0]].title_type
                elif len(column) == 0:
                    expected = self.model.row_header_data.value_type
                else:
                    expected = self.model.column_data[column[0]].value_type
                self.assertIsInstance(result, AbstractValueType)
                self.assertIs(result, expected)

    def test_data_updated(self):
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.data = [
                DataItem(a=i+1, b=20*(i+1), c=str(i)) for i in range(10)
            ]
        self.assertTrue(self.structure_changed_event.new)

    def test_data_items_updated_item_added(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.data += [DataItem(a=100, b=200, c="a string")]
        self.assertTrue(self.structure_changed_event.new)

    def test_data_items_updated_item_replaced(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1] = DataItem(a=100, b=200, c="a string")
        self.assertEqual(self.values_changed_event.new, ((1,), (), (1,), ()))

    def test_data_items_updated_item_replaced_negative(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[-2] = DataItem(a=100, b=200, c="a string")
        # Negative indices are reported as their normalized positive row.
        self.assertEqual(self.values_changed_event.new, ((8,), (), (8,), ()))

    def test_data_items_updated_items_replaced(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1:3] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        self.assertEqual(self.values_changed_event.new, ((1,), (), (2,), ()))

    def test_data_items_updated_slice_replaced(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[1:4:2] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        # An extended slice is reported as the span covering rows 1..3.
        self.assertEqual(self.values_changed_event.new, ((1,), (), (3,), ()))

    def test_data_items_updated_reverse_slice_replaced(self):
        self._set_trait_list_data()
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.data[3:1:-1] = [
                DataItem(a=100, b=200, c="a string"),
                DataItem(a=200, b=300, c="another string"),
            ]
        # A reversed slice is still reported in ascending row order.
        self.assertEqual(self.values_changed_event.new, ((2,), (), (3,), ()))

    def test_row_header_data_updated(self):
        with self.assertTraitChanges(self.model, "values_changed"):
            self.model.row_header_data = AttributeDataAccessor(attr='b')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (), (), ())
        )

    def test_row_header_data_values_updated(self):
        header = self.model.row_header_data
        last_row = len(self.data) - 1
        with self.assertTraitChanges(self.model, "values_changed"):
            header.updated = (header, 'value')
        self.assertEqual(
            self.values_changed_event.new,
            ((0,), (), (last_row,), ())
        )

    def test_row_header_data_title_updated(self):
        header = self.model.row_header_data
        with self.assertTraitChanges(self.model, "values_changed"):
            header.updated = (header, 'title')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (), (), ())
        )

    def test_no_data_row_header_data_update(self):
        model = self._create_model()

        # check that updating accessors is safe with empty data
        with self.assertTraitDoesNotChange(model, 'values_changed'):
            model.row_header_data.attr = 'b'

    def test_column_data_updated(self):
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.column_data = [
                AttributeDataAccessor(
                    attr='c',
                    value_type=TextValue(),
                ),
                AttributeDataAccessor(
                    attr='b',
                    value_type=IntValue(),
                ),
            ]
        self.assertTrue(self.structure_changed_event.new)

    def test_column_data_items_updated(self):
        with self.assertTraitChanges(self.model, "structure_changed"):
            self.model.column_data.pop()
        self.assertTrue(self.structure_changed_event.new)

    def test_column_data_value_updated(self):
        accessor = self.model.column_data[0]
        last_row = len(self.data) - 1
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'value')
        self.assertEqual(
            self.values_changed_event.new,
            ((0,), (0,), (last_row,), (0,))
        )

    def test_no_data_column_data_update(self):
        model = self._create_model()

        # check that updating accessors is safe with empty data
        with self.assertTraitDoesNotChange(model, 'values_changed'):
            model.column_data[0].attr = 'a'

    def test_column_data_title_updated(self):
        accessor = self.model.column_data[0]
        with self.assertTraitChanges(self.model, "values_changed"):
            accessor.updated = (accessor, 'title')
        self.assertEqual(
            self.values_changed_event.new,
            ((), (0,), (), (0,))
        )

    def test_list_tuple_data(self):
        data = [
            (i, 10*i, str(i)) for i in range(10)
        ]
        model = RowTableDataModel(
            data=data,
            row_header_data=IndexDataAccessor(
                index=0,
                value_type=IntValue(),
            ),
            column_data=[
                IndexDataAccessor(
                    index=1,
                    value_type=IntValue(),
                ),
                IndexDataAccessor(
                    index=2,
                    value_type=TextValue(),
                ),
            ]
        )

        for row, column in model.iter_items():
            with self.subTest(row=row, column=column):
                result = model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, '0')
                elif len(row) == 0:
                    index = model.column_data[column[0]].index
                    self.assertEqual(result, str(index))
                elif len(column) == 0:
                    self.assertEqual(result, row[0])
                else:
                    index = model.column_data[column[0]].index
                    self.assertEqual(
                        result,
                        data[row[0]][index]
                    )

    def test_list_dict_data(self):
        data = [
            {'a': i, 'b': 10*i, 'c': str(i)} for i in range(10)
        ]
        model = RowTableDataModel(
            data=data,
            row_header_data=KeyDataAccessor(
                key='a',
                value_type=IntValue(),
            ),
            column_data=[
                KeyDataAccessor(
                    key='b',
                    value_type=IntValue(),
                ),
                KeyDataAccessor(
                    key='c',
                    value_type=TextValue(),
                ),
            ]
        )

        for row, column in model.iter_items():
            with self.subTest(row=row, column=column):
                result = model.get_value(row, column)
                if len(row) == 0 and len(column) == 0:
                    self.assertEqual(result, 'A')
                elif len(row) == 0:
                    key = model.column_data[column[0]].key
                    self.assertEqual(result, str(key).title())
                elif len(column) == 0:
                    self.assertEqual(result, data[row[0]]['a'])
                else:
                    key = model.column_data[column[0]].key
                    self.assertEqual(
                        result,
                        data[row[0]][key]
                    )
430