1## Copyright (C) 2014 Nir Krakauer
2##
3## This program is free software; you can redistribute it and/or modify
4## it under the terms of the GNU General Public License as published by
5## the Free Software Foundation; either version 3 of the License, or
6## (at your option) any later version.
7##
8## This program is distributed in the hope that it will be useful,
9## but WITHOUT ANY WARRANTY; without even the implied warranty of
10## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11## GNU General Public License for more details.
12##
13## You should have received a copy of the GNU General Public License
14## along with this program; If not, see <http://www.gnu.org/licenses/>.
15
16## -*- texinfo -*-
17## @deftypefn{Function File}{@var{C} =} cvpartition (@var{X}, [@var{partition_type}, [@var{k}]])
18## Create a partition object for cross validation.
19##
## @var{X} may be a positive integer, interpreted as the number of values @var{n} to partition, or a vector of length @var{n} containing class designations for the elements, in which case the partitioning types @samp{KFold} and @samp{HoldOut} attempt to ensure each partition represents the classes proportionately.
21##
22## @var{partition_type} must be one of the following:
23##
24## @table @asis
25## @item @samp{KFold}
## Divide set into @var{k} subsets of (nearly) equal size (this is the default, with @var{k}=10).
27## @item @samp{HoldOut}
28## Divide set into two subsets, "training" and "validation". If @var{k} is a fraction, that is the fraction of values put in the validation subset; if it is a positive integer, that is the number of values in the validation subset (by default @var{k}=0.1).
29## @item @samp{LeaveOut}
30## Leave-one-out partition (each element is placed in its own subset).
31## @item @samp{resubstitution}
32## Training and validation subsets that both contain all the original elements.
33## @item @samp{Given}
34## Subset indices are as given in @var{X}.
35## @end table
36##
37## The following fields are defined for the @samp{cvpartition} class:
38##
39## @table @asis
40## @item @samp{classes}
41## Class designations for the elements.
42## @item @samp{inds}
43## Subset indices for the elements.
44## @item @samp{n_classes}
45## Number of different classes.
46## @item @samp{NumObservations}
47## @var{n}, number of elements in data set.
48## @item @samp{NumTestSets}
49## Number of testing subsets.
50## @item @samp{TestSize}
51## Number of elements in (each) testing subset.
52## @item @samp{TrainSize}
53## Number of elements in (each) training subset.
54## @item @samp{Type}
55## Partition type.
56## @end table
57##
58## @seealso{crossval}
59## @end deftypefn
60
61## Author: Nir Krakauer
62
function C = cvpartition (X, partition_type = 'KFold', k = [])

  ## X is either the number of observations (scalar) or a vector holding one
  ## class designation / subset index per observation.
  if (nargin < 1 || nargin > 3 || ! isvector (X))
    print_usage ();
  endif

  if (isscalar (X))
    n = X;
    n_classes = 1;
  else
    n = numel (X);
  endif

  ## Unrecognized partition types fall back to KFold (with a warning).
  switch (tolower (partition_type))
    case {'kfold' 'holdout' 'leaveout' 'resubstitution' 'given'}
    otherwise
      warning ('unrecognized type, using KFold')
      partition_type = 'KFold';
  endswitch
  ptype = tolower (partition_type);

  ## 'Given' takes its subset indices from X, so a scalar count is not enough
  ## (previously this failed later with an undefined-variable error).
  if (strcmp (ptype, 'given') && isscalar (X))
    error ("cvpartition: the 'Given' partition type requires X to be a vector of subset indices");
  endif

  ## For label vectors, map each element to a class number j (1..n_classes)
  ## and count the elements in each class; only these types use the labels.
  if (! isscalar (X) && any (strcmp (ptype, {'kfold' 'holdout' 'given'})))
    [~, ~, j] = unique (X(:));
    n_per_class = accumarray (j, 1);
    n_classes = numel (n_per_class);
  endif

  C = struct ("classes", [], "inds", [], "n_classes", [], "NumObservations", [], "NumTestSets", [], "TestSize", [], "TrainSize", [], "Type", []);
  ## The non-Matlab fields classes, inds, n_classes are only useful for some methods

  switch (ptype)
    case 'kfold'
      if (isempty (k))
        k = 10;
      endif
      ## A non-integer k would produce ceil(k) folds while reporting k of them.
      if (! (isscalar (k) && k >= 1 && k == fix (k)))
        error ("cvpartition: K must be a positive integer for KFold partitions");
      endif
      if (n_classes == 1)
        ## Assign consecutive runs of elements to folds 1..k.
        inds = floor ((0:(n-1))' * (k / n)) + 1;
      else
        ## Stratify: spread the elements of each class evenly over the folds.
        inds = nan (n, 1);
        for i = 1:n_classes
          ## Alternate the ordering over classes so that the subsets are more
          ## nearly the same size.
          if (mod (i, 2))
            inds(j == i) = floor ((0:(n_per_class(i)-1))' * (k / n_per_class(i))) + 1;
          else
            inds(j == i) = floor (((n_per_class(i)-1):-1:0)' * (k / n_per_class(i))) + 1;
          endif
        endfor
      endif
      C.inds = inds;
      C.NumTestSets = k;
      ## Count elements per fold over all k folds, so that a fold left empty
      ## (e.g. when k > n) still gets an entry of 0 and the size vectors have
      ## exactly NumTestSets entries.
      n_per_subset = accumarray (inds, 1, [k, 1]);
      C.TrainSize = n - n_per_subset;
      C.TestSize = n_per_subset;
    case 'given'
      ## Each distinct value of X defines one test subset.
      C.inds = j;
      C.NumTestSets = n_classes;
      C.TrainSize = n - n_per_class;
      C.TestSize = n_per_class;
    case 'holdout'
      if (isempty (k))
        k = 0.1;
      endif
      if (k < 1)
        f = k;              # target fraction to sample
        k = round (k * n);  # number of samples
      else
        f = k / n;
      endif
      inds = false (n, 1);  # true marks membership in the test set
      if (n_classes == 1)
        inds(randsample (n, k)) = true;
      else
        ## Sample fraction f from each class, then adjust the total to k.
        k_check = 0;
        for i = 1:n_classes
          ki = round (f * n_per_class(i));
          inds(find (j == i)(randsample (n_per_class(i), ki))) = true;
          k_check += ki;
        endfor
        if (k_check < k)
          ## add random elements to the test set to make its size k
          inds(find (! inds)(randsample (n - k_check, k - k_check))) = true;
        elseif (k_check > k)
          ## remove random elements from the test set
          inds(find (inds)(randsample (k_check, k_check - k))) = false;
        endif
        C.classes = j;
      endif
      C.n_classes = n_classes;
      C.TrainSize = n - k;
      C.TestSize = k;
      C.NumTestSets = 1;
      C.inds = inds;
    case 'leaveout'
      ## Each element in turn forms a single-element test set; the other
      ## n-1 elements form the corresponding training set.
      C.TrainSize = (n-1) * ones (n, 1);
      C.TestSize = ones (n, 1);
      C.NumTestSets = n;
    case 'resubstitution'
      ## Training and test sets both contain every element.
      C.TrainSize = C.TestSize = n;
      C.NumTestSets = 1;
  endswitch

  C.NumObservations = n;
  C.Type = ptype;

  C = class (C, "cvpartition");

endfunction
170
171
%!demo
%! # Partition with Fisher iris dataset (n = 150)
%! # Stratified by species
%! # NOTE(review): assumes fisheriris.txt is on the load path and that its
%! # first column holds the species labels -- TODO confirm
%! load fisheriris.txt
%! y = fisheriris(:, 1);
%! # 10-fold cross-validation partition
%! c = cvpartition (y, 'KFold', 10)
%! # hold-out partition with 10 elements in the test set
%! c1 = cvpartition (y, 'HoldOut', 10)
%! # test-set and training-set selections for fold 2 of c
%! idx1 = test (c, 2);
%! idx2 = training (c, 2);
%! # re-draw the hold-out partition c1 (another 10-element test set)
%! c2 = repartition (c1)
# Uncomment to plot the fold index assigned to each element by partition c:
#plot(struct(c).inds, '*')
186
187