1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 use std::{
21 cell::RefCell,
22 error::Error,
23 os::raw::{c_int, c_void},
24 ptr,
25 };
26
27 use crate::allocator::Allocation;
28 use crate::errors::InvalidPointer;
29 use std::alloc::LayoutErr;
30
/// Alignment, in bytes, requested for every workspace allocation (64 = 1 << 6).
/// Value taken from `kTempAllocaAlignment` in `device_api.h`.
const WS_ALIGN: usize = 1 << 6;
32
/// Removes the first element of `vec` that compares equal to `item` and returns it.
///
/// Returns `None` (leaving `vec` untouched) when no element matches. Elements after
/// the removed one shift left, so the relative order of the rest is preserved.
pub fn remove_item<T: PartialEq>(vec: &mut Vec<T>, item: &T) -> Option<T> {
    let found_at = vec.iter().position(|candidate| candidate == item);
    found_at.map(|index| vec.remove(index))
}
37
38 struct WorkspacePool {
39 workspaces: Vec<Allocation>,
40 free: Vec<usize>,
41 in_use: Vec<usize>,
42 }
43
44 impl WorkspacePool {
new() -> Self45 fn new() -> Self {
46 WorkspacePool {
47 workspaces: Vec::new(),
48 free: Vec::new(),
49 in_use: Vec::new(),
50 }
51 }
52
alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr>53 fn alloc_new(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
54 self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
55 self.in_use.push(self.workspaces.len() - 1);
56 Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
57 }
58
alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr>59 fn alloc(&mut self, size: usize) -> Result<*mut u8, LayoutErr> {
60 if self.free.is_empty() {
61 return self.alloc_new(size);
62 }
63 let idx = self
64 .free
65 .iter()
66 .fold(None, |cur_ws_idx: Option<usize>, &idx| {
67 let ws_size = self.workspaces[idx].size();
68 if ws_size < size {
69 return cur_ws_idx;
70 }
71 cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
72 let cur_size = self.workspaces[cur_idx].size();
73 Some(if ws_size <= cur_size { idx } else { cur_idx })
74 })
75 });
76 match idx {
77 Some(idx) => {
78 remove_item(&mut self.free, &idx).unwrap();
79 self.in_use.push(idx);
80 Ok(self.workspaces[idx].as_mut_ptr())
81 }
82 None => self.alloc_new(size),
83 }
84 }
85
free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>>86 fn free(&mut self, ptr: *mut u8) -> Result<(), Box<dyn Error>> {
87 let mut ws_idx = None;
88 for i in 0..self.in_use.len() {
89 let idx = self.in_use[i];
90 if self.workspaces[idx].as_mut_ptr() == ptr {
91 self.in_use.remove(i);
92 ws_idx = Some(idx);
93 break;
94 }
95 }
96 let ws_idx = ws_idx.ok_or_else(|| InvalidPointer(ptr))?;
97 self.free.push(ws_idx);
98 Ok(())
99 }
100 }
101
102 thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
103
104 const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
105
106 #[no_mangle]
TVMBackendAllocWorkspace( _device_type: c_int, _device_id: c_int, size: u64, _dtype_code_hint: c_int, _dtype_bits_hint: c_int, ) -> *mut c_void107 pub extern "C" fn TVMBackendAllocWorkspace(
108 _device_type: c_int,
109 _device_id: c_int,
110 size: u64,
111 _dtype_code_hint: c_int,
112 _dtype_bits_hint: c_int,
113 ) -> *mut c_void {
114 let nbytes = if size == 0 {
115 WORKSPACE_PAGE_SIZE
116 } else {
117 size as usize
118 };
119 WORKSPACE_POOL.with(|pool_cell| {
120 pool_cell
121 .borrow_mut()
122 .alloc(nbytes as usize)
123 .unwrap_or(ptr::null_mut()) as *mut c_void
124 })
125 }
126
127 #[no_mangle]
TVMBackendFreeWorkspace( _device_type: c_int, _device_id: c_int, ptr: *mut c_void, ) -> c_int128 pub extern "C" fn TVMBackendFreeWorkspace(
129 _device_type: c_int,
130 _device_id: c_int,
131 ptr: *mut c_void,
132 ) -> c_int {
133 WORKSPACE_POOL.with(|pool_cell| {
134 (match pool_cell.borrow_mut().free(ptr as *mut u8) {
135 Ok(()) => 0,
136 Err(_) => -1,
137 }) as c_int
138 })
139 }
140