[MLIR] Add conversion from AtomicRMWOp -> GenericAtomicRMWOp.

Adding this pattern reduces code duplication. There is no need to have a
custom implementation for lowering to llvm.cmpxchg.

Differential Revision: https://reviews.llvm.org/D78753
This commit is contained in:
Alexander Belyaev 2020-05-05 08:30:30 +02:00
parent cd3a54c55a
commit b79751e83d
13 changed files with 222 additions and 127 deletions

View File

@ -39,3 +39,7 @@ This document describes the available MLIR passes and their contracts.
## `spv` Dialect Passes
[include "SPIRVPasses.md"]
## `standard` Dialect Passes
[include "StandardPasses.md"]

View File

@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(MLIRStandardTransformsIncGen)
add_mlir_doc(Passes -gen-pass-doc StandardPasses ./)

View File

@ -0,0 +1,29 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors in the loop
// transformation library.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
#include <memory>
namespace mlir {
class Pass;
/// Creates an instance of the ExpandAtomic pass.
std::unique_ptr<Pass> createExpandAtomicPass();
} // end namespace mlir
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,19 @@
//===-- Passes.td - StandardOps pass definition file -------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
def ExpandAtomic : FunctionPass<"expand-atomic"> {
let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp.";
let constructor = "mlir::createExpandAtomicPass()";
}
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

View File

@ -34,6 +34,7 @@
#include "mlir/Dialect/LoopOps/Passes.h"
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/ViewOpGraph.h"
@ -86,6 +87,10 @@ inline void registerAllPasses() {
// SPIR-V
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/SPIRV/Passes.h.inc"
// Standard
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
}
} // namespace mlir

View File

@ -2743,113 +2743,6 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
}
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
/// | v v
/// | +--------------------------------+
/// | | loop(%loaded): |
/// | | <body contents> |
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
/// v
/// +--------------------------------+
/// | end: |
/// | <code after the AtomicRMWOp> |
/// +--------------------------------+
///
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (maybeKind)
return failure();
LLVM::FCmpPredicate predicate;
switch (atomicOp.kind()) {
case AtomicRMWKind::maxf:
predicate = LLVM::FCmpPredicate::ogt;
break;
case AtomicRMWKind::minf:
predicate = LLVM::FCmpPredicate::olt;
break;
default:
return failure();
}
OperandAdaptor<AtomicRMWOp> adaptor(operands);
auto loc = op->getLoc();
auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto initPosition = rewriter.getInsertionPoint();
auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
auto loopArgument = loopBlock->addArgument(valueType);
auto loopPosition = rewriter.getInsertionPoint();
auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
auto predicateI64 =
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
auto lhs = loopArgument;
auto rhs = adaptor.value();
auto cmp =
rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
// Prepare the epilog of the loop block.
rewriter.setInsertionPointToEnd(loopBlock);
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, select, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
return success();
}
};
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
@ -2985,7 +2878,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
AddIOpLowering,
AllocaOpLowering,
AndOpLowering,
AtomicCmpXchgOpLowering,
AtomicRMWOpLowering,
BranchOpLowering,
CallIndirectOpLowering,

View File

@ -17,3 +17,5 @@ add_mlir_dialect_library(MLIRStandardOps
MLIRSideEffects
MLIRViewLikeInterface
)
add_subdirectory(Transforms)

View File

@ -0,0 +1,17 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
ExpandAtomic.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
DEPENDS
MLIRStandardTransformsIncGen
)
target_link_libraries(MLIRStandardOpsTransforms
PUBLIC
MLIRIR
MLIRPass
MLIRStandardOps
MLIRSupport
LLVMSupport
)

View File

@ -0,0 +1,93 @@
//===- ExpandAtomic.cpp - Code to perform loop fusion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
/// `generic_atomic_rmw` with the expanded code.
///
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
/// %cmp = cmpf "ogt", %current, %fval : f32
/// %new_value = select %cmp, %current, %fval : f32
/// atomic_yield %new_value : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtomicRMWOp op,
PatternRewriter &rewriter) const final {
CmpFPredicate predicate;
switch (op.kind()) {
case AtomicRMWKind::maxf:
predicate = CmpFPredicate::OGT;
break;
case AtomicRMWKind::minf:
predicate = CmpFPredicate::OLT;
break;
default:
return failure();
}
auto loc = op.getLoc();
auto genericOp =
rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody());
Value lhs = genericOp.getCurrentValue();
Value rhs = op.value();
Value cmp = bodyBuilder.create<CmpFOp>(loc, predicate, lhs, rhs);
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
bodyBuilder.create<AtomicYieldOp>(loc, select);
rewriter.replaceOp(op, genericOp.getResult());
return success();
}
};
struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> {
void runOnFunction() override {
OwningRewritePatternList patterns;
patterns.insert<AtomicRMWOpConverter>(&getContext());
ConversionTarget target(getContext());
target.addLegalOp<GenericAtomicRMWOp>();
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
return op.kind() != AtomicRMWKind::maxf &&
op.kind() != AtomicRMWKind::minf;
});
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::createExpandAtomicPass() {
return std::make_unique<ExpandAtomic>();
}

View File

@ -0,0 +1,23 @@
//===- PassDetail.h - GPU Pass class details --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
class AtomicRMWOp;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
} // end namespace mlir
#endif // DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_

View File

@ -1110,25 +1110,6 @@ func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval :
// -----
// CHECK-LABEL: func @cmpxchg
func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
%x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*">
// CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float)
// CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
// CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
// CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
// CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
// CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
// CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
// CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
// CHECK-NEXT: ^bb2:
return %x : f32
// CHECK-NEXT: llvm.return %[[new]]
}
// -----
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {

View File

@ -0,0 +1,24 @@
// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = std.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32
// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
// CHECK: atomic_yield [[SELECT]] : f32
// CHECK: }
// CHECK: return %0 : f32
// -----
// CHECK-LABEL: func @atomic_rmw_no_conversion
func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK-NOT: generic_atomic_rmw