forked from OSchip/llvm-project
[mlir][Linalg] Move named op conversions out of canonicalizations.
These conversions are better suited to be applied at whole tensor level. Applying these as canonicalizations end up triggering such canonicalizations at all levels of the stack which might be undesirable. For example some of the resulting code patterns wont bufferize in-place and need additional stack buffers. Best is to be more deliberate in when these canonicalizations apply. Differential Revision: https://reviews.llvm.org/D115912
This commit is contained in:
parent
bee5bc9075
commit
4142932a83
|
@ -26,6 +26,8 @@ std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
|
|||
std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
|
||||
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
|
||||
|
||||
std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
|
||||
ArrayRef<int64_t> tileSizes = {},
|
||||
linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,
|
||||
|
|
|
@ -100,6 +100,12 @@ def LinalgFoldReshapeOpsByLinearization :
|
|||
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
|
||||
let summary = "Convert from one named linalg op to another.";
|
||||
let constructor = "mlir::createLinalgNamedOpConversionPass()";
|
||||
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
|
||||
}
|
||||
|
||||
def LinalgLowerTiledLoopsToSCF
|
||||
: FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
|
||||
let summary = "Lower linalg tiled loops to SCF loops and parallel loops";
|
||||
|
|
|
@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
|
|||
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Patterns to convert from one named op to another. These can be seen as
|
||||
/// canonicalizations of named ops into another named op.
|
||||
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Populates the given list with patterns to bufferize linalg ops.
|
||||
void populateLinalgBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &converter,
|
||||
|
|
|
@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
}
|
||||
};
|
||||
|
||||
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
|
||||
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
|
||||
}
|
||||
|
||||
LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
|
||||
Value kernel, Value iZp, Value kZp,
|
||||
Value init, Attribute stride,
|
||||
Attribute dilation,
|
||||
PatternRewriter &rewriter) {
|
||||
Location loc = operation->getLoc();
|
||||
auto linalgOp = dyn_cast<LinalgOp>(operation);
|
||||
// Exit out on the memref version of this operation.
|
||||
if (!linalgOp || !linalgOp.hasTensorSemantics())
|
||||
return failure();
|
||||
|
||||
auto result = operation->getResult(0);
|
||||
|
||||
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
|
||||
auto initTy = init.getType().dyn_cast<RankedTensorType>();
|
||||
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!kernelTy || !initTy || !resultTy)
|
||||
return failure();
|
||||
|
||||
if (kernelTy.getDimSize(3) != 1)
|
||||
return failure();
|
||||
|
||||
// Collapse kernel dims.
|
||||
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
|
||||
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
|
||||
auto newKernelTy = RankedTensorType::get(
|
||||
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
|
||||
kernelTy.getElementType());
|
||||
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, newKernelTy, kernel, collapsedKernelDims);
|
||||
|
||||
// Collapse init dims.
|
||||
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
|
||||
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
|
||||
getIndicesVector(3, 5)};
|
||||
auto newInitTy =
|
||||
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
|
||||
initTy.getDimSize(2), initTy.getDimSize(3)},
|
||||
initTy.getElementType());
|
||||
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, newInitTy, init, collapsedInitDims);
|
||||
|
||||
Value newConv;
|
||||
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
|
||||
newConv = rewriter
|
||||
.create<DepthwiseConv2DNhwcHwcOp>(
|
||||
loc, newInitTy, ValueRange{input, collapsedKernel},
|
||||
ValueRange{collapsedInit}, stride, dilation)
|
||||
.getResult(0);
|
||||
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
|
||||
newConv =
|
||||
rewriter
|
||||
.create<DepthwiseConv2DNhwcHwcQOp>(
|
||||
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
|
||||
ValueRange{collapsedInit}, stride, dilation)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
if (!newConv)
|
||||
return failure();
|
||||
|
||||
// Expand dimensions back out to
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
operation, resultTy, newConv, collapsedInitDims);
|
||||
return success();
|
||||
}
|
||||
|
||||
struct SimplifyDepthwiseConvOp
|
||||
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
|
||||
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *operation = op.getOperation();
|
||||
Value input = op.getInputOperand(0)->get();
|
||||
Value kernel = op.getInputOperand(1)->get();
|
||||
Value init = op.getOutputOperand(0)->get();
|
||||
|
||||
auto stride = op.strides();
|
||||
auto dilation = op.dilations();
|
||||
|
||||
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
|
||||
nullptr, init, stride, dilation,
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct SimplifyDepthwiseConvQOp
|
||||
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
|
||||
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *operation = op.getOperation();
|
||||
Value input = op.getInputOperand(0)->get();
|
||||
Value kernel = op.getInputOperand(1)->get();
|
||||
Value iZp = op.getInputOperand(2)->get();
|
||||
Value kZp = op.getInputOperand(3)->get();
|
||||
Value init = op.getOutputOperand(0)->get();
|
||||
|
||||
auto stride = op.strides();
|
||||
auto dilation = op.dilations();
|
||||
|
||||
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
|
||||
init, stride, dilation, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define LINALGOP_FOLDERS(XXX) \
|
||||
|
@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp)
|
|||
|
||||
void LinalgDialect::getCanonicalizationPatterns(
|
||||
RewritePatternSet &results) const {
|
||||
results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
|
||||
SimplifyDepthwiseConvQOp>(getContext());
|
||||
results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
|
||||
}
|
||||
|
||||
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
|
||||
|
|
|
@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
Interchange.cpp
|
||||
Loops.cpp
|
||||
LinalgStrategyPasses.cpp
|
||||
NamedOpConversions.cpp
|
||||
Promotion.cpp
|
||||
Tiling.cpp
|
||||
Transforms.cpp
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements conversions between named ops that can be seens as
|
||||
// canonicalizations of named ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
|
||||
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
|
||||
Value iZp, Value kZp, Value init, Attribute stride,
|
||||
Attribute dilation, PatternRewriter &rewriter) {
|
||||
Location loc = operation->getLoc();
|
||||
auto linalgOp = dyn_cast<LinalgOp>(operation);
|
||||
// Exit out on the memref version of this operation.
|
||||
if (!linalgOp || !linalgOp.hasTensorSemantics())
|
||||
return failure();
|
||||
|
||||
auto result = operation->getResult(0);
|
||||
|
||||
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
|
||||
auto initTy = init.getType().dyn_cast<RankedTensorType>();
|
||||
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!kernelTy || !initTy || !resultTy)
|
||||
return failure();
|
||||
|
||||
if (kernelTy.getDimSize(3) != 1)
|
||||
return failure();
|
||||
|
||||
// Collapse kernel dims.
|
||||
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
|
||||
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
|
||||
auto newKernelTy = RankedTensorType::get(
|
||||
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
|
||||
kernelTy.getElementType());
|
||||
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, newKernelTy, kernel, collapsedKernelDims);
|
||||
|
||||
// Collapse init dims.
|
||||
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
|
||||
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
|
||||
getIndicesVector(3, 5)};
|
||||
auto newInitTy =
|
||||
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
|
||||
initTy.getDimSize(2), initTy.getDimSize(3)},
|
||||
initTy.getElementType());
|
||||
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
|
||||
loc, newInitTy, init, collapsedInitDims);
|
||||
|
||||
Value newConv;
|
||||
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
|
||||
newConv = rewriter
|
||||
.create<DepthwiseConv2DNhwcHwcOp>(
|
||||
loc, newInitTy, ValueRange{input, collapsedKernel},
|
||||
ValueRange{collapsedInit}, stride, dilation)
|
||||
.getResult(0);
|
||||
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
|
||||
newConv =
|
||||
rewriter
|
||||
.create<DepthwiseConv2DNhwcHwcQOp>(
|
||||
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
|
||||
ValueRange{collapsedInit}, stride, dilation)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
if (!newConv)
|
||||
return failure();
|
||||
|
||||
// Expand dimensions back out to
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||
operation, resultTy, newConv, collapsedInitDims);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct SimplifyDepthwiseConvOp
|
||||
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
|
||||
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *operation = op.getOperation();
|
||||
Value input = op.getInputOperand(0)->get();
|
||||
Value kernel = op.getInputOperand(1)->get();
|
||||
Value init = op.getOutputOperand(0)->get();
|
||||
|
||||
auto stride = op.strides();
|
||||
auto dilation = op.dilations();
|
||||
|
||||
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
|
||||
nullptr, init, stride, dilation,
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct SimplifyDepthwiseConvQOp
|
||||
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
|
||||
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *operation = op.getOperation();
|
||||
Value input = op.getInputOperand(0)->get();
|
||||
Value kernel = op.getInputOperand(1)->get();
|
||||
Value iZp = op.getInputOperand(2)->get();
|
||||
Value kZp = op.getInputOperand(3)->get();
|
||||
Value init = op.getOutputOperand(0)->get();
|
||||
|
||||
auto stride = op.strides();
|
||||
auto dilation = op.dilations();
|
||||
|
||||
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
|
||||
init, stride, dilation, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
struct LinalgNamedOpConversionPass
|
||||
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
|
||||
LinalgNamedOpConversionPass() = default;
|
||||
LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
populateLinalgNamedOpConversionPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::linalg::populateLinalgNamedOpConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
|
||||
return std::make_unique<LinalgNamedOpConversionPass>();
|
||||
}
|
|
@ -38,6 +38,10 @@ namespace memref {
|
|||
class MemRefDialect;
|
||||
} // namespace memref
|
||||
|
||||
namespace tensor {
|
||||
class TensorDialect;
|
||||
} // namespace tensor
|
||||
|
||||
namespace vector {
|
||||
class VectorDialect;
|
||||
} // namespace vector
|
||||
|
|
|
@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
|
|||
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
|
||||
return %r2 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv
|
||||
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
|
||||
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
|
||||
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
|
||||
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
|
||||
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
|
||||
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
|
||||
return %0 : tensor<?x?x?x?x1xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv_q
|
||||
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
|
||||
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
|
||||
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
|
||||
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
|
||||
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
|
||||
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
|
||||
return %0 : tensor<?x?x?x?x1xi32>
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv
|
||||
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
|
||||
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
|
||||
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
|
||||
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
|
||||
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
|
||||
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
|
||||
return %0 : tensor<?x?x?x?x1xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @depthwise_conv_q
|
||||
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
|
||||
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
|
||||
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
|
||||
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
|
||||
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
|
||||
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
|
||||
return %0 : tensor<?x?x?x?x1xi32>
|
||||
}
|
Loading…
Reference in New Issue