1#!/usr/bin/env python
2# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16# Author: Qiming Sun <osirpt.sun@gmail.com>
17#
18
19'''
20Extension to numpy and scipy
21'''
22
23import string
24import ctypes
25import math
26import numpy
27from pyscf.lib import misc
28from numpy import asarray  # For backward compatibility
29
# Operand-size threshold for _contract: if either operand has fewer elements
# than this, the contraction is delegated to numpy.einsum.  Read from the
# pyscf config attribute ``lib_einsum_max_size`` (default 2000).
EINSUM_MAX_SIZE = getattr(misc.__config__, 'lib_einsum_max_size', 2000)

try:
    # Import tblis before libnp_helper to avoid potential dl-loading conflicts
    from pyscf import tblis_einsum
    FOUND_TBLIS = True
except (ImportError, OSError):
    # tblis is optional; without it _contract uses its own reshape + dot path
    FOUND_TBLIS = False

# ctypes handle to the compiled helper library that provides the NP* kernels
# called throughout this module
_np_helper = misc.load_library('libnp_helper')
40
# Edge length of the square tiles used by the blocked transpose and
# hermitian-sum loops in this module
BLOCK_DIM = 192

# Relation between the upper and the lower triangle of a square matrix;
# passed as ``filltriu``/``hermi`` to the triangular helpers
PLAIN = 0
HERMITIAN = 1
ANTIHERMI = 2
SYMMETRIC = 3

# Levi-Civita symbol epsilon_{ijk}: +1 for the cyclic permutations of (0,1,2),
# -1 for the same permutations with their last two indices swapped, else 0
LeviCivita = numpy.zeros((3,3,3))
for _i, _j, _k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
    LeviCivita[_i, _j, _k] = 1
    LeviCivita[_i, _k, _j] = -1
del _i, _j, _k

# Pauli matrices sigma_x, sigma_y, sigma_z stacked along the first axis
PauliMatrices = numpy.array([[[0, 1], [1, 0]],       # x
                             [[0, -1j], [1j, 0]],    # y
                             [[1, 0], [0, -1]]],     # z
                            dtype=numpy.complex128)
57
def _einsum_path_fallback(subscripts, *operands, **kwargs):
    '''Greedy pairwise contraction path for numpy versions without
    numpy.einsum_path.

    Mimics the ``einsum_call=True`` output of numpy.einsum_path and returns
    ``(operands, contraction_list)``.  Each entry of contraction_list is
    ``(positions, removed_indices, einsum_str, remaining)``.  The operands at
    ``positions`` are popped, in the given order, from the running operand
    list; einsum_str is applied to them and the intermediate is appended to
    the end of that list.  At every step the pair of operands sharing the
    most indices is contracted.  If no pair shares an index (outer product),
    the first two operands are used.  Extra keyword arguments (e.g.
    ``optimize``, ``einsum_call``) are accepted and ignored.
    '''
    if '->' in subscripts:
        indices_in, idx_final = subscripts.split('->')
        indices_in = indices_in.split(',')
    else:
        idx_final = ''
        indices_in = subscripts.split(',')

    if len(indices_in) <= 2:
        idx_removed = set(indices_in[0]).intersection(indices_in[1])
        einsum_str = indices_in[1] + ',' + indices_in[0] + '->' + idx_final
        return operands, [((1,0), idx_removed, einsum_str, idx_final)]

    input_sets = [set(x) for x in indices_in]
    # Default pair: without it, operands sharing no index at all would leave
    # a/b undefined (NameError).
    a, b = 1, 0
    idx_removed = set()
    n_shared_max = 0
    for i in range(len(indices_in)):
        for j in range(i):
            shared = input_sets[i].intersection(input_sets[j])
            if len(shared) > n_shared_max:
                n_shared_max = len(shared)
                idx_removed = shared
                a, b = i, j

    # a > b, so popping a first leaves position b untouched
    idxA = indices_in.pop(a)
    idxB = indices_in.pop(b)
    rest_idx = ''.join(indices_in) + idx_final
    # The intermediate keeps only indices still needed by the remaining
    # operands or by the final output; sorted for a deterministic label.
    idx_out = input_sets[a].union(input_sets[b])
    idx_out = ''.join(sorted(idx_out.intersection(rest_idx)))

    indices_in.append(idx_out)
    einsum_str = idxA + ',' + idxB + '->' + idx_out
    einsum_args = _einsum_path_fallback(','.join(indices_in)+'->'+idx_final)[1]
    einsum_args.insert(0, ((a, b), idx_removed, einsum_str, indices_in))
    return operands, einsum_args

if hasattr(numpy, 'einsum_path'):
    _einsum_path = numpy.einsum_path
else:
    _einsum_path = _einsum_path_fallback

