Rename FunctionAttr to SymbolRefAttr.

This allows for the attribute to hold symbolic references to other operations than FuncOp. This also allows for removing the dependence on FuncOp from the base Builder.

PiperOrigin-RevId: 257650017
This commit is contained in:
River Riddle 2019-07-11 11:41:04 -07:00 committed by jpienaar
parent 4dfe6d457b
commit 9dbef0bf96
40 changed files with 133 additions and 115 deletions

View File

@ -655,7 +655,7 @@ PYBIND11_MODULE(pybind, m) {
});
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
auto function = FuncOp::getFromOpaquePointer(func.function);
auto attr = FunctionAttr::get(function.getName(), function.getContext());
auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
return ValueHandle::create<ConstantOp>(function.getType(), attr);
});
m.def("appendTo", [](const PythonBlockHandle &handle) {

View File

@ -26,6 +26,7 @@
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

View File

@ -27,6 +27,7 @@
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

View File

@ -27,6 +27,7 @@
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

View File

@ -160,10 +160,10 @@ public:
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc),
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i)});
});
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
// clang-format on
} else {
IndexHandle N(vOp.ub(1));
@ -171,10 +171,10 @@ public:
LoopBuilder(&i, zero, M, 1)([&]{
LoopBuilder(&j, zero, N, 1)([&]{
llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc),
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i, j)});
});
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol});
});
// clang-format on
}

View File

@ -27,6 +27,7 @@
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

View File

@ -177,7 +177,7 @@ bare-id ::= (letter|[_]) (letter|digit|[_$.])*
bare-id-list ::= bare-id (`,` bare-id)*
suffix-id ::= digit+ | ((letter|id-punct) (letter|id-punct|digit)*)
function-id ::= `@` bare-id
symbol-ref-id ::= `@` bare-id
ssa-id ::= `%` suffix-id
ssa-id-list ::= ssa-id (`,` ssa-id)*
@ -690,8 +690,8 @@ attribute-value ::= affine-map-attribute
| integer-attribute
| integer-set-attribute
| float-attribute
| function-attribute
| string-attribute
| symbol-ref-attribute
| type-attribute
| unit-attribute
```
@ -847,17 +847,6 @@ float-attribute ::= float-literal (`:` float-type)?
A float attribute is a literal attribute that represents a floating point value
of the specified [float type](#floating-point-types).
#### Function Attribute
Syntax:
``` {.ebnf}
function-attribute ::= function-id
```
A function attribute is a literal attribute that represents a named reference to
the given function.
#### String Attribute
Syntax:
@ -868,6 +857,17 @@ string-attribute ::= string-literal (`:` type)?
A string attribute is an attribute that represents a string literal value.
#### Symbol Reference Attribute
Syntax:
``` {.ebnf}
symbol-ref-attribute ::= symbol-ref-id
```
A symbol reference attribute is a literal attribute that represents a named
reference to a given operation.
#### Type Attribute
Syntax:
@ -924,7 +924,8 @@ referenced by name via a string attribute):
``` {.ebnf}
function ::= `func` function-signature function-attributes? function-body?
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
function-signature ::= symbol-ref-id `(` argument-list `)`
(`->` function-result-type)?
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/
named-argument ::= ssa-id `:` type attribute-dict?
@ -1277,7 +1278,7 @@ single function to return.
Syntax:
``` {.ebnf}
operation ::= `call` function-id `(` ssa-use-list? `)` `:` function-type
operation ::= `call` symbol-ref-id `(` ssa-use-list? `)` `:` function-type
```
The `call` operation represents a direct call to a function. The operands and

View File

