forked from mindspore-Ecosystem/mindspore
Add oneslike gathernd and square tensorrt ops
This commit is contained in:
parent
7014e91c14
commit
4424f0d694
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int GatherNDTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) {
|
||||
#if TRT_VERSION_GE(8, 2)
|
||||
if (in_tensors.size() != INPUT_SIZE2) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors[1].DataType() != DataType::kNumberTypeInt32) {
|
||||
MS_LOG(ERROR) << "Gather indices only support Int32";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
#else
|
||||
MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher";
|
||||
return RET_ERROR;
|
||||
#endif
|
||||
}
|
||||
|
||||
int GatherNDTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||
#if TRT_VERSION_GE(8, 2)
|
||||
if (ctx == nullptr || ctx->network() == nullptr) {
|
||||
MS_LOG(ERROR) << "context or network is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ITensorHelper gather_nd_input = input(ctx, 0);
|
||||
ITensorHelper indices_tensor = input(ctx, 1);
|
||||
if (indices_tensor.trt_tensor_->getDimensions().nbDims < 1) {
|
||||
MS_LOG(ERROR) << "addGather failed for TensorRT.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto nbElementWiseDims = gather_nd_input.trt_tensor_->getDimensions().d[0];
|
||||
if (nbElementWiseDims == -1) {
|
||||
nbElementWiseDims = 0;
|
||||
}
|
||||
nvinfer1::IGatherLayer *gather_layer =
|
||||
ctx->network()->addGatherV2(*gather_nd_input.trt_tensor_, *indices_tensor.trt_tensor_, nvinfer1::GatherMode::kND);
|
||||
if (gather_layer == nullptr) {
|
||||
MS_LOG(ERROR) << "addGatherND failed for TensorRT.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
gather_layer->setNbElementWiseDims(nbElementWiseDims);
|
||||
this->layer_ = gather_layer;
|
||||
gather_layer->setName(op_name_.c_str());
|
||||
nvinfer1::ITensor *op_output = gather_layer->getOutput(0);
|
||||
ctx->RegisterTensor(ITensorHelper{op_output, gather_nd_input.format_, gather_nd_input.same_format_},
|
||||
out_tensors_[0].Name());
|
||||
return RET_OK;
|
||||
#else
|
||||
MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher";
|
||||
return RET_ERROR;
|
||||
#endif
|
||||
}
|
||||
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_GatherNd, GatherNDTensorRT)
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class GatherNDTensorRT : public TensorRTOp {
|
||||
public:
|
||||
GatherNDTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~GatherNDTensorRT() override = default;
|
||||
|
||||
int AddInnerOp(TensorRTContext *ctx) override;
|
||||
|
||||
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int OneslikeTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) {
|
||||
if (in_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int OneslikeTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||
if (ctx == nullptr || ctx->network() == nullptr) {
|
||||
MS_LOG(ERROR) << "context or network is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims;
|
||||
if (input_nbdims == -1) {
|
||||
MS_LOG(ERROR) << "oneslike op failed for " << op_name_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int ret = RunAsTrtOps(ctx);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "oneslike op failed for " << op_name_;
|
||||
return ret;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int OneslikeTensorRT::RunAsTrtOps(TensorRTContext *ctx) {
|
||||
if (ctx == nullptr || ctx->network() == nullptr) {
|
||||
MS_LOG(ERROR) << "context or network is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto const_zero = ctx->ConvertTo1DTensor(std::vector<float>(input(ctx, 0).trt_tensor_->getDimensions().nbDims, 0.f));
|
||||
CHECK_NULL_RETURN(const_zero);
|
||||
auto const_one = ctx->ConvertTo1DTensor(std::vector<float>(input(ctx, 0).trt_tensor_->getDimensions().nbDims, 1.f));
|
||||
CHECK_NULL_RETURN(const_one);
|
||||
auto prod_tensor = ctx->network()
|
||||
->addElementWise(*input(ctx, 0).trt_tensor_, *const_zero, nvinfer1::ElementWiseOperation::kPROD)
|
||||
->getOutput(0);
|
||||
CHECK_NULL_RETURN(prod_tensor);
|
||||
auto oneslike_layer = ctx->network()->addElementWise(*prod_tensor, *const_one, nvinfer1::ElementWiseOperation::kSUM);
|
||||
CHECK_NULL_RETURN(oneslike_layer);
|
||||
auto out_tensor = oneslike_layer->getOutput(0);
|
||||
CHECK_NULL_RETURN(out_tensor);
|
||||
this->layer_ = oneslike_layer;
|
||||
ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_},
|
||||
out_tensors_[0].Name());
|
||||
return RET_OK;
|
||||
}
|
||||
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_OnesLike, OneslikeTensorRT)
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class OneslikeTensorRT : public TensorRTOp {
|
||||
public:
|
||||
OneslikeTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~OneslikeTensorRT() override = default;
|
||||
|
||||
int AddInnerOp(TensorRTContext *ctx) override;
|
||||
|
||||
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) override;
|
||||
|
||||
private:
|
||||
int RunAsTrtOps(TensorRTContext *ctx);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/extendrt/delegate/tensorrt/op/square_tensorrt.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int SquareTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) {
|
||||
if (in_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (out_tensors.size() != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SquareTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
||||
if (ctx == nullptr || ctx->network() == nullptr) {
|
||||
MS_LOG(ERROR) << "context or network is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto square_op = op_primitive_->value_as_Square();
|
||||
CHECK_NULL_RETURN(square_op);
|
||||
int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims;
|
||||
if (input_nbdims == -1) {
|
||||
MS_LOG(ERROR) << "square failed for " << op_name_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int ret = RunAsTrtOps(ctx);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "square failed for " << op_name_;
|
||||
return ret;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int SquareTensorRT::RunAsTrtOps(TensorRTContext *ctx) {
|
||||
if (ctx == nullptr || ctx->network() == nullptr) {
|
||||
MS_LOG(ERROR) << "context or network is invalid";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto square_layer = ctx->network()->addElementWise(*input(ctx, 0).trt_tensor_, *input(ctx, 0).trt_tensor_,
|
||||
nvinfer1::ElementWiseOperation::kPROD);
|
||||
CHECK_NULL_RETURN(square_layer);
|
||||
auto out_tensor = square_layer->getOutput(0);
|
||||
CHECK_NULL_RETURN(out_tensor);
|
||||
this->layer_ = square_layer;
|
||||
ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_},
|
||||
out_tensors_[0].Name());
|
||||
return RET_OK;
|
||||
}
|
||||
REGISTER_TENSORRT_CREATOR(schema::PrimitiveType_Square, SquareTensorRT)
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class SquareTensorRT : public TensorRTOp {
|
||||
public:
|
||||
SquareTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name,
|
||||
const schema::QuantType &quant_type)
|
||||
: TensorRTOp(primitive, in_tensors, out_tensors, name, quant_type) {}
|
||||
|
||||
~SquareTensorRT() override = default;
|
||||
|
||||
int AddInnerOp(TensorRTContext *ctx) override;
|
||||
|
||||
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
|
||||
const std::vector<mindspore::MSTensor> &out_tensors) override;
|
||||
|
||||
private:
|
||||
int RunAsTrtOps(TensorRTContext *ctx);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_
|
Loading…
Reference in New Issue