1 /*
2 * R : A Computer Language for Statistical Data Analysis
3 * Copyright (C) 1998-2017 The R Core Team.
4 * Copyright (C) 2004-2017 The R Foundation
5 * Copyright (C) 1995, 1996 Robert Gentleman and Ross Ihaka
6 *
7 * This program is free software; you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation; either version 2 of the License, or
10 * (at your option) any later version.
11 *
12 * This program is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with this program; if not, a copy is available at
19 * https://www.R-project.org/Licenses/
20 *
21 *
22 * Symbolic Differentiation
23 */
24
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28
29 #include "Defn.h"
30 #undef _
31 #ifdef ENABLE_NLS
32 #include <libintl.h>
33 #define _(String) dgettext ("stats", String)
34 #else
35 #define _(String) (String)
36 #endif
37
38 static SEXP ParenSymbol;
39 static SEXP PlusSymbol;
40 static SEXP MinusSymbol;
41 static SEXP TimesSymbol;
42 static SEXP DivideSymbol;
43 static SEXP PowerSymbol;
44 static SEXP ExpSymbol;
45 static SEXP LogSymbol;
46 static SEXP SinSymbol;
47 static SEXP CosSymbol;
48 static SEXP TanSymbol;
49 static SEXP SinhSymbol;
50 static SEXP CoshSymbol;
51 static SEXP TanhSymbol;
52 static SEXP SqrtSymbol;
53 static SEXP PnormSymbol;
54 static SEXP DnormSymbol;
55 static SEXP AsinSymbol;
56 static SEXP AcosSymbol;
57 static SEXP AtanSymbol;
58 static SEXP GammaSymbol;
59 static SEXP LGammaSymbol;
60 static SEXP DiGammaSymbol;
61 static SEXP TriGammaSymbol;
62 static SEXP PsiSymbol;
63 /* new symbols in R 3.4.0: */
64 static SEXP PiSymbol;
65 static SEXP ExpM1Symbol;
66 static SEXP Log1PSymbol;
67 static SEXP Log2Symbol;
68 static SEXP Log10Symbol;
69 static SEXP SinPiSymbol;
70 static SEXP CosPiSymbol;
71 static SEXP TanPiSymbol;
72 static SEXP FactorialSymbol;
73 static SEXP LFactorialSymbol;
74 /* possible future symbols
75 static SEXP Log1PExpSymbol;
76 static SEXP Log1MExpSymbol;
77 static SEXP Log1PMxSymbol;
78 */
79
80 static Rboolean Initialized = FALSE;
81
82
InitDerivSymbols(void)83 static void InitDerivSymbols(void)
84 {
85 /* Called from doD() and deriv() */
86 if(Initialized) return;
87 ParenSymbol = install("(");
88 PlusSymbol = install("+");
89 MinusSymbol = install("-");
90 TimesSymbol = install("*");
91 DivideSymbol = install("/");
92 PowerSymbol = install("^");
93 ExpSymbol = install("exp");
94 LogSymbol = install("log");
95 SinSymbol = install("sin");
96 CosSymbol = install("cos");
97 TanSymbol = install("tan");
98 SinhSymbol = install("sinh");
99 CoshSymbol = install("cosh");
100 TanhSymbol = install("tanh");
101 SqrtSymbol = install("sqrt");
102 PnormSymbol = install("pnorm");
103 DnormSymbol = install("dnorm");
104 AsinSymbol = install("asin");
105 AcosSymbol = install("acos");
106 AtanSymbol = install("atan");
107 GammaSymbol = install("gamma");
108 LGammaSymbol = install("lgamma");
109 DiGammaSymbol = install("digamma");
110 TriGammaSymbol = install("trigamma");
111 PsiSymbol = install("psigamma");
112 /* new symbols */
113 PiSymbol = install("pi");
114 ExpM1Symbol = install("expm1");
115 Log1PSymbol = install("log1p");
116 Log2Symbol = install("log2");
117 Log10Symbol = install("log10");
118 SinPiSymbol = install("sinpi");
119 CosPiSymbol = install("cospi");
120 TanPiSymbol = install("tanpi");
121 FactorialSymbol = install("factorial");
122 LFactorialSymbol = install("lfactorial");
123 /* possible future symbols
124 Log1PExpSymbol = install("log1pexp"); # log(1+exp(x))
125 Log1MExpSymbol = install("log1mexp"); # log(1-exp(-x)), for x > 0
126 Log1PMxSymbol = install("log1pmx"); # log1p(x)-x
127 */
128
129 Initialized = TRUE;
130 }
131
Constant(double x)132 static SEXP Constant(double x)
133 {
134 return ScalarReal(x);
135 }
136
/* Nonzero iff asReal(s) is exactly 0.  Objects asReal() cannot coerce
 * yield NA_real_, which never compares equal, so they are "not zero". */
static int isZero(SEXP s)
{
    double v = asReal(s);
    return v == 0.0;
}
141
/* Nonzero iff asReal(s) is exactly 1 (NA never compares equal). */
static int isOne(SEXP s)
{
    double v = asReal(s);
    return v == 1.0;
}
146
/* Is 's' a unary-minus call, i.e. `-`(x) or `-`(x, <missing>)?
 * Returns 1 or 0; a call to `-` of any other length is an R error. */
static int isUminus(SEXP s)
{
    if (TYPEOF(s) != LANGSXP || CAR(s) != MinusSymbol)
	return 0;

    switch (length(s)) {
    case 2:			/* -x */
	return 1;
    case 3:			/* `-`(x, <missing>) counts as unary */
	return CADDR(s) == R_MissingArg;
    default:
	error(_("invalid form in unary minus check"));
	return -1; /* not reached; keeps -Wall quiet */
    }
}
164
165 /* Pointer protect and return the argument */
166
PP(SEXP s)167 static SEXP PP(SEXP s)
168 {
169 PROTECT(s);
170 return s;
171 }
172
/* simplify(): build the call fun(arg1[, arg2]) while applying simple
 * algebraic rewrites (0 + e = e, 1 * e = e, e^1 = e, moving unary minus
 * outward, ...).
 *
 * arg2 is R_MissingArg for one-argument functions and for unary minus.
 * The result is NOT protected and may be arg1, arg2 or a part of them.
 * A function symbol not handled here yields NA_real_.
 */
static SEXP simplify(SEXP fun, SEXP arg1, SEXP arg2)
{
    SEXP ans;
    if (fun == PlusSymbol) {
	if (isZero(arg1))
	    ans = arg2;
	else if (isZero(arg2))
	    ans = arg1;
	else if (isUminus(arg1))	/* (-a) + b  ->  b - a */
	    ans = simplify(MinusSymbol, arg2, CADR(arg1));
	else if (isUminus(arg2))	/* a + (-b)  ->  a - b */
	    ans = simplify(MinusSymbol, arg1, CADR(arg2));
	else
	    ans = lang3(PlusSymbol, arg1, arg2);
    }
    else if (fun == MinusSymbol) {
	if (arg2 == R_MissingArg) {	/* unary minus */
	    if (isZero(arg1))
		ans = Constant(0.);
	    else if (isUminus(arg1))	/* -(-a)  ->  a */
		ans = CADR(arg1);
	    else
		ans = lang2(MinusSymbol, arg1);
	}
	else {
	    if (isZero(arg2))
		ans = arg1;
	    else if (isZero(arg1))
		ans = simplify(MinusSymbol, arg2, R_MissingArg);
	    else if (isUminus(arg1)) {	/* (-a) - b  ->  -(a + b) */
		ans = simplify(MinusSymbol,
			       PP(simplify(PlusSymbol, CADR(arg1), arg2)),
			       R_MissingArg);
		UNPROTECT(1);
	    }
	    else if (isUminus(arg2))	/* a - (-b)  ->  a + b */
		ans = simplify(PlusSymbol, arg1, CADR(arg2));
	    else
		ans = lang3(MinusSymbol, arg1, arg2);
	}
    }
    else if (fun == TimesSymbol) {
	if (isZero(arg1) || isZero(arg2))
	    ans = Constant(0.);
	else if (isOne(arg1))
	    ans = arg2;
	else if (isOne(arg2))
	    ans = arg1;
	else if (isUminus(arg1)) {
	    ans = simplify(MinusSymbol,
			   PP(simplify(TimesSymbol, CADR(arg1), arg2)),
			   R_MissingArg);
	    UNPROTECT(1);
	}
	else if (isUminus(arg2)) {
	    ans = simplify(MinusSymbol,
			   PP(simplify(TimesSymbol, arg1, CADR(arg2))),
			   R_MissingArg);
	    UNPROTECT(1);
	}
	else
	    ans = lang3(TimesSymbol, arg1, arg2);
    }
    else if (fun == DivideSymbol) {
	if (isZero(arg1))
	    ans = Constant(0.);
	else if (isZero(arg2))		/* literal division by zero */
	    ans = Constant(NA_REAL);
	else if (isOne(arg2))
	    ans = arg1;
	else if (isUminus(arg1)) {
	    ans = simplify(MinusSymbol,
			   PP(simplify(DivideSymbol, CADR(arg1), arg2)),
			   R_MissingArg);
	    UNPROTECT(1);
	}
	else if (isUminus(arg2)) {
	    ans = simplify(MinusSymbol,
			   PP(simplify(DivideSymbol, arg1, CADR(arg2))),
			   R_MissingArg);
	    UNPROTECT(1);
	}
	else ans = lang3(DivideSymbol, arg1, arg2);
    }
    else if (fun == PowerSymbol) {
	if (isZero(arg2))
	    ans = Constant(1.);
	else if (isZero(arg1))
	    ans = Constant(0.);
	else if (isOne(arg1))
	    ans = Constant(1.);
	else if (isOne(arg2))
	    ans = arg1;
	else
	    ans = lang3(PowerSymbol, arg1, arg2);
    }
    else if (fun == ExpSymbol) {
	/* FIXME: simplify exp(lgamma( E )) = gamma( E ) */
	/* FIXME: simplify exp(lfactorial( E )) = factorial( E ) */
	ans = lang2(ExpSymbol, arg1);
    }
    else if (fun == LogSymbol) {
	/* FIXME: simplify log(gamma( E )) = lgamma( E ) */
	/* FIXME: simplify log(factorial( E )) = lfactorial( E ) */
	ans = lang2(LogSymbol, arg1);
    }
    else if (fun == CosSymbol) ans = lang2(CosSymbol, arg1);
    else if (fun == SinSymbol) ans = lang2(SinSymbol, arg1);
    else if (fun == TanSymbol) ans = lang2(TanSymbol, arg1);
    else if (fun == CoshSymbol) ans = lang2(CoshSymbol, arg1);
    else if (fun == SinhSymbol) ans = lang2(SinhSymbol, arg1);
    else if (fun == TanhSymbol) ans = lang2(TanhSymbol, arg1);
    else if (fun == SqrtSymbol) ans = lang2(SqrtSymbol, arg1);
    else if (fun == PnormSymbol) ans = lang2(PnormSymbol, arg1);
    else if (fun == DnormSymbol) ans = lang2(DnormSymbol, arg1);
    else if (fun == AsinSymbol) ans = lang2(AsinSymbol, arg1);
    else if (fun == AcosSymbol) ans = lang2(AcosSymbol, arg1);
    else if (fun == AtanSymbol) ans = lang2(AtanSymbol, arg1);
    else if (fun == GammaSymbol) ans = lang2(GammaSymbol, arg1);
    else if (fun == LGammaSymbol) ans = lang2(LGammaSymbol, arg1);
    else if (fun == DiGammaSymbol) ans = lang2(DiGammaSymbol, arg1);
    else if (fun == TriGammaSymbol) ans = lang2(TriGammaSymbol, arg1);
    else if (fun == PsiSymbol) {
	/* psigamma(x) or psigamma(x, deriv) */
	if (arg2 == R_MissingArg) ans = lang2(PsiSymbol, arg1);
	else ans = lang3(PsiSymbol, arg1, arg2);
    }
    /* new symbols */
    else if (fun == ExpM1Symbol) {
	/* FIXME: simplify expm1(log1p( E )) = E */
	ans = lang2(ExpM1Symbol, arg1);
    }
    else if (fun == Log1PSymbol) {
	/* This test used to be 'fun == LogSymbol', which is handled above,
	   so log1p() calls were unreachable here and became NA. */
	/* FIXME: simplify log1p(expm1( E )) = E */
	ans = lang2(Log1PSymbol, arg1);
    }
    else if (fun == Log2Symbol) ans = lang2(Log2Symbol, arg1);
    else if (fun == Log10Symbol) ans = lang2(Log10Symbol, arg1);
    else if (fun == CosPiSymbol) ans = lang2(CosPiSymbol, arg1);
    else if (fun == SinPiSymbol) ans = lang2(SinPiSymbol, arg1);
    else if (fun == TanPiSymbol) ans = lang2(TanPiSymbol, arg1);
    else if (fun == FactorialSymbol) ans = lang2(FactorialSymbol, arg1);
    else if (fun == LFactorialSymbol) ans = lang2(LFactorialSymbol, arg1);
    /* possible future symbols
    else if (fun == Log1PExpSymbol) ans = lang2(Log1PExpSymbol, arg1);
    else if (fun == Log1MExpSymbol) ans = lang2(Log1MExpSymbol, arg1);
    else if (fun == Log1PMxSymbol) ans = lang2(Log1PMxSymbol, arg1);
    */

    else ans = Constant(NA_REAL);
    /* FIXME */
#ifdef NOTYET
    if (length(ans) == 2 && isAtomic(CADR(ans)) && CAR(ans) != MinusSymbol)
	c = eval(c, rho);
    if (length(c) == 3 && isAtomic(CADR(ans)) && isAtomic(CADDR(ans)))
	c = eval(c, rho);
#endif
    return ans;
}/* simplify() */
331
332
333 /* D() implements the "derivative table" : */
D(SEXP expr,SEXP var)334 static SEXP D(SEXP expr, SEXP var)
335 {
336
337 #define PP_S(F,a1,a2) PP(simplify(F,a1,a2))
338 #define PP_S2(F,a1) PP(simplify(F,a1, R_MissingArg))
339
340 SEXP ans = R_NilValue, expr1, expr2;
341 switch(TYPEOF(expr)) {
342 case LGLSXP:
343 case INTSXP:
344 case REALSXP:
345 case CPLXSXP:
346 ans = Constant(0);
347 break;
348 case SYMSXP:
349 if (expr == var) ans = Constant(1.);
350 else ans = Constant(0.);
351 break;
352 case LISTSXP:
353 if (inherits(expr, "expression")) ans = D(CAR(expr), var);
354 else ans = Constant(NA_REAL);
355 break;
356 case LANGSXP:
357 if (CAR(expr) == ParenSymbol) {
358 ans = D(CADR(expr), var);
359 }
360 else if (CAR(expr) == PlusSymbol) {
361 if (length(expr) == 2)
362 ans = D(CADR(expr), var);
363 else {
364 ans = simplify(PlusSymbol,
365 PP(D(CADR(expr), var)),
366 PP(D(CADDR(expr), var)));
367 UNPROTECT(2);
368 }
369 }
370 else if (CAR(expr) == MinusSymbol) {
371 if (length(expr) == 2) {
372 ans = simplify(MinusSymbol,
373 PP(D(CADR(expr), var)),
374 R_MissingArg);
375 UNPROTECT(1);
376 }
377 else {
378 ans = simplify(MinusSymbol,
379 PP(D(CADR(expr), var)),
380 PP(D(CADDR(expr), var)));
381 UNPROTECT(2);
382 }
383 }
384 else if (CAR(expr) == TimesSymbol) {
385 ans = simplify(PlusSymbol,
386 PP_S(TimesSymbol,PP(D(CADR(expr),var)), CADDR(expr)),
387 PP_S(TimesSymbol,CADR(expr), PP(D(CADDR(expr),var))));
388 UNPROTECT(4);
389 }
390 else if (CAR(expr) == DivideSymbol) {
391 PROTECT(expr1 = D(CADR(expr), var));
392 PROTECT(expr2 = D(CADDR(expr), var));
393 ans = simplify(MinusSymbol,
394 PP_S(DivideSymbol, expr1, CADDR(expr)),
395 PP_S(DivideSymbol,
396 PP_S(TimesSymbol, CADR(expr), expr2),
397 PP_S(PowerSymbol,CADDR(expr),PP(Constant(2.)))));
398 UNPROTECT(7);
399 }
400 else if (CAR(expr) == PowerSymbol) {
401 if (isLogical(CADDR(expr)) || isNumeric(CADDR(expr))) {
402 ans = simplify(TimesSymbol,
403 CADDR(expr),
404 PP_S(TimesSymbol,
405 PP(D(CADR(expr), var)),
406 PP_S(PowerSymbol,
407 CADR(expr),
408 PP(Constant(asReal(CADDR(expr))-1.)))));
409 UNPROTECT(4);
410 }
411 else {
412 expr1 = simplify(TimesSymbol,
413 PP_S(PowerSymbol,
414 CADR(expr),
415 PP_S(MinusSymbol,
416 CADDR(expr),
417 PP(Constant(1.0)))),
418 PP_S(TimesSymbol,
419 CADDR(expr),
420 PP(D(CADR(expr), var))));
421 UNPROTECT(5);
422 PROTECT(expr1);
423 expr2 = simplify(TimesSymbol,
424 PP_S(PowerSymbol, CADR(expr), CADDR(expr)),
425 PP_S(TimesSymbol,
426 PP_S2(LogSymbol, CADR(expr)),
427 PP(D(CADDR(expr), var))));
428 UNPROTECT(4);
429 PROTECT(expr2);
430 ans = simplify(PlusSymbol, expr1, expr2);
431 UNPROTECT(2);
432 }
433 }
434 else if (CAR(expr) == ExpSymbol) {
435 ans = simplify(TimesSymbol,
436 expr,
437 PP(D(CADR(expr), var)));
438 UNPROTECT(1);
439 }
440 else if (CAR(expr) == LogSymbol) {
441 if (length(expr) != 2)
442 error("only single-argument calls to log() are supported;\n"
443 " maybe use log(x,a) = log(x)/log(a)");
444 ans = simplify(DivideSymbol,
445 PP(D(CADR(expr), var)),
446 CADR(expr));
447 UNPROTECT(1);
448 }
449 else if (CAR(expr) == CosSymbol) {
450 ans = simplify(TimesSymbol,
451 PP_S2(SinSymbol, CADR(expr)),
452 PP_S2(MinusSymbol, PP(D(CADR(expr), var))));
453 UNPROTECT(3);
454 }
455 else if (CAR(expr) == SinSymbol) {
456 ans = simplify(TimesSymbol,
457 PP_S2(CosSymbol, CADR(expr)),
458 PP(D(CADR(expr), var)));
459 UNPROTECT(2);
460 }
461 else if (CAR(expr) == TanSymbol) {
462 ans = simplify(DivideSymbol,
463 PP(D(CADR(expr), var)),
464 PP_S(PowerSymbol,
465 PP_S2(CosSymbol, CADR(expr)),
466 PP(Constant(2.0))));
467 UNPROTECT(4);
468 }
469 else if (CAR(expr) == CoshSymbol) {
470 ans = simplify(TimesSymbol,
471 PP_S2(SinhSymbol, CADR(expr)),
472 PP(D(CADR(expr), var)));
473 UNPROTECT(2);
474 }
475 else if (CAR(expr) == SinhSymbol) {
476 ans = simplify(TimesSymbol,
477 PP_S2(CoshSymbol, CADR(expr)),
478 PP(D(CADR(expr), var))),
479 UNPROTECT(2);
480 }
481 else if (CAR(expr) == TanhSymbol) {
482 ans = simplify(DivideSymbol,
483 PP(D(CADR(expr), var)),
484 PP_S(PowerSymbol,
485 PP_S2(CoshSymbol, CADR(expr)),
486 PP(Constant(2.0))));
487 UNPROTECT(4);
488 }
489 else if (CAR(expr) == SqrtSymbol) {
490 PROTECT(expr1 = allocList(3));
491 SET_TYPEOF(expr1, LANGSXP);
492 SETCAR(expr1, PowerSymbol);
493 SETCADR(expr1, CADR(expr));
494 SETCADDR(expr1, Constant(0.5));
495 ans = D(expr1, var);
496 UNPROTECT(1);
497 }
498 else if (CAR(expr) == PnormSymbol) {
499 ans = simplify(TimesSymbol,
500 PP_S2(DnormSymbol, CADR(expr)),
501 PP(D(CADR(expr), var)));
502 UNPROTECT(2);
503 }
504 else if (CAR(expr) == DnormSymbol) {
505 ans = simplify(TimesSymbol,
506 PP_S2(MinusSymbol, CADR(expr)),
507 PP_S(TimesSymbol,
508 PP_S2(DnormSymbol, CADR(expr)),
509 PP(D(CADR(expr), var))));
510 UNPROTECT(4);
511 }
512 else if (CAR(expr) == AsinSymbol) {
513 ans = simplify(DivideSymbol,
514 PP(D(CADR(expr), var)),
515 PP_S(SqrtSymbol,
516 PP_S(MinusSymbol, PP(Constant(1.)),
517 PP_S(PowerSymbol,CADR(expr),PP(Constant(2.)))),
518 R_MissingArg));
519 UNPROTECT(6);
520 }
521 else if (CAR(expr) == AcosSymbol) {
522 ans = simplify(MinusSymbol,
523 PP_S(DivideSymbol,
524 PP(D(CADR(expr), var)),
525 PP_S(SqrtSymbol,
526 PP_S(MinusSymbol, PP(Constant(1.)),
527 PP_S(PowerSymbol,
528 CADR(expr),PP(Constant(2.)))),
529 R_MissingArg)), R_MissingArg);
530 UNPROTECT(7);
531 }
532 else if (CAR(expr) == AtanSymbol) {
533 ans = simplify(DivideSymbol,
534 PP(D(CADR(expr), var)),
535 PP_S(PlusSymbol,PP(Constant(1.)),
536 PP_S(PowerSymbol, CADR(expr),PP(Constant(2.)))));
537 UNPROTECT(5);
538 }
539 else if (CAR(expr) == LGammaSymbol) {
540 ans = simplify(TimesSymbol,
541 PP(D(CADR(expr), var)),
542 PP_S2(DiGammaSymbol, CADR(expr)));
543 UNPROTECT(2);
544 }
545 else if (CAR(expr) == GammaSymbol) {
546 ans = simplify(TimesSymbol,
547 PP(D(CADR(expr), var)),
548 PP_S(TimesSymbol,
549 expr,
550 PP_S2(DiGammaSymbol, CADR(expr))));
551 UNPROTECT(3);
552 }
553 else if (CAR(expr) == DiGammaSymbol) {
554 ans = simplify(TimesSymbol,
555 PP(D(CADR(expr), var)),
556 PP_S2(TriGammaSymbol, CADR(expr)));
557 UNPROTECT(2);
558 }
559 else if (CAR(expr) == TriGammaSymbol) {
560 ans = simplify(TimesSymbol,
561 PP(D(CADR(expr), var)),
562 PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(2))));
563 UNPROTECT(3);
564 }
565 else if (CAR(expr) == PsiSymbol) {
566 if (length(expr) == 2){
567 ans = simplify(TimesSymbol,
568 PP(D(CADR(expr), var)),
569 PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(1))));
570 UNPROTECT(3);
571 } else if (TYPEOF(CADDR(expr)) == INTSXP ||
572 TYPEOF(CADDR(expr)) == REALSXP) {
573 ans = simplify(TimesSymbol,
574 PP(D(CADR(expr), var)),
575 PP_S(PsiSymbol,
576 CADR(expr),
577 PP(ScalarInteger(asInteger(CADDR(expr))+1))));
578 UNPROTECT(3);
579 } else {
580 ans = simplify(TimesSymbol,
581 PP(D(CADR(expr), var)),
582 PP_S(PsiSymbol,
583 CADR(expr),
584 simplify(PlusSymbol, CADDR(expr),
585 PP(ScalarInteger(1)))));
586 UNPROTECT(3);
587 }
588 }
589 /* new in R 3.4.0 */
590 else if (CAR(expr) == ExpM1Symbol) {
591 ans = simplify(TimesSymbol,
592 PP_S2(ExpSymbol, CADR(expr)),
593 PP(D(CADR(expr), var)));
594 UNPROTECT(2);
595 }
596 else if (CAR(expr) == Log1PSymbol) {
597 ans = simplify(DivideSymbol,
598 PP(D(CADR(expr), var)),
599 PP_S(PlusSymbol, PP(Constant(1.)), CADR(expr)));
600 UNPROTECT(3);
601 }
602 else if (CAR(expr) == Log2Symbol) {
603 ans = simplify(DivideSymbol,
604 PP(D(CADR(expr), var)),
605 PP_S(TimesSymbol, CADR(expr),
606 PP_S2(LogSymbol, PP(Constant(2.)))));
607 UNPROTECT(4);
608 }
609 else if (CAR(expr) == Log10Symbol) {
610 ans = simplify(DivideSymbol,
611 PP(D(CADR(expr), var)),
612 PP_S(TimesSymbol, CADR(expr),
613 PP_S2(LogSymbol, PP(Constant(10.)))));
614 UNPROTECT(4);
615 }
616 else if (CAR(expr) == CosPiSymbol) {
617 ans = simplify(TimesSymbol,
618 PP_S2(SinPiSymbol, CADR(expr)),
619 PP_S(TimesSymbol, PP_S2(MinusSymbol, PiSymbol),
620 PP(D(CADR(expr), var)) ));
621 UNPROTECT(4);
622 }
623 else if (CAR(expr) == SinPiSymbol) {
624 ans = simplify(TimesSymbol,
625 PP_S2(CosPiSymbol, CADR(expr)),
626 PP_S(TimesSymbol, PiSymbol,
627 PP(D(CADR(expr), var)) ));
628 UNPROTECT(3);
629 }
630 else if (CAR(expr) == TanPiSymbol) {
631 ans = simplify(DivideSymbol,
632 PP_S(TimesSymbol, PiSymbol, PP(D(CADR(expr), var))),
633 PP_S(PowerSymbol,
634 PP_S2(CosPiSymbol, CADR(expr)),
635 PP(Constant(2.0))));
636 UNPROTECT(5);
637 }
638 else if (CAR(expr) == LFactorialSymbol) {
639 ans = simplify(TimesSymbol,
640 PP(D(CADR(expr), var)),
641 PP_S2(DiGammaSymbol, PP_S(PlusSymbol,
642 CADR(expr),
643 PP(ScalarInteger(1)))));
644 UNPROTECT(4);
645 }
646 else if (CAR(expr) == FactorialSymbol) {
647 ans = simplify(TimesSymbol,
648 PP(D(CADR(expr), var)),
649 PP_S(TimesSymbol,
650 expr,
651 PP_S2(DiGammaSymbol, PP_S(PlusSymbol,
652 CADR(expr),
653 PP(ScalarInteger(1))))));
654 UNPROTECT(5);
655 }
656 /* possible future symbols
657 else if (CAR(expr) == Log1PExpSymbol) {
658 ans = simplify(DivideSymbol,
659 PP_S(TimesSymbol, PP(D(CADR(expr), var)),
660 PP_S2(ExpSymbol, CADR(expr))),
661 PP_S(PlusSymbol,PP(Constant(1.)),
662 PP_S2(ExpSymbol, CADR(expr)) ));
663 UNPROTECT(6);
664 }
665 else if (CAR(expr) == Log1MExpSymbol) {
666 ans = simplify(DivideSymbol,
667 PP_S(TimesSymbol, PP_S2(MinusSymbol, PP(D(CADR(expr), var))),
668 PP_S2(ExpSymbol, PP_S2(MinusSymbol, CADR(expr))) ),
669 PP_S2(ExpM1Symbol, PP_S2(MinusSymbol, CADR(expr))) );
670 UNPROTECT(7);
671 }
672 else if (CAR(expr) == Log1PMxSymbol) {
673 ans = simplify(DivideSymbol,
674 PP_S2(MinusSymbol, PP(D(CADR(expr), var))),
675 PP_S(PlusSymbol,PP(Constant(1.)), CADR(expr)) );
676 UNPROTECT(4);
677 }
678 */
679
680 else {
681 SEXP u = deparse1(CAR(expr), 0, SIMPLEDEPARSE);
682 error(_("Function '%s' is not in the derivatives table"),
683 translateChar(STRING_ELT(u, 0)));
684 }
685
686 break;
687 default:
688 ans = Constant(NA_REAL);
689 }
690 return ans;
691
692 #undef PP_S
693 #undef PP_S2
694
695 } /* D() */
696
/* Is 'expr' a binary call  a + b ?  (Unary plus does not qualify.) */
static int isPlusForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
	return 0;
    if (length(expr) != 3)
	return 0;
    return CAR(expr) == PlusSymbol;
}
703
/* Is 'expr' a binary call  a - b ?  (Unary minus does not qualify.) */
static int isMinusForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
	return 0;
    if (length(expr) != 3)
	return 0;
    return CAR(expr) == MinusSymbol;
}
710
/* Is 'expr' a two-argument call  a * b ? */
static int isTimesForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
	return 0;
    if (length(expr) != 3)
	return 0;
    return CAR(expr) == TimesSymbol;
}
717
/* Is 'expr' a two-argument call  a / b ? */
static int isDivideForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
	return 0;
    if (length(expr) != 3)
	return 0;
    return CAR(expr) == DivideSymbol;
}
724
/* Is 'expr' a two-argument call  a ^ b ? */
static int isPowerForm(SEXP expr)
{
    if (TYPEOF(expr) != LANGSXP)
	return 0;
    if (length(expr) != 3)
	return 0;
    return CAR(expr) == PowerSymbol;
}
731
/* AddParens(): destructively wrap operands in `(` calls wherever deparsing
 * the tree would otherwise lose the intended grouping.  Arguments are
 * processed first (bottom-up); then, depending on the operator of 'expr':
 *   a + b : parenthesize b if it is a sum
 *   a - b : parenthesize b if it is a sum or difference
 *   a * b, a / b : parenthesize b if it is + - * or /, and a if + or -
 *   a ^ b : parenthesize a if it is a power, b if it is + - * or /
 * Returns 'expr' (modified in place).
 */
