1 /* Copyright (c) 2007-2014 Massachusetts Institute of Technology
2  *
3  * Permission is hereby granted, free of charge, to any person obtaining
4  * a copy of this software and associated documentation files (the
5  * "Software"), to deal in the Software without restriction, including
6  * without limitation the rights to use, copy, modify, merge, publish,
7  * distribute, sublicense, and/or sell copies of the Software, and to
8  * permit persons to whom the Software is furnished to do so, subject to
9  * the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be
12  * included in all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21  */
22 
23 /* Matlab MEX interface to NLopt, and in particular to nlopt_optimize */
24 
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <mex.h>

#include "nlopt.h"
32 
/* Raise msg as a MATLAB error (mexErrMsgTxt aborts the MEX call) unless cond
 * holds.  Wrapped in do { } while (0) so that "CHECK0(...);" is exactly one
 * statement and stays correct inside an unbraced if/else. */
#define CHECK0(cond, msg) do { if (!(cond)) mexErrMsgTxt(msg); } while (0)
34 
struct_val_default(const mxArray * s,const char * name,double dflt)35 static double struct_val_default(const mxArray *s, const char *name, double dflt)
36 {
37      mxArray *val = mxGetField(s, 0, name);
38      if (val) {
39 	  CHECK0(mxIsNumeric(val) && !mxIsComplex(val)
40 		&& mxGetM(val) * mxGetN(val) == 1,
41 		"opt fields, other than xtol_abs, must be real scalars");
42 	  return mxGetScalar(val);
43      }
44      return dflt;
45 }
46 
/* Return a pointer to the data of the vector field `name` from element 0 of
 * the struct s, or dflt when the field is absent.
 *
 * The field must be a real, full (non-sparse) double array with exactly n
 * elements.  Its data is read via mxGetPr, which only describes n doubles
 * for a full double array; integer/single/sparse inputs would otherwise be
 * silently misread.  The returned pointer aliases MATLAB-owned memory and
 * must not be freed.  Violations raise a MATLAB error. */
static double *struct_arrval(const mxArray *s, const char *name, unsigned n,
                             double *dflt)
{
     mxArray *val = mxGetField(s, 0, name);

     if (!val)
          return dflt;

     CHECK0(mxIsDouble(val) && !mxIsComplex(val) && !mxIsSparse(val),
            "opt vector field must be a real (non-sparse) double array");
     CHECK0(mxGetM(val) * mxGetN(val) == n,
            "opt vector field is not of length n");
     return mxGetPr(val);
}
59 
struct_funcval(const mxArray * s,const char * name)60 static mxArray *struct_funcval(const mxArray *s, const char *name)
61 {
62      mxArray *val = mxGetField(s, 0, name);
63      if (val) {
64 	  CHECK0(mxIsChar(val) || mxIsFunctionHandle(val),
65 		 "opt function field is not a function handle/name");
66 	  return val;
67      }
68      return NULL;
69 }
70 
/* Set the first n entries of arr to val and return arr, so a call can be
 * used directly as a "default array" argument.  arr is not touched when
 * n == 0. */
