from sympy.core import Basic, Expr
from sympy.core.sympify import _sympify
from sympy.matrices.expressions.transpose import transpose


class DotProduct(Expr):
    """
    Dot product of vector matrices

    The input should be two 1 x n or n x 1 matrices. The output represents the
    scalar dot product.

    This is similar to using MatrixElement and MatMul, except DotProduct does
    not require one vector to be a row vector and the other vector to be
    a column vector.

    >>> from sympy import MatrixSymbol, DotProduct
    >>> A = MatrixSymbol('A', 1, 3)
    >>> B = MatrixSymbol('B', 1, 3)
    >>> DotProduct(A, B)
    DotProduct(A, B)
    >>> DotProduct(A, B).doit()
    A[0, 0]*B[0, 0] + A[0, 1]*B[0, 1] + A[0, 2]*B[0, 2]
    """

    def __new__(cls, arg1, arg2):
        """Validate that both arguments are equal-length vectors and build the node.

        Raises:
            TypeError: if either argument is not a matrix, is not a vector
                (neither dimension equals 1), or the two vectors differ in
                length.
        """
        arg1, arg2 = _sympify((arg1, arg2))

        if not arg1.is_Matrix:
            raise TypeError("Argument 1 of DotProduct is not a matrix")
        if not arg2.is_Matrix:
            raise TypeError("Argument 2 of DotProduct is not a matrix")
        if 1 not in arg1.shape:
            raise TypeError("Argument 1 of DotProduct is not a vector")
        if 1 not in arg2.shape:
            raise TypeError("Argument 2 of DotProduct is not a vector")

        # Comparing the shapes as sets lets a 1 x n row vector pair with an
        # n x 1 column vector while still requiring the same length n.
        if set(arg1.shape) != set(arg2.shape):
            raise TypeError("DotProduct arguments are not the same length")

        return Basic.__new__(cls, arg1, arg2)

    def doit(self, expand=False, **hints):
        """Evaluate the dot product as the single entry of a 1 x 1 matrix product.

        ``expand`` is kept for backward compatibility and is not used.
        Additional ``**hints`` (for example ``deep``) are accepted so that this
        method honours the keyword-argument signature used by ``doit`` calls;
        they are currently ignored.
        """
        first, second = self.args

        # Orient the operands so the product is always
        # (1 x n) * (n x 1) -> 1 x 1.
        if first.shape == second.shape:
            if first.shape[0] == 1:
                # Both are row vectors.
                mul = first*transpose(second)
            else:
                # Both are column vectors.
                mul = transpose(first)*second
        else:
            if first.shape[0] == 1:
                # Row vector times column vector: already aligned.
                mul = first*second
            else:
                # Column vector and row vector: transpose both.
                mul = transpose(first)*transpose(second)

        # The product is 1 x 1; index 0 extracts its sole entry.
        return mul[0]