static SEXP AddParens(SEXP expr)
{
    if (TYPEOF(expr) == LANGSXP) {
	SEXP arg;
	for (arg = CDR(expr); arg != R_NilValue; arg = CDR(arg))
	    SETCAR(arg, AddParens(CAR(arg)));
    }

    if (isPlusForm(expr)) {
	SEXP rhs = CADDR(expr);
	if (isPlusForm(rhs))
	    SETCADDR(expr, lang2(ParenSymbol, rhs));
    }
    else if (isMinusForm(expr)) {
	SEXP rhs = CADDR(expr);
	if (isPlusForm(rhs) || isMinusForm(rhs))
	    SETCADDR(expr, lang2(ParenSymbol, rhs));
    }
    else if (isTimesForm(expr) || isDivideForm(expr)) {
	SEXP rhs = CADDR(expr), lhs;
	if (isPlusForm(rhs) || isMinusForm(rhs)
	    || isTimesForm(rhs) || isDivideForm(rhs))
	    SETCADDR(expr, lang2(ParenSymbol, rhs));
	lhs = CADR(expr);
	if (isPlusForm(lhs) || isMinusForm(lhs))
	    SETCADR(expr, lang2(ParenSymbol, lhs));
    }
    else if (isPowerForm(expr)) {
	SEXP base = CADR(expr), expo;
	if (isPowerForm(base))
	    SETCADR(expr, lang2(ParenSymbol, base));
	expo = CADDR(expr);
	if (isPlusForm(expo) || isMinusForm(expo)
	    || isTimesForm(expo) || isDivideForm(expo))
	    SETCADDR(expr, lang2(ParenSymbol, expo));
    }
    return expr;
}
781
/* doD(): C entry point for D(expr, name).
 * The first cell of 'args' is skipped (presumably the .External function
 * slot -- confirm against the R-level caller).  The remaining arguments are:
 *   expr - a call, symbol, numeric or complex constant, or an expression
 *          vector whose first element is differentiated;
 *   name - a character vector; only its first element is used.
 * Returns the simplified derivative with parentheses added for deparsing.
 */
