forked from OSchip/llvm-project
Make helper functions static or move them into anonymous namespaces. NFC.
This commit is contained in:
parent
bff33bd5c8
commit
df186507e1
|
@ -3081,7 +3081,7 @@ Error ASTNodeImporter::ImportFunctionDeclBody(FunctionDecl *FromFD,
|
||||||
|
|
||||||
// Returns true if the given D has a DeclContext up to the TranslationUnitDecl
|
// Returns true if the given D has a DeclContext up to the TranslationUnitDecl
|
||||||
// which is equal to the given DC.
|
// which is equal to the given DC.
|
||||||
bool isAncestorDeclContextOf(const DeclContext *DC, const Decl *D) {
|
static bool isAncestorDeclContextOf(const DeclContext *DC, const Decl *D) {
|
||||||
const DeclContext *DCi = D->getDeclContext();
|
const DeclContext *DCi = D->getDeclContext();
|
||||||
while (DCi != D->getTranslationUnitDecl()) {
|
while (DCi != D->getTranslationUnitDecl()) {
|
||||||
if (DCi == DC)
|
if (DCi == DC)
|
||||||
|
|
|
@ -14260,6 +14260,7 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
struct BuiltinAlignArgs {
|
struct BuiltinAlignArgs {
|
||||||
llvm::Value *Src = nullptr;
|
llvm::Value *Src = nullptr;
|
||||||
llvm::Type *SrcType = nullptr;
|
llvm::Type *SrcType = nullptr;
|
||||||
|
@ -14288,6 +14289,7 @@ struct BuiltinAlignArgs {
|
||||||
Mask = CGF.Builder.CreateSub(Alignment, One, "mask");
|
Mask = CGF.Builder.CreateSub(Alignment, One, "mask");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// Generate (x & (y-1)) == 0.
|
/// Generate (x & (y-1)) == 0.
|
||||||
RValue CodeGenFunction::EmitBuiltinIsAligned(const CallExpr *E) {
|
RValue CodeGenFunction::EmitBuiltinIsAligned(const CallExpr *E) {
|
||||||
|
|
|
@ -1438,6 +1438,7 @@ CGOpenMPRuntime::getUserDefinedReduction(const OMPDeclareReductionDecl *D) {
|
||||||
return UDRMap.lookup(D);
|
return UDRMap.lookup(D);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
// Temporary RAII solution to perform a push/pop stack event on the OpenMP IR
|
// Temporary RAII solution to perform a push/pop stack event on the OpenMP IR
|
||||||
// Builder if one is present.
|
// Builder if one is present.
|
||||||
struct PushAndPopStackRAII {
|
struct PushAndPopStackRAII {
|
||||||
|
@ -1481,6 +1482,7 @@ struct PushAndPopStackRAII {
|
||||||
}
|
}
|
||||||
llvm::OpenMPIRBuilder *OMPBuilder;
|
llvm::OpenMPIRBuilder *OMPBuilder;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
static llvm::Function *emitParallelOrTeamsOutlinedFunction(
|
static llvm::Function *emitParallelOrTeamsOutlinedFunction(
|
||||||
CodeGenModule &CGM, const OMPExecutableDirective &D, const CapturedStmt *CS,
|
CodeGenModule &CGM, const OMPExecutableDirective &D, const CapturedStmt *CS,
|
||||||
|
@ -11122,8 +11124,8 @@ bool checkContext<OMP_CTX_SET_device, OMP_CTX_kind, CodeGenModule &>(
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool matchesContext(CodeGenModule &CGM,
|
static bool matchesContext(CodeGenModule &CGM,
|
||||||
const CompleteOMPContextSelectorData &ContextData) {
|
const CompleteOMPContextSelectorData &ContextData) {
|
||||||
for (const OMPContextSelectorData &Data : ContextData) {
|
for (const OMPContextSelectorData &Data : ContextData) {
|
||||||
switch (Data.Ctx) {
|
switch (Data.Ctx) {
|
||||||
case OMP_CTX_vendor:
|
case OMP_CTX_vendor:
|
||||||
|
|
|
@ -63,9 +63,9 @@ static void getARMHWDivFeatures(const Driver &D, const Arg *A,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle -mfpu=.
|
// Handle -mfpu=.
|
||||||
unsigned getARMFPUFeatures(const Driver &D, const Arg *A,
|
static unsigned getARMFPUFeatures(const Driver &D, const Arg *A,
|
||||||
const ArgList &Args, StringRef FPU,
|
const ArgList &Args, StringRef FPU,
|
||||||
std::vector<StringRef> &Features) {
|
std::vector<StringRef> &Features) {
|
||||||
unsigned FPUID = llvm::ARM::parseFPU(FPU);
|
unsigned FPUID = llvm::ARM::parseFPU(FPU);
|
||||||
if (!llvm::ARM::getFPUFeatures(FPUID, Features))
|
if (!llvm::ARM::getFPUFeatures(FPUID, Features))
|
||||||
D.Diag(clang::diag::err_drv_clang_unsupported) << A->getAsString(Args);
|
D.Diag(clang::diag::err_drv_clang_unsupported) << A->getAsString(Args);
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
using namespace clang;
|
using namespace clang;
|
||||||
using namespace ento;
|
using namespace ento;
|
||||||
|
|
||||||
|
namespace {
|
||||||
class PlacementNewChecker : public Checker<check::PreStmt<CXXNewExpr>> {
|
class PlacementNewChecker : public Checker<check::PreStmt<CXXNewExpr>> {
|
||||||
public:
|
public:
|
||||||
void checkPreStmt(const CXXNewExpr *NE, CheckerContext &C) const;
|
void checkPreStmt(const CXXNewExpr *NE, CheckerContext &C) const;
|
||||||
|
@ -22,6 +23,7 @@ private:
|
||||||
BugType BT{this, "Insufficient storage for placement new",
|
BugType BT{this, "Insufficient storage for placement new",
|
||||||
categories::MemoryError};
|
categories::MemoryError};
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
SVal PlacementNewChecker::getExtentSizeOfPlace(const Expr *Place,
|
SVal PlacementNewChecker::getExtentSizeOfPlace(const Expr *Place,
|
||||||
ProgramStateRef State,
|
ProgramStateRef State,
|
||||||
|
|
|
@ -223,7 +223,8 @@ static const ExplodedNode *getAcquireSite(const ExplodedNode *N, SymbolRef Sym,
|
||||||
|
|
||||||
/// Returns the symbols extracted from the argument or null if it cannot be
|
/// Returns the symbols extracted from the argument or null if it cannot be
|
||||||
/// found.
|
/// found.
|
||||||
SymbolRef getFuchsiaHandleSymbol(QualType QT, SVal Arg, ProgramStateRef State) {
|
static SymbolRef getFuchsiaHandleSymbol(QualType QT, SVal Arg,
|
||||||
|
ProgramStateRef State) {
|
||||||
int PtrToHandleLevel = 0;
|
int PtrToHandleLevel = 0;
|
||||||
while (QT->isAnyPointerType() || QT->isReferenceType()) {
|
while (QT->isAnyPointerType() || QT->isReferenceType()) {
|
||||||
++PtrToHandleLevel;
|
++PtrToHandleLevel;
|
||||||
|
|
|
@ -1734,8 +1734,8 @@ static bool getHexUint(const MIToken &Token, APInt &Result) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getUnsigned(const MIToken &Token, unsigned &Result,
|
static bool getUnsigned(const MIToken &Token, unsigned &Result,
|
||||||
ErrorCallbackType ErrCB) {
|
ErrorCallbackType ErrCB) {
|
||||||
if (Token.hasIntegerValue()) {
|
if (Token.hasIntegerValue()) {
|
||||||
const uint64_t Limit = uint64_t(std::numeric_limits<unsigned>::max()) + 1;
|
const uint64_t Limit = uint64_t(std::numeric_limits<unsigned>::max()) + 1;
|
||||||
uint64_t Val64 = Token.integerValue().getLimitedValue(Limit);
|
uint64_t Val64 = Token.integerValue().getLimitedValue(Limit);
|
||||||
|
|
|
@ -837,7 +837,7 @@ static bool isTargetDarwin(const MachineFunction &MF) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convenience function to determine whether I is an SVE callee save.
|
// Convenience function to determine whether I is an SVE callee save.
|
||||||
bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
|
static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
|
||||||
switch (I->getOpcode()) {
|
switch (I->getOpcode()) {
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -341,9 +341,9 @@ static bool MSA3OpIntrinsicToGeneric(MachineInstr &MI, unsigned Opcode,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSA2OpIntrinsicToGeneric(MachineInstr &MI, unsigned Opcode,
|
static bool MSA2OpIntrinsicToGeneric(MachineInstr &MI, unsigned Opcode,
|
||||||
MachineIRBuilder &MIRBuilder,
|
MachineIRBuilder &MIRBuilder,
|
||||||
const MipsSubtarget &ST) {
|
const MipsSubtarget &ST) {
|
||||||
assert(ST.hasMSA() && "MSA intrinsic not supported on target without MSA.");
|
assert(ST.hasMSA() && "MSA intrinsic not supported on target without MSA.");
|
||||||
MIRBuilder.buildInstr(Opcode)
|
MIRBuilder.buildInstr(Opcode)
|
||||||
.add(MI.getOperand(0))
|
.add(MI.getOperand(0))
|
||||||
|
|
|
@ -1649,7 +1649,7 @@ Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) {
|
||||||
return CastInst::Create(CastOpc, NarrowBO, BO.getType());
|
return CastInst::Create(CastOpc, NarrowBO, BO.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) {
|
static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) {
|
||||||
// At least one GEP must be inbounds.
|
// At least one GEP must be inbounds.
|
||||||
if (!GEP1.isInBounds() && !GEP2.isInBounds())
|
if (!GEP1.isInBounds() && !GEP2.isInBounds())
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -117,6 +117,7 @@ static LogicalResult getInstIndexSet(Operation *op,
|
||||||
return getIndexSet(loops, indexSet);
|
return getIndexSet(loops, indexSet);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
// ValuePositionMap manages the mapping from Values which represent dimension
|
// ValuePositionMap manages the mapping from Values which represent dimension
|
||||||
// and symbol identifiers from 'src' and 'dst' access functions to positions
|
// and symbol identifiers from 'src' and 'dst' access functions to positions
|
||||||
// in new space where some Values are kept separate (using addSrc/DstValue)
|
// in new space where some Values are kept separate (using addSrc/DstValue)
|
||||||
|
@ -195,6 +196,7 @@ private:
|
||||||
DenseMap<Value, unsigned> dstDimPosMap;
|
DenseMap<Value, unsigned> dstDimPosMap;
|
||||||
DenseMap<Value, unsigned> symbolPosMap;
|
DenseMap<Value, unsigned> symbolPosMap;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Builds a map from Value to identifier position in a new merged identifier
|
// Builds a map from Value to identifier position in a new merged identifier
|
||||||
// list, which is the result of merging dim/symbol lists from src/dst
|
// list, which is the result of merging dim/symbol lists from src/dst
|
||||||
|
@ -240,12 +242,11 @@ static void buildDimAndSymbolPositionMaps(
|
||||||
|
|
||||||
// Sets up dependence constraints columns appropriately, in the format:
|
// Sets up dependence constraints columns appropriately, in the format:
|
||||||
// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
|
// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
|
||||||
void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
|
static void initDependenceConstraints(
|
||||||
const FlatAffineConstraints &dstDomain,
|
const FlatAffineConstraints &srcDomain,
|
||||||
const AffineValueMap &srcAccessMap,
|
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
|
||||||
const AffineValueMap &dstAccessMap,
|
const AffineValueMap &dstAccessMap, const ValuePositionMap &valuePosMap,
|
||||||
const ValuePositionMap &valuePosMap,
|
FlatAffineConstraints *dependenceConstraints) {
|
||||||
FlatAffineConstraints *dependenceConstraints) {
|
|
||||||
// Calculate number of equalities/inequalities and columns required to
|
// Calculate number of equalities/inequalities and columns required to
|
||||||
// initialize FlatAffineConstraints for 'dependenceDomain'.
|
// initialize FlatAffineConstraints for 'dependenceDomain'.
|
||||||
unsigned numIneq =
|
unsigned numIneq =
|
||||||
|
|
|
@ -1388,8 +1388,9 @@ static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
|
||||||
// Check if the pos^th identifier can be expressed as a floordiv of an affine
|
// Check if the pos^th identifier can be expressed as a floordiv of an affine
|
||||||
// function of other identifiers (where the divisor is a positive constant).
|
// function of other identifiers (where the divisor is a positive constant).
|
||||||
// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
|
// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
|
||||||
bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
|
static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
|
||||||
SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
|
SmallVectorImpl<AffineExpr> *memo,
|
||||||
|
MLIRContext *context) {
|
||||||
assert(pos < cst.getNumIds() && "invalid position");
|
assert(pos < cst.getNumIds() && "invalid position");
|
||||||
|
|
||||||
SmallVector<unsigned, 4> lbIndices, ubIndices;
|
SmallVector<unsigned, 4> lbIndices, ubIndices;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
/// Builds and holds block information during the construction phase.
|
/// Builds and holds block information during the construction phase.
|
||||||
struct BlockInfoBuilder {
|
struct BlockInfoBuilder {
|
||||||
using ValueSetT = Liveness::ValueSetT;
|
using ValueSetT = Liveness::ValueSetT;
|
||||||
|
@ -107,6 +108,7 @@ struct BlockInfoBuilder {
|
||||||
/// The set of all used values.
|
/// The set of all used values.
|
||||||
ValueSetT useValues;
|
ValueSetT useValues;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// Builds the internal liveness block mapping.
|
/// Builds the internal liveness block mapping.
|
||||||
static void buildBlockMapping(MutableArrayRef<Region> regions,
|
static void buildBlockMapping(MutableArrayRef<Region> regions,
|
||||||
|
|
|
@ -471,8 +471,8 @@ static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
|
// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
|
||||||
LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
|
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
|
||||||
FlatAffineConstraints *cst) {
|
FlatAffineConstraints *cst) {
|
||||||
for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
|
for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
|
||||||
auto value = cst->getIdValue(i);
|
auto value = cst->getIdValue(i);
|
||||||
if (ivs.count(value) == 0) {
|
if (ivs.count(value) == 0) {
|
||||||
|
|
|
@ -70,6 +70,8 @@ using urem = ValueBuilder<mlir::LLVM::URemOp>;
|
||||||
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
|
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
|
||||||
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
|
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static LLVMType getPtrToElementType(T containerType,
|
static LLVMType getPtrToElementType(T containerType,
|
||||||
LLVMTypeConverter &lowering) {
|
LLVMTypeConverter &lowering) {
|
||||||
|
@ -104,7 +106,6 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
/// EDSC-compatible wrapper for MemRefDescriptor.
|
/// EDSC-compatible wrapper for MemRefDescriptor.
|
||||||
class BaseViewConversionHelper {
|
class BaseViewConversionHelper {
|
||||||
public:
|
public:
|
||||||
|
@ -139,7 +140,6 @@ private:
|
||||||
|
|
||||||
MemRefDescriptor d;
|
MemRefDescriptor d;
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// RangeOp creates a new range descriptor.
|
// RangeOp creates a new range descriptor.
|
||||||
class RangeOpConversion : public LLVMOpLowering {
|
class RangeOpConversion : public LLVMOpLowering {
|
||||||
|
@ -421,12 +421,16 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
||||||
return fnNameAttr;
|
return fnNameAttr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
Type LinalgTypeConverter::convertType(Type t) {
|
Type LinalgTypeConverter::convertType(Type t) {
|
||||||
if (auto result = LLVMTypeConverter::convertType(t))
|
if (auto result = LLVMTypeConverter::convertType(t))
|
||||||
return result;
|
return result;
|
||||||
return convertLinalgType(t, *this);
|
return convertLinalgType(t, *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
// LinalgOpConversion<LinalgOp> creates a new call to the
|
// LinalgOpConversion<LinalgOp> creates a new call to the
|
||||||
// `LinalgOp::getLibraryCallName()` function.
|
// `LinalgOp::getLibraryCallName()` function.
|
||||||
// The implementation of the function can be either in the same module or in an
|
// The implementation of the function can be either in the same module or in an
|
||||||
|
@ -552,6 +556,8 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
|
||||||
ctx);
|
ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
/// Populate the given list with patterns that convert from Linalg to LLVM.
|
||||||
void mlir::populateLinalgToLLVMConversionPatterns(
|
void mlir::populateLinalgToLLVMConversionPatterns(
|
||||||
LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
|
LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
|
||||||
|
|
|
@ -98,7 +98,7 @@ static Value getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
|
||||||
// This roughly corresponds to the "matcher" part of the pattern-based
|
// This roughly corresponds to the "matcher" part of the pattern-based
|
||||||
// rewriting infrastructure.
|
// rewriting infrastructure.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
|
static LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
|
||||||
Region &limit = forOp.region();
|
Region &limit = forOp.region();
|
||||||
for (unsigned i = 0, e = numDims; i < e; ++i) {
|
for (unsigned i = 0, e = numDims; i < e; ++i) {
|
||||||
Operation *nested = &forOp.getBody()->front();
|
Operation *nested = &forOp.getBody()->front();
|
||||||
|
@ -124,8 +124,8 @@ LogicalResult checkLoopNestMappableImpl(OpTy forOp, unsigned numDims) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
|
static LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
|
||||||
unsigned numThreadDims) {
|
unsigned numThreadDims) {
|
||||||
if (numBlockDims < 1 || numThreadDims < 1) {
|
if (numBlockDims < 1 || numThreadDims < 1) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "nothing to map");
|
LLVM_DEBUG(llvm::dbgs() << "nothing to map");
|
||||||
return success();
|
return success();
|
||||||
|
@ -142,8 +142,8 @@ LogicalResult checkLoopNestMappable(OpTy forOp, unsigned numBlockDims,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims,
|
static LogicalResult checkLoopOpMappable(OpTy forOp, unsigned numBlockDims,
|
||||||
unsigned numThreadDims) {
|
unsigned numThreadDims) {
|
||||||
if (numBlockDims < 1 || numThreadDims < 1) {
|
if (numBlockDims < 1 || numThreadDims < 1) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "nothing to map");
|
LLVM_DEBUG(llvm::dbgs() << "nothing to map");
|
||||||
return success();
|
return success();
|
||||||
|
@ -265,8 +265,8 @@ Optional<OpTy> LoopToGpuConverter::collectBounds(OpTy forOp,
|
||||||
/// `nids`. The innermost loop is mapped to the x-dimension, followed by the
|
/// `nids`. The innermost loop is mapped to the x-dimension, followed by the
|
||||||
/// next innermost loop to y-dimension, followed by z-dimension.
|
/// next innermost loop to y-dimension, followed by z-dimension.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
|
static OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
|
||||||
ArrayRef<Value> nids) {
|
ArrayRef<Value> nids) {
|
||||||
auto nDims = ids.size();
|
auto nDims = ids.size();
|
||||||
assert(nDims == nids.size());
|
assert(nDims == nids.size());
|
||||||
for (auto dim : llvm::seq<unsigned>(0, nDims)) {
|
for (auto dim : llvm::seq<unsigned>(0, nDims)) {
|
||||||
|
@ -285,9 +285,10 @@ OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
|
||||||
/// Utility method to convert the gpu::KernelDim3 object for representing id of
|
/// Utility method to convert the gpu::KernelDim3 object for representing id of
|
||||||
/// each workgroup/workitem and number of workgroup/workitems along a dimension
|
/// each workgroup/workitem and number of workgroup/workitems along a dimension
|
||||||
/// of the launch into a container.
|
/// of the launch into a container.
|
||||||
void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids,
|
static void packIdAndNumId(gpu::KernelDim3 kernelIds,
|
||||||
unsigned nDims, SmallVectorImpl<Value> &ids,
|
gpu::KernelDim3 kernelNids, unsigned nDims,
|
||||||
SmallVectorImpl<Value> &nids) {
|
SmallVectorImpl<Value> &ids,
|
||||||
|
SmallVectorImpl<Value> &nids) {
|
||||||
assert(nDims <= 3 && "invalid number of launch dimensions");
|
assert(nDims <= 3 && "invalid number of launch dimensions");
|
||||||
SmallVector<Value, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
|
SmallVector<Value, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
|
||||||
SmallVector<Value, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
|
SmallVector<Value, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
|
||||||
|
@ -300,9 +301,9 @@ void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids,
|
||||||
|
|
||||||
/// Generate the body of the launch operation.
|
/// Generate the body of the launch operation.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp,
|
static LogicalResult
|
||||||
gpu::LaunchOp launchOp, unsigned numBlockDims,
|
createLaunchBody(OpBuilder &builder, OpTy rootForOp, gpu::LaunchOp launchOp,
|
||||||
unsigned numThreadDims) {
|
unsigned numBlockDims, unsigned numThreadDims) {
|
||||||
OpBuilder::InsertionGuard bodyInsertionGuard(builder);
|
OpBuilder::InsertionGuard bodyInsertionGuard(builder);
|
||||||
builder.setInsertionPointToEnd(&launchOp.body().front());
|
builder.setInsertionPointToEnd(&launchOp.body().front());
|
||||||
auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
|
auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
|
||||||
|
@ -337,8 +338,9 @@ LogicalResult createLaunchBody(OpBuilder &builder, OpTy rootForOp,
|
||||||
// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the
|
// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the
|
||||||
// given workgroup size and number of workgroups.
|
// given workgroup size and number of workgroups.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
LogicalResult createLaunchFromOp(OpTy rootForOp, ArrayRef<Value> numWorkGroups,
|
static LogicalResult createLaunchFromOp(OpTy rootForOp,
|
||||||
ArrayRef<Value> workGroupSizes) {
|
ArrayRef<Value> numWorkGroups,
|
||||||
|
ArrayRef<Value> workGroupSizes) {
|
||||||
OpBuilder builder(rootForOp.getOperation());
|
OpBuilder builder(rootForOp.getOperation());
|
||||||
if (numWorkGroups.size() > 3) {
|
if (numWorkGroups.size() > 3) {
|
||||||
return rootForOp.emitError("invalid ")
|
return rootForOp.emitError("invalid ")
|
||||||
|
|
|
@ -139,10 +139,11 @@ public:
|
||||||
|
|
||||||
// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
|
// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
|
||||||
// MemRefType with AffineMap that has static strides. Handle dynamic strides
|
// MemRefType with AffineMap that has static strides. Handle dynamic strides
|
||||||
spirv::AccessChainOp getElementPtr(OpBuilder &builder,
|
static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
|
||||||
SPIRVTypeConverter &typeConverter,
|
SPIRVTypeConverter &typeConverter,
|
||||||
Location loc, MemRefType origBaseType,
|
Location loc, MemRefType origBaseType,
|
||||||
Value basePtr, ArrayRef<Value> indices) {
|
Value basePtr,
|
||||||
|
ArrayRef<Value> indices) {
|
||||||
// Get base and offset of the MemRefType and verify they are static.
|
// Get base and offset of the MemRefType and verify they are static.
|
||||||
int64_t offset;
|
int64_t offset;
|
||||||
SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
|
|
|
@ -34,6 +34,8 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static LLVM::LLVMType getPtrToElementType(T containerType,
|
static LLVM::LLVMType getPtrToElementType(T containerType,
|
||||||
LLVMTypeConverter &lowering) {
|
LLVMTypeConverter &lowering) {
|
||||||
|
@ -948,6 +950,8 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// Populate the given list with patterns that convert from Vector to LLVM.
|
/// Populate the given list with patterns that convert from Vector to LLVM.
|
||||||
void mlir::populateVectorToLLVMConversionPatterns(
|
void mlir::populateVectorToLLVMConversionPatterns(
|
||||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||||
|
|
|
@ -141,7 +141,8 @@ bool mlir::isValidDim(Value value) {
|
||||||
/// Returns true if the 'index' dimension of the `memref` defined by
|
/// Returns true if the 'index' dimension of the `memref` defined by
|
||||||
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
|
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
|
||||||
template <typename AnyMemRefDefOp>
|
template <typename AnyMemRefDefOp>
|
||||||
bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) {
|
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,
|
||||||
|
unsigned index) {
|
||||||
auto memRefType = memrefDefOp.getType();
|
auto memRefType = memrefDefOp.getType();
|
||||||
// Statically shaped.
|
// Statically shaped.
|
||||||
if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
|
if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
|
||||||
|
@ -1620,7 +1621,8 @@ static LogicalResult verify(AffineIfOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseAffineIfOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseAffineIfOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
// Parse the condition attribute set.
|
// Parse the condition attribute set.
|
||||||
IntegerSetAttr conditionAttr;
|
IntegerSetAttr conditionAttr;
|
||||||
unsigned numDims;
|
unsigned numDims;
|
||||||
|
@ -1667,7 +1669,7 @@ ParseResult parseAffineIfOp(OpAsmParser &parser, OperationState &result) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter &p, AffineIfOp op) {
|
static void print(OpAsmPrinter &p, AffineIfOp op) {
|
||||||
auto conditionAttr =
|
auto conditionAttr =
|
||||||
op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
|
op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
|
||||||
p << "affine.if " << conditionAttr;
|
p << "affine.if " << conditionAttr;
|
||||||
|
@ -2057,7 +2059,7 @@ static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void print(OpAsmPrinter &p, AffinePrefetchOp op) {
|
static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
|
||||||
p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
|
p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
|
||||||
AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
|
AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
|
||||||
if (mapAttr) {
|
if (mapAttr) {
|
||||||
|
@ -2074,7 +2076,7 @@ void print(OpAsmPrinter &p, AffinePrefetchOp op) {
|
||||||
p << " : " << op.getMemRefType();
|
p << " : " << op.getMemRefType();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(AffinePrefetchOp op) {
|
static LogicalResult verify(AffinePrefetchOp op) {
|
||||||
auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
|
auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
|
||||||
if (mapAttr) {
|
if (mapAttr) {
|
||||||
AffineMap map = mapAttr.getValue();
|
AffineMap map = mapAttr.getValue();
|
||||||
|
|
|
@ -258,7 +258,7 @@ iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
|
||||||
return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
|
return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(LaunchOp op) {
|
static LogicalResult verify(LaunchOp op) {
|
||||||
// Kernel launch takes kNumConfigOperands leading operands for grid/block
|
// Kernel launch takes kNumConfigOperands leading operands for grid/block
|
||||||
// sizes and transforms them into kNumConfigRegionAttributes region arguments
|
// sizes and transforms them into kNumConfigRegionAttributes region arguments
|
||||||
// for block/thread identifiers and grid/block sizes.
|
// for block/thread identifiers and grid/block sizes.
|
||||||
|
@ -300,7 +300,7 @@ static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
|
||||||
p << size.z << " = " << operands[2] << ')';
|
p << size.z << " = " << operands[2] << ')';
|
||||||
}
|
}
|
||||||
|
|
||||||
void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
|
static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
|
||||||
ValueRange operands = op.getOperands();
|
ValueRange operands = op.getOperands();
|
||||||
|
|
||||||
// Print the launch configuration.
|
// Print the launch configuration.
|
||||||
|
@ -370,7 +370,7 @@ parseSizeAssignment(OpAsmParser &parser,
|
||||||
// (`args` ssa-reassignment `:` type-list)?
|
// (`args` ssa-reassignment `:` type-list)?
|
||||||
// region attr-dict?
|
// region attr-dict?
|
||||||
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
|
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
|
||||||
ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
|
||||||
// Sizes of the grid and block.
|
// Sizes of the grid and block.
|
||||||
SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
|
SmallVector<OpAsmParser::OperandType, LaunchOp::kNumConfigOperands> sizes(
|
||||||
LaunchOp::kNumConfigOperands);
|
LaunchOp::kNumConfigOperands);
|
||||||
|
@ -549,7 +549,7 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
|
||||||
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
|
return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(LaunchFuncOp op) {
|
static LogicalResult verify(LaunchFuncOp op) {
|
||||||
auto module = op.getParentOfType<ModuleOp>();
|
auto module = op.getParentOfType<ModuleOp>();
|
||||||
if (!module)
|
if (!module)
|
||||||
return op.emitOpError("expected to belong to a module");
|
return op.emitOpError("expected to belong to a module");
|
||||||
|
@ -729,7 +729,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prints a GPU Func op.
|
/// Prints a GPU Func op.
|
||||||
void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
|
static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
|
||||||
p << GPUFuncOp::getOperationName() << ' ';
|
p << GPUFuncOp::getOperationName() << ' ';
|
||||||
p.printSymbolName(op.getName());
|
p.printSymbolName(op.getName());
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GenericOpType>
|
template <typename GenericOpType>
|
||||||
LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
|
static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
|
||||||
|
|
||||||
template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
|
template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
|
||||||
auto nViews = op.getNumInputsAndOutputs();
|
auto nViews = op.getNumInputsAndOutputs();
|
||||||
|
@ -158,7 +158,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GenericOpType>
|
template <typename GenericOpType>
|
||||||
LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
|
static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
|
||||||
|
|
||||||
template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
|
template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
|
||||||
auto nViews = op.getNumInputsAndOutputs();
|
auto nViews = op.getNumInputsAndOutputs();
|
||||||
|
@ -228,7 +228,7 @@ LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GenericOpType>
|
template <typename GenericOpType>
|
||||||
LogicalResult verifyGenericOp(GenericOpType op) {
|
static LogicalResult verifyGenericOp(GenericOpType op) {
|
||||||
auto nInputViews = op.getNumInputs();
|
auto nInputViews = op.getNumInputs();
|
||||||
auto nLoops = op.getNumLoops();
|
auto nLoops = op.getNumLoops();
|
||||||
auto nViews = op.getNumInputsAndOutputs();
|
auto nViews = op.getNumInputsAndOutputs();
|
||||||
|
@ -729,7 +729,7 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GenericOpType>
|
template <typename GenericOpType>
|
||||||
LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
|
static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
|
||||||
// The operand number and types must match the view element types.
|
// The operand number and types must match the view element types.
|
||||||
auto nOutputViews = genericOp.getNumOutputs();
|
auto nOutputViews = genericOp.getNumOutputs();
|
||||||
if (op.getNumOperands() != nOutputViews)
|
if (op.getNumOperands() != nOutputViews)
|
||||||
|
|
|
@ -37,6 +37,8 @@ using IndexedAffineValue = TemplatedIndexedValue<affine_load, affine_store>;
|
||||||
using edsc::op::operator+;
|
using edsc::op::operator+;
|
||||||
using edsc::op::operator==;
|
using edsc::op::operator==;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
static SmallVector<ValueHandle, 8>
|
static SmallVector<ValueHandle, 8>
|
||||||
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
|
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
|
||||||
ArrayRef<Value> vals) {
|
ArrayRef<Value> vals) {
|
||||||
|
@ -379,7 +381,6 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
|
||||||
// This struct is for factoring out the implementation and support template
|
// This struct is for factoring out the implementation and support template
|
||||||
// instantiations in the following 2 cases:
|
// instantiations in the following 2 cases:
|
||||||
// 1. Appending to a list of patterns via RewritePatternList.
|
// 1. Appending to a list of patterns via RewritePatternList.
|
||||||
|
@ -393,7 +394,6 @@ class LinalgOpToLoopsImpl {
|
||||||
public:
|
public:
|
||||||
static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
|
static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
|
template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
|
||||||
LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
|
LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
|
||||||
|
@ -538,6 +538,8 @@ void LowerLinalgToLoopsPass<LoopType, IndexedValueType>::runOnFunction() {
|
||||||
applyPatternsGreedily(this->getFunction(), patterns);
|
applyPatternsGreedily(this->getFunction(), patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
/// Create a pass to convert Linalg operations to loop.for loops and
|
/// Create a pass to convert Linalg operations to loop.for loops and
|
||||||
/// std.load/std.store accesses.
|
/// std.load/std.store accesses.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>>
|
std::unique_ptr<OpPassBase<FuncOp>>
|
||||||
|
|
|
@ -170,7 +170,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
|
||||||
//
|
//
|
||||||
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
|
// TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices
|
||||||
// does not lead to losing information.
|
// does not lead to losing information.
|
||||||
void transformIndexedGenericOpIndices(
|
static void transformIndexedGenericOpIndices(
|
||||||
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
|
OpBuilder &b, LinalgOp op, ArrayRef<ValueHandle *> pivs,
|
||||||
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
|
const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
|
||||||
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
|
auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
|
||||||
|
|
|
@ -68,7 +68,7 @@ void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
|
||||||
bodyRegion->front().addArgument(builder->getIndexType());
|
bodyRegion->front().addArgument(builder->getIndexType());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verify(ForOp op) {
|
static LogicalResult verify(ForOp op) {
|
||||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
|
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
|
||||||
if (cst.getValue() <= 0)
|
if (cst.getValue() <= 0)
|
||||||
return op.emitOpError("constant step operand must be positive");
|
return op.emitOpError("constant step operand must be positive");
|
||||||
|
|
|
@ -26,8 +26,6 @@ public:
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
|
/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
|
||||||
template <typename ConcreteRewriteClass, typename FakeQuantOp>
|
template <typename ConcreteRewriteClass, typename FakeQuantOp>
|
||||||
class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
|
class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
|
||||||
|
@ -126,6 +124,8 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void ConvertSimulatedQuantPass::runOnFunction() {
|
void ConvertSimulatedQuantPass::runOnFunction() {
|
||||||
bool hadFailure = false;
|
bool hadFailure = false;
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
|
@ -301,8 +301,9 @@ AffineExpr SDBMExpr::getAsAffineExpr() const {
|
||||||
// expression is already a sum expression, update its constant and extract the
|
// expression is already a sum expression, update its constant and extract the
|
||||||
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
// LHS if the constant becomes zero. Otherwise, construct a sum expression.
|
||||||
template <typename Result>
|
template <typename Result>
|
||||||
Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant, bool negated,
|
static Result addConstantAndSink(SDBMDirectExpr expr, int64_t constant,
|
||||||
function_ref<Result(SDBMDirectExpr)> builder) {
|
bool negated,
|
||||||
|
function_ref<Result(SDBMDirectExpr)> builder) {
|
||||||
SDBMDialect *dialect = expr.getDialect();
|
SDBMDialect *dialect = expr.getDialect();
|
||||||
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
if (auto sumExpr = expr.dyn_cast<SDBMSumExpr>()) {
|
||||||
if (negated)
|
if (negated)
|
||||||
|
|
|
@ -375,6 +375,7 @@ Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
|
||||||
return parseAndVerifyInteger<uint64_t>(dialect, parser);
|
return parseAndVerifyInteger<uint64_t>(dialect, parser);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
// Functor object to parse a comma separated list of specs. The function
|
// Functor object to parse a comma separated list of specs. The function
|
||||||
// parseAndVerify does the actual parsing and verification of individual
|
// parseAndVerify does the actual parsing and verification of individual
|
||||||
// elements. This is a functor since parsing the last element of the list
|
// elements. This is a functor since parsing the last element of the list
|
||||||
|
@ -407,6 +408,7 @@ template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
|
// dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
|
||||||
//
|
//
|
||||||
|
|
|
@ -33,8 +33,8 @@ using namespace mlir;
|
||||||
|
|
||||||
// Deserializes the SPIR-V binary module stored in the file named as
|
// Deserializes the SPIR-V binary module stored in the file named as
|
||||||
// `inputFilename` and returns a module containing the SPIR-V module.
|
// `inputFilename` and returns a module containing the SPIR-V module.
|
||||||
OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
|
static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
Builder builder(context);
|
Builder builder(context);
|
||||||
|
|
||||||
// Make sure the input stream can be treated as a stream of SPIR-V words
|
// Make sure the input stream can be treated as a stream of SPIR-V words
|
||||||
|
@ -71,7 +71,7 @@ static TranslateToMLIRRegistration fromBinary(
|
||||||
// Serialization registration
|
// Serialization registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
|
static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
|
||||||
if (!module)
|
if (!module)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -104,8 +104,9 @@ static TranslateFromMLIRRegistration
|
||||||
// Round-trip registration
|
// Round-trip registration
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, raw_ostream &output,
|
static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
|
||||||
MLIRContext *context) {
|
raw_ostream &output,
|
||||||
|
MLIRContext *context) {
|
||||||
// Parse an MLIR module from the source manager.
|
// Parse an MLIR module from the source manager.
|
||||||
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||||
if (!srcModule)
|
if (!srcModule)
|
||||||
|
|
|
@ -918,9 +918,10 @@ static ParseResult parseInsertStridedSliceOp(OpAsmParser &parser,
|
||||||
|
|
||||||
// TODO(ntv) Should be moved to Tablegen Confined attributes.
|
// TODO(ntv) Should be moved to Tablegen Confined attributes.
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr,
|
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
|
||||||
ArrayRef<int64_t> shape,
|
ArrayAttr arrayAttr,
|
||||||
StringRef attrName) {
|
ArrayRef<int64_t> shape,
|
||||||
|
StringRef attrName) {
|
||||||
if (arrayAttr.size() > shape.size())
|
if (arrayAttr.size() > shape.size())
|
||||||
return op.emitOpError("expected ")
|
return op.emitOpError("expected ")
|
||||||
<< attrName << " attribute of rank smaller than vector rank";
|
<< attrName << " attribute of rank smaller than vector rank";
|
||||||
|
@ -931,10 +932,10 @@ LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr,
|
||||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||||
// Otherwise, the admissible interval is [min, max].
|
// Otherwise, the admissible interval is [min, max].
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr,
|
static LogicalResult
|
||||||
int64_t min, int64_t max,
|
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
|
||||||
StringRef attrName,
|
int64_t max, StringRef attrName,
|
||||||
bool halfOpen = true) {
|
bool halfOpen = true) {
|
||||||
for (auto attr : arrayAttr) {
|
for (auto attr : arrayAttr) {
|
||||||
auto val = attr.cast<IntegerAttr>().getInt();
|
auto val = attr.cast<IntegerAttr>().getInt();
|
||||||
auto upper = max;
|
auto upper = max;
|
||||||
|
@ -951,7 +952,7 @@ LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr,
|
||||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||||
// Otherwise, the admissible interval is [min, max].
|
// Otherwise, the admissible interval is [min, max].
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
LogicalResult
|
static LogicalResult
|
||||||
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
|
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
|
||||||
ArrayRef<int64_t> shape, StringRef attrName,
|
ArrayRef<int64_t> shape, StringRef attrName,
|
||||||
bool halfOpen = true, int64_t min = 0) {
|
bool halfOpen = true, int64_t min = 0) {
|
||||||
|
@ -975,7 +976,7 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
|
||||||
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
// interval. If `halfOpen` is true then the admissible interval is [min, max).
|
||||||
// Otherwise, the admissible interval is [min, max].
|
// Otherwise, the admissible interval is [min, max].
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
|
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
|
||||||
OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
|
OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
|
||||||
ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
|
ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
|
||||||
bool halfOpen = true, int64_t min = 1) {
|
bool halfOpen = true, int64_t min = 1) {
|
||||||
|
@ -1470,7 +1471,8 @@ static void print(OpAsmPrinter &p, TransferReadOp op) {
|
||||||
p << " : " << op.getMemRefType() << ", " << op.getVectorType();
|
p << " : " << op.getMemRefType() << ", " << op.getVectorType();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseTransferReadOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseTransferReadOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
llvm::SMLoc typesLoc;
|
llvm::SMLoc typesLoc;
|
||||||
OpAsmParser::OperandType memrefInfo;
|
OpAsmParser::OperandType memrefInfo;
|
||||||
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
|
||||||
|
@ -1545,7 +1547,8 @@ static void print(OpAsmPrinter &p, TransferWriteOp op) {
|
||||||
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
|
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult parseTransferWriteOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
llvm::SMLoc typesLoc;
|
llvm::SMLoc typesLoc;
|
||||||
OpAsmParser::OperandType storeValueInfo;
|
OpAsmParser::OperandType storeValueInfo;
|
||||||
OpAsmParser::OperandType memRefInfo;
|
OpAsmParser::OperandType memRefInfo;
|
||||||
|
@ -1682,7 +1685,8 @@ static LogicalResult verify(TupleGetOp op) {
|
||||||
// ConstantMaskOp
|
// ConstantMaskOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseConstantMaskOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
Type resultType;
|
Type resultType;
|
||||||
ArrayAttr maskDimSizesAttr;
|
ArrayAttr maskDimSizesAttr;
|
||||||
StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
|
StringRef attrName = ConstantMaskOp::getMaskDimSizesAttrName();
|
||||||
|
@ -1729,7 +1733,8 @@ static LogicalResult verify(ConstantMaskOp &op) {
|
||||||
// CreateMaskOp
|
// CreateMaskOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult parseCreateMaskOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parseCreateMaskOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
auto indexType = parser.getBuilder().getIndexType();
|
auto indexType = parser.getBuilder().getIndexType();
|
||||||
Type resultType;
|
Type resultType;
|
||||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||||
|
@ -1758,7 +1763,7 @@ static LogicalResult verify(CreateMaskOp op) {
|
||||||
// PrintOp
|
// PrintOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
||||||
OpAsmParser::OperandType source;
|
OpAsmParser::OperandType source;
|
||||||
Type sourceType;
|
Type sourceType;
|
||||||
return failure(parser.parseOperand(source) ||
|
return failure(parser.parseOperand(source) ||
|
||||||
|
|
|
@ -522,6 +522,7 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
|
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
|
||||||
// scheme of its unique ExtractSlicesOp user.
|
// scheme of its unique ExtractSlicesOp user.
|
||||||
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
||||||
|
@ -657,6 +658,8 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
|
||||||
// TODO(andydavis) Add this as DRR pattern.
|
// TODO(andydavis) Add this as DRR pattern.
|
||||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||||
|
|
|
@ -122,7 +122,7 @@ static std::string makePackedFunctionName(StringRef name) {
|
||||||
// For each function in the LLVM module, define an interface function that wraps
|
// For each function in the LLVM module, define an interface function that wraps
|
||||||
// all the arguments of the original function and all its results into an i8**
|
// all the arguments of the original function and all its results into an i8**
|
||||||
// pointer to provide a unified invocation interface.
|
// pointer to provide a unified invocation interface.
|
||||||
void packFunctionArguments(Module *module) {
|
static void packFunctionArguments(Module *module) {
|
||||||
auto &ctx = module->getContext();
|
auto &ctx = module->getContext();
|
||||||
llvm::IRBuilder<> builder(ctx);
|
llvm::IRBuilder<> builder(ctx);
|
||||||
DenseSet<llvm::Function *> interfaceFunctions;
|
DenseSet<llvm::Function *> interfaceFunctions;
|
||||||
|
|
|
@ -674,8 +674,8 @@ static Attribute rebuildAttrAfterRAUW(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a new symbol reference attribute with a new leaf reference.
|
/// Generates a new symbol reference attribute with a new leaf reference.
|
||||||
SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
|
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
|
||||||
FlatSymbolRefAttr newLeafAttr) {
|
FlatSymbolRefAttr newLeafAttr) {
|
||||||
if (oldAttr.isa<FlatSymbolRefAttr>())
|
if (oldAttr.isa<FlatSymbolRefAttr>())
|
||||||
return newLeafAttr;
|
return newLeafAttr;
|
||||||
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
|
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
|
||||||
|
|
|
@ -135,7 +135,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
|
||||||
printPass(/*indent=*/0, &pass);
|
printPass(/*indent=*/0, &pass);
|
||||||
}
|
}
|
||||||
|
|
||||||
void printStatistics(OpPassManager &pm, PassDisplayMode displayMode) {
|
static void printStatistics(OpPassManager &pm, PassDisplayMode displayMode) {
|
||||||
auto os = llvm::CreateInfoOutputFile();
|
auto os = llvm::CreateInfoOutputFile();
|
||||||
|
|
||||||
// Print the stats header.
|
// Print the stats header.
|
||||||
|
|
|
@ -351,7 +351,7 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
|
||||||
// in the parent loop. Collect at most `maxLoops` loops and append them to
|
// in the parent loop. Collect at most `maxLoops` loops and append them to
|
||||||
// `forOps`.
|
// `forOps`.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void getPerfectlyNestedLoopsImpl(
|
static void getPerfectlyNestedLoopsImpl(
|
||||||
SmallVectorImpl<T> &forOps, T rootForOp,
|
SmallVectorImpl<T> &forOps, T rootForOp,
|
||||||
unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
|
unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
|
||||||
for (unsigned i = 0; i < maxLoops; ++i) {
|
for (unsigned i = 0; i < maxLoops; ++i) {
|
||||||
|
|
|
@ -21,7 +21,8 @@ static void createOpI(PatternRewriter &rewriter, Value input) {
|
||||||
rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
|
rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
|
||||||
}
|
}
|
||||||
|
|
||||||
void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) {
|
static void handleNoResultOp(PatternRewriter &rewriter,
|
||||||
|
OpSymbolBindingNoResult op) {
|
||||||
// Turn the no result op to a one-result op.
|
// Turn the no result op to a one-result op.
|
||||||
rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
|
rewriter.create<OpSymbolBindingB>(op.getLoc(), op.operand().getType(),
|
||||||
op.operand());
|
op.operand());
|
||||||
|
@ -56,6 +57,7 @@ static mlir::PassRegistration<TestPatternDriver>
|
||||||
// ReturnType Driver.
|
// ReturnType Driver.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
struct ReturnTypeOpMatch : public RewritePattern {
|
struct ReturnTypeOpMatch : public RewritePattern {
|
||||||
ReturnTypeOpMatch(MLIRContext *ctx)
|
ReturnTypeOpMatch(MLIRContext *ctx)
|
||||||
: RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
|
: RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
|
||||||
|
@ -94,7 +96,6 @@ struct ReturnTypeOpMatch : public RewritePattern {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
|
||||||
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
|
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
mlir::OwningRewritePatternList patterns;
|
mlir::OwningRewritePatternList patterns;
|
||||||
|
|
Loading…
Reference in New Issue