1#!/usr/bin/env python
2#
3# Copyright 2008,2010,2013 Free Software Foundation, Inc.
4#
5# This file is part of GNU Radio
6#
7# GNU Radio is free software; you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation; either version 3, or (at your option)
10# any later version.
11#
12# GNU Radio is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with GNU Radio; see the file COPYING.  If not, write to
19# the Free Software Foundation, Inc., 51 Franklin Street,
20# Boston, MA 02110-1301, USA.
21#
22
23from __future__ import division
24
25import numpy
26from gnuradio import gr, gr_unittest, wavelet, analog, blocks
27import copy
28#import pygsl.wavelet as wavelet # FIXME: pygsl not checked for in config
29import math
30
def sqr(x):
    """Return ``x`` multiplied by itself.

    Any operand supporting ``*`` is accepted; for numpy arrays the
    product is taken element-wise.
    """
    squared = x * x
    return squared
33
def np2(k):
    """Return the exponent of the next power of two at or above ``k``.

    That is the smallest integer ``m >= 0`` with ``2**m >= k``, which equals
    ``ceil(log2(k))`` for integer ``k >= 1``.  For ``k <= 1`` the result is 0.

    :param k: integer size to round up to a power of two
    :returns: the exponent ``m`` as an int
    """
    m = 0
    n = k - 1
    while n > 0:
        m += 1
        # Drop one bit per pass; without this shift the loop never
        # terminates for k > 1.  m ends up as the bit length of k - 1.
        n >>= 1
    return m
40
41
class test_classify(gr_unittest.TestCase):
    """QA tests for analog.rail_ff, blocks.stretch_ff, wavelet.squash_ff
    and wavelet.wvps_ff.

    Each test runs a short vector through a flowgraph and compares the sink
    output against a Python-side reference using a mean-squared-error bound.
    """

    def setUp(self):
        self.tb = gr.top_block()

    def tearDown(self):
        self.tb = None

    def _assert_mse_small(self, expected, actual, tol=1e-6):
        """Assert that ``actual`` matches ``expected`` in mean-squared error.

        Lengths are checked first: comparing through a bare ``zip`` would
        silently truncate to the shorter sequence, so a flowgraph that
        produced too few (or no) samples would otherwise pass vacuously.

        :param expected: reference samples (sequence or 1-D numpy array)
        :param actual: samples read back from the sink
        :param tol: exclusive upper bound on the mean-squared error
        """
        expected = list(expected)
        actual = list(actual)
        self.assertEqual(len(actual), len(expected),
                         "expected %d output samples, got %d"
                         % (len(expected), len(actual)))
        self.assertGreater(len(expected), 0, "no reference samples to compare")
        mse = sum((u - v) * (u - v) for (u, v) in zip(expected, actual))
        mse /= float(len(expected))
        self.assertLess(mse, tol)

#     def test_000_(self):
#         src_data = numpy.zeros(10)
#         trg_data = numpy.zeros(10)
#         src = blocks.vector_source_f(src_data)
#         dst = blocks.vector_sink_f()
#         self.tb.connect(src, dst)
#         self.tb.run()
#         rsl_data = dst.data()
#         sum = 0
#         for (u,v) in zip(trg_data, rsl_data):
#             w = u - v
#             sum += w * w
#         sum /= float(len(trg_data))
#         assert sum < 1e-6

    def test_001_(self):
        # The reference expects rail_ff(-0.5, 0.5) to map +/-1.0 to +/-0.5.
        src_data = numpy.array([-1.0, 1.0, -1.0, 1.0])
        trg_data = src_data * 0.5
        src = blocks.vector_source_f(src_data)
        dst = blocks.vector_sink_f()
        rail = analog.rail_ff(-0.5, 0.5)
        self.tb.connect(src, rail)
        self.tb.connect(rail, dst)
        self.tb.run()
        self._assert_mse_small(trg_data, dst.data())

    def test_002_(self):
        src_data = numpy.array([-1.0,
                                -1.0 / 2.0,
                                -1.0 / 3.0,
                                -1.0 / 4.0,
                                -1.0 / 5.0])
        # The reference is the unchanged input vector.
        trg_data = copy.deepcopy(src_data)

        src = blocks.vector_source_f(src_data, False, len(src_data))
        st = blocks.stretch_ff(-1.0 / 5.0, len(src_data))
        dst = blocks.vector_sink_f(len(src_data))
        self.tb.connect(src, st)
        self.tb.connect(st, dst)
        self.tb.run()
        self._assert_mse_small(trg_data, dst.data())

    def test_003_(self):
        # Source and target grids are identical, so the reference is the
        # unchanged input vector.
        src_grid = (0.0, 1.0, 2.0, 3.0, 4.0)
        trg_grid = copy.deepcopy(src_grid)
        src_data = (0.0, 1.0, 0.0, 1.0, 0.0)

        src = blocks.vector_source_f(src_data, False, len(src_grid))
        sq = wavelet.squash_ff(src_grid, trg_grid)
        dst = blocks.vector_sink_f(len(trg_grid))
        self.tb.connect(src, sq)
        self.tb.connect(sq, dst)
        self.tb.run()
        self._assert_mse_small(src_data, dst.data())

#    def test_004_(self): # FIXME: requires pygsl
#
#        n = 256
#        o = 4
#        ws = wavelet.workspace(n)
#        w = wavelet.daubechies(o)
#
#        a = numpy.arange(n)
#        b = numpy.sin(a*numpy.pi/16.0)
#        c = w.transform_forward(b, ws)
#        d = w.transform_inverse(c, ws)
#
#        src = gr.vector_source_f(b, False, n)
#        wv = wavelet.wavelet_ff(n, o, True)
#        src = blocks.vector_source_f(b, False, n)
#        wv = wavelet.wavelet_ff(n, o, True)
#
#        dst = blocks.vector_sink_f(n)
#        self.tb.connect(src, wv)
#        self.tb.connect(wv, dst)
#        self.tb.run()
#        e = dst.data()
#
#        sum = 0
#        for (u, v) in zip(c, e):
#            w = u - v
#            sum += w * w
#        sum /= float(len(c))
#        assert sum < 1e-6

    def test_005_(self):

        src_data = (1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        # Number of output bins, ceil(log2(len(src_data))), computed with
        # exact integer arithmetic instead of a floating-point math.log.
        nscales = (len(src_data) - 1).bit_length()

        # Reference wavelet power spectrum: element 0 sets the scale, then
        # successive groups of 1, 2, 4, ... coefficients are squared, summed
        # and divided by the square of element 0.
        dwav = numpy.array(src_data)
        wvps = numpy.zeros(nscales)
        scl = 1.0 / sqr(dwav[0])
        k = 1
        for e in range(len(wvps)):
            width = 1 << e
            wvps[e] = scl * sqr(dwav[k:k + width]).sum()
            k += width

        src = blocks.vector_source_f(src_data, False, len(src_data))
        kon = wavelet.wvps_ff(len(src_data))
        dst = blocks.vector_sink_f(nscales)

        self.tb.connect(src, kon)
        self.tb.connect(kon, dst)

        self.tb.run()
        self._assert_mse_small(wvps, dst.data())
182
if __name__ == '__main__':
    # Run this module's test case through GNU Radio's unittest runner; the
    # second argument names an XML output file (presumably the test report --
    # confirm against gr_unittest.run).
    gr_unittest.run(test_classify, "test_classify.xml")
185