1% STK_PREDICT [overload STK function]
2
3% Copyright Notice
4%
5%    Copyright (C) 2015-2018 CentraleSupelec
6%    Copyright (C) 2011-2014 SUPELEC
7%
8%    Authors:  Julien Bect       <julien.bect@centralesupelec.fr>
9%              Emmanuel Vazquez  <emmanuel.vazquez@centralesupelec.fr>
10
11% Copying Permission Statement
12%
13%    This file is part of
14%
15%            STK: a Small (Matlab/Octave) Toolbox for Kriging
16%               (http://sourceforge.net/projects/kriging)
17%
18%    STK is free software: you can redistribute it and/or modify it under
19%    the terms of the GNU General Public License as published by the Free
20%    Software Foundation,  either version 3  of the License, or  (at your
21%    option) any later version.
22%
23%    STK is distributed  in the hope that it will  be useful, but WITHOUT
24%    ANY WARRANTY;  without even the implied  warranty of MERCHANTABILITY
25%    or FITNESS  FOR A  PARTICULAR PURPOSE.  See  the GNU  General Public
26%    License for more details.
27%
28%    You should  have received a copy  of the GNU  General Public License
29%    along with STK.  If not, see <http://www.gnu.org/licenses/>.
30
function [zp, lambda, mu, K] = stk_predict (M_post, xt)
% Kriging prediction at the points xt, using the posterior model M_post.
%
% Inputs:
%    M_post   posterior model object.  This function reads its properties
%             prior_model, input_data, output_data and kreq.
%    xt       prediction points, one per row (converted using double).  If
%             the prior has covariance_type 'stk_discretecov' and xt is
%             empty, predictions are computed on all the points of the
%             underlying discrete space.
%
% Outputs:
%    zp       stk_dataframe with two columns, 'mean' and 'var', and one row
%             per prediction point.  The 'mean' column is filled with NaN
%             when M_post.output_data is empty.  The variances do NOT
%             include the noise variance, and negative values (numerical
%             inaccuracies) are replaced by zero, with a warning.
%    lambda   first n_obs rows of the 'lambda_mu' matrix obtained from the
%             kriging equation, where n_obs is the number of rows of
%             M_post.input_data (one column per prediction point).
%    mu       remaining rows of the same 'lambda_mu' matrix.
%    K        matrix K0 - 0.5 * (deltaK + deltaK'), where K0 is the prior
%             covariance matrix at xt and deltaK = lambda_mu' * RS.
%
% Errors: stk_error (..., 'IncorrectSize') is called if xt has more than
% two dimensions, or if its number of columns differs from M_prior.dim.

% TODO: these should become options
block_size = [];

M_prior = M_post.prior_model;

% Which optional outputs have been requested ?
need_lambda_mu = (nargout > 1);  % lambda, mu and K all need lambda_mu
need_RS = (nargout > 3);         % only K needs the full RS matrix

%--- Convert and check input arguments: xt -------------------------------------

xt = double (xt);
% FIXME: check variable names

if (strcmp (M_prior.covariance_type, 'stk_discretecov')) && (isempty (xt))
    % In this case, predict on all points of the underlying discrete space
    nt = size (M_prior.param.K, 1);
    xt = (1:nt)';
else
    nt = size (xt, 1);
    if ndims (xt) > 2
        stk_error (['The input argument xt should not have more than two ' ...
            'dimensions'], 'IncorrectSize');
    elseif ~ isequal (size (xt), [nt M_prior.dim])
        stk_error (sprintf (['The number of columns of xt (which is %d) ' ...
            'does not agree with the dimension of the model (which is ' ...
            '%d).'], size (xt, 2), M_prior.dim), 'IncorrectSize');
    end
end

%--- Prepare the output arguments ----------------------------------------------

zp_v = zeros (nt, 1);
compute_prediction = ~ isempty (M_post.output_data);

% compute the kriging prediction, or just the variances ?
if compute_prediction
    % Convert the observations only once (used in the main loop and for the
    % exact-prediction correction below)
    zi = double (M_post.output_data);
    zp_a = zeros (nt, 1);
else
    zp_a = nan (nt, 1);
end

%--- Choose nb_blocks & block_size ---------------------------------------------

n_obs = size (M_post.input_data, 1);

if isempty (block_size)
    % Default: choose the block size in such a way that an array of doubles
    % of size n_obs x block_size takes about MAX_RS_SIZE bytes
    MAX_RS_SIZE = 5e6;  SIZE_OF_DOUBLE = 8;  % in bytes
    block_size = ceil (MAX_RS_SIZE / (n_obs * SIZE_OF_DOUBLE));
end

if nt == 0
    % skip main loop
    nb_blocks = 0;
else
    % blocks of size approx. block_size
    nb_blocks = max (1, ceil (nt / block_size));
    block_size = ceil (nt / nb_blocks);
end

% The full lambda_mu matrix is only needed when nargout > 1
if need_lambda_mu
    lambda_mu = zeros (n_obs + get (M_post.kreq, 'r'), nt);
end

% The full RS matrix is only needed when nargout > 3
if need_RS
    RS = zeros (size (lambda_mu));
end

%--- MAIN LOOP (over blocks) ---------------------------------------------------

% TODO: this loop should be parallelized !!!

for block_num = 1:nb_blocks

    % compute the indices for the current block
    idx_beg = 1 + block_size * (block_num - 1);
    idx_end = min (nt, idx_beg + block_size - 1);
    idx = idx_beg:idx_end;

    % solve the kriging equation for the current block
    xt_ = xt(idx, :);
    kreq = stk_make_kreq (M_post, xt_);

    % compute the kriging mean
    if compute_prediction
        zp_a(idx) = (get (kreq, 'lambda'))' * zi;
    end

    % The full lambda_mu matrix is only needed when nargout > 1
    if need_lambda_mu
        lambda_mu(:, idx) = get (kreq, 'lambda_mu');
    end

    % The full RS matrix is only needed when nargout > 3
    if need_RS
        RS(:, idx) = get (kreq, 'RS');
    end

    % compute kriging variances (this does NOT include the noise variance)
    zp_v(idx) = stk_make_matcov (M_prior, xt_, xt_, true) ...
        - get (kreq, 'delta_var');

    % note: the following modification would compute prediction variances
    % for noisy observations, i.e., variances including the noise variance
    %    zp_v(idx) = stk_make_matcov (M_prior, xt_, [], true) ...
    %                 - get (kreq, 'delta_var');

    % Clip negative variances.  Only the current block has to be inspected:
    % the previous blocks have already been clipped, and the entries of the
    % next blocks are still equal to zero.
    b = (zp_v(idx) < 0);
    if any (b)
        zp_v(idx(b)) = 0.0;
        warning ('STK:stk_predict:NegativeVariancesSetToZero', sprintf ( ...
            ['Correcting numerical inaccuracies in kriging variance.\n' ...
            '(%d negative variances have been set to zero)'], sum (b)));
    end

end

%--- Ensure exact prediction at observation points for noiseless models --------

if ~ stk_isnoisy (M_prior)

    % FIXME: Fix the kreq object instead ?

    xi = double (M_post.input_data);

    % b(i) is true if xt(i, :) is one of the observation points, in which
    % case loc(i) is the index of a matching row in xi
    [b, loc] = ismember (xt, xi, 'rows');
    if any (b)

        if compute_prediction
            zp_a(b) = zi(loc(b));
        end

        zp_v(b) = 0.0;

        if need_lambda_mu
            % Put all the weight on the matching observation
            lambda_mu(:, b) = 0.0;
            lambda_mu(sub2ind (size (lambda_mu), loc(b), find (b))) = 1.0;
        end
    end
end


%--- Prepare outputs -----------------------------------------------------------

zp = stk_dataframe ([zp_a zp_v], {'mean' 'var'});

if nargout > 1 % lambda requested
    lambda = lambda_mu(1:n_obs, :);
end

if nargout > 2 % mu requested
    mu = lambda_mu((n_obs+1):end, :);
end

if nargout > 3
    K0 = stk_make_matcov (M_prior, xt, xt);
    deltaK = lambda_mu' * RS;
    % Subtract the symmetric part of deltaK, so that K is symmetric
    % whenever K0 is
    K = K0 - 0.5 * (deltaK + deltaK');
end

end % function
193
194%#ok<*SPWRN>
195
196
197%!shared n, m, M_post, M_prior, x0, x_obs, z_obs, x_prd, y_prd, idx_obs, idx_prd
198%!
199%! n = 10;     % number of observations
200%! m = n + 1;  % number of predictions
201%! d = 1;      % dimension of the input space
202%!
203%! x0 = (linspace (0, pi, n + m))';
204%!
205%! idx_obs = (2:2:(n+m-1))';
206%! idx_prd = (1:2:(n+m))';
207%!
208%! x_obs = x0(idx_obs);
209%! z_obs = sin (x_obs);
210%! x_prd = x0(idx_prd);
211%!
212%! M_prior = stk_model ('stk_materncov32_iso');
213%! M_prior.param = log ([1.0; 2.1]);
214%!
215%! M_post = stk_model_gpposterior (M_prior, x_obs, z_obs);
216
217%!error y_prd = stk_predict (M_post);
218%!test  y_prd = stk_predict (M_post, x_prd);
219%!error y_prd = stk_predict (M_post, [x_prd x_prd]);
220
221%!test  % nargout = 2
222%! [y_prd1, lambda] = stk_predict (M_post, x_prd);
223%! assert (stk_isequal_tolrel (y_prd, y_prd1));
224%! assert (isequal (size (lambda), [n m]));
225
226%!test  % nargout = 3
227%! [y_prd1, lambda, mu] = stk_predict (M_post, x_prd);
228%! assert (stk_isequal_tolrel (y_prd, y_prd1));
229%! assert (isequal (size (lambda), [n m]));
230%! assert (isequal (size (mu), [1 m]));  % ordinary kriging
231
232%!test  % nargout = 4
233%! [y_prd1, lambda, mu, K] = stk_predict (M_post, x_prd);
234%! assert (stk_isequal_tolrel (y_prd, y_prd1));
235%! assert (isequal (size (lambda), [n m]));
236%! assert (isequal (size (mu), [1 m]));  % ordinary kriging
237%! assert (isequal (size (K), [m m]));
238
239%!test  % nargout = 2, compute only variances
240%! M_post1 = stk_model_gpposterior (M_prior, x_obs, []);
241%! [y_prd_nan, lambda] = stk_predict (M_post1, x_prd);
242%! assert (isequal (size (lambda), [n m]));
243%! assert (all (isnan (y_prd_nan.mean)));
244
245%!test % discrete model (prediction indices provided)
246%! M_prior1 = stk_model ('stk_discretecov', M_prior, x0);
247%! M_post1 = stk_model_gpposterior (M_prior1, idx_obs, z_obs);
248%! y_prd1 = stk_predict (M_post1, idx_prd);
249%! assert (stk_isequal_tolrel (y_prd, y_prd1));
250
251%!test % discrete model (prediction indices *not* provided)
252%! M_prior1 = stk_model ('stk_discretecov', M_prior, x0);
253%! M_post1 = stk_model_gpposterior (M_prior1, idx_obs, z_obs);
254%! y_prd1 = stk_predict (M_post1, []);  % predict them all!
255%! assert (stk_isequal_tolrel (y_prd, y_prd1(idx_prd, :)));
256