_numpy_einsum = numpy.einsum
def _contract(subscripts, *tensors, **kwargs):
    '''Contract exactly two tensors according to an einsum subscript string.

    Strategy, in order:
      1. If either operand has fewer than EINSUM_MAX_SIZE elements, use
         numpy.einsum.
      2. If tblis is available and the result dtype is double, use
         tblis_einsum.contract.
      3. If the pattern cannot be written as one matrix product, use
         numpy.einsum.  This covers implicit output, an index not occurring
         exactly twice, no shared index, one operand's indices being a subset
         of the other's, or repeated indices within one operand.
      4. Otherwise move the shared (summed) indices to the end of A and the
         front of B, reshape both operands into matrices, multiply them with
         :func:`dot`, then reshape and transpose into the requested output
         order.

    Args:
        subscripts : str
            einsum subscripts for two operands, e.g. 'ijk,kl->ijl'.  Spaces
            are ignored.
        tensors : two array-like operands (e.g. ndarray or HDF5 Dataset)

    Kwargs:
        DEBUG : bool
            Print the index bookkeeping of step 4.
        All kwargs (including DEBUG) are forwarded to tblis_einsum.contract.
        The numpy.einsum fallbacks ignore them.

    Raises:
        ValueError if a shared index has different lengths in A and B.
    '''
    idx_str = subscripts.replace(' ','')
    A, B = tensors
    # Call numpy.asarray because A or B may be HDF5 Datasets
    A = numpy.asarray(A)
    B = numpy.asarray(B)

    # small problem size
    if A.size < EINSUM_MAX_SIZE or B.size < EINSUM_MAX_SIZE:
        return _numpy_einsum(idx_str, A, B)

    C_dtype = numpy.result_type(A, B)
    if FOUND_TBLIS and C_dtype == numpy.double:
        # tblis is slow for complex type
        return tblis_einsum.contract(idx_str, A, B, **kwargs)

    # Only explicit-output strings in which every index occurs exactly twice
    # in total are mapped onto a matrix product; anything else -> numpy.
    indices  = idx_str.replace(',', '').replace('->', '')
    if '->' not in idx_str or any(indices.count(x) != 2 for x in set(indices)):
        return _numpy_einsum(idx_str, A, B)

    # Split the strings into a list of idx char's
    idxA, idxBC = idx_str.split(',')
    idxB, idxC = idxBC.split('->')
    assert len(idxA) == A.ndim
    assert len(idxB) == B.ndim

    uniq_idxA = set(idxA)
    uniq_idxB = set(idxB)
    # Find the shared indices being summed over
    shared_idxAB = uniq_idxA.intersection(uniq_idxB)

    if ((not shared_idxAB) or  # Indices must overlap
        # one operand is a subset of the other one (e.g. 'ijkl,jk->il')
        uniq_idxA == shared_idxAB or uniq_idxB == shared_idxAB or
        # repeated indices (e.g. 'iijk,kl->jl')
        len(idxA) != len(uniq_idxA) or len(idxB) != len(uniq_idxB)):
        return _numpy_einsum(idx_str, A, B)

    DEBUG = kwargs.get('DEBUG', False)

    if DEBUG:
        print("*** Einsum for", idx_str)
        print(" idxA =", idxA)
        print(" idxB =", idxB)
        print(" idxC =", idxC)

    # Get the range for each index and put it in a dictionary
    rangeA = dict(zip(idxA, A.shape))
    rangeB = dict(zip(idxB, B.shape))
    #rangeC = dict(zip(idxC, C.shape))
    if DEBUG:
        print("rangeA =", rangeA)
        print("rangeB =", rangeB)

    # idxAt/idxBt: index orders of A and B after transposition.  Shared
    # indices are visited in the same (set iteration) order for both
    # operands, so the trailing block of A matches the leading block of B.
    idxAt = list(idxA)
    idxBt = list(idxB)
    inner_shape = 1
    insert_B_loc = 0
    for n in shared_idxAB:
        if rangeA[n] != rangeB[n]:
            err = ('ERROR: In index string %s, the range of index %s is '
                   'different in A (%d) and B (%d)' %
                   (idx_str, n, rangeA[n], rangeB[n]))
            raise ValueError(err)

        # Bring idx all the way to the right for A
        # and to the left (but preserve order) for B
        idxA_n = idxAt.index(n)
        # len(idxAt)-1 is evaluated before pop() shortens the list, so the
        # insert position equals the new length, i.e. this appends n.
        idxAt.insert(len(idxAt)-1, idxAt.pop(idxA_n))

        idxB_n = idxBt.index(n)
        idxBt.insert(insert_B_loc, idxBt.pop(idxB_n))
        insert_B_loc += 1

        # Size of the summed (inner) matrix dimension
        inner_shape *= rangeA[n]

    if DEBUG:
        print("shared_idxAB =", shared_idxAB)
        print("inner_shape =", inner_shape)

    # Transpose the tensors into the proper order and reshape into matrices
    new_orderA = [idxA.index(idx) for idx in idxAt]
    new_orderB = [idxB.index(idx) for idx in idxBt]

    if DEBUG:
        print("Transposing A as", new_orderA)
        print("Transposing B as", new_orderB)
        print("Reshaping A as (-1,", inner_shape, ")")
        print("Reshaping B as (", inner_shape, ",-1)")

    # Shape/index order of the raw matrix product: A's free indices (they
    # precede the shared ones in idxAt) followed by B's free indices.
    shapeCt = list()
    idxCt = list()
    for idx in idxAt:
        if idx in shared_idxAB:
            break
        shapeCt.append(rangeA[idx])
        idxCt.append(idx)
    for idx in idxBt:
        if idx in shared_idxAB:
            continue
        shapeCt.append(rangeB[idx])
        idxCt.append(idx)
    # Permutation from the product's index order to the requested output order
    new_orderCt = [idxCt.index(idx) for idx in idxC]

    if A.size == 0 or B.size == 0:
        shapeCt = [shapeCt[i] for i in new_orderCt]
        return numpy.zeros(shapeCt, dtype=C_dtype)

    At = A.transpose(new_orderA)
    Bt = B.transpose(new_orderB)

    # Keep F order when the transposed view is F-contiguous (dot/ddot accept
    # both C- and F-ordered matrices); presumably this avoids a copy.
    if At.flags.f_contiguous:
        At = numpy.asarray(At.reshape(-1,inner_shape), order='F')
    else:
        At = numpy.asarray(At.reshape(-1,inner_shape), order='C')
    if Bt.flags.f_contiguous:
        Bt = numpy.asarray(Bt.reshape(inner_shape,-1), order='F')
    else:
        Bt = numpy.asarray(Bt.reshape(inner_shape,-1), order='C')

    return dot(At,Bt).reshape(shapeCt, order='A').transpose(new_orderCt)
223
def einsum(subscripts, *tensors, **kwargs):
    '''Perform a more efficient einsum via reshaping to a matrix multiply.

    Current differences compared to numpy.einsum:
    This assumes that each repeated index is actually summed (i.e. no 'i,i->i')
    and appears only twice (i.e. no 'ij,ik,il->jkl'). The output indices must
    be explicitly specified (i.e. 'ij,j->i' and not 'ij,j').

    Kwargs:
        _contract : callable, optional
            Pairwise contraction routine ``f(subscripts, A, B, **kwargs)``.
            It is used for every two-operand contraction, both for a plain
            two-tensor call and for each step of a multi-tensor contraction
            path.  Defaults to the module-level ``_contract``.
        Remaining kwargs are forwarded to numpy.einsum (zero/one operand or
        '...' subscripts) or to the contraction routine (two operands).

    Returns:
        ndarray
    '''
    # Resolve the default lazily so that an explicit override never touches
    # the module-level _contract.
    contract = kwargs.pop('_contract', None)
    if contract is None:
        contract = _contract

    subscripts = subscripts.replace(' ','')
    if len(tensors) <= 1 or '...' in subscripts:
        return _numpy_einsum(subscripts, *tensors, **kwargs)
    if len(tensors) == 2:
        return contract(subscripts, *tensors, **kwargs)

    # Three or more operands: follow numpy's optimized pairwise path and
    # evaluate each step with the pairwise contraction routine.
    tensors = list(tensors)
    contraction_list = _einsum_path(subscripts, *tensors, optimize=True,
                                    einsum_call=True)[1]
    for contraction in contraction_list:
        inds, einsum_str = contraction[0], contraction[2]
        tmp_operands = [tensors.pop(x) for x in inds]
        if len(tmp_operands) > 2:
            out = _numpy_einsum(einsum_str, *tmp_operands)
        else:
            out = contract(einsum_str, *tmp_operands)
        tensors.append(out)
    return out
