1# SPDX-License-Identifier: Apache-2.0
2# Generate test instances from a large tensor product set of options
3
# ---------------------------------------------------------------------------
# Option tables.  The __main__ driver in this file iterates Kernels,
# Semirings, DataTypes and DataShapes; Monoids and Binops are not read by any
# code visible in this file.
# ---------------------------------------------------------------------------
Monoids = ["PLUS", "MIN", "MAX", "TIMES", "ANY"]
Binops = [
    "TIMES", "PLUS", "MIN", "MAX", "DIV", "MINUS",
    "RDIV", "RMINUS", "FIRST", "SECOND", "PAIR",
]
Semirings = ["PLUS_TIMES", "MIN_PLUS", "MAX_PLUS"]
# Reduced alternative for quick runs:
# Semirings = ["PLUS_TIMES"]

# Type names are pasted verbatim into the generated template argument lists.
# Widest alternative:
# DataTypes = ["bool", "int8_t", "int16_t", "int32_t", "int64_t",
#              "uint8_t", "uint16_t", "uint32_t", "uint64_t",
#              "float", "double"]
DataTypes = ["int32_t", "int64_t", "uint32_t", "uint64_t", "float", "double"]
# Floating-point-only alternative:
# DataTypes = ["float", "double"]
# Active selection: this rebinding replaces the six-type list above.
DataTypes = ["int32_t", "uint64_t"]

# Problem sizes keyed by shape name; buildTest reads 'N', 'Anz' and 'Bnz'.
# NOTE(review): presumably N is the matrix dimension and Anz/Bnz are the
# nonzero counts of A and B -- confirm against the C++ test factories.
DataShapes = {
    "tinyxtiny": {"N": 32, "Anz": 256, "Bnz": 128},
    "smallxsmall": {"N": 1024, "Anz": 65_536, "Bnz": 65_536},
    # Larger shapes, currently disabled:
    # "medxmed": {"N": 4096, "Anz": 2**20, "Bnz": 2**20},
    # "largexlarge": {"N": 2**16, "Anz": 64 * 2**20, "Bnz": 64 * 2**20},
}

# Kernel names are spliced into the suite name and the phase-3 factory name.
Kernels = ["warp", "mp", "vsvs", "dndn", "spdn", "vssp"]
# Active selection: only the "warp" kernel is generated for now.
Kernels = ["warp"]
25
26
27
def buildTest(ts="TestsuiteName", kern="vsvs", ds="tinyxtiny", SR="PLUS_TIMES", phase=3,
              typeC="int", typeM="int", typeA="int", typeB="int",
              type_x="int", type_y="int", type_z="int"):
    """Build the header and body text of one generated C++ test instance.

    Args:
        ts: Test-suite name, placed first inside ``TEST( ... )``.
        kern: Kernel name spliced into the phase-3 factory name
            (``test_AxB_dot3_{kern}_factory``).
        ds: Key into the module-level ``DataShapes`` table.  The default is
            ``"tinyxtiny"``; the previous default ``"tiny-tiny"`` was not a
            key of that table, so calling with defaults always failed.
        SR: Semiring name; only used here as part of the test name.
        phase: Which body to return: 1, 2 or 3.
        typeC, typeM, typeA, typeB, type_x, type_y, type_z: Type names
            inserted verbatim into the test name and template arguments.

    Returns:
        A ``(TEST_HEAD, TEST_BODY)`` tuple of strings.

    Raises:
        KeyError: If ``ds`` is not in ``DataShapes`` or ``phase`` is not
            1, 2 or 3.
    """
    test_name = f"{ds}{SR}C{typeC}M{typeM}A{typeA}B{typeB}X{type_x}Y{type_y}Z{type_z}"
    test_head = f"""TEST( {ts}, {test_name})"""

    shape = DataShapes[ds]
    N = shape['N']
    Anz = shape['Anz']
    Bnz = shape['Bnz']

    bodies = {
        1: f""" test_AxB_dot3_phase1_factory< {typeC}, {typeM}, {typeA}, {typeB}>( 5, {N}, {Anz},{Bnz});""",
        2: f""" test_AxB_dot3_phase2_factory< {typeC} >( 5, {N}, {Anz},{Bnz});""",
        # The trailing "SR" is literal text, not the Python argument: it names
        # the C++ variable that the __main__ driver declares just before
        # writing this body.
        3: f""" test_AxB_dot3_{kern}_factory< {typeC},{typeM},{typeA},{typeB},{type_x},{type_y},{type_z} > (5, {N}, {Anz}, {Bnz}, SR);""",
    }
    return test_head, bodies[phase]
49
50
if __name__ == "__main__":
    # Script entry point: write one generated test instance per combination
    # of kernel, semiring, C type, mask type, A type, B type, data shape and
    # phase to the output header.

    # Imported here because the file has no top-level import block.
    from itertools import product

    outfile = "AxB_dot3_test_instances.hpp"
    mask_types = ["bool", "int32_t"]
    phases = [3]  # only phase-3 tests are generated

    # The context manager closes the file even if buildTest raises part-way
    # through, so a failed run does not leak the handle.
    with open(outfile, 'w') as fp:
        # product() varies its rightmost argument fastest, which matches the
        # order of the equivalent nested loops (kernel outermost).
        for k, SR, dtC, dtM, dtA, dtB, ds, phase in product(
                Kernels, Semirings, DataTypes, mask_types,
                DataTypes, DataTypes, DataShapes, phases):
            Test_suite = f'AxB_dot3_tests_{k}'
            # The x, y and z types all follow the C type.
            TEST_HEAD, TEST_BODY = buildTest(Test_suite, k, ds, SR, phase,
                                             dtC, dtM, dtA, dtB, dtC, dtC, dtC)
            # Emits: <head>{ std::string SR = "<semiring>"; <body>}
            fp.write(f'{TEST_HEAD}{{ std::string SR = "{SR}"; {TEST_BODY}}}\n')
83
84