SEXP doD(SEXP args)
{
    SEXP expr, var;
    args = CDR(args);
    if (isExpression(CAR(args))) {
	/* VECTOR_ELT(, 0) on expression() would index past the end */
	if (XLENGTH(CAR(args)) < 1)
	    error(_("expression must not be of length zero"));
	expr = VECTOR_ELT(CAR(args), 0);
    }
    else expr = CAR(args);
    if (!(isLanguage(expr) || isSymbol(expr) || isNumeric(expr) || isComplex(expr)))
	error(_("expression must not be type '%s'"), type2char(TYPEOF(expr)));
    var = CADR(args);
    if (!isString(var) || length(var) < 1)
	error(_("variable must be a character string"));
    if (length(var) > 1)
	warning(_("only the first element is used as variable name"));
    var = installTrChar(STRING_ELT(var, 0));
    InitDerivSymbols();
    PROTECT(expr = D(expr, var));
    /* D() may return structure shared with the caller's expression (its
       arguments, or the whole call); AddParens() modifies its argument in
       place, so give it a private copy rather than mutate the user's
       language object. */
    PROTECT(expr = duplicate(expr));
    expr = AddParens(expr);
    UNPROTECT(2);
    return expr;
}
802
803 /* ------ FindSubexprs ------ and ------ Accumulate ------ */
804
InvalidExpression(char * where)805 static void NORET InvalidExpression(char *where)
806 {
807 error(_("invalid expression in '%s'"), where);
808 }
809
/* equal(): structural equality of two expression trees, used by
 * Accumulate() to find repeated subexpressions.
 *   - objects of different type are unequal;
 *   - NULLs are equal; symbols compare by identity;
 *   - logical/integer/double/complex constants must have the same length
 *     and equal elements (NaN is never equal, as with C '==');
 *   - calls and pairlists compare CAR and CDR recursively;
 *   - any other type is an error.
 * Returns 1 if equal, 0 otherwise.
 */
static int equal(SEXP expr1, SEXP expr2)
{
    R_xlen_t i, n;

    if (TYPEOF(expr1) != TYPEOF(expr2))
	return 0;

    switch(TYPEOF(expr1)) {
    case NILSXP:
	return 1;
    case SYMSXP:
	return expr1 == expr2;
    /* For the atomic cases, comparing only element [0] read past the end
       of length-0 constants and treated e.g. c(1,2) and c(1,3) as equal,
       which merged distinct subexpressions. */
    case LGLSXP:
    case INTSXP: {
	const int *a = INTEGER(expr1), *b = INTEGER(expr2);
	n = XLENGTH(expr1);
	if (n != XLENGTH(expr2)) return 0;
	for (i = 0; i < n; i++)
	    if (a[i] != b[i]) return 0;
	return 1;
    }
    case REALSXP: {
	const double *a = REAL(expr1), *b = REAL(expr2);
	n = XLENGTH(expr1);
	if (n != XLENGTH(expr2)) return 0;
	for (i = 0; i < n; i++)
	    if (!(a[i] == b[i])) return 0;
	return 1;
    }
    case CPLXSXP: {
	const Rcomplex *a = COMPLEX(expr1), *b = COMPLEX(expr2);
	n = XLENGTH(expr1);
	if (n != XLENGTH(expr2)) return 0;
	for (i = 0; i < n; i++)
	    if (!(a[i].r == b[i].r && a[i].i == b[i].i)) return 0;
	return 1;
    }
    case LANGSXP:
    case LISTSXP:
	return equal(CAR(expr1), CAR(expr2))
	    && equal(CDR(expr1), CDR(expr2));
    default:
	InvalidExpression("equal");
    }
    return 0;
}
836
/* Return the 1-based position of 'expr' among the elements after the head
 * of 'exprlist' (compared with equal()).  If it is not present, append it
 * destructively and return its new position. */
