1import os
2import argparse
3import AutoGemmParameters
4import Common
5
6
7################################################################################
8# Auto-Gemm
9################################################################################
10
def writeOfflineCompilation(args):
  """Write AutoGemmKernelsToPreCompile.h, the table of kernels to pre-compile.

  The header is written into Common.getIncludePath() (created if missing).
  It declares ``gemmPreCompile``, an array with one 8-column row per
  (precision, order, transA, transB, beta, tile) combination, and
  ``gemmPreCompileNum``, the number of such rows.

  Args:
    args: object with iterable attributes ``precisions``, ``orders``,
      ``transposes`` and ``betas``. Each transpose entry is indexed as
      [0] for A and [1] for B.
  """
  print("AutoGemm.py: Generating list of kernels to pre-compile.")
  includePath = Common.getIncludePath()
  if not os.path.exists( includePath ):
    os.makedirs( includePath )

  ocFileName = includePath + "AutoGemmKernelsToPreCompile.h"
  # one table row: precision, order, transA, transB, beta, rows, cols, unroll
  rowFormat = "  { %1u, %1u, %1u, %1u, %1u, %3u, %3u, %2u },\n"

  count = 0
  # 'with' guarantees the file is closed even if a lookup below raises
  with open(ocFileName, "w") as ocFile:
    ocFile.write( Common.getAutoGemmHeader() )
    ocFile.write(
        "\n/*precision, order, transA, transB, beta, tileNumRows, tileNumCols, unroll*/\n"
        "\nunsigned int gemmPreCompile[][8] = {\n" )

    for precision in args.precisions:
      validTiles = AutoGemmParameters.getTilesForPrecision(precision)
      isRealPrecision = precision in ("s", "d")
      rows = []
      for order in args.orders:
        for transpose in args.transposes:
          transA = transpose[0]
          transB = transpose[1]
          if isRealPrecision and (transA == "C" or transB == "C"):
            # real precision doesn't have conjugate transpose
            continue
          for beta in args.betas:
            for tile in validTiles:
              rows.append( rowFormat % (
                  Common.precisionInt[precision],
                  Common.orderInt[order],
                  Common.transposeInt[transA],
                  Common.transposeInt[transB],
                  beta,
                  tile.macroTileNumRows,
                  tile.macroTileNumCols,
                  tile.unroll ) )
      # flush this precision's rows in a single write
      ocFile.write( "".join(rows) )
      count += len(rows)

    if count == 0:
      # emit one all-zero placeholder row when no combination was selected
      # (presumably so the C array initializer is never empty - confirm)
      ocFile.write( rowFormat % ( 0, 0, 0, 0, 0, 0, 0, 0 ) )
    ocFile.write( "};\n" )
    ocFile.write( "unsigned int gemmPreCompileNum = " + str(count) + ";\n" )

  # NOTE(review): the reported total is 4x the number of table rows;
  # presumably each row expands to four kernels - confirm against the
  # kernel generator.
  print("AutoGemm.py: %u kernels will be pre-compiled." % (count * 4))
61
62
63################################################################################
64# Main
65################################################################################
if __name__ == "__main__":

  # command-line interface: an optional output path plus four multi-value
  # selectors restricted to the choices AutoGemmParameters provides
  parser = argparse.ArgumentParser(description="Which gemm kernels to compile offline.")
  multiValue = dict(action="store", nargs="+")
  parser.add_argument("--output-path", dest="output")
  parser.add_argument("--precisions", dest="precisions",
      choices=AutoGemmParameters.precisions, **multiValue)
  parser.add_argument("--orders", dest="orders",
      choices=AutoGemmParameters.orders, **multiValue)
  parser.add_argument("--transposes", dest="transposes",
      choices=AutoGemmParameters.getTransposeChoices(), **multiValue)
  parser.add_argument("--betas", dest="betas", type=int,
      choices=AutoGemmParameters.betas, **multiValue)
  args = parser.parse_args()

  if not args.output:
    print("Warning: No output path specified; default is working directory.")
  else:
    Common.setOutputPath(args.output)

  # selectors omitted on the command line come back as None; replace them
  # with empty lists so the header writer can iterate over them
  for selector in ("precisions", "transposes", "orders", "betas"):
    if getattr(args, selector) is None:
      setattr(args, selector, [])

  # write offline compilation header
  writeOfflineCompilation(args)
91