forked from OSchip/llvm-project
[MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.
Adding this pattern reduces code duplication. There is no need to have a custom implementation for lowering to llvm.cmpxchg. Differential Revision: https://reviews.llvm.org/D78753
This commit is contained in:
parent
cd3a54c55a
commit
b79751e83d
|
@ -39,3 +39,7 @@ This document describes the available MLIR passes and their contracts.
|
|||
## `spv` Dialect Passes
|
||||
|
||||
[include "SPIRVPasses.md"]
|
||||
|
||||
## `standard` Dialect Passes
|
||||
|
||||
[include "StandardPasses.md"]
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(MLIRStandardTransformsIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc StandardPasses ./)
|
|
@ -0,0 +1,29 @@
|
|||
|
||||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines prototypes that expose pass constructors in the loop
|
||||
// transformation library.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
|
||||
#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class Pass;
|
||||
|
||||
/// Creates an instance of the ExpandAtomic pass.
|
||||
std::unique_ptr<Pass> createExpandAtomicPass();
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,19 @@
|
|||
//===-- Passes.td - StandardOps pass definition file -------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ExpandAtomic : FunctionPass<"expand-atomic"> {
|
||||
let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp.";
|
||||
let constructor = "mlir::createExpandAtomicPass()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
|
@ -34,6 +34,7 @@
|
|||
#include "mlir/Dialect/LoopOps/Passes.h"
|
||||
#include "mlir/Dialect/Quant/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/LocationSnapshot.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/ViewOpGraph.h"
|
||||
|
@ -86,6 +87,10 @@ inline void registerAllPasses() {
|
|||
// SPIR-V
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/SPIRV/Passes.h.inc"
|
||||
|
||||
// Standard
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -2743,113 +2743,6 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
|
||||
/// retried until it succeeds in atomically storing a new value into memory.
|
||||
///
|
||||
/// +---------------------------------+
|
||||
/// | <code before the AtomicRMWOp> |
|
||||
/// | <compute initial %loaded> |
|
||||
/// | br loop(%loaded) |
|
||||
/// +---------------------------------+
|
||||
/// |
|
||||
/// -------| |
|
||||
/// | v v
|
||||
/// | +--------------------------------+
|
||||
/// | | loop(%loaded): |
|
||||
/// | | <body contents> |
|
||||
/// | | %pair = cmpxchg |
|
||||
/// | | %ok = %pair[0] |
|
||||
/// | | %new = %pair[1] |
|
||||
/// | | cond_br %ok, end, loop(%new) |
|
||||
/// | +--------------------------------+
|
||||
/// | | |
|
||||
/// |----------- |
|
||||
/// v
|
||||
/// +--------------------------------+
|
||||
/// | end: |
|
||||
/// | <code after the AtomicRMWOp> |
|
||||
/// +--------------------------------+
|
||||
///
|
||||
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto atomicOp = cast<AtomicRMWOp>(op);
|
||||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (maybeKind)
|
||||
return failure();
|
||||
|
||||
LLVM::FCmpPredicate predicate;
|
||||
switch (atomicOp.kind()) {
|
||||
case AtomicRMWKind::maxf:
|
||||
predicate = LLVM::FCmpPredicate::ogt;
|
||||
break;
|
||||
case AtomicRMWKind::minf:
|
||||
predicate = LLVM::FCmpPredicate::olt;
|
||||
break;
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
|
||||
OperandAdaptor<AtomicRMWOp> adaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
|
||||
|
||||
// Split the block into initial, loop, and ending parts.
|
||||
auto *initBlock = rewriter.getInsertionBlock();
|
||||
auto initPosition = rewriter.getInsertionPoint();
|
||||
auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
|
||||
auto loopArgument = loopBlock->addArgument(valueType);
|
||||
auto loopPosition = rewriter.getInsertionPoint();
|
||||
auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
|
||||
|
||||
// Compute the loaded value and branch to the loop block.
|
||||
rewriter.setInsertionPointToEnd(initBlock);
|
||||
auto memRefType = atomicOp.getMemRefType();
|
||||
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
|
||||
adaptor.indices(), rewriter, getModule());
|
||||
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
|
||||
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
|
||||
|
||||
// Prepare the body of the loop block.
|
||||
rewriter.setInsertionPointToStart(loopBlock);
|
||||
auto predicateI64 =
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
|
||||
auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
|
||||
auto lhs = loopArgument;
|
||||
auto rhs = adaptor.value();
|
||||
auto cmp =
|
||||
rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
|
||||
auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
|
||||
|
||||
// Prepare the epilog of the loop block.
|
||||
rewriter.setInsertionPointToEnd(loopBlock);
|
||||
// Append the cmpxchg op to the end of the loop block.
|
||||
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
|
||||
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
|
||||
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
|
||||
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
|
||||
loc, pairType, dataPtr, loopArgument, select, successOrdering,
|
||||
failureOrdering);
|
||||
// Extract the %new_loaded and %ok values from the pair.
|
||||
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
|
||||
Value ok = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
|
||||
|
||||
// Conditionally branch to the end or back to the loop depending on %ok.
|
||||
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
|
||||
loopBlock, newLoaded);
|
||||
|
||||
// The 'result' of the atomic_rmw op is the newly loaded value.
|
||||
rewriter.replaceOp(op, {newLoaded});
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
|
||||
/// retried until it succeeds in atomically storing a new value into memory.
|
||||
///
|
||||
|
@ -2985,7 +2878,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|||
AddIOpLowering,
|
||||
AllocaOpLowering,
|
||||
AndOpLowering,
|
||||
AtomicCmpXchgOpLowering,
|
||||
AtomicRMWOpLowering,
|
||||
BranchOpLowering,
|
||||
CallIndirectOpLowering,
|
||||
|
|
|
@ -17,3 +17,5 @@ add_mlir_dialect_library(MLIRStandardOps
|
|||
MLIRSideEffects
|
||||
MLIRViewLikeInterface
|
||||
)
|
||||
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_dialect_library(MLIRStandardOpsTransforms
|
||||
ExpandAtomic.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRStandardTransformsIncGen
|
||||
)
|
||||
target_link_libraries(MLIRStandardOpsTransforms
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRStandardOps
|
||||
MLIRSupport
|
||||
LLVMSupport
|
||||
)
|
|
@ -0,0 +1,93 @@
|
|||
//===- ExpandAtomic.cpp - Code to perform loop fusion ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
|
||||
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
|
||||
/// `generic_atomic_rmw` with the expanded code.
|
||||
///
|
||||
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
///
|
||||
/// will be lowered to
|
||||
///
|
||||
/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
|
||||
/// ^bb0(%current: f32):
|
||||
/// %cmp = cmpf "ogt", %current, %fval : f32
|
||||
/// %new_value = select %cmp, %current, %fval : f32
|
||||
/// atomic_yield %new_value : f32
|
||||
/// }
|
||||
struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(AtomicRMWOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
CmpFPredicate predicate;
|
||||
switch (op.kind()) {
|
||||
case AtomicRMWKind::maxf:
|
||||
predicate = CmpFPredicate::OGT;
|
||||
break;
|
||||
case AtomicRMWKind::minf:
|
||||
predicate = CmpFPredicate::OLT;
|
||||
break;
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto genericOp =
|
||||
rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
|
||||
OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody());
|
||||
|
||||
Value lhs = genericOp.getCurrentValue();
|
||||
Value rhs = op.value();
|
||||
Value cmp = bodyBuilder.create<CmpFOp>(loc, predicate, lhs, rhs);
|
||||
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
|
||||
bodyBuilder.create<AtomicYieldOp>(loc, select);
|
||||
|
||||
rewriter.replaceOp(op, genericOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> {
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<AtomicRMWOpConverter>(&getContext());
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalOp<GenericAtomicRMWOp>();
|
||||
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
|
||||
return op.kind() != AtomicRMWKind::maxf &&
|
||||
op.kind() != AtomicRMWKind::minf;
|
||||
});
|
||||
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createExpandAtomicPass() {
|
||||
return std::make_unique<ExpandAtomic>();
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
//===- PassDetail.h - GPU Pass class details --------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
|
||||
#define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AtomicRMWOp;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
|
|
@ -1110,25 +1110,6 @@ func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval :
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cmpxchg
|
||||
func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
|
||||
%x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
// CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
|
||||
// CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
|
||||
// CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
|
||||
// CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
|
||||
// CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
|
||||
// CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
|
||||
// CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
|
||||
// CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
|
||||
// CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
|
||||
// CHECK-NEXT: ^bb2:
|
||||
return %x : f32
|
||||
// CHECK-NEXT: llvm.return %[[new]]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @generic_atomic_rmw
|
||||
func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
|
||||
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw_to_generic
|
||||
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
|
||||
func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
||||
%x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %x : f32
|
||||
}
|
||||
// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
|
||||
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
|
||||
// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32
|
||||
// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
|
||||
// CHECK: atomic_yield [[SELECT]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: return %0 : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @atomic_rmw_no_conversion
|
||||
func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
|
||||
%x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
|
||||
return %x : f32
|
||||
}
|
||||
// CHECK-NOT: generic_atomic_rmw
|
Loading…
Reference in New Issue