forked from OSchip/llvm-project
[mlir] Add support for referencing a SymbolRefAttr in a SideEffectInstance
This allows for operations that exclusively affect symbol operations to better describe their side effects. Differential Revision: https://reviews.llvm.org/D91581
This commit is contained in:
parent
5f0ae23e71
commit
c0958b7b4c
|
@ -110,9 +110,20 @@ class EffectOpInterfaceBase<string name, string baseEffect>
|
|||
llvm::erase_if(effects, [&](auto &it) { return it.getValue() != value; });
|
||||
}
|
||||
|
||||
/// Collect all of the effect instances that operate on the provided symbol
|
||||
/// reference and place them in 'effects'.
|
||||
void getEffectsOnSymbol(::mlir::SymbolRefAttr value,
|
||||
llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
|
||||
}] # baseEffect # [{>> & effects) {
|
||||
getEffects(effects);
|
||||
llvm::erase_if(effects, [&](auto &it) {
|
||||
return it.getSymbolRef() != value;
|
||||
});
|
||||
}
|
||||
|
||||
/// Collect all of the effect instances that operate on the provided
|
||||
/// resource and place them in 'effects'.
|
||||
void getEffectsOnValue(::mlir::SideEffects::Resource *resource,
|
||||
void getEffectsOnResource(::mlir::SideEffects::Resource *resource,
|
||||
llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<
|
||||
}] # baseEffect # [{>> & effects) {
|
||||
getEffects(effects);
|
||||
|
|
|
@ -131,9 +131,9 @@ struct AutomaticAllocationScopeResource
|
|||
|
||||
/// This class represents a specific instance of an effect. It contains the
|
||||
/// effect being applied, a resource that corresponds to where the effect is
|
||||
/// applied, an optional value (either operand, result, or region entry
|
||||
/// argument) that the effect is applied to, and an optional parameters
|
||||
/// attribute further specifying the details of the effect.
|
||||
/// applied, and an optional symbol reference or value(either operand, result,
|
||||
/// or region entry argument) that the effect is applied to, and an optional
|
||||
/// parameters attribute further specifying the details of the effect.
|
||||
template <typename EffectT> class EffectInstance {
|
||||
public:
|
||||
EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
|
||||
|
@ -141,6 +141,9 @@ public:
|
|||
EffectInstance(EffectT *effect, Value value,
|
||||
Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource), value(value) {}
|
||||
EffectInstance(EffectT *effect, SymbolRefAttr symbol,
|
||||
Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource), value(symbol) {}
|
||||
EffectInstance(EffectT *effect, Attribute parameters,
|
||||
Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource), parameters(parameters) {}
|
||||
|
@ -148,13 +151,23 @@ public:
|
|||
Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource), value(value),
|
||||
parameters(parameters) {}
|
||||
EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
|
||||
Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource), value(symbol),
|
||||
parameters(parameters) {}
|
||||
|
||||
/// Return the effect being applied.
|
||||
EffectT *getEffect() const { return effect; }
|
||||
|
||||
/// Return the value the effect is applied on, or nullptr if there isn't a
|
||||
/// known value being affected.
|
||||
Value getValue() const { return value; }
|
||||
Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
|
||||
|
||||
/// Return the symbol reference the effect is applied on, or nullptr if there
|
||||
/// isn't a known smbol being affected.
|
||||
SymbolRefAttr getSymbolRef() const {
|
||||
return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
|
||||
}
|
||||
|
||||
/// Return the resource that the effect applies to.
|
||||
Resource *getResource() const { return resource; }
|
||||
|
@ -169,8 +182,8 @@ private:
|
|||
/// The resource that the given value resides in.
|
||||
Resource *resource;
|
||||
|
||||
/// The value that the effect applies to. This is optionally null.
|
||||
Value value;
|
||||
/// The Symbol or Value that the effect applies to. This is optionally null.
|
||||
PointerUnion<SymbolRefAttr, Value> value;
|
||||
|
||||
/// Additional parameters of the effect instance. An attribute is used for
|
||||
/// type-safe structured storage and context-based uniquing. Concrete effects
|
||||
|
|
|
@ -94,6 +94,10 @@ public:
|
|||
// of `TypeAttrBase`).
|
||||
bool isTypeAttr() const;
|
||||
|
||||
// Returns true if this attribute is a symbol reference attribute (i.e., a
|
||||
// subclass of `SymbolRefAttr` or `FlatSymbolRefAttr`).
|
||||
bool isSymbolRefAttr() const;
|
||||
|
||||
// Returns true if this attribute is an enum attribute (i.e., a subclass of
|
||||
// `EnumAttrInfo`)
|
||||
bool isEnumAttr() const;
|
||||
|
|
|
@ -55,6 +55,13 @@ bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
|
|||
|
||||
bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
|
||||
|
||||
bool Attribute::isSymbolRefAttr() const {
|
||||
StringRef defName = def->getName();
|
||||
if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
|
||||
return true;
|
||||
return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
|
||||
}
|
||||
|
||||
bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
|
||||
|
||||
StringRef Attribute::getStorageType() const {
|
||||
|
|
|
@ -19,6 +19,11 @@
|
|||
{effect="allocate", on_result, test_resource}
|
||||
]} : () -> i32
|
||||
|
||||
// expected-remark@+1 {{found an instance of 'read' on a symbol '@foo_ref', on resource '<Test>'}}
|
||||
"test.side_effect_op"() {effects = [
|
||||
{effect="read", on_reference = @foo_ref, test_resource}
|
||||
]} : () -> i32
|
||||
|
||||
// No _memory_ effects, but a parametric test effect.
|
||||
// expected-remark@+2 {{operation has no memory effects}}
|
||||
// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}}
|
||||
|
|
|
@ -744,17 +744,18 @@ void SideEffectOp::getEffects(
|
|||
.Case("read", MemoryEffects::Read::get())
|
||||
.Case("write", MemoryEffects::Write::get());
|
||||
|
||||
// Check for a result to affect.
|
||||
Value value;
|
||||
if (effectElement.get("on_result"))
|
||||
value = getResult();
|
||||
|
||||
// Check for a non-default resource to use.
|
||||
SideEffects::Resource *resource = SideEffects::DefaultResource::get();
|
||||
if (effectElement.get("test_resource"))
|
||||
resource = TestResource::get();
|
||||
|
||||
effects.emplace_back(effect, value, resource);
|
||||
// Check for a result to affect.
|
||||
if (effectElement.get("on_result"))
|
||||
effects.emplace_back(effect, getResult(), resource);
|
||||
else if (Attribute ref = effectElement.get("on_reference"))
|
||||
effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
|
||||
else
|
||||
effects.emplace_back(effect, resource);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,8 @@ struct SideEffectsPass
|
|||
|
||||
if (instance.getValue())
|
||||
diag << " on a value,";
|
||||
else if (SymbolRefAttr symbolRef = instance.getSymbolRef())
|
||||
diag << " on a symbol '" << symbolRef << "',";
|
||||
|
||||
diag << " on resource '" << instance.getResource()->getName() << "'";
|
||||
}
|
||||
|
|
|
@ -11,7 +11,12 @@ class TEST_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
def CustomResource : Resource<"CustomResource">;
|
||||
|
||||
def SideEffectOpA : TEST_Op<"side_effect_op_a"> {
|
||||
let arguments = (ins Arg<Variadic<AnyMemRef>, "", [MemRead]>);
|
||||
let arguments = (ins
|
||||
Arg<Variadic<AnyMemRef>, "", [MemRead]>,
|
||||
Arg<SymbolRefAttr, "", [MemRead]>:$symbol,
|
||||
Arg<FlatSymbolRefAttr, "", [MemWrite]>:$flat_symbol,
|
||||
Arg<OptionalAttr<SymbolRefAttr>, "", [MemRead]>:$optional_symbol
|
||||
);
|
||||
let results = (outs Res<AnyMemRef, "", [MemAlloc<CustomResource>]>);
|
||||
}
|
||||
|
||||
|
@ -21,6 +26,10 @@ def SideEffectOpB : TEST_Op<"side_effect_op_b",
|
|||
// CHECK: void SideEffectOpA::getEffects
|
||||
// CHECK: for (::mlir::Value value : getODSOperands(0))
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), value, ::mlir::SideEffects::DefaultResource::get());
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbol(), ::mlir::SideEffects::DefaultResource::get());
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Write::get(), flat_symbol(), ::mlir::SideEffects::DefaultResource::get());
|
||||
// CHECK: if (auto symbolRef = optional_symbolAttr())
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Read::get(), symbolRef, ::mlir::SideEffects::DefaultResource::get());
|
||||
// CHECK: for (::mlir::Value value : getODSResults(0))
|
||||
// CHECK: effects.emplace_back(MemoryEffects::Allocate::get(), value, CustomResource::get());
|
||||
|
||||
|
|
|
@ -1627,12 +1627,12 @@ void OpEmitter::genOpInterfaceMethods() {
|
|||
}
|
||||
|
||||
void OpEmitter::genSideEffectInterfaceMethods() {
|
||||
enum EffectKind { Operand, Result, Static };
|
||||
enum EffectKind { Operand, Result, Symbol, Static };
|
||||
struct EffectLocation {
|
||||
/// The effect applied.
|
||||
SideEffect effect;
|
||||
|
||||
/// The index if the kind is either operand or result.
|
||||
/// The index if the kind is not static.
|
||||
unsigned index : 30;
|
||||
|
||||
/// The kind of the location.
|
||||
|
@ -1661,17 +1661,29 @@ void OpEmitter::genSideEffectInterfaceMethods() {
|
|||
effects.push_back(EffectLocation{cast<SideEffect>(decorator),
|
||||
/*index=*/0, EffectKind::Static});
|
||||
}
|
||||
/// Operands.
|
||||
/// Attributes and Operands.
|
||||
for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
|
||||
if (op.getArg(i).is<NamedTypeConstraint *>()) {
|
||||
Argument arg = op.getArg(i);
|
||||
if (arg.is<NamedTypeConstraint *>()) {
|
||||
resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
|
||||
++operandIt;
|
||||
continue;
|
||||
}
|
||||
const NamedAttribute *attr = arg.get<NamedAttribute *>();
|
||||
if (attr->attr.getBaseAttr().isSymbolRefAttr())
|
||||
resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
|
||||
}
|
||||
/// Results.
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
|
||||
resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
|
||||
|
||||
// The code used to add an effect instance.
|
||||
// {0}: The effect class.
|
||||
// {1}: Optional value or symbol reference.
|
||||
// {1}: The resource class.
|
||||
const char *addEffectCode =
|
||||
" effects.emplace_back({0}::get(), {1}{2}::get());\n";
|
||||
|
||||
for (auto &it : interfaceEffects) {
|
||||
// Generate the 'getEffects' method.
|
||||
std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
|
||||
|
@ -1684,19 +1696,30 @@ void OpEmitter::genSideEffectInterfaceMethods() {
|
|||
|
||||
// Add effect instances for each of the locations marked on the operation.
|
||||
for (auto &location : it.second) {
|
||||
if (location.kind != EffectKind::Static) {
|
||||
StringRef effect = location.effect.getName();
|
||||
StringRef resource = location.effect.getResource();
|
||||
if (location.kind == EffectKind::Static) {
|
||||
// A static instance has no attached value.
|
||||
body << llvm::formatv(addEffectCode, effect, "", resource).str();
|
||||
} else if (location.kind == EffectKind::Symbol) {
|
||||
// A symbol reference requires adding the proper attribute.
|
||||
const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
|
||||
if (attr->attr.isOptional()) {
|
||||
body << " if (auto symbolRef = " << attr->name << "Attr())\n "
|
||||
<< llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
|
||||
.str();
|
||||
} else {
|
||||
body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
|
||||
resource)
|
||||
.str();
|
||||
}
|
||||
} else {
|
||||
// Otherwise this is an operand/result, so we need to attach the Value.
|
||||
body << " for (::mlir::Value value : getODS"
|
||||
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
|
||||
<< "(" << location.index << "))\n ";
|
||||
<< "(" << location.index << "))\n "
|
||||
<< llvm::formatv(addEffectCode, effect, "value, ", resource).str();
|
||||
}
|
||||
|
||||
body << " effects.emplace_back(" << location.effect.getName()
|
||||
<< "::get()";
|
||||
|
||||
// If the effect isn't static, it has a specific value attached to it.
|
||||
if (location.kind != EffectKind::Static)
|
||||
body << ", value";
|
||||
body << ", " << location.effect.getResource() << "::get());\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue