1"""
Classes for easy interpolation of trajectories and curves.
Only NumPy is required (SciPy is not needed).
4"""
5
6import numpy as np
7
8
class Interpolator:
    """Poor man's linear interpolator; doesn't require SciPy.

    Thin wrapper around ``np.interp`` over the sample points ``(tt, ss)``.

    Parameters
    ----------
    tt, ss : sequences of numbers, optional
        Sample abscissae (e.g. times) and the values at those points.
        They need not be given in increasing order of ``tt``: they are
        sorted here because ``np.interp`` requires increasing abscissae.
    ttss : sequence of (t, s) pairs, optional
        Alternative to ``tt``/``ss``; overrides them when given.
    left, right : optional
        Values returned for ``t`` below / above the sampled range.
        ``None`` lets ``np.interp`` use the first / last sampled value.

    Raises
    ------
    ValueError
        If ``tt`` and ``ss`` do not have the same shape, or are empty.
    """

    def __init__(self, tt=None, ss=None, ttss=None, left=None, right=None):

        if ttss is not None:
            tt, ss = zip(*ttss)

        tt = 1.0 * np.array(tt)
        ss = 1.0 * np.array(ss)
        if tt.shape != ss.shape:
            raise ValueError(
                "tt and ss must have the same shape, got %s and %s"
                % (tt.shape, ss.shape))

        # np.interp does not check that its sample points are increasing
        # and returns meaningless values otherwise; a stable sort keeps
        # the original relative order of samples sharing the same time.
        order = np.argsort(tt, kind="stable")
        self.tt = tt[order]
        self.ss = ss[order]
        self.left = left
        self.right = right
        # .min()/.max() raise ValueError on empty input, like min()/max().
        self.tmin, self.tmax = self.tt.min(), self.tt.max()

    def __call__(self, t):
        """Return the linearly interpolated value(s) at ``t``."""
        return np.interp(t, self.tt, self.ss, self.left, self.right)
25
class Trajectory:
    """A 2D trajectory: positions ``(x, y)`` sampled at times ``t``.

    Calling an instance with a time ``t`` returns the linearly
    interpolated position ``np.array([x(t), y(t)])``.

    Parameters
    ----------
    tt : sequence of numbers
        Sample times, in seconds (files store them in milliseconds).
    xx, yy : sequences of numbers
        x and y coordinates at each sample time.
    """

    def __init__(self, tt, xx, yy):

        self.tt = 1.0*np.array(tt)
        self.xx = np.array(xx)
        self.yy = np.array(yy)
        self.update_interpolators()

    def __call__(self, t):
        """Return the interpolated position ``np.array([x(t), y(t)])``."""
        return np.array([self.xi(t), self.yi(t)])

    def addx(self, x):
        """Return a new Trajectory shifted by ``x`` along the x axis."""
        return Trajectory(self.tt, self.xx + x, self.yy)

    def addy(self, y):
        """Return a new Trajectory shifted by ``y`` along the y axis."""
        return Trajectory(self.tt, self.xx, self.yy + y)

    def update_interpolators(self):
        """(Re)build the per-axis interpolators from tt, xx and yy.

        Call this after modifying ``tt``, ``xx`` or ``yy`` in place.
        """
        self.xi = Interpolator(self.tt, self.xx)
        self.yi = Interpolator(self.tt, self.yy)

    def txy(self, tms=False):
        """Return an iterator of ``(t, x, y)`` tuples.

        If ``tms`` is true, times are in milliseconds instead of seconds.
        """
        return zip((1000 if tms else 1)*self.tt, self.xx, self.yy)

    def to_file(self, filename):
        """Write the trajectory as tab-separated ``t(ms) x y`` rows.

        Values are written with ``%d``, so fractional parts are dropped.
        """
        # txy() returns a lazy zip iterator; np.array() on an iterator
        # yields a 0-d object array that savetxt rejects, so materialize it.
        np.savetxt(filename, np.array(list(self.txy(tms=True))),
                   fmt="%d", delimiter='\t')

    @staticmethod
    def from_file(filename):
        """Load a trajectory written by ``to_file`` (times in ms)."""
        # ndmin=2 keeps a single-row file at shape (1, 3) so that the
        # unpacking below still yields 1-D arrays.
        arr = np.loadtxt(filename, delimiter='\t', ndmin=2)
        tt, xx, yy = arr.T
        return Trajectory(1.0*tt/1000, xx, yy)

    @staticmethod
    def save_list(trajs, filename):
        """Save several trajectories side by side in one file.

        Each trajectory contributes three tab-separated columns
        ``t(ms) x y``; all trajectories must have the same number of
        samples (they are stacked column-wise). Values are written with
        ``%d``, and a commented header line names the columns.
        """
        N = len(trajs)
        arr = np.hstack([np.array(list(t.txy(tms=True))) for t in trajs])
        np.savetxt(filename, arr, fmt="%d", delimiter='\t',
                   header="\t".join(N*['t(ms)', 'x', 'y']))

    @staticmethod
    def load_list(filename):
        """Load the list of trajectories written by ``save_list``."""
        # ndmin=2 so a single-row file still transposes to (3*N, 1).
        arr = np.loadtxt(filename, delimiter='\t', ndmin=2).T
        n_columns = arr.shape[0]
        # np.split expects an integer number of sections (3 columns each).
        return [Trajectory(tt=1.0*a[0]/1000, xx=a[1], yy=a[2])
                for a in np.split(arr, n_columns // 3)]
74