!43525 [MSLITE] Support dynamic stride slice and slice fusion

Merge pull request !43525 from zhangyongxian/dev_zhangyongxian_slice
i-robot 2022-10-10 12:11:28 +00:00 committed by Gitee
commit 84a688bd29
7 changed files with 495 additions and 5 deletions

src/litert/delegate/tensorrt/op/elementwise_tensorrt.cc

@@ -266,7 +266,9 @@ int ElementWiseTensorRT::AddConstTensor(TensorRTContext *ctx) {
   nvinfer1::ITensor *constant_input =
     ConvertConstantTensorWithDims(ctx, in_tensors_[const_tensor_index], expect_shape, op_name_);
   CHECK_NULL_RETURN(constant_input);
-  auto const_helper = ITensorHelper{constant_input, input(ctx, 1 - const_tensor_index).format_, true};
+  bool is_tensor = !in_tensors_[const_tensor_index].Shape().empty();
+  auto const_helper = ITensorHelper{constant_input, input(ctx, 1 - const_tensor_index).format_,
+                                    input(ctx, 1 - const_tensor_index).same_format_, is_tensor};
   ctx->RegisterTensor(const_helper, in_tensors_[const_tensor_index].Name());
   return RET_OK;
 }
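The new lines replace a hard-coded true with the producer's same_format_ and a new is_tensor flag, so rank-0 (scalar) constants are no longer registered as full tensors. The aggregate initialization implies an ITensorHelper with at least four members; a rough sketch, with field names inferred from the call sites in this diff rather than from the authoritative definition:

// Sketch only: inferred from the aggregate initializations in this diff.
// mindspore::Format is stood in by a local enum to keep the sketch
// self-contained; the real type lives in the MindSpore Lite headers.
#include <NvInfer.h>

namespace ms_sketch { enum class Format { NHWC, NCHW }; }

struct ITensorHelper {
  nvinfer1::ITensor *trt_tensor_{nullptr};
  ms_sketch::Format format_{ms_sketch::Format::NHWC};  // layout of trt_tensor_
  bool same_format_{true};  // layout already matches the surrounding graph
  bool is_tensor_{true};    // false for rank-0 (scalar) constants
};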

src/litert/delegate/tensorrt/op/slice_tensorrt.cc

@@ -293,7 +293,5 @@ int SliceTensorRT::AddInnerOp(TensorRTContext *ctx) {
   MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper);
   return RET_OK;
 }
-REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_StridedSlice, SliceTensorRT)
-REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_SliceFusion, SliceTensorRT)
 REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Crop, SliceTensorRT)
 }  // namespace mindspore::lite
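The StridedSlice and SliceFusion registrations move out of SliceTensorRT and into the two dedicated ops added below; SliceTensorRT keeps only Crop. For readers unfamiliar with the mechanism, REGISTER_TENSORRT_CREATOR ties a primitive type to an op constructor in a static registry. A minimal sketch of that general pattern (illustrative only, not the delegate's actual macro or factory):

// Illustrative creator-registry pattern; the real REGISTER_TENSORRT_CREATOR
// macro and op factory live elsewhere in the TensorRT delegate sources.
#include <functional>
#include <map>

enum class PrimitiveType { kStridedSlice, kSliceFusion, kCrop };
struct Op {};  // stand-in for TensorRTOp

std::map<PrimitiveType, std::function<Op *()>> &Registry() {
  static std::map<PrimitiveType, std::function<Op *()>> registry;
  return registry;
}

struct Registrar {
  Registrar(PrimitiveType type, std::function<Op *()> creator) {
    Registry()[type] = std::move(creator);
  }
};

// A file-scope Registrar runs at static-initialization time, so linking the
// translation unit is enough to make the op discoverable by primitive type.
static Registrar g_slice_fusion_reg(PrimitiveType::kSliceFusion, [] { return new Op(); });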

src/litert/delegate/tensorrt/op/slicefusion_tensorrt.cc

@@ -0,0 +1,156 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <utility>
#include "src/litert/delegate/tensorrt/op/slicefusion_tensorrt.h"
#include "src/litert/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
nvinfer1::ITensor *SliceFusionTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input,
const nvinfer1::Dims &start_dims,
const nvinfer1::Dims &size_dims) {
auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0);
if (in_tensor_shape == nullptr) {
MS_LOG(ERROR) << "add shape layer of input failed!";
return nullptr;
}
std::vector<nvinfer1::ITensor *> shape_tensors;
auto input_dims = input->getDimensions();
std::vector<int> input_shape_vec;
for (int i = 0; i != input_dims.nbDims; ++i) {
if (input_dims.d[i] == -1) {
if (!input_shape_vec.empty()) {
shape_tensors.push_back(ctx->ConvertTo1DTensor(input_shape_vec));
input_shape_vec.clear();
}
auto starts = nvinfer1::Dims{1, {i}};
auto size = nvinfer1::Dims{1, {1}};
auto strides = nvinfer1::Dims{1, {1}};
auto slice_layer = ctx->network()->addSlice(*in_tensor_shape, starts, size, strides);
if (slice_layer == nullptr) {
MS_LOG(ERROR) << "add slice layer failed";
return nullptr;
}
auto start_tensor = ctx->ConvertTo1DTensor(start_dims.d[i]);
shape_tensors.push_back(
ctx->network()
->addElementWise(*slice_layer->getOutput(0), *start_tensor, nvinfer1::ElementWiseOperation::kSUB)
->getOutput(0));
} else {
input_shape_vec.push_back(size_dims.d[i]);
}
}
if (!input_shape_vec.empty()) {
shape_tensors.push_back(ctx->ConvertTo1DTensor(input_shape_vec));
}
auto concat_layer = ctx->network()->addConcatenation(shape_tensors.data(), shape_tensors.size());
if (concat_layer == nullptr) {
MS_LOG(ERROR) << "add concat layer failed!";
return nullptr;
}
concat_layer->setAxis(0);
return concat_layer->getOutput(0);
}
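GetDynamicSliceSize stitches the runtime slice size together from mixed sources: every static axis keeps its precomputed size, while every -1 (dynamic) axis contributes shape[i] - start[i], obtained from a Shape layer followed by a one-element Slice and a Sub; the pieces are then concatenated back into one 1-D size tensor. A self-contained host-side illustration of that arithmetic (plain C++; an assumption about what the layer chain evaluates to at runtime, not delegate code):

#include <cassert>
#include <vector>

std::vector<int> BuildSliceSize(const std::vector<int> &runtime_shape,
                                const std::vector<int> &build_time_shape,
                                const std::vector<int> &start,
                                const std::vector<int> &size) {
  std::vector<int> out;
  for (size_t i = 0; i < build_time_shape.size(); ++i) {
    // dynamic axis: remaining extent shape[i] - start[i]; static axis: as given
    out.push_back(build_time_shape[i] == -1 ? runtime_shape[i] - start[i] : size[i]);
  }
  return out;
}

int main() {
  // build-time shape (-1, 224, 3) resolved to (8, 224, 3) at runtime
  auto out = BuildSliceSize({8, 224, 3}, {-1, 224, 3}, {1, 0, 0}, {-1, 224, 3});
  assert((out == std::vector<int>{7, 224, 3}));
  return 0;
}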
int SliceFusionTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != SLICE_INPUT_SIZE) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
dynamic_shape_params_.support_hw_dynamic_ = false;
return RET_OK;
}
int SliceFusionTensorRT::AddInnerOp(TensorRTContext *ctx) {
ITensorHelper slice_input;
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &slice_input);
if (ret != RET_OK || slice_input.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
return RET_ERROR;
}
const auto &begin = in_tensors_.at(1);
const auto &size = in_tensors_.at(SIZE_INDEX);
auto start_dims = lite::ConvertCudaDims(begin.Data().get(), begin.ElementNum());
auto size_dims =
size.Data() == nullptr ? nvinfer1::Dims{0} : lite::ConvertCudaDims(size.Data().get(), size.ElementNum());
nvinfer1::ITensor *size_tensor = nullptr;
for (int i = 0; i != size_dims.nbDims; ++i) {
if (size_dims.d[i] == -1 && !IsDynamicInput(ctx, 0)) {
size_dims.d[i] = slice_input.trt_tensor_->getDimensions().d[i];
}
}
if (IsDynamicInput(ctx, 0)) {
size_tensor = GetDynamicSliceSize(ctx, slice_input.trt_tensor_, start_dims, size_dims);
size_dims = nvinfer1::Dims{-1};
}
if (size.Data() == nullptr) {
size_tensor = input(ctx, INPUT_SIZE2).trt_tensor_;
auto shape_vec_int64 = in_tensors_[0].Shape();
slice_input.trt_tensor_ = ConvertConstantTensor(ctx, in_tensors_[0], op_name_ + "_input");
CHECK_NULL_RETURN(slice_input.trt_tensor_);
std::vector<int> shape_vec_int32;
std::copy(shape_vec_int64.begin(), shape_vec_int64.end(), std::back_inserter(shape_vec_int32));
auto input_shape = ctx->ConvertTo1DTensor(shape_vec_int32);
CHECK_NULL_RETURN(input_shape);
auto minus_one = ctx->ConvertTo1DTensor(-1);
auto eq_minus_one =
ctx->network()->addElementWise(*size_tensor, *minus_one, nvinfer1::ElementWiseOperation::kEQUAL)->getOutput(0);
auto int_eq_minus_one =
TRTTensorCast(ctx, eq_minus_one, nvinfer1::DataType::kINT32, op_name_ + "_cast_int_minus_one");
auto gr_minus_one =
ctx->network()->addElementWise(*size_tensor, *minus_one, nvinfer1::ElementWiseOperation::kGREATER)->getOutput(0);
auto int_gr_minus_one =
TRTTensorCast(ctx, gr_minus_one, nvinfer1::DataType::kINT32, op_name_ + "_cast_int_ge_minus_one");
auto x = ctx->network()
->addElementWise(*int_gr_minus_one, *size_tensor, nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0);
auto y = ctx->network()
->addElementWise(*int_eq_minus_one, *input_shape, nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0);
size_tensor = ctx->network()->addElementWise(*x, *y, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
size_dims = nvinfer1::Dims{-1};
}
auto stride_dims = lite::ConvertCudaDims(1, begin.ElementNum());
nvinfer1::ISliceLayer *slice_layer =
ctx->network()->addSlice(*slice_input.trt_tensor_, start_dims, size_dims, stride_dims);
if (slice_layer == nullptr) {
MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_;
return RET_ERROR;
}
if (size_tensor != nullptr) {
slice_layer->setInput(INPUT_SIZE2, *size_tensor);
}
this->layer_ = slice_layer;
slice_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0);
auto helper = ITensorHelper{out_tensor, slice_input.format_, slice_input.same_format_};
ctx->RegisterTensor(helper, out_tensors_[0].Name());
MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper);
return RET_OK;
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_SliceFusion, SliceFusionTensorRT)
} // namespace mindspore::lite
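When the size input is itself a runtime tensor (size.Data() == nullptr), a -1 entry means "take the whole axis". The kEQUAL/kGREATER/kPROD/kSUM chain above resolves that without control flow: size = (size > -1) * size + (size == -1) * input_shape. A host-side rendering of the same branch-free select (an illustration under that reading of the layer chain, not delegate code):

#include <cassert>
#include <vector>

// out[i] = size[i] == -1 ? shape[i] : size[i], computed the way the
// elementwise layer chain does: two 0/1 masks, two products, one sum.
std::vector<int> ResolveSize(const std::vector<int> &size, const std::vector<int> &shape) {
  std::vector<int> out(size.size());
  for (size_t i = 0; i < size.size(); ++i) {
    int eq = (size[i] == -1);  // int_eq_minus_one
    int gt = (size[i] > -1);   // int_gr_minus_one
    out[i] = gt * size[i] + eq * shape[i];
  }
  return out;
}

int main() {
  assert((ResolveSize({-1, 4}, {16, 8}) == std::vector<int>{16, 4}));
  return 0;
}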

src/litert/delegate/tensorrt/op/slicefusion_tensorrt.h

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_
#define MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_
#include <string>
#include <vector>
#include <memory>
#include "src/litert/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
constexpr int SIZE_INDEX = 2;
constexpr int SLICE_INPUT_SIZE = 3;
class SliceFusionTensorRT : public TensorRTOp {
public:
SliceFusionTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~SliceFusionTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input,
const nvinfer1::Dims &start_dims, const nvinfer1::Dims &size_dims);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_

src/litert/delegate/tensorrt/op/strideslice_tensorrt.cc

@@ -0,0 +1,231 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <numeric>
#include "src/litert/delegate/tensorrt/op/strideslice_tensorrt.h"
#include "src/litert/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicAxisSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input,
int size_dim, int axis,
nvinfer1::ITensor *size_tensor) {
auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0);
if (in_tensor_shape == nullptr) {
MS_LOG(ERROR) << "add shape layer of input failed!";
return nullptr;
}
auto len_tensor = (size_tensor == nullptr ? ctx->ConvertTo1DTensor(static_cast<int>(size_dim)) : size_tensor);
if (len_tensor == nullptr) {
MS_LOG(ERROR) << "convert 1d tensor failed!";
return nullptr;
}
nvinfer1::ITensor *concat_input_tensors[INPUT_SIZE2];
concat_input_tensors[0] = in_tensor_shape;
concat_input_tensors[1] = len_tensor;
auto concat_layer = ctx->network()->addConcatenation(concat_input_tensors, INPUT_SIZE2);
if (concat_layer == nullptr) {
MS_LOG(ERROR) << "add concat layer failed!";
return nullptr;
}
concat_layer->setAxis(0);
auto shape_and_len = concat_layer->getOutput(0);
if (shape_and_len == nullptr) {
MS_LOG(ERROR) << "get concat layer result failed!";
return nullptr;
}
std::vector<int> gather_slices(input->getDimensions().nbDims);
std::iota(gather_slices.begin(), gather_slices.end(), 0);
gather_slices[axis] = gather_slices.size();
auto gather_slices_tensor = ctx->ConvertTo1DTensor(gather_slices);
nvinfer1::IGatherLayer *gather_layer = ctx->network()->addGather(*shape_and_len, *gather_slices_tensor, 0);
if (gather_layer == nullptr) {
MS_LOG(ERROR) << "add gather layer failed!";
return nullptr;
}
return gather_layer->getOutput(0);
}
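GetDynamicAxisSliceSize replaces a single axis of the runtime shape with a supplied length without any scatter operation: it concatenates the length onto the 1-D shape tensor, then gathers with an index vector that is the identity everywhere except at the target axis, where it points at the appended slot. The same trick on host-side vectors (self-contained sketch of the concat+gather logic):

#include <cassert>
#include <numeric>
#include <vector>

std::vector<int> AxisSliceSize(const std::vector<int> &shape, int axis, int len) {
  std::vector<int> shape_and_len = shape;
  shape_and_len.push_back(len);                       // concat(shape, [len])
  std::vector<int> idx(shape.size());
  std::iota(idx.begin(), idx.end(), 0);               // identity indices
  idx[axis] = static_cast<int>(shape.size());         // redirect axis to len
  std::vector<int> out;
  for (int i : idx) out.push_back(shape_and_len[i]);  // gather
  return out;
}

int main() {
  assert((AxisSliceSize({4, 8, 16}, 1, 5) == std::vector<int>{4, 5, 16}));
  return 0;
}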
nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input,
const nvinfer1::Dims &size_dims) {
auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0);
if (in_tensor_shape == nullptr) {
MS_LOG(ERROR) << "add shape layer of input failed!";
return nullptr;
}
std::vector<int> is_dynamic;
std::vector<int> is_fix;
std::vector<int> size_vec;
for (int i = 0; i != size_dims.nbDims; ++i) {
is_dynamic.push_back(size_dims.d[i] < 0);
is_fix.push_back(size_dims.d[i] >= 0);
size_vec.push_back(size_dims.d[i]);
}
auto is_dynamic_tensor = ctx->ConvertTo1DTensor(is_dynamic);
auto is_fix_tensor = ctx->ConvertTo1DTensor(is_fix);
auto size_tensor = ctx->ConvertTo1DTensor(size_vec);
auto fix_tensor =
ctx->network()->addElementWise(*is_fix_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
auto dynamic_tensor = ctx->network()
->addElementWise(*is_dynamic_tensor, *in_tensor_shape, nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0);
size_tensor =
ctx->network()->addElementWise(*dynamic_tensor, *fix_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
return size_tensor;
}
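This is the same branch-free select used by the SliceFusion op, but with the masks fixed at build time: is_fix and is_dynamic are baked 0/1 constant tensors derived from the sign of each size entry, and size = is_dynamic * shape + is_fix * size substitutes the runtime extent exactly where size_dims was negative. Condensed host-side equivalent (illustration only):

#include <cassert>
#include <vector>

// out[i] = size[i] < 0 ? shape[i] : size[i], with the 0/1 masks precomputed,
// matching the two kPROD layers and the kSUM layer above.
std::vector<int> MaskedSize(const std::vector<int> &size, const std::vector<int> &shape) {
  std::vector<int> out(size.size());
  for (size_t i = 0; i < size.size(); ++i) {
    int is_dynamic = size[i] < 0;
    int is_fix = size[i] >= 0;
    out[i] = is_dynamic * shape[i] + is_fix * size[i];
  }
  return out;
}

int main() {
  assert((MaskedSize({-1, 4, -1}, {2, 8, 16}) == std::vector<int>{2, 4, 16}));
  return 0;
}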
int StrideSliceTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() < HAS_AXIS - 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int StrideSliceTensorRT::ComputeSliceDims(TensorRTContext *ctx, ITensorHelper *slice_input) {
shrink_axis_ = op_primitive_->value_as_StridedSlice()->shrink_axis_mask();
size_t start_mask = op_primitive_->value_as_StridedSlice()->begin_mask();
size_t end_mask = op_primitive_->value_as_StridedSlice()->end_mask();
const mindspore::MSTensor &begin = in_tensors_.at(BEGINS_INDEX);
const mindspore::MSTensor &stride = in_tensors_.back();
const mindspore::MSTensor &end = in_tensors_.at(ENDS_INDEX);
auto input_dims = slice_input->trt_tensor_->getDimensions();
size_t axis_index = in_tensors_.size() == HAS_AXIS ? AXIS_INDEX : -1;
if (static_cast<size_t>(begin.ElementNum()) == slice_input->trt_tensor_->getDimensions().nbDims) {
start_dims_ = lite::ConvertCudaDims(begin.Data().get(), begin.ElementNum());
auto end_dims = lite::ConvertCudaDims(end.Data().get(), end.ElementNum());
size_dims_.nbDims = input_dims.nbDims;
for (int i = 0; i < size_dims_.nbDims; i++) {
size_t mask = 1 << i;
start_dims_.d[i] = ((start_mask & mask) == 0 ? start_dims_.d[i] : 0);
if (end.Data() == nullptr) {
continue;
}
end_dims.d[i] = ((end_mask & mask) == 0 ? end_dims.d[i] : slice_input->trt_tensor_->getDimensions().d[i]);
size_dims_.d[i] = end_dims.d[i] - start_dims_.d[i];
}
stride_dims_ = lite::ConvertCudaDims(stride.Data().get(), stride.ElementNum());
if (IsDynamicInput(ctx, 0)) {
size_tensor_ = GetDynamicSliceSize(ctx, slice_input->trt_tensor_, size_dims_);
size_dims_ = nvinfer1::Dims{-1};
}
} else {
if (axis_index == -1 || in_tensors_.at(axis_index).ElementNum() != 1) {
MS_LOG(ERROR) << "invalid input params for " << op_name_;
return RET_ERROR;
}
int axis_value = *(static_cast<const int *>(in_tensors_.at(axis_index).Data().get()));
int start_value = *(static_cast<const int *>(begin.Data().get()));
int stride_value = *(static_cast<const int *>(stride.Data().get()));
start_dims_.nbDims = input_dims.nbDims;
std::fill(start_dims_.d, start_dims_.d + start_dims_.nbDims, 0);
stride_dims_.nbDims = input_dims.nbDims;
std::fill(stride_dims_.d, stride_dims_.d + stride_dims_.nbDims, 1);
size_dims_ = slice_input->trt_tensor_->getDimensions();
if (start_value < 0) {
start_value = input_dims.d[axis_value] + start_value;
}
for (int i = 0; i < start_dims_.nbDims; i++) {
if (i == axis_value) {
start_dims_.d[i] = start_value;
stride_dims_.d[i] = stride_value;
if (end.Data() != nullptr) {
int end_value = *(static_cast<const int *>(end.Data().get()));
if (end_value >= 0) {
size_dims_.d[i] = std::min(end_value, input_dims.d[i]) - start_dims_.d[i];
} else if (end_value >= -input_dims.d[i]) {
size_dims_.d[i] = end_value + input_dims.d[i] - start_dims_.d[i];
} else {
size_dims_.d[i] = input_dims.d[i];
}
}
}
}
if (IsDynamicInput(ctx, 0)) {
size_tensor_ =
GetDynamicAxisSliceSize(ctx, slice_input->trt_tensor_, size_dims_.d[axis_value], axis_value, nullptr);
size_dims_ = nvinfer1::Dims{-1};
}
if (end.Data() == nullptr) {
auto start_tensor = ctx->ConvertTo1DTensor(start_value);
auto len_tensor =
ctx->network()
->addElementWise(*input(ctx, INPUT_SIZE2).trt_tensor_, *start_tensor, nvinfer1::ElementWiseOperation::kSUB)
->getOutput(0);
size_tensor_ = GetDynamicAxisSliceSize(ctx, slice_input->trt_tensor_, -1, axis_value, len_tensor);
size_dims_ = nvinfer1::Dims{-1};
}
}
return RET_OK;
}
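In the full-rank branch above, begin_mask and end_mask are per-axis bit fields: bit i set in begin_mask forces the start of axis i to 0, bit i set in end_mask snaps the end of axis i to the full extent, and the slice size is then end - start. A runnable toy version of that bookkeeping (assuming non-negative begin/end and unit strides, which is all the masking itself touches):

#include <cassert>
#include <vector>

std::vector<int> SliceSizes(std::vector<int> begin, std::vector<int> end,
                            const std::vector<int> &dims,
                            unsigned begin_mask, unsigned end_mask) {
  std::vector<int> size(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    unsigned bit = 1u << i;
    if (begin_mask & bit) begin[i] = 0;    // masked begin: start of axis
    if (end_mask & bit) end[i] = dims[i];  // masked end: full extent
    size[i] = end[i] - begin[i];
  }
  return size;
}

int main() {
  // begin_mask bit 1 set: axis 1 starts at 0 despite begin = (1, 2)
  assert((SliceSizes({1, 2}, {3, 4}, {8, 8}, 0b10, 0) == std::vector<int>{2, 4}));
  return 0;
}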
int StrideSliceTensorRT::AddInnerOp(TensorRTContext *ctx) {
ITensorHelper slice_input;
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &slice_input);
if (ret != RET_OK || slice_input.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
return RET_ERROR;
}
if (ComputeSliceDims(ctx, &slice_input) != RET_OK) {
return RET_ERROR;
}
nvinfer1::ISliceLayer *slice_layer =
ctx->network()->addSlice(*slice_input.trt_tensor_, start_dims_, size_dims_, stride_dims_);
if (slice_layer == nullptr) {
MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_;
return RET_ERROR;
}
if (size_tensor_ != nullptr) {
slice_layer->setInput(INPUT_SIZE2, *size_tensor_);
}
this->layer_ = slice_layer;
slice_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0);
auto shape = ConvertMSShape(out_tensor->getDimensions());
bool rank_0 = false;
if (shrink_axis_ != 0) {
for (int i = shape.size() - 1; i >= 0; --i) {
int mask = 1 << i;
if ((shrink_axis_ & mask) != 0) {
shape.erase(shape.begin() + i);
}
}
if (!shape.empty()) {
out_tensor = Reshape(ctx, out_tensor, shape);
} else {
rank_0 = true;
}
}
auto helper = ITensorHelper{out_tensor, slice_input.format_, slice_input.same_format_, !rank_0};
ctx->RegisterTensor(helper, out_tensors_[0].Name());
MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper);
return RET_OK;
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_StridedSlice, StrideSliceTensorRT)
} // namespace mindspore::lite
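The shrink_axis_mask post-processing above erases every axis whose bit is set from the sliced shape (iterating back to front so the indices stay valid) and reshapes the output accordingly; if every axis is erased, the result is rank-0 and the registered helper carries is_tensor = false. Host-side sketch of the erase loop:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ShrinkShape(std::vector<int64_t> shape, unsigned mask) {
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (mask & (1u << i)) shape.erase(shape.begin() + i);  // drop axis i
  }
  return shape;
}

int main() {
  // bits 0 and 2 set: (1, 4, 1) shrinks to (4)
  assert((ShrinkShape({1, 4, 1}, 0b101) == std::vector<int64_t>{4}));
  return 0;
}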

src/litert/delegate/tensorrt/op/strideslice_tensorrt.h

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_
#define MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_
#include <string>
#include <vector>
#include <tuple>
#include <memory>
#include "src/litert/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
constexpr int BEGINS_INDEX = 1;
constexpr int ENDS_INDEX = 2;
constexpr int HAS_AXIS = 5;
constexpr int AXIS_INDEX = 3;
class StrideSliceTensorRT : public TensorRTOp {
public:
StrideSliceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~StrideSliceTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input,
const nvinfer1::Dims &size_dims);
nvinfer1::ITensor *GetDynamicAxisSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, int size_dim, int axis,
nvinfer1::ITensor *size_tensor);
int ComputeSliceDims(TensorRTContext *ctx, ITensorHelper *slice_input);
size_t shrink_axis_;
size_t start_axis_;
size_t end_axis_;
nvinfer1::Dims start_dims_;
nvinfer1::Dims size_dims_;
nvinfer1::Dims stride_dims_;
nvinfer1::ITensor *size_tensor_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_LITERT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_

src/litert/delegate/tensorrt/tensorrt_subgraph.h

@@ -57,10 +57,10 @@ class TensorRTSubGraph : public kernel::Kernel {
     trt_specific_weight_nodes_ = {
       schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_ReduceFusion, schema::PrimitiveType_Transpose,
       schema::PrimitiveType_Gather, schema::PrimitiveType_Reshape, schema::PrimitiveType_MatMulFusion,
-      schema::PrimitiveType_ScaleFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_PadFusion,
+      schema::PrimitiveType_ScaleFusion, schema::PrimitiveType_PadFusion, schema::PrimitiveType_BroadcastTo,
       schema::PrimitiveType_FullConnection, schema::PrimitiveType_Cast, schema::PrimitiveType_ExpandDims,
       schema::PrimitiveType_Resize, schema::PrimitiveType_LSTM, schema::PrimitiveType_LayerNormFusion,
-      schema::PrimitiveType_TopKFusion, schema::PrimitiveType_TileFusion, schema::PrimitiveType_BroadcastTo,
+      schema::PrimitiveType_TopKFusion, schema::PrimitiveType_TileFusion,
     };
     if (!support_resize) {
       input_batchsize_index_ = -1;