forked from OSchip/llvm-project
Rename FunctionAttr to SymbolRefAttr.
This allows for the attribute to hold symbolic references to other operations than FuncOp. This also allows for removing the dependence on FuncOp from the base Builder. PiperOrigin-RevId: 257650017
This commit is contained in:
parent
4dfe6d457b
commit
9dbef0bf96
|
@ -655,7 +655,7 @@ PYBIND11_MODULE(pybind, m) {
|
|||
});
|
||||
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
|
||||
auto function = FuncOp::getFromOpaquePointer(func.function);
|
||||
auto attr = FunctionAttr::get(function.getName(), function.getContext());
|
||||
auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
|
||||
return ValueHandle::create<ConstantOp>(function.getType(), attr);
|
||||
});
|
||||
m.def("appendTo", [](const PythonBlockHandle &handle) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
|
|
@ -160,10 +160,10 @@ public:
|
|||
// clang-format off
|
||||
LoopBuilder(&i, zero, M, 1)([&]{
|
||||
llvmCall(retTy,
|
||||
rewriter.getFunctionAttr(printfFunc),
|
||||
rewriter.getSymbolRefAttr(printfFunc),
|
||||
{fmtCst, iOp(i)});
|
||||
});
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
|
||||
// clang-format on
|
||||
} else {
|
||||
IndexHandle N(vOp.ub(1));
|
||||
|
@ -171,10 +171,10 @@ public:
|
|||
LoopBuilder(&i, zero, M, 1)([&]{
|
||||
LoopBuilder(&j, zero, N, 1)([&]{
|
||||
llvmCall(retTy,
|
||||
rewriter.getFunctionAttr(printfFunc),
|
||||
rewriter.getSymbolRefAttr(printfFunc),
|
||||
{fmtCst, iOp(i, j)});
|
||||
});
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
|
||||
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
|
|
@ -177,7 +177,7 @@ bare-id ::= (letter|[_]) (letter|digit|[_$.])*
|
|||
bare-id-list ::= bare-id (`,` bare-id)*
|
||||
suffix-id ::= digit+ | ((letter|id-punct) (letter|id-punct|digit)*)
|
||||
|
||||
function-id ::= `@` bare-id
|
||||
symbol-ref-id ::= `@` bare-id
|
||||
ssa-id ::= `%` suffix-id
|
||||
ssa-id-list ::= ssa-id (`,` ssa-id)*
|
||||
|
||||
|
@ -690,8 +690,8 @@ attribute-value ::= affine-map-attribute
|
|||
| integer-attribute
|
||||
| integer-set-attribute
|
||||
| float-attribute
|
||||
| function-attribute
|
||||
| string-attribute
|
||||
| symbol-ref-attribute
|
||||
| type-attribute
|
||||
| unit-attribute
|
||||
```
|
||||
|
@ -847,17 +847,6 @@ float-attribute ::= float-literal (`:` float-type)?
|
|||
A float attribute is a literal attribute that represents a floating point value
|
||||
of the specified [float type](#floating-point-types).
|
||||
|
||||
#### Function Attribute
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
function-attribute ::= function-id
|
||||
```
|
||||
|
||||
A function attribute is a literal attribute that represents a named reference to
|
||||
the given function.
|
||||
|
||||
#### String Attribute
|
||||
|
||||
Syntax:
|
||||
|
@ -868,6 +857,17 @@ string-attribute ::= string-literal (`:` type)?
|
|||
|
||||
A string attribute is an attribute that represents a string literal value.
|
||||
|
||||
#### Symbol Reference Attribute
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
symbol-ref-attribute ::= symbol-ref-id
|
||||
```
|
||||
|
||||
A symbol reference attribute is a literal attribute that represents a named
|
||||
reference to a given operation.
|
||||
|
||||
#### Type Attribute
|
||||
|
||||
Syntax:
|
||||
|
@ -924,7 +924,8 @@ referenced by name via a string attribute):
|
|||
``` {.ebnf}
|
||||
function ::= `func` function-signature function-attributes? function-body?
|
||||
|
||||
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
|
||||
function-signature ::= symbol-ref-id `(` argument-list `)`
|
||||
(`->` function-result-type)?
|
||||
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
||||
argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/
|
||||
named-argument ::= ssa-id `:` type attribute-dict?
|
||||
|
@ -1277,7 +1278,7 @@ single function to return.
|
|||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
operation ::= `call` function-id `(` ssa-use-list? `)` `:` function-type
|
||||
operation ::= `call` symbol-ref-id `(` ssa-use-list? `)` `:` function-type
|
||||
```
|
||||
|
||||
The `call` operation represents a direct call to a function. The operands and
|
||||
|
|
|
@ -167,10 +167,10 @@ in this chapter for lowering `toy.print`:
|
|||
LoopBuilder(&i, zero, M, 1)({
|
||||
LoopBuilder(&j, zero, N, 1)({
|
||||
llvmCall(retTy,
|
||||
rewriter.getFunctionAttr(printfFunc),
|
||||
rewriter.getSymbolRefAttr(printfFunc),
|
||||
{fmtCst, iOp(i, j)})
|
||||
}),
|
||||
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
|
||||
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol})
|
||||
});
|
||||
```
|
||||
|
||||
|
|
|
@ -141,11 +141,11 @@ enum Kind {
|
|||
Bool,
|
||||
Dictionary,
|
||||
Float,
|
||||
Function,
|
||||
Integer,
|
||||
IntegerSet,
|
||||
Opaque,
|
||||
String,
|
||||
SymbolRef,
|
||||
Type,
|
||||
Unit,
|
||||
|
||||
|
@ -306,25 +306,6 @@ public:
|
|||
Type type, const APFloat &value);
|
||||
};
|
||||
|
||||
/// A function attribute represents a reference to a function object.
|
||||
class FunctionAttr
|
||||
: public Attribute::AttrBase<FunctionAttr, Attribute,
|
||||
detail::StringAttributeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
using ValueType = StringRef;
|
||||
|
||||
static FunctionAttr get(StringRef value, MLIRContext *ctx);
|
||||
|
||||
/// Returns the name of the held function reference.
|
||||
StringRef getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::Function;
|
||||
}
|
||||
};
|
||||
|
||||
class IntegerAttr
|
||||
: public Attribute::AttrBase<IntegerAttr, Attribute,
|
||||
detail::IntegerAttributeStorage> {
|
||||
|
@ -417,6 +398,26 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// A symbol reference attribute represents a symbolic reference to another
|
||||
/// operation.
|
||||
class SymbolRefAttr
|
||||
: public Attribute::AttrBase<SymbolRefAttr, Attribute,
|
||||
detail::StringAttributeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
using ValueType = StringRef;
|
||||
|
||||
static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
|
||||
|
||||
/// Returns the name of the held symbol reference.
|
||||
StringRef getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::SymbolRef;
|
||||
}
|
||||
};
|
||||
|
||||
class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
|
||||
detail::TypeAttributeStorage> {
|
||||
public:
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#ifndef MLIR_IR_BUILDERS_H
|
||||
#define MLIR_IR_BUILDERS_H
|
||||
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -44,7 +43,7 @@ class FloatAttr;
|
|||
class StringAttr;
|
||||
class TypeAttr;
|
||||
class ArrayAttr;
|
||||
class FunctionAttr;
|
||||
class SymbolRefAttr;
|
||||
class ElementsAttr;
|
||||
class DenseElementsAttr;
|
||||
class DenseIntElementsAttr;
|
||||
|
@ -111,8 +110,8 @@ public:
|
|||
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(FuncOp value);
|
||||
FunctionAttr getFunctionAttr(StringRef value);
|
||||
SymbolRefAttr getSymbolRefAttr(Operation *value);
|
||||
SymbolRefAttr getSymbolRefAttr(StringRef value);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values);
|
||||
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
||||
|
|
|
@ -831,12 +831,12 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
|
|||
let constBuilderCall = ?;
|
||||
}
|
||||
|
||||
// Attributes containing functions.
|
||||
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
|
||||
"function attribute"> {
|
||||
let storageType = [{ FunctionAttr }];
|
||||
// Attributes containing symbol references.
|
||||
def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
|
||||
"symbol reference attribute"> {
|
||||
let storageType = [{ SymbolRefAttr }];
|
||||
let returnType = [{ StringRef }];
|
||||
let constBuilderCall = "$_builder.getFunctionAttr($0)";
|
||||
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
class FuncOp;
|
||||
class PatternRewriter;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -199,7 +199,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">;
|
|||
|
||||
// Call-related operations.
|
||||
def LLVM_CallOp : LLVM_Op<"call">,
|
||||
Arguments<(ins OptionalAttr<FunctionAttr>:$callee,
|
||||
Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
|
||||
// TODO(b/133216756): fix test failure and
|
||||
// change to LLVM_Type
|
||||
Variadic<AnyType>)>,
|
||||
|
|
|
@ -83,7 +83,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
|
|||
|
||||
let arguments = (ins
|
||||
SPV_ExecutionModelAttr:$execution_model,
|
||||
FunctionAttr:$fn,
|
||||
SymbolRefAttr:$fn,
|
||||
Variadic<SPV_AnyPtr>:$interface
|
||||
);
|
||||
|
||||
|
|
|
@ -210,20 +210,20 @@ def CallOp : Std_Op<"call"> {
|
|||
%2 = call @my_add(%0, %1) : (f32, f32) -> f32
|
||||
}];
|
||||
|
||||
let arguments = (ins FunctionAttr:$callee, Variadic<AnyType>:$operands);
|
||||
let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, FuncOp callee,"
|
||||
"ArrayRef<Value *> operands = {}", [{
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||
result->addTypes(callee.getType().getResults());
|
||||
}]>, OpBuilder<
|
||||
"Builder *builder, OperationState *result, StringRef callee,"
|
||||
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||
result->addTypes(results);
|
||||
}]>];
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/AffineOps/AffineOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
|
|
|
@ -324,7 +324,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// TODO(herhut): This should rather be a static global once supported.
|
||||
auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
|
||||
auto cubinGetter =
|
||||
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
|
||||
kernelFunction.getAttrOfType<SymbolRefAttr>(kCubinGetterAnnotation);
|
||||
if (!cubinGetter) {
|
||||
kernelFunction.emitError("Missing ")
|
||||
<< kCubinGetterAnnotation << " attribute.";
|
||||
|
@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
auto cuModule = allocatePointer(builder, loc);
|
||||
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleLoad),
|
||||
builder.getSymbolRefAttr(cuModuleLoad),
|
||||
ArrayRef<Value *>{cuModule, data.getResult(0)});
|
||||
// Get the function from the module. The name corresponds to the name of
|
||||
// the kernel function.
|
||||
|
@ -349,14 +349,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuModuleGetFunction),
|
||||
builder.getSymbolRefAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
FuncOp cuGetStreamHelper =
|
||||
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
|
||||
auto cuStream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
||||
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
|
||||
// Invoke the function with required arguments.
|
||||
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
|
||||
auto cuFunctionRef =
|
||||
|
@ -366,7 +366,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuLaunchKernel),
|
||||
builder.getSymbolRefAttr(cuLaunchKernel),
|
||||
ArrayRef<Value *>{cuFunctionRef, launchOp.getOperand(0),
|
||||
launchOp.getOperand(1), launchOp.getOperand(2),
|
||||
launchOp.getOperand(3), launchOp.getOperand(4),
|
||||
|
@ -377,7 +377,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// Sync on the stream to make it synchronous.
|
||||
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getFunctionAttr(cuStreamSync),
|
||||
builder.getSymbolRefAttr(cuStreamSync),
|
||||
ArrayRef<Value *>(cuStream.getResult(0)));
|
||||
launchOp.erase();
|
||||
}
|
||||
|
|
|
@ -94,7 +94,7 @@ private:
|
|||
auto memory =
|
||||
ob.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{LLVM::LLVMType::getInt8PtrTy(llvmDialect)},
|
||||
builder.getFunctionAttr(getMallocHelper(loc, builder)),
|
||||
builder.getSymbolRefAttr(getMallocHelper(loc, builder)),
|
||||
ArrayRef<Value *>{sizeConstant})
|
||||
.getResult(0);
|
||||
for (auto byte : llvm::enumerate(blob.getValue().bytes())) {
|
||||
|
@ -111,7 +111,7 @@ private:
|
|||
}
|
||||
ob.create<LLVM::ReturnOp>(loc, ArrayRef<Value *>{memory});
|
||||
// Store the name of the getter on the function for easier lookup.
|
||||
orig.setAttr(kCubinGetterAnnotation, builder.getFunctionAttr(result));
|
||||
orig.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -456,7 +456,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
Value *allocated =
|
||||
rewriter
|
||||
.create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
|
||||
rewriter.getFunctionAttr(mallocFunc),
|
||||
rewriter.getSymbolRefAttr(mallocFunc),
|
||||
cumulativeSize)
|
||||
.getResult(0);
|
||||
auto structElementType = lowering.convertType(elementType);
|
||||
|
@ -520,7 +520,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
Value *casted = rewriter.create<LLVM::BitcastOp>(
|
||||
op->getLoc(), getVoidPtrType(), bufferPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
op, ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
|
||||
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/GPU/GPUDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -386,7 +387,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
|||
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
|
||||
result->addOperands(kernelOperands);
|
||||
result->addAttribute(getKernelAttrName(),
|
||||
builder->getFunctionAttr(kernelFunc));
|
||||
builder->getSymbolRefAttr(kernelFunc));
|
||||
}
|
||||
|
||||
void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
||||
|
@ -398,7 +399,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
|||
}
|
||||
|
||||
StringRef LaunchFuncOp::kernel() {
|
||||
return getAttrOfType<FunctionAttr>(getKernelAttrName()).getValue();
|
||||
return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
|
||||
}
|
||||
|
||||
unsigned LaunchFuncOp::getNumKernelOperands() {
|
||||
|
@ -421,7 +422,7 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
auto kernelAttr = this->getAttr(getKernelAttrName());
|
||||
if (!kernelAttr) {
|
||||
return emitOpError("attribute 'kernel' must be specified");
|
||||
} else if (!kernelAttr.isa<FunctionAttr>()) {
|
||||
} else if (!kernelAttr.isa<SymbolRefAttr>()) {
|
||||
return emitOpError("attribute 'kernel' must be a function");
|
||||
}
|
||||
|
||||
|
|
|
@ -670,10 +670,9 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
|
|||
case StandardAttributes::Type:
|
||||
printType(attr.cast<TypeAttr>().getValue());
|
||||
break;
|
||||
case StandardAttributes::Function: {
|
||||
os << '@' << attr.cast<FunctionAttr>().getValue();
|
||||
case StandardAttributes::SymbolRef:
|
||||
os << '@' << attr.cast<SymbolRefAttr>().getValue();
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::OpaqueElements: {
|
||||
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
|
||||
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
|
||||
|
|
|
@ -246,15 +246,15 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FunctionAttr
|
||||
// SymbolRefAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::Function, value,
|
||||
SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::SymbolRef, value,
|
||||
NoneType::get(ctx));
|
||||
}
|
||||
|
||||
StringRef FunctionAttr::getValue() const { return getImpl()->value; }
|
||||
StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IntegerAttr
|
||||
|
|
|
@ -175,11 +175,14 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
|
|||
|
||||
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
|
||||
|
||||
FunctionAttr Builder::getFunctionAttr(FuncOp value) {
|
||||
return getFunctionAttr(value.getName());
|
||||
SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
|
||||
auto symName =
|
||||
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
||||
assert(symName && "value does not have a valid symbol name");
|
||||
return getSymbolRefAttr(symName.getValue());
|
||||
}
|
||||
FunctionAttr Builder::getFunctionAttr(StringRef value) {
|
||||
return FunctionAttr::get(value, getContext());
|
||||
SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
|
||||
return SymbolRefAttr::get(value, getContext());
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
|
|
|
@ -154,8 +154,8 @@ ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
|
||||
auto &builder = parser->getBuilder();
|
||||
|
||||
// Parse the name as a function attribute.
|
||||
FunctionAttr nameAttr;
|
||||
// Parse the name as a symbol reference attribute.
|
||||
SymbolRefAttr nameAttr;
|
||||
if (parser->parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
result->attributes))
|
||||
return failure();
|
||||
|
|
|
@ -82,7 +82,7 @@ namespace {
|
|||
struct BuiltinDialect : public Dialect {
|
||||
BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
|
||||
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
|
||||
DictionaryAttr, FloatAttr, FunctionAttr, IntegerAttr,
|
||||
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
|
||||
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
|
||||
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
|
||||
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, UnknownLoc>();
|
||||
|
|
|
@ -335,7 +335,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
|||
SmallVector<NamedAttribute, 4> attrs;
|
||||
SmallVector<OpAsmParser::OperandType, 8> operands;
|
||||
Type type;
|
||||
FunctionAttr funcAttr;
|
||||
SymbolRefAttr funcAttr;
|
||||
llvm::SMLoc trailingTypeLoc;
|
||||
|
||||
// Parse an operand list that will, in practice, contain 0 or 1 operand. In
|
||||
|
|
|
@ -204,7 +204,7 @@ public:
|
|||
Value *allocSize =
|
||||
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
|
||||
Value *allocated =
|
||||
call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize)
|
||||
call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
|
||||
.getOperation()
|
||||
->getResult(0);
|
||||
allocated = bitcast(elementPtrType, allocated);
|
||||
|
@ -248,7 +248,7 @@ public:
|
|||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
|
||||
positionAttr(rewriter, 0)));
|
||||
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
|
||||
call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||
rewriter.replaceOp(op, llvm::None);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
@ -650,7 +650,7 @@ static void getLLVMLibraryCallDefinition(FuncOp fn,
|
|||
implFnArgs.push_back(alloca);
|
||||
llvm_store(arg, alloca);
|
||||
}
|
||||
call(ArrayRef<Type>(), builder.getFunctionAttr(implFn), implFnArgs);
|
||||
call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
|
||||
llvm_return{ArrayRef<Value *>()};
|
||||
}
|
||||
|
||||
|
@ -694,7 +694,7 @@ public:
|
|||
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
|
||||
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
|
||||
|
||||
auto fAttr = rewriter.getFunctionAttr(f);
|
||||
auto fAttr = rewriter.getSymbolRefAttr(f);
|
||||
auto named = rewriter.getNamedAttr("callee", fAttr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
|
||||
ArrayRef<NamedAttribute>{named});
|
||||
|
|
|
@ -215,7 +215,7 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
|
|||
|
||||
/// Lex an '@foo' identifier.
|
||||
///
|
||||
/// function-id ::= `@` bare-id
|
||||
/// symbol-ref-id ::= `@` bare-id
|
||||
///
|
||||
Token Lexer::lexAtIdentifier(const char *tokStart) {
|
||||
// These always start with a letter or underscore.
|
||||
|
|
|
@ -931,7 +931,7 @@ ParseResult Parser::parseXInDimensionList() {
|
|||
/// | type
|
||||
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
|
||||
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
|
||||
/// | function-id `:` function-type
|
||||
/// | symbol-ref-id
|
||||
/// | `dense` `<` attribute-value `>` `:`
|
||||
/// (tensor-type | vector-type)
|
||||
/// | `sparse` `<` attribute-value `,` attribute-value `>`
|
||||
|
@ -1010,13 +1010,6 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
nullptr);
|
||||
}
|
||||
|
||||
// Parse a function attribute.
|
||||
case Token::at_identifier: {
|
||||
auto nameStr = getTokenSpelling();
|
||||
consumeToken(Token::at_identifier);
|
||||
return builder.getFunctionAttr(nameStr.drop_front());
|
||||
}
|
||||
|
||||
// Parse a location attribute.
|
||||
case Token::kw_loc: {
|
||||
LocationAttr attr;
|
||||
|
@ -1043,6 +1036,13 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
: StringAttr::get(val, getContext());
|
||||
}
|
||||
|
||||
// Parse a symbol reference attribute.
|
||||
case Token::at_identifier: {
|
||||
auto nameStr = getTokenSpelling();
|
||||
consumeToken(Token::at_identifier);
|
||||
return builder.getSymbolRefAttr(nameStr.drop_front());
|
||||
}
|
||||
|
||||
// Parse a 'unit' attribute.
|
||||
case Token::kw_unit:
|
||||
consumeToken(Token::kw_unit);
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/SPIRV/SPIRVTypes.h"
|
||||
|
@ -280,8 +281,8 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
|
|||
parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
|
||||
return failure();
|
||||
}
|
||||
if (!fn.isa<FunctionAttr>()) {
|
||||
return parser->emitError(loc, "expected function attribute");
|
||||
if (!fn.isa<SymbolRefAttr>()) {
|
||||
return parser->emitError(loc, "expected symbol reference attribute");
|
||||
}
|
||||
state->addTypes(
|
||||
spirv::EntryPointType::get(parser->getBuilder().getContext()));
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/SPIRV/Serialization.h"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
@ -401,7 +402,7 @@ void BranchOp::eraseOperand(unsigned index) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
||||
FunctionAttr calleeAttr;
|
||||
SymbolRefAttr calleeAttr;
|
||||
FunctionType calleeType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
auto calleeLoc = parser->getNameLoc();
|
||||
|
@ -428,9 +429,9 @@ static void print(OpAsmPrinter *p, CallOp op) {
|
|||
|
||||
static LogicalResult verify(CallOp op) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
return op.emitOpError("requires a 'callee' symbol reference attribute");
|
||||
auto fn =
|
||||
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
|
||||
if (!fn)
|
||||
|
@ -474,7 +475,7 @@ struct SimplifyIndirectCallWithKnownCallee
|
|||
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the callee is a constant callee.
|
||||
FunctionAttr calledFn;
|
||||
SymbolRefAttr calledFn;
|
||||
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
|
||||
return matchFailure();
|
||||
|
||||
|
@ -1029,8 +1030,8 @@ static void print(OpAsmPrinter *p, ConstantOp &op) {
|
|||
*p << ' ';
|
||||
p->printAttribute(op.getValue());
|
||||
|
||||
// If the value is a function, print a trailing type.
|
||||
if (op.getValue().isa<FunctionAttr>()) {
|
||||
// If the value is a symbol reference, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr>()) {
|
||||
*p << " : ";
|
||||
p->printType(op.getType());
|
||||
}
|
||||
|
@ -1043,9 +1044,9 @@ static ParseResult parseConstantOp(OpAsmParser *parser,
|
|||
parser->parseAttribute(valueAttr, "value", result->attributes))
|
||||
return failure();
|
||||
|
||||
// If the attribute is a function, then we expect a trailing type.
|
||||
// If the attribute is a symbol reference, then we expect a trailing type.
|
||||
Type type;
|
||||
if (!valueAttr.isa<FunctionAttr>())
|
||||
if (!valueAttr.isa<SymbolRefAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser->parseColonType(type))
|
||||
return failure();
|
||||
|
@ -1093,7 +1094,7 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
}
|
||||
|
||||
if (type.isa<FunctionType>()) {
|
||||
auto fnAttr = value.dyn_cast<FunctionAttr>();
|
||||
auto fnAttr = value.dyn_cast<SymbolRefAttr>();
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires 'value' to be a function reference");
|
||||
|
||||
|
@ -1125,8 +1126,8 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
|||
/// Returns true if a constant operation can be built with the given value and
|
||||
/// result type.
|
||||
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
||||
// FunctionAttr can only be used with a function type.
|
||||
if (value.isa<FunctionAttr>())
|
||||
// SymbolRefAttr can only be used with a function type.
|
||||
if (value.isa<SymbolRefAttr>())
|
||||
return type.isa<FunctionType>();
|
||||
// Otherwise, the attribute must have the same type as 'type'.
|
||||
if (value.getType() != type)
|
||||
|
|
|
@ -83,7 +83,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
|
|||
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
|
||||
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
|
||||
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
|
||||
if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
|
||||
if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
|
||||
return functionMapping.lookup(funcAttr.getValue());
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
|
@ -176,7 +176,7 @@ bool ModuleTranslation::convertOperation(Operation &opInst,
|
|||
auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
|
||||
auto operands = lookupValues(op.getOperands());
|
||||
ArrayRef<llvm::Value *> operandsRef(operands);
|
||||
if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) {
|
||||
if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
|
||||
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
|
||||
operandsRef);
|
||||
} else {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "mlir/Analysis/Dominance.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/StandardOps/Ops.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
|
|
|
@ -30,7 +30,7 @@ spv.module "Logical" "VulkanKHR" {
|
|||
func @do_nothing() -> () {
|
||||
spv.Return
|
||||
}
|
||||
// expected-error @+1 {{custom op 'spv.EntryPoint' expected function attribute}}
|
||||
// expected-error @+1 {{custom op 'spv.EntryPoint' expected symbol reference attribute}}
|
||||
%4 = spv.EntryPoint "GLCompute" "do_nothing"
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ def BOp : NS_Op<"b_op", []> {
|
|||
F64Attr:$f64_attr,
|
||||
StrAttr:$str_attr,
|
||||
ElementsAttr:$elements_attr,
|
||||
FunctionAttr:$function_attr,
|
||||
SymbolRefAttr:$function_attr,
|
||||
SomeTypeAttr:$type_attr,
|
||||
ArrayAttr:$array_attr,
|
||||
TypedArrayAttrBase<SomeAttr, "SomeAttr array">:$some_attr_array,
|
||||
|
@ -122,7 +122,7 @@ def BOp : NS_Op<"b_op", []> {
|
|||
// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
|
||||
// CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
|
||||
// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
|
||||
// CHECK: if (!((tblgen_function_attr.isa<FunctionAttr>())))
|
||||
// CHECK: if (!((tblgen_function_attr.isa<SymbolRefAttr>())))
|
||||
// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
|
||||
// CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
|
||||
// CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
|
||||
|
|
Loading…
Reference in New Issue