1 #pragma once
2
3 #include <utility>
4
5 #include <absl/types/optional.h>
6
7 #include "chainerx/array.h"
8 #include "chainerx/dtype.h"
9 #include "chainerx/macro.h"
10 #include "chainerx/scalar.h"
11
12 namespace chainerx {
13 namespace internal {
14
15 // Returns the default dtype.
GetDefaultDtype(DtypeKind kind)16 inline Dtype GetDefaultDtype(DtypeKind kind) {
17 switch (kind) {
18 case DtypeKind::kBool:
19 return Dtype::kBool;
20 case DtypeKind::kInt:
21 return Dtype::kInt32;
22 case DtypeKind::kFloat:
23 return Dtype::kFloat32;
24 default:
25 CHAINERX_NEVER_REACH();
26 }
27 }
28
GetMathResultDtype(Dtype dtype)29 inline Dtype GetMathResultDtype(Dtype dtype) {
30 if (GetKind(dtype) == DtypeKind::kFloat) {
31 return dtype;
32 }
33 return Dtype::kFloat32; // TODO(niboshi): Default dtype
34 }
35
36 } // namespace internal
37
38 namespace type_util_detail {
39
40 class ResultTypeResolver {
41 public:
42 template <typename Arg, typename... Args>
ResolveArgs(Arg arg,Args...args)43 Dtype ResolveArgs(Arg arg, Args... args) {
44 // At least single argument is required.
45 AddArgsImpl(std::forward<Arg>(arg));
46 AddArgsImpl(std::forward<Args>(args)...);
47 return Resolve();
48 }
49
50 Dtype Resolve() const;
51
52 void AddArg(const Array& arg);
53
54 void AddArg(Scalar arg);
55
56 private:
57 absl::optional<Dtype> array_max_dtype_;
58 absl::optional<Dtype> scalar_max_dtype_;
59
AddArgsImpl()60 void AddArgsImpl() {
61 // nop
62 }
63
64 template <typename Arg, typename... Args>
AddArgsImpl(Arg arg,Args...args)65 void AddArgsImpl(Arg arg, Args... args) {
66 AddArg(std::forward<Arg>(arg));
67 AddArgsImpl(std::forward<Args>(args)...);
68 }
69
GetDtypeCategory(Dtype dtype)70 static int GetDtypeCategory(Dtype dtype) {
71 switch (GetKind(dtype)) {
72 case DtypeKind::kFloat:
73 return 2;
74 default:
75 return 1;
76 }
77 }
78 };
79
80 } // namespace type_util_detail
81
ResultType(const Array & arg)82 inline Dtype ResultType(const Array& arg) { return arg.dtype(); }
83
ResultType(Scalar arg)84 inline Dtype ResultType(Scalar arg) { return internal::GetDefaultDtype(arg.kind()); }
85
86 template <typename Arg, typename... Args>
ResultType(Arg arg,Args...args)87 Dtype ResultType(Arg arg, Args... args) {
88 return type_util_detail::ResultTypeResolver{}.ResolveArgs(std::forward<Arg>(arg), std::forward<Args>(args)...);
89 }
90
91 template <typename Container>
ResultType(Container args)92 Dtype ResultType(Container args) {
93 type_util_detail::ResultTypeResolver resolver{};
94 if (args.empty()) {
95 throw ChainerxError{"At least one argument is required."};
96 }
97 for (const Array& arg : args) {
98 resolver.AddArg(arg);
99 }
100 return resolver.Resolve();
101 }
102
103 } // namespace chainerx
104