!41307 [MSLITE] Refactor TensorRT BroadcastTo and Reshape ops

Merge pull request !41307 from zhangyongxian/dev_zhangyongxian_removemsshape
This commit is contained in:
i-robot 2022-09-02 08:56:17 +00:00 committed by Gitee
commit 44769acbe1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 133 additions and 111 deletions

View File

@ -18,6 +18,7 @@
#include <vector>
#include <numeric>
#include <functional>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
int ShuffleTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
@ -61,14 +62,10 @@ int ShuffleTensorRT::IsSupport(const schema::Primitive *primitive, const std::ve
return RET_ERROR;
}
dynamic_shape_params_.support_hw_dynamic_ = false;
if (in_tensors[0].Shape()[0] != out_tensors[0].Shape()[0]) {
dynamic_shape_params_.support_dynamic_ = false;
}
break;
}
case schema::PrimitiveType_Transpose:
case schema::PrimitiveType_ExpandDims:
case schema::PrimitiveType_BroadcastTo: {
case schema::PrimitiveType_ExpandDims: {
if (in_tensors.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "PrimitiveType_Transpose Unsupported in_tensors size: " << in_tensors.size();
return RET_ERROR;
@ -79,6 +76,13 @@ int ShuffleTensorRT::IsSupport(const schema::Primitive *primitive, const std::ve
}
break;
}
case schema::PrimitiveType_BroadcastTo: {
if (in_tensors.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "PrimitiveType_Transpose Unsupported in_tensors size: " << in_tensors.size();
return RET_ERROR;
}
break;
}
default: {
MS_LOG(ERROR) << "Unsupported op type:" << schema::EnumNamePrimitiveType(type_);
return RET_ERROR;
@ -243,7 +247,7 @@ int ShuffleTensorRT::AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
nvinfer1::ITensor *expand_input = shuffler_input_;
if (input(ctx_, 0).is_tensor_ == true) {
for (size_t i = 0; i < param_axis_->size(); i++) {
expand_input = ExpandDim(shuffle_layer, expand_input, param_axis_->Get(i));
expand_input = ExpandDim(ctx_, expand_input, param_axis_->Get(i));
}
}
shuffler_output_ = expand_input;
@ -292,8 +296,12 @@ int ShuffleTensorRT::AddReshapeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
mindspore::MSTensor &shape_tensor = in_tensors_[1];
if (shape_tensor.Data() != nullptr) {
// static shuffle layer
shuffle_layer->setReshapeDimensions(
InferReshapeDims(shuffler_input_->getDimensions(), in_tensors_[0].Shape(), out_tensors_[0].Shape()));
nvinfer1::Dims reshape_dims{shape_tensor.ElementNum()};
const int *shape_ptr = reinterpret_cast<const int *>(shape_tensor.Data().get());
for (int i = 0; i != shape_tensor.ElementNum(); ++i) {
reshape_dims.d[i] = *(shape_ptr + i);
}
shuffle_layer->setReshapeDimensions(reshape_dims);
} else {
if (in_tensors_.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "invalid shape tensor for reshape " << op_name_;
@ -325,114 +333,44 @@ int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) {
}
auto axis_data = static_cast<const int *>(in_tensors_[1].Data().get());
int axis = axis_data[0];
shuffler_output_ = ExpandDim(shuffle_layer, shuffler_input_, axis);
shuffler_output_ = ExpandDim(ctx_, shuffler_input_, axis);
return shuffler_output_ == nullptr ? RET_ERROR : RET_OK;
}
int ShuffleTensorRT::AddBroadcastToOp(nvinfer1::IShuffleLayer *shuffle_layer) {
if (out_tensors_[0].ElementNum() != in_tensors_[0].ElementNum() &&
out_tensors_[0].Shape().size() == in_tensors_[0].Shape().size()) {
MS_LOG(WARNING) << "broadcast element cnt changes, ignore broadcast for " << op_name_;
shuffle_layer->setReshapeDimensions(shuffler_input_->getDimensions());
MS_LOG(WARNING) << "here " << op_name_;
} else if (out_tensors_[0].ElementNum() == in_tensors_[0].ElementNum()) {
nvinfer1::Dims new_dims = ConvertCudaDims(out_tensors_[0].Shape());
if (new_dims.nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return RET_ERROR;
}
new_dims.d[0] = shuffler_input_->getDimensions().d[0];
shuffle_layer->setReshapeDimensions(new_dims);
MS_LOG(WARNING) << "here " << op_name_;
if (ReadyInputsNumber(ctx_) == INPUT_SIZE2) {
auto input_shape_tensor = input(ctx_, 1).trt_tensor_;
shuffler_output_ = Broadcast(ctx_, shuffler_input_, input_shape_tensor);
} else {
MS_LOG(ERROR) << "broadcast needs check for " << op_name_;
std::vector<int> input_shape;
const int *shape_ptr = reinterpret_cast<const int *>(in_tensors_[1].Data().get());
for (int i = 0; i != in_tensors_[1].ElementNum(); ++i) {
input_shape.push_back(*(shape_ptr + i));
}
nvinfer1::Dims in_tensor_dims = shuffler_input_->getDimensions();
auto input_shape_tensor = ctx_->ConvertTo1DTensor(input_shape);
while (in_tensor_dims.nbDims < input_shape.size()) {
shuffler_input_ = ExpandDim(ctx_, shuffler_input_, 0);
if (shuffler_input_->getDimensions().nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return RET_ERROR;
}
shuffle_layer->setReshapeDimensions(shuffler_input_->getDimensions());
shuffler_input_ = shuffle_layer->getOutput(0);
in_tensor_dims = shuffler_input_->getDimensions();
}
auto size_tensor = ctx_->network()->addShape(*shuffler_input_)->getOutput(0);
size_tensor = ctx_->network()
->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX)
->getOutput(0);
shuffler_output_ = Broadcast(ctx_, shuffler_input_, size_tensor);
}
shuffler_output_ = shuffle_layer->getOutput(0);
return shuffler_output_ == nullptr ? RET_ERROR : RET_OK;
}
// Inserts a size-1 dimension at `axis` of input_tensor by configuring the given
// shuffle layer's reshape dims. When the input shape is dynamic and the new axis
// is not the last one, the size-1 dim is appended at the end and then permuted
// into place with an extra transpose shuffle layer (TensorRT's 0-placeholder
// reshape cannot move a dynamic dim). Returns the expanded tensor, or nullptr on
// failure.
nvinfer1::ITensor *ShuffleTensorRT::ExpandDim(nvinfer1::IShuffleLayer *shuffle_layer, nvinfer1::ITensor *input_tensor,
                                              int axis) {
  auto input_dims = input_tensor->getDimensions();
  // if expand dim not at last dim and shape is dynamic, change to expanddim at last dim and transpose
  bool special_expand = false;
  for (int i = 0; i < input_dims.nbDims; i++) {
    // any -1 dim marks the shape as dynamic
    special_expand = special_expand || input_dims.d[i] == -1;
  }
  special_expand = special_expand && (axis != -1 && axis != input_dims.nbDims - 1);
  if (special_expand) {
    // Reshape to (original dims..., 1); 0 means "keep the input dim" in TensorRT
    // reshape semantics, so dynamic dims pass through unchanged.
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    new_shape.push_back(1);
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
      return nullptr;
    }
    shuffle_layer->setReshapeDimensions(new_dims);
    // transpose
    // Permutation moves the trailing size-1 dim to `axis`, shifting the dims at
    // and after `axis` right by one.
    nvinfer1::Permutation perm{};
    for (int i = 0; i < new_dims.nbDims; i++) {
      if (i < axis) {
        perm.order[i] = i;
      } else if (i == axis) {
        perm.order[i] = new_dims.nbDims - 1;
      } else {
        perm.order[i] = i - 1;
      }
    }
    nvinfer1::IShuffleLayer *trans_layer = ctx_->network()->addShuffle(*shuffle_layer->getOutput(0));
    if (trans_layer == nullptr) {
      MS_LOG(ERROR) << "add transpose layer failed for special expand dims op " << op_name_;
      return nullptr;
    }
    trans_layer->setFirstTranspose(perm);
    return trans_layer->getOutput(0);
  } else {
    // Static shape, or expanding at the last position: build the target shape
    // directly with the 1 inserted at `axis`.
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      if (axis == i) {
        new_shape.push_back(1);
      }
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    // axis == -1 or axis == rank both mean "append at the end".
    if (axis == -1 || axis == input_dims.nbDims) {
      new_shape.push_back(1);
    }
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
      return nullptr;
    }
    shuffle_layer->setReshapeDimensions(new_dims);
    return shuffle_layer->getOutput(0);
  }
}
// Builds TensorRT reshape dims from the MindSpore output shape, replacing dims
// that are dynamic on the TRT side: 0 ("copy input dim") when the MS input and
// output agree at that index, -1 ("infer") otherwise.
// NOTE(review): indexes ms_input_shape/input_dims with the output rank —
// presumably callers guarantee equal ranks; verify at call sites.
nvinfer1::Dims ShuffleTensorRT::InferReshapeDims(const nvinfer1::Dims &input_dims,
                                                 const std::vector<int64_t> &ms_input_shape,
                                                 const std::vector<int64_t> &ms_output_shape) {
  // tensorrt support infer shape of 0 and -1
  nvinfer1::Dims reshape_dims = ConvertCudaDims(ms_output_shape);
  if (reshape_dims.nbDims == -1) {
    MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
    return reshape_dims;
  }
  int dim_idx = 0;
  while (dim_idx < reshape_dims.nbDims) {
    const bool dynamic_input_dim = (input_dims.d[dim_idx] == -1);
    if (dynamic_input_dim) {
      reshape_dims.d[dim_idx] = (ms_input_shape[dim_idx] == ms_output_shape[dim_idx]) ? 0 : -1;
    }
    MS_LOG(DEBUG) << "reshape infer_index " << dim_idx << " value: " << reshape_dims.d[dim_idx];
    ++dim_idx;
  }
  return reshape_dims;
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Unsqueeze, ShuffleTensorRT)
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Squeeze, ShuffleTensorRT)
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Reshape, ShuffleTensorRT)

