1#!/usr/bin/env python 2# 3# Copyright 2008,2010,2013 Free Software Foundation, Inc. 4# 5# This file is part of GNU Radio 6# 7# GNU Radio is free software; you can redistribute it and/or modify 8# it under the terms of the GNU General Public License as published by 9# the Free Software Foundation; either version 3, or (at your option) 10# any later version. 11# 12# GNU Radio is distributed in the hope that it will be useful, 13# but WITHOUT ANY WARRANTY; without even the implied warranty of 14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15# GNU General Public License for more details. 16# 17# You should have received a copy of the GNU General Public License 18# along with GNU Radio; see the file COPYING. If not, write to 19# the Free Software Foundation, Inc., 51 Franklin Street, 20# Boston, MA 02110-1301, USA. 21# 22 23from __future__ import division 24 25import numpy 26from gnuradio import gr, gr_unittest, wavelet, analog, blocks 27import copy 28#import pygsl.wavelet as wavelet # FIXME: pygsl not checked for in config 29import math 30 31def sqr(x): 32 return x*x 33 34def np2(k): 35 m = 0 36 n = k - 1 37 while n > 0: 38 m += 1 39 return m 40 41 42class test_classify(gr_unittest.TestCase): 43 44 def setUp(self): 45 self.tb = gr.top_block() 46 47 def tearDown(self): 48 self.tb = None 49 50# def test_000_(self): 51# src_data = numpy.zeros(10) 52# trg_data = numpy.zeros(10) 53# src = blocks.vector_source_f(src_data) 54# dst = blocks.vector_sink_f() 55# self.tb.connect(src, dst) 56# self.tb.run() 57# rsl_data = dst.data() 58# sum = 0 59# for (u,v) in zip(trg_data, rsl_data): 60# w = u - v 61# sum += w * w 62# sum /= float(len(trg_data)) 63# assert sum < 1e-6 64 65 def test_001_(self): 66 src_data = numpy.array([-1.0, 1.0, -1.0, 1.0]) 67 trg_data = src_data * 0.5 68 src = blocks.vector_source_f(src_data) 69 dst = blocks.vector_sink_f() 70 rail = analog.rail_ff(-0.5, 0.5) 71 self.tb.connect(src, rail) 72 self.tb.connect(rail, dst) 73 self.tb.run() 74 rsl_data = dst.data() 75 sum = 0 76 for (u, v) in zip(trg_data, rsl_data): 77 w = u - v 78 sum += w * w 79 sum /= float(len(trg_data)) 80 assert sum < 1e-6 81 82 def test_002_(self): 83 src_data = numpy.array([-1.0, 84 -1.0 / 2.0, 85 -1.0 / 3.0, 86 -1.0 / 4.0, 87 -1.0 / 5.0]) 88 trg_data = copy.deepcopy(src_data) 89 90 src = blocks.vector_source_f(src_data, False, len(src_data)) 91 st = blocks.stretch_ff(-1.0 / 5.0, len(src_data)) 92 dst = blocks.vector_sink_f(len(src_data)) 93 self.tb.connect(src, st) 94 self.tb.connect(st, dst) 95 self.tb.run() 96 rsl_data = dst.data() 97 sum = 0 98 for (u, v) in zip(trg_data, rsl_data): 99 w = u - v 100 sum += w * w 101 sum /= float(len(trg_data)) 102 assert sum < 1e-6 103 104 def test_003_(self): 105 src_grid = (0.0, 1.0, 2.0, 3.0, 4.0) 106 trg_grid = copy.deepcopy(src_grid) 107 src_data = (0.0, 1.0, 0.0, 1.0, 0.0) 108 109 src = blocks.vector_source_f(src_data, False, len(src_grid)) 110 sq = wavelet.squash_ff(src_grid, trg_grid) 111 dst = blocks.vector_sink_f(len(trg_grid)) 112 self.tb.connect(src, sq) 113 self.tb.connect(sq, dst) 114 self.tb.run() 115 rsl_data = dst.data() 116 sum = 0 117 for (u, v) in zip(src_data, rsl_data): 118 w = u - v 119 sum += w * w 120 sum /= float(len(src_data)) 121 assert sum < 1e-6 122 123# def test_004_(self): # FIXME: requires pygsl 124# 125# n = 256 126# o = 4 127# ws = wavelet.workspace(n) 128# w = wavelet.daubechies(o) 129# 130# a = numpy.arange(n) 131# b = numpy.sin(a*numpy.pi/16.0) 132# c = w.transform_forward(b, ws) 133# d = w.transform_inverse(c, ws) 134# 135# src = gr.vector_source_f(b, False, n) 136# wv = wavelet.wavelet_ff(n, o, True) 137# src = blocks.vector_source_f(b, False, n) 138# wv = wavelet.wavelet_ff(n, o, True) 139# 140# dst = blocks.vector_sink_f(n) 141# self.tb.connect(src, wv) 142# self.tb.connect(wv, dst) 143# self.tb.run() 144# e = dst.data() 145# 146# sum = 0 147# for (u, v) in zip(c, e): 148# w = u - v 149# sum += w * w 150# sum /= float(len(c)) 151# assert sum < 1e-6 152 153 def test_005_(self): 154 155 src_data = (1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0) 156 157 dwav = numpy.array(src_data) 158 wvps = numpy.zeros(3) 159 # wavelet power spectrum 160 scl = 1.0 / sqr(dwav[0]) 161 k = 1 162 for e in range(len(wvps)): 163 wvps[e] = scl*sqr(dwav[k:k+(0o1<<e)]).sum() 164 k += 0o1<<e 165 166 src = blocks.vector_source_f(src_data, False, len(src_data)) 167 kon = wavelet.wvps_ff(len(src_data)) 168 dst = blocks.vector_sink_f(int(math.ceil(math.log(len(src_data), 2)))) 169 170 self.tb.connect(src, kon) 171 self.tb.connect(kon, dst) 172 173 self.tb.run() 174 snk_data = dst.data() 175 176 sum = 0 177 for (u,v) in zip(snk_data, wvps): 178 w = u - v 179 sum += w * w 180 sum /= float(len(snk_data)) 181 assert sum < 1e-6 182 183if __name__ == '__main__': 184 gr_unittest.run(test_classify, "test_classify.xml") 185