forked from OSchip/llvm-project
[mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.
This also includes LLVM dialect ops created from intrinsics. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D97944
This commit is contained in:
parent
d0eb25a643
commit
4e393350c5
|
@ -35,6 +35,7 @@ def LLVM_Dialect : Dialect {
|
|||
static StringRef getLoopAttrName() { return "llvm.loop"; }
|
||||
static StringRef getParallelAccessAttrName() { return "parallel_access"; }
|
||||
static StringRef getLoopOptionsAttrName() { return "options"; }
|
||||
static StringRef getAccessGroupsAttrName() { return "access_groups"; }
|
||||
|
||||
/// Verifies if the given string is a well-formed data layout descriptor.
|
||||
/// Uses `reportError` to report errors.
|
||||
|
@ -247,7 +248,8 @@ def LLVM_IntrPatterns {
|
|||
// `llvm::Intrinsic` enum; one usually wants these to be related.
|
||||
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
|
||||
list<int> overloadedResults, list<int> overloadedOperands,
|
||||
list<OpTrait> traits, int numResults>
|
||||
list<OpTrait> traits, int numResults,
|
||||
bit requiresAccessGroup = 0>
|
||||
: LLVM_OpBase<dialect, opName, traits>,
|
||||
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
|
||||
string resultPattern = !if(!gt(numResults, 1),
|
||||
|
@ -264,19 +266,21 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
|
|||
overloadedOperands>.lst), ", ") # [{
|
||||
});
|
||||
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
|
||||
}] # !if(!gt(numResults, 0), "$res = ", "")
|
||||
# [{builder.CreateCall(fn, operands);
|
||||
}];
|
||||
}] # [{auto *inst = builder.CreateCall(fn, operands);
|
||||
}] # !if(!gt(requiresAccessGroup, 0),
|
||||
"moduleTranslation.setAccessGroupsMetadata(op, inst);",
|
||||
"(void) inst;")
|
||||
# !if(!gt(numResults, 0), "$res = inst;", "");
|
||||
}
|
||||
|
||||
// Base class for LLVM intrinsic operations, should not be used directly. Places
|
||||
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
|
||||
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
|
||||
list<int> overloadedOperands, list<OpTrait> traits,
|
||||
int numResults>
|
||||
int numResults, bit requiresAccessGroup = 0>
|
||||
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
|
||||
overloadedResults, overloadedOperands, traits,
|
||||
numResults>;
|
||||
numResults, requiresAccessGroup>;
|
||||
|
||||
// Base class for LLVM intrinsic operations returning no results. Places the
|
||||
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
|
||||
|
|
|
@ -287,6 +287,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
|
|||
inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
|
||||
}
|
||||
}];
|
||||
|
||||
code setAccessGroupsMetadataCode = [{
|
||||
moduleTranslation.setAccessGroupsMetadata(op, inst);
|
||||
}];
|
||||
}
|
||||
|
||||
// Memory-related operations.
|
||||
|
@ -326,12 +330,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]>,
|
|||
|
||||
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
|
||||
let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
|
||||
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
|
||||
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
|
||||
UnitAttr:$nontemporal);
|
||||
let results = (outs LLVM_Type:$res);
|
||||
string llvmBuilder = [{
|
||||
auto *inst = builder.CreateLoad($addr, $volatile_);
|
||||
}] # setAlignmentCode # setNonTemporalMetadataCode # [{
|
||||
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{
|
||||
$res = inst;
|
||||
}];
|
||||
let builders = [
|
||||
|
@ -346,16 +351,18 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
|
|||
CArg<"bool", "false">:$isNonTemporal)>];
|
||||
let parser = [{ return parseLoadOp(parser, result); }];
|
||||
let printer = [{ printLoadOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
|
||||
let arguments = (ins LLVM_LoadableType:$value,
|
||||
LLVM_PointerTo<LLVM_LoadableType>:$addr,
|
||||
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
|
||||
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
|
||||
UnitAttr:$nontemporal);
|
||||
string llvmBuilder = [{
|
||||
auto *inst = builder.CreateStore($value, $addr, $volatile_);
|
||||
}] # setAlignmentCode # setNonTemporalMetadataCode;
|
||||
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode;
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$value, "Value":$addr,
|
||||
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
|
||||
|
@ -363,6 +370,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
|
|||
];
|
||||
let parser = [{ return parseStoreOp(parser, result); }];
|
||||
let printer = [{ printStoreOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
// Casts.
|
||||
|
|
|
@ -128,6 +128,9 @@ public:
|
|||
"attempting to map loop options that was already mapped");
|
||||
}
|
||||
|
||||
// Sets LLVM metadata for memory operations that are in a parallel loop.
|
||||
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
|
||||
|
||||
/// Converts the type from MLIR LLVM dialect to LLVM.
|
||||
llvm::Type *convertType(Type type);
|
||||
|
||||
|
|
|
@ -404,6 +404,34 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
|
|||
// Builder, printer and parser for for LLVM::LoadOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyAccessGroups(Operation *op) {
|
||||
if (Attribute attribute =
|
||||
op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
|
||||
// The attribute is already verified to be a symbol ref array attribute via
|
||||
// a constraint in the operation definition.
|
||||
for (SymbolRefAttr accessGroupRef :
|
||||
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
|
||||
StringRef metadataName = accessGroupRef.getRootReference();
|
||||
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
|
||||
op->getParentOp(), metadataName);
|
||||
if (!metadataOp)
|
||||
return op->emitOpError() << "expected '" << accessGroupRef
|
||||
<< "' to reference a metadata op";
|
||||
StringRef accessGroupName = accessGroupRef.getLeafReference();
|
||||
Operation *accessGroupOp =
|
||||
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
|
||||
if (!accessGroupOp)
|
||||
return op->emitOpError() << "expected '" << accessGroupRef
|
||||
<< "' to reference an access_group op";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(LoadOp op) {
|
||||
return verifyAccessGroups(op.getOperation());
|
||||
}
|
||||
|
||||
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
|
||||
Value addr, unsigned alignment, bool isVolatile,
|
||||
bool isNonTemporal) {
|
||||
|
@ -462,6 +490,10 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
|||
// Builder, printer and parser for LLVM::StoreOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(StoreOp op) {
|
||||
return verifyAccessGroups(op.getOperation());
|
||||
}
|
||||
|
||||
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
|
||||
Value addr, unsigned alignment, bool isVolatile,
|
||||
bool isNonTemporal) {
|
||||
|
|
|
@ -656,6 +656,27 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
|
|||
return success();
|
||||
}
|
||||
|
||||
void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
|
||||
llvm::Instruction *inst) {
|
||||
auto accessGroups =
|
||||
op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
|
||||
if (accessGroups && !accessGroups.empty()) {
|
||||
llvm::Module *module = inst->getModule();
|
||||
SmallVector<llvm::Metadata *> metadatas;
|
||||
for (SymbolRefAttr accessGroupRef :
|
||||
accessGroups.getAsRange<SymbolRefAttr>())
|
||||
metadatas.push_back(getAccessGroup(*op, accessGroupRef));
|
||||
|
||||
llvm::MDNode *unionMD = nullptr;
|
||||
if (metadatas.size() == 1)
|
||||
unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
|
||||
else if (metadatas.size() >= 2)
|
||||
unionMD = llvm::MDNode::get(module->getContext(), metadatas);
|
||||
|
||||
inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Type *ModuleTranslation::convertType(Type type) {
|
||||
return typeTranslator.translateType(type);
|
||||
}
|
||||
|
|
|
@ -796,3 +796,39 @@ module {
|
|||
llvm.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
|
||||
// expected-error@below {{attribute 'access_groups' failed to satisfy constraint: symbol ref array attribute}}
|
||||
%0 = llvm.load %arg0 { "access_groups" = "test" } : !llvm.ptr<i32>
|
||||
llvm.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
|
||||
// expected-error@below {{expected '@func1' to reference a metadata op}}
|
||||
%0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr<i32>
|
||||
llvm.return
|
||||
}
|
||||
llvm.func @func1() {
|
||||
llvm.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
|
||||
// expected-error@below {{expected '@metadata' to reference an access_group op}}
|
||||
%0 = llvm.load %arg0 { "access_groups" = [@metadata] } : !llvm.ptr<i32>
|
||||
llvm.return
|
||||
}
|
||||
llvm.metadata @metadata {
|
||||
llvm.return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1483,6 +1483,7 @@ module {
|
|||
llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
|
||||
^bb4:
|
||||
%3 = llvm.add %1, %arg2 : i32
|
||||
// CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]]
|
||||
%5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
|
||||
// CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
|
||||
llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
|
||||
|
@ -1504,3 +1505,4 @@ module {
|
|||
// CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
|
||||
// CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
|
||||
// CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}
|
||||
// CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]}
|
||||
|
|
|
@ -23,11 +23,33 @@
|
|||
// It has no side effects.
|
||||
// CHECK: [NoSideEffect]
|
||||
// It has a result.
|
||||
// CHECK: 1>
|
||||
// CHECK: 1,
|
||||
// It does not require an access group.
|
||||
// CHECK: 0>
|
||||
// CHECK: Arguments<(ins LLVM_Type, LLVM_Type
|
||||
|
||||
//---------------------------------------------------------------------------//
|
||||
|
||||
// This checks that we can define an op that takes in an access group metadata.
|
||||
//
|
||||
// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \
|
||||
// RUN: | grep -v "llvm/IR/Intrinsics" \
|
||||
// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=ptrmask --llvmir-intrinsics-access-group-regexp=ptrmask \
|
||||
// RUN: | FileCheck --check-prefix=GROUPS %s
|
||||
|
||||
// GROUPS-LABEL: def LLVM_ptrmask
|
||||
// GROUPS: LLVM_IntrOp<"ptrmask
|
||||
// It has no side effects.
|
||||
// GROUPS: [NoSideEffect]
|
||||
// It has a result.
|
||||
// GROUPS: 1,
|
||||
// It requires generation of an access group LLVM metadata.
|
||||
// GROUPS: 1>
|
||||
// It has an access group attribute.
|
||||
// GROUPS: OptionalAttr<SymbolRefArrayAttr>:$access_groups
|
||||
|
||||
//---------------------------------------------------------------------------//
|
||||
|
||||
// This checks that the ODS we produce can be consumed by MLIR tablegen. We only
|
||||
// make sure the entire process does not fail and produces some C++. The shape
|
||||
// of this C++ code is tested by ODS tests.
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/MachineValueType.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Main.h"
|
||||
|
@ -37,6 +38,12 @@ static llvm::cl::opt<std::string>
|
|||
"are planning to emit"),
|
||||
llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(IntrinsicGenCat));
|
||||
|
||||
static llvm::cl::opt<std::string> accessGroupRegexp(
|
||||
"llvmir-intrinsics-access-group-regexp",
|
||||
llvm::cl::desc("Mark intrinsics that match the specified "
|
||||
"regexp as taking an access group metadata"),
|
||||
llvm::cl::cat(IntrinsicGenCat));
|
||||
|
||||
// Used to represent the indices of overloadable operands/results.
|
||||
using IndicesTy = llvm::SmallBitVector;
|
||||
|
||||
|
@ -185,6 +192,10 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
|
|||
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
||||
LLVMIntrinsic intr(record);
|
||||
|
||||
llvm::Regex accessGroupMatcher(accessGroupRegexp);
|
||||
bool requiresAccessGroup =
|
||||
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
|
||||
|
||||
// Prepare strings for traits, if any.
|
||||
llvm::SmallVector<llvm::StringRef, 2> traits;
|
||||
if (intr.isCommutative())
|
||||
|
@ -195,6 +206,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
|||
// Prepare strings for operands.
|
||||
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
|
||||
"LLVM_Type");
|
||||
if (requiresAccessGroup)
|
||||
operands.push_back("OptionalAttr<SymbolRefArrayAttr>:$access_groups");
|
||||
|
||||
// Emit the definition.
|
||||
os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
|
||||
|
@ -204,7 +217,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
|||
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
|
||||
os << ", ";
|
||||
printBracketedRange(traits, os);
|
||||
os << ", " << intr.getNumResults() << ">, Arguments<(ins"
|
||||
os << ", " << intr.getNumResults() << ", "
|
||||
<< (requiresAccessGroup ? "1" : "0") << ">, Arguments<(ins"
|
||||
<< (operands.empty() ? "" : " ");
|
||||
llvm::interleaveComma(operands, os);
|
||||
os << ")>;\n\n";
|
||||
|
|
Loading…
Reference in New Issue