1from sympy.core import Basic, Expr
2from sympy.core.sympify import _sympify
3from sympy.matrices.expressions.transpose import transpose
4
5
class DotProduct(Expr):
    """
    Dot product of vector matrices

    The input should be two 1 x n or n x 1 matrices. The output represents the
    scalar dot product.

    This is similar to using MatrixElement and MatMul, except DotProduct does
    not require one vector to be a row vector and the other vector to be
    a column vector.

    >>> from sympy import MatrixSymbol, DotProduct
    >>> A = MatrixSymbol('A', 1, 3)
    >>> B = MatrixSymbol('B', 1, 3)
    >>> DotProduct(A, B)
    DotProduct(A, B)
    >>> DotProduct(A, B).doit()
    A[0, 0]*B[0, 0] + A[0, 1]*B[0, 1] + A[0, 2]*B[0, 2]
    """

    def __new__(cls, arg1, arg2):
        """Create an unevaluated dot product of two vector matrices.

        Raises
        ======

        TypeError
            If either argument is not a matrix, is not a vector (neither
            dimension equals 1), or the two vectors differ in length.
        """
        arg1, arg2 = _sympify((arg1, arg2))

        if not arg1.is_Matrix:
            raise TypeError("Argument 1 of DotProduct is not a matrix")
        if not arg2.is_Matrix:
            raise TypeError("Argument 2 of DotProduct is not a matrix")
        if 1 not in arg1.shape:
            raise TypeError("Argument 1 of DotProduct is not a vector")
        if 1 not in arg2.shape:
            raise TypeError("Argument 2 of DotProduct is not a vector")

        # Comparing the shapes as sets ignores orientation: (1, n) matches
        # both (1, n) and (n, 1), so only the vector length must agree.
        if set(arg1.shape) != set(arg2.shape):
            raise TypeError("DotProduct arguments are not the same length")

        return Basic.__new__(cls, arg1, arg2)

    def doit(self, expand=False, **hints):
        """Evaluate the dot product as a scalar expression.

        Each operand is transposed where needed so that the product is
        (row vector) * (column vector). That yields a 1 x 1 matrix, and its
        single entry is returned.

        ``expand`` is kept for backward compatibility and is not used.
        Extra ``**hints`` (for example ``deep``) are accepted and ignored so
        that generic callers such as ``Basic.doit``, which pass their hints
        on to sub-expressions, do not fail with an unexpected keyword
        argument.
        """
        lhs, rhs = self.args

        if lhs.shape == rhs.shape:
            # Same orientation: transpose one side to get row * column.
            if lhs.shape[0] == 1:
                mul = lhs*transpose(rhs)             # row . row
            else:
                mul = transpose(lhs)*rhs             # column . column
        else:
            # Opposite orientations (shapes already checked to be equal as sets).
            if lhs.shape[0] == 1:
                mul = lhs*rhs                        # row . column
            else:
                mul = transpose(lhs)*transpose(rhs)  # column . row

        # mul is 1 x 1; index 0 extracts its only entry.
        return mul[0]
56