forked from mindspore-Ecosystem/mindspore
!46556 [MSLITE] Support TensorScatterUpdate, TensorScatterAdd ops and fix graph binding bug
Merge pull request !46556 from zhangyongxian/dev_zhangyongxian_wd
This commit is contained in:
commit
7aaafb123d
|
@ -26,10 +26,6 @@ int BatchNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
|
|||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,15 +18,13 @@
|
|||
#include <numeric>
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||
#include "ops/scatter_nd_update.h"
|
||||
#include "ops/tensor_scatter_update.h"
|
||||
#include "ops/tensor_scatter_add.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
|
||||
const std::vector<TensorInfo> &out_tensors) {
|
||||
#if TRT_VERSION_GE(8, 2)
|
||||
if (!IsShapeKnown()) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors.size() != INPUT_SIZE3) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
|
||||
return RET_ERROR;
|
||||
|
@ -45,14 +43,31 @@ int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std
|
|||
|
||||
int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||
#if TRT_VERSION_GE(8, 2)
|
||||
ITensorHelper scatter_input;
|
||||
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &scatter_input);
|
||||
if (ret != RET_OK || scatter_input.trt_tensor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_;
|
||||
return ret;
|
||||
ITensorHelper scatter_input = input(ctx, 0);
|
||||
if (in_tensors_[0].IsConst() && scatter_input.trt_tensor_ == nullptr) {
|
||||
scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_);
|
||||
scatter_input.format_ = Format::NCHW;
|
||||
ctx->RegisterTensor(scatter_input, in_tensors_[0].Name());
|
||||
}
|
||||
if (type_ == ops::kNameTensorScatterAdd) {
|
||||
nvinfer1::ITensor *value_tensor = ctx->ConvertTo1DTensor(0.f);
|
||||
if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) {
|
||||
value_tensor = ctx->ConvertTo1DTensor(0);
|
||||
}
|
||||
auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
|
||||
CHECK_NULL_RETURN(unsqueeze_layer);
|
||||
auto shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0);
|
||||
int rank = shape->getDimensions().d[0];
|
||||
nvinfer1::Dims unsqueeze{rank};
|
||||
std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
|
||||
unsqueeze_layer->setReshapeDimensions(unsqueeze);
|
||||
unsqueeze_layer->setZeroIsPlaceholder(false);
|
||||
value_tensor = unsqueeze_layer->getOutput(0);
|
||||
CHECK_NULL_RETURN(value_tensor);
|
||||
scatter_input.trt_tensor_ = Broadcast(ctx, value_tensor, shape);
|
||||
}
|
||||
ITensorHelper indices_helper;
|
||||
ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
|
||||
int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper);
|
||||
if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) {
|
||||
MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_;
|
||||
return ret;
|
||||
|
@ -72,6 +87,11 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
|||
}
|
||||
|
||||
nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0);
|
||||
if (type_ == ops::kNameTensorScatterAdd) {
|
||||
out_tensor = ctx->network()
|
||||
->addElementWise(*out_tensor, *input(ctx, 0).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM)
|
||||
->getOutput(0);
|
||||
}
|
||||
ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_},
|
||||
out_tensors_[0].Name());
|
||||
this->layer_ = scatter_layer;
|
||||
|
@ -82,4 +102,6 @@ int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
|||
#endif
|
||||
}
|
||||
REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT)
|
||||
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT)
|
||||
REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, ScatterNdTensorRT)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -431,6 +431,14 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
|
|||
int TensorRTSubGraph::MarkOutputs() {
|
||||
// Mark NetWork Output Tensor.
|
||||
for (const auto &out_tensor : outputs_) {
|
||||
std::string output_name = out_tensor.Name();
|
||||
auto input_it = std::find_if(inputs_.begin(), inputs_.end(),
|
||||
[=](const TensorInfo &input) { return input.Name() == output_name; });
|
||||
if (input_it != inputs_.end()) {
|
||||
nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(*input_it, GetInputIndexByName(input_it->Name()));
|
||||
ctx_->network()->markOutput(*trt_tensor);
|
||||
continue;
|
||||
}
|
||||
if (out_tensor.IsConst()) {
|
||||
MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
|
||||
auto output_helper = ctx_->MsName2Tensor(out_tensor.Name());
|
||||
|
|
Loading…
Reference in New Issue