/**************************************************************************************************
* *
* This file is part of HPIPM. *
* *
* HPIPM -- High-Performance Interior Point Method. *
* Copyright (C) 2019 by Gianluca Frison. *
* Developed at IMTEK (University of Freiburg) under the supervision of Moritz Diehl. *
* All rights reserved. *
* *
* The 2-Clause BSD License *
* *
* Redistribution and use in source and binary forms, with or without *
* modification, are permitted provided that the following conditions are met: *
* *
* 1. Redistributions of source code must retain the above copyright notice, this *
* list of conditions and the following disclaimer. *
* 2. Redistributions in binary form must reproduce the above copyright notice, *
* this list of conditions and the following disclaimer in the documentation *
* and/or other materials provided with the distribution. *
* *
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND *
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED *
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE *
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR *
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES *
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; *
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS *
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *
* *
* Author: Gianluca Frison, gianluca.frison (at) imtek.uni-freiburg.de *
* *
**************************************************************************************************/

INIT_VAR_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_ARG * arg,struct TREE_OCP_QP_IPM_WORKSPACE * ws)36 void INIT_VAR_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_ARG *arg, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
37 {
38
39 // struct CORE_QP_IPM_WORKSPACE *cws = ws->core_workspace;
40
41 // loop index
42 int ii, jj;
43
44 //
45 int Nn = qp->dim->Nn;
46 int *nx = qp->dim->nx;
47 int *nu = qp->dim->nu;
48 int *nb = qp->dim->nb;
49 int *ng = qp->dim->ng;
50 int *ns = qp->dim->ns;
51
52 REAL mu0 = arg->mu0;
53
54 //
55 REAL *ux, *pi, *d_lb, *d_ub, *d_lg, *d_ug, *lam_lb, *lam_ub, *lam_lg, *lam_ug, *t_lb, *t_ub, *t_lg, *t_ug;
56 int *idxb;
57
58 REAL thr0 = 0.1;
59
60
61
62 // primal and dual variables
63 if(arg->warm_start==2)
64 {
65
66 thr0 = 1e-1;
67
68 for(ii=0; ii<Nn; ii++)
69 {
70 lam_lb = qp_sol->lam[ii].pa+0;
71 t_lb = qp_sol->t[ii].pa+0;
72
73 for(jj=0; jj<2*nb[ii]+2*ng[ii]+2*ns[ii]; jj++)
74 {
75 if(lam_lb[jj]<thr0)
76 lam_lb[jj] = thr0;
77 if(t_lb[jj]<thr0)
78 t_lb[jj] = thr0;
79 }
80 }
81
82 return;
83 }
84
85
86
87 // ux
88 if(arg->warm_start==0)
89 {
90 // cold start
91 for(ii=0; ii<Nn; ii++)
92 {
93 ux = qp_sol->ux[ii].pa;
94 for(jj=0; jj<nu[ii]+nx[ii]+2*ns[ii]; jj++)
95 {
96 ux[jj] = 0.0;
97 }
98 }
99 }
100 else
101 {
102 // warm start (keep u and x in solution)
103 for(ii=0; ii<Nn; ii++)
104 {
105 ux = qp_sol->ux[ii].pa;
106 for(jj=nu[ii]+nx[ii]; jj<nu[ii]+nx[ii]+2*ns[ii]; jj++)
107 {
108 ux[jj] = 0.0;
109 }
110 }
111 }
112
113 // pi
114 for(ii=0; ii<Nn-1; ii++)
115 {
116 pi = qp_sol->pi[ii].pa;
117 for(jj=0; jj<nx[ii+1]; jj++)
118 {
119 pi[jj] = 0.0;
120 }
121 }
122
123 // box constraints
124 for(ii=0; ii<Nn; ii++)
125 {
126 ux = qp_sol->ux[ii].pa;
127 d_lb = qp->d[ii].pa+0;
128 d_ub = qp->d[ii].pa+nb[ii]+ng[ii];
129 lam_lb = qp_sol->lam[ii].pa+0;
130 lam_ub = qp_sol->lam[ii].pa+nb[ii]+ng[ii];
131 t_lb = qp_sol->t[ii].pa+0;
132 t_ub = qp_sol->t[ii].pa+nb[ii]+ng[ii];
133 idxb = qp->idxb[ii];
134 for(jj=0; jj<nb[ii]; jj++)
135 {
136 #if 1
137 t_lb[jj] = - d_lb[jj] + ux[idxb[jj]];
138 t_ub[jj] = - d_ub[jj] - ux[idxb[jj]];
139 if(t_lb[jj]<thr0)
140 {
141 if(t_ub[jj]<thr0)
142 {
143 ux[idxb[jj]] = 0.5*(d_lb[jj] + d_ub[jj]);
144 t_lb[jj] = thr0;
145 t_ub[jj] = thr0;
146 }
147 else
148 {
149 t_lb[jj] = thr0;
150 ux[idxb[jj]] = d_lb[jj] + thr0;
151 }
152 }
153 else if(t_ub[jj]<thr0)
154 {
155 t_ub[jj] = thr0;
156 ux[idxb[jj]] = - d_ub[jj] - thr0;
157 }
158 #else
159 t_lb[jj] = 1.0;
160 t_ub[jj] = 1.0;
161 #endif
162 lam_lb[jj] = mu0/t_lb[jj];
163 lam_ub[jj] = mu0/t_ub[jj];
164 }
165 }
166
167 // general constraints
168 for(ii=0; ii<Nn; ii++)
169 {
170 t_lg = qp_sol->t[ii].pa+nb[ii];
171 t_ug = qp_sol->t[ii].pa+2*nb[ii]+ng[ii];
172 lam_lg = qp_sol->lam[ii].pa+nb[ii];
173 lam_ug = qp_sol->lam[ii].pa+2*nb[ii]+ng[ii];
174 d_lg = qp->d[ii].pa+nb[ii];
175 d_ug = qp->d[ii].pa+2*nb[ii]+ng[ii];
176 ux = qp_sol->ux[ii].pa;
177 GEMV_T(nu[ii]+nx[ii], ng[ii], 1.0, qp->DCt+ii, 0, 0, qp_sol->ux+ii, 0, 0.0, qp_sol->t+ii, nb[ii], qp_sol->t+ii, nb[ii]);
178 for(jj=0; jj<ng[ii]; jj++)
179 {
180 #if 1
181 t_ug[jj] = - t_lg[jj];
182 t_lg[jj] -= d_lg[jj];
183 t_ug[jj] -= d_ug[jj];
184 // t_lg[jj] = fmax(thr0, t_lg[jj]);
185 // t_ug[jj] = fmax(thr0, t_ug[jj]);
186 t_lg[jj] = thr0>t_lg[jj] ? thr0 : t_lg[jj];
187 t_ug[jj] = thr0>t_ug[jj] ? thr0 : t_ug[jj];
188 #else
189 t_lg[jj] = 1.0;
190 t_ug[jj] = 1.0;
191 #endif
192 lam_lg[jj] = mu0/t_lg[jj];
193 lam_ug[jj] = mu0/t_ug[jj];
194 }
195 }
196
197 // soft constraints
198 for(ii=0; ii<Nn; ii++)
199 {
200 lam_lb = qp_sol->lam[ii].pa+2*nb[ii]+2*ng[ii];
201 lam_ub = qp_sol->lam[ii].pa+2*nb[ii]+2*ng[ii]+ns[ii];
202 t_lb = qp_sol->t[ii].pa+2*nb[ii]+2*ng[ii];
203 t_ub = qp_sol->t[ii].pa+2*nb[ii]+2*ng[ii]+ns[ii];
204 for(jj=0; jj<ns[ii]; jj++)
205 {
206 t_lb[jj] = 1.0; // thr0;
207 t_ub[jj] = 1.0; // thr0;
208 lam_lb[jj] = mu0/t_lb[jj];
209 lam_ub[jj] = mu0/t_ub[jj];
210 }
211 }
212
213 return;
214
215 }
216
217
218
COMPUTE_RES_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_RES * res,struct TREE_OCP_QP_RES_WORKSPACE * ws)219 void COMPUTE_RES_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_RES *res, struct TREE_OCP_QP_RES_WORKSPACE *ws)
220 {
221
222 struct tree *ttree = qp->dim->ttree;
223
224 // loop index
225 int ii, jj;
226
227 int nkids, idxkid;
228
229 //
230 int Nn = qp->dim->Nn;
231 int *nx = qp->dim->nx;
232 int *nu = qp->dim->nu;
233 int *nb = qp->dim->nb;
234 int *ng = qp->dim->ng;
235 int *ns = qp->dim->ns;
236
237 int nct = 0;
238 for(ii=0; ii<Nn; ii++)
239 nct += 2*nb[ii]+2*ng[ii]+2*ns[ii];
240
241 REAL nct_inv = 1.0/nct;
242
243
244 struct STRMAT *BAbt = qp->BAbt;
245 struct STRMAT *RSQrq = qp->RSQrq;
246 struct STRMAT *DCt = qp->DCt;
247 struct STRVEC *b = qp->b;
248 struct STRVEC *rqz = qp->rqz;
249 struct STRVEC *d = qp->d;
250 struct STRVEC *m = qp->m;
251 int **idxb = qp->idxb;
252 struct STRVEC *Z = qp->Z;
253 int **idxs = qp->idxs;
254
255 struct STRVEC *ux = qp_sol->ux;
256 struct STRVEC *pi = qp_sol->pi;
257 struct STRVEC *lam = qp_sol->lam;
258 struct STRVEC *t = qp_sol->t;
259
260 struct STRVEC *res_g = res->res_g;
261 struct STRVEC *res_b = res->res_b;
262 struct STRVEC *res_d = res->res_d;
263 struct STRVEC *res_m = res->res_m;
264
265 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
266 struct STRVEC *tmp_nsM = ws->tmp_nsM;
267
268 int nx0, nx1, nu0, nu1, nb0, ng0, ns0;
269
270 //
271 REAL mu = 0.0;
272
273 // loop over nodes
274 for(ii=0; ii<Nn; ii++)
275 {
276
277 nx0 = nx[ii];
278 nu0 = nu[ii];
279 nb0 = nb[ii];
280 ng0 = ng[ii];
281 ns0 = ns[ii];
282
283 VECCP(nu0+nx0, rqz+ii, 0, res_g+ii, 0);
284
285 // if not root
286 if(ii>0)
287 AXPY(nx0, -1.0, pi+(ii-1), 0, res_g+ii, nu0, res_g+ii, nu0);
288
289 SYMV_L(nu0+nx0, nu0+nx0, 1.0, RSQrq+ii, 0, 0, ux+ii, 0, 1.0, res_g+ii, 0, res_g+ii, 0);
290
291 if(nb0+ng0>0)
292 {
293 AXPY(nb0+ng0, -1.0, lam+ii, 0, lam+ii, nb[ii]+ng[ii], tmp_nbgM+0, 0);
294 AXPY(nb0+ng0, 1.0, d+ii, 0, t+ii, 0, res_d+ii, 0);
295 AXPY(nb0+ng0, 1.0, d+ii, nb0+ng0, t+ii, nb0+ng0, res_d+ii, nb0+ng0);
296 // box
297 if(nb0>0)
298 {
299 VECAD_SP(nb0, 1.0, tmp_nbgM+0, 0, idxb[ii], res_g+ii, 0);
300 VECEX_SP(nb0, 1.0, idxb[ii], ux+ii, 0, tmp_nbgM+1, 0);
301 }
302 // general
303 if(ng0>0)
304 {
305 GEMV_NT(nu0+nx0, ng0, 1.0, 1.0, DCt+ii, 0, 0, tmp_nbgM+0, nb[ii], ux+ii, 0, 1.0, 0.0, res_g+ii, 0, tmp_nbgM+1, nb0, res_g+ii, 0, tmp_nbgM+1, nb0);
306 }
307
308 AXPY(nb0+ng0, -1.0, tmp_nbgM+1, 0, res_d+ii, 0, res_d+ii, 0);
309 AXPY(nb0+ng0, 1.0, tmp_nbgM+1, 0, res_d+ii, nb0+ng0, res_d+ii, nb0+ng0);
310 }
311 if(ns0>0)
312 {
313 // res_g
314 GEMV_DIAG(2*ns0, 1.0, Z+ii, 0, ux+ii, nu0+nx0, 1.0, rqz+ii, nu0+nx0, res_g+ii, nu0+nx0);
315 AXPY(2*ns0, -1.0, lam+ii, 2*nb0+2*ng0, res_g+ii, nu0+nx0, res_g+ii, nu0+nx0);
316 VECEX_SP(ns0, 1.0, idxs[ii], lam+ii, 0, tmp_nsM, 0);
317 AXPY(ns0, -1.0, tmp_nsM, 0, res_g+ii, nu0+nx0, res_g+ii, nu0+nx0);
318 VECEX_SP(ns0, 1.0, idxs[ii], lam+ii, nb0+ng0, tmp_nsM, 0);
319 AXPY(ns0, -1.0, tmp_nsM, 0, res_g+ii, nu0+nx0+ns0, res_g+ii, nu0+nx0+ns0);
320 // res_d
321 VECAD_SP(ns0, -1.0, ux+ii, nu0+nx0, idxs[ii], res_d+ii, 0);
322 VECAD_SP(ns0, -1.0, ux+ii, nu0+nx0+ns0, idxs[ii], res_d+ii, nb0+ng0);
323 AXPY(2*ns0, -1.0, ux+ii, nu0+nx0, t+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0);
324 AXPY(2*ns0, 1.0, d+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0);
325 }
326
327 // work on kids
328 nkids = (ttree->root+ii)->nkids;
329 for(jj=0; jj<nkids; jj++)
330 {
331
332 idxkid = (ttree->root+ii)->kids[jj];
333
334 nu1 = nu[idxkid];
335 nx1 = nx[idxkid];
336
337 AXPY(nx1, -1.0, ux+idxkid, nu1, b+idxkid-1, 0, res_b+idxkid-1, 0);
338
339 GEMV_NT(nu0+nx0, nx1, 1.0, 1.0, BAbt+idxkid-1, 0, 0, pi+idxkid-1, 0, ux+ii, 0, 1.0, 1.0, res_g+ii, 0, res_b+idxkid-1, 0, res_g+ii, 0, res_b+idxkid-1, 0);
340
341 }
342
343 mu += VECMULDOT(2*nb0+2*ng0+2*ns0, lam+ii, 0, t+ii, 0, res_m+ii, 0);
344 AXPY(2*nb0+2*ng0+2*ns0, -1.0, m+ii, 0, res_m+ii, 0, res_m+ii, 0);
345
346 }
347
348 res->res_mu = mu*nct_inv;
349
350 return;
351
352 }
353
354
355
COMPUTE_LIN_RES_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_SOL * qp_step,struct TREE_OCP_QP_RES * res,struct TREE_OCP_QP_RES_WORKSPACE * ws)356 void COMPUTE_LIN_RES_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_SOL *qp_step, struct TREE_OCP_QP_RES *res, struct TREE_OCP_QP_RES_WORKSPACE *ws)
357 {
358
359 struct tree *ttree = qp->dim->ttree;
360
361 // loop index
362 int ii, jj;
363
364 int nkids, idxkid;
365
366 //
367 int Nn = qp->dim->Nn;
368 int *nx = qp->dim->nx;
369 int *nu = qp->dim->nu;
370 int *nb = qp->dim->nb;
371 int *ng = qp->dim->ng;
372 int *ns = qp->dim->ns;
373
374 struct STRMAT *BAbt = qp->BAbt;
375 struct STRMAT *RSQrq = qp->RSQrq;
376 struct STRMAT *DCt = qp->DCt;
377 struct STRVEC *b = qp->b;
378 struct STRVEC *rqz = qp->rqz;
379 struct STRVEC *d = qp->d;
380 struct STRVEC *m = qp->m;
381 int **idxb = qp->idxb;
382 struct STRVEC *Z = qp->Z;
383 int **idxs = qp->idxs;
384
385 struct STRVEC *ux = qp_step->ux;
386 struct STRVEC *pi = qp_step->pi;
387 struct STRVEC *lam = qp_step->lam;
388 struct STRVEC *t = qp_step->t;
389
390 struct STRVEC *Lam = qp_sol->lam;
391 struct STRVEC *T = qp_sol->t;
392
393 struct STRVEC *res_g = res->res_g;
394 struct STRVEC *res_b = res->res_b;
395 struct STRVEC *res_d = res->res_d;
396 struct STRVEC *res_m = res->res_m;
397
398 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
399 struct STRVEC *tmp_nsM = ws->tmp_nsM;
400
401 int nx0, nx1, nu0, nu1, nb0, ng0, ns0;
402
403 //
404 REAL mu = 0.0;
405
406 // loop over nodes
407 for(ii=0; ii<Nn; ii++)
408 {
409
410 nx0 = nx[ii];
411 nu0 = nu[ii];
412 nb0 = nb[ii];
413 ng0 = ng[ii];
414 ns0 = ns[ii];
415
416 VECCP(nu0+nx0, rqz+ii, 0, res_g+ii, 0);
417
418 // if not root
419 if(ii>0)
420 AXPY(nx0, -1.0, pi+(ii-1), 0, res_g+ii, nu0, res_g+ii, nu0);
421
422 SYMV_L(nu0+nx0, nu0+nx0, 1.0, RSQrq+ii, 0, 0, ux+ii, 0, 1.0, res_g+ii, 0, res_g+ii, 0);
423
424 if(nb0+ng0>0)
425 {
426 AXPY(nb0+ng0, -1.0, lam+ii, 0, lam+ii, nb[ii]+ng[ii], tmp_nbgM+0, 0);
427 AXPY(nb0+ng0, 1.0, d+ii, 0, t+ii, 0, res_d+ii, 0);
428 AXPY(nb0+ng0, 1.0, d+ii, nb0+ng0, t+ii, nb0+ng0, res_d+ii, nb0+ng0);
429 // box
430 if(nb0>0)
431 {
432 VECAD_SP(nb0, 1.0, tmp_nbgM+0, 0, idxb[ii], res_g+ii, 0);
433 VECEX_SP(nb0, 1.0, idxb[ii], ux+ii, 0, tmp_nbgM+1, 0);
434 }
435 // general
436 if(ng0>0)
437 {
438 GEMV_NT(nu0+nx0, ng0, 1.0, 1.0, DCt+ii, 0, 0, tmp_nbgM+0, nb[ii], ux+ii, 0, 1.0, 0.0, res_g+ii, 0, tmp_nbgM+1, nb0, res_g+ii, 0, tmp_nbgM+1, nb0);
439 }
440
441 AXPY(nb0+ng0, -1.0, tmp_nbgM+1, 0, res_d+ii, 0, res_d+ii, 0);
442 AXPY(nb0+ng0, 1.0, tmp_nbgM+1, 0, res_d+ii, nb0+ng0, res_d+ii, nb0+ng0);
443 }
444 if(ns0>0)
445 {
446 // res_g
447 GEMV_DIAG(2*ns0, 1.0, Z+ii, 0, ux+ii, nu0+nx0, 1.0, rqz+ii, nu0+nx0, res_g+ii, nu0+nx0);
448 AXPY(2*ns0, -1.0, lam+ii, 2*nb0+2*ng0, res_g+ii, nu0+nx0, res_g+ii, nu0+nx0);
449 VECEX_SP(ns0, 1.0, idxs[ii], lam+ii, 0, tmp_nsM, 0);
450 AXPY(ns0, -1.0, tmp_nsM, 0, res_g+ii, nu0+nx0, res_g+ii, nu0+nx0);
451 VECEX_SP(ns0, 1.0, idxs[ii], lam+ii, nb0+ng0, tmp_nsM, 0);
452 AXPY(ns0, -1.0, tmp_nsM, 0, res_g+ii, nu0+nx0+ns0, res_g+ii, nu0+nx0+ns0);
453 // res_d
454 VECAD_SP(ns0, -1.0, ux+ii, nu0+nx0, idxs[ii], res_d+ii, 0);
455 VECAD_SP(ns0, -1.0, ux+ii, nu0+nx0+ns0, idxs[ii], res_d+ii, nb0+ng0);
456 AXPY(2*ns0, -1.0, ux+ii, nu0+nx0, t+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0);
457 AXPY(2*ns0, 1.0, d+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0, res_d+ii, 2*nb0+2*ng0);
458 }
459
460 // work on kids
461 nkids = (ttree->root+ii)->nkids;
462 for(jj=0; jj<nkids; jj++)
463 {
464
465 idxkid = (ttree->root+ii)->kids[jj];
466
467 nu1 = nu[idxkid];
468 nx1 = nx[idxkid];
469
470 AXPY(nx1, -1.0, ux+idxkid, nu1, b+idxkid-1, 0, res_b+idxkid-1, 0);
471
472 GEMV_NT(nu0+nx0, nx1, 1.0, 1.0, BAbt+idxkid-1, 0, 0, pi+idxkid-1, 0, ux+ii, 0, 1.0, 1.0, res_g+ii, 0, res_b+idxkid-1, 0, res_g+ii, 0, res_b+idxkid-1, 0);
473
474 }
475
476 VECCP(2*nb0+2*ng0+2*ns0, m+ii, 0, res_m+ii, 0);
477 VECMULACC(2*nb0+2*ng0+2*ns0, Lam+ii, 0, t+ii, 0, res_m+ii, 0);
478 VECMULACC(2*nb0+2*ng0+2*ns0, lam+ii, 0, T+ii, 0, res_m+ii, 0);
479
480 }
481
482 return;
483
484 }
485
486
487
488 // backward Riccati recursion
FACT_SOLVE_KKT_UNCONSTR_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_ARG * arg,struct TREE_OCP_QP_IPM_WORKSPACE * ws)489 void FACT_SOLVE_KKT_UNCONSTR_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_ARG *arg, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
490 {
491
492 int Nn = qp->dim->Nn;
493 int *nx = qp->dim->nx;
494 int *nu = qp->dim->nu;
495 int *nb = qp->dim->nb;
496 int *ng = qp->dim->ng;
497
498 struct tree *ttree = qp->dim->ttree;
499
500 struct STRMAT *BAbt = qp->BAbt;
501 struct STRMAT *RSQrq = qp->RSQrq;
502 struct STRVEC *b = qp->b;
503 struct STRVEC *rqz = qp->rqz;
504
505 struct STRVEC *ux = qp_sol->ux;
506 struct STRVEC *pi = qp_sol->pi;
507
508 struct STRMAT *L = ws->L;
509 struct STRMAT *AL = ws->AL;
510 struct STRVEC *tmp_nxM = ws->tmp_nxM;
511
512 //
513 int ii, jj;
514
515 int idx, nkids, idxkid;
516
517 struct CORE_QP_IPM_WORKSPACE *cws = ws->core_workspace;
518
519 // backward factorization and substitution
520
521 // loop over nodes, starting from the end
522
523 for(ii=0; ii<Nn; ii++)
524 {
525
526 idx = Nn-ii-1;
527
528 nkids = (ttree->root+idx)->nkids;
529
530 #if defined(DOUBLE_PRECISION)
531 TRCP_L(nu[idx]+nx[idx], RSQrq+idx, 0, 0, L+idx, 0, 0); // TODO blasfeo_dtrcp_l with m and n, for m>=n
532 #else
533 GECP(nu[idx]+nx[idx], nu[idx]+nx[idx], RSQrq+idx, 0, 0, L+idx, 0, 0); // TODO blasfeo_dtrcp_l with m and n, for m>=n
534 #endif
535 ROWIN(nu[idx]+nx[idx], 1.0, rqz+idx, 0, L+idx, nu[idx]+nx[idx], 0);
536
537 for(jj=0; jj<nkids; jj++)
538 {
539
540 idxkid = (ttree->root+idx)->kids[jj];
541
542 ROWIN(nx[idxkid], 1.0, b+idxkid-1, 0, BAbt+idxkid-1, nu[idx]+nx[idx], 0);
543 TRMM_RLNN(nu[idx]+nx[idx]+1, nx[idxkid], 1.0, L+idxkid, nu[idxkid], nu[idxkid], BAbt+idxkid-1, 0, 0, AL, 0, 0);
544 GEAD(1, nx[idxkid], 1.0, L+idxkid, nu[idxkid]+nx[idxkid], nu[idxkid], AL, nu[idx]+nx[idx], 0);
545
546 SYRK_LN_MN(nu[idx]+nx[idx]+1, nu[idx]+nx[idx], nx[idxkid], 1.0, AL, 0, 0, AL, 0, 0, 1.0, L+idx, 0, 0, L+idx, 0, 0);
547
548 }
549
550 POTRF_L_MN(nu[idx]+nx[idx]+1, nu[idx]+nx[idx], L+idx, 0, 0, L+idx, 0, 0);
551
552 }
553
554
555 // forward substitution
556
557 // loop over nodes, starting from the root
558
559 // root
560 ii = 0;
561
562 idx = ii;
563 nkids = (ttree->root+idx)->nkids;
564
565 ROWEX(nu[idx]+nx[idx], -1.0, L+idx, nu[idx]+nx[idx], 0, ux+idx, 0);
566 TRSV_LTN(nu[idx]+nx[idx], L+idx, 0, 0, ux+idx, 0, ux+idx, 0);
567
568 for(jj=0; jj<nkids; jj++)
569 {
570
571 idxkid = (ttree->root+idx)->kids[jj];
572
573 GEMV_T(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, ux+idx, 0, 1.0, b+idxkid-1, 0, ux+idxkid, nu[idxkid]);
574 ROWEX(nx[idxkid], 1.0, L+idxkid, nu[idxkid]+nx[idxkid], nu[idxkid], tmp_nxM, 0);
575 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], ux+idxkid, nu[idxkid], pi+idxkid-1, 0);
576 AXPY(nx[idxkid], 1.0, tmp_nxM, 0, pi+idxkid-1, 0, pi+idxkid-1, 0);
577 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], pi+idxkid-1, 0, pi+idxkid-1, 0);
578
579 }
580
581 // other nodes
582 for(ii=1; ii<Nn; ii++)
583 {
584
585 idx = ii;
586 nkids = (ttree->root+idx)->nkids;
587
588 ROWEX(nu[idx], -1.0, L+idx, nu[idx]+nx[idx], 0, ux+idx, 0);
589 TRSV_LTN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, ux+idx, 0, ux+idx, 0);
590
591 for(jj=0; jj<nkids; jj++)
592 {
593
594 idxkid = (ttree->root+idx)->kids[jj];
595
596 GEMV_T(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, ux+idx, 0, 1.0, b+idxkid-1, 0, ux+idxkid, nu[idxkid]);
597 ROWEX(nx[idxkid], 1.0, L+idxkid, nu[idxkid]+nx[idxkid], nu[idxkid], tmp_nxM, 0);
598 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], ux+idxkid, nu[idxkid], pi+idxkid-1, 0);
599 AXPY(nx[idxkid], 1.0, tmp_nxM, 0, pi+idxkid-1, 0, pi+idxkid-1, 0);
600 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], pi+idxkid-1, 0, pi+idxkid-1, 0);
601
602 }
603
604 }
605
606 return;
607
608 }
609
610
611
COND_SLACKS_FACT_SOLVE(int ss,struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_WORKSPACE * ws)612 static void COND_SLACKS_FACT_SOLVE(int ss, struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
613 {
614
615 int ii, idx;
616
617 int nx0 = qp->dim->nx[ss];
618 int nu0 = qp->dim->nu[ss];
619 int nb0 = qp->dim->nb[ss];
620 int ng0 = qp->dim->ng[ss];
621 int ns0 = qp->dim->ns[ss];
622
623 struct STRVEC *Z = qp->Z+ss;
624 int *idxs0 = qp->idxs[ss];
625
626 // struct STRVEC *res_g = ws->res->res_g+ss; // TODO !!!
627 struct STRVEC *res_g = qp->rqz+ss;
628
629 // struct STRVEC *dux = ws->sol_step->ux+ss; // TODO !!!
630 struct STRVEC *dux = qp_sol->ux+ss;
631
632 struct STRVEC *Gamma = ws->Gamma+ss;
633 struct STRVEC *gamma = ws->gamma+ss;
634 struct STRVEC *Zs_inv = ws->Zs_inv+ss;
635 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
636
637 REAL *ptr_Gamma = Gamma->pa;
638 REAL *ptr_gamma = gamma->pa;
639 REAL *ptr_Z = Z->pa;
640 REAL *ptr_Zs_inv = Zs_inv->pa;
641 REAL *ptr_dux = dux->pa;
642 REAL *ptr_res_g = res_g->pa;
643 REAL *ptr_tmp0 = (tmp_nbgM+0)->pa;
644 REAL *ptr_tmp1 = (tmp_nbgM+1)->pa;
645 REAL *ptr_tmp2 = (tmp_nbgM+2)->pa;
646 REAL *ptr_tmp3 = (tmp_nbgM+3)->pa;
647
648 REAL tmp0, tmp1;
649
650 VECCP(nb0+ng0, Gamma, 0, tmp_nbgM+0, 0);
651 VECCP(nb0+ng0, Gamma, nb0+ng0, tmp_nbgM+1, 0);
652 VECCP(nb0+ng0, gamma, 0, tmp_nbgM+2, 0);
653 VECCP(nb0+ng0, gamma, nb0+ng0, tmp_nbgM+3, 0);
654
655 for(ii=0; ii<ns0; ii++)
656 {
657 idx = idxs0[ii];
658 ptr_Zs_inv[0+ii] = ptr_Z[0+ii] + ptr_Gamma[0+idx] + ptr_Gamma[2*nb0+2*ng0+ii];
659 ptr_Zs_inv[ns0+ii] = ptr_Z[ns0+ii] + ptr_Gamma[nb0+ng0+idx] + ptr_Gamma[2*nb0+2*ng0+ns0+ii];
660 ptr_dux[nu0+nx0+ii] = ptr_res_g[nu0+nx0+ii] + ptr_gamma[0+idx] + ptr_gamma[2*nb0+2*ng0+ii];
661 ptr_dux[nu0+nx0+ns0+ii] = ptr_res_g[nu0+nx0+ns0+ii] + ptr_gamma[nb0+ng0+idx] + ptr_gamma[2*nb0+2*ng0+ns0+ii];
662 ptr_Zs_inv[0+ii] = 1.0/ptr_Zs_inv[0+ii];
663 ptr_Zs_inv[ns0+ii] = 1.0/ptr_Zs_inv[ns0+ii];
664 tmp0 = ptr_dux[nu0+nx0+ii]*ptr_Zs_inv[0+ii];
665 tmp1 = ptr_dux[nu0+nx0+ns0+ii]*ptr_Zs_inv[ns0+ii];
666 ptr_tmp0[idx] = ptr_tmp0[idx] - ptr_tmp0[idx]*ptr_Zs_inv[0+ii]*ptr_tmp0[idx];
667 ptr_tmp1[idx] = ptr_tmp1[idx] - ptr_tmp1[idx]*ptr_Zs_inv[ns0+ii]*ptr_tmp1[idx];
668 ptr_tmp2[idx] = ptr_tmp2[idx] - ptr_Gamma[0+idx]*tmp0;
669 ptr_tmp3[idx] = ptr_tmp3[idx] - ptr_Gamma[nb0+ng0+idx]*tmp1;
670 }
671
672 AXPY(nb0+ng0, 1.0, tmp_nbgM+1, 0, tmp_nbgM+0, 0, tmp_nbgM+0, 0);
673 AXPY(nb0+ng0, -1.0, tmp_nbgM+3, 0, tmp_nbgM+2, 0, tmp_nbgM+1, 0);
674
675 return;
676
677 }
678
679
680
COND_SLACKS_SOLVE(int ss,struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_WORKSPACE * ws)681 static void COND_SLACKS_SOLVE(int ss, struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
682 {
683
684 int ii, idx;
685
686 int nx0 = qp->dim->nx[ss];
687 int nu0 = qp->dim->nu[ss];
688 int nb0 = qp->dim->nb[ss];
689 int ng0 = qp->dim->ng[ss];
690 int ns0 = qp->dim->ns[ss];
691
692 int *idxs0 = qp->idxs[ss];
693
694 // struct STRVEC *res_g = ws->res->res_g+ss; // TODO !!!
695 struct STRVEC *res_g = qp->rqz+ss;
696
697 // struct STRVEC *dux = ws->sol_step->ux+ss; // TODO !!!
698 struct STRVEC *dux = qp_sol->ux+ss;
699
700 struct STRVEC *Gamma = ws->Gamma+ss;
701 struct STRVEC *gamma = ws->gamma+ss;
702 struct STRVEC *Zs_inv = ws->Zs_inv+ss;
703 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
704
705 REAL *ptr_Gamma = Gamma->pa;
706 REAL *ptr_gamma = gamma->pa;
707 REAL *ptr_Zs_inv = Zs_inv->pa;
708 REAL *ptr_dux = dux->pa;
709 REAL *ptr_res_g = res_g->pa;
710 REAL *ptr_tmp2 = (tmp_nbgM+2)->pa;
711 REAL *ptr_tmp3 = (tmp_nbgM+3)->pa;
712
713 REAL tmp0, tmp1;
714
715 VECCP(nb0+ng0, gamma, 0, tmp_nbgM+2, 0);
716 VECCP(nb0+ng0, gamma, nb0+ng0, tmp_nbgM+3, 0);
717
718 for(ii=0; ii<ns0; ii++)
719 {
720 idx = idxs0[ii];
721 ptr_dux[nu0+nx0+ii] = ptr_res_g[nu0+nx0+ii] + ptr_gamma[0+idx] + ptr_gamma[2*nb0+2*ng0+ii];
722 ptr_dux[nu0+nx0+ns0+ii] = ptr_res_g[nu0+nx0+ns0+ii] + ptr_gamma[nb0+ng0+idx] + ptr_gamma[2*nb0+2*ng0+ns0+ii];
723 tmp0 = ptr_dux[nu0+nx0+ii]*ptr_Zs_inv[0+ii];
724 tmp1 = ptr_dux[nu0+nx0+ns0+ii]*ptr_Zs_inv[ns0+ii];
725 ptr_tmp2[idx] = ptr_tmp2[idx] - ptr_Gamma[0+idx]*tmp0;
726 ptr_tmp3[idx] = ptr_tmp3[idx] - ptr_Gamma[nb0+ng0+idx]*tmp1;
727 }
728
729 AXPY(nb0+ng0, -1.0, tmp_nbgM+3, 0, tmp_nbgM+2, 0, tmp_nbgM+1, 0);
730
731 return;
732
733 }
734
735
736
EXPAND_SLACKS(int ss,struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_WORKSPACE * ws)737 static void EXPAND_SLACKS(int ss, struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
738 {
739
740 int ii, idx;
741
742 int nx0 = qp->dim->nx[ss];
743 int nu0 = qp->dim->nu[ss];
744 int nb0 = qp->dim->nb[ss];
745 int ng0 = qp->dim->ng[ss];
746 int ns0 = qp->dim->ns[ss];
747
748 int *idxs0 = qp->idxs[ss];
749
750 struct STRVEC *dux = qp_sol->ux+ss;
751 struct STRVEC *dt = qp_sol->t+ss;
752
753 struct STRVEC *Gamma = ws->Gamma+ss;
754 struct STRVEC *Zs_inv = ws->Zs_inv+ss;
755
756 REAL *ptr_Gamma = Gamma->pa;
757 REAL *ptr_dux = dux->pa;
758 REAL *ptr_dt = dt->pa;
759 REAL *ptr_Zs_inv = Zs_inv->pa;
760
761 for(ii=0; ii<ns0; ii++)
762 {
763 idx = idxs0[ii];
764 ptr_dux[nu0+nx0+ii] = - ptr_Zs_inv[0+ii] * (ptr_dux[nu0+nx0+ii] + ptr_dt[idx]*ptr_Gamma[idx]);
765 ptr_dux[nu0+nx0+ns0+ii] = - ptr_Zs_inv[ns0+ii] * (ptr_dux[nu0+nx0+ns0+ii] + ptr_dt[nb0+ng0+idx]*ptr_Gamma[nb0+ng0+idx]);
766 ptr_dt[2*nb0+2*ng0+ii] = ptr_dux[nu0+nx0+ii];
767 ptr_dt[2*nb0+2*ng0+ns0+ii] = ptr_dux[nu0+nx0+ns0+ii];
768 ptr_dt[0+idx] = ptr_dt[0+idx] + ptr_dux[nu0+nx0+ii];
769 ptr_dt[nb0+ng0+idx] = ptr_dt[nb0+ng0+idx] + ptr_dux[nu0+nx0+ns0+ii];
770
771 }
772
773 return;
774
775 }
776
777
778
779 // backward Riccati recursion
FACT_SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_ARG * arg,struct TREE_OCP_QP_IPM_WORKSPACE * ws)780 void FACT_SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_ARG *arg, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
781 {
782
783 int Nn = qp->dim->Nn;
784 int *nx = qp->dim->nx;
785 int *nu = qp->dim->nu;
786 int *nb = qp->dim->nb;
787 int *ng = qp->dim->ng;
788 int *ns = qp->dim->ns;
789
790 struct tree *ttree = qp->dim->ttree;
791
792 struct STRMAT *BAbt = qp->BAbt;
793 struct STRMAT *RSQrq = qp->RSQrq;
794 struct STRMAT *DCt = qp->DCt;
795 struct STRVEC *Z = qp->Z;
796 struct STRVEC *res_g = qp->rqz;
797 struct STRVEC *res_b = qp->b;
798 struct STRVEC *res_d = qp->d;
799 struct STRVEC *res_m = qp->m;
800 int **idxb = qp->idxb;
801 int **idxs = qp->idxs;
802
803 struct STRVEC *dux = qp_sol->ux;
804 struct STRVEC *dpi = qp_sol->pi;
805 struct STRVEC *dlam = qp_sol->lam;
806 struct STRVEC *dt = qp_sol->t;
807
808 struct STRMAT *L = ws->L;
809 struct STRMAT *AL = ws->AL;
810 struct STRVEC *Gamma = ws->Gamma;
811 struct STRVEC *gamma = ws->gamma;
812 struct STRVEC *Pb = ws->Pb;
813 struct STRVEC *Zs_inv = ws->Zs_inv;
814 struct STRVEC *tmp_nxM = ws->tmp_nxM;
815 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
816
817 REAL *ptr0, *ptr1, *ptr2, *ptr3;
818
819 //
820 int ss, jj;
821
822 int idx, nkids, idxkid;
823
824 struct CORE_QP_IPM_WORKSPACE *cws = ws->core_workspace;
825
826
827 COMPUTE_GAMMA_GAMMA_QP(res_d[0].pa, res_m[0].pa, cws);
828
829 // backward factorization and substitution
830
831 // loop over nodes, starting from the end
832 for(ss=0; ss<Nn; ss++)
833 {
834
835 idx = Nn-ss-1;
836
837 nkids = (ttree->root+idx)->nkids;
838
839 #if defined(DOUBLE_PRECISION)
840 TRCP_L(nu[idx]+nx[idx], RSQrq+idx, 0, 0, L+idx, 0, 0); // TODO blasfeo_dtrcp_l with m and n, for m>=n
841 #else
842 GECP(nu[idx]+nx[idx], nu[idx]+nx[idx], RSQrq+idx, 0, 0, L+idx, 0, 0); // TODO blasfeo_dtrcp_l with m and n, for m>=n
843 #endif
844 DIARE(nu[idx]+nx[idx], arg->reg_prim, L+idx, 0, 0);
845 ROWIN(nu[idx]+nx[idx], 1.0, res_g+idx, 0, L+idx, nu[idx]+nx[idx], 0);
846
847 for(jj=0; jj<nkids; jj++)
848 {
849
850 idxkid = (ttree->root+idx)->kids[jj];
851
852 GECP(nu[idx]+nx[idx], nx[idxkid], BAbt+idxkid-1, 0, 0, AL, 0, 0);
853 ROWIN(nx[idxkid], 1.0, res_b+idxkid-1, 0, AL, nu[idx]+nx[idx], 0);
854 TRMM_RLNN(nu[idx]+nx[idx]+1, nx[idxkid], 1.0, L+idxkid, nu[idxkid], nu[idxkid], AL, 0, 0, AL, 0, 0);
855 ROWEX(nx[idxkid], 1.0, AL, nu[idx]+nx[idx], 0, tmp_nxM, 0);
856 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, Pb+idxkid-1, 0);
857 GEAD(1, nx[idxkid], 1.0, L+idxkid, nu[idxkid]+nx[idxkid], nu[idxkid], AL, nu[idx]+nx[idx], 0);
858
859 SYRK_LN_MN(nu[idx]+nx[idx]+1, nu[idx]+nx[idx], nx[idxkid], 1.0, AL, 0, 0, AL, 0, 0, 1.0, L+idx, 0, 0, L+idx, 0, 0);
860
861 }
862
863 if(ns[idx]>0)
864 {
865 COND_SLACKS_FACT_SOLVE(idx, qp, qp_sol, ws);
866 }
867 else if(nb[idx]+ng[idx]>0)
868 {
869 AXPY(nb[idx]+ng[idx], 1.0, Gamma+idx, nb[idx]+ng[idx], Gamma+idx, 0, tmp_nbgM+0, 0);
870 AXPY(nb[idx]+ng[idx], -1.0, gamma+idx, nb[idx]+ng[idx], gamma+idx, 0, tmp_nbgM+1, 0);
871 }
872 if(nb[idx]>0)
873 {
874 DIAAD_SP(nb[idx], 1.0, tmp_nbgM+0, 0, idxb[idx], L+idx, 0, 0);
875 ROWAD_SP(nb[idx], 1.0, tmp_nbgM+1, 0, idxb[idx], L+idx, nu[idx]+nx[idx], 0);
876 }
877 if(ng[idx]>0)
878 {
879 GEMM_R_DIAG(nu[idx]+nx[idx], ng[idx], 1.0, DCt+idx, 0, 0, tmp_nbgM+0, nb[idx], 0.0, AL+0, 0, 0, AL+0, 0, 0);
880 ROWIN(ng[idx], 1.0, tmp_nbgM+1, nb[idx], AL+0, nu[idx]+nx[idx], 0);
881 SYRK_POTRF_LN_MN(nu[idx]+nx[idx]+1, nu[idx]+nx[idx], ng[idx], AL+0, 0, 0, DCt+idx, 0, 0, L+idx, 0, 0, L+idx, 0, 0);
882 }
883 else
884 {
885 POTRF_L_MN(nu[idx]+nx[idx]+1, nu[idx]+nx[idx], L+idx, 0, 0, L+idx, 0, 0);
886 }
887
888 }
889
890 // forward substitution
891
892 // loop over nodes, starting from the root
893 for(ss=0; ss<Nn; ss++)
894 {
895
896 idx = ss;
897 nkids = (ttree->root+idx)->nkids;
898
899 if(idx>0)
900 {
901 ROWEX(nu[idx], -1.0, L+idx, nu[idx]+nx[idx], 0, dux+idx, 0);
902 TRSV_LTN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
903 }
904 else
905 {
906 ROWEX(nu[idx]+nx[idx], -1.0, L+idx, nu[idx]+nx[idx], 0, dux+idx, 0);
907 TRSV_LTN(nu[idx]+nx[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
908 }
909
910 for(jj=0; jj<nkids; jj++)
911 {
912
913 idxkid = (ttree->root+idx)->kids[jj];
914
915 GEMV_T(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, dux+idx, 0, 1.0, res_b+idxkid-1, 0, dux+idxkid, nu[idxkid]);
916 if(arg->comp_dual_sol)
917 {
918 ROWEX(nx[idxkid], 1.0, L+idxkid, nu[idxkid]+nx[idxkid], nu[idxkid], tmp_nxM, 0);
919 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], dux+idxkid, nu[idxkid], dpi+idxkid-1, 0);
920 AXPY(nx[idxkid], 1.0, tmp_nxM, 0, dpi+idxkid-1, 0, dpi+idxkid-1, 0);
921 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], dpi+idxkid-1, 0, dpi+idxkid-1, 0);
922 }
923
924 }
925
926 }
927
928
929
930 for(ss=0; ss<Nn; ss++)
931 VECEX_SP(nb[ss], 1.0, idxb[ss], dux+ss, 0, dt+ss, 0);
932 for(ss=0; ss<Nn; ss++)
933 GEMV_T(nu[ss]+nx[ss], ng[ss], 1.0, DCt+ss, 0, 0, dux+ss, 0, 0.0, dt+ss, nb[ss], dt+ss, nb[ss]);
934
935 for(ss=0; ss<Nn; ss++)
936 {
937 VECCP(nb[ss]+ng[ss], dt+ss, 0, dt+ss, nb[ss]+ng[ss]);
938 VECSC(nb[ss]+ng[ss], -1.0, dt+ss, nb[ss]+ng[ss]);
939 }
940
941 for(ss=0; ss<Nn; ss++)
942 {
943 if(ns[ss]>0)
944 EXPAND_SLACKS(ss, qp, qp_sol, ws);
945 }
946
947 COMPUTE_LAM_T_QP(res_d[0].pa, res_m[0].pa, dlam[0].pa, dt[0].pa, cws);
948
949 return;
950
951 }
952
953
954
955 // backward Riccati recursion
FACT_LQ_SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_ARG * arg,struct TREE_OCP_QP_IPM_WORKSPACE * ws)956 void FACT_LQ_SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_ARG *arg, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
957 {
958
959 int Nn = qp->dim->Nn;
960 int *nx = qp->dim->nx;
961 int *nu = qp->dim->nu;
962 int *nb = qp->dim->nb;
963 int *ng = qp->dim->ng;
964 int *ns = qp->dim->ns;
965
966 struct tree *ttree = qp->dim->ttree;
967
968 struct STRMAT *BAbt = qp->BAbt;
969 struct STRMAT *RSQrq = qp->RSQrq;
970 struct STRMAT *DCt = qp->DCt;
971 struct STRVEC *Z = qp->Z;
972 struct STRVEC *res_g = qp->rqz;
973 struct STRVEC *res_b = qp->b;
974 struct STRVEC *res_d = qp->d;
975 struct STRVEC *res_m = qp->m;
976 int **idxb = qp->idxb;
977 int **idxs = qp->idxs;
978
979 struct STRVEC *dux = qp_sol->ux;
980 struct STRVEC *dpi = qp_sol->pi;
981 struct STRVEC *dlam = qp_sol->lam;
982 struct STRVEC *dt = qp_sol->t;
983
984 struct STRMAT *L = ws->L;
985 struct STRMAT *Lh = ws->Lh;
986 struct STRMAT *AL = ws->AL;
987 struct STRVEC *Gamma = ws->Gamma;
988 struct STRVEC *gamma = ws->gamma;
989 struct STRVEC *Pb = ws->Pb;
990 struct STRVEC *Zs_inv = ws->Zs_inv;
991 struct STRVEC *tmp_nxM = ws->tmp_nxM;
992 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
993 struct STRMAT *lq0 = ws->lq0;
994 void *lq_work0 = ws->lq_work0;
995
996 REAL *ptr0, *ptr1, *ptr2, *ptr3;
997
998 REAL tmp;
999
1000 //
1001 int ss, ii, jj;
1002
1003 int idx, nkids, idxkid;
1004
1005 struct CORE_QP_IPM_WORKSPACE *cws = ws->core_workspace;
1006
1007
1008 COMPUTE_GAMMA_GAMMA_QP(res_d[0].pa, res_m[0].pa, cws);
1009
1010 // backward factorization and substitution
1011
1012 // loop over nodes, starting from the end
1013 for(ss=0; ss<Nn; ss++)
1014 {
1015
1016 idx = Nn-ss-1;
1017
1018 nkids = (ttree->root+idx)->nkids;
1019
1020 // GESE(nu[idx]+nx[idx], 2*nu[idx]+2*nx[idx]+ng[idx], 0.0, lq0, 0, 0);
1021 GESE(nu[idx]+nx[idx], nu[idx]+nx[idx]+ng[idx], 0.0, lq0, 0, nu[idx]+nx[idx]);
1022 //
1023 if(ws->use_hess_fact[idx]==0)
1024 {
1025 POTRF_L(nu[idx]+nx[idx], RSQrq+idx, 0, 0, Lh+idx, 0, 0);
1026 ws->use_hess_fact[idx]=1;
1027 }
1028 #if defined(LA_HIGH_PERFORMANCE) | defined(LA_REFERENCE)
1029 TRCP_L(nu[idx]+nx[idx], Lh+idx, 0, 0, L+idx, 0, 0);
1030 #else
1031 GECP(nu[idx]+nx[idx], nu[idx]+nx[idx], Lh+idx, 0, 0, L+idx, 0, 0);
1032 #endif
1033
1034 VECCP(nu[idx]+nx[idx], res_g+idx, 0, dux+idx, 0);
1035
1036 if(ns[idx]>0)
1037 {
1038 COND_SLACKS_FACT_SOLVE(idx, qp, qp_sol, ws);
1039 }
1040 else if(nb[idx]+ng[idx]>0)
1041 {
1042 AXPY(nb[idx]+ng[idx], 1.0, Gamma+idx, nb[idx]+ng[idx], Gamma+idx, 0, tmp_nbgM+0, 0);
1043 AXPY(nb[idx]+ng[idx], -1.0, gamma+idx, nb[idx]+ng[idx], gamma+idx, 0, tmp_nbgM+1, 0);
1044 }
1045 if(nb[idx]>0)
1046 {
1047 for(ii=0; ii<nb[idx]; ii++)
1048 {
1049 tmp = BLASFEO_DVECEL(tmp_nbgM+0, ii);
1050 tmp = tmp>=0.0 ? tmp : 0.0;
1051 tmp = sqrt( tmp );
1052 BLASFEO_DMATEL(lq0, idxb[idx][ii], nu[idx]+nx[idx]+idxb[idx][ii]) = tmp>0.0 ? tmp : 0.0;
1053 }
1054 VECAD_SP(nb[idx], 1.0, tmp_nbgM+1, 0, idxb[idx], dux+idx, 0);
1055 }
1056 if(ng[idx]>0)
1057 {
1058 for(ii=0; ii<ng[idx]; ii++)
1059 {
1060 tmp = BLASFEO_DVECEL(tmp_nbgM+0, nb[idx]+ii);
1061 tmp = tmp>=0.0 ? tmp : 0.0;
1062 tmp = sqrt( tmp );
1063 BLASFEO_DVECEL(tmp_nbgM+0, nb[idx]+ii) = tmp;
1064 }
1065 GEMM_R_DIAG(nu[idx]+nx[idx], ng[idx], 1.0, DCt+idx, 0, 0, tmp_nbgM+0, nb[idx], 0.0, lq0, 0, 2*nu[idx]+2*nx[idx], lq0, 0, 2*nu[idx]+2*nx[idx]);
1066 GEMV_N(nu[idx]+nx[idx], ng[idx], 1.0, DCt+idx, 0, 0, tmp_nbgM+1, nb[idx], 1.0, dux+idx, 0, dux+idx, 0);
1067 }
1068
1069 DIARE(nu[idx]+nx[idx], arg->reg_prim, lq0, 0, nu[idx]+nx[idx]);
1070 #if defined(LA_HIGH_PERFORMANCE) | defined(LA_REFERENCE)
1071 GELQF_PD_LLA(nu[idx]+nx[idx], ng[idx], L+idx, 0, 0, lq0, 0, nu[idx]+nx[idx], lq0, 0, 2*nu[idx]+2*nx[idx], lq_work0); // TODO reduce lq1 size !!!
1072 #else
1073 TRCP_L(nu[idx]+nx[idx], L+idx, 0, 0, lq0, 0, 0);
1074 GELQF(nu[idx]+nx[idx], 2*nu[idx]+2*nx[idx]+ng[idx], lq0, 0, 0, lq0, 0, 0, lq_work0);
1075 TRCP_L(nu[idx]+nx[idx], lq0, 0, 0, L+idx, 0, 0);
1076 for(ii=0; ii<nu[idx]+nx[idx]; ii++)
1077 if(BLASFEO_DMATEL(L+idx, ii, ii) < 0)
1078 COLSC(nu[idx]+nx[idx]-ii, -1.0, L+idx, ii, ii);
1079 #endif
1080
1081 for(jj=0; jj<nkids; jj++)
1082 {
1083
1084 idxkid = (ttree->root+idx)->kids[jj];
1085
1086 TRMM_RLNN(nu[idx]+nx[idx], nx[idxkid], 1.0, L+idxkid, nu[idxkid], nu[idxkid], BAbt+idxkid-1, 0, 0, lq0, 0, nu[idx]+nx[idx]);
1087
1088 #if defined(LA_HIGH_PERFORMANCE) | defined(LA_REFERENCE)
1089 GELQF_PD_LA(nu[idx]+nx[idx], nx[idxkid], L+idx, 0, 0, lq0, 0, nu[idx]+nx[idx], lq_work0);
1090 #else
1091 TRCP_L(nu[idx]+nx[idx], L+idx, 0, 0, lq0, 0, 0);
1092 GELQF(nu[idx]+nx[idx], nu[idx]+nx[idx]+nx[idxkid], lq0, 0, 0, lq0, 0, 0, lq_work0);
1093 TRCP_L(nu[idx]+nx[idx], lq0, 0, 0, L+idx, 0, 0);
1094 for(ii=0; ii<nu[idx]+nx[idx]; ii++)
1095 if(BLASFEO_DMATEL(L+idx, ii, ii) < 0)
1096 COLSC(nu[idx]+nx[idx]-ii, -1.0, L+idx, ii, ii);
1097 #endif
1098
1099 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], res_b+idxkid-1, 0, Pb+idxkid-1, 0);
1100 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], Pb+idxkid-1, 0, Pb+idxkid-1, 0);
1101 AXPY(nx[idxkid], 1.0, dux+idxkid, nu[idxkid], Pb+idxkid-1, 0, tmp_nxM, 0);
1102 GEMV_N(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, tmp_nxM, 0, 1.0, dux+idx, 0, dux+idx, 0);
1103
1104 }
1105
1106 if(idx>0)
1107 {
1108 TRSV_LNN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1109 }
1110 else // root
1111 {
1112 TRSV_LNN(nu[idx]+nx[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1113 }
1114
1115 }
1116
1117
1118
1119 // forward substitution
1120
1121 // loop over nodes, starting from the root
1122 for(ss=0; ss<Nn; ss++)
1123 {
1124
1125 idx = ss;
1126 nkids = (ttree->root+idx)->nkids;
1127
1128 if(idx>0)
1129 {
1130 VECSC(nu[idx], -1.0, dux+idx, 0);
1131 TRSV_LTN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1132 }
1133 else // root
1134 {
1135 VECSC(nu[idx]+nx[idx], -1.0, dux+idx, 0);
1136 TRSV_LTN(nu[idx]+nx[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1137 }
1138
1139 for(jj=0; jj<nkids; jj++)
1140 {
1141
1142 idxkid = (ttree->root+idx)->kids[jj];
1143
1144 VECCP(nx[idxkid], dux+idxkid, nu[idxkid], dpi+idxkid-1, 0);
1145 GEMV_T(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, dux+idx, 0, 1.0, res_b+idxkid-1, 0, dux+idxkid, nu[idxkid]);
1146 VECCP(nx[idxkid], dux+idxkid, nu[idxkid], tmp_nxM, 0);
1147 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1148 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1149 AXPY(nx[idxkid], 1.0, tmp_nxM, 0, dpi+idxkid-1, 0, dpi+idxkid-1, 0);
1150
1151 }
1152
1153 }
1154
1155
1156
1157 for(ss=0; ss<Nn; ss++)
1158 VECEX_SP(nb[ss], 1.0, idxb[ss], dux+ss, 0, dt+ss, 0);
1159 for(ss=0; ss<Nn; ss++)
1160 GEMV_T(nu[ss]+nx[ss], ng[ss], 1.0, DCt+ss, 0, 0, dux+ss, 0, 0.0, dt+ss, nb[ss], dt+ss, nb[ss]);
1161
1162 for(ss=0; ss<Nn; ss++)
1163 {
1164 VECCP(nb[ss]+ng[ss], dt+ss, 0, dt+ss, nb[ss]+ng[ss]);
1165 VECSC(nb[ss]+ng[ss], -1.0, dt+ss, nb[ss]+ng[ss]);
1166 }
1167
1168 for(ss=0; ss<Nn; ss++)
1169 {
1170 if(ns[ss]>0)
1171 EXPAND_SLACKS(ss, qp, qp_sol, ws);
1172 }
1173
1174 COMPUTE_LAM_T_QP(res_d[0].pa, res_m[0].pa, dlam[0].pa, dt[0].pa, cws);
1175
1176 return;
1177
1178 }
1179
1180
1181 // backward Riccati recursion
SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP * qp,struct TREE_OCP_QP_SOL * qp_sol,struct TREE_OCP_QP_IPM_ARG * arg,struct TREE_OCP_QP_IPM_WORKSPACE * ws)1182 void SOLVE_KKT_STEP_TREE_OCP_QP(struct TREE_OCP_QP *qp, struct TREE_OCP_QP_SOL *qp_sol, struct TREE_OCP_QP_IPM_ARG *arg, struct TREE_OCP_QP_IPM_WORKSPACE *ws)
1183 {
1184
1185 int Nn = qp->dim->Nn;
1186 int *nx = qp->dim->nx;
1187 int *nu = qp->dim->nu;
1188 int *nb = qp->dim->nb;
1189 int *ng = qp->dim->ng;
1190 int *ns = qp->dim->ns;
1191
1192 struct tree *ttree = qp->dim->ttree;
1193
1194 struct STRMAT *BAbt = qp->BAbt;
1195 // struct STRMAT *RSQrq = qp->RSQrq;
1196 struct STRMAT *DCt = qp->DCt;
1197 struct STRVEC *res_g = qp->rqz;
1198 struct STRVEC *res_b = qp->b;
1199 struct STRVEC *res_d = qp->d;
1200 struct STRVEC *res_m = qp->m;
1201 int **idxb = qp->idxb;
1202 // int **idxs = qp->idxs;
1203
1204 struct STRVEC *dux = qp_sol->ux;
1205 struct STRVEC *dpi = qp_sol->pi;
1206 struct STRVEC *dlam = qp_sol->lam;
1207 struct STRVEC *dt = qp_sol->t;
1208
1209 struct STRMAT *L = ws->L;
1210 struct STRVEC *gamma = ws->gamma;
1211 struct STRVEC *Pb = ws->Pb;
1212 struct STRVEC *tmp_nxM = ws->tmp_nxM;
1213 struct STRVEC *tmp_nbgM = ws->tmp_nbgM;
1214
1215 //
1216 int ii, jj;
1217
1218 int idx, nkids, idxkid;
1219
1220 struct CORE_QP_IPM_WORKSPACE *cws = ws->core_workspace;
1221
1222 COMPUTE_GAMMA_QP(res_d[0].pa, res_m[0].pa, cws);
1223
1224
1225 // backward substitution
1226
1227 // loop over nodes, starting from the end
1228 for(ii=0; ii<Nn; ii++)
1229 {
1230
1231 idx = Nn-ii-1;
1232
1233 nkids = (ttree->root+idx)->nkids;
1234
1235 VECCP(nu[idx]+nx[idx], res_g+idx, 0, dux+idx, 0);
1236
1237 for(jj=0; jj<nkids; jj++)
1238 {
1239
1240 idxkid = (ttree->root+idx)->kids[jj];
1241
1242 if(ws->use_Pb)
1243 {
1244 AXPY(nx[idxkid], 1.0, dux+idxkid, nu[idxkid], Pb+idxkid-1, 0, tmp_nxM, 0);
1245 }
1246 else
1247 {
1248 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], res_b+idxkid-1, 0, tmp_nxM, 0);
1249 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1250 AXPY(nx[idxkid], 1.0, dux+idxkid, nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1251 }
1252 GEMV_N(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, tmp_nxM, 0, 1.0, dux+idx, 0, dux+idx, 0);
1253
1254 }
1255
1256 if(ns[idx]>0)
1257 {
1258 COND_SLACKS_SOLVE(idx, qp, qp_sol, ws);
1259 }
1260 else if(nb[idx]+ng[idx]>0)
1261 {
1262 AXPY(nb[idx]+ng[idx], -1.0, gamma+idx, nb[idx]+ng[idx], gamma+idx, 0, tmp_nbgM+1, 0);
1263 }
1264 if(nb[idx]>0)
1265 {
1266 VECAD_SP(nb[idx], 1.0, tmp_nbgM+1, 0, idxb[idx], dux+idx, 0);
1267 }
1268 if(ng[idx]>0)
1269 {
1270 GEMV_N(nu[idx]+nx[idx], ng[idx], 1.0, DCt+idx, 0, 0, tmp_nbgM+1, nb[idx], 1.0, dux+idx, 0, dux+idx, 0);
1271 }
1272
1273 if(idx>0)
1274 {
1275 TRSV_LNN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1276 }
1277 else // root
1278 {
1279 TRSV_LNN(nu[idx]+nx[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1280 }
1281
1282 }
1283
1284
1285 // forward substitution
1286
1287 // loop over nodes, starting from the root
1288 for(ii=0; ii<Nn; ii++)
1289 {
1290
1291 idx = ii;
1292 nkids = (ttree->root+idx)->nkids;
1293
1294 if(idx>0)
1295 {
1296 VECSC(nu[idx], -1.0, dux+idx, 0);
1297 TRSV_LTN_MN(nu[idx]+nx[idx], nu[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1298 }
1299 else // root
1300 {
1301 VECSC(nu[idx]+nx[idx], -1.0, dux+idx, 0);
1302 TRSV_LTN(nu[idx]+nx[idx], L+idx, 0, 0, dux+idx, 0, dux+idx, 0);
1303 }
1304
1305 for(jj=0; jj<nkids; jj++)
1306 {
1307
1308 idxkid = (ttree->root+idx)->kids[jj];
1309
1310 if(arg->comp_dual_sol)
1311 {
1312 VECCP(nx[idxkid], dux+idxkid, nu[idxkid], dpi+idxkid-1, 0);
1313 }
1314 GEMV_T(nu[idx]+nx[idx], nx[idxkid], 1.0, BAbt+idxkid-1, 0, 0, dux+idx, 0, 1.0, res_b+idxkid-1, 0, dux+idxkid, nu[idxkid]);
1315 if(arg->comp_dual_sol)
1316 {
1317 VECCP(nx[idxkid], dux+idxkid, nu[idxkid], tmp_nxM, 0);
1318 TRMV_LTN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1319 TRMV_LNN(nx[idxkid], nx[idxkid], L+idxkid, nu[idxkid], nu[idxkid], tmp_nxM, 0, tmp_nxM, 0);
1320 AXPY(nx[idxkid], 1.0, tmp_nxM, 0, dpi+idxkid-1, 0, dpi+idxkid-1, 0);
1321 }
1322
1323 }
1324
1325 }
1326
1327
1328
1329 for(ii=0; ii<Nn; ii++)
1330 VECEX_SP(nb[ii], 1.0, idxb[ii], dux+ii, 0, dt+ii, 0);
1331 for(ii=0; ii<Nn; ii++)
1332 GEMV_T(nu[ii]+nx[ii], ng[ii], 1.0, DCt+ii, 0, 0, dux+ii, 0, 0.0, dt+ii, nb[ii], dt+ii, nb[ii]);
1333
1334 for(ii=0; ii<Nn; ii++)
1335 {
1336 VECCP(nb[ii]+ng[ii], dt+ii, 0, dt+ii, nb[ii]+ng[ii]);
1337 VECSC(nb[ii]+ng[ii], -1.0, dt+ii, nb[ii]+ng[ii]);
1338 }
1339
1340 for(ii=0; ii<Nn; ii++)
1341 {
1342 if(ns[ii]>0)
1343 EXPAND_SLACKS(ii, qp, qp_sol, ws);
1344 }
1345
1346 COMPUTE_LAM_T_QP(res_d[0].pa, res_m[0].pa, dlam[0].pa, dt[0].pa, cws);
1347
1348 return;
1349
1350 }
1351
1352
1353
1354