1 /*
2  *  R : A Computer Language for Statistical Data Analysis
3  *  Copyright (C) 1998-2017   The R Core Team.
4  *  Copyright (C) 2004-2017   The R Foundation
5  *  Copyright (C) 1995, 1996  Robert Gentleman and Ross Ihaka
6  *
7  *  This program is free software; you can redistribute it and/or modify
8  *  it under the terms of the GNU General Public License as published by
9  *  the Free Software Foundation; either version 2 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This program is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU General Public License for more details.
16  *
17  *  You should have received a copy of the GNU General Public License
18  *  along with this program; if not, a copy is available at
19  *  https://www.R-project.org/Licenses/
20  *
21  *
22  *  Symbolic Differentiation
23  */
24 
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28 
29 #include "Defn.h"
30 #undef _
31 #ifdef ENABLE_NLS
32 #include <libintl.h>
33 #define _(String) dgettext ("stats", String)
34 #else
35 #define _(String) (String)
36 #endif
37 
/* Cached symbols recognised by the differentiation code.  They are
 * installed once by InitDerivSymbols() and then compared by pointer
 * identity (e.g. `fun == PlusSymbol`, `CAR(expr) == PlusSymbol`)
 * throughout this file. */
static SEXP ParenSymbol;
static SEXP PlusSymbol;
static SEXP MinusSymbol;
static SEXP TimesSymbol;
static SEXP DivideSymbol;
static SEXP PowerSymbol;
static SEXP ExpSymbol;
static SEXP LogSymbol;
static SEXP SinSymbol;
static SEXP CosSymbol;
static SEXP TanSymbol;
static SEXP SinhSymbol;
static SEXP CoshSymbol;
static SEXP TanhSymbol;
static SEXP SqrtSymbol;
static SEXP PnormSymbol;
static SEXP DnormSymbol;
static SEXP AsinSymbol;
static SEXP AcosSymbol;
static SEXP AtanSymbol;
static SEXP GammaSymbol;
static SEXP LGammaSymbol;
static SEXP DiGammaSymbol;
static SEXP TriGammaSymbol;
static SEXP PsiSymbol;
/* new symbols in R 3.4.0: */
static SEXP PiSymbol;
static SEXP ExpM1Symbol;
static SEXP Log1PSymbol;
static SEXP Log2Symbol;
static SEXP Log10Symbol;
static SEXP SinPiSymbol;
static SEXP CosPiSymbol;
static SEXP TanPiSymbol;
static SEXP FactorialSymbol;
static SEXP LFactorialSymbol;
/* possible future symbols
static SEXP Log1PExpSymbol;
static SEXP Log1MExpSymbol;
static SEXP Log1PMxSymbol;
*/

/* Set to TRUE by InitDerivSymbols() once the symbols above are installed,
 * so that later calls return immediately. */
