[mlir][NVGPU] Verifiers for nvgpu.mma.sync Op
- Adds verification for the `nvgpu.mma.sync` op
- Adds tests to `mlir/test/Dialect/NVGPU/invalid.mlir`
- The `nvgpu.mma.sync` verifier caught a bug and triggered a failure in the m16n8k4_tf32_f32 variant in `mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir`: the shape of the vector holding the thread-level accumulators was inconsistent and is fixed in this change (illustrated below)

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D129400
parent 0aefc94651
commit f7d42d5149
@@ -81,7 +81,10 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
   }];
 }
 
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+    NoSideEffect,
+    PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>,
+  ]> {
   let description = [{
   The `nvgpu.mma.sync` op represents the distributed form of a collective
   matrix-multiply-and-accumulate (mma) operation that is compatible with
@@ -112,6 +115,8 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
+
+  let hasVerifier = 1;
 }
@@ -88,5 +88,103 @@ LogicalResult DeviceAsyncCopyOp::verify() {
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+
+  // Fundamental tensor core mma.sync op
+  // For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
+  // operation is of shape 8-by-8-by-128b. F64 is an exception. The
+  // verification for mma.sync covering various shapes and data types is based
+  // on the fundamental tensor core operation.
+  constexpr int kThreads = 32; // 32 threads per warp
+  int64_t shapeM = 8;
+  int64_t shapeN = 8;
+  int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+  // Number of elements A, B, and C per thread per fundamental tensor core tile
+  int64_t numElementA;    // set based on data type (32b except F64)
+  int64_t numElementB;    // set based on data type (32b except F64)
+  int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+  // nvgpu.mma.sync vector operands (per thread)
+  auto aVector = getMatrixA().getType().cast<VectorType>();
+  auto bVector = getMatrixB().getType().cast<VectorType>();
+  auto cVector = getMatrixC().getType().cast<VectorType>();
+
+  // vector shapes
+  ArrayRef<int64_t> aShape = aVector.getShape();
+  ArrayRef<int64_t> bShape = bVector.getShape();
+  ArrayRef<int64_t> cShape = cVector.getShape();
+
+  // vector element type
+  Type aType = aVector.getElementType();
+
+  // nvgpu.mma.sync shape (per 32 threads or per warp)
+  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
+  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
+  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+
+  if (aType.isF64()) {
+    // exception to the 8-by-8-by-128b fundamental tensor core tile size
+    shapeK = 4;
+    numElementA = 1;
+    numElementB = 1;
+  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+             aType.isInteger(8) || aType.isInteger(4)) {
+    // 8-by-8-by-128b fundamental tensor core tile size
+    int operandBitwidth = aType.getIntOrFloatBitWidth();
+    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    numElementA = 32 / operandBitwidth; // 32b wide operand A
+    numElementB = 32 / operandBitwidth; // 32b wide operand B
+  } else {
+    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+                          "supported by nvgpu.mma.sync";
+  }
+
+  //
+  // Basic verification
+  //
+
+  // verify warp-wide size for vector a
+  if (aShape[0] * aShape[1] * kThreads != m * k)
+    return emitOpError() << "expected " << m * k
+                         << " warp-wide matrix A elements";
+
+  // verify warp-wide size for vector b
+  if (bShape[0] * bShape[1] * kThreads != k * n)
+    return emitOpError() << "expected " << k * n
+                         << " warp-wide matrix B elements";
+
+  // verify warp-wide size for vector c
+  if (cShape[0] * cShape[1] * kThreads != m * n)
+    return emitOpError() << "expected " << m * n
+                         << " warp-wide matrix C elements";
+
+  //
+  // Extended verification
+  //
+
+  // tiles of fundamental tensor core operations
+  int64_t mTile = m / shapeM;
+  int64_t nTile = n / shapeN;
+  int64_t kTile = k / shapeK;
+
+  // verify shape of aVector
+  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
+    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+                         << " x " << numElementA << ")";
+
+  // verify shape of bVector
+  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
+    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+                         << " x " << numElementB << ")";
+
+  // verify shape of cVector
+  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
+    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+                         << " x " << numElementC << ")";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
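As a quick sanity check of the verifier logic above (an illustrative sketch, not part of this diff): for an f16 m16n8k16 op, shapeK = 128 / 16 = 8, so mTile = 2, nTile = 1, kTile = 2, and numElementA = numElementB = numElementC = 2. The expected per-thread operand shapes are therefore 4x2 for A, 2x2 for B, and 2x2 for C, which is the valid baseline the new invalid.mlir tests below perturb. A form that would pass the new verifier (hypothetical SSA names %a, %b, %c):

  %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>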
@@ -205,7 +205,7 @@ func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
 // -----
 
 // CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // The A, B operand should be bitcast to i32
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
@@ -219,17 +219,22 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
-  return %d : vector<4x1xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>
+  return %d : vector<2x2xf32>
 }
 
 // -----
@@ -1,4 +1,73 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+  return %d : vector<2x4xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  return %d : vector<1x2xf64>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+  return %d : vector<2x1xf64>
+}
+// -----
+
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @+1 {{destination memref must have memory space 3}}