forked from OSchip/llvm-project
[Builder] Eliminate the StringRef/StringAttr forms of getSymbolRefAttr.
The StringAttr version doesn't need a context, so we can just use the existing `SymbolRefAttr::get` form. The StringRef version isn't preferred so we want to encourage people to use StringAttr. There is an additional form of getSymbolRefAttr that takes a (SymbolTrait implementing) operation. This should also be moved, but I'll do that as a separate patch. Differential Revision: https://reviews.llvm.org/D108922
This commit is contained in:
parent
7f2ce19d1c
commit
faf1c22408
|
@ -2545,7 +2545,7 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
|
|||
[{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttribute(calleeAttrName($_state.name),
|
||||
$_builder.getSymbolRefAttr(callee));
|
||||
SymbolRefAttr::get(callee));
|
||||
$_state.addTypes(callee.getType().getResults());
|
||||
}]>,
|
||||
OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
|
||||
|
@ -2560,7 +2560,8 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
|
|||
"llvm::ArrayRef<mlir::Type>":$results,
|
||||
CArg<"mlir::ValueRange", "{}">:$operands),
|
||||
[{
|
||||
build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results,
|
||||
build($_builder, $_state,
|
||||
SymbolRefAttr::get($_builder.getContext(), callee), results,
|
||||
operands);
|
||||
}]>];
|
||||
|
||||
|
|
|
@ -919,7 +919,7 @@ mlir::SymbolRefAttr IntrinsicLibrary::getUnrestrictedIntrinsicSymbolRefAttr(
|
|||
funcOp = getWrapper(rtCallGenerator, name, signature, loadRefArguments);
|
||||
}
|
||||
|
||||
return builder.getSymbolRefAttr(funcOp.getName());
|
||||
return SymbolRefAttr::get(funcOp);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -398,8 +398,7 @@ mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void fir::ConvertOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
}
|
||||
OwningRewritePatternList &results, MLIRContext *context) {}
|
||||
|
||||
mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
|
||||
if (value().getType() == getType())
|
||||
|
@ -629,7 +628,8 @@ void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result,
|
|||
result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type));
|
||||
result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(name));
|
||||
result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name));
|
||||
result.addAttribute(symbolAttrName(),
|
||||
SymbolRefAttr::get(builder.getContext(), name));
|
||||
if (isConstant)
|
||||
result.addAttribute(constantAttrName(result.name), builder.getUnitAttr());
|
||||
if (initialVal)
|
||||
|
@ -1330,7 +1330,7 @@ static constexpr llvm::StringRef getTargetOffsetAttr() {
|
|||
template <typename A, typename... AdditionalArgs>
|
||||
static A getSubOperands(unsigned pos, A allArgs,
|
||||
mlir::DenseIntElementsAttr ranges,
|
||||
AdditionalArgs &&... additionalArgs) {
|
||||
AdditionalArgs &&...additionalArgs) {
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < pos; ++i)
|
||||
start += (*(ranges.begin() + i)).getZExtValue();
|
||||
|
|
|
@ -174,7 +174,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -174,7 +174,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
|
|
|
@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
|
|
|
@ -256,7 +256,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
|
|
|
@ -282,7 +282,8 @@ void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
|||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder.getSymbolRefAttr(callee));
|
||||
state.addAttribute("callee",
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee));
|
||||
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
|
|
|
@ -522,7 +522,7 @@ private:
|
|||
mlir::FuncOp calledFunc = calledFuncIt->second;
|
||||
return builder.create<GenericCallOp>(
|
||||
location, calledFunc.getType().getResult(0),
|
||||
builder.getSymbolRefAttr(callee), operands);
|
||||
mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
|
||||
}
|
||||
|
||||
/// Emit a print expression. It emits specific operations for two builtins:
|
||||
|
|
|
@ -515,14 +515,22 @@ def LLVM_CallOp : LLVM_Op<"call",
|
|||
let results = (outs Variadic<LLVM_Type>);
|
||||
let builders = [
|
||||
OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), [{
|
||||
Type resultType = func.getType().getReturnType();
|
||||
if (!resultType.isa<LLVM::LLVMVoidType>())
|
||||
$_state.addTypes(resultType);
|
||||
$_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
|
||||
$_state.addAttribute("callee", SymbolRefAttr::get(func));
|
||||
$_state.addAttributes(attributes);
|
||||
$_state.addOperands(operands);
|
||||
}]>,
|
||||
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, results, SymbolRefAttr::get(callee), operands);
|
||||
}]>,
|
||||
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, results,
|
||||
StringAttr::get($_builder.getContext(), callee), operands);
|
||||
}]>];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseCallOp(parser, result); }];
|
||||
|
|
|
@ -560,7 +560,7 @@ def CallOp : Std_Op<"call",
|
|||
let builders = [
|
||||
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttribute("callee",$_builder.getSymbolRefAttr(callee));
|
||||
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
|
||||
$_state.addTypes(callee.getType().getResults());
|
||||
}]>,
|
||||
OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
|
||||
|
@ -569,14 +569,19 @@ def CallOp : Std_Op<"call",
|
|||
$_state.addAttribute("callee", callee);
|
||||
$_state.addTypes(results);
|
||||
}]>,
|
||||
OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
|
||||
}]>,
|
||||
OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
|
||||
CArg<"ValueRange", "{}">:$operands), [{
|
||||
build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results,
|
||||
operands);
|
||||
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
|
||||
results, operands);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getCallee() { return callee(); }
|
||||
StringAttr getCalleeAttr() { return calleeAttr().getAttr(); }
|
||||
FunctionType getCalleeType();
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
|
|
|
@ -97,17 +97,6 @@ public:
|
|||
FloatAttr getFloatAttr(Type type, const APFloat &value);
|
||||
StringAttr getStringAttr(const Twine &bytes);
|
||||
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||
FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
|
||||
FlatSymbolRefAttr getSymbolRefAttr(StringAttr value);
|
||||
SymbolRefAttr getSymbolRefAttr(StringAttr value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences);
|
||||
SymbolRefAttr getSymbolRefAttr(StringRef value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
|
||||
return getSymbolRefAttr(getStringAttr(value), nestedReferences);
|
||||
}
|
||||
FlatSymbolRefAttr getSymbolRefAttr(StringRef value) {
|
||||
return getSymbolRefAttr(getStringAttr(value));
|
||||
}
|
||||
|
||||
// Returns a 0-valued attribute of the given `type`. This function only
|
||||
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
|
||||
|
|
|
@ -23,6 +23,7 @@ class FunctionType;
|
|||
class IntegerSet;
|
||||
class IntegerType;
|
||||
class Location;
|
||||
class Operation;
|
||||
class ShapedType;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -685,12 +686,17 @@ public:
|
|||
using ValueType = StringRef;
|
||||
|
||||
/// Construct a symbol reference for the given value name.
|
||||
static FlatSymbolRefAttr get(StringAttr value) {
|
||||
return SymbolRefAttr::get(value);
|
||||
}
|
||||
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
|
||||
return SymbolRefAttr::get(ctx, value);
|
||||
}
|
||||
|
||||
static FlatSymbolRefAttr get(StringAttr value) {
|
||||
return SymbolRefAttr::get(value);
|
||||
/// Convenience getter for building a SymbolRefAttr based on an operation
|
||||
/// that implements the SymbolTrait.
|
||||
static FlatSymbolRefAttr get(Operation *symbol) {
|
||||
return SymbolRefAttr::get(symbol);
|
||||
}
|
||||
|
||||
/// Returns the name of the held symbol reference as a StringAttr.
|
||||
|
|
|
@ -893,8 +893,16 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
|
|||
}]>,
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
|
||||
static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedRefs);
|
||||
/// Convenience getters for building a SymbolRefAttr with no path, which is
|
||||
/// known to produce a FlatSymbolRefAttr.
|
||||
static FlatSymbolRefAttr get(StringAttr value);
|
||||
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
|
||||
|
||||
/// Convenience getter for buliding a SymbolRefAttr based on an operation
|
||||
/// that implements the SymbolTrait.
|
||||
static FlatSymbolRefAttr get(Operation *symbol);
|
||||
|
||||
/// Returns the name of the fully resolved symbol, i.e. the leaf of the
|
||||
/// reference path.
|
||||
|
|
|
@ -1582,15 +1582,16 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::SymbolRefAttr>()">,
|
|||
let storageType = [{ ::mlir::SymbolRefAttr }];
|
||||
let returnType = [{ ::mlir::SymbolRefAttr }];
|
||||
let valueType = NoneType;
|
||||
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
|
||||
let constBuilderCall = "SymbolRefAttr::get($_builder.getContext(), $0)";
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<::mlir::FlatSymbolRefAttr>()">,
|
||||
"flat symbol reference attribute"> {
|
||||
let storageType = [{ ::mlir::FlatSymbolRefAttr }];
|
||||
let returnType = [{ ::llvm::StringRef }];
|
||||
let valueType = NoneType;
|
||||
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
|
||||
let constBuilderCall = "SymbolRefAttr::get($_builder.getContext(), $0)";
|
||||
let convertFromStorage = "$_self.getValue()";
|
||||
}
|
||||
|
||||
|
|
|
@ -367,7 +367,7 @@ public:
|
|||
|
||||
// Allocate memory for the coroutine frame.
|
||||
auto coroAlloc = rewriter.create<LLVM::CallOp>(
|
||||
loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
|
||||
loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc),
|
||||
ValueRange(coroSize.getResult()));
|
||||
|
||||
// Begin a coroutine: @llvm.coro.begin.
|
||||
|
@ -399,8 +399,8 @@ public:
|
|||
auto coroMem = rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, operands);
|
||||
|
||||
// Free the memory.
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
|
||||
rewriter.getSymbolRefAttr(kFree),
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree),
|
||||
ValueRange(coroMem.getResult()));
|
||||
|
||||
return success();
|
||||
|
|
|
@ -62,8 +62,7 @@ public:
|
|||
|
||||
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(
|
||||
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
|
||||
castedOperands);
|
||||
op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands);
|
||||
|
||||
if (resultType == operands.front().getType()) {
|
||||
rewriter.replaceOp(op, {callOp.getResult(0)});
|
||||
|
|
|
@ -171,14 +171,13 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
|
|||
|
||||
// Create vulkan launch call op.
|
||||
auto vulkanLaunchCallOp = builder.create<CallOp>(
|
||||
loc, TypeRange{}, builder.getSymbolRefAttr(kVulkanLaunch),
|
||||
loc, TypeRange{}, SymbolRefAttr::get(builder.getContext(), kVulkanLaunch),
|
||||
vulkanLaunchOperands);
|
||||
|
||||
// Set SPIR-V binary shader data as an attribute.
|
||||
vulkanLaunchCallOp->setAttr(
|
||||
kSPIRVBlobAttrName,
|
||||
StringAttr::get(loc->getContext(),
|
||||
StringRef(binary.data(), binary.size())));
|
||||
builder.getStringAttr(StringRef(binary.data(), binary.size())));
|
||||
|
||||
// Set entry point name as an attribute.
|
||||
vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
|
||||
|
|
|
@ -248,9 +248,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
|
|||
}
|
||||
// Create call to `bindMemRef`.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(
|
||||
StringRef(symbolName.data(), symbolName.size())),
|
||||
loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
|
||||
ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
|
||||
ptrToMemRefDescriptor});
|
||||
}
|
||||
|
@ -373,8 +371,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
|
|||
Location loc = cInterfaceVulkanLaunchCallOp.getLoc();
|
||||
// Create call to `initVulkan`.
|
||||
auto initVulkanCall = builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan),
|
||||
ValueRange{});
|
||||
loc, TypeRange{getPointerType()}, kInitVulkan);
|
||||
// The result of `initVulkan` function is a pointer to Vulkan runtime, we
|
||||
// need to pass that pointer to each Vulkan runtime call.
|
||||
auto vulkanRuntime = initVulkanCall.getResult(0);
|
||||
|
@ -396,32 +393,29 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
|
|||
// Create call to `setBinaryShader` runtime function with the given pointer to
|
||||
// SPIR-V binary and binary size.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader),
|
||||
loc, TypeRange(), kSetBinaryShader,
|
||||
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
|
||||
// Create LLVM global with entry point name.
|
||||
Value entryPointName = createEntryPointNameConstant(
|
||||
spirvAttributes.second.getValue(), loc, builder);
|
||||
// Create call to `setEntryPoint` runtime function with the given pointer to
|
||||
// entry point name.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kSetEntryPoint),
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
|
||||
ValueRange{vulkanRuntime, entryPointName});
|
||||
|
||||
// Create number of local workgroup for each dimension.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups),
|
||||
loc, TypeRange(), kSetNumWorkGroups,
|
||||
ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
|
||||
cInterfaceVulkanLaunchCallOp.getOperand(1),
|
||||
cInterfaceVulkanLaunchCallOp.getOperand(2)});
|
||||
|
||||
// Create call to `runOnVulkan` runtime function.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kRunOnVulkan),
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(), kRunOnVulkan,
|
||||
ValueRange{vulkanRuntime});
|
||||
|
||||
// Create call to 'deinitVulkan' runtime function.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kDeinitVulkan),
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(), kDeinitVulkan,
|
||||
ValueRange{vulkanRuntime});
|
||||
|
||||
// Declare runtime functions.
|
||||
|
|
|
@ -50,7 +50,8 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|||
}
|
||||
|
||||
// fnName is a dynamic std::string, unique it via a SymbolRefAttr.
|
||||
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
||||
FlatSymbolRefAttr fnNameAttr =
|
||||
SymbolRefAttr::get(rewriter.getContext(), fnName);
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
if (module.lookupSymbol(fnNameAttr.getAttr()))
|
||||
return fnNameAttr;
|
||||
|
|
|
@ -305,7 +305,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
|
|||
op.getLoc(), getVoidPtrType(),
|
||||
memref.allocatedPtr(rewriter, op.getLoc()));
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||
op, TypeRange(), SymbolRefAttr::get(freeFunc), casted);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -559,9 +559,10 @@ SymbolRefAttr PatternLowering::generateRewriter(
|
|||
/*results=*/llvm::None));
|
||||
|
||||
builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
|
||||
return builder.getSymbolRefAttr(
|
||||
return SymbolRefAttr::get(
|
||||
builder.getContext(),
|
||||
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
|
||||
builder.getSymbolRefAttr(rewriterFunc));
|
||||
SymbolRefAttr::get(rewriterFunc));
|
||||
}
|
||||
|
||||
void PatternLowering::generateRewriter(
|
||||
|
|
|
@ -1194,8 +1194,8 @@ private:
|
|||
// Helper to emit a call.
|
||||
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *ref, ValueRange params = ValueRange()) {
|
||||
rewriter.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
rewriter.getSymbolRefAttr(ref), params);
|
||||
rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -542,8 +542,9 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
|
|||
blockSize.y, blockSize.z});
|
||||
result.addOperands(kernelOperands);
|
||||
auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
|
||||
auto kernelSymbol = builder.getSymbolRefAttr(
|
||||
kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())});
|
||||
auto kernelSymbol =
|
||||
SymbolRefAttr::get(kernelModule.getNameAttr(),
|
||||
{SymbolRefAttr::get(kernelFunc.getNameAttr())});
|
||||
result.addAttribute(getKernelAttrName(), kernelSymbol);
|
||||
SmallVector<int32_t, 8> segmentSizes(8, 1);
|
||||
segmentSizes.front() = 0; // Initially no async dependencies.
|
||||
|
|
|
@ -129,7 +129,7 @@ Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
|
|||
ValueRange paramTypes,
|
||||
ArrayRef<Type> resultTypes) {
|
||||
return b
|
||||
.create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn),
|
||||
.create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
|
||||
paramTypes)
|
||||
->getResults();
|
||||
}
|
||||
|
|
|
@ -1060,7 +1060,7 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
|||
|
||||
void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
|
||||
spirv::GlobalVariableOp var) {
|
||||
build(builder, state, var.type(), builder.getSymbolRefAttr(var));
|
||||
build(builder, state, var.type(), SymbolRefAttr::get(var));
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
||||
|
@ -1712,8 +1712,7 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
|
|||
ArrayRef<Attribute> interfaceVars) {
|
||||
build(builder, state,
|
||||
spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
|
||||
builder.getSymbolRefAttr(function),
|
||||
builder.getArrayAttr(interfaceVars));
|
||||
SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
|
||||
}
|
||||
|
||||
static ParseResult parseEntryPointOp(OpAsmParser &parser,
|
||||
|
@ -1772,7 +1771,7 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
|
|||
spirv::FuncOp function,
|
||||
spirv::ExecutionMode executionMode,
|
||||
ArrayRef<int32_t> params) {
|
||||
build(builder, state, builder.getSymbolRefAttr(function),
|
||||
build(builder, state, SymbolRefAttr::get(function),
|
||||
spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
|
||||
builder.getI32ArrayAttr(params));
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ public:
|
|||
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
||||
op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr()));
|
||||
op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -156,7 +156,7 @@ struct DecomposeCallGraphTypesForCallOp
|
|||
resultMapping.push_back(i);
|
||||
}
|
||||
|
||||
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCallee(),
|
||||
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
|
||||
newResultTypes, newOperands);
|
||||
|
||||
// Build a replacement value for each result to replace its uses. If a
|
||||
|
|
|
@ -210,23 +210,6 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
|
|||
return ArrayAttr::get(context, value);
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
|
||||
auto symName =
|
||||
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
||||
assert(symName && "value does not have a valid symbol name");
|
||||
return getSymbolRefAttr(symName.getValue());
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) {
|
||||
return SymbolRefAttr::get(value);
|
||||
}
|
||||
|
||||
SymbolRefAttr
|
||||
Builder::getSymbolRefAttr(StringAttr value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
|
||||
return SymbolRefAttr::get(value, nestedReferences);
|
||||
}
|
||||
|
||||
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
|
||||
auto attrs = llvm::to_vector<8>(llvm::map_range(
|
||||
values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
|
||||
|
|
|
@ -10,14 +10,14 @@
|
|||
#include "AttributeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -272,14 +272,26 @@ LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
// SymbolRefAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedRefs) {
|
||||
return get(StringAttr::get(ctx, value), nestedRefs);
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
|
||||
return get(StringAttr::get(ctx, value));
|
||||
return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
|
||||
return get(value, {}).cast<FlatSymbolRefAttr>();
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
|
||||
auto symName =
|
||||
symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
||||
assert(symName && "value does not have a valid symbol name");
|
||||
return SymbolRefAttr::get(symName);
|
||||
}
|
||||
|
||||
StringAttr SymbolRefAttr::getLeafReference() const {
|
||||
ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
|
||||
return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
|
||||
|
|
|
@ -191,7 +191,8 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
consumeToken(Token::at_identifier);
|
||||
nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
|
||||
}
|
||||
SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);
|
||||
SymbolRefAttr symbolRefAttr =
|
||||
SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
|
||||
|
||||
// If we are populating the assembly state, record this symbol reference.
|
||||
if (state.asmState)
|
||||
|
|
|
@ -1406,8 +1406,7 @@ public:
|
|||
// If we are populating the assembly parser state, record this as a symbol
|
||||
// reference.
|
||||
if (parser.getState().asmState) {
|
||||
parser.getState().asmState->addUses(
|
||||
getBuilder().getSymbolRefAttr(result.getValue()),
|
||||
parser.getState().asmState->addUses(SymbolRefAttr::get(result),
|
||||
atToken.getLocRange());
|
||||
}
|
||||
return success();
|
||||
|
|
|
@ -245,7 +245,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
|||
return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
|
||||
}
|
||||
if (auto *f = dyn_cast<llvm::Function>(value))
|
||||
return b.getSymbolRefAttr(f->getName());
|
||||
return SymbolRefAttr::get(b.getContext(), f->getName());
|
||||
|
||||
// Convert constant data to a dense elements attribute.
|
||||
if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
|
||||
|
@ -668,8 +668,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
|
|||
}
|
||||
Operation *op;
|
||||
if (llvm::Function *callee = ci->getCalledFunction()) {
|
||||
op = b.create<CallOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
|
||||
ops);
|
||||
op = b.create<CallOp>(
|
||||
loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops);
|
||||
} else {
|
||||
Value calledValue = processValue(ci->getCalledOperand());
|
||||
if (!calledValue)
|
||||
|
@ -713,9 +713,10 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
|
|||
|
||||
Operation *op;
|
||||
if (llvm::Function *callee = ii->getCalledFunction()) {
|
||||
op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
|
||||
ops, blocks[ii->getNormalDest()], normalArgs,
|
||||
blocks[ii->getUnwindDest()], unwindArgs);
|
||||
op = b.create<InvokeOp>(
|
||||
loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops,
|
||||
blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()],
|
||||
unwindArgs);
|
||||
} else {
|
||||
ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
|
||||
op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
|
||||
|
@ -771,7 +772,7 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
|
|||
|
||||
// If it directly has a name, we can use it.
|
||||
if (pf->hasName())
|
||||
return b.getSymbolRefAttr(pf->getName());
|
||||
return SymbolRefAttr::get(b.getContext(), pf->getName());
|
||||
|
||||
// If it doesn't have a name, currently, only function pointers that are
|
||||
// bitcast to i8* are parsed.
|
||||
|
@ -779,7 +780,7 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
|
|||
if (ce->getOpcode() == llvm::Instruction::BitCast &&
|
||||
ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
|
||||
if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
|
||||
return b.getSymbolRefAttr(func->getName());
|
||||
return SymbolRefAttr::get(b.getContext(), func->getName());
|
||||
}
|
||||
}
|
||||
return FlatSymbolRefAttr();
|
||||
|
|
|
@ -44,20 +44,19 @@ Value spirv::Deserializer::getValue(uint32_t id) {
|
|||
}
|
||||
if (auto varOp = getGlobalVariable(id)) {
|
||||
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
||||
unknownLoc, varOp.type(),
|
||||
opBuilder.getSymbolRefAttr(varOp.getOperation()));
|
||||
unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
|
||||
return addressOfOp.pointer();
|
||||
}
|
||||
if (auto constOp = getSpecConstant(id)) {
|
||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||
unknownLoc, constOp.default_value().getType(),
|
||||
opBuilder.getSymbolRefAttr(constOp.getOperation()));
|
||||
SymbolRefAttr::get(constOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
}
|
||||
if (auto constCompositeOp = getSpecConstantComposite(id)) {
|
||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||
unknownLoc, constCompositeOp.type(),
|
||||
opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
|
||||
SymbolRefAttr::get(constCompositeOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
}
|
||||
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
|
||||
|
@ -357,11 +356,11 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
|
|||
return emitError(unknownLoc, "undefined result <id> ")
|
||||
<< words[wordIndex] << " while decoding OpEntryPoint";
|
||||
}
|
||||
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
|
||||
interface.push_back(SymbolRefAttr::get(arg.getOperation()));
|
||||
wordIndex++;
|
||||
}
|
||||
opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
|
||||
opBuilder.getSymbolRefAttr(fnName),
|
||||
opBuilder.create<spirv::EntryPointOp>(
|
||||
unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
|
||||
opBuilder.getArrayAttr(interface));
|
||||
return success();
|
||||
}
|
||||
|
@ -394,7 +393,8 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
|
|||
}
|
||||
auto values = opBuilder.getArrayAttr(attrListElems);
|
||||
opBuilder.create<spirv::ExecutionModeOp>(
|
||||
unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
|
||||
unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
|
||||
execMode, values);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -461,8 +461,8 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
|
||||
auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
|
||||
unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
|
||||
arguments);
|
||||
unknownLoc, resultType,
|
||||
SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);
|
||||
|
||||
if (resultType)
|
||||
valueMap[resultID] = opFunctionCall.getResult(0);
|
||||
|
|
|
@ -575,7 +575,7 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
|
|||
<< operands[wordIndex] << "used as initializer";
|
||||
}
|
||||
wordIndex++;
|
||||
initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
|
||||
initializer = SymbolRefAttr::get(initializerOp.getOperation());
|
||||
}
|
||||
if (wordIndex != operands.size()) {
|
||||
return emitError(unknownLoc,
|
||||
|
@ -1279,7 +1279,7 @@ spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
|
|||
elements.reserve(operands.size() - 2);
|
||||
for (unsigned i = 2, e = operands.size(); i < e; ++i) {
|
||||
auto elementInfo = getSpecConstant(operands[i]);
|
||||
elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
|
||||
elements.push_back(SymbolRefAttr::get(elementInfo));
|
||||
}
|
||||
|
||||
auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
|
||||
|
|
|
@ -129,10 +129,10 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
|
|||
|
||||
// Functions called by this function.
|
||||
funcOp.walk([&](CallOp callOp) {
|
||||
StringRef callee = callOp.getCallee();
|
||||
StringAttr callee = callOp.getCalleeAttr();
|
||||
for (FuncOp &funcOp : normalizableFuncs) {
|
||||
// We compare FuncOp and callee's name.
|
||||
if (callee == funcOp.getName()) {
|
||||
if (callee == funcOp.getNameAttr()) {
|
||||
setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
|
||||
normalizableFuncs);
|
||||
break;
|
||||
|
@ -255,10 +255,9 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
|
|||
auto callOp = dyn_cast<CallOp>(userOp);
|
||||
if (!callOp)
|
||||
continue;
|
||||
StringRef callee = callOp.getCallee();
|
||||
Operation *newCallOp = builder.create<CallOp>(
|
||||
userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
|
||||
userOp->getOperands());
|
||||
Operation *newCallOp =
|
||||
builder.create<CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
|
||||
resultTypes, userOp->getOperands());
|
||||
bool replacingMemRefUsesFailed = false;
|
||||
bool returnTypeChanged = false;
|
||||
for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
|
||||
|
|
Loading…
Reference in New Issue