# SPDX-License-Identifier: Apache-2.0
# Generate test instances from a large tensor product set of options.
#
# Running this file as a script writes AxB_dot3_test_instances.hpp containing
# one TEST(...) case per combination of kernel, semiring, element types and
# data shape configured below. Edit the lists to widen or narrow the sweep.

import itertools

# Not referenced by the generator below; kept as a catalogue of names.
Monoids = ["PLUS", "MIN", "MAX", "TIMES", "ANY"]
Binops = ["TIMES", "PLUS", "MIN", "MAX", "DIV", "MINUS", "RDIV", "RMINUS",
          "FIRST", "SECOND", "PAIR"]

Semirings = ["PLUS_TIMES", "MIN_PLUS", "MAX_PLUS"]
# Semirings = ["PLUS_TIMES"]  # ,"MIN_PLUS"] #, "MAX_PLUS"]

# Full menu of element types:
# DataTypes = ["bool", "int8_t", "int16_t", "int32_t", "int64_t",
#              "uint8_t", "uint16_t", "uint32_t", "uint64_t",
#              "float", "double"]
# Alternative sweeps (previously assigned and then overwritten):
# DataTypes = ["int32_t", "int64_t", "uint32_t", "uint64_t", "float", "double"]
# DataTypes = ["float", "double"]
DataTypes = ["int32_t", "uint64_t"]

# Size parameters passed straight through to the C++ factory calls.
# NOTE(review): presumably N is the matrix dimension and Anz/Bnz the nonzero
# counts of A and B -- confirm against the factory signatures.
DataShapes = {
    "tinyxtiny": {'N': 32, 'Anz': 256, 'Bnz': 128},
    "smallxsmall": {'N': 1024, 'Anz': 65_536, 'Bnz': 65_536},
    # "medxmed": {'N': 4096, 'Anz': 2**20, 'Bnz': 2**20},
    # "largexlarge": {'N': 2**16, 'Anz': 64*2**20, 'Bnz': 64*2**20},
}

# Full kernel list: ["warp", "mp", "vsvs", "dndn", "spdn", "vssp"]
Kernels = ["warp"]  # , "vsvs","dndn", "spdn","vssp"]


def buildTest(ts="TestsuiteName", kern="vsvs", ds="tinyxtiny", SR="PLUS_TIMES",
              phase=3, typeC="int", typeM="int", typeA="int", typeB="int",
              type_x="int", type_y="int", type_z="int"):
    """Build the C++ text for one generated test case.

    Args:
        ts: test-suite name placed in the ``TEST( suite, name)`` header.
        kern: kernel name substituted into the phase-3 factory name
            (``test_AxB_dot3_<kern>_factory``).
        ds: key into ``DataShapes`` selecting N / Anz / Bnz.
        SR: semiring name; only used to build the test name. The phase-3 body
            refers to a C++ variable named ``SR`` that the caller must declare.
        phase: 1, 2 or 3 -- selects which factory call forms the body.
        typeC, typeM, typeA, typeB, type_x, type_y, type_z: C++ type names
            interpolated into the test name and the factory template arguments.

    Returns:
        ``(TEST_HEAD, TEST_BODY)`` strings. The head is the
        ``TEST( suite, name)`` line (presumably a GoogleTest macro); the body is
        a single factory-call statement.

    Raises:
        KeyError: if ``ds`` is not in ``DataShapes`` or ``phase`` is not 1-3.
    """
    test_name = (f"{ds}{SR}C{typeC}M{typeM}A{typeA}B{typeB}"
                 f"X{type_x}Y{type_y}Z{type_z}")
    test_head = f"TEST( {ts}, {test_name})"

    shape = DataShapes[ds]
    N, Anz, Bnz = shape['N'], shape['Anz'], shape['Bnz']

    # NOTE(review): the leading literal 5 is passed unchanged to every factory;
    # its meaning is defined on the C++ side.
    bodies = {
        1: (f" test_AxB_dot3_phase1_factory< {typeC}, {typeM}, {typeA}, {typeB}>"
            f"( 5, {N}, {Anz},{Bnz});"),
        2: f" test_AxB_dot3_phase2_factory< {typeC} >( 5, {N}, {Anz},{Bnz});",
        3: (f" test_AxB_dot3_{kern}_factory< {typeC},{typeM},{typeA},{typeB},"
            f"{type_x},{type_y},{type_z} > (5, {N}, {Anz}, {Bnz}, SR);"),
    }
    try:
        test_body = bodies[phase]
    except KeyError:
        raise KeyError(f"unsupported phase {phase!r}; expected 1, 2 or 3") from None

    return test_head, test_body


def main(outfile="AxB_dot3_test_instances.hpp"):
    """Write one test case per point of the configured sweep to *outfile*.

    Each case is written on its own line as
    ``TEST(...){ std::string SR = "<semiring>";  <factory call>}``.
    The X/Y/Z types always equal the C type.
    """
    mask_types = ["bool", "int32_t"]
    phases = [3]
    # itertools.product varies the rightmost iterable fastest, matching the
    # original nested-loop order so the generated file is unchanged.
    sweep = itertools.product(Kernels, Semirings, DataTypes, mask_types,
                              DataTypes, DataTypes, DataShapes, phases)

    with open(outfile, "w") as fp:
        for k, SR, dtC, dtM, dtA, dtB, ds, phase in sweep:
            dtX = dtY = dtZ = dtC
            test_head, test_body = buildTest(
                f"AxB_dot3_tests_{k}", k, ds, SR, phase,
                dtC, dtM, dtA, dtB, dtX, dtY, dtZ)
            fp.write(test_head)
            # Declare the C++ `SR` variable that the phase-3 body refers to.
            fp.write(f'{{ std::string SR = "{SR}"; ')
            fp.write(test_body)
            fp.write("}\n")


if __name__ == "__main__":
    main()