1# This file is part of h5py, a Python interface to the HDF5 library.
2#
3# http://www.h5py.org
4#
5# Copyright 2008-2013 Andrew Collette and contributors
6#
7# License:  Standard 3-clause BSD; see "license.txt" for full license terms
8#           and contributor agreement.
9
10"""
11    Demonstrates use of h5py in a multi-threaded GUI program.
12
13    In a perfect world, multi-threaded programs would practice strict
14    separation of tasks, with separate threads for HDF5, user interface,
15    processing, etc, communicating via queues.  In the real world, shared
16    state is frequently encountered, especially in the world of GUIs.  It's
17    quite common to initialize a shared resource (in this case an HDF5 file),
18    and pass it around between threads.  One must then be careful to regulate
19    access using locks, to ensure that each thread sees the file in a
20    consistent fashion.
21
22    This program demonstrates how to use h5py in a medium-sized
23    "shared-state" threading application.  Two threads exist: a GUI thread
24    (Tkinter) which takes user input and displays results, and a calculation
25    thread which is used to perform computation in the background, leaving
26    the GUI responsive to user input.
27
28    The computation thread calculates portions of the Mandelbrot set and
29    stores them in an HDF5 file.  The visualization/control thread reads
30    datasets from the same file and displays them using matplotlib.
31"""
32
33import tkinter as tk
34import threading
35
36import numpy as np
37from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
38from matplotlib.figure import Figure
39
40import h5py
41
42
# Re-entrant lock guarding the shared HDF5 file.  ComputeThread.run() holds it
# while creating and filling a new dataset; ViewWidget holds it while reading
# a dataset and its attributes for display.
file_lock = threading.RLock()  # Protects the file from concurrent access

# The most recently started ComputeThread, or None if none has been started.
# ComputeWidget.compute() joins it before replacing it, and the main script
# joins it once more before closing the file.
t = None  # We'll use this to store the active computation thread
46
class ComputeThread(threading.Thread):

    """
        Worker thread that evaluates one rectangular patch of the Mandelbrot
        set and stores the result in the HDF5 file as a new dataset.
    """

    def __init__(self, f, shape, escape, startcoords, extent, eventcall):
        """ Record the job parameters; the thread is not started here.

        f: HDF5 File object
        shape: 2-tuple (NX, NY)
        escape: Integer giving max iterations to escape
        startcoords: Complex number giving initial location on the plane
        extent: Complex number giving calculation extent on the plane
        eventcall: Callable invoked with no arguments at the end of run()
        """
        super().__init__()

        self.f = f
        self.shape = shape
        self.escape = escape
        self.startcoords = startcoords
        self.extent = extent
        self.eventcall = eventcall

    @staticmethod
    def _escape_steps(point, limit):
        """ Iterate z -> z**2 + point, starting from zero, at most `limit` times.

        Returns the index of the last iteration performed: the one at which
        abs(z) first exceeded 2, or limit-1 when that never happened.
        """
        z = 0j
        for step in range(limit):
            z = z**2 + point
            if abs(z) > 2:
                break
        return step

    def run(self):
        """ Compute the whole grid, then write it to the file in one go """

        nx, ny = self.shape

        # Every cell is assigned below, so the buffer needs no initial fill.
        counts = np.empty((nx, ny), dtype='i')

        # Distance between neighbouring grid points along each axis.
        dx = self.extent.real/nx
        dy = self.extent.imag/ny

        origin = self.startcoords
        limit = self.escape
        escape_steps = self._escape_steps

        for ix in range(nx):
            if not ix % 25:
                print("Computing row %d" % ix)
            for iy in range(ny):
                point = origin + complex(ix*dx, iy*dy)
                counts[ix, iy] = escape_steps(point, limit)

        # The file is only touched while holding the lock; the dataset name
        # is derived from the number of entries already in the file.
        with file_lock:
            dsname = "slice%03d" % len(self.f)
            dset = self.f.create_dataset(dsname, (nx, ny), 'i')
            dset.attrs['shape'] = self.shape
            dset.attrs['start'] = self.startcoords
            dset.attrs['extent'] = self.extent
            dset.attrs['escape'] = self.escape
            dset[...] = counts

        print("Calculation for %s done" % dsname)

        # Tell whoever created us that a new dataset is available.
        self.eventcall()