static int Accumulate(SEXP expr, SEXP exprlist)
{
    SEXP last = exprlist;
    int pos;

    for (pos = 1; CDR(last) != R_NilValue; pos++) {
	last = CDR(last);
	if (equal(expr, CAR(last)))
	    return pos;
    }
    SETCDR(last, CONS(expr, R_NilValue));
    return pos;
}
852
/* Unconditionally append 'expr' to 'exprlist' (no duplicate check) and
 * return its 1-based position after the list head. */
static int Accumulate2(SEXP expr, SEXP exprlist)
{
    SEXP tail;
    int pos;

    for (tail = exprlist, pos = 1; CDR(tail) != R_NilValue; tail = CDR(tail))
	pos++;
    SETCDR(tail, CONS(expr, R_NilValue));
    return pos;
}
866
/* Return the symbol <tag><k>, e.g. ".expr3", naming the k-th recorded
 * subexpression.  'tag' is a character vector; its first element (in the
 * native encoding) is the prefix.
 *
 * The name used to be formatted into a fixed 64-byte buffer, which could
 * silently truncate it (a 60-byte tag plus "1000" became "...100", the name
 * of a different variable; translation may also lengthen the tag).  The
 * buffer is now sized exactly.
 */
static SEXP MakeVariable(int k, SEXP tag)
{
    const void *vmax = vmaxget();
    const char *prefix = translateChar(STRING_ELT(tag, 0));
    int len = snprintf(NULL, 0, "%s%d", prefix, k);
    char *buf;
    SEXP sym;

    if (len < 0)
	error(_("invalid tag"));
    buf = R_alloc((size_t) len + 1, sizeof(char));
    snprintf(buf, (size_t) len + 1, "%s%d", prefix, k);
    /* install before vmaxset(): 'buf' lives in R_alloc storage */
    sym = install(buf);
    vmaxset(vmax);
    return sym;
}
875
/* FindSubexprs(): destructively factor 'expr' into subexpressions.
 * Constants and symbols return 0.  For a call, each argument is processed
 * first and, if it was itself recorded, replaced by the symbol <tag><k>;
 * the call is then recorded in 'exprlist' via Accumulate(), whose index is
 * returned.  `(` calls are transparent.  Other object types are an error.
 */
static int FindSubexprs(SEXP expr, SEXP exprlist, SEXP tag)
{
    SEXP arg;
    int idx;

    switch (TYPEOF(expr)) {
    case SYMSXP:
    case LGLSXP:
    case INTSXP:
    case REALSXP:
    case CPLXSXP:
	return 0;

    case LISTSXP:
	if (!inherits(expr, "expression")) {
	    InvalidExpression("FindSubexprs");
	    return -1/*-Wall*/;
	}
	return FindSubexprs(CAR(expr), exprlist, tag);

    case LANGSXP:
	if (CAR(expr) == install("("))
	    return FindSubexprs(CADR(expr), exprlist, tag);
	for (arg = CDR(expr); arg != R_NilValue; arg = CDR(arg)) {
	    idx = FindSubexprs(CAR(arg), exprlist, tag);
	    if (idx != 0)
		SETCAR(arg, MakeVariable(idx, tag));
	}
	return Accumulate(expr, exprlist);

    default:
	InvalidExpression("FindSubexprs");
	return -1/*-Wall*/;
    }
}
912
/* Count how many times the symbol 'sym' appears anywhere in the tree
 * 'lst' (calls and pairlists are searched in both CAR and CDR). */
static int CountOccurrences(SEXP sym, SEXP lst)
{
    int count = 0;

    /* walk the CDR chain iteratively, recursing only into each CAR */
    while (TYPEOF(lst) == LISTSXP || TYPEOF(lst) == LANGSXP) {
	count += CountOccurrences(sym, CAR(lst));
	lst = CDR(lst);
    }
    if (TYPEOF(lst) == SYMSXP && lst == sym)
	count++;
    return count;
}
926
/* Destructively substitute 'expr' for every occurrence of the symbol 'sym'
 * in 'lst'.  Returns the (possibly replaced) root. */
static SEXP Replace(SEXP sym, SEXP expr, SEXP lst)
{
    SEXPTYPE t = TYPEOF(lst);

    if (t == SYMSXP)
	return (lst == sym) ? expr : lst;
    if (t == LISTSXP || t == LANGSXP) {
	SETCAR(lst, Replace(sym, expr, CAR(lst)));
	SETCDR(lst, Replace(sym, expr, CDR(lst)));
    }
    return lst;
}
942
/* Build the call
 *   .grad <- array(0, c(length(.value), n), list(NULL, c("n1", ..., "nn")))
 * where n1..nn are the elements of the character vector 'names'. */
