1%
2%     This file is part of CasADi.
3%
4%     CasADi -- A symbolic framework for dynamic optimization.
5%     Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6%                             K.U. Leuven. All rights reserved.
7%     Copyright (C) 2011-2014 Greg Horn
8%
9%     CasADi is free software; you can redistribute it and/or
10%     modify it under the terms of the GNU Lesser General Public
11%     License as published by the Free Software Foundation; either
12%     version 3 of the License, or (at your option) any later version.
13%
14%     CasADi is distributed in the hope that it will be useful,
15%     but WITHOUT ANY WARRANTY; without even the implied warranty of
16%     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17%     Lesser General Public License for more details.
18%
19%     You should have received a copy of the GNU Lesser General Public
20%     License along with CasADi; if not, write to the Free Software
21%     Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22%
23%
24
25import casadi.*
disp('Testing sensitivity analysis in CasADi')

% Integrators under test. The DAE-capable integrators are also run on the
% ODE example, so they are appended to the ODE list.
DAE_integrators = {'idas', 'collocation'};
ODE_integrators = [{'cvodes', 'rk'}, DAE_integrators];
31
% Run the sensitivity checks on two problems: first a small semi-explicit
% DAE (is_ode == 0), then an ODE model (is_ode == 1).
% The loop index is named is_ode (not ode) so that the right-hand-side
% expression 'ode' built inside each branch does not overwrite it.
%
% For every integrator applicable to the current problem, the script prints
% the following results for comparison:
%   - the nominal solution (xf, qf)
%   - first order sensitivities w.r.t. p: finite differences, forward mode
%     and adjoint mode
%   - second order sensitivities: finite differences of the adjoint,
%     forward-over-adjoint and adjoint-over-adjoint

% Space-separated printout of a numeric array (values are generally non-integer)
fmt = @(vals) sprintf('%g ', vals);

