[mlir][tosa] Migrate tosa to more efficient linalg.conv

Existing linalg.conv2d is not well optimized for performance. Changed to a
version that is more aligned for optimziation. Include the corresponding
transposes to use this optimized version.

This also splits the conv and depthwise conv into separate implementations
to avoid overly complex lowerings.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D107504
This commit is contained in:
Rob Suderman 2021-08-11 11:05:08 -07:00
parent c1a8f12873
commit 7de439b2be
5 changed files with 388 additions and 329 deletions

View File

@ -628,10 +628,10 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_input_nhwc_filter_ohwi_poly
cpp_class_name: Conv2DInputNhwcFilterOhwiPolyOp
name: conv_2d_nchw
cpp_class_name: Conv2DNchwOp
doc: |-
Performs a 2-D convolution.
Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
@ -648,13 +648,13 @@ structured_op: !LinalgStructuredOpConfig
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s4, s5, s6, s3)>
-> (s4, s1, s5, s6)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s7, s8, s4)>
-> (s0, s4, s7, s8, s1)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
@ -670,19 +670,19 @@ structured_op: !LinalgStructuredOpConfig
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d5, d3, d4, d6)>
s9, s10, s11, s12] -> (d1, d4, d5, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1, d2, d5)>
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
iterator_types:
- parallel
- parallel
- parallel
- reduction
- reduction
- parallel
- reduction
- reduction
- reduction
assignments:
- !ScalarAssign
arg: O
@ -710,14 +710,13 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: K
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_input_nhwc_filter_ohwi_poly_q
cpp_class_name: Conv2DInputNhwcFilterOhwiPolyQOp
name: conv_2d_nhwc_hwcf
cpp_class_name: Conv2DNhwcHwcfOp
doc: |-
Performs a 2-D quantized convolution.
Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Includes zero point
adjustment for quantization.
them to the same data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@ -731,21 +730,13 @@ structured_op: !LinalgStructuredOpConfig
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s4, s5, s6, s3)>
- !LinalgOperandDefConfig
name: IZp
usage: InputOperand
type_var: I32
- !LinalgOperandDefConfig
name: KZp
usage: InputOperand
type_var: I32
-> (s4, s5, s3, s6)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s7, s8, s4)>
-> (s0, s7, s8, s6)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
@ -761,23 +752,19 @@ structured_op: !LinalgStructuredOpConfig
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
s9, s10, s11, s12] -> (d0, d1 * s9 + d4 * s11, d2 * s10 + d5 * s12, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d5, d3, d4, d6)>
s9, s10, s11, s12] -> (d4, d5, d6, d3)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1, d2, d5)>
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
iterator_types:
- parallel
- parallel
- parallel
- reduction
- reduction
- parallel
- reduction
- reduction
- reduction
assignments:
- !ScalarAssign
arg: O
@ -791,38 +778,18 @@ structured_op: !LinalgStructuredOpConfig
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: sub
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_apply:
fn_name: sub
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
@ -906,7 +873,122 @@ structured_op: !LinalgStructuredOpConfig
scalar_arg: K
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2D_nchw
name: conv_2d_nhwc_hwcf_q
cpp_class_name: Conv2DNhwcHwcfQOp
doc: |-
Performs 2-D convolution with zero point offsets.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s1, s2, s3)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s4, s5, s3, s6)>
- !LinalgOperandDefConfig
name: IZp
usage: InputOperand
type_var: I32
- !LinalgOperandDefConfig
name: KZp
usage: InputOperand
type_var: I32
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s7, s8, s6)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12] -> (s9, s10)>
- !LinalgOperandDefConfig
name: dilations
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12] -> (s11, s12)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1 * s9 + d4 * s11, d2 * s10 + d5 * s12, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d4, d5, d6, d3)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
iterator_types:
- parallel
- parallel
- parallel
- parallel
- reduction
- reduction
- reduction
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
scalar_apply:
fn_name: sub
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_apply:
fn_name: sub
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv2D_nchw
cpp_class_name: DepthwiseConv2DNchwOp
doc: |-
Performs depth-wise 2-D convolution.
@ -1101,88 +1183,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw
cpp_class_name: Conv2DNchwOp
doc: |-
Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s1, s2, s3)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s4, s1, s5, s6)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
-> (s0, s4, s7, s8, s1)>
- !LinalgOperandDefConfig
name: strides
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12] -> (s9, s10)>
- !LinalgOperandDefConfig
name: dilations
usage: IndexAttribute
type_var: I64
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12] -> (s11, s12)>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d1, d4, d5, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
iterator_types:
- parallel
- parallel
- parallel
- parallel
- reduction
- reduction
- reduction
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_apply:
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_apply:
fn_name: mul
operands:
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- !ScalarExpression
symbolic_cast:
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: pooling_nhwc_sum
cpp_class_name: PoolingNhwcSumOp
@ -1896,3 +1896,4 @@ structured_op: !LinalgStructuredOpConfig
operands:
- !ScalarExpression
scalar_arg: I