static SEXP CreateGrad(SEXP names)
{
    SEXP dimnames, cell, dim, data, call;
    int i, n = length(names);

    /* list(NULL, c(...)) */
    PROTECT(dimnames = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dimnames, install("list"));
    PROTECT(cell = allocList(n));
    SETCADDR(dimnames, LCONS(install("c"), cell));
    UNPROTECT(1);		/* the cells are now reachable via dimnames */
    for (i = 0; i < n; i++, cell = CDR(cell))
	SETCAR(cell, ScalarString(STRING_ELT(names, i)));

    /* c(length(.value), n) */
    PROTECT(dim = lang3(R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dim, install("c"));
    SETCADR(dim, lang2(install("length"), install(".value")));
    SETCADDR(dim, ScalarInteger(n)); /* was real? */

    PROTECT(data = ScalarReal(0.));
    PROTECT(call = lang4(install("array"), data, dim, dimnames));
    call = lang3(install("<-"), install(".grad"), call);
    UNPROTECT(4);
    return call;
}
968
/* Build the call
 *   .hessian <- array(0, c(length(.value), n, n),
 *                     list(NULL, c("n1", ..., "nn"), c("n1", ..., "nn")))
 * where n1..nn are the elements of the character vector 'names'. */
static SEXP CreateHess(SEXP names)
{
    SEXP dimnames, cell, dim, data, call;
    int i, n = length(names);

    /* list(NULL, c(...), c(...)) */
    PROTECT(dimnames = lang4(R_NilValue, R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dimnames, install("list"));
    PROTECT(cell = allocList(n));
    SETCADDR(dimnames, LCONS(install("c"), cell));
    UNPROTECT(1);		/* the cells are now reachable via dimnames */
    for (i = 0; i < n; i++, cell = CDR(cell))
	SETCAR(cell, ScalarString(STRING_ELT(names, i)));
    SETCADDDR(dimnames, duplicate(CADDR(dimnames)));

    /* c(length(.value), n, n) */
    PROTECT(dim = lang4(R_NilValue, R_NilValue, R_NilValue, R_NilValue));
    SETCAR(dim, install("c"));
    SETCADR(dim, lang2(install("length"), install(".value")));
    SETCADDR(dim, ScalarInteger(n));
    SETCADDDR(dim, ScalarInteger(n));

    PROTECT(data = ScalarReal(0.));
    PROTECT(call = lang4(install("array"), data, dim, dimnames));
    call = lang3(install("<-"), install(".hessian"), call);
    UNPROTECT(4);
    return call;
}
996
/* Build the call  .grad[, "name"] <- expr  ('name' is a CHARSXP). */
static SEXP DerivAssign(SEXP name, SEXP expr)
{
    SEXP assign, label, target;

    assign = PROTECT(lang3(install("<-"), R_NilValue, expr));
    label = PROTECT(ScalarString(name));
    target = lang4(R_BracketSymbol, install(".grad"), R_MissingArg, label);
    SETCADR(assign, target);
    UNPROTECT(2);
    return assign;
}
1006
/* Build the call  .hessian[, "name", "name"] <- expr  (diagonal entry). */
static SEXP HessAssign1(SEXP name, SEXP expr)
{
    SEXP assign, label, target;

    assign = PROTECT(lang3(install("<-"), R_NilValue, expr));
    label = PROTECT(ScalarString(name));
    target = lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
		   label, label);
    SETCADR(assign, target);
    UNPROTECT(2);
    return assign;
}
1017
/* Build the call
 *   .hessian[, "name1", "name2"] <- .hessian[, "name2", "name1"] <- expr
 * which fills both symmetric off-diagonal entries. */
static SEXP HessAssign2(SEXP name1, SEXP name2, SEXP expr)
{
    SEXP s1, s2, lhs12, lhs21, inner, ans;

    /* every piece is protected: overkill, but PR#14772 found an issue */
    s1 = PROTECT(ScalarString(name1));
    s2 = PROTECT(ScalarString(name2));
    lhs12 = PROTECT(lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
			  s1, s2));
    lhs21 = PROTECT(lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
			  s2, s1));
    inner = PROTECT(lang3(install("<-"), lhs21, expr));
    ans = lang3(install("<-"), lhs12, inner);
    UNPROTECT(5);
    return ans;
}
1033
1034 /* attr(.value, "gradient") <- .grad */
1035
AddGrad(void)1036 static SEXP AddGrad(void)
1037 {
1038 SEXP ans;
1039 PROTECT(ans = mkString("gradient"));
1040 PROTECT(ans = lang3(install("attr"), install(".value"), ans));
1041 ans = lang3(install("<-"), ans, install(".grad"));
1042 UNPROTECT(2);
1043 return ans;
1044 }
1045
AddHess(void)1046 static SEXP AddHess(void)
1047 {
1048 SEXP ans;
1049 PROTECT(ans = mkString("hessian"));
1050 PROTECT(ans = lang3(install("attr"), install(".value"), ans));
1051 ans = lang3(install("<-"), ans, install(".hessian"));
1052 UNPROTECT(2);
1053 return ans;
1054 }
1055
/* Destructively drop every element whose CAR is R_MissingArg from the
 * pairlist 'lst'.  Returns the new head (R_NilValue if nothing remains). */
static SEXP Prune(SEXP lst)
{
    SEXP rest;

    if (lst == R_NilValue)
	return R_NilValue;
    rest = Prune(CDR(lst));
    SETCDR(lst, rest);
    return (CAR(lst) == R_MissingArg) ? rest : lst;
}
1065
/* .External entry point:  deriv(expr, namevec, function.arg, tag, hessian)
 *
 * ARGS is the .External argument pairlist.  After its first cell is
 * dropped, it holds, in order:
 *   expr          a language object, or an expression vector whose first
 *                 element is differentiated;
 *   namevec       character vector (length >= 1) of variables to
 *                 differentiate with respect to;
 *   function.arg  a closure (its formals and environment are reused),
 *                 a character vector of formal argument names, TRUE
 *                 (namevec becomes the formals), or anything else
 *                 (an expression vector is returned instead of a closure);
 *   tag           non-NA string of 1..60 bytes used as the prefix of the
 *                 generated sub-expression variables;
 *   hessian       whether second derivatives are generated as well.
 *
 * Returns a closure whose body is the generated `{ ... }` block, or an
 * expression vector of length one holding that block.  The block assigns
 * .value, builds .grad (and optionally .hessian), attaches them with
 * attr(.value, "gradient") / attr(.value, "hessian"), and yields .value.
 */
