1""" Compressed Continuous Computation in python """
2
3from __future__ import print_function
4
5import sys
6import os
try:
    import c3
except ImportError:
    # c3 is not importable from the current sys.path: fall back to the
    # in-tree SWIG build located under $C3HOME/build/wrappers/python.
    C3HOME = os.environ.get('C3HOME')
    if C3HOME is None:
        # Fail with an ImportError (instead of continuing and crashing with a
        # NameError on the undefined C3HOME below).
        raise ImportError("Must set environ variable C3HOME")
    C3HOME = os.path.join(C3HOME, "build", "wrappers", "python")
    if not os.path.exists(C3HOME):
        raise ImportError("Must have ${C3HOME}/build/wrappers/python directory structure")
    sys.path.insert(0, C3HOME)
    import c3
19
20import numpy as np
21import copy
22import pycback as pcb
23import atexit
24# import contextlib
25
def poly_randu(dim, ranks, maxorder):
    """ Random function with specified ranks and legendre polynomials of maximum order

    Parameters
    ----------
    dim : int
        Number of input dimensions.
    ranks :
        Ranks forwarded to ``c3.function_train_poly_randu``.
    maxorder : int
        Maximum Legendre polynomial order forwarded to C3.

    Returns
    -------
    FunctionTrain
        Wrapper that takes ownership of the newly created C3 function train.
    """
    bounds = c3.bounding_box_init_std(dim)
    raw_ft = c3.function_train_poly_randu(c3.LEGENDRE, bounds, ranks, maxorder)
    # The bounding box is released as soon as the function train exists.
    c3.bounding_box_free(bounds)
    wrapped = FunctionTrain(dim, ft=raw_ft)
    return wrapped
