1#!/usr/bin/env python
2###############################################################################
3#                                                                             #
4# torusMesh.py                                                                #
5#                                                                             #
6# Implements structure and methods for working with a torus shaped mesh       #
7#                                                                             #
8# Copyright (C) Michael Imelfort                                              #
9#                                                                             #
10###############################################################################
11#                                                                             #
12#          .d8888b.                                    888b     d888          #
13#         d88P  Y88b                                   8888b   d8888          #
14#         888    888                                   88888b.d88888          #
15#         888        888d888 .d88b.   .d88b.  88888b.  888Y88888P888          #
16#         888  88888 888P"  d88""88b d88""88b 888 "88b 888 Y888P 888          #
17#         888    888 888    888  888 888  888 888  888 888  Y8P  888          #
18#         Y88b  d88P 888    Y88..88P Y88..88P 888 d88P 888   "   888          #
19#          "Y8888P88 888     "Y88P"   "Y88P"  88888P"  888       888          #
20#                                             888                             #
21#                                             888                             #
22#                                             888                             #
23#                                                                             #
24###############################################################################
25#                                                                             #
26# This program is free software: you can redistribute it and/or modify        #
27# it under the terms of the GNU General Public License as published by        #
28# the Free Software Foundation, either version 3 of the License, or           #
29# (at your option) any later version.                                         #
30#                                                                             #
31# This program is distributed in the hope that it will be useful,             #
32# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
33# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the                #
34# GNU General Public License for more details.                                #
35#                                                                             #
36# You should have received a copy of the GNU General Public License           #
37# along with this program. If not, see <http://www.gnu.org/licenses/>.        #
38#                                                                             #
39###############################################################################
40
# Module metadata: authorship, licence, version and release status.
__author__ = "Michael Imelfort"
__copyright__ = "Copyright 2012/2013"
__credits__ = ["Michael Imelfort"]
__license__ = "GPL3"
__version__ = "0.1.0"
__maintainer__ = "Michael Imelfort"
__email__ = "mike@mikeimelfort.com"
__status__ = "Released"
49
50###############################################################################
51import sys
52import numpy as np
53from PIL import Image, ImageDraw
54from scipy.spatial.distance import cdist
55from colorsys import hsv_to_rgb as htr
56
# Runs at import time: ask numpy to raise FloatingPointError for every
# floating-point error category (divide, overflow, underflow, invalid)
# instead of emitting a warning, so numerical problems fail loudly.
np.seterr(all='raise')
58
59###############################################################################
60###############################################################################
61###############################################################################
62###############################################################################
63
class UnknownDistanceType(Exception):
    """Error type for an unrecognised distance type.

    Judging by its name this is meant to be raised when a caller asks for
    a distance measure that is not supported; nothing in this module
    raises it.

    Derives from Exception (not BaseException) so that ordinary
    ``except Exception`` handlers see it; BaseException is reserved for
    interpreter-level exits such as SystemExit and KeyboardInterrupt.
    """
    pass
66
67###############################################################################
68###############################################################################
69###############################################################################
70###############################################################################
71
class TorusMesh:
    """A closed mesh, in the shape of a torus.

    The mesh is a rows x columns grid whose edges wrap around in both
    directions. Every grid point stores a vector of `dimension` values.

    Attributes set by the constructor:
        rows, columns, dimension  - grid geometry
        shape      - (rows, columns, dimension)
        flatShape  - (rows * columns, dimension)
        size       - total number of stored scalars
        halfRow, halfColumn - half of the grid extents (floats)
        nodes      - the (rows, columns, dimension) value array
        flatNodes  - `nodes` reshaped to `flatShape` (row-major order)
        cVec       - unit vector along the all-ones direction
        largestMag - norm of the all-ones vector
        maxAngle   - angle between the first axis and cVec (radians)
    """

    def __init__(self, rows, columns=0, dimension=1, randomize=False):
        """Build the mesh.

        rows      - number of rows in the grid
        columns   - number of columns; 0 (the default) makes the grid square
        dimension - number of values stored per node; 1 gives a scalar
                    field, larger values give a vector field
        randomize - fill the nodes with uniform random values in [0, 1)
                    instead of zeros
        """
        self.rows = rows
        # columns == 0 is the "make it square" sentinel
        self.columns = rows if columns == 0 else columns
        self.dimension = dimension
        self.shape = (self.rows, self.columns, self.dimension)
        self.flatShape = (self.rows * self.columns, self.dimension)
        self.size = self.dimension * self.rows * self.columns

        # we use these values in dist many many times
        self.halfRow = float(self.rows) / 2
        self.halfColumn = float(self.columns) / 2

        # the first two dimensions represent points on the surface
        # the remainder represent values
        if randomize:
            self.nodes = np.random.random(self.shape)
        else:
            self.nodes = np.zeros(self.shape)

        # make an array of flattened nodes
        self.flatNodes = self.nodes.reshape(self.flatShape)

        # work out the central vector and corresponding max angle
        c_vec = np.ones((self.dimension))
        self.largestMag = np.linalg.norm(c_vec)
        self.cVec = c_vec / self.largestMag
        top_vec = np.zeros_like(self.cVec)
        top_vec[0] = 1
        # NOTE: for dimension == 1 the two vectors coincide and maxAngle is 0
        self.maxAngle = self.getAngBetweenNormed(top_vec, self.cVec)

    def fixFlatNodes(self, weights=None):
        """Make sure `nodes` and `flatNodes` are in sync.

        weights - optional replacement for `self.nodes`; it must be
                  reshapeable to `self.flatShape`

        Returns the refreshed `flatNodes` array.
        """
        if weights is not None:
            self.nodes = weights
        self.flatNodes = self.nodes.reshape(self.flatShape)
        return self.flatNodes

#------------------------------------------------------------------------------
# WORKING WITH THE DATA

    def bestMatch(self, targetVector):
        """Return the [row, col] of the node closest to targetVector.

        Uses Euclidean distance against `self.flatNodes`.
        """
        loc = np.argmin(cdist(self.flatNodes, [targetVector]))
        # flatNodes is row-major, so the flat index decomposes over the
        # number of COLUMNS (not rows) -- this matters for non-square meshes
        row, col = divmod(int(loc), self.columns)
        return [row, col]

    def buildVarianceSurface(self):
        """Sum the absolute differences between each node and its eight
        neighbours.

        Neighbours wrap around the grid edges (torus topology). Returns an
        array of shape `self.shape` holding, per node and per value, the
        total absolute difference to the surrounding eight nodes.
        """
        nodes = self.nodes
        diff_array = np.zeros(self.shape)

        # (row offset, column offset) of every neighbour B relative to A.
        # Each pair of opposite neighbours is accumulated together:
        # horizontal, main diagonal, vertical, anti-diagonal.
        neighbour_offsets = (
            (0, 1), (0, -1),
            (1, 1), (-1, -1),
            (1, 0), (-1, 0),
            (1, -1), (-1, 1),
        )
        for d_row, d_col in neighbour_offsets:
            # rolling by -offset brings nodes[r + d_row, c + d_col] to [r, c]
            # with wrap-around, whatever the grid's aspect ratio
            neighbour = np.roll(np.roll(nodes, -d_row, axis=0), -d_col, axis=1)
            diff_array += np.abs(neighbour - nodes)

        return diff_array

#------------------------------------------------------------------------------
# COLORING

    def getColor(self, vector):
        """Return an [R, G, B] colour (ints) for a given weight vector.

        The hue comes from the angle between the vector and `self.cVec`
        (as a fraction of `self.maxAngle`); the saturation comes from the
        vector's magnitude (as a fraction of `self.largestMag`). A zero
        vector maps to hue 0 / saturation 0.
        """
        sn = np.linalg.norm(vector)
        if sn > 0:
            vv = np.asarray(vector, dtype=float) / sn
            if self.maxAngle > 0:
                ang_perc = self.getAngBetweenNormed(vv, self.cVec) / self.maxAngle
            else:
                # dimension == 1: there is no angular spread, avoid 0/0
                ang_perc = 0.0
            mag_perc = sn / self.largestMag
        else:
            ang_perc = 0.0
            mag_perc = 0.0
        V = 1       # VAL remain fixed at 1. Reduce to make pastels if that's your preference...
        return [int(channel * 255) for channel in htr(ang_perc, mag_perc, V)]

    def getAngBetweenNormed(self, P1, P2):
        """Return the angle (in radians) between two vectors.

        Both vectors are expected to be unit length already; the dot
        product is used directly as the cosine of the angle.
        """
        # rounding errors hurt everyone... keep the cosine inside [-1, 1]
        cosine = np.clip(np.dot(P1, P2), -1.0, 1.0)
        return np.arccos(cosine) # in radians

#------------------------------------------------------------------------------
# IO and IMAGE RENDERING

    def __str__(self):
        """Return one text line per row, each node written as "[ v0 v1 ... ],"."""
        lines = []
        for r in range(self.rows):
            cells = []
            for c in range(self.columns):
                values = "".join(
                    str(self.nodes[r, c, v]) + " " for v in range(self.dimension)
                )
                cells.append("[ " + values + "],")
            lines.append("".join(cells) + "\n")
        return "".join(lines)

    def renderSurface(self, fileName, nodes=None):
        """Write an image of the node values to fileName.

        fileName - path handed to PIL's Image.save
        nodes    - optional (rows, columns, dimension) array to draw
                   instead of `self.nodes`

        Each node is coloured with getColor and drawn as a 10x10 pixel
        square. On failure the exception type is printed and the exception
        is re-raised.
        """
        if nodes is None:
            nodes = self.nodes
        ns = np.shape(nodes)
        rows = ns[0]
        columns = ns[1]
        try:
            img = Image.new("RGB", (columns, rows))
            for r in range(rows):
                # build a color value for a vector value
                for c in range(columns):
                    col = self.getColor(nodes[r, c])
                    img.putpixel((c, r), (col[0], col[1], col[2]))
            # scale up without interpolation so every node stays a flat square
            img = img.resize((columns * 10, rows * 10), Image.NEAREST)
            img.save(fileName)
        except Exception:
            # single-argument print(...) behaves the same on Python 2 and 3
            print(sys.exc_info()[0])
            raise
255
256###############################################################################
257###############################################################################
258###############################################################################
259###############################################################################
260