1#
2#     This file is part of CasADi.
3#
4#     CasADi -- A symbolic framework for dynamic optimization.
5#     Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6#                             K.U. Leuven. All rights reserved.
7#     Copyright (C) 2011-2014 Greg Horn
8#
9#     CasADi is free software; you can redistribute it and/or
10#     modify it under the terms of the GNU Lesser General Public
11#     License as published by the Free Software Foundation; either
12#     version 3 of the License, or (at your option) any later version.
13#
14#     CasADi is distributed in the hope that it will be useful,
15#     but WITHOUT ANY WARRANTY; without even the implied warranty of
16#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17#     Lesser General Public License for more details.
18#
19#     You should have received a copy of the GNU Lesser General Public
20#     License along with CasADi; if not, write to the Free Software
21#     Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22#
23#
24# -*- coding: utf-8 -*-
25from casadi import *
26#
27# How to use Callback
28# Joel Andersson
29#
30
class MyCallback(Callback):
  """Callback with one input and one output that evaluates sin(d*x).

  The factor ``d`` is stored on the instance at construction time.
  """

  def __init__(self, name, d, opts={}):
    Callback.__init__(self)
    # Store the factor before construct() so the object is fully set up
    # by the time the function is finalized.
    self.d = d
    self.construct(name, opts)

  def get_n_in(self):
    """Return the number of inputs (one)."""
    return 1

  def get_n_out(self):
    """Return the number of outputs (one)."""
    return 1

  def init(self):
    """Initialization hook; only reports that it was invoked."""
    print('initializing object')

  def eval(self, arg):
    """Evaluate numerically: return [sin(d*x)] for x = arg[0]."""
    scaled = self.d*arg[0]
    return [sin(scaled)]
50
# Evaluate the callback numerically
f = MyCallback('f', 0.5)
res = f(2)
print(res)

# A Callback may also be called with a symbolic argument
x = MX.sym("x")
print(f(x))

# Derivatives OPTION 1: finite-differences
# First by hand, with a one-sided difference quotient around x = 2
eps = 1e-5
f_perturbed = f(2+eps)
f_nominal = f(2)
print((f_perturbed-f_nominal)/eps)

# Then by constructing the callback with the "enable_fd" option switched on
# and asking for the Jacobian symbolically
f = MyCallback('f', 0.5, {"enable_fd":True})
jac_expr = jacobian(f(x),x)
J = Function('J',[x],[jac_expr])
print(J(2))
67
# Derivatives OPTION 2: Supply forward mode
69# Example from https://www.youtube.com/watch?v=mYOkLkS5yqc&t=4s
70
class Example4To3(Callback):
  """Callback mapping a dense 4-by-1 input to a dense 3-by-1 output.

  With the input split into (a, b, c, d), the output is
  [sin(c)*d + d**2, 2*a + c, b**2 + 5*c].
  """

  def __init__(self, name, opts={}):
    Callback.__init__(self)
    self.construct(name, opts)

  def get_n_in(self):
    """Return the number of inputs (one)."""
    return 1

  def get_n_out(self):
    """Return the number of outputs (one)."""
    return 1

  def get_sparsity_in(self,i):
    """Return the input sparsity: a dense 4-by-1 vector for every index."""
    return Sparsity.dense(4,1)

  def get_sparsity_out(self,i):
    """Return the output sparsity: a dense 3-by-1 vector for every index."""
    return Sparsity.dense(3,1)

  def eval(self, arg):
    """Evaluate numerically and return the 3-by-1 result in a list."""
    a,b,c,d = vertsplit(arg[0])
    r0 = sin(c)*d+d**2
    r1 = 2*a+c
    r2 = b**2+5*c
    return [vertcat(r0,r1,r2)]
90
91
class Example4To3_Fwd(Example4To3):
  """Example4To3 extended with a hand-written forward-mode derivative."""

  def has_forward(self,nfwd):
    # This example is written to work with a single forward seed vector
    # For efficiency, you may allow more seeds at once
    return nfwd==1

  def get_forward(self,nfwd,name,inames,onames,opts):
    """Return a Callback that computes one forward directional derivative.

    Only ``name`` is used (as the name of the returned function); ``nfwd``,
    ``inames``, ``onames`` and ``opts`` are not read here.
    """

    class ForwardFun(Callback):
      def __init__(self, opts={}):
        Callback.__init__(self)
        self.construct(name, opts)

      def get_n_in(self): return 3
      def get_n_out(self): return 1

      def get_sparsity_in(self,i):
        if i==0: # nominal input
          return Sparsity.dense(4,1)
        elif i==1: # nominal output (not needed here, so no nonzeros)
          return Sparsity(3,1)
        else: # Forward seed
          return Sparsity.dense(4,1)

      def get_sparsity_out(self,i):
        # Forward sensitivity
        return Sparsity.dense(3,1)

      # Evaluate numerically
      def eval(self, arg):
        a,b,c,d = vertsplit(arg[0])
        a_dot,b_dot,c_dot,d_dot = vertsplit(arg[2])
        print("Forward sweep with", a_dot,b_dot,c_dot,d_dot)

        # r0 = w1 + w2  with  w0 = sin(c), w1 = w0*d, w2 = d**2
        w0 = sin(c)
        w0_dot = cos(c)*c_dot
        w1_dot = w0_dot*d+w0*d_dot
        w2_dot = 2*d_dot*d
        r0_dot = w1_dot + w2_dot

        # r1 = w3 + c  with  w3 = 2*a
        w3_dot = 2*a_dot
        r1_dot = w3_dot+c_dot

        # r2 = w4 + w5  with  w4 = b**2, w5 = 5*c
        # (w5 must be 5*c, not 5*sin(c), to match Example4To3.eval)
        w4_dot = 2*b_dot*b
        w5_dot = 5*c_dot
        r2_dot = w4_dot + w5_dot

        ret = vertcat(r0_dot,r1_dot,r2_dot)
        return [ret]

    # You are required to keep a reference alive to the returned Callback object
    self.fwd_callback = ForwardFun()
    return self.fwd_callback
147
148
# Build the Jacobian of the callback symbolically and evaluate it at one point
f = Example4To3_Fwd('f')
x = MX.sym("x",4)
jac_expr = jacobian(f(x),x)
J = Function('J',[x],[jac_expr])
x0 = vertcat(1,2,0,3)
print(J(x0))
153
# Derivatives OPTION 3: Supply reverse mode
155
class Example4To3_Rev(Example4To3):
  """Example4To3 extended with a hand-written reverse-mode derivative."""

  def has_reverse(self,nadj):
    # This example is written to work with a single adjoint seed vector
    # For efficiency, you may allow more seeds at once
    return nadj==1

  def get_reverse(self,nfwd,name,inames,onames,opts):
    """Return a Callback that computes one adjoint (reverse) derivative.

    Only ``name`` is used (as the name of the returned function); the
    other arguments are not read here. The first parameter keeps its
    original name ``nfwd``; presumably it carries the number of adjoint
    directions, matching ``has_reverse``.
    """

    class ReverseFun(Callback):
      def __init__(self, opts={}):
        Callback.__init__(self)
        self.construct(name, opts)

      def get_n_in(self): return 3
      def get_n_out(self): return 1

      def get_sparsity_in(self,i):
        if i==0: # nominal input
          return Sparsity.dense(4,1)
        elif i==1: # nominal output (not needed here, so no nonzeros)
          return Sparsity(3,1)
        else: # Reverse seed
          return Sparsity.dense(3,1)

      def get_sparsity_out(self,i):
        # Reverse sensitivity
        return Sparsity.dense(4,1)

      # Evaluate numerically
      def eval(self, arg):
        a,b,c,d = vertsplit(arg[0])
        r0_bar,r1_bar,r2_bar = vertsplit(arg[2])
        print("Reverse sweep with", r0_bar, r1_bar, r2_bar)

        # Only nominal intermediate needed by the backward sweep
        w0 = sin(c)

        # r2 = w4 + w5  with  w4 = b**2, w5 = 5*c
        # (w5 must be 5*c, not 5*sin(c), to match Example4To3.eval, so its
        # adjoint flows straight into c_bar)
        w4_bar = r2_bar
        w5_bar = r2_bar
        c_bar = 5*w5_bar
        b_bar = 2*b*w4_bar

        # r1 = w3 + c  with  w3 = 2*a
        w3_bar = r1_bar
        c_bar = c_bar + r1_bar
        a_bar = 2*w3_bar

        # r0 = w1 + w2  with  w0 = sin(c), w1 = w0*d, w2 = d**2
        w1_bar = r0_bar
        w2_bar = r0_bar
        d_bar = 2*d*w2_bar
        w0_bar = w1_bar*d
        d_bar = d_bar + w0*w1_bar
        c_bar = c_bar + cos(c)*w0_bar

        ret = vertcat(a_bar,b_bar,c_bar,d_bar)
        return [ret]

    # You are required to keep a reference alive to the returned Callback object
    self.rev_callback = ReverseFun()
    return self.rev_callback
215
216
# Build the Jacobian of the callback symbolically and evaluate it at one point
f = Example4To3_Rev('f')
x = MX.sym("x",4)
jac_expr = jacobian(f(x),x)
J = Function('J',[x],[jac_expr])
x0 = vertcat(1,2,0,3)
print(J(x0))
221
# Derivatives OPTION 4: Supply full Jacobian
223
class Example4To3_Jac(Example4To3):
  """Example4To3 extended with a hand-written full Jacobian."""

  def has_jacobian(self):
    """Report that a Jacobian function is supplied."""
    return True

  def get_jacobian(self,name,inames,onames,opts):
    """Return a Callback that evaluates the 3-by-4 Jacobian.

    Only ``name`` is used (as the name of the returned function);
    ``inames``, ``onames`` and ``opts`` are not read here.
    """

    class JacFun(Callback):
      def __init__(self, opts={}):
        Callback.__init__(self)
        self.construct(name, opts)

      def get_n_in(self):
        return 2

      def get_n_out(self):
        return 1

      def get_sparsity_in(self,i):
        # Input 0 is the nominal input, input 1 the nominal output;
        # any other index falls through and yields None.
        if i==0:
          return Sparsity.dense(4,1)
        if i==1:
          return Sparsity(3,1)

      def get_sparsity_out(self,i):
        # Ones mark the structurally nonzero Jacobian entries
        pattern = DM([[0,0,1,1],[1,0,1,0],[0,1,1,0]])
        return sparsify(pattern).sparsity()

      # Evaluate numerically
      def eval(self, arg):
        a,b,c,d = vertsplit(arg[0])
        jac = DM(3,4)
        # Row 0: derivatives of sin(c)*d + d**2
        jac[0,2] = d*cos(c)
        jac[0,3] = sin(c)+2*d
        # Row 1: derivatives of 2*a + c
        jac[1,0] = 2
        jac[1,2] = 1
        # Row 2: derivatives of b**2 + 5*c
        jac[2,1] = 2*b
        jac[2,2] = 5
        return [jac]

    # You are required to keep a reference alive to the returned Callback object
    self.jac_callback = JacFun()
    return self.jac_callback
259
# Build the Jacobian of the callback symbolically and evaluate it at one point
f = Example4To3_Jac('f')
x = MX.sym("x",4)
jac_expr = jacobian(f(x),x)
J = Function('J',[x],[jac_expr])
x0 = vertcat(1,2,0,3)
print(J(x0))
264
265
266