!46556 [MSLITE] Support TensorScatterUpdate, TensorScatterAdd ops and fix graph binding bug

Merge pull request !46556 from zhangyongxian/dev_zhangyongxian_wd
i-robot 2022-12-08 02:22:26 +00:00 committed by Gitee
commit 7aaafb123d
3 changed files with 40 additions and 14 deletions
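
Reviewer note (not part of the change set): the two ops being added have straightforward scatter semantics that the ScatterNdTensorRT changes below have to reproduce. A minimal 1-D C++ reference sketch, with hypothetical names, for orientation only:

#include <cstdint>
#include <vector>

// TensorScatterUpdate: copy the input, then overwrite position indices[i] with updates[i].
std::vector<float> TensorScatterUpdateRef(std::vector<float> input, const std::vector<int32_t> &indices,
                                           const std::vector<float> &updates) {
  for (size_t i = 0; i < indices.size(); ++i) {
    input[indices[i]] = updates[i];
  }
  return input;
}

// TensorScatterAdd: copy the input, then accumulate updates[i] into position indices[i].
std::vector<float> TensorScatterAddRef(std::vector<float> input, const std::vector<int32_t> &indices,
                                        const std::vector<float> &updates) {
  for (size_t i = 0; i < indices.size(); ++i) {
    input[indices[i]] += updates[i];
  }
  return input;
}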

Changed file 1 of 3

@@ -26,10 +26,6 @@ int BatchNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
     MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
     return RET_ERROR;
   }
-  if (out_tensors.size() != 1) {
-    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
-    return RET_ERROR;
-  }
   return RET_OK;
 }

Changed file 2 of 3

@@ -18,15 +18,13 @@
 #include <numeric>
 #include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
 #include "ops/scatter_nd_update.h"
+#include "ops/tensor_scatter_update.h"
+#include "ops/tensor_scatter_add.h"
 namespace mindspore::lite {
 int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                                  const std::vector<TensorInfo> &out_tensors) {
 #if TRT_VERSION_GE(8, 2)
-  if (!IsShapeKnown()) {
-    MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
-    return RET_ERROR;
-  }
   if (in_tensors.size() != INPUT_SIZE3) {
     MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
     return RET_ERROR;
@@ -45,14 +43,31 @@ int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
 int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
 #if TRT_VERSION_GE(8, 2)
-  ITensorHelper scatter_input;
-  int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &scatter_input);
-  if (ret != RET_OK || scatter_input.trt_tensor_ == nullptr) {
-    MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
-    return ret;
-  }
+  ITensorHelper scatter_input = input(ctx, 0);
+  if (in_tensors_[0].IsConst() && scatter_input.trt_tensor_ == nullptr) {
+    scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_);
+    scatter_input.format_ = Format::NCHW;
+    ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
+  }
+  if (type_ == ops::kNameTensorScatterAdd) {
+    nvinfer1::ITensor *value_tensor = ctx->ConvertTo1DTensor(0.f);
+    if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
+      value_tensor = ctx->ConvertTo1DTensor(0);
+    }
+    auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
+    CHECK_NULL_RETURN(unsqueeze_layer);
+    auto shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0);
+    int rank = shape->getDimensions().d[0];
+    nvinfer1::Dims unsqueeze{rank};
+    std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
+    unsqueeze_layer->setReshapeDimensions(unsqueeze);
+    unsqueeze_layer->setZeroIsPlaceholder(false);
+    value_tensor = unsqueeze_layer->getOutput(0);
+    CHECK_NULL_RETURN(value_tensor);
+    scatter_input.trt_tensor_ = Broadcast(ctx, value_tensor, shape);
+  }
   ITensorHelper indices_helper;
-  ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
+  int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
   if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
     MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_;
     return ret;
@@ -72,6 +87,11 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
   }
   nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
+  if (type_ == ops::kNameTensorScatterAdd) {
+    out_tensor = ctx->network()
+                   ->addElementWise(*out_tensor, *input(ctx, 0).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM)
+                   ->getOutput(0);
+  }
   ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
                       out_tensors_[0].Name());
   this->layer_ = scatter_layer;
@@ -82,4 +102,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
 #endif
 }
 REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
+REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT)
+REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, ScatterNdTensorRT)
 }  // namespace mindspore::lite
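
Side note, not from the patch: the TensorRT scatter layer used here writes values rather than accumulating them, so the AddInnerOp changes above realize TensorScatterAdd in two steps: scatter the updates into an all-zero tensor broadcast to the input's shape, then add the original input element-wise with an ElementWise kSUM layer. A minimal 1-D sketch of that decomposition, with hypothetical names; it matches TensorScatterAdd when indices do not repeat:

#include <cstdint>
#include <vector>

// Mirrors the zero-tensor + scatter + element-wise-sum construction used above.
std::vector<float> TensorScatterAddViaZeroScatter(const std::vector<float> &input,
                                                  const std::vector<int32_t> &indices,
                                                  const std::vector<float> &updates) {
  std::vector<float> scattered(input.size(), 0.0f);  // zeros broadcast to the input's shape
  for (size_t i = 0; i < indices.size(); ++i) {
    scattered[indices[i]] = updates[i];  // scatter-update into the zero tensor
  }
  std::vector<float> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    out[i] = input[i] + scattered[i];  // element-wise sum with the original input
  }
  return out;
}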

Changed file 3 of 3

@@ -431,6 +431,14 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
 int TensorRTSubGraph::MarkOutputs() {
   // Mark NetWork Output Tensor.
   for (const auto &out_tensor : outputs_) {
+    std::string output_name = out_tensor.Name();
+    auto input_it = std::find_if(inputs_.begin(), inputs_.end(),
+                                 [=](const TensorInfo &input) { return input.Name() == output_name; });
+    if (input_it != inputs_.end()) {
+      nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(*input_it, GetInputIndexByName(input_it->Name()));
+      ctx_->network()->markOutput(*trt_tensor);
+      continue;
+    }
     if (out_tensor.IsConst()) {
       MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
       auto output_helper = ctx_->MsName2Tensor(out_tensor.Name());
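
Side note, not from the patch: the graph-binding fix in MarkOutputs covers subgraphs where an output tensor is also one of the subgraph's inputs, so no layer inside the network produced it; the new branch binds the corresponding network input tensor and marks it as an output directly. A minimal sketch of that passthrough check, using hypothetical stand-in types:

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for the real TensorInfo; illustration only.
struct FakeTensorInfo {
  std::string name;
  std::string Name() const { return name; }
};

// Returns true when a graph output is a direct passthrough of a graph input,
// i.e. the case the new branch in MarkOutputs now handles explicitly.
bool IsPassthroughOutput(const FakeTensorInfo &out_tensor, const std::vector<FakeTensorInfo> &inputs) {
  const std::string output_name = out_tensor.Name();
  return std::find_if(inputs.begin(), inputs.end(),
                      [&](const FakeTensorInfo &input) { return input.Name() == output_name; }) != inputs.end();
}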