1 /* Copyright (c) 2007-2014 Massachusetts Institute of Technology
2 *
3 * Permission is hereby granted, free of charge, to any person obtaining
4 * a copy of this software and associated documentation files (the
5 * "Software"), to deal in the Software without restriction, including
6 * without limitation the rights to use, copy, modify, merge, publish,
7 * distribute, sublicense, and/or sell copies of the Software, and to
8 * permit persons to whom the Software is furnished to do so, subject to
9 * the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be
12 * included in all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 */
22
23 /* Matlab MEX interface to NLopt, and in particular to nlopt_optimize */
24
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <mex.h>

#include "nlopt.h"
32
/* Abort the MEX call with msg (via mexErrMsgTxt) unless cond holds.
   The do/while(0) wrapper makes the expansion a single statement, so a call
   followed by ';' is safe inside an unbraced if/else. */
#define CHECK0(cond, msg) do { if (!(cond)) mexErrMsgTxt(msg); } while (0)
34
struct_val_default(const mxArray * s,const char * name,double dflt)35 static double struct_val_default(const mxArray *s, const char *name, double dflt)
36 {
37 mxArray *val = mxGetField(s, 0, name);
38 if (val) {
39 CHECK0(mxIsNumeric(val) && !mxIsComplex(val)
40 && mxGetM(val) * mxGetN(val) == 1,
41 "opt fields, other than xtol_abs, must be real scalars");
42 return mxGetScalar(val);
43 }
44 return dflt;
45 }
46
/* Return a pointer to the n doubles stored in field `name` of struct s
   (element 0), or dflt when that field does not exist.  The returned pointer
   aliases the MATLAB array's data; it is not a copy.
   Raises a MATLAB error (via CHECK0) if the field is not a real double array,
   or if it does not have exactly n elements. */
static double *struct_arrval(const mxArray *s, const char *name, unsigned n,
                             double *dflt)
{
    mxArray *val = mxGetField(s, 0, name);

    if (!val)
        return dflt;

    /* mxGetPr exposes the raw data buffer, which holds n doubles only for
       mxDOUBLE_CLASS arrays.  Any other numeric class (single, int32, ...)
       would be misinterpreted, and possibly over-read, by the caller. */
    CHECK0(mxIsDouble(val) && !mxIsComplex(val),
           "opt vector fields must be real double-precision arrays");
    CHECK0(mxGetM(val) * mxGetN(val) == n,
           "opt vector field is not of length n");
    return mxGetPr(val);
}
59
struct_funcval(const mxArray * s,const char * name)60 static mxArray *struct_funcval(const mxArray *s, const char *name)
61 {
62 mxArray *val = mxGetField(s, 0, name);
63 if (val) {
64 CHECK0(mxIsChar(val) || mxIsFunctionHandle(val),
65 "opt function field is not a function handle/name");
66 return val;
67 }
68 return NULL;
69 }
70
/* Set each of the first n entries of arr to val and return arr, so that the
   call can be passed directly as a default-array argument. */
