import os
import argparse
import AutoGemmParameters
import Common


################################################################################
# Auto-Gemm
################################################################################

def writeOfflineCompilation(args):
    """Write the header listing the GEMM kernels to compile offline.

    Creates the include directory returned by ``Common.getIncludePath()`` if
    it does not exist, then writes ``AutoGemmKernelsToPreCompile.h`` into it.
    The header defines ``gemmPreCompile``, a table with one 8-column row per
    selected (precision, order, transA, transB, beta, tile) combination, and
    ``gemmPreCompileNum``, the number of such rows.

    Args:
        args: Object with iterable attributes ``precisions``, ``orders``,
            ``transposes`` and ``betas``. Each transpose entry is indexed
            as ``[0]`` (transA) and ``[1]`` (transB).
    """
    print("AutoGemm.py: Generating list of kernels to pre-compile.")

    includePath = Common.getIncludePath()
    if not os.path.exists(includePath):
        os.makedirs(includePath)

    rowFormat = " { %1u, %1u, %1u, %1u, %1u, %3u, %3u, %2u },\n"
    ocFileName = includePath + "AutoGemmKernelsToPreCompile.h"

    count = 0
    # The context manager closes the header even if a lookup or format below
    # raises part-way through generation.
    with open(ocFileName, "w") as ocFile:
        ocFile.write(Common.getAutoGemmHeader())
        ocFile.write(
            "\n/*precision, order, transA, transB, beta, tileNumRows, tileNumCols, unroll*/\n"
        )
        ocFile.write("\nunsigned int gemmPreCompile[][8] = {\n")

        for precision in args.precisions:
            validTiles = AutoGemmParameters.getTilesForPrecision(precision)
            isReal = precision in ("s", "d")
            rows = []
            for order in args.orders:
                for transpose in args.transposes:
                    transA = transpose[0]
                    transB = transpose[1]
                    if isReal and (transA == "C" or transB == "C"):
                        # real precision doesn't have conjugate transpose
                        continue
                    for beta in args.betas:
                        for tile in validTiles:
                            rows.append(rowFormat % (
                                Common.precisionInt[precision],
                                Common.orderInt[order],
                                Common.transposeInt[transA],
                                Common.transposeInt[transB],
                                beta,
                                tile.macroTileNumRows,
                                tile.macroTileNumCols,
                                tile.unroll,
                            ))
            # Flush once per precision so the whole table is never held in
            # memory as a single string.
            ocFile.write("".join(rows))
            count += len(rows)

        if count == 0:
            # Emit one all-zero placeholder row when nothing was selected, so
            # the array initializer in the generated header is not empty.
            ocFile.write(rowFormat % ((0,) * 8))
        ocFile.write("};\n")
        ocFile.write("unsigned int gemmPreCompileNum = " + str(count) + ";\n")

    # NOTE(review): the reported total is four times the number of table
    # rows; presumably each row expands to four kernels -- confirm against
    # the kernel generator.
    print("AutoGemm.py: %u kernels will be pre-compiled." % (count * 4))


################################################################################
# Main
################################################################################
if __name__ == "__main__":

    # parse arguments
    ap = argparse.ArgumentParser(description="Which gemm kernels to compile offline.")
    ap.add_argument("--output-path", dest="output")
    ap.add_argument("--precisions", dest="precisions", action="store", nargs="+",
                    choices=AutoGemmParameters.precisions)
    ap.add_argument("--orders", dest="orders", action="store", nargs="+",
                    choices=AutoGemmParameters.orders)
    ap.add_argument("--transposes", dest="transposes", action="store", nargs="+",
                    choices=AutoGemmParameters.getTransposeChoices())
    ap.add_argument("--betas", dest="betas", action="store", nargs="+", type=int,
                    choices=AutoGemmParameters.betas)
    args = ap.parse_args()
    if args.output:
        Common.setOutputPath(args.output)
    else:
        print("Warning: No output path specified; default is working directory.")

    # write offline compilation header
    # Options that were not given parse as None; treat them as empty
    # selections so the generator's loops simply do not run.
    for optionName in ("precisions", "transposes", "orders", "betas"):
        if getattr(args, optionName) is None:
            setattr(args, optionName, [])
    writeOfflineCompilation(args)