!21100 [MSLITE] add MulFusion, StridedSlice, AvgPoolFusion, PadFusion ops support for tensorrt

Merge pull request !21100 from Liu_Xuu/trt_0729_gan
i-robot 2021-07-31 01:40:47 +00:00 committed by Gitee
commit 4e8b34749a
12 changed files with 432 additions and 7 deletions

View File

@@ -26,6 +26,7 @@ int ElementWiseTensorRT::IsSupport(const schema::Primitive *primitive,
{schema::PrimitiveType_PowFusion, nvinfer1::ElementWiseOperation::kPOW},
{schema::PrimitiveType_DivFusion, nvinfer1::ElementWiseOperation::kDIV},
{schema::PrimitiveType_SubFusion, nvinfer1::ElementWiseOperation::kSUB},
{schema::PrimitiveType_MulFusion, nvinfer1::ElementWiseOperation::kPROD},
};
auto iter_op = element_wise_ops.find(this->type_);
if (iter_op != element_wise_ops.end()) {
@@ -151,6 +152,15 @@ nvinfer1::ITensor *ElementWiseTensorRT::AddActivation(nvinfer1::INetworkDefiniti
activation = sub_op->activation_type();
break;
}
case schema::PrimitiveType_MulFusion: {
auto mul_op = op_primitive_->value_as_MulFusion();
if (mul_op == nullptr) {
MS_LOG(ERROR) << "MulFusion convert failed.";
return nullptr;
}
activation = mul_op->activation_type();
break;
}
default:
MS_LOG(DEBUG) << "no activation need for: " << op_name_;
}
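For orientation, the two hunks above first route PrimitiveType_MulFusion to TensorRT's elementwise kPROD operation and then pick up the fused activation from the MulFusion primitive. A minimal sketch of the layer pair this conversion produces (the network handle and input tensors x, y are assumed to exist; illustrative only, not part of the patch):

// Sketch: elementwise multiply with a fused activation, TensorRT API.
nvinfer1::IElementWiseLayer *mul_layer =
    network->addElementWise(*x, *y, nvinfer1::ElementWiseOperation::kPROD);
nvinfer1::ITensor *out = mul_layer->getOutput(0);
// If the primitive carries ActivationType_RELU, an activation layer is appended.
nvinfer1::IActivationLayer *act_layer =
    network->addActivation(*out, nvinfer1::ActivationType::kRELU);
out = act_layer->getOutput(0);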

View File

@@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <functional>
#include "src/delegate/tensorrt/op/pad_tensorrt.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
int PadTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != 2 && in_tensors.size() != 3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
if (in_tensors_[1].Data() == nullptr) {
MS_LOG(ERROR) << "invalid pad tensor for: " << op_name_;
return RET_ERROR;
}
auto pad_primitive = this->GetPrimitive()->value_as_PadFusion();
if (pad_primitive == nullptr) {
MS_LOG(ERROR) << "convert PadFusion failed: " << op_name_;
return RET_ERROR;
}
schema::PaddingMode padding_mode = pad_primitive->padding_mode();
if (padding_mode != schema::PaddingMode::PaddingMode_CONSTANT) {
MS_LOG(ERROR) << "Unsupported padding mode: " << pad_primitive << ", for op: " << op_name_;
return RET_ERROR;
}
constant_value_ = pad_primitive->constant_value();
return RET_OK;
}
int PadTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
mindspore::MSTensor &pad_tensor = in_tensors_[1];
int element_cnt = std::accumulate(pad_tensor.Shape().begin(), pad_tensor.Shape().end(), 1, std::multiplies<int>());
if (element_cnt != tensorrt_in_tensors_[0]->getDimensions().nbDims * 2) {
MS_LOG(ERROR) << "pad tensor cnt is invalid. cnt: " << element_cnt
<< ", input tensor dims cnt: " << tensorrt_in_tensors_[0]->getDimensions().nbDims;
return RET_ERROR;
}
// transpose: NHWC->NCHW
nvinfer1::IShuffleLayer *transpose_layer_in = NHWC2NCHW(network, *tensorrt_in_tensors_[0]);
if (transpose_layer_in == nullptr) {
MS_LOG(ERROR) << "transpose: NHWC->NCHW failed";
return RET_ERROR;
}
transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str());
// TensorRT 6 only supports 2D padding
const int *padding_data = reinterpret_cast<const int *>(in_tensors_[1].Data().get());
nvinfer1::IPaddingLayer *padding_layer = nullptr;
if (element_cnt == index_NHWC_ * 2) {
// NHWC input: padding is only supported on the H and W dimensions
// 0: N_pre, 1: N_post, 2: H_pre, 3: H_post, 4: W_pre, 5: W_post, 6: C_pre, 7: C_post
if (*padding_data != 0 || *(padding_data + 1) != 0 || *(padding_data + 6) != 0 || *(padding_data + 7) != 0) {
MS_LOG(WARNING) << "tensorrt padding only support pad at HW index, unsupported padding value of: " << op_name_;
}
nvinfer1::DimsHW prePadding{*(padding_data + 2), *(padding_data + 4)};
nvinfer1::DimsHW postPadding{*(padding_data + 3), *(padding_data + 5)};
MS_LOG(INFO) << "prePadding: " << *(padding_data + 2) << ", " << *(padding_data + 4);
MS_LOG(INFO) << "postPadding: " << *(padding_data + 3) << ", " << *(padding_data + 5);
padding_layer = network->addPadding(*transpose_layer_in->getOutput(0), prePadding, postPadding);
} else {
MS_LOG(ERROR) << "need check for pad_tensor dims: " << op_name_
<< ", pad_tensor ElementNum: " << pad_tensor.ElementNum();
return RET_ERROR;
}
if (padding_layer == nullptr) {
MS_LOG(ERROR) << "add padding layer failed for " << op_name_;
return RET_ERROR;
}
padding_layer->setName(op_name_.c_str());
// transpose: NCHW->NHWC
nvinfer1::IShuffleLayer *transpose_layer_out = NCHW2NHWC(network, *padding_layer->getOutput(0));
if (transpose_layer_out == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
transpose_layer_out->setName((op_name_ + "_transpose2NHWC").c_str());
this->AddInnerOutTensors(transpose_layer_out->getOutput(0));
return RET_OK;
}
} // namespace mindspore::lite
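As a note on the HW-only restriction above: the flattened NHWC pad tensor carries eight values, and only the four H/W entries can be handed to TensorRT's 2D padding layer. A self-contained sketch of that index mapping (values illustrative, not part of the patch):

// {N_pre, N_post, H_pre, H_post, W_pre, W_post, C_pre, C_post}
const int paddings[8] = {0, 0, 1, 1, 2, 2, 0, 0};
nvinfer1::DimsHW pre_padding{paddings[2], paddings[4]};   // H_pre, W_pre
nvinfer1::DimsHW post_padding{paddings[3], paddings[5]};  // H_post, W_post
// N and C entries must be zero; the converter above only logs a warning otherwise.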

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_
#include <string>
#include <vector>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class PadTensorRT : public TensorRTOp {
public:
PadTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~PadTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
const int index_NHWC_ = 4;
float constant_value_ = 0.0f;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_

View File

@@ -0,0 +1,116 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/op/pool_tensorrt.h"
#include "src/delegate/tensorrt/op/activation_tensorrt.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
int PoolTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
const schema::AvgPoolFusion *pool_primitive = this->GetPrimitive()->value_as_AvgPoolFusion();
if (pool_primitive == nullptr) {
MS_LOG(ERROR) << "convert PoolFusion failed: " << op_name_;
return RET_ERROR;
}
if (tensorrt_in_tensors_.size() != 1) {
MS_LOG(ERROR) << "invalid input tensor size: " << tensorrt_in_tensors_.size();
return RET_ERROR;
}
// transpose: NHWC->NCHW
nvinfer1::IShuffleLayer *transpose_layer_in = NHWC2NCHW(network, *tensorrt_in_tensors_[0]);
if (transpose_layer_in == nullptr) {
MS_LOG(ERROR) << "transpose: NHWC->NCHW failed";
return RET_ERROR;
}
transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str());
// pooling layer
nvinfer1::PoolingType pooling_type = nvinfer1::PoolingType::kAVERAGE;
auto kernel_size = pool_primitive->kernel_size();
if (kernel_size == nullptr) {
MS_LOG(ERROR) << "get kernel size failed: " << op_name_;
return RET_ERROR;
}
std::vector<int64_t> kernel_size_val = std::vector<int64_t>(kernel_size->begin(), kernel_size->end());
nvinfer1::Dims windowSize = lite::ConvertCudaDims(kernel_size_val);
nvinfer1::IPoolingLayer *pooling_layer =
network->addPoolingNd(*transpose_layer_in->getOutput(0), pooling_type, windowSize);
if (pooling_layer == nullptr) {
MS_LOG(ERROR) << "addPoolingNd failed for TensorRT.";
return RET_ERROR;
}
AddParams(pool_primitive, pooling_layer);
pooling_layer->setName(op_name_.c_str());
// add activation
nvinfer1::ILayer *activation_layer = nullptr;
if (pool_primitive->activation_type() == schema::ActivationType::ActivationType_NO_ACTIVATION) {
activation_layer = pooling_layer;
} else {
activation_layer =
ActivationTensorRT::AddActivation(network, pool_primitive->activation_type(), 0, pooling_layer->getOutput(0));
if (activation_layer == nullptr) {
MS_LOG(ERROR) << "addActivation for pool failed";
return RET_ERROR;
}
activation_layer->setName((op_name_ + "_activation").c_str());
}
// transpose: NCHW->NHWC
nvinfer1::IShuffleLayer *transpose_layer_out = NCHW2NHWC(network, *activation_layer->getOutput(0));
if (transpose_layer_out == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
transpose_layer_out->setName((op_name_ + "_transpose2NHWC").c_str());
this->AddInnerOutTensors(transpose_layer_out->getOutput(0));
return RET_OK;
}
void PoolTensorRT::AddParams(const schema::AvgPoolFusion *primitive, nvinfer1::IPoolingLayer *pooling_layer) {
auto stride = primitive->strides();
std::vector<int64_t> stride_val = std::vector<int64_t>(stride->begin(), stride->end());
nvinfer1::Dims stride_dims = ConvertCudaDims(stride_val);
pooling_layer->setStrideNd(stride_dims);
schema::PadMode pad_mode = primitive->pad_mode();
if (pad_mode == schema::PadMode::PadMode_SAME) {
pooling_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
auto padding = primitive->pad();
if (padding != nullptr) {
auto padding_val = std::vector<int64_t>(padding->begin(), padding->end());
nvinfer1::Dims dims{};
dims.nbDims = 2;
dims.d[0] = padding_val[1];
dims.d[1] = padding_val[2];
pooling_layer->setPaddingNd(dims);
}
}
} // namespace mindspore::lite
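For reference, the spatial output extent the pooling layer ends up with under the two pad modes handled in AddParams can be sketched as follows (a simplification assuming no dilation; not part of the patch):

// PadMode_SAME maps to kSAME_UPPER, where TensorRT pads so that
// out = ceil(in / stride); with explicit padding the usual formula applies.
int PoolOutDim(int in, int kernel, int stride, int pad_pre, int pad_post, bool same_mode) {
  if (same_mode) {
    return (in + stride - 1) / stride;  // ceil(in / stride)
  }
  return (in + pad_pre + pad_post - kernel) / stride + 1;
}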

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_
#include <string>
#include <vector>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class PoolTensorRT : public TensorRTOp {
public:
PoolTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~PoolTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
void AddParams(const schema::AvgPoolFusion *primitive, nvinfer1::IPoolingLayer *pooling_layer);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_

View File

@@ -190,8 +190,8 @@ int ShuffleTensorRT::AddReshapeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
MS_LOG(ERROR) << "AddReshapeOp size of in tensort needs check: " << in_tensors_.size();
return RET_ERROR;
}
-mindspore::MSTensor shape_tensor = in_tensors_[1];
-nvinfer1::Dims reshape_dims = ConvertCudaDims(shape_tensor.MutableData(), shape_tensor.ElementNum());
+mindspore::MSTensor &shape_tensor = in_tensors_[1];
+nvinfer1::Dims reshape_dims = ConvertCudaDims(shape_tensor.Data().get(), shape_tensor.ElementNum());
int ret = InferReshapeDims(tensorrt_in_tensors_[0]->getDimensions(), &reshape_dims);
if (ret != RET_OK) {
MS_LOG(ERROR) << "invalid dims for reshape " << op_name_;

View File

@@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/op/slice_tensorrt.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
namespace mindspore::lite {
int SliceTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (in_tensors.size() != 4 && in_tensors.size() != 5) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
if (in_tensors_[1].Data() == nullptr) {
MS_LOG(ERROR) << "invalid pad tensor for: " << op_name_;
return RET_ERROR;
}
return RET_OK;
}
int SliceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
auto slice_primitive = this->GetPrimitive()->value_as_StridedSlice();
if (slice_primitive == nullptr) {
MS_LOG(ERROR) << "convert StridedSlice failed: " << op_name_;
return RET_ERROR;
}
const mindspore::MSTensor &begin = in_tensors_[1];
// mindspore::MSTensor &end = in_tensors_[2];
const mindspore::MSTensor &stride = in_tensors_[3];
nvinfer1::Dims start_dims = lite::ConvertCudaDims(begin.Data().get(), begin.ElementNum());
nvinfer1::Dims size_dims = lite::ConvertCudaDims(out_tensors_[0].Shape());
nvinfer1::Dims stride_dims = lite::ConvertCudaDims(stride.Data().get(), stride.ElementNum());
nvinfer1::ISliceLayer *slice_layer = network->addSlice(*tensorrt_in_tensors_[0], start_dims, size_dims, stride_dims);
if (slice_layer == nullptr) {
MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_;
return RET_ERROR;
}
slice_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0);
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor create failed";
return RET_ERROR;
}
out_tensor->setName(out_tensors_[0].Name().c_str());
this->AddInnerOutTensors(out_tensor);
return RET_OK;
}
} // namespace mindspore::lite
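Because the slice extent is taken directly from the output tensor's shape, the end tensor (in_tensors_[2]) can stay unused. A minimal illustration of the addSlice semantics (1-D, values illustrative; input and network handles assumed to exist):

nvinfer1::Dims start{1, {1}};   // begin index per axis
nvinfer1::Dims size{1, {3}};    // output extent per axis (from the out tensor shape)
nvinfer1::Dims stride{1, {2}};  // step per axis
// Selects elements {1, 3, 5} from an 8-element input tensor.
nvinfer1::ISliceLayer *layer = network->addSlice(*input, start, size, stride);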

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_
#include <string>
#include <vector>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class SliceTensorRT : public TensorRTOp {
public:
SliceTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~SliceTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_

View File

@@ -33,6 +33,9 @@
#include "src/delegate/tensorrt/op/unary_tensorrt.h"
#include "src/delegate/tensorrt/op/matmul_tensorrt.h"
#include "src/delegate/tensorrt/op/scale_tensorrt.h"
#include "src/delegate/tensorrt/op/slice_tensorrt.h"
#include "src/delegate/tensorrt/op/pool_tensorrt.h"
#include "src/delegate/tensorrt/op/pad_tensorrt.h"
namespace mindspore::lite {
bool IsHardwareSupport() {
@@ -76,12 +79,16 @@ int TensorRTDelegate::Init() {
{schema::PrimitiveType_DivFusion, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_PowFusion, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_AddFusion, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_MulFusion, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_Eltwise, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_Transpose, GetTensorRTOp<ShuffleTensorRT>},
{schema::PrimitiveType_ReduceFusion, GetTensorRTOp<ReduceTensorRT>},
{schema::PrimitiveType_Sqrt, GetTensorRTOp<UnaryTensorRT>},
{schema::PrimitiveType_MatMul, GetTensorRTOp<MatMulTensorRT>},
{schema::PrimitiveType_ScaleFusion, GetTensorRTOp<ScaleTensorRT>},
{schema::PrimitiveType_StridedSlice, GetTensorRTOp<SliceTensorRT>},
{schema::PrimitiveType_AvgPoolFusion, GetTensorRTOp<PoolTensorRT>},
{schema::PrimitiveType_PadFusion, GetTensorRTOp<PadTensorRT>},
};
return RET_OK;
}
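The four new entries extend the primitive-to-builder table consulted when the delegate walks the graph; the dispatch itself is presumably a plain map lookup along these lines (member and argument names here are hypothetical, shown for illustration only):

// Hypothetical sketch of consulting the table above for one node.
auto it = op_func_lists_.find(node_primitive->value_type());
if (it == op_func_lists_.end()) {
  return nullptr;  // primitive type not handled by the TensorRT delegate
}
return it->second(node_primitive, in_tensors, out_tensors, node_name);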

View File

@@ -37,7 +37,8 @@ class TensorRTSubGraph : public kernel::Kernel {
trt_specific_weight_nodes_ = {
schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_ReduceFusion, schema::PrimitiveType_Transpose,
schema::PrimitiveType_Gather, schema::PrimitiveType_Reshape, schema::PrimitiveType_PowFusion,
-schema::PrimitiveType_DivFusion, schema::PrimitiveType_MatMul, schema::PrimitiveType_ScaleFusion};
+schema::PrimitiveType_DivFusion, schema::PrimitiveType_MatMul, schema::PrimitiveType_ScaleFusion,
+schema::PrimitiveType_MulFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_PadFusion};
}
~TensorRTSubGraph() override;

View File

@@ -37,11 +37,11 @@ nvinfer1::Dims ConvertCudaDims(int data, size_t size) {
return dims;
}
-nvinfer1::Dims ConvertCudaDims(void *data, size_t size) {
+nvinfer1::Dims ConvertCudaDims(const void *data, int64_t size) {
nvinfer1::Dims dims{};
dims.nbDims = size;
-int *dims_data = reinterpret_cast<int *>(data);
-for (size_t i = 0; i < size; i++) {
+const int *dims_data = reinterpret_cast<const int *>(data);
+for (int i = 0; i < size; i++) {
dims.d[i] = *(dims_data + i);
}
return dims;
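With the const-qualified signature, call sites can pass the read-only buffer from MSTensor::Data() directly, as the shuffle and slice changes above already do; typical usage:

// Read an int32 index/shape tensor straight into nvinfer1::Dims.
const mindspore::MSTensor &begin = in_tensors_[1];
nvinfer1::Dims start_dims = lite::ConvertCudaDims(begin.Data().get(), begin.ElementNum());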

View File

@@ -34,7 +34,7 @@ struct ActivationParams {
nvinfer1::Dims ConvertCudaDims(const std::vector<int64_t> &shape);
// Convert Tensor data to Cuda dims.
-nvinfer1::Dims ConvertCudaDims(void *data, size_t size);
+nvinfer1::Dims ConvertCudaDims(const void *data, int64_t size);
nvinfer1::Dims ConvertCudaDims(int data, size_t size);