static Rboolean Initialized = FALSE;
81 
82 
InitDerivSymbols(void)83 static void InitDerivSymbols(void)
84 {
85     /* Called from doD() and deriv() */
86     if(Initialized) return;
87     ParenSymbol = install("(");
88     PlusSymbol = install("+");
89     MinusSymbol = install("-");
90     TimesSymbol = install("*");
91     DivideSymbol = install("/");
92     PowerSymbol = install("^");
93     ExpSymbol = install("exp");
94     LogSymbol = install("log");
95     SinSymbol = install("sin");
96     CosSymbol = install("cos");
97     TanSymbol = install("tan");
98     SinhSymbol = install("sinh");
99     CoshSymbol = install("cosh");
100     TanhSymbol = install("tanh");
101     SqrtSymbol = install("sqrt");
102     PnormSymbol = install("pnorm");
103     DnormSymbol = install("dnorm");
104     AsinSymbol = install("asin");
105     AcosSymbol = install("acos");
106     AtanSymbol = install("atan");
107     GammaSymbol = install("gamma");
108     LGammaSymbol = install("lgamma");
109     DiGammaSymbol = install("digamma");
110     TriGammaSymbol = install("trigamma");
111     PsiSymbol = install("psigamma");
112 /* new symbols */
113     PiSymbol = install("pi");
114     ExpM1Symbol = install("expm1");
115     Log1PSymbol = install("log1p");
116     Log2Symbol = install("log2");
117     Log10Symbol = install("log10");
118     SinPiSymbol = install("sinpi");
119     CosPiSymbol = install("cospi");
120     TanPiSymbol = install("tanpi");
121     FactorialSymbol = install("factorial");
122     LFactorialSymbol = install("lfactorial");
123 /* possible future symbols
124     Log1PExpSymbol = install("log1pexp");    # log(1+exp(x))
125     Log1MExpSymbol = install("log1mexp");    # log(1-exp(-x)), for x > 0
126     Log1PMxSymbol = install("log1pmx");      # log1p(x)-x
127 */
128 
129     Initialized = TRUE;
130 }
131 
Constant(double x)132 static SEXP Constant(double x)
133 {
134     return ScalarReal(x);
135 }
136 
isZero(SEXP s)137 static int isZero(SEXP s)
138 {
139     return asReal(s) == 0.0;
140 }
141 
isOne(SEXP s)142 static int isOne(SEXP s)
143 {
144     return asReal(s) == 1.0;
145 }
146 
isUminus(SEXP s)147 static int isUminus(SEXP s)
148 {
149     if (TYPEOF(s) == LANGSXP && CAR(s) == MinusSymbol) {
150 	switch(length(s)) {
151 	case 2:
152 	    return 1;
153 	case 3:
154 	    if (CADDR(s) == R_MissingArg)
155 		return 1;
156 	    else return 0;
157 	default:
158 	    error(_("invalid form in unary minus check"));
159 	    return -1;/* for -Wall */
160 	}
161     }
162     else return 0;
163 }
164 
165 /* Pointer protect and return the argument */
166 
PP(SEXP s)167 static SEXP PP(SEXP s)
168 {
169     PROTECT(s);
170     return s;
171 }
172 
simplify(SEXP fun,SEXP arg1,SEXP arg2)173 static SEXP simplify(SEXP fun, SEXP arg1, SEXP arg2)
174 {
175     SEXP ans;
176     if (fun == PlusSymbol) {
177 	if (isZero(arg1))
178 	    ans = arg2;
179 	else if (isZero(arg2))
180 	    ans = arg1;
181 	else if (isUminus(arg1))
182 	    ans = simplify(MinusSymbol, arg2, CADR(arg1));
183 	else if (isUminus(arg2))
184 	    ans = simplify(MinusSymbol, arg1, CADR(arg2));
185 	else
186 	    ans = lang3(PlusSymbol, arg1, arg2);
187     }
188     else if (fun == MinusSymbol) {
189 	if (arg2 == R_MissingArg) {
190 	    if (isZero(arg1))
191 		ans = Constant(0.);
192 	    else if (isUminus(arg1))
193 		ans = CADR(arg1);
194 	    else
195 		ans = lang2(MinusSymbol, arg1);
196 	}
197 	else {
198 	    if (isZero(arg2))
199 		ans = arg1;
200 	    else if (isZero(arg1))
201 		ans = simplify(MinusSymbol, arg2, R_MissingArg);
202 	    else if (isUminus(arg1)) {
203 		ans = simplify(MinusSymbol,
204 			       PP(simplify(PlusSymbol, CADR(arg1), arg2)),
205 			       R_MissingArg);
206 		UNPROTECT(1);
207 	    }
208 	    else if (isUminus(arg2))
209 		ans = simplify(PlusSymbol, arg1, CADR(arg2));
210 	    else
211 		ans = lang3(MinusSymbol, arg1, arg2);
212 	}
213     }
214     else if (fun == TimesSymbol) {
215 	if (isZero(arg1) || isZero(arg2))
216 	    ans = Constant(0.);
217 	else if (isOne(arg1))
218 	    ans = arg2;
219 	else if (isOne(arg2))
220 	    ans = arg1;
221 	else if (isUminus(arg1)) {
222 	    ans = simplify(MinusSymbol,
223 			   PP(simplify(TimesSymbol, CADR(arg1), arg2)),
224 			   R_MissingArg);
225 	    UNPROTECT(1);
226 	}
227 	else if (isUminus(arg2)) {
228 	    ans = simplify(MinusSymbol,
229 			   PP(simplify(TimesSymbol, arg1, CADR(arg2))),
230 			   R_MissingArg);
231 	    UNPROTECT(1);
232 	}
233 	else
234 	    ans = lang3(TimesSymbol, arg1, arg2);
235     }
236     else if (fun == DivideSymbol) {
237 	if (isZero(arg1))
238 	    ans = Constant(0.);
239 	else if (isZero(arg2))
240 	    ans = Constant(NA_REAL);
241 	else if (isOne(arg2))
242 	    ans = arg1;
243 	else if (isUminus(arg1)) {
244 	    ans = simplify(MinusSymbol,
245 			   PP(simplify(DivideSymbol, CADR(arg1), arg2)),
246 			   R_MissingArg);
247 	    UNPROTECT(1);
248 	}
249 	else if (isUminus(arg2)) {
250 	    ans = simplify(MinusSymbol,
251 			   PP(simplify(DivideSymbol, arg1, CADR(arg2))),
252 			   R_MissingArg);
253 	    UNPROTECT(1);
254 	}
255 	else ans = lang3(DivideSymbol, arg1, arg2);
256     }
257     else if (fun == PowerSymbol) {
258 	if (isZero(arg2))
259 	    ans = Constant(1.);
260 	else if (isZero(arg1))
261 	    ans = Constant(0.);
262 	else if (isOne(arg1))
263 	    ans = Constant(1.);
264 	else if (isOne(arg2))
265 	    ans = arg1;
266 	else
267 	    ans = lang3(PowerSymbol, arg1, arg2);
268     }
269     else if (fun == ExpSymbol) {
270         /* FIXME: simplify exp(lgamma( E )) = gamma( E ) */
271         /* FIXME: simplify exp(lfactorial( E )) = factorial( E ) */
272         ans = lang2(ExpSymbol, arg1);
273     }
274     else if (fun == LogSymbol) {
275         /* FIXME: simplify log(gamma( E )) = lgamma( E ) */
276         /* FIXME: simplify log(factorial( E )) = lfactorial( E ) */
277         ans = lang2(LogSymbol, arg1);
278     }
279     else if (fun == CosSymbol)  ans = lang2(CosSymbol, arg1);
280     else if (fun == SinSymbol)  ans = lang2(SinSymbol, arg1);
281     else if (fun == TanSymbol)  ans = lang2(TanSymbol, arg1);
282     else if (fun == CoshSymbol) ans = lang2(CoshSymbol, arg1);
283     else if (fun == SinhSymbol) ans = lang2(SinhSymbol, arg1);
284     else if (fun == TanhSymbol) ans = lang2(TanhSymbol, arg1);
285     else if (fun == SqrtSymbol) ans = lang2(SqrtSymbol, arg1);
286     else if (fun == PnormSymbol)ans = lang2(PnormSymbol, arg1);
287     else if (fun == DnormSymbol)ans = lang2(DnormSymbol, arg1);
288     else if (fun == AsinSymbol) ans = lang2(AsinSymbol, arg1);
289     else if (fun == AcosSymbol) ans = lang2(AcosSymbol, arg1);
290     else if (fun == AtanSymbol) ans = lang2(AtanSymbol, arg1);
291     else if (fun == GammaSymbol)ans = lang2(GammaSymbol, arg1);
292     else if (fun == LGammaSymbol)ans = lang2(LGammaSymbol, arg1);
293     else if (fun == DiGammaSymbol) ans = lang2(DiGammaSymbol, arg1);
294     else if (fun == TriGammaSymbol) ans = lang2(TriGammaSymbol, arg1);
295     else if (fun == PsiSymbol){
296        if (arg2 == R_MissingArg) ans = lang2(PsiSymbol, arg1);
297        else ans = lang3(PsiSymbol, arg1, arg2);
298     }
299 /* new symbols */
300     else if (fun == ExpM1Symbol) {
301         /* FIXME: simplify expm1(log1p( E )) = E */
302         ans = lang2(ExpM1Symbol, arg1);
303     }
304     else if (fun == LogSymbol) {
305         /* FIXME: simplify log1p(expm1( E )) = E */
306         ans = lang2(Log1PSymbol, arg1);
307     }
308     else if (fun == Log2Symbol) ans = lang2(Log2Symbol, arg1);
309     else if (fun == Log10Symbol) ans = lang2(Log10Symbol, arg1);
310     else if (fun == CosPiSymbol) ans = lang2(CosPiSymbol, arg1);
311     else if (fun == SinPiSymbol) ans = lang2(SinPiSymbol, arg1);
312     else if (fun == TanPiSymbol) ans = lang2(TanPiSymbol, arg1);
313     else if (fun == FactorialSymbol)ans = lang2(FactorialSymbol, arg1);
314     else if (fun == LFactorialSymbol)ans = lang2(LFactorialSymbol, arg1);
315 /* possible future symbols
316     else if (fun == Log1PExpSymbol) ans = lang2(Log1PExpSymbol, arg1);
317     else if (fun == Log1MExpSymbol) ans = lang2(Log1MExpSymbol, arg1);
318     else if (fun == Log1PMxSymbol) ans = lang2(Log1PMxSymbol, arg1);
319 */
320 
321     else ans = Constant(NA_REAL);
322     /* FIXME */
323 #ifdef NOTYET
324     if (length(ans) == 2 && isAtomic(CADR(ans)) && CAR(ans) != MinusSymbol)
325 	c = eval(c, rho);
326     if (length(c) == 3 && isAtomic(CADR(ans)) && isAtomic(CADDR(ans)))
327 	c = eval(c, rho);
328 #endif
329     return ans;
330 }/* simplify() */
331 
332 
333 /* D() implements the "derivative table" : */
D(SEXP expr,SEXP var)334 static SEXP D(SEXP expr, SEXP var)
335 {
336 
337 #define PP_S(F,a1,a2) PP(simplify(F,a1,a2))
338 #define PP_S2(F,a1)   PP(simplify(F,a1, R_MissingArg))
339 
340     SEXP ans = R_NilValue, expr1, expr2;
341     switch(TYPEOF(expr)) {
342     case LGLSXP:
343     case INTSXP:
344     case REALSXP:
345     case CPLXSXP:
346 	ans = Constant(0);
347 	break;
348     case SYMSXP:
349 	if (expr == var) ans = Constant(1.);
350 	else ans = Constant(0.);
351 	break;
352     case LISTSXP:
353 	if (inherits(expr, "expression")) ans = D(CAR(expr), var);
354 	else ans = Constant(NA_REAL);
355 	break;
356     case LANGSXP:
357 	if (CAR(expr) == ParenSymbol) {
358 	    ans = D(CADR(expr), var);
359 	}
360 	else if (CAR(expr) == PlusSymbol) {
361 	    if (length(expr) == 2)
362 		ans = D(CADR(expr), var);
363 	    else {
364 		ans = simplify(PlusSymbol,
365 			       PP(D(CADR(expr), var)),
366 			       PP(D(CADDR(expr), var)));
367 		UNPROTECT(2);
368 	    }
369 	}
370 	else if (CAR(expr) == MinusSymbol) {
371 	    if (length(expr) == 2) {
372 		ans = simplify(MinusSymbol,
373 			       PP(D(CADR(expr), var)),
374 			       R_MissingArg);
375 		UNPROTECT(1);
376 	    }
377 	    else {
378 		ans = simplify(MinusSymbol,
379 			       PP(D(CADR(expr), var)),
380 			       PP(D(CADDR(expr), var)));
381 		UNPROTECT(2);
382 	    }
383 	}
384 	else if (CAR(expr) == TimesSymbol) {
385 	    ans = simplify(PlusSymbol,
386 			   PP_S(TimesSymbol,PP(D(CADR(expr),var)), CADDR(expr)),
387 			   PP_S(TimesSymbol,CADR(expr), PP(D(CADDR(expr),var))));
388 	    UNPROTECT(4);
389 	}
390 	else if (CAR(expr) == DivideSymbol) {
391 	    PROTECT(expr1 = D(CADR(expr), var));
392 	    PROTECT(expr2 = D(CADDR(expr), var));
393 	    ans = simplify(MinusSymbol,
394 			   PP_S(DivideSymbol, expr1, CADDR(expr)),
395 			   PP_S(DivideSymbol,
396 				PP_S(TimesSymbol, CADR(expr), expr2),
397 				PP_S(PowerSymbol,CADDR(expr),PP(Constant(2.)))));
398 	    UNPROTECT(7);
399 	}
400 	else if (CAR(expr) == PowerSymbol) {
401 	    if (isLogical(CADDR(expr)) || isNumeric(CADDR(expr))) {
402 		ans = simplify(TimesSymbol,
403 			       CADDR(expr),
404 			       PP_S(TimesSymbol,
405 				    PP(D(CADR(expr), var)),
406 				    PP_S(PowerSymbol,
407 					 CADR(expr),
408 					 PP(Constant(asReal(CADDR(expr))-1.)))));
409 		UNPROTECT(4);
410 	    }
411 	    else {
412 		expr1 = simplify(TimesSymbol,
413 				 PP_S(PowerSymbol,
414 				      CADR(expr),
415 				      PP_S(MinusSymbol,
416 					   CADDR(expr),
417 					   PP(Constant(1.0)))),
418 				 PP_S(TimesSymbol,
419 				      CADDR(expr),
420 				      PP(D(CADR(expr), var))));
421 		UNPROTECT(5);
422 		PROTECT(expr1);
423 		expr2 = simplify(TimesSymbol,
424 				 PP_S(PowerSymbol, CADR(expr), CADDR(expr)),
425 				 PP_S(TimesSymbol,
426 				      PP_S2(LogSymbol, CADR(expr)),
427 				      PP(D(CADDR(expr), var))));
428 		UNPROTECT(4);
429 		PROTECT(expr2);
430 		ans = simplify(PlusSymbol, expr1, expr2);
431 		UNPROTECT(2);
432 	    }
433 	}
434 	else if (CAR(expr) == ExpSymbol) {
435 	    ans = simplify(TimesSymbol,
436 			   expr,
437 			   PP(D(CADR(expr), var)));
438 	    UNPROTECT(1);
439 	}
440 	else if (CAR(expr) == LogSymbol) {
441 	    if (length(expr) != 2)
442 		error("only single-argument calls to log() are supported;\n"
443 		      "  maybe use log(x,a) = log(x)/log(a)");
444 	    ans = simplify(DivideSymbol,
445 			   PP(D(CADR(expr), var)),
446 			   CADR(expr));
447 	    UNPROTECT(1);
448 	}
449 	else if (CAR(expr) == CosSymbol) {
450 	    ans = simplify(TimesSymbol,
451 			   PP_S2(SinSymbol, CADR(expr)),
452 			   PP_S2(MinusSymbol, PP(D(CADR(expr), var))));
453 	    UNPROTECT(3);
454 	}
455 	else if (CAR(expr) == SinSymbol) {
456 	    ans = simplify(TimesSymbol,
457 			   PP_S2(CosSymbol, CADR(expr)),
458 			   PP(D(CADR(expr), var)));
459 	    UNPROTECT(2);
460 	}
461 	else if (CAR(expr) == TanSymbol) {
462 	    ans = simplify(DivideSymbol,
463 			   PP(D(CADR(expr), var)),
464 			   PP_S(PowerSymbol,
465 				PP_S2(CosSymbol, CADR(expr)),
466 				PP(Constant(2.0))));
467 	    UNPROTECT(4);
468 	}
469 	else if (CAR(expr) == CoshSymbol) {
470 	    ans = simplify(TimesSymbol,
471 			   PP_S2(SinhSymbol, CADR(expr)),
472 			   PP(D(CADR(expr), var)));
473 	    UNPROTECT(2);
474 	}
475 	else if (CAR(expr) == SinhSymbol) {
476 	    ans = simplify(TimesSymbol,
477 			   PP_S2(CoshSymbol, CADR(expr)),
478 			   PP(D(CADR(expr), var))),
479 		UNPROTECT(2);
480 	}
481 	else if (CAR(expr) == TanhSymbol) {
482 	    ans = simplify(DivideSymbol,
483 			   PP(D(CADR(expr), var)),
484 			   PP_S(PowerSymbol,
485 				PP_S2(CoshSymbol, CADR(expr)),
486 				PP(Constant(2.0))));
487 	    UNPROTECT(4);
488 	}
489 	else if (CAR(expr) == SqrtSymbol) {
490 	    PROTECT(expr1 = allocList(3));
491 	    SET_TYPEOF(expr1, LANGSXP);
492 	    SETCAR(expr1, PowerSymbol);
493 	    SETCADR(expr1, CADR(expr));
494 	    SETCADDR(expr1, Constant(0.5));
495 	    ans = D(expr1, var);
496 	    UNPROTECT(1);
497 	}
498 	else if (CAR(expr) == PnormSymbol) {
499 	    ans = simplify(TimesSymbol,
500 			   PP_S2(DnormSymbol, CADR(expr)),
501 			   PP(D(CADR(expr), var)));
502 	    UNPROTECT(2);
503 	}
504 	else if (CAR(expr) == DnormSymbol) {
505 	    ans = simplify(TimesSymbol,
506 			   PP_S2(MinusSymbol, CADR(expr)),
507 			   PP_S(TimesSymbol,
508 				PP_S2(DnormSymbol, CADR(expr)),
509 				PP(D(CADR(expr), var))));
510 	    UNPROTECT(4);
511 	}
512 	else if (CAR(expr) == AsinSymbol) {
513 	    ans = simplify(DivideSymbol,
514 			   PP(D(CADR(expr), var)),
515 			   PP_S(SqrtSymbol,
516 				PP_S(MinusSymbol, PP(Constant(1.)),
517 				     PP_S(PowerSymbol,CADR(expr),PP(Constant(2.)))),
518 				R_MissingArg));
519 	    UNPROTECT(6);
520 	}
521 	else if (CAR(expr) == AcosSymbol) {
522 	    ans = simplify(MinusSymbol,
523 			   PP_S(DivideSymbol,
524 				PP(D(CADR(expr), var)),
525 				PP_S(SqrtSymbol,
526 				     PP_S(MinusSymbol, PP(Constant(1.)),
527 					  PP_S(PowerSymbol,
528 					       CADR(expr),PP(Constant(2.)))),
529 				     R_MissingArg)), R_MissingArg);
530 	    UNPROTECT(7);
531 	}
532 	else if (CAR(expr) == AtanSymbol) {
533 	    ans = simplify(DivideSymbol,
534 			   PP(D(CADR(expr), var)),
535 			   PP_S(PlusSymbol,PP(Constant(1.)),
536 				PP_S(PowerSymbol, CADR(expr),PP(Constant(2.)))));
537 	    UNPROTECT(5);
538 	}
539 	else if (CAR(expr) == LGammaSymbol) {
540 	    ans = simplify(TimesSymbol,
541 			   PP(D(CADR(expr), var)),
542 			   PP_S2(DiGammaSymbol, CADR(expr)));
543 	    UNPROTECT(2);
544 	}
545 	else if (CAR(expr) == GammaSymbol) {
546 	    ans = simplify(TimesSymbol,
547 			   PP(D(CADR(expr), var)),
548 			   PP_S(TimesSymbol,
549 				expr,
550 				PP_S2(DiGammaSymbol, CADR(expr))));
551 	    UNPROTECT(3);
552 	}
553 	else if (CAR(expr) == DiGammaSymbol) {
554 	    ans = simplify(TimesSymbol,
555 			   PP(D(CADR(expr), var)),
556 			   PP_S2(TriGammaSymbol, CADR(expr)));
557 	    UNPROTECT(2);
558 	}
559 	else if (CAR(expr) == TriGammaSymbol) {
560 	    ans = simplify(TimesSymbol,
561 			   PP(D(CADR(expr), var)),
562 			   PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(2))));
563 	    UNPROTECT(3);
564 	}
565 	else if (CAR(expr) == PsiSymbol) {
566 	    if (length(expr) == 2){
567 		ans = simplify(TimesSymbol,
568 			       PP(D(CADR(expr), var)),
569 			       PP_S(PsiSymbol, CADR(expr), PP(ScalarInteger(1))));
570 		UNPROTECT(3);
571 	    } else if (TYPEOF(CADDR(expr)) == INTSXP ||
572 		       TYPEOF(CADDR(expr)) == REALSXP) {
573 		ans = simplify(TimesSymbol,
574 			       PP(D(CADR(expr), var)),
575 			       PP_S(PsiSymbol,
576 				    CADR(expr),
577 				    PP(ScalarInteger(asInteger(CADDR(expr))+1))));
578 		UNPROTECT(3);
579 	    } else {
580 		ans = simplify(TimesSymbol,
581 			       PP(D(CADR(expr), var)),
582 			       PP_S(PsiSymbol,
583 				    CADR(expr),
584 				    simplify(PlusSymbol, CADDR(expr),
585 					     PP(ScalarInteger(1)))));
586 		UNPROTECT(3);
587 	    }
588 	}
589 /* new in R 3.4.0 */
590         else if (CAR(expr) == ExpM1Symbol) {
591             ans = simplify(TimesSymbol,
592 			   PP_S2(ExpSymbol, CADR(expr)),
593                            PP(D(CADR(expr), var)));
594             UNPROTECT(2);
595         }
596         else if (CAR(expr) == Log1PSymbol) {
597             ans = simplify(DivideSymbol,
598                            PP(D(CADR(expr), var)),
599                            PP_S(PlusSymbol, PP(Constant(1.)), CADR(expr)));
600             UNPROTECT(3);
601         }
602         else if (CAR(expr) == Log2Symbol) {
603             ans = simplify(DivideSymbol,
604                            PP(D(CADR(expr), var)),
605                            PP_S(TimesSymbol, CADR(expr),
606 				             PP_S2(LogSymbol, PP(Constant(2.)))));
607             UNPROTECT(4);
608         }
609         else if (CAR(expr) == Log10Symbol) {
610             ans = simplify(DivideSymbol,
611                            PP(D(CADR(expr), var)),
612                            PP_S(TimesSymbol, CADR(expr),
613 				             PP_S2(LogSymbol, PP(Constant(10.)))));
614             UNPROTECT(4);
615         }
616         else if (CAR(expr) == CosPiSymbol) {
617             ans = simplify(TimesSymbol,
618                            PP_S2(SinPiSymbol, CADR(expr)),
619                            PP_S(TimesSymbol, PP_S2(MinusSymbol, PiSymbol),
620 				             PP(D(CADR(expr), var)) ));
621             UNPROTECT(4);
622         }
623         else if (CAR(expr) == SinPiSymbol) {
624             ans = simplify(TimesSymbol,
625                            PP_S2(CosPiSymbol, CADR(expr)),
626                            PP_S(TimesSymbol, PiSymbol,
627                                              PP(D(CADR(expr), var)) ));
628             UNPROTECT(3);
629         }
630         else if (CAR(expr) == TanPiSymbol) {
631             ans = simplify(DivideSymbol,
632                            PP_S(TimesSymbol, PiSymbol, PP(D(CADR(expr), var))),
633 			   PP_S(PowerSymbol,
634 				PP_S2(CosPiSymbol, CADR(expr)),
635 				PP(Constant(2.0))));
636             UNPROTECT(5);
637         }
638         else if (CAR(expr) == LFactorialSymbol) {
639             ans = simplify(TimesSymbol,
640                            PP(D(CADR(expr), var)),
641                            PP_S2(DiGammaSymbol, PP_S(PlusSymbol,
642 						     CADR(expr),
643 						     PP(ScalarInteger(1)))));
644             UNPROTECT(4);
645         }
646         else if (CAR(expr) == FactorialSymbol) {
647             ans = simplify(TimesSymbol,
648                            PP(D(CADR(expr), var)),
649                            PP_S(TimesSymbol,
650                                 expr,
651                                 PP_S2(DiGammaSymbol, PP_S(PlusSymbol,
652 							  CADR(expr),
653 							  PP(ScalarInteger(1))))));
654             UNPROTECT(5);
655         }
656 /* possible future symbols
657         else if (CAR(expr) == Log1PExpSymbol) {
658             ans = simplify(DivideSymbol,
659                            PP_S(TimesSymbol, PP(D(CADR(expr), var)),
660                                 PP_S2(ExpSymbol, CADR(expr))),
661                            PP_S(PlusSymbol,PP(Constant(1.)),
662                                 PP_S2(ExpSymbol, CADR(expr)) ));
663             UNPROTECT(6);
664         }
665         else if (CAR(expr) == Log1MExpSymbol) {
666             ans = simplify(DivideSymbol,
667                            PP_S(TimesSymbol, PP_S2(MinusSymbol, PP(D(CADR(expr), var))),
668                                 PP_S2(ExpSymbol, PP_S2(MinusSymbol, CADR(expr))) ),
669                            PP_S2(ExpM1Symbol, PP_S2(MinusSymbol, CADR(expr))) );
670             UNPROTECT(7);
671         }
672         else if (CAR(expr) == Log1PMxSymbol) {
673             ans = simplify(DivideSymbol,
674                            PP_S2(MinusSymbol, PP(D(CADR(expr), var))),
675                            PP_S(PlusSymbol,PP(Constant(1.)), CADR(expr)) );
676             UNPROTECT(4);
677         }
678 */
679 
680 	else {
681 	    SEXP u = deparse1(CAR(expr), 0, SIMPLEDEPARSE);
682 	    error(_("Function '%s' is not in the derivatives table"),
683 		  translateChar(STRING_ELT(u, 0)));
684 	}
685 
686 	break;
687     default:
688 	ans = Constant(NA_REAL);
689     }
690     return ans;
691 
692 #undef PP_S
693 #undef PP_S2
694 
695 } /* D() */
696 
isPlusForm(SEXP expr)697 static int isPlusForm(SEXP expr)
698 {
699     return TYPEOF(expr) == LANGSXP
700 	&& length(expr) == 3
701 	&& CAR(expr) == PlusSymbol;
702 }
703 
isMinusForm(SEXP expr)704 static int isMinusForm(SEXP expr)
705 {
706     return TYPEOF(expr) == LANGSXP
707 	&& length(expr) == 3
708 	&& CAR(expr) == MinusSymbol;
709 }
710 
isTimesForm(SEXP expr)711 static int isTimesForm(SEXP expr)
712 {
713     return TYPEOF(expr) == LANGSXP
714 	&& length(expr) == 3
715 	&& CAR(expr) == TimesSymbol;
716 }
717 
isDivideForm(SEXP expr)718 static int isDivideForm(SEXP expr)
719 {
720     return TYPEOF(expr) == LANGSXP
721 	&& length(expr) == 3
722 	&& CAR(expr) == DivideSymbol;
723 }
724 
isPowerForm(SEXP expr)725 static int isPowerForm(SEXP expr)
726 {
727     return (TYPEOF(expr) == LANGSXP
728 	    && length(expr) == 3
729 	    && CAR(expr) == PowerSymbol);
730 }
731 
AddParens(SEXP expr)732 static SEXP AddParens(SEXP expr)
733 {
734     SEXP e;
735     if (TYPEOF(expr) == LANGSXP) {
736 	e = CDR(expr);
737 	while(e != R_NilValue) {
738 	    SETCAR(e, AddParens(CAR(e)));
739 	    e = CDR(e);
740 	}
741     }
742     if (isPlusForm(expr)) {
743 	if (isPlusForm(CADDR(expr))) {
744 	    SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
745 	}
746     }
747     else if (isMinusForm(expr)) {
748 	if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr))) {
749 	    SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
750 	}
751     }
752     else if (isTimesForm(expr)) {
753 	if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr))
754 	    || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) {
755 	    SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
756 	}
757 	if (isPlusForm(CADR(expr)) || isMinusForm(CADR(expr))) {
758 	    SETCADR(expr, lang2(ParenSymbol, CADR(expr)));
759 	}
760     }
761     else if (isDivideForm(expr)) {
762 	if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr))
763 	    || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) {
764 	    SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
765 	}
766 	if (isPlusForm(CADR(expr)) || isMinusForm(CADR(expr))) {
767 	    SETCADR(expr, lang2(ParenSymbol, CADR(expr)));
768 	}
769     }
770     else if (isPowerForm(expr)) {
771 	if (isPowerForm(CADR(expr))) {
772 	    SETCADR(expr, lang2(ParenSymbol, CADR(expr)));
773 	}
774 	if (isPlusForm(CADDR(expr)) || isMinusForm(CADDR(expr))
775 	    || isTimesForm(CADDR(expr)) || isDivideForm(CADDR(expr))) {
776 	    SETCADDR(expr, lang2(ParenSymbol, CADDR(expr)));
777 	}
778     }
779     return expr;
780 }
781 
doD(SEXP args)782 SEXP doD(SEXP args)
783 {
784     SEXP expr, var;
785     args = CDR(args);
786     if (isExpression(CAR(args))) expr = VECTOR_ELT(CAR(args), 0);
787     else expr = CAR(args);
788     if (!(isLanguage(expr) || isSymbol(expr) || isNumeric(expr) || isComplex(expr)))
789         error(_("expression must not be type '%s'"), type2char(TYPEOF(expr)));
790     var = CADR(args);
791     if (!isString(var) || length(var) < 1)
792 	error(_("variable must be a character string"));
793     if (length(var) > 1)
794 	warning(_("only the first element is used as variable name"));
795     var = installTrChar(STRING_ELT(var, 0));
796     InitDerivSymbols();
797     PROTECT(expr = D(expr, var));
798     expr = AddParens(expr);
799     UNPROTECT(1);
800     return expr;
801 }
802 
803 /* ------ FindSubexprs ------ and ------ Accumulate ------ */
804 
InvalidExpression(char * where)805 static void NORET InvalidExpression(char *where)
806 {
807     error(_("invalid expression in '%s'"), where);
808 }
809 
equal(SEXP expr1,SEXP expr2)810 static int equal(SEXP expr1, SEXP expr2)
811 {
812     if (TYPEOF(expr1) == TYPEOF(expr2)) {
813 	switch(TYPEOF(expr1)) {
814 	case NILSXP:
815 	    return 1;
816 	case SYMSXP:
817 	    return expr1 == expr2;
818 	case LGLSXP:
819 	case INTSXP:
820 	    return INTEGER(expr1)[0] == INTEGER(expr2)[0];
821 	case REALSXP:
822 	    return REAL(expr1)[0] == REAL(expr2)[0];
823 	case CPLXSXP:
824 	    return COMPLEX(expr1)[0].r == COMPLEX(expr2)[0].r
825 		&& COMPLEX(expr1)[0].i == COMPLEX(expr2)[0].i;
826 	case LANGSXP:
827 	case LISTSXP:
828 	    return equal(CAR(expr1), CAR(expr2))
829 		&& equal(CDR(expr1), CDR(expr2));
830 	default:
831 	    InvalidExpression("equal");
832 	}
833     }
834     return 0;
835 }
836 
Accumulate(SEXP expr,SEXP exprlist)837 static int Accumulate(SEXP expr, SEXP exprlist)
838 {
839     SEXP e;
840     int k;
841     e = exprlist;
842     k = 0;
843     while(CDR(e) != R_NilValue) {
844 	e = CDR(e);
845 	k = k + 1;
846 	if (equal(expr, CAR(e)))
847 	    return k;
848     }
849     SETCDR(e, CONS(expr, R_NilValue));
850     return k + 1;
851 }
852 
Accumulate2(SEXP expr,SEXP exprlist)853 static int Accumulate2(SEXP expr, SEXP exprlist)
854 {
855     SEXP e;
856     int k;
857     e = exprlist;
858     k = 0;
859     while(CDR(e) != R_NilValue) {
860 	e = CDR(e);
861 	k = k + 1;
862     }
863     SETCDR(e, CONS(expr, R_NilValue));
864     return k + 1;
865 }
866 
MakeVariable(int k,SEXP tag)867 static SEXP MakeVariable(int k, SEXP tag)
868 {
869     const void *vmax = vmaxget();
870     char buf[64];
871     snprintf(buf, 64, "%s%d", translateChar(STRING_ELT(tag, 0)), k);
872     vmaxset(vmax);
873     return install(buf);
874 }
875 
FindSubexprs(SEXP expr,SEXP exprlist,SEXP tag)876 static int FindSubexprs(SEXP expr, SEXP exprlist, SEXP tag)
877 {
878     SEXP e;
879     int k;
880     switch(TYPEOF(expr)) {
881     case SYMSXP:
882     case LGLSXP:
883     case INTSXP:
884     case REALSXP:
885     case CPLXSXP:
886 	return 0;
887 	break;
888     case LISTSXP:
889 	if (inherits(expr, "expression"))
890 	    return FindSubexprs(CAR(expr), exprlist, tag);
891 	else { InvalidExpression("FindSubexprs"); return -1/*-Wall*/; }
892 	break;
893     case LANGSXP:
894 	if (CAR(expr) == install("(")) {
895 	    return FindSubexprs(CADR(expr), exprlist, tag);
896 	}
897 	else {
898 	    e = CDR(expr);
899 	    while(e != R_NilValue) {
900 		if ((k = FindSubexprs(CAR(e), exprlist, tag)) != 0)
901 		    SETCAR(e, MakeVariable(k, tag));
902 		e = CDR(e);
903 	    }
904 	    return Accumulate(expr, exprlist);
905 	}
906 	break;
907     default:
908 	InvalidExpression("FindSubexprs");
909 	return -1/*-Wall*/;
910     }
911 }
912 
CountOccurrences(SEXP sym,SEXP lst)913 static int CountOccurrences(SEXP sym, SEXP lst)
914 {
915     switch(TYPEOF(lst)) {
916     case SYMSXP:
917 	return lst == sym;
918     case LISTSXP:
919     case LANGSXP:
920 	return CountOccurrences(sym, CAR(lst))
921 	    + CountOccurrences(sym, CDR(lst));
922     default:
923 	return 0;
924     }
925 }
926 
Replace(SEXP sym,SEXP expr,SEXP lst)927 static SEXP Replace(SEXP sym, SEXP expr, SEXP lst)
928 {
929     switch(TYPEOF(lst)) {
930     case SYMSXP:
931 	if (lst == sym) return expr;
932 	else return lst;
933     case LISTSXP:
934     case LANGSXP:
935 	SETCAR(lst, Replace(sym, expr, CAR(lst)));
936 	SETCDR(lst, Replace(sym, expr, CDR(lst)));
937 	return lst;
938     default:
939 	return lst;
940     }
941 }
942 
CreateGrad(SEXP names)943 static SEXP CreateGrad(SEXP names)
944 {
945     SEXP p, q, data, dim, dimnames;
946     int i, n;
947     n = length(names);
948     PROTECT(dimnames = lang3(R_NilValue, R_NilValue, R_NilValue));
949     SETCAR(dimnames, install("list"));
950     p = install("c");
951     PROTECT(q = allocList(n));
952     SETCADDR(dimnames, LCONS(p, q));
953     UNPROTECT(1);
954     for(i = 0 ; i < n ; i++) {
955 	SETCAR(q, ScalarString(STRING_ELT(names, i)));
956 	q = CDR(q);
957     }
958     PROTECT(dim = lang3(R_NilValue, R_NilValue, R_NilValue));
959     SETCAR(dim, install("c"));
960     SETCADR(dim, lang2(install("length"), install(".value")));
961     SETCADDR(dim, ScalarInteger(length(names))); /* was real? */
962     PROTECT(data = ScalarReal(0.));
963     PROTECT(p = lang4(install("array"), data, dim, dimnames));
964     p = lang3(install("<-"), install(".grad"), p);
965     UNPROTECT(4);
966     return p;
967 }
968 
CreateHess(SEXP names)969 static SEXP CreateHess(SEXP names)
970 {
971     SEXP p, q, data, dim, dimnames;
972     int i, n;
973     n = length(names);
974     PROTECT(dimnames = lang4(R_NilValue, R_NilValue, R_NilValue, R_NilValue));
975     SETCAR(dimnames, install("list"));
976     p = install("c");
977     PROTECT(q = allocList(n));
978     SETCADDR(dimnames, LCONS(p, q));
979     UNPROTECT(1);
980     for(i = 0 ; i < n ; i++) {
981 	SETCAR(q, ScalarString(STRING_ELT(names, i)));
982 	q = CDR(q);
983     }
984     SETCADDDR(dimnames, duplicate(CADDR(dimnames)));
985     PROTECT(dim = lang4(R_NilValue, R_NilValue, R_NilValue,R_NilValue));
986     SETCAR(dim, install("c"));
987     SETCADR(dim, lang2(install("length"), install(".value")));
988     SETCADDR(dim, ScalarInteger(length(names)));
989     SETCADDDR(dim, ScalarInteger(length(names)));
990     PROTECT(data = ScalarReal(0.));
991     PROTECT(p = lang4(install("array"), data, dim, dimnames));
992     p = lang3(install("<-"), install(".hessian"), p);
993     UNPROTECT(4);
994     return p;
995 }
996 
DerivAssign(SEXP name,SEXP expr)997 static SEXP DerivAssign(SEXP name, SEXP expr)
998 {
999     SEXP ans, newname;
1000     PROTECT(ans = lang3(install("<-"), R_NilValue, expr));
1001     PROTECT(newname = ScalarString(name));
1002     SETCADR(ans, lang4(R_BracketSymbol, install(".grad"), R_MissingArg, newname));
1003     UNPROTECT(2);
1004     return ans;
1005 }
1006 
HessAssign1(SEXP name,SEXP expr)1007 static SEXP HessAssign1(SEXP name, SEXP expr)
1008 {
1009     SEXP ans, newname;
1010     PROTECT(ans = lang3(install("<-"), R_NilValue, expr));
1011     PROTECT(newname = ScalarString(name));
1012     SETCADR(ans, lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
1013 		       newname, newname));
1014     UNPROTECT(2);
1015     return ans;
1016 }
1017 
HessAssign2(SEXP name1,SEXP name2,SEXP expr)1018 static SEXP HessAssign2(SEXP name1, SEXP name2, SEXP expr)
1019 {
1020     SEXP ans, newname1, newname2, tmp1, tmp2, tmp3;
1021     PROTECT(newname1 = ScalarString(name1));
1022     PROTECT(newname2 = ScalarString(name2));
1023     /* this is overkill, but PR#14772 found an issue */
1024     PROTECT(tmp1 = lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
1025 			 newname1, newname2));
1026     PROTECT(tmp2 = lang5(R_BracketSymbol, install(".hessian"), R_MissingArg,
1027 			 newname2, newname1));
1028     PROTECT(tmp3 = lang3(install("<-"), tmp2, expr));
1029     ans = lang3(install("<-"), tmp1, tmp3);
1030     UNPROTECT(5);
1031     return ans;
1032 }
1033 
1034 /* attr(.value, "gradient") <- .grad */
1035 
AddGrad(void)1036 static SEXP AddGrad(void)
1037 {
1038     SEXP ans;
1039     PROTECT(ans = mkString("gradient"));
1040     PROTECT(ans = lang3(install("attr"), install(".value"), ans));
1041     ans = lang3(install("<-"), ans, install(".grad"));
1042     UNPROTECT(2);
1043     return ans;
1044 }
1045 
AddHess(void)1046 static SEXP AddHess(void)
1047 {
1048     SEXP ans;
1049     PROTECT(ans = mkString("hessian"));
1050     PROTECT(ans = lang3(install("attr"), install(".value"), ans));
1051     ans = lang3(install("<-"), ans, install(".hessian"));
1052     UNPROTECT(2);
1053     return ans;
1054 }
1055 
Prune(SEXP lst)1056 static SEXP Prune(SEXP lst)
1057 {
1058     if (lst == R_NilValue)
1059 	return lst;
1060     SETCDR(lst, Prune(CDR(lst)));
1061     if (CAR(lst) == R_MissingArg)
1062 	return CDR(lst);
1063     else return lst ;
1064 }
1065 
deriv(SEXP args)1066 SEXP deriv(SEXP args)
1067 {
1068 /* deriv(expr, namevec, function.arg, tag, hessian) */
1069     SEXP ans, ans2, expr, funarg, names, s;
1070     int f_index, *d_index, *d2_index;
1071     int i, j, k, nexpr, nderiv=0, hessian;
1072     SEXP exprlist, tag;
1073 
1074     args = CDR(args);
1075     InitDerivSymbols();
1076     PROTECT(exprlist = LCONS(R_BraceSymbol, R_NilValue));
1077     /* expr: */
1078     if (isExpression(CAR(args)))
1079 	PROTECT(expr = VECTOR_ELT(CAR(args), 0));
1080     else PROTECT(expr = CAR(args));
1081     args = CDR(args);
1082     /* namevec: */
1083     names = CAR(args);
1084     if (!isString(names) || (nderiv = length(names)) < 1)
1085 	error(_("invalid variable names"));
1086     args = CDR(args);
1087     /* function.arg: */
1088     funarg = CAR(args);
1089     args = CDR(args);
1090     /* tag: */
1091     tag = CAR(args);
1092     if (!isString(tag) || length(tag) < 1
1093 	|| length(STRING_ELT(tag, 0)) < 1 || length(STRING_ELT(tag, 0)) > 60)
1094 	error(_("invalid tag"));
1095     args = CDR(args);
1096     /* hessian: */
1097     hessian = asLogical(CAR(args));
1098     /* NOTE: FindSubexprs is destructive, hence the duplication.
1099        It can allocate, so protect the duplicate.
1100      */
1101     PROTECT(ans = duplicate(expr));
1102     f_index = FindSubexprs(ans, exprlist, tag);
1103     d_index = (int*)R_alloc((size_t) nderiv, sizeof(int));
1104     if (hessian)
1105 	d2_index = (int*)R_alloc((size_t) ((nderiv * (1 + nderiv))/2),
1106 				 sizeof(int));
1107     else d2_index = d_index;/*-Wall*/
1108     UNPROTECT(1);
1109     for(i=0, k=0; i<nderiv ; i++) {
1110 	PROTECT(ans = duplicate(expr));
1111 	PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
1112 	PROTECT(ans2 = duplicate(ans));	/* keep a temporary copy */
1113 	d_index[i] = FindSubexprs(ans, exprlist, tag); /* examine the derivative first */
1114 	PROTECT(ans = duplicate(ans2));	/* restore the copy */
1115 	if (hessian) {
1116 	    for(j = i; j < nderiv; j++) {
1117 		PROTECT(ans2 = duplicate(ans)); /* install could allocate */
1118 		PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
1119 		d2_index[k] = FindSubexprs(ans2, exprlist, tag);
1120 		k++;
1121 		UNPROTECT(2);
1122 	    }
1123 	}
1124 	UNPROTECT(4);
1125     }
1126     nexpr = length(exprlist) - 1;
1127     if (f_index) {
1128 	Accumulate2(MakeVariable(f_index, tag), exprlist);
1129     }
1130     else {
1131 	PROTECT(ans = duplicate(expr));
1132 	Accumulate2(expr, exprlist);
1133 	UNPROTECT(1);
1134     }
1135     Accumulate2(R_NilValue, exprlist);
1136     if (hessian) { Accumulate2(R_NilValue, exprlist); }
1137     for (i = 0, k = 0; i < nderiv ; i++) {
1138 	if (d_index[i]) {
1139 	    Accumulate2(MakeVariable(d_index[i], tag), exprlist);
1140 	    if (hessian) {
1141 		PROTECT(ans = duplicate(expr));
1142 		PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
1143 		for (j = i; j < nderiv; j++) {
1144 		    if (d2_index[k]) {
1145 			Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
1146 		    } else {
1147 			PROTECT(ans2 = duplicate(ans));
1148 			PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
1149 			Accumulate2(ans2, exprlist);
1150 			UNPROTECT(2);
1151 		    }
1152 		    k++;
1153 		}
1154 		UNPROTECT(2);
1155 	    }
1156 	} else { /* the first derivative is constant or simple variable */
1157 	    PROTECT(ans = duplicate(expr));
1158 	    PROTECT(ans = D(ans, installTrChar(STRING_ELT(names, i))));
1159 	    Accumulate2(ans, exprlist);
1160 	    UNPROTECT(2);
1161 	    if (hessian) {
1162 		for (j = i; j < nderiv; j++) {
1163 		    if (d2_index[k]) {
1164 			Accumulate2(MakeVariable(d2_index[k], tag), exprlist);
1165 		    } else {
1166 			PROTECT(ans2 = duplicate(ans));
1167 			PROTECT(ans2 = D(ans2, installTrChar(STRING_ELT(names, j))));
1168 			if(isZero(ans2)) Accumulate2(R_MissingArg, exprlist);
1169 			else Accumulate2(ans2, exprlist);
1170 			UNPROTECT(2);
1171 		    }
1172 		    k++;
1173 		}
1174 	    }
1175 	}
1176     }
1177     Accumulate2(R_NilValue, exprlist);
1178     Accumulate2(R_NilValue, exprlist);
1179     if (hessian) { Accumulate2(R_NilValue, exprlist); }
1180 
1181     i = 0;
1182     ans = CDR(exprlist);
1183     while (i < nexpr) {
1184 	if (CountOccurrences(MakeVariable(i+1, tag), CDR(ans)) < 2) {
1185 	    SETCDR(ans, Replace(MakeVariable(i+1, tag), CAR(ans), CDR(ans)));
1186 	    SETCAR(ans, R_MissingArg);
1187 	}
1188 	else {
1189             SEXP var;
1190             PROTECT(var = MakeVariable(i+1, tag));
1191             SETCAR(ans, lang3(install("<-"), var, AddParens(CAR(ans))));
1192             UNPROTECT(1);
1193         }
1194 	i = i + 1;
1195 	ans = CDR(ans);
1196     }
1197     /* .value <- ... */
1198     SETCAR(ans, lang3(install("<-"), install(".value"), AddParens(CAR(ans))));
1199     ans = CDR(ans);
1200     /* .grad <- ... */
1201     SETCAR(ans, CreateGrad(names));
1202     ans = CDR(ans);
1203     /* .hessian <- ... */
1204     if (hessian) { SETCAR(ans, CreateHess(names)); ans = CDR(ans); }
1205     /* .grad[, "..."] <- ... */
1206     for (i = 0; i < nderiv ; i++) {
1207 	SETCAR(ans, DerivAssign(STRING_ELT(names, i), AddParens(CAR(ans))));
1208 	ans = CDR(ans);
1209 	if (hessian) {
1210 	    for (j = i; j < nderiv; j++) {
1211 		if (CAR(ans) != R_MissingArg) {
1212 		    if (i == j) {
1213 			SETCAR(ans, HessAssign1(STRING_ELT(names, i),
1214 						AddParens(CAR(ans))));
1215 		    } else {
1216 			SETCAR(ans, HessAssign2(STRING_ELT(names, i),
1217 						STRING_ELT(names, j),
1218 						AddParens(CAR(ans))));
1219 		    }
1220 		}
1221 		ans = CDR(ans);
1222 	    }
1223 	}
1224     }
1225     /* attr(.value, "gradient") <- .grad */
1226     SETCAR(ans, AddGrad());
1227     ans = CDR(ans);
1228     if (hessian) { SETCAR(ans, AddHess()); ans = CDR(ans); }
1229     /* .value */
1230     SETCAR(ans, install(".value"));
1231     /* Prune the expression list removing eliminated sub-expressions */
1232     SETCDR(exprlist, Prune(CDR(exprlist)));
1233 
1234     if (TYPEOF(funarg) == LGLSXP && LOGICAL(funarg)[0]) { /* fun = TRUE */
1235 	funarg = names;
1236     }
1237 
1238     if (TYPEOF(funarg) == CLOSXP)
1239     {
1240 	s = allocSExp(CLOSXP);
1241 	SET_FORMALS(s, FORMALS(funarg));
1242 	SET_CLOENV(s, CLOENV(funarg));
1243 	funarg = s;
1244 	SET_BODY(funarg, exprlist);
1245     }
1246     else if (isString(funarg)) {
1247 	PROTECT(names = duplicate(funarg));
1248 	PROTECT(funarg = allocSExp(CLOSXP));
1249 	PROTECT(ans = allocList(length(names)));
1250 	SET_FORMALS(funarg, ans);
1251 	for(i = 0; i < length(names); i++) {
1252 	    SET_TAG(ans, installTrChar(STRING_ELT(names, i)));
1253 	    SETCAR(ans, R_MissingArg);
1254 	    ans = CDR(ans);
1255 	}
1256 	UNPROTECT(3);
1257 	SET_BODY(funarg, exprlist);
1258 	SET_CLOENV(funarg, R_GlobalEnv);
1259     }
1260     else {
1261 	funarg = allocVector(EXPRSXP, 1);
1262 	SET_VECTOR_ELT(funarg, 0, exprlist);
1263 	/* funarg = lang2(install("expression"), exprlist); */
1264     }
1265     UNPROTECT(2);
1266     return funarg;
1267 }
1268