@ -167,10 +167,10 @@ in this chapter for lowering `toy.print`:
LoopBuilder(&i, zero, M, 1)({
LoopBuilder(&j, zero, N, 1)({
llvmCall(retTy,
rewriter.getFunctionAttr(printfFunc),
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i, j)})
}),
llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol})
});
```

View File

@ -141,11 +141,11 @@ enum Kind {
Bool,
Dictionary,
Float,
Function,
Integer,
IntegerSet,
Opaque,
String,
SymbolRef,
Type,
Unit,
@ -306,25 +306,6 @@ public:
Type type, const APFloat &value);
};
/// A function attribute represents a reference to a function object.
class FunctionAttr
: public Attribute::AttrBase<FunctionAttr, Attribute,
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = StringRef;
static FunctionAttr get(StringRef value, MLIRContext *ctx);
/// Returns the name of the held function reference.
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Function;
}
};
class IntegerAttr
: public Attribute::AttrBase<IntegerAttr, Attribute,
detail::IntegerAttributeStorage> {
@ -417,6 +398,26 @@ public:
}
};
/// A symbol reference attribute represents a symbolic reference to another
/// operation.
class SymbolRefAttr
: public Attribute::AttrBase<SymbolRefAttr, Attribute,
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = StringRef;
static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
/// Returns the name of the held symbol reference.
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::SymbolRef;
}
};
class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
detail::TypeAttributeStorage> {
public:

View File

@ -18,7 +18,6 @@
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
@ -44,7 +43,7 @@ class FloatAttr;
class StringAttr;
class TypeAttr;
class ArrayAttr;
class FunctionAttr;
class SymbolRefAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
@ -111,8 +110,8 @@ public:
AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(FuncOp value);
FunctionAttr getFunctionAttr(StringRef value);
SymbolRefAttr getSymbolRefAttr(Operation *value);
SymbolRefAttr getSymbolRefAttr(StringRef value);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,

View File

@ -831,12 +831,12 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
let constBuilderCall = ?;
}
// Attributes containing functions.
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
"function attribute"> {
let storageType = [{ FunctionAttr }];
// Attributes containing symbol references.
def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
"symbol reference attribute"> {
let storageType = [{ SymbolRefAttr }];
let returnType = [{ StringRef }];
let constBuilderCall = "$_builder.getFunctionAttr($0)";
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
}
//===----------------------------------------------------------------------===//

View File

@ -22,6 +22,7 @@
namespace mlir {
class FuncOp;
class PatternRewriter;
//===----------------------------------------------------------------------===//

View File

@ -199,7 +199,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc">;
// Call-related operations.
def LLVM_CallOp : LLVM_Op<"call">,
Arguments<(ins OptionalAttr<FunctionAttr>:$callee,
Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
// TODO(b/133216756): fix test failure and
// change to LLVM_Type
Variadic<AnyType>)>,

View File

@ -83,7 +83,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
FunctionAttr:$fn,
SymbolRefAttr:$fn,
Variadic<SPV_AnyPtr>:$interface
);

View File

@ -210,20 +210,20 @@ def CallOp : Std_Op<"call"> {
%2 = call @my_add(%0, %1) : (f32, f32) -> f32
}];
let arguments = (ins FunctionAttr:$callee, Variadic<AnyType>:$operands);
let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let builders = [OpBuilder<
"Builder *builder, OperationState *result, FuncOp callee,"
"ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addAttribute("callee", builder->getSymbolRefAttr(callee));
result->addTypes(callee.getType().getResults());
}]>, OpBuilder<
"Builder *builder, OperationState *result, StringRef callee,"
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addAttribute("callee", builder->getSymbolRefAttr(callee));
result->addTypes(results);
}]>];

View File

@ -18,6 +18,7 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"

View File

@ -26,6 +26,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/StandardOps/Ops.h"

View File

@ -324,7 +324,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// TODO(herhut): This should rather be a static global once supported.
auto kernelFunction = getModule().lookupSymbol<FuncOp>(launchOp.kernel());
auto cubinGetter =
kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation);
kernelFunction.getAttrOfType<SymbolRefAttr>(kCubinGetterAnnotation);
if (!cubinGetter) {
kernelFunction.emitError("Missing ")
<< kCubinGetterAnnotation << " attribute.";
@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto cuModule = allocatePointer(builder, loc);
FuncOp cuModuleLoad = getModule().lookupSymbol<FuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleLoad),
builder.getSymbolRefAttr(cuModuleLoad),
ArrayRef<Value *>{cuModule, data.getResult(0)});
// Get the function from the module. The name corresponds to the name of
// the kernel function.
@ -349,14 +349,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
getModule().lookupSymbol<FuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
builder.getSymbolRefAttr(cuModuleGetFunction),
ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
FuncOp cuGetStreamHelper =
getModule().lookupSymbol<FuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
builder.getFunctionAttr(cuGetStreamHelper), ArrayRef<Value *>{});
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
// Invoke the function with required arguments.
auto cuLaunchKernel = getModule().lookupSymbol<FuncOp>(cuLaunchKernelName);
auto cuFunctionRef =
@ -366,7 +366,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.create<LLVM::IntToPtrOp>(loc, getPointerPointerType(), zero);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuLaunchKernel),
builder.getSymbolRefAttr(cuLaunchKernel),
ArrayRef<Value *>{cuFunctionRef, launchOp.getOperand(0),
launchOp.getOperand(1), launchOp.getOperand(2),
launchOp.getOperand(3), launchOp.getOperand(4),
@ -377,7 +377,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// Sync on the stream to make it synchronous.
auto cuStreamSync = getModule().lookupSymbol<FuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuStreamSync),
builder.getSymbolRefAttr(cuStreamSync),
ArrayRef<Value *>(cuStream.getResult(0)));
launchOp.erase();
}

View File

@ -94,7 +94,7 @@ private:
auto memory =
ob.create<LLVM::CallOp>(
loc, ArrayRef<Type>{LLVM::LLVMType::getInt8PtrTy(llvmDialect)},
builder.getFunctionAttr(getMallocHelper(loc, builder)),
builder.getSymbolRefAttr(getMallocHelper(loc, builder)),
ArrayRef<Value *>{sizeConstant})
.getResult(0);
for (auto byte : llvm::enumerate(blob.getValue().bytes())) {
@ -111,7 +111,7 @@ private:
}
ob.create<LLVM::ReturnOp>(loc, ArrayRef<Value *>{memory});
// Store the name of the getter on the function for easier lookup.
orig.setAttr(kCubinGetterAnnotation, builder.getFunctionAttr(result));
orig.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result));
return result;
}

View File

@ -456,7 +456,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
Value *allocated =
rewriter
.create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
rewriter.getFunctionAttr(mallocFunc),
rewriter.getSymbolRefAttr(mallocFunc),
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
@ -520,7 +520,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
}
};

View File

@ -21,6 +21,7 @@
#include "mlir/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@ -386,7 +387,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
result->addOperands(kernelOperands);
result->addAttribute(getKernelAttrName(),
builder->getFunctionAttr(kernelFunc));
builder->getSymbolRefAttr(kernelFunc));
}
void LaunchFuncOp::build(Builder *builder, OperationState *result,
@ -398,7 +399,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
}
StringRef LaunchFuncOp::kernel() {
return getAttrOfType<FunctionAttr>(getKernelAttrName()).getValue();
return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
}
unsigned LaunchFuncOp::getNumKernelOperands() {
@ -421,7 +422,7 @@ LogicalResult LaunchFuncOp::verify() {
auto kernelAttr = this->getAttr(getKernelAttrName());
if (!kernelAttr) {
return emitOpError("attribute 'kernel' must be specified");
} else if (!kernelAttr.isa<FunctionAttr>()) {
} else if (!kernelAttr.isa<SymbolRefAttr>()) {
return emitOpError("attribute 'kernel' must be a function");
}

View File

@ -670,10 +670,9 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
case StandardAttributes::Type:
printType(attr.cast<TypeAttr>().getValue());
break;
case StandardAttributes::Function: {
os << '@' << attr.cast<FunctionAttr>().getValue();
case StandardAttributes::SymbolRef:
os << '@' << attr.cast<SymbolRefAttr>().getValue();
break;
}
case StandardAttributes::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";

View File

@ -246,15 +246,15 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
}
//===----------------------------------------------------------------------===//
// FunctionAttr
// SymbolRefAttr
//===----------------------------------------------------------------------===//
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::Function, value,
SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::SymbolRef, value,
NoneType::get(ctx));
}
StringRef FunctionAttr::getValue() const { return getImpl()->value; }
StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// IntegerAttr

View File

@ -175,11 +175,14 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
FunctionAttr Builder::getFunctionAttr(FuncOp value) {
return getFunctionAttr(value.getName());
SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
auto symName =
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
assert(symName && "value does not have a valid symbol name");
return getSymbolRefAttr(symName.getValue());
}
FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
return SymbolRefAttr::get(value, getContext());
}
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,

View File

@ -154,8 +154,8 @@ ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
auto &builder = parser->getBuilder();
// Parse the name as a function attribute.
FunctionAttr nameAttr;
// Parse the name as a symbol reference attribute.
SymbolRefAttr nameAttr;
if (parser->parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(),
result->attributes))
return failure();

View File

@ -82,7 +82,7 @@ namespace {
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
DictionaryAttr, FloatAttr, FunctionAttr, IntegerAttr,
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, UnknownLoc>();

View File

@ -335,7 +335,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
SmallVector<OpAsmParser::OperandType, 8> operands;
Type type;
FunctionAttr funcAttr;
SymbolRefAttr funcAttr;
llvm::SMLoc trailingTypeLoc;
// Parse an operand list that will, in practice, contain 0 or 1 operand. In

View File

@ -204,7 +204,7 @@ public:
Value *allocSize =
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
Value *allocated =
call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize)
call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
.getOperation()
->getResult(0);
allocated = bitcast(elementPtrType, allocated);
@ -248,7 +248,7 @@ public:
edsc::ScopedContext context(rewriter, op->getLoc());
Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
positionAttr(rewriter, 0)));
call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
@ -650,7 +650,7 @@ static void getLLVMLibraryCallDefinition(FuncOp fn,
implFnArgs.push_back(alloca);
llvm_store(arg, alloca);
}
call(ArrayRef<Type>(), builder.getFunctionAttr(implFn), implFnArgs);
call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
llvm_return{ArrayRef<Value *>()};
}
@ -694,7 +694,7 @@ public:
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
auto fAttr = rewriter.getSymbolRefAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
ArrayRef<NamedAttribute>{named});

View File

@ -215,7 +215,7 @@ Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
/// Lex an '@foo' identifier.
///
/// function-id ::= `@` bare-id
/// symbol-ref-id ::= `@` bare-id
///
Token Lexer::lexAtIdentifier(const char *tokStart) {
// These always start with a letter or underscore.

View File

@ -931,7 +931,7 @@ ParseResult Parser::parseXInDimensionList() {
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
/// | function-id `:` function-type
/// | symbol-ref-id
/// | `dense` `<` attribute-value `>` `:`
/// (tensor-type | vector-type)
/// | `sparse` `<` attribute-value `,` attribute-value `>`
@ -1010,13 +1010,6 @@ Attribute Parser::parseAttribute(Type type) {
nullptr);
}
// Parse a function attribute.
case Token::at_identifier: {
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
return builder.getFunctionAttr(nameStr.drop_front());
}
// Parse a location attribute.
case Token::kw_loc: {
LocationAttr attr;
@ -1043,6 +1036,13 @@ Attribute Parser::parseAttribute(Type type) {
: StringAttr::get(val, getContext());
}
// Parse a symbol reference attribute.
case Token::at_identifier: {
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
return builder.getSymbolRefAttr(nameStr.drop_front());
}
// Parse a 'unit' attribute.
case Token::kw_unit:
consumeToken(Token::kw_unit);

View File

@ -22,6 +22,7 @@
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/SPIRV/SPIRVTypes.h"
@ -280,8 +281,8 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
return failure();
}
if (!fn.isa<FunctionAttr>()) {
return parser->emitError(loc, "expected function attribute");
if (!fn.isa<SymbolRefAttr>()) {
return parser->emitError(loc, "expected symbol reference attribute");
}
state->addTypes(
spirv::EntryPointType::get(parser->getBuilder().getContext()));

View File

@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/Serialization.h"

View File

@ -20,6 +20,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
@ -401,7 +402,7 @@ void BranchOp::eraseOperand(unsigned index) {
//===----------------------------------------------------------------------===//
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
FunctionAttr calleeAttr;
SymbolRefAttr calleeAttr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto calleeLoc = parser->getNameLoc();
@ -428,9 +429,9 @@ static void print(OpAsmPrinter *p, CallOp op) {
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
return op.emitOpError("requires a 'callee' symbol reference attribute");
auto fn =
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
if (!fn)
@ -474,7 +475,7 @@ struct SimplifyIndirectCallWithKnownCallee
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
// Check that the callee is a constant callee.
FunctionAttr calledFn;
SymbolRefAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
@ -1029,8 +1030,8 @@ static void print(OpAsmPrinter *p, ConstantOp &op) {
*p << ' ';
p->printAttribute(op.getValue());
// If the value is a function, print a trailing type.
if (op.getValue().isa<FunctionAttr>()) {
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>()) {
*p << " : ";
p->printType(op.getType());
}
@ -1043,9 +1044,9 @@ static ParseResult parseConstantOp(OpAsmParser *parser,
parser->parseAttribute(valueAttr, "value", result->attributes))
return failure();
// If the attribute is a function, then we expect a trailing type.
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<FunctionAttr>())
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser->parseColonType(type))
return failure();
@ -1093,7 +1094,7 @@ static LogicalResult verify(ConstantOp &op) {
}
if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FunctionAttr>();
auto fnAttr = value.dyn_cast<SymbolRefAttr>();
if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
@ -1125,8 +1126,8 @@ OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// FunctionAttr can only be used with a function type.
if (value.isa<FunctionAttr>())
// SymbolRefAttr can only be used with a function type.
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// Otherwise, the attribute must have the same type as 'type'.
if (value.getType() != type)

View File

@ -83,7 +83,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
return functionMapping.lookup(funcAttr.getValue());
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
@ -176,7 +176,7 @@ bool ModuleTranslation::convertOperation(Operation &opInst,
auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
auto operands = lookupValues(op.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) {
if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
operandsRef);
} else {

View File

@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Transforms/FoldUtils.h"

View File

@ -30,6 +30,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -29,6 +29,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -28,6 +28,7 @@
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/MathExtras.h"

View File

@ -30,7 +30,7 @@ spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{custom op 'spv.EntryPoint' expected function attribute}}
// expected-error @+1 {{custom op 'spv.EntryPoint' expected symbol reference attribute}}
%4 = spv.EntryPoint "GLCompute" "do_nothing"
}

View File

@ -85,7 +85,7 @@ def BOp : NS_Op<"b_op", []> {
F64Attr:$f64_attr,
StrAttr:$str_attr,
ElementsAttr:$elements_attr,
FunctionAttr:$function_attr,
SymbolRefAttr:$function_attr,
SomeTypeAttr:$type_attr,
ArrayAttr:$array_attr,
TypedArrayAttrBase<SomeAttr, "SomeAttr array">:$some_attr_array,
@ -122,7 +122,7 @@ def BOp : NS_Op<"b_op", []> {
// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
// CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
// CHECK: if (!((tblgen_function_attr.isa<FunctionAttr>())))
// CHECK: if (!((tblgen_function_attr.isa<SymbolRefAttr>())))
// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
// CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
// CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))