View File

@ -849,9 +849,136 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
return success();
}
static LogicalResult
convolutionMatchAndRewriterHelper(Operation *op,
ConversionPatternRewriter &rewriter) {
namespace {
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const final {
return elementwiseMatchAndRewriteHelper(op, rewriter);
}
};
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
ShapedType inputTy = input.getType().cast<ShapedType>();
ShapedType weightTy = weight.getType().cast<ShapedType>();
ShapedType biasTy = bias.getType().cast<ShapedType>();
ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
bool isQuantized = op->hasAttr("quantization_info");
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
return rewriter.notifyMatchFailure(op,
"tosa.conv ops require static shapes");
auto weightShape = weightTy.getShape();
// Apply padding as necessary.
Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
getValuesFromIntArrayAttribute(padAttr, pad);
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
SmallVector<int64_t> weightPerm{1, 2, 3, 0};
SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
weightShape[3], weightShape[0]};
auto weightPermAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
Value weightPermValue = rewriter.create<ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
// Broadcast the initial value to the output tensor before convolving.
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultETy);
Value biasBroadcast =
rewriter
.create<linalg::GenericOp>(
loc, resultTy, bias, initTensor, indexingMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);
// Extract the attributes for convolution.
llvm::SmallVector<int64_t> stride, dilation;
getValuesFromIntArrayAttribute(strideTosaAttr, stride);
getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);
// Create the convolution op.
auto strideAttr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI64Type()), stride);
auto dilationAttr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
Value conv;
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
auto iZp = rewriter.getI32IntegerAttr(
quantizationInfo.input_zp().getValue().getSExtValue());
auto kZp = rewriter.getI32IntegerAttr(
quantizationInfo.weight_zp().getValue().getSExtValue());
auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(
op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
ValueRange{biasBroadcast}, strideAttr, dilationAttr);
return success();
}
rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(
op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
strideAttr, dilationAttr);
return success();
}
};
class DepthwiseConvConverter
: public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tosa::DepthwiseConv2DOp op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
@ -905,8 +1032,8 @@ convolutionMatchAndRewriterHelper(Operation *op,
{rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultTy.getElementType());
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, resultETy);
Value biasBroadcast =
rewriter
@ -929,24 +1056,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
RankedTensorType::get({2}, rewriter.getI64Type()), stride);
auto dilationAttr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
strideAttr, dilationAttr);
return success();
}
if (isa<tosa::Conv2DOp>(op) && isQuantized) {
auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
ValueRange{biasBroadcast}, strideAttr, dilationAttr);
return success();
}
if (isa<tosa::DepthwiseConv2DOp>(op)) {
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
@ -976,32 +1085,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
rewriter.replaceOp(op, reshape);
return success();
}
return failure();
}
namespace {
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const final {
return elementwiseMatchAndRewriteHelper(op, rewriter);
}
};
template <typename T>
class ConvConverter : public OpConversionPattern<T> {
public:
using OpConversionPattern<T>::OpConversionPattern;
LogicalResult
matchAndRewrite(T op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
return convolutionMatchAndRewriterHelper(op, rewriter);
}
};
class TransposeConvConverter
@ -2528,8 +2611,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
ReduceConverter<tosa::ReduceProdOp>,
ArgMaxConverter,
ConcatConverter,
ConvConverter<tosa::Conv2DOp>,
ConvConverter<tosa::DepthwiseConv2DOp>,
ConvConverter,
DepthwiseConvConverter,
TransposeConvConverter,
GatherConverter,
PadConverter,

View File