32
class FunctionTrain(object):
    """ Function-Train Decompositions

    Wraps a C3 function-train handle (``self.ft``) together with the
    per-dimension approximation options (``self.opts``) and the ranks used to
    initialize regression (``self.ranks``).  The C3 handle is owned by this
    object and released by :meth:`close`, which is also called from
    ``__del__``.
    """

    # Class-level default so that close()/__del__ are safe even when
    # __init__ did not run to completion.
    ft = None

    def __init__(self, din, ft=None):
        """ Initialize a Function Train of dimension *din*

        Parameters
        ----------
        din : int
            Number of input dimensions.
        ft : optional
            Existing C3 function-train handle.  This object takes ownership
            of it and frees it in :meth:`close`.
        """
        self.dim = din
        # Exactly one options dictionary per dimension (None until
        # set_dim_opts is called for that dimension).
        self.opts = [None] * din
        # Ranks r_0 .. r_dim; set_ranks forces the boundary ranks to 1.
        self.ranks = [1] * (din + 1)
        self.ft = ft

    def copy(self):
        """ Copy a function train (C3 object, options and ranks) """
        ft_out = FunctionTrain(self.dim)
        ft_out.ft = c3.function_train_copy(self.ft)
        ft_out.opts = copy.deepcopy(self.opts)
        ft_out.ranks = list(self.ranks)
        return ft_out

    def save(self, filename):
        """ Save a function train to a file with name *filename* """
        c3.function_train_save(self.ft, filename)

    def load(self, filename):
        """ Load a function-train from a file with name *filename*

        Any function train currently held by this object is freed once the
        new one has been loaded.
        """
        ft = c3.function_train_load(filename)
        if self.ft is not None:
            c3.function_train_free(self.ft)
        self.dim = c3.function_train_get_dim(ft)
        self.ft = ft

    def set_dim_opts(self, dim, ftype, lb=-1, ub=1, nparam=4,
                     kernel_height_scale=1.0, kernel_width_scale=1.0,
                     kernel_adapt_center=0,
                     lin_elem_nodes=None, kernel_nodes=None,
                     maxnum=np.inf, coeff_check=2, tol=1e-10,
                     nregions=5):
        """ Set approximation options per dimension

        Parameters
        ----------
        dim : int
            Index of the dimension being configured.
        ftype : str
            ``"legendre"``, ``"hermite"``, ``"fourier"``, ``"linelm"``,
            ``"kernel"`` or ``"piecewise"``.  The type is only validated when
            the options are converted to C3 form.
        lb, ub : float
            Domain bounds (overridden by explicit node arrays).
        nparam : int
            Number of parameters (overridden by explicit node arrays).
        lin_elem_nodes, kernel_nodes : 1-D numpy array, optional
            Explicit nodes.  They are rounded to 12 decimals and passed
            through ``np.unique``, and then determine ``nparam``, ``lb`` and
            ``ub``.  If *lin_elem_nodes* is not given,
            ``np.linspace(lb, ub, nparam)`` is stored instead.
        maxnum : number
            Maximum number of parameters for legendre/fourier; ``np.inf``
            means "do not set".

        Raises
        ------
        AttributeError
            If options for *dim* were already set.
        """
        if self.opts[dim] is not None:
            raise AttributeError('cannot call set_dim_opts because was already called')

        o = {
            'type': ftype,
            'lb': lb,
            'ub': ub,
            'nparam': nparam,
            'maxnum': maxnum,
            'coeff_check': coeff_check,
            'kernel_height_scale': kernel_height_scale,
            'kernel_width_scale': kernel_width_scale,
            'kernel_adapt_center': kernel_adapt_center,
            'tol': tol,
            'nregions': nregions,
        }
        if lin_elem_nodes is not None:
            assert lin_elem_nodes.ndim == 1
            lin_elem_nodes = np.unique(lin_elem_nodes.round(decimals=12))
            o['nparam'] = len(lin_elem_nodes)
            o['lb'] = lin_elem_nodes[0]
            o['ub'] = lin_elem_nodes[-1]
        else:
            lin_elem_nodes = np.linspace(lb, ub, nparam)
        o['lin_elem_nodes'] = lin_elem_nodes

        if kernel_nodes is not None:
            assert kernel_nodes.ndim == 1
            kernel_nodes = np.unique(kernel_nodes.round(decimals=12))
            o['nparam'] = len(kernel_nodes)
            o['lb'] = kernel_nodes[0]
            o['ub'] = kernel_nodes[-1]
        o['kernel_nodes'] = kernel_nodes

        # Assign in place (never insert) so the list keeps exactly one slot
        # per dimension regardless of the order in which dimensions are set.
        self.opts[dim] = o

    def __convert_opts_to_c3_form__(self):
        """ Build the low-level C3 option handle for every dimension.

        Returns a list with one handle per dimension.  The caller must
        release each handle with the free routine matching its type (as
        ``_free_approx_params`` does).

        Raises
        ------
        AttributeError
            If a dimension has an unknown function type.
        """
        return [self._c3_opts_for_dim(self.opts[ii]) for ii in range(self.dim)]

    @staticmethod
    def _c3_opts_for_dim(o):
        """ Convert one per-dimension options dictionary into a C3 handle. """
        ftype = o['type']
        lb = o['lb']
        ub = o['ub']
        nparam = o['nparam']

        if ftype == "legendre" or ftype == "fourier":
            opt = c3.ope_opts_alloc(c3.LEGENDRE if ftype == "legendre" else c3.FOURIER)
            c3.ope_opts_set_lb(opt, lb)
            c3.ope_opts_set_ub(opt, ub)
            c3.ope_opts_set_nparams(opt, nparam)
            c3.ope_opts_set_coeffs_check(opt, o['coeff_check'])
            c3.ope_opts_set_start(opt, nparam)
            maxnum = o['maxnum']
            # np.isinf returns numpy.bool_, so test truthiness rather than
            # identity with the built-in False.
            if not np.isinf(maxnum):
                c3.ope_opts_set_maxnum(opt, int(maxnum))
            c3.ope_opts_set_tol(opt, o['tol'])
            return opt

        if ftype == "hermite":
            opt = c3.ope_opts_alloc(c3.HERMITE)
            c3.ope_opts_set_nparams(opt, nparam)
            c3.ope_opts_set_tol(opt, o['tol'])
            return opt

        if ftype == "linelm":
            return c3.lin_elem_exp_aopts_alloc(o['lin_elem_nodes'])

        if ftype == "kernel":
            kernel_nodes = o['kernel_nodes']
            height_scale = o['kernel_height_scale']
            # NOTE(review): an unused bandwidth heuristic
            # (nparam**(-0.2) * std(x) * kernel_width_scale) used to be
            # computed here; kernel_width_scale itself is what is passed to
            # kernel_approx_opts_gauss.  Confirm which one is intended.
            width_scale = o['kernel_width_scale']
            adapt_center = o['kernel_adapt_center']
            if adapt_center == 0:
                if kernel_nodes is not None:
                    x = list(kernel_nodes)
                else:
                    x = list(np.linspace(lb, ub, nparam))
                return c3.kernel_approx_opts_gauss(len(x), x, height_scale, width_scale)

            if kernel_nodes is not None:
                x = list(kernel_nodes)
            else:
                # With center adaptation only half of the parameters are
                # centers laid out on [lb, ub].
                assert nparam % 2 == 0, "number of parameters has to be even for adaptation"
                x = list(np.linspace(lb, ub, int(nparam/2)))
            opt = c3.kernel_approx_opts_gauss(len(x), x, height_scale, width_scale)
            c3.kernel_approx_opts_set_center_adapt(opt, adapt_center)
            return opt

        if ftype == "piecewise":
            nregions = o['nregions']
            opt = c3.pw_poly_opts_alloc(c3.LEGENDRE, lb, ub)
            c3.pw_poly_opts_set_maxorder(opt, nparam)
            c3.pw_poly_opts_set_coeffs_check(opt, o['coeff_check'])
            c3.pw_poly_opts_set_tol(opt, o['tol'])
            c3.pw_poly_opts_set_minsize(opt, ((ub-lb)/nregions)**8)
            c3.pw_poly_opts_set_nregions(opt, nregions)
            return opt

        raise AttributeError('No options can be specified for function type '
                             + ftype)

    def _build_approx_params(self, method=None):
        """ Allocate the C3 approximation structures for this function train.

        Parameters
        ----------
        method : C3 approximation-method constant, optional
            Passed to ``c3approx_create``.  ``None`` (the default) means
            ``c3.REGRESS``; it is resolved at call time rather than at
            class-definition time.

        Returns
        -------
        tuple
            ``(c3a, onedopts, low_opts, optnodes)``, to be released with
            ``_free_approx_params``.
        """
        if method is None:
            method = c3.REGRESS

        c3_ope_opts = self.__convert_opts_to_c3_form__()
        oned_types = {
            "legendre": c3.POLYNOMIAL,
            "hermite": c3.POLYNOMIAL,
            "fourier": c3.POLYNOMIAL,
            "linelm": c3.LINELM,
            "kernel": c3.KERNEL,
            "piecewise": c3.PIECEWISE,
        }

        c3a = c3.c3approx_create(method, self.dim)
        onedopts = []
        optnodes = []
        for ii in range(self.dim):
            ftype = self.opts[ii]['type']
            if ftype not in oned_types:
                raise AttributeError("Don't know what to do here")
            onedopts.append(c3.one_approx_opts_alloc(oned_types[ftype], c3_ope_opts[ii]))

            if ftype == "hermite" or ftype == "fourier":
                # These types additionally get a 50-point c3vector on
                # [lb, ub] registered via c3approx_set_opt_opts_dim.
                nn = 50
                x = c3.linspace(self.opts[ii]['lb'], self.opts[ii]['ub'], nn)
                cvec = c3.c3vector_alloc(nn, x)
                optnodes.append((x, cvec))
                c3.c3approx_set_opt_opts_dim(c3a, ii, cvec)

            c3.c3approx_set_approx_opts_dim(c3a, ii, onedopts[ii])

        return c3a, onedopts, c3_ope_opts, optnodes

    def _free_approx_params(self, c3a, onedopts, low_opts, optnodes):
        """ Release the structures returned by ``_build_approx_params``. """
        low_free = {
            "legendre": c3.ope_opts_free,
            "hermite": c3.ope_opts_free,
            "fourier": c3.ope_opts_free,
            "linelm": c3.lin_elem_exp_aopts_free,
            "kernel": c3.kernel_approx_opts_free,
            "piecewise": c3.pw_poly_opts_free,
        }
        for ii in range(self.dim):
            c3.one_approx_opts_free(onedopts[ii])
            ftype = self.opts[ii]['type']
            if ftype not in low_free:
                raise AttributeError("Don't know what to do here")
            low_free[ftype](low_opts[ii])

        # Only the c3vector wrappers are released; the arrays returned by
        # c3.linspace (first tuple element) are not freed here.
        for _, cvec in optnodes:
            c3.c3vector_free(cvec)

        c3.c3approx_destroy(c3a)

    def _assemble_cross_args(self, verbose, init_rank, maxrank=10, cross_tol=1e-8,
                             round_tol=1e-8, kickrank=5, maxiter=5):
        """ Build C3 approximation structures configured for cross approximation.

        Starting fibers for each dimension are filled with
        ``c3.dd_row_linspace`` over that dimension's [lb, ub] using
        *init_rank* points.  Returns the same tuple as
        ``_build_approx_params``; the caller releases it with
        ``_free_approx_params``.
        """
        start_fibers = c3.malloc_dd(self.dim)

        c3a, onedopts, low_opts, optnodes = self._build_approx_params(c3.CROSS)

        for ii in range(self.dim):
            c3.dd_row_linspace(start_fibers, ii, self.opts[ii]['lb'],
                               self.opts[ii]['ub'], init_rank)

        c3.c3approx_init_cross(c3a, init_rank, verbose, start_fibers)
        c3.c3approx_set_verbose(c3a, verbose)
        c3.c3approx_set_cross_tol(c3a, cross_tol)
        c3.c3approx_set_cross_maxiter(c3a, maxiter)
        c3.c3approx_set_round_tol(c3a, round_tol)
        c3.c3approx_set_adapt_maxrank_all(c3a, maxrank)
        c3.c3approx_set_adapt_kickrank(c3a, kickrank)

        c3.free_dd(self.dim, start_fibers)
        return c3a, onedopts, low_opts, optnodes

    def build_approximation(self, f, fargs, init_rank, verbose, adapt, maxrank=10, cross_tol=1e-8,
                            round_tol=1e-8, kickrank=5, maxiter=5):
        """ Build an adaptive approximation of *f* with C3 cross approximation

        *f* and *fargs* are registered through ``pcb.assign``; *adapt* is
        forwarded to ``c3approx_do_cross``.  Any previously held function
        train is freed and replaced by the result.  The temporary C3
        structures are released even if the cross step raises.
        """
        fobj = pcb.alloc_cobj()
        pcb.assign(fobj, self.dim, f, fargs)
        fw = c3.fwrap_create(self.dim, "python")
        try:
            c3.fwrap_set_pyfunc(fw, fobj)
            c3a, onedopts, low_opts, optnodes = self._assemble_cross_args(
                verbose, init_rank, maxrank=maxrank, cross_tol=cross_tol,
                round_tol=round_tol, kickrank=kickrank, maxiter=maxiter)
            try:
                new_ft = c3.c3approx_do_cross(c3a, fw, adapt)
            finally:
                self._free_approx_params(c3a, onedopts, low_opts, optnodes)
        finally:
            c3.fwrap_destroy(fw)

        if self.ft is not None:
            c3.function_train_free(self.ft)
        self.ft = new_ft

    def build_data_model(self, ndata, xdata, ydata, alg="AIO", obj="LS", verbose=0,
                         opt_type="BFGS", opt_gtol=1e-10, opt_relftol=1e-10,
                         opt_absxtol=1e-30, opt_maxiter=2000, opt_sgd_learn_rate=1e-3,
                         adaptrank=0, roundtol=1e-5, maxrank=10, kickrank=2,
                         kristoffel=False, regweight=1e-7, cvnparam=None,
                         cvregweight=None, kfold=5, cvverbose=0, als_max_sweep=20,
                         cvrank=None, norm_ydata=False, store_opt_info=False):
        """ Fit the function train to data by regression.

        Any approximation already held in ``self.ft`` is freed and replaced.
        The final ranks may differ from ``self.ranks`` (e.g. when
        *adaptrank* is enabled or ranks are cross-validated).

        Parameters
        ----------
        ndata : int
            Number of samples (not used by the implementation).
        xdata : numpy.ndarray
            Inputs, one sample per row (``ndata x dim``); flattened in C
            order before being passed to C3.
        ydata : numpy.ndarray
            1-D array of outputs.
        alg : {"AIO", "ALS"}
        obj : {"LS", "LS_SPARSECORE"}
        opt_type : {"BFGS", "SGD"}
        cvnparam, cvregweight, cvrank : sequences, optional
            Candidate values for a cross-validation grid search.  Supported
            combinations are: nparam alone, regweight alone,
            nparam + regweight, and nparam + rank.  Other combinations skip
            cross-validation.
        norm_ydata : bool
            If True, ydata is mapped affinely onto [0, 1] for fitting and the
            result is mapped back afterwards.
        store_opt_info : bool
            If True, record and return the per-iteration objective values.

        Returns
        -------
        numpy.ndarray or None
            Stored objective values when *store_opt_info* is True.

        Raises
        ------
        AttributeError
            For an unknown *opt_type* or *alg*/*obj* combination.
        """
        assert isinstance(xdata, np.ndarray)
        assert ydata.ndim == 1

        optimizer = self._create_optimizer(opt_type, xdata, verbose, opt_gtol,
                                           opt_relftol, opt_absxtol, opt_maxiter,
                                           opt_sgd_learn_rate, store_opt_info)
        c3a, onedopts, low_opts, opt_opts = self._build_approx_params(c3.REGRESS)
        reg = None
        results = None
        try:
            multiopts = c3.c3approx_get_approx_args(c3a)
            reg = c3.ft_regress_alloc(self.dim, multiopts, self.ranks)
            self._configure_regressor(reg, alg, obj, verbose, regweight,
                                      als_max_sweep, opt_relftol, adaptrank,
                                      roundtol, maxrank, kickrank, kfold,
                                      kristoffel)

            # Release the previous approximation and clear the handle so a
            # later failure cannot cause a double free in close().
            if self.ft is not None:
                c3.function_train_free(self.ft)
                self.ft = None

            yuse = ydata
            if norm_ydata is True:
                vmin = np.min(ydata)
                vmax = np.max(ydata)
                vdiff = vmax - vmin
                assert vdiff > 1e-14
                yuse = ydata / vdiff - vmin / vdiff

            cvgrid = self._create_cv_grid(cvnparam, cvregweight, cvrank)
            if cvgrid is not None:
                c3.cv_opt_grid_set_verbose(cvgrid, cvverbose)
                cv = c3.cross_validate_init(self.dim, xdata.flatten(order='C'),
                                            yuse, kfold, cvverbose)
                c3.cross_validate_grid_opt(cv, cvgrid, reg, optimizer)
                c3.cv_opt_grid_free(cvgrid)
                c3.cross_validate_free(cv)

            self.ft = c3.ft_regress_run(reg, optimizer, xdata.flatten(order='C'), yuse)

            if norm_ydata is True:
                # Undo the normalization: f <- vdiff * f + vmin.
                ft_use = self.scale_and_shift(vdiff, vmin, c3_pointer=True)
                c3.function_train_free(self.ft)
                self.ft = ft_use

            if store_opt_info is True:
                results = self._collect_opt_history(alg, reg, optimizer)
        finally:
            if reg is not None:
                c3.ft_regress_free(reg)
            c3.c3opt_free(optimizer)
            self._free_approx_params(c3a, onedopts, low_opts, opt_opts)

        return results

    @staticmethod
    def _create_optimizer(opt_type, xdata, verbose, opt_gtol, opt_relftol,
                          opt_absxtol, opt_maxiter, opt_sgd_learn_rate,
                          store_opt_info):
        """ Allocate and configure the C3 optimizer used by build_data_model. """
        if opt_type == "BFGS":
            optimizer = c3.c3opt_create(c3.BFGS)
            c3.c3opt_set_absxtol(optimizer, opt_absxtol)
            c3.c3opt_ls_set_maxiter(optimizer, 300)
        elif opt_type == "SGD":
            optimizer = c3.c3opt_create(c3.SGD)
            c3.c3opt_set_sgd_nsamples(optimizer, xdata.shape[0])
            c3.c3opt_set_sgd_learn_rate(optimizer, opt_sgd_learn_rate)
            c3.c3opt_set_absxtol(optimizer, opt_absxtol)
        else:
            raise AttributeError('Optimizer:  ' + opt_type + " is unknown")

        if store_opt_info is True:
            c3.c3opt_set_storage_options(optimizer, 1, 0, 0)

        if verbose > 1:
            c3.c3opt_set_verbose(optimizer, 1)

        c3.c3opt_set_gtol(optimizer, opt_gtol)
        c3.c3opt_set_relftol(optimizer, opt_relftol)
        c3.c3opt_set_maxiter(optimizer, opt_maxiter)
        return optimizer

    @staticmethod
    def _configure_regressor(reg, alg, obj, verbose, regweight, als_max_sweep,
                             opt_relftol, adaptrank, roundtol, maxrank,
                             kickrank, kfold, kristoffel):
        """ Apply algorithm/objective, adaptation and verbosity settings to *reg*. """
        if alg not in ("AIO", "ALS") or obj not in ("LS", "LS_SPARSECORE"):
            raise AttributeError('Option combination of algorithm and objective not implemented '
                                 + alg + obj)
        c3_alg = c3.AIO if alg == "AIO" else c3.ALS
        c3_obj = c3.FTLS if obj == "LS" else c3.FTLS_SPARSEL2
        c3.ft_regress_set_alg_and_obj(reg, c3_alg, c3_obj)
        if obj == "LS_SPARSECORE":
            c3.ft_regress_set_regularization_weight(reg, regweight)
        if alg == "ALS":
            c3.ft_regress_set_max_als_sweep(reg, als_max_sweep)
            c3.ft_regress_set_als_conv_tol(reg, opt_relftol)

        if adaptrank != 0:
            c3.ft_regress_set_adapt(reg, adaptrank)
            c3.ft_regress_set_roundtol(reg, roundtol)
            c3.ft_regress_set_maxrank(reg, maxrank)
            c3.ft_regress_set_kickrank(reg, kickrank)
            c3.ft_regress_set_kfold(reg, kfold)

        c3.ft_regress_set_verbose(reg, verbose)

        if kristoffel is True:
            c3.ft_regress_set_kristoffel(reg, 1)

    @staticmethod
    def _create_cv_grid(cvnparam, cvregweight, cvrank):
        """ Build a cross-validation grid, or return None if no supported combination is given. """
        have_nparam = cvnparam is not None
        have_regweight = cvregweight is not None
        have_rank = cvrank is not None

        if have_nparam and not have_regweight and not have_rank:
            cvgrid = c3.cv_opt_grid_init(1)
            c3.cv_opt_grid_add_param(cvgrid, "num_param", len(cvnparam), list(cvnparam))
        elif not have_nparam and have_regweight and not have_rank:
            cvgrid = c3.cv_opt_grid_init(1)
            c3.cv_opt_grid_add_param(cvgrid, "reg_weight", len(cvregweight), list(cvregweight))
        elif have_nparam and have_regweight and not have_rank:
            cvgrid = c3.cv_opt_grid_init(2)
            c3.cv_opt_grid_add_param(cvgrid, "num_param", len(cvnparam), list(cvnparam))
            c3.cv_opt_grid_add_param(cvgrid, "reg_weight", len(cvregweight), list(cvregweight))
        elif have_nparam and have_rank:
            cvgrid = c3.cv_opt_grid_init(2)
            c3.cv_opt_grid_add_param(cvgrid, "rank", len(cvrank), list(cvrank))
            c3.cv_opt_grid_add_param(cvgrid, "num_param", len(cvnparam), list(cvnparam))
        else:
            cvgrid = None
        return cvgrid

    @staticmethod
    def _collect_opt_history(alg, reg, optimizer):
        """ Read the stored per-iteration objective values into a numpy array. """
        if alg == "ALS":
            nepoch = c3.ft_regress_get_nepochs(reg)
            results = np.zeros((nepoch))
            for ii in range(nepoch):
                results[ii] = c3.ft_regress_get_stored_fvals(reg, ii)
        else:
            nepoch = c3.c3opt_get_niters(optimizer)
            results = np.zeros((nepoch))
            for ii in range(nepoch):
                results[ii] = c3.c3opt_get_stored_function(optimizer, ii)
        return results

    def set_ranks(self, ranks):
        """ Set the ranks used to initialize regression.

        *ranks* must have length ``dim + 1``.  The first and last entries
        are forced to 1, with a printed warning, if they are not.

        Raises
        ------
        AttributeError
            If ``len(ranks) != dim + 1``.
        """
        if len(ranks) != self.dim+1:
            raise AttributeError("Ranks must be a list of size dim+1, "
                                 "with the first and last elements = 1")

        self.ranks = list(copy.deepcopy(ranks))

        if ranks[0] != 1:
            print("Warning: rank[0] is not specified to 1, overwriting ")
            self.ranks[0] = 1

        if ranks[self.dim] != 1:
            print("Warning: rank[dim] is not specified to 1, overwriting ")
            self.ranks[self.dim] = 1

    def eval(self, pt):
        """ Evaluate a FunctionTrain

        *pt* is either a single point (1-D array), which gives a scalar, or a
        2-D array with one point per row, which gives a 1-D array of values.
        """
        assert isinstance(pt, np.ndarray)
        if pt.ndim == 1:
            return c3.function_train_eval(self.ft, pt)
        assert pt.shape[1] == self.dim
        out = np.zeros((pt.shape[0]))
        for ii, p in enumerate(pt):
            out[ii] = c3.function_train_eval(self.ft, p)
        return out

    def grad_eval(self, pt):
        """ Evaluate the gradient of a FunctionTrain at a single point *pt* """
        grad_out = np.zeros((self.dim))
        # Filled in place by the C routine.
        c3.function_train_gradient_eval(self.ft, pt, grad_out)
        return grad_out

    def round(self, eps=1e-14):
        """ Round a FunctionTrain in place with tolerance *eps* """
        c3a, onedopts, low_opts, optopts = self._build_approx_params()
        multiopts = c3.c3approx_get_approx_args(c3a)
        # function_train_round returns a new function train; swap it in.
        ftc = c3.function_train_round(self.ft, eps, multiopts)
        c3.function_train_free(self.ft)
        self.ft = ftc
        self._free_approx_params(c3a, onedopts, low_opts, optopts)

    def __add__(self, other, eps=0):
        """ Add two function trains (result rounded with *eps*) """
        out = FunctionTrain(self.dim)
        out.ft = c3.function_train_sum(self.ft, other.ft)
        out.opts = copy.deepcopy(self.opts)
        out.round(eps)
        return out

    def __sub__(self, other, eps=0):
        """ Subtract two function trains: self - other (rounded with *eps*) """
        neg_other = c3.function_train_copy(other.ft)
        try:
            c3.function_train_scale(neg_other, -1.0)
            out_ft = FunctionTrain(self.dim)
            out_ft.opts = copy.deepcopy(self.opts)
            out_ft.ft = c3.function_train_sum(self.ft, neg_other)
            out_ft.round(eps)
        finally:
            c3.function_train_free(neg_other)
        return out_ft

    def __mul__(self, other, eps=0):
        """ Multiply two function trains (result rounded with *eps*) """
        out = FunctionTrain(self.dim)
        out.ft = c3.function_train_product(self.ft, other.ft)
        out.opts = copy.deepcopy(self.opts)
        out.round(eps)
        return out

    def integrate(self):
        """ Integral of the function train (``c3.function_train_integrate``) """
        return c3.function_train_integrate(self.ft)

    def inner(self, other):
        """ Inner product with *other* (``c3.function_train_inner``) """
        return c3.function_train_inner(self.ft, other.ft)

    def scale(self, a, eps=0):
        """ f <- a*f (in place); returns self """
        c3.function_train_scale(self.ft, a)
        return self

    def scale_and_shift(self, scale, shift, eps=0, c3_pointer=False):
        """ Return ``scale * f + shift``, rounded with tolerance *eps*.

        Parameters
        ----------
        scale, shift : float
        eps : float
            Rounding tolerance passed to ``function_train_round``.
        c3_pointer : bool
            If True, return the raw C3 handle (the caller owns it and must
            free it); otherwise return a new :class:`FunctionTrain` that
            shares this object's options.
        """
        c3a, onedopts, low_opts, optnodes = self._build_approx_params()
        try:
            multiopts = c3.c3approx_get_approx_args(c3a)

            ft1 = c3.function_train_copy(self.ft)
            c3.function_train_scale(ft1, scale)
            ft2 = c3.function_train_constant(shift, multiopts)
            summed = c3.function_train_sum(ft1, ft2)
            c3.function_train_free(ft1)
            c3.function_train_free(ft2)

            # function_train_round returns a new handle; keep it and free
            # the unrounded sum.
            rounded = c3.function_train_round(summed, eps, multiopts)
            c3.function_train_free(summed)
        finally:
            self._free_approx_params(c3a, onedopts, low_opts, optnodes)

        if c3_pointer is True:
            return rounded

        ft_out = FunctionTrain(self.dim)
        ft_out.opts = copy.deepcopy(self.opts)
        ft_out.ft = rounded
        return ft_out

    def get_ranks(self):
        """ Ranks r_0 .. r_dim of the held function train, as a numpy array """
        dim = c3.function_train_get_dim(self.ft)
        return np.array([c3.function_train_get_rank(self.ft, ii) for ii in range(dim+1)])

    def norm2(self):
        """ Norm of the function train (``c3.function_train_norm2``) """
        return c3.function_train_norm2(self.ft)

    def expectation(self):
        """ Weighted integral of the function train (``function_train_integrate_weighted``) """
        return c3.function_train_integrate_weighted(self.ft)

    def variance(self):
        """ Weighted second moment minus the squared expectation """
        mean_val = self.expectation()
        second_moment = c3.function_train_inner_weighted(self.ft, self.ft)
        return second_moment - mean_val*mean_val

    def get_uni_func(self, dim, row, col):
        """ Wrap (a copy of) ``c3.function_train_get_gfuni(self.ft, dim, row, col)`` """
        return GenericFunction(c3.function_train_get_gfuni(self.ft, dim, row, col))

    def laplace(self, eps=0.0):
        """ Return the Laplacian (``c3.exact_laplace``), rounded if *eps* > 0 """
        ft_out = FunctionTrain(self.dim)

        c3a, onedopts, low_opts, opt_opts = self._build_approx_params(c3.REGRESS)
        multiopts = c3.c3approx_get_approx_args(c3a)

        ft_out.ft = c3.exact_laplace(self.ft, multiopts)
        ft_out.opts = copy.deepcopy(self.opts)

        self._free_approx_params(c3a, onedopts, low_opts, opt_opts)

        if eps > 0.0:
            ft_out.round(eps=eps)

        return ft_out

    def __del__(self):
        self.close()

    def close(self):
        """ Free the held C3 function train (idempotent) """
        if self.ft is not None:
            c3.function_train_free(self.ft)
            self.ft = None