258
259
260# 2d -> 1d or 3d -> 2d
def pack_tril(mat, axis=-1, out=None):
    '''flatten the lower triangular part of a matrix.
    Given mat, it returns mat[...,numpy.tril_indices(mat.shape[0])]

    Args:
        mat : 2D (nd, nd) array, or 3D array
            A 3D array is (count, nd, nd) when axis=-1, or (nd, nd, count)
            when axis=0.

    Kwargs:
        axis : int
            -1 (default) packs the last two dimensions; 0 packs the leading
            two dimensions of a 3D array.
        out : ndarray, optional
            Buffer that receives the result.

    Returns:
        (nd*(nd+1)//2,) for 2D input, (count, nd*(nd+1)//2) for axis=-1, or
        (nd*(nd+1)//2, count) for axis=0.

    Examples:

    >>> pack_tril(numpy.arange(9).reshape(3,3))
    [0 3 4 6 7 8]
    '''
    if mat.size == 0 and mat.ndim not in (2, 3):
        # Degenerate input with no meaningful triangular shape
        return numpy.zeros(mat.shape+(0,), dtype=mat.dtype)

    if mat.ndim == 2:
        count, nd = 1, mat.shape[0]
        shape = nd*(nd+1)//2
    else:
        count, nd = mat.shape[:2]
        shape = (count, nd*(nd+1)//2)

    if mat.ndim == 2 or axis == -1:
        mat = numpy.asarray(mat, order='C')
        out = numpy.ndarray(shape, mat.dtype, buffer=out)
        if mat.size == 0:
            # out already has the correctly shaped (empty) result
            return out
        if mat.dtype == numpy.double:
            fn = _np_helper.NPdpack_tril_2d
        elif mat.dtype == numpy.complex128:
            fn = _np_helper.NPzpack_tril_2d
        else:
            # Index the last two axes so that stacked (3D) input works too
            rows, cols = numpy.tril_indices(nd)
            out[:] = mat[..., rows, cols]
            return out

        fn(ctypes.c_int(count), ctypes.c_int(nd),
           out.ctypes.data_as(ctypes.c_void_p),
           mat.ctypes.data_as(ctypes.c_void_p))
        return out

    # pack the leading two dimensions
    assert(axis == 0)
    rows, cols = numpy.tril_indices(nd)
    packed = mat[rows, cols]
    if out is None:
        return packed
    out = numpy.ndarray(packed.shape, packed.dtype, buffer=out)
    out[:] = packed
    return out
300
301# 1d -> 2d or 2d -> 3d, write hermitian lower triangle to upper triangle
def unpack_tril(tril, filltriu=HERMITIAN, axis=-1, out=None):
    '''Reversed operation of pack_tril.

    Kwargs:
        filltriu : int

            | 0           Do not fill the upper triangular part, random number may appear
                          in the upper triangular part
            | 1 (default) Transpose the lower triangular part to fill the upper triangular part
                          (with complex conjugation, giving a hermitian matrix)
            | 2           Similar to filltriu=1, negative of the lower triangular part is assigned
                          to the upper triangular part to make the matrix anti-hermitian
            | 3           Transpose without conjugation (symmetric matrix).  For
                          double/complex128 input along the last axis the value is
                          passed to the compiled kernel as is.

        axis : int
            For 2D input, -1 (or 1) unpacks the last axis, giving
            (count, nd, nd); 0 unpacks the leading axis, giving (nd, nd, count).
        out : ndarray, optional
            Buffer that receives the result.

    Examples:

    >>> unpack_tril(numpy.arange(6.))
    [[ 0. 1. 3.]
     [ 1. 2. 4.]
     [ 3. 4. 5.]]
    >>> unpack_tril(numpy.arange(6.), 0)
    [[ 0. 0. 0.]
     [ 1. 2. 0.]
     [ 3. 4. 5.]]
    >>> unpack_tril(numpy.arange(6.), 2)
    [[ 0. -1. -3.]
     [ 1.  2. -4.]
     [ 3.  4.  5.]]
    '''
    tril = numpy.asarray(tril, order='C')
    if tril.ndim == 1:
        count, nd = 1, tril.size
        nd = int(numpy.sqrt(nd*2))
        shape = (nd,nd)
    elif tril.ndim == 2:
        if axis == 0:
            nd, count = tril.shape
        else:
            count, nd = tril.shape
        nd = int(numpy.sqrt(nd*2))
        shape = (count,nd,nd)
    else:
        raise NotImplementedError('unpack_tril for high dimension arrays')

    if tril.ndim == 2 and axis not in (-1, 1):
        # unpack the leading dimension (handled for every dtype)
        assert(axis == 0)
        return _unpack_tril_leading_axis(tril, nd, filltriu, out)

    out = numpy.ndarray(shape, tril.dtype, buffer=out)
    if tril.dtype != numpy.double and tril.dtype != numpy.complex128:
        idx, idy = numpy.tril_indices(nd)
        if filltriu == HERMITIAN:
            out[...,idy,idx] = tril.conj()
        elif filltriu == ANTIHERMI:
            out[...,idy,idx] = -tril.conj()
        else:
            out[...,idy,idx] = tril
        # Written last so the diagonal keeps the packed (lower) values
        out[...,idx,idy] = tril
        return out

    if tril.dtype == numpy.double:
        fn = _np_helper.NPdunpack_tril_2d
    else:
        fn = _np_helper.NPzunpack_tril_2d
    fn(ctypes.c_int(count), ctypes.c_int(nd),
       tril.ctypes.data_as(ctypes.c_void_p),
       out.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(filltriu))
    return out

def _unpack_tril_leading_axis(tril, nd, filltriu, out):
    '''Unpack a (nd*(nd+1)//2, ...) array along axis 0 into (nd, nd, ...).'''
    shape = (nd,nd) + tril.shape[1:]
    out = numpy.ndarray(shape, tril.dtype, buffer=out)
    idx = numpy.tril_indices(nd)
    if filltriu == HERMITIAN:
        for ij,(i,j) in enumerate(zip(*idx)):
            # Mirror first so that the diagonal keeps the packed value
            out[j,i] = tril[ij].conj()
            out[i,j] = tril[ij]
    elif filltriu == ANTIHERMI:
        for ij,(i,j) in enumerate(zip(*idx)):
            # Mirror first: otherwise the diagonal would be negated
            out[j,i] =-tril[ij].conj()
            out[i,j] = tril[ij]
    elif filltriu == SYMMETRIC:
        # Map every (i,j) of the square matrix to its packed position
        idxy = numpy.empty((nd,nd), dtype=numpy.intp)
        idxy[idx[0],idx[1]] = idxy[idx[1],idx[0]] = numpy.arange(nd*(nd+1)//2)
        numpy.take(tril, idxy, axis=0, out=out)
    else:
        out[idx] = tril
    return out
387
388# extract a row from a tril-packed matrix
def unpack_row(tril, row_id):
    '''Extract one row of the symmetric matrix stored as a packed lower
    triangle.  Equivalent to unpack_tril(tril)[row_id].

    Args:
        tril : 1D array of length nd*(nd+1)//2
        row_id : int

    Returns:
        1D array of length nd.

    Examples:

    >>> unpack_row(numpy.arange(6.), 0)
    [ 0. 1. 3.]
    >>> unpack_tril(numpy.arange(6.))[0]
    [ 0. 1. 3.]
    '''
    tril = numpy.ascontiguousarray(tril)
    nd = int(numpy.sqrt(tril.size*2))

    is_real = tril.dtype == numpy.double
    is_cplx = tril.dtype == numpy.complex128
    if not (is_real or is_cplx):
        # Entries left of the diagonal are stored contiguously in the row...
        start = row_id*(row_id+1)//2
        head = tril[start:start+row_id]
        # ...the rest are read down column row_id (symmetric fill)
        rows = numpy.arange(row_id, nd)
        tail = tril[rows*(rows+1)//2 + row_id]
        return numpy.append(head, tail)

    kernel = _np_helper.NPdunpack_row if is_real else _np_helper.NPzunpack_row
    row = numpy.empty(nd, tril.dtype)
    kernel.restype = ctypes.c_void_p
    kernel(ctypes.c_int(nd), ctypes.c_int(row_id),
           tril.ctypes.data_as(ctypes.c_void_p),
           row.ctypes.data_as(ctypes.c_void_p))
    return row
418
419# for i > j of 2d mat, mat[j,i] = mat[i,j]
def hermi_triu(mat, hermi=HERMITIAN, inplace=True):
    '''Use the elements of the lower triangular part to fill the upper triangular part.

    Kwargs:
        hermi : int

            | 1 (default) return a hermitian matrix
            | 2           return an anti-hermitian matrix

        inplace : bool
            If True (default), mat is overwritten; otherwise a copy is filled
            and returned.

    Returns:
        The filled square matrix (mat itself when inplace=True).

    Examples:

    >>> hermi_triu(numpy.arange(9.).reshape(3,3), 1)
    [[ 0.  3.  6.]
     [ 3.  4.  7.]
     [ 6.  7.  8.]]
    >>> hermi_triu(numpy.arange(9.).reshape(3,3), 2)
    [[ 0. -3. -6.]
     [ 3.  4. -7.]
     [ 6.  7.  8.]]
    '''
    assert(hermi == HERMITIAN or hermi == ANTIHERMI)
    if not inplace:
        mat = mat.copy('A')

    nd = mat.shape[0]
    assert(mat.size == nd**2)

    if mat.flags.c_contiguous and mat.dtype == numpy.double:
        fn = _np_helper.NPdsymm_triu
    elif mat.flags.c_contiguous and mat.dtype == numpy.complex128:
        fn = _np_helper.NPzhermi_triu
    else:
        # The kernels fill the upper triangle of a C-ordered buffer.  Running
        # them on mat.T for an F-ordered mat would read mat's *upper* triangle
        # instead, so F-ordered/strided input and other dtypes use numpy
        # indexing.
        rows, cols = numpy.triu_indices(nd, 1)
        lower = mat[cols, rows].conj()
        if hermi == ANTIHERMI:
            lower = -lower
        mat[rows, cols] = lower
        return mat

    fn.restype = ctypes.c_void_p
    fn(ctypes.c_int(nd), mat.ctypes.data_as(ctypes.c_void_p),
       ctypes.c_int(hermi))
    return mat
463
464
# Singular values at or below this are treated as linear dependencies
LINEAR_DEP_THRESHOLD = 1e-10
def solve_lineq_by_SVD(a, b):
    '''Solving a * x = b.  If a is a singular matrix, its small SVD values are
    neglected.

    Singular values not larger than LINEAR_DEP_THRESHOLD are dropped.  x is
    the pseudo-inverse solution restricted to the retained singular vectors.

    Args:
        a : (m, n) array
        b : (m,) or (m, nrhs) array

    Returns:
        x : (n,) or (n, nrhs) array.  All zeros if no singular value
        exceeds the threshold.
    '''
    t, w, vH = numpy.linalg.svd(a)
    b = numpy.asarray(b)
    idx = numpy.flatnonzero(w > LINEAR_DEP_THRESHOLD)
    if idx.size == 0:
        return numpy.zeros((vH.shape[1],) + b.shape[1:],
                           dtype=numpy.result_type(t, b))
    tb = numpy.dot(t[:,idx].T.conj(), b)
    # Row i of tb belongs to singular value w[i]; reshape w so a 2D
    # right-hand side is scaled per singular value, not per column.
    tb = tb / w[idx].reshape((-1,) + (1,) * (tb.ndim - 1))
    return numpy.dot(vH[idx,:].T.conj(), tb)
482
def take_2d(a, idx, idy, out=None):
    '''Equivalent to a[idx[:,None],idy] for a 2D array.

    Args:
        a : 2D array
        idx, idy : sequences of row / column indices

    Kwargs:
        out : ndarray, optional
            Buffer that receives the (len(idx), len(idy)) result.

    Examples:

    >>> a = numpy.arange(9.).reshape(3,3)
    >>> take_2d(a, [0,2], [0,2])
    [[ 0.  2.]
     [ 6.  8.]]
    '''
    a = numpy.asarray(a, order='C')
    out = numpy.ndarray((len(idx),len(idy)), dtype=a.dtype, buffer=out)
    idx = numpy.asarray(idx, dtype=numpy.int32)
    idy = numpy.asarray(idy, dtype=numpy.int32)
    if a.dtype == numpy.double:
        fn = _np_helper.NPdtake_2d
    elif a.dtype == numpy.complex128:
        fn = _np_helper.NPztake_2d
    else:
        # Write into out so that a caller-supplied buffer is honoured
        out[:] = a[idx[:,None],idy]
        return out
    fn(out.ctypes.data_as(ctypes.c_void_p),
       a.ctypes.data_as(ctypes.c_void_p),
       idx.ctypes.data_as(ctypes.c_void_p),
       idy.ctypes.data_as(ctypes.c_void_p),
       ctypes.c_int(out.shape[1]), ctypes.c_int(a.shape[1]),
       ctypes.c_int(idx.size), ctypes.c_int(idy.size))
    return out
510
def takebak_2d(out, a, idx, idy, thread_safe=True):
    '''Reverse operation of take_2d: add a into out[idx[:,None],idy].

    Contributions to repeated (idx, idy) pairs are accumulated, as
    numpy.add.at does.

    Args:
        out : C-contiguous 2D array, updated in place
        a : (len(idx), len(idy)) array, cast to out.dtype if necessary
        idx, idy : sequences of row / column indices

    Kwargs:
        thread_safe : bool
            Forwarded to the compiled kernels for double/complex128 arrays.
            The numpy path used for other dtypes is serial, so it is valid
            for either value.

    Returns:
        out

    Examples:

    >>> out = numpy.zeros((3,3))
    >>> takebak_2d(out, numpy.ones((2,2)), [0,2], [0,2])
    [[ 1.  0.  1.]
     [ 0.  0.  0.]
     [ 1.  0.  1.]]
    '''
    assert(out.flags.c_contiguous)
    a = numpy.asarray(a, order='C')
    if out.dtype != a.dtype:
        a = a.astype(out.dtype)
    # Convert before any use: the numpy path indexes with idx[:,None]
    idx = numpy.asarray(idx, dtype=numpy.int32)
    idy = numpy.asarray(idy, dtype=numpy.int32)
    if out.dtype == numpy.double:
        fn = _np_helper.NPdtakebak_2d
    elif out.dtype == numpy.complex128:
        fn = _np_helper.NPztakebak_2d
    else:
        numpy.add.at(out, (idx[:,None], idy), a)
        return out
    fn(out.ctypes.data_as(ctypes.c_void_p),
       a.ctypes.data_as(ctypes.c_void_p),
       idx.ctypes.data_as(ctypes.c_void_p),
       idy.ctypes.data_as(ctypes.c_void_p),
       ctypes.c_int(out.shape[1]), ctypes.c_int(a.shape[1]),
       ctypes.c_int(idx.size), ctypes.c_int(idy.size),
       ctypes.c_int(thread_safe))
    return out
547
def transpose(a, axes=None, inplace=False, out=None):
    '''Transposing an array with better memory efficiency

    Args:
        a : ndarray

    Kwargs:
        axes : tuple or list of ints, optional
            Ignored for 2D input.  A C-contiguous double/complex128 3D array
            is supported only with axes=(0,2,1).  Other non-2D inputs that are
            not C-contiguous, or have another dtype, are returned as the view
            a.transpose(axes).
        inplace : bool
            Transpose the square 2D matrix a in place, block by block.
        out : ndarray, optional
            Buffer for the result (not used when inplace=True or when a view
            is returned).

    Examples:

    >>> transpose(numpy.ones((3,2)))
    [[ 1.  1.  1.]
     [ 1.  1.  1.]]
    '''
    if axes is not None:
        axes = tuple(axes)

    if inplace:
        arow, acol = a.shape
        assert(arow == acol)
        # The scratch tile must share a's dtype; a float64 tile would drop
        # the imaginary part of complex input.
        tmp = numpy.empty((BLOCK_DIM,BLOCK_DIM), dtype=a.dtype)
        for c0, c1 in misc.prange(0, acol, BLOCK_DIM):
            for r0, r1 in misc.prange(0, c0, BLOCK_DIM):
                tmp[:c1-c0,:r1-r0] = a[c0:c1,r0:r1]
                a[c0:c1,r0:r1] = a[r0:r1,c0:c1].T
                a[r0:r1,c0:c1] = tmp[:c1-c0,:r1-r0].T
            # diagonal blocks
            a[c0:c1,c0:c1] = a[c0:c1,c0:c1].T
        return a

    if (not a.flags.c_contiguous
        or (a.dtype != numpy.double and a.dtype != numpy.complex128)):
        if a.ndim == 2:
            arow, acol = a.shape
            out = numpy.ndarray((acol,arow), a.dtype, buffer=out)
            # Tile-by-tile copy to limit cache misses on the strided side
            for c0, c1 in misc.prange(0, acol, BLOCK_DIM):
                for r0, r1 in misc.prange(0, arow, BLOCK_DIM):
                    out[c0:c1,r0:r1] = a[r0:r1,c0:c1].T
            return out
        else:
            return a.transpose(axes)

    if a.ndim == 2:
        arow, acol = a.shape
        c_shape = (ctypes.c_int*3)(1, arow, acol)
        out = numpy.ndarray((acol, arow), a.dtype, buffer=out)
    elif a.ndim == 3 and axes == (0,2,1):
        d0, arow, acol = a.shape
        c_shape = (ctypes.c_int*3)(d0, arow, acol)
        out = numpy.ndarray((d0, acol, arow), a.dtype, buffer=out)
    else:
        raise NotImplementedError

    assert(a.flags.c_contiguous)
    if a.dtype == numpy.double:
        fn = _np_helper.NPdtranspose_021
    else:
        fn = _np_helper.NPztranspose_021
    fn.restype = ctypes.c_void_p
    fn(c_shape, a.ctypes.data_as(ctypes.c_void_p),
       out.ctypes.data_as(ctypes.c_void_p))
    return out
610
def transpose_sum(a, inplace=False, out=None):
    '''Computing a + a.T with better memory efficiency

    Thin wrapper around hermi_sum with its default hermi argument (HERMITIAN).
    For complex input the result is therefore a + a.T.conj().

    Examples:

    >>> transpose_sum(numpy.arange(4.).reshape(2,2))
    [[ 0.  3.]
     [ 3.  6.]]
    '''
    result = hermi_sum(a, axes=None, inplace=inplace, out=out)
    return result
621
def hermi_sum(a, axes=None, hermi=HERMITIAN, inplace=False, out=None):
    '''Computing a + a.T.conj() with better memory efficiency

    Kwargs:
        axes : None for a square 2D matrix, or (0,2,1) to process every matrix
            of a C-contiguous 3D stack.
        hermi : int
            HERMITIAN (default) computes a + a.T.conj().  ANTIHERMI computes
            a - a.T.conj() in the numpy path.  The value is forwarded
            unchanged to the compiled kernels for C-contiguous
            double/complex128 input.
        inplace : bool
            Overwrite a with the result.
        out : ndarray, optional
            Buffer for the result (ignored when inplace=True).

    Examples:

    >>> hermi_sum(numpy.arange(4.).reshape(2,2))
    [[ 0.  3.]
     [ 3.  6.]]
    '''
    if inplace:
        out = a
    else:
        out = numpy.ndarray(a.shape, a.dtype, buffer=out)

    if (not a.flags.c_contiguous
        or (a.dtype != numpy.double and a.dtype != numpy.complex128)):
        if a.ndim == 2:
            anti = hermi == ANTIHERMI
            combine = numpy.subtract if anti else numpy.add
            na = a.shape[0]
            for c0, c1 in misc.prange(0, na, BLOCK_DIM):
                for r0, r1 in misc.prange(0, c0, BLOCK_DIM):
                    # tmp is a new array, so writing into out (possibly a
                    # itself when inplace=True) is safe afterwards
                    tmp = combine(a[r0:r1,c0:c1], a[c0:c1,r0:r1].conj().T)
                    mirror = tmp.T.conj()
                    out[c0:c1,r0:r1] = -mirror if anti else mirror
                    out[r0:r1,c0:c1] = tmp
                # diagonal blocks
                tmp = combine(a[c0:c1,c0:c1], a[c0:c1,c0:c1].conj().T)
                out[c0:c1,c0:c1] = tmp
            return out
        else:
            raise NotImplementedError('input array is not C-contiguous')

    if a.ndim == 2:
        assert(a.shape[0] == a.shape[1])
        c_shape = (ctypes.c_int*3)(1, a.shape[0], a.shape[1])
    elif a.ndim == 3 and axes == (0,2,1):
        assert(a.shape[1] == a.shape[2])
        c_shape = (ctypes.c_int*3)(*(a.shape))
    else:
        raise NotImplementedError

    assert(a.flags.c_contiguous)
    if a.dtype == numpy.double:
        fn = _np_helper.NPdsymm_021_sum
    else:
        fn = _np_helper.NPzhermi_021_sum
    fn(c_shape, a.ctypes.data_as(ctypes.c_void_p),
       out.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(hermi))
    return out
669
# NOTE: do not assume the arrays a and b are C-contiguous: C- or F-ordered
# arrays are passed to BLAS directly (F order via the transpose flag), and
# other layouts are copied first.  numpy.dot might not call an optimized BLAS.
def ddot(a, b, alpha=1, c=None, beta=0):
    '''Matrix-matrix multiplication for double precision arrays

    Computes ``alpha * a.dot(b) + beta * c`` via _dgemm.

    Args:
        a : (m, k) array
        b : (k, n) array

    Kwargs:
        alpha : scalar
        c : (m, n) array, optional
            Output buffer.  When omitted, a new array is allocated and beta
            is ignored.
        beta : scalar

    Returns:
        c
    '''
    def as_row_major(mat):
        # C order is used as is; F order is handed over as its C-ordered
        # transpose with the 'T' flag; any other layout is copied to C order.
        if mat.flags.c_contiguous:
            return mat, 'N'
        if mat.flags.f_contiguous:
            return mat.T, 'T'
        return numpy.asarray(mat, order='C'), 'N'

    m, k = a.shape[0], a.shape[1]
    n = b.shape[1]
    a, trans_a = as_row_major(a)

    assert(k == b.shape[0])
    b, trans_b = as_row_major(b)

    if c is not None:
        assert(c.shape == (m,n))
    else:
        c = numpy.empty((m,n))
        beta = 0

    return _dgemm(trans_a, trans_b, m, n, k, a, b, c, alpha, beta)
707
def zdot(a, b, alpha=1, c=None, beta=0):
    '''Matrix-matrix multiplication for double complex arrays

    Computes ``alpha * a.dot(b) + beta * c`` via _zgemm.  C- or F-contiguous
    operands are passed without copying (F order through the transpose
    flag).  Other layouts are first copied to C order, the same policy as
    ddot.

    Args:
        a : (m, k) complex128 array
        b : (k, n) complex128 array

    Kwargs:
        alpha : scalar
        c : (m, n) complex128 array, optional
            Output buffer.  When omitted, a new array is allocated and beta
            is ignored.
        beta : scalar

    Returns:
        c
    '''
    m = a.shape[0]
    k = a.shape[1]
    n = b.shape[1]
    if a.flags.c_contiguous:
        trans_a = 'N'
    elif a.flags.f_contiguous:
        trans_a = 'T'
        a = a.T
    else:
        a = numpy.asarray(a, order='C')
        trans_a = 'N'

    assert(k == b.shape[0])
    if b.flags.c_contiguous:
        trans_b = 'N'
    elif b.flags.f_contiguous:
        trans_b = 'T'
        b = b.T
    else:
        b = numpy.asarray(b, order='C')
        trans_b = 'N'

    if c is None:
        beta = 0
        c = numpy.empty((m,n), dtype=numpy.complex128)
    else:
        assert(c.shape == (m,n))

    return _zgemm(trans_a, trans_b, m, n, k, a, b, c, alpha, beta)
738
def dot(a, b, alpha=1, c=None, beta=0):
    '''Compute ``alpha * a.dot(b) + beta * c`` for 2D arrays.

    double x double and complex128 x complex128 products go to ddot/zdot.
    Mixed real/complex products are done as two real products (real and
    imaginary parts).  Any other dtype falls back to numpy.dot.

    Kwargs:
        alpha : scalar
        c : ndarray, optional
            Output buffer updated with the result.  When omitted, a new
            array is returned and beta has no effect.
        beta : scalar

    Returns:
        The result array (c when it was supplied and could be used in place).
    '''
    atype = a.dtype
    btype = b.dtype

    if atype == numpy.float64 and btype == numpy.float64:
        if c is None or c.dtype == numpy.float64:
            return ddot(a, b, alpha, c, beta)
        else:
            # Real product accumulated into a non-double c: update the real
            # part with ddot and apply beta to the imaginary part separately.
            cr = numpy.asarray(c.real, order='C')
            c.real = ddot(a, b, alpha, cr, beta)
            if numpy.iscomplexobj(c):
                if beta == 0:
                    c.imag = 0
                else:
                    c.imag *= beta
            return c

    elif atype == numpy.complex128 and btype == numpy.complex128:
        # Gauss's complex multiplication algorithm may affect numerical stability
        #k1 = ddot(a.real+a.imag, b.real.copy(), alpha)
        #k2 = ddot(a.real.copy(), b.imag-b.real, alpha)
        #k3 = ddot(a.imag.copy(), b.real+b.imag, alpha)
        #ab = k1-k3 + (k1+k2)*1j
        return zdot(a, b, alpha, c, beta)

    elif atype == numpy.float64 and btype == numpy.complex128:
        if b.flags.f_contiguous:
            order = 'F'
        else:
            order = 'C'
        cr = ddot(a, numpy.asarray(b.real, order=order), alpha)
        ci = ddot(a, numpy.asarray(b.imag, order=order), alpha)
        ab = numpy.ndarray(cr.shape, dtype=numpy.complex128, buffer=c)
        if c is None or beta == 0:
            ab.real = cr
            ab.imag = ci
        else:
            ab *= beta
            ab.real += cr
            ab.imag += ci
        return ab

    elif atype == numpy.complex128 and btype == numpy.float64:
        if a.flags.f_contiguous:
            order = 'F'
        else:
            order = 'C'
        cr = ddot(numpy.asarray(a.real, order=order), b, alpha)
        ci = ddot(numpy.asarray(a.imag, order=order), b, alpha)
        ab = numpy.ndarray(cr.shape, dtype=numpy.complex128, buffer=c)
        if c is None or beta == 0:
            ab.real = cr
            ab.imag = ci
        else:
            ab *= beta
            ab.real += cr
            ab.imag += ci
        return ab

    else:
        if c is None:
            c = numpy.dot(a, b) * alpha
        elif beta == 0:
            c[:] = numpy.dot(a, b) * alpha
        else:
            c *= beta
            c += numpy.dot(a, b) * alpha
        return c
802
803# a, b, c in C-order
def _dgemm(trans_a, trans_b, m, n, k, a, b, c, alpha=1, beta=0,
           offseta=0, offsetb=0, offsetc=0):
    '''c = alpha * op(a).dot(op(b)) + beta * c for C-ordered double arrays.

    op(x) is x for 'N' and x.T for 'T'.  c is updated in place and returned.
    If a or b is empty, the product vanishes and only beta * c remains.
    '''
    if a.size == 0 or b.size == 0:
        if beta == 0:
            c[:] = 0
        else:
            c[:] *= beta
        return c

    assert(a.flags.c_contiguous)
    assert(b.flags.c_contiguous)
    assert(c.flags.c_contiguous)
    # NPdgemm reads raw double buffers; reject other dtypes (as _zgemm does)
    assert(a.dtype == numpy.double)
    assert(b.dtype == numpy.double)
    assert(c.dtype == numpy.double)

    # a/b, m/n and their leading dimensions are passed swapped: computing the
    # transposed product c^T = op(b)^T op(a)^T lets a column-major GEMM work
    # on row-major data (presumably NPdgemm follows the Fortran BLAS layout).
    _np_helper.NPdgemm(ctypes.c_char(trans_b.encode('ascii')),
                       ctypes.c_char(trans_a.encode('ascii')),
                       ctypes.c_int(n), ctypes.c_int(m), ctypes.c_int(k),
                       ctypes.c_int(b.shape[1]), ctypes.c_int(a.shape[1]),
                       ctypes.c_int(c.shape[1]),
                       ctypes.c_int(offsetb), ctypes.c_int(offseta),
                       ctypes.c_int(offsetc),
                       b.ctypes.data_as(ctypes.c_void_p),
                       a.ctypes.data_as(ctypes.c_void_p),
                       c.ctypes.data_as(ctypes.c_void_p),
                       ctypes.c_double(alpha), ctypes.c_double(beta))
    return c
def _zgemm(trans_a, trans_b, m, n, k, a, b, c, alpha=1, beta=0,
           offseta=0, offsetb=0, offsetc=0):
    '''Complex counterpart of _dgemm: c = alpha * op(a).dot(op(b)) + beta * c
    for C-ordered complex128 arrays.  c is updated in place and returned.
    '''
    if 0 in (a.size, b.size):
        # Empty product: only the beta * c term survives
        if beta:
            c[:] *= beta
        else:
            c[:] = 0
        return c

    for arr in (a, b, c):
        assert(arr.flags.c_contiguous)
    for arr in (a, b, c):
        assert(arr.dtype == numpy.complex128)

    c_int = ctypes.c_int
    z_alpha = (ctypes.c_double*2)(alpha.real, alpha.imag)
    z_beta = (ctypes.c_double*2)(beta.real, beta.imag)
    # Operands are passed swapped (b before a, n before m) so that the
    # row-major product is computed as its transpose.
    _np_helper.NPzgemm(ctypes.c_char(trans_b.encode('ascii')),
                       ctypes.c_char(trans_a.encode('ascii')),
                       c_int(n), c_int(m), c_int(k),
                       c_int(b.shape[1]), c_int(a.shape[1]), c_int(c.shape[1]),
                       c_int(offsetb), c_int(offseta), c_int(offsetc),
                       b.ctypes.data_as(ctypes.c_void_p),
                       a.ctypes.data_as(ctypes.c_void_p),
                       c.ctypes.data_as(ctypes.c_void_p),
                       z_alpha, z_beta)
    return c
858
def frompointer(pointer, count, dtype=float):
    '''Interpret a buffer that the pointer refers to as a 1-dimensional array.

    The returned array shares memory with the buffer; no data is copied.

    Args:
        pointer : int or ctypes pointer
            address of a buffer
        count : int
            Number of items to read.
        dtype : data-type, optional
            Data-type of the returned array; default: float.

    Examples:

    >>> s = numpy.ones(3, dtype=numpy.int32)
    >>> ptr = s.ctypes.data
    >>> frompointer(ptr, count=6, dtype=numpy.int16)
    [1, 0, 1, 0, 1, 0]
    '''
    import operator
    try:
        address = operator.index(pointer)
    except TypeError:
        # ctypes pointer / c_void_p object: from_address needs an integer
        address = ctypes.cast(pointer, ctypes.c_void_p).value
    dtype = numpy.dtype(dtype)
    nbytes = count * dtype.itemsize
    buf = (ctypes.c_char * nbytes).from_address(address)
    a = numpy.ndarray(nbytes, dtype=numpy.int8, buffer=buf)
    return a.view(dtype)
882
def _norm_legacy(x, ord=None, axis=None):
    '''numpy.linalg.norm replacement for old numpy (1.6.*) whose norm does
    not accept the ``axis`` argument.

    With axis given (and ord None), returns the Euclidean norm along that
    axis.  Otherwise defers to numpy.linalg.norm(x, ord).
    '''
    if axis is None or ord is not None:
        return numpy.linalg.norm(x, ord)
    x = numpy.asarray(x)
    axes = string.ascii_lowercase[:x.ndim]
    target = axes.replace(axes[axis], '')
    descr = '%s,%s->%s' % (axes, axes, target)
    xx = numpy.einsum(descr, x.conj(), x)
    return numpy.sqrt(xx.real)

# Feature-detect axis support instead of comparing versions: distutils
# (LooseVersion) is no longer available in Python 3.12+.
try:
    numpy.linalg.norm(numpy.zeros((1, 1)), axis=0)
except TypeError:
    norm = _norm_legacy
else:
    norm = numpy.linalg.norm
900
def cond(x, p=None):
    '''Compute the condition number of a matrix or of a sequence of matrices.

    Args:
        x : array-like
            A single 2D matrix (ndarray or nested sequence), or a sequence
            of 2D matrices (which may have different shapes).
        p : optional
            Norm order. When given, x is passed unchanged to
            numpy.linalg.cond(x, p).

    Returns:
        A float for a single matrix, otherwise an ndarray holding one
        condition number per matrix.
    '''
    if p is not None:
        return numpy.linalg.cond(x, p)

    if isinstance(x, numpy.ndarray):
        ndim = x.ndim
    else:
        try:
            ndim = numpy.ndim(x)
        except ValueError:
            # Ragged sequence of matrices with different shapes; newer numpy
            # refuses to build an array from it. Treat it as a batch.
            ndim = None

    if ndim == 2:
        return numpy.linalg.cond(x)
    return numpy.asarray([numpy.linalg.cond(xi) for xi in x])
907
def cartesian_prod(arrays, out=None):
    '''
    Build every combination of elements taken one from each input array.
    http://stackoverflow.com/questions/1208118/using-numpy-to-build-an-array-of-all-combinations-of-two-arrays

    Args:
        arrays : list of array-like
            1-D arrays whose cartesian product is formed.
        out : buffer, optional
            Memory (e.g. an ndarray with enough room) used to store the
            product. A new array is allocated when omitted.

    Returns:
        out : ndarray
            2-D array of shape (M, len(arrays)); each row is one combination,
            with the last input varying fastest.

    Examples:

    >>> cartesian_prod(([1, 2, 3], [4, 5], [6, 7]))
    array([[1, 4, 6],
           [1, 4, 7],
           [1, 5, 6],
           [1, 5, 7],
           [2, 4, 6],
           [2, 4, 7],
           [2, 5, 6],
           [2, 5, 7],
           [3, 4, 6],
           [3, 4, 7],
           [3, 5, 6],
           [3, 5, 7]])

    '''
    arrays = [numpy.asarray(x) for x in arrays]
    result_dtype = numpy.result_type(*arrays)
    ncols = len(arrays)
    grid = numpy.ndarray([ncols] + [len(x) for x in arrays], result_dtype,
                         buffer=out)

    # grid[k] varies along axis k only: give arrays[k] one trailing singleton
    # axis per remaining input so that it broadcasts over them.
    for k, values in enumerate(arrays):
        grid[k] = values.reshape((-1,) + (1,) * (ncols - k - 1))

    # Flatten the index grid in C order; row r is the r-th combination.
    return grid.reshape(ncols, -1).T
952
def direct_sum(subscripts, *operands):
    '''Apply the summation over many operands with the einsum fashion.

    Each operand is labelled by a group of index symbols. Groups are joined
    by ``+`` (or ``,``) to add the operand, or by ``-`` to subtract it; a
    leading ``-`` negates the first operand. The result is the outer
    (broadcast) sum of all operands, laid out according to the optional
    ``->dest`` part. Without ``->``, the output indices are the input
    symbols in order of first appearance. A symbol repeated within one
    operand takes that operand's diagonal; a symbol shared by different
    operands takes the diagonal of the combined sum.

    Examples:

    >>> a = numpy.random.random((6,5))
    >>> b = numpy.random.random((4,3,2))
    >>> direct_sum('ij,klm->ijklm', a, b).shape
    (6, 5, 4, 3, 2)
    >>> direct_sum('ij,klm', a, b).shape
    (6, 5, 4, 3, 2)
    >>> direct_sum('i,j,klm->mjlik', a[0], a[:,0], b).shape
    (2, 6, 3, 5, 4)
    >>> direct_sum('ij-klm->ijklm', a, b).shape
    (6, 5, 4, 3, 2)
    >>> direct_sum('ij+klm', a, b).shape
    (6, 5, 4, 3, 2)
    >>> direct_sum('-i-j+klm->mjlik', a[0], a[:,0], b).shape
    (2, 6, 3, 5, 4)
    >>> c = numpy.random.random((3,5))
    >>> z = direct_sum('ik+jk->kij', a, c)  # This is slow
    >>> abs(a.T.reshape(5,6,1) + c.T.reshape(5,1,3) - z).sum()
    0.0
    '''

    def unique_symbols(symbols):
        '''Symbols with duplicates removed, kept in order of first appearance.'''
        uniq = []
        for s in symbols:
            if s not in uniq:
                uniq.append(s)
        return ''.join(uniq)

    def sign_and_symbs(subscript):
        ''' sign list and notation list'''
        subscript = subscript.replace(' ', '').replace(',', '+')

        if subscript[0] not in '+-':
            subscript = '+' + subscript
        sign = [x for x in subscript if x in '+-']

        symbs = subscript[1:].replace('-', '+').split('+')
        return sign, symbs

    if '->' in subscripts:
        src, dest = subscripts.split('->')
        sign, src = sign_and_symbs(src)
        dest = dest.replace(' ', '')
    else:
        sign, src = sign_and_symbs(subscripts)
        # Implicit output: each input symbol once. A repeated symbol in the
        # output spec would be rejected by the final einsum.
        dest = unique_symbols(''.join(src))
    assert(len(src) == len(operands))

    for i, symb in enumerate(src):
        op = numpy.asarray(operands[i])
        assert(len(symb) == op.ndim)
        # Deterministic (first-appearance) order instead of set ordering
        unisymb = unique_symbols(symb)
        if len(unisymb) != len(symb):
            # Repeated symbol inside one operand: take its diagonal first
            op = _numpy_einsum('->'.join((symb, unisymb)), op)
            src[i] = unisymb
        if i == 0:
            if sign[i] == '+':
                out = op
            else:
                out = -op
        else:
            # Trailing singleton axes let op broadcast along new dimensions
            out = out.reshape(out.shape+(1,)*op.ndim)
            if sign[i] == '+':
                out = out + op
            else:
                out = out - op

    # Reorder the axes (and take diagonals of shared symbols) into dest
    out = _numpy_einsum('->'.join((''.join(src), dest)), out)
    out.flags.writeable = True  # old numpy has this issue
    return out
1021
def condense(opname, a, loc_x, loc_y=None):
    '''Reduce each block of a 2D float64 array with a libnp_helper operator.

    .. code-block:: python

        for i,i0 in enumerate(loc_x):
            i1 = loc_x[i+1]
            for j,j0 in enumerate(loc_y):
                j1 = loc_y[j+1]
                out[i,j] = op(a[i0:i1,j0:j1])

    Args:
        opname : str
            Name of the reduction in libnp_helper; the ``NP_`` prefix is
            added when missing.
        a : 2D ndarray of float64
        loc_x, loc_y : sequence of int
            Block boundaries along rows and columns. loc_y defaults to loc_x.

    Returns:
        ndarray of shape (len(loc_x)-1, len(loc_y)-1). It is Fortran-ordered
        when a is Fortran- but not C-contiguous, C-ordered otherwise.
    '''
    assert(a.dtype == numpy.double)
    if not opname.startswith('NP_'):
        opname = 'NP_' + opname
    op = getattr(_np_helper, opname)
    if loc_y is None:
        loc_y = loc_x
    loc_x = numpy.asarray(loc_x, numpy.int32)
    loc_y = numpy.asarray(loc_y, numpy.int32)

    # NPcondense only receives a data pointer and the block boundaries, and
    # the non-Fortran path forces a C-ordered buffer, so the kernel presumably
    # reads `a` row-major. A Fortran-ordered `a` is the C-ordered `a.T`:
    # condense a.T with loc_x/loc_y swapped and transpose the result back.
    # Assumes op gives the same value for a block and its transpose
    # (true for sum/max/min-like reductions; TODO confirm for every NP_ op).
    transposed = a.flags.f_contiguous and not a.flags.c_contiguous
    if transposed:
        a = a.T
        loc_x, loc_y = loc_y, loc_x
    else:
        a = numpy.asarray(a, order='C')
    nloc_x = loc_x.size - 1
    nloc_y = loc_y.size - 1
    out = numpy.zeros((nloc_x, nloc_y))
    _np_helper.NPcondense(op, out.ctypes.data_as(ctypes.c_void_p),
                          a.ctypes.data_as(ctypes.c_void_p),
                          loc_x.ctypes.data_as(ctypes.c_void_p),
                          loc_y.ctypes.data_as(ctypes.c_void_p),
                          ctypes.c_int(nloc_x), ctypes.c_int(nloc_y))
    if transposed:
        out = out.T
    return out
1053
def expm(a):
    '''Equivalent to scipy.linalg.expm

    Matrix exponential by scaling and squaring: a truncated Taylor series of
    exp(a / 2**n) followed by n squarings.

    Args:
        a : (N, N) array-like
            Square matrix. Integer and single-precision real input is
            promoted to float64.

    Returns:
        (N, N) ndarray approximating exp(a).
    '''
    a = numpy.asarray(a)
    # Integer input would make the in-place scaling `b *= <float>` below
    # fail. Promote it to at least float64; float64 input is used as-is.
    a = numpy.asarray(a, dtype=numpy.result_type(a.dtype, numpy.double))

    bs = [a.copy()]  # bs[k] holds a**(k+1)
    n = 0
    for n in range(1, 14):
        bs.append(ddot(bs[-1], a))
        radius = (2**(n*(n+2))*math.factorial(n+2)*1e-16) **((n+1.)/(n+2))
        # Stop adding powers once every entry of the newest power lies
        # within the (heuristic) threshold radius.
        if bs[-1].max() < radius and -bs[-1].min() < radius:
            break

    # Taylor series of exp(a / 2**n): I + sum_i (a / 2**n)**(i+1) / (i+1)!
    y = numpy.eye(a.shape[0])
    fac = 1
    for i, b in enumerate(bs):
        fac *= i + 1
        b *= (.5**(n*(i+1)) / fac)
        y += b
    # Undo the scaling with n squarings. bs[0] is reused as the spare buffer;
    # ddot(y, y, 1, buf, 0) presumably stores y @ y into buf.
    buf, bs = bs[0], None
    for i in range(n):
        ddot(y, y, 1, buf, 0)
        y, buf = buf, y
    return y
1076
1077
class NPArrayWithTag(numpy.ndarray):
    '''ndarray subclass that carries arbitrary attributes ("tags").

    Tags live in the instance ``__dict__`` and are attached by the function
    tag_array. They survive pickling, but any ufunc result is returned as a
    plain ndarray without tags.
    '''

    # Customize __reduce__ and __setstate__ to keep tags after serialization
    # pickle.loads(pickle.dumps(tagarray)).  This is needed by mpi communication
    def __reduce__(self):
        pickled = numpy.ndarray.__reduce__(self)
        # Append the tag dict to ndarray's own state tuple
        state = pickled[2] + (self.__dict__,)
        return (pickled[0], pickled[1], state)

    def __setstate__(self, state):
        numpy.ndarray.__setstate__(self, state[0:-1])
        self.__dict__.update(state[-1])

    # Whenever the contents of the array was modified (through ufunc), the tag
    # should be expired. Overwrite the output of ufunc to restore ndarray type.
    def __array_wrap__(self, out, context=None, return_scalar=False):
        # NumPy >= 2.0 passes return_scalar (and deprecates implementations
        # that do not accept it); when True, a scalar is expected for a 0-d
        # result. NumPy 1.x never passes it.
        out = numpy.ndarray.__array_wrap__(self, out, context).view(numpy.ndarray)
        if return_scalar:
            return out[()]
        return out
1100
1101
def tag_array(a, **kwargs):
    '''Return a view of ``a`` carrying the keyword arguments as attributes.

    The result is always a fresh NPArrayWithTag view over the data of ``a``.
    When ``a`` is itself an NPArrayWithTag, its existing attributes are
    copied over first, so the keyword arguments add to (or override) them.
    '''
    # numpy.asarray drops the subclass (and with it the attributes), which is
    # why any inherited attributes are re-applied explicitly below.
    tagged = numpy.asarray(a).view(NPArrayWithTag)

    if isinstance(a, NPArrayWithTag):
        inherited = a.__dict__
    else:
        inherited = {}
    # dict.update applies `inherited` first, then the keyword arguments
    tagged.__dict__.update(inherited, **kwargs)
    return tagged
1115
1116#TODO: merge with function pbc.cc.kccsd_rhf.vector_to_nested
def split_reshape(a, shapes):
    '''
    Cut a flat vector into consecutive pieces and reshape each piece.

    ``shapes`` is either a single shape tuple (the whole vector is reshaped)
    or a (possibly nested) sequence of shape tuples; nested sequences
    produce nested lists of tensors.

    Returns:
        tensors : a list of tensors (nested like ``shapes``)

    Examples:

    >>> a = numpy.arange(12)
    >>> split_reshape(a, ((2,3), (1,), ((2,2), (1,1))))
    [array([[0, 1, 2],
            [3, 4, 5]]),
     array([6]),
     [array([[ 7,  8],
             [ 9, 10]]),
      array([[11]])]]
    '''
    def _is_leaf(shape):
        '''A shape whose first entry is an integer describes one tensor.'''
        return isinstance(shape[0], (int, numpy.integer))

    if _is_leaf(shapes):
        return a.reshape(shapes)

    def _consume(buf, layout):
        '''Split the head of buf per layout; return (tensors, elements used).'''
        pieces = []
        offset = 0
        for shape in layout:
            if _is_leaf(shape):
                nelem = numpy.prod(shape)
                pieces.append(buf[offset:offset+nelem].reshape(shape))
            else:
                nested, nelem = _consume(buf[offset:], shape)
                pieces.append(nested)
            offset += nelem
        return pieces, offset

    tensors, _ = _consume(a, shapes)
    return tensors
1153
if __name__ == '__main__':
    # Self-check: an optimized 4-operand contraction must agree with the same
    # contraction carried out one pair of operands at a time.
    a, b, c, d = (numpy.random.random(shape) for shape in
                  [(30,40,5,10), (10,30,5,20), (10,20,20), (20,10)])
    fused = einsum('ijkl,xiky,ayp,px->ajl', a, b, c, d, optimize=True)
    stepwise = einsum('ijkl,xiky->jlxy', a, b)
    stepwise = einsum('jlxy,ayp->jlxap', stepwise, c)
    stepwise = einsum('jlxap,px->ajl', stepwise, d)
    print(abs(stepwise - fused).max())
1164