"""
Plotting (requires matplotlib)
"""

from colorsys import hsv_to_rgb, hls_to_rgb
from .libmp import NoConvergence
from .libmp.backend import xrange


class VisualizationMethods(object):
    # Exceptions that the plotting routines treat as "no value at this
    # point" (the point is dropped or greyed out) instead of aborting.
    plot_ignore = (ValueError, ArithmeticError, ZeroDivisionError, NoConvergence)


def plot(ctx, f, xlim=(-5, 5), ylim=None, points=200, file=None, dpi=None,
         singularities=(), axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception or
    returns an infinite value are removed from the graph.
    Singularities can also be excluded explicitly
    as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])   # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])   # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Further options:

    * *ylim* - optional pair fixing the vertical range of the plot.
    * *points* - the interval ``[a, b]`` is sampled with step
      ``(b-a)/points``.
    * *file* - if given, the figure is saved to this file (with
      resolution *dpi*) instead of being shown; *axes* is then ignored.
    * *axes* - an existing matplotlib axes object to draw into; in that
      case no new figure is created, shown or saved.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        # Saving always uses a fresh figure; a caller-supplied axes
        # object is ignored in that case.
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    for n, func in enumerate(f):
        x = ctx.arange(a, b, (b-a)/float(points))
        # The curve is split into contiguous segments; a new segment starts
        # after a dropped point and whenever f switches between real and
        # complex values (the two kinds are drawn differently).
        segments = []
        segment = []
        in_complex = False
        for i in xrange(len(x)):
            try:
                if i != 0:
                    # Break the line when a declared singularity lies
                    # between this sample and the previous one.
                    for sing in singularities:
                        if x[i-1] <= sing and x[i] >= sing:
                            raise ValueError
                v = func(x[i])
                if ctx.isnan(v) or abs(v) > 1e300:
                    raise ValueError
                if hasattr(v, "imag") and v.imag:
                    re = float(v.real)
                    im = float(v.imag)
                    if not in_complex:
                        in_complex = True
                        segments.append(segment)
                        segment = []
                    segment.append((float(x[i]), re, im))
                else:
                    if in_complex:
                        in_complex = False
                        segments.append(segment)
                        segment = []
                    if hasattr(v, "real"):
                        v = v.real
                    # Convert to float like the complex branch does, so
                    # matplotlib always receives plain floats.
                    segment.append((float(x[i]), float(v)))
            except ctx.plot_ignore:
                if segment:
                    segments.append(segment)
                segment = []
        if segment:
            segments.append(segment)
        c = colors[n % len(colors)]
        for segment in segments:
            if not segment:
                continue
            xs = [s[0] for s in segment]
            ys = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex segment: real part dashed, imaginary part dotted.
                zs = [s[2] for s in segment]
                axes.plot(xs, ys, '--'+c, linewidth=3)
                axes.plot(xs, zs, ':'+c, linewidth=3)
            else:
                axes.plot(xs, ys, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()


def default_color_function(ctx, z):
    """
    Map a complex number *z* to an RGB 3-tuple of floats in [0, 1].

    The hue encodes the argument (phase) of *z*, shifted by half a turn;
    the lightness grows from 0 at ``z = 0`` towards 1 as ``|z|`` grows.
    Infinities are drawn white and NaNs grey.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    pi = 3.1415926535898
    # Map arg(z) in [-pi, pi] to a hue in [0, 1].  The float ``pi`` is used
    # (rather than the multiprecision ctx.pi) so that the returned tuple
    # consists of plain floats, as documented for color functions.
    a = (float(ctx.arg(z)) + pi) / (2*pi)
    a = (a + 0.5) % 1.0
    b = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(a, b, 0.8)
# Piecewise-linear color ramp used by phase_color_function.  Each entry is
# (position, (r, g, b)) where position is arg(z)/pi.  The final entry at 2.0
# is a sentinel so that position 1.0 still finds an upper neighbour.
blue_orange_colors = [
    (-1.0,  (0.0, 0.0, 0.0)),
    (-0.95, (0.1, 0.2, 0.5)),   # dark blue
    (-0.5,  (0.0, 0.5, 1.0)),   # blueish
    (-0.05, (0.4, 0.8, 0.8)),   # cyanish
    ( 0.0,  (1.0, 1.0, 1.0)),
    ( 0.05, (1.0, 0.9, 0.3)),   # yellowish
    ( 0.5,  (0.9, 0.5, 0.0)),   # orangeish
    ( 0.95, (0.7, 0.1, 0.0)),   # redish
    ( 1.0,  (0.0, 0.0, 0.0)),
    ( 2.0,  (0.0, 0.0, 0.0)),
]


def phase_color_function(ctx, z):
    """
    Map a complex number *z* to an RGB 3-tuple of floats in [0, 1] that
    depends only on its argument: the normalized phase ``arg(z)/pi`` is
    linearly interpolated in the ``blue_orange_colors`` ramp.
    Infinities are drawn white and NaNs grey.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    pi = 3.1415926535898
    w = float(ctx.arg(z)) / pi
    w = max(min(w, 1.0), -1.0)
    # Find the first ramp stop strictly above w and interpolate between it
    # and its predecessor.
    for i in range(1, len(blue_orange_colors)):
        if blue_orange_colors[i][0] > w:
            a, (ra, ga, ba) = blue_orange_colors[i-1]
            b, (rb, gb, bb) = blue_orange_colors[i]
            s = (w-a) / (b-a)
            return ra+(rb-ra)*s, ga+(gb-ga)*s, ba+(bb-ba)*s


def cplot(ctx, f, re=(-5, 5), im=(-5, 5), points=2000, color=None,
          verbose=False, file=None, dpi=None, axes=None):
    """
    Plots the given complex-valued function *f* over a rectangular part
    of the complex plane specified by the pairs of intervals *re* and *im*.
    For example::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    By default, the complex argument (phase) is shown as color (hue) and
    the magnitude is shown as brightness. You can also supply a
    custom color function (*color*). This function should take a
    complex number as input and return an RGB 3-tuple containing
    floats in the range 0.0-1.0.

    Alternatively, you can select a builtin color function by passing
    a string as *color*:

      * "default" - default color scheme
      * "phase" - a color scheme that only renders the phase of the function,
         with white for positive reals, black for negative reals, gold in the
         upper half plane, and blue in the lower half plane.

    To obtain a sharp image, the number of points may need to be
    increased to 100,000 or thereabout. Since evaluating the
    function that many times is likely to be slow, the 'verbose'
    option is useful to display progress.

    Further options:

    * *points* - approximate total number of sample points; the grid
      shape follows the aspect ratio of the rectangle.
    * *verbose* - print progress after each row of the grid.
    * *file*, *dpi*, *axes* - as for :func:`plot`: save to *file* at
      resolution *dpi*, or draw into an existing *axes* object.

    Points where *f* raises a numerical exception are drawn in grey.

    .. note :: This function requires matplotlib (pylab).
    """
    if color is None or color == "default":
        color = ctx.default_color_function
    if color == "phase":
        color = ctx.phase_color_function
    import pylab
    if file:
        # Saving always uses a fresh figure; a caller-supplied axes
        # object is ignored in that case.
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    rea, reb = re
    ima, imb = im
    dre = reb - rea
    dim = imb - ima
    # Choose an M x N grid with M*N ~ points and M/N ~ dre/dim, so that the
    # sample cells are roughly square.
    M = int(ctx.sqrt(points*dre/dim)+1)
    N = int(ctx.sqrt(points*dim/dre)+1)
    x = pylab.linspace(rea, reb, M)
    y = pylab.linspace(ima, imb, N)
    # Note: we have to be careful to get the right rotation.
    # Test with these plots:
    #   cplot(lambda z: z if z.real < 0 else 0)
    #   cplot(lambda z: z if z.imag < 0 else 0)
    w = pylab.zeros((N, M, 3))
    for n in xrange(N):
        for m in xrange(M):
            z = ctx.mpc(x[m], y[n])
            try:
                v = color(f(z))
            except ctx.plot_ignore:
                v = (0.5, 0.5, 0.5)
            w[n,m] = v
        if verbose:
            print(str(n) + ' of ' + str(N))
    rea, reb, ima, imb = [float(_) for _ in [rea, reb, ima, imb]]
    axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()


def splot(ctx, f, u=(-5, 5), v=(-5, 5), points=100, keep_aspect=True,
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x` ranging over `u` and `y` ranging over `v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    Further options:

    * *points* - number of samples per direction, either a single
      integer or a pair ``(M, N)`` for the `u` and `v` directions.
    * *keep_aspect* - widen the axis limits so that the x, y and z axes
      span equal lengths around the data.
    * *wireframe* - draw a wireframe instead of a shaded surface.
    * *file*, *dpi*, *axes* - save to *file* at resolution *dpi*, or draw
      into an existing 3D *axes* object.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    .. note :: This function requires matplotlib (pylab) and its
       mplot3d toolkit.
    """
    import pylab
    # Importing mplot3d registers the '3d' projection on older matplotlib.
    import mpl_toolkits.mplot3d as mplot3d
    if file:
        # Saving always uses a fresh figure; a caller-supplied axes
        # object is ignored in that case.
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        # Axes3D(fig) is no longer attached to the figure automatically in
        # recent matplotlib releases (giving a blank plot); request a 3D
        # subplot from the figure instead.
        axes = fig.add_subplot(111, projection='3d')
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    us = pylab.linspace(ua, ub, M)
    vs = pylab.linspace(va, vb, N)
    x, y, z = [pylab.zeros((M, N)) for i in xrange(3)]
    # Running [min, max] of each coordinate.  Start from the empty interval
    # (+inf, -inf) so the bounds reflect the data only; starting from
    # [0, 0] would wrongly force 0 into every range and distort the
    # keep_aspect computation for surfaces away from the origin.
    inf = float('inf')
    xab, yab, zab = [[inf, -inf] for i in xrange(3)]
    for n in xrange(N):
        for m in xrange(M):
            fdata = f(ctx.convert(us[m]), ctx.convert(vs[n]))
            try:
                # Parametric surface: f returned (x, y, z).
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # Scalar result: graph of z = f(x, y).
                x[m,n], y[m,n], z[m,n] = us[m], vs[n], fdata
            for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
                if c < cab[0]:
                    cab[0] = c
                if c > cab[1]:
                    cab[1] = c
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    if keep_aspect:
        # Pad each shorter axis symmetrically so all three span maxd.
        dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
        maxd = max(dx, dy, dz)
        if dx < maxd:
            delta = maxd - dx
            axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
        if dy < maxd:
            delta = maxd - dy
            axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
        if dz < maxd:
            delta = maxd - dz
            axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()


VisualizationMethods.plot = plot
VisualizationMethods.default_color_function = default_color_function
VisualizationMethods.phase_color_function = phase_color_function
VisualizationMethods.cplot = cplot
VisualizationMethods.splot = splot