[mlir][NVGPU] Verifiers for nvgpu.mma.sync Op

- Adds verification for `nvgpu.mma.sync` op
- Adds tests to `mlir/test/Dialect/NVGPU/invalid.mlir`
- The `nvgpu.mma.sync` verifier caught a bug in the m16n8k4_tf32_f32 variant in `mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir`
    - The shape of the output vector holding the thread-level accumulators was inconsistent; this change fixes it

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D129400
commit f7d42d5149 (parent 0aefc94651)
Manish Gupta, 2022-07-13 17:53:52 +00:00, committed by Thomas Raoux
4 changed files with 190 additions and 13 deletions

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

@@ -81,7 +81,10 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
  }];
}
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+  NoSideEffect,
+  PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>,
+  ]> {
  let description = [{
    The `nvgpu.mma.sync` op represents the distributed form of a collective
    matrix-multiply-and-accumulate (mma) operation that is compatible with
@@ -112,6 +115,8 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
  }];
+  let hasVerifier = 1;
}
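
A minimal sketch of the assembly format declared above, using the f16 m16n8k16 shapes exercised by the tests added later in this commit (SSA value names are hypothetical):

    %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>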

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

@@ -88,5 +88,103 @@ LogicalResult DeviceAsyncCopyOp::verify() {
  return success();
}

LogicalResult MmaSyncOp::verify() {
  // Fundamental tensor core mma.sync op.
  // For F32 (TF32), F16, S8, and S4 data types, the fundamental tensor core
  // operation is of shape 8-by-8-by-128b. F64 is an exception. The
  // verification for mma.sync covering various shapes and data types is based
  // on this fundamental tensor core operation.
  constexpr int kThreads = 32; // 32 threads per warp

  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to the 8-by-8-by-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-by-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }
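  // Illustrative example: for f16 operands, operandBitwidth = 16, so
  // shapeK = 128 / 16 = 8 and numElementA = numElementB = 32 / 16 = 2.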
  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
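  // Illustrative example: mmaShape = [16, 8, 16] on f16 gives
  // mTile = 16 / 8 = 2, nTile = 8 / 8 = 1, kTile = 16 / 8 = 2, so matrix A
  // is expected to be vector<4x2xf16> (mTile * kTile rows of numElementA).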
  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

@@ -205,7 +205,7 @@ func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
// -----
// CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // The A, B operand should be bitcast to i32
  // CHECK: llvm.extractvalue
  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
@@ -219,17 +219,22 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
  // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
  // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
-  return %d : vector<4x1xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>
+  return %d : vector<2x2xf32>
}
// -----
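
The corrected accumulator type above follows from the verifier arithmetic. A short worked calculation for tf32 with mmaShape = [16, 8, 4]:

    m * n = 16 * 8 = 128 warp-wide C elements
    128 / 32 threads = 4 accumulator elements per thread
    rows = mTile * nTile = (16 / 8) * (8 / 8) = 2, columns = numElementC = 2

Hence each thread holds its accumulator slice as vector<2x2xf32>, not vector<4x1xf32>.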

mlir/test/Dialect/NVGPU/invalid.mlir

@@ -1,4 +1,73 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s

func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}

// -----

func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}

// -----

func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
  return %d : vector<2x4xf16>
}

// -----

func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}

// -----

func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}

// -----

func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
  return %d : vector<2x2xf32>
}

// -----

func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
  return %d : vector<1x2xf64>
}

// -----

func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
  return %d : vector<2x1xf64>
}

// -----

func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
  return %d : vector<2x2xi32>
}

// -----

func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
  // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
  return %d : vector<2x2xi32>
}

// -----

func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
  // expected-error @+1 {{destination memref must have memory space 3}}