1 //---------------------------------------------------------------------------// 2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com> 3 // 4 // Distributed under the Boost Software License, Version 1.0 5 // See accompanying file LICENSE_1_0.txt or copy at 6 // http://www.boost.org/LICENSE_1_0.txt 7 // 8 // See http://boostorg.github.com/compute for more information. 9 //---------------------------------------------------------------------------// 10 11 #ifndef BOOST_COMPUTE_DEVICE_PTR_HPP 12 #define BOOST_COMPUTE_DEVICE_PTR_HPP 13 14 #include <boost/type_traits.hpp> 15 #include <boost/static_assert.hpp> 16 17 #include <boost/compute/buffer.hpp> 18 #include <boost/compute/config.hpp> 19 #include <boost/compute/detail/is_buffer_iterator.hpp> 20 #include <boost/compute/detail/read_write_single_value.hpp> 21 #include <boost/compute/type_traits/is_device_iterator.hpp> 22 23 namespace boost { 24 namespace compute { 25 namespace detail { 26 27 template<class T, class IndexExpr> 28 struct device_ptr_index_expr 29 { 30 typedef T result_type; 31 device_ptr_index_exprboost::compute::detail::device_ptr_index_expr32 device_ptr_index_expr(const buffer &buffer, 33 uint_ index, 34 const IndexExpr &expr) 35 : m_buffer(buffer), 36 m_index(index), 37 m_expr(expr) 38 { 39 } 40 operator Tboost::compute::detail::device_ptr_index_expr41 operator T() const 42 { 43 BOOST_STATIC_ASSERT_MSG(boost::is_integral<IndexExpr>::value, 44 "Index expression must be integral"); 45 46 BOOST_ASSERT(m_buffer.get()); 47 48 const context &context = m_buffer.get_context(); 49 const device &device = context.get_device(); 50 command_queue queue(context, device); 51 52 return detail::read_single_value<T>(m_buffer, m_expr, queue); 53 } 54 55 const buffer &m_buffer; 56 uint_ m_index; 57 IndexExpr m_expr; 58 }; 59 60 template<class T> 61 class device_ptr 62 { 63 public: 64 typedef T value_type; 65 typedef std::size_t size_type; 66 typedef std::ptrdiff_t difference_type; 67 typedef std::random_access_iterator_tag iterator_category; 68 typedef T* pointer; 69 typedef T& reference; 70 device_ptr()71 device_ptr() 72 : m_index(0) 73 { 74 } 75 device_ptr(const buffer & buffer,size_t index=0)76 device_ptr(const buffer &buffer, size_t index = 0) 77 : m_buffer(buffer.get(), false), 78 m_index(index) 79 { 80 } 81 device_ptr(const device_ptr<T> & other)82 device_ptr(const device_ptr<T> &other) 83 : m_buffer(other.m_buffer.get(), false), 84 m_index(other.m_index) 85 { 86 } 87 operator =(const device_ptr<T> & other)88 device_ptr<T>& operator=(const device_ptr<T> &other) 89 { 90 if(this != &other){ 91 m_buffer.get() = other.m_buffer.get(); 92 m_index = other.m_index; 93 } 94 95 return *this; 96 } 97 98 #ifndef BOOST_COMPUTE_NO_RVALUE_REFERENCES device_ptr(device_ptr<T> && other)99 device_ptr(device_ptr<T>&& other) BOOST_NOEXCEPT 100 : m_buffer(other.m_buffer.get(), false), 101 m_index(other.m_index) 102 { 103 other.m_buffer.get() = 0; 104 } 105 operator =(device_ptr<T> && other)106 device_ptr<T>& operator=(device_ptr<T>&& other) BOOST_NOEXCEPT 107 { 108 m_buffer.get() = other.m_buffer.get(); 109 m_index = other.m_index; 110 111 other.m_buffer.get() = 0; 112 113 return *this; 114 } 115 #endif // BOOST_COMPUTE_NO_RVALUE_REFERENCES 116 ~device_ptr()117 ~device_ptr() 118 { 119 // set buffer to null so that its reference count will 120 // not be decremented when its destructor is called 121 m_buffer.get() = 0; 122 } 123 get_index() const124 size_type get_index() const 125 { 126 return m_index; 127 } 128 get_buffer() const129 const buffer& get_buffer() const 130 { 131 return m_buffer; 132 } 133 134 template<class OT> cast() const135 device_ptr<OT> cast() const 136 { 137 return device_ptr<OT>(m_buffer, m_index); 138 } 139 operator +(difference_type n) const140 device_ptr<T> operator+(difference_type n) const 141 { 142 return device_ptr<T>(m_buffer, m_index + n); 143 } 144 operator +(const device_ptr<T> & other) const145 device_ptr<T> operator+(const device_ptr<T> &other) const 146 { 147 return device_ptr<T>(m_buffer, m_index + other.m_index); 148 } 149 operator +=(difference_type n)150 device_ptr<T>& operator+=(difference_type n) 151 { 152 m_index += static_cast<size_t>(n); 153 return *this; 154 } 155 operator -(const device_ptr<T> & other) const156 difference_type operator-(const device_ptr<T> &other) const 157 { 158 return static_cast<difference_type>(m_index - other.m_index); 159 } 160 operator -=(difference_type n)161 device_ptr<T>& operator-=(difference_type n) 162 { 163 m_index -= n; 164 return *this; 165 } 166 operator ==(const device_ptr<T> & other) const167 bool operator==(const device_ptr<T> &other) const 168 { 169 return m_buffer.get() == other.m_buffer.get() && 170 m_index == other.m_index; 171 } 172 operator !=(const device_ptr<T> & other) const173 bool operator!=(const device_ptr<T> &other) const 174 { 175 return !(*this == other); 176 } 177 178 template<class Expr> 179 detail::device_ptr_index_expr<T, Expr> operator [](const Expr & expr) const180 operator[](const Expr &expr) const 181 { 182 BOOST_ASSERT(m_buffer.get()); 183 184 return detail::device_ptr_index_expr<T, Expr>(m_buffer, 185 uint_(m_index), 186 expr); 187 } 188 189 private: 190 const buffer m_buffer; 191 size_t m_index; 192 }; 193 194 // is_buffer_iterator specialization for device_ptr 195 template<class Iterator> 196 struct is_buffer_iterator< 197 Iterator, 198 typename boost::enable_if< 199 boost::is_same< 200 device_ptr<typename Iterator::value_type>, 201 typename boost::remove_const<Iterator>::type 202 > 203 >::type 204 > : public boost::true_type {}; 205 206 } // end detail namespace 207 208 // is_device_iterator specialization for device_ptr 209 template<class T> 210 struct is_device_iterator<detail::device_ptr<T> > : boost::true_type {}; 211 212 } // end compute namespace 213 } // end boost namespace 214 215 #endif // BOOST_COMPUTE_DEVICE_PTR_HPP 216