1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
21 
22 use failure::Error;
23 use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
24 use serde;
25 use serde_json;
26 use tvm_common::{
27     array::{DataType, TVMContext},
28     ffi::{DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor},
29     TVMArgValue,
30 };
31 
32 use crate::{errors::GraphFormatError, Module, Storage, Tensor};
33 
// Magic number identifying a serialized NDArray.
// @see `kTVMNDArrayMagic` in `ndarray.h`
// NOTE(review): not referenced in this file; presumably these are the 8 bytes the
// `tensor` parser skips with `take!(8)` without validating them — consider checking.
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// Magic number identifying a serialized list of NDArrays (a params file).
// @see `kTVMNDArrayListMagic` in `graph_runtime.h`
// NOTE(review): likewise unused; `parse_param_dict` skips its first 8 bytes unchecked.
const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
38 
/// A TVM computation graph, deserialized from its JSON description
/// (see the `TryFrom<&str>` and `TryFrom<&String>` impls).
///
/// # Examples
///
/// ```
/// let graph_json = fs::read_to_string("graph.json").unwrap();
/// let graph = Graph::try_from(&graph_json).unwrap();
/// ```
#[derive(Serialize, Deserialize, Debug)]
pub struct Graph {
    /// All nodes of the graph, indexed by node id (`Entry::id`).
    pub nodes: Vec<Node>,
    /// Ids of the nodes that are graph inputs (data inputs and parameters); only these
    /// can be addressed by name through the executor's input methods.
    pub arg_nodes: Vec<usize>,
    /// The graph outputs, in output order.
    pub heads: Vec<Entry>,
    /// For each node id, the position of that node's first output in the flattened list
    /// of all node outputs; an entry's position is `node_row_ptr[id] + index`.
    pub node_row_ptr: Option<Vec<usize>>,
    /// Graph-level attributes. The executor reads `storage_id`, `shape` and `dltype`, each
    /// a two-element `[String, values]` array whose second element has one value per node
    /// output.
    pub attrs: Option<HashMap<String, serde_json::Value>>,
}
55 
/// A reference to one output of a graph node.
#[derive(Serialize, Deserialize, Debug)]
pub struct Entry {
    /// Id (index into `Graph::nodes`) of the node producing the output.
    pub id: usize,
    /// Which output of that node is referenced.
    pub index: usize,
    /// Not read by any code in this file.
    pub version: usize,
}
62 
63 impl Graph {
entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError>64     fn entry_index(&self, entry: &Entry) -> Result<usize, GraphFormatError> {
65         self.node_row_ptr
66             .as_ref()
67             .map(|nrp| nrp[entry.id] + entry.index)
68             .ok_or_else(|| GraphFormatError::MissingField("node_row_ptr"))
69     }
70 
71     /// Attempt to deserialize a JSON attribute to a type `T`.
get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError>72     fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T, GraphFormatError> {
73         Ok(serde_json::from_value::<T>(
74             self.attrs
75                 .as_ref()
76                 .ok_or(GraphFormatError::MissingField("attrs"))?
77                 .get(attr)
78                 .ok_or_else(|| {
79                     GraphFormatError::MissingAttr("graph".to_string(), attr.to_string())
80                 })?
81                 .to_owned(),
82         )
83         .map_err(|err| GraphFormatError::Parse(err.into()))?)
84     }
85 }
86 
/// A single node of a [`Graph`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Node {
    /// Operator kind. The executor skips `"null"` nodes (graph inputs/parameters) and
    /// supports only `"tvm_op"` for everything else.
    pub op: String,
    /// Node name; graph inputs are looked up by this name.
    pub name: String,
    /// The node outputs this node consumes, in argument order.
    pub inputs: Vec<Entry>,
    /// String-valued attributes; for `tvm_op` nodes they must include `func_name`,
    /// `num_outputs` and `flatten_data` (see `Node::parse_attrs`).
    pub attrs: Option<HashMap<String, String>>,
    /// Not read by any code in this file.
    pub control_deps: Option<Vec<Entry>>,
}
95 
/// The attributes of a `tvm_op` node that the executor needs, parsed from `Node::attrs`.
struct NodeAttrs {
    /// Name of the function to look up in the `Module`; `"__nop"` means the node is skipped.
    func_name: String,
    /// Number of outputs the node produces.
    num_outputs: usize,
    /// True when the `flatten_data` attribute parses (as a `u8`) to 1; tensors are then
    /// passed to the function flattened.
    flatten_data: bool,
}
101 
/// Looks up the literal key `$attr` in the node attribute map `$attrs`.
///
/// Expands to a `Result<&String, GraphFormatError>`: `Ok` with the value when present,
/// otherwise `Err(GraphFormatError::MissingAttr(<$node as String>, <$attr as String>))`.
macro_rules! get_node_attr {
    ($node:expr, $attrs:ident, $attr:literal) => {
        match $attrs.get($attr) {
            Some(found) => Ok(found),
            None => Err(GraphFormatError::MissingAttr($node.to_owned(), $attr.to_owned())),
        }
    };
}
109 
110 impl Node {
parse_attrs(&self) -> Result<NodeAttrs, Error>111     fn parse_attrs(&self) -> Result<NodeAttrs, Error> {
112         let attrs = self
113             .attrs
114             .as_ref()
115             .ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;
116         Ok(NodeAttrs {
117             func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
118             num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
119             flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
120         })
121     }
122 }
123 
124 impl<'a> TryFrom<&'a String> for Graph {
125     type Error = Error;
try_from(graph_json: &String) -> Result<Self, self::Error>126     fn try_from(graph_json: &String) -> Result<Self, self::Error> {
127         let graph = serde_json::from_str(graph_json)?;
128         Ok(graph)
129     }
130 }
131 
132 impl<'a> TryFrom<&'a str> for Graph {
133     type Error = Error;
try_from(graph_json: &'a str) -> Result<Self, Self::Error>134     fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
135         let graph = serde_json::from_str(graph_json)?;
136         Ok(graph)
137     }
138 }
139 
/// An executor for a TVM computation graph.
///
/// Construction allocates storage for every node output and binds each executable node to
/// its function and argument tensors; `run` then invokes the bound operators in node order.
///
/// # Examples
///
/// ```
/// use ndarray::Array;
///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
///
/// let mut params_bytes = Vec::new();
/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
///
/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
///
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
/// exec.load_params(params);
///
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
/// exec.set_input("data", x.into());
/// exec.run();
/// let output = exec.get_output(0).unwrap();
///
/// println!("{:#?}", Array::try_from(output).unwrap());
/// ```
pub struct GraphExecutor<'m, 't> {
    /// The graph being executed.
    graph: Graph,
    /// One closure per executable node (op not `"null"`, function not `"__nop"`), in node
    /// order; each calls the node's module function with a pre-built argument list.
    op_execs: Vec<Box<dyn Fn() + 'm>>,
    /// One tensor per node output, indexed by entry position (see `Graph::node_row_ptr`).
    /// Tensors assigned the same storage id share that storage (the first owns it, later
    /// ones hold views).
    tensors: Vec<Tensor<'t>>,
}

// SAFETY: NOTE(review): no justification is recorded for this impl. `op_execs` holds
// closures that are not required to be `Send` and that capture `DLTensor`s built from
// `tensors`. Moving the executor to another thread is sound only if the module's functions
// and the `Storage` buffers are not tied to the creating thread — confirm before relying on it.
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}
172 
173 impl<'m, 't> GraphExecutor<'m, 't> {
new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error>174     pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self, Error> {
175         let tensors = Self::setup_storages(&graph)?;
176         Ok(GraphExecutor {
177             op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
178             tensors: tensors,
179             graph: graph,
180         })
181     }
182 
183     /// Runs the computation graph.
run(&self)184     pub fn run(&self) {
185         self.op_execs.iter().for_each(|op_exec| {
186             op_exec();
187         });
188     }
189 
190     /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error>191     fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'t>>, Error> {
192         let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
193         let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
194         let dtypes = graph
195             .get_attr::<(String, Vec<String>)>("dltype")?
196             .1
197             .iter()
198             .map(|dltype| {
199                 if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
200                     Ok(dtype)
201                 } else {
202                     Err(GraphFormatError::InvalidDLType(dltype.to_string()))
203                 }
204             })
205             .collect::<Result<Vec<DataType>, GraphFormatError>>()?;
206 
207         let align = dtypes.iter().map(|dtype| dtype.bits() as usize).max();
208         let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
209         for (i, &storage_id) in storage_ids.iter().enumerate() {
210             let dtype_size = dtypes[i].bits() * dtypes[i].lanes() >> 3;
211             let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
212             storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
213         }
214 
215         let mut storages: Vec<Storage> = storage_num_bytes
216             .into_iter()
217             .map(|nbytes| Storage::new(nbytes, align))
218             .collect::<Result<Vec<Storage>, Error>>()?;
219 
220         let tensors = izip!(storage_ids, shapes, dtypes)
221             .map(|(storage_id, shape, dtype)| {
222                 let storage = storages[storage_id].view();
223                 Tensor {
224                     data: mem::replace(&mut storages[storage_id], storage),
225                     ctx: TVMContext::default(),
226                     dtype: dtype,
227                     size: shape.iter().product::<i64>() as usize,
228                     shape: shape,
229                     strides: None,
230                     byte_offset: 0,
231                 }
232             })
233             .collect();
234 
235         Ok(tensors)
236     }
237 
238     /// Creates closures which represent the computation performed by this graph.
setup_op_execs<M: 'm + Module>( graph: &Graph, lib: &'m M, tensors: &Vec<Tensor<'t>>, ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error>239     fn setup_op_execs<M: 'm + Module>(
240         graph: &Graph,
241         lib: &'m M,
242         tensors: &Vec<Tensor<'t>>,
243     ) -> Result<Vec<Box<dyn Fn() + 'm>>, Error> {
244         ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
245         let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();
246 
247         let mut op_execs = Vec::new();
248         for (i, node) in graph.nodes.iter().enumerate() {
249             if node.op == "null" {
250                 continue;
251             }
252             ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
253             ensure!(node.attrs.is_some(), "Missing node attrs.");
254 
255             let attrs = node.parse_attrs()?;
256 
257             if attrs.func_name == "__nop" {
258                 continue;
259             }
260 
261             let func = lib.get_function(&attrs.func_name).ok_or(format_err!(
262                 "Library is missing function {}",
263                 attrs.func_name
264             ))?;
265             let arg_indices = node
266                 .inputs
267                 .iter()
268                 .map(|entry| graph.entry_index(entry))
269                 .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));
270 
271             let dl_tensors = arg_indices
272                 .map(|idx| {
273                     let tensor = &tensors[idx?];
274                     Ok(if attrs.flatten_data {
275                         Tensor::as_dltensor(tensor, true /* flatten */)
276                     } else {
277                         DLTensor::from(tensor)
278                     })
279                 })
280                 .collect::<Result<Vec<DLTensor>, Error>>()
281                 .unwrap();
282             let op: Box<dyn Fn()> = box move || {
283                 let args = dl_tensors
284                     .iter()
285                     .map(|t| t.into())
286                     .collect::<Vec<TVMArgValue>>();
287                 func(&args).unwrap();
288             };
289             op_execs.push(op);
290         }
291         Ok(op_execs)
292     }
293 
load_params(&mut self, params: HashMap<String, Tensor>)294     pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
295         params.into_iter().for_each(|(name, param)| {
296             self.set_input(name, param);
297         })
298     }
299 
set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor)300     pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor) {
301         if let Some(idx) = self.get_input_index(name.as_ref()) {
302             // TODO: consider `new_with_params` to avoid ever allocating
303             let ptr = self.tensors[idx].data.as_ptr();
304             let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
305             let owner = to_replace.nth(0).unwrap();
306             if value.data.is_owned() {
307                 // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
308                 // mem::replace(&mut (*owner), value);
309                 // to_replace.for_each(|t| {
310                 //   panic!("replacing");
311                 //   t.data = owner.data.view();
312                 // });
313                 owner.copy(&value);
314             } else {
315                 owner.copy(&value);
316             }
317         } else {
318             println!("Unexpected input `{}`", name.as_ref());
319         }
320     }
321 
322     /// Returns the graph input with name `name`, if it exists.
get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor>323     pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
324         self.get_input_index(name.as_ref())
325             .and_then(move |idx| Some(&self.tensors[idx]))
326     }
327 
328     /// Returns the graph output with index `index`, if it exists.
get_output(&self, idx: usize) -> Option<&Tensor>329     pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
330         let graph = &self.graph;
331         graph.heads.get(idx).and_then(|entry| {
332             graph
333                 .entry_index(entry)
334                 .map(|idx| self.tensors.get(idx))
335                 .unwrap_or(None)
336         })
337     }
338 
339     /// Returns the index for graph input with name `name`, if it exists.
get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize>340     pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
341         let graph = &self.graph;
342         (0..graph.nodes.len())
343             .skip_while(|&i| graph.nodes[i].name != name.as_ref())
344             .nth(0)
345             .and_then(|i| {
346                 if graph.arg_nodes.iter().any(|&id| id == i) {
347                     graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
348                 } else {
349                     None
350                 }
351             })
352     }
353 }
354 
// Parses a TVM dtype string such as `float32`, `int8` or `uint8x4` into a `DataType`:
// an alphabetic type name, a decimal bit width, and an optional `x<lanes>` suffix
// (lanes default to 1). @see `String2TVMType` in packed_func.h
//
// NOTE(review): any type name other than `int`/`uint`/`float` silently maps to the float
// type code; bit widths above `u8::MAX` or lane counts above `u16::MAX` panic on `unwrap`;
// unconsumed trailing input is returned to the caller rather than rejected.
named!(
  tvm_str_to_type<CompleteStr, DataType>,
  do_parse!(
    type_name: alpha1 >>
    bits: digit1 >>
    lanes: opt!(tuple!(tag!("x"), digit1)) >>
    (DataType {
      code: match type_name {
        CompleteStr("int") => DLDataTypeCode_kDLInt,
        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
        CompleteStr("float") => DLDataTypeCode_kDLFloat,
        _ => DLDataTypeCode_kDLFloat,
      } as usize,
      bits: bits.parse::<u8>().unwrap() as usize,
      lanes: match lanes {
        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
        None => 1,
      },
    })
  )
);
377 
// Parses a length-prefixed string: a little-endian `u64` byte count followed by that
// many bytes, which must be valid UTF-8 (otherwise the parser fails).
named!(
    name<String>,
    map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
        b.to_vec()
    ))
);
385 
// Parses a TVMContext: a little-endian `u32` device type followed by a little-endian
// `i32` device id.
// NOTE(review): both are widened with `as usize`, so a negative device id wraps around.
named!(
  tvm_ctx<&[u8], TVMContext>,
  do_parse!(
    device_type: le_u32 >>
    device_id: le_i32 >>
    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
  )
);
395 
// Parses a DataType: a `u8` type code, a `u8` bit width, then a little-endian `u16`
// lane count.
named!(
  data_type<&[u8], DataType>,
  do_parse!(
    code: le_u8 >>
    bits: le_u8 >>
    lanes: le_u16 >>
    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
  )
);
406 
// Parses one serialized Tensor from a TVM array file. Layout, in order (little-endian):
//   8 bytes      skipped without validation (presumably the `_NDARRAY_MAGIC` header — TODO confirm)
//   u64          reserved, must be 0
//   ctx          see `tvm_ctx`
//   u32          ndim
//   dtype        see `data_type`
//   ndim x i64   shape
//   i64          data length in bytes, followed by that many bytes of data
// The tensor's `size` is the product of its shape. NOTE(review): presumably
// `Storage::from(data)` borrows the parsed slice rather than copying it — confirm.
named!(
    tensor<Tensor>,
    do_parse!(
        take!(8)
            >> bits!(tag_bits!(u64, 64, 0))
            >> ctx: tvm_ctx
            >> ndim: le_u32
            >> dtype: data_type
            >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
            >> length: le_i64
            >> data: take!(length)
            >> (Tensor {
                data: Storage::from(data),
                ctx: ctx,
                dtype: dtype,
                size: shape.iter().product::<i64>() as usize,
                shape: shape,
                strides: None,
                byte_offset: 0,
            })
    )
);
430 
// Parses a graph params dict from a params binary file. Layout, in order (little-endian):
//   8 bytes   skipped without validation (presumably `_NDARRAY_LIST_MAGIC` — TODO confirm)
//   u64       reserved, must be 0
//   u64       name count, then that many length-prefixed names (see `name`)
//   u64       tensor count, then that many tensors (see `tensor`)
// Names and tensors are paired by position. If the counts differ, `zip` silently drops the
// extras, and a repeated name keeps the last tensor inserted for it.
named!(
    parse_param_dict<HashMap<String, Tensor>>,
    do_parse!(
        take!(8)
            >> bits!(tag_bits!(u64, 64, 0))
            >> names: length_count!(le_u64, name)
            >> tensors: length_count!(le_u64, tensor)
            >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
    )
);
442 
443 /// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError>444 pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
445     if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
446         if remaining_bytes.len() == 0 {
447             Ok(param_dict)
448         } else {
449             Err(GraphFormatError::Params)
450         }
451     } else {
452         Err(GraphFormatError::Params)
453     }
454 }
455 
#[cfg(test)]
mod tests {
    use super::*;

    /// `tvm_str_to_type` decodes the type code, the bit width and the optional lane count.
    #[test]
    fn test_str_to_type() {
        // (input, expected type code, expected bits, expected lanes)
        let cases = [
            ("float24", DLDataTypeCode_kDLFloat, 24, 1),
            ("uint111x44", DLDataTypeCode_kDLUInt, 111, 44),
        ];
        for &(input, code, bits, lanes) in cases.iter() {
            let expected = DataType {
                code: code as usize,
                bits,
                lanes,
            };
            assert_eq!(tvm_str_to_type(CompleteStr(input)).unwrap().1, expected);
        }
    }
}
480