Update syntax for amx.tile_muli to use two UnitAttrs to mark the zext cases

This ties each annotation to its operand, and the use of a keyword makes
its meaning more explicit and readable.

Differential Revision: https://reviews.llvm.org/D99001
This commit is contained in:
Mehdi Amini 2021-03-20 01:23:12 +00:00
parent ea48bf8649
commit cdb6eb7e83
8 changed files with 30 additions and 32 deletions

View File

@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
the attributes). The operation is eventually lowered into one of
the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
the corresponding tile configuration.
the attributes and defaults to sign extension). The operation is eventually
lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
instructions with the corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_muli %a, %b, %c [true, true]
%0 = amx.tile_muli %a zext, %b zext, %c
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
```
}];
@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
BoolArrayAttr:$zext);
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
return res().getType().cast<VectorType>();
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
}

View File

@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) {
}
static LogicalResult verify(amx::TileMulIOp op) {
if (op.zext().size() != 2)
return op.emitOpError("unexpected zext length");
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();

View File

@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
bool zexta = op.isZextLhs();
bool zextb = op.isZextRhs();
if (zexta && zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),

View File

@ -46,13 +46,3 @@ func @multsize() {
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
%3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
}
// -----
func @zextsize() {
%0 = amx.tile_zero : vector<8x8xi8>
%1 = amx.tile_zero : vector<8x8xi8>
%2 = amx.tile_zero : vector<8x8xi32>
// expected-error@+1 {{'amx.tile_muli' op unexpected zext length}}
%3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
}

View File

@ -17,13 +17,13 @@ func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
%1 = amx.tile_zero : vector<16x64xi8>
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
%5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
%6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
%7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
return
}

View File

@ -28,14 +28,22 @@ func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
// Verify the parsing/printing of the zero-extension (`zext`) annotation.
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
%0 = constant 0 : index
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
// Verify the various `zext` combinations.
%5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
%7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
return
}

View File

@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
%4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
%4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}
@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
%3 = amx.tile_zero : vector<16x4xi32>
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
return
}

View File

@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_zero : vector<2x2xi32>
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}
@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>,
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
return
}