1"""Quantum mechanical angular momemtum."""
2
3from sympy import (Add, binomial, cos, exp, Expr, factorial, I, Integer, Mul,
4                   pi, Rational, S, sin, simplify, sqrt, Sum, symbols, sympify,
5                   Tuple, Dummy)
6from sympy.matrices import zeros
7from sympy.printing.pretty.stringpict import prettyForm, stringPict
8from sympy.printing.pretty.pretty_symbology import pretty_symbol
9
10from sympy.physics.quantum.qexpr import QExpr
11from sympy.physics.quantum.operator import (HermitianOperator, Operator,
12                                            UnitaryOperator)
13from sympy.physics.quantum.state import Bra, Ket, State
14from sympy.functions.special.tensor_functions import KroneckerDelta
15from sympy.physics.quantum.constants import hbar
16from sympy.physics.quantum.hilbert import ComplexSpace, DirectSumHilbertSpace
17from sympy.physics.quantum.tensorproduct import TensorProduct
18from sympy.physics.quantum.cg import CG
19from sympy.physics.quantum.qapply import qapply
20
21
# Public names exported by ``from ... import *`` (order is significant and
# kept stable).
__all__ = [
    'm_values',
    'Jplus', 'Jminus', 'Jx', 'Jy', 'Jz', 'J2',
    'Rotation', 'WignerD',
    'JxKet', 'JxBra',
    'JyKet', 'JyBra',
    'JzKet', 'JzBra',
    'JzOp', 'J2Op',
    'JxKetCoupled', 'JxBraCoupled',
    'JyKetCoupled', 'JyBraCoupled',
    'JzKetCoupled', 'JzBraCoupled',
    'couple', 'uncouple',
]
49
50
def m_values(j):
    """Return the dimension and the allowed m values of a spin-j space.

    Parameters
    ==========

    j : Number
        Total angular momentum; ``2*j + 1`` must be a positive integer,
        i.e. ``j`` is a non-negative integer or half-integer.

    Returns
    =======

    ``(size, mvals)`` where ``size == 2*j + 1`` (a SymPy Integer) and
    ``mvals`` lists the m values in descending order ``[j, j - 1, ..., -j]``.

    Raises
    ======

    ValueError
        If ``2*j + 1`` is not a positive integer.
    """
    j = sympify(j)
    size = 2*j + 1
    if not size.is_Integer or not size > 0:
        raise ValueError(
            'Only integer or half-integer values allowed for j, got: %r' % j
        )
    return size, [j - i for i in range(int(size))]
59
60
61#-----------------------------------------------------------------------------
62# Spin Operators
63#-----------------------------------------------------------------------------
64
65
class SpinOpBase:
    """Base class for spin operators.

    This is a mixin combined with a SymPy quantum ``Operator`` subclass.
    Concrete operators define the class attribute ``_coord`` (the suffix
    printed after the operator name, e.g. ``'x'`` or ``'+'``) and, where the
    helpers below use them, ``basis`` (the basis name passed to ``rewrite``
    by ``_apply_op``) and ``matrix_element(j, m, jp, mp)``.
    """

    @classmethod
    def _eval_hilbert_space(cls, label):
        # We consider all j values so our space is infinite.
        return ComplexSpace(S.Infinity)

    @property
    def name(self):
        """The label of the operator, i.e. its first argument."""
        return self.args[0]

    def _print_contents(self, printer, *args):
        return '%s%s' % (self.name, self._coord)

    def _print_contents_pretty(self, printer, *args):
        a = stringPict(str(self.name))
        b = stringPict(self._coord)
        return self._print_subscript_pretty(a, b)

    def _print_contents_latex(self, printer, *args):
        return r'%s_%s' % (self.name, self._coord)

    def _represent_base(self, basis, **options):
        """Return the matrix of the operator within a single ``j`` block.

        ``j`` is read from ``options`` (default 1/2). Rows and columns follow
        the ordering returned by ``m_values``: m = j, j - 1, ..., -j.
        """
        j = options.get('j', S.Half)
        size, mvals = m_values(j)
        result = zeros(size, size)
        for p, m_row in enumerate(mvals):
            for q, m_col in enumerate(mvals):
                result[p, q] = self.matrix_element(j, m_row, j, m_col)
        return result

    def _apply_op(self, ket, orig_basis, **options):
        """Apply the operator by first rewriting ``ket`` into ``self.basis``.

        A single rewritten state is scaled by ``hbar*m``; a ``Sum`` is handled
        by ``_apply_operator_Sum``; any other expression is passed to
        ``qapply``. The result is rewritten back into ``orig_basis``.

        Raises NotImplementedError if the product is left unevaluated.
        """
        state = ket.rewrite(self.basis)
        # If the state has only one term
        if isinstance(state, State):
            ret = (hbar*state.m)*state
        # state is a linear combination of states
        elif isinstance(state, Sum):
            ret = self._apply_operator_Sum(state, **options)
        else:
            ret = qapply(self*state)
        if ret == self*state:
            raise NotImplementedError
        return ret.rewrite(orig_basis)

    def _apply_operator_JxKet(self, ket, **options):
        return self._apply_op(ket, 'Jx', **options)

    def _apply_operator_JxKetCoupled(self, ket, **options):
        return self._apply_op(ket, 'Jx', **options)

    def _apply_operator_JyKet(self, ket, **options):
        return self._apply_op(ket, 'Jy', **options)

    def _apply_operator_JyKetCoupled(self, ket, **options):
        return self._apply_op(ket, 'Jy', **options)

    def _apply_operator_JzKet(self, ket, **options):
        return self._apply_op(ket, 'Jz', **options)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        return self._apply_op(ket, 'Jz', **options)

    def _apply_operator_TensorProduct(self, tp, **options):
        """Apply the operator to a tensor product of states.

        The result is the expanded sum of terms in which the operator acts
        on exactly one factor of ``tp`` while the other factors are left
        unchanged. Only Jx, Jy and Jz support this; others raise
        NotImplementedError.
        """
        # Uncoupling operator is only easily found for coordinate basis spin operators
        # TODO: add methods for uncoupling operators
        if not isinstance(self, (JxOp, JyOp, JzOp)):
            raise NotImplementedError
        factors = tp.args
        result = []
        for n, factor in enumerate(factors):
            # Forward the caller's options so they reach the per-factor
            # application, as qapply would.
            acted = self._apply_operator(factor, **options)
            result.append(tp.__class__(*factors[:n], acted, *factors[n + 1:]))
        return Add(*result).expand()

    # TODO: move this to qapply_Mul
    def _apply_operator_Sum(self, s, **options):
        """Apply the operator inside the summand of the ``Sum`` ``s``."""
        new_func = qapply(self*s.function)
        if new_func == self*s.function:
            raise NotImplementedError
        return Sum(new_func, *s.limits)

    def _eval_trace(self, **options):
        # TODO: use options to use different j values
        # For now eval at default basis
        # NOTE(review): this rebuilds the matrix representation on every
        # call.
        return self._represent_default_basis().trace()
159
160
class JplusOp(SpinOpBase, Operator):
    """The J+ (raising) operator.

    Acting on a Jz eigenstate |j, m> it gives
    hbar*sqrt(j(j + 1) - m(m + 1))|j, m + 1>; for numeric j and m with
    m >= j the result is zero.
    """

    _coord = '+'

    basis = 'Jz'

    def _eval_commutator_JminusOp(self, other):
        # [J+, J-] = 2*hbar*Jz
        return 2*hbar*JzOp(self.name)

    @staticmethod
    def _raising_coefficient(j, m):
        """Return the J+ coefficient for |j, m>, or None if m is numerically >= j."""
        at_top = m.is_Number and j.is_Number and m >= j
        if at_top:
            return None
        return hbar*sqrt(j*(j + S.One) - m*(m + S.One))

    def _apply_operator_JzKet(self, ket, **options):
        j, m = ket.j, ket.m
        coeff = self._raising_coefficient(j, m)
        if coeff is None:
            return S.Zero
        return coeff*JzKet(j, m + S.One)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        j, m = ket.j, ket.m
        coeff = self._raising_coefficient(j, m)
        if coeff is None:
            return S.Zero
        return coeff*JzKetCoupled(j, m + S.One, ket.jn, ket.coupling)

    def matrix_element(self, j, m, jp, mp):
        """Return <j, m| J+ |jp, mp>."""
        amplitude = hbar*sqrt(j*(j + S.One) - mp*(mp + S.One))
        return amplitude*KroneckerDelta(m, mp + 1)*KroneckerDelta(j, jp)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(basis, **options)

    def _eval_rewrite_as_xyz(self, *args, **kwargs):
        # J+ = Jx + i*Jy
        label = args[0]
        return JxOp(label) + I*JyOp(label)
203
204
class JminusOp(SpinOpBase, Operator):
    """The J- (lowering) operator.

    Acting on a Jz eigenstate |j, m> it gives
    hbar*sqrt(j(j + 1) - m(m - 1))|j, m - 1>; for numeric j and m with
    m <= -j the result is zero.
    """

    _coord = '-'

    basis = 'Jz'

    @staticmethod
    def _lowering_coefficient(j, m):
        """Return the J- coefficient for |j, m>, or None if m is numerically <= -j."""
        at_bottom = m.is_Number and j.is_Number and m <= -j
        if at_bottom:
            return None
        return hbar*sqrt(j*(j + S.One) - m*(m - S.One))

    def _apply_operator_JzKet(self, ket, **options):
        j, m = ket.j, ket.m
        coeff = self._lowering_coefficient(j, m)
        if coeff is None:
            return S.Zero
        return coeff*JzKet(j, m - S.One)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        j, m = ket.j, ket.m
        coeff = self._lowering_coefficient(j, m)
        if coeff is None:
            return S.Zero
        return coeff*JzKetCoupled(j, m - S.One, ket.jn, ket.coupling)

    def matrix_element(self, j, m, jp, mp):
        """Return <j, m| J- |jp, mp>."""
        amplitude = hbar*sqrt(j*(j + S.One) - mp*(mp - S.One))
        return amplitude*KroneckerDelta(m, mp - 1)*KroneckerDelta(j, jp)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(basis, **options)

    def _eval_rewrite_as_xyz(self, *args, **kwargs):
        # J- = Jx - i*Jy
        label = args[0]
        return JxOp(label) - I*JyOp(label)
