1# -*- coding: utf-8 -*-
2#
3# colormaps.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22# ConnPlotter --- A Tool to Generate Connectivity Pattern Matrices
23
24"""
25Colormaps for ConnPlotter.
26
27Provides the following functions and colormaps:
28
 - make_colormap: based on color specification, create colormap
                  running from white to fully saturated color
31 - redblue: from fully saturated red to white to fully saturated blue
32 - bluered: from fully saturated blue to white to fully saturated red
33
For all colormaps, "bad" values (masked or NaN) are mapped to light
yellow (bad_color).
35
Also provides ZeroCenterNorm, which always maps 0 to 0.5, negative
values to 0..0.5 and positive values to 0.5..1.
38"""
39
40# ----------------------------------------------------------------------------
41
42import matplotlib.colors as mc
43import matplotlib.cbook as cbook
44import numpy as np
45
# Public names exported by ``from colormaps import *``.
__all__ = [
    'ZeroCenterNorm',
    'make_colormap',
    'redblue',
    'bluered',
    'bad_color',
]
48
49# ----------------------------------------------------------------------------
50
# Light-yellow RGB triple; every colormap in this module uses it as the
# color for "bad" (masked / NaN) entries via set_bad().
bad_color = (1., 1., .9)
52
53# ----------------------------------------------------------------------------
54
55
class ZeroCenterNorm(mc.Normalize):
    """
    Normalize so that value 0 is always mapped to 0.5.

    Negative values are mapped linearly onto [0, 0.5] (vmin -> 0) and
    positive values linearly onto [0.5, 1] (vmax -> 1). Each call extends
    the data limits so that vmin <= 0 <= vmax. Clipping is not supported.

    Code from matplotlib.colors.Normalize.
    Copyright (c) 2002-2009 John D. Hunter; All Rights Reserved
    http://matplotlib.sourceforge.net/users/license.html
    """

    # ------------------------------------------------------------------------

    def __call__(self, value, clip=None):
        """
        Normalize given values to [0,1].

        Returns data in same form as passed in:
        value can be scalar or array.

        Parameters
        ----------
        value : scalar or array_like
            Data to normalize. Masked entries remain masked in the result.
        clip : None or False
            Clipping is not supported; any truthy value raises.

        Returns
        -------
        A masked float array for array input, a single element for
        scalar input.

        Raises
        ------
        NotImplementedError
            If clipping is requested. (Previously an ``assert``, which is
            silently skipped when Python runs with ``-O``.)
        """
        if clip:
            raise NotImplementedError("ZeroCenterNorm does not support clip")

        # np.iterable / builtin float replace cbook.iterable and np.float,
        # both of which were removed from matplotlib and NumPy respectively.
        is_scalar = not np.iterable(value)
        if is_scalar:
            val = np.ma.array([value]).astype(float)
        else:
            val = np.ma.asarray(value).astype(float)

        # Let the base class fill in unset limits, then force the range to
        # straddle 0 so that 0 can be pinned at 0.5.
        self.autoscale_None(val)
        self.vmin = min(0, self.vmin)
        self.vmax = max(0, self.vmax)

        # imshow expects masked arrays: start with every entry at 0.5 (the
        # image of 0) and carry over the input mask.
        result = np.ma.array(np.full(np.shape(val), 0.5), mask=val.mask)

        # Map negatives onto [0, 0.5) and positives onto (0.5, 1].
        neg = val < 0
        pos = val > 0
        result[neg] = 0.5 * (self.vmin - val[neg]) / self.vmin
        result[pos] = 0.5 + 0.5 * val[pos] / self.vmax

        return result[0] if is_scalar else result

    # ------------------------------------------------------------------------

    def inverse(self, value):
        """
        Invert the normalization, mapping [0, 1] back to data values.

        Required by colorbar().

        Raises
        ------
        ValueError
            If vmin/vmax have not been set yet.
        """
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        vmin, vmax = self.vmin, self.vmax

        if np.iterable(value):
            val = np.asarray(value)
            res = np.zeros(np.shape(val))
            low = val < 0.5
            high = val > 0.5
            res[low] = vmin - 2 * vmin * val[low]  # vmin < 0
            res[high] = 2 * (val[high] - 0.5) * vmax
            return res

        if value == 0.5:
            return 0
        if value < 0.5:
            return vmin - 2 * vmin * value  # vmin < 0
        return 2 * (value - 0.5) * vmax
