forked from OSchip/llvm-project
[mlir][Vector] Add a folder for vector.broadcast
Fold the operation if the source is a scalar constant or splat constant. Update transform-patterns-matmul-to-vector.mlir because the broadcast ops are folded in the conversion. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D87703
This commit is contained in:
parent
7f1f89ec8d
commit
f16abe5f84
|
@ -270,6 +270,7 @@ def Vector_BroadcastOp :
|
|||
}
|
||||
}];
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Vector_ShuffleOp :
|
||||
|
|
|
@ -929,6 +929,17 @@ static LogicalResult verify(BroadcastOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0])
|
||||
return {};
|
||||
auto vectorType = getVectorType();
|
||||
if (operands[0].getType().isIntOrIndexOrFloat())
|
||||
return DenseElementsAttr::get(vectorType, operands[0]);
|
||||
if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
|
||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue());
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -13,13 +13,8 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
|||
}
|
||||
|
||||
// CHECK-LABEL:func @matmul
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
|
||||
//
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
|
||||
//
|
||||
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
|
||||
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
|
||||
//
|
||||
// CHECK: linalg.copy
|
||||
|
|
|
@ -385,3 +385,28 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
|
|||
%2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
|
||||
return %0, %2 : vector<4x8xf32>, vector<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: broadcast_folding1
|
||||
// CHECK: %[[CST:.*]] = constant dense<42> : vector<4xi32>
|
||||
// CHECK-NOT: vector.broadcast
|
||||
// CHECK: return %[[CST]]
|
||||
func @broadcast_folding1() -> vector<4xi32> {
|
||||
%0 = constant 42 : i32
|
||||
%1 = vector.broadcast %0 : i32 to vector<4xi32>
|
||||
return %1 : vector<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @broadcast_folding2
|
||||
// CHECK: %[[CST:.*]] = constant dense<42> : vector<4x16xi32>
|
||||
// CHECK-NOT: vector.broadcast
|
||||
// CHECK: return %[[CST]]
|
||||
func @broadcast_folding2() -> vector<4x16xi32> {
|
||||
%0 = constant 42 : i32
|
||||
%1 = vector.broadcast %0 : i32 to vector<16xi32>
|
||||
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
|
||||
return %2 : vector<4x16xi32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue