//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/StandardTypes.h" #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Twine.h" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// bool Type::isBF16() { return getKind() == StandardTypes::BF16; } bool Type::isF16() { return getKind() == StandardTypes::F16; } bool Type::isF32() { return getKind() == StandardTypes::F32; } bool Type::isF64() { return getKind() == StandardTypes::F64; } bool Type::isIndex() { return isa(); } /// Return true if this is an integer type with the specified width. bool Type::isInteger(unsigned width) { if (auto intTy = dyn_cast()) return intTy.getWidth() == width; return false; } bool Type::isSignlessInteger() { if (auto intTy = dyn_cast()) return intTy.isSignless(); return false; } bool Type::isSignlessInteger(unsigned width) { if (auto intTy = dyn_cast()) return intTy.isSignless() && intTy.getWidth() == width; return false; } bool Type::isSignedInteger() { if (auto intTy = dyn_cast()) return intTy.isSigned(); return false; } bool Type::isSignedInteger(unsigned width) { if (auto intTy = dyn_cast()) return intTy.isSigned() && intTy.getWidth() == width; return false; } bool Type::isUnsignedInteger() { if (auto intTy = dyn_cast()) return intTy.isUnsigned(); return false; } bool Type::isUnsignedInteger(unsigned width) { if (auto intTy = dyn_cast()) return intTy.isUnsigned() && intTy.getWidth() == width; return false; } bool Type::isSignlessIntOrIndex() { return isSignlessInteger() || isa(); } bool Type::isSignlessIntOrIndexOrFloat() { return isSignlessInteger() || isa(); } bool Type::isSignlessIntOrFloat() { return isSignlessInteger() || isa(); } bool Type::isIntOrIndex() { return isa() || isIndex(); } bool Type::isIntOrFloat() { return isa(); } bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); } //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// ComplexType ComplexType::get(Type elementType) { return Base::get(elementType.getContext(), StandardTypes::Complex, elementType); } ComplexType ComplexType::getChecked(Type elementType, Location location) { return Base::getChecked(location, StandardTypes::Complex, elementType); } /// Verify the construction of an integer type. LogicalResult ComplexType::verifyConstructionInvariants(Location loc, Type elementType) { if (!elementType.isIntOrFloat()) return emitError(loc, "invalid element type for complex"); return success(); } Type ComplexType::getElementType() { return getImpl()->elementType; } //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// // static constexpr must have a definition (until in C++17 and inline variable). constexpr unsigned IntegerType::kMaxWidth; /// Verify the construction of an integer type. LogicalResult IntegerType::verifyConstructionInvariants(Location loc, unsigned width, SignednessSemantics signedness) { if (width > IntegerType::kMaxWidth) { return emitError(loc) << "integer bitwidth is limited to " << IntegerType::kMaxWidth << " bits"; } return success(); } unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); } IntegerType::SignednessSemantics IntegerType::getSignedness() const { return getImpl()->getSignedness(); } //===----------------------------------------------------------------------===// // Float Type //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { switch (getKind()) { case StandardTypes::BF16: case StandardTypes::F16: return 16; case StandardTypes::F32: return 32; case StandardTypes::F64: return 64; default: llvm_unreachable("unexpected type"); } } /// Returns the floating semantics for the given type. const llvm::fltSemantics &FloatType::getFloatSemantics() { if (isBF16()) return APFloat::BFloat(); if (isF16()) return APFloat::IEEEhalf(); if (isF32()) return APFloat::IEEEsingle(); if (isF64()) return APFloat::IEEEdouble(); llvm_unreachable("non-floating point type used"); } unsigned Type::getIntOrFloatBitWidth() { assert(isIntOrFloat() && "only integers and floats have a bitwidth"); if (auto intType = dyn_cast()) return intType.getWidth(); return cast().getWidth(); } //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// constexpr int64_t ShapedType::kDynamicSize; constexpr int64_t ShapedType::kDynamicStrideOrOffset; Type ShapedType::getElementType() const { return static_cast(impl)->elementType; } unsigned ShapedType::getElementTypeBitWidth() const { return getElementType().getIntOrFloatBitWidth(); } int64_t ShapedType::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); auto shape = getShape(); int64_t num = 1; for (auto dim : shape) num *= dim; return num; } int64_t ShapedType::getRank() const { return getShape().size(); } bool ShapedType::hasRank() const { return !isa(); } int64_t ShapedType::getDimSize(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); return getShape()[idx]; } bool ShapedType::isDynamicDim(unsigned idx) const { assert(idx < getRank() && "invalid index for shaped type"); return isDynamic(getShape()[idx]); } unsigned ShapedType::getDynamicDimIndex(unsigned index) const { assert(index < getRank() && "invalid index"); assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); } /// Get the number of bits require to store a value of the given shaped type. /// Compute the value recursively since tensors are allowed to have vectors as /// elements. int64_t ShapedType::getSizeInBits() const { assert(hasStaticShape() && "cannot get the bit size of an aggregate with a dynamic shape"); auto elementType = getElementType(); if (elementType.isIntOrFloat()) return elementType.getIntOrFloatBitWidth() * getNumElements(); // Tensors can have vectors and other tensors as elements, other shaped types // cannot. assert(isa() && "unsupported element type"); assert((elementType.isa()) && "unsupported tensor element type"); return getNumElements() * elementType.cast().getSizeInBits(); } ArrayRef ShapedType::getShape() const { switch (getKind()) { case StandardTypes::Vector: return cast().getShape(); case StandardTypes::RankedTensor: return cast().getShape(); case StandardTypes::MemRef: return cast().getShape(); default: llvm_unreachable("not a ShapedType or not ranked"); } } int64_t ShapedType::getNumDynamicDims() const { return llvm::count_if(getShape(), isDynamic); } bool ShapedType::hasStaticShape() const { return hasRank() && llvm::none_of(getShape(), isDynamic); } bool ShapedType::hasStaticShape(ArrayRef shape) const { return hasStaticShape() && getShape() == shape; } //===----------------------------------------------------------------------===// // VectorType //===----------------------------------------------------------------------===// VectorType VectorType::get(ArrayRef shape, Type elementType) { return Base::get(elementType.getContext(), StandardTypes::Vector, shape, elementType); } VectorType VectorType::getChecked(ArrayRef shape, Type elementType, Location location) { return Base::getChecked(location, StandardTypes::Vector, shape, elementType); } LogicalResult VectorType::verifyConstructionInvariants(Location loc, ArrayRef shape, Type elementType) { if (shape.empty()) return emitError(loc, "vector types must have at least one dimension"); if (!isValidElementType(elementType)) return emitError(loc, "vector elements must be int or float type"); if (any_of(shape, [](int64_t i) { return i <= 0; })) return emitError(loc, "vector types must have positive constant sizes"); return success(); } ArrayRef VectorType::getShape() const { return getImpl()->getShape(); } //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// // Check if "elementType" can be an element type of a tensor. Emit errors if // location is not nullptr. Returns failure if check failed. static inline LogicalResult checkTensorElementType(Location location, Type elementType) { if (!TensorType::isValidElementType(elementType)) return emitError(location, "invalid tensor element type"); return success(); } //===----------------------------------------------------------------------===// // RankedTensorType //===----------------------------------------------------------------------===// RankedTensorType RankedTensorType::get(ArrayRef shape, Type elementType) { return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape, elementType); } RankedTensorType RankedTensorType::getChecked(ArrayRef shape, Type elementType, Location location) { return Base::getChecked(location, StandardTypes::RankedTensor, shape, elementType); } LogicalResult RankedTensorType::verifyConstructionInvariants( Location loc, ArrayRef shape, Type elementType) { for (int64_t s : shape) { if (s < -1) return emitError(loc, "invalid tensor dimension size"); } return checkTensorElementType(loc, elementType); } ArrayRef RankedTensorType::getShape() const { return getImpl()->getShape(); } //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// UnrankedTensorType UnrankedTensorType::get(Type elementType) { return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor, elementType); } UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, Location location) { return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType); } LogicalResult UnrankedTensorType::verifyConstructionInvariants(Location loc, Type elementType) { return checkTensorElementType(loc, elementType); } //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. Assumes the arguments define a /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType /// construction failures. MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, /*location=*/llvm::None); assert(result && "Failed to construct instance of MemRefType."); return result; } /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space declared at the given location. /// If the location is unknown, the last argument should be an instance of /// UnknownLoc. If the MemRefType defined by the arguments would be /// ill-formed, emits errors (to the handler registered with the context or to /// the error stream) and returns nullptr. MemRefType MemRefType::getChecked(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, Location location) { return getImpl(shape, elementType, affineMapComposition, memorySpace, location); } /// Get or create a new MemRefType defined by the arguments. If the resulting /// type would be ill-formed, return nullptr. If the location is provided, /// emit detailed error messages. To emit errors when the location is unknown, /// pass in an instance of UnknownLoc. MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, Optional location) { auto *context = elementType.getContext(); // Check that memref is formed from allowed types. if (!elementType.isIntOrFloat() && !elementType.isa()) return emitOptionalError(location, "invalid memref element type"), MemRefType(); for (int64_t s : shape) { // Negative sizes are not allowed except for `-1` that means dynamic size. if (s < -1) return emitOptionalError(location, "invalid memref size"), MemRefType(); } // Check that the structure of the composition is valid, i.e. that each // subsequent affine map has as many inputs as the previous map has results. // Take the dimensionality of the MemRef for the first map. auto dim = shape.size(); unsigned i = 0; for (const auto &affineMap : affineMapComposition) { if (affineMap.getNumDims() != dim) { if (location) emitError(*location) << "memref affine map dimension mismatch between " << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) << " and affine map" << i + 1 << ": " << dim << " != " << affineMap.getNumDims(); return nullptr; } dim = affineMap.getNumResults(); ++i; } // Drop identity maps from the composition. // This may lead to the composition becoming empty, which is interpreted as an // implicit identity. SmallVector cleanedAffineMapComposition; for (const auto &map : affineMapComposition) { if (map.isIdentity()) continue; cleanedAffineMapComposition.push_back(map); } return Base::get(context, StandardTypes::MemRef, shape, elementType, cleanedAffineMapComposition, memorySpace); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } ArrayRef MemRefType::getAffineMaps() const { return getImpl()->getAffineMaps(); } unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; } //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// UnrankedMemRefType UnrankedMemRefType::get(Type elementType, unsigned memorySpace) { return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef, elementType, memorySpace); } UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, unsigned memorySpace, Location location) { return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType, memorySpace); } unsigned UnrankedMemRefType::getMemorySpace() const { return getImpl()->memorySpace; } LogicalResult UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, unsigned memorySpace) { // Check that memref is formed from allowed types. if (!elementType.isIntOrFloat() && !elementType.isa()) return emitError(loc, "invalid memref element type"); return success(); } // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( // i.e. single term). Accumulate the AffineExpr into the existing one. static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef strides, AffineExpr &offset) { if (auto dim = e.dyn_cast()) strides[dim.getPosition()] = strides[dim.getPosition()] + multiplicativeFactor; else offset = offset + e * multiplicativeFactor; } /// Takes a single AffineExpr `e` and populates the `strides` array with the /// strides expressions for each dim position. /// The convention is that the strides for dimensions d0, .. dn appear in /// order to make indexing intuitive into the result. static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef strides, AffineExpr &offset) { auto bin = e.dyn_cast(); if (!bin) { extractStridesFromTerm(e, multiplicativeFactor, strides, offset); return success(); } if (bin.getKind() == AffineExprKind::CeilDiv || bin.getKind() == AffineExprKind::FloorDiv || bin.getKind() == AffineExprKind::Mod) return failure(); if (bin.getKind() == AffineExprKind::Mul) { auto dim = bin.getLHS().dyn_cast(); if (dim) { strides[dim.getPosition()] = strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; return success(); } // LHS and RHS may both contain complex expressions of dims. Try one path // and if it fails try the other. This is guaranteed to succeed because // only one path may have a `dim`, otherwise this is not an AffineExpr in // the first place. if (bin.getLHS().isSymbolicOrConstant()) return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), strides, offset); return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), strides, offset); } if (bin.getKind() == AffineExprKind::Add) { auto res1 = extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); auto res2 = extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); return success(succeeded(res1) && succeeded(res2)); } llvm_unreachable("unexpected binary operation"); } LogicalResult mlir::getStridesAndOffset(MemRefType t, SmallVectorImpl &strides, AffineExpr &offset) { auto affineMaps = t.getAffineMaps(); // For now strides are only computed on a single affine map with a single // result (i.e. the closed subset of linearization maps that are compatible // with striding semantics). // TODO: support more forms on a per-need basis. if (affineMaps.size() > 1) return failure(); if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) return failure(); auto zero = getAffineConstantExpr(0, t.getContext()); auto one = getAffineConstantExpr(1, t.getContext()); offset = zero; strides.assign(t.getRank(), zero); AffineMap m; if (!affineMaps.empty()) { m = affineMaps.front(); assert(!m.isIdentity() && "unexpected identity map"); } // Canonical case for empty map. if (!m) { // 0-D corner case, offset is already 0. if (t.getRank() == 0) return success(); auto stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); if (succeeded(extractStrides(stridedExpr, one, strides, offset))) return success(); assert(false && "unexpected failure: extract strides in canonical layout"); } // Non-canonical case requires more work. auto stridedExpr = simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); if (failed(extractStrides(stridedExpr, one, strides, offset))) { offset = AffineExpr(); strides.clear(); return failure(); } // Simplify results to allow folding to constants and simple checks. unsigned numDims = m.getNumDims(); unsigned numSymbols = m.getNumSymbols(); offset = simplifyAffineExpr(offset, numDims, numSymbols); for (auto &stride : strides) stride = simplifyAffineExpr(stride, numDims, numSymbols); /// In practice, a strided memref must be internally non-aliasing. Test /// against 0 as a proxy. /// TODO: static cases can have more advanced checks. /// TODO: dynamic cases would require a way to compare symbolic /// expressions and would probably need an affine set context propagated /// everywhere. if (llvm::any_of(strides, [](AffineExpr e) { return e == getAffineConstantExpr(0, e.getContext()); })) { offset = AffineExpr(); strides.clear(); return failure(); } return success(); } LogicalResult mlir::getStridesAndOffset(MemRefType t, SmallVectorImpl &strides, int64_t &offset) { AffineExpr offsetExpr; SmallVector strideExprs; if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) return failure(); if (auto cst = offsetExpr.dyn_cast()) offset = cst.getValue(); else offset = ShapedType::kDynamicStrideOrOffset; for (auto e : strideExprs) { if (auto c = e.dyn_cast()) strides.push_back(c.getValue()); else strides.push_back(ShapedType::kDynamicStrideOrOffset); } return success(); } //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. TupleType TupleType::get(ArrayRef elementTypes, MLIRContext *context) { return Base::get(context, StandardTypes::Tuple, elementTypes); } /// Return the elements types for this tuple. ArrayRef TupleType::getTypes() const { return getImpl()->getTypes(); } /// Accumulate the types contained in this tuple and tuples nested within it. /// Note that this only flattens nested tuples, not any other container type, /// e.g. a tuple, tuple>> is flattened to /// (i32, tensor, f32, i64) void TupleType::getFlattenedTypes(SmallVectorImpl &types) { for (Type type : getTypes()) { if (auto nestedTuple = type.dyn_cast()) nestedTuple.getFlattenedTypes(types); else types.push_back(type); } } /// Return the number of element types. size_t TupleType::size() const { return getImpl()->size(); } AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef strides, int64_t offset, MLIRContext *context) { AffineExpr expr; unsigned nSymbols = 0; // AffineExpr for offset. // Static case. if (offset != MemRefType::getDynamicStrideOrOffset()) { auto cst = getAffineConstantExpr(offset, context); expr = cst; } else { // Dynamic case, new symbol for the offset. auto sym = getAffineSymbolExpr(nSymbols++, context); expr = sym; } // AffineExpr for strides. for (auto en : llvm::enumerate(strides)) { auto dim = en.index(); auto stride = en.value(); assert(stride != 0 && "Invalid stride specification"); auto d = getAffineDimExpr(dim, context); AffineExpr mult; // Static case. if (stride != MemRefType::getDynamicStrideOrOffset()) mult = getAffineConstantExpr(stride, context); else // Dynamic case, new symbol for each new stride. mult = getAffineSymbolExpr(nSymbols++, context); expr = expr + d * mult; } return AffineMap::get(strides.size(), nSymbols, expr); } /// Return a version of `t` with identity layout if it can be determined /// statically that the layout is the canonical contiguous strided layout. /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of /// `t` with simplified layout. /// If `t` has multiple layout maps or a multi-result layout, just return `t`. MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { auto affineMaps = t.getAffineMaps(); // Already in canonical form. if (affineMaps.empty()) return t; // Can't reduce to canonical identity form, return in canonical form. if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) return t; // If the canonical strided layout for the sizes of `t` is equal to the // simplified layout of `t` we can just return an empty layout. Otherwise, // just simplify the existing layout. AffineExpr expr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); auto m = affineMaps[0]; auto simplifiedLayoutExpr = simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); if (expr != simplifiedLayoutExpr) return MemRefType::Builder(t).setAffineMaps({AffineMap::get( m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); return MemRefType::Builder(t).setAffineMaps({}); } AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, ArrayRef exprs, MLIRContext *context) { AffineExpr expr; bool dynamicPoisonBit = false; unsigned numDims = 0; unsigned nSymbols = 0; // Compute the number of symbols and dimensions of the passed exprs. for (AffineExpr expr : exprs) { expr.walk([&numDims, &nSymbols](AffineExpr d) { if (AffineDimExpr dim = d.dyn_cast()) numDims = std::max(numDims, dim.getPosition() + 1); else if (AffineSymbolExpr symbol = d.dyn_cast()) nSymbols = std::max(nSymbols, symbol.getPosition() + 1); }); } int64_t runningSize = 1; for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { int64_t size = std::get<1>(en); // Degenerate case, no size =-> no stride if (size == 0) continue; AffineExpr dimExpr = std::get<0>(en); AffineExpr stride = dynamicPoisonBit ? getAffineSymbolExpr(nSymbols++, context) : getAffineConstantExpr(runningSize, context); expr = expr ? expr + dimExpr * stride : dimExpr * stride; if (size > 0) runningSize *= size; else dynamicPoisonBit = true; } return simplifyAffineExpr(expr, numDims, nSymbols); } /// Return a version of `t` with a layout that has all dynamic offset and /// strides. This is used to erase the static layout. MemRefType mlir::eraseStridedLayout(MemRefType t) { auto val = ShapedType::kDynamicStrideOrOffset; return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( SmallVector(t.getRank(), val), val, t.getContext())); } AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, MLIRContext *context) { SmallVector exprs; exprs.reserve(sizes.size()); for (auto dim : llvm::seq(0, sizes.size())) exprs.push_back(getAffineDimExpr(dim, context)); return makeCanonicalStridedLayoutExpr(sizes, exprs, context); } /// Return true if the layout for `t` is compatible with strided semantics. bool mlir::isStrided(MemRefType t) { int64_t offset; SmallVector stridesAndOffset; auto res = getStridesAndOffset(t, stridesAndOffset, offset); return succeeded(res); }