1"""Basic operations that are needed repeatedly throughout Skyfield."""
2
3from numpy import (
4    arcsin, arctan2, array, cos, einsum, finfo, float64,
5    full_like, load, rollaxis, sin, sqrt,
6)
7from pkgutil import get_data
8from skyfield.constants import tau
9
# Smallest positive normal float64.  It is added to denominators that may
# be exactly zero (such as the length of a zero vector) so the division
# does not divide by zero; for any non-negligible denominator the addition
# is too small to change the result.
_AVOID_DIVIDE_BY_ZERO = finfo(float64).tiny
11
class A(object):
    """Shorthand for building NumPy arrays: ``A[1, 2, 3]``.

    Subscripting the singleton instance passes the subscript straight to
    ``numpy.array``, so ``A[1, 2, 3]`` is ``array((1, 2, 3))``.
    """
    def __getitem__(self, items):
        return array(items)

A = A()
16
def dots(v, u):
    """Return the dot product of `v` and `u`, summed over their first axis.

    Both arguments may be single vectors of shape ``(3,)``, or arrays of
    shape ``(3, N)`` holding N corresponding x, y, and z coordinates, in
    which case N dot products are returned.

    """
    products = v * u
    return products.sum(axis=0)
26
def T(M):
    """Return `M` with its first two axes exchanged (a transpose for 2-D)."""
    swapped = rollaxis(M, 1, 0)
    return swapped
30
def mxv(M, v):
    """Multiply matrix `M` by vector `v`.

    Any trailing axes beyond the matrix/vector dimensions are carried
    through, so a stack of matrices can be applied to a stack of vectors.
    """
    subscripts = 'ij...,j...->i...'
    return einsum(subscripts, M, v)
34
def mxm(M1, M2):
    """Multiply matrix `M1` by matrix `M2` (trailing axes carried through)."""
    subscripts = 'ij...,jk...->ik...'
    return einsum(subscripts, M1, M2)
38
def mxmxm(M1, M2, M3):
    """Return the product ``M1 @ M2 @ M3`` (trailing axes carried through)."""
    subscripts = 'ij...,jk...,kl...->il...'
    return einsum(subscripts, M1, M2, M3)
42
# Keep the older underscore-prefixed names as aliases, in case anyone
# imported these functions under their old names.
_T, _mxv, _mxm, _mxmxm = T, mxv, mxm, mxmxm
44
def length_of(xyz):
    """Return the Euclidean length of the vector |xyz|.

    |xyz| may hold three scalars, or be a ``(3, N)`` array of x, y, and z
    coordinate series, in which case N lengths are returned.

    """
    squares = xyz * xyz
    return sqrt(squares.sum(axis=0))
53
def angle_between(u, v):
    """Return the angle in radians between vectors `u` and `v`.

    Each argument may have shape ``(3,)`` or be a ``(3, N)`` array of
    corresponding x, y, and z coordinates.  The result lies between 0
    and tau/2.

    Rather than the arccosine of a normalized dot product, this uses the
    numerically stable formula from Section 12 of:
    https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf

    """
    u_scaled = u * length_of(v)
    v_scaled = v * length_of(u)
    difference = length_of(u_scaled - v_scaled)
    total = length_of(u_scaled + v_scaled)
    return 2.0 * arctan2(difference, total)
69
def to_spherical(xyz):
    """Return the spherical coordinates ``(r, theta, phi)`` of |xyz|.

    ``r`` - length of the vector
    ``theta`` - angle above (+) or below (-) the xy-plane
    ``phi`` - angle around the z-axis, in the range [0, tau)

    ``theta`` is an elevation angle measured from the xy-plane rather
    than a polar angle measured from the z-axis, matching the convention
    used for latitude and declination.

    """
    r = length_of(xyz)
    x, y, z = xyz
    padded_r = r + _AVOID_DIVIDE_BY_ZERO
    theta = arcsin(z / padded_r)
    phi = arctan2(y, x) % tau
    return r, theta, phi