for is_ode=0:1
  if is_ode
    disp '******'
    disp 'Testing ODE example'
    Integrators = ODE_integrators;

    % Time
    t = SX.sym('t');

    % Parameter
    u = SX.sym('u');

    % Differential states
    s = SX.sym('s'); v = SX.sym('v'); m = SX.sym('m');
    x = [s;v;m];

    % Constants
    alpha = 0.05; % friction
    beta = 0.1;   % fuel consumption rate

    % Differential equation
    ode = [v;
           (u-alpha*v*v)/m;
           -beta*u*u];

    % Quadrature (depends explicitly on t, hence 't' in the DAE struct)
    quad = v^3 + ((3-sin(t)) - u)^2;

    % DAE
    dae = struct('t', t, 'x', x, 'p', u, 'ode', ode, 'quad', quad);

    % Time length
    tf = 0.5;

    % Initial position
    x0 = [0;0;1];

    % Parameter
    u0 = 0.4;

  else
    disp '******'
    disp 'Testing DAE example'
    Integrators = DAE_integrators;

    % Differential state
    x = SX.sym('x');

    % Algebraic variable
    z = SX.sym('z');

    % Parameter
    u = SX.sym('u');

    % Differential equation
    ode = -x + 0.5*x*x + u + 0.5*z;

    % Algebraic constraint
    alg = z + exp(z) - 1.0 + x;

    % Quadrature
    quad = x*x + 3.0*u*u;

    % DAE (no expression depends on time, so no 't' entry)
    dae = struct('x', x, 'z', z, 'p', u, 'ode', ode, 'alg', alg, 'quad', quad);

    % End time
    tf = 5;

    % Initial position
    x0 = 1;

    % Parameter
    u0 = 0.4;
  end

  % Repeat all checks for every integrator applicable to this problem
  for integ=1:numel(Integrators)
    MyIntegrator = Integrators{integ};

    disp('========');
    fprintf('Integrator: %s\n', MyIntegrator);
    disp('========');

    % Integrator options
    opts = struct('tf', tf);
    if strcmp(MyIntegrator,'collocation')
      % collocation is given an explicit rootfinder plugin
      opts.rootfinder = 'kinsol';
    end

    % Integrator
    I = casadi.integrator('I', MyIntegrator, dae, opts);

    % Integrate to get results
    res = I('x0', x0, 'p', u0);
    xf = full(res.xf);
    qf = full(res.qf);
    fprintf('%50s: xf=%s, qf=%s\n', 'Unperturbed solution', ...
            fmt(xf), fmt(qf));

    % Perturb p to get a (forward) finite difference approximation,
    % used as a reference for the first order sensitivities below
    h = 0.001;
    res = I('x0', x0, 'p', u0+h);
    fd_xf = (full(res.xf)-xf)/h;
    fd_qf = (full(res.qf)-qf)/h;
    fprintf('%50s: d(xf)/d(p)=%s, d(qf)/d(p)=%s\n', 'Finite differences', ...
            fmt(fd_xf), fmt(fd_qf));

    % Calculate one directional derivative, forward mode (seed 1 in p)
    I_fwd = I.factory('I_fwd', {'x0', 'z0', 'p', 'fwd:p'}, {'fwd:xf', 'fwd:qf'});
    res = I_fwd('x0', x0, 'p', u0, 'fwd_p', 1);
    fwd_xf = full(res.fwd_xf);
    fwd_qf = full(res.fwd_qf);
    fprintf('%50s: d(xf)/d(p)=%s, d(qf)/d(p)=%s\n', 'Forward sensitivities', ...
            fmt(fwd_xf), fmt(fwd_qf));

    % Calculate one directional derivative, reverse mode (seed 1 in qf),
    % giving the gradient of qf w.r.t. x0 and p
    I_adj = I.factory('I_adj', {'x0', 'z0', 'p', 'adj:qf'}, {'adj:x0', 'adj:p'});
    res = I_adj('x0', x0, 'p', u0, 'adj_qf', 1);
    adj_x0 = full(res.adj_x0);
    adj_p = full(res.adj_p);
    fprintf('%50s: d(qf)/d(x0)=%s, d(qf)/d(p)=%s\n', 'Adjoint sensitivities', ...
            fmt(adj_x0), fmt(adj_p));

    % Perturb adjoint solution to get a finite difference approximation of
    % the second order sensitivities
    res = I_adj('x0', x0, 'p', u0+h, 'adj_qf', 1);
    fd_adj_x0 = (full(res.adj_x0)-adj_x0)/h;
    fd_adj_p = (full(res.adj_p)-adj_p)/h;
    fprintf('%50s: d2(qf)/d(x0)d(p)=%s, d2(qf)/d(p)d(p)=%s\n', ...
            'FD of adjoint sensitivities', ...
            fmt(fd_adj_x0), fmt(fd_adj_p));

    % Forward over adjoint to get the second order sensitivities
    I_foa = I_adj.factory('I_foa', {'x0', 'z0', 'p', 'adj_qf', 'fwd:p'}, ...
                          {'fwd:adj_x0', 'fwd:adj_p'});
    res = I_foa('x0', x0, 'p', u0, 'adj_qf', 1, 'fwd_p', 1);
    fwd_adj_x0 = full(res.fwd_adj_x0);
    fwd_adj_p = full(res.fwd_adj_p);
    % Print the forward-over-adjoint results (not the FD values above)
    fprintf('%50s: d2(qf)/d(x0)d(p)=%s, d2(qf)/d(p)d(p)=%s\n', ...
            'Forward over adjoint sensitivities', ...
            fmt(fwd_adj_x0), fmt(fwd_adj_p));

    % Adjoint over adjoint to get the second order sensitivities
    I_aoa = I_adj.factory('I_aoa', {'x0', 'z0', 'p', 'adj_qf', 'adj:adj_p'}, ...
                         {'adj:x0', 'adj:p'});
    res = I_aoa('x0', x0, 'p', u0, 'adj_qf', 1, 'adj_adj_p', 1);
    adj_adj_x0 = full(res.adj_x0);
    adj_adj_p = full(res.adj_p);
    fprintf('%50s: d2(qf)/d(x0)d(p)=%s, d2(qf)/d(p)d(p)=%s\n', ...
            'Adjoint over adjoint sensitivities', ...
            fmt(adj_adj_x0), fmt(adj_adj_p));
  end
end
186