1## Copyright (C) 2020-2021 Stefano Guidoni <ilguido@users.sf.net>
2##
3## This program is free software; you can redistribute it and/or modify it under
4## the terms of the GNU General Public License as published by the Free Software
5## Foundation; either version 3 of the License, or (at your option) any later
6## version.
7##
8## This program is distributed in the hope that it will be useful, but WITHOUT
9## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
11## details.
12##
13## You should have received a copy of the GNU General Public License along with
14## this program; if not, see <http://www.gnu.org/licenses/>.
15
16##
17## Author: Stefano Guidoni <ilguido@users.sf.net>
18
19classdef ConfusionMatrixChart < handle
20
21  ## -*- texinfo -*-
22  ## @deftypefn {} {@var{p} =} ConfusionMatrixChart ()
23  ## Create object @var{p}, a Confusion Matrix Chart object.
24  ##
25  ## @table @asis
26  ## @item @qcode{"DiagonalColor"}
27  ## The color of the patches on the diagonal, default is [0.0, 0.4471, 0.7412].
28  ##
29  ## @item @qcode{"OffDiagonalColor"}
30  ## The color of the patches off the diagonal, default is [0.851, 0.3255, 0.098].
31  ##
32  ## @item @qcode{"GridVisible"}
33  ## Available values: @qcode{on} (default), @qcode{off}.
34  ##
35  ## @item @qcode{"Normalization"}
36  ## Available values: @qcode{absolute} (default), @qcode{column-normalized},
37  ## @qcode{row-normalized}, @qcode{total-normalized}.
38  ##
39  ## @item @qcode{"ColumnSummary"}
  ## Available values: @qcode{off} (default), @qcode{absolute},
  ## @qcode{column-normalized}, @qcode{total-normalized}.
42  ##
43  ## @item @qcode{"RowSummary"}
44  ## Available values: @qcode{off} (default), @qcode{absolute},
45  ## @qcode{row-normalized}, @qcode{total-normalized}.
46  ## @end table
47  ##
  ## MATLAB compatibility -- the following properties are not implemented:
  ## FontColor, PositionConstraint, InnerPosition, Layout.
50  ##
51  ## @seealso{confusionchart}
52  ## @end deftypefn
53
  ## Publicly readable and writable properties.  Each of them has a set
  ## method in this class that validates the new value before it is used.
  properties (Access = public)
    ## text properties
    ## labels of the two axes and title of the chart
    XLabel = "Predicted Class";
    YLabel = "True Class";
    Title  = "";

    ## font settings; the constructor replaces these defaults with the
    ## "fontname" and "fontsize" values read from the axes
    FontName  = "";
    FontSize  = 0;

    ## chart colours
    ## RGB triplets; the set methods also accept a colour name
    DiagonalColor = [0 0.4471 0.7412];
    OffDiagonalColor = [0.8510 0.3255 0.0980];

    ## data visualization
    ## how the cell values and the two optional summaries are displayed
    Normalization = "absolute";
    ColumnSummary = "off";
    RowSummary = "off";

    ## "on" or "off"
    GridVisible = "on";

    ## the set methods of the following properties forward the value to the
    ## property of the same name of the underlying axes
    HandleVisibility = "";
    OuterPosition = [];
    Position = [];
    Units = "";
  endproperties
79
  ## Properties that can be read from outside the class, but are only
  ## written by the methods of the class.
  properties (GetAccess = public, SetAccess = private)
    ClassLabels = {}; # a string cell array of classes
    NormalizedValues = []; # the normalized confusion matrix
    Parent = 0; # a handle to the parent object
  endproperties

  ## Internal state of the chart.
  properties (Access = protected)
    hax = 0.0; # a handle to the axes
    ClassN = 0; # the number of classes, i.e. the rows of the matrix
    AbsoluteValues = []; # the original confusion matrix
    ## the summary vectors are filled by the draw method: the diagonal values
    ## come first, followed by the off-diagonal totals
    ColumnSummaryAbsoluteValues = []; # default values of the column summary
    RowSummaryAbsoluteValues = []; # default values of the row summary
  endproperties