View File

@ -44,9 +44,6 @@ class ShuffleTensorRT : public TensorRTOp {
int AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer);
int AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer);
int AddBroadcastToOp(nvinfer1::IShuffleLayer *shuffle_layer);
nvinfer1::ITensor *ExpandDim(nvinfer1::IShuffleLayer *shuffle_layer, nvinfer1::ITensor *input_tensor, int axis);
nvinfer1::Dims InferReshapeDims(const nvinfer1::Dims &input_dims, const std::vector<int64_t> &ms_input_shape,
const std::vector<int64_t> &ms_output_shape);
Format out_format_ = Format::NHWC;
nvinfer1::ITensor *shuffler_input_{nullptr};

View File

@ -58,7 +58,7 @@ class TensorRTSubGraph : public kernel::Kernel {
schema::PrimitiveType_ScaleFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_PadFusion,
schema::PrimitiveType_FullConnection, schema::PrimitiveType_Cast, schema::PrimitiveType_ExpandDims,
schema::PrimitiveType_Resize, schema::PrimitiveType_LSTM, schema::PrimitiveType_LayerNormFusion,
schema::PrimitiveType_TopKFusion, schema::PrimitiveType_TileFusion,
schema::PrimitiveType_TopKFusion, schema::PrimitiveType_TileFusion, schema::PrimitiveType_BroadcastTo,
};
if (!support_resize) {
input_batchsize_index_ = -1;

View File

@ -24,6 +24,9 @@
#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h"
namespace mindspore::lite {
namespace {
const int INPUT2 = 2;
}
nvinfer1::Dims ConvertCudaDims(int data, size_t size) {
nvinfer1::Dims dims{};
dims.nbDims = -1;
@ -720,6 +723,86 @@ int ParseData2Vector(const mindspore::MSTensor &ms_tensor, std::vector<float> *d
return RET_OK;
}
// Inserts a size-1 dimension at `axis` of input_tensor via a new shuffle layer.
// When the input shape is dynamic and the new axis is not the last one, the dim
// is appended at the end and permuted into place with an extra transpose layer
// (TensorRT's 0-placeholder reshape cannot move a dynamic dim).
// Returns the expanded tensor, or nullptr on failure.
nvinfer1::ITensor *ExpandDim(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor, int axis) {
  // input has to prepocess to nchw
  auto input_dims = input_tensor->getDimensions();
  nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor);
  // addShuffle can fail; the layer is dereferenced below, so check it here.
  if (shuffle_layer == nullptr) {
    MS_LOG(ERROR) << "add shuffle layer failed for expand dims";
    return nullptr;
  }
  // if expand dim not at last dim and shape is dynamic, change to expanddim at last dim and transpose
  bool special_expand = false;
  for (int i = 0; i < input_dims.nbDims; i++) {
    // any -1 dim marks the shape as dynamic
    special_expand = special_expand || input_dims.d[i] == -1;
  }
  special_expand = special_expand && (axis != -1 && axis != input_dims.nbDims);
  if (special_expand) {
    // Reshape to (original dims..., 1); 0 keeps the corresponding input dim.
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    new_shape.push_back(1);
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      // log the failure instead of returning nullptr silently
      MS_LOG(ERROR) << "ConvertCudaDims failed for expand dims";
      return nullptr;
    }
    shuffle_layer->setReshapeDimensions(new_dims);
    // transpose
    // Permutation moves the trailing size-1 dim to `axis`, shifting the dims at
    // and after `axis` right by one.
    nvinfer1::Permutation perm{};
    for (int i = 0; i < new_dims.nbDims; i++) {
      if (i < axis) {
        perm.order[i] = i;
      } else if (i == axis) {
        perm.order[i] = new_dims.nbDims - 1;
      } else {
        perm.order[i] = i - 1;
      }
    }
    nvinfer1::IShuffleLayer *trans_layer = ctx->network()->addShuffle(*shuffle_layer->getOutput(0));
    if (trans_layer == nullptr) {
      MS_LOG(ERROR) << "add transpose layer failed for special expand dims op ";
      return nullptr;
    }
    trans_layer->setFirstTranspose(perm);
    return trans_layer->getOutput(0);
  } else {
    // Static shape, or expanding at the last position: build the target shape
    // directly with the 1 inserted at `axis`.
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      if (axis == i) {
        new_shape.push_back(1);
      }
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    // axis == -1 or axis == rank both mean "append at the end".
    if (axis == -1 || axis == input_dims.nbDims) {
      new_shape.push_back(1);
    }
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      // log the failure instead of returning nullptr silently
      MS_LOG(ERROR) << "ConvertCudaDims failed for expand dims";
      return nullptr;
    }
    shuffle_layer->setReshapeDimensions(new_dims);
    return shuffle_layer->getOutput(0);
  }
}
// Broadcasts `input` to the runtime shape given by the 1-D `shape` tensor using
// a slice layer in kWRAP mode (out-of-range reads wrap around, so size-1 dims
// repeat). Returns the broadcast tensor, or nullptr on failure.
nvinfer1::ITensor *Broadcast(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::ITensor *shape) {
  int rank = shape->getDimensions().d[0];

  // start at 0 and stride 1 in every dimension; the output size is supplied
  // dynamically via setInput below.
  nvinfer1::Dims starts{rank};
  std::fill(starts.d, starts.d + rank, 0);
  nvinfer1::Dims strides{rank};
  std::fill(strides.d, strides.d + rank, 1);

  auto slice_layer = ctx->network()->addSlice(*input, starts, {}, strides);
  // addSlice can fail; the layer is dereferenced below, so check it here.
  if (slice_layer == nullptr) {
    MS_LOG(ERROR) << "add slice layer failed";
    return nullptr;
  }
  slice_layer->setMode(nvinfer1::SliceMode::kWRAP);
  // slice input index 2 is the runtime output-size tensor.
  slice_layer->setInput(INPUT2, *shape);

  auto shuffler_output = slice_layer->getOutput(0);
  if (shuffler_output == nullptr) {
    MS_LOG(ERROR) << "add slice layer failed";
  }
  return shuffler_output;
}
// Reshapes `input` to a static `shape` by converting it to TensorRT dims and
// delegating to the Dims-based Reshape overload.
nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const std::vector<int64_t> &shape) {
  const nvinfer1::Dims trt_dims = ConvertCudaDims(shape);
  return Reshape(ctx, input, trt_dims);
}

View File

@ -138,6 +138,10 @@ int ParseData2Vector(const mindspore::MSTensor &ms_tensor, std::vector<float> *d
void DebugDims(const std::string &key, const nvinfer1::Dims &dims);
nvinfer1::ITensor *ExpandDim(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor, int axis);
nvinfer1::ITensor *Broadcast(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::ITensor *shape);
template <typename T>
nvinfer1::DataType GetNvinferDataType();