244
245
class JxOp(SpinOpBase, HermitianOperator):
    """The Jx operator.

    Its action on Jz eigenstates and its matrix in the Jz basis are built
    from the ladder operators through Jx = (J+ + J-)/2.
    """

    _coord = 'x'

    basis = 'Jx'

    def _eval_commutator_JyOp(self, other):
        # [Jx, Jy] = i*hbar*Jz
        return I*hbar*JzOp(self.name)

    def _eval_commutator_JzOp(self, other):
        # [Jx, Jz] = -i*hbar*Jy
        return -I*hbar*JyOp(self.name)

    def _ladder_sum(self, method, target, options):
        """Return (J+ + J-)/2 evaluated through the ladder operators' ``method``."""
        raised = getattr(JplusOp(self.name), method)(target, **options)
        lowered = getattr(JminusOp(self.name), method)(target, **options)
        return (raised + lowered)/Integer(2)

    def _apply_operator_JzKet(self, ket, **options):
        return self._ladder_sum('_apply_operator_JzKet', ket, options)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        return self._ladder_sum('_apply_operator_JzKetCoupled', ket, options)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._ladder_sum('_represent_JzOp', basis, options)

    def _eval_rewrite_as_plusminus(self, *args, **kwargs):
        label = args[0]
        return (JplusOp(label) + JminusOp(label))/2
279
280
class JyOp(SpinOpBase, HermitianOperator):
    """The Jy operator.

    Its action on Jz eigenstates and its matrix in the Jz basis are built
    from the ladder operators through Jy = (J+ - J-)/(2i).
    """

    _coord = 'y'

    basis = 'Jy'

    def _eval_commutator_JzOp(self, other):
        # [Jy, Jz] = i*hbar*Jx
        return I*hbar*JxOp(self.name)

    def _eval_commutator_JxOp(self, other):
        # [Jy, Jx] = -[Jx, Jy] = -i*hbar*Jz. (Previously returned J^2,
        # which contradicts JxOp._eval_commutator_JyOp.)
        return -I*hbar*JzOp(self.name)

    def _apply_operator_JzKet(self, ket, **options):
        jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options)
        jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options)
        return (jp - jm)/(Integer(2)*I)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options)
        jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options)
        return (jp - jm)/(Integer(2)*I)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        jp = JplusOp(self.name)._represent_JzOp(basis, **options)
        jm = JminusOp(self.name)._represent_JzOp(basis, **options)
        return (jp - jm)/(Integer(2)*I)

    def _eval_rewrite_as_plusminus(self, *args, **kwargs):
        return (JplusOp(args[0]) - JminusOp(args[0]))/(2*I)
314
315
class JzOp(SpinOpBase, HermitianOperator):
    """The Jz operator.

    Diagonal in the Jz basis: its matrix element between |j, m> states is
    hbar*m.
    """

    _coord = 'z'

    basis = 'Jz'

    def _eval_commutator_JxOp(self, other):
        # [Jz, Jx] = i*hbar*Jy
        return I*hbar*JyOp(self.name)

    def _eval_commutator_JyOp(self, other):
        # [Jz, Jy] = -i*hbar*Jx
        return -I*hbar*JxOp(self.name)

    def _eval_commutator_JplusOp(self, other):
        # [Jz, J+] = hbar*J+
        return hbar*JplusOp(self.name)

    def _eval_commutator_JminusOp(self, other):
        # [Jz, J-] = -hbar*J-
        return -hbar*JminusOp(self.name)

    def matrix_element(self, j, m, jp, mp):
        """Return <j, m| Jz |jp, mp> = hbar*mp*delta(m, mp)*delta(j, jp)."""
        diagonal = hbar*mp
        return diagonal*KroneckerDelta(m, mp)*KroneckerDelta(j, jp)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(basis, **options)
