forked from mindspore-Ecosystem/mindspore
!23612 [MSLITE] fix compatibility bugs for diverse networks in the TensorRT delegate
Merge pull request !23612 from Liu_Xuu/trt_0914_widedeep
commit 9e61c44d07
@@ -122,6 +122,10 @@ std::vector<mindspore::MSTensor> GraphOutTensors(const std::vector<T *> &ops, De
       if (find(out_tensors.begin(), out_tensors.end(), out_tensor) == out_tensors.end()) {
         all_out_tensors.push_back(out_tensor);
       }
+      if (find(model->outputs().begin(), model->outputs().end(), out_tensor) != model->outputs().end() &&
+          find(out_tensors.begin(), out_tensors.end(), out_tensor) == out_tensors.end()) {
+        out_tensors.push_back(out_tensor);
+      }
     }
   }
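The hunk above makes GraphOutTensors record an op output as a graph output only when the model itself exposes that tensor and it has not been recorded yet, so shared outputs are no longer duplicated. A minimal standalone sketch of the two find() checks; std::string stands in for mindspore::MSTensor (a simplification, not the real type):

// Illustrative sketch only, not part of the commit.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> model_outputs = {"out_a", "out_b"};
  std::vector<std::string> op_outputs = {"mid", "out_a", "out_a", "out_b"};
  std::vector<std::string> all_out_tensors;
  std::vector<std::string> out_tensors;
  for (const auto &t : op_outputs) {
    if (std::find(all_out_tensors.begin(), all_out_tensors.end(), t) == all_out_tensors.end()) {
      all_out_tensors.push_back(t);
    }
    // record only tensors the model exposes, and each of them once
    if (std::find(model_outputs.begin(), model_outputs.end(), t) != model_outputs.end() &&
        std::find(out_tensors.begin(), out_tensors.end(), t) == out_tensors.end()) {
      out_tensors.push_back(t);
    }
  }
  for (const auto &t : out_tensors) std::cout << t << "\n";  // out_a, out_b -- no duplicates
  return 0;
}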
@@ -198,7 +198,8 @@ nvinfer1::ITensor *ElementWiseTensorRT::AddActivation(nvinfer1::INetworkDefiniti
 int ElementWiseTensorRT::AddConstTensor(nvinfer1::INetworkDefinition *network) {
   // create ITensor from MS constant tensor of index 1 - first_in_tensor_index_
   nvinfer1::ITensor *constant_input = nullptr;
-  if (this->in_tensors_[1 - first_in_tensor_index_].Shape().size() == 0) {
+  if (this->in_tensors_[1 - first_in_tensor_index_].Shape().size() == 0 ||
+      this->in_tensors_[1 - first_in_tensor_index_].ElementNum() == 1) {
     constant_input = lite::ConvertScalarToITensor(network, this->in_tensors_[first_in_tensor_index_].Shape().size(),
                                                   in_tensors_[1 - first_in_tensor_index_].Data().get(),
                                                   in_tensors_[1 - first_in_tensor_index_].DataType());
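The broadened condition treats an element-wise operand as a scalar not only when its rank is zero but also when it holds exactly one element (shape [1], [1, 1], ...), which the branch name suggests networks like Wide&Deep feed to broadcast ops. A standalone sketch of the test, with a hypothetical ElementNum helper mirroring MSTensor::ElementNum:

// Illustrative sketch only, not part of the commit.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int64_t ElementNum(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

// Old check: rank 0 only. New check: rank 0 or exactly one element.
bool TreatAsScalar(const std::vector<int64_t> &shape) {
  return shape.empty() || ElementNum(shape) == 1;
}

int main() {
  std::cout << TreatAsScalar({})      // 1: rank-0 scalar
            << TreatAsScalar({1})     // 1: caught only by the new check
            << TreatAsScalar({1, 1})  // 1: likewise
            << TreatAsScalar({2})     // 0: a real vector
            << "\n";
  return 0;
}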
@@ -53,14 +53,27 @@ int GatherTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
     MS_LOG(ERROR) << "network is invalid";
     return RET_ERROR;
   }
   // convert constant MSTensor to ITensor
-  nvinfer1::ITensor *add_tensor = lite::ConvertConstantTensor(network, this->in_tensors_[1]);
-  if (add_tensor == nullptr) {
-    MS_LOG(ERROR) << "add a new tensor failed for TensorRT GatherTensorRTOp.";
-    return RET_ERROR;
-  }
-  nvinfer1::ITensor *gather_input = tensorrt_in_tensors_[0].trt_tensor_;
+  nvinfer1::ITensor *gather_input = this->tensorrt_in_tensors_[0].trt_tensor_;
+  if (in_tensors_[0].IsConst()) {
+    gather_input = lite::ConvertConstantTensor(network, this->in_tensors_[0]);
+    MS_LOG(INFO) << "gather input is const tensor " << op_name_;
+  }
+  if (gather_input == nullptr) {
+    MS_LOG(ERROR) << "get gather input failed for: " << op_name_;
+    return RET_ERROR;
+  }
+
+  nvinfer1::ITensor *indices_tensor = this->tensorrt_in_tensors_[tensorrt_in_tensors_.size() - 1].trt_tensor_;
+  if (in_tensors_[1].IsConst()) {
+    indices_tensor = lite::ConvertConstantTensor(network, this->in_tensors_[1]);
+    MS_LOG(INFO) << "gather indices is const tensor " << op_name_;
+  }
+  if (indices_tensor == nullptr) {
+    MS_LOG(ERROR) << "get gather indices failed for: " << op_name_;
+    return RET_ERROR;
+  }
   Format out_format = tensorrt_in_tensors_[0].format_;
   if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
       tensorrt_in_tensors_[0].format_ == Format::NCHW) {

@@ -75,7 +88,8 @@ int GatherTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
     out_format = Format::NHWC;
   }
-  nvinfer1::IGatherLayer *gather_layer = network->addGather(*gather_input, *add_tensor /* indices */, axis_ /* axis */);
+  nvinfer1::IGatherLayer *gather_layer =
+    network->addGather(*gather_input, *indices_tensor /* indices */, axis_ /* axis */);
   if (gather_layer == nullptr) {
     MS_LOG(ERROR) << "addGather failed for TensorRT.";
     return RET_ERROR;
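The rewritten gather setup no longer assumes input 0 comes from an upstream TensorRT op while input 1 is a constant: either operand may be a constant MSTensor (converted via lite::ConvertConstantTensor) or a live ITensor from tensorrt_in_tensors_. A standalone sketch of that selection, with stand-in types instead of the real nvinfer1/MSTensor classes:

// Illustrative sketch only, not part of the commit.
#include <iostream>
#include <string>

struct Tensor { bool is_const; std::string name; };

// Constants are converted on the spot (ConvertConstantTensor in the real
// code); everything else comes from the upstream op's TensorRT output.
std::string ResolveOperand(const Tensor &t, const std::string &upstream) {
  if (t.is_const) {
    return "converted:" + t.name;
  }
  return upstream;
}

int main() {
  const std::string upstream = "trt:prev_op_output";
  Tensor input{false, "embedding_table"};
  Tensor indices{true, "gather_indices"};
  std::cout << ResolveOperand(input, upstream) << "\n";    // trt:prev_op_output
  std::cout << ResolveOperand(indices, upstream) << "\n";  // converted:gather_indices
  return 0;
}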
@@ -56,7 +56,7 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
     return RET_ERROR;
   }
   bool keep_dims = reduce_op->keep_dims();
-  nvinfer1::ITensor *reduce_input = tensorrt_in_tensors_[0].trt_tensor_;
+
   if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
       tensorrt_in_tensors_[0].format_ == Format::NCHW) {
     out_format_ = Format::NHWC;
@@ -64,8 +64,26 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
     out_format_ = tensorrt_in_tensors_[0].format_;
   }
-  uint32_t reduceAxis = GetAxis();
+  nvinfer1::ITensor *reduce_input = tensorrt_in_tensors_[0].trt_tensor_;
+  // 4 dims support reduce at each axis
+  if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims != DIMENSION_4D) {
+    nvinfer1::IShuffleLayer *unsqueeze_layer = network->addShuffle(*tensorrt_in_tensors_[0].trt_tensor_);
+    if (unsqueeze_layer == nullptr) {
+      MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
+      return RET_ERROR;
+    }
+    unsqueeze_layer->setName((op_name_ + "_unsqueeze4dims").c_str());
+    nvinfer1::Dims unsqueeze_dims = tensorrt_in_tensors_[0].trt_tensor_->getDimensions();
+    for (int i = unsqueeze_dims.nbDims; i < 4; i++) {
+      unsqueeze_dims.d[i] = 1;
+    }
+    unsqueeze_dims.nbDims = 4;
+    unsqueeze_layer->setReshapeDimensions(unsqueeze_dims);
+    reduce_input = unsqueeze_layer->getOutput(0);
+  }
+
+  uint32_t reduceAxis = GetAxis();
   nvinfer1::IReduceLayer *layer = network->addReduce(*reduce_input, reduce_op_, reduceAxis, keep_dims);
   if (layer == nullptr) {
     MS_LOG(ERROR) << "addReduce failed for TensorRT.";
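Because the delegate's reduce path reasons in four dimensions, lower-rank inputs are now padded with trailing 1s by a shuffle layer before addReduce; the next hunk squeezes the result back to the expected output shape. A standalone sketch of the padding arithmetic, with a stand-in Dims in place of nvinfer1::Dims:

// Illustrative sketch only, not part of the commit.
#include <array>
#include <iostream>

struct Dims {
  int nbDims;
  std::array<int, 8> d;
};

// Mirror of the unsqueeze loop above: pad trailing axes with 1 up to rank 4.
Dims UnsqueezeTo4D(Dims dims) {
  for (int i = dims.nbDims; i < 4; i++) {
    dims.d[i] = 1;
  }
  dims.nbDims = 4;
  return dims;
}

int main() {
  Dims in{2, {32, 10}};  // e.g. a [batch, features] input to a reduce op
  Dims out = UnsqueezeTo4D(in);
  for (int i = 0; i < out.nbDims; i++) std::cout << out.d[i] << " ";  // 32 10 1 1
  std::cout << "\n";
  return 0;
}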
@@ -74,6 +92,18 @@ int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
   layer->setName(op_name_.c_str());
   nvinfer1::ITensor *out_tensor = layer->getOutput(0);
+  if (in_tensors_[0].Shape().size() != DIMENSION_4D) {
+    // squeeze to origin dim
+    nvinfer1::IShuffleLayer *squeeze_layer = network->addShuffle(*layer->getOutput(0));
+    if (squeeze_layer == nullptr) {
+      MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
+      return RET_ERROR;
+    }
+    squeeze_layer->setName((op_name_ + "_squeeze").c_str());
+    nvinfer1::Dims squeeze_dims = ConvertCudaDims(out_tensors_[0].Shape());
+    squeeze_layer->setReshapeDimensions(squeeze_dims);
+    out_tensor = squeeze_layer->getOutput(0);
+  }
   if (out_tensor == nullptr) {
     MS_LOG(ERROR) << "addReduce output tensor create failed for TensorRT.";
     return RET_ERROR;
@@ -96,10 +126,10 @@ uint32_t ReduceTensorRT::GetAxis() {
   int *axis_data = reinterpret_cast<int *>(axis_tensor.MutableData());
   bool need_transpose_axis =
     (out_format_ == Format::NCHW) && (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D);
+  uint32_t base = std::pow(2, in_tensors_[0].Shape().size());
   for (int i = 0; i < axis_tensor.ElementNum(); i++) {
     int format_axis_data = need_transpose_axis ? ConvertAxisFromNHWC2NCHW(*axis_data) : *axis_data;
-    // 16 is 1111 as 4 dims
-    reduceAxis |= (16 - (1u << format_axis_data));
+    reduceAxis |= (base - (1u << format_axis_data));
     axis_data++;
   }
   MS_LOG(DEBUG) << "reduceAxis: " << reduceAxis;
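GetAxis builds TensorRT's reduce-axes bitmask. The old code hard-coded base = 16 (the 4-bit mask 1111), which only fits 4D inputs; the fix derives base = 2^ndims from the input rank. Note that base - (1u << a) has bits a through ndims-1 set, so each entry marks the given axis and every higher one, exactly as the committed expression does. A standalone sketch of the arithmetic:

// Illustrative sketch only, not part of the commit: the mask arithmetic
// from GetAxis(), shown standalone.
#include <bitset>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  for (uint32_t ndims : {2u, 4u}) {
    uint32_t base = static_cast<uint32_t>(std::pow(2, ndims));  // 2^rank, was hard-coded 16
    for (uint32_t axis = 0; axis < ndims; axis++) {
      uint32_t mask = base - (1u << axis);  // sets bits axis..ndims-1
      std::cout << "ndims=" << ndims << " axis=" << axis
                << " mask=" << std::bitset<4>(mask) << "\n";
    }
  }
  return 0;
}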
@@ -63,6 +63,14 @@ void TensorRTAllocator::MarkMemValid(const std::string &name, bool isValid) {
   return;
 }
+
+bool TensorRTAllocator::GetMemIsValid(const std::string &name) {
+  if (cuda_tensor_map_.find(name) == cuda_tensor_map_.end()) {
+    MS_LOG(INFO) << "tensor :" << name << " not in cuda Allocator pool.";
+    return false;
+  }
+  return cuda_tensor_map_[name].is_valid_mem;
+}

 void *TensorRTAllocator::GetDevicePtr(const std::string &tensor_name) {
   if (tensor_name.empty()) {
     return nullptr;
@@ -51,6 +51,8 @@ class TensorRTAllocator {
   void MarkMemValid(const std::string &name, bool isValid);

+  bool GetMemIsValid(const std::string &name);
+
  private:
   std::map<std::string, CudaTensorParam> cuda_tensor_map_;
 };
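GetMemIsValid is the read side of the existing MarkMemValid flag: it reports whether a named device buffer still holds current data, and Execute() (below) uses it to skip redundant host-to-device copies. A standalone sketch of the bookkeeping, with stand-in types for the allocator's cuda_tensor_map_:

// Illustrative sketch only, not part of the commit.
#include <iostream>
#include <map>
#include <string>

struct CudaTensorParam {
  void *device_ptr = nullptr;
  bool is_valid_mem = false;  // device copy holds current data
};

class Allocator {
 public:
  void MarkMemValid(const std::string &name, bool is_valid) { pool_[name].is_valid_mem = is_valid; }
  bool GetMemIsValid(const std::string &name) {
    auto it = pool_.find(name);
    if (it == pool_.end()) {
      std::cout << "tensor " << name << " not in pool\n";
      return false;  // unknown tensors are never valid
    }
    return it->second.is_valid_mem;
  }

 private:
  std::map<std::string, CudaTensorParam> pool_;
};

int main() {
  Allocator alloc;
  alloc.MarkMemValid("input_0", true);
  std::cout << alloc.GetMemIsValid("input_0") << "\n";  // 1
  std::cout << alloc.GetMemIsValid("input_1") << "\n";  // logs a note, then 0
  return 0;
}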
@@ -62,7 +62,6 @@ int TensorRTSubGraph::Init() {
   }
   for (size_t i = 0; i < inputs_.size(); i++) {
     if (inputs_[i].Shape().size() != DIMENSION_4D) {
-      MS_LOG(WARNING) << "hw dims resize is unsupported.";
       input_hw_index_ = -1;
     }
   }
@@ -140,6 +139,13 @@ bool TensorRTSubGraph::SupportFP16() {
 }

 nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) {
+  for (int i = 0; i < this->network_->getNbInputs(); i++) {
+    if (in_tensor.Name().compare(this->network_->getInput(i)->getName()) == 0) {
+      MS_LOG(INFO) << "input tensor is already added in network: " << in_tensor.Name();
+      return this->network_->getInput(i);
+    }
+  }
+
   auto cuda_dtype = ConvertDataType(in_tensor.DataType());
   if (static_cast<int>(cuda_dtype) == -1) {
     MS_LOG(ERROR) << "Unsupported input data type " << static_cast<int>(in_tensor.DataType());
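SetTensorRTNetworkInput now scans the inputs already registered with the network and returns the existing ITensor when the names match, so a tensor consumed by several ops is added as a network input only once. A standalone sketch of the lookup, using a plain vector of names instead of nvinfer1::INetworkDefinition:

// Illustrative sketch only, not part of the commit.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> network_inputs;  // names already registered with the network

int AddInputOnce(const std::string &name) {
  for (size_t i = 0; i < network_inputs.size(); i++) {
    if (name.compare(network_inputs[i]) == 0) {
      std::cout << "input tensor is already added in network: " << name << "\n";
      return static_cast<int>(i);  // reuse the existing network input
    }
  }
  network_inputs.push_back(name);
  return static_cast<int>(network_inputs.size()) - 1;
}

int main() {
  AddInputOnce("ids");
  AddInputOnce("ids");  // a second op sharing the input reuses index 0
  std::cout << "network inputs: " << network_inputs.size() << "\n";  // 1
  return 0;
}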
@@ -190,7 +196,7 @@ nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MS
     MS_LOG(ERROR) << "setDimensions of kMAX failed.";
     return nullptr;
   }
   MS_LOG(INFO) << "add network input: " << in_tensor.Name();
   return this->network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
 }
@@ -198,6 +204,7 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
   MS_ASSERT(!all_ops_.empty());
   // Connect NetWork.
   int ret;

   for (auto cur_op : all_ops_) {
     for (auto in_tensor : cur_op->inputs()) {
       // Data From CPU
@@ -215,13 +222,13 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
       if (trt_tensor.trt_tensor_ == nullptr) {
         // weight tensor
         if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) {
-          if (in_tensor == nullptr) {
-            MS_LOG(ERROR) << "Weight Tensor is nullptr.";
+          if (in_tensor.Data() == nullptr) {
+            MS_LOG(ERROR) << "Weight Tensor data is nullptr.";
             return RET_ERROR;
           }
           trt_tensor.trt_tensor_ = lite::ConvertConstantTensor(this->network_, in_tensor);
           trt_tensor.format_ = Format::NHWC;
-          MS_LOG(INFO) << "auto convert constant tensor for: " << cur_op->GetOpName();
+          MS_LOG(INFO) << "auto convert constant tensor for: " << in_tensor.Name();
           cur_op->AddInnerInTensors(trt_tensor);
         }
       } else {
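The weight check moves from the tensor handle to its data pointer: an MSTensor handle can exist while its backing buffer is empty, and comparing the handle itself missed that case. A standalone sketch of the distinction, with a stand-in Tensor whose Data() returns a shared pointer (an assumption mirroring MSTensor::Data):

// Illustrative sketch only, not part of the commit.
#include <iostream>
#include <memory>
#include <string>

struct Tensor {
  std::string name;
  std::shared_ptr<float> data;  // may be empty even when the handle exists
  std::shared_ptr<float> Data() const { return data; }
};

int main() {
  Tensor weight{"fc.weight", nullptr};  // declared in the graph, data missing
  // A check on the handle itself would pass; the data check catches it.
  if (weight.Data() == nullptr) {
    std::cout << "Weight Tensor data is nullptr: " << weight.name << "\n";
  }
  return 0;
}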
@@ -422,12 +429,16 @@ int TensorRTSubGraph::Execute() {
     return RET_ERROR;
   }
   for (size_t i = 0; i < inputs_.size(); i++) {
-    runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], false);
+    if (runtime_->GetAllocator()->GetMemIsValid(trt_in_tensor_name_[i])) {
+      MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i];
+      continue;
+    }
     int ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i];
       return ret;
     }
+    runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true);
   }

   auto ret = this->trt_context_->executeV2(tensor_bindings_);
@@ -460,6 +471,11 @@ int TensorRTSubGraph::Execute() {
       MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name_[i];
       return sync_ret;
     }
+    runtime_->GetAllocator()->MarkMemValid(trt_out_tensor_name_[i], false);
   }
+  // make mem invalid, prepare for next execute
+  for (size_t i = 0; i < inputs_.size(); i++) {
+    runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], false);
+  }
   return RET_OK;
 }
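Together the two Execute() hunks form a cycle: an input binding already marked valid skips the host-to-device sync (which pays off when another consumer of the same buffer synced it first), each synced binding is marked valid, and after inference all bindings are marked invalid so the next run re-syncs. A standalone sketch of the flag flow, with a plain map standing in for TensorRTAllocator:

// Illustrative sketch only, not part of the commit.
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::map<std::string, bool> valid_mem;  // name -> device copy is current

void SyncInputs(const std::vector<std::string> &inputs) {
  for (const auto &name : inputs) {
    if (valid_mem[name]) {  // GetMemIsValid path
      std::cout << "no need memcpy to cuda for input tensor: " << name << "\n";
      continue;
    }
    std::cout << "host->device memcpy for " << name << "\n";  // SyncMemInHostAndDevice
    valid_mem[name] = true;  // MarkMemValid(name, true)
  }
}

int main() {
  std::vector<std::string> inputs = {"ids", "wide_weight"};  // hypothetical names
  SyncInputs(inputs);  // first run: copies both
  SyncInputs(inputs);  // still valid: both skipped
  for (auto &kv : valid_mem) kv.second = false;  // post-execute MarkMemValid(..., false)
  SyncInputs(inputs);  // next run: copies again
  return 0;
}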
@@ -468,7 +484,7 @@ ITensorHelper TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, const min
   for (auto input_op : cur_op->in_ops()) {
     for (size_t i = 0; i < input_op->outputs().size(); i++) {
       auto out_tensor = input_op->outputs().at(i);
-      if (in_tensor == out_tensor) {
+      if (in_tensor.Name().compare(out_tensor.Name()) == 0) {
         return input_op->GetInnerOutTensor().at(i);
       }
     }
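FindTensorRTInputs now matches tensors by name rather than by handle equality. A plausible reading (not stated in the commit message) is that distinct MSTensor handles can wrap the same logical tensor yet compare unequal, while their names still agree; the sketch below contrasts the two comparisons with stand-in types:

// Illustrative sketch only, not part of the commit. TensorHandle is a
// stand-in for mindspore::MSTensor; its operator== compares identity.
#include <iostream>
#include <memory>
#include <string>

struct TensorImpl { std::string name; };

struct TensorHandle {
  std::shared_ptr<TensorImpl> impl;
  std::string Name() const { return impl->name; }
  bool operator==(const TensorHandle &other) const { return impl == other.impl; }
};

int main() {
  TensorHandle a{std::make_shared<TensorImpl>(TensorImpl{"gather_output"})};
  TensorHandle b{std::make_shared<TensorImpl>(TensorImpl{"gather_output"})};  // same logical tensor, new handle
  std::cout << "handle equality: " << (a == b) << "\n";                           // 0 -- misses the match
  std::cout << "name equality:   " << (a.Name().compare(b.Name()) == 0) << "\n";  // 1 -- finds it
  return 0;
}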