87
def _to_spherical_and_rates(r, v):
    """Convert Cartesian position and velocity vectors to angles and rates.

    ``r`` is a position ``(x, y, z)`` and ``v`` the matching velocity
    ``(xdot, ydot, zdot)``; each component may be a scalar or an array.

    Returns a 6-tuple ``(length, lat, lon, range_rate, lat_rate,
    lon_rate)``:

    ``length`` - length of ``r``
    ``lat`` - radian angle above (+) or below (-) the xy-plane
    ``lon`` - radian angle around the z-axis, in the range [0, tau)
    ``range_rate`` - rate of change of ``length``, i.e. ``r . v / |r|``
    ``lat_rate``, ``lon_rate`` - rates of change of ``lat`` and ``lon``,
    in radians per time unit of ``v`` (assuming ``v`` uses the same
    length unit as ``r``)

    """
    x, y, z = r
    xdot, ydot, zdot = v

    length = length_of(r)
    # Padding the denominator with the tiny constant keeps a zero-length
    # vector from dividing by zero; for any non-negligible length the
    # addition leaves the result unchanged.
    safe_length = length + _AVOID_DIVIDE_BY_ZERO
    lat = arcsin(z / safe_length)
    lon = arctan2(y, x) % tau
    # Reuse the already-computed length rather than calling length_of()
    # a second time, guarded the same way as the latitude division.
    range_rate = dots(r, v) / safe_length

    x2 = x * x
    y2 = y * y
    # Squared distance from the z-axis, padded so positions on the z-axis
    # do not divide by zero in the rate formulas below.
    x2_plus_y2 = x2 + y2 + _AVOID_DIVIDE_BY_ZERO
    # d(lat)/dt = (rxy**2 * zdot - z * (x*xdot + y*ydot)) / (|r|**2 * rxy)
    lat_rate = (x2_plus_y2 * zdot - z * (x * xdot + y * ydot)) / (
        (x2_plus_y2 + z*z) * sqrt(x2_plus_y2))
    # d(lon)/dt = (x*ydot - y*xdot) / rxy**2
    lon_rate = (x * ydot - xdot * y) / x2_plus_y2

    return length, lat, lon, range_rate, lat_rate, lon_rate
106
def from_spherical(r, theta, phi):
    """Return the Cartesian vector |xyz| for spherical ``(r, theta, phi)``.

    ``r`` - length of the vector
    ``theta`` - radian angle above (+) or below (-) the xy-plane
    ``phi`` - radian angle around the z-axis

    ``theta`` is an elevation angle measured from the xy-plane rather
    than a polar angle measured from the z-axis, matching the convention
    used for latitude and declination.

    """
    horizontal = r * cos(theta)
    x = horizontal * cos(phi)
    y = horizontal * sin(phi)
    z = r * sin(theta)
    return array((x, y, z))
121
# Backward-compatible aliases for users who imported these functions under
# their old names.  The "polar" names were a misnomer: the coordinates these
# functions handle, (r, theta, phi), are spherical.
to_polar = to_spherical
from_polar = from_spherical
126
def rot_x(theta):
    """Return the matrix for a rotation by `theta` radians about the x-axis.

    The matrix is ``((1, 0, 0), (0, c, -s), (0, s, c))`` with ``c`` and
    ``s`` the cosine and sine of `theta`.  If `theta` is an array, every
    entry has `theta`'s shape, so the result has shape ``(3, 3) +
    theta.shape``.
    """
    zeros = theta * 0.0
    ones = zeros + 1.0
    cos_t = cos(theta)
    sin_t = sin(theta)
    first = (ones, zeros, zeros)
    second = (zeros, cos_t, -sin_t)
    third = (zeros, sin_t, cos_t)
    return array((first, second, third))
