1 //===- ComparisonCategories.cpp - Three Way Comparison Data -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file defines the Comparison Category enum and data types, which
10 //  store the types and expressions needed to support operator<=>
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "clang/AST/ComparisonCategories.h"
15 #include "clang/AST/ASTContext.h"
16 #include "clang/AST/Decl.h"
17 #include "clang/AST/DeclCXX.h"
18 #include "clang/AST/Type.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include <optional>
21 
22 using namespace clang;
23 
24 std::optional<ComparisonCategoryType>
25 clang::getComparisonCategoryForBuiltinCmp(QualType T) {
26   using CCT = ComparisonCategoryType;
27 
28   if (T->isIntegralOrEnumerationType())
29     return CCT::StrongOrdering;
30 
31   if (T->isRealFloatingType())
32     return CCT::PartialOrdering;
33 
34   // C++2a [expr.spaceship]p8: If the composite pointer type is an object
35   // pointer type, p <=> q is of type std::strong_ordering.
36   // Note: this assumes neither operand is a null pointer constant.
37   if (T->isObjectPointerType())
38     return CCT::StrongOrdering;
39 
40   // TODO: Extend support for operator<=> to ObjC types.
41   return std::nullopt;
42 }
43 
44 bool ComparisonCategoryInfo::ValueInfo::hasValidIntValue() const {
45   assert(VD && "must have var decl");
46   if (!VD->isUsableInConstantExpressions(VD->getASTContext()))
47     return false;
48 
49   // Before we attempt to get the value of the first field, ensure that we
50   // actually have one (and only one) field.
51   auto *Record = VD->getType()->getAsCXXRecordDecl();
52   if (std::distance(Record->field_begin(), Record->field_end()) != 1 ||
53       !Record->field_begin()->getType()->isIntegralOrEnumerationType())
54     return false;
55 
56   return true;
57 }
58 
59 /// Attempt to determine the integer value used to represent the comparison
60 /// category result by evaluating the initializer for the specified VarDecl as
61 /// a constant expression and retrieving the value of the class's first
62 /// (and only) field.
63 ///
64 /// Note: The STL types are expected to have the form:
65 ///    struct X { T value; };
66 /// where T is an integral or enumeration type.
67 llvm::APSInt ComparisonCategoryInfo::ValueInfo::getIntValue() const {
68   assert(hasValidIntValue() && "must have a valid value");
69   return VD->evaluateValue()->getStructField(0).getInt();
70 }
71 
72 ComparisonCategoryInfo::ValueInfo *ComparisonCategoryInfo::lookupValueInfo(
73     ComparisonCategoryResult ValueKind) const {
74   // Check if we already have a cache entry for this value.
75   auto It = llvm::find_if(
76       Objects, [&](ValueInfo const &Info) { return Info.Kind == ValueKind; });
77   if (It != Objects.end())
78     return &(*It);
79 
80   // We don't have a cached result. Lookup the variable declaration and create
81   // a new entry representing it.
82   DeclContextLookupResult Lookup = Record->getCanonicalDecl()->lookup(
83       &Ctx.Idents.get(ComparisonCategories::getResultString(ValueKind)));
84   if (Lookup.empty() || !isa<VarDecl>(Lookup.front()))
85     return nullptr;
86   Objects.emplace_back(ValueKind, cast<VarDecl>(Lookup.front()));
87   return &Objects.back();
88 }
89 
90 static const NamespaceDecl *lookupStdNamespace(const ASTContext &Ctx,
91                                                NamespaceDecl *&StdNS) {
92   if (!StdNS) {
93     DeclContextLookupResult Lookup =
94         Ctx.getTranslationUnitDecl()->lookup(&Ctx.Idents.get("std"));
95     if (!Lookup.empty())
96       StdNS = dyn_cast<NamespaceDecl>(Lookup.front());
97   }
98   return StdNS;
99 }
100 
101 static CXXRecordDecl *lookupCXXRecordDecl(const ASTContext &Ctx,
102                                           const NamespaceDecl *StdNS,
103                                           ComparisonCategoryType Kind) {
104   StringRef Name = ComparisonCategories::getCategoryString(Kind);
105   DeclContextLookupResult Lookup = StdNS->lookup(&Ctx.Idents.get(Name));
106   if (!Lookup.empty())
107     if (CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(Lookup.front()))
108       return RD;
109   return nullptr;
110 }
111 
112 const ComparisonCategoryInfo *
113 ComparisonCategories::lookupInfo(ComparisonCategoryType Kind) const {
114   auto It = Data.find(static_cast<char>(Kind));
115   if (It != Data.end())
116     return &It->second;
117 
118   if (const NamespaceDecl *NS = lookupStdNamespace(Ctx, StdNS))
119     if (CXXRecordDecl *RD = lookupCXXRecordDecl(Ctx, NS, Kind))
120       return &Data.try_emplace((char)Kind, Ctx, RD, Kind).first->second;
121 
122   return nullptr;
123 }
124 
125 const ComparisonCategoryInfo *
126 ComparisonCategories::lookupInfoForType(QualType Ty) const {
127   assert(!Ty.isNull() && "type must be non-null");
128   using CCT = ComparisonCategoryType;
129   auto *RD = Ty->getAsCXXRecordDecl();
130   if (!RD)
131     return nullptr;
132 
133   // Check to see if we have information for the specified type cached.
134   const auto *CanonRD = RD->getCanonicalDecl();
135   for (auto &KV : Data) {
136     const ComparisonCategoryInfo &Info = KV.second;
137     if (CanonRD == Info.Record->getCanonicalDecl())
138       return &Info;
139   }
140 
141   if (!RD->getEnclosingNamespaceContext()->isStdNamespace())
142     return nullptr;
143 
144   // If not, check to see if the decl names a type in namespace std with a name
145   // matching one of the comparison category types.
146   for (unsigned I = static_cast<unsigned>(CCT::First),
147                 End = static_cast<unsigned>(CCT::Last);
148        I <= End; ++I) {
149     CCT Kind = static_cast<CCT>(I);
150 
151     // We've found the comparison category type. Build a new cache entry for
152     // it.
153     if (getCategoryString(Kind) == RD->getName())
154       return &Data.try_emplace((char)Kind, Ctx, RD, Kind).first->second;
155   }
156 
157   // We've found nothing. This isn't a comparison category type.
158   return nullptr;
159 }
160 
161 const ComparisonCategoryInfo &ComparisonCategories::getInfoForType(QualType Ty) const {
162   const ComparisonCategoryInfo *Info = lookupInfoForType(Ty);
163   assert(Info && "info for comparison category not found");
164   return *Info;
165 }
166 
167 QualType ComparisonCategoryInfo::getType() const {
168   assert(Record);
169   return QualType(Record->getTypeForDecl(), 0);
170 }
171 
172 StringRef ComparisonCategories::getCategoryString(ComparisonCategoryType Kind) {
173   using CCKT = ComparisonCategoryType;
174   switch (Kind) {
175   case CCKT::PartialOrdering:
176     return "partial_ordering";
177   case CCKT::WeakOrdering:
178     return "weak_ordering";
179   case CCKT::StrongOrdering:
180     return "strong_ordering";
181   }
182   llvm_unreachable("unhandled cases in switch");
183 }
184 
185 StringRef ComparisonCategories::getResultString(ComparisonCategoryResult Kind) {
186   using CCVT = ComparisonCategoryResult;
187   switch (Kind) {
188   case CCVT::Equal:
189     return "equal";
190   case CCVT::Equivalent:
191     return "equivalent";
192   case CCVT::Less:
193     return "less";
194   case CCVT::Greater:
195     return "greater";
196   case CCVT::Unordered:
197     return "unordered";
198   }
199   llvm_unreachable("unhandled case in switch");
200 }
201 
202 std::vector<ComparisonCategoryResult>
203 ComparisonCategories::getPossibleResultsForType(ComparisonCategoryType Type) {
204   using CCT = ComparisonCategoryType;
205   using CCR = ComparisonCategoryResult;
206   std::vector<CCR> Values;
207   Values.reserve(4);
208   bool IsStrong = Type == CCT::StrongOrdering;
209   Values.push_back(IsStrong ? CCR::Equal : CCR::Equivalent);
210   Values.push_back(CCR::Less);
211   Values.push_back(CCR::Greater);
212   if (Type == CCT::PartialOrdering)
213     Values.push_back(CCR::Unordered);
214   return Values;
215 }
216