From 2418cd92c0340723bbf03630a91520089ec9ca7e Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 7 Feb 2022 17:54:04 -0800 Subject: [PATCH] [mlir] Update uses of `parser`/`printer` ODS op field to `hasCustomAssemblyFormat` The parser/printer fields are deprecated and in the process of being removed. --- .../mlir/Dialect/Affine/IR/AffineOps.td | 22 +- .../Dialect/Arithmetic/IR/ArithmeticOps.td | 3 +- .../include/mlir/Dialect/Async/IR/AsyncOps.td | 4 +- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +- mlir/include/mlir/Dialect/GPU/GPUOps.td | 10 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 57 +- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 3 +- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 14 +- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 13 +- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 4 +- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 16 +- .../mlir/Dialect/OpenACC/OpenACCOps.td | 8 +- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 24 +- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 5 +- .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 3 +- mlir/include/mlir/Dialect/SCF/SCFOps.td | 20 +- .../mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td | 15 - .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 12 +- .../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td | 8 +- .../mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td | 13 +- .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 26 +- .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 3 - .../mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td | 7 +- .../include/mlir/Dialect/Shape/IR/ShapeOps.td | 15 +- .../SparseTensor/IR/SparseTensorOps.td | 5 +- .../mlir/Dialect/StandardOps/IR/Ops.td | 10 +- .../mlir/Dialect/Tensor/IR/TensorOps.td | 8 +- .../mlir/Dialect/Vector/IR/VectorOps.td | 17 +- mlir/include/mlir/IR/BuiltinOps.td | 3 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 218 +++--- .../Dialect/Arithmetic/IR/ArithmeticOps.cpp | 20 +- mlir/lib/Dialect/Async/IR/Async.cpp | 20 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 8 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 82 +-- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 299 ++++---- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 5 +- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 14 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 72 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 109 +-- mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 91 ++- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 190 +++-- mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp | 14 +- mlir/lib/Dialect/SCF/SCF.cpp | 94 +-- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 683 +++++++++++------- mlir/lib/Dialect/Shape/IR/Shape.cpp | 55 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 11 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 105 ++- mlir/lib/IR/BuiltinDialect.cpp | 8 +- mlir/test/Dialect/SPIRV/IR/bit-ops.mlir | 2 +- mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 4 +- .../mlir-linalg-ods-yaml-gen.cpp | 20 +- 51 files changed, 1242 insertions(+), 1233 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 31aebe40e43a..5e1b910c8b68 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -28,15 +28,7 @@ def Affine_Dialect : Dialect { // Base class for Affine dialect ops. class Affine_Op traits = []> : - Op { - // For every affine op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; // Require regions to have affine.yield. def ImplicitAffineTerminator @@ -109,6 +101,7 @@ def AffineApplyOp : Affine_Op<"apply", [NoSideEffect]> { }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -348,6 +341,7 @@ def AffineForOp : Affine_Op<"for", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -472,6 +466,7 @@ def AffineIfOp : Affine_Op<"if", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -538,6 +533,7 @@ def AffineLoadOp : AffineLoadOpBase<"load"> { let extraClassDeclaration = extraClassDeclarationBase; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -567,8 +563,7 @@ class AffineMinMaxOpBase traits = []> : operands().end()}; } }]; - let printer = [{ return ::printAffineMinMaxOp(p, *this); }]; - let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -754,6 +749,7 @@ def AffineParallelOp : Affine_Op<"parallel", } }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -834,6 +830,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -899,6 +896,7 @@ def AffineStoreOp : AffineStoreOpBase<"store"> { let extraClassDeclaration = extraClassDeclarationBase; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -990,6 +988,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> { }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1055,6 +1054,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> { }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td index b278de529db2..f31fe2d9f044 100644 --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -1202,8 +1202,7 @@ def SelectOp : Arith_Op<"select", [ let hasVerifier = 1; // FIXME: Switch this to use the declarative assembly format. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasCustomAssemblyFormat = 1; } #endif // ARITHMETIC_OPS diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index d27a6234dd5c..d199bb2b42b0 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -81,9 +81,7 @@ def Async_ExecuteOp : Variadic:$results); let regions = (region SizedRegion<1>:$body); - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; - + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; let hasVerifier = 1; let builders = [ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 4e587d411c5d..7135cd9f12c6 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -143,8 +143,7 @@ def EmitC_IncludeOp Arg:$include, UnitAttr:$is_standard_include ); - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasCustomAssemblyFormat = 1; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index 41bb6d0c37cc..a07f73eaa432 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -294,9 +294,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [ /// Verifies the body of the function. LogicalResult verifyBody(); }]; - - let printer = [{ printGPUFuncOp(p, *this); }]; - let parser = [{ return parseGPUFuncOp(parser, result); }]; + let hasCustomAssemblyFormat = 1; } def GPU_LaunchFuncOp : GPU_Op<"launch_func", @@ -556,9 +554,8 @@ def GPU_LaunchOp : GPU_Op<"launch">, static constexpr unsigned kNumConfigRegionAttributes = 12; }]; - let parser = [{ return parseLaunchOp(parser, result); }]; - let printer = [{ printLaunchOp(p, *this); }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -783,9 +780,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [ ``` }]; let builders = [OpBuilder<(ins "StringRef":$name)>]; - let parser = [{ return ::parseGPUModuleOp(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; let regions = (region SizedRegion<1>:$body); + let hasCustomAssemblyFormat = 1; // We need to ensure the block inside the region is properly terminated; // the auto-generated builders do not guarantee that. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 4a40132df963..7dae79047b1b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -202,8 +202,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> { build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), predicate, lhs, rhs); }]>]; - let parser = [{ return parseCmpOp(parser, result); }]; - let printer = [{ printICmpOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; } // Predicate for float comparisons @@ -246,8 +245,7 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [ let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; - let parser = [{ return parseCmpOp(parser, result); }]; - let printer = [{ printFCmpOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; } // Floating point binary operations. @@ -312,8 +310,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase { build($_builder, $_state, resultType, arraySize, $_builder.getI64IntegerAttr(alignment)); }]>]; - let parser = [{ return parseAllocaOp(parser, result); }]; - let printer = [{ printAllocaOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; } def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { @@ -382,8 +379,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { OpBuilder<(ins "Type":$t, "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal)>]; - let parser = [{ return parseLoadOp(parser, result); }]; - let printer = [{ printLoadOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -406,8 +402,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes { CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal)> ]; - let parser = [{ return parseStoreOp(parser, result); }]; - let printer = [{ printStoreOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -491,8 +486,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps, unwindOps, normal, unwind); }]>]; - let parser = [{ return parseInvokeOp(parser, result); }]; - let printer = [{ printInvokeOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -500,8 +494,7 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { let arguments = (ins UnitAttr:$cleanup, Variadic); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; - let parser = [{ return parseLandingpadOp(parser, result); }]; - let printer = [{ printLandingpadOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -559,8 +552,7 @@ def LLVM_CallOp : LLVM_Op<"call", build($_builder, $_state, results, StringAttr::get($_builder.getContext(), callee), operands); }]>]; - let parser = [{ return parseCallOp(parser, result); }]; - let printer = [{ printCallOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { @@ -572,8 +564,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let builders = [ OpBuilder<(ins "Value":$vector, "Value":$position, CArg<"ArrayRef", "{}">:$attrs)>]; - let parser = [{ return parseExtractElementOp(parser, result); }]; - let printer = [{ printExtractElementOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { @@ -583,8 +574,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { $res = builder.CreateExtractValue($container, extractPosition($position)); }]; let builders = [LLVM_OneResultOpBuilder]; - let parser = [{ return parseExtractValueOp(parser, result); }]; - let printer = [{ printExtractValueOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -596,8 +586,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> { $res = builder.CreateInsertElement($vector, $value, $position); }]; let builders = [LLVM_OneResultOpBuilder]; - let parser = [{ return parseInsertElementOp(parser, result); }]; - let printer = [{ printInsertElementOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> { @@ -613,8 +602,7 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> { [{ build($_builder, $_state, container.getType(), container, value, position); }]>]; - let parser = [{ return parseInsertValueOp(parser, result); }]; - let printer = [{ printInsertValueOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> { @@ -628,8 +616,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> { let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask, CArg<"ArrayRef", "{}">:$attrs)>]; - let parser = [{ return parseShuffleVectorOp(parser, result); }]; - let printer = [{ printShuffleVectorOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -708,8 +695,7 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> { builder.CreateRetVoid(); }]; - let parser = [{ return parseReturnOp(parser, result); }]; - let printer = [{ printReturnOp(p, *this); }]; + let assemblyFormat = "attr-dict ($args^ `:` type($args))?"; let hasVerifier = 1; } def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> { @@ -1151,8 +1137,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global", } }]; - let printer = "printGlobalOp(p, *this);"; - let parser = "return parseGlobalOp(parser, result);"; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1288,8 +1273,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ LogicalResult verifyType(); }]; - let printer = [{ printLLVMFuncOp(p, *this); }]; - let parser = [{ return parseLLVMFuncOp(parser, result); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1824,8 +1808,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> { llvm::MaybeAlign(), getLLVMAtomicOrdering($ordering)); }]; - let parser = [{ return parseAtomicRMWOp(parser, result); }]; - let printer = [{ printAtomicRMWOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1854,8 +1837,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> { getLLVMAtomicOrdering($success_ordering), getLLVMAtomicOrdering($failure_ordering)); }]; - let parser = [{ return parseAtomicCmpXchgOp(parser, result); }]; - let printer = [{ printAtomicCmpXchgOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1877,8 +1859,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> { builder.CreateFence(getLLVMAtomicOrdering($ordering), llvmContext.getOrInsertSyncScopeID($syncscope)); }]; - let parser = [{ return parseFenceOp(parser, result); }]; - let printer = [{ printFenceOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 4a55ddd96cb7..196fbdd95105 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -146,8 +146,7 @@ def NVVM_VoteBallotOp : $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred}); }]; - let parser = [{ return parseNVVMVoteBallotOp(parser, result); }]; - let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }]; + let hasCustomAssemblyFormat = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 45642da7426a..bd80bde25960 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -159,12 +159,7 @@ def ROCDL_MubufLoadOp : llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc, $slc}, {$_resultType}); }]; - let parser = [{ return parseROCDLMubufLoadOp(parser, result); }]; - let printer = [{ - Operation *op = this->getOperation(); - p << " " << op->getOperands() - << " : " << op->getResultTypes(); - }]; + let hasCustomAssemblyFormat = 1; } def ROCDL_MubufStoreOp : @@ -181,12 +176,7 @@ def ROCDL_MubufStoreOp : llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, $offset, $glc, $slc}, {vdataType}); }]; - let parser = [{ return parseROCDLMubufStoreOp(parser, result); }]; - let printer = [{ - Operation *op = this->getOperation(); - p << " " << op->getOperands() - << " : " << vdata().getType(); - }]; + let hasCustomAssemblyFormat = 1; } #endif // ROCDLIR_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index d8ae172e1749..518a2cfacf2d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -22,15 +22,7 @@ include "mlir/Interfaces/ViewLikeInterface.td" // Base class for Linalg dialect ops that do not correspond to library calls. class Linalg_Op traits = []> : - Op { - // For every linalg op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect, @@ -123,6 +115,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", ]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -141,6 +134,7 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>, ``` }]; let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -423,6 +417,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index da01d10ebde0..8bdec0971ee1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -260,10 +260,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [AttrSizedOperandSegments]> { } }]; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseGenericOp(parser, result); }]; - let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index abb8231d7e5b..2445280b0157 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -26,10 +26,7 @@ def MemRefTypeAttr } class MemRef_Op traits = []> - : Op { - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + : Op; // Base class for ops with static/dynamic offset, sizes and strides // attributes/arguments. @@ -275,6 +272,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope", let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$bodyRegion); + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -657,6 +655,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> { return getOperand(getNumOperands() - 1); } }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -771,6 +770,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [ return memref().getType().cast(); } }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -996,6 +996,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> { static StringRef getIsDataCacheAttrName() { return "isDataCache"; } }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -1048,8 +1049,6 @@ def MemRef_ReinterpretCastOp attr-dict `:` type($source) `to` type($result) }]; - let parser = ?; - let printer = ?; let hasVerifier = 1; let builders = [ @@ -1243,9 +1242,8 @@ class MemRef_ReassociativeReshapeOp traits = []> : let hasFolder = 1; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> { @@ -1700,6 +1698,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>, ShapedType getShapedType() { return in().getType().cast(); } }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -1770,6 +1769,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [ }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 9fc293accb49..70cd3153d039 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -34,11 +34,7 @@ include "mlir/Dialect/OpenACC/AccCommon.td" // Base class for OpenACC dialect ops. class OpenACC_Op traits = []> : - Op { - - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; // Reduction operation enumeration. def OpenACC_ReductionOpAdd : I32EnumAttrCase<"redop_add", 0>; @@ -152,6 +148,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// The i-th data operand passed. Value getDataOperand(unsigned i); }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// @@ -404,6 +401,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", static StringRef getPrivateKeyword() { return "private"; } static StringRef getReductionKeyword() { return "reduction"; } }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index b5532b1185e7..39a6f2f72878 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -126,8 +126,7 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments, let builders = [ OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)> ]; - let parser = [{ return parseParallelOp(parser, result); }]; - let printer = [{ return printParallelOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -215,8 +214,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> { let regions = (region SizedRegion<1>:$region); - let parser = [{ return parseSectionsOp(parser, result); }]; - let printer = [{ return printSectionsOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -334,8 +332,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, /// Returns the number of reduction variables. unsigned getNumReductionVars() { return reduction_vars().size(); } }]; - let parser = [{ return parseWsLoopOp(parser, result); }]; - let printer = [{ return printWsLoopOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -419,8 +416,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> { let regions = (region AnyRegion:$region); - let parser = [{ return parseTargetOp(parser, result); }]; - let printer = [{ return printTargetOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; } @@ -608,8 +604,7 @@ def AtomicReadOp : OpenMP_Op<"atomic.read"> { OpenMP_PointerLikeType:$v, DefaultValuedAttr:$hint, OptionalAttr:$memory_order); - let parser = [{ return parseAtomicReadOp(parser, result); }]; - let printer = [{ return printAtomicReadOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -637,8 +632,7 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> { AnyType:$value, DefaultValuedAttr:$hint, OptionalAttr:$memory_order); - let parser = [{ return parseAtomicWriteOp(parser, result); }]; - let printer = [{ return printAtomicWriteOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -702,8 +696,7 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update"> { AtomicBinOpKindAttr:$binop, DefaultValuedAttr:$hint, OptionalAttr:$memory_order); - let parser = [{ return parseAtomicUpdateOp(parser, result); }]; - let printer = [{ return printAtomicUpdateOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -746,8 +739,7 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture", let arguments = (ins DefaultValuedAttr:$hint, OptionalAttr:$memory_order); let regions = (region SizedRegion<1>:$region); - let parser = [{ return parseAtomicCaptureOp(parser, result); }]; - let printer = [{ return printAtomicCaptureOp(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 7f7b8a6a06d9..0c35d0457d84 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -23,10 +23,7 @@ include "mlir/IR/SymbolInterfaces.td" //===----------------------------------------------------------------------===// class PDL_Op traits = []> - : Op { - let printer = [{ ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + : Op; //===----------------------------------------------------------------------===// // pdl::ApplyNativeConstraintOp diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index e3d83970f303..bfb572dab4e1 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -623,8 +623,7 @@ def PDLInterp_ForEachOp /// Returns the loop variable. BlockArgument getLoopVariable() { return region().getArgument(0); } }]; - let parser = [{ return ::parseForEachOp(parser, result); }]; - let printer = [{ return ::print(p, *this); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td index b06c90093ff6..a2218d13ab0f 100644 --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -26,15 +26,11 @@ def SCF_Dialect : Dialect { // Base class for SCF dialect ops. class SCF_Op traits = []> : - Op { - // For every standard op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; + +//===----------------------------------------------------------------------===// +// ConditionOp +//===----------------------------------------------------------------------===// def ConditionOp : SCF_Op<"condition", [ HasParent<"WhileOp">, @@ -107,6 +103,7 @@ def ExecuteRegionOp : SCF_Op<"execute_region"> { let regions = (region AnyRegion:$region); let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 0; let hasVerifier = 1; @@ -308,6 +305,7 @@ def ForOp : SCF_Op<"for", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -401,6 +399,7 @@ def IfOp : SCF_Op<"if", }]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -483,6 +482,7 @@ def ParallelOp : SCF_Op<"parallel", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -531,6 +531,7 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { ]; let arguments = (ins AnyType:$operand); + let hasCustomAssemblyFormat = 1; let regions = (region SizedRegion<1>:$reductionOperator); let hasVerifier = 1; } @@ -684,6 +685,7 @@ def WhileOp : SCF_Op<"while", }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td index d15de160ecbd..f8140fda5823 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td @@ -16,9 +16,6 @@ class SPV_AtomicUpdateOp traits = []> : SPV_Op { - let parser = [{ return ::parseAtomicUpdateOp(parser, result, false); }]; - let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; - let arguments = (ins SPV_AnyPtr:$pointer, SPV_ScopeAttr:$memory_scope, @@ -32,9 +29,6 @@ class SPV_AtomicUpdateOp traits = []> : class SPV_AtomicUpdateWithValueOp traits = []> : SPV_Op { - let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; - let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; - let arguments = (ins SPV_AnyPtr:$pointer, SPV_ScopeAttr:$memory_scope, @@ -163,9 +157,6 @@ def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> { let results = (outs SPV_Integer:$result ); - - let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }]; - let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }]; } // ----- @@ -215,9 +206,6 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { let results = (outs SPV_Integer:$result ); - - let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }]; - let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }]; } // ----- @@ -331,9 +319,6 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> { let results = (outs SPV_Float:$result ); - - let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; - let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 055835ea4808..6d494859c03f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4269,12 +4269,11 @@ class SPV_Op traits = []> : // For each SPIR-V op, the following static functions need to be defined // in SPVOps.cpp: // - // * static ParseResult parse(OpAsmParser &parser, - // OperationState &result) - // * static void print(OpAsmPrinter &p, op) + // * ParseResult ::parse(OpAsmParser &parser, + // OperationState &result) + // * void ::print(OpAsmPrinter &p) // * LogicalResult ::verify() - let parser = [{ return ::parse$cppClass(parser, result); }]; - let printer = [{ return ::print(*this, p); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; // Specifies whether this op has a direct corresponding SPIR-V binary @@ -4320,8 +4319,7 @@ class SPV_UnaryOp:$result ); - let parser = [{ return ::parseUnaryOp(parser, result); }]; - let printer = [{ return ::printUnaryOp(getOperation(), p); }]; + let assemblyFormat = "$operand `:` type($operand) attr-dict"; // No additional verification needed in addition to the ODS-generated ones. let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td index b9b31f6dd5ab..71884165e9f6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -53,9 +53,11 @@ class SPV_BitUnaryOp traits = []> : class SPV_ShiftOp traits = []> : SPV_BinaryOp { - let parser = [{ return ::parseShiftOp(parser, result); }]; - let printer = [{ ::printShiftOp(this->getOperation(), p); }]; + [NoSideEffect, SameOperandsAndResultShape, + AllTypesMatch<["operand1", "result"]>])> { + let assemblyFormat = [{ + operands attr-dict `:` type($operand1) `,` type($operand2) + }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td index f0d5515d2bcf..dbfcc727b05a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -45,10 +45,6 @@ class SPV_GLSLUnaryOp:$result ); - let parser = [{ return parseUnaryOp(parser, result); }]; - - let printer = [{ return printUnaryOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -56,7 +52,10 @@ class SPV_GLSLUnaryOp traits = []> : - SPV_GLSLUnaryOp; + SPV_GLSLUnaryOp { + let assemblyFormat = "$operand `:` type($operand) attr-dict"; +} // Base class for GLSL binary ops. class SPV_GLSLBinaryOp:$result ); - let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ return printOneResultOp(getOperation(), p); }]; - + let hasCustomAssemblyFormat = 1; let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 9c490830547d..a6ec4969beea 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -21,11 +21,15 @@ class SPV_LogicalBinaryOp traits = []> : // Result type is SPV_Bool. SPV_BinaryOp { - let parser = [{ return ::parseLogicalBinaryOp(parser, result); }]; - let printer = [{ return ::printLogicalOp(getOperation(), p); }]; + !listconcat(traits, [ + NoSideEffect, SameTypeOperands, + SameOperandsAndResultShape, + TypesMatchWith<"type of result to correspond to the `i1` " + "equivalent of the operand", + "operand1", "result", + "getUnaryOpResultType($_self)" + >])> { + let assemblyFormat = "$operand1 `,` $operand2 `:` type($operand1) attr-dict"; let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs), @@ -37,10 +41,14 @@ class SPV_LogicalUnaryOp traits = []> : // Result type is SPV_Bool. SPV_UnaryOp { - let parser = [{ return ::parseLogicalUnaryOp(parser, result); }]; - let printer = [{ return ::printLogicalOp(getOperation(), p); }]; + !listconcat(traits, [ + NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, + TypesMatchWith<"type of result to correspond to the `i1` " + "equivalent of the operand", + "operand", "result", + "getUnaryOpResultType($_self)" + >])> { + let assemblyFormat = "$operand `:` type($operand) attr-dict"; let builders = [ OpBuilder<(ins "Value":$value), diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index 1077e8a802a8..534a263ff891 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -27,9 +27,6 @@ class SPV_GroupNonUniformArithmeticOp:$result ); - - let parser = [{ return parseGroupNonUniformArithmeticOp(parser, result); }]; - let printer = [{ printGroupNonUniformArithmeticOp(getOperation(), p); }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td index 0b3c08a51011..fcc69a49d559 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -44,9 +44,7 @@ class SPV_OCLUnaryOp:$result ); - let parser = [{ return parseUnaryOp(parser, result); }]; - - let printer = [{ return printUnaryOp(getOperation(), p); }]; + let assemblyFormat = "$operand `:` type($operand) attr-dict"; let hasVerifier = 0; } @@ -55,7 +53,8 @@ class SPV_OCLUnaryOp traits = []> : - SPV_OCLUnaryOp; + SPV_OCLUnaryOp; // Base class for OpenCL binary ops. class SPV_OCLBinaryOp]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; } def Shape_ShapeOfOp : Shape_Op<"shape_of", @@ -883,9 +880,6 @@ def Shape_AssumingOp : Shape_Op<"assuming", [ let regions = (region SizedRegion<1>:$doRegion); let results = (outs Variadic:$results); - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; - let extraClassDeclaration = [{ // Inline the region into the region containing the AssumingOp and delete // the AssumingOp. @@ -900,6 +894,7 @@ def Shape_AssumingOp : Shape_Op<"assuming", [ ]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1086,9 +1081,7 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library", let builders = [OpBuilder<(ins "StringRef":$name)>]; let skipDefaultBuilders = 1; - - let printer = [{ ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasCustomAssemblyFormat = 1; } #endif // SHAPE_OPS diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 278fedbbd3cb..0000714fc50c 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -19,10 +19,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// class SparseTensor_Op traits = []> - : Op { - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + : Op; //===----------------------------------------------------------------------===// // Sparse Tensor Operations. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index ffd16c16c09d..4fa909c56b69 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -31,15 +31,7 @@ def StandardOps_Dialect : Dialect { // Base class for Standard dialect ops. class Std_Op traits = []> : - Op { - // For every standard op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; //===----------------------------------------------------------------------===// // CallOp diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index a4dc2953b201..3a2ec73791d3 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -18,10 +18,7 @@ include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" class Tensor_Op traits = []> - : Op { - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + : Op; // Base class for ops with static/dynamic offset, sizes and strides // attributes/arguments. @@ -737,9 +734,8 @@ class Tensor_ReassociativeReshapeOp traits = []> : let hasFolder = 1; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6feb8e0fa4a8..6d11514f0b64 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -26,15 +26,7 @@ def Vector_Dialect : Dialect { // Base class for Vector dialect ops. class Vector_Op traits = []> : - Op { - // For every vector op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} + Op; // The "kind" of combining function for contractions and reductions. def COMBINING_KIND_ADD : BitEnumAttrCaseBit<"ADD", 0, "add">; @@ -253,6 +245,7 @@ def Vector_ContractionOp : }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -289,6 +282,7 @@ def Vector_ReductionOp : return vector().getType().cast(); } }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -478,6 +472,7 @@ def Vector_ShuffleOp : return vector().getType().cast(); } }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -559,6 +554,7 @@ def Vector_ExtractOp : } }]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -968,6 +964,7 @@ def Vector_OuterProductOp : return CombiningKind::ADD; } }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -1350,6 +1347,7 @@ def Vector_TransferReadOp : CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasVerifier = 1; } @@ -1489,6 +1487,7 @@ def Vector_TransferWriteOp : ]; let hasFolder = 1; let hasCanonicalizer = 1; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td index 93ab8eea0f9e..c447c7924842 100644 --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -155,8 +155,7 @@ def FuncOp : Builtin_Op<"func", [ bool isDeclaration() { return isExternal(); } }]; - let parser = [{ return ::parseFuncOp(parser, result); }]; - let printer = [{ return ::print(*this, p); }]; + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index eb6a0bd730fc..d3b3c6e471b8 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -493,8 +493,7 @@ AffineValueMap AffineApplyOp::getAffineValueMap() { return AffineValueMap(getAffineMap(), getOperands(), getResult()); } -static ParseResult parseAffineApplyOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); @@ -516,11 +515,11 @@ static ParseResult parseAffineApplyOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, AffineApplyOp op) { - p << " " << op.mapAttr(); - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - op.getAffineMap().getNumDims(), p); - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"}); +void AffineApplyOp::print(OpAsmPrinter &p) { + p << " " << mapAttr(); + printDimAndSymbolList(operand_begin(), operand_end(), + getAffineMap().getNumDims(), p); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"}); } LogicalResult AffineApplyOp::verify() { @@ -1434,8 +1433,7 @@ static ParseResult parseBound(bool isLower, OperationState &result, "expected valid affine map representation for loop bounds"); } -static ParseResult parseAffineForOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType inductionVariable; // Parse the induction variable followed by '='. @@ -1551,37 +1549,36 @@ unsigned AffineForOp::getNumIterOperands() { return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); } -static void print(OpAsmPrinter &p, AffineForOp op) { +void AffineForOp::print(OpAsmPrinter &p) { p << ' '; - p.printOperand(op.getBody()->getArgument(0)); + p.printOperand(getBody()->getArgument(0)); p << " = "; - printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); + printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p); p << " to "; - printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); + printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p); - if (op.getStep() != 1) - p << " step " << op.getStep(); + if (getStep() != 1) + p << " step " << getStep(); bool printBlockTerminators = false; - if (op.getNumIterOperands() > 0) { + if (getNumIterOperands() > 0) { p << " iter_args("; - auto regionArgs = op.getRegionIterArgs(); - auto operands = op.getIterOperands(); + auto regionArgs = getRegionIterArgs(); + auto operands = getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); - p << ") -> (" << op.getResultTypes() << ")"; + p << ") -> (" << getResultTypes() << ")"; printBlockTerminators = true; } p << ' '; - p.printRegion(op.region(), - /*printEntryBlockArgs=*/false, printBlockTerminators); - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getLowerBoundAttrName(), - op.getUpperBoundAttrName(), - op.getStepAttrName()}); + p.printRegion(region(), /*printEntryBlockArgs=*/false, printBlockTerminators); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{getLowerBoundAttrName(), + getUpperBoundAttrName(), + getStepAttrName()}); } /// Fold the constant bounds of a loop. @@ -2081,8 +2078,7 @@ LogicalResult AffineIfOp::verify() { return success(); } -static ParseResult parseAffineIfOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the condition attribute set. IntegerSetAttr conditionAttr; unsigned numDims; @@ -2132,30 +2128,29 @@ static ParseResult parseAffineIfOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, AffineIfOp op) { +void AffineIfOp::print(OpAsmPrinter &p) { auto conditionAttr = - op->getAttrOfType(op.getConditionAttrName()); + (*this)->getAttrOfType(getConditionAttrName()); p << " " << conditionAttr; - printDimAndSymbolList(op.operand_begin(), op.operand_end(), + printDimAndSymbolList(operand_begin(), operand_end(), conditionAttr.getValue().getNumDims(), p); - p.printOptionalArrowTypeList(op.getResultTypes()); + p.printOptionalArrowTypeList(getResultTypes()); p << ' '; - p.printRegion(op.thenRegion(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/op.getNumResults()); + p.printRegion(thenRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/getNumResults()); // Print the 'else' regions if it has any blocks. - auto &elseRegion = op.elseRegion(); + auto &elseRegion = this->elseRegion(); if (!elseRegion.empty()) { p << " else "; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/op.getNumResults()); + /*printBlockTerminators=*/getNumResults()); } // Print the attribute list. - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/op.getConditionAttrName()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/getConditionAttrName()); } IntegerSet AffineIfOp::getIntegerSet() { @@ -2259,8 +2254,7 @@ void AffineLoadOp::build(OpBuilder &builder, OperationState &result, build(builder, result, memref, map, indices); } -static ParseResult parseAffineLoadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); @@ -2280,15 +2274,15 @@ static ParseResult parseAffineLoadOp(OpAsmParser &parser, parser.addTypeToList(type.getElementType(), result.types)); } -static void print(OpAsmPrinter &p, AffineLoadOp op) { - p << " " << op.getMemRef() << '['; +void AffineLoadOp::print(OpAsmPrinter &p) { + p << " " << getMemRef() << '['; if (AffineMapAttr mapAttr = - op->getAttrOfType(op.getMapAttrName())) - p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); + (*this)->getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getMapAttrName()}); - p << " : " << op.getMemRefType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{getMapAttrName()}); + p << " : " << getMemRefType(); } /// Verify common indexing invariants of affine.load, affine.store, @@ -2374,8 +2368,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result, build(builder, result, valueToStore, memref, map, indices); } -static ParseResult parseAffineStoreOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType type; @@ -2396,16 +2389,16 @@ static ParseResult parseAffineStoreOp(OpAsmParser &parser, parser.resolveOperands(mapOperands, indexTy, result.operands)); } -static void print(OpAsmPrinter &p, AffineStoreOp op) { - p << " " << op.getValueToStore(); - p << ", " << op.getMemRef() << '['; +void AffineStoreOp::print(OpAsmPrinter &p) { + p << " " << getValueToStore(); + p << ", " << getMemRef() << '['; if (AffineMapAttr mapAttr = - op->getAttrOfType(op.getMapAttrName())) - p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); + (*this)->getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getMapAttrName()}); - p << " : " << op.getMemRefType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{getMapAttrName()}); + p << " : " << getMemRefType(); } LogicalResult AffineStoreOp::verify() { @@ -2669,6 +2662,12 @@ void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns, LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); } +ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAffineMinMaxOp(parser, result); +} + +void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); } + //===----------------------------------------------------------------------===// // AffineMaxOp //===----------------------------------------------------------------------===// @@ -2690,6 +2689,12 @@ void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); } +ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseAffineMinMaxOp(parser, result); +} + +void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); } + //===----------------------------------------------------------------------===// // AffinePrefetchOp //===----------------------------------------------------------------------===// @@ -2697,8 +2702,8 @@ LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); } // // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> // -static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffinePrefetchOp::parse(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); @@ -2746,21 +2751,20 @@ static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, AffinePrefetchOp op) { - p << " " << op.memref() << '['; - AffineMapAttr mapAttr = op->getAttrOfType(op.getMapAttrName()); - if (mapAttr) { - SmallVector operands(op.getMapOperands()); - p.printAffineMapOfSSAIds(mapAttr, operands); - } - p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " - << "locality<" << op.localityHint() << ">, " - << (op.isDataCache() ? "data" : "instr"); +void AffinePrefetchOp::print(OpAsmPrinter &p) { + p << " " << memref() << '['; + AffineMapAttr mapAttr = + (*this)->getAttrOfType(getMapAttrName()); + if (mapAttr) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); + p << ']' << ", " << (isWrite() ? "write" : "read") << ", " + << "locality<" << localityHint() << ">, " + << (isDataCache() ? "data" : "instr"); p.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), - op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); - p << " : " << op.getMemRefType(); + (*this)->getAttrs(), + /*elidedAttrs=*/{getMapAttrName(), getLocalityHintAttrName(), + getIsDataCacheAttrName(), getIsWriteAttrName()}); + p << " : " << getMemRefType(); } LogicalResult AffinePrefetchOp::verify() { @@ -3133,36 +3137,36 @@ static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, } } -static void print(OpAsmPrinter &p, AffineParallelOp op) { - p << " (" << op.getBody()->getArguments() << ") = ("; - printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(), - op.getLowerBoundsOperands(), "max"); +void AffineParallelOp::print(OpAsmPrinter &p) { + p << " (" << getBody()->getArguments() << ") = ("; + printMinMaxBound(p, lowerBoundsMapAttr(), lowerBoundsGroupsAttr(), + getLowerBoundsOperands(), "max"); p << ") to ("; - printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(), - op.getUpperBoundsOperands(), "min"); + printMinMaxBound(p, upperBoundsMapAttr(), upperBoundsGroupsAttr(), + getUpperBoundsOperands(), "min"); p << ')'; - SmallVector steps = op.getSteps(); + SmallVector steps = getSteps(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); if (!elideSteps) { p << " step ("; llvm::interleaveComma(steps, p); p << ')'; } - if (op.getNumResults()) { + if (getNumResults()) { p << " reduce ("; - llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { + llvm::interleaveComma(reductions(), p, [&](auto &attr) { arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind( attr.template cast().getInt()); p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\""; }); - p << ") -> (" << op.getResultTypes() << ")"; + p << ") -> (" << getResultTypes() << ")"; } p << ' '; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/op.getNumResults()); + p.printRegion(region(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/getNumResults()); p.printOptionalAttrDict( - op->getAttrs(), + (*this)->getAttrs(), /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), AffineParallelOp::getLowerBoundsMapAttrName(), AffineParallelOp::getLowerBoundsGroupsAttrName(), @@ -3319,8 +3323,8 @@ static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, // `to` parallel-bound steps? region attr-dict? // steps ::= `steps` `(` integer-literals `)` // -static ParseResult parseAffineParallelOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineParallelOp::parse(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); SmallVector ivs; @@ -3469,8 +3473,8 @@ void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } -static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); @@ -3492,15 +3496,15 @@ static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser, parser.addTypeToList(resultType, result.types)); } -static void print(OpAsmPrinter &p, AffineVectorLoadOp op) { - p << " " << op.getMemRef() << '['; +void AffineVectorLoadOp::print(OpAsmPrinter &p) { + p << " " << getMemRef() << '['; if (AffineMapAttr mapAttr = - op->getAttrOfType(op.getMapAttrName())) - p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); + (*this)->getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getMapAttrName()}); - p << " : " << op.getMemRefType() << ", " << op.getType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{getMapAttrName()}); + p << " : " << getMemRefType() << ", " << getType(); } /// Verify common invariants of affine.vector_load and affine.vector_store. @@ -3559,8 +3563,8 @@ void AffineVectorStoreOp::getCanonicalizationPatterns( results.add>(context); } -static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser, + OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType memrefType; @@ -3583,16 +3587,16 @@ static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, parser.resolveOperands(mapOperands, indexTy, result.operands)); } -static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { - p << " " << op.getValueToStore(); - p << ", " << op.getMemRef() << '['; +void AffineVectorStoreOp::print(OpAsmPrinter &p) { + p << " " << getValueToStore(); + p << ", " << getMemRef() << '['; if (AffineMapAttr mapAttr = - op->getAttrOfType(op.getMapAttrName())) - p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); + (*this)->getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getMapAttrName()}); - p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{getMapAttrName()}); + p << " : " << getMemRefType() << ", " << getValueToStore().getType(); } LogicalResult AffineVectorStoreOp::verify() { diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp index c836444ffb6c..fffc1b98b008 100644 --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1506,16 +1506,7 @@ OpFoldResult arith::SelectOp::fold(ArrayRef operands) { return nullptr; } -static void print(OpAsmPrinter &p, arith::SelectOp op) { - p << " " << op.getOperands(); - p.printOptionalAttrDict(op->getAttrs()); - p << " : "; - if (ShapedType condType = op.getCondition().getType().dyn_cast()) - p << condType << ", "; - p << op.getType(); -} - -static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { +ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { Type conditionType, resultType; SmallVector operands; if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || @@ -1538,6 +1529,15 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { parser.getNameLoc(), result.operands); } +void arith::SelectOp::print(OpAsmPrinter &p) { + p << " " << getOperands(); + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : "; + if (ShapedType condType = getCondition().getType().dyn_cast()) + p << condType << ", "; + p << getType(); +} + LogicalResult arith::SelectOp::verify() { Type conditionType = getCondition().getType(); if (conditionType.isSignlessInteger(1)) diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 7a8c070b4e3e..916e2d1af451 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -125,16 +125,16 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result, } } -static void print(OpAsmPrinter &p, ExecuteOp op) { +void ExecuteOp::print(OpAsmPrinter &p) { // [%tokens,...] - if (!op.dependencies().empty()) - p << " [" << op.dependencies() << "]"; + if (!dependencies().empty()) + p << " [" << dependencies() << "]"; // (%value as %unwrapped: !async.value, ...) - if (!op.operands().empty()) { + if (!operands().empty()) { p << " ("; - Block *entry = op.body().empty() ? nullptr : &op.body().front(); - llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { + Block *entry = body().empty() ? nullptr : &body().front(); + llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable { Value argument = entry ? entry->getArgument(n++) : Value(); p << operand << " as " << argument << ": " << operand.getType(); }); @@ -142,14 +142,14 @@ static void print(OpAsmPrinter &p, ExecuteOp op) { } // -> (!async.value, ...) - p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), + p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes())); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {kOperandSegmentSizesAttr}); p << ' '; - p.printRegion(op.body(), /*printEntryBlockArgs=*/false); + p.printRegion(body(), /*printEntryBlockArgs=*/false); } -static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { +ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = result.getContext(); // Sizes of parsed variadic operands, will be updated below after parsing. diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 383833f685b1..d654c05d290d 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -121,18 +121,18 @@ OpFoldResult emitc::ConstantOp::fold(ArrayRef operands) { // IncludeOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, IncludeOp &op) { - bool standardInclude = op.is_standard_include(); +void IncludeOp::print(OpAsmPrinter &p) { + bool standardInclude = is_standard_include(); p << " "; if (standardInclude) p << "<"; - p << "\"" << op.include() << "\""; + p << "\"" << include() << "\""; if (standardInclude) p << ">"; } -static ParseResult parseIncludeOp(OpAsmParser &parser, OperationState &result) { +ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) { bool standardInclude = !parser.parseOptionalLess(); StringAttr include; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 731c1e8d330a..013ab7d31309 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -445,21 +445,21 @@ static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, p << size.z << " = " << operands.z << ')'; } -static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) { +void LaunchOp::print(OpAsmPrinter &p) { // Print the launch configuration. - p << ' ' << op.getBlocksKeyword(); - printSizeAssignment(p, op.getGridSize(), op.getGridSizeOperandValues(), - op.getBlockIds()); - p << ' ' << op.getThreadsKeyword(); - printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(), - op.getThreadIds()); - if (op.dynamicSharedMemorySize()) - p << ' ' << op.getDynamicSharedMemorySizeKeyword() << ' ' - << op.dynamicSharedMemorySize(); + p << ' ' << getBlocksKeyword(); + printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(), + getBlockIds()); + p << ' ' << getThreadsKeyword(); + printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(), + getThreadIds()); + if (dynamicSharedMemorySize()) + p << ' ' << getDynamicSharedMemorySizeKeyword() << ' ' + << dynamicSharedMemorySize(); p << ' '; - p.printRegion(op.body(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op->getAttrs()); + p.printRegion(body(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict((*this)->getAttrs()); } // Parse the size assignment blocks for blocks and threads. These have the form @@ -492,12 +492,14 @@ parseSizeAssignment(OpAsmParser &parser, return parser.parseRParen(); } -// Parses a Launch operation. -// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment -// `threads` `(` ssa-id-list `)` `in` ssa-reassignment -// region attr-dict? -// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` -static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) { +/// Parses a Launch operation. +/// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` +/// ssa-reassignment +/// `threads` `(` ssa-id-list `)` `in` +/// ssa-reassignment +/// region attr-dict? +/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` +ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) { // Sizes of the grid and block. SmallVector sizes( LaunchOp::kNumConfigOperands); @@ -778,7 +780,7 @@ parseAttributions(OpAsmParser &parser, StringRef keyword, /// ::= `gpu.func` symbol-ref-id `(` argument-list `)` /// (`->` function-result-list)? memory-attribution `kernel`? /// function-attributes? region -static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { +ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector entryArgs; SmallVector argAttrs; SmallVector resultAttrs; @@ -853,27 +855,26 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword, p << ')'; } -/// Prints a GPU Func op. -static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) { +void GPUFuncOp::print(OpAsmPrinter &p) { p << ' '; - p.printSymbolName(op.getName()); + p.printSymbolName(getName()); - FunctionType type = op.getType(); - function_interface_impl::printFunctionSignature( - p, op.getOperation(), type.getInputs(), - /*isVariadic=*/false, type.getResults()); + FunctionType type = getType(); + function_interface_impl::printFunctionSignature(p, *this, type.getInputs(), + /*isVariadic=*/false, + type.getResults()); - printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions()); - printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions()); - if (op.isKernel()) - p << ' ' << op.getKernelKeyword(); + printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions()); + printAttributions(p, getPrivateKeyword(), getPrivateAttributions()); + if (isKernel()) + p << ' ' << getKernelKeyword(); function_interface_impl::printFunctionAttributes( - p, op.getOperation(), type.getNumInputs(), type.getNumResults(), - {op.getNumWorkgroupAttributionsAttrName(), + p, *this, type.getNumInputs(), type.getNumResults(), + {getNumWorkgroupAttributionsAttrName(), GPUDialect::getKernelFuncAttrName()}); p << ' '; - p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); + p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } LogicalResult GPUFuncOp::verifyType() { @@ -970,10 +971,9 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result, ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); } -static ParseResult parseGPUModuleOp(OpAsmParser &parser, - OperationState &result) { +ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); @@ -991,13 +991,13 @@ static ParseResult parseGPUModuleOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, GPUModuleOp op) { +void GPUModuleOp::print(OpAsmPrinter &p) { p << ' '; - p.printSymbolName(op.getName()); - p.printOptionalAttrDictWithKeyword(op->getAttrs(), - {SymbolTable::getSymbolAttrName()}); + p.printSymbolName(getName()); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {mlir::SymbolTable::getSymbolAttrName()}); p << ' '; - p.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index fb2cd2546d5f..8714faf474e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -90,18 +90,19 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::CmpOp. //===----------------------------------------------------------------------===// -static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { - p << " \"" << stringifyICmpPredicate(op.getPredicate()) << "\" " - << op.getOperand(0) << ", " << op.getOperand(1); - p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); - p << " : " << op.getLhs().getType(); + +void ICmpOp::print(OpAsmPrinter &p) { + p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) + << ", " << getOperand(1); + p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); + p << " : " << getLhs().getType(); } -static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { - p << " \"" << stringifyFCmpPredicate(op.getPredicate()) << "\" " - << op.getOperand(0) << ", " << op.getOperand(1); - p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); - p << " : " << op.getLhs().getType(); +void FCmpOp::print(OpAsmPrinter &p) { + p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) + << ", " << getOperand(1); + p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); + p << " : " << getLhs().getType(); } // ::= `llvm.icmp` string-literal ssa-use `,` ssa-use @@ -171,27 +172,35 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { return success(); } +ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { + return parseCmpOp(parser, result); +} + +ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { + return parseCmpOp(parser, result); +} + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::AllocaOp. //===----------------------------------------------------------------------===// -static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { - auto elemTy = op.getType().cast().getElementType(); +void AllocaOp::print(OpAsmPrinter &p) { + auto elemTy = getType().cast().getElementType(); - auto funcTy = FunctionType::get( - op.getContext(), {op.getArraySize().getType()}, {op.getType()}); + auto funcTy = + FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); - p << ' ' << op.getArraySize() << " x " << elemTy; - if (op.getAlignment().hasValue() && *op.getAlignment() != 0) - p.printOptionalAttrDict(op->getAttrs()); + p << ' ' << getArraySize() << " x " << elemTy; + if (getAlignment().hasValue() && *getAlignment() != 0) + p.printOptionalAttrDict((*this)->getAttrs()); else - p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); + p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"}); p << " : " << funcTy; } // ::= `llvm.alloca` ssa-use `x` type attribute-dict? // `:` type `,` type -static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { +ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType arraySize; Type type, elemType; SMLoc trailingTypeLoc; @@ -627,13 +636,13 @@ void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); } -static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { +void LoadOp::print(OpAsmPrinter &p) { p << ' '; - if (op.getVolatile_()) + if (getVolatile_()) p << "volatile "; - p << op.getAddr(); - p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); - p << " : " << op.getAddr().getType(); + p << getAddr(); + p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); + p << " : " << getAddr().getType(); } // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return @@ -648,7 +657,7 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type, } // ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type -static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { +ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType addr; Type type; SMLoc trailingTypeLoc; @@ -687,18 +696,18 @@ void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); } -static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { +void StoreOp::print(OpAsmPrinter &p) { p << ' '; - if (op.getVolatile_()) + if (getVolatile_()) p << "volatile "; - p << op.getValue() << ", " << op.getAddr(); - p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); - p << " : " << op.getAddr().getType(); + p << getValue() << ", " << getAddr(); + p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); + p << " : " << getAddr().getType(); } // ::= `llvm.store` `volatile` ssa-use `,` ssa-use // attribute-dict? `:` type -static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { +ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType addr, value; Type type; SMLoc trailingTypeLoc; @@ -750,8 +759,8 @@ LogicalResult InvokeOp::verify() { return success(); } -static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { - auto callee = op.getCallee(); +void InvokeOp::print(OpAsmPrinter &p) { + auto callee = getCallee(); bool isDirect = callee.hasValue(); p << ' '; @@ -760,27 +769,26 @@ static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { if (isDirect) p.printSymbolName(callee.getValue()); else - p << op.getOperand(0); + p << getOperand(0); - p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; + p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; p << " to "; - p.printSuccessorAndUseList(op.getNormalDest(), op.getNormalDestOperands()); + p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); p << " unwind "; - p.printSuccessorAndUseList(op.getUnwindDest(), op.getUnwindDestOperands()); + p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); - p.printOptionalAttrDict(op->getAttrs(), + p.printOptionalAttrDict((*this)->getAttrs(), {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); p << " : "; - p.printFunctionalType( - llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), - op.getResultTypes()); + p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), + getResultTypes()); } /// ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` /// `to` bb-id (`[` ssa-use-and-type-list `]`)? /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? /// attribute-dict? `:` function-type -static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { +ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; FunctionType funcType; SymbolRefAttr funcAttr; @@ -913,11 +921,11 @@ LogicalResult LandingpadOp::verify() { return success(); } -static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { - p << (op.getCleanup() ? " cleanup " : " "); +void LandingpadOp::print(OpAsmPrinter &p) { + p << (getCleanup() ? " cleanup " : " "); // Clauses - for (auto value : op.getOperands()) { + for (auto value : getOperands()) { // Similar to llvm - if clause is an array type then it is filter // clause else catch clause bool isArrayTy = value.getType().isa(); @@ -925,15 +933,14 @@ static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { << value.getType() << ") "; } - p.printOptionalAttrDict(op->getAttrs(), {"cleanup"}); + p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); - p << ": " << op.getType(); + p << ": " << getType(); } /// ::= `llvm.landingpad` `cleanup`? /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? -static ParseResult parseLandingpadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { // Check for cleanup if (succeeded(parser.parseOptionalKeyword("cleanup"))) result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); @@ -1045,8 +1052,8 @@ LogicalResult CallOp::verify() { return success(); } -static void printCallOp(OpAsmPrinter &p, CallOp &op) { - auto callee = op.getCallee(); +void CallOp::print(OpAsmPrinter &p) { + auto callee = getCallee(); bool isDirect = callee.hasValue(); // Print the direct callee if present as a function attribute, or an indirect @@ -1055,20 +1062,20 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) { if (isDirect) p.printSymbolName(callee.getValue()); else - p << op.getOperand(0); + p << getOperand(0); - auto args = op.getOperands().drop_front(isDirect ? 0 : 1); + auto args = getOperands().drop_front(isDirect ? 0 : 1); p << '(' << args << ')'; - p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"}); + p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); // Reconstruct the function MLIR function type from operand and result types. - p << " : " - << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); + p << " : "; + p.printFunctionalType(args.getTypes(), getResultTypes()); } // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` // attribute-dict? `:` function-type -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { +ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; Type type; SymbolRefAttr funcAttr; @@ -1162,17 +1169,17 @@ void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, result.addAttributes(attrs); } -static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { - p << ' ' << op.getVector() << "[" << op.getPosition() << " : " - << op.getPosition().getType() << "]"; - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op.getVector().getType(); +void ExtractElementOp::print(OpAsmPrinter &p) { + p << ' ' << getVector() << "[" << getPosition() << " : " + << getPosition().getType() << "]"; + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getVector().getType(); } // ::= `llvm.extractelement` ssa-use `, ` ssa-use // attribute-dict? `:` type -static ParseResult parseExtractElementOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ExtractElementOp::parse(OpAsmParser &parser, + OperationState &result) { SMLoc loc; OpAsmParser::OperandType vector, position; Type type, positionType; @@ -1209,10 +1216,10 @@ LogicalResult ExtractElementOp::verify() { // Printing/parsing for LLVM::ExtractValueOp. //===----------------------------------------------------------------------===// -static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { - p << ' ' << op.getContainer() << op.getPosition(); - p.printOptionalAttrDict(op->getAttrs(), {"position"}); - p << " : " << op.getContainer().getType(); +void ExtractValueOp::print(OpAsmPrinter &p) { + p << ' ' << getContainer() << getPosition(); + p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); + p << " : " << getContainer().getType(); } // Extract the type at `position` in the wrapped LLVM IR aggregate type @@ -1308,8 +1315,7 @@ static Type getInsertExtractValueElementType(Type containerType, // ::= `llvm.extractvalue` ssa-use // `[` integer-literal (`,` integer-literal)* `]` // attribute-dict? `:` type -static ParseResult parseExtractValueOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType container; Type containerType; ArrayAttr positionAttr; @@ -1376,17 +1382,17 @@ LogicalResult ExtractValueOp::verify() { // Printing/parsing for LLVM::InsertElementOp. //===----------------------------------------------------------------------===// -static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { - p << ' ' << op.getValue() << ", " << op.getVector() << "[" << op.getPosition() - << " : " << op.getPosition().getType() << "]"; - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op.getVector().getType(); +void InsertElementOp::print(OpAsmPrinter &p) { + p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " + << getPosition().getType() << "]"; + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getVector().getType(); } // ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use // attribute-dict? `:` type -static ParseResult parseInsertElementOp(OpAsmParser &parser, - OperationState &result) { +ParseResult InsertElementOp::parse(OpAsmParser &parser, + OperationState &result) { SMLoc loc; OpAsmParser::OperandType vector, value, position; Type vectorType, positionType; @@ -1427,17 +1433,16 @@ LogicalResult InsertElementOp::verify() { // Printing/parsing for LLVM::InsertValueOp. //===----------------------------------------------------------------------===// -static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { - p << ' ' << op.getValue() << ", " << op.getContainer() << op.getPosition(); - p.printOptionalAttrDict(op->getAttrs(), {"position"}); - p << " : " << op.getContainer().getType(); +void InsertValueOp::print(OpAsmPrinter &p) { + p << ' ' << getValue() << ", " << getContainer() << getPosition(); + p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); + p << " : " << getContainer().getType(); } // ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use // `[` integer-literal (`,` integer-literal)* `]` // attribute-dict? `:` type -static ParseResult parseInsertValueOp(OpAsmParser &parser, - OperationState &result) { +ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType container, value; Type containerType; ArrayAttr positionAttr; @@ -1483,34 +1488,6 @@ LogicalResult InsertValueOp::verify() { // Printing, parsing and verification for LLVM::ReturnOp. //===----------------------------------------------------------------------===// -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p.printOptionalAttrDict(op->getAttrs()); - assert(op.getNumOperands() <= 1); - - if (op.getNumOperands() == 0) - return; - - p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); -} - -// ::= `llvm.return` ssa-use-list attribute-dict? `:` -// type-list-no-parens -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { - SmallVector operands; - Type type; - - if (parser.parseOperandList(operands) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - if (operands.empty()) - return success(); - - if (parser.parseColonType(type) || - parser.resolveOperand(operands[0], type, result.operands)) - return failure(); - return success(); -} - LogicalResult ReturnOp::verify() { if (getNumOperands() > 1) return emitOpError("expected at most 1 operand"); @@ -1636,34 +1613,34 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, result.addRegion(); } -static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { - p << ' ' << stringifyLinkage(op.getLinkage()) << ' '; - if (auto unnamedAddr = op.getUnnamedAddr()) { +void GlobalOp::print(OpAsmPrinter &p) { + p << ' ' << stringifyLinkage(getLinkage()) << ' '; + if (auto unnamedAddr = getUnnamedAddr()) { StringRef str = stringifyUnnamedAddr(*unnamedAddr); if (!str.empty()) p << str << ' '; } - if (op.getConstant()) + if (getConstant()) p << "constant "; - p.printSymbolName(op.getSymName()); + p.printSymbolName(getSymName()); p << '('; - if (auto value = op.getValueOrNull()) + if (auto value = getValueOrNull()) p.printAttribute(value); p << ')'; // Note that the alignment attribute is printed using the // default syntax here, even though it is an inherent attribute // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) - p.printOptionalAttrDict(op->getAttrs(), + p.printOptionalAttrDict((*this)->getAttrs(), {SymbolTable::getSymbolAttrName(), "global_type", "constant", "value", getLinkageAttrName(), getUnnamedAddrAttrName()}); // Print the trailing type unless it's a string global. - if (op.getValueOrNull().dyn_cast_or_null()) + if (getValueOrNull().dyn_cast_or_null()) return; - p << " : " << op.getType(); + p << " : " << getType(); - Region &initializer = op.getInitializerRegion(); + Region &initializer = getInitializerRegion(); if (!initializer.empty()) { p << ' '; p.printRegion(initializer, /*printEntryBlockArgs=*/false); @@ -1720,15 +1697,15 @@ static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, // // The type can be omitted for string attributes, in which case it will be // inferred from the value of the string as [strlen(value) x i8]. -static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { +ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = parser.getContext(); // Parse optional linkage, default to External. - result.addAttribute(getLinkageAttrName(), + result.addAttribute(::getLinkageAttrName(), LLVM::LinkageAttr::get( ctx, parseOptionalLLVMKeyword( parser, result, LLVM::Linkage::External))); // Parse optional UnnamedAddr, default to None. - result.addAttribute(getUnnamedAddrAttrName(), + result.addAttribute(::getUnnamedAddrAttrName(), parser.getBuilder().getI64IntegerAttr( parseOptionalLLVMKeyword( parser, result, LLVM::UnnamedAddr::None))); @@ -1910,17 +1887,17 @@ void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, result.addAttributes(attrs); } -static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { - p << ' ' << op.getV1() << ", " << op.getV2() << " " << op.getMask(); - p.printOptionalAttrDict(op->getAttrs(), {"mask"}); - p << " : " << op.getV1().getType() << ", " << op.getV2().getType(); +void ShuffleVectorOp::print(OpAsmPrinter &p) { + p << ' ' << getV1() << ", " << getV2() << " " << getMask(); + p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); + p << " : " << getV1().getType() << ", " << getV2().getType(); } // ::= `llvm.shufflevector` ssa-use `, ` ssa-use // `[` integer-literal (`,` integer-literal)* `]` // attribute-dict? `:` type -static ParseResult parseShuffleVectorOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, + OperationState &result) { SMLoc loc; OpAsmParser::OperandType v1, v2; ArrayAttr maskAttr; @@ -2035,11 +2012,10 @@ buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, // operation ::= `llvm.func` linkage? function-signature function-attributes? // function-body // -static ParseResult parseLLVMFuncOp(OpAsmParser &parser, - OperationState &result) { +ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { // Default to external linkage if no keyword is provided. result.addAttribute( - getLinkageAttrName(), + ::getLinkageAttrName(), LinkageAttr::get(parser.getContext(), parseOptionalLLVMKeyword( parser, result, LLVM::Linkage::External))); @@ -2083,13 +2059,13 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser, // Print the LLVMFuncOp. Collects argument and result types and passes them to // helper functions. Drops "void" result since it cannot be parsed back. Skips // the external linkage since it is the default value. -static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { +void LLVMFuncOp::print(OpAsmPrinter &p) { p << ' '; - if (op.getLinkage() != LLVM::Linkage::External) - p << stringifyLinkage(op.getLinkage()) << ' '; - p.printSymbolName(op.getName()); + if (getLinkage() != LLVM::Linkage::External) + p << stringifyLinkage(getLinkage()) << ' '; + p.printSymbolName(getName()); - LLVMFunctionType fnType = op.getType(); + LLVMFunctionType fnType = getType(); SmallVector argTypes; SmallVector resTypes; argTypes.reserve(fnType.getNumParams()); @@ -2100,13 +2076,13 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { if (!returnType.isa()) resTypes.push_back(returnType); - function_interface_impl::printFunctionSignature(p, op, argTypes, - op.isVarArg(), resTypes); + function_interface_impl::printFunctionSignature(p, *this, argTypes, + isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); + p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); // Print the body if this is not an external function. - Region &body = op.getBody(); + Region &body = getBody(); if (!body.empty()) { p << ' '; p.printRegion(body, /*printEntryBlockArgs=*/false, @@ -2279,17 +2255,16 @@ static ParseResult parseAtomicOrdering(OpAsmParser &parser, // Printer, parser and verifier for LLVM::AtomicRMWOp. //===----------------------------------------------------------------------===// -static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { - p << ' ' << stringifyAtomicBinOp(op.getBinOp()) << ' ' << op.getPtr() << ", " - << op.getVal() << ' ' << stringifyAtomicOrdering(op.getOrdering()) << ' '; - p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); - p << " : " << op.getRes().getType(); +void AtomicRMWOp::print(OpAsmPrinter &p) { + p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " + << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; + p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); + p << " : " << getRes().getType(); } // ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword // attribute-dict? `:` type -static ParseResult parseAtomicRMWOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { Type type; OpAsmParser::OperandType ptr, val; if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || @@ -2348,19 +2323,19 @@ LogicalResult AtomicRMWOp::verify() { // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. //===----------------------------------------------------------------------===// -static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { - p << ' ' << op.getPtr() << ", " << op.getCmp() << ", " << op.getVal() << ' ' - << stringifyAtomicOrdering(op.getSuccessOrdering()) << ' ' - << stringifyAtomicOrdering(op.getFailureOrdering()); - p.printOptionalAttrDict(op->getAttrs(), +void AtomicCmpXchgOp::print(OpAsmPrinter &p) { + p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' + << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' + << stringifyAtomicOrdering(getFailureOrdering()); + p.printOptionalAttrDict((*this)->getAttrs(), {"success_ordering", "failure_ordering"}); - p << " : " << op.getVal().getType(); + p << " : " << getVal().getType(); } // ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use // keyword keyword attribute-dict? `:` type -static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); Type type; OpAsmParser::OperandType ptr, cmp, val; @@ -2416,7 +2391,7 @@ LogicalResult AtomicCmpXchgOp::verify() { // ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword // attribute-dict? -static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { +ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr sScope; StringRef syncscopeKeyword = "syncscope"; if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { @@ -2434,12 +2409,12 @@ static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { +void FenceOp::print(OpAsmPrinter &p) { StringRef syncscopeKeyword = "syncscope"; p << ' '; - if (!op->getAttr(syncscopeKeyword).cast().getValue().empty()) - p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; - p << stringifyAtomicOrdering(op.getOrdering()); + if (!(*this)->getAttr(syncscopeKeyword).cast().getValue().empty()) + p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; + p << stringifyAtomicOrdering(getOrdering()); } LogicalResult FenceOp::verify() { diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 5d5e8f401212..3557525eb0fe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -46,8 +46,7 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { } // ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type -static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, - OperationState &result) { +ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) { MLIRContext *context = parser.getContext(); auto int32Ty = IntegerType::get(context, 32); auto int1Ty = IntegerType::get(context, 1); @@ -62,6 +61,8 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, parser.getNameLoc(), result.operands)); } +void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); } + LogicalResult CpAsyncOp::verify() { if (size() != 4 && size() != 8 && size() != 16) return emitError("expected byte size to be either 4, 8 or 16."); diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 19aeeadab01e..278f236acba0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -39,8 +39,7 @@ using namespace ROCDL; // ::= // `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc : // result_type` -static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult MubufLoadOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; if (parser.parseOperandList(ops, 5) || parser.parseColonType(type) || @@ -56,11 +55,14 @@ static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser, parser.getNameLoc(), result.operands); } +void MubufLoadOp::print(OpAsmPrinter &p) { + p << " " << getOperands() << " : " << (*this)->getResultTypes(); +} + // ::= // `llvm.amdgcn.buffer.store.* %vdata, %rsrc, %vindex, %offset, %glc, %slc : // result_type` -static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser, - OperationState &result) { +ParseResult MubufStoreOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; if (parser.parseOperandList(ops, 6) || parser.parseColonType(type)) @@ -78,6 +80,10 @@ static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser, return success(); } +void MubufStoreOp::print(OpAsmPrinter &p) { + p << " " << getOperands() << " : " << vdata().getType(); +} + //===----------------------------------------------------------------------===// // ROCDLDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 6ccde902efd8..133ff048a44a 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -517,50 +517,51 @@ void GenericOp::build( /*libraryCall=*/"", bodyBuild, attributes); } -static void print(OpAsmPrinter &p, GenericOp op) { +void GenericOp::print(OpAsmPrinter &p) { p << " "; // Print extra attributes. - auto genericAttrNames = op.linalgTraitAttrNames(); + auto genericAttrNames = linalgTraitAttrNames(); llvm::StringSet<> genericAttrNamesSet; genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; - for (auto attr : op->getAttrs()) + for (auto attr : (*this)->getAttrs()) if (genericAttrNamesSet.count(attr.getName().strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { - auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); + auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs); p << genericDictAttr; } // Printing is shared with named ops, except for the region and attributes - printCommonStructuredOpParts(p, op); + printCommonStructuredOpParts(p, *this); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); bool hasExtraAttrs = false; - for (NamedAttribute n : op->getAttrs()) { + for (NamedAttribute n : (*this)->getAttrs()) { if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref()))) break; } if (hasExtraAttrs) { p << " attrs = "; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/genericAttrNames); } // Print region. - if (!op.region().empty()) { + if (!region().empty()) { p << ' '; - p.printRegion(op.region()); + p.printRegion(region()); } // Print results. - printNamedStructuredOpResults(p, op.result_tensors().getTypes()); + printNamedStructuredOpResults(p, result_tensors().getTypes()); } -static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { +ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) { DictionaryAttr dictAttr; // Parse the core linalg traits that must check into a dictAttr. // The name is unimportant as we will overwrite result.attributes. @@ -988,15 +989,15 @@ LogicalResult InitTensorOp::reifyResultShapes( // YieldOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, linalg::YieldOp op) { - if (op.getNumOperands() > 0) - p << ' ' << op.getOperands(); - p.printOptionalAttrDict(op->getAttrs()); - if (op.getNumOperands() > 0) - p << " : " << op.getOperandTypes(); +void linalg::YieldOp::print(OpAsmPrinter &p) { + if (getNumOperands() > 0) + p << ' ' << getOperands(); + p.printOptionalAttrDict((*this)->getAttrs()); + if (getNumOperands() > 0) + p << " : " << getOperandTypes(); } -static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { +ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector opInfo; SmallVector types; SMLoc loc = parser.getCurrentLocation(); @@ -1137,22 +1138,22 @@ void TiledLoopOp::build(OpBuilder &builder, OperationState &result, } } -static void print(OpAsmPrinter &p, TiledLoopOp op) { - p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to (" - << op.upperBound() << ") step (" << op.step() << ")"; +void TiledLoopOp::print(OpAsmPrinter &p) { + p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to (" + << upperBound() << ") step (" << step() << ")"; - if (!op.inputs().empty()) { + if (!inputs().empty()) { p << " ins ("; - llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p, + llvm::interleaveComma(llvm::zip(getRegionInputArgs(), inputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); }); p << ")"; } - if (!op.outputs().empty()) { + if (!outputs().empty()) { p << " outs ("; - llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p, + llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), outputs()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it) << ": " << std::get<1>(it).getType(); @@ -1160,25 +1161,24 @@ static void print(OpAsmPrinter &p, TiledLoopOp op) { p << ")"; } - if (llvm::any_of(op.iterator_types(), [](Attribute attr) { + if (llvm::any_of(iterator_types(), [](Attribute attr) { return attr.cast().getValue() != getParallelIteratorTypeName(); })) - p << " iterators" << op.iterator_types(); + p << " iterators" << iterator_types(); - if (op.distribution_types().hasValue()) - p << " distribution" << op.distribution_types().getValue(); + if (distribution_types().hasValue()) + p << " distribution" << distribution_types().getValue(); p << ' '; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict( - op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), - getIteratorTypesAttrName(), - getDistributionTypesAttrName()}); + p.printRegion(region(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ + TiledLoopOp::getOperandSegmentSizeAttr(), + getIteratorTypesAttrName(), + getDistributionTypesAttrName()}); } -static ParseResult parseTiledLoopOp(OpAsmParser &parser, - OperationState &result) { +ParseResult TiledLoopOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 33192e1bb704..b64fb00ce4cc 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -206,23 +206,22 @@ void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, // AllocaScopeOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, AllocaScopeOp &op) { +void AllocaScopeOp::print(OpAsmPrinter &p) { bool printBlockTerminators = false; p << ' '; - if (!op.results().empty()) { - p << " -> (" << op.getResultTypes() << ")"; + if (!results().empty()) { + p << " -> (" << getResultTypes() << ")"; printBlockTerminators = true; } p << ' '; - p.printRegion(op.bodyRegion(), + p.printRegion(bodyRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); - p.printOptionalAttrDict(op->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs()); } -static ParseResult parseAllocaScopeOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { // Create a region for the body. result.regions.reserve(1); Region *bodyRegion = result.addRegion(); @@ -778,17 +777,16 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result, result.addOperands({stride, elementsPerStride}); } -static void print(OpAsmPrinter &p, DmaStartOp op) { - p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], " - << op.getDstMemRef() << '[' << op.getDstIndices() << "], " - << op.getNumElements() << ", " << op.getTagMemRef() << '[' - << op.getTagIndices() << ']'; - if (op.isStrided()) - p << ", " << op.getStride() << ", " << op.getNumElementsPerStride(); +void DmaStartOp::print(OpAsmPrinter &p) { + p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], " + << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() + << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; + if (isStrided()) + p << ", " << getStride() << ", " << getNumElementsPerStride(); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op.getSrcMemRef().getType() << ", " - << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType(); + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() + << ", " << getTagMemRef().getType(); } // Parse DmaStartOp. @@ -799,8 +797,7 @@ static void print(OpAsmPrinter &p, DmaStartOp op) { // memref<1024 x f32, 2>, // memref<1 x i32> // -static ParseResult parseDmaStartOp(OpAsmParser &parser, - OperationState &result) { +ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; @@ -993,8 +990,8 @@ LogicalResult GenericAtomicRMWOp::verify() { return hasSideEffects ? failure() : success(); } -static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, - OperationState &result) { +ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser, + OperationState &result) { OpAsmParser::OperandType memref; Type memrefType; SmallVector ivs; @@ -1015,11 +1012,11 @@ static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { - p << ' ' << op.memref() << "[" << op.indices() - << "] : " << op.memref().getType() << ' '; - p.printRegion(op.getRegion()); - p.printOptionalAttrDict(op->getAttrs()); +void GenericAtomicRMWOp::print(OpAsmPrinter &p) { + p << ' ' << memref() << "[" << indices() << "] : " << memref().getType() + << ' '; + p.printRegion(getRegion()); + p.printOptionalAttrDict((*this)->getAttrs()); } //===----------------------------------------------------------------------===// @@ -1163,20 +1160,19 @@ OpFoldResult LoadOp::fold(ArrayRef cstOperands) { // PrefetchOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, PrefetchOp op) { - p << " " << op.memref() << '['; - p.printOperands(op.indices()); - p << ']' << ", " << (op.isWrite() ? "write" : "read"); - p << ", locality<" << op.localityHint(); - p << ">, " << (op.isDataCache() ? "data" : "instr"); +void PrefetchOp::print(OpAsmPrinter &p) { + p << " " << memref() << '['; + p.printOperands(indices()); + p << ']' << ", " << (isWrite() ? "write" : "read"); + p << ", locality<" << localityHint(); + p << ">, " << (isDataCache() ? "data" : "instr"); p.printOptionalAttrDict( - op->getAttrs(), + (*this)->getAttrs(), /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); - p << " : " << op.getMemRefType(); + p << " : " << getMemRefType(); } -static ParseResult parsePrefetchOp(OpAsmParser &parser, - OperationState &result) { +ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; IntegerAttr localityHint; @@ -1374,12 +1370,19 @@ SmallVector ExpandShapeOp::getReassociationExprs() { getReassociationIndices()); } -static void print(OpAsmPrinter &p, ExpandShapeOp op) { - ::mlir::printReshapeOp(p, op); +ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseReshapeLikeOp(parser, result); +} +void ExpandShapeOp::print(OpAsmPrinter &p) { + ::mlir::printReshapeOp(p, *this); } -static void print(OpAsmPrinter &p, CollapseShapeOp op) { - ::mlir::printReshapeOp(p, op); +ParseResult CollapseShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseReshapeLikeOp(parser, result); +} +void CollapseShapeOp::print(OpAsmPrinter &p) { + ::mlir::printReshapeOp(p, *this); } /// Detect whether memref dims [dim, dim + extent) can be reshaped without @@ -2241,15 +2244,13 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, } // transpose $in $permutation attr-dict : type($in) `to` type(results) -static void print(OpAsmPrinter &p, TransposeOp op) { - p << " " << op.in() << " " << op.permutation(); - p.printOptionalAttrDict(op->getAttrs(), - {TransposeOp::getPermutationAttrName()}); - p << " : " << op.in().getType() << " to " << op.getType(); +void TransposeOp::print(OpAsmPrinter &p) { + p << " " << in() << " " << permutation(); + p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); + p << " : " << in().getType() << " to " << getType(); } -static ParseResult parseTransposeOp(OpAsmParser &parser, - OperationState &result) { +ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType in; AffineMap permutation; MemRefType srcType, dstType; @@ -2292,7 +2293,7 @@ OpFoldResult TransposeOp::fold(ArrayRef) { // ViewOp //===----------------------------------------------------------------------===// -static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { +ParseResult ViewOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; SmallVector offsetInfo; SmallVector sizesInfo; @@ -2317,12 +2318,12 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { parser.addTypeToList(dstType, result.types)); } -static void print(OpAsmPrinter &p, ViewOp op) { - p << ' ' << op.getOperand(0) << '['; - p.printOperand(op.byte_shift()); - p << "][" << op.sizes() << ']'; - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +void ViewOp::print(OpAsmPrinter &p) { + p << ' ' << getOperand(0) << '['; + p.printOperand(byte_shift()); + p << "][" << sizes() << ']'; + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getOperand(0).getType() << " to " << getType(); } LogicalResult ViewOp::verify() { diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 465f07666d38..fe5af72e98e1 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -218,8 +218,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern { /// `private` `(` value-list `)`? /// `firstprivate` `(` value-list `)`? /// region attr-dict? -static ParseResult parseParallelOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { Builder &builder = parser.getBuilder(); SmallVector privateOperands, firstprivateOperands, copyOperands, copyinOperands, @@ -390,99 +389,94 @@ static ParseResult parseParallelOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &printer, ParallelOp &op) { +void ParallelOp::print(OpAsmPrinter &printer) { // async()? - if (Value async = op.async()) + if (Value async = this->async()) printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " << async.getType() << ")"; // wait()? - printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); + printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer); // num_gangs()? - if (Value numGangs = op.numGangs()) + if (Value numGangs = this->numGangs()) printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs << ": " << numGangs.getType() << ")"; // num_workers()? - if (Value numWorkers = op.numWorkers()) + if (Value numWorkers = this->numWorkers()) printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers << ": " << numWorkers.getType() << ")"; // vector_length()? - if (Value vectorLength = op.vectorLength()) + if (Value vectorLength = this->vectorLength()) printer << " " << ParallelOp::getVectorLengthKeyword() << "(" << vectorLength << ": " << vectorLength.getType() << ")"; // if()? - if (Value ifCond = op.ifCond()) + if (Value ifCond = this->ifCond()) printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; // self()? - if (Value selfCond = op.selfCond()) + if (Value selfCond = this->selfCond()) printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; // reduction()? - printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(), + printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(), printer); // copy()? - printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer); + printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer); // copyin()? - printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(), - printer); + printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer); // copyin_readonly()? - printOperandList(op.copyinReadonlyOperands(), + printOperandList(copyinReadonlyOperands(), ParallelOp::getCopyinReadonlyKeyword(), printer); // copyout()? - printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(), - printer); + printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer); // copyout_zero()? - printOperandList(op.copyoutZeroOperands(), - ParallelOp::getCopyoutZeroKeyword(), printer); - - // create()? - printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(), + printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(), printer); + // create()? + printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer); + // create_zero()? - printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(), + printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(), printer); // no_create()? - printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(), + printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(), printer); // present()? - printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(), - printer); + printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer); // deviceptr()? - printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), + printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), printer); // attach()? - printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(), - printer); + printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer); // private()? - printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(), + printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(), printer); // firstprivate()? - printOperandList(op.gangFirstPrivateOperands(), + printOperandList(gangFirstPrivateOperands(), ParallelOp::getFirstPrivateKeyword(), printer); printer << ' '; - printer.printRegion(op.region(), + printer.printRegion(region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); printer.printOptionalAttrDictWithKeyword( - op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); + (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); } unsigned ParallelOp::getNumDataOperands() { @@ -518,7 +512,7 @@ Value ParallelOp::getDataOperand(unsigned i) { /// (`private` `(` value-list `)`)? /// (`reduction` `(` value-list `)`)? /// region attr-dict? -static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { +ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { Builder &builder = parser.getBuilder(); unsigned executionMapping = OpenACCExecMapping::NONE; SmallVector operandTypes; @@ -606,12 +600,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void print(OpAsmPrinter &printer, LoopOp &op) { - unsigned execMapping = op.exec_mapping(); +void LoopOp::print(OpAsmPrinter &printer) { + unsigned execMapping = exec_mapping(); if (execMapping & OpenACCExecMapping::GANG) { printer << " " << LoopOp::getGangKeyword(); - Value gangNum = op.gangNum(); - Value gangStatic = op.gangStatic(); + Value gangNum = this->gangNum(); + Value gangStatic = this->gangStatic(); // Print optional gang operands if (gangNum || gangStatic) { @@ -633,7 +627,7 @@ static void print(OpAsmPrinter &printer, LoopOp &op) { printer << " " << LoopOp::getWorkerKeyword(); // Print optional worker operand if present - if (Value workerNum = op.workerNum()) + if (Value workerNum = this->workerNum()) printer << "(" << workerNum << ": " << workerNum.getType() << ")"; } @@ -641,31 +635,30 @@ static void print(OpAsmPrinter &printer, LoopOp &op) { printer << " " << LoopOp::getVectorKeyword(); // Print optional vector operand if present - if (Value vectorLength = op.vectorLength()) + if (Value vectorLength = this->vectorLength()) printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; } // tile()? - printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer); + printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer); // private()? - printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer); + printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer); // reduction()? - printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(), - printer); + printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer); - if (op.getNumResults() > 0) - printer << " -> (" << op.getResultTypes() << ")"; + if (getNumResults() > 0) + printer << " -> (" << getResultTypes() << ")"; printer << ' '; - printer.printRegion(op.region(), + printer.printRegion(region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); printer.printOptionalAttrDictWithKeyword( - op->getAttrs(), {LoopOp::getExecutionMappingAttrName(), - LoopOp::getOperandSegmentSizeAttr()}); + (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(), + LoopOp::getOperandSegmentSizeAttr()}); } LogicalResult acc::LoopOp::verify() { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d15b6ddea163..3fb60f3d95df 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -172,48 +172,47 @@ LogicalResult ParallelOp::verify() { return success(); } -static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { +void ParallelOp::print(OpAsmPrinter &p) { p << " "; - if (auto ifCond = op.if_expr_var()) + if (auto ifCond = if_expr_var()) p << "if(" << ifCond << " : " << ifCond.getType() << ") "; - if (auto threads = op.num_threads_var()) + if (auto threads = num_threads_var()) p << "num_threads(" << threads << " : " << threads.getType() << ") "; - printDataVars(p, op.private_vars(), "private"); - printDataVars(p, op.firstprivate_vars(), "firstprivate"); - printDataVars(p, op.shared_vars(), "shared"); - printDataVars(p, op.copyin_vars(), "copyin"); + printDataVars(p, private_vars(), "private"); + printDataVars(p, firstprivate_vars(), "firstprivate"); + printDataVars(p, shared_vars(), "shared"); + printDataVars(p, copyin_vars(), "copyin"); - if (!op.allocate_vars().empty()) - printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); + if (!allocate_vars().empty()) + printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); - if (auto def = op.default_val()) + if (auto def = default_val()) p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") "; - if (auto bind = op.proc_bind_val()) + if (auto bind = proc_bind_val()) p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; p << ' '; - p.printRegion(op.getRegion()); + p.printRegion(getRegion()); } -static void printTargetOp(OpAsmPrinter &p, TargetOp op) { +void TargetOp::print(OpAsmPrinter &p) { p << " "; - if (auto ifCond = op.if_expr()) + if (auto ifCond = if_expr()) p << "if(" << ifCond << " : " << ifCond.getType() << ") "; - if (auto device = op.device()) + if (auto device = this->device()) p << "device(" << device << " : " << device.getType() << ") "; - if (auto threads = op.thread_limit()) + if (auto threads = thread_limit()) p << "thread_limit(" << threads << " : " << threads.getType() << ") "; - if (op.nowait()) { + if (nowait()) p << "nowait "; - } - p.printRegion(op.getRegion()); + p.printRegion(getRegion()); } //===----------------------------------------------------------------------===// @@ -971,8 +970,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, /// clause ::= if | num-threads | private | firstprivate | shared | copyin | /// allocate | default | proc-bind /// -static ParseResult parseParallelOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector clauses = { ifClause, numThreadsClause, privateClause, firstprivateClause, sharedClause, copyinClause, @@ -1000,7 +998,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser, /// clause-list ::= clause | clause clause-list /// clause ::= if | device | thread_limit | nowait /// -static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) { +ParseResult TargetOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector clauses = {ifClause, deviceClause, threadLimitClause, nowaitClause}; @@ -1031,9 +1029,7 @@ static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) { /// clause-list ::= clause clause-list | empty /// clause ::= private | firstprivate | lastprivate | reduction | allocate | /// nowait -static ParseResult parseSectionsOp(OpAsmParser &parser, - OperationState &result) { - +ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector clauses = {privateClause, firstprivateClause, lastprivateClause, reductionClause, allocateClause, nowaitClause}; @@ -1053,23 +1049,23 @@ static ParseResult parseSectionsOp(OpAsmParser &parser, return success(); } -static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { +void SectionsOp::print(OpAsmPrinter &p) { p << " "; - printDataVars(p, op.private_vars(), "private"); - printDataVars(p, op.firstprivate_vars(), "firstprivate"); - printDataVars(p, op.lastprivate_vars(), "lastprivate"); + printDataVars(p, private_vars(), "private"); + printDataVars(p, firstprivate_vars(), "firstprivate"); + printDataVars(p, lastprivate_vars(), "lastprivate"); - if (!op.reduction_vars().empty()) - printReductionVarList(p, op.reductions(), op.reduction_vars()); + if (!reduction_vars().empty()) + printReductionVarList(p, reductions(), reduction_vars()); - if (!op.allocate_vars().empty()) - printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); + if (!allocate_vars().empty()) + printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); - if (op.nowait()) + if (nowait()) p << "nowait"; p << ' '; - p.printRegion(op.region()); + p.printRegion(region()); } LogicalResult SectionsOp::verify() { @@ -1108,8 +1104,7 @@ LogicalResult SectionsOp::verify() { /// clause-list ::= clause clause-list | empty /// clause ::= private | firstprivate | lastprivate | linear | schedule | // collapse | nowait | ordered | order | reduction -static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { - +ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, @@ -1166,43 +1161,43 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { - auto args = op.getRegion().front().getArguments(); - p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() - << ") to (" << op.upperBound() << ") "; - if (op.inclusive()) { +void WsLoopOp::print(OpAsmPrinter &p) { + auto args = getRegion().front().getArguments(); + p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() + << ") to (" << upperBound() << ") "; + if (inclusive()) { p << "inclusive "; } - p << "step (" << op.step() << ") "; + p << "step (" << step() << ") "; - printDataVars(p, op.private_vars(), "private"); - printDataVars(p, op.firstprivate_vars(), "firstprivate"); - printDataVars(p, op.lastprivate_vars(), "lastprivate"); + printDataVars(p, private_vars(), "private"); + printDataVars(p, firstprivate_vars(), "firstprivate"); + printDataVars(p, lastprivate_vars(), "lastprivate"); - if (!op.linear_vars().empty()) - printLinearClause(p, op.linear_vars(), op.linear_step_vars()); + if (!linear_vars().empty()) + printLinearClause(p, linear_vars(), linear_step_vars()); - if (auto sched = op.schedule_val()) - printScheduleClause(p, sched.getValue(), op.schedule_modifier(), - op.simd_modifier(), op.schedule_chunk_var()); + if (auto sched = schedule_val()) + printScheduleClause(p, sched.getValue(), schedule_modifier(), + simd_modifier(), schedule_chunk_var()); - if (auto collapse = op.collapse_val()) + if (auto collapse = collapse_val()) p << "collapse(" << collapse << ") "; - if (op.nowait()) + if (nowait()) p << "nowait "; - if (auto ordered = op.ordered_val()) + if (auto ordered = ordered_val()) p << "ordered(" << ordered << ") "; - if (auto order = op.order_val()) + if (auto order = order_val()) p << "order(" << stringifyClauseOrderKind(*order) << ") "; - if (!op.reduction_vars().empty()) - printReductionVarList(p, op.reductions(), op.reduction_vars()); + if (!reduction_vars().empty()) + printReductionVarList(p, reductions(), reduction_vars()); p << ' '; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); + p.printRegion(region(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// @@ -1439,8 +1434,7 @@ LogicalResult OrderedRegionOp::verify() { /// /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type /// address ::= operand `:` type -static ParseResult parseAtomicReadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType x, v; Type addressType; SmallVector clauses = {memoryOrderClause, hintClause}; @@ -1456,14 +1450,13 @@ static ParseResult parseAtomicReadOp(OpAsmParser &parser, return success(); } -/// Printer for AtomicReadOp -static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { - p << " " << op.v() << " = " << op.x() << " "; - if (auto mo = op.memory_order()) +void AtomicReadOp::print(OpAsmPrinter &p) { + p << " " << v() << " = " << x() << " "; + if (auto mo = memory_order()) p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; - if (op.hintAttr()) - printSynchronizationHint(p << " ", op, op.hintAttr()); - p << ": " << op.x().getType(); + if (hintAttr()) + printSynchronizationHint(p << " ", *this, hintAttr()); + p << ": " << x().getType(); } /// Verifier for AtomicReadOp @@ -1491,8 +1484,7 @@ LogicalResult AtomicReadOp::verify() { /// operands ::= address `,` value /// address ::= operand `:` type /// value ::= operand `:` type -static ParseResult parseAtomicWriteOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType address, value; Type addrType, valueType; SmallVector clauses = {memoryOrderClause, hintClause}; @@ -1509,14 +1501,13 @@ static ParseResult parseAtomicWriteOp(OpAsmParser &parser, return success(); } -/// Printer for AtomicWriteOp -static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { - p << " " << op.address() << " = " << op.value() << " "; - if (auto mo = op.memory_order()) +void AtomicWriteOp::print(OpAsmPrinter &p) { + p << " " << address() << " = " << value() << " "; + if (auto mo = memory_order()) p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; - if (op.hintAttr()) - printSynchronizationHint(p, op, op.hintAttr()); - p << ": " << op.address().getType() << ", " << op.value().getType(); + if (hintAttr()) + printSynchronizationHint(p, *this, hintAttr()); + p << ": " << address().getType() << ", " << value().getType(); } /// Verifier for AtomicWriteOp @@ -1538,8 +1529,7 @@ LogicalResult AtomicWriteOp::verify() { /// Parser for AtomicUpdateOp /// /// operation ::= `omp.atomic.update` atomic-clause-list region -static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector clauses = {memoryOrderClause, hintClause}; SmallVector segments; OpAsmParser::OperandType x, y, z; @@ -1579,24 +1569,22 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, return parser.resolveOperand(expr, exprType, result.operands); } -/// Printer for AtomicUpdateOp -static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) { - p << " " << op.x() << " = "; +void AtomicUpdateOp::print(OpAsmPrinter &p) { + p << " " << x() << " = "; Value y, z; - if (op.isXBinopExpr()) { - y = op.x(); - z = op.expr(); + if (isXBinopExpr()) { + y = x(); + z = expr(); } else { - y = op.expr(); - z = op.x(); + y = expr(); + z = x(); } - p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z - << " "; - if (auto mo = op.memory_order()) + p << y << " " << AtomicBinOpKindToString(binop()).lower() << " " << z << " "; + if (auto mo = memory_order()) p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; - if (op.hintAttr()) - printSynchronizationHint(p, op, op.hintAttr()); - p << ": " << op.x().getType() << ", " << op.expr().getType(); + if (hintAttr()) + printSynchronizationHint(p, *this, hintAttr()); + p << ": " << x().getType() << ", " << expr().getType(); } /// Verifier for AtomicUpdateOp @@ -1615,9 +1603,8 @@ LogicalResult AtomicUpdateOp::verify() { // AtomicCaptureOp //===----------------------------------------------------------------------===// -/// Parser for AtomicCaptureOp -static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AtomicCaptureOp::parse(OpAsmParser &parser, + OperationState &result) { SmallVector clauses = {memoryOrderClause, hintClause}; SmallVector segments; if (parseClauses(parser, result, clauses, segments) || @@ -1626,13 +1613,12 @@ static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser, return success(); } -/// Printer for AtomicCaptureOp -static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) { - if (op.memory_order()) - p << "memory_order(" << op.memory_order() << ") "; - if (op.hintAttr()) - printSynchronizationHint(p, op, op.hintAttr()); - p.printRegion(op.region()); +void AtomicCaptureOp::print(OpAsmPrinter &p) { + if (memory_order()) + p << "memory_order(" << memory_order() << ") "; + if (hintAttr()) + printSynchronizationHint(p, *this, hintAttr()); + p.printRegion(region()); } /// Verifier for AtomicCaptureOp diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index 039e38c855da..1a45f4a92df5 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -97,7 +97,7 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, } } -static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) { +ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the loop variable followed by type. OpAsmParser::OperandType loopVariable; Type loopVariableType; @@ -137,13 +137,13 @@ static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void print(OpAsmPrinter &p, ForEachOp op) { - BlockArgument arg = op.getLoopVariable(); - p << ' ' << arg << " : " << arg.getType() << " in " << op.values() << ' '; - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op->getAttrs()); +void ForEachOp::print(OpAsmPrinter &p) { + BlockArgument arg = getLoopVariable(); + p << ' ' << arg << " : " << arg.getType() << " in " << values() << ' '; + p.printRegion(region(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict((*this)->getAttrs()); p << " -> "; - p.printSuccessor(op.successor()); + p.printSuccessor(successor()); } LogicalResult ForEachOp::verify() { diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index d8a76f835df9..258ead87848f 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -101,8 +101,8 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, /// return %idx : i32 /// } /// -static ParseResult parseExecuteRegionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, + OperationState &result) { if (parser.parseOptionalArrowTypeList(result.types)) return failure(); @@ -115,15 +115,15 @@ static ParseResult parseExecuteRegionOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ExecuteRegionOp op) { - p.printOptionalArrowTypeList(op.getResultTypes()); +void ExecuteRegionOp::print(OpAsmPrinter &p) { + p.printOptionalArrowTypeList(getResultTypes()); p << ' '; - p.printRegion(op.getRegion(), + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); - p.printOptionalAttrDict(op->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs()); } LogicalResult ExecuteRegionOp::verify() { @@ -340,22 +340,22 @@ static void printInitializationList(OpAsmPrinter &p, p << ")"; } -static void print(OpAsmPrinter &p, ForOp op) { - p << " " << op.getInductionVar() << " = " << op.getLowerBound() << " to " - << op.getUpperBound() << " step " << op.getStep(); +void ForOp::print(OpAsmPrinter &p) { + p << " " << getInductionVar() << " = " << getLowerBound() << " to " + << getUpperBound() << " step " << getStep(); - printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(), + printInitializationList(p, getRegionIterArgs(), getIterOperands(), " iter_args"); - if (!op.getIterOperands().empty()) - p << " -> (" << op.getIterOperands().getTypes() << ')'; + if (!getIterOperands().empty()) + p << " -> (" << getIterOperands().getTypes() << ')'; p << ' '; - p.printRegion(op.getRegion(), + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/op.hasIterOperands()); - p.printOptionalAttrDict(op->getAttrs()); + /*printBlockTerminators=*/hasIterOperands()); + p.printOptionalAttrDict((*this)->getAttrs()); } -static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) { +ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType inductionVariable, lb, ub, step; // Parse the induction variable followed by '='. @@ -1070,7 +1070,7 @@ LogicalResult IfOp::verify() { return RegionBranchOpInterface::verifyTypes(*this); } -static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { +ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. result.regions.reserve(2); Region *thenRegion = result.addRegion(); @@ -1103,22 +1103,22 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void print(OpAsmPrinter &p, IfOp op) { +void IfOp::print(OpAsmPrinter &p) { bool printBlockTerminators = false; - p << " " << op.getCondition(); - if (!op.getResults().empty()) { - p << " -> (" << op.getResultTypes() << ")"; + p << " " << getCondition(); + if (!getResults().empty()) { + p << " -> (" << getResultTypes() << ")"; // Print yield explicitly if the op defines values. printBlockTerminators = true; } p << ' '; - p.printRegion(op.getThenRegion(), + p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); // Print the 'else' regions if it exists and has a block. - auto &elseRegion = op.getElseRegion(); + auto &elseRegion = getElseRegion(); if (!elseRegion.empty()) { p << " else "; p.printRegion(elseRegion, @@ -1126,7 +1126,7 @@ static void print(OpAsmPrinter &p, IfOp op) { /*printBlockTerminators=*/printBlockTerminators); } - p.printOptionalAttrDict(op->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs()); } /// Given the region at `index`, or the parent operation if `index` is None, @@ -1784,8 +1784,7 @@ LogicalResult ParallelOp::verify() { return success(); } -static ParseResult parseParallelOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` SmallVector ivs; @@ -1855,16 +1854,17 @@ static ParseResult parseParallelOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ParallelOp op) { - p << " (" << op.getBody()->getArguments() << ") = (" << op.getLowerBound() - << ") to (" << op.getUpperBound() << ") step (" << op.getStep() << ")"; - if (!op.getInitVals().empty()) - p << " init (" << op.getInitVals() << ")"; - p.printOptionalArrowTypeList(op.getResultTypes()); +void ParallelOp::print(OpAsmPrinter &p) { + p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() + << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; + if (!getInitVals().empty()) + p << " init (" << getInitVals() << ")"; + p.printOptionalArrowTypeList(getResultTypes()); p << ' '; - p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false); + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( - op->getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr()); + (*this)->getAttrs(), + /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr()); } Region &ParallelOp::getLoopBody() { return getRegion(); } @@ -2096,7 +2096,7 @@ LogicalResult ReduceOp::verify() { return success(); } -static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { +ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by the reduced value followed by `)` OpAsmParser::OperandType operand; if (parser.parseLParen() || parser.parseOperand(operand) || @@ -2117,10 +2117,10 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void print(OpAsmPrinter &p, ReduceOp op) { - p << "(" << op.getOperand() << ") "; - p << " : " << op.getOperand().getType() << ' '; - p.printRegion(op.getReductionOperator()); +void ReduceOp::print(OpAsmPrinter &p) { + p << "(" << getOperand() << ") "; + p << " : " << getOperand().getType() << ' '; + p.printRegion(getReductionOperator()); } //===----------------------------------------------------------------------===// @@ -2192,7 +2192,7 @@ void WhileOp::getSuccessorRegions(Optional index, /// initializer ::= /* empty */ | `(` assignment-list `)` /// assignment-list ::= assignment | assignment `,` assignment-list /// assignment ::= ssa-value `=` ssa-value -static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) { +ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector regionArgs, operands; Region *before = result.addRegion(); Region *after = result.addRegion(); @@ -2229,16 +2229,16 @@ static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) { } /// Prints a `while` op. -static void print(OpAsmPrinter &p, scf::WhileOp op) { - printInitializationList(p, op.getBefore().front().getArguments(), - op.getInits(), " "); +void scf::WhileOp::print(OpAsmPrinter &p) { + printInitializationList(p, getBefore().front().getArguments(), getInits(), + " "); p << " : "; - p.printFunctionalType(op.getInits().getTypes(), op.getResults().getTypes()); + p.printFunctionalType(getInits().getTypes(), getResults().getTypes()); p << ' '; - p.printRegion(op.getBefore(), /*printEntryBlockArgs=*/false); + p.printRegion(getBefore(), /*printEntryBlockArgs=*/false); p << " do "; - p.printRegion(op.getAfter()); - p.printOptionalAttrDictWithKeyword(op->getAttrs()); + p.printRegion(getAfter()); + p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } /// Verifies that two ranges of types match, i.e. have the same number of diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 1d060473547d..1b4b4de3bb17 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -924,84 +924,15 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { return success(); } -static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type type; - if (parser.parseOperand(operandInfo) || parser.parseColonType(type) || - parser.resolveOperands(operandInfo, type, state.operands)) { - return failure(); - } - state.addTypes(type); - return success(); -} - -static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { - printer << ' ' << unaryOp->getOperand(0) << " : " - << unaryOp->getOperand(0).getType(); -} - /// Result of a logical op must be a scalar or vector of boolean type. -static Type getUnaryOpResultType(Builder &builder, Type operandType) { +static Type getUnaryOpResultType(Type operandType) { + Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); - if (auto vecType = operandType.dyn_cast()) { + if (auto vecType = operandType.dyn_cast()) return VectorType::get(vecType.getNumElements(), resultType); - } return resultType; } -static ParseResult parseLogicalUnaryOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type type; - if (parser.parseOperand(operandInfo) || parser.parseColonType(type) || - parser.resolveOperand(operandInfo, type, state.operands)) { - return failure(); - } - state.addTypes(getUnaryOpResultType(parser.getBuilder(), type)); - return success(); -} - -static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands)) { - return failure(); - } - result.addTypes(getUnaryOpResultType(parser.getBuilder(), type)); - return success(); -} - -static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << ' ' << logicalOp->getOperands() << " : " - << logicalOp->getOperand(0).getType(); -} - -static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) { - SmallVector operandInfo; - Type baseType; - Type shiftType; - auto loc = parser.getCurrentLocation(); - - if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() || - parser.parseType(baseType) || parser.parseComma() || - parser.parseType(shiftType) || - parser.resolveOperands(operandInfo, {baseType, shiftType}, loc, - state.operands)) { - return failure(); - } - state.addTypes(baseType); - return success(); -} - -static void printShiftOp(Operation *op, OpAsmPrinter &printer) { - Value base = op->getOperand(0); - Value shift = op->getOperand(1); - printer << ' ' << base << ", " << shift << " : " << base.getType() << ", " - << shift.getType(); -} - static LogicalResult verifyShiftOp(Operation *op) { if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " @@ -1096,8 +1027,8 @@ void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state, build(builder, state, type, basePtr, indices); } -static ParseResult parseAccessChainOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser, + OperationState &state) { OpAsmParser::OperandType ptrInfo; SmallVector indicesInfo; Type type; @@ -1114,8 +1045,8 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, // Check that the provided indices list is not empty before parsing their // type list. if (indicesInfo.empty()) { - return emitError(state.location, "'spv.AccessChain' op expected at " - "least one index "); + return mlir::emitError(state.location, "'spv.AccessChain' op expected at " + "least one index "); } if (parser.parseComma() || parser.parseTypeList(indicesTypes)) @@ -1124,9 +1055,9 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser, // Check that the indices types list is not empty and that it has a one-to-one // mapping to the provided indices. if (indicesTypes.size() != indicesInfo.size()) { - return emitError(state.location, "'spv.AccessChain' op indices " - "types' count must be equal to indices " - "info count"); + return mlir::emitError(state.location, + "'spv.AccessChain' op indices types' count must be " + "equal to indices info count"); } if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) @@ -1148,8 +1079,8 @@ static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { << "] : " << op.base_ptr().getType() << ", " << indices.getTypes(); } -static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) { - printAccessChain(op, op.indices(), printer); +void spirv::AccessChainOp::print(OpAsmPrinter &printer) { + printAccessChain(*this, indices(), printer); } template @@ -1280,6 +1211,14 @@ LogicalResult spirv::AtomicAndOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicAndOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicCompareExchangeOp //===----------------------------------------------------------------------===// @@ -1288,6 +1227,14 @@ LogicalResult spirv::AtomicCompareExchangeOp::verify() { return ::verifyAtomicCompareExchangeImpl(*this); } +ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicCompareExchangeImpl(parser, result); +} +void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) { + ::printAtomicCompareExchangeImpl(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicCompareExchangeWeakOp //===----------------------------------------------------------------------===// @@ -1296,18 +1243,26 @@ LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { return ::verifyAtomicCompareExchangeImpl(*this); } +ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicCompareExchangeImpl(parser, result); +} +void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { + ::printAtomicCompareExchangeImpl(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicExchange //===----------------------------------------------------------------------===// -static void print(spirv::AtomicExchangeOp atomOp, OpAsmPrinter &printer) { - printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" - << stringifyMemorySemantics(atomOp.semantics()) << "\" " - << atomOp.getOperands() << " : " << atomOp.pointer().getType(); +void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { + printer << " \"" << stringifyScope(memory_scope()) << "\" \"" + << stringifyMemorySemantics(semantics()) << "\" " << getOperands() + << " : " << pointer().getType(); } -static ParseResult parseAtomicExchangeOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, + OperationState &state) { spirv::Scope memoryScope; spirv::MemorySemantics semantics; SmallVector operandInfo; @@ -1349,13 +1304,21 @@ LogicalResult spirv::AtomicExchangeOp::verify() { } //===----------------------------------------------------------------------===// -// spv.AtomicFAddEXTOp +// spv.AtomicIAddOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicIAddOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicIAddOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicFAddEXTOp //===----------------------------------------------------------------------===// @@ -1364,6 +1327,14 @@ LogicalResult spirv::AtomicFAddEXTOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicIDecrementOp //===----------------------------------------------------------------------===// @@ -1372,6 +1343,14 @@ LogicalResult spirv::AtomicIDecrementOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, false); +} +void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicIIncrementOp //===----------------------------------------------------------------------===// @@ -1380,6 +1359,14 @@ LogicalResult spirv::AtomicIIncrementOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, false); +} +void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicISubOp //===----------------------------------------------------------------------===// @@ -1388,6 +1375,14 @@ LogicalResult spirv::AtomicISubOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicISubOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicOrOp //===----------------------------------------------------------------------===// @@ -1396,6 +1391,14 @@ LogicalResult spirv::AtomicOrOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicOrOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicSMaxOp //===----------------------------------------------------------------------===// @@ -1404,6 +1407,14 @@ LogicalResult spirv::AtomicSMaxOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicSMinOp //===----------------------------------------------------------------------===// @@ -1412,6 +1423,14 @@ LogicalResult spirv::AtomicSMinOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicSMinOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicUMaxOp //===----------------------------------------------------------------------===// @@ -1420,6 +1439,14 @@ LogicalResult spirv::AtomicUMaxOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicUMinOp //===----------------------------------------------------------------------===// @@ -1428,6 +1455,14 @@ LogicalResult spirv::AtomicUMinOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicUMinOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.AtomicXorOp //===----------------------------------------------------------------------===// @@ -1436,6 +1471,14 @@ LogicalResult spirv::AtomicXorOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } +ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser, + OperationState &result) { + return ::parseAtomicUpdateOp(parser, result, true); +} +void spirv::AtomicXorOp::print(OpAsmPrinter &p) { + ::printAtomicUpdateOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.BitcastOp //===----------------------------------------------------------------------===// @@ -1489,8 +1532,8 @@ spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) { : falseTargetOperandsMutable(); } -static ParseResult parseBranchConditionalOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, + OperationState &state) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType condInfo; Block *dest; @@ -1540,10 +1583,10 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser, return success(); } -static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { - printer << ' ' << branchOp.condition(); +void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { + printer << ' ' << condition(); - if (auto weights = branchOp.branch_weights()) { + if (auto weights = branch_weights()) { printer << " ["; llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { printer << a.cast().getInt(); @@ -1552,11 +1595,9 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { } printer << ", "; - printer.printSuccessorAndUseList(branchOp.getTrueBlock(), - branchOp.getTrueBlockArguments()); + printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); printer << ", "; - printer.printSuccessorAndUseList(branchOp.getFalseBlock(), - branchOp.getFalseBlockArguments()); + printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); } LogicalResult spirv::BranchConditionalOp::verify() { @@ -1577,8 +1618,8 @@ LogicalResult spirv::BranchConditionalOp::verify() { // spv.CompositeConstruct //===----------------------------------------------------------------------===// -static ParseResult parseCompositeConstructOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser, + OperationState &state) { SmallVector operands; Type type; auto loc = parser.getCurrentLocation(); @@ -1611,10 +1652,8 @@ static ParseResult parseCompositeConstructOp(OpAsmParser &parser, return parser.resolveOperands(operands, elementTypes, loc, state.operands); } -static void print(spirv::CompositeConstructOp compositeConstructOp, - OpAsmPrinter &printer) { - printer << " " << compositeConstructOp.constituents() << " : " - << compositeConstructOp.getResult().getType(); +void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) { + printer << " " << constituents() << " : " << getResult().getType(); } LogicalResult spirv::CompositeConstructOp::verify() { @@ -1658,8 +1697,8 @@ void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, build(builder, state, elementType, composite, indexAttr); } -static ParseResult parseCompositeExtractOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, + OperationState &state) { OpAsmParser::OperandType compositeInfo; Attribute indicesAttr; Type compositeType; @@ -1682,11 +1721,8 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser, return success(); } -static void print(spirv::CompositeExtractOp compositeExtractOp, - OpAsmPrinter &printer) { - printer << ' ' << compositeExtractOp.composite() - << compositeExtractOp.indices() << " : " - << compositeExtractOp.composite().getType(); +void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { + printer << ' ' << composite() << indices() << " : " << composite().getType(); } LogicalResult spirv::CompositeExtractOp::verify() { @@ -1715,8 +1751,8 @@ void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, build(builder, state, composite.getType(), object, composite, indexAttr); } -static ParseResult parseCompositeInsertOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, + OperationState &state) { SmallVector operands; Type objectType, compositeType; Attribute indicesAttr; @@ -1753,19 +1789,17 @@ LogicalResult spirv::CompositeInsertOp::verify() { return success(); } -static void print(spirv::CompositeInsertOp compositeInsertOp, - OpAsmPrinter &printer) { - printer << " " << compositeInsertOp.object() << ", " - << compositeInsertOp.composite() << compositeInsertOp.indices() - << " : " << compositeInsertOp.object().getType() << " into " - << compositeInsertOp.composite().getType(); +void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { + printer << " " << object() << ", " << composite() << indices() << " : " + << object().getType() << " into " << composite().getType(); } //===----------------------------------------------------------------------===// // spv.Constant //===----------------------------------------------------------------------===// -static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, + OperationState &state) { Attribute value; if (parser.parseAttribute(value, kValueAttrName, state.attributes)) return failure(); @@ -1779,10 +1813,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) { return parser.addTypeToList(type, state.types); } -static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { - printer << ' ' << constOp.value(); - if (constOp.getType().isa()) - printer << " : " << constOp.getType(); +void spirv::ConstantOp::print(OpAsmPrinter &printer) { + printer << ' ' << value(); + if (getType().isa()) + printer << " : " << getType(); } static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, @@ -2034,8 +2068,8 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); } -static ParseResult parseEntryPointOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, + OperationState &state) { spirv::ExecutionModel execModel; SmallVector identifiers; SmallVector idTypes; @@ -2065,11 +2099,10 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser, return success(); } -static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) { - printer << " \"" << stringifyExecutionModel(entryPointOp.execution_model()) - << "\" "; - printer.printSymbolName(entryPointOp.fn()); - auto interfaceVars = entryPointOp.interface().getValue(); +void spirv::EntryPointOp::print(OpAsmPrinter &printer) { + printer << " \"" << stringifyExecutionModel(execution_model()) << "\" "; + printer.printSymbolName(fn()); + auto interfaceVars = interface().getValue(); if (!interfaceVars.empty()) { printer << ", "; llvm::interleaveComma(interfaceVars, printer); @@ -2095,8 +2128,8 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, builder.getI32ArrayAttr(params)); } -static ParseResult parseExecutionModeOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, + OperationState &state) { spirv::ExecutionMode execMode; Attribute fn; if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) || @@ -2119,12 +2152,11 @@ static ParseResult parseExecutionModeOp(OpAsmParser &parser, return success(); } -static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { +void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; - printer.printSymbolName(execModeOp.fn()); - printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode()) - << "\""; - auto values = execModeOp.values(); + printer.printSymbolName(fn()); + printer << " \"" << stringifyExecutionMode(execution_mode()) << "\""; + auto values = this->values(); if (values.empty()) return; printer << ", "; @@ -2161,7 +2193,7 @@ LogicalResult spirv::UConvertOp::verify() { // spv.func //===----------------------------------------------------------------------===// -static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector entryArgs; SmallVector argAttrs; SmallVector resultAttrs; @@ -2209,22 +2241,22 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { return failure(result.hasValue() && failed(*result)); } -static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) { +void spirv::FuncOp::print(OpAsmPrinter &printer) { // Print function name, signature, and control. printer << " "; - printer.printSymbolName(fnOp.sym_name()); - auto fnType = fnOp.getType(); + printer.printSymbolName(sym_name()); + auto fnType = getType(); function_interface_impl::printFunctionSignature( - printer, fnOp, fnType.getInputs(), + printer, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); - printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control()) + printer << " \"" << spirv::stringifyFunctionControl(function_control()) << "\""; function_interface_impl::printFunctionAttributes( - printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(), + printer, *this, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); // Print the body if this is not an external function. - Region &body = fnOp.body(); + Region &body = this->body(); if (!body.empty()) { printer << ' '; printer.printRegion(body, /*printEntryBlockArgs=*/false, @@ -2354,6 +2386,46 @@ Operation::operand_range spirv::FunctionCallOp::getArgOperands() { return arguments(); } +//===----------------------------------------------------------------------===// +// spv.GLSLFClampOp +//===----------------------------------------------------------------------===// + +ParseResult spirv::GLSLFClampOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseOneResultSameOperandTypeOp(parser, result); +} +void spirv::GLSLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spv.GLSLUClampOp +//===----------------------------------------------------------------------===// + +ParseResult spirv::GLSLUClampOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseOneResultSameOperandTypeOp(parser, result); +} +void spirv::GLSLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spv.GLSLSClampOp +//===----------------------------------------------------------------------===// + +ParseResult spirv::GLSLSClampOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseOneResultSameOperandTypeOp(parser, result); +} +void spirv::GLSLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } + +//===----------------------------------------------------------------------===// +// spv.GLSLFmaOp +//===----------------------------------------------------------------------===// + +ParseResult spirv::GLSLFmaOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseOneResultSameOperandTypeOp(parser, result); +} +void spirv::GLSLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } + //===----------------------------------------------------------------------===// // spv.GlobalVariable //===----------------------------------------------------------------------===// @@ -2379,8 +2451,8 @@ void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); } -static ParseResult parseGlobalVariableOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, + OperationState &state) { // Parse variable name. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), @@ -2415,18 +2487,17 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser, return success(); } -static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) { - auto *op = varOp.getOperation(); +void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; // Print variable name. printer << ' '; - printer.printSymbolName(varOp.sym_name()); + printer.printSymbolName(sym_name()); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); // Print optional initializer - if (auto initializer = varOp.initializer()) { + if (auto initializer = this->initializer()) { printer << " " << kInitializerAttrName << '('; printer.printSymbolName(initializer.getValue()); printer << ')'; @@ -2434,8 +2505,8 @@ static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) { } elidedAttrs.push_back(kTypeAttrName); - printVariableDecorations(op, printer, elidedAttrs); - printer << " : " << varOp.type(); + printVariableDecorations(*this, printer, elidedAttrs); + printer << " : " << type(); } LogicalResult spirv::GlobalVariableOp::verify() { @@ -2526,8 +2597,8 @@ LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { // spv.SubgroupBlockReadINTEL //===----------------------------------------------------------------------===// -static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser, + OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::OperandType ptrInfo; @@ -2549,11 +2620,8 @@ static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser, return success(); } -static void print(spirv::SubgroupBlockReadINTELOp blockReadOp, - OpAsmPrinter &printer) { - SmallVector elidedAttrs; - printer << " " << blockReadOp.ptr(); - printer << " : " << blockReadOp.getType(); +void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) { + printer << " " << ptr() << " : " << getType(); } LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { @@ -2567,8 +2635,8 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { // spv.SubgroupBlockWriteINTEL //===----------------------------------------------------------------------===// -static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser, + OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; @@ -2591,11 +2659,8 @@ static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser, return success(); } -static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp, - OpAsmPrinter &printer) { - SmallVector elidedAttrs; - printer << " " << blockWriteOp.ptr() << ", " << blockWriteOp.value(); - printer << " : " << blockWriteOp.value().getType(); +void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) { + printer << " " << ptr() << ", " << value() << " : " << value().getType(); } LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { @@ -2625,6 +2690,14 @@ LogicalResult spirv::GroupNonUniformFAddOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformFMaxOp //===----------------------------------------------------------------------===// @@ -2633,6 +2706,14 @@ LogicalResult spirv::GroupNonUniformFMaxOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformFMinOp //===----------------------------------------------------------------------===// @@ -2641,6 +2722,14 @@ LogicalResult spirv::GroupNonUniformFMinOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformFMulOp //===----------------------------------------------------------------------===// @@ -2649,6 +2738,14 @@ LogicalResult spirv::GroupNonUniformFMulOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformIAddOp //===----------------------------------------------------------------------===// @@ -2657,6 +2754,14 @@ LogicalResult spirv::GroupNonUniformIAddOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformIMulOp //===----------------------------------------------------------------------===// @@ -2665,6 +2770,14 @@ LogicalResult spirv::GroupNonUniformIMulOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformSMaxOp //===----------------------------------------------------------------------===// @@ -2673,6 +2786,14 @@ LogicalResult spirv::GroupNonUniformSMaxOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformSMinOp //===----------------------------------------------------------------------===// @@ -2681,6 +2802,14 @@ LogicalResult spirv::GroupNonUniformSMinOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformUMaxOp //===----------------------------------------------------------------------===// @@ -2689,6 +2818,14 @@ LogicalResult spirv::GroupNonUniformUMaxOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.GroupNonUniformUMinOp //===----------------------------------------------------------------------===// @@ -2697,6 +2834,14 @@ LogicalResult spirv::GroupNonUniformUMinOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } +ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseGroupNonUniformArithmeticOp(parser, result); +} +void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { + printGroupNonUniformArithmeticOp(*this, p); +} + //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -2709,7 +2854,7 @@ void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, alignment); } -static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::OperandType ptrInfo; @@ -2730,17 +2875,16 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) { return success(); } -static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) { - auto *op = loadOp.getOperation(); +void spirv::LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - loadOp.ptr().getType().cast().getStorageClass()); - printer << " \"" << sc << "\" " << loadOp.ptr(); + ptr().getType().cast().getStorageClass()); + printer << " \"" << sc << "\" " << ptr(); - printMemoryAccessAttribute(loadOp, printer, elidedAttrs); + printMemoryAccessAttribute(*this, printer, elidedAttrs); - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); - printer << " : " << loadOp.getType(); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << " : " << getType(); } LogicalResult spirv::LoadOp::verify() { @@ -2764,21 +2908,19 @@ void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { state.addRegion(); } -static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) { if (parseControlAttribute(parser, state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } -static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) { - auto *op = loopOp.getOperation(); - - auto control = loopOp.loop_control(); +void spirv::LoopOp::print(OpAsmPrinter &printer) { + auto control = loop_control(); if (control != spirv::LoopControl::None) printer << " control(" << spirv::stringifyLoopControl(control) << ")"; printer << ' '; - printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -2970,19 +3112,19 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, builder.getStringAttr(*name)); } -static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) { Region *body = state.addRegion(); // If the name is present, parse it. StringAttr nameAttr; - parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - state.attributes); + parser.parseOptionalSymbolName( + nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes); // Parse attributes spirv::AddressingModel addrModel; spirv::MemoryModel memoryModel; - if (parseEnumKeywordAttr(addrModel, parser, state) || - parseEnumKeywordAttr(memoryModel, parser, state)) + if (::parseEnumKeywordAttr(addrModel, parser, state) || + ::parseEnumKeywordAttr(memoryModel, parser, state)) return failure(); if (succeeded(parser.parseOptionalKeyword("requires"))) { @@ -3006,29 +3148,29 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { return success(); } -static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { - if (Optional name = moduleOp.getName()) { +void spirv::ModuleOp::print(OpAsmPrinter &printer) { + if (Optional name = getName()) { printer << ' '; printer.printSymbolName(*name); } SmallVector elidedAttrs; - printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model()) - << " " << spirv::stringifyMemoryModel(moduleOp.memory_model()); + printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " " + << spirv::stringifyMemoryModel(memory_model()); auto addressingModelAttrName = spirv::attributeName(); auto memoryModelAttrName = spirv::attributeName(); elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, - SymbolTable::getSymbolAttrName()}); + mlir::SymbolTable::getSymbolAttrName()}); - if (Optional triple = moduleOp.vce_triple()) { + if (Optional triple = vce_triple()) { printer << " requires " << *triple; elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); } - printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs); + printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); printer << ' '; - printer.printRegion(moduleOp.getRegion()); + printer.printRegion(getRegion()); } LogicalResult spirv::ModuleOp::verify() { @@ -3163,21 +3305,20 @@ LogicalResult spirv::SelectOp::verify() { // spv.mlir.selection //===----------------------------------------------------------------------===// -static ParseResult parseSelectionOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, + OperationState &state) { if (parseControlAttribute(parser, state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } -static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) { - auto *op = selectionOp.getOperation(); - auto control = selectionOp.selection_control(); +void spirv::SelectionOp::print(OpAsmPrinter &printer) { + auto control = selection_control(); if (control != spirv::SelectionControl::None) printer << " control(" << spirv::stringifySelectionControl(control) << ")"; printer << ' '; - printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -3279,8 +3420,8 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen( // spv.SpecConstant //===----------------------------------------------------------------------===// -static ParseResult parseSpecConstantOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, + OperationState &state) { StringAttr nameAttr; Attribute valueAttr; @@ -3304,12 +3445,12 @@ static ParseResult parseSpecConstantOp(OpAsmParser &parser, return success(); } -static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) { +void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { printer << ' '; - printer.printSymbolName(constOp.sym_name()); - if (auto specID = constOp->getAttrOfType(kSpecIdAttrName)) + printer.printSymbolName(sym_name()); + if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; - printer << " = " << constOp.default_value(); + printer << " = " << default_value(); } LogicalResult spirv::SpecConstantOp::verify() { @@ -3332,7 +3473,7 @@ LogicalResult spirv::SpecConstantOp::verify() { // spv.StoreOp //===----------------------------------------------------------------------===// -static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; @@ -3353,17 +3494,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) { return success(); } -static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) { - auto *op = storeOp.getOperation(); +void spirv::StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - storeOp.ptr().getType().cast().getStorageClass()); - printer << " \"" << sc << "\" " << storeOp.ptr() << ", " << storeOp.value(); + ptr().getType().cast().getStorageClass()); + printer << " \"" << sc << "\" " << ptr() << ", " << value(); - printMemoryAccessAttribute(storeOp, printer, elidedAttrs); + printMemoryAccessAttribute(*this, printer, elidedAttrs); - printer << " : " << storeOp.value().getType(); - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); + printer << " : " << value().getType(); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); } LogicalResult spirv::StoreOp::verify() { @@ -3397,7 +3537,8 @@ LogicalResult spirv::UnreachableOp::verify() { // spv.Variable //===----------------------------------------------------------------------===// -static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { +ParseResult spirv::VariableOp::parse(OpAsmParser &parser, + OperationState &state) { // Parse optional initializer Optional initInfo; if (succeeded(parser.parseOptionalKeyword("init"))) { @@ -3438,15 +3579,15 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { return success(); } -static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { +void spirv::VariableOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; // Print optional initializer - if (varOp.getNumOperands() != 0) - printer << " init(" << varOp.initializer() << ")"; + if (getNumOperands() != 0) + printer << " init(" << initializer() << ")"; - printVariableDecorations(varOp, printer, elidedAttrs); - printer << " : " << varOp.getType(); + printVariableDecorations(*this, printer, elidedAttrs); + printer << " : " << getType(); } LogicalResult spirv::VariableOp::verify() { @@ -3526,8 +3667,8 @@ LogicalResult spirv::VectorShuffleOp::verify() { // spv.CooperativeMatrixLoadNV //===----------------------------------------------------------------------===// -static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser, + OperationState &state) { SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); @@ -3548,13 +3689,12 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, return success(); } -static void print(spirv::CooperativeMatrixLoadNVOp m, OpAsmPrinter &printer) { - printer << " " << m.pointer() << ", " << m.stride() << ", " - << m.columnmajor(); +void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) { + printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); // Print optional memory access attribute. - if (auto memAccess = m.memory_access()) + if (auto memAccess = memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << m.pointer().getType() << " as " << m.getType(); + printer << " : " << pointer().getType() << " as " << getType(); } static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, @@ -3585,8 +3725,8 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { // spv.CooperativeMatrixStoreNV //===----------------------------------------------------------------------===// -static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser, + OperationState &state) { SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); @@ -3607,15 +3747,13 @@ static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser, return success(); } -static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix, - OpAsmPrinter &printer) { - printer << " " << coopMatrix.pointer() << ", " << coopMatrix.object() << ", " - << coopMatrix.stride() << ", " << coopMatrix.columnmajor(); +void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) { + printer << " " << pointer() << ", " << object() << ", " << stride() << ", " + << columnmajor(); // Print optional memory access attribute. - if (auto memAccess = coopMatrix.memory_access()) + if (auto memAccess = memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << coopMatrix.pointer().getType() << ", " - << coopMatrix.getOperand(1).getType(); + printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); } LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { @@ -3694,40 +3832,31 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() { // spv.CopyMemory //===----------------------------------------------------------------------===// -static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) { - auto *op = copyMemory.getOperation(); +void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) { printer << ' '; - StringRef targetStorageClass = - stringifyStorageClass(copyMemory.target() - .getType() - .cast() - .getStorageClass()); - printer << " \"" << targetStorageClass << "\" " << copyMemory.target() - << ", "; + StringRef targetStorageClass = stringifyStorageClass( + target().getType().cast().getStorageClass()); + printer << " \"" << targetStorageClass << "\" " << target() << ", "; - StringRef sourceStorageClass = - stringifyStorageClass(copyMemory.source() - .getType() - .cast() - .getStorageClass()); - printer << " \"" << sourceStorageClass << "\" " << copyMemory.source(); + StringRef sourceStorageClass = stringifyStorageClass( + source().getType().cast().getStorageClass()); + printer << " \"" << sourceStorageClass << "\" " << source(); SmallVector elidedAttrs; - printMemoryAccessAttribute(copyMemory, printer, elidedAttrs); - printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs, - copyMemory.source_memory_access(), - copyMemory.source_alignment()); + printMemoryAccessAttribute(*this, printer, elidedAttrs); + printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, + source_memory_access(), source_alignment()); - printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = - copyMemory.target().getType().cast().getPointeeType(); + target().getType().cast().getPointeeType(); printer << " : " << pointeeType; } -static ParseResult parseCopyMemoryOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser, + OperationState &state) { spirv::StorageClass targetStorageClass; OpAsmParser::OperandType targetPtrInfo; @@ -3857,8 +3986,8 @@ LogicalResult spirv::MatrixTimesMatrixOp::verify() { // spv.SpecConstantComposite //===----------------------------------------------------------------------===// -static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, + OperationState &state) { StringAttr compositeName; if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), @@ -3897,16 +4026,16 @@ static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser, return success(); } -static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) { +void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; - printer.printSymbolName(op.sym_name()); + printer.printSymbolName(sym_name()); printer << " ("; - auto constituents = op.constituents().getValue(); + auto constituents = this->constituents().getValue(); if (!constituents.empty()) llvm::interleaveComma(constituents, printer); - printer << ") : " << op.type(); + printer << ") : " << type(); } LogicalResult spirv::SpecConstantCompositeOp::verify() { @@ -3945,8 +4074,8 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() { // spv.SpecConstantOperation //===----------------------------------------------------------------------===// -static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, + OperationState &state) { Region *body = state.addRegion(); if (parser.parseKeyword("wraps")) @@ -3972,9 +4101,9 @@ static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser, return success(); } -static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { +void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { printer << " wraps "; - printer.printGenericOp(&op.body().front().front()); + printer.printGenericOp(&body().front().front()); } LogicalResult spirv::SpecConstantOperationOp::verify() { @@ -4255,14 +4384,14 @@ void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, build(builder, state, type, basePtr, element, indices); } -static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, + OperationState &state) { return parsePtrAccessChainOpImpl( spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state); } -static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) { - printAccessChain(op, concatElemAndIndices(op), printer); +void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { + printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { @@ -4281,14 +4410,14 @@ void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, build(builder, state, type, basePtr, element, indices); } -static ParseResult parsePtrAccessChainOp(OpAsmParser &parser, - OperationState &state) { +ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser, + OperationState &state) { return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), parser, state); } -static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) { - printAccessChain(op, concatElemAndIndices(op), printer); +void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) { + printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult spirv::PtrAccessChainOp::verify() { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 9241e2d12ba2..1ae9e68d2024 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -256,8 +256,7 @@ OpFoldResult AnyOp::fold(ArrayRef operands) { // AssumingOp //===----------------------------------------------------------------------===// -static ParseResult parseAssumingOp(OpAsmParser &parser, - OperationState &result) { +ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { result.regions.reserve(1); Region *doRegion = result.addRegion(); @@ -283,17 +282,17 @@ static ParseResult parseAssumingOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, AssumingOp op) { - bool yieldsResults = !op.getResults().empty(); +void AssumingOp::print(OpAsmPrinter &p) { + bool yieldsResults = !getResults().empty(); - p << " " << op.getWitness(); + p << " " << getWitness(); if (yieldsResults) - p << " -> (" << op.getResultTypes() << ")"; + p << " -> (" << getResultTypes() << ")"; p << ' '; - p.printRegion(op.getDoRegion(), + p.printRegion(getDoRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/yieldsResults); - p.printOptionalAttrDict(op->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs()); } namespace { @@ -905,18 +904,16 @@ OpFoldResult ConcatOp::fold(ArrayRef operands) { // ConstShapeOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ConstShapeOp &op) { +void ConstShapeOp::print(OpAsmPrinter &p) { p << " "; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"}); p << "["; - interleaveComma(op.getShape().getValues(), p, - [&](int64_t i) { p << i; }); + interleaveComma(getShape().getValues(), p); p << "] : "; - p.printType(op.getType()); + p.printType(getType()); } -static ParseResult parseConstShapeOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // We piggy-back on ArrayAttr parsing, though we don't internally store the @@ -1215,8 +1212,8 @@ FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { return lookupSymbol(attr); } -ParseResult parseFunctionLibraryOp(OpAsmParser &parser, - OperationState &result) { +ParseResult FunctionLibraryOp::parse(OpAsmParser &parser, + OperationState &result) { // Parse the op name. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), @@ -1241,16 +1238,16 @@ ParseResult parseFunctionLibraryOp(OpAsmParser &parser, return success(); } -void print(OpAsmPrinter &p, FunctionLibraryOp op) { +void FunctionLibraryOp::print(OpAsmPrinter &p) { p << ' '; - p.printSymbolName(op.getName()); + p.printSymbolName(getName()); p.printOptionalAttrDictWithKeyword( - op->getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"}); + (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"}); p << ' '; - p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); p << " mapping "; - p.printAttributeWithoutType(op.getMappingAttr()); + p.printAttributeWithoutType(getMappingAttr()); } //===----------------------------------------------------------------------===// @@ -1846,7 +1843,7 @@ LogicalResult ReduceOp::verify() { return success(); } -static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { +ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { // Parse operands. SmallVector operands; Type shapeOrExtentTensorType; @@ -1876,13 +1873,13 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { return success(); } -static void print(OpAsmPrinter &p, ReduceOp op) { - p << '(' << op.getShape() << ", " << op.getInitVals() - << ") : " << op.getShape().getType(); - p.printOptionalArrowTypeList(op.getResultTypes()); +void ReduceOp::print(OpAsmPrinter &p) { + p << '(' << getShape() << ", " << getInitVals() + << ") : " << getShape().getType(); + p.printOptionalArrowTypeList(getResultTypes()); p << ' '; - p.printRegion(op.getRegion()); - p.printOptionalAttrDict(op->getAttrs()); + p.printRegion(getRegion()); + p.printOptionalAttrDict((*this)->getAttrs()); } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index a26ad993b495..71871ee50381 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -733,13 +733,16 @@ SmallVector ExpandShapeOp::getReassociationExprs() { getReassociationIndices()); } -static void print(OpAsmPrinter &p, ExpandShapeOp op) { - ::mlir::printReshapeOp(p, op); +ParseResult ExpandShapeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseReshapeLikeOp(parser, result); } +void ExpandShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); } -static void print(OpAsmPrinter &p, CollapseShapeOp op) { - ::mlir::printReshapeOp(p, op); +ParseResult CollapseShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseReshapeLikeOp(parser, result); } +void CollapseShapeOp::print(OpAsmPrinter &p) { printReshapeOp(p, *this); } /// Compute the RankedTensorType obtained by applying `reassociation` to `type`. static RankedTensorType diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b576005c47e9..8213243f402f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -404,8 +404,7 @@ LogicalResult ReductionOp::verify() { return success(); } -static ParseResult parseReductionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandsInfo; Type redType; Type resType; @@ -426,11 +425,11 @@ static ParseResult parseReductionOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ReductionOp op) { - p << " \"" << op.kind() << "\", " << op.vector(); - if (!op.acc().empty()) - p << ", " << op.acc(); - p << " : " << op.vector().getType() << " into " << op.dest().getType(); +void ReductionOp::print(OpAsmPrinter &p) { + p << " \"" << kind() << "\", " << vector(); + if (!acc().empty()) + p << ", " << acc(); + p << " : " << vector().getType() << " into " << dest().getType(); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -510,8 +509,7 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, builder.getContext())); } -static ParseResult parseContractionOp(OpAsmParser &parser, - OperationState &result) { +ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType lhsInfo; OpAsmParser::OperandType rhsInfo; OpAsmParser::OperandType accInfo; @@ -557,25 +555,25 @@ static ParseResult parseContractionOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, ContractionOp op) { +void ContractionOp::print(OpAsmPrinter &p) { // TODO: Unify printing code with linalg ops. - auto attrNames = op.getTraitAttrNames(); + auto attrNames = getTraitAttrNames(); llvm::StringSet<> traitAttrsSet; traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; - for (auto attr : op->getAttrs()) + for (auto attr : (*this)->getAttrs()) if (traitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); - p << " " << dictAttr << " " << op.lhs() << ", "; - p << op.rhs() << ", " << op.acc(); - if (op.masks().size() == 2) - p << ", " << op.masks(); + auto dictAttr = DictionaryAttr::get(getContext(), attrs); + p << " " << dictAttr << " " << lhs() << ", "; + p << rhs() << ", " << acc(); + if (masks().size() == 2) + p << ", " << masks(); - p.printOptionalAttrDict(op->getAttrs(), attrNames); - p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into " - << op.getResultType(); + p.printOptionalAttrDict((*this)->getAttrs(), attrNames); + p << " : " << lhs().getType() << ", " << rhs().getType() << " into " + << getResultType(); } static bool verifyDimMap(VectorType lhsType, VectorType rhsType, @@ -967,13 +965,14 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, build(builder, result, source, positionConstants); } -static void print(OpAsmPrinter &p, vector::ExtractOp op) { - p << " " << op.vector() << op.position(); - p.printOptionalAttrDict(op->getAttrs(), {"position"}); - p << " : " << op.vector().getType(); +void vector::ExtractOp::print(OpAsmPrinter &p) { + p << " " << vector() << position(); + p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); + p << " : " << vector().getType(); } -static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { +ParseResult vector::ExtractOp::parse(OpAsmParser &parser, + OperationState &result) { SMLoc attributeLoc, typeLoc; NamedAttrList attrs; OpAsmParser::OperandType vector; @@ -1731,10 +1730,10 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, result.addAttribute(getMaskAttrName(), maskAttr); } -static void print(OpAsmPrinter &p, ShuffleOp op) { - p << " " << op.v1() << ", " << op.v2() << " " << op.mask(); - p.printOptionalAttrDict(op->getAttrs(), {ShuffleOp::getMaskAttrName()}); - p << " : " << op.v1().getType() << ", " << op.v2().getType(); +void ShuffleOp::print(OpAsmPrinter &p) { + p << " " << v1() << ", " << v2() << " " << mask(); + p.printOptionalAttrDict((*this)->getAttrs(), {ShuffleOp::getMaskAttrName()}); + p << " : " << v1().getType() << ", " << v2().getType(); } LogicalResult ShuffleOp::verify() { @@ -1770,7 +1769,7 @@ LogicalResult ShuffleOp::verify() { return success(); } -static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) { +ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType v1, v2; Attribute attr; VectorType v1Type, v2Type; @@ -2134,17 +2133,16 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result, result.addTypes(acc.getType()); } -static void print(OpAsmPrinter &p, OuterProductOp op) { - p << " " << op.lhs() << ", " << op.rhs(); - if (!op.acc().empty()) { - p << ", " << op.acc(); - p.printOptionalAttrDict(op->getAttrs()); +void OuterProductOp::print(OpAsmPrinter &p) { + p << " " << lhs() << ", " << rhs(); + if (!acc().empty()) { + p << ", " << acc(); + p.printOptionalAttrDict((*this)->getAttrs()); } - p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); + p << " : " << lhs().getType() << ", " << rhs().getType(); } -static ParseResult parseOuterProductOp(OpAsmParser &parser, - OperationState &result) { +ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandsInfo; Type tLHS, tRHS; if (parser.parseOperandList(operandsInfo) || @@ -2773,16 +2771,15 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } -static void print(OpAsmPrinter &p, TransferReadOp op) { - p << " " << op.source() << "[" << op.indices() << "], " << op.padding(); - if (op.mask()) - p << ", " << op.mask(); - printTransferAttrs(p, cast(op.getOperation())); - p << " : " << op.getShapedType() << ", " << op.getVectorType(); +void TransferReadOp::print(OpAsmPrinter &p) { + p << " " << source() << "[" << indices() << "], " << padding(); + if (mask()) + p << ", " << mask(); + printTransferAttrs(p, *this); + p << " : " << getShapedType() << ", " << getVectorType(); } -static ParseResult parseTransferReadOp(OpAsmParser &parser, - OperationState &result) { +ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); SMLoc typesLoc; OpAsmParser::OperandType sourceInfo; @@ -3160,8 +3157,8 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result, build(builder, result, vector, dest, indices, permutationMap, inBounds); } -static ParseResult parseTransferWriteOp(OpAsmParser &parser, - OperationState &result) { +ParseResult TransferWriteOp::parse(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); SMLoc typesLoc; OpAsmParser::OperandType vectorInfo, sourceInfo; @@ -3213,12 +3210,12 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser, parser.addTypeToList(shapedType, result.types)); } -static void print(OpAsmPrinter &p, TransferWriteOp op) { - p << " " << op.vector() << ", " << op.source() << "[" << op.indices() << "]"; - if (op.mask()) - p << ", " << op.mask(); - printTransferAttrs(p, cast(op.getOperation())); - p << " : " << op.getVectorType() << ", " << op.getShapedType(); +void TransferWriteOp::print(OpAsmPrinter &p) { + p << " " << vector() << ", " << source() << "[" << indices() << "]"; + if (mask()) + p << ", " << mask(); + printTransferAttrs(p, *this); + p << " : " << getVectorType() << ", " << getShapedType(); } LogicalResult TransferWriteOp::verify() { diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index 2d573eae8ebe..af94c5e8a985 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -112,7 +112,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, /*resultAttrs=*/llvm::None); } -static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, @@ -122,10 +122,10 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { parser, result, /*allowVariadic=*/false, buildFuncType); } -static void print(FuncOp op, OpAsmPrinter &p) { - FunctionType fnType = op.getType(); +void FuncOp::print(OpAsmPrinter &p) { + FunctionType fnType = getType(); function_interface_impl::printFunctionOp( - p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); + p, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } LogicalResult FuncOp::verify() { diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir index 0ca176102e3e..cd775b4d945f 100644 --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -181,7 +181,7 @@ func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 { // ----- func @shift_left_logical_invalid_result_type(%arg0: i32, %arg1 : i16) -> i16 { - // expected-error @+1 {{expected the same type for the first operand and result, but provided 'i32' and 'i16'}} + // expected-error @+1 {{op inferred type(s) 'i32' are incompatible with return type(s) of operation 'i16'}} %0 = "spv.ShiftLeftLogical" (%arg0, %arg1) : (i32, i16) -> (i16) spv.ReturnValue %0 : i16 } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index a8af6bf7f217..cd989394ca1b 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -108,7 +108,7 @@ func @logicalBinary(%arg0 : i1, %arg1 : i1) func @logicalBinary(%arg0 : i1, %arg1 : i1) { - // expected-error @+1 {{custom op 'spv.LogicalAnd' expected 2 operands}} + // expected-error @+1 {{expected ','}} %0 = spv.LogicalAnd %arg0 : i1 return } @@ -166,7 +166,7 @@ func @logicalUnary(%arg0 : i1) func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} %0 = spv.LogicalNot %arg0 : i32 return } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index aeb469b12e30..925b90848ac0 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -509,10 +509,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], }]> {5} ]; - let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; - let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result); - }]; + let hasCustomAssemblyFormat = 1; let hasFolder = 1; let extraClassDeclaration = structuredOpsBaseDecls # [{{ @@ -588,6 +585,18 @@ void {0}::getEffects(SmallVectorImpl< } )FMT"; +// Implementation of parse/print. +// Parameters: +// {0}: Class name +static const char structuredOpParserFormat[] = R"FMT( +ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ + return ::parseNamedStructuredOp<{0}>(parser, result); +} +void {0}::print(OpAsmPrinter &p) {{ + ::printNamedStructuredOp(p, *this); +} +)FMT"; + static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, GenerationContext &genContext) { if (!genContext.shouldGenerateOds()) @@ -1008,6 +1017,9 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{ interleaveToString(stmts, "\n ")); } + // Parser and printer. + os << llvm::formatv(structuredOpParserFormat, className); + // Canonicalizers and folders. os << llvm::formatv(structuredOpFoldersFormat, className);