1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
12 
13 
14 namespace Eigen {
15 namespace internal {
16 
// Sharding selectors. These constants are intended to be used as the
// ShardingType template argument of TensorContractionBlocking, whose default
// is ShardByCol.
enum {
  ShardByRow = 0,  // any value other than ShardByCol takes the "swapped" path
  ShardByCol       // implicitly 1
};
21 
22 
23 // Default Blocking Strategy
24 template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
25 class TensorContractionBlocking {
26  public:
27 
28   typedef typename LhsMapper::Scalar LhsScalar;
29   typedef typename RhsMapper::Scalar RhsScalar;
30 
31   EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
kc_(k)32       kc_(k), mc_(m), nc_(n)
33   {
34     if (ShardingType == ShardByCol) {
35       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
36     }
37     else {
38       computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
39     }
40   }
41 
kc()42   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
mc()43   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
nc()44   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
45 
46  private:
47   Index kc_;
48   Index mc_;
49   Index nc_;
50 };
51 
52 
53 } // end namespace internal
54 } // end namespace Eigen
55 
56 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
57