Support custom accumulator provided as region to gpu.all_reduce.

In addition to specifying the type of accumulation through the 'op' attribute, the accumulation can now also be specified as an arbitrary code region.

Adds a gpu.yield op to specify the result of the accumulation.

Also support more types (integers) and accumulations (mul).

PiperOrigin-RevId: 275065447
Christian Sigg 2019-10-16 10:43:12 -07:00 committed by A. Unique TensorFlower
parent e7b49eef1d
commit d2f0f847af
8 changed files with 312 additions and 21 deletions
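For illustration, a minimal sketch of the two accumulation forms this change enables; the f32 operand %0 and the surrounding GPU function are assumed context, not part of the commit:

```
// Accumulation named through the 'op' attribute.
%sum = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32)
// Accumulation given as a code region; gpu.yield returns the combined value.
%prod = "gpu.all_reduce"(%0) ({
^bb(%lhs : f32, %rhs : f32):
  %m = mulf %lhs, %rhs : f32
  "gpu.yield"(%m) : (f32) -> ()
}) : (f32) -> (f32)
```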


@ -59,23 +59,48 @@ def gpu_Return : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
let printer = [{ p << getOperationName(); }];
}
def gpu_Yield : GPU_Op<"yield", [Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "GPU yield operation";
let description = [{
"gpu.yield" is a special terminator operation for blocks inside regions
in gpu ops. It returns values to the immediately enclosing gpu op.
Example:
gpu.yield %f0, %f1 : f32, f32
}];
}
def gpu_AllReduce : GPU_Op<"all_reduce",
[SameOperandsAndResultType, IsolatedFromAbove]>,
Arguments<(ins AnyType:$value, OptionalAttr<StrAttr>:$op)>,
Results<(outs AnyType)> {
let summary = "Reduce values among workgroup.";
let description = [{
The "all_reduce" op reduces the value of every invocation across a local
workgroup.
workgroup. The result is equal for all invocations of a local workgroup.
For example,
For example, both
```
%1 = "gpu.all_reduce"(%0) ({}) { op = "add" } : (f32) -> (f32)
%2 = "gpu.all_reduce"(%0) ({
^bb(%lhs : f32, %rhs : f32):
%sum = addf %lhs, %rhs : f32
"gpu.yield"(%sum) : (f32) -> ()
}) : (f32) -> (f32)
```
compute the sum of each invocation's %0 value. The first version specifies
the accumulation as an operation, whereas the second version specifies the
accumulation as a code region. The accumulation operation must be either
`add` or `mul`.
Either none or all invocations of a local workgroup need to execute this op
in convergence.
}];
let regions = (region AnyRegion:$body);
let verifier = [{ return ::verifyAllReduce(*this); }];
}
#endif // GPU_OPS
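To illustrate the newly supported integer types and the `mul` accumulation, a hedged example in the attribute form (assuming an i32 value %val inside a GPU kernel; not taken from the commit's tests):

```
%prod = "gpu.all_reduce"(%val) ({}) { op = "mul" } : (i32) -> (i32)
```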


@ -25,6 +25,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -36,6 +37,9 @@ namespace {
/// Converts all_reduce op to LLVM/NVVM ops.
struct GPUAllReduceOpLowering : public LLVMOpLowering {
using AccumulatorFactory = std::function<Value *(
Location, Value *, Value *, ConversionPatternRewriter &)>;
explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
: LLVMOpLowering(gpu::AllReduce::getOperationName(),
lowering_.getDialect()->getContext(), lowering_),
@ -44,12 +48,102 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value *operand = operands.front();
// TODO(csigg): Generalize to other types of accumulation.
assert(op->getOperand(0)->getType().isIntOrFloat());
// Create the reduction using an accumulator factory.
AccumulatorFactory factory = getFactory(cast<gpu::AllReduce>(op), operand);
assert(factory && "failed to create accumulator factory");
Value *result = createBlockReduce(loc, operand, factory, rewriter);
rewriter.replaceOp(op, {result});
return matchSuccess();
}
private:
/// Returns an accumulator factory using either the op attribute or the body
/// region.
AccumulatorFactory getFactory(gpu::AllReduce allReduce,
Value *operand) const {
if (!allReduce.body().empty()) {
return getFactory(allReduce.body());
}
if (allReduce.op()) {
auto type = operand->getType().cast<LLVM::LLVMType>();
return getFactory(*allReduce.op(), type.getUnderlyingType());
}
return AccumulatorFactory();
}
/// Returns an accumulator factory that clones the body. The body's entry
/// block is expected to have 2 arguments. The gpu.yield op returns the
/// accumulated value of the same type.
AccumulatorFactory getFactory(Region &body) const {
return AccumulatorFactory([&](Location loc, Value *lhs, Value *rhs,
ConversionPatternRewriter &rewriter) {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
// Insert the accumulator body between the current block and the split block.
BlockAndValueMapping mapping;
mapping.map(body.front().getArgument(0), lhs);
mapping.map(body.front().getArgument(1), rhs);
rewriter.cloneRegionBefore(body, *split->getParent(),
split->getIterator(), mapping);
// Add a branch from the block before the inserted body into the body.
block = block->getNextNode();
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
llvm::makeArrayRef(block),
llvm::ArrayRef<Value *>());
// Replace all gpu.yield ops with branch out of body.
for (; block != split; block = block->getNextNode()) {
Operation *terminator = block->getTerminator();
if (!llvm::isa<gpu::Yield>(terminator))
continue;
rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
terminator, ArrayRef<Value *>{}, llvm::makeArrayRef(split),
llvm::makeArrayRef(terminator->getOperand(0)));
}
// Return accumulator result.
rewriter.setInsertionPointToStart(split);
return split->addArgument(lhs->getType());
});
}
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
if (type->isVectorTy() || type->isArrayTy())
return getFactory(opName, type->getSequentialElementType());
bool isFloatingPoint = type->isFloatingPointTy();
if (opName == "add") {
return isFloatingPoint ? getFactory<LLVM::FAddOp>()
: getFactory<LLVM::AddOp>();
}
if (opName == "mul") {
return isFloatingPoint ? getFactory<LLVM::FMulOp>()
: getFactory<LLVM::MulOp>();
}
return AccumulatorFactory();
}
/// Returns an accumulator factory that creates an op of type T.
template <typename T> AccumulatorFactory getFactory() const {
return [](Location loc, Value *lhs, Value *rhs,
ConversionPatternRewriter &rewriter) {
return rewriter.create<T>(loc, lhs->getType(), lhs, rhs);
};
}
/// Creates an all_reduce across the block.
///
/// First reduce the elements within a warp. The first thread of each warp
@ -87,6 +181,7 @@ private:
/// return %result
///
Value *createBlockReduce(Location loc, Value *operand,
AccumulatorFactory &accumFactory,
ConversionPatternRewriter &rewriter) const {
auto type = operand->getType().cast<LLVM::LLVMType>();
@ -106,8 +201,8 @@ private:
Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
// Reduce elements within each warp to produce the intermediate results.
Value *warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
accumFactory, rewriter);
// Write the intermediate results to shared memory, using the first lane of
// each warp.
@ -131,7 +226,8 @@ private:
Value *loadSrc = rewriter.create<LLVM::GEPOp>(
loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
Value *result = createWarpReduce(loc, numWarps, laneId, value,
accumFactory, rewriter);
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
});
rewriter.create<NVVM::Barrier0Op>(loc);
@ -205,9 +301,8 @@ private:
/// Creates a reduction across the first activeWidth lanes of a warp.
/// The first lane returns the result; the return values of all other lanes are undefined.
Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
Value *operand, AccumulatorFactory accumFactory,
ConversionPatternRewriter &rewriter) const {
Value *warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
@ -251,7 +346,7 @@ private:
loc, rewriter, isActiveSrcLane,
[&] {
return llvm::SmallVector<Value *, 1>{
accumFactory(loc, value, shfl, rewriter)};
},
[&] { return llvm::makeArrayRef(value); });
value = rewriter.getInsertionBlock()->getArgument(0);
@ -269,7 +364,7 @@ private:
loc, int32Type, rewriter.getI32IntegerAttr(i));
Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp);
value = accumFactory(loc, value, shfl, rewriter);
}
return llvm::SmallVector<Value *, 1>{value};
});
@ -382,6 +477,8 @@ public:
target.addIllegalDialect<gpu::GPUDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<NVVM::NVVMDialect>();
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<gpu::Yield>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
}
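Schematically, each butterfly step of the lowered warp reduction now pairs the shuffle with whatever the accumulator factory emits. The sketch below uses generic op form with illustrative operand names and types; it approximates the lowering output and is not verbatim IR:

```
%shfl = "nvvm.shfl.sync.bfly"(%active_mask, %value, %offset, %mask_and_clamp)
    : (!llvm.i32, !llvm.float, !llvm.i32, !llvm.i32) -> !llvm.float
// With op = "add" on floats the factory creates llvm.fadd; with a body region,
// the cloned region computes the combined value instead.
%acc = "llvm.fadd"(%value, %shfl) : (!llvm.float, !llvm.float) -> !llvm.float
```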


@ -137,6 +137,38 @@ template <typename T> static LogicalResult verifyIndexOp(T op) {
return success();
}
static LogicalResult verifyAllReduce(gpu::AllReduce allReduce) {
if (allReduce.body().empty() != allReduce.op().hasValue())
return allReduce.emitError(
"expected either an op attribute or a non-empty body");
if (allReduce.op()) {
ArrayRef<StringRef> supportedOps{"add", "mul"};
if (!llvm::is_contained(supportedOps, *allReduce.op()))
return allReduce.emitError("op \"") << *allReduce.op() << "\" is invalid";
}
if (!allReduce.body().empty()) {
if (allReduce.body().front().getNumArguments() != 2)
return allReduce.emitError("expected two region arguments");
for (auto *argument : allReduce.body().front().getArguments()) {
if (argument->getType() != allReduce.getType())
return allReduce.emitError("incorrect region argument type");
}
unsigned yieldCount = 0;
for (Block &block : allReduce.body()) {
if (auto yield = dyn_cast<gpu::Yield>(block.getTerminator())) {
if (yield.getNumOperands() != 1)
return allReduce.emitError("expected one gpu.yield operand");
if (yield.getOperand(0)->getType() != allReduce.getType())
return allReduce.emitError("incorrect gpu.yield type");
++yieldCount;
}
}
if (yieldCount == 0)
return allReduce.emitError("expected gpu.yield op in region");
}
return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"


@ -1,4 +1,4 @@
// RUN: mlir-opt %s -lower-gpu-ops-to-nvvm-ops -split-input-file | FileCheck %s
module attributes {gpu.kernel_module} {
// CHECK-LABEL: func @gpu_index_ops()
@ -32,12 +32,42 @@ module attributes {gpu.kernel_module} {
// CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
std.return
}
}
// -----
module attributes {gpu.kernel_module} {
// CHECK-LABEL: func @gpu_all_reduce_op()
func @gpu_all_reduce_op()
attributes { gpu.kernel } {
%arg0 = constant 1.0 : f32
// TODO(csigg): Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.barrier0
// CHECK: nvvm.shfl.sync.bfly
%result = "gpu.all_reduce"(%one) {scope = "workgroup", kernel = "add"} : (f32) -> (f32)
// CHECK: llvm.fadd
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
std.return
}
}
// -----
module attributes {gpu.kernel_module} {
// CHECK-LABEL: func @gpu_all_reduce_region()
func @gpu_all_reduce_region()
attributes { gpu.kernel } {
%arg0 = constant 1 : i32
// TODO(csigg): Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.barrier0
%result = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : i32, %rhs : i32):
%xor = xor %lhs, %rhs : i32
"gpu.yield"(%xor) : (i32) -> ()
}) : (i32) -> (i32)
std.return
}
}


@ -282,3 +282,81 @@ func @illegal_dimension() {
return
}
// -----
func @reduce_no_op_no_body(%arg0 : f32) {
// expected-error@+1 {{expected either an op attribute or a non-empty body}}
%res = "gpu.all_reduce"(%arg0) ({}) : (f32) -> (f32)
return
}
// -----
func @reduce_op_and_body(%arg0 : f32) {
// expected-error@+1 {{expected either an op attribute or a non-empty body}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32, %rhs : f32):
"gpu.yield"(%lhs) : (f32) -> ()
}) {op = "add"} : (f32) -> (f32)
}
// -----
func @reduce_invalid_op(%arg0 : f32) {
// expected-error@+1 {{op "foo" is invalid}}
%res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32)
return
}
// -----
func @reduce_incorrect_region_arguments(%arg0 : f32) {
// expected-error@+1 {{expected two region arguments}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32):
"gpu.yield"(%lhs) : (f32) -> ()
}) : (f32) -> (f32)
}
// -----
func @reduce_incorrect_region_arguments(%arg0 : f32) {
// expected-error@+1 {{incorrect region argument type}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32, %rhs : i32):
"gpu.yield"(%lhs) : (f32) -> ()
}) : (f32) -> (f32)
}
// -----
func @reduce_incorrect_yield(%arg0 : f32) {
// expected-error@+1 {{expected one gpu.yield operand}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32, %rhs : f32):
"gpu.yield"(%lhs, %rhs) : (f32, f32) -> ()
}) : (f32) -> (f32)
}
// -----
func @reduce_incorrect_yield(%arg0 : f32) {
// expected-error@+1 {{incorrect gpu.yield type}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32, %rhs : f32):
%one = constant 1 : i32
"gpu.yield"(%one) : (i32) -> ()
}) : (f32) -> (f32)
}
// -----
func @reduce_incorrect_yield(%arg0 : f32) {
// expected-error@+1 {{expected gpu.yield op in region}}
%res = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : f32, %rhs : f32):
return
}) : (f32) -> (f32)
}


@ -80,7 +80,7 @@ module attributes {gpu.container_module} {
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
%one = constant 1.0 : f32
%sum = "gpu.all_reduce"(%one) : (f32) -> (f32)
%sum = "gpu.all_reduce"(%one) ({}) {op = "add"} : (f32) -> (f32)
"some_op"(%bIdX, %tIdX) : (index, index) -> ()
%42 = load %arg1[%bIdX] : memref<?xf32, 1>


@ -19,7 +19,7 @@ func @main() {
%idx = addi %tx, %t2 : index
%t3 = index_cast %idx : index to i32
%val = sitofp %t3 : i32 to f32
%sum = "gpu.all_reduce"(%val) { op = "add" } : (f32) -> (f32)
%sum = "gpu.all_reduce"(%val) ({}) { op = "add" } : (f32) -> (f32)
store %sum, %kernel_dst[%tx, %ty, %tz] : memref<?x?x?xf32>
gpu.return
}


@ -0,0 +1,29 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
// CHECK: [3.500000e+01, 3.500000e+01, {{.*}}, 3.500000e+01, 3.500000e+01]
func @main() {
%arg = alloc() : memref<35xf32>
%dst = memref_cast %arg : memref<35xf32> to memref<?xf32>
%zero = constant 0 : i32
%one = constant 1 : index
%sx = dim %dst, 0 : memref<?xf32>
call @mcuMemHostRegister(%dst, %zero) : (memref<?xf32>, i32) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one)
args(%kernel_dst = %dst) : memref<?xf32> {
%val = index_cast %tx : index to i32
%xor = "gpu.all_reduce"(%val) ({
^bb(%lhs : i32, %rhs : i32):
%xor = xor %lhs, %rhs : i32
"gpu.yield"(%xor) : (i32) -> ()
}) : (i32) -> (i32)
%res = sitofp %xor : i32 to f32
store %res, %kernel_dst[%tx] : memref<?xf32>
gpu.return
}
call @mcuPrintFloat(%dst) : (memref<?xf32>) -> ()
return
}
func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
func @mcuPrintFloat(%ptr : memref<?xf32>)