1# -*- coding: utf-8 -*-
2import numpy as np
3import matplotlib as mpl
4import matplotlib.pyplot as plt
5from matplotlib.colors import Normalize, LogNorm
6
7
def generate_gridspec(**kwargs):
    """Return the 2x2 GridSpec used to lay out the TCM figure.

    The grid spans 0.12 .. 0.96 of the figure in both directions. The
    first column is three times as wide as the second, and the second row
    is three times as tall as the first, giving one large cell (bottom
    left) flanked by a short top cell and a narrow right cell.

    Extra keyword arguments (for example ``hspace``/``wspace``) are
    forwarded unchanged to ``GridSpec``.
    """
    from matplotlib.gridspec import GridSpec
    offset = 0.12
    extent = 0.84
    lower = offset
    upper = offset + extent
    return GridSpec(2, 2,
                    width_ratios=[3, 1],
                    height_ratios=[1, 3],
                    bottom=lower, top=upper,
                    left=lower, right=upper,
                    **kwargs)
17
18
def plot_DOS(ax, energy_e, dos_e, base_e, dos_min, dos_max,
             flip=False, fill=None, line=None):
    """Draw a density of states on a bare side-panel axis.

    The contribution occupies the band between ``base_e`` and
    ``base_e + dos_e``. A non-zero ``base_e`` lets several DOSes be
    stacked on top of each other.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on. Its tick labels and top/right spines are removed.
    energy_e : array
        Energy grid.
    dos_e : array
        DOS values on ``energy_e``.
    base_e : array
        Baseline on which this DOS is drawn.
    dos_min, dos_max : float
        Limits of the DOS axis.
    flip : bool
        If True, energy runs along the y axis and DOS along the x axis.
    fill : dict or None
        Keyword arguments for the filled band. A falsy value disables it.
    line : dict or None
        Keyword arguments for the outline curve. A falsy value disables it.
    """
    # Bare axis: no tick labels, only left/bottom spines with ticks.
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if flip:
        # Rotated panel: energy on y, DOS on x.
        set_label = ax.set_xlabel
        fill_between = ax.fill_betweenx
        set_energy_lim = ax.set_ylim
        set_dos_lim = ax.set_xlim

        def plot(x, y, *args, **kwargs):
            return ax.plot(y, x, *args, **kwargs)
    else:
        set_label = ax.set_ylabel
        fill_between = ax.fill_between
        set_energy_lim = ax.set_xlim
        set_dos_lim = ax.set_ylim

        def plot(x, y, *args, **kwargs):
            return ax.plot(x, y, *args, **kwargs)
    if fill or line:
        # Upper edge of this contribution; identical to dos_e when the
        # baseline is zero (no stacking).
        top_e = dos_e + base_e
    if fill:
        fill_between(energy_e, base_e, top_e, **fill)
    if line:
        # Trace the same (possibly stacked) edge as the fill, not the
        # unstacked DOS, so outline and filled band stay aligned.
        plot(energy_e, top_e, **line)
    set_label('DOS', labelpad=0)
    set_energy_lim(np.take(energy_e, (0, -1)))
    set_dos_lim(dos_min, dos_max)