static double *fill(double *arr, unsigned n, double val)
{
    double *p = arr;

    while (n-- > 0)
        *p++ = val;
    return arr;
}
77
78 #define FLEN 128 /* max length of user function name */
79 #define MAXRHS 3 /* max nrhs for user function */
80 typedef struct user_function_data_s {
81 char f[FLEN];
82 mxArray *plhs[2];
83 mxArray *prhs[MAXRHS];
84 int xrhs, nrhs;
85 int verbose, neval;
86 struct user_function_data_s *dpre;
87 nlopt_opt opt;
88 } user_function_data;
89
user_function(unsigned n,const double * x,double * gradient,void * d_)90 static double user_function(unsigned n, const double *x,
91 double *gradient, /* NULL if not needed */
92 void *d_)
93 {
94 user_function_data *d = (user_function_data *) d_;
95 double f;
96
97 d->plhs[0] = d->plhs[1] = NULL;
98 memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
99
100 CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs,
101 d->nrhs, d->prhs, d->f),
102 "error calling user function");
103
104 CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0])
105 && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
106 "user function must return real scalar");
107 f = mxGetScalar(d->plhs[0]);
108 mxDestroyArray(d->plhs[0]);
109 if (gradient) {
110 CHECK0(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
111 && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
112 && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
113 "gradient vector from user function is the wrong size");
114 memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
115 mxDestroyArray(d->plhs[1]);
116 }
117 d->neval++;
118 if (d->verbose) mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
119 if (mxIsNaN(f)) nlopt_force_stop(d->opt);
120 return f;
121 }
122
user_pre(unsigned n,const double * x,const double * v,double * vpre,void * d_)123 static void user_pre(unsigned n, const double *x, const double *v,
124 double *vpre, void *d_)
125 {
126 user_function_data *d = ((user_function_data *) d_)->dpre;
127 d->plhs[0] = d->plhs[1] = NULL;
128 memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
129 memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
130
131 CHECK0(0 == mexCallMATLAB(1, d->plhs,
132 d->nrhs, d->prhs, d->f),
133 "error calling user function");
134
135 CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
136 && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
137 && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
138 "vpre vector from user function is the wrong size");
139 memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
140 mxDestroyArray(d->plhs[0]);
141 d->neval++;
142 if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
143 }
144
145 #define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
146
make_opt(const mxArray * opts,unsigned n)147 nlopt_opt make_opt(const mxArray *opts, unsigned n)
148 {
149 nlopt_opt opt = NULL, local_opt = NULL;
150 nlopt_algorithm algorithm;
151 double *tmp = NULL;
152 unsigned i;
153
154 algorithm = (nlopt_algorithm)
155 struct_val_default(opts, "algorithm", NLOPT_NUM_ALGORITHMS);
156 CHECK1(((int)algorithm) >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
157 "invalid opt.algorithm");
158
159 tmp = (double *) mxCalloc(n, sizeof(double));
160 opt = nlopt_create(algorithm, n);
161 CHECK1(opt, "nlopt: out of memory");
162
163 nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n,
164 fill(tmp, n, -HUGE_VAL)));
165 nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n,
166 fill(tmp, n, +HUGE_VAL)));
167
168 nlopt_set_stopval(opt, struct_val_default(opts, "stopval", -HUGE_VAL));
169 nlopt_set_ftol_rel(opt, struct_val_default(opts, "ftol_rel", 0.0));
170 nlopt_set_ftol_abs(opt, struct_val_default(opts, "ftol_abs", 0.0));
171 nlopt_set_xtol_rel(opt, struct_val_default(opts, "xtol_rel", 0.0));
172 nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n,
173 fill(tmp, n, 0.0)));
174 nlopt_set_x_weights(opt, struct_arrval(opts, "x_weights", n,
175 fill(tmp, n, 1.0)));
176 nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ?
177 0 : struct_val_default(opts, "maxeval", 0.0));
178 nlopt_set_maxtime(opt, struct_val_default(opts, "maxtime", 0.0));
179
180 nlopt_set_population(opt, struct_val_default(opts, "population", 0));
181 nlopt_set_vector_storage(opt, struct_val_default(opts, "vector_storage", 0));
182
183 if (struct_arrval(opts, "initial_step", n, NULL))
184 nlopt_set_initial_step(opt,
185 struct_arrval(opts, "initial_step", n, NULL));
186
187 if (mxGetField(opts, 0, "local_optimizer")) {
188 const mxArray *local_opts = mxGetField(opts, 0, "local_optimizer");
189 CHECK1(mxIsStruct(local_opts),
190 "opt.local_optimizer must be a structure");
191 CHECK1(local_opt = make_opt(local_opts, n),
192 "error initializing local optimizer");
193 nlopt_set_local_optimizer(opt, local_opt);
194 nlopt_destroy(local_opt); local_opt = NULL;
195 }
196
197 mxFree(tmp);
198 return opt;
199 }
200
201 #define CHECK(cond, msg) if (!(cond)) { mxFree(dh); mxFree(dfc); nlopt_destroy(opt); mexErrMsgTxt(msg); }
202
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])203 void mexFunction(int nlhs, mxArray *plhs[],
204 int nrhs, const mxArray *prhs[])
205 {
206 unsigned n;
207 double *x, *x0, opt_f;
208 nlopt_result ret;
209 mxArray *x_mx, *mx;
210 user_function_data d, dpre, *dfc = NULL, *dh = NULL;
211 nlopt_opt opt = NULL;
212
213 CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
214
215 /* options = prhs[0] */
216 CHECK(mxIsStruct(prhs[0]), "opt must be a struct");
217
218 /* x0 = prhs[1] */
219 CHECK(mxIsDouble(prhs[1]) && !mxIsComplex(prhs[1])
220 && (mxGetM(prhs[1]) == 1 || mxGetN(prhs[1]) == 1),
221 "x must be real row or column vector");
222 n = mxGetM(prhs[1]) * mxGetN(prhs[1]),
223 x0 = mxGetPr(prhs[1]);
224
225 CHECK(opt = make_opt(prhs[0], n), "error initializing nlopt options");
226
227 d.neval = 0;
228 d.verbose = (int) struct_val_default(prhs[0], "verbose", 0);
229 d.opt = opt;
230
231 /* function f = prhs[1] */
232 mx = struct_funcval(prhs[0], "min_objective");
233 if (!mx) mx = struct_funcval(prhs[0], "max_objective");
234 CHECK(mx, "either opt.min_objective or opt.max_objective must exist");
235 if (mxIsChar(mx)) {
236 CHECK(mxGetString(mx, d.f, FLEN) == 0,
237 "error reading function name string (too long?)");
238 d.nrhs = 1;
239 d.xrhs = 0;
240 }
241 else {
242 d.prhs[0] = mx;
243 strcpy(d.f, "feval");
244 d.nrhs = 2;
245 d.xrhs = 1;
246 }
247 d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
248
249 if ((mx = struct_funcval(prhs[0], "pre"))) {
250 CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
251 "pre must contain function handles or function names");
252 if (mxIsChar(mx)) {
253 CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
254 "error reading function name string (too long?)");
255 dpre.nrhs = 2;
256 dpre.xrhs = 0;
257 }
258 else {
259 dpre.prhs[0] = mx;
260 strcpy(dpre.f, "feval");
261 dpre.nrhs = 3;
262 dpre.xrhs = 1;
263 }
264 dpre.verbose = d.verbose > 2;
265 dpre.opt = opt;
266 dpre.neval = 0;
267 dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
268 dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
269 d.dpre = &dpre;
270
271 if (struct_funcval(prhs[0], "min_objective"))
272 nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
273 else
274 nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
275 }
276 else {
277 dpre.nrhs = 0;
278 if (struct_funcval(prhs[0], "min_objective"))
279 nlopt_set_min_objective(opt, user_function, &d);
280 else
281 nlopt_set_max_objective(opt, user_function, &d);
282 }
283
284 if ((mx = mxGetField(prhs[0], 0, "fc"))) {
285 int j, m;
286 double *fc_tol;
287
288 CHECK(mxIsCell(mx), "fc must be a Cell array");
289 m = mxGetM(mx) * mxGetN(mx);;
290 dfc = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
291 fc_tol = struct_arrval(prhs[0], "fc_tol", m, NULL);
292
293 for (j = 0; j < m; ++j) {
294 mxArray *fc = mxGetCell(mx, j);
295 CHECK(mxIsChar(fc) || mxIsFunctionHandle(fc),
296 "fc must contain function handles or function names");
297 if (mxIsChar(fc)) {
298 CHECK(mxGetString(fc, dfc[j].f, FLEN) == 0,
299 "error reading function name string (too long?)");
300 dfc[j].nrhs = 1;
301 dfc[j].xrhs = 0;
302 }
303 else {
304 dfc[j].prhs[0] = fc;
305 strcpy(dfc[j].f, "feval");
306 dfc[j].nrhs = 2;
307 dfc[j].xrhs = 1;
308 }
309 dfc[j].verbose = d.verbose > 1;
310 dfc[j].opt = opt;
311 dfc[j].neval = 0;
312 dfc[j].prhs[dfc[j].xrhs] = d.prhs[d.xrhs];
313 CHECK(nlopt_add_inequality_constraint(opt, user_function,
314 dfc + j,
315 fc_tol ? fc_tol[j] : 0)
316 > 0, "nlopt error adding inequality constraint");
317 }
318 }
319
320
321 if ((mx = mxGetField(prhs[0], 0, "h"))) {
322 int j, m;
323 double *h_tol;
324
325 CHECK(mxIsCell(mx), "h must be a Cell array");
326 m = mxGetM(mx) * mxGetN(mx);;
327 dh = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
328 h_tol = struct_arrval(prhs[0], "h_tol", m, NULL);
329
330 for (j = 0; j < m; ++j) {
331 mxArray *h = mxGetCell(mx, j);
332 CHECK(mxIsChar(h) || mxIsFunctionHandle(h),
333 "h must contain function handles or function names");
334 if (mxIsChar(h)) {
335 CHECK(mxGetString(h, dh[j].f, FLEN) == 0,
336 "error reading function name string (too long?)");
337 dh[j].nrhs = 1;
338 dh[j].xrhs = 0;
339 }
340 else {
341 dh[j].prhs[0] = h;
342 strcpy(dh[j].f, "feval");
343 dh[j].nrhs = 2;
344 dh[j].xrhs = 1;
345 }
346 dh[j].verbose = d.verbose > 1;
347 dh[j].opt = opt;
348 dh[j].neval = 0;
349 dh[j].prhs[dh[j].xrhs] = d.prhs[d.xrhs];
350 CHECK(nlopt_add_equality_constraint(opt, user_function,
351 dh + j,
352 h_tol ? h_tol[j] : 0)
353 > 0, "nlopt error adding equality constraint");
354 }
355 }
356
357
358 x_mx = mxCreateDoubleMatrix(mxGetM(prhs[1]), mxGetN(prhs[1]), mxREAL);
359 x = mxGetPr(x_mx);
360 memcpy(x, x0, sizeof(double) * n);
361
362 ret = nlopt_optimize(opt, x, &opt_f);
363
364 mxFree(dh);
365 mxFree(dfc);
366 mxDestroyArray(d.prhs[d.xrhs]);
367 if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
368 nlopt_destroy(opt);
369
370 plhs[0] = x_mx;
371 if (nlhs > 1) {
372 plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
373 *(mxGetPr(plhs[1])) = opt_f;
374 }
375 if (nlhs > 2) {
376 plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
377 *(mxGetPr(plhs[2])) = (int) ret;
378 }
379 }
380