forked from mindspore-Ecosystem/mindspore
!46556 [MSLITE] Support TensorScatterUpdate, TensorScatterAdd ops and fix graph binding bug
Merge pull request !46556 from zhangyongxian/dev_zhangyongxian_wd
This commit is contained in:
commit
7aaafb123d
|
@ -26,10 +26,6 @@ int BatchNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
|
||||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (out_tensors.size() != 1) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,15 +18,13 @@
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||||
#include "ops/scatter_nd_update.h"
|
#include "ops/scatter_nd_update.h"
|
||||||
|
#include "ops/tensor_scatter_update.h"
|
||||||
|
#include "ops/tensor_scatter_add.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
|
int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
|
||||||
const std::vector<TensorInfo> &out_tensors) {
|
const std::vector<TensorInfo> &out_tensors) {
|
||||||
#if TRT_VERSION_GE(8, 2)
|
#if TRT_VERSION_GE(8, 2)
|
||||||
if (!IsShapeKnown()) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (in_tensors.size() != INPUT_SIZE3) {
|
if (in_tensors.size() != INPUT_SIZE3) {
|
||||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
|
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -45,14 +43,31 @@ int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
|
||||||
|
|
||||||
int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||||
#if TRT_VERSION_GE(8, 2)
|
#if TRT_VERSION_GE(8, 2)
|
||||||
ITensorHelper scatter_input;
|
ITensorHelper scatter_input = input(ctx, 0);
|
||||||
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &scatter_input);
|
if (in_tensors_[0].IsConst() && scatter_input.trt_tensor_ == nullptr) {
|
||||||
if (ret != RET_OK || scatter_input.trt_tensor_ == nullptr) {
|
scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_);
|
||||||
MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
|
scatter_input.format_ = Format::NCHW;
|
||||||
return ret;
|
ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
|
||||||
|
}
|
||||||
|
if (type_ == ops::kNameTensorScatterAdd) {
|
||||||
|
nvinfer1::ITensor *value_tensor = ctx->ConvertTo1DTensor(0.f);
|
||||||
|
if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
|
||||||
|
value_tensor = ctx->ConvertTo1DTensor(0);
|
||||||
|
}
|
||||||
|
auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
|
||||||
|
CHECK_NULL_RETURN(unsqueeze_layer);
|
||||||
|
auto shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0);
|
||||||
|
int rank = shape->getDimensions().d[0];
|
||||||
|
nvinfer1::Dims unsqueeze{rank};
|
||||||
|
std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
|
||||||
|
unsqueeze_layer->setReshapeDimensions(unsqueeze);
|
||||||
|
unsqueeze_layer->setZeroIsPlaceholder(false);
|
||||||
|
value_tensor = unsqueeze_layer->getOutput(0);
|
||||||
|
CHECK_NULL_RETURN(value_tensor);
|
||||||
|
scatter_input.trt_tensor_ = Broadcast(ctx, value_tensor, shape);
|
||||||
}
|
}
|
||||||
ITensorHelper indices_helper;
|
ITensorHelper indices_helper;
|
||||||
ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
|
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
|
||||||
if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
|
if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_;
|
MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_;
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -72,6 +87,11 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
|
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
|
||||||
|
if (type_ == ops::kNameTensorScatterAdd) {
|
||||||
|
out_tensor = ctx->network()
|
||||||
|
->addElementWise(*out_tensor, *input(ctx, 0).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM)
|
||||||
|
->getOutput(0);
|
||||||
|
}
|
||||||
ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
|
ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
|
||||||
out_tensors_[0].Name());
|
out_tensors_[0].Name());
|
||||||
this->layer_ = scatter_layer;
|
this->layer_ = scatter_layer;
|
||||||
|
@ -82,4 +102,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
|
REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
|
||||||
|
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT)
|
||||||
|
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, ScatterNdTensorRT)
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -431,6 +431,14 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
|
||||||
int TensorRTSubGraph::MarkOutputs() {
|
int TensorRTSubGraph::MarkOutputs() {
|
||||||
// Mark NetWork Output Tensor.
|
// Mark NetWork Output Tensor.
|
||||||
for (const auto &out_tensor : outputs_) {
|
for (const auto &out_tensor : outputs_) {
|
||||||
|
std::string output_name = out_tensor.Name();
|
||||||
|
auto input_it = std::find_if(inputs_.begin(), inputs_.end(),
|
||||||
|
[=](const TensorInfo &input) { return input.Name() == output_name; });
|
||||||
|
if (input_it != inputs_.end()) {
|
||||||
|
nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(*input_it, GetInputIndexByName(input_it->Name()));
|
||||||
|
ctx_->network()->markOutput(*trt_tensor);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (out_tensor.IsConst()) {
|
if (out_tensor.IsConst()) {
|
||||||
MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
|
MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
|
||||||
auto output_helper = ctx_->MsName2Tensor(out_tensor.Name());
|
auto output_helper = ctx_->MsName2Tensor(out_tensor.Name());
|
||||||
|
|
Loading…
Reference in New Issue