50
51
class TCM(object):
    """Transition contribution map (TCM) figure with DOS side panels.

    The layout comes from ``generate_gridspec``. The map is drawn in the
    large bottom-left cell, the occupied-state DOS in the cell above it,
    and the unoccupied-state DOS (rotated) in the cell to its right. The
    axes are created lazily on first access (see ``__getattr__``).

    Parameters
    ----------
    energy_o : array
        Energy grid of occupied states (x axis of the map).
    energy_u : array
        Energy grid of unoccupied states (y axis of the map).
    fermilevel : float
        Energy at which horizontal and vertical guide lines are drawn.
        The axis labels state eV, so presumably all energies are in eV.
    """

    def __init__(self, energy_o, energy_u, fermilevel):
        self.energy_o = energy_o
        self.energy_u = energy_u
        self.fermilevel = fermilevel

        # Running baselines for stacked DOS plots. Forced to floating point
        # so that accumulating DOS values also works for integer grids.
        self.base_o = np.zeros_like(energy_o, dtype=float)
        self.base_u = np.zeros_like(energy_u, dtype=float)

    def __getattr__(self, attr):
        """Create plotting axes lazily, the first time they are requested.

        This is only invoked for attributes not found by normal lookup.
        The three main panels share one GridSpec and are created together.
        Any other unknown attribute raises AttributeError.
        """
        if attr in ['ax_occ_dos', 'ax_unocc_dos', 'ax_tcm']:
            gs = generate_gridspec(hspace=0.05, wspace=0.05)
            self.ax_occ_dos = plt.subplot(gs[0])    # top-left cell
            self.ax_unocc_dos = plt.subplot(gs[3])  # bottom-right cell
            self.ax_tcm = plt.subplot(gs[2])        # bottom-left cell
            return getattr(self, attr)
        if attr in ['ax_spec']:
            gs = generate_gridspec(hspace=0.8, wspace=0.8)
            self.ax_spec = plt.subplot(gs[1])       # top-right cell
            return getattr(self, attr)
        if attr in ['ax_cbar']:
            # Small inset axis, in figure coordinates.
            self.ax_cbar = plt.axes((0.15, 0.6, 0.02, 0.1))
            return getattr(self, attr)
        raise AttributeError('%s object has no attribute %s' %
                             (repr(self.__class__.__name__), repr(attr)))

    def plot_TCM(self, tcm_ou, vmax='80%', vmin='symmetrize', cmap='seismic',
                 log=False, colorbar=True, lw=None):
        """Draw the transition contribution map on ``self.ax_tcm``.

        Parameters
        ----------
        tcm_ou : ndarray
            Map indexed ``[o, u]``. It is transposed before drawing, so it
            presumably has shape ``(len(energy_o), len(energy_u))``.
            Complex input is drawn with phase as hue and amplitude as
            brightness.
        vmax : float or str
            Upper colour limit. A string such as ``'80%'`` means that
            percentage of the largest absolute value in ``tcm_ou``.
        vmin : float or 'symmetrize'
            Lower colour limit. ``'symmetrize'`` uses ``-vmax``. It is
            unused for complex input. With ``log=True`` an explicit
            positive value is needed, because LogNorm cannot handle the
            negative limit produced by ``'symmetrize'``.
        cmap : str
            Colormap for real input.
        log : bool
            Use a logarithmic colour scale for real input.
        colorbar : bool
            Also draw a colorbar in ``self.ax_cbar``.
        lw : float or None
            Width of the Fermi-level guide lines. Defaults to
            ``rcParams['lines.linewidth']``.

        Raises
        ------
        ValueError
            If ``vmax`` is a string that does not end with ``'%'``.
        """
        if lw is None:
            lw = mpl.rcParams['lines.linewidth']
        energy_o = self.energy_o
        energy_u = self.energy_u
        fermilevel = self.fermilevel

        tcmmax = np.max(np.absolute(tcm_ou))
        print('tcmmax', tcmmax)

        ax = self.ax_tcm
        plt.sca(ax)
        plt.cla()
        if isinstance(vmax, str):
            # Percentage of the largest absolute value, e.g. '80%'.
            if not vmax.endswith('%'):
                raise ValueError('vmax given as a string must end with "%", '
                                 'got ' + repr(vmax))
            vmax = tcmmax * float(vmax[:-1]) / 100.0
        if isinstance(vmin, str) and vmin == 'symmetrize':
            vmin = -vmax
        if np.iscomplexobj(tcm_ou):
            # Zero amplitude maps to black, so use white guide lines.
            linecolor = 'w'
            from matplotlib.colors import hsv_to_rgb

            def transform_to_hsv(z, rmin, rmax, hue_start=90):
                # Amplitude -> brightness, clipped to [rmin, rmax].
                amp = np.clip(np.absolute(z), rmin, rmax)
                # Phase (degrees), rotated by hue_start -> hue in [0, 1).
                ph = np.angle(z, deg=True) + hue_start
                h = (ph % 360) / 360
                # Saturation above 1 is kept from the original tuning
                # (presumably for more vivid colours). It yields RGB
                # components outside [0, 1], which are clipped below
                # exactly as imshow would do, but without its warning.
                s = 1.85 * np.ones_like(h)
                span = rmax - rmin
                if span != 0:
                    v = (amp - rmin) / span
                else:
                    # e.g. an all-zero map (vmax == 0): black, not NaN.
                    v = np.zeros_like(amp)
                rgb = hsv_to_rgb(np.dstack((h, s, v)))
                return np.clip(rgb, 0, 1)

            img = transform_to_hsv(tcm_ou.T, 0, vmax)
            mappable = plt.imshow(img, origin='lower',
                                  extent=[energy_o[0], energy_o[-1],
                                          energy_u[0], energy_u[-1]],
                                  interpolation='bilinear')
        else:
            # The dark low end of 'magma' needs white guide lines.
            linecolor = 'w' if cmap == 'magma' else 'k'
            if log:
                norm = LogNorm(vmin=vmin, vmax=vmax)
            else:
                norm = Normalize(vmin=vmin, vmax=vmax)
            mappable = plt.pcolormesh(energy_o, energy_u, tcm_ou.T,
                                      cmap=cmap, rasterized=True, norm=norm)
        if colorbar:
            ax = self.ax_cbar
            ax.clear()
            # Pass the mappable explicitly rather than relying on the
            # "current image" of the figure.
            cb = plt.colorbar(mappable, cax=ax)
            cb.outline.set_edgecolor(linecolor)
            ax.tick_params(axis='both', colors=linecolor)
        ax = self.ax_tcm
        plt.sca(ax)
        plt.axhline(fermilevel, c=linecolor, lw=lw)
        plt.axvline(fermilevel, c=linecolor, lw=lw)

        ax.tick_params(axis='both', which='major', pad=2)
        plt.xlabel(r'Occ. energy $\varepsilon_{o}$ (eV)', labelpad=0)
        plt.ylabel(r'Unocc. energy $\varepsilon_{u}$ (eV)', labelpad=0)
        plt.xlim(np.take(energy_o, (0, -1)))
        plt.ylim(np.take(energy_u, (0, -1)))

    def plot_DOS(self, dos_o, dos_u, stack=False,
                 fill={'color': '0.8'}, line={'color': 'k'}):
        """Draw the occupied and unoccupied DOS in the two side panels.

        Parameters
        ----------
        dos_o, dos_u : array
            DOS on ``self.energy_o`` and ``self.energy_u``.
        stack : bool
            If True, draw on top of the DOSes stacked so far, then add
            these to the running baselines.
        fill, line : dict or None
            Keyword arguments for the filled band and the outline. A falsy
            value disables that element. The default dicts are only read,
            never mutated.
        """
        if stack:
            base_o = self.base_o
            base_u = self.base_u
        else:
            base_o = np.zeros_like(self.energy_o, dtype=float)
            base_u = np.zeros_like(self.energy_u, dtype=float)
        # Scale to the highest drawn value, including the stacked baseline,
        # so stacked contributions are not clipped. Without stacking the
        # baseline is zero and this is simply the DOS maximum.
        dos_min = 0.0
        dos_max = 1.01 * max(np.max(base_o + dos_o), np.max(base_u + dos_u))
        plot_DOS(self.ax_occ_dos, self.energy_o, dos_o, base_o,
                 dos_min, dos_max, flip=False, fill=fill, line=line)
        plot_DOS(self.ax_unocc_dos, self.energy_u, dos_u, base_u,
                 dos_min, dos_max, flip=True, fill=fill, line=line)
        if stack:
            self.base_o += dos_o
            self.base_u += dos_u

    def plot_spectrum(self):
        """Not implemented."""
        raise NotImplementedError()

    def plot_TCM_diagonal(self, energy, **kwargs):
        """Draw the line ``eps_u = eps_o + energy`` across the map.

        Keyword arguments are forwarded to ``Axes.plot``.
        """
        x_o = np.take(self.energy_o, (0, -1))
        self.ax_tcm.plot(x_o, x_o + energy, **kwargs)

    def set_title(self, *args, **kwargs):
        """Set the figure title on the top (occupied DOS) panel."""
        self.ax_occ_dos.set_title(*args, **kwargs)
