forked from OSchip/llvm-project
[mlir][SideEffects] Enable specifying side effects directly on the arguments/results of an operation.
Summary: New classes are added to ODS to enable specifying additional information on the arguments and results of an operation. These classes, `Arg` and `Res` allow for adding a description and a set of 'decorators' along with the constraint. This enables specifying the side effects of an operation directly on the arguments and results themselves. Example: ``` def LoadOp : Std_Op<"load"> { let arguments = (ins Arg<AnyMemRef, "the MemRef to load from", [MemRead]>:$memref, Variadic<Index>:$indices); } ``` Differential Revision: https://reviews.llvm.org/D74440
This commit is contained in:
parent
f8923584da
commit
20dca52288
|
@ -16,6 +16,7 @@
|
|||
include "mlir/Analysis/CallInterfaces.td"
|
||||
include "mlir/Analysis/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SideEffects.td"
|
||||
|
||||
def Std_Dialect : Dialect {
|
||||
let name = "std";
|
||||
|
@ -1052,7 +1053,9 @@ def LoadOp : Std_Op<"load",
|
|||
%3 = load %0[%1, %1] : memref<4x4xi32>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
|
||||
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
|
||||
[MemRead]>:$memref,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs AnyType:$result);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
|
@ -1563,8 +1566,10 @@ def StoreOp : Std_Op<"store",
|
|||
store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyType:$value, AnyMemRef:$memref,
|
||||
Variadic<Index>:$indices);
|
||||
let arguments = (ins AnyType:$value,
|
||||
Arg<AnyMemRef, "the reference to store to",
|
||||
[MemWrite]>:$memref,
|
||||
Variadic<Index>:$indices);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *, OperationState &result, Value valueToStore, Value memref", [{
|
||||
|
@ -1846,7 +1851,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
|
|||
%12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyMemRef:$memref);
|
||||
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
|
||||
[MemRead]>:$memref);
|
||||
let results = (outs AnyTensor:$result);
|
||||
// TensorLoadOp is fully verified by traits.
|
||||
let verifier = ?;
|
||||
|
@ -1890,7 +1896,9 @@ def TensorStoreOp : Std_Op<"tensor_store",
|
|||
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
|
||||
let arguments = (ins AnyTensor:$tensor,
|
||||
Arg<AnyMemRef, "the reference to store to",
|
||||
[MemWrite]>:$memref);
|
||||
// TensorStoreOp is fully verified by traits.
|
||||
let verifier = ?;
|
||||
|
||||
|
|
|
@ -1712,6 +1712,29 @@ class OpBuilder<string p, code b = ""> {
|
|||
code body = b;
|
||||
}
|
||||
|
||||
// A base decorator class that may optionally be added to OpVariables.
|
||||
class OpVariableDecorator;
|
||||
|
||||
// Class for providing additional information on the variables, i.e. arguments
|
||||
// and results, of an operation.
|
||||
class OpVariable<Constraint varConstraint, string desc = "",
|
||||
list<OpVariableDecorator> varDecorators = []> {
|
||||
// The constraint, either attribute or type, of the argument.
|
||||
Constraint constraint = varConstraint;
|
||||
|
||||
// A description for the argument.
|
||||
string description = desc;
|
||||
|
||||
// The list of decorators for this variable, e.g. side effects.
|
||||
list<OpVariableDecorator> decorators = varDecorators;
|
||||
}
|
||||
class Arg<Constraint constraint, string desc = "",
|
||||
list<OpVariableDecorator> decorators = []>
|
||||
: OpVariable<constraint, desc, decorators>;
|
||||
class Res<Constraint constraint, string desc = "",
|
||||
list<OpVariableDecorator> decorators = []>
|
||||
: OpVariable<constraint, desc, decorators>;
|
||||
|
||||
// Base class for all ops.
|
||||
class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
|
||||
// The dialect of the op.
|
||||
|
|
|
@ -107,7 +107,7 @@ class EffectOpInterfaceBase<string name, string baseEffect>
|
|||
// This class is the general base side effect class. This is used by derived
|
||||
// effect interfaces to define their effects.
|
||||
class SideEffect<EffectOpInterfaceBase interface, string effectName,
|
||||
string resourceName> {
|
||||
string resourceName> : OpVariableDecorator {
|
||||
/// The parent interface that the effect belongs to.
|
||||
string interfaceTrait = interface.trait;
|
||||
|
||||
|
|
|
@ -57,6 +57,34 @@ public:
|
|||
// Returns this op's C++ class name prefixed with namespaces.
|
||||
std::string getQualCppClassName() const;
|
||||
|
||||
/// A class used to represent the decorators of an operator variable, i.e.
|
||||
/// argument or result.
|
||||
struct VariableDecorator {
|
||||
public:
|
||||
explicit VariableDecorator(const llvm::Record *def) : def(def) {}
|
||||
const llvm::Record &getDef() const { return *def; }
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this decorator.
|
||||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
// A utility iterator over a list of variable decorators.
|
||||
struct VariableDecoratorIterator
|
||||
: public llvm::mapped_iterator<llvm::Init *const *,
|
||||
VariableDecorator (*)(llvm::Init *)> {
|
||||
using reference = VariableDecorator;
|
||||
|
||||
/// Initializes the iterator to the specified iterator.
|
||||
VariableDecoratorIterator(llvm::Init *const *it)
|
||||
: llvm::mapped_iterator<llvm::Init *const *,
|
||||
VariableDecorator (*)(llvm::Init *)>(it,
|
||||
&unwrap) {}
|
||||
static VariableDecorator unwrap(llvm::Init *init);
|
||||
};
|
||||
using var_decorator_iterator = VariableDecoratorIterator;
|
||||
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
|
||||
|
||||
using value_iterator = NamedTypeConstraint *;
|
||||
using value_range = llvm::iterator_range<value_iterator>;
|
||||
|
||||
|
@ -84,6 +112,8 @@ public:
|
|||
TypeConstraint getResultTypeConstraint(int index) const;
|
||||
// Returns the `index`-th result's name.
|
||||
StringRef getResultName(int index) const;
|
||||
// Returns the `index`-th result's decorators.
|
||||
var_decorator_range getResultDecorators(int index) const;
|
||||
|
||||
// Returns the number of variadic results in this operation.
|
||||
unsigned getNumVariadicResults() const;
|
||||
|
@ -128,6 +158,7 @@ public:
|
|||
// Op argument (attribute or operand) accessors.
|
||||
Argument getArg(int index) const;
|
||||
StringRef getArgName(int index) const;
|
||||
var_decorator_range getArgDecorators(int index) const;
|
||||
|
||||
// Returns the trait wrapper for the given MLIR C++ `trait`.
|
||||
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Wrapper around side effect related classes defined in TableGen.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
|
||||
#define MLIR_TABLEGEN_SIDEEFFECTS_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// This class represents a specific instance of an effect that is being
|
||||
// exhibited.
|
||||
class SideEffect : public Operator::VariableDecorator {
|
||||
public:
|
||||
// Return the name of the C++ effect.
|
||||
StringRef getName() const;
|
||||
|
||||
// Return the name of the base C++ effect.
|
||||
StringRef getBaseName() const;
|
||||
|
||||
// Return the name of the parent interface trait.
|
||||
StringRef getInterfaceTrait() const;
|
||||
|
||||
// Return the name of the resource class.
|
||||
StringRef getResource() const;
|
||||
|
||||
static bool classof(const Operator::VariableDecorator *var);
|
||||
};
|
||||
|
||||
// This class represents an instance of a side effect interface applied to an
|
||||
// operation. This is a wrapper around an OpInterfaceTrait that also includes
|
||||
// the effects that are applied.
|
||||
class SideEffectTrait : public InterfaceOpTrait {
|
||||
public:
|
||||
// Return the effects that are attached to the side effect interface.
|
||||
Operator::var_decorator_range getEffects() const;
|
||||
|
||||
static bool classof(const OpTrait *t);
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_
|
|
@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
|
|||
OpTrait.cpp
|
||||
Pattern.cpp
|
||||
Predicate.cpp
|
||||
SideEffects.cpp
|
||||
Successor.cpp
|
||||
Type.cpp
|
||||
|
||||
|
|
|
@ -109,6 +109,15 @@ StringRef tblgen::Operator::getResultName(int index) const {
|
|||
return results->getArgNameStr(index);
|
||||
}
|
||||
|
||||
auto tblgen::Operator::getResultDecorators(int index) const
|
||||
-> var_decorator_range {
|
||||
Record *result =
|
||||
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
|
||||
if (!result->isSubClassOf("OpVariable"))
|
||||
return var_decorator_range(nullptr, nullptr);
|
||||
return *result->getValueAsListInit("decorators");
|
||||
}
|
||||
|
||||
unsigned tblgen::Operator::getNumVariadicResults() const {
|
||||
return std::count_if(
|
||||
results.begin(), results.end(),
|
||||
|
@ -138,6 +147,15 @@ StringRef tblgen::Operator::getArgName(int index) const {
|
|||
return argumentValues->getArgName(index)->getValue();
|
||||
}
|
||||
|
||||
auto tblgen::Operator::getArgDecorators(int index) const
|
||||
-> var_decorator_range {
|
||||
Record *arg =
|
||||
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
|
||||
if (!arg->isSubClassOf("OpVariable"))
|
||||
return var_decorator_range(nullptr, nullptr);
|
||||
return *arg->getValueAsListInit("decorators");
|
||||
}
|
||||
|
||||
const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
|
||||
for (const auto &t : traits) {
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
|
||||
|
@ -226,6 +244,7 @@ void tblgen::Operator::populateOpStructure() {
|
|||
auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
|
||||
auto attrClass = recordKeeper.getClass("Attr");
|
||||
auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
|
||||
auto opVarClass = recordKeeper.getClass("OpVariable");
|
||||
numNativeAttributes = 0;
|
||||
|
||||
DagInit *argumentValues = def.getValueAsDag("arguments");
|
||||
|
@ -240,10 +259,12 @@ void tblgen::Operator::populateOpStructure() {
|
|||
PrintFatalError(def.getLoc(),
|
||||
Twine("undefined type for argument #") + Twine(i));
|
||||
Record *argDef = argDefInit->getDef();
|
||||
if (argDef->isSubClassOf(opVarClass))
|
||||
argDef = argDef->getValueAsDef("constraint");
|
||||
|
||||
if (argDef->isSubClassOf(typeConstraintClass)) {
|
||||
operands.push_back(
|
||||
NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
|
||||
NamedTypeConstraint{givenName, TypeConstraint(argDef)});
|
||||
} else if (argDef->isSubClassOf(attrClass)) {
|
||||
if (givenName.empty())
|
||||
PrintFatalError(argDef->getLoc(), "attributes must be named");
|
||||
|
@ -285,6 +306,8 @@ void tblgen::Operator::populateOpStructure() {
|
|||
int operandIndex = 0, attrIndex = 0;
|
||||
for (unsigned i = 0; i != numArgs; ++i) {
|
||||
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
|
||||
if (argDef->isSubClassOf(opVarClass))
|
||||
argDef = argDef->getValueAsDef("constraint");
|
||||
|
||||
if (argDef->isSubClassOf(typeConstraintClass)) {
|
||||
arguments.emplace_back(&operands[operandIndex++]);
|
||||
|
@ -303,11 +326,14 @@ void tblgen::Operator::populateOpStructure() {
|
|||
// Handle results.
|
||||
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
|
||||
auto name = resultsDag->getArgNameStr(i);
|
||||
auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
|
||||
if (!resultDef) {
|
||||
auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
|
||||
if (!resultInit) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
Twine("undefined type for result #") + Twine(i));
|
||||
}
|
||||
auto *resultDef = resultInit->getDef();
|
||||
if (resultDef->isSubClassOf(opVarClass))
|
||||
resultDef = resultDef->getValueAsDef("constraint");
|
||||
results.push_back({name, TypeConstraint(resultDef)});
|
||||
}
|
||||
|
||||
|
@ -394,3 +420,8 @@ void tblgen::Operator::print(llvm::raw_ostream &os) const {
|
|||
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
|
||||
}
|
||||
}
|
||||
|
||||
auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
|
||||
-> VariableDecorator {
|
||||
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
//===- SideEffects.cpp - SideEffect classes -------------------------------===//
|
||||
//
|
||||
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/SideEffects.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SideEffect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef SideEffect::getName() const {
|
||||
return def->getValueAsString("effect");
|
||||
}
|
||||
|
||||
StringRef SideEffect::getBaseName() const {
|
||||
return def->getValueAsString("baseEffect");
|
||||
}
|
||||
|
||||
StringRef SideEffect::getInterfaceTrait() const {
|
||||
return def->getValueAsString("interfaceTrait");
|
||||
}
|
||||
|
||||
StringRef SideEffect::getResource() const {
|
||||
auto value = def->getValueAsString("resource");
|
||||
return value.empty() ? "::mlir::SideEffects::DefaultResource" : value;
|
||||
}
|
||||
|
||||
bool SideEffect::classof(const Operator::VariableDecorator *var) {
|
||||
return var->getDef().isSubClassOf("SideEffect");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SideEffectsTrait
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Operator::var_decorator_range SideEffectTrait::getEffects() const {
|
||||
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("effects"));
|
||||
return {listInit->begin(), listInit->end()};
|
||||
}
|
||||
|
||||
bool SideEffectTrait::classof(const OpTrait *t) {
|
||||
return t->getDef().isSubClassOf("SideEffectsTraitBase");
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/SideEffects.td"
|
||||
|
||||
def TEST_Dialect : Dialect {
|
||||
let name = "test";
|
||||
}
|
||||
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TEST_Dialect, mnemonic, traits>;
|
||||
|
||||
def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
|
||||
let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
|
||||
let results = (outs Res<AnyMemRef, "", [MemAlloc<"CustomResource">]>);
|
||||
}
|
||||
|
||||
def SideEffectOpB : TEST_Op<"side_effect_op_b",
|
||||
[MemoryEffects<[MemWrite<"CustomResource">]>]>;
|
||||
|
||||
// CHECK: void SideEffectOpA::getEffects
|
||||
// CHECK: for (Value value : getODSOperands(0))
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
|
||||
// CHECK: for (Value value : getODSResults(0))
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
|
||||
|
||||
// CHECK: void SideEffectOpB::getEffects
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get());
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/TableGen/OpInterfaces.h"
|
||||
#include "mlir/TableGen/OpTrait.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/SideEffects.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
|
@ -280,6 +281,9 @@ private:
|
|||
// Generate the OpInterface methods.
|
||||
void genOpInterfaceMethods();
|
||||
|
||||
// Generate the side effect interface methods.
|
||||
void genSideEffectInterfaceMethods();
|
||||
|
||||
private:
|
||||
// The TableGen record for this op.
|
||||
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
|
||||
|
@ -321,6 +325,7 @@ OpEmitter::OpEmitter(const Operator &op)
|
|||
genFolderDecls();
|
||||
genOpInterfaceMethods();
|
||||
generateOpFormat(op, opClass);
|
||||
genSideEffectInterfaceMethods();
|
||||
}
|
||||
|
||||
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
|
||||
|
@ -1161,6 +1166,75 @@ void OpEmitter::genOpInterfaceMethods() {
|
|||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genSideEffectInterfaceMethods() {
|
||||
enum EffectKind { Operand, Result, Static };
|
||||
struct EffectLocation {
|
||||
/// The effect applied.
|
||||
SideEffect effect;
|
||||
|
||||
/// The index if the kind is either operand or result.
|
||||
unsigned index : 30;
|
||||
|
||||
/// The kind of the location.
|
||||
EffectKind kind : 2;
|
||||
};
|
||||
|
||||
StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
|
||||
auto resolveDecorators = [&](Operator::var_decorator_range decorators,
|
||||
unsigned index, EffectKind kind) {
|
||||
for (auto decorator : decorators)
|
||||
if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
|
||||
interfaceEffects[effect->getInterfaceTrait()].push_back(
|
||||
EffectLocation{*effect, index, kind});
|
||||
};
|
||||
|
||||
// Collect effects that were specified via:
|
||||
/// Traits.
|
||||
for (const auto &trait : op.getTraits())
|
||||
if (const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait))
|
||||
resolveDecorators(opTrait->getEffects(), /*index=*/0, EffectKind::Static);
|
||||
/// Operands.
|
||||
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
|
||||
if (op.getArg(i).is<NamedTypeConstraint *>()) {
|
||||
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
|
||||
++operandIt;
|
||||
}
|
||||
}
|
||||
/// Results.
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
|
||||
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
|
||||
|
||||
for (auto &it : interfaceEffects) {
|
||||
StringRef baseEffect = it.second.front().effect.getBaseName();
|
||||
auto effectsParam =
|
||||
llvm::formatv(
|
||||
"SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
|
||||
baseEffect)
|
||||
.str();
|
||||
|
||||
// Generate the 'getEffects' method.
|
||||
auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
|
||||
auto &body = getEffects.body();
|
||||
|
||||
// Add effect instances for each of the locations marked on the operation.
|
||||
for (auto &location : it.second) {
|
||||
if (location.kind != EffectKind::Static) {
|
||||
body << " for (Value value : getODS"
|
||||
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
|
||||
<< "(" << location.index << "))\n ";
|
||||
}
|
||||
|
||||
body << " effects.emplace_back(" << location.effect.getName()
|
||||
<< "::get()";
|
||||
|
||||
// If the effect isn't static, it has a specific value attached to it.
|
||||
if (location.kind != EffectKind::Static)
|
||||
body << ", value";
|
||||
body << ", " << location.effect.getResource() << "::get());\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genParser() {
|
||||
if (!hasStringAttribute(def, "parser") ||
|
||||
hasStringAttribute(def, "assemblyFormat"))
|
||||
|
|
Loading…
Reference in New Issue