[mlir][NVGPU] nvgpu.mmasync on F32 through TF32

Adds optional attribute to support tensor cores on F32 datatype by lowering to `mma.sync` with TF32 operands. Since, TF32 is not a native datatype in LLVM we are adding `tf32Enabled` as an attribute to allow the IR to be aware of `MmaSyncOp` datatype. Additionally, this patch adds placeholders for nvgpu-to-nvgpu transformation targeting higher precision tf32x3.

For mma.sync on f32 input using tensor cores there are two possibilites:
(a) tf32   (1 `mma.sync` per warp-level matrix-multiply-accumulate)
(b) tf32x3 (3 `mma.sync` per warp-level matrix-multiply-accumulate)

Typically, tf32 tensor core acceleration comes at a cost of accuracy from missing precision bits. While f32 has 23 precision bits, tf32 has only 10 precision bits. tf32x3 aims to recover the precision bits by splitting each operand into two tf32 values and issue three `mma.sync` tensor core operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130294
This commit is contained in:
Manish Gupta 2022-08-01 23:06:23 +00:00 committed by Thomas Raoux
parent bcef4d238d
commit 14d79afeae
16 changed files with 283 additions and 8 deletions

View File

@ -110,11 +110,22 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```
}];
let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
let arguments = (ins AnyVector:$matrixA,
AnyVector:$matrixB,
AnyVector:$matrixC,
I64ArrayAttr:$mmaShape,
OptionalAttr<UnitAttr>:$tf32Enabled
);
let results = (outs AnyVector:$res);
let builders = [
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
"ArrayAttr":$mmaShape)>
];
let assemblyFormat = [{
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)

View File

@ -19,6 +19,10 @@
namespace mlir {
namespace nvgpu {
///
/// Passes
///
/// Optimizes vectorized accesses to a shared memory buffer specified by
/// memrefValue. This transformation assumes the following:
/// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
@ -41,6 +45,29 @@ namespace nvgpu {
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue);
///
/// Rewrites patterns
///
//===----------------------------------------------------------------------===//
// NVGPU transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Enum to control the lowering of `nvgpu.mmasync`.
enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 };
/// Collect patterns to convert mma.sync on f32 input and rewrite
/// to use tensor cores with user provided level of accuracy:
/// (a) tf32 (1 mma.sync per warp-level matrix-multiply-accumulate)
/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
/// Typically, tf32 tensor core acceleration comes at a cost
/// of accuracy from missing precision bits. While f32 has 23 precision
/// bits, tf32 has only 10 precision bits. tf32x3 aims to recover the
/// precision bits by spliting each operand into two tf32 values
/// and issue three mma.sync tensor core operations.
void populateMmaSyncF32ToTF32Patterns(
RewritePatternSet &patterns,
nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);
} // namespace nvgpu
} // namespace mlir

View File

@ -275,10 +275,14 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
NVVM::MMATypes ptxTypeB;
Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
cType.getElementType(), /*isAccumulator=*/true);
if (!ptxTypeC) {
if (!ptxTypeC)
return op->emitError(
"could not infer the PTX type for the accumulator/result");
}
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
if (aType.getElementType().isF32() && !tf32Enabled)
return failure();
Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
if (aType.getElementType().isInteger(8)) {

View File

@ -687,8 +687,8 @@ convertContractOpToMmaSync(vector::ContractionOp op,
int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
Value matmul = b.create<nvgpu::MmaSyncOp>(
op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
b.getI64ArrayAttr({m, n, k}));
valueMapping[op.getResult()] = matmul;
return success();
}

View File

@ -91,6 +91,12 @@ LogicalResult DeviceAsyncCopyOp::verify() {
//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, Value matrixA,
Value matrixB, Value matrixC, ArrayAttr mmaShape) {
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
mmaShape, UnitAttr());
}
LogicalResult MmaSyncOp::verify() {
@ -122,6 +128,9 @@ LogicalResult MmaSyncOp::verify() {
// vector element type
Type aType = aVector.getElementType();
// tensor float32 (TF32) enabled
bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
// nvgpu.mma.sync shape (per 32 threads or per warp)
int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
@ -163,6 +172,10 @@ LogicalResult MmaSyncOp::verify() {
return emitOpError() << "expected " << m * n
<< " warp-wide matrix C elements";
// verify tf32 tensor cores are enabled for only F32 datatype
if (tf32Enabled && !(aType.isF32()))
return emitOpError() << "expected tf32 tensor cores only for F32 operands";
//
// Extended verification
//

View File

@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRNVGPUTransforms
OptimizeSharedMemory.cpp
OptimizeSharedMemory.cpp
MmaSyncTF32Transform.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU

View File

@ -0,0 +1,73 @@
//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma sync
// operations on f32 input datatype
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace mlir::nvgpu;
namespace {
struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
MmaSyncF32ToTF32Pattern(MLIRContext *context,
nvgpu::MmaSyncF32Lowering precision)
: OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
precision(precision) {}
LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
PatternRewriter &rewrite) const override {
Location location = op->getLoc();
if (op->hasAttr(op.getTf32EnabledAttrName()))
return failure();
if (precision == MmaSyncF32Lowering::Unkown)
return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
"unknown precision level");
if (precision == MmaSyncF32Lowering::TF32x3)
return emitError(location, "TF32x3 is not supported at the moment "
"for nvgpu.mma.sync on f32 datatype");
if (precision == MmaSyncF32Lowering::TF32)
op.setTf32EnabledAttr(rewrite.getUnitAttr());
return success();
}
private:
/// Precision for F32 Tensor Cores (TF32 or TF32x3)
nvgpu::MmaSyncF32Lowering precision;
};
} // namespace
void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
}

View File

@ -219,7 +219,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>

View File

@ -76,6 +76,13 @@ func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1:
}
// -----
func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
// expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
return %d : vector<2x2xf16>
}
// -----
func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{expected 128 warp-wide matrix A elements}}
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>

View File

@ -0,0 +1,20 @@
// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
// CHECK-LABEL: m16n8k4_tf32
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: nvgpu.mma.sync
// CHECK-SAME: tf32Enabled
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
// CHECK-LABEL: m16n8k8_tf32
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: nvgpu.mma.sync
// CHECK-SAME: tf32Enabled
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----

View File

@ -0,0 +1,18 @@
// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32x3" -split-input-file | FileCheck %s
// CHECK-LABEL: m16n8k4_tf32
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----
// CHECK-LABEL: m16n8k8_tf32
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
return %d : vector<2x2xf32>
}
// -----

View File

@ -5,6 +5,7 @@ add_subdirectory(GPU)
add_subdirectory(Linalg)
add_subdirectory(Math)
add_subdirectory(MemRef)
add_subdirectory(NVGPU)
add_subdirectory(SCF)
add_subdirectory(Shape)
add_subdirectory(SPIRV)

View File

@ -0,0 +1,21 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRNVGPUTestPasses
TestNVGPUTransforms.cpp
EXCLUDE_FROM_LIBMLIR
LINK_LIBS PUBLIC
MLIRIR
MLIRAffineDialect
MLIRAnalysis
MLIRFuncDialect
MLIRGPUOps
MLIRLLVMDialect
MLIRMemRefDialect
MLIRNVGPUDialect
MLIRNVGPUTransforms
MLIRPass
MLIRSCFDialect
MLIRTransformUtils
)

View File

@ -0,0 +1,76 @@
//===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::nvgpu;
namespace {
struct TestMmaSyncF32ToTF32Patterns
: public PassWrapper<TestMmaSyncF32ToTF32Patterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
StringRef getArgument() const final {
return "test-nvgpu-mmasync-f32-to-tf32-patterns";
}
StringRef getDescription() const final {
return "Test patterns to convert mma.sync on f32 with tf32 precision";
}
TestMmaSyncF32ToTF32Patterns() = default;
TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
: PassWrapper(pass) {}
Option<std::string> precision{
*this, "precision",
llvm::cl::desc(
"Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
llvm::cl::init("tf32")};
MmaSyncF32Lowering tf32Precision =
llvm::StringSwitch<MmaSyncF32Lowering>(precision)
.Case("tf32", MmaSyncF32Lowering::TF32)
.Case("tf32x3", MmaSyncF32Lowering::TF32x3)
.Default(MmaSyncF32Lowering::Unkown);
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestNvgpuLowerings() {
PassRegistration<TestMmaSyncF32ToTF32Patterns>();
}
} // namespace test
} // namespace mlir

View File

@ -20,6 +20,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLinalgTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
MLIRSPIRVTestPasses

View File

@ -113,6 +113,7 @@ void registerTestTensorTransforms();
void registerTestTilingInterface();
void registerTestTransformDialectInterpreterPass();
void registerTestVectorLowerings();
void registerTestNvgpuLowerings();
} // namespace test
} // namespace mlir
@ -208,6 +209,7 @@ void registerTestPasses() {
mlir::test::registerTestTilingInterface();
mlir::test::registerTestTransformDialectInterpreterPass();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestNvgpuLowerings();
}
#endif