346
347
class J2Op(SpinOpBase, HermitianOperator):
    """The J^2 operator.

    It commutes with every component and ladder operator, and scales any
    angular momentum ket by hbar**2*j*(j + 1).
    """

    _coord = '2'

    def _eval_commutator_JxOp(self, other):
        return S.Zero

    def _eval_commutator_JyOp(self, other):
        return S.Zero

    def _eval_commutator_JzOp(self, other):
        return S.Zero

    def _eval_commutator_JplusOp(self, other):
        return S.Zero

    def _eval_commutator_JminusOp(self, other):
        return S.Zero

    def _scale_by_eigenvalue(self, ket):
        """Return ``ket`` multiplied by the J^2 eigenvalue hbar**2*j*(j + 1)."""
        j = ket.j
        return hbar**2*j*(j + 1)*ket

    def _apply_operator_JxKet(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def _apply_operator_JxKetCoupled(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def _apply_operator_JyKet(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def _apply_operator_JyKetCoupled(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def _apply_operator_JzKet(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        return self._scale_by_eigenvalue(ket)

    def matrix_element(self, j, m, jp, mp):
        """Return <j, m| J^2 |jp, mp>."""
        eigenvalue = (hbar**2)*j*(j + 1)
        return eigenvalue*KroneckerDelta(m, mp)*KroneckerDelta(j, jp)

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(basis, **options)

    def _print_contents_pretty(self, printer, *args):
        base = prettyForm(str(self.name))
        return base**prettyForm('2')

    def _print_contents_latex(self, printer, *args):
        return r'%s^2' % (str(self.name),)

    def _eval_rewrite_as_xyz(self, *args, **kwargs):
        # J^2 = Jx^2 + Jy^2 + Jz^2
        label = args[0]
        return JxOp(label)**2 + JyOp(label)**2 + JzOp(label)**2

    def _eval_rewrite_as_plusminus(self, *args, **kwargs):
        # J^2 = Jz^2 + (J+J- + J-J+)/2
        label = args[0]
        raising, lowering = JplusOp(label), JminusOp(label)
        return JzOp(label)**2 + \
            S.Half*(raising*lowering + lowering*raising)
419
420
class Rotation(UnitaryOperator):
    """Wigner D operator in terms of Euler angles.

    Defines the rotation operator in terms of the Euler angles defined by
    the z-y-z convention for a passive transformation. That is, the coordinate
    axes are rotated first about the z-axis, giving the new x'-y'-z' axes. Then
    this new coordinate system is rotated about the new y'-axis, giving new
    x''-y''-z'' axes. Then this new coordinate system is rotated about the
    z''-axis. Conventions follow those laid out in [1]_.

    Parameters
    ==========

    alpha : Number, Symbol
        First Euler angle
    beta : Number, Symbol
        Second Euler angle
    gamma : Number, Symbol
        Third Euler angle

    Examples
    ========

    A simple example rotation operator:

        >>> from sympy import pi
        >>> from sympy.physics.quantum.spin import Rotation
        >>> Rotation(pi, 0, pi/2)
        R(pi,0,pi/2)

    With symbolic Euler angles and calculating the inverse rotation operator:

        >>> from sympy import symbols
        >>> a, b, c = symbols('a b c')
        >>> Rotation(a, b, c)
        R(a,b,c)
        >>> Rotation(a, b, c).inverse()
        R(-c,-b,-a)

    See Also
    ========

    WignerD: Symbolic Wigner-D function
    D: Wigner-D function
    d: Wigner small-d function

    References
    ==========

    .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988.
    """

    @classmethod
    def _eval_args(cls, args):
        args = QExpr._eval_args(args)
        if len(args) != 3:
            # Wrap in a 1-tuple: '%r' % args on a bare tuple would raise
            # TypeError (argument-count mismatch) instead of this ValueError.
            raise ValueError('3 Euler angles required, got: %r' % (args,))
        return args

    @classmethod
    def _eval_hilbert_space(cls, label):
        # We consider all j values so our space is infinite.
        return ComplexSpace(S.Infinity)

    @property
    def alpha(self):
        """First Euler angle."""
        return self.label[0]

    @property
    def beta(self):
        """Second Euler angle."""
        return self.label[1]

    @property
    def gamma(self):
        """Third Euler angle."""
        return self.label[2]

    def _print_operator_name(self, printer, *args):
        return 'R'

    def _print_operator_name_pretty(self, printer, *args):
        if printer._use_unicode:
            return prettyForm('\N{SCRIPT CAPITAL R}' + ' ')
        else:
            return prettyForm("R ")

    def _print_operator_name_latex(self, printer, *args):
        return r'\mathcal{R}'

    def _eval_inverse(self):
        # The inverse applies the opposite rotations in reverse order.
        return Rotation(-self.gamma, -self.beta, -self.alpha)

    @classmethod
    def D(cls, j, m, mp, alpha, beta, gamma):
        """Wigner D-function.

        Returns an instance of the WignerD class corresponding to the Wigner-D
        function specified by the parameters.

        Parameters
        ===========

        j : Number
            Total angular momentum
        m : Number
            Eigenvalue of angular momentum along axis after rotation
        mp : Number
            Eigenvalue of angular momentum along rotated axis
        alpha : Number, Symbol
            First Euler angle of rotation
        beta : Number, Symbol
            Second Euler angle of rotation
        gamma : Number, Symbol
            Third Euler angle of rotation

        Examples
        ========

        Return the Wigner-D matrix element for a defined rotation, both
        numerical and symbolic:

            >>> from sympy.physics.quantum.spin import Rotation
            >>> from sympy import pi, symbols
            >>> alpha, beta, gamma = symbols('alpha beta gamma')
            >>> Rotation.D(1, 1, 0, pi, pi/2, -pi)
            WignerD(1, 1, 0, pi, pi/2, -pi)

        See Also
        ========

        WignerD: Symbolic Wigner-D function

        """
        return WignerD(j, m, mp, alpha, beta, gamma)

    @classmethod
    def d(cls, j, m, mp, beta):
        """Wigner small-d function.

        Returns an instance of the WignerD class corresponding to the Wigner-D
        function specified by the parameters with the alpha and gamma angles
        given as 0.

        Parameters
        ===========

        j : Number
            Total angular momentum
        m : Number
            Eigenvalue of angular momentum along axis after rotation
        mp : Number
            Eigenvalue of angular momentum along rotated axis
        beta : Number, Symbol
            Second Euler angle of rotation

        Examples
        ========

        Return the Wigner-D matrix element for a defined rotation, both
        numerical and symbolic:

            >>> from sympy.physics.quantum.spin import Rotation
            >>> from sympy import pi, symbols
            >>> beta = symbols('beta')
            >>> Rotation.d(1, 1, 0, pi/2)
            WignerD(1, 1, 0, 0, pi/2, 0)

        See Also
        ========

        WignerD: Symbolic Wigner-D function

        """
        return WignerD(j, m, mp, 0, beta, 0)

    def matrix_element(self, j, m, jp, mp):
        """Return <j, m| R |jp, mp> = delta(j, jp)*D(jp, m, mp, alpha, beta, gamma)."""
        result = self.__class__.D(
            jp, m, mp, self.alpha, self.beta, self.gamma
        )
        result *= KroneckerDelta(j, jp)
        return result

    def _represent_base(self, basis, **options):
        """Return the matrix of Wigner-D elements for a single ``j`` block.

        ``j`` is read from ``options`` (default 1/2). When the ``doit``
        option is truthy every element is evaluated with ``doit()``.
        """
        j = sympify(options.get('j', S.Half))
        # TODO: move evaluation up to represent function/implement elsewhere
        evaluate = sympify(options.get('doit'))
        size, mvals = m_values(j)
        result = zeros(size, size)
        for p, m_row in enumerate(mvals):
            for q, m_col in enumerate(mvals):
                me = self.matrix_element(j, m_row, j, m_col)
                result[p, q] = me.doit() if evaluate else me
        return result

    def _represent_default_basis(self, **options):
        return self._represent_JzOp(None, **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(basis, **options)

    def _rotate(self, j, m, make_state, dummy):
        """Return the rotation of |j, m> expanded over ``make_state(mp)``.

        For numeric ``j`` the expansion is an explicit ``Add`` over every mp
        from ``m_values(j)`` with evaluated Wigner-D coefficients. Otherwise
        an unevaluated ``Sum`` over mp from -j to j is returned, using a
        ``Dummy`` index when ``dummy`` is true and the symbol ``mp`` if not.
        """
        a, b, g = self.alpha, self.beta, self.gamma
        if j.is_number:
            _, mvals = m_values(j)
            return Add(*[Rotation.D(j, m, mp, a, b, g).doit()*make_state(mp)
                         for mp in mvals])
        mp = Dummy('mp') if dummy else symbols('mp')
        return Sum(Rotation.D(j, m, mp, a, b, g)*make_state(mp), (mp, -j, j))

    def _apply_operator_uncoupled(self, state, ket, *, dummy=True, **options):
        j = ket.j
        return self._rotate(j, ket.m, lambda mp: state(j, mp), dummy)

    def _apply_operator_JxKet(self, ket, **options):
        return self._apply_operator_uncoupled(JxKet, ket, **options)

    def _apply_operator_JyKet(self, ket, **options):
        return self._apply_operator_uncoupled(JyKet, ket, **options)

    def _apply_operator_JzKet(self, ket, **options):
        return self._apply_operator_uncoupled(JzKet, ket, **options)

    def _apply_operator_coupled(self, state, ket, *, dummy=True, **options):
        j, jn, coupling = ket.j, ket.jn, ket.coupling
        return self._rotate(
            j, ket.m, lambda mp: state(j, mp, jn, coupling), dummy)

    def _apply_operator_JxKetCoupled(self, ket, **options):
        return self._apply_operator_coupled(JxKetCoupled, ket, **options)

    def _apply_operator_JyKetCoupled(self, ket, **options):
        return self._apply_operator_coupled(JyKetCoupled, ket, **options)

    def _apply_operator_JzKetCoupled(self, ket, **options):
        return self._apply_operator_coupled(JzKetCoupled, ket, **options)
687
class WignerD(Expr):
    r"""Wigner-D function

    The Wigner D-function gives the matrix elements of the rotation
    operator in the jm-representation. For the Euler angles `\alpha`,
    `\beta`, `\gamma`, the D-function is defined such that:

    .. math ::
        <j,m| \mathcal{R}(\alpha, \beta, \gamma ) |j',m'> = \delta_{jj'} D(j, m, m', \alpha, \beta, \gamma)

    Where the rotation operator is as defined by the Rotation class [1]_.

    The Wigner D-function defined in this way factorizes as:

    .. math ::
        D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma}

    where d is the Wigner small-d function, which carries all of the
    dependence on the second Euler angle. Rotation.D returns the full
    D-function and Rotation.d returns the small-d function (the D-function
    with `\alpha = \gamma = 0`).

    Note that to evaluate the D-function, the j, m and mp parameters must
    be integer or half integer numbers.

    Parameters
    ==========

    j : Number
        Total angular momentum
    m : Number
        Eigenvalue of angular momentum along axis after rotation
    mp : Number
        Eigenvalue of angular momentum along rotated axis
    alpha : Number, Symbol
        First Euler angle of rotation
    beta : Number, Symbol
        Second Euler angle of rotation
    gamma : Number, Symbol
        Third Euler angle of rotation

    Examples
    ========

    Evaluate the Wigner-D matrix elements of a simple rotation:

        >>> from sympy.physics.quantum.spin import Rotation
        >>> from sympy import pi
        >>> rot = Rotation.D(1, 1, 0, pi, pi/2, 0)
        >>> rot
        WignerD(1, 1, 0, pi, pi/2, 0)
        >>> rot.doit()
        sqrt(2)/2

    Evaluate the Wigner-d matrix elements of a simple rotation

        >>> rot = Rotation.d(1, 1, 0, pi/2)
        >>> rot
        WignerD(1, 1, 0, 0, pi/2, 0)
        >>> rot.doit()
        -sqrt(2)/2

    See Also
    ========

    Rotation: Rotation operator

    References
    ==========

    .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988.
    """

    is_commutative = True

    def __new__(cls, *args, **hints):
        if len(args) != 6:
            # Wrap in a 1-tuple so '%s' formats the whole argument tuple; a
            # bare tuple would raise TypeError instead of this ValueError.
            raise ValueError('6 parameters expected, got %s' % (args,))
        args = sympify(args)
        evaluate = hints.get('evaluate', False)
        if evaluate:
            return Expr.__new__(cls, *args)._eval_wignerd()
        return Expr.__new__(cls, *args)

    @property
    def j(self):
        """Total angular momentum."""
        return self.args[0]

    @property
    def m(self):
        """Eigenvalue of angular momentum along axis after rotation."""
        return self.args[1]

    @property
    def mp(self):
        """Eigenvalue of angular momentum along rotated axis."""
        return self.args[2]

    @property
    def alpha(self):
        """First Euler angle."""
        return self.args[3]

    @property
    def beta(self):
        """Second Euler angle."""
        return self.args[4]

    @property
    def gamma(self):
        """Third Euler angle."""
        return self.args[5]

    def _latex(self, printer, *args):
        # Print as the small-d function when alpha and gamma are both zero.
        if self.alpha == 0 and self.gamma == 0:
            return r'd^{%s}_{%s,%s}\left(%s\right)' % (
                printer._print(self.j), printer._print(self.m),
                printer._print(self.mp), printer._print(self.beta))
        return r'D^{%s}_{%s,%s}\left(%s,%s,%s\right)' % (
            printer._print(self.j), printer._print(self.m),
            printer._print(self.mp), printer._print(self.alpha),
            printer._print(self.beta), printer._print(self.gamma))

    def _pretty(self, printer, *args):
        top = printer._print(self.j)

        bot = printer._print(self.m)
        bot = prettyForm(*bot.right(','))
        bot = prettyForm(*bot.right(printer._print(self.mp)))

        # Pad superscript and subscript to a common width.
        pad = max(top.width(), bot.width())
        top = prettyForm(*top.left(' '))
        bot = prettyForm(*bot.left(' '))
        if pad > top.width():
            top = prettyForm(*top.right(' '*(pad - top.width())))
        if pad > bot.width():
            bot = prettyForm(*bot.right(' '*(pad - bot.width())))
        if self.alpha == 0 and self.gamma == 0:
            args = printer._print(self.beta)
            s = stringPict('d' + ' '*pad)
        else:
            args = printer._print(self.alpha)
            args = prettyForm(*args.right(','))
            args = prettyForm(*args.right(printer._print(self.beta)))
            args = prettyForm(*args.right(','))
            args = prettyForm(*args.right(printer._print(self.gamma)))

            s = stringPict('D' + ' '*pad)

        args = prettyForm(*args.parens())
        s = prettyForm(*s.above(top))
        s = prettyForm(*s.below(bot))
        s = prettyForm(*s.right(args))
        return s

    def doit(self, **hints):
        hints['evaluate'] = True
        return WignerD(*self.args, **hints)

    def _eval_wignerd(self):
        """Evaluate the D-function explicitly.

        Raises ValueError if ``j`` is not numeric (unless all three angles
        are zero, in which case KroneckerDelta(m, mp) is returned).
        """
        j = sympify(self.j)
        m = sympify(self.m)
        mp = sympify(self.mp)
        alpha = sympify(self.alpha)
        beta = sympify(self.beta)
        gamma = sympify(self.gamma)
        # No rotation at all: the D-matrix is the identity.
        if alpha == 0 and beta == 0 and gamma == 0:
            return KroneckerDelta(m, mp)
        if not j.is_number:
            raise ValueError(
                'j parameter must be numerical to evaluate, got %s' % j)
        r = 0
        if beta == pi/2:
            # Varshalovich Equation (5), Section 4.16, page 113, setting
            # alpha=gamma=0.
            for k in range(2*j + 1):
                if k > j + mp or k > j - m or k < mp - m:
                    continue
                r += (S.NegativeOne)**k*binomial(j + mp, k)*binomial(j - mp, k + m - mp)
            r *= (S.NegativeOne)**(m - mp) / 2**j*sqrt(factorial(j + m) *
                    factorial(j - m) / (factorial(j + mp)*factorial(j - mp)))
        else:
            # Varshalovich Equation(5), Section 4.7.2, page 87, where we set
            # beta1=beta2=pi/2, and we get alpha=gamma=pi/2 and beta=phi+pi,
            # then we use the Eq. (1), Section 4.4. page 79, to simplify:
            # d(j, m, mp, beta+pi) = (-1)**(j-mp)*d(j, m, -mp, beta)
            # This happens to be almost the same as in Eq.(10), Section 4.16,
            # except that we need to substitute -mp for mp.
            _, mvals = m_values(j)
            for mpp in mvals:
                r += Rotation.d(j, m, mpp, pi/2).doit()*(cos(-mpp*beta) + I*sin(-mpp*beta))*\
                    Rotation.d(j, mpp, -mp, pi/2).doit()
            # Empirical normalization factor so results match Varshalovich
            # Tables 4.3-4.12
            # Note that this exact normalization does not follow from the
            # above equations
            r = r*I**(2*j - m - mp)*(-1)**(2*m)
            # Finally, simplify the whole expression
            r = simplify(r)
        # Attach the alpha and gamma phase factors.
        r *= exp(-I*m*alpha)*exp(-I*mp*gamma)
        return r
894
895
# Ready-made spin operator instances, all labelled 'J'.
Jx, Jy, Jz = JxOp('J'), JyOp('J'), JzOp('J')
J2 = J2Op('J')
Jplus, Jminus = JplusOp('J'), JminusOp('J')
902
903
904#-----------------------------------------------------------------------------
905# Spin States
906#-----------------------------------------------------------------------------
907
908
class SpinState(State):
    """Base class for angular momentum states.

    States are labelled by ``(j, m)``: the total angular momentum quantum
    number and the eigenvalue of the projection operator. Numeric labels are
    validated on construction; symbolic labels are accepted as given.
    """

    _label_separator = ','

    def __new__(cls, j, m):
        j = sympify(j)
        m = sympify(m)
        if j.is_number:
            if 2*j != int(2*j):
                raise ValueError(
                    'j must be integer or half-integer, got: %s' % j)
            if j < 0:
                raise ValueError('j must be >= 0, got: %s' % j)
        if m.is_number:
            if 2*m != int(2*m):
                raise ValueError(
                    'm must be integer or half-integer, got: %s' % m)
        if j.is_number and m.is_number:
            if abs(m) > j:
                raise ValueError('Allowed values for m are -j <= m <= j, got j, m: %s, %s' % (j, m))
            # j and m must be both integer or both half-integer, i.e. j - m
            # must itself be an integer
            if int(j - m) != j - m:
                raise ValueError('Both j and m must be integer or half-integer, got j, m: %s, %s' % (j, m))
        return State.__new__(cls, j, m)

    @property
    def j(self):
        """Total angular momentum quantum number (first label)."""
        return self.label[0]

    @property
    def m(self):
        """Projection quantum number (second label)."""
        return self.label[1]

    @classmethod
    def _eval_hilbert_space(cls, label):
        # A spin-j state lives in a (2j + 1)-dimensional complex space
        return ComplexSpace(2*label[0] + 1)

    def _represent_base(self, **options):
        """Column vector of this state in a basis rotated by Euler angles.

        The optional keyword options ``alpha``, ``beta`` and ``gamma``
        (default 0) are passed to ``Rotation.D``. Rows follow the order of
        ``m_values(j)``, i.e. from m = j down to m = -j.
        """
        j = self.j
        m = self.m
        alpha = sympify(options.get('alpha', 0))
        beta = sympify(options.get('beta', 0))
        gamma = sympify(options.get('gamma', 0))
        size, mvals = m_values(j)
        result = zeros(size, 1)
        # Only evaluate the Wigner D elements when m is numeric. For symbolic
        # m they are left unevaluated; _rewrite_basis relies on this to read
        # the rotation angles back out of the matrix elements.
        for p, mval in enumerate(mvals):
            if m.is_number:
                result[p, 0] = Rotation.D(
                    self.j, mval, self.m, alpha, beta, gamma).doit()
            else:
                result[p, 0] = Rotation.D(self.j, mval,
                                          self.m, alpha, beta, gamma)
        return result

    def _eval_rewrite_as_Jx(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jx, JxBra, **options)
        return self._rewrite_basis(Jx, JxKet, **options)

    def _eval_rewrite_as_Jy(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jy, JyBra, **options)
        return self._rewrite_basis(Jy, JyKet, **options)

    def _eval_rewrite_as_Jz(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jz, JzBra, **options)
        return self._rewrite_basis(Jz, JzKet, **options)

    def _rewrite_basis(self, basis, evect, **options):
        """Rewrite this state as a combination of eigenstates of ``basis``.

        ``evect`` is the state class used to build the result. For numeric
        ``j`` the result is an explicit linear combination. For symbolic
        ``j`` it is a ``Sum`` over a fresh index weighted by Wigner D
        functions.
        """
        from sympy.physics.quantum.represent import represent
        j = self.j
        args = self.args[2:]
        if j.is_number:
            if isinstance(self, CoupledSpinState):
                # Offset of the spin-j block inside the coupled representation
                if j == int(j):
                    start = j**2
                else:
                    start = (2*j - 1)*(2*j + 1)/4
            else:
                start = 0
            vect = represent(self, basis=basis, **options)
            result = Add(
                *[vect[start + i]*evect(j, j - i, *args) for i in range(2*j + 1)])
            if isinstance(self, CoupledSpinState) and options.get('coupled') is False:
                return uncouple(result)
            return result
        else:
            # Pick a summation index that does not clash with any symbol
            # already present in the state: try mi, then mi1, mi2, ... until
            # a free name is found.
            i = 0
            mi = symbols('mi')
            while self.subs(mi, 0) != self:
                i += 1
                mi = symbols('mi%d' % i)
            # TODO: better way to get angles of rotation
            # Represent a throwaway state with symbolic m, so that its matrix
            # element stays an unevaluated WignerD whose args[3:6] are the
            # rotation angles.
            if isinstance(self, CoupledSpinState):
                test_args = (0, mi, (0, 0))
            else:
                test_args = (0, mi)
            if isinstance(self, Ket):
                angles = represent(
                    self.__class__(*test_args), basis=basis)[0].args[3:6]
            else:
                # For bras the WignerD sits one level deeper in the element
                angles = represent(self.__class__(
                    *test_args), basis=basis)[0].args[0].args[3:6]
            if angles == (0, 0, 0):
                # Already expressed in the requested basis
                return self
            else:
                state = evect(j, mi, *args)
                lt = Rotation.D(j, mi, self.m, *angles)
                return Sum(lt*state, (mi, -j, j))

    def _eval_innerproduct_JxBra(self, bra, **hints):
        result = KroneckerDelta(self.j, bra.j)
        if bra.dual_class() is not self.__class__:
            # Different bases: take the component of this state's Jx
            # representation at the row for bra.m (rows run from m = j down)
            result *= self._represent_JxOp(None)[bra.j - bra.m]
        else:
            result *= KroneckerDelta(
                self.j, bra.j)*KroneckerDelta(self.m, bra.m)
        return result

    def _eval_innerproduct_JyBra(self, bra, **hints):
        result = KroneckerDelta(self.j, bra.j)
        if bra.dual_class() is not self.__class__:
            result *= self._represent_JyOp(None)[bra.j - bra.m]
        else:
            result *= KroneckerDelta(
                self.j, bra.j)*KroneckerDelta(self.m, bra.m)
        return result

    def _eval_innerproduct_JzBra(self, bra, **hints):
        result = KroneckerDelta(self.j, bra.j)
        if bra.dual_class() is not self.__class__:
            result *= self._represent_JzOp(None)[bra.j - bra.m]
        else:
            result *= KroneckerDelta(
                self.j, bra.j)*KroneckerDelta(self.m, bra.m)
        return result

    def _eval_trace(self, bra, **hints):

        # One way to implement this method is to assume the basis set k is
        # passed.
        # Then we can apply the discrete form of the trace formula here:
        # Tr(|i><j| ) = \Sum_k <k|i><j|k>
        # then we do qapply() on each inner product and sum over them.

        # OR

        # Inner product of |i><j| = Trace(Outer Product).
        # We could just use this unless there are cases when this is not true.

        return (bra*self).doit()
1064
1065
class JxKet(SpinState, Ket):
    """Eigenket of the Jx operator.

    See JzKet for the usage of spin eigenstates.

    See Also
    ========

    JzKet: Usage of spin states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this ket
        return JxBra

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled ket
        return JxKetCoupled

    def _represent_default_basis(self, **options):
        # Without an explicit basis, use this state's own (Jx) basis
        return self._represent_JxOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_base(**options)

    def _represent_JyOp(self, basis, **options):
        # Euler angles below are forwarded to Rotation.D via _represent_base
        return self._represent_base(alpha=pi*Rational(3, 2), **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(beta=pi/2, **options)
1097
1098
class JxBra(SpinState, Bra):
    """Eigenbra of the Jx operator.

    See JzKet for the usage of spin eigenstates.

    See Also
    ========

    JzKet: Usage of spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this bra
        return JxKet

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled bra
        return JxBraCoupled
1118
1119
class JyKet(SpinState, Ket):
    """Eigenket of the Jy operator.

    See JzKet for the usage of spin eigenstates.

    See Also
    ========

    JzKet: Usage of spin states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this ket
        return JyBra

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled ket
        return JyKetCoupled

    def _represent_default_basis(self, **options):
        # Without an explicit basis, use this state's own (Jy) basis
        return self._represent_JyOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        # Euler angles below are forwarded to Rotation.D via _represent_base
        return self._represent_base(gamma=pi/2, **options)

    def _represent_JyOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_base(**options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_base(alpha=pi*Rational(3, 2), beta=-pi/2, gamma=pi/2, **options)
1151
1152
class JyBra(SpinState, Bra):
    """Eigenbra of the Jy operator.

    See JzKet for the usage of spin eigenstates.

    See Also
    ========

    JzKet: Usage of spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this bra
        return JyKet

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled bra
        return JyBraCoupled
1172
1173
class JzKet(SpinState, Ket):
    """Eigenket of Jz.

    Spin state which is an eigenstate of the Jz operator. Uncoupled states,
    that is states representing the interaction of multiple separate spin
    states, are defined as a tensor product of states.

    Parameters
    ==========

    j : Number, Symbol
        Total spin angular momentum
    m : Number, Symbol
        Eigenvalue of the Jz spin operator

    Examples
    ========

    *Normal States:*

    Defining simple spin states, both numerical and symbolic:

        >>> from sympy.physics.quantum.spin import JzKet, JxKet
        >>> from sympy import symbols
        >>> JzKet(1, 0)
        |1,0>
        >>> j, m = symbols('j m')
        >>> JzKet(j, m)
        |j,m>

    Rewriting the JzKet in terms of eigenkets of the Jx operator.
    Note that the resulting eigenstates are JxKets:

        >>> JzKet(1,1).rewrite("Jx")
        |1,-1>/2 - sqrt(2)*|1,0>/2 + |1,1>/2

    Get the vector representation of a state in terms of the basis elements
    of the Jx operator:

        >>> from sympy.physics.quantum.represent import represent
        >>> from sympy.physics.quantum.spin import Jx, Jz
        >>> represent(JzKet(1,-1), basis=Jx)
        Matrix([
        [      1/2],
        [sqrt(2)/2],
        [      1/2]])

    Apply inner products between states:

        >>> from sympy.physics.quantum.innerproduct import InnerProduct
        >>> from sympy.physics.quantum.spin import JxBra
        >>> i = InnerProduct(JxBra(1,1), JzKet(1,1))
        >>> i
        <1,1|1,1>
        >>> i.doit()
        1/2

    *Uncoupled States:*

    Define an uncoupled state as a TensorProduct between two Jz eigenkets:

        >>> from sympy.physics.quantum.tensorproduct import TensorProduct
        >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2')
        >>> TensorProduct(JzKet(1,0), JzKet(1,1))
        |1,0>x|1,1>
        >>> TensorProduct(JzKet(j1,m1), JzKet(j2,m2))
        |j1,m1>x|j2,m2>

    A TensorProduct can be rewritten, in which case the eigenstates that make
    up the tensor product are rewritten to the new basis:

        >>> TensorProduct(JzKet(1,1),JxKet(1,1)).rewrite('Jz')
        |1,1>x|1,-1>/2 + sqrt(2)*|1,1>x|1,0>/2 + |1,1>x|1,1>/2

    The represent method for TensorProducts gives the vector representation of
    the state. Note that the state in the product basis is the equivalent of the
    tensor product of the vector representation of the component eigenstates:

        >>> represent(TensorProduct(JzKet(1,0),JzKet(1,1)))
        Matrix([
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0],
        [0]])
        >>> represent(TensorProduct(JzKet(1,1),JxKet(1,1)), basis=Jz)
        Matrix([
        [      1/2],
        [sqrt(2)/2],
        [      1/2],
        [        0],
        [        0],
        [        0],
        [        0],
        [        0],
        [        0]])

    See Also
    ========

    JzKetCoupled: Coupled eigenstates
    sympy.physics.quantum.tensorproduct.TensorProduct: Used to specify uncoupled states
    uncouple: Uncouples states given coupling parameters
    couple: Couples uncoupled states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this ket
        return JzBra

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled ket
        return JzKetCoupled

    def _represent_default_basis(self, **options):
        # Without an explicit basis, use this state's own (Jz) basis
        return self._represent_JzOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        # Euler angles below are forwarded to Rotation.D via _represent_base
        return self._represent_base(beta=pi*Rational(3, 2), **options)

    def _represent_JyOp(self, basis, **options):
        return self._represent_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options)

    def _represent_JzOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_base(**options)
1304
1305
class JzBra(SpinState, Bra):
    """Eigenbra of the Jz operator.

    See JzKet for the usage of spin eigenstates.

    See Also
    ========

    JzKet: Usage of spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this bra
        return JzKet

    @classmethod
    def coupled_class(cls):
        # Coupled-state counterpart of this uncoupled bra
        return JzBraCoupled
1325
1326
1327# Method used primarily to create coupled_n and coupled_jn by __new__ in
1328# CoupledSpinState
1329# This same method is also used by the uncouple method, and is separated from
1330# the CoupledSpinState class to maintain consistency in defining coupling
1331def _build_coupled(jcoupling, length):
1332    n_list = [ [n + 1] for n in range(length) ]
1333    coupled_jn = []
1334    coupled_n = []
1335    for n1, n2, j_new in jcoupling:
1336        coupled_jn.append(j_new)
1337        coupled_n.append( (n_list[n1 - 1], n_list[n2 - 1]) )
1338        n_sort = sorted(n_list[n1 - 1] + n_list[n2 - 1])
1339        n_list[n_sort[0] - 1] = n_sort
1340    return coupled_n, coupled_jn
1341
1342
class CoupledSpinState(SpinState):
    """Base class for coupled angular momentum states.

    The label is ``(j, m, jn, jcoupling)``:

    - ``j`` and ``m`` are the total angular momentum quantum numbers.
    - ``jn`` is a Tuple of the j values of the individual spaces.
    - ``jcoupling`` is a Tuple of ``(n1, n2, j12)`` triples describing how
      those spaces are coupled.

    See JzKetCoupled for a full description of the coupling parameters.
    """

    def __new__(cls, j, m, jn, *jcoupling):
        # Check j and m values using SpinState
        SpinState(j, m)
        # jn is indexed below to build the default coupling scheme, so check
        # its type first to report the intended error message
        if not isinstance(jn, (list, tuple, Tuple)):
            raise TypeError('jn must be Tuple, list or tuple, got %s' %
                            jn.__class__.__name__)
        # Build and check coupling scheme from arguments
        if len(jcoupling) == 0:
            # Use default coupling scheme: couple the spaces in order, taking
            # the maximal value for each intermediate j
            jcoupling = []
            for n in range(2, len(jn)):
                jcoupling.append( (1, n, Add(*[jn[i] for i in range(n)])) )
            jcoupling.append( (1, len(jn), j) )
        elif len(jcoupling) == 1:
            # Use specified coupling scheme
            jcoupling = jcoupling[0]
        else:
            raise TypeError("CoupledSpinState only takes 3 or 4 arguments, got: %s" % (len(jcoupling) + 3) )
        # Check arguments have correct form
        if not isinstance(jcoupling, (list, tuple, Tuple)):
            raise TypeError('jcoupling must be Tuple, list or tuple, got %s' %
                            jcoupling.__class__.__name__)
        if not all(isinstance(term, (list, tuple, Tuple)) for term in jcoupling):
            raise TypeError(
                'All elements of jcoupling must be list, tuple or Tuple')
        if not len(jn) - 1 == len(jcoupling):
            raise ValueError('jcoupling must have length of %d, got %d' %
                             (len(jn) - 1, len(jcoupling)))
        if not all(len(x) == 3 for x in jcoupling):
            raise ValueError('All elements of jcoupling must have length 3')
        # Build sympified args
        j = sympify(j)
        m = sympify(m)
        jn = Tuple( *[sympify(ji) for ji in jn] )
        jcoupling = Tuple( *[Tuple(sympify(
            n1), sympify(n2), sympify(ji)) for (n1, n2, ji) in jcoupling] )
        # Check values in coupling scheme give physical state
        if any(2*ji != int(2*ji) for ji in jn if ji.is_number):
            raise ValueError('All elements of jn must be integer or half-integer, got: %s' % jn)
        if any(n1 != int(n1) or n2 != int(n2) for (n1, n2, _) in jcoupling):
            raise ValueError('Indices in jcoupling must be integers')
        if any(n1 < 1 or n2 < 1 or n1 > len(jn) or n2 > len(jn) for (n1, n2, _) in jcoupling):
            raise ValueError('Indices must be between 1 and the number of coupled spin spaces')
        if any(2*ji != int(2*ji) for (_, _, ji) in jcoupling if ji.is_number):
            raise ValueError('All coupled j values in coupling scheme must be integer or half-integer')
        coupled_n, coupled_jn = _build_coupled(jcoupling, len(jn))
        # jvals[k] tracks the current j of the group referenced by index k+1
        jvals = list(jn)
        for n, (n1, n2) in enumerate(coupled_n):
            j1 = jvals[min(n1) - 1]
            j2 = jvals[min(n2) - 1]
            j3 = coupled_jn[n]
            if all(sympify(x).is_number for x in (j1, j2, j3)):
                # Triangle inequality |j1 - j2| <= j3 <= j1 + j2. Values are
                # formatted with %s so half-integers are not truncated.
                # NOTE: the parity condition (j1 + j2 - j3 integral) is not
                # enforced here.
                if j1 + j2 < j3:
                    raise ValueError('All couplings must have j1+j2 >= j3, '
                        'in coupling number %d got j1,j2,j3: %s,%s,%s' % (n + 1, j1, j2, j3))
                if abs(j1 - j2) > j3:
                    raise ValueError("All couplings must have |j1-j2| <= j3, "
                        "in coupling number %d got j1,j2,j3: %s,%s,%s" % (n + 1, j1, j2, j3))
            # The merged group is referenced by its smallest index
            jvals[min(n1 + n2) - 1] = j3
        if len(jcoupling) > 0 and jcoupling[-1][2] != j:
            raise ValueError('Last j value coupled together must be the final j of the state')
        # Return state
        return State.__new__(cls, j, m, jn, jcoupling)

    def _print_label(self, printer, *args):
        label = [printer._print(self.j), printer._print(self.m)]
        for i, ji in enumerate(self.jn, start=1):
            label.append('j%d=%s' % (
                i, printer._print(ji)
            ))
        # The last coupling produces the total j, which is already printed
        for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]):
            label.append('j(%s)=%s' % (
                ','.join(str(i) for i in sorted(n1 + n2)), printer._print(jn)
            ))
        return ','.join(label)

    def _print_label_pretty(self, printer, *args):
        label = [self.j, self.m]
        for i, ji in enumerate(self.jn, start=1):
            symb = 'j%d' % i
            symb = pretty_symbol(symb)
            symb = prettyForm(symb + '=')
            item = prettyForm(*symb.right(printer._print(ji)))
            label.append(item)
        for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]):
            # Last character of pretty_symbol("jN") is the rendered subscript
            n = ','.join(pretty_symbol("j%d" % i)[-1] for i in sorted(n1 + n2))
            symb = prettyForm('j' + n + '=')
            item = prettyForm(*symb.right(printer._print(jn)))
            label.append(item)
        return self._print_sequence_pretty(
            label, self._label_separator, printer, *args
        )

    def _print_label_latex(self, printer, *args):
        label = [
            printer._print(self.j, *args),
            printer._print(self.m, *args)
        ]
        for i, ji in enumerate(self.jn, start=1):
            label.append('j_{%d}=%s' % (i, printer._print(ji, *args)) )
        for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]):
            n = ','.join(str(i) for i in sorted(n1 + n2))
            label.append('j_{%s}=%s' % (n, printer._print(jn, *args)) )
        return self._label_separator.join(label)

    @property
    def jn(self):
        """Tuple of j values of the individual coupled spaces."""
        return self.label[2]

    @property
    def coupling(self):
        """Tuple of ``(n1, n2, j)`` coupling triples."""
        return self.label[3]

    @property
    def coupled_jn(self):
        """j value produced by each coupling step."""
        return _build_coupled(self.label[3], len(self.label[2]))[1]

    @property
    def coupled_n(self):
        """Pair of index lists joined at each coupling step."""
        return _build_coupled(self.label[3], len(self.label[2]))[0]

    @classmethod
    def _eval_hilbert_space(cls, label):
        j = Add(*label[2])
        if j.is_number:
            return DirectSumHilbertSpace(*[ ComplexSpace(x) for x in range(int(2*j + 1), 0, -2) ])
        else:
            # TODO: Need hilbert space fix, see issue 5732
            # Desired behavior:
            #ji = symbols('ji')
            #ret = Sum(ComplexSpace(2*ji + 1), (ji, 0, j))
            # Temporary fix:
            return ComplexSpace(2*j + 1)

    def _represent_coupled_base(self, **options):
        """Embed the uncoupled spin-j representation into the coupled space.

        The rotation options are forwarded to the uncoupled class's
        ``_represent_base``. The block for this j starts after the blocks of
        all smaller j values.
        """
        evect = self.uncoupled_class()
        if not self.j.is_number:
            raise ValueError(
                'State must not have symbolic j value to represent')
        if not self.hilbert_space.dimension.is_number:
            raise ValueError(
                'State must not have symbolic j values to represent')
        result = zeros(self.hilbert_space.dimension, 1)
        if self.j == int(self.j):
            start = self.j**2
        else:
            start = (2*self.j - 1)*(1 + 2*self.j)/4
        result[start:start + 2*self.j + 1, 0] = evect(
            self.j, self.m)._represent_base(**options)
        return result

    def _eval_rewrite_as_Jx(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jx, JxBraCoupled, **options)
        return self._rewrite_basis(Jx, JxKetCoupled, **options)

    def _eval_rewrite_as_Jy(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jy, JyBraCoupled, **options)
        return self._rewrite_basis(Jy, JyKetCoupled, **options)

    def _eval_rewrite_as_Jz(self, *args, **options):
        if isinstance(self, Bra):
            return self._rewrite_basis(Jz, JzBraCoupled, **options)
        return self._rewrite_basis(Jz, JzKetCoupled, **options)
