!46556 [MSLITE] Support TensorScatterUpdate, TensorScatterAdd ops and fix graph binding bug

Merge pull request !46556 from zhangyongxian/dev_zhangyongxian_wd
This commit is contained in:
i-robot 2022-12-08 02:22:26 +00:00 committed by Gitee
commit 7aaafb123d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 40 additions and 14 deletions

View File

@ -26,10 +26,6 @@ int BatchNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}

View File

@ -18,15 +18,13 @@
#include <numeric>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "ops/scatter_nd_update.h"
#include "ops/tensor_scatter_update.h"
#include "ops/tensor_scatter_add.h"
namespace mindspore::lite {
int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
#if TRT_VERSION_GE(8, 2)
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != INPUT_SIZE3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
return RET_ERROR;
@ -45,14 +43,31 @@ int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
#if TRT_VERSION_GE(8, 2)
ITensorHelper scatter_input;
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &scatter_input);
if (ret != RET_OK || scatter_input.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
return ret;
ITensorHelper scatter_input = input(ctx, 0);
if (in_tensors_[0].IsConst() && scatter_input.trt_tensor_ == nullptr) {
scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_);
scatter_input.format_ = Format::NCHW;
ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
}
if (type_ == ops::kNameTensorScatterAdd) {
nvinfer1::ITensor *value_tensor = ctx->ConvertTo1DTensor(0.f);
if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
value_tensor = ctx->ConvertTo1DTensor(0);
}
auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
CHECK_NULL_RETURN(unsqueeze_layer);
auto shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0);
int rank = shape->getDimensions().d[0];
nvinfer1::Dims unsqueeze{rank};
std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
unsqueeze_layer->setReshapeDimensions(unsqueeze);
unsqueeze_layer->setZeroIsPlaceholder(false);
value_tensor = unsqueeze_layer->getOutput(0);
CHECK_NULL_RETURN(value_tensor);
scatter_input.trt_tensor_ = Broadcast(ctx, value_tensor, shape);
}
ITensorHelper indices_helper;
ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_;
return ret;
@ -72,6 +87,11 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
}
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
if (type_ == ops::kNameTensorScatterAdd) {
out_tensor = ctx->network()
->addElementWise(*out_tensor, *input(ctx, 0).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM)
->getOutput(0);
}
ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
out_tensors_[0].Name());
this->layer_ = scatter_layer;
@ -82,4 +102,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
#endif
}
REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT)
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, ScatterNdTensorRT)
} // namespace mindspore::lite

View File

@ -431,6 +431,14 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
int TensorRTSubGraph::MarkOutputs() {
// Mark NetWork Output Tensor.
for (const auto &out_tensor : outputs_) {
std::string output_name = out_tensor.Name();
auto input_it = std::find_if(inputs_.begin(), inputs_.end(),
[=](const TensorInfo &input) { return input.Name() == output_name; });
if (input_it != inputs_.end()) {
nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(*input_it, GetInputIndexByName(input_it->Name()));
ctx_->network()->markOutput(*trt_tensor);
continue;
}
if (out_tensor.IsConst()) {
MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
auto output_helper = ctx_->MsName2Tensor(out_tensor.Name());