1"""
2Plotting (requires matplotlib)
3"""
4
5from colorsys import hsv_to_rgb, hls_to_rgb
6from .libmp import NoConvergence
7from .libmp.backend import xrange
8
class VisualizationMethods(object):
    """Holder for the plotting methods of a context object (``ctx``).

    The module-level functions ``plot``, ``cplot``, ``splot`` and the
    builtin color functions are attached to this class as methods at the
    bottom of the module, so each receives the context as its first
    argument.
    """

    # Exception types that mark a single sample point as unplottable:
    # ``plot`` drops such points (breaking the curve there) and ``cplot``
    # paints them grey instead of aborting the whole plot.
    plot_ignore = (
        ValueError,
        ArithmeticError,
        ZeroDivisionError,
        NoConvergence,
    )
11
def _plot_segments(ctx, func, xs, singularities):
    """Evaluate *func* at the abscissas *xs* and split the samples into
    contiguous segments suitable for line plotting.

    Returns a list of segments (possibly including empty ones). Each
    segment is a list of tuples: ``(x, y)`` where *func* returned a value
    with zero imaginary part, or ``(x, re, im)`` where it returned a value
    with nonzero imaginary part; a single segment never mixes the two.

    A sample is dropped, and the current segment closed, when *func*
    raises one of ``ctx.plot_ignore``, returns NaN or a value of magnitude
    above 1e300 (treated as infinite), or when an entry of *singularities*
    lies between this abscissa and the previous one (inclusive). The last
    two cases are signalled by raising ValueError, so they are only caught
    if ``ctx.plot_ignore`` contains ValueError (the default does).
    """
    segments = []
    segment = []
    in_complex = False
    for i, xi in enumerate(xs):
        try:
            if i != 0:
                prev = xs[i-1]
                for sing in singularities:
                    if prev <= sing and xi >= sing:
                        # Drop this sample so no line is drawn across the
                        # singularity.
                        raise ValueError
            v = func(xi)
            if ctx.isnan(v) or abs(v) > 1e300:
                raise ValueError
            if hasattr(v, "imag") and v.imag:
                re_part = float(v.real)
                im_part = float(v.imag)
                if not in_complex:
                    # Switching from real to complex: start a new segment.
                    in_complex = True
                    segments.append(segment)
                    segment = []
                segment.append((float(xi), re_part, im_part))
            else:
                if in_complex:
                    # Switching from complex to real: start a new segment.
                    in_complex = False
                    segments.append(segment)
                    segment = []
                if hasattr(v, "real"):
                    v = v.real
                segment.append((float(xi), v))
        except ctx.plot_ignore:
            if segment:
                segments.append(segment)
            segment = []
    if segment:
        segments.append(segment)
    return segments