108
class ComputeWidget(object):

    """
        Responsible for input widgets, and starting new computation threads.
    """

    # Presets cycled through by the "Suggest" button.  Each tuple is in field
    # order: (NX, NY, escape, start X, start Y, extent X, extent Y).
    _SUGGESTIONS = (
        (200, 200, 50, -2, -1, 3, 2),
        (500, 500, 200, 0.110, -0.680, 0.05, 0.05),
        (200, 200, 1000, -0.16070135-5e-8, 1.0375665-5e-8, 1e-7, 1e-7),
        (500, 500, 100, -1, 0, 0.5, 0.5),
    )

    # Index of the preset the next suggest() call will show.  It has its own
    # name on purpose: storing it as "self.suggest" would replace the
    # suggest() method on the instance with an int.
    _suggest_index = 0

    def __init__(self, f, master, eventcall):
        """ Build the input form and its two buttons inside a new frame.

        f: HDF5 File object, handed to every computation thread
        master: Tk widget that will contain self.mainframe
        eventcall: Callable handed to every computation thread
        """
        self.f = f
        self.eventcall = eventcall

        self.mainframe = tk.Frame(master=master)

        entryframe = tk.Frame(master=self.mainframe)

        self.nxfield = tk.Entry(entryframe)
        self.nyfield = tk.Entry(entryframe)
        self.escapefield = tk.Entry(entryframe)
        self.startxfield = tk.Entry(entryframe)
        self.startyfield = tk.Entry(entryframe)
        self.extentxfield = tk.Entry(entryframe)
        self.extentyfield = tk.Entry(entryframe)

        # One grid row per input: right-aligned label in column 0, entry in
        # column 1.
        rows = (("NX", self.nxfield),
                ("NY", self.nyfield),
                ("Escape", self.escapefield),
                ("Start X", self.startxfield),
                ("Start Y", self.startyfield),
                ("Extent X", self.extentxfield),
                ("Extent Y", self.extentyfield))
        for row, (text, field) in enumerate(rows):
            tk.Label(entryframe, text=text).grid(row=row, column=0, sticky=tk.E)
            field.grid(row=row, column=1)

        entryframe.grid(row=0, rowspan=2, column=0)

        self.suggestbutton = tk.Button(master=self.mainframe, text="Suggest", command=self.suggest)
        self.computebutton = tk.Button(master=self.mainframe, text="Compute", command=self.compute)

        self.suggestbutton.grid(row=0, column=1)
        self.computebutton.grid(row=1, column=1)

        self._suggest_index = 0

    def compute(self, *args):
        """ Validate input and start calculation thread.

        We use a global variable "t" to store the current thread, to make
        sure old threads are properly joined before they are discarded.

        Invalid input is reported with print() and no thread is started.
        """
        global t

        try:
            nx = int(self.nxfield.get())
            ny = int(self.nyfield.get())
            escape = int(self.escapefield.get())
            start = complex(float(self.startxfield.get()), float(self.startyfield.get()))
            extent = complex(float(self.extentxfield.get()), float(self.extentyfield.get()))
            if (nx <= 0) or (ny <= 0) or (escape <= 0):
                raise ValueError("NX, NY and ESCAPE must be positive")
            # A zero extent would map every grid point onto the start point.
            if abs(extent) == 0:
                raise ValueError("Extent must be finite")
        except (ValueError, TypeError) as e:
            print(e)
            return

        # NOTE(review): this join() runs inside the "Compute" button callback,
        # so the callback does not return until the previous computation
        # thread has finished.
        if t is not None:
            t.join()

        t = ComputeThread(self.f, (nx, ny), escape, start, extent, self.eventcall)
        t.start()

    def suggest(self, *args):
        """ Populate the input fields with interesting locations

        Each call overwrites all seven fields with the next preset and wraps
        around after the last one.
        """
        fields = (self.nxfield, self.nyfield, self.escapefield,
                  self.startxfield, self.startyfield, self.extentxfield,
                  self.extentyfield)

        for entry, val in zip(fields, self._SUGGESTIONS[self._suggest_index]):
            # tk.END clears the whole entry, whatever its current length.
            entry.delete(0, tk.END)
            entry.insert(0, repr(val))

        self._suggest_index = (self._suggest_index + 1) % len(self._SUGGESTIONS)
210
211
class ViewWidget(object):

    """
        Draws images using the datasets recorded in the HDF5 file.  Also
        provides widgets to pick which dataset is displayed.
    """

    def __init__(self, f, master):
        """ Build the viewer widgets and show the last dataset, if there is one.

        f: HDF5 File object to read datasets from
        master: Tk widget that will contain self.mainframe
        """
        self.f = f

        self.mainframe = tk.Frame(master=master)
        self.lbutton = tk.Button(self.mainframe, text="<= Back", command=self.back)
        self.rbutton = tk.Button(self.mainframe, text="Next =>", command=self.forward)
        self.loclabel = tk.Label(self.mainframe, text='To start, enter values and click "compute"')
        self.infolabel = tk.Label(self.mainframe, text='Or, click the "suggest" button for interesting locations')

        self.fig = Figure(figsize=(5, 5), dpi=100)
        self.plot = self.fig.add_subplot(111)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.mainframe)
        self.canvas.draw_idle()

        # Labels on top; below them the plot flanked by the two buttons.
        self.loclabel.grid(row=0, column=1)
        self.infolabel.grid(row=1, column=1)
        self.lbutton.grid(row=2, column=0)
        self.canvas.get_tk_widget().grid(row=2, column=1)
        self.rbutton.grid(row=2, column=2)

        # Position of the displayed dataset within list(self.f.keys()).
        self.index = 0

        self.jumptolast()

    def draw_fractal(self):
        """ Read a dataset from the HDF5 file and display it

        The dataset shown is the one at position self.index in the file's
        key listing.
        """

        # Everything that comes from the file is read while holding the lock;
        # widgets are updated afterwards from the local copies.
        with file_lock:
            name = list(self.f.keys())[self.index]
            dset = self.f[name]
            arr = dset[...]
            start = dset.attrs['start']
            extent = dset.attrs['extent']
            loctext = 'Displaying dataset "%s" (%d of %d)' % (dset.name, self.index+1, len(self.f))
            infotext = "%(shape)s pixels, starts at %(start)s, extent %(extent)s" % dset.attrs

        self.loclabel["text"] = loctext
        self.infolabel["text"] = infotext

        self.plot.clear()
        # The array is indexed [x, y]; transpose so x runs along the
        # horizontal axis of the image.
        self.plot.imshow(arr.transpose(), cmap='jet', aspect='auto', origin='lower',
                         extent=(start.real, (start.real+extent.real),
                                 start.imag, (start.imag+extent.imag)))
        self.canvas.draw_idle()

    def back(self):
        """ Go to the previous dataset (in ASCII order) """
        if self.index == 0:
            print("Can't go back")
            return
        self.index -= 1
        self.draw_fractal()

    def forward(self):
        """ Go to the next dataset (in ASCII order) """
        # ">=" rather than "==": with an empty file len(self.f)-1 is -1, and
        # an equality test would let the index run past the end.
        if self.index >= (len(self.f)-1):
            print("Can't go forward")
            return
        self.index += 1
        self.draw_fractal()

    def jumptolast(self, *args):
        """ Jump to the last (ASCII order) dataset and display it

        Extra positional arguments are accepted and ignored, so this method
        can be bound directly as a Tk event handler.
        """
        with file_lock:
            if len(self.f) == 0:
                print("can't jump to last (no datasets)")
                return
            index = len(self.f)-1
        self.index = index
        self.draw_fractal()
287
288
if __name__ == '__main__':

    # The one HDF5 file shared by the GUI and the computation thread.
    f = h5py.File('mandelbrot_gui.hdf5', 'a')

    # Everything after the open sits inside the try, so the file is closed
    # even if building the GUI fails before the main loop is reached.
    try:
        root = tk.Tk()

        display = ViewWidget(f, root)

        # A finished computation is announced with a virtual event; its
        # handler shows the last dataset in the file.
        root.bind("<<FractalEvent>>", display.jumptolast)
        def callback():
            root.event_generate("<<FractalEvent>>")
        compute = ComputeWidget(f, root, callback)

        display.mainframe.grid(row=0, column=0)
        compute.mainframe.grid(row=1, column=0)

        root.mainloop()
    finally:
        # Wait for any computation thread before closing the file it uses.
        if t is not None:
            t.join()
        f.close()
311