1 #pragma once
2 
3 #include <utility>
4 
5 #include <absl/types/optional.h>
6 
7 #include "chainerx/array.h"
8 #include "chainerx/dtype.h"
9 #include "chainerx/macro.h"
10 #include "chainerx/scalar.h"
11 
12 namespace chainerx {
13 namespace internal {
14 
15 // Returns the default dtype.
GetDefaultDtype(DtypeKind kind)16 inline Dtype GetDefaultDtype(DtypeKind kind) {
17     switch (kind) {
18         case DtypeKind::kBool:
19             return Dtype::kBool;
20         case DtypeKind::kInt:
21             return Dtype::kInt32;
22         case DtypeKind::kFloat:
23             return Dtype::kFloat32;
24         default:
25             CHAINERX_NEVER_REACH();
26     }
27 }
28 
GetMathResultDtype(Dtype dtype)29 inline Dtype GetMathResultDtype(Dtype dtype) {
30     if (GetKind(dtype) == DtypeKind::kFloat) {
31         return dtype;
32     }
33     return Dtype::kFloat32;  // TODO(niboshi): Default dtype
34 }
35 
36 }  // namespace internal
37 
38 namespace type_util_detail {
39 
40 class ResultTypeResolver {
41 public:
42     template <typename Arg, typename... Args>
ResolveArgs(Arg arg,Args...args)43     Dtype ResolveArgs(Arg arg, Args... args) {
44         // At least single argument is required.
45         AddArgsImpl(std::forward<Arg>(arg));
46         AddArgsImpl(std::forward<Args>(args)...);
47         return Resolve();
48     }
49 
50     Dtype Resolve() const;
51 
52     void AddArg(const Array& arg);
53 
54     void AddArg(Scalar arg);
55 
56 private:
57     absl::optional<Dtype> array_max_dtype_;
58     absl::optional<Dtype> scalar_max_dtype_;
59 
AddArgsImpl()60     void AddArgsImpl() {
61         // nop
62     }
63 
64     template <typename Arg, typename... Args>
AddArgsImpl(Arg arg,Args...args)65     void AddArgsImpl(Arg arg, Args... args) {
66         AddArg(std::forward<Arg>(arg));
67         AddArgsImpl(std::forward<Args>(args)...);
68     }
69 
GetDtypeCategory(Dtype dtype)70     static int GetDtypeCategory(Dtype dtype) {
71         switch (GetKind(dtype)) {
72             case DtypeKind::kFloat:
73                 return 2;
74             default:
75                 return 1;
76         }
77     }
78 };
79 
80 }  // namespace type_util_detail
81 
ResultType(const Array & arg)82 inline Dtype ResultType(const Array& arg) { return arg.dtype(); }
83 
ResultType(Scalar arg)84 inline Dtype ResultType(Scalar arg) { return internal::GetDefaultDtype(arg.kind()); }
85 
86 template <typename Arg, typename... Args>
ResultType(Arg arg,Args...args)87 Dtype ResultType(Arg arg, Args... args) {
88     return type_util_detail::ResultTypeResolver{}.ResolveArgs(std::forward<Arg>(arg), std::forward<Args>(args)...);
89 }
90 
91 template <typename Container>
ResultType(Container args)92 Dtype ResultType(Container args) {
93     type_util_detail::ResultTypeResolver resolver{};
94     if (args.empty()) {
95         throw ChainerxError{"At least one argument is required."};
96     }
97     for (const Array& arg : args) {
98         resolver.AddArg(arg);
99     }
100     return resolver.Resolve();
101 }
102 
103 }  // namespace chainerx
104