static double *fill(double *arr, unsigned n, double val)
{
     unsigned k;

     for (k = n; k > 0; --k)
          arr[k - 1] = val;
     return arr;
}
77 
78 #define FLEN 128 /* max length of user function name */
79 #define MAXRHS 3 /* max nrhs for user function */
80 typedef struct user_function_data_s {
81      char f[FLEN];
82      mxArray *plhs[2];
83      mxArray *prhs[MAXRHS];
84      int xrhs, nrhs;
85      int verbose, neval;
86      struct user_function_data_s *dpre;
87      nlopt_opt opt;
88 } user_function_data;
89 
user_function(unsigned n,const double * x,double * gradient,void * d_)90 static double user_function(unsigned n, const double *x,
91 			    double *gradient, /* NULL if not needed */
92 			    void *d_)
93 {
94   user_function_data *d = (user_function_data *) d_;
95   double f;
96 
97   d->plhs[0] = d->plhs[1] = NULL;
98   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
99 
100   CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs,
101 			   d->nrhs, d->prhs, d->f),
102 	"error calling user function");
103 
104   CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0])
105 	&& mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
106 	"user function must return real scalar");
107   f = mxGetScalar(d->plhs[0]);
108   mxDestroyArray(d->plhs[0]);
109   if (gradient) {
110      CHECK0(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
111 	   && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
112 	   && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
113 	   "gradient vector from user function is the wrong size");
114      memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
115      mxDestroyArray(d->plhs[1]);
116   }
117   d->neval++;
118   if (d->verbose) mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
119   if (mxIsNaN(f)) nlopt_force_stop(d->opt);
120   return f;
121 }
122 
user_pre(unsigned n,const double * x,const double * v,double * vpre,void * d_)123 static void user_pre(unsigned n, const double *x, const double *v,
124 		       double *vpre, void *d_)
125 {
126   user_function_data *d = ((user_function_data *) d_)->dpre;
127   d->plhs[0] = d->plhs[1] = NULL;
128   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
129   memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
130 
131   CHECK0(0 == mexCallMATLAB(1, d->plhs,
132 			    d->nrhs, d->prhs, d->f),
133 	 "error calling user function");
134 
135   CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
136 	 && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
137 	 && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
138 	 "vpre vector from user function is the wrong size");
139   memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
140   mxDestroyArray(d->plhs[0]);
141   d->neval++;
142   if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
143 }
144 
/* make_opt failure check.  On failure it releases make_opt's locals tmp, opt
 * and local_opt, issues msg as a MATLAB warning, and returns NULL from the
 * enclosing function.  The do { } while (0) wrapper makes "CHECK1(...);" a
 * single statement, safe inside an unbraced if/else. */
#define CHECK1(cond, msg)                          \
     do {                                          \
          if (!(cond)) {                           \
               mxFree(tmp);                        \
               nlopt_destroy(opt);                 \
               nlopt_destroy(local_opt);           \
               mexWarnMsgTxt(msg);                 \
               return NULL;                        \
          }                                        \
     } while (0)
146 
make_opt(const mxArray * opts,unsigned n)147 nlopt_opt make_opt(const mxArray *opts, unsigned n)
148 {
149      nlopt_opt opt = NULL, local_opt = NULL;
150      nlopt_algorithm algorithm;
151      double *tmp = NULL;
152      unsigned i;
153 
154      algorithm = (nlopt_algorithm)
155 	  struct_val_default(opts, "algorithm", NLOPT_NUM_ALGORITHMS);
156      CHECK1(((int)algorithm) >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
157 	    "invalid opt.algorithm");
158 
159      tmp = (double *) mxCalloc(n, sizeof(double));
160      opt = nlopt_create(algorithm, n);
161      CHECK1(opt, "nlopt: out of memory");
162 
163      nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n,
164 					       fill(tmp, n, -HUGE_VAL)));
165      nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n,
166 					       fill(tmp, n, +HUGE_VAL)));
167 
168      nlopt_set_stopval(opt, struct_val_default(opts, "stopval", -HUGE_VAL));
169      nlopt_set_ftol_rel(opt, struct_val_default(opts, "ftol_rel", 0.0));
170      nlopt_set_ftol_abs(opt, struct_val_default(opts, "ftol_abs", 0.0));
171      nlopt_set_xtol_rel(opt, struct_val_default(opts, "xtol_rel", 0.0));
172      nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n,
173 					   fill(tmp, n, 0.0)));
174      nlopt_set_x_weights(opt, struct_arrval(opts, "x_weights", n,
175 					   fill(tmp, n, 1.0)));
176      nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ?
177 		       0 : struct_val_default(opts, "maxeval", 0.0));
178      nlopt_set_maxtime(opt, struct_val_default(opts, "maxtime", 0.0));
179 
180      nlopt_set_population(opt, struct_val_default(opts, "population", 0));
181      nlopt_set_vector_storage(opt, struct_val_default(opts, "vector_storage", 0));
182 
183      if (struct_arrval(opts, "initial_step", n, NULL))
184 	  nlopt_set_initial_step(opt,
185 				 struct_arrval(opts, "initial_step", n, NULL));
186 
187      if (mxGetField(opts, 0, "local_optimizer")) {
188 	  const mxArray *local_opts = mxGetField(opts, 0, "local_optimizer");
189 	  CHECK1(mxIsStruct(local_opts),
190 		 "opt.local_optimizer must be a structure");
191 	  CHECK1(local_opt = make_opt(local_opts, n),
192 		 "error initializing local optimizer");
193 	  nlopt_set_local_optimizer(opt, local_opt);
194 	  nlopt_destroy(local_opt); local_opt = NULL;
195      }
196 
197      mxFree(tmp);
198      return opt;
199 }
200 
/* mexFunction failure check.  On failure it frees the constraint data dh and
 * dfc, destroys opt (all locals of mexFunction), and raises msg as a MATLAB
 * error.  The do { } while (0) wrapper makes "CHECK(...);" a single
 * statement, safe inside an unbraced if/else. */
