From 93217d2ef342521c1d128b09ece16c6a93259b84 Mon Sep 17 00:00:00 2001
From: 张勇贤
Date: Fri, 2 Sep 2022 15:15:10 +0800
Subject: [PATCH] [MSLITE] Support rank 0 for expanddim and fix bug for fill

---
 .../delegate/tensorrt/op/fill_tensorrt.cc    | 19 +++++++++++++------
 .../delegate/tensorrt/op/shuffle_tensorrt.cc |  4 ++++
 .../delegate/tensorrt/tensorrt_subgraph.cc   | 14 --------------
 3 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc
index 21a7e9bd8fc..0003fba381a 100644
--- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc
+++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc
@@ -49,14 +49,21 @@ int FillTensorRT::AddInnerOp(TensorRTContext *ctx) {
     return RET_ERROR;
   }
   fill_layer->setInput(0, *input(ctx, 1).trt_tensor_);
-  auto alpha_tensor =
-    ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
+  nvinfer1::ITensor *alpha_tensor = nullptr;
+  if (in_tensors_[0].Data() == nullptr) {
+    alpha_tensor = input(ctx, 0).trt_tensor_;
+  } else {
+    alpha_tensor =
+      ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data().get(), in_tensors_[0].DataType(), op_name_ + "_alpha");
+  }
   fill_layer->setInput(1, *alpha_tensor);
   int nbdims = input(ctx, 1).trt_tensor_->getDimensions().d[0];
-  zeros_ = std::vector<float>(nbdims, 0.f);
-  nvinfer1::Dims beta_dims{1, {nbdims}};
-  nvinfer1::Weights weights{ConvertDataType(DataType::kNumberTypeFloat32), &zeros_[0], nbdims};
-  auto beta_tensor = ctx->network()->addConstant(beta_dims, weights)->getOutput(0);
+  nvinfer1::ITensor *beta_tensor = nullptr;
+  if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
+    beta_tensor = ctx->ConvertTo1DTensor(std::vector<int>(nbdims, 0));
+  } else {
+    beta_tensor = ctx->ConvertTo1DTensor(std::vector<float>(nbdims, 0.f));
+  }
   fill_layer->setInput(INPUT_SIZE2, *beta_tensor);
 
   nvinfer1::ITensor *out_tensor = fill_layer->getOutput(0);
diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc
index 80c45fc7dd2..be60dd2d4ae 100644
--- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc
+++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc
@@ -320,6 +320,10 @@ int ShuffleTensorRT::AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer) {
 }
 
 int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) {
+  if (!input(ctx_, 0).is_tensor_) {
+    shuffler_output_ = shuffler_input_;
+    return RET_OK;
+  }
   if (in_tensors_[1].DataType() != DataType::kNumberTypeInt32) {
     MS_LOG(WARNING) << op_name_ << " axis tensor data type is " << static_cast<int>(in_tensors_[1].DataType());
   }
diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
index 44c25f2909d..7ea2911fe08 100644
--- a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
+++ b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc
@@ -657,20 +657,6 @@ bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims
   if (static_cast<size_t>(construct_dims.nbDims) != resize_input_shape.size()) {
     MS_LOG(ERROR) << "invalid resize input.";
     return false;
   }
-  if (input_hw_index_ == -1) {
-    // only NHWC format support HW resize, otherwise only support batchsize resize
-    for (int d = 0; d < construct_dims.nbDims; d++) {
-      if (d != input_batchsize_index_ && construct_dims.d[d] != resize_input_shape[d]) {
-        MS_LOG(ERROR) << "only support dynamic batch size resize input.";
-        return false;
-      }
-    }
-  } else if ((input_hw_index_ == 1 && construct_dims.d[DIMENSION_3D] != resize_input_shape[DIMENSION_3D]) ||
-             (input_hw_index_ == DIMENSION_2D && construct_dims.d[1] != resize_input_shape[1])) {
-    // input may be nhwc || nchw
-    MS_LOG(ERROR) << "don't support dynamic channel resize input.";
-    return false;
-  }
   return true;
 }
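
Reviewer note: the sketch below summarizes the two decisions the fill_tensorrt.cc hunk makes, for anyone reading the diff without the surrounding delegate code. It is a minimal standalone illustration, not MindSpore or TensorRT API: the DataType enum values are copied from the diff, while FillInputPlan, plan_fill_inputs, and the main driver are hypothetical names invented here. In the real code these choices feed fill_layer->setInput(1, ...) and fill_layer->setInput(INPUT_SIZE2, ...); the shuffle_tensorrt.cc hunk is simpler and just turns ExpandDims on a rank-0 (non-tensor) input into a pass-through.

#include <cassert>
#include <vector>

// Stand-in for the subset of the DataType enum used in the hunk (illustrative only).
enum class DataType { kNumberTypeFloat32, kNumberTypeInt32 };

// Hypothetical summary of what the patched FillTensorRT::AddInnerOp decides:
//  - alpha (fill value): if the value tensor has no host data it is a runtime
//    tensor, so reuse the ITensor already in the network instead of
//    dereferencing a null Data() pointer (the old bug).
//  - beta (per-dimension delta): zeros typed to match the fill value, so an
//    int32 fill no longer receives a float32 delta.
struct FillInputPlan {
  bool alpha_from_runtime_tensor;  // true -> use input(ctx, 0).trt_tensor_
  std::vector<int> int_zeros;      // beta when the fill value is int32
  std::vector<float> float_zeros;  // beta otherwise
};

FillInputPlan plan_fill_inputs(const void *host_data, DataType dtype, int nbdims) {
  FillInputPlan plan{};
  plan.alpha_from_runtime_tensor = (host_data == nullptr);
  if (dtype == DataType::kNumberTypeInt32) {
    plan.int_zeros = std::vector<int>(nbdims, 0);
  } else {
    plan.float_zeros = std::vector<float>(nbdims, 0.f);
  }
  return plan;
}

int main() {
  // Constant float fill value: alpha comes from host data, beta is float zeros.
  float value = 1.5f;
  FillInputPlan a = plan_fill_inputs(&value, DataType::kNumberTypeFloat32, 2);
  assert(!a.alpha_from_runtime_tensor && a.float_zeros.size() == 2);

  // Runtime int32 fill value: alpha falls back to the network tensor,
  // beta is int zeros of matching type.
  FillInputPlan b = plan_fill_inputs(nullptr, DataType::kNumberTypeInt32, 3);
  assert(b.alpha_from_runtime_tensor && b.int_zeros.size() == 3);
  return 0;
}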