179
180
class TCMPlotter(TCM):
    """TCM figure whose maps and DOSes are computed by a ``ksd`` object.

    Parameters
    ----------
    ksd : object
        Must provide ``get_eig_n``, ``get_TCM`` and ``get_weighted_DOS``.
        Presumably a Kohn-Sham decomposition; confirm against callers.
    energy_o, energy_u : array
        Energy grids for occupied and unoccupied states.
    sigma : float
        Passed through unchanged to ``ksd.get_TCM`` and
        ``ksd.get_weighted_DOS``. Its meaning is defined there.
    zero_fermilevel : bool
        Forwarded to ``ksd.get_eig_n``.
    """

    def __init__(self, ksd, energy_o, energy_u, sigma,
                 zero_fermilevel=True):
        eigenvalues, fermi = ksd.get_eig_n(zero_fermilevel)
        super(TCMPlotter, self).__init__(energy_o, energy_u, fermi)
        self.ksd = ksd
        self.sigma = sigma
        self.eig_n = eigenvalues

    def plot_TCM(self, weight_p, **kwargs):
        """Compute the map for ``weight_p`` via ``ksd`` and draw it.

        Extra keyword arguments go to ``TCM.plot_TCM``.
        """
        ksd = self.ksd
        tcm_ou = ksd.get_TCM(weight_p, self.eig_n,
                             self.energy_o, self.energy_u,
                             self.sigma)
        super(TCMPlotter, self).plot_TCM(tcm_ou, **kwargs)

    def plot_DOS(self, weight_n=1.0, **kwargs):
        """Compute the weighted DOS via ``ksd`` and draw it.

        Extra keyword arguments go to ``TCM.plot_DOS``.
        """
        ksd = self.ksd
        dos_o, dos_u = ksd.get_weighted_DOS(weight_n, self.eig_n,
                                            self.energy_o, self.energy_u,
                                            self.sigma)
        super(TCMPlotter, self).plot_DOS(dos_o, dos_u, **kwargs)
204