!41324 [MSLITE] Add TensorRT ConstantOfShape Op

Merge pull request !41324 from zhangyongxian/dev_zhangyongxian_constantofshape
This commit is contained in:
i-robot 2022-09-02 10:44:05 +00:00 committed by Gitee
commit b6fdaf90e9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 112 additions and 0 deletions

View File

@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include "src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
// Checks whether this ConstantOfShape op can be offloaded to TensorRT.
// Requires exactly one input (the target shape) and one output, a valid
// ConstantOfShape primitive, and an output data type of int32 or float32.
// Returns RET_OK when supported, RET_ERROR otherwise.
int ConstantOfShapeTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
                                       const std::vector<mindspore::MSTensor> &in_tensors,
                                       const std::vector<mindspore::MSTensor> &out_tensors) {
  if (in_tensors.size() != 1) {
    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_;
    return RET_ERROR;
  }
  if (out_tensors.size() != 1) {
    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_;
    return RET_ERROR;
  }
  CHECK_NULL_RETURN(primitive);
  // value_as_ConstantOfShape() returns nullptr if the flatbuffer union holds a
  // different primitive type; dereferencing it unchecked would crash.
  auto constofshape_op = primitive->value_as_ConstantOfShape();
  CHECK_NULL_RETURN(constofshape_op);
  DataType data_type = static_cast<DataType>(constofshape_op->data_type());
  if (data_type != DataType::kNumberTypeInt32 && data_type != DataType::kNumberTypeFloat32) {
    MS_LOG(ERROR) << "Unsupported data type for " << op_name_;
    return RET_ERROR;
  }
  return RET_OK;
}
// Builds the TensorRT subgraph for ConstantOfShape: creates a 1-D constant
// holding the fill value, reshapes it to rank(output) with all dims = 1, and
// broadcasts it to the runtime shape given by input 0. Casts to int32 when the
// primitive requests an int32 output. Returns RET_OK on success.
int ConstantOfShapeTensorRT::AddInnerOp(TensorRTContext *ctx) {
  CHECK_NULL_RETURN(ctx);
  auto constofshape_op = op_primitive_->value_as_ConstantOfShape();
  CHECK_NULL_RETURN(constofshape_op);
  auto &&value_vector = constofshape_op->value();
  CHECK_NULL_RETURN(value_vector);
  if (value_vector->size() == 0) {
    // begin() would be dereferenced below; guard against an empty value list.
    MS_LOG(ERROR) << "value of ConstantOfShape is empty for " << op_name_;
    return RET_ERROR;
  }
  // Lift the scalar fill value into a 1-D tensor so it can be reshaped.
  auto value_tensor = ctx->ConvertTo1DTensor(*value_vector->begin());
  CHECK_NULL_RETURN(value_tensor);
  auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor);
  CHECK_NULL_RETURN(unsqueeze_layer);
  // Input 0 is a 1-D shape tensor; its length is the rank of the output.
  auto shape = input(ctx, 0).trt_tensor_;
  CHECK_NULL_RETURN(shape);
  int rank = shape->getDimensions().d[0];
  // Reshape the value tensor to [1, 1, ..., 1] (rank ones) so Broadcast can
  // expand it to the requested shape.
  nvinfer1::Dims unsqueeze{rank};
  std::fill(unsqueeze.d, unsqueeze.d + rank, 1);
  unsqueeze_layer->setReshapeDimensions(unsqueeze);
  // A literal 0 in the reshape dims must mean "dimension of size 0", not
  // "copy from input" (the TensorRT placeholder semantics).
  unsqueeze_layer->setZeroIsPlaceholder(false);
  value_tensor = unsqueeze_layer->getOutput(0);
  CHECK_NULL_RETURN(value_tensor);
  auto out_tensor = Broadcast(ctx, value_tensor, shape);
  CHECK_NULL_RETURN(out_tensor);
  if (static_cast<DataType>(constofshape_op->data_type()) == DataType::kNumberTypeInt32) {
    // Broadcast produces a float tensor; cast when an int32 output is required.
    out_tensor = TRTTensorCast(ctx, out_tensor, nvinfer1::DataType::kINT32, op_name_ + "_cast_out");
    CHECK_NULL_RETURN(out_tensor);
  }
  auto output_helper = ITensorHelper{out_tensor, Format::NHWC, true};
  ctx->RegisterTensor(output_helper, out_tensors_[0].Name());
  MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper);
  this->layer_ = unsqueeze_layer;
  return RET_OK;
}
// Register this op implementation so the TensorRT delegate can instantiate it
// for PrimitiveType_ConstantOfShape nodes.
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_ConstantOfShape, ConstantOfShapeTensorRT)
} // namespace mindspore::lite

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_
#include <string>
#include <vector>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
// TensorRT delegate op for ConstantOfShape: produces a tensor whose shape is
// given by input 0 (a 1-D shape tensor) filled with the constant value stored
// in the primitive. Output type is int32 or float32 (see IsSupport in the .cc).
class ConstantOfShapeTensorRT : public TensorRTOp {
 public:
  ConstantOfShapeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                          const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
                          const schema::QuantType &quant_type)
      : TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}

  ~ConstantOfShapeTensorRT() override = default;

  // Builds the shuffle + broadcast (+ optional int32 cast) network for this op.
  int AddInnerOp(TensorRTContext *ctx) override;

  // Validates tensor counts and the requested output data type.
  int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                const std::vector<mindspore::MSTensor> &out_tensors) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_