1 /*
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef NNPI_NNPITRACING_ML_WRAPPER_H
17 #define NNPI_NNPITRACING_ML_WRAPPER_H
18 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <nnpi_ice_caps.h>
#include <nnpi_inference.h>
#include <string>
#include <vector>
23 
/// A single trace record, as held by NNPITraceContext after load().
/// Plain aggregate: both timestamps start at zero and the parameter map starts
/// empty until populated.
struct NNPITraceEntry {
  /// Timestamp reported by the engine (units not established in this header).
  uint64_t engineTime = 0;
  /// Timestamp on the host side (units not established in this header).
  uint64_t hostTime = 0;
  /// Free-form key/value attributes attached to this record.
  std::map<std::string, std::string> params{};
};
29 
30 /// Device trace api wrapper.
31 class NNPITraceContext {
32 public:
33   NNPITraceContext(unsigned devID);
34   virtual ~NNPITraceContext();
35   /// Start capturing traces from the HW device.
36   bool startCapture(NNPIDeviceContext deviceContext, bool swTraces,
37                     bool hwTraces, uint32_t softwareBufferSizeMB,
38                     uint32_t hardwareBufferSizeMB);
39   /// Start capturing.
40   bool stopCapture(NNPIDeviceContext deviceContext) const;
41   /// Load traces (valid only after stopCapture()).
42   bool load();
43   /// Returns the number of traces captured and loaded (valid only after
44   /// load()).
getTraceCount()45   size_t getTraceCount() const { return entries_.size(); }
46   /// Read a loaded entry by index.
getEntry(int index)47   NNPITraceEntry &getEntry(int index) { return entries_[index]; }
48   /// Get the context device ID.
getDeviceID()49   uint32_t getDeviceID() const { return devID_; }
50   /// Returns true if device ID was set, false otherwise.
isDeviceIDSet()51   bool isDeviceIDSet() const { return devIDSet_; }
52   /// Get a vector of the loaded entries (valid only after load()).
getEntries()53   std::vector<NNPITraceEntry> getEntries() const { return entries_; }
54 
55 private:
56   bool destroyInternalContext();
57   bool createInternalContext(bool swTraces, bool hwTraces,
58                              uint32_t softwareBufferSizeMB,
59                              uint32_t hardwareBufferSizeMB);
60   bool readTraceOutput();
61 
62   IceCaps_t capsSession_{0};
63   uint64_t devMask_{0};
64   unsigned devID_{0};
65   bool devIDSet_{false};
66   std::vector<NNPITraceEntry> entries_;
67 };
68 
69 #endif // NNPI_NNPITRACING_ML_WRAPPER_H
70