636
637
638
639
class GenericFunction(object):
    """ Univariate Functions

    Owns a private copy of a C3 generic function (``self.gf``), released by
    :meth:`close` (also called from ``__del__``).
    """

    # Class-level default so close()/__del__ are safe if __init__ fails.
    gf = None

    def __init__(self, gf):
        """ Wrap a copy of the C3 generic function *gf*.

        The argument itself is copied, not owned, so the caller keeps
        responsibility for it.
        """
        self.gf = c3.generic_function_copy(gf)

    def eval(self, x):
        """ Evaluate the function at *x*.

        A list or tuple (including subclasses) gives a list of values.  A
        1-D numpy array (including subclasses) gives a 1-D numpy array.
        Anything else is evaluated as a single point.
        """
        if isinstance(x, (list, tuple)):
            return [c3.generic_function_1d_eval(self.gf, xx) for xx in x]
        if isinstance(x, np.ndarray):
            assert x.ndim == 1, "Only 1d arrays handled"
            return np.array([c3.generic_function_1d_eval(self.gf, xx) for xx in x])
        return c3.generic_function_1d_eval(self.gf, x)

    def __del__(self):
        self.close()

    def close(self):
        """ Free the held C3 generic function (idempotent) """
        if self.gf is not None:
            c3.generic_function_free(self.gf)
            self.gf = None
663
class SobolIndices(object):
    """ Sobol Sensitivity Indices

    Holds a C3 Sobol-sensitivity structure (``self.si``) computed from a
    FunctionTrain; the structure is released by :meth:`close`, which is
    also called from ``__del__``.
    """

    # Class-level default so close()/__del__ are safe if __init__ fails.
    si = None

    def __init__(self, ft, order=None):
        """ Compute the sensitivity structure for FunctionTrain *ft*.

        *order* is forwarded to ``c3_sobol_sensitivity_calculate`` and
        defaults to ``ft.dim``.
        """
        max_order = ft.dim if order is None else order
        self.si = c3.c3_sobol_sensitivity_calculate(ft.ft, max_order)

    def get_total_sensitivity(self, index):
        """ Total sensitivity of variable *index* """
        return c3.c3_sobol_sensitivity_get_total(self.si, index)

    def get_main_sensitivity(self, index):
        """ Main (first-order) sensitivity of variable *index* """
        return c3.c3_sobol_sensitivity_get_main(self.si, index)

    def get_variance(self):
        """ Variance stored in the sensitivity structure """
        return c3.c3_sobol_sensitivity_get_variance(self.si)

    def get_interaction(self, variables):
        """ Interaction index for the given variable indices.

        *variables* must be strictly increasing; this is checked with an
        ``assert``.
        """
        assert all(variables[k - 1] < variables[k] for k in range(1, len(variables))), \
            "Sobol index variables must be ordered with v[i] < V[j] for i < j"
        return c3.c3_sobol_sensitivity_get_interaction(self.si, variables)

    def __del__(self):
        self.close()

    def close(self):
        """ Free the held sensitivity structure (idempotent) """
        if self.si is None:
            return
        c3.c3_sobol_sensitivity_free(self.si)
        self.si = None
698