1 #include "chainerx/routines/type_util.h" 2 3 #include "chainerx/array.h" 4 #include "chainerx/dtype.h" 5 #include "chainerx/macro.h" 6 #include "chainerx/scalar.h" 7 8 namespace chainerx { 9 namespace type_util_detail { 10 Resolve() const11Dtype ResultTypeResolver::Resolve() const { 12 // If there were arrays, return the promoted array dtype. 13 // Otherwise, return the promoted scalar dtype. 14 if (array_max_dtype_.has_value()) { 15 Dtype array_max_dtype = *array_max_dtype_; 16 if (scalar_max_dtype_.has_value()) { 17 Dtype scalar_max_dtype = *scalar_max_dtype_; 18 if (GetDtypeCategory(scalar_max_dtype) > GetDtypeCategory(array_max_dtype)) { 19 return scalar_max_dtype; 20 } 21 } 22 return array_max_dtype; 23 } 24 CHAINERX_ASSERT(scalar_max_dtype_.has_value()); 25 return *scalar_max_dtype_; 26 } 27 AddArg(const Array & arg)28void ResultTypeResolver::AddArg(const Array& arg) { 29 // If there already were arrays, compare with the promoted array dtype. 30 // Othewise, keep the new dtype and forget scalars. 31 if (array_max_dtype_.has_value()) { 32 array_max_dtype_ = PromoteTypes(*array_max_dtype_, arg.dtype()); 33 } else { 34 array_max_dtype_ = arg.dtype(); 35 } 36 } 37 AddArg(Scalar arg)38void ResultTypeResolver::AddArg(Scalar arg) { 39 if (scalar_max_dtype_.has_value()) { 40 scalar_max_dtype_ = PromoteTypes(*scalar_max_dtype_, internal::GetDefaultDtype(arg.kind())); 41 } else { 42 scalar_max_dtype_ = internal::GetDefaultDtype(arg.kind()); 43 } 44 } 45 46 } // namespace type_util_detail 47 } // namespace chainerx 48