forked from OSchip/llvm-project
Update syntax for amx.tile_muli to use two UnitAttrs to mark the zext case
This ties each annotation to the operand it applies to, and using a keyword makes its meaning more explicit and readable. Differential Revision: https://reviews.llvm.org/D99001
This commit is contained in:
parent
ea48bf8649
commit
cdb6eb7e83
|
@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
|
||||
combinations (4 bytes packed into dwords in the columns of both the
|
||||
source operand tiles; the zero or sign extension is specified with
|
||||
the attributes). The operation is eventually lowered into one of
|
||||
the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
|
||||
the corresponding tile configuration.
|
||||
the attributes and defaults to sign extension). The operation is eventually
|
||||
lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
|
||||
instructions with the corresponding tile configuration.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = amx.tile_muli %a, %b, %c [true, true]
|
||||
%0 = amx.tile_muli %a zext, %b zext, %c
|
||||
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
```
|
||||
}];
|
||||
|
@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
|
||||
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
|
||||
VectorOfRankAndType<[2], [I32, I8]>:$acc,
|
||||
BoolArrayAttr:$zext);
|
||||
UnitAttr:$isZextLhs,
|
||||
UnitAttr:$isZextRhs
|
||||
);
|
||||
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getLhsVectorType() {
|
||||
|
@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
return res().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
|
||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||
"type($lhs) `,` type($rhs) `,` type($acc) ";
|
||||
}
|
||||
|
||||
|
|
|
@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(amx::TileMulIOp op) {
|
||||
if (op.zext().size() != 2)
|
||||
return op.emitOpError("unexpected zext length");
|
||||
VectorType aType = op.getLhsVectorType();
|
||||
VectorType bType = op.getRhsVectorType();
|
||||
VectorType cType = op.getVectorType();
|
||||
|
|
|
@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
|
|||
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
|
||||
// Replace operation with intrinsic.
|
||||
Type resType = typeConverter->convertType(cType);
|
||||
bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
|
||||
bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
|
||||
bool zexta = op.isZextLhs();
|
||||
bool zextb = op.isZextRhs();
|
||||
if (zexta && zextb)
|
||||
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
|
||||
op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
|
||||
|
|
|
@ -46,13 +46,3 @@ func @multsize() {
|
|||
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
|
||||
%3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @zextsize() {
|
||||
%0 = amx.tile_zero : vector<8x8xi8>
|
||||
%1 = amx.tile_zero : vector<8x8xi8>
|
||||
%2 = amx.tile_zero : vector<8x8xi32>
|
||||
// expected-error@+1 {{'amx.tile_muli' op unexpected zext length}}
|
||||
%3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
|
||||
}
|
||||
|
|
|
@ -17,13 +17,13 @@ func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
|
|||
%1 = amx.tile_zero : vector<16x64xi8>
|
||||
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
|
||||
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
|
||||
%5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
|
||||
%6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
|
||||
%7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, vector<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -28,14 +28,22 @@ func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
|
|||
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
|
||||
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
|
||||
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
|
||||
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
|
||||
// Verify the parsing/printing of the zero-extension (`zext`) annotation.
|
||||
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
|
||||
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
|
||||
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
|
||||
func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
|
||||
%0 = constant 0 : index
|
||||
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
|
||||
// Verify the various `zext` combinations.
|
||||
%5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
%7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8>
|
||||
%3 = amx.tile_zero : vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = amx.tile_zero : vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>,
|
|||
%1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8>
|
||||
%3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
%4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
|
||||
amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue