forked from OSchip/llvm-project
Remove useless nesting blok and dead return statement in TosaToLinalg.cpp (NFC)
Flagged by Coverity.
This commit is contained in:
parent
a1e62aa75b
commit
e4e463e747
|
@ -1431,248 +1431,240 @@ public:
|
|||
getNParallelLoopsAttrs(resultTy.getRank()));
|
||||
rewriter.replaceOp(op, genericOp.getResult(0));
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard regionGuard(rewriter);
|
||||
rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
|
||||
TypeRange({resultElementTy}));
|
||||
Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
|
||||
Value y = rewriter.create<linalg::IndexOp>(loc, 1);
|
||||
Value x = rewriter.create<linalg::IndexOp>(loc, 2);
|
||||
Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
|
||||
OpBuilder::InsertionGuard regionGuard(rewriter);
|
||||
rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
|
||||
TypeRange({resultElementTy}));
|
||||
Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
|
||||
Value y = rewriter.create<linalg::IndexOp>(loc, 1);
|
||||
Value x = rewriter.create<linalg::IndexOp>(loc, 2);
|
||||
Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
|
||||
|
||||
auto hwMin = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(0));
|
||||
auto hMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(imageH - 1));
|
||||
auto wMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(imageW - 1));
|
||||
auto hwMin =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
|
||||
auto hMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(imageH - 1));
|
||||
auto wMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(imageW - 1));
|
||||
|
||||
Value inY =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
|
||||
Value inX =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
|
||||
Value inY =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
|
||||
Value inX =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
|
||||
|
||||
int32_t shift = op.shift();
|
||||
bool floatingPointMode = shift == 0;
|
||||
int32_t shift = op.shift();
|
||||
bool floatingPointMode = shift == 0;
|
||||
|
||||
Value yStride, xStride, yOffset, xOffset;
|
||||
if (floatingPointMode) {
|
||||
yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
|
||||
xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
|
||||
yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
|
||||
xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
|
||||
} else {
|
||||
SmallVector<int32_t> stride, offset;
|
||||
getValuesFromIntArrayAttribute(op.stride(), stride);
|
||||
getValuesFromIntArrayAttribute(op.offset(), offset);
|
||||
Value yStride, xStride, yOffset, xOffset;
|
||||
if (floatingPointMode) {
|
||||
yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
|
||||
xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
|
||||
yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
|
||||
xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
|
||||
} else {
|
||||
SmallVector<int32_t> stride, offset;
|
||||
getValuesFromIntArrayAttribute(op.stride(), stride);
|
||||
getValuesFromIntArrayAttribute(op.offset(), offset);
|
||||
|
||||
yStride = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(stride[0]));
|
||||
xStride = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(stride[1]));
|
||||
yOffset = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(offset[0]));
|
||||
xOffset = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(offset[1]));
|
||||
}
|
||||
|
||||
// Compute the the integer index and partial offset.
|
||||
// x = x * stride + offset;
|
||||
// ix = floor(x)
|
||||
// dx = x - ix
|
||||
Value ix, iy, dx, dy;
|
||||
if (floatingPointMode) {
|
||||
Value y =
|
||||
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
|
||||
Value x =
|
||||
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
|
||||
|
||||
y = rewriter.create<arith::MulFOp>(loc, y, yStride);
|
||||
x = rewriter.create<arith::MulFOp>(loc, x, xStride);
|
||||
|
||||
y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
|
||||
x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
|
||||
|
||||
iy = rewriter.create<math::FloorOp>(loc, y);
|
||||
ix = rewriter.create<math::FloorOp>(loc, x);
|
||||
|
||||
dy = rewriter.create<arith::SubFOp>(loc, y, iy);
|
||||
dx = rewriter.create<arith::SubFOp>(loc, x, ix);
|
||||
|
||||
iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
|
||||
ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
|
||||
} else {
|
||||
Value shiftVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(shift));
|
||||
|
||||
Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
|
||||
Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
|
||||
|
||||
y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
|
||||
x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
|
||||
|
||||
iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
|
||||
ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
|
||||
|
||||
Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
|
||||
Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
|
||||
|
||||
dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
|
||||
dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
|
||||
}
|
||||
|
||||
if (op.mode() == "NEAREST_NEIGHBOR") {
|
||||
Value yPred, xPred;
|
||||
// Round the index position towards the closest pixel location.
|
||||
if (floatingPointMode) {
|
||||
auto halfVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(0.5f));
|
||||
yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
|
||||
dy, halfVal);
|
||||
xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
|
||||
dx, halfVal);
|
||||
} else {
|
||||
auto halfVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
|
||||
yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
dy, halfVal);
|
||||
xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
dx, halfVal);
|
||||
}
|
||||
|
||||
auto zeroVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(0));
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1));
|
||||
|
||||
auto yOffset =
|
||||
rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
|
||||
auto xOffset =
|
||||
rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);
|
||||
|
||||
iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
|
||||
ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
|
||||
|
||||
// Clamp the to be within the bounds of the input image.
|
||||
|
||||
iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
// Read the value from the input array.
|
||||
iy = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
iy);
|
||||
ix = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
ix);
|
||||
|
||||
Value result = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, iy, ix, channel});
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
if (op.mode() == "BILINEAR") {
|
||||
Value y0 = iy;
|
||||
Value x0 = ix;
|
||||
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1));
|
||||
Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
|
||||
Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
|
||||
|
||||
y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
y0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
y0);
|
||||
y1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
y1);
|
||||
x0 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
x0);
|
||||
x1 = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||
x1);
|
||||
|
||||
Value y0x0 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y0, x0, channel});
|
||||
Value y0x1 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y0, x1, channel});
|
||||
Value y1x0 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y1, x0, channel});
|
||||
Value y1x1 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y1, x1, channel});
|
||||
|
||||
if (floatingPointMode) {
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(1.f));
|
||||
Value rightPart = dx;
|
||||
Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
|
||||
|
||||
y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
|
||||
y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
|
||||
Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
|
||||
|
||||
y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
|
||||
y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
|
||||
Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
|
||||
|
||||
Value bottomPart = dy;
|
||||
Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
|
||||
topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
|
||||
bottomAcc =
|
||||
rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
|
||||
Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
return success();
|
||||
}
|
||||
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
|
||||
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
|
||||
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
|
||||
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
|
||||
|
||||
if (resultElementTy.getIntOrFloatBitWidth() > 32) {
|
||||
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
|
||||
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
|
||||
}
|
||||
|
||||
auto unitVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
|
||||
Value rightPart = dx;
|
||||
Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
|
||||
|
||||
y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
|
||||
y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
|
||||
Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
|
||||
|
||||
y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
|
||||
y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
|
||||
Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
|
||||
|
||||
Value bottomPart = dy;
|
||||
Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
|
||||
topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
|
||||
bottomAcc =
|
||||
rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
|
||||
Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
yStride = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(stride[0]));
|
||||
xStride = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(stride[1]));
|
||||
yOffset = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(offset[0]));
|
||||
xOffset = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(offset[1]));
|
||||
}
|
||||
|
||||
return success();
|
||||
// Compute the the integer index and partial offset.
|
||||
// x = x * stride + offset;
|
||||
// ix = floor(x)
|
||||
// dx = x - ix
|
||||
Value ix, iy, dx, dy;
|
||||
if (floatingPointMode) {
|
||||
Value y =
|
||||
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
|
||||
Value x =
|
||||
rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
|
||||
|
||||
y = rewriter.create<arith::MulFOp>(loc, y, yStride);
|
||||
x = rewriter.create<arith::MulFOp>(loc, x, xStride);
|
||||
|
||||
y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
|
||||
x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
|
||||
|
||||
iy = rewriter.create<math::FloorOp>(loc, y);
|
||||
ix = rewriter.create<math::FloorOp>(loc, x);
|
||||
|
||||
dy = rewriter.create<arith::SubFOp>(loc, y, iy);
|
||||
dx = rewriter.create<arith::SubFOp>(loc, x, ix);
|
||||
|
||||
iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
|
||||
ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
|
||||
} else {
|
||||
Value shiftVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(shift));
|
||||
|
||||
Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
|
||||
Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
|
||||
|
||||
y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
|
||||
x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
|
||||
|
||||
iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
|
||||
ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
|
||||
|
||||
Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
|
||||
Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
|
||||
|
||||
dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
|
||||
dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
|
||||
}
|
||||
|
||||
if (op.mode() == "NEAREST_NEIGHBOR") {
|
||||
Value yPred, xPred;
|
||||
// Round the index position towards the closest pixel location.
|
||||
if (floatingPointMode) {
|
||||
auto halfVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(0.5f));
|
||||
yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
|
||||
dy, halfVal);
|
||||
xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
|
||||
dx, halfVal);
|
||||
} else {
|
||||
auto halfVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
|
||||
yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
dy, halfVal);
|
||||
xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
dx, halfVal);
|
||||
}
|
||||
|
||||
auto zeroVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(0));
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1));
|
||||
|
||||
auto yOffset =
|
||||
rewriter.create<mlir::SelectOp>(loc, yPred, oneVal, zeroVal);
|
||||
auto xOffset =
|
||||
rewriter.create<mlir::SelectOp>(loc, xPred, oneVal, zeroVal);
|
||||
|
||||
iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
|
||||
ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
|
||||
|
||||
// Clamp the to be within the bounds of the input image.
|
||||
|
||||
iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
// Read the value from the input array.
|
||||
iy =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
|
||||
ix =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
|
||||
|
||||
Value result = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, iy, ix, channel});
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
if (op.mode() == "BILINEAR") {
|
||||
Value y0 = iy;
|
||||
Value x0 = ix;
|
||||
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(1));
|
||||
Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
|
||||
Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
|
||||
|
||||
y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
|
||||
arith::CmpIPredicate::slt, rewriter);
|
||||
|
||||
y0 =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
|
||||
y1 =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
|
||||
x0 =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
|
||||
x1 =
|
||||
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
|
||||
|
||||
Value y0x0 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y0, x0, channel});
|
||||
Value y0x1 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y0, x1, channel});
|
||||
Value y1x0 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y1, x0, channel});
|
||||
Value y1x1 = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{batch, y1, x1, channel});
|
||||
|
||||
if (floatingPointMode) {
|
||||
auto oneVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(1.f));
|
||||
Value rightPart = dx;
|
||||
Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
|
||||
|
||||
y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
|
||||
y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
|
||||
Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
|
||||
|
||||
y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
|
||||
y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
|
||||
Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
|
||||
|
||||
Value bottomPart = dy;
|
||||
Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
|
||||
topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
|
||||
bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
|
||||
Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
return success();
|
||||
}
|
||||
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
|
||||
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
|
||||
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
|
||||
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
|
||||
|
||||
if (resultElementTy.getIntOrFloatBitWidth() > 32) {
|
||||
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
|
||||
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
|
||||
}
|
||||
|
||||
auto unitVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
|
||||
Value rightPart = dx;
|
||||
Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
|
||||
|
||||
y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
|
||||
y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
|
||||
Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
|
||||
|
||||
y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
|
||||
y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
|
||||
Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
|
||||
|
||||
Value bottomPart = dy;
|
||||
Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
|
||||
topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
|
||||
bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
|
||||
Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
|
||||
|
||||
rewriter.create<linalg::YieldOp>(loc, result);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue