Add oneslike gathernd and square tensorrt ops

This commit is contained in:
wangwenzhe 2022-09-06 21:42:39 +08:00
parent 7014e91c14
commit 4424f0d694
6 changed files with 346 additions and 0 deletions

View File

@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
int GatherNDTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
#if TRT_VERSION_GE(8, 2)
if (in_tensors.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (in_tensors[1].DataType() != DataType::kNumberTypeInt32) {
MS_LOG(ERROR) << "Gather indices only support Int32";
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
#else
MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher";
return RET_ERROR;
#endif
}
int GatherNDTensorRT::AddInnerOp(TensorRTContext *ctx) {
#if TRT_VERSION_GE(8, 2)
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
ITensorHelper gather_nd_input = input(ctx, 0);
ITensorHelper indices_tensor = input(ctx, 1);
if (indices_tensor.trt_tensor_->getDimensions().nbDims < 1) {
MS_LOG(ERROR) << "addGather failed for TensorRT.";
return RET_ERROR;
}
auto nbElementWiseDims = gather_nd_input.trt_tensor_->getDimensions().d[0];
if (nbElementWiseDims == -1) {
nbElementWiseDims = 0;
}
nvinfer1::IGatherLayer *gather_layer =
ctx->network()->addGatherV2(*gather_nd_input.trt_tensor_, *indices_tensor.trt_tensor_, nvinfer1::GatherMode::kND);
if (gather_layer == nullptr) {
MS_LOG(ERROR) << "addGatherND failed for TensorRT.";
return RET_ERROR;
}
gather_layer->setNbElementWiseDims(nbElementWiseDims);
this->layer_ = gather_layer;
gather_layer->setName(op_name_.c_str());
nvinfer1::ITensor *op_output = gather_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{op_output, gather_nd_input.format_, gather_nd_input.same_format_},
out_tensors_[0].Name());
return RET_OK;
#else
MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher";
return RET_ERROR;
#endif
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_GatherNd, GatherNDTensorRT)
} // namespace mindspore::lite

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class GatherNDTensorRT : public TensorRTOp {
public:
GatherNDTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~GatherNDTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h"
namespace mindspore::lite {
int OneslikeTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int OneslikeTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims;
if (input_nbdims == -1) {
MS_LOG(ERROR) << "oneslike op failed for " << op_name_;
return RET_ERROR;
}
int ret = RunAsTrtOps(ctx);
if (ret != RET_OK) {
MS_LOG(ERROR) << "oneslike op failed for " << op_name_;
return ret;
}
return ret;
}
int OneslikeTensorRT::RunAsTrtOps(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto const_zero = ctx->ConvertTo1DTensor(std::vector<float>(input(ctx, 0).trt_tensor_->getDimensions().nbDims, 0.f));
CHECK_NULL_RETURN(const_zero);
auto const_one = ctx->ConvertTo1DTensor(std::vector<float>(input(ctx, 0).trt_tensor_->getDimensions().nbDims, 1.f));
CHECK_NULL_RETURN(const_one);
auto prod_tensor = ctx->network()
->addElementWise(*input(ctx, 0).trt_tensor_, *const_zero, nvinfer1::ElementWiseOperation::kPROD)
->getOutput(0);
CHECK_NULL_RETURN(prod_tensor);
auto oneslike_layer = ctx->network()->addElementWise(*prod_tensor, *const_one, nvinfer1::ElementWiseOperation::kSUM);
CHECK_NULL_RETURN(oneslike_layer);
auto out_tensor = oneslike_layer->getOutput(0);
CHECK_NULL_RETURN(out_tensor);
this->layer_ = oneslike_layer;
ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_},
out_tensors_[0].Name());
return RET_OK;
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_OnesLike, OneslikeTensorRT)
} // namespace mindspore::lite

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class OneslikeTensorRT : public TensorRTOp {
public:
OneslikeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~OneslikeTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
int RunAsTrtOps(TensorRTContext *ctx);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_

View File

@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/square_tensorrt.h"
namespace mindspore::lite {
int SquareTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int SquareTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto square_op = op_primitive_->value_as_Square();
CHECK_NULL_RETURN(square_op);
int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims;
if (input_nbdims == -1) {
MS_LOG(ERROR) << "square failed for " << op_name_;
return RET_ERROR;
}
int ret = RunAsTrtOps(ctx);
if (ret != RET_OK) {
MS_LOG(ERROR) << "square failed for " << op_name_;
return ret;
}
return ret;
}
int SquareTensorRT::RunAsTrtOps(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto square_layer = ctx->network()->addElementWise(*input(ctx, 0).trt_tensor_, *input(ctx, 0).trt_tensor_,
nvinfer1::ElementWiseOperation::kPROD);
CHECK_NULL_RETURN(square_layer);
auto out_tensor = square_layer->getOutput(0);
CHECK_NULL_RETURN(out_tensor);
this->layer_ = square_layer;
ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_},
out_tensors_[0].Name());
return RET_OK;
}
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Square, SquareTensorRT)
} // namespace mindspore::lite

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class SquareTensorRT : public TensorRTOp {
public:
SquareTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
const schema::QuantType &quant_type)
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
~SquareTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
int RunAsTrtOps(TensorRTContext *ctx);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_