[mlir][SideEffects] Enable specifying side effects directly on the arguments/results of an operation.

Summary:
New classes are added to ODS to enable specifying additional information on the arguments and results of an operation. These classes, `Arg` and `Res` allow for adding a description and a set of 'decorators' along with the constraint. This enables specifying the side effects of an operation directly on the arguments and results themselves.

Example:
```
def LoadOp : Std_Op<"load"> {
  let arguments = (ins Arg<AnyMemRef, "the MemRef to load from",
                           [MemRead]>:$memref,
                       Variadic<Index>:$indices);
}
```

Differential Revision: https://reviews.llvm.org/D74440
This commit is contained in:
River Riddle 2020-03-06 13:55:36 -08:00
parent f8923584da
commit 20dca52288
10 changed files with 309 additions and 9 deletions

View File

@ -16,6 +16,7 @@
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SideEffects.td"
def Std_Dialect : Dialect {
let name = "std";
@ -1052,7 +1053,9 @@ def LoadOp : Std_Op<"load",
%3 = load %0[%1, %1] : memref<4x4xi32>
}];
let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices);
let results = (outs AnyType:$result);
let builders = [OpBuilder<
@ -1563,8 +1566,10 @@ def StoreOp : Std_Op<"store",
store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
}];
let arguments = (ins AnyType:$value, AnyMemRef:$memref,
Variadic<Index>:$indices);
let arguments = (ins AnyType:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices);
let builders = [OpBuilder<
"Builder *, OperationState &result, Value valueToStore, Value memref", [{
@ -1846,7 +1851,8 @@ def TensorLoadOp : Std_Op<"tensor_load",
%12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0>
}];
let arguments = (ins AnyMemRef:$memref);
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref);
let results = (outs AnyTensor:$result);
// TensorLoadOp is fully verified by traits.
let verifier = ?;
@ -1890,7 +1896,9 @@ def TensorStoreOp : Std_Op<"tensor_store",
tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
}];
let arguments = (ins AnyTensor:$tensor, AnyMemRef:$memref);
let arguments = (ins AnyTensor:$tensor,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;

View File

@ -1712,6 +1712,29 @@ class OpBuilder<string p, code b = ""> {
code body = b;
}
// A base decorator class that may optionally be added to OpVariables.
class OpVariableDecorator;
// Class for providing additional information on the variables, i.e. arguments
// and results, of an operation.
class OpVariable<Constraint varConstraint, string desc = "",
list<OpVariableDecorator> varDecorators = []> {
// The constraint, either attribute or type, of the argument.
Constraint constraint = varConstraint;
// A description for the argument.
string description = desc;
// The list of decorators for this variable, e.g. side effects.
list<OpVariableDecorator> decorators = varDecorators;
}
class Arg<Constraint constraint, string desc = "",
list<OpVariableDecorator> decorators = []>
: OpVariable<constraint, desc, decorators>;
class Res<Constraint constraint, string desc = "",
list<OpVariableDecorator> decorators = []>
: OpVariable<constraint, desc, decorators>;
// Base class for all ops.
class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// The dialect of the op.

View File

@ -107,7 +107,7 @@ class EffectOpInterfaceBase<string name, string baseEffect>
// This class is the general base side effect class. This is used by derived
// effect interfaces to define their effects.
class SideEffect<EffectOpInterfaceBase interface, string effectName,
string resourceName> {
string resourceName> : OpVariableDecorator {
/// The parent interface that the effect belongs to.
string interfaceTrait = interface.trait;

View File

@ -57,6 +57,34 @@ public:
// Returns this op's C++ class name prefixed with namespaces.
std::string getQualCppClassName() const;
/// A class used to represent the decorators of an operator variable, i.e.
/// argument or result.
struct VariableDecorator {
public:
explicit VariableDecorator(const llvm::Record *def) : def(def) {}
const llvm::Record &getDef() const { return *def; }
protected:
// The TableGen definition of this decorator.
const llvm::Record *def;
};
// A utility iterator over a list of variable decorators.
struct VariableDecoratorIterator
: public llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)> {
using reference = VariableDecorator;
/// Initializes the iterator to the specified iterator.
VariableDecoratorIterator(llvm::Init *const *it)
: llvm::mapped_iterator<llvm::Init *const *,
VariableDecorator (*)(llvm::Init *)>(it,
&unwrap) {}
static VariableDecorator unwrap(llvm::Init *init);
};
using var_decorator_iterator = VariableDecoratorIterator;
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
@ -84,6 +112,8 @@ public:
TypeConstraint getResultTypeConstraint(int index) const;
// Returns the `index`-th result's name.
StringRef getResultName(int index) const;
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;
// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
@ -128,6 +158,7 @@ public:
// Op argument (attribute or operand) accessors.
Argument getArg(int index) const;
StringRef getArgName(int index) const;
var_decorator_range getArgDecorators(int index) const;
// Returns the trait wrapper for the given MLIR C++ `trait`.
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of

View File

@ -0,0 +1,55 @@
//===- SideEffects.h - Side Effects classes ---------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around side effect related classes defined in TableGen.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_SIDEEFFECTS_H_
#define MLIR_TABLEGEN_SIDEEFFECTS_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Operator.h"
namespace mlir {
namespace tblgen {
// This class represents a specific instance of an effect that is being
// exhibited.
class SideEffect : public Operator::VariableDecorator {
public:
// Return the name of the C++ effect.
StringRef getName() const;
// Return the name of the base C++ effect.
StringRef getBaseName() const;
// Return the name of the parent interface trait.
StringRef getInterfaceTrait() const;
// Return the name of the resource class.
StringRef getResource() const;
static bool classof(const Operator::VariableDecorator *var);
};
// This class represents an instance of a side effect interface applied to an
// operation. This is a wrapper around an OpInterfaceTrait that also includes
// the effects that are applied.
class SideEffectTrait : public InterfaceOpTrait {
public:
// Return the effects that are attached to the side effect interface.
Operator::var_decorator_range getEffects() const;
static bool classof(const OpTrait *t);
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_SIDEEFFECTS_H_

View File

@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
OpTrait.cpp
Pattern.cpp
Predicate.cpp
SideEffects.cpp
Successor.cpp
Type.cpp

View File

@ -109,6 +109,15 @@ StringRef tblgen::Operator::getResultName(int index) const {
return results->getArgNameStr(index);
}
auto tblgen::Operator::getResultDecorators(int index) const
-> var_decorator_range {
Record *result =
cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
if (!result->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *result->getValueAsListInit("decorators");
}
unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
@ -138,6 +147,15 @@ StringRef tblgen::Operator::getArgName(int index) const {
return argumentValues->getArgName(index)->getValue();
}
auto tblgen::Operator::getArgDecorators(int index) const
-> var_decorator_range {
Record *arg =
cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
if (!arg->isSubClassOf("OpVariable"))
return var_decorator_range(nullptr, nullptr);
return *arg->getValueAsListInit("decorators");
}
const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
for (const auto &t : traits) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
@ -226,6 +244,7 @@ void tblgen::Operator::populateOpStructure() {
auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto attrClass = recordKeeper.getClass("Attr");
auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
auto opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
DagInit *argumentValues = def.getValueAsDag("arguments");
@ -240,10 +259,12 @@ void tblgen::Operator::populateOpStructure() {
PrintFatalError(def.getLoc(),
Twine("undefined type for argument #") + Twine(i));
Record *argDef = argDefInit->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
operands.push_back(
NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
NamedTypeConstraint{givenName, TypeConstraint(argDef)});
} else if (argDef->isSubClassOf(attrClass)) {
if (givenName.empty())
PrintFatalError(argDef->getLoc(), "attributes must be named");
@ -285,6 +306,8 @@ void tblgen::Operator::populateOpStructure() {
int operandIndex = 0, attrIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
if (argDef->isSubClassOf(opVarClass))
argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
arguments.emplace_back(&operands[operandIndex++]);
@ -303,11 +326,14 @@ void tblgen::Operator::populateOpStructure() {
// Handle results.
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
auto name = resultsDag->getArgNameStr(i);
auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
if (!resultDef) {
auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
if (!resultInit) {
PrintFatalError(def.getLoc(),
Twine("undefined type for result #") + Twine(i));
}
auto *resultDef = resultInit->getDef();
if (resultDef->isSubClassOf(opVarClass))
resultDef = resultDef->getValueAsDef("constraint");
results.push_back({name, TypeConstraint(resultDef)});
}
@ -394,3 +420,8 @@ void tblgen::Operator::print(llvm::raw_ostream &os) const {
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
}
}
auto tblgen::Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
-> VariableDecorator {
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
}

View File

@ -0,0 +1,51 @@
//===- SideEffects.cpp - SideEffect classes -------------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/SideEffects.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// SideEffect
//===----------------------------------------------------------------------===//
StringRef SideEffect::getName() const {
return def->getValueAsString("effect");
}
StringRef SideEffect::getBaseName() const {
return def->getValueAsString("baseEffect");
}
StringRef SideEffect::getInterfaceTrait() const {
return def->getValueAsString("interfaceTrait");
}
StringRef SideEffect::getResource() const {
auto value = def->getValueAsString("resource");
return value.empty() ? "::mlir::SideEffects::DefaultResource" : value;
}
bool SideEffect::classof(const Operator::VariableDecorator *var) {
return var->getDef().isSubClassOf("SideEffect");
}
//===----------------------------------------------------------------------===//
// SideEffectsTrait
//===----------------------------------------------------------------------===//
Operator::var_decorator_range SideEffectTrait::getEffects() const {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("effects"));
return {listInit->begin(), listInit->end()};
}
bool SideEffectTrait::classof(const OpTrait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}

View File

@ -0,0 +1,26 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
include "mlir/IR/SideEffects.td"
def TEST_Dialect : Dialect {
let name = "test";
}
class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TEST_Dialect, mnemonic, traits>;
def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
let results = (outs Res<AnyMemRef, "", [MemAlloc<"CustomResource">]>);
}
def SideEffectOpB : TEST_Op<"side_effect_op_b",
[MemoryEffects<[MemWrite<"CustomResource">]>]>;
// CHECK: void SideEffectOpA::getEffects
// CHECK: for (Value value : getODSOperands(0))
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
// CHECK: for (Value value : getODSResults(0))
// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
// CHECK: void SideEffectOpB::getEffects
// CHECK: effects.emplace_back(MemoryEffects::Write::get(), CustomResource::get());

View File

@ -20,6 +20,7 @@
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Signals.h"
@ -280,6 +281,9 @@ private:
// Generate the OpInterface methods.
void genOpInterfaceMethods();
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
private:
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@ -321,6 +325,7 @@ OpEmitter::OpEmitter(const Operator &op)
genFolderDecls();
genOpInterfaceMethods();
generateOpFormat(op, opClass);
genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@ -1161,6 +1166,75 @@ void OpEmitter::genOpInterfaceMethods() {
}
}
void OpEmitter::genSideEffectInterfaceMethods() {
enum EffectKind { Operand, Result, Static };
struct EffectLocation {
/// The effect applied.
SideEffect effect;
/// The index if the kind is either operand or result.
unsigned index : 30;
/// The kind of the location.
EffectKind kind : 2;
};
StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
auto resolveDecorators = [&](Operator::var_decorator_range decorators,
unsigned index, EffectKind kind) {
for (auto decorator : decorators)
if (SideEffect *effect = dyn_cast<SideEffect>(&decorator))
interfaceEffects[effect->getInterfaceTrait()].push_back(
EffectLocation{*effect, index, kind});
};
// Collect effects that were specified via:
/// Traits.
for (const auto &trait : op.getTraits())
if (const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait))
resolveDecorators(opTrait->getEffects(), /*index=*/0, EffectKind::Static);
/// Operands.
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
if (op.getArg(i).is<NamedTypeConstraint *>()) {
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
++operandIt;
}
}
/// Results.
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
for (auto &it : interfaceEffects) {
StringRef baseEffect = it.second.front().effect.getBaseName();
auto effectsParam =
llvm::formatv(
"SmallVectorImpl<SideEffects::EffectInstance<{0}>> &effects",
baseEffect)
.str();
// Generate the 'getEffects' method.
auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam);
auto &body = getEffects.body();
// Add effect instances for each of the locations marked on the operation.
for (auto &location : it.second) {
if (location.kind != EffectKind::Static) {
body << " for (Value value : getODS"
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
<< "(" << location.index << "))\n ";
}
body << " effects.emplace_back(" << location.effect.getName()
<< "::get()";
// If the effect isn't static, it has a specific value attached to it.
if (location.kind != EffectKind::Static)
body << ", value";
body << ", " << location.effect.getResource() << "::get());\n";
}
}
}
void OpEmitter::genParser() {
if (!hasStringAttribute(def, "parser") ||
hasStringAttribute(def, "assemblyFormat"))