93
94  methods (Access = public)
95    ## class constructor
96    ## inputs: axis handle, a confusion matrix, a list of class labels,
97    ##         an array of optional property-value pairs.
98    function this = ConfusionMatrixChart (hax, cm, cl, args)
99      ## class initialization
100      this.hax = hax;
101      this.Parent = get (this.hax, "parent");
102      this.ClassLabels = cl;
103      this.NormalizedValues = cm;
104      this.AbsoluteValues = cm;
105      this.ClassN = rows (cm);
106      this.FontName = get (this.hax, "fontname");
107      this.FontSize = get (this.hax, "fontsize");
108
109      set (this.hax, "xlabel", this.XLabel);
110      set (this.hax, "ylabel", this.YLabel);
111
112      ## draw the chart
113      draw (this);
114
115      ## apply paired properties
116      if (! isempty (args))
117        pair_idx = 1;
118        while (pair_idx < length (args))
119          switch (args{pair_idx})
120            case "XLabel"
121              this.XLabel = args{pair_idx + 1};
122            case "YLabel"
123              this.YLabel = args{pair_idx + 1};
124            case "Title"
125              this.Title = args{pair_idx + 1};
126            case "FontName"
127              this.FontName = args{pair_idx + 1};
128            case "FontSize"
129              this.FontSize = args{pair_idx + 1};
130            case "DiagonalColor"
131              this.DiagonalColor = args{pair_idx + 1};
132            case "OffDiagonalColor"
133              this.OffDiagonalColor = args{pair_idx + 1};
134            case "Normalization"
135              this.Normalization = args{pair_idx + 1};
136            case "ColumnSummary"
137              this.ColumnSummary = args{pair_idx + 1};
138            case "RowSummary"
139              this.RowSummary = args{pair_idx + 1};
140            case "GridVisible"
141              this.GridVisible = args{pair_idx + 1};
142            case "HandleVisibility"
143              this.HandleVisibility = args{pair_idx + 1};
144            case "OuterPosition"
145              this.OuterPosition = args{pair_idx + 1};
146            case "Position"
147              this.Position = args{pair_idx + 1};
148            case "Units"
149              this.Units = args{pair_idx + 1};
150          otherwise
151              close (this.Parent);
152              error ("confusionchart: invalid property %s", args{pair_idx});
153          endswitch
154
155          pair_idx += 2;
156        endwhile
157      endif
158
159      ## init the color map
160      updateColorMap (this);
161    endfunction
162
163    ## set functions
164    function set.XLabel (this, string)
165      if (! ischar (string))
166        close (this.Parent);
167        error ("confusionchart: XLabel must be a string.");
168      endif
169
170      this.XLabel = updateAxesProperties (this, "xlabel", string);
171    endfunction
172
173    function set.YLabel (this, string)
174      if (! ischar (string))
175        close (this.Parent);
176        error ("confusionchart: YLabel must be a string.");
177      endif
178
179      this.YLabel = updateAxesProperties (this, "ylabel", string);
180    endfunction
181
182    function set.Title (this, string)
183      if (! ischar (string))
184        close (this.Parent);
185        error ("confusionchart: Title must be a string.");
186      endif
187
188      this.Title = updateAxesProperties (this, "title", string);
189    endfunction
190
191    function set.FontName (this, string)
192      if (! ischar (string))
193        close (this.Parent);
194        error ("confusionchart: FontName must be a string.");
195      endif
196
197      this.FontName = updateTextProperties (this, "fontname", string);
198    endfunction
199
200    function set.FontSize (this, value)
201      if (! isnumeric (value))
202        close (this.Parent);
203        error ("confusionchart: FontSize must be numeric.");
204      endif
205
206      this.FontSize = updateTextProperties (this, "fontsize", value);
207    endfunction
208
209    function set.DiagonalColor (this, color)
210      if (ischar (color))
211        color = this.convertNamedColor (color);
212      endif
213
214      if (! (isvector (color) && length (color) == 3 ))
215        close (this.Parent);
216        error ("confusionchart: DiagonalColor must be a color.");
217      endif
218
219      this.DiagonalColor = color;
220      updateColorMap (this);
221    endfunction
222
223    function set.OffDiagonalColor (this, color)
224      if (ischar (color))
225        color = this.convertNamedColor (color);
226      endif
227
228      if (! (isvector (color) && length (color) == 3))
229        close (this.Parent);
230        error ("confusionchart: OffDiagonalColor must be a color.");
231      endif
232
233      this.OffDiagonalColor = color;
234      updateColorMap (this);
235    endfunction
236
237    function set.Normalization (this, string)
238      if (! any (strcmp (string, {"absolute", "column-normalized",...
239        "row-normalized", "total-normalized"})))
240        close (this.Parent);
241        error ("confusionchart: invalid value for Normalization.");
242      endif
243
244      this.Normalization = string;
245      updateChart (this);
246    endfunction
247
248    function set.ColumnSummary (this, string)
249      if (! any (strcmp (string, {"off", "absolute", "column-normalized",...
250        "total-normalized"})))
251        close (this.Parent);
252        error ("confusionchart: invalid value for ColumnSummary.");
253      endif
254
255      this.ColumnSummary = string;
256      updateChart (this);
257    endfunction
258
259    function set.RowSummary (this, string)
260      if (! any (strcmp (string, {"off", "absolute", "row-normalized",...
261        "total-normalized"})))
262        close (this.Parent);
263        error ("confusionchart: invalid value for RowSummary.");
264      endif
265
266      this.RowSummary = string;
267      updateChart (this);
268    endfunction
269
270    function set.GridVisible (this, string)
271      if (! any (strcmp (string, {"off", "on"})))
272        close (this.Parent);
273        error ("confusionchart: invalid value for GridVisible.");
274      endif
275
276      this.GridVisible = string;
277      setGridVisibility (this);
278    endfunction
279
280    function set.HandleVisibility (this, string)
281      if (! any (strcmp (string, {"off", "on", "callback"})))
282        close (this.Parent);
283        error ("confusionchart: invalid value for HandleVisibility");
284      endif
285
286      set (this.hax, "handlevisibility", string);
287    endfunction
288
289    function set.OuterPosition (this, vector)
290      if (! isvector (vector) || ! isnumeric (vector) || length (vector) != 4)
291        close (this.Parent);
292        error ("confusionchart: invalid value for OuterPosition");
293      endif
294
295      set (this.hax, "outerposition", vector);
296    endfunction
297
298    function set.Position (this, vector)
299      if (! isvector (vector) || ! isnumeric (vector) || length (vector) != 4)
300        close (this.Parent);
301        error ("confusionchart: invalid value for Position");
302      endif
303
304      set (this.hax, "position", vector);
305    endfunction
306
307    function set.Units (this, string)
308      if (! any (strcmp (string, {"centimeters", "characters", "inches", ...
309                                  "normalized", "pixels", "points"})))
310        close (this.Parent);
311        error ("confusionchart: invalid value for Units");
312      endif
313
314      set (this.hax, "units", string);
315    endfunction
316
317    ## display method
318    ## MATLAB compatibility, this tries to mimic the MATLAB behaviour
319    function disp (this)
320      nv_sizes = size (this.NormalizedValues);
321      cl_sizes = size (this.ClassLabels);
322
323      printf ("%s with properties:\n\n", class (this));
324      printf ("\tNormalizedValues: [ %dx%d %s ]\n", nv_sizes(1), nv_sizes(2),...
325        class (this.NormalizedValues));
326      printf ("\tClassLabels: { %dx%d %s }\n\n", cl_sizes(1), cl_sizes(2),...
327        class (this.ClassLabels));
328    endfunction
329
330    ## sortClasses
331    ## reorder the chart
332    function sortClasses (this, order)
333      ## -*- texinfo -*-
334      ## @deftypefn  {} {} sortClasses (@var{cm},@var{order})
335      ## Sort the classes of the @code{ConfusionMatriChart} object @var{cm}
336      ## according to @var{order}.
337      ##
338      ## Valid values for @var{order} can be an array or cell array including
339      ## the same class labels as @var{cm}, or a value like @code{'auto'},
340      ## @code{'ascending-diagonal'}, @code{'descending-diagonal'} and
341      ## @code{'cluster'}.
342      ##
343      ## @end deftypefn
344      ##
345      ## @seealso{confusionchart, linkage, pdist}
346
347      ## check the input parameters
348      if (nargin != 2)
349        print_usage ();
350      endif
351
352      cl = this.ClassLabels;
353      cm_size = this.ClassN;
354      nv = this.NormalizedValues;
355      av = this.AbsoluteValues;
356      cv = this.ColumnSummaryAbsoluteValues;
357      rv = this.RowSummaryAbsoluteValues;
358
359      scl = {};
360      Idx = [];
361
362      if (strcmp (order, "auto"))
363        [scl, Idx] = sort (cl);
364      elseif (strcmp (order, "ascending-diagonal"))
365        [s, Idx] = sort (diag (nv));
366        scl = cl(Idx);
367      elseif (strcmp (order, "descending-diagonal"))
368        [s, Idx] = sort (diag (nv));
369        Idx = flip (Idx);
370        scl = cl(Idx);
371      elseif (strcmp (order, "cluster"))
372        ## the classes are all grouped together
373        ## this way one can visually evaluate which are the most similar classes
374        ## according to the learning algorithm
375        D = zeros (1, ((cm_size - 1) * cm_size / 2)); # a pdist like vector
376        maxD = 2 * max (max (av));
377        k = 1; # better than computing the index at every cycle
378        for i = 1 : (cm_size - 1)
379          for j = (i + 1) : cm_size
380            D(k++) = maxD - (av(i, j) + av(j, i)); # distance
381          endfor
382        endfor
383        tree = linkage (D, "average"); # clustering
384        ## we could have optimal leaf ordering with
385        Idx = optimalleaforder (tree, D); # optimal clustering
386        ## [sorted_v Idx] = sort (cluster (tree, ));
387        nodes_to_visit = 2 * cm_size - 1;
388        nodecount = 0;
389        while (! isempty (nodes_to_visit))
390          current_node = nodes_to_visit(1);
391          nodes_to_visit(1) = [];
392          if (current_node > cm_size)
393            node = current_node - cm_size;
394            nodes_to_visit = [tree(node,[2 1]) nodes_to_visit];
395          end
396
397          if (current_node <= cm_size)
398            nodecount++;
399            Idx(nodecount) = current_node;
400          end
401        end
402        ##
403        scl = cl(Idx);
404      else
405        ## must be an array or cell array of labels
406        if (! iscellstr (order))
407          if (! ischar (order))
408            if (isrow (order))
409              order = vec (order);
410            endif
411            order = num2str (order);
412          endif
413
414          scl = cellstr (order);
415        endif
416
417        if (length (scl) != length (cl))
418          error ("sortClasses: wrong size for order.")
419        endif
420
421        Idx = zeros (length (scl), 1);
422
423        for i = 1 : length (scl)
424          Idx(i) = find (strcmp (cl, scl{i}));
425        endfor
426      endif
427
428      ## rearrange the normalized values...
429      nv = nv(Idx, :);
430      nv = nv(:, Idx);
431      this.NormalizedValues = nv;
432
433      ## ...and the absolute values...
434      av = av(Idx, :);
435      av = av(:, Idx);
436      this.AbsoluteValues = av;
437
438      cv = cv([Idx ( Idx + cm_size )]);
439      this.ColumnSummaryAbsoluteValues = cv;
440
441      rv = rv([Idx ( Idx + cm_size )]);
442      this.RowSummaryAbsoluteValues = rv;
443
444      ## ...and the class labels
445      this.ClassLabels = scl;
446
447      ## update the axes
448      set (this.hax, "xtick", (0.5 : 1 : (cm_size - 0.5)), "xticklabel", scl,...
449          "ytick", (0.5 : 1 : (cm_size - 0.5)), "yticklabel", scl);
450
451      ## get text and patch handles
452      kids = get (this.hax, "children");
453      t_kids = kids(find (isprop (kids, "fontname"))); # hack to find texts
454      m_kid = kids(find (strcmp (get (kids, "userdata"), "MainChart")));
455      c_kid = kids(find (strcmp (get (kids, "userdata"), "ColumnSummary")));
456      r_kid = kids(find (strcmp (get (kids, "userdata"), "RowSummary")));
457
458      ## re-assign colors to the main chart
459      cdata_m = reshape (get (m_kid, "cdata"), cm_size, cm_size);
460      cdata_m = cdata_m(Idx, :);
461      cdata_m = cdata_m(:, Idx);
462
463      cdata_v = vec (cdata_m);
464
465      set (m_kid, "cdata", cdata_v);
466
467      ## re-assign colors to the column summary
468      cdata_m = reshape (transpose (get (c_kid, "cdata")), cm_size, 2);
469      cdata_m = cdata_m(Idx, :);
470
471      cdata_v = vec (cdata_m);
472
473      set (c_kid, "cdata", cdata_v);
474
475      ## re-assign colors to the row summary
476      cdata_m = reshape (get (r_kid, "cdata"), cm_size, 2);
477      cdata_m = cdata_m(Idx, :);
478
479      cdata_v = vec (cdata_m);
480
481      set (r_kid, "cdata", cdata_v);
482
483      ## move the text labels
484      for i = 1:length (t_kids)
485        t_pos = get (t_kids(i), "userdata");
486
487        if (t_pos(2) > cm_size)
488          ## row summary
489          t_pos(1) = find (Idx == (t_pos(1) + 1)) - 1;
490          set (t_kids(i), "userdata", t_pos);
491
492          t_pos = t_pos([2 1]) + 0.5;
493          set (t_kids(i), "position", t_pos);
494        elseif (t_pos(1) > cm_size)
495          ## column summary
496          t_pos(2) = find (Idx == (t_pos(2) + 1)) - 1;
497          set (t_kids(i), "userdata", t_pos);
498
499          t_pos = t_pos([2 1]) + 0.5;
500          set (t_kids(i), "position", t_pos);
501        else
502          ## main chart
503          t_pos(1) = find (Idx == (t_pos(1) + 1)) - 1;
504          t_pos(2) = find (Idx == (t_pos(2) + 1)) - 1;
505          set (t_kids(i), "userdata", t_pos);
506
507          t_pos = t_pos([2 1]) + 0.5;
508          set (t_kids(i), "position", t_pos);
509        endif
510      endfor
511
512      updateChart (this);
513    endfunction
514  endmethods
515
516  methods (Access = private)
517    ## convertNamedColor
518    ## convert a named colour to a colour triplet
519    function ret = convertNamedColor (this, color)
520      vColorNames = ["ymcrgbwk"]';
521      vColorTriplets = [1 1 0; 1 0 1; 0 1 1; 1 0 0; 0 1 0; 0 0 1; 1 1 1; 0 0 0];
522      if (strcmp (color, "black"))
523        color = 'k';
524      endif
525
526      index = find (vColorNames == color(1));
527      if (! isempty (index))
528        ret = vColorTriplets(index, :);
529      else
530        ret = []; # trigger an error message
531      endif
532    endfunction
533
534    ## updateAxesProperties
535    ## update the properties of the axes
536    function ret = updateAxesProperties (this, prop, value)
537      set (this.hax, prop, value);
538
539      ret = value;
540    endfunction
541
542    ## updateTextProperties
543    ## set the properties of the texts
544    function ret = updateTextProperties (this, prop, value)
545      hax_kids = get (this.hax, "children");
546      text_kids = hax_kids(isprop (hax_kids , "fontname")); # hack to find texts
547      text_kids(end + 1) = get (this.hax, "xlabel");
548      text_kids(end + 1) = get (this.hax, "ylabel");
549      text_kids(end + 1) = get (this.hax, "title");
550
551      updateAxesProperties (this, prop, value);
552      set (text_kids, prop, value);
553
554      ret = value;
555    endfunction
556
557    ## setGridVisibility
558    ## toggle the visibility of the grid
559    function setGridVisibility (this)
560      kids = get (this.hax, "children");
561      kids = kids(find (isprop (kids, "linestyle")));
562
563      if (strcmp (this.GridVisible, "on"))
564        set (kids, "linestyle", "-");
565      else
566        set (kids, "linestyle", "none");
567      endif
568    endfunction
569
570    ## updateColorMap
571    ## change the colormap and, accordingly, the text colors
572    function updateColorMap (this)
573      cm_size = this.ClassN;
574      d_color = this.DiagonalColor;
575      o_color = this.OffDiagonalColor;
576
577      ## quick hack
578      d_color(find (d_color == 1.0)) = 0.999;
579      o_color(find (o_color == 1.0)) = 0.999;
580
581      ## 64 shades for each color
582      cm_colormap(1:64,:) = [1.0 : (-(1.0 - o_color(1)) / 63) : o_color(1);...
583        1.0 : (-(1.0 - o_color(2)) / 63) : o_color(2);...
584        1.0 : (-(1.0 - o_color(3)) / 63) : o_color(3)]';
585      cm_colormap(65:128,:) = [1.0 : (-(1.0 - d_color(1)) / 63) : d_color(1);...
586        1.0 : (-(1.0 - d_color(2)) / 63) : d_color(2);...
587        1.0 : (-(1.0 - d_color(3)) / 63) : d_color(3)]';
588
589      colormap (this.hax, cm_colormap);
590
591      ## update text colors
592      kids = get (this.hax, "children");
593      t_kids = kids(find (isprop (kids, "fontname"))); # hack to find texts
594      m_patch = kids(find (strcmp (get (kids, "userdata"), "MainChart")));
595      c_patch = kids(find (strcmp (get (kids, "userdata"), "ColumnSummary")));
596      r_patch = kids(find (strcmp (get (kids, "userdata"), "RowSummary")));
597
598      m_colors = get (m_patch, "cdata");
599      c_colors = get (c_patch, "cdata");
600      r_colors = get (r_patch, "cdata");
601
602      ## when a patch is dark, let's use a pale color for the text
603      for i = 1 : length (t_kids)
604        t_pos = get (t_kids(i), "userdata");
605        color_idx = 1;
606
607        if (t_pos(2) > cm_size)
608          ## row summary
609          idx = (t_pos(2) - cm_size - 1) * cm_size + t_pos(1) + 1;
610          color_idx = r_colors(idx) + 1;
611        elseif (t_pos(1) > cm_size)
612          ## column summary
613          idx = (t_pos(1) - cm_size - 1) * cm_size + t_pos(2) + 1;
614          color_idx = c_colors(idx) + 1;
615        else
616          ## main chart
617          idx = t_pos(2) * cm_size + t_pos(1) + 1;
618          color_idx = m_colors(idx) + 1;
619        endif
620
621        if (sum (cm_colormap(color_idx, :)) < 1.8)
622          set (t_kids(i), "color", [.97 .97 1.0]);
623        else
624          set (t_kids(i), "color", [.15 .15 .15]);
625        endif
626      endfor
627    endfunction
628
629    ## updateChart
630    ## update the text labels and the NormalizedValues property
631    function updateChart (this)
632      cm_size = this.ClassN;
633      cm = this.AbsoluteValues;
634      l_cs = this.ColumnSummaryAbsoluteValues;
635      l_rs = this.RowSummaryAbsoluteValues;
636
637      kids = get (this.hax, "children");
638      t_kids = kids(find (isprop (kids, "fontname"))); # hack to find texts
639
640      normalization = this.Normalization;
641      column_summary = this.ColumnSummary;
642      row_summary = this.RowSummary;
643
644      ## normalization for labelling
645      row_totals = sum (cm, 2);
646      col_totals = sum (cm, 1);
647      mat_total = sum (col_totals);
648      cm_labels = cm;
649      add_percent = true;
650
651      if (strcmp (normalization, "column-normalized"))
652        for i = 1 : cm_size
653          cm_labels(:,i) = cm_labels(:,i) ./ col_totals(i);
654        endfor
655      elseif (strcmp (normalization, "row-normalized"))
656        for i = 1 : cm_size
657          cm_labels(i,:) = cm_labels(i,:) ./ row_totals(i);
658        endfor
659      elseif (strcmp (normalization, "total-normalized"))
660        cm_labels = cm_labels ./ mat_total;
661      else
662        add_percent = false;
663      endif
664
665      ## update NormalizedValues
666      this.NormalizedValues = cm_labels;
667
668      ## update axes
669      last_row = cm_size;
670      last_col = cm_size;
671      userdata = cell2mat (get (t_kids, "userdata"));
672
673      cs_kids = t_kids(find (userdata(:,1) > cm_size));
674      cs_kids(end + 1) = kids(find (strcmp (get (kids, "userdata"),...
675          "ColumnSummary")));
676
677      if (! strcmp ("off", column_summary))
678        set (cs_kids, "visible", "on");
679        last_row += 3;
680      else
681        set (cs_kids, "visible", "off");
682      endif
683
684      rs_kids = t_kids(find (userdata(:,2) > cm_size));
685      rs_kids(end + 1) = kids(find (strcmp (get (kids, "userdata"),...
686          "RowSummary")));
687
688      if (! strcmp ("off", row_summary))
689        set (rs_kids, "visible", "on");
690        last_col += 3;
691      else
692        set (rs_kids, "visible", "off");
693      endif
694
695      axis (this.hax, [0 last_col 0 last_row]);
696
697      ## update column summary data
698      cs_add_percent = true;
699      if (! strcmp (column_summary, "off"))
700        if (strcmp (column_summary, "column-normalized"))
701          for i = 1 : cm_size
702            if (col_totals(i) == 0)
703              ## avoid division by zero
704              l_cs([i (cm_size + i)]) = 0;
705            else
706              l_cs([i, cm_size + i]) = l_cs([i, cm_size + i]) ./ col_totals(i);
707            endif
708          endfor
709        elseif strcmp (column_summary, "total-normalized")
710          l_cs = l_cs ./ mat_total;
711        else
712          cs_add_percent = false;
713        endif
714      endif
715
716      ## update row summary data
717      rs_add_percent = true;
718      if (! strcmp (row_summary, "off"))
719        if (strcmp (row_summary, "row-normalized"))
720          for i = 1 : cm_size
721            if (row_totals(i) == 0)
722              ## avoid division by zero
723              l_rs([i (cm_size + i)]) = 0;
724            else
725              l_rs([i, cm_size + i]) = l_rs([i, cm_size + i]) ./ row_totals(i);
726            endif
727          endfor
728        elseif (strcmp (row_summary, "total-normalized"))
729          l_rs = l_rs ./ mat_total;
730        else
731          rs_add_percent = false;
732        endif
733      endif
734
735      ## update text
736      label_list = vec (cm_labels);
737
738      for i = 1 : length (t_kids)
739        t_pos = get (t_kids(i), "userdata");
740        new_string = "";
741
742        if (t_pos(2) > cm_size)
743          ## this is the row summary
744          idx = (t_pos(2) - cm_size - 1) * cm_size + t_pos(1) + 1;
745
746          if (rs_add_percent)
747            new_string = num2str (100.0 * l_rs(idx), "%3.1f");
748            new_string = [new_string "%"];
749          else
750            new_string = num2str (l_rs(idx));
751          endif
752        elseif (t_pos(1) > cm_size)
753          ## this is the column summary
754          idx = (t_pos(1) - cm_size - 1) * cm_size + t_pos(2) + 1;
755
756          if (cs_add_percent)
757            new_string = num2str (100.0 * l_cs(idx), "%3.1f");
758            new_string = [new_string "%"];
759          else
760            new_string = num2str (l_cs(idx));
761          endif
762        else
763          ## this is the main chart
764          idx = t_pos(2) * cm_size + t_pos(1) + 1;
765
766          if (add_percent)
767            new_string = num2str (100.0 * label_list(idx), "%3.1f");
768            new_string = [new_string "%"];
769          else
770            new_string = num2str (label_list(idx));
771          endif
772        endif
773
774        set (t_kids(i), "string", new_string);
775      endfor
776
777    endfunction
778
779    ## draw
780    ## draw the chart
781    function draw (this)
782      cm = this.AbsoluteValues;
783      cl = this.ClassLabels;
784      cm_size = this.ClassN;
785
786      ## set up the axes
787      set (this.hax, "xtick", (0.5 : 1 : (cm_size - 0.5)), "xticklabel",  cl,...
788          "ytick", (0.5 : 1 : (cm_size - 0.5)), "yticklabel",  cl );
789      axis ("ij");
790      axis (this.hax, [0 cm_size 0 cm_size]);
791
792      ## prepare the patches
793      indices_b = 0 : (cm_size -1);
794      indices_v = repmat (indices_b, cm_size, 1);
795      indices_vx = transpose (vec (indices_v));
796      indices_vy = vec (indices_v', 2);
797      indices_ex = vec ((cm_size + 1) * [1; 2] .* ones (2, cm_size), 2);
798
799      ## normalization for colorization
800      ## it is used a colormap of 128 shades of two colors, 64 shades for each
801      ## color
802      normal = max (max (cm));
803      cm_norm = round (63 * cm ./ normal);
804      cm_norm = cm_norm + 64 * eye (cm_size);
805
806      ## default normalization: absolute
807      cm_labels = vec (cm);
808
809      ## the patches of the main chart
810      x_patch = [indices_vx;
811                ( indices_vx + 1 );
812                ( indices_vx + 1 );
813                indices_vx];
814      y_patch = [indices_vy;
815                indices_vy;
816                ( indices_vy + 1 );
817                ( indices_vy + 1 )];
818      c_patch = vec (cm_norm(1 : cm_size, 1 : cm_size));
819
820      ## display the patches
821      ph = patch (this.hax, x_patch, y_patch, c_patch);
822
823      set (ph, "userdata", "MainChart");
824
825      ## display the labels
826      userdata = [indices_vy; indices_vx]';
827      nonzero_idx = find (cm_labels != 0);
828      th = text ((x_patch(1, nonzero_idx) + 0.5), (y_patch(1, nonzero_idx) +...
829          0.5), num2str (cm_labels(nonzero_idx)), "parent", this.hax );
830
831      set (th, "horizontalalignment", "center");
832      for i = 1 : length (nonzero_idx)
833        set (th(i), "userdata", userdata(nonzero_idx(i), :));
834      endfor
835
836      ## patches for the summaries
837      main_values = diag (cm);
838      ct_values = sum (cm)';
839      rt_values = sum (cm, 2);
840      cd_values = ct_values - main_values;
841      rd_values = rt_values - main_values;
842
843      ## column summary
844      x_cs = [[indices_b indices_b];
845              ( [indices_b indices_b] + 1 );
846              ( [indices_b indices_b] + 1 );
847              [indices_b indices_b]];
848      y_cs = [(repmat ([1 1 2 2]', 1, cm_size)) (repmat ([2 2 3 3]', 1, cm_size))] +...
849          cm_size;
850      c_cs = [(round (63 * (main_values ./ ct_values)) + 64);
851              (round (63 * (cd_values ./ ct_values)))];
852      c_cs(isnan (c_cs)) = 0;
853      l_cs = [main_values; cd_values];
854
855      ph = patch (this.hax, x_cs, y_cs, c_cs);
856
857      set (ph, "userdata", "ColumnSummary");
858      set (ph, "visible", "off" );
859
860      userdata = [y_cs(1,:); x_cs(1,:)]';
861      nonzero_idx = find (l_cs != 0);
862      th = text ((x_cs(1,nonzero_idx) + 0.5), (y_cs(1,nonzero_idx) + 0.5),...
863          num2str (l_cs(nonzero_idx)), "parent", this.hax);
864
865      set (th, "horizontalalignment", "center");
866      for i = 1 : length (nonzero_idx)
867        set (th(i), "userdata", userdata(nonzero_idx(i), :));
868      endfor
869      set (th, "visible", "off");
870
871      ## row summary
872      x_rs = y_cs;
873      y_rs = x_cs;
874      c_rs = [(round (63 * (main_values ./ rt_values)) + 64);
875              (round (63 * (rd_values ./ rt_values)))];
876      c_rs(isnan (c_rs)) = 0;
877      l_rs = [main_values; rd_values];
878
879      ph = patch (this.hax, x_rs, y_rs, c_rs);
880
881      set (ph, "userdata", "RowSummary");
882      set (ph, "visible", "off");
883
884      userdata = [y_rs(1,:); x_rs(1,:)]';
885      nonzero_idx = find (l_rs != 0);
886      th = text ((x_rs(1,nonzero_idx) + 0.5), (y_rs(1,nonzero_idx) + 0.5),...
887          num2str (l_rs(nonzero_idx)), "parent", this.hax);
888
889      set (th, "horizontalalignment", "center");
890      for i = 1 : length (nonzero_idx)
891        set (th(i), "userdata", userdata(nonzero_idx(i), :));
892      endfor
893      set (th, "visible", "off");
894
895      this.ColumnSummaryAbsoluteValues = l_cs;
896      this.RowSummaryAbsoluteValues = l_rs;
897    endfunction
898  endmethods
899
900endclassdef
901
902