1513
1514
class JxKetCoupled(CoupledSpinState, Ket):
    """Coupled eigenket of the Jx operator.

    See JzKetCoupled for the usage of coupled spin eigenstates.

    See Also
    ========

    JzKetCoupled: Usage of coupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this coupled ket
        return JxBraCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class used to build the representation
        return JxKet

    def _represent_default_basis(self, **options):
        # NOTE: unlike JxKet, the default representation here is the Jz basis
        return self._represent_JzOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_coupled_base(**options)

    def _represent_JyOp(self, basis, **options):
        return self._represent_coupled_base(alpha=pi*Rational(3, 2), **options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_coupled_base(beta=pi/2, **options)
1546
1547
class JxBraCoupled(CoupledSpinState, Bra):
    """Coupled eigenbra of the Jx operator.

    See JzKetCoupled for the usage of coupled spin eigenstates.

    See Also
    ========

    JzKetCoupled: Usage of coupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this coupled bra
        return JxKetCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class matching this coupled bra
        return JxBra
1567
1568
class JyKetCoupled(CoupledSpinState, Ket):
    """Coupled eigenket of the Jy operator.

    See JzKetCoupled for the usage of coupled spin eigenstates.

    See Also
    ========

    JzKetCoupled: Usage of coupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this coupled ket
        return JyBraCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class used to build the representation
        return JyKet

    def _represent_default_basis(self, **options):
        # NOTE: unlike JyKet, the default representation here is the Jz basis
        return self._represent_JzOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        return self._represent_coupled_base(gamma=pi/2, **options)

    def _represent_JyOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_coupled_base(**options)

    def _represent_JzOp(self, basis, **options):
        return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=-pi/2, gamma=pi/2, **options)
1600
1601
class JyBraCoupled(CoupledSpinState, Bra):
    """Coupled eigenbra of the Jy operator.

    See JzKetCoupled for the usage of coupled spin eigenstates.

    See Also
    ========

    JzKetCoupled: Usage of coupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this coupled bra
        return JyKetCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class matching this coupled bra
        return JyBra