#define CHECK(cond, msg)                           \
     do {                                          \
          if (!(cond)) {                           \
               mxFree(dh);                         \
               mxFree(dfc);                        \
               nlopt_destroy(opt);                 \
               mexErrMsgTxt(msg);                  \
          }                                        \
     } while (0)
202 
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])203 void mexFunction(int nlhs, mxArray *plhs[],
204                  int nrhs, const mxArray *prhs[])
205 {
206      unsigned n;
207      double *x, *x0, opt_f;
208      nlopt_result ret;
209      mxArray *x_mx, *mx;
210      user_function_data d, dpre, *dfc = NULL, *dh = NULL;
211      nlopt_opt opt = NULL;
212 
213      CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
214 
215      /* options = prhs[0] */
216      CHECK(mxIsStruct(prhs[0]), "opt must be a struct");
217 
218      /* x0 = prhs[1] */
219      CHECK(mxIsDouble(prhs[1]) && !mxIsComplex(prhs[1])
220 	   && (mxGetM(prhs[1]) == 1 || mxGetN(prhs[1]) == 1),
221 	   "x must be real row or column vector");
222      n = mxGetM(prhs[1]) * mxGetN(prhs[1]),
223      x0 = mxGetPr(prhs[1]);
224 
225      CHECK(opt = make_opt(prhs[0], n), "error initializing nlopt options");
226 
227      d.neval = 0;
228      d.verbose = (int) struct_val_default(prhs[0], "verbose", 0);
229      d.opt = opt;
230 
231      /* function f = prhs[1] */
232      mx = struct_funcval(prhs[0], "min_objective");
233      if (!mx) mx = struct_funcval(prhs[0], "max_objective");
234      CHECK(mx, "either opt.min_objective or opt.max_objective must exist");
235      if (mxIsChar(mx)) {
236 	  CHECK(mxGetString(mx, d.f, FLEN) == 0,
237 		"error reading function name string (too long?)");
238 	  d.nrhs = 1;
239 	  d.xrhs = 0;
240      }
241      else {
242 	  d.prhs[0] = mx;
243 	  strcpy(d.f, "feval");
244 	  d.nrhs = 2;
245 	  d.xrhs = 1;
246      }
247      d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
248 
249      if ((mx = struct_funcval(prhs[0], "pre"))) {
250 	  CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
251 		"pre must contain function handles or function names");
252 	  if (mxIsChar(mx)) {
253 	       CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
254                      "error reading function name string (too long?)");
255 	       dpre.nrhs = 2;
256 	       dpre.xrhs = 0;
257 	  }
258 	  else {
259 	       dpre.prhs[0] = mx;
260 	       strcpy(dpre.f, "feval");
261 	       dpre.nrhs = 3;
262 	       dpre.xrhs = 1;
263 	  }
264 	  dpre.verbose = d.verbose > 2;
265 	  dpre.opt = opt;
266 	  dpre.neval = 0;
267 	  dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
268 	  dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
269 	  d.dpre = &dpre;
270 
271 	  if (struct_funcval(prhs[0], "min_objective"))
272 	       nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
273 	  else
274 	       nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
275      }
276      else {
277 	  dpre.nrhs = 0;
278 	  if (struct_funcval(prhs[0], "min_objective"))
279 	       nlopt_set_min_objective(opt, user_function, &d);
280 	  else
281 	       nlopt_set_max_objective(opt, user_function, &d);
282      }
283 
284      if ((mx = mxGetField(prhs[0], 0, "fc"))) {
285 	  int j, m;
286 	  double *fc_tol;
287 
288 	  CHECK(mxIsCell(mx), "fc must be a Cell array");
289 	  m = mxGetM(mx) * mxGetN(mx);;
290 	  dfc = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
291 	  fc_tol = struct_arrval(prhs[0], "fc_tol", m, NULL);
292 
293 	  for (j = 0; j < m; ++j) {
294 	       mxArray *fc = mxGetCell(mx, j);
295 	       CHECK(mxIsChar(fc) || mxIsFunctionHandle(fc),
296 		     "fc must contain function handles or function names");
297 	       if (mxIsChar(fc)) {
298 		    CHECK(mxGetString(fc, dfc[j].f, FLEN) == 0,
299 		     "error reading function name string (too long?)");
300 		    dfc[j].nrhs = 1;
301 		    dfc[j].xrhs = 0;
302 	       }
303 	       else {
304 		    dfc[j].prhs[0] = fc;
305 		    strcpy(dfc[j].f, "feval");
306 		    dfc[j].nrhs = 2;
307 		    dfc[j].xrhs = 1;
308 	       }
309 	       dfc[j].verbose = d.verbose > 1;
310 	       dfc[j].opt = opt;
311 	       dfc[j].neval = 0;
312 	       dfc[j].prhs[dfc[j].xrhs] = d.prhs[d.xrhs];
313 	       CHECK(nlopt_add_inequality_constraint(opt, user_function,
314 						     dfc + j,
315 						     fc_tol ? fc_tol[j] : 0)
316 		     > 0, "nlopt error adding inequality constraint");
317 	  }
318      }
319 
320 
321      if ((mx = mxGetField(prhs[0], 0, "h"))) {
322 	  int j, m;
323 	  double *h_tol;
324 
325 	  CHECK(mxIsCell(mx), "h must be a Cell array");
326 	  m = mxGetM(mx) * mxGetN(mx);;
327 	  dh = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
328 	  h_tol = struct_arrval(prhs[0], "h_tol", m, NULL);
329 
330 	  for (j = 0; j < m; ++j) {
331 	       mxArray *h = mxGetCell(mx, j);
332 	       CHECK(mxIsChar(h) || mxIsFunctionHandle(h),
333 		     "h must contain function handles or function names");
334 	       if (mxIsChar(h)) {
335 		    CHECK(mxGetString(h, dh[j].f, FLEN) == 0,
336 		     "error reading function name string (too long?)");
337 		    dh[j].nrhs = 1;
338 		    dh[j].xrhs = 0;
339 	       }
340 	       else {
341 		    dh[j].prhs[0] = h;
342 		    strcpy(dh[j].f, "feval");
343 		    dh[j].nrhs = 2;
344 		    dh[j].xrhs = 1;
345 	       }
346 	       dh[j].verbose = d.verbose > 1;
347 	       dh[j].opt = opt;
348 	       dh[j].neval = 0;
349 	       dh[j].prhs[dh[j].xrhs] = d.prhs[d.xrhs];
350 	       CHECK(nlopt_add_equality_constraint(opt, user_function,
351 						     dh + j,
352 						   h_tol ? h_tol[j] : 0)
353 		     > 0, "nlopt error adding equality constraint");
354 	  }
355      }
356 
357 
358      x_mx = mxCreateDoubleMatrix(mxGetM(prhs[1]), mxGetN(prhs[1]), mxREAL);
359      x = mxGetPr(x_mx);
360      memcpy(x, x0, sizeof(double) * n);
361 
362      ret = nlopt_optimize(opt, x, &opt_f);
363 
364      mxFree(dh);
365      mxFree(dfc);
366      mxDestroyArray(d.prhs[d.xrhs]);
367      if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
368      nlopt_destroy(opt);
369 
370      plhs[0] = x_mx;
371      if (nlhs > 1) {
372 	  plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
373 	  *(mxGetPr(plhs[1])) = opt_f;
374      }
375      if (nlhs > 2) {
376 	  plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
377 	  *(mxGetPr(plhs[2])) = (int) ret;
378      }
379 }
380