def plot(ctx, f, xlim=(-5, 5), ylim=None, points=200, file=None, dpi=None,
    singularities=(), axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception or
    returns an infinite value are removed from the graph.
    Singularities can also be excluded explicitly
    as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])   # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])  # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Other parameters:

    * *xlim* -- pair ``(a, b)``; each function is sampled on
      ``ctx.arange(a, b, (b-a)/points)``.
    * *ylim* -- optional pair fixing the vertical range.
    * *file*, *dpi* -- if *file* is given, the plot is drawn in a new
      figure (any *axes* passed is ignored) and saved to *file*.
    * *axes* -- existing matplotlib axes to draw into; when given (and
      no *file*), no new figure is created and nothing is shown.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        # Saving always goes through a fresh figure, never the caller's axes.
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    for n, func in enumerate(f):
        c = colors[n % len(colors)]
        xs = ctx.arange(a, b, (b-a)/float(points))
        for segment in _plot_segments(ctx, func, xs, singularities):
            if not segment:
                continue
            x = [s[0] for s in segment]
            y = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex segment: dashed real part, dotted imaginary part.
                z = [s[2] for s in segment]
                axes.plot(x, y, '--'+c, linewidth=3)
                axes.plot(x, z, ':'+c, linewidth=3)
            else:
                axes.plot(x, y, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    if fig:
        # pylab is bound here: fig is only set in the branch importing it.
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
110
def default_color_function(ctx, z):
    """
    Default color scheme for ``cplot``: map the complex number *z* to an
    RGB 3-tuple of floats in the range 0.0-1.0.

    The hue encodes the phase ``arg(z)`` (rotated so that positive reals
    are red). The lightness is ``1 - 1/(1 + |z|**0.3)``, so 0 is black and
    large magnitudes approach white. The saturation is fixed at 0.8.
    Infinities are drawn white and NaNs grey.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Plain float pi (same value as math.pi) so the hue stays a Python
    # float; ``ctx.pi`` is presumably a context number, which would make
    # this per-pixel arithmetic multiprecision for no visible benefit.
    pi = 3.141592653589793
    hue = (float(ctx.arg(z)) + pi) / (2*pi)
    # Shift by half a turn so that arg(z) == 0 maps to hue 0 (red).
    hue = (hue + 0.5) % 1.0
    lightness = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(hue, lightness, 0.8)
121
# Piecewise-linear color table used by phase_color_function. Each entry is
# (w, (r, g, b)) with w = arg(z)/pi; entries are sorted by w and colors are
# interpolated linearly between neighbouring entries.
blue_orange_colors = [
    (-1.0, (0.0, 0.0, 0.0)),    # black: negative real axis (arg near -pi)
    (-0.95, (0.1, 0.2, 0.5)),   # dark blue
    (-0.5, (0.0, 0.5, 1.0)),    # bluish
    (-0.05, (0.4, 0.8, 0.8)),   # cyan-ish
    (0.0, (1.0, 1.0, 1.0)),     # white: positive real axis
    (0.05, (1.0, 0.9, 0.3)),    # yellowish
    (0.5, (0.9, 0.5, 0.0)),     # orange-ish
    (0.95, (0.7, 0.1, 0.0)),    # reddish
    (1.0, (0.0, 0.0, 0.0)),     # black: negative real axis (arg near +pi)
    # Sentinel above the clamped range [-1, 1]: guarantees that w == 1.0
    # still finds an upper neighbour (and is drawn black).
    (2.0, (0.0, 0.0, 0.0)),
]
134
def phase_color_function(ctx, z):
    """
    Color scheme for ``cplot`` that renders only the phase of *z*.

    ``arg(z)/pi``, clamped to [-1, 1], is looked up in the piecewise-linear
    table ``blue_orange_colors``. Positive reals come out white, negative
    reals black, the upper half plane in yellow/orange/red tones and the
    lower half plane in blue/cyan tones. Infinities are white and NaNs
    grey. Returns an RGB 3-tuple of floats.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    pi = 3.1415926535898
    w = max(min(float(ctx.arg(z)) / pi, 1.0), -1.0)
    upper_entries = blue_orange_colors[1:]
    for (lo, lo_rgb), (hi, hi_rgb) in zip(blue_orange_colors, upper_entries):
        if hi > w:
            # Linear interpolation between the two entries bracketing w.
            s = (w - lo) / (hi - lo)
            return tuple(p + (q - p)*s for p, q in zip(lo_rgb, hi_rgb))
149
def cplot(ctx, f, re=(-5, 5), im=(-5, 5), points=2000, color=None,
    verbose=False, file=None, dpi=None, axes=None):
    """
    Plots the given complex-valued function *f* over a rectangular part
    of the complex plane specified by the pairs of intervals *re* and *im*.
    For example::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    By default, the complex argument (phase) is shown as color (hue) and
    the magnitude is shown as brightness. You can also supply a
    custom color function (*color*). This function should take a
    complex number as input and return an RGB 3-tuple containing
    floats in the range 0.0-1.0.

    Alternatively, you can select a builtin color function by passing
    a string as *color*:

      * "default" - default color scheme
      * "phase" - a color scheme that only renders the phase of the function,
         with white for positive reals, black for negative reals, gold in the
         upper half plane, and blue in the lower half plane.

    To obtain a sharp image, the number of points may need to be
    increased to 100,000 or thereabout. Since evaluating the
    function that many times is likely to be slow, the 'verbose'
    option is useful to display progress.

    Points where *f* or the color function raises one of
    ``ctx.plot_ignore`` are drawn grey. If *file* is given, the image is
    drawn in a new figure (any *axes* passed is ignored) and saved to
    *file* with the given *dpi*; if *axes* is given instead, the image is
    drawn into it and nothing is shown.

    .. note :: This function requires matplotlib (pylab).
    """
    # numpy is a hard dependency of matplotlib; pylab.linspace/zeros are
    # re-exports of these same functions.
    import numpy
    if color is None or color == "default":
        color = ctx.default_color_function
    if color == "phase":
        color = ctx.phase_color_function
    if file:
        # Saving always goes through a fresh figure, never the caller's axes.
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    rea, reb = re
    ima, imb = im
    dre = reb - rea
    dim = imb - ima
    # Choose an M x N grid with about *points* pixels and square cells.
    M = int(ctx.sqrt(points*dre/dim)+1)
    N = int(ctx.sqrt(points*dim/dre)+1)
    x = numpy.linspace(rea, reb, M)
    y = numpy.linspace(ima, imb, N)
    # Note: we have to be careful to get the right rotation.
    # Test with these plots:
    #   cplot(lambda z: z if z.real < 0 else 0)
    #   cplot(lambda z: z if z.imag < 0 else 0)
    # Rows are indexed by Im(z) and columns by Re(z); together with
    # origin='lower' below this puts Im(z) increasing upwards.
    w = numpy.zeros((N, M, 3))
    for n, yn in enumerate(y):
        for m, xm in enumerate(x):
            z = ctx.mpc(xm, yn)
            try:
                v = color(f(z))
            except ctx.plot_ignore:
                v = (0.5, 0.5, 0.5)
            w[n, m] = v
        if verbose:
            print(str(n) + ' of ' + str(N))
    rea, reb, ima, imb = [float(_) for _ in [rea, reb, ima, imb]]
    axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig:
        # pylab is bound here: fig is only set in the branch importing it.
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
225
def splot(ctx, f, u=(-5, 5), v=(-5, 5), points=100, keep_aspect=True,
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x = u` and `y = v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    *points* is either one sample count used for both parameters or a
    pair ``(M, N)`` giving the counts along *u* and *v*. With
    *keep_aspect* (the default), every axis range shorter than the longest
    one is widened symmetrically around the data to that length, so the
    surface is not distorted. *wireframe* draws a wireframe instead of a
    shaded surface. If *file* is given, the surface is drawn in a new
    figure (any *axes* passed is ignored) and saved to *file*; if an
    existing 3D *axes* is given instead, the surface is drawn into it and
    nothing is shown.

    .. note :: This function requires matplotlib (pylab) 1.0 or higher,
       including the ``mpl_toolkits.mplot3d`` toolkit.
    """
    # numpy is a hard dependency of matplotlib; pylab.linspace/zeros are
    # re-exports of these same functions.
    import numpy
    if file:
        # Saving always goes through a fresh figure, never the caller's axes.
        axes = None
    fig = None
    if not axes:
        import pylab
        # Importing mplot3d registers the '3d' projection with matplotlib.
        import mpl_toolkits.mplot3d  # noqa: F401
        fig = pylab.figure()
        # Axes3D(fig) no longer adds itself to the figure in recent
        # matplotlib versions (it would produce an empty plot).
        axes = fig.add_subplot(111, projection='3d')
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    us = numpy.linspace(ua, ub, M)
    vs = numpy.linspace(va, vb, N)
    x, y, z = [numpy.zeros((M, N)) for _ in range(3)]
    # Running [min, max] of each coordinate. Start from an empty range so
    # the box reflects only the sampled surface; starting from [0, 0]
    # would always pull the origin in and distort keep_aspect.
    inf = float('inf')
    xab, yab, zab = [[inf, -inf] for _ in range(3)]
    for n in range(N):
        for m in range(M):
            fdata = f(ctx.convert(us[m]), ctx.convert(vs[n]))
            try:
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # f returned a single value: the surface is z = f(x, y).
                x[m,n], y[m,n], z[m,n] = us[m], vs[n], fdata
            for c, cab in ((x[m,n], xab), (y[m,n], yab), (z[m,n], zab)):
                if c < cab[0]:
                    cab[0] = c
                if c > cab[1]:
                    cab[1] = c
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    if keep_aspect:
        bounds = (xab, yab, zab)
        maxd = max(hi - lo for lo, hi in bounds)
        setters = (axes.set_xlim3d, axes.set_ylim3d, axes.set_zlim3d)
        for (lo, hi), set_lim in zip(bounds, setters):
            if lo > hi:
                # No comparable sample on this axis; leave it autoscaled.
                continue
            d = hi - lo
            if d < maxd:
                delta = maxd - d
                set_lim(lo - delta / 2.0, hi + delta / 2.0)
    if fig:
        # pylab is bound here: fig is only set in the branch importing it.
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
307
308
# Expose the functions above as methods of VisualizationMethods; each takes
# the context object as its first argument (``ctx``).
for _name, _method in (
    ("plot", plot),
    ("default_color_function", default_color_function),
    ("phase_color_function", phase_color_function),
    ("cplot", cplot),
    ("splot", splot),
):
    setattr(VisualizationMethods, _name, _method)
del _name, _method
314