forked from OSchip/llvm-project
[mlir][Linalg] Teach constant -> generic op fusion to handle scalar constants.
The current folder of constant -> generic op only handles splat constants. The same logic holds for scalar constants. Teach the pattern to handle such cases. Differential Revision: https://reviews.llvm.org/D109982
This commit is contained in:
parent
474816384f
commit
a40a08ed98
|
@ -1162,11 +1162,12 @@ private:
|
||||||
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Pattern to fold a generic op with a splat constant.
|
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
|
||||||
class FoldSplatConstants : public OpRewritePattern<GenericOp> {
|
/// handle cases where the constant is not single-valued.
|
||||||
|
class FoldConstants : public OpRewritePattern<GenericOp> {
|
||||||
public:
|
public:
|
||||||
FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
|
FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
|
||||||
PatternBenefit benefit = 1)
|
PatternBenefit benefit = 1)
|
||||||
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||||
|
@ -1175,10 +1176,37 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||||
Operation *def = opOperand->get().getDefiningOp();
|
Operation *def = opOperand->get().getDefiningOp();
|
||||||
DenseElementsAttr constantAttr;
|
Attribute constantAttr;
|
||||||
if (!def ||
|
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
|
||||||
!matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
|
{
|
||||||
!constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
|
DenseElementsAttr splatAttr;
|
||||||
|
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
|
||||||
|
splatAttr.isSplat() &&
|
||||||
|
splatAttr.getType().getElementType().isIntOrFloat()) {
|
||||||
|
constantAttr = splatAttr.getSplatValue();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
IntegerAttr intAttr;
|
||||||
|
if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
|
||||||
|
constantAttr = intAttr;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
FloatAttr floatAttr;
|
||||||
|
if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
|
||||||
|
constantAttr = floatAttr;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto resultValue = opOperand->get().dyn_cast<OpResult>();
|
||||||
|
if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
|
||||||
|
!controlFn(resultValue, *opOperand))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// The operands and the indexing_maps of the fused operation the same as
|
// The operands and the indexing_maps of the fused operation the same as
|
||||||
|
@ -1205,8 +1233,7 @@ public:
|
||||||
|
|
||||||
// Create a constant scalar value from the splat constant.
|
// Create a constant scalar value from the splat constant.
|
||||||
Value scalarConstant = rewriter.create<ConstantOp>(
|
Value scalarConstant = rewriter.create<ConstantOp>(
|
||||||
def->getLoc(), constantAttr.getSplatValue(),
|
def->getLoc(), constantAttr, constantAttr.getType());
|
||||||
constantAttr.getType().getElementType());
|
|
||||||
|
|
||||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||||
auto fusedOp = rewriter.create<GenericOp>(
|
auto fusedOp = rewriter.create<GenericOp>(
|
||||||
|
@ -1411,7 +1438,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
|
||||||
void mlir::linalg::populateElementwiseOpsFusionPatterns(
|
void mlir::linalg::populateElementwiseOpsFusionPatterns(
|
||||||
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
|
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
|
||||||
auto *context = patterns.getContext();
|
auto *context = patterns.getContext();
|
||||||
patterns.add<FuseElementwiseOps, FoldSplatConstants>(
|
patterns.add<FuseElementwiseOps, FoldConstants>(
|
||||||
context, options.controlElementwiseOpsFusionFn);
|
context, options.controlElementwiseOpsFusionFn);
|
||||||
patterns.add<RemoveOutsDependency>(context);
|
patterns.add<RemoveOutsDependency>(context);
|
||||||
populateFoldReshapeOpsByExpansionPatterns(patterns,
|
populateFoldReshapeOpsByExpansionPatterns(patterns,
|
||||||
|
|
|
@ -740,3 +740,37 @@ func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
|
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
|
||||||
// CHECK: %[[RESULT:.+]] = linalg.generic
|
// CHECK: %[[RESULT:.+]] = linalg.generic
|
||||||
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
|
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
|
||||||
|
%cst = constant 4.0 : f32
|
||||||
|
%c42 = constant 42 : i32
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||||
|
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||||
|
%0 = linalg.init_tensor[%d0, %d1] : tensor<?x?xf32>
|
||||||
|
%1 = linalg.init_tensor[%d0, %d1] : tensor<?x?xi32>
|
||||||
|
%2:2 = linalg.generic {
|
||||||
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> ()>,
|
||||||
|
affine_map<(d0, d1) -> ()>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel"]}
|
||||||
|
ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
|
||||||
|
outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
|
||||||
|
^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
|
||||||
|
%3 = addf %arg1, %arg2 : f32
|
||||||
|
linalg.yield %3, %arg3 : f32, i32
|
||||||
|
} -> (tensor<?x?xf32>, tensor<?x?xi32>)
|
||||||
|
return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @fuse_scalar_constant
|
||||||
|
// CHECK-DAG: %[[CST:.+]] = constant 4.000000e+00 : f32
|
||||||
|
// CHECK-DAG: %[[C42:.+]] = constant 42 : i32
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: ins(%{{.+}} : tensor<?x?xf32>)
|
||||||
|
// CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
|
||||||
|
// CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32
|
||||||
|
|
Loading…
Reference in New Issue