1# -*- coding: utf-8 -*-
2"""
3Tests the speed of image updates for an ImageItem and RawImageWidget.
4The speed will generally depend on the type of data being shown, whether
5it is being scaled and/or converted by lookup table, and whether OpenGL
6is used by the view widget
7"""
8
9## Add path to library (just for examples; you do not need this)
10import initExample
11
12import argparse
13import sys
14
15import numpy as np
16
17import pyqtgraph as pg
18from pyqtgraph.Qt import QtGui, QtCore, QT_LIB
19from time import perf_counter
20
21
# Frames in this example are built as (height, width[, 3]) arrays (see
# mkData), so the first image axis is treated as rows.
pg.setConfigOption('imageAxisOrder', 'row-major')

# Import the UI template module for the active Qt binding; the module name is
# 'VideoTemplate_' followed by the lower-cased QT_LIB string.
import importlib
ui_template = importlib.import_module(f'VideoTemplate_{QT_LIB.lower()}')

# Optional dependency: cupy. When it imports, pyqtgraph's "useCupy" option is
# switched on; otherwise cp is None and _has_cupy is False.
try:
    import cupy as cp
    pg.setConfigOption("useCupy", True)
    _has_cupy = True
except ImportError:
    cp = None
    _has_cupy = False

# Optional dependency: numba. Only its availability is recorded here; the
# "useNumba" option itself is set by noticeNumbaCheck().
try:
    import numba
    _has_numba = True
except ImportError:
    numba = None
    _has_numba = False

# Optional OpenGL-backed widget. RawImageGLWidget is None when the import
# fails, which later disables the corresponding radio button.
try:
    from pyqtgraph.widgets.RawImageWidget import RawImageGLWidget
except ImportError:
    RawImageGLWidget = None
46
# Command-line options; each one pre-populates a control in the UI further
# down, so every setting can still be changed interactively afterwards.
parser = argparse.ArgumentParser(description="Benchmark for testing video performance")
parser.add_argument('--cuda', default=False, action='store_true', help="Use CUDA to process on the GPU", dest="cuda")
parser.add_argument('--dtype', default='uint8', choices=['uint8', 'uint16', 'float'], help="Image dtype (uint8, uint16, or float)")
parser.add_argument('--frames', default=3, type=int, help="Number of image frames to generate (default=3)")
parser.add_argument('--image-mode', default='mono', choices=['mono', 'rgb'], help="Image data mode (mono or rgb)", dest='image_mode')
# --levels is parsed into a tuple of floats; its length (2 or 6) is validated
# later, when the values are copied into the level spin boxes.
parser.add_argument('--levels', default=None, type=lambda s: tuple([float(x) for x in s.split(',')]), help="min,max levels to scale monochromatic image dynamic range, or rmin,rmax,gmin,gmax,bmin,bmax to scale rgb")
parser.add_argument('--lut', default=False, action='store_true', help="Use color lookup table")
parser.add_argument('--lut-alpha', default=False, action='store_true', help="Use alpha color lookup table", dest='lut_alpha')
# --size is split on 'x' into a tuple of ints; the code below reads
# args.size[0] as width and args.size[1] as height.
# NOTE(review): the number of 'x'-separated fields is not validated here.
parser.add_argument('--size', default='512x512', type=lambda s: tuple([int(x) for x in s.split('x')]), help="WxH image dimensions default='512x512'")
args = parser.parse_args(sys.argv[1:])
57
# The default surface format is configured here, before the application
# object is created below, and only when the OpenGL widget could be imported.
if RawImageGLWidget is not None:
    # don't limit frame rate to vsync
    sfmt = QtGui.QSurfaceFormat()
    sfmt.setSwapInterval(0)
    QtGui.QSurfaceFormat.setDefaultFormat(sfmt)

app = pg.mkQApp("Video Speed Test Example")

# Build the main window from the UI template and show it.
win = QtGui.QMainWindow()
win.setWindowTitle('pyqtgraph example: VideoSpeedTest')
ui = ui_template.Ui_MainWindow()
ui.setupUi(win)
win.show()

if RawImageGLWidget is None:
    # No OpenGL widget: grey out its radio button and say why in the label.
    ui.rawGLRadio.setEnabled(False)
    ui.rawGLRadio.setText(ui.rawGLRadio.text() + " (OpenGL not available)")
else:
    # Add the OpenGL widget as an extra page of the stacked widget; update()
    # selects it with ui.stack.setCurrentIndex(2).
    ui.rawGLImg = RawImageGLWidget()
    ui.stack.addWidget(ui.rawGLImg)
78
# read in CLI args
# The CUDA / numba check boxes are only enabled (and only pre-checked) when
# the corresponding optional module was importable.
ui.cudaCheck.setChecked(args.cuda and _has_cupy)
ui.cudaCheck.setEnabled(_has_cupy)
ui.numbaCheck.setChecked(_has_numba and pg.getConfigOption("useNumba"))
ui.numbaCheck.setEnabled(_has_numba)
ui.framesSpin.setValue(args.frames)
# args.size is (width, height)
ui.widthSpin.setValue(args.size[0])
ui.heightSpin.setValue(args.size[1])
ui.dtypeCombo.setCurrentText(args.dtype)
ui.rgbCheck.setChecked(args.image_mode=='rgb')
# Initial values for the first min/max pair; may be overridden by --levels.
ui.maxSpin1.setOpts(value=255, step=1)
ui.minSpin1.setOpts(value=0, step=1)
# Spin boxes in the same order as the six-value form of --levels
# (min/max for channel 1, then channel 2, then channel 3).
levelSpins = [ui.minSpin1, ui.maxSpin1, ui.minSpin2, ui.maxSpin2, ui.minSpin3, ui.maxSpin3]
# xp is the array module used to generate frames: cupy when CUDA was
# requested and cupy is available, numpy otherwise.
if args.cuda and _has_cupy:
    xp = cp
else:
    xp = np
if args.levels is None:
    # No --levels given: start with level scaling switched off.
    ui.scaleCheck.setChecked(False)
    ui.rgbLevelsCheck.setChecked(False)
else:
    ui.scaleCheck.setChecked(True)
    if len(args.levels) == 2:
        # Single min,max pair applied through the first pair of spin boxes.
        ui.rgbLevelsCheck.setChecked(False)
        ui.minSpin1.setValue(args.levels[0])
        ui.maxSpin1.setValue(args.levels[1])
    elif len(args.levels) == 6:
        # One min,max pair per channel.
        ui.rgbLevelsCheck.setChecked(True)
        for spin,val in zip(levelSpins, args.levels):
            spin.setValue(val)
    else:
        raise ValueError("levels argument must be 2 or 6 comma-separated values (got %r)" % (args.levels,))
ui.lutCheck.setChecked(args.lut)
ui.alphaCheck.setChecked(args.lut_alpha)
113
114
115#ui.graphicsView.useOpenGL()  ## buggy, but you can try it if you need extra speed.
116
# Scene for the ImageItem display path: a ViewBox with locked aspect ratio
# is made the central item of the graphics view and holds a single ImageItem.
# update() pushes frames into `img` when neither raw-widget radio is checked.
vb = pg.ViewBox()
ui.graphicsView.setCentralItem(vb)
vb.setAspectLocked()
img = pg.ImageItem()
vb.addItem(img)
122
123
124
LUT = None


def updateLUT():
    """Rebuild the module-level lookup table ``LUT`` from the gradient editor.

    The table has 256 entries when the dtype combo box reads 'uint8' and 4096
    entries for any other dtype. An alpha channel is requested when the alpha
    check box is ticked. When cupy is the active array module the table is
    converted to a cupy array.
    """
    global LUT
    entries = 256 if ui.dtypeCombo.currentText() == 'uint8' else 4096
    withAlpha = ui.alphaCheck.isChecked()
    LUT = ui.gradient.getLookupTable(entries, alpha=withAlpha)
    # Keep the table in the same array module as the frame data.
    if _has_cupy and xp is cp:
        LUT = cp.asarray(LUT)
# Rebuild the lookup table whenever the gradient changes, and build it once
# now so that LUT is populated before the first frame is drawn.
ui.gradient.sigGradientChanged.connect(updateLUT)
updateLUT()

# Toggling the alpha check box changes the table contents, so rebuild it.
ui.alphaCheck.toggled.connect(updateLUT)
140
def updateScale():
    """Enable or disable the per-channel level spin boxes.

    The spin boxes after the first min/max pair (``levelSpins[2:]``) are
    enabled exactly when the RGB-levels check box is ticked.
    """
    perChannel = ui.rgbLevelsCheck.isChecked()
    for spin in levelSpins[2:]:
        spin.setEnabled(perChannel)
149
# Bring the spin boxes in line with the initial check-box state...
updateScale()

# ...and keep them in sync whenever the RGB-levels check box is toggled.
ui.rgbLevelsCheck.toggled.connect(updateScale)
153
cache = {}


def mkData():
    """Create (or reuse) the stack of frames for the current UI settings.

    Reads the frame count, width, height, dtype name and RGB flag from the UI
    and stores the resulting array in the module-level ``data``. The array is
    allocated with the active array module ``xp`` and has shape
    ``(frames, height, width)``, or ``(frames, height, width, 3)`` when the
    RGB check box is ticked. Each channel image is filled with normally
    distributed noise; integer dtypes are clipped to ``[0, mx]`` first.

    A small arrow-shaped marker is stamped into every frame. It is written
    with slices only, so images smaller than the marker receive the part that
    fits instead of raising ``IndexError``.

    Only the most recently generated array is kept in ``cache``, keyed by
    ``(dtype name, rgb flag, frames, width, height)``. Afterwards the lookup
    table and the size label / view range are refreshed.

    Raises:
        ValueError: if the dtype combo box holds an unsupported dtype name.
    """
    global data, cache
    with pg.BusyCursor():
        frames = ui.framesSpin.value()
        width = ui.widthSpin.value()
        height = ui.heightSpin.value()
        cacheKey = (ui.dtypeCombo.currentText(), ui.rgbCheck.isChecked(), frames, width, height)
        if cacheKey not in cache:
            # Per-dtype storage type, noise parameters and marker value.
            if cacheKey[0] == 'uint8':
                dt = xp.uint8
                loc = 128
                scale = 64
                mx = 255
            elif cacheKey[0] == 'uint16':
                dt = xp.uint16
                loc = 4096
                scale = 1024
                mx = 2**16 - 1
            elif cacheKey[0] == 'float':
                dt = xp.float32
                loc = 1.0
                scale = 0.1
                mx = 1.0
            else:
                raise ValueError(f"unable to handle dtype: {cacheKey[0]}")

            chan_shape = (height, width)
            if ui.rgbCheck.isChecked():
                frame_shape = chan_shape + (3,)
            else:
                frame_shape = chan_shape
            data = xp.empty((frames,) + frame_shape, dtype=dt)
            # Reinterpret the buffer as a flat sequence of (height, width)
            # images so each one can be filled independently.
            view = data.reshape((-1,) + chan_shape)
            for idx in range(view.shape[0]):
                subdata = xp.random.normal(loc=loc, scale=scale, size=chan_shape)
                # note: gaussian filtering has been removed as it slows down array
                #       creation greatly.
                if cacheKey[0] != 'float':
                    xp.clip(subdata, 0, mx, out=subdata)
                view[idx] = subdata

            # Arrow marker. Length-1 slices replace the former integer indices
            # (column 10, rows 48 and 47): the result is the same on large
            # images, while out-of-range parts are silently clipped on small
            # ones instead of raising IndexError.
            data[:, 10:50, 10:11] = mx
            data[:, 48:49, 9:12] = mx
            data[:, 47:48, 8:13] = mx
            cache = {cacheKey: data} # clear to save memory (but keep one to prevent unnecessary regeneration)

        data = cache[cacheKey]
        updateLUT()
        updateSize()
204
def updateSize():
    """Refresh the data-size label and fit the view box to the image size.

    The label shows ``frames * width * height * channels * itemsize`` in MB
    (bytes / 1e6, truncated to an integer), where channels is 3 when the RGB
    check box is ticked and 1 otherwise. The view range is set to a rectangle
    of ``width`` x ``height`` anchored at the origin.
    """
    frames = ui.framesSpin.value()
    width = ui.widthSpin.value()
    height = ui.heightSpin.value()
    dtypeName = str(ui.dtypeCombo.currentText())
    # mkData() stores 'float' frames as float32, whereas dtype('float') is the
    # platform double; map the name explicitly so the label matches the
    # memory that is actually allocated.
    if dtypeName == 'float':
        dtypeName = 'float32'
    itemsize = xp.dtype(dtypeName).itemsize
    channels = 3 if ui.rgbCheck.isChecked() else 1
    ui.sizeLabel.setText('%d MB' % (frames * width * height * channels * itemsize / 1e6))
    vb.setRange(QtCore.QRectF(0, 0, width, height))
214
215
def noticeCudaCheck():
    """React to the CUDA check box: pick the array module and rebuild data.

    The frame cache is emptied first. ``xp`` becomes cupy only when the box is
    ticked and cupy is available; otherwise it becomes numpy, and a ticked box
    is cleared again if cupy is missing. Finally the frames are regenerated.
    """
    global xp, cache
    cache = {}
    wantsCuda = ui.cudaCheck.isChecked()
    if wantsCuda and _has_cupy:
        xp = cp
    else:
        xp = np
        if wantsCuda:
            # Requested but unavailable: reflect that in the UI.
            ui.cudaCheck.setChecked(False)
    mkData()
228
229
def noticeNumbaCheck():
    """Copy the numba check box state into pyqtgraph's 'useNumba' option.

    The option is always set to False when numba could not be imported.
    """
    useNumba = _has_numba and ui.numbaCheck.isChecked()
    pg.setConfigOption('useNumba', useNumba)
232
233
# Generate the initial frame stack from the settings applied above.
mkData()


# Regenerate the data when the dtype or RGB mode changes, or when editing of
# one of the size spin boxes is finished.
ui.dtypeCombo.currentIndexChanged.connect(mkData)
ui.rgbCheck.toggled.connect(mkData)
ui.widthSpin.editingFinished.connect(mkData)
ui.heightSpin.editingFinished.connect(mkData)
ui.framesSpin.editingFinished.connect(mkData)

# The size label and view range follow every spin-box value change, without
# regenerating the data.
ui.widthSpin.valueChanged.connect(updateSize)
ui.heightSpin.valueChanged.connect(updateSize)
ui.framesSpin.valueChanged.connect(updateSize)
# Back-end toggles.
ui.cudaCheck.toggled.connect(noticeCudaCheck)
ui.numbaCheck.toggled.connect(noticeNumbaCheck)
248
249
ptr = 0
lastTime = perf_counter()
fps = None


def update():
    """Display the next frame and refresh the frames-per-second label.

    The frame at index ``ptr % data.shape[0]`` is sent to the widget selected
    by the radio buttons: the raw image widget (stack page 1), the OpenGL raw
    image widget (stack page 2) or the ImageItem (stack page 0). The lookup
    table and levels are passed only when their check boxes are ticked;
    levels are three ``[min, max]`` pairs in RGB-levels mode and a single
    ``[min, max]`` pair otherwise.

    The displayed rate is the first measurement on the first call and
    afterwards a running blend of the previous value and the instantaneous
    rate, weighted by ``clip(3 * dt, 0, 1)``.
    """
    global ptr, lastTime, fps

    useLut = LUT if ui.lutCheck.isChecked() else None
    downsample = ui.downsampleCheck.isChecked()

    useScale = None
    if ui.scaleCheck.isChecked():
        if ui.rgbLevelsCheck.isChecked():
            spinPairs = (
                (ui.minSpin1, ui.maxSpin1),
                (ui.minSpin2, ui.maxSpin2),
                (ui.minSpin3, ui.maxSpin3),
            )
            useScale = [[lo.value(), hi.value()] for lo, hi in spinPairs]
        else:
            useScale = [ui.minSpin1.value(), ui.maxSpin1.value()]

    frame = data[ptr % data.shape[0]]
    if ui.rawRadio.isChecked():
        ui.rawImg.setImage(frame, lut=useLut, levels=useScale)
        page = 1
    elif ui.rawGLRadio.isChecked():
        ui.rawGLImg.setImage(frame, lut=useLut, levels=useScale)
        page = 2
    else:
        img.setImage(frame, autoLevels=False, levels=useScale, lut=useLut, autoDownsample=downsample)
        page = 0
    ui.stack.setCurrentIndex(page)

    ptr += 1
    now = perf_counter()
    elapsed = now - lastTime
    lastTime = now
    rate = 1.0 / elapsed
    if fps is None:
        fps = rate
    else:
        # Longer frame times give the newest measurement more weight.
        weight = np.clip(elapsed * 3., 0, 1)
        fps = fps * (1 - weight) + rate * weight
    ui.fpsLabel.setText('%0.2f fps' % fps)
    app.processEvents()  ## force complete redraw for every plot
# Drive update() from a timer started with a 0 ms interval, so frames are
# requested back-to-back rather than at a fixed rate.
timer = QtCore.QTimer()
timer.timeout.connect(update)
timer.start(0)

# Enter the Qt event loop only when run as a script.
if __name__ == '__main__':
    pg.exec()
301