SEXP deriv(SEXP args)
{
    SEXP ans, ans2, expr, funarg, names, s;
    int f_index, *d_index, *d2_index;
    int i, j, k, nexpr, nderiv = 0, hessian;
    SEXP exprlist, tag;

    args = CDR(args);
    InitDerivSymbols();
    PROTECT(exprlist = LCONS(R_BraceSymbol, R_NilValue));

    /* expr: */
    if (isExpression(CAR(args))) {
        /* Element 0 of an empty expression vector does not exist. */
        if (length(CAR(args)) < 1)
            error(_("invalid expression in '%s'"), "deriv");
        PROTECT(expr = VECTOR_ELT(CAR(args), 0));
    }
    else PROTECT(expr = CAR(args));
    args = CDR(args);

    /* namevec: */
    names = CAR(args);
    if (!isString(names) || (nderiv = length(names)) < 1)
        error(_("invalid variable names"));
    args = CDR(args);

    /* function.arg: */
    funarg = CAR(args);
    args = CDR(args);

    /* tag: NA_STRING has to be rejected explicitly -- its CHARSXP is
       "NA", so the byte-length checks alone would accept it. */
    tag = CAR(args);
    if (!isString(tag) || length(tag) < 1
        || STRING_ELT(tag, 0) == NA_STRING
        || length(STRING_ELT(tag, 0)) < 1 || length(STRING_ELT(tag, 0)) > 60)
        error(_("invalid tag"));
    args = CDR(args);

    /* hessian: */
    hessian = asLogical(CAR(args));

    /* NOTE: FindSubexprs is destructive, hence the duplication.
       It can allocate, so protect the duplicate.
    */
    PROTECT(ans = duplicate(expr));
    f_index = FindSubexprs(ans, exprlist, tag);
    d_index = (int*)R_alloc((size_t) nderiv, sizeof(int));
    if (hessian) {
        /* One slot per (i, j) with j >= i.  Computed in size_t so the
           product cannot overflow int. */
        d2_index = (int*)R_alloc(((size_t) nderiv * ((size_t) nderiv + 1)) / 2,
                                 sizeof(int));
    }
    else d2_index = d_index;/*-Wall*/
    UNPROTECT(1);

    /* Register the sub-expressions of every first (and, if requested,
       second) derivative in exprlist, remembering the index of each. */
    for (i = 0, k = 0; i < nderiv; i++) {
        PROTECT(ans = duplicate(expr));
        PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
        PROTECT(ans2 = duplicate(ans));     /* keep a temporary copy */
        d_index[i] = FindSubexprs(ans, exprlist, tag); /* examine the derivative first */
        PROTECT(ans = duplicate(ans2));     /* restore the copy */
        if (hessian) {
            for (j = i; j < nderiv; j++) {
                PROTECT(ans2 = duplicate(ans)); /* install could allocate */
                PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
                d2_index[k] = FindSubexprs(ans2, exprlist, tag);
                k++;
                UNPROTECT(2);
            }
        }
        UNPROTECT(4);
    }
    nexpr = length(exprlist) - 1;

    /* Append the remaining slots after the nexpr sub-expressions:
       value, .grad placeholder, [.hessian placeholder], then for each
       variable i its gradient entry followed by [hessian entries j >= i],
       then attr(gradient), [attr(hessian)] and the final .value. */
    if (f_index) {
        Accumulate2(MakeVariable(f_index, tag), exprlist);
    }
    else {
        /* Accumulate a private copy so the returned body never shares
           structure with the caller's expression. */
        PROTECT(ans = duplicate(expr));
        Accumulate2(ans, exprlist);
        UNPROTECT(1);
    }
    Accumulate2(R_NilValue, exprlist);
    if (hessian) { Accumulate2(R_NilValue, exprlist); }
    for (i = 0, k = 0; i < nderiv; i++) {
        if (d_index[i]) {
            Accumulate2(MakeVariable(d_index[i], tag), exprlist);
            if (hessian) {
                PROTECT(ans = duplicate(expr));
                PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
                for (j = i; j < nderiv; j++) {
                    if (d2_index[k]) {
                        Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
                    } else {
                        PROTECT(ans2 = duplicate(ans));
                        PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
                        Accumulate2(ans2, exprlist);
                        UNPROTECT(2);
                    }
                    k++;
                }
                UNPROTECT(2);
            }
        } else { /* the first derivative is constant or simple variable */
            PROTECT(ans = duplicate(expr));
            PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
            Accumulate2(ans, exprlist);
            UNPROTECT(2); /* ans stays reachable through exprlist */
            if (hessian) {
                for (j = i; j < nderiv; j++) {
                    if (d2_index[k]) {
                        Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
                    } else {
                        PROTECT(ans2 = duplicate(ans));
                        PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
                        /* Zero entries are marked for pruning. */
                        if (isZero(ans2)) Accumulate2(R_MissingArg, exprlist);
                        else Accumulate2(ans2, exprlist);
                        UNPROTECT(2);
                    }
                    k++;
                }
            }
        }
    }
    Accumulate2(R_NilValue, exprlist);
    Accumulate2(R_NilValue, exprlist);
    if (hessian) { Accumulate2(R_NilValue, exprlist); }

    /* Sub-expressions referenced fewer than twice are substituted into
       their use site and their slot is marked R_MissingArg (removed by
       Prune below); the rest become assignments to the tag-named
       variable returned by MakeVariable. */
    i = 0;
    ans = CDR(exprlist);
    while (i < nexpr) {
        if (CountOccurrences(MakeVariable(i+1, tag), CDR(ans)) < 2) {
            SETCDR(ans, Replace(MakeVariable(i+1, tag), CAR(ans), CDR(ans)));
            SETCAR(ans, R_MissingArg);
        }
        else {
            SEXP var;
            PROTECT(var = MakeVariable(i+1, tag));
            SETCAR(ans, lang3(install("<-"), var, AddParens(CAR(ans))));
            UNPROTECT(1);
        }
        i = i + 1;
        ans = CDR(ans);
    }
    /* .value <- ... */
    SETCAR(ans, lang3(install("<-"), install(".value"), AddParens(CAR(ans))));
    ans = CDR(ans);
    /* .grad <- ... */
    SETCAR(ans, CreateGrad(names));
    ans = CDR(ans);
    /* .hessian <- ... */
    if (hessian) { SETCAR(ans, CreateHess(names)); ans = CDR(ans); }
    /* .grad[, "..."] <- ... */
    for (i = 0; i < nderiv; i++) {
        SETCAR(ans, DerivAssign(STRING_ELT(names, i), AddParens(CAR(ans))));
        ans = CDR(ans);
        if (hessian) {
            for (j = i; j < nderiv; j++) {
                if (CAR(ans) != R_MissingArg) {
                    if (i == j) {
                        SETCAR(ans, HessAssign1(STRING_ELT(names, i),
                                                AddParens(CAR(ans))));
                    } else {
                        SETCAR(ans, HessAssign2(STRING_ELT(names, i),
                                                STRING_ELT(names, j),
                                                AddParens(CAR(ans))));
                    }
                }
                ans = CDR(ans);
            }
        }
    }
    /* attr(.value, "gradient") <- .grad */
    SETCAR(ans, AddGrad());
    ans = CDR(ans);
    if (hessian) { SETCAR(ans, AddHess()); ans = CDR(ans); }
    /* .value */
    SETCAR(ans, install(".value"));
    /* Prune the expression list removing eliminated sub-expressions */
    SETCDR(exprlist, Prune(CDR(exprlist)));

    /* function.arg = TRUE: use namevec as the formal arguments.  The
       length guard keeps a zero-length logical from being indexed. */
    if (TYPEOF(funarg) == LGLSXP && length(funarg) > 0 && LOGICAL(funarg)[0]) {
        funarg = names;
    }

    if (TYPEOF(funarg) == CLOSXP)
    {
        /* No allocation happens between allocSExp and SET_BODY. */
        s = allocSExp(CLOSXP);
        SET_FORMALS(s, FORMALS(funarg));
        SET_CLOENV(s, CLOENV(funarg));
        funarg = s;
        SET_BODY(funarg, exprlist);
    }
    else if (isString(funarg)) {
        PROTECT(names = duplicate(funarg));
        PROTECT(funarg = allocSExp(CLOSXP));
        PROTECT(ans = allocList(length(names)));
        SET_FORMALS(funarg, ans);
        for (i = 0; i < length(names); i++) {
            SET_TAG(ans, installTrChar(STRING_ELT(names, i)));
            SETCAR(ans, R_MissingArg);
            ans = CDR(ans);
        }
        UNPROTECT(3);
        SET_BODY(funarg, exprlist);
        SET_CLOENV(funarg, R_GlobalEnv);
    }
    else {
        funarg = allocVector(EXPRSXP, 1);
        SET_VECTOR_ELT(funarg, 0, exprlist);
        /* funarg = lang2(install("expression"), exprlist); */
    }
    UNPROTECT(2); /* exprlist, expr */
    return funarg;
}
1268