1621
1622
class JzKetCoupled(CoupledSpinState, Ket):
    r"""Coupled eigenket of Jz

    Spin state that is an eigenket of Jz which represents the coupling of
    separate spin spaces.

    The arguments for creating instances of JzKetCoupled are ``j``, ``m``,
    ``jn`` and an optional ``jcoupling`` argument. The ``j`` and ``m`` options
    are the total angular momentum quantum numbers, as used for normal states
    (e.g. JzKet).

    The other required parameter is ``jn``, which is a tuple defining the `j_n`
    angular momentum quantum numbers of the product spaces. So for example, if
    a state represented the coupling of the product basis state
    `\left|j_1,m_1\right\rangle\times\left|j_2,m_2\right\rangle`, the ``jn``
    for this state would be ``(j1,j2)``.

    The final option is ``jcoupling``, which is used to define how the spaces
    specified by ``jn`` are coupled, which includes both the order these spaces
    are coupled together and the quantum numbers that arise from these
    couplings. The ``jcoupling`` parameter itself is a list of lists, such that
    each of the sublists defines a single coupling between the spin spaces. If
    there are N coupled angular momentum spaces, that is ``jn`` has N elements,
    then there must be N-1 sublists. Each of these sublists making up the
    ``jcoupling`` parameter have length 3. The first two elements are the
    indices of the product spaces that are considered to be coupled together.
    For example, if we want to couple `j_1` and `j_4`, the indices would be 1
    and 4. If a state has already been coupled, it is referenced by the
    smallest index that is coupled, so if `j_2` and `j_4` have already been
    coupled to some `j_{24}`, then this value can be coupled by referencing it
    with index 2. The final element of the sublist is the quantum number of the
    coupled state. So putting everything together, into a valid sublist for
    ``jcoupling``, if `j_1` and `j_2` are coupled to an angular momentum space
    with quantum number `j_{12}` with the value ``j12``, the sublist would be
    ``(1,2,j12)``, N-1 of these sublists are used in the list for
    ``jcoupling``.

    Note the ``jcoupling`` parameter is optional, if it is not specified, the
    default coupling is taken. This default is to couple the spaces in order
    and take the quantum number of each intermediate coupling to be the
    maximum value. For example, if the spin spaces are `j_1`, `j_2`, `j_3`,
    `j_4`, then the default coupling couples `j_1` and `j_2` to
    `j_{12}=j_1+j_2`, then, `j_{12}` and `j_3` are coupled to
    `j_{123}=j_{12}+j_3`, and finally `j_{123}` and `j_4` to `j=j_{123}+j_4`.
    The jcoupling value that would correspond to this is:

        ``((1,2,j1+j2),(1,3,j1+j2+j3),(1,4,j1+j2+j3+j4))``

    Parameters
    ==========

    args : tuple
        The arguments that must be passed are ``j``, ``m`` and ``jn``;
        ``jcoupling`` is optional. The ``j`` value is the total angular
        momentum. The ``m`` value is the eigenvalue of the Jz spin operator.
        The ``jn`` list are the j values of angular momentum spaces coupled
        together. The ``jcoupling`` parameter defines how the spaces are
        coupled together. See the above description for how these coupling
        parameters are defined.

    Examples
    ========

    Defining simple spin states, both numerical and symbolic:

        >>> from sympy.physics.quantum.spin import JzKetCoupled
        >>> from sympy import symbols
        >>> JzKetCoupled(1, 0, (1, 1))
        |1,0,j1=1,j2=1>
        >>> j, m, j1, j2 = symbols('j m j1 j2')
        >>> JzKetCoupled(j, m, (j1, j2))
        |j,m,j1=j1,j2=j2>

    Defining coupled spin states for more than 2 coupled spaces with various
    coupling parameters:

        >>> JzKetCoupled(2, 1, (1, 1, 1))
        |2,1,j1=1,j2=1,j3=1,j(1,2)=2>
        >>> JzKetCoupled(2, 1, (1, 1, 1), ((1,2,2),(1,3,2)) )
        |2,1,j1=1,j2=1,j3=1,j(1,2)=2>
        >>> JzKetCoupled(2, 1, (1, 1, 1), ((2,3,1),(1,2,2)) )
        |2,1,j1=1,j2=1,j3=1,j(2,3)=1>

    Rewriting the JzKetCoupled in terms of eigenkets of the Jx operator.
    Note that the resulting eigenstates are JxKetCoupled:

        >>> JzKetCoupled(1,1,(1,1)).rewrite("Jx")
        |1,-1,j1=1,j2=1>/2 - sqrt(2)*|1,0,j1=1,j2=1>/2 + |1,1,j1=1,j2=1>/2

    The rewrite method can be used to convert a coupled state to an uncoupled
    state. This is done by passing coupled=False to the rewrite function:

        >>> JzKetCoupled(1, 0, (1, 1)).rewrite('Jz', coupled=False)
        -sqrt(2)*|1,-1>x|1,1>/2 + sqrt(2)*|1,1>x|1,-1>/2

    Get the vector representation of a state in terms of the basis elements
    of the Jx operator:

        >>> from sympy.physics.quantum.represent import represent
        >>> from sympy.physics.quantum.spin import Jx
        >>> from sympy import S
        >>> represent(JzKetCoupled(1,-1,(S(1)/2,S(1)/2)), basis=Jx)
        Matrix([
        [        0],
        [      1/2],
        [sqrt(2)/2],
        [      1/2]])

    See Also
    ========

    JzKet: Normal spin eigenstates
    uncouple: Uncoupling of coupling spin states
    couple: Coupling of uncoupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Bra counterpart of this coupled ket
        return JzBraCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class used to build the representation
        return JzKet

    def _represent_default_basis(self, **options):
        # Without an explicit basis, use this state's own (Jz) basis
        return self._represent_JzOp(None, **options)

    def _represent_JxOp(self, basis, **options):
        # Euler angles below are forwarded to Rotation.D through the
        # uncoupled class's _represent_base
        return self._represent_coupled_base(beta=pi*Rational(3, 2), **options)

    def _represent_JyOp(self, basis, **options):
        return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options)

    def _represent_JzOp(self, basis, **options):
        # Own eigenbasis: no rotation needed
        return self._represent_coupled_base(**options)
1759
1760
class JzBraCoupled(CoupledSpinState, Bra):
    """Coupled eigenbra of the Jz operator.

    See JzKetCoupled for the usage of coupled spin eigenstates.

    See Also
    ========

    JzKetCoupled: Usage of coupled spin states

    """

    @classmethod
    def dual_class(cls):
        # Ket counterpart of this coupled bra
        return JzKetCoupled

    @classmethod
    def uncoupled_class(cls):
        # Uncoupled state class matching this coupled bra
        return JzBra
1780
1781#-----------------------------------------------------------------------------
1782# Coupling/uncoupling
1783#-----------------------------------------------------------------------------
1784
1785
def couple(expr, jcoupling_list=None):
    """ Couple a tensor product of spin states

    Replaces every tensor product of spin states found in ``expr`` with its
    coupled form. All of the eigenstates within one tensor product must be
    of the same class. The result is a linear combination of eigenstates
    that are subclasses of CoupledSpinState, weighted by Clebsch-Gordan
    angular momentum coupling coefficients. Tensor products that contain
    anything other than spin states are left unchanged.

    Parameters
    ==========

    expr : Expr
        An expression involving TensorProducts of spin states to be coupled.
        Each state must be a subclass of SpinState and they all must be the
        same class.

    jcoupling_list : list or tuple
        Elements of this list are sub-lists of length 2 specifying the order of
        the coupling of the spin spaces. The length of this must be N-1, where N
        is the number of states in the tensor product to be coupled. The
        elements of this sublist are the same as the first two elements of each
        sublist in the ``jcoupling`` parameter defined for JzKetCoupled. If this
        parameter is not specified, the default value is taken, which couples
        the first and second product basis spaces, then couples this new coupled
        space to the third product space, etc.

    Raises
    ======

    TypeError
        If a tensor product consisting only of spin states mixes state
        classes (bases).

    Examples
    ========

    Couple a tensor product of numerical states for two spaces:

        >>> from sympy.physics.quantum.spin import JzKet, couple
        >>> from sympy.physics.quantum.tensorproduct import TensorProduct
        >>> couple(TensorProduct(JzKet(1,0), JzKet(1,1)))
        -sqrt(2)*|1,1,j1=1,j2=1>/2 + sqrt(2)*|2,1,j1=1,j2=1>/2


    Numerical coupling of three spaces using the default coupling method, i.e.
    first and second spaces couple, then this couples to the third space:

        >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0)))
        sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,2)=2>/3

    Perform this same coupling, but we define the coupling to first couple
    the first and third spaces:

        >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0)), ((1,3),(1,2)) )
        sqrt(2)*|2,2,j1=1,j2=1,j3=1,j(1,3)=1>/2 - sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,3)=2>/6 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,3)=2>/3

    Couple a tensor product of symbolic states:

        >>> from sympy import symbols
        >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2')
        >>> couple(TensorProduct(JzKet(j1,m1), JzKet(j2,m2)))
        Sum(CG(j1, m1, j2, m2, j, m1 + m2)*|j,m1 + m2,j1=j1,j2=j2>, (j, m1 + m2, j1 + j2))

    """
    for product in expr.atoms(TensorProduct):
        factors = product.args
        # Products containing non-spin factors are skipped, so other tensor
        # products may appear in the expression.
        if not all(isinstance(factor, SpinState) for factor in factors):
            continue
        # A product made purely of spin states must use a single basis.
        if any(factor.__class__ is not factors[0].__class__
               for factor in factors):
            raise TypeError('All states must be the same basis')
        expr = expr.subs(product, _couple(product, jcoupling_list))
    return expr
1854
1855
def _couple(tp, jcoupling_list):
    """Couple the spin states of a single tensor product.

    Helper for :func:`couple`, which only calls it for tensor products whose
    arguments are all spin states of one class.

    Parameters
    ==========

    tp : TensorProduct
        The tensor product of uncoupled spin states.

    jcoupling_list : list, tuple or None
        Sequence of ``(n1, n2)`` pairs of 1-based space indices giving the
        coupling order; it must have ``len(tp.args) - 1`` entries. Once two
        spaces are coupled, the merged space is referred to by the smaller
        of the two indices. When None, space 1 is coupled to space 2, the
        result to space 3, and so on.

    Returns
    =======

    When every ``j`` and ``m`` is numeric, an ``Add`` of coupled states
    weighted by evaluated Clebsch-Gordan coefficients. Otherwise a ``Sum``
    over the symbolic coupled ``j`` values.

    Raises
    ======

    TypeError
        If ``jcoupling_list`` does not have ``len(tp.args) - 1`` entries.
    ValueError
        If a coupling does not name exactly two spaces, couples a space to
        itself, uses a numeric space index outside ``1..len(tp.args)``, or
        refers to an already-coupled space by an index other than its
        smallest one.
    """
    states = tp.args
    n_states = len(states)
    coupled_evect = states[0].coupled_class()

    # Default coupling (1, 2), (1, 3), ...: each new space joins the running
    # coupled space, which is labelled by its smallest index, 1.
    if jcoupling_list is None:
        jcoupling_list = [(1, n + 1) for n in range(1, n_states)]

    # Check jcoupling_list valid
    if len(jcoupling_list) != n_states - 1:
        raise TypeError('jcoupling_list must be length %d, got %d' %
                        (n_states - 1, len(jcoupling_list)))
    if not all(len(coupling) == 2 for coupling in jcoupling_list):
        raise ValueError('Each coupling must define 2 spaces')
    if any(n1 == n2 for n1, n2 in jcoupling_list):
        raise ValueError('Spin spaces cannot couple to themselves')
    if all(sympify(n1).is_number and sympify(n2).is_number
           for n1, n2 in jcoupling_list):
        # Space indices are 1-based. Without this check, 0 or negative
        # indices would silently wrap around to the last spaces through
        # Python's negative indexing.
        for pair in jcoupling_list:
            for n in pair:
                if not 1 <= n <= n_states:
                    raise ValueError(
                        'Space index must be between 1 and %d, got %s'
                        % (n_states, n))
        # After a coupling the larger index is retired (-1): the merged
        # space may only be referenced by its smallest index.
        j_test = [0]*n_states
        for n1, n2 in jcoupling_list:
            if j_test[n1 - 1] == -1 or j_test[n2 - 1] == -1:
                raise ValueError('Spaces coupling j_n\'s are referenced by smallest n value')
            j_test[max(n1, n2) - 1] = -1

    # j and m values of the states to be coupled together
    jn = [state.j for state in states]
    mn = [state.m for state in states]

    # Build coupling_list. For each coupling it holds the lists of original
    # 1-based space indices that make up the first and second spaces.
    coupling_list = []
    n_list = [[i + 1] for i in range(n_states)]
    for n1, n2 in jcoupling_list:
        j1_n = list(n_list[n1 - 1])
        j2_n = list(n_list[n2 - 1])
        coupling_list.append((j1_n, j2_n))
        # The merged space is stored under the smaller index
        n_list[min(n1, n2) - 1] = sorted(j1_n + j2_n)

    if all(state.j.is_number and state.m.is_number for state in states):
        # Numerical coupling
        # Loose upper bound, per coupling, on j1 + j2 - j3: the sum of
        # (j_n - m_n) over every space involved in that coupling.
        diff_max = [Add(*[jn[n - 1] - mn[n - 1]
                          for n in coupling[0] + coupling[1]])
                    for coupling in coupling_list]
        n_couplings = len(coupling_list)
        result = []
        # ``diff`` is the total of (j1 + j2 - j3) over all couplings
        for diff in range(diff_max[-1] + 1):
            # Number of ways to distribute ``diff`` among the couplings
            tot = binomial(diff + n_couplings - 1, diff)

            for config_num in range(tot):
                diff_list = _confignum_to_difflist(config_num, diff,
                                                   n_couplings)

                # Lazy physicality check given the loose bounds of diff_max
                if any(d > dmax for d, dmax in zip(diff_list, diff_max)):
                    continue

                # Determine term
                cg_terms = []
                coupled_j = list(jn)
                jcoupling = []
                for (j1_n, j2_n), coupling_diff in zip(coupling_list, diff_list):
                    j1 = coupled_j[min(j1_n) - 1]
                    j2 = coupled_j[min(j2_n) - 1]
                    j3 = j1 + j2 - coupling_diff
                    coupled_j[min(j1_n + j2_n) - 1] = j3
                    m1 = Add(*[mn[x - 1] for x in j1_n])
                    m2 = Add(*[mn[x - 1] for x in j2_n])
                    m3 = m1 + m2
                    cg_terms.append((j1, m1, j2, m2, j3, m3))
                    jcoupling.append((min(j1_n), min(j2_n), j3))
                # Stricter checks: |m3| <= j3 and the triangle condition
                # |j1 - j2| <= j3 <= j1 + j2 for every coupling
                if any(abs(term[5]) > term[4] for term in cg_terms):
                    continue
                if any(term[0] + term[2] < term[4] for term in cg_terms):
                    continue
                if any(abs(term[0] - term[2]) > term[4] for term in cg_terms):
                    continue
                coeff = Mul(*[CG(*term).doit() for term in cg_terms])
                state = coupled_evect(j3, m3, jn, jcoupling)
                result.append(coeff*state)
        return Add(*result)
    else:
        # Symbolic coupling
        cg_terms = []
        jcoupling = []
        sum_terms = []
        coupled_j = list(jn)
        for j1_n, j2_n in coupling_list:
            j1 = coupled_j[min(j1_n) - 1]
            j2 = coupled_j[min(j2_n) - 1]
            if len(j1_n + j2_n) == n_states:
                # The coupling that joins every space yields the total j
                j3 = symbols('j')
            else:
                j3_name = 'j' + ''.join(["%s" % n for n in j1_n + j2_n])
                j3 = symbols(j3_name)
            coupled_j[min(j1_n + j2_n) - 1] = j3
            m1 = Add(*[mn[x - 1] for x in j1_n])
            m2 = Add(*[mn[x - 1] for x in j2_n])
            m3 = m1 + m2
            cg_terms.append((j1, m1, j2, m2, j3, m3))
            jcoupling.append((min(j1_n), min(j2_n), j3))
            sum_terms.append((j3, m3, j1 + j2))
        coeff = Mul(*[CG(*term) for term in cg_terms])
        state = coupled_evect(j3, m3, jn, jcoupling)
        return Sum(coeff*state, *sum_terms)
1967
1968
def uncouple(expr, jn=None, jcoupling_list=None):
    """ Uncouple a coupled spin state

    Gives the uncoupled representation of a coupled spin state. The states
    in ``expr`` must be either subclasses of CoupledSpinState, or subclasses
    of SpinState accompanied by ``jn``, a sequence giving the j values of the
    spaces that were coupled.

    Parameters
    ==========

    expr : Expr
        The expression containing states that are to be uncoupled. If the
        states are a subclass of SpinState, the ``jn`` parameter must be
        defined; ``jcoupling_list`` is optional. If the states are a
        subclass of CoupledSpinState, the j values and coupling scheme are
        taken from the state and ``jn`` and ``jcoupling_list`` are ignored.

    jn : list or tuple
        The list of the j-values that are coupled. If state is a
        CoupledSpinState, this parameter is ignored. This must be defined if
        state is not a subclass of CoupledSpinState. The syntax of this
        parameter is the same as the ``jn`` parameter of JzKetCoupled.

    jcoupling_list : list or tuple
        The list defining how the j-values are coupled together. If state is a
        CoupledSpinState, this parameter is ignored. The syntax of this
        parameter is the same as the ``jcoupling`` parameter of JzKetCoupled.
        If it is not given, the first and second spaces are coupled, then
        this new space is coupled to the third, and so on. Each intermediate
        j value is taken as the sum of the j values coupled so far.

    Raises
    ======

    ValueError
        If a SpinState is given without ``jn``, or ``jcoupling_list`` does
        not have one fewer entry than ``jn``.
    TypeError
        If ``jn`` or ``jcoupling_list`` is not a list or tuple.

    Examples
    ========

    Uncouple a numerical state using a CoupledSpinState state:

        >>> from sympy.physics.quantum.spin import JzKetCoupled, uncouple
        >>> from sympy import S
        >>> uncouple(JzKetCoupled(1, 0, (S(1)/2, S(1)/2)))
        sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2

    Perform the same calculation using a SpinState state:

        >>> from sympy.physics.quantum.spin import JzKet
        >>> uncouple(JzKet(1, 0), (S(1)/2, S(1)/2))
        sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2

    Uncouple a numerical state of three coupled spaces using a CoupledSpinState state:

        >>> uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1,3,1),(1,2,1)) ))
        |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2

    Perform the same calculation using a SpinState state:

        >>> uncouple(JzKet(1, 1), (1, 1, 1), ((1,3,1),(1,2,1)) )
        |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2

    Uncouple a symbolic state using a CoupledSpinState state:

        >>> from sympy import symbols
        >>> j,m,j1,j2 = symbols('j m j1 j2')
        >>> uncouple(JzKetCoupled(j, m, (j1, j2)))
        Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2))

    Perform the same calculation using a SpinState state

        >>> uncouple(JzKet(j, m), (j1, j2))
        Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2))

    """
    # The set of spin states is collected once up front; each is then
    # replaced in turn by its uncoupled expansion.
    for state in expr.atoms(SpinState):
        expr = expr.subs(state, _uncouple(state, jn, jcoupling_list))
    return expr
2041
2042
def _uncouple(state, jn, jcoupling_list):
    """Express a single coupled spin state in the uncoupled product basis.

    Helper for :func:`uncouple`.

    Parameters
    ==========

    state : SpinState
        The state to uncouple. For a CoupledSpinState the j values and
        coupling scheme are read from the state, and ``jn`` and
        ``jcoupling_list`` are ignored.

    jn : list, tuple or None
        j values of the uncoupled spaces. Required when ``state`` is not a
        CoupledSpinState.

    jcoupling_list : list, tuple or None
        Coupling scheme, in the ``jcoupling`` format of JzKetCoupled. When
        None, the default coupling (1, 2), (1, 3), ... is used, and each
        intermediate j is the sum of the j values coupled so far.

    Returns
    =======

    When ``state.j`` and ``state.m`` are numeric, an ``Add`` of
    TensorProducts of uncoupled states weighted by evaluated Clebsch-Gordan
    coefficients. Otherwise a ``Sum`` over symbolic m values ``m1, m2, ...``.

    Raises
    ======

    ValueError
        If ``jn`` is missing, or ``jcoupling_list`` does not have
        ``len(jn) - 1`` entries.
    TypeError
        If ``jn`` or ``jcoupling_list`` is not a list or tuple, or if
        ``state`` is not a spin state.
    """
    if isinstance(state, CoupledSpinState):
        jn = state.jn
        coupled_n = state.coupled_n
        coupled_jn = state.coupled_jn
        evect = state.uncoupled_class()
    elif isinstance(state, SpinState):
        if jn is None:
            raise ValueError("Must specify j-values for coupled state")
        if not isinstance(jn, (list, tuple)):
            raise TypeError("jn must be list or tuple")
        if jcoupling_list is None:
            # Default: couple space 1 with space i + 1, giving the merged
            # space the maximal j (the sum of the j values coupled so far)
            jcoupling_list = [(1, 1 + i, Add(*jn[:i + 1]))
                              for i in range(1, len(jn))]
        if not isinstance(jcoupling_list, (list, tuple)):
            raise TypeError("jcoupling must be a list or tuple")
        if len(jcoupling_list) != len(jn) - 1:
            raise ValueError(
                "Must specify 1 fewer coupling terms than the number of j values")
        coupled_n, coupled_jn = _build_coupled(jcoupling_list, len(jn))
        evect = state.__class__
    else:
        raise TypeError("state must be a spin state")
    j = state.j
    m = state.m

    # For each coupling, record (first space indices, second space indices,
    # j1, j2, j3). j_list tracks the j value currently held by each merged
    # space, stored under its smallest index.
    coupling_list = []
    j_list = list(jn)
    for j3, (n1, n2) in zip(coupled_jn, coupled_n):
        j1 = j_list[n1[0] - 1]
        j2 = j_list[n2[0] - 1]
        coupling_list.append((n1, n2, j1, j2, j3))
        j_list[min(n1 + n2) - 1] = j3

    if j.is_number and m.is_number:
        # Each uncoupled space n has m_n = j_n - d_n with 0 <= d_n <= 2*j_n,
        # and the d_n must total sum(jn) - m.
        diff_max = [2*x for x in jn]
        diff = Add(*jn) - m

        n = len(jn)
        tot = binomial(diff + n - 1, diff)

        result = []
        for config_num in range(tot):
            diff_list = _confignum_to_difflist(config_num, diff, n)
            if any(d > p for d, p in zip(diff_list, diff_max)):
                continue

            cg_terms = []
            for j1_n, j2_n, j1, j2, j3 in coupling_list:
                m1 = Add(*[jn[x - 1] - diff_list[x - 1] for x in j1_n])
                m2 = Add(*[jn[x - 1] - diff_list[x - 1] for x in j2_n])
                m3 = m1 + m2
                cg_terms.append((j1, m1, j2, m2, j3, m3))
            coeff = Mul(*[CG(*term).doit() for term in cg_terms])
            product = TensorProduct(
                *[evect(jv, jv - d) for jv, d in zip(jn, diff_list)])
            result.append(coeff*product)
        return Add(*result)
    else:
        # Symbolic coupling
        mvals = symbols("m1:%d" % (len(jn) + 1))

        def m_sum(spaces):
            # Sum of the symbolic m values of the listed (1-based) spaces
            return Add(*[mvals[k - 1] for k in spaces])

        # Intermediate couplings produce j3 with m equal to the sum of the
        # m values of every space involved ...
        cg_terms = [(j1, m_sum(j1_n), j2, m_sum(j2_n), j3, m_sum(j1_n + j2_n))
                    for j1_n, j2_n, j1, j2, j3 in coupling_list[:-1]]
        # ... while the final coupling must produce the state's own j and m
        j1_n, j2_n, j1, j2, _ = coupling_list[-1]
        cg_terms.append((j1, m_sum(j1_n), j2, m_sum(j2_n), j, m))
        cg_coeff = Mul(*[CG(*cg_term) for cg_term in cg_terms])
        sum_terms = [(mv, -jv, jv) for jv, mv in zip(jn, mvals)]
        product = TensorProduct(*[evect(jv, mv) for jv, mv in zip(jn, mvals)])
        return Sum(cg_coeff*product, *sum_terms)
2122
2123
def _confignum_to_difflist(config_num, diff, list_len):
    """Decode ``config_num`` into ``list_len`` non-negative parts summing to ``diff``.

    Configurations are enumerated slot by slot. For each slot, the count of
    configurations that leave it with a given value is skipped over until
    ``config_num`` falls within the current range. Callers pass
    ``0 <= config_num < binomial(diff + list_len - 1, diff)``.
    """
    parts = []
    remaining = diff
    for slot in range(list_len):
        start = remaining
        spots_after = list_len - slot - 1
        # Ways to spread ``remaining`` over the slots after this one
        count = binomial(remaining + spots_after - 1, remaining)
        while config_num >= count:
            config_num -= count
            remaining -= 1
            count = binomial(remaining + spots_after - 1, remaining)
        parts.append(start - remaining)
    return parts
2139