@ -144,49 +144,39 @@ def dot(
implements(ContractionOpInterface)
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
@linalg_structured_op
def conv_2d_input_nhwc_filter_ohwi_poly(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
def conv_2d_nchw(
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs a 2-D convolution.
"""Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
O[D.n, D.oh, D.ow, D.oc] += cast(
U, I[D.n,
D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,
D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic])
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
@linalg_structured_op
def conv_2d_input_nhwc_filter_ohwi_poly_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
def conv_2d_nhwc_hwcf(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs a 2-D quantized convolution.
"""Performs 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Includes zero point
adjustment for quantization.
them to the same data type as the accumulator/output.
"""
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
O[D.n, D.oh, D.ow, D.oc] += ((cast(
U, I[D.n,
D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,
D.ic]) - cast(U, IZp)) *
(cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp)))
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
]) * cast(U, K[D.kh, D.kw, D.c, D.f])
@linalg_structured_op
def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
@ -206,24 +196,27 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
D.c]) * cast(U, K[D.kh, D.kw, D.c])
@linalg_structured_op
def conv_2d_nchw(
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
def conv_2d_nhwc_hwcf_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
"""Performs 2-D convolution.
"""Performs 2-D convolution with zero point offsets.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
"""
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += (cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp))
def depthwise_conv2D_nchw( #TODO: Fix name
@linalg_structured_op
def depthwise_conv2D_nchw(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
@ -239,8 +232,8 @@ def depthwise_conv2D_nchw( #TODO: Fix name
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
def depthwise_conv2D_nchw_q( #TODO: Fix name
@linalg_structured_op
def depthwise_conv2D_nchw_q(
I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
IZp=ScalarDef(I32),

View File

@ -1176,14 +1176,19 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
// -----
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-LABEL: @conv2d_f32
// CHECK-LABEL @conv2d_f32
func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28]
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
// CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
// CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
// CHECK: linalg.yield %arg3 : f32
// CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: linalg.yield %arg3 : f32
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
%0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
return
}
@ -1192,26 +1197,17 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
// CHECK-LABEL: @conv2d_padded_f32
func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: linalg.pad_tensor %arg0
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly
// CHECK: linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: linalg.conv_2d_nhwc_hwcf
%0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
return
}
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @conv2d_quant
func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>, %arg2 : tensor<1024xi32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 10, 10, 1024]
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1024xi32>) outs(%[[INIT]] : tensor<1x10x10x1024xi32>)
// CHECK: ^bb0(%arg3: i32, %arg4: i32):
// CHECK: linalg.yield %arg3 : i32
// CHECK: %[[C128:.+]] = constant -128
// CHECK: %[[C42:.+]] = constant 42
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>)
// CHECK: linalg.conv_2d_nhwc_hwcf_q
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32>
return
}
@ -1229,7 +1225,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: linalg.yield %arg3 : f32
// CHECK: } -> tensor<1x5x5x33xf32>
// CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]]
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
return
@ -1260,8 +1256,8 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
// CHECK-LABEL: @transpose_conv
func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x16x16x2xf32>, tensor<4x3x3x2xf32>)
// CHECK: linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
// CHECK: linalg.conv_2d_nhwc_hwcf
%0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 14, 14, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x14x14x4xf32>
return
}
@ -1271,7 +1267,7 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>,
// CHECK-LABEL: @transpose_conv_dilated
func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
// CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
// CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
// CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<3x3x2x4xf32>)
%0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
return
}

View File

@ -1,19 +1,5 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor
func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor(%input: tensor<2x4x5x3xi8>, %filter: tensor<2x2x2x3xi8>) -> tensor<2x3x4x2xi32> {
%zero = constant 0 : i32
%init = linalg.init_tensor [2, 3, 4, 2] : tensor<2x3x4x2xi32>
%fill = linalg.fill(%zero, %init) : i32, tensor<2x3x4x2xi32> -> tensor<2x3x4x2xi32>
%c128 = constant -128 : i32
%c42 = constant 42 : i32
%0 = linalg.conv_2d_input_nhwc_filter_ohwi_poly_q
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter, %c128, %c42 : tensor<2x4x5x3xi8>, tensor<2x2x2x3xi8>, i32, i32)
outs(%fill : tensor<2x3x4x2xi32>) -> tensor<2x3x4x2xi32>
return %0 : tensor<2x3x4x2xi32>
}
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor
func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
%zero = constant 0.000000e+00 : f32