1 #include "chainerx/routines/type_util.h"
2 
3 #include "chainerx/array.h"
4 #include "chainerx/dtype.h"
5 #include "chainerx/macro.h"
6 #include "chainerx/scalar.h"
7 
8 namespace chainerx {
9 namespace type_util_detail {
10 
Resolve() const11 Dtype ResultTypeResolver::Resolve() const {
12     // If there were arrays, return the promoted array dtype.
13     // Otherwise, return the promoted scalar dtype.
14     if (array_max_dtype_.has_value()) {
15         Dtype array_max_dtype = *array_max_dtype_;
16         if (scalar_max_dtype_.has_value()) {
17             Dtype scalar_max_dtype = *scalar_max_dtype_;
18             if (GetDtypeCategory(scalar_max_dtype) > GetDtypeCategory(array_max_dtype)) {
19                 return scalar_max_dtype;
20             }
21         }
22         return array_max_dtype;
23     }
24     CHAINERX_ASSERT(scalar_max_dtype_.has_value());
25     return *scalar_max_dtype_;
26 }
27 
AddArg(const Array & arg)28 void ResultTypeResolver::AddArg(const Array& arg) {
29     // If there already were arrays, compare with the promoted array dtype.
30     // Othewise, keep the new dtype and forget scalars.
31     if (array_max_dtype_.has_value()) {
32         array_max_dtype_ = PromoteTypes(*array_max_dtype_, arg.dtype());
33     } else {
34         array_max_dtype_ = arg.dtype();
35     }
36 }
37 
AddArg(Scalar arg)38 void ResultTypeResolver::AddArg(Scalar arg) {
39     if (scalar_max_dtype_.has_value()) {
40         scalar_max_dtype_ = PromoteTypes(*scalar_max_dtype_, internal::GetDefaultDtype(arg.kind()));
41     } else {
42         scalar_max_dtype_ = internal::GetDefaultDtype(arg.kind());
43     }
44 }
45 
46 }  // namespace type_util_detail
47 }  // namespace chainerx
48