1 #include <map>
2 #include <sstream>
3 #include <vector>
4 
5 #include "DebugToFile.h"
6 #include "Function.h"
7 #include "IRMutator.h"
8 #include "IROperator.h"
9 
10 namespace Halide {
11 namespace Internal {
12 
13 using std::map;
14 using std::ostringstream;
15 using std::string;
16 using std::vector;
17 
18 class DebugToFile : public IRMutator {
19     const map<string, Function> &env;
20 
21     using IRMutator::visit;
22 
visit(const Realize * op)23     Stmt visit(const Realize *op) override {
24         map<string, Function>::const_iterator iter = env.find(op->name);
25         if (iter != env.end() && !iter->second.debug_file().empty()) {
26             Function f = iter->second;
27             vector<Expr> args;
28 
29             user_assert(op->types.size() == 1)
30                 << "debug_to_file doesn't handle functions with multiple values yet\n";
31 
32             // The name of the file
33             args.emplace_back(f.debug_file());
34 
35             // Inject loads to the corners of the function so that any
36             // passes doing further analysis of buffer use understand
37             // what we're doing (e.g. so we trigger a copy-back from a
38             // device pointer).
39             Expr num_elements = 1;
40             for (size_t i = 0; i < op->bounds.size(); i++) {
41                 num_elements *= op->bounds[i].extent;
42             }
43 
44             int type_code = 0;
45             Type t = op->types[0];
46             if (t == Float(32)) {
47                 type_code = 0;
48             } else if (t == Float(64)) {
49                 type_code = 1;
50             } else if (t == UInt(8) || t == UInt(1)) {
51                 type_code = 2;
52             } else if (t == Int(8)) {
53                 type_code = 3;
54             } else if (t == UInt(16)) {
55                 type_code = 4;
56             } else if (t == Int(16)) {
57                 type_code = 5;
58             } else if (t == UInt(32)) {
59                 type_code = 6;
60             } else if (t == Int(32)) {
61                 type_code = 7;
62             } else if (t == UInt(64)) {
63                 type_code = 8;
64             } else if (t == Int(64)) {
65                 type_code = 9;
66             } else {
67                 user_error << "Type " << t << " not supported for debug_to_file\n";
68             }
69             args.emplace_back(type_code);
70 
71             Expr buf = Variable::make(Handle(), f.name() + ".buffer");
72             args.push_back(buf);
73 
74             Expr call = Call::make(Int(32), Call::debug_to_file, args, Call::Intrinsic);
75             string call_result_name = unique_name("debug_to_file_result");
76             Expr call_result_var = Variable::make(Int(32), call_result_name);
77             Stmt body = AssertStmt::make(call_result_var == 0,
78                                          Call::make(Int(32), "halide_error_debug_to_file_failed",
79                                                     {f.name(), f.debug_file(), call_result_var},
80                                                     Call::Extern));
81             body = LetStmt::make(call_result_name, call, body);
82             body = Block::make(mutate(op->body), body);
83 
84             return Realize::make(op->name, op->types, op->memory_type, op->bounds, op->condition, body);
85         } else {
86             return IRMutator::visit(op);
87         }
88     }
89 
90 public:
DebugToFile(const map<string,Function> & e)91     DebugToFile(const map<string, Function> &e)
92         : env(e) {
93     }
94 };
95 
96 class RemoveDummyRealizations : public IRMutator {
97     const vector<Function> &outputs;
98 
99     using IRMutator::visit;
100 
visit(const Realize * op)101     Stmt visit(const Realize *op) override {
102         for (Function f : outputs) {
103             if (op->name == f.name()) {
104                 return mutate(op->body);
105             }
106         }
107         return IRMutator::visit(op);
108     }
109 
110 public:
RemoveDummyRealizations(const vector<Function> & o)111     RemoveDummyRealizations(const vector<Function> &o)
112         : outputs(o) {
113     }
114 };
115 
116 class AddDummyRealizations : public IRMutator {
117     const vector<Function> &outputs;
118 
119     using IRMutator::visit;
120 
visit(const ProducerConsumer * op)121     Stmt visit(const ProducerConsumer *op) override {
122         Stmt s = IRMutator::visit(op);
123         for (Function out : outputs) {
124             if (op->name == out.name()) {
125                 std::vector<Range> output_bounds;
126                 for (int i = 0; i < out.dimensions(); i++) {
127                     string dim = std::to_string(i);
128                     Expr min = Variable::make(Int(32), out.name() + ".min." + dim);
129                     Expr extent = Variable::make(Int(32), out.name() + ".extent." + dim);
130                     output_bounds.emplace_back(min, extent);
131                 }
132                 return Realize::make(out.name(),
133                                      out.output_types(),
134                                      MemoryType::Auto,
135                                      output_bounds,
136                                      const_true(),
137                                      s);
138             }
139         }
140         return s;
141     }
142 
143 public:
AddDummyRealizations(const vector<Function> & o)144     AddDummyRealizations(const vector<Function> &o)
145         : outputs(o) {
146     }
147 };
148 
debug_to_file(Stmt s,const vector<Function> & outputs,const map<string,Function> & env)149 Stmt debug_to_file(Stmt s, const vector<Function> &outputs, const map<string, Function> &env) {
150     // Temporarily wrap the produce nodes for the output functions in
151     // realize nodes so that we know when to write the debug outputs.
152     s = AddDummyRealizations(outputs).mutate(s);
153 
154     s = DebugToFile(env).mutate(s);
155 
156     // Remove the realize node we wrapped around the output
157     s = RemoveDummyRealizations(outputs).mutate(s);
158 
159     return s;
160 }
161 
162 }  // namespace Internal
163 }  // namespace Halide
164