[MSLITE] add matmul, fully connected optimization for tensorrt 0209_01

This commit is contained in:
Liu_Xuu 2022-01-24 16:53:46 +08:00
parent 9c8a4279c8
commit 8b47d8c515
14 changed files with 420 additions and 189 deletions

View File

@@ -69,7 +69,7 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
// transpose weight
const mindspore::MSTensor &weight_tensor = in_tensors_[1];
nvinfer1::Weights kernelWeights = lite::TransposeWeight(weight_tensor, &pack_weight_);
nvinfer1::Weights kernelWeights = lite::TransposeWeight4D(weight_tensor, &pack_weight_);
// conv
int nbOutputMaps = weight_tensor.Shape()[0];

View File

@@ -66,7 +66,7 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
// transpose weight
const mindspore::MSTensor &weight_tensor = in_tensors_[1];
nvinfer1::Weights kernelWeights = lite::TransposeWeight(weight_tensor, &pack_weight_);
nvinfer1::Weights kernelWeights = lite::TransposeWeight4D(weight_tensor, &pack_weight_);
// deconv basic params
int nbOutputMaps = weight_tensor.Shape()[0];

View File

@@ -0,0 +1,106 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/op/fullyconnected_tensorrt.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "src/delegate/tensorrt/op/activation_tensorrt.h"
namespace mindspore::lite {
constexpr int BIAS_INDEX = 2;
int FullyConnectedTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor with unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != INPUT_SIZE3) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int FullyConnectedTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
int axis = -1;
if (type_ == schema::PrimitiveType_FullConnection) {
auto primitive = this->GetPrimitive()->value_as_FullConnection();
if (primitive == nullptr) {
MS_LOG(ERROR) << "convert to primitive FullConnection failed for " << op_name_;
return RET_ERROR;
}
activation_ = primitive->activation_type();
axis = primitive->axis();
}
if (axis < 0 || axis >= out_tensors_[0].Shape().size()) {
MS_LOG(ERROR) << "axis: " << axis << " is invalid for " << op_name_;
return RET_ERROR;
}
ITensorHelper fc_input;
auto ret = PreprocessInputs(network, &fc_input);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PreprocessInputs failed for " << op_name_;
return ret;
}
auto kernel_weight = ConvertWeight(in_tensors_[1]);
auto bias_weight = ConvertWeight(in_tensors_[BIAS_INDEX]);
nvinfer1::IFullyConnectedLayer *fc_layer =
network->addFullyConnected(*(fc_input.trt_tensor_), out_tensors_[0].Shape()[axis], kernel_weight, bias_weight);
if (fc_layer == nullptr) {
MS_LOG(ERROR) << "addFullyConnected failed for " << op_name_;
return RET_ERROR;
}
fc_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = fc_layer->getOutput(0);
if (out_tensor->getDimensions().nbDims != out_tensors_[0].Shape().size()) {
std::vector<int64_t> squeeze_dim(out_tensors_[0].Shape());
squeeze_dim[0] = out_tensor->getDimensions().d[0] == -1 ? -1 : squeeze_dim[0];
out_tensor = Reshape(network, out_tensor, squeeze_dim);
}
// add activation
if (activation_ != schema::ActivationType::ActivationType_NO_ACTIVATION) {
nvinfer1::ILayer *activation_layer = ActivationTensorRT::AddActivation(network, activation_, 0, 0, 0, out_tensor);
if (activation_layer == nullptr) {
MS_LOG(ERROR) << "addActivation for fully connected failed";
return RET_ERROR;
}
activation_layer->setName((op_name_ + "_activation").c_str());
out_tensor = activation_layer->getOutput(0);
}
out_tensor->setName((op_name_ + "_output").c_str());
MS_LOG(DEBUG) << "output " << GetTensorFormat(out_tensor);
this->AddInnerOutTensors(ITensorHelper{out_tensor, fc_input.format_});
return RET_OK;
}
int FullyConnectedTensorRT::PreprocessInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *fc_input) {
auto ret = PreprocessInputs2SameDim(network, tensorrt_in_tensors_[0], fc_input);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim failed for " << op_name_;
return ret;
}
auto origin_dims = fc_input->trt_tensor_->getDimensions();
if (origin_dims.nbDims != DIMENSION_4D) {
std::vector<int64_t> expand_dim(origin_dims.d, origin_dims.d + origin_dims.nbDims);
for (int i = 0; i < DIMENSION_4D - origin_dims.nbDims; i++) {
expand_dim.push_back(1);
}
fc_input->trt_tensor_ = Reshape(network, fc_input->trt_tensor_, expand_dim);
}
return RET_OK;
}
} // namespace mindspore::lite
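For reference, PreprocessInputs pads lower-rank inputs up to four dimensions because IFullyConnectedLayer flattens the last three input dimensions and emits a [N, K, 1, 1] result, which AddInnerOp then reshapes back to the expected output shape. A minimal standalone sketch of that padding rule (PadTo4D is a hypothetical name, not part of this commit):

#include <cstdint>
#include <vector>

// Append trailing 1s until the shape is 4-D, mirroring the expand_dim loop
// in FullyConnectedTensorRT::PreprocessInputs, e.g. {8, 256} -> {8, 256, 1, 1}.
std::vector<int64_t> PadTo4D(std::vector<int64_t> dims) {
  while (dims.size() < 4) {
    dims.push_back(1);
  }
  return dims;
}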

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_
#include <string>
#include <vector>
#include <map>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
namespace mindspore::lite {
class FullyConnectedTensorRT : public TensorRTOp {
public:
FullyConnectedTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~FullyConnectedTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
int PreprocessInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *fc_input);
schema::ActivationType activation_{schema::ActivationType::ActivationType_NO_ACTIVATION};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_

View File

@@ -162,7 +162,10 @@ int LSTMTensorRT::AddLSTMLayers() {
layer_weights[0].max_seq_size_ = max_sequence_size;
int ret = ParseLSTMCellInputs(i, hidden_init, cell_init, layer_input_states, &input_weight_offset,
&state_weight_offset, &bias_offset, layer_weights, next_state);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseLSTMCellInputs failed for " << op_name_;
return RET_ERROR;
}
data_out = AddLSTMCell(layer_input_states, layer_weights, &next_state);
hidden_outputs.push_back(next_state.hidden_);
cell_outputs.push_back(next_state.cell_);

View File

@@ -17,9 +17,15 @@
#include "src/delegate/tensorrt/op/matmul_tensorrt.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "src/delegate/tensorrt/op/activation_tensorrt.h"
namespace mindspore::lite {
constexpr int BIAS_INDEX = 2;
MatMulTensorRT::~MatMulTensorRT() {
if (weight_ptr_ != nullptr) {
free(weight_ptr_);
weight_ptr_ = nullptr;
}
}
int MatMulTensorRT::IsSupport(const mindspore::schema::Primitive *primitive,
const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
@@ -48,58 +54,22 @@ int MatMulTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
transpose_a_ = primitive->transpose_a() ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
transpose_b_ = primitive->transpose_b() ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
activation_ = primitive->activation_type();
} else if (type_ == schema::PrimitiveType_FullConnection) {
transpose_a_ = nvinfer1::MatrixOperation::kNONE;
transpose_b_ = nvinfer1::MatrixOperation::kTRANSPOSE;
}
ITensorHelper matmul_a;
ITensorHelper matmul_b;
int ret = PreprocessInputs(network, &matmul_a, &matmul_b);
if (ret != RET_OK || matmul_a.trt_tensor_ == nullptr || matmul_b.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs matmul failed for " << op_name_;
nvinfer1::ITensor *out_tensor = nullptr;
if (in_tensors_.size() == INPUT_SIZE3 && in_tensors_[1].Data() != nullptr &&
in_tensors_[BIAS_INDEX].Data() != nullptr && transpose_a_ == nvinfer1::MatrixOperation::kNONE &&
in_tensors_[1].Shape().size() == DIMENSION_2D &&
(in_tensors_[0].Shape().size() == DIMENSION_2D || in_tensors_[0].Shape().size() == DIMENSION_4D)) {
MS_LOG(DEBUG) << "use fully connected instead of matmul for " << op_name_;
out_tensor = AddAsFullConnect(network);
} else {
out_tensor = AddAsMatmul(network);
}
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "add matmul failed for " << op_name_;
return RET_ERROR;
}
MS_LOG(DEBUG) << "matmul input a " << GetTensorFormat(matmul_a);
MS_LOG(DEBUG) << "matmul input b " << GetTensorFormat(matmul_b);
auto matmul_layer =
network->addMatrixMultiply(*matmul_a.trt_tensor_, transpose_a_, *matmul_b.trt_tensor_, transpose_b_);
if (matmul_layer == nullptr) {
MS_LOG(ERROR) << "addMatrixMultiply failed for " << op_name_;
return RET_ERROR;
}
matmul_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = matmul_layer->getOutput(0);
tensor_name_map_[matmul_layer->getOutput(0)->getName()] = op_name_;
if (in_tensors_.size() == BIAS_INDEX + 1) {
nvinfer1::ITensor *bias = nullptr;
if (in_tensors_[BIAS_INDEX].Shape().size() < static_cast<size_t>(out_tensor->getDimensions().nbDims)) {
bias =
ConvertTensorWithExpandDims(network, in_tensors_[BIAS_INDEX], out_tensor->getDimensions().nbDims, op_name_);
} else if (in_tensors_[BIAS_INDEX].Shape().size() == static_cast<size_t>(out_tensor->getDimensions().nbDims)) {
bias = ConvertConstantTensor(network, in_tensors_[BIAS_INDEX], op_name_);
} else {
MS_LOG(ERROR) << "input tensor shape is invalid for " << op_name_;
return RET_ERROR;
}
if (bias == nullptr) {
MS_LOG(ERROR) << "create constant bias tensor failed for " << op_name_;
return RET_ERROR;
}
auto bias_layer = network->addElementWise(*matmul_layer->getOutput(0), *bias, nvinfer1::ElementWiseOperation::kSUM);
if (bias_layer == nullptr) {
MS_LOG(ERROR) << "add bias add layer failed for " << op_name_;
return RET_ERROR;
}
auto bias_layer_name = op_name_ + "_bias";
bias_layer->setName(bias_layer_name.c_str());
out_tensor = bias_layer->getOutput(0);
}
// add activation
if (activation_ != schema::ActivationType::ActivationType_NO_ACTIVATION) {
nvinfer1::ILayer *activation_layer = ActivationTensorRT::AddActivation(network, activation_, 0, 0, 0, out_tensor);
@@ -117,8 +87,8 @@ int MatMulTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_OK;
}
int MatMulTensorRT::PreprocessInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *matmul_a,
ITensorHelper *matmul_b) {
int MatMulTensorRT::PreprocessMatMulInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *matmul_a,
ITensorHelper *matmul_b) {
int ret;
if (tensorrt_in_tensors_.size() == INPUT_SIZE2) {
int a_index =
@@ -179,4 +149,110 @@ int MatMulTensorRT::PreprocessInputs(nvinfer1::INetworkDefinition *network, ITen
}
return RET_OK;
}
nvinfer1::ITensor *MatMulTensorRT::AddAsMatmul(nvinfer1::INetworkDefinition *network) {
ITensorHelper matmul_a;
ITensorHelper matmul_b;
int ret = PreprocessMatMulInputs(network, &matmul_a, &matmul_b);
if (ret != RET_OK || matmul_a.trt_tensor_ == nullptr || matmul_b.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessMatMulInputs matmul failed for " << op_name_;
return nullptr;
}
MS_LOG(DEBUG) << "matmul input a " << GetTensorFormat(matmul_a);
MS_LOG(DEBUG) << "matmul input b " << GetTensorFormat(matmul_b);
auto matmul_layer =
network->addMatrixMultiply(*matmul_a.trt_tensor_, transpose_a_, *matmul_b.trt_tensor_, transpose_b_);
if (matmul_layer == nullptr) {
MS_LOG(ERROR) << "addMatrixMultiply failed for " << op_name_;
return nullptr;
}
matmul_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = matmul_layer->getOutput(0);
tensor_name_map_[matmul_layer->getOutput(0)->getName()] = op_name_;
if (in_tensors_.size() == BIAS_INDEX + 1) {
nvinfer1::ITensor *bias = nullptr;
if (in_tensors_[BIAS_INDEX].Shape().size() < static_cast<size_t>(out_tensor->getDimensions().nbDims)) {
bias =
ConvertTensorWithExpandDims(network, in_tensors_[BIAS_INDEX], out_tensor->getDimensions().nbDims, op_name_);
} else if (in_tensors_[BIAS_INDEX].Shape().size() == static_cast<size_t>(out_tensor->getDimensions().nbDims)) {
bias = ConvertConstantTensor(network, in_tensors_[BIAS_INDEX], op_name_);
} else {
MS_LOG(ERROR) << "input tensor shape is invalid for " << op_name_;
return nullptr;
}
if (bias == nullptr) {
MS_LOG(ERROR) << "create constant bias tensor failed for " << op_name_;
return nullptr;
}
auto bias_layer = network->addElementWise(*matmul_layer->getOutput(0), *bias, nvinfer1::ElementWiseOperation::kSUM);
if (bias_layer == nullptr) {
MS_LOG(ERROR) << "add bias add layer failed for " << op_name_;
return nullptr;
}
auto bias_layer_name = op_name_ + "_bias";
bias_layer->setName(bias_layer_name.c_str());
out_tensor = bias_layer->getOutput(0);
}
return out_tensor;
}
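// Illustrative sketch (not from this commit): the bias path above broadcasts
// via addElementWise(kSUM); a lower-rank bias is first expanded to the output
// rank (assumed semantics of ConvertTensorWithExpandDims: leading 1s are
// prepended, which TensorRT then broadcasts over).
std::vector<int64_t> ExpandBiasDims(std::vector<int64_t> bias_shape, size_t target_rank) {
  std::vector<int64_t> expanded(target_rank - bias_shape.size(), 1);  // e.g. {K} -> {1, 1, K}
  expanded.insert(expanded.end(), bias_shape.begin(), bias_shape.end());
  return expanded;
}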
nvinfer1::ITensor *MatMulTensorRT::AddAsFullConnect(nvinfer1::INetworkDefinition *network) {
nvinfer1::Weights weight;
nvinfer1::Weights bias = ConvertWeight(in_tensors_[BIAS_INDEX]);
nvinfer1::ITensor *input_a = tensorrt_in_tensors_[0].trt_tensor_;
out_format_ = tensorrt_in_tensors_[0].format_;
if (input_a->getDimensions().nbDims != DIMENSION_4D) {
nvinfer1::Dims in_dims(input_a->getDimensions());
in_dims.nbDims = DIMENSION_4D;
for (int i = input_a->getDimensions().nbDims; i < DIMENSION_4D; i++) {
in_dims.d[i] = 1;
}
input_a = Reshape(network, input_a, in_dims);
if (input_a == nullptr) {
MS_LOG(ERROR) << "reshape input failed for " << op_name_;
return nullptr;
}
MS_LOG(DEBUG) << "full connect expand input a to " << GetTensorFormat(input_a);
} else {
ITensorHelper tmp_input;
int ret = PreprocessInputs2SameDim(network, tensorrt_in_tensors_[0], &tmp_input);
if (ret != RET_OK || tmp_input.trt_tensor_ == nullptr) {
MS_LOG(ERROR) << "PreprocessInputs2SameDim failed for " << op_name_;
return nullptr;
}
input_a = tmp_input.trt_tensor_;
out_format_ = tmp_input.format_;
MS_LOG(DEBUG) << "full connect preprocess input a to " << GetTensorFormat(tmp_input);
}
if (transpose_b_ == nvinfer1::MatrixOperation::kNONE) {
// transpose weight
weight = TransposeWeight2D(in_tensors_[1], &weight_ptr_);
if (weight.values == nullptr || weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "TransposeWeight2D input weight failed for " << op_name_;
return nullptr;
}
} else {
weight = ConvertWeight(in_tensors_[1]);
}
int output_cnt = in_tensors_[BIAS_INDEX].Shape()[0];
auto fc_layer = network->addFullyConnected(*input_a, output_cnt, weight, bias);
if (fc_layer == nullptr) {
MS_LOG(ERROR) << "add fully connected layer failed for " << op_name_;
return nullptr;
}
fc_layer->setName((op_name_ + "_fullyconnected").c_str());
nvinfer1::ITensor *out_tensor = fc_layer->getOutput(0);
if (out_tensor->getDimensions().nbDims != out_tensors_[0].Shape().size()) {
std::vector<int64_t> out_dims(out_tensors_[0].Shape());
out_dims[0] = out_tensor->getDimensions().d[0];
out_tensor = Reshape(network, out_tensor, out_dims);
}
return out_tensor;
}
} // namespace mindspore::lite
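The eligibility test at the top of AddInnerOp decides between the two lowering paths; restated as a hedged standalone predicate (the helper name is illustrative, the real check is inlined above):

// MatMul lowers to IFullyConnectedLayer only when weight and bias are
// constant, A is not transposed, the weight is 2-D, and the activation
// input is 2-D or 4-D; everything else goes through addMatrixMultiply.
bool CanUseFullyConnected(size_t input_count, bool const_weight, bool const_bias,
                          bool transpose_a, size_t weight_rank, size_t input_rank) {
  return input_count == 3 && const_weight && const_bias && !transpose_a &&
         weight_rank == 2 && (input_rank == 2 || input_rank == 4);
}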

View File

@@ -28,7 +28,7 @@ class MatMulTensorRT : public TensorRTOp {
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~MatMulTensorRT() override = default;
~MatMulTensorRT() override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
@@ -36,11 +36,17 @@ class MatMulTensorRT : public TensorRTOp {
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
private:
int PreprocessInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *matmul_a, ITensorHelper *matmul_b);
int PreprocessMatMulInputs(nvinfer1::INetworkDefinition *network, ITensorHelper *matmul_a, ITensorHelper *matmul_b);
nvinfer1::ITensor *AddAsMatmul(nvinfer1::INetworkDefinition *network);
nvinfer1::ITensor *AddAsFullConnect(nvinfer1::INetworkDefinition *network);
nvinfer1::MatrixOperation transpose_a_ = nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation transpose_b_ = nvinfer1::MatrixOperation::kNONE;
Format out_format_;
schema::ActivationType activation_{schema::ActivationType::ActivationType_NO_ACTIVATION};
void *weight_ptr_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_

View File

@@ -45,7 +45,7 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
bool keep_dims = reduce_op->keep_dims();
out_format_ = tensorrt_in_tensors_[0].format_;
nvinfer1::ITensor *shuffler_input = tensorrt_in_tensors_[0].trt_tensor_;
nvinfer1::ITensor *reduce_input = tensorrt_in_tensors_[0].trt_tensor_;
MS_LOG(DEBUG) << "origin input " << GetTensorFormat(tensorrt_in_tensors_[0]);
if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
!SameDims(tensorrt_in_tensors_[0].trt_tensor_->getDimensions(), in_tensors_[0].Shape())) {
@@ -57,31 +57,23 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_ERROR;
}
transpose_layer->setName((op_name_ + "_transpose_in").c_str());
shuffler_input = transpose_layer->getOutput(0);
reduce_input = transpose_layer->getOutput(0);
out_format_ = Format::NHWC;
} else if (tensorrt_in_tensors_[0].format_ == Format::NHWC) {
// NHWC->NCHW
nvinfer1::IShuffleLayer *transpose_layer = NHWC2NCHW(network, *tensorrt_in_tensors_[0].trt_tensor_);
if (transpose_layer == nullptr) {
MS_LOG(ERROR) << "create transpose layer failed for " << op_name_;
return RET_ERROR;
}
transpose_layer->setName((op_name_ + "_transpose_in").c_str());
reduce_input = transpose_layer->getOutput(0);
out_format_ = Format::NCHW;
} else {
MS_LOG(WARNING) << "input tensor format needs check: " << op_name_;
}
}
nvinfer1::ITensor *reduce_input = shuffler_input;
// 4-D inputs support reduce at any axis, so expand lower-rank inputs first
if (reduce_input->getDimensions().nbDims < DIMENSION_4D) {
nvinfer1::IShuffleLayer *unsqueeze_layer = network->addShuffle(*reduce_input);
if (unsqueeze_layer == nullptr) {
MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
return RET_ERROR;
}
unsqueeze_layer->setName((op_name_ + "_unsqueeze4dims").c_str());
nvinfer1::Dims unsqueeze_dims = reduce_input->getDimensions();
for (int i = unsqueeze_dims.nbDims; i < DIMENSION_4D; i++) {
unsqueeze_dims.d[i] = 1;
}
unsqueeze_dims.nbDims = DIMENSION_4D;
unsqueeze_layer->setReshapeDimensions(unsqueeze_dims);
reduce_input = unsqueeze_layer->getOutput(0);
}
MS_LOG(DEBUG) << "after transpose and expand dims " << GetTensorFormat(reduce_input, out_format_, true);
MS_LOG(DEBUG) << "after transpose input " << GetTensorFormat(reduce_input, out_format_, true);
uint32_t reduceAxis = GetAxis();
nvinfer1::IReduceLayer *layer =
@@ -93,27 +85,6 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = layer->getOutput(0);
if (in_tensors_[0].Shape().size() != DIMENSION_4D) {
// squeeze to origin dims
nvinfer1::IShuffleLayer *squeeze_layer = network->addShuffle(*layer->getOutput(0));
if (squeeze_layer == nullptr) {
MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
return RET_ERROR;
}
squeeze_layer->setName((op_name_ + "_squeeze").c_str());
nvinfer1::Dims squeeze_dims = ConvertCudaDims(out_tensors_[0].Shape());
if (squeeze_dims.nbDims == -1) {
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return RET_ERROR;
}
for (int i = 0; i < squeeze_dims.nbDims; i++) {
if (layer->getOutput(0)->getDimensions().d[i] == -1) {
squeeze_dims.d[i] = 0;
}
}
squeeze_layer->setReshapeDimensions(squeeze_dims);
out_tensor = squeeze_layer->getOutput(0);
}
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "addReduce output tensor create failed for TensorRT.";
return RET_ERROR;
@@ -123,6 +94,7 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
MS_LOG(DEBUG) << "output " << GetTensorFormat(tensorrt_out_tensors_[0]);
return RET_OK;
}
uint32_t ReduceTensorRT::GetAxis() {
// axis
uint32_t reduceAxis = 0;
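GetAxis is truncated here, but its uint32_t return type reflects the TensorRT contract: network->addReduce takes the reduction axes as a bitmask, one bit per input dimension. A hedged sketch of that folding (assumed shape; the real implementation derives the axes from the op's attributes):

#include <cstdint>
#include <vector>

// Fold axis indices into the bitmask form nvinfer1::IReduceLayer expects,
// e.g. axes {1, 2} -> 0b0110.
uint32_t AxesToBitmask(const std::vector<int> &axes) {
  uint32_t mask = 0;
  for (int axis : axes) {
    mask |= 1u << axis;
  }
  return mask;
}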

View File

@@ -138,7 +138,7 @@ int ScaleTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
nvinfer1::ITensor *ScaleTensorRT::PreProcessInputTensor(nvinfer1::INetworkDefinition *network) {
nvinfer1::ITensor *scale_in_tensor = tensorrt_in_tensors_[0].trt_tensor_;
if (in_tensors_[0].Shape().size() < INPUT_SIZE4) {
if (in_tensors_[0].Shape().size() < DIMENSION_4D) {
// unsqueeze input Itensor to 4 dims
scale_in_tensor = AddUnsqueezeOp(network);
if (scale_in_tensor == nullptr) {
@@ -202,12 +202,6 @@ nvinfer1::ScaleMode ScaleTensorRT::GetScaleMode(int64_t axis) {
}
nvinfer1::ITensor *ScaleTensorRT::AddUnsqueezeOp(nvinfer1::INetworkDefinition *network) {
nvinfer1::IShuffleLayer *unsqueeze_layer = network->addShuffle(*this->tensorrt_in_tensors_[0].trt_tensor_);
if (unsqueeze_layer == nullptr) {
MS_LOG(ERROR) << "addShuffle failed for: " << op_name_;
return nullptr;
}
unsqueeze_layer->setName((op_name_ + "_unsqueeze").c_str());
auto unsqueeze_shape = ConvertMSShape(tensorrt_in_tensors_[0].trt_tensor_->getDimensions());
size_t unsqueeze_size = DIMENSION_4D - unsqueeze_shape.size();
for (size_t i = 0; i < unsqueeze_size; i++) {
@@ -223,24 +217,16 @@ nvinfer1::ITensor *ScaleTensorRT::AddUnsqueezeOp(nvinfer1::INetworkDefinition *n
MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_;
return nullptr;
}
unsqueeze_layer->setReshapeDimensions(unsqueeze_dims);
return unsqueeze_layer->getOutput(0);
return Reshape(network, tensorrt_in_tensors_[0].trt_tensor_, unsqueeze_shape);
}
nvinfer1::ITensor *ScaleTensorRT::AddSqueezeOp(nvinfer1::ITensor *in_tensor, nvinfer1::INetworkDefinition *network) {
nvinfer1::IShuffleLayer *squeeze_layer = network->addShuffle(*in_tensor);
if (squeeze_layer == nullptr) {
MS_LOG(ERROR) << "addShuffle failed for: " << op_name_;
return nullptr;
}
squeeze_layer->setName((op_name_ + "_squeeze").c_str());
nvinfer1::Dims squeeze_dims;
squeeze_dims.nbDims = out_tensors_[0].Shape().size();
for (int i = 0; i < squeeze_dims.nbDims; i++) {
squeeze_dims.d[i] = in_tensor->getDimensions().d[i] == -1 ? 0 : in_tensor->getDimensions().d[i];
}
MS_LOG(DEBUG) << "squeeze_dims cnt for scale: " << squeeze_dims.nbDims;
squeeze_layer->setReshapeDimensions(squeeze_dims);
return squeeze_layer->getOutput(0);
return Reshape(network, in_tensor, squeeze_dims);
}
} // namespace mindspore::lite

View File

@@ -43,6 +43,7 @@
#include "src/delegate/tensorrt/op/reducescatter_tensorrt.h"
#include "src/delegate/tensorrt/op/allgather_tensorrt.h"
#include "src/delegate/tensorrt/op/lstm_tensorrt.h"
#include "src/delegate/tensorrt/op/fullyconnected_tensorrt.h"
namespace mindspore::lite {
TensorRTDelegate::~TensorRTDelegate() {
@@ -102,7 +103,7 @@ Status TensorRTDelegate::Init() {
{schema::PrimitiveType_Gather, GetTensorRTOp<GatherTensorRT>},
{schema::PrimitiveType_LSTM, GetTensorRTOp<LSTMTensorRT>},
{schema::PrimitiveType_MatMulFusion, GetTensorRTOp<MatMulTensorRT>},
{schema::PrimitiveType_FullConnection, GetTensorRTOp<MatMulTensorRT>},
{schema::PrimitiveType_FullConnection, GetTensorRTOp<FullyConnectedTensorRT>},
{schema::PrimitiveType_AvgPoolFusion, GetTensorRTOp<PoolTensorRT>},
{schema::PrimitiveType_MaxPoolFusion, GetTensorRTOp<PoolTensorRT>},
{schema::PrimitiveType_PadFusion, GetTensorRTOp<PadTensorRT>},
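Each map value is an op-factory function pointer; a hedged sketch of how a GetTensorRTOp-style factory is typically shaped (assumed, not shown in this diff; the real template lives in the delegate headers):

// Hypothetical sketch: construct the op, verify support, return nullptr on failure.
template <class T>
TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive,
                          const std::vector<mindspore::MSTensor> &in_tensors,
                          const std::vector<mindspore::MSTensor> &out_tensors,
                          const std::string &name) {
  auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name);
  if (op == nullptr) {
    return nullptr;
  }
  // Reject ops whose inputs/shapes the TensorRT implementation cannot handle.
  if (op->IsSupport(primitive, in_tensors, out_tensors) != RET_OK) {
    delete op;
    return nullptr;
  }
  return op;
}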

View File

@@ -96,7 +96,7 @@ int TensorRTSubGraph::Init(cudaStream_t stream) {
int TensorRTSubGraph::GetInt8DynamicRange() {
if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(WARNING) << "no int8 mode, no need for dynamic range.";
MS_LOG(INFO) << "no int8 mode, no need for dynamic range.";
}
// input tensor
for (size_t i = 0; i < inputs_.size(); i++) {
@@ -185,33 +185,6 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
return RET_OK;
}
bool TensorRTSubGraph::SupportFP16() {
int deviceCnt = 0;
cudaError ret = cudaGetDeviceCount(&deviceCnt);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cudaGetDeviceCount failed.";
return false;
}
std::vector<std::string> supportFP16_versions{"5.3", "6.0", "6.2", "7.0", "7.2", "7.5", "8.0", "8.6"};
cudaDeviceProp prop;
std::string version;
for (int dev = 0; dev < deviceCnt; dev++) {
ret = cudaGetDeviceProperties(&prop, dev);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cuDeviceGetAttribute failed.";
return false;
}
version = std::to_string(prop.major) + "." + std::to_string(prop.minor);
if (std::find(supportFP16_versions.begin(), supportFP16_versions.end(), version) != supportFP16_versions.end()) {
MS_LOG(INFO) << "cuda device version is: " << version << ", support FP16, set enable FP16 tag successful";
return true;
}
}
MS_LOG(WARNING) << "cuda device version is: " << version << ", don't support FP16, set enable FP16 tag failed";
return false;
}
bool TensorRTSubGraph::IsInt8Mode() {
bool isInt8Mode = false;
for (auto cur_op : all_ops_) {
@@ -242,7 +215,7 @@ bool TensorRTSubGraph::IsInt8Mode() {
void TensorRTSubGraph::SetInt8LayerPresion() {
if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(WARNING) << "no int8 mode, no need for layer precision.";
MS_LOG(INFO) << "no int8 mode, no need for layer precision.";
return;
}
@@ -373,9 +346,6 @@ nvinfer1::Dims TensorRTSubGraph::ParseInputDimsProfile(const mindspore::MSTensor
return input_dims;
}
nvinfer1::Dims input_dims_opt = ConvertCudaDims(in_tensor.Shape());
if (input_batchsize_index_ != -1) {
input_dims_opt.d[input_batchsize_index_] = 1;
}
if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims_opt)) {
MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name();
return input_dims;

View File

@@ -96,8 +96,6 @@ class TensorRTSubGraph : public kernel::Kernel {
int GetTensorName(TensorRTOp *cur_op);
bool SupportFP16();
nvinfer1::ITensor *SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor);
ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor);

View File

@@ -19,6 +19,7 @@
#include <map>
#include <numeric>
#include <functional>
#include "src/delegate/tensorrt/distribution/distribution_collective.h"
namespace mindspore::lite {
nvinfer1::Dims ConvertCudaDims(int data, size_t size) {
@@ -236,53 +237,100 @@ nvinfer1::ITensor *ConvertTensorWithExpandDims(nvinfer1::INetworkDefinition *net
return constant_tensor->getOutput(0);
}
nvinfer1::Weights TransposeWeight(const mindspore::MSTensor &ms_tensor, void **pack_weight) {
nvinfer1::Weights weights{};
MS_LOG(DEBUG) << "ms_tensor.DataType(): " << static_cast<int>(ms_tensor.DataType());
if (ms_tensor.DataType() == DataType::kNumberTypeFloat16) {
weights.type = nvinfer1::DataType::kHALF;
weights.count = ms_tensor.ElementNum();
void *pack_weight_tmp = malloc(ms_tensor.DataSize());
if (pack_weight_tmp == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return weights;
}
MS_ASSERT(ms_tensor.Data());
auto weight_shape = ms_tensor.Shape();
PackNHWCToNCHWFp16(ms_tensor.Data().get(), pack_weight_tmp, weight_shape[0], weight_shape[1] * weight_shape[2],
weight_shape[3], 0, 0);
*pack_weight = pack_weight_tmp;
weights.values = pack_weight_tmp;
return weights;
} else {
return TransposeWeightFP32(ms_tensor, pack_weight);
}
}
nvinfer1::Weights TransposeWeightFP32(const mindspore::MSTensor &ms_tensor, void **pack_weight) {
nvinfer1::Weights TransposeWeight4D(const mindspore::MSTensor &ms_tensor, void **pack_weight) {
// usage notice: the malloc'd buffer address is saved to *pack_weight; the caller keeps the pointer and frees it in its destructor
nvinfer1::Weights weights{};
weights.count = ms_tensor.ElementNum();
if (lite::ConvertDataType(ms_tensor.DataType()) != nvinfer1::DataType::kFLOAT) {
MS_LOG(WARNING) << "weights data type is not float32";
}
weights.type = nvinfer1::DataType::kFLOAT;
auto weight_shape = ms_tensor.Shape();
const void *src_ptr = ms_tensor.Data().get();
if (src_ptr == nullptr) {
MS_LOG(ERROR) << "TransposeWeight from a MSTensor with nullptr data";
if (weight_shape.size() != DIMENSION_4D) {
MS_LOG(ERROR) << ms_tensor.Name() << " dims is " << weight_shape.size();
return weights;
}
float *pack_weight_tmp = reinterpret_cast<float *>(malloc(ms_tensor.ElementNum() * sizeof(float)));
if (ms_tensor.Data() == nullptr) {
MS_LOG(ERROR) << ms_tensor.Name() << " has null data";
return weights;
}
void *pack_weight_tmp = malloc(ms_tensor.DataSize());
if (pack_weight_tmp == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return weights;
}
PackNHWCToNCHWFp32(src_ptr, pack_weight_tmp, weight_shape[0], weight_shape[1] * weight_shape[2], weight_shape[3], 0,
0);
weights.values = pack_weight_tmp;
*pack_weight = pack_weight_tmp;
weights.values = pack_weight_tmp;
switch (ms_tensor.DataType()) {
case DataType::kNumberTypeFloat16: {
weights.type = nvinfer1::DataType::kHALF;
PackNHWCToNCHWFp16(ms_tensor.Data().get(), pack_weight_tmp, weight_shape[0], weight_shape[1] * weight_shape[2],
weight_shape[3], 0, 0);
break;
}
case DataType::kNumberTypeFloat32: {
weights.type = nvinfer1::DataType::kFLOAT;
PackNHWCToNCHWFp32(ms_tensor.Data().get(), pack_weight_tmp, weight_shape[0], weight_shape[1] * weight_shape[2],
weight_shape[3], 0, 0);
break;
}
default: {
MS_LOG(ERROR) << ms_tensor.Name() << " has unsupported tensor datatype for transpose data : "
<< static_cast<int>(ms_tensor.DataType());
}
}
return weights;
}
nvinfer1::Weights TransposeWeight2D(const mindspore::MSTensor &ms_tensor, void **pack_weight) {
// usage notice: the malloc'd buffer address is saved to *pack_weight; the caller keeps the pointer and frees it in its destructor
nvinfer1::Weights weights{};
weights.count = ms_tensor.ElementNum();
auto weight_shape = ms_tensor.Shape();
if (weight_shape.size() != DIMENSION_2D) {
MS_LOG(ERROR) << ms_tensor.Name() << " dims is " << weight_shape.size();
return weights;
}
if (ms_tensor.Data() == nullptr) {
MS_LOG(ERROR) << ms_tensor.Name() << " has null data";
return weights;
}
void *pack_weight_tmp = malloc(ms_tensor.DataSize());
if (pack_weight_tmp == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return weights;
}
*pack_weight = pack_weight_tmp;
weights.values = pack_weight_tmp;
int row = weight_shape[0];
int col = weight_shape[1];
switch (ms_tensor.DataType()) {
case DataType::kNumberTypeFloat16: {
weights.type = nvinfer1::DataType::kHALF;
auto src = static_cast<const uint16_t *>(ms_tensor.Data().get());
auto dst = static_cast<uint16_t *>(pack_weight_tmp);
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
dst[c * row + r] = src[r * col + c];
}
}
break;
}
case DataType::kNumberTypeFloat32: {
weights.type = nvinfer1::DataType::kFLOAT;
auto dst = static_cast<float *>(pack_weight_tmp);
auto src = static_cast<const float *>(ms_tensor.Data().get());
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
dst[c * row + r] = src[r * col + c];
}
}
break;
}
default: {
MS_LOG(ERROR) << ms_tensor.Name() << " has unsupported tensor datatype for transpose data : "
<< static_cast<int>(ms_tensor.DataType());
}
}
return weights;
}
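As a sanity check on the mapping above, dst[c * row + r] = src[r * col + c] is the plain row-major 2-D transpose; a tiny standalone verification (illustrative only, not part of the commit):

#include <cassert>
#include <vector>

int main() {
  const int row = 2, col = 3;
  std::vector<float> src = {1, 2, 3, 4, 5, 6};  // 2x3 weight, row-major
  std::vector<float> dst(src.size());
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      dst[c * row + r] = src[r * col + c];  // same mapping as TransposeWeight2D
    }
  }
  assert((dst == std::vector<float>{1, 4, 2, 5, 3, 6}));  // the 3x2 transpose
  return 0;
}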
@@ -522,4 +570,20 @@ void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, siz
*buffer = static_cast<const char *>(*buffer) + cpy_size;
*buffer_size -= cpy_size;
}
nvinfer1::ITensor *Reshape(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input,
const std::vector<int64_t> &shape) {
return Reshape(network, input, ConvertCudaDims(shape));
}
nvinfer1::ITensor *Reshape(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input,
const nvinfer1::Dims &shape) {
auto reshape_layer = network->addShuffle(*input);
if (reshape_layer == nullptr) {
MS_LOG(ERROR) << "add reshape_layer failed";
return nullptr;
}
reshape_layer->setReshapeDimensions(shape);
return reshape_layer->getOutput(0);
}
} // namespace mindspore::lite
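The new Reshape overloads consolidate the addShuffle boilerplate that Scale, Reduce, and the matmul/fully-connected paths previously duplicated; a hedged usage sketch (batch and channels are placeholder values):

// Squeeze a fully-connected output of shape [batch, channels, 1, 1] back to
// [batch, channels]; Reshape returns nullptr if addShuffle fails.
nvinfer1::ITensor *squeezed =
  Reshape(network, fc_layer->getOutput(0), std::vector<int64_t>{batch, channels});
if (squeezed == nullptr) {
  MS_LOG(ERROR) << "Reshape failed for " << op_name_;
}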

View File

@@ -24,7 +24,6 @@
#include "mindspore/core/ir/dtype/type_id.h"
#include "schema/ops_generated.h"
#include "nnacl/pack.h"
#include "src/delegate/tensorrt/distribution/distribution_collective.h"
#define kNCHW_N 0
#define kNCHW_C 1
@@ -81,9 +80,9 @@ nvinfer1::ITensor *ConvertScalarToITensor(nvinfer1::INetworkDefinition *network,
nvinfer1::ITensor *ConvertScalarToITensor(nvinfer1::INetworkDefinition *network, size_t shape_size, const void *value,
const DataType data_type, const std::string &op_name);
nvinfer1::Weights TransposeWeight(const mindspore::MSTensor &ms_tensor, void **pack_weight);
nvinfer1::Weights TransposeWeight4D(const mindspore::MSTensor &ms_tensor, void **pack_weight);
nvinfer1::Weights TransposeWeightFP32(const mindspore::MSTensor &ms_tensor, void **pack_weight);
nvinfer1::Weights TransposeWeight2D(const mindspore::MSTensor &ms_tensor, void **pack_weight);
nvinfer1::Weights ConvertWeight(const mindspore::MSTensor &ms_tensor);
@@ -117,6 +116,12 @@ void SerializeValue(void **buffer, const void *value, size_t cpy_size);
void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size);
nvinfer1::ITensor *Reshape(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input,
const std::vector<int64_t> &shape);
nvinfer1::ITensor *Reshape(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *input,
const nvinfer1::Dims &shape);
template <typename T1, typename T2>
bool SameDims(const std::vector<T1> &shape1, const std::vector<T2> &shape2) {
if (shape1.size() != shape2.size()) {