128
129
130# ----------------------------------------------------------------------------
131
def make_colormap(color):
    """
    Create LinearSegmentedColormap ranging from white to the given color.

    Color can be given in any format accepted by matplotlib's color
    converter. Bad (masked / NaN) values are drawn in ``bad_color``
    (light yellow).

    Parameters
    ----------
    color : matplotlib color specification
        Color reached at the top (1.0) of the colormap. Its string form
        is used as the colormap name.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap

    Raises
    ------
    ValueError
        If ``color`` is not a legal color specification.
    """
    try:
        r, g, b = mc.colorConverter.to_rgb(color)
    except Exception:
        # Wrap in a 1-tuple so tuple colors are shown via repr rather than
        # being unpacked as format arguments.
        raise ValueError('Illegal color specification: %r' % (color,))

    # Each channel runs linearly from 1.0 (white) at x=0 to the target
    # channel value at x=1.
    cm = mc.LinearSegmentedColormap(str(color),
                                    {'red': [(0.0, 1.0, 1.0),
                                             (1.0, r, r)],
                                     'green': [(0.0, 1.0, 1.0),
                                               (1.0, g, g)],
                                     'blue': [(0.0, 1.0, 1.0),
                                              (1.0, b, b)]})
    cm.set_bad(color=bad_color)  # light yellow
    return cm
152
153
154# ----------------------------------------------------------------------------
155
# Diverging colormap: saturated red at 0.0, white at 0.5, saturated blue
# at 1.0. Each segment entry is (x, value left of x, value right of x).
redblue = mc.LinearSegmentedColormap(
    name='redblue',
    segmentdata=dict(
        red=[(0.0, 0.0, 1.0), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)],
        green=[(0.0, 0.0, 0.0), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)],
        blue=[(0.0, 0.0, 0.0), (0.5, 1.0, 1.0), (1.0, 1.0, 1.0)],
    ),
)
redblue.set_bad(color=bad_color)
168
169# ----------------------------------------------------------------------------
170
# Diverging colormap: saturated blue at 0.0, white at 0.5, saturated red
# at 1.0. Each segment entry is (x, value left of x, value right of x).
bluered = mc.LinearSegmentedColormap(
    name='bluered',
    segmentdata=dict(
        red=[(0.0, 0.0, 0.0), (0.5, 1.0, 1.0), (1.0, 1.0, 1.0)],
        green=[(0.0, 0.0, 0.0), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)],
        blue=[(0.0, 1.0, 1.0), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)],
    ),
)
bluered.set_bad(color=bad_color)
183
184# ----------------------------------------------------------------------------
185
if __name__ == '__main__':

    # Quick self-check of ZeroCenterNorm; this should be proper unit tests.
    # Each case: (label, norm instance, input values); all cases must
    # normalize to the same expected output.
    expected = [0, 0.25, 0.5, 0.75, 1.0]
    checks = [
        ('n1', ZeroCenterNorm(), [-1, -0.5, 0.0, 0.5, 1.0]),
        ('n2', ZeroCenterNorm(-1, 2), [-1, -0.5, 0.0, 1.0, 2.0]),
    ]
    for label, norm, values in checks:
        # Compare with a tolerance rather than exact float equality.
        if np.allclose(norm(values).data, expected):
            print("%s ok" % label)
        else:
            print("%s failed." % label)
202