1## Copyright (C) 2016 - Juan Pablo Carbajal
2##
## This program is free software; you can redistribute it and/or modify
4## it under the terms of the GNU General Public License as published by
5## the Free Software Foundation; either version 3 of the License, or
6## (at your option) any later version.
7##
8## This program is distributed in the hope that it will be useful,
9## but WITHOUT ANY WARRANTY; without even the implied warranty of
10## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11## GNU General Public License for more details.
12##
13## You should have received a copy of the GNU General Public License
14## along with this program. If not, see <http://www.gnu.org/licenses/>.
15
16## Author: Juan Pablo Carbajal <ajuanpi+dev@gmail.com>
17
18## -*- texinfo -*-
19## @defun {@var{h} =} violin (@var{x})
20## @defunx {@var{h} =} violin (@dots{}, @var{property}, @var{value}, @dots{})
21## @defunx {@var{h} =} violin (@var{hax}, @dots{})
22## @defunx {@var{h} =} violin (@dots{}, @asis{"horizontal"})
23## Produce a Violin plot of the data @var{x}.
24##
## The input data @var{x} can be an N-by-m array containing N observations of m variables.
## It can also be a cell with m elements, for the case in which the variables
## are not uniformly sampled.
28##
29## The following @var{property} can be set using @var{property}/@var{value} pairs
30## (default values in parenthesis).
31## The value of the property can be a scalar indicating that it applies
32## to all the variables in the data.
33## It can also be a cell/array, indicating the property for each variable.
34## In this case it should have m columns (as many as variables).
35##
36## @table @asis
37##
38## @item Color
39## (@asis{"y"}) Indicates the filling color of the violins.
40##
41## @item Nbins
42## (50) Internally, the function calls @command{hist} to compute the histogram of the data.
43## This property indicates how many bins to use. See @command{help hist}
44## for more details.
45##
46## @item SmoothFactor
## (4) The function performs simple kernel density estimation and automatically
## finds the bandwidth of the kernel function that best approximates the histogram
## using optimization (@command{sqp}).
## The result is in general very noisy. To smooth the result the bandwidth is
## multiplied by the value of this property. The higher the value the smoother
## the violins, but values too high might remove features from the data distribution.
53##
54## @item Bandwidth
## (NA) If this property is given a value other than NA, it sets the bandwidth of the
## kernel function. No optimization is performed and the property @asis{SmoothFactor}
## is ignored.
58##
59## @item Width
## (0.5) Sets the maximum width of the violins. Violins are centered at integer axis
## values. The distance between the middle axes of two adjacent violins is 1. Setting a value
## higher than 1 in this property will cause the violins to overlap.
63## @end table
64##
65## If the string @asis{"Horizontal"} is among the input arguments, the violin
66## plot is rendered along the x axis with the variables in the y axis.
67##
68## The returned structure @var{h} has handles to the plot elements, allowing
69## customization of the visualization using set/get functions.
70##
71## Example:
72##
73## @example
74## title ("Grade 3 heights");
75## axis ([0,3]);
76## set (gca, "xtick", 1:2, "xticklabel", @{"girls"; "boys"@});
77## h = violin (@{randn(100,1)*5+140, randn(130,1)*8+135@}, "Nbins", 10);
78## set (h.violin, "linewidth", 2)
79## @end example
80##
81## @seealso{boxplot, hist}
82## @end defun
83
function h = violin (ax, varargin)
  ## Entry point: draw one violin per variable of the data on an axes.
  ##
  ## Call forms handled here:
  ##   violin (X, ...)        draw on the current axes
  ##   violin (HAX, X, ...)   draw on the axes HAX
  ## X is an N-by-m array (one variable per column), a vector (a single
  ## variable) or a cell with one sample vector per variable.
  ##
  ## Returns a struct with the fields violin, mean, median and quartile
  ## holding the graphics handles of the corresponding plot elements.
  ## Nothing is returned when the caller requests no output.

  if (nargin < 1)
    print_usage ();
  endif

  ## Decide whether the first argument is an axes handle or the data.
  ## isscalar is tested first so that ishandle never hands a non-scalar
  ## result to the short-circuit operator when the data is an array.
  if (isscalar (ax) && ishandle (ax))
    if (isempty (varargin))
      error ("Octave:invalid-input-arg", ...
             "violin: the data X must follow the axes handle");
    endif
    x = varargin{1};
    varargin(1) = [];
  else
    x  = ax;
    ax = gca ();
  endif

  ######################
  ## Parse parameters ##
  parser = inputParser ();
  parser.CaseSensitive = false;
  parser.FunctionName = 'violin';

  parser.addParamValue ('Nbins', 50);
  parser.addParamValue ('SmoothFactor', 4);
  parser.addParamValue ('Bandwidth', NA);
  parser.addParamValue ('Width', 0.5);
  parser.addParamValue ('Color', "y");
  parser.addSwitch ('Horizontal');

  parser.parse (varargin{:});
  res = parser.Results;

  c        = res.Color;        # Color of violins
  if (ischar (c))
    c = c(:);                  # one color letter per row
  endif
  nb       = res.Nbins;        # Number of bins in histogram
  sf       = res.SmoothFactor; # Smoothing factor for kernel estimation
  r0       = res.Bandwidth;    # User value for KDE bandwidth to prevent optimization
  is_horiz = res.Horizontal;   # Whether the plot must be rotated
  width    = res.Width;        # Width of the violins
  clear parser res
  ######################

  ## Make everything a 1-by-m cell of column vectors for code simplicity.
  ## All the cellfun calls below require identically shaped cells and
  ## column-shaped samples, whatever orientation the caller used.
  if (iscell (x))
    x = cellfun (@(v) v(:), x(:).', "UniformOutput", false);
  else
    if (isvector (x))
      x = x(:);                # a plain vector is a single variable
    endif
    x = num2cell (x, 1);
  endif
  Nc = numel (x);

  try
    [nb, c, sf, r0, width] = to_cell (nb, c, sf, r0, width, Nc);
  catch err
    if (strcmp (err.identifier, "to_cell:element_idx"))
      ## The message carries the position of the offending option.
      n   = sscanf (err.message, "%d");
      txt = {"Nbins", "Color", "SmoothFactor", "Bandwidth", "Width"};
      error ("Octave:invalid-input-arg", ...
             ["options should be scalars or cell/array with as many values as" ...
              " numbers of variables in the data (wrong size of %s)."], txt{n});
    else
      rethrow (err);
    endif
  end_try_catch

  ## Build violins
  [px, py, mx] = cellfun (@(y,n,s,r) build_polygon (y, n, s, r), ...
                          x, nb, sf, r0, "UniformOutput", false);

  pos  = 1:Nc;                 # violin k is centered at axis value k
  cpos = num2cell (pos);

  ## Remember the hold state of the target axes (not of the current axes)
  old_hold = ishold (ax);

  unwind_protect
    ## Draw plain violins: unit-width outline scaled by the requested
    ## width and shifted to the violin's position.
    h.violin = cellfun (@(x,y,n,u,w) patch (ax, (w * x + n)(:), y(:), u), ...
                        px, py, cpos, c, width);

    hold (ax, "on");

    ## Overlay mean value
    h.mean = cellfun (@(z,y) plot (ax, z, y, '.k', "markersize", 6), cpos, mx);

    ## Overlay median
    Mx       = cellfun (@median, x, "UniformOutput", false);
    h.median = cellfun (@(z,y) plot (ax, z, y, 'ok'), cpos, Mx);

    ## Overlay 1st and 3rd quartiles as error bars around the median
    LUBU = cellfun (@(x,y) abs (quantile (x, [0.25 0.75]) - y), ...
                    x, Mx, "UniformOutput", false);
    eb   = cellfun (@(x,y,z) errorbar (ax, x, y, z(1), z(2)), ...
                    cpos, Mx, LUBU)(:);

    ## Flatten errorbar output handles: each errorbar handle followed by
    ## its children.  Going through the handles one by one also works for
    ## a single violin, where allchild returns a plain vector, not a cell.
    grp        = arrayfun (@(g) [g; allchild(g)(:)], eb, ...
                           "UniformOutput", false);
    h.quartile = cell2mat (grp);
  unwind_protect_cleanup
    ## Restore the hold state, even if drawing failed half-way
    if (old_hold)
      hold (ax, "on");
    else
      hold (ax, "off");
    endif
  end_unwind_protect

  ## Rotate the plot if it is horizontal
  if (is_horiz)
    structfun (@swap_axes, h);
    set (ax, "ytick", pos);
  else
    set (ax, "xtick", pos);
  endif

  if (nargout < 1)
    clear h;
  endif

endfunction
200
function k = kde (x, r)
  ## Gaussian kernel density estimate, rescaled to a peak value of 1.
  ##
  ## X holds sample-minus-evaluation-point differences, one evaluation
  ## point per column; R is the kernel bandwidth.  The result is a row
  ## with one value per column of X.
  ##
  ## The kernel is evaluated inline instead of through stdnormal_pdf, so
  ## this helper does not depend on that function being installed.  The
  ## constant factors 1/sqrt(2*pi) and 1/r are dropped because they cancel
  ## in the normalization by the maximum.
  ##
  ## The average is taken explicitly along dimension 1 so that a 1-by-nb
  ## input (a single observation) is not collapsed to a scalar.
  k  = mean (exp (-0.5 * (x / r) .^ 2), 1);
  k /= max (k);
endfunction
205
function [px py mx] = build_polygon (x, nb, sf, r)
  ## Build the outline of one violin from the sample X.
  ##
  ## Inputs:
  ##   x   sample values (any vector orientation)
  ##   nb  bins argument handed to hist
  ##   sf  factor multiplying the optimized bandwidth
  ##   r   kernel bandwidth, or NA to fit it to the histogram
  ## Outputs:
  ##   px  M-by-2, right and (mirrored) left half of the outline, with a
  ##       peak value of 1
  ##   py  M-by-2, matching positions along the data axis, in data units
  ##   mx  mean of the sample

  x  = x(:);                   # the broadcasting below needs a column
  N  = numel (x);
  mx = mean (x);
  sx = std (x);
  X  = (x - mx) / sx;          # standardized sample

  [count bin] = hist (X, nb);
  count      /= max (count);

  if (isna (r))
    ## Fit the bandwidth so that the KDE evaluated at the bin centers
    ## matches the normalized histogram.  The starting point is Silverman's
    ## rule of thumb, 1.06 * sigma * N^(-1/5), with sigma = 1 because the
    ## sample is standardized.
    Y  = X - bin;              # N-by-nb differences to the bin centers
    r0 = 1.06 * N^(-1/5);
    r  = sqp (r0, @(r) sumsq (kde (Y, r) - count), [], [], 1e-3, 1e2);
  else
    ## A user supplied bandwidth is used as is
    sf = 1;
  endif
  sig = sf * r;

  ## Create violin polygon
  ## smooth tails: extend the bins by 1.83 bandwidths on each side
  xx  = linspace (0, 1.83 * sig, 5);
  bin = [bin(1)-fliplr(xx) bin bin(end)+xx];
  py  = [bin; fliplr(bin)].' * sx + mx;  # back to data units

  v  = kde (X - bin, sig).';
  px = [v -flipud(v)];

endfunction
234
function tf = swap_axes (h)
  ## Exchange the "xdata" and "ydata" properties of every handle in H.
  ## Always returns true, which lets the function be mapped over a struct
  ## with structfun.

  handles = mat2cell (h(:), ones (length (h), 1), 1);

  ## Phase 1: read the coordinates of all objects before changing any
  ydata = cellfun (@(g) get (g, "ydata"), handles, "UniformOutput", false);
  xdata = cellfun (@(g) get (g, "xdata"), handles, "UniformOutput", false);

  ## Phase 2: write them back with the roles exchanged
  for k = 1:numel (handles)
    set (handles{k}, "xdata", ydata{k}, "ydata", xdata{k});
  endfor

  tf = true;
endfunction
243
function varargout = to_cell (varargin)
  ## Expand option values into 1-by-m cells, one element per variable.
  ##
  ## The last argument is m, the number of variables.  Each preceding
  ## argument is converted independently:
  ##   * a scalar is replicated m times;
  ##   * a cell must have exactly m elements and is returned as a row;
  ##   * an array must have a dimension of length m.  It is split along
  ##     its columns, after being transposed when only its first
  ##     dimension has length m.
  ## On a size mismatch an error with identifier "to_cell:element_idx" is
  ## raised; its message is the position of the offending argument.

  m    = varargin{end};
  args = varargin(1:end-1);

  for i = 1:numel (args)
    x = args{i};
    if (isscalar (x))
      x = repmat (x, m, 1);
    endif

    if (iscell (x))
      if (numel (x) ~= m) # wrong number of elements
        error ("to_cell:element_idx", "%d\n",i);
      endif
      ## Always hand back a row, like the array branch below does; the
      ## caller combines the outputs in cellfun, which needs equal shapes.
      varargout{i} = x(:).';
      continue
    endif

    sz = size (x);
    d  = find (sz == m);
    if (isempty (d)) # no dimension equals m
      error ("to_cell:element_idx", "%d\n",i);
    elseif (length (d) == 2)
      # both dims are m: keep the orientation, columns are the elements
    elseif (d == 1) # only the 1st dimension is m --> transpose
      x  = x.';
      sz = fliplr (sz);
    endif

    varargout{i} = mat2cell (x, sz(1), ones (m,1));

  endfor

endfunction
277
278%!demo
279%! clf
280%! x = zeros (9e2, 10);
281%! for i=1:10
282%!   x(:,i) = (0.1 * randn (3e2, 3) * (randn (3,1) + 1) + ...
283%!          2 * randn (1,3))(:);
284%! endfor
285%! h = violin (x, "color", "c");
286%! axis tight
287%! set (h.violin, "linewidth", 2);
288%! set (gca, "xgrid", "on");
289%! xlabel ("Variables")
290%! ylabel ("Values")
291
292%!demo
293%! clf
294%! data = {randn(100,1)*5+140, randn(130,1)*8+135};
295%! subplot (1,2,1)
296%! title ("Grade 3 heights - vertical");
297%! set (gca, "xtick", 1:2, "xticklabel", {"girls"; "boys"});
298%! violin (data, "Nbins", 10);
299%! axis tight
300%!
301%! subplot(1,2,2)
302%! title ("Grade 3 heights - horizontal");
303%! set (gca, "ytick", 1:2, "yticklabel", {"girls"; "boys"});
304%! violin (data, "horizontal", "Nbins", 10);
305%! axis tight
306
307%!demo
308%! clf
309%! data = exprnd (0.1, 500,4);
310%! violin (data, "nbins", {5,10,50,100});
311%! axis ([0 5 0 max(data(:))])
312
313%!demo
314%! clf
315%! data = exprnd (0.1, 500,4);
316%! violin (data, "color", jet(4));
317%! axis ([0 5 0 max(data(:))])
318
319%!demo
320%! clf
321%! data = repmat(exprnd (0.1, 500,1), 1, 4);
322%! violin (data, "width", linspace (0.1,0.5,4));
323%! axis ([0 5 0 max(data(:))])
324
325%!demo
326%! clf
327%! data = repmat(exprnd (0.1, 500,1), 1, 4);
328%! violin (data, "nbins", [5,10,50,100], "smoothfactor", [4 4 8 10]);
329%! axis ([0 5 0 max(data(:))])
330