133
def rot_y(theta):
    """Return the matrix for a rotation by `theta` radians about the y-axis.

    The matrix is ``((c, 0, s), (0, 1, 0), (-s, 0, c))`` with ``c`` and
    ``s`` the cosine and sine of `theta`.  If `theta` is an array, every
    entry has `theta`'s shape, so the result has shape ``(3, 3) +
    theta.shape``.
    """
    zeros = theta * 0.0
    ones = zeros + 1.0
    cos_t = cos(theta)
    sin_t = sin(theta)
    first = (cos_t, zeros, sin_t)
    second = (zeros, ones, zeros)
    third = (-sin_t, zeros, cos_t)
    return array((first, second, third))
140
def rot_z(theta):
    """Return the matrix for a rotation by `theta` radians about the z-axis.

    The matrix is ``((c, -s, 0), (s, c, 0), (0, 0, 1))`` with ``c`` and
    ``s`` the cosine and sine of `theta`.  If `theta` is an array, every
    entry has `theta`'s shape, so the result has shape ``(3, 3) +
    theta.shape``.
    """
    zeros = theta * 0.0
    ones = zeros + 1.0
    cos_t = cos(theta)
    sin_t = sin(theta)
    first = (cos_t, -sin_t, zeros)
    second = (sin_t, cos_t, zeros)
    third = (zeros, zeros, ones)
    return array((first, second, third))
147
def angular_velocity_matrix(angular_velocity_vector):
    """Return the skew-symmetric matrix of an angular velocity vector.

    For ``w = (x, y, z)`` the result is ``((0, -z, y), (z, 0, -x),
    (-y, x, 0))``; multiplying it by a vector ``v`` gives the cross
    product of ``w`` with ``v``.  The components may also be arrays (of
    matching shape), in which case every matrix entry has that shape.
    """
    wx, wy, wz = angular_velocity_vector
    zero = wx * 0.0
    row_x = (zero, -wz, wy)
    row_y = (wz, zero, -wx)
    row_z = (-wy, wx, zero)
    return array((row_x, row_y, row_z))
152
153def _to_array(value):
154    """Convert plain Python sequences into NumPy arrays.
155
156    This helps Skyfield endpoints convert caller-provided tuples and
157    lists into NumPy arrays.  If the ``value`` is not a sequence, then
158    it is coerced to a Numpy float object, but not an actual array.
159
160    """
161    if hasattr(value, 'shape'):
162        return value
163    elif hasattr(value, '__len__'):
164        return array(value)
165    else:
166        return float64(value)
167
168def _reconcile(a, b):
169    """Coerce two NumPy generics-or-arrays to the same number of dimensions."""
170    an = getattr(a, 'ndim', 0)
171    bn = getattr(b, 'ndim', 0)
172    difference = bn - an
173    if difference > 0:
174        if an:
175            a.shape += (1,) * difference
176        else:
177            a = full_like(b, a)
178    elif difference < 0:
179        if bn:
180            b.shape += (1,) * -difference
181        else:
182            b = full_like(a, b)
183    return a, b
184
try:
    from io import BytesIO
except ImportError:
    # Only Pythons that lack the ``io`` module reach this fallback; its
    # StringIO buffer accepts byte strings there.  Catching ImportError
    # specifically, rather than using a bare except, keeps unrelated errors
    # (including KeyboardInterrupt) from being silently swallowed.
    from StringIO import StringIO as BytesIO
189
def load_bundled_npy(filename):
    """Load a binary NumPy array file that is bundled with Skyfield.

    ``filename`` is read from the ``data/`` directory of the ``skyfield``
    package with ``pkgutil.get_data()``, and its bytes are parsed with
    ``numpy.load()``.

    Raises ``IOError`` if the package's loader cannot supply the data.
    """
    data = get_data('skyfield', 'data/{0}'.format(filename))
    if data is None:
        # get_data() returns None, rather than raising, when the package
        # cannot be located or its loader does not support get_data();
        # without this check numpy.load() would fail on an empty buffer
        # with an error that does not mention the missing resource.
        raise IOError('cannot load bundled data file {0!r} from the'
                      ' skyfield package'.format(filename))
    return load(BytesIO(data))
194