diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 1bd1ac0508a..17d0b3a7687 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -28,6 +28,7 @@ option(MSLITE_ENABLE_CONVERTER "enable converter, only x86_64 support" on) option(MSLITE_ENABLE_TOOLS "enable tools" on) option(MSLITE_ENABLE_TESTCASES "enable testcase" off) option(MSLITE_ENABLE_NNIE "enable NNIE" off) +option(MSLITE_ENABLE_TENSORRT "enable TensorRT" off) # Option that can be configured through manually option(ENABLE_VERBOSE "" off) @@ -69,6 +70,9 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_NNIE}) set(MSLITE_ENABLE_NNIE $ENV{MSLITE_ENABLE_NNIE}) endif() +if(DEFINED ENV{MSLITE_ENABLE_TENSORRT}) + set(MSLITE_ENABLE_TENSORRT $ENV{MSLITE_ENABLE_TENSORRT}) +endif() if(PLATFORM_ARM64 OR PLATFORM_ARM32) set(PLATFORM_ARM "on") @@ -101,6 +105,20 @@ if(MSLITE_ENABLE_NPU) endif() endif() +if(MSLITE_ENABLE_TENSORRT) + set(SUPPORT_TENSORRT on) + if(DEFINED ENV{TENSORRT_HOME}) + message("TENSORRT_HOME = $ENV{TENSORRT_HOME}") + else() + message(FATAL_ERROR "please set TENSORRT_HOME, example: export TENSORRT_HOME=/root/usr/TensorRT-6.0.1.5/") + endif() + if(DEFINED ENV{CUDA_HOME}) + message("CUDA_HOME = $ENV{CUDA_HOME}") + else() + message(FATAL_ERROR "please set CUDA_HOME, example: export CUDA_HOME=/usr/local/cuda-10.1/") + endif() +endif() + message(STATUS "************MindSpore Lite Build Option:************") message(STATUS "\tMSLITE_GPU_BACKEND = \t${MSLITE_GPU_BACKEND}") message(STATUS "\tMSLITE_ENABLE_NPU = \t${MSLITE_ENABLE_NPU}") @@ -110,6 +128,7 @@ message(STATUS "\tMSLITE_ENABLE_AVX = \t${MSLITE_ENABLE_AVX}") message(STATUS "\tMSLITE_ENABLE_CONVERTER = \t${MSLITE_ENABLE_CONVERTER}") message(STATUS "\tMSLITE_ENABLE_TOOLS = \t${MSLITE_ENABLE_TOOLS}") message(STATUS "\tMSLITE_ENABLE_TESTCASES = \t${MSLITE_ENABLE_TESTCASES}") +message(STATUS "\tMSLITE_ENABLE_TENSORRT = \t${MSLITE_ENABLE_TENSORRT}") if(ENABLE_ASAN) add_definitions(-fsanitize=address 
-fno-omit-frame-pointer -mllvm -asan-use-private-alias=1) diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 13c81dd0032..fed601c0426 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -60,7 +60,6 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/common/tensor_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/delegate.cc - ${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/inner_allocator.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/infer_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/tensor.cc @@ -118,6 +117,7 @@ if(MSLITE_GPU_BACKEND STREQUAL cuda) ) endif() set(TRAIN_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/common/quant_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/transfer_session.cc @@ -198,6 +198,19 @@ if(ENABLE_MINDRT) target_link_libraries(mindspore-lite_static mindrt_mid) endif() +if(SUPPORT_TENSORRT) + add_compile_definitions(GPU_TENSORRT) + set(TENSORRT_PATH $ENV{TENSORRT_HOME}) + set(CUDA_PATH $ENV{CUDA_HOME}) + set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib) + set(CUDA_LIB_PATH ${CUDA_PATH}/lib64) + include_directories(${TENSORRT_PATH}/include) + include_directories(${CUDA_PATH}/include) + add_subdirectory(delegate/tensorrt) + target_link_libraries(mindspore-lite tensorrt_kernel_mid) + target_link_libraries(mindspore-lite_static tensorrt_kernel_mid) +endif() + if(MSLITE_GPU_BACKEND STREQUAL opencl) add_subdirectory(runtime/kernel/opencl) target_link_libraries(mindspore-lite cpu_kernel_mid opencl_kernel_mid nnacl_mid cpu_ops_mid) diff --git a/mindspore/lite/src/delegate/delegate_utils.cc b/mindspore/lite/src/delegate/delegate_utils.cc new file mode 100644 index 00000000000..4ab89d6377e --- /dev/null +++ b/mindspore/lite/src/delegate/delegate_utils.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/delegate_utils.h" +namespace mindspore::lite { +bool IsSubGraphInputTensor(const std::vector &inputs, tensor::MSTensor *input) { + if (find(inputs.begin(), inputs.end(), input) != inputs.end()) { + return true; + } + return false; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/delegate_utils.h b/mindspore/lite/src/delegate/delegate_utils.h new file mode 100644 index 00000000000..84114fc645b --- /dev/null +++ b/mindspore/lite/src/delegate/delegate_utils.h @@ -0,0 +1,211 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_DELEGATE_DELEGATE_UTILS +#define MINDSPORE_LITE_SRC_DELEGATE_DELEGATE_UTILS +#include +#include "include/ms_tensor.h" +#include "include/delegate.h" +#include "src/common/log_adapter.h" +#include "src/delegate/tensorrt/op/tensorrt_op.h" + +namespace mindspore::lite { +bool IsSubGraphInputTensor(const std::vector &inputs, tensor::MSTensor *input); + +template +std::vector GetGraphInTensors(std::vector ops) { + std::vector inputs; + auto is_op_output = [&](tensor::MSTensor *tensor) -> bool { + for (auto op : ops) { + auto out_tensors = op->outputs(); + if (find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end()) { + return true; + } + } + return false; + }; + + for (auto op : ops) { + for (auto in_tensor : op->inputs()) { + if (in_tensor->data() == nullptr && !is_op_output(in_tensor)) { + inputs.push_back(in_tensor); + } + } + } + return inputs; +} + +template +std::vector GetGraphOutTensors(const std::vector &ops) { + std::vector outputs; + auto is_op_input = [&](const tensor::MSTensor *tensor) -> bool { + for (auto op : ops) { + auto in_tensors = op->inputs(); + if (find(in_tensors.begin(), in_tensors.end(), tensor) != in_tensors.end()) { + return true; + } + } + return false; + }; + + for (auto op : ops) { + for (auto out_tensor : op->outputs()) { + if (!is_op_input(out_tensor)) { + outputs.push_back(out_tensor); + } + } + } + + for (auto op : ops) { + for (auto out_op : op->out_ops()) { + if (find(ops.begin(), ops.end(), out_op) == ops.end()) { + // visit the out op that is not in the subgraph + for (auto tensor : op->outputs()) { + if (find(out_op->inputs().begin(), out_op->inputs().end(), tensor) != out_op->inputs().end()) { + // find the connected tensor + outputs.push_back(tensor); + break; + } + } + } + } + } + return outputs; +} + +template +std::vector GraphInTensors(const std::vector &ops, DelegateModel *model, KernelIter from, + KernelIter end) { + auto in_tensors = GetGraphInTensors(ops); + std::vector 
all_in_tensors; + for (auto op : ops) { + for (auto in_tensor : op->inputs()) { + if (in_tensor->data() != nullptr && find(in_tensors.begin(), in_tensors.end(), in_tensor) == in_tensors.end()) { + all_in_tensors.push_back(in_tensor); + } + } + } + + for (auto iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { + if (iter >= from && iter <= end) { + continue; + } + // The output of other kernels is the input of the current subgraph kernel. + for (auto out_tensor : (*iter)->outputs()) { + if (std::find(all_in_tensors.begin(), all_in_tensors.end(), out_tensor) != all_in_tensors.end()) { + in_tensors.push_back(out_tensor); + } + } + } + return in_tensors; +} + +template +std::vector GraphOutTensors(const std::vector &ops, DelegateModel *model, KernelIter from, + KernelIter end) { + auto out_tensors = GetGraphOutTensors(ops); + std::vector all_out_tensors; + for (auto op : ops) { + for (auto out_tensor : op->outputs()) { + if (find(out_tensors.begin(), out_tensors.end(), out_tensor) == out_tensors.end()) { + all_out_tensors.push_back(out_tensor); + } + } + } + + for (auto iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { + if (iter >= from && iter <= end) { + continue; + } + // The input of other kernels is the output of the current subgraph kernel. 
+ for (auto in_tensor : (*iter)->inputs()) { + if (find(all_out_tensors.begin(), all_out_tensors.end(), in_tensor) != all_out_tensors.end()) { + out_tensors.push_back(in_tensor); + } + } + } + return out_tensors; +} + +template +std::vector FindPreOps(T *cur_op, std::vector all_ops) { + std::vector in_ops; + for (auto in_tensor : cur_op->inputs()) { + for (auto op : all_ops) { + if (find(op->outputs().begin(), op->outputs().end(), in_tensor) != op->outputs().end()) { + in_ops.push_back(op); + } + } + } + return in_ops; +} + +template +std::vector FindNextOps(T *cur_op, std::vector all_ops) { + std::vector out_ops; + for (auto out_tensor : cur_op->outputs()) { + for (auto op : all_ops) { + if (find(op->inputs().begin(), op->inputs().end(), out_tensor) != op->inputs().end()) { + out_ops.push_back(op); + } + } + } + return out_ops; +} + +template +void FindPreNextOps(std::vector all_ops) { + for (auto op : all_ops) { + auto in_ops = FindPreOps(op, all_ops); + op->set_in_ops(in_ops); + auto out_ops = FindNextOps(op, all_ops); + op->set_out_ops(out_ops); + } +} + +template +int GetGraphInOutOps(const std::vector &inputs, + const std::vector &outputs, std::vector *in_ops, + std::vector *out_ops, const std::vector &all_ops) { + for (auto in_tensor : inputs) { + for (auto op : all_ops) { + if (find(op->inputs().begin(), op->inputs().end(), in_tensor) != op->inputs().end() && + find(in_ops->begin(), in_ops->end(), op) == in_ops->end()) { + in_ops->push_back(op); + } + } + } + if (in_ops->empty()) { + MS_LOG(ERROR) << "Can't find the input ops for npu sub graph."; + return RET_ERROR; + } + + for (auto out_tensor : outputs) { + for (auto op : all_ops) { + if (find(op->outputs().begin(), op->outputs().end(), out_tensor) != op->outputs().end() && + find(out_ops->begin(), out_ops->end(), op) == out_ops->end()) { + out_ops->push_back(op); + } + } + } + if (out_ops->empty()) { + MS_LOG(ERROR) << "Can't find the output ops for npu sub graph."; + return RET_ERROR; + } + return 
RET_OK; +} +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_SRC_DELEGATE_DELEGATE_UTILS diff --git a/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt b/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt new file mode 100644 index 00000000000..ccf4b2b9a3f --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/CMakeLists.txt @@ -0,0 +1,22 @@ +include_directories(${TENSORRT_PATH}/include) +include_directories(${CUDA_PATH}/include) +file(GLOB_RECURSE TENSORRT_RUNTIME_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../delegate_utils.cc + ) +add_library(libcudart SHARED IMPORTED) +set_target_properties(libcudart PROPERTIES IMPORTED_LOCATION + ${CUDA_LIB_PATH}/libcudart.so) + +add_library(libnvinfer SHARED IMPORTED) +set_target_properties(libnvinfer PROPERTIES IMPORTED_LOCATION + ${TENSORRT_LIB_PATH}/libnvinfer.so) + +add_library(tensorrt_kernel_mid OBJECT ${TENSORRT_RUNTIME_SRC}) +add_dependencies(tensorrt_kernel_mid fbs_src) +target_link_libraries( + tensorrt_kernel_mid + libcudart + libnvinfer +) diff --git a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc new file mode 100644 index 00000000000..059b1bc5814 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/delegate/tensorrt/op/activation_tensorrt.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" + +namespace mindspore::lite { +int ActivationTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + return RET_ERROR; + } + if (out_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + return RET_ERROR; + } + if (type_ != schema::PrimitiveType_Activation) { + MS_LOG(ERROR) << "Unsupported schema type:" << schema::EnumNamePrimitiveType(type_); + return RET_ERROR; + } + return RET_OK; +} +int ActivationTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is invalid"; + return RET_ERROR; + } + auto activation_op = this->op_primitive_->value_as_Activation(); + if (activation_op == nullptr) { + MS_LOG(ERROR) << "op convert failed"; + return RET_ERROR; + } + nvinfer1::ActivationType action_code = ConvertActivationType(activation_op->activation_type()); + + nvinfer1::IActivationLayer *activation_layer = network->addActivation(*tensorrt_in_tensors_[0], action_code); + if (activation_layer == nullptr) { + MS_LOG(ERROR) << "add activation op failed for TensorRT."; + return RET_ERROR; + } + + if (activation_op->alpha() != activation_layer->getAlpha()) { + activation_layer->setAlpha(activation_op->alpha()); + } + activation_layer->setName(op_name_.c_str()); + this->AddInnerOutTensors(activation_layer->getOutput(0)); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h new file mode 100644 index 00000000000..475464798de --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.h @@ -0,0 +1,37 @@ +/** + 
* Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ +#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ +#include +#include +#include "src/delegate/tensorrt/op/tensorrt_op.h" + +namespace mindspore::lite { +class ActivationTensorRT : public TensorRTOp { + public: + ActivationTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, const std::string &name) + : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + + ~ActivationTensorRT() override = default; + + int AddInnerOp(nvinfer1::INetworkDefinition *network) override; + + int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) override; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ diff --git a/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.cc new file mode 100644 index 00000000000..2acd8e808c5 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/op/concate_tensorrt.h" +#include + +namespace mindspore::lite { +int ConcateTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.size() < 1) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + return RET_ERROR; + } + if (out_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + return RET_ERROR; + } + return RET_OK; +} +int ConcateTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is invalid"; + return RET_ERROR; + } + // Concat + auto concate_op = this->op_primitive_->value_as_Concat(); + if (concate_op == nullptr) { + MS_LOG(ERROR) << "concate_op convert failed"; + return RET_ERROR; + } + MS_LOG(INFO) << "in tensort size of concate: " << tensorrt_in_tensors_.size(); + if (tensorrt_in_tensors_.size() != in_tensors_.size()) { + MS_LOG(ERROR) << "concate_op in tensor is invalid"; + return RET_ERROR; + } + + int axis = RET_INVALID_OP_ATTR; + axis = concate_op->axis(); + + nvinfer1::ITensor *trt_input_tensors[tensorrt_in_tensors_.size()]; + std::copy(tensorrt_in_tensors_.begin(), tensorrt_in_tensors_.end(), trt_input_tensors); + + nvinfer1::IConcatenationLayer *concate_layer = + network->addConcatenation(trt_input_tensors, static_cast(tensorrt_in_tensors_.size())); + if (concate_layer == nullptr) { + MS_LOG(ERROR) << "addConcatenation failed for TensorRT."; + return 
RET_ERROR; + } + + if (axis != RET_INVALID_OP_ATTR) { + concate_layer->setAxis(axis); + } + concate_layer->setName(op_name_.c_str()); + this->AddInnerOutTensors(concate_layer->getOutput(0)); + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h new file mode 100644 index 00000000000..afd4ebfa91a --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/concate_tensorrt.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ +#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ +#include +#include +#include "src/delegate/tensorrt/op/tensorrt_op.h" + +namespace mindspore::lite { +class ConcateTensorRT : public TensorRTOp { + public: + ConcateTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, const std::string &name) + : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + + ~ConcateTensorRT() override = default; + + int AddInnerOp(nvinfer1::INetworkDefinition *network) override; + + int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) override; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ diff --git a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc new file mode 100644 index 00000000000..8d6439593de --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/delegate/tensorrt/op/convolution_tensorrt.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" +#include "nnacl/pack.h" + +namespace mindspore::lite { +int ConvolutionTensorRT::IsSupport(const schema::Primitive *primitive, + const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.size() != 2 && in_tensors.size() != 3) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + return RET_ERROR; + } + if (out_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + return RET_ERROR; + } + return RET_OK; +} + +int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is invalid"; + return RET_ERROR; + } + const schema::Conv2DFusion *conv_op = this->op_primitive_->value_as_Conv2DFusion(); + if (conv_op == nullptr) { + MS_LOG(ERROR) << "op action convert failed"; + return RET_ERROR; + } + // transpose: NHWC->NCHW + nvinfer1::IShuffleLayer *transpose_layer_in = NHWC2NCHW(network, *tensorrt_in_tensors_[0]); + if (transpose_layer_in == nullptr) { + MS_LOG(ERROR) << "transpose: NHWC->NCHW failed"; + return RET_ERROR; + } + transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str()); + + // conv + int nbOutputMaps = conv_op->out_channel(); + if (nbOutputMaps <= 0) { + MS_LOG(ERROR) << "out_channel is invalid"; + return RET_ERROR; + } + + nvinfer1::Dims kernelSize{}; + auto kernel_size = conv_op->kernel_size(); + if (kernel_size == nullptr) { + MS_LOG(ERROR) << "kernel_size is null"; + return RET_ERROR; + } + kernelSize.nbDims = static_cast(kernel_size->size()); + for (int i = 0; i < kernelSize.nbDims; i++) { + kernelSize.d[i] = kernel_size->Get(i); + } + + // transpose weight + tensor::MSTensor *weight_tensor = in_tensors_[1]; + nvinfer1::Weights kernelWeights{}; + kernelWeights.count = weight_tensor->ElementsNum(); + if
(lite::ConvertDataType(weight_tensor->data_type()) != nvinfer1::DataType::kFLOAT) { + MS_LOG(WARNING) << "kernelWeights data type is not float"; + } + kernelWeights.type = nvinfer1::DataType::kFLOAT; + std::vector weight_shape = weight_tensor->shape(); + float *src_val = reinterpret_cast(weight_tensor->data()); + pack_weight_ = reinterpret_cast(malloc(weight_tensor->ElementsNum() * sizeof(float))); + if (pack_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + PackNHWCToNCHWFp32(src_val, pack_weight_, weight_shape[0], weight_shape[1] * weight_shape[2], weight_shape[3], 0, 0); + kernelWeights.values = pack_weight_; + + // bias + nvinfer1::Weights biasWeights{}; + if (in_tensors_.size() >= 3) { + tensor::MSTensor *bias_tensor = in_tensors_[2]; + biasWeights.type = ConvertDataType(bias_tensor->data_type()); + biasWeights.values = bias_tensor->data(); + biasWeights.count = bias_tensor->ElementsNum(); + } else { + biasWeights.type = nvinfer1::DataType::kFLOAT; + biasWeights.count = 0; + biasWeights.values = nullptr; + } + + nvinfer1::IConvolutionLayer *conv_layer = + network->addConvolutionNd(*transpose_layer_in->getOutput(0), nbOutputMaps, kernelSize, kernelWeights, biasWeights); + + if (conv_layer == nullptr) { + MS_LOG(ERROR) << "ConvolutionLayer failed"; + return RET_ERROR; + } + conv_layer->setName((op_name_ + "_conv").c_str()); + + // add params + SetAttributes(conv_op, conv_layer); + + // add activation + nvinfer1::ILayer *activation_layer = nullptr; + if (conv_op->activation_type() == schema::ActivationType::ActivationType_NO_ACTIVATION) { + activation_layer = conv_layer; + } else if (conv_op->activation_type() == schema::ActivationType::ActivationType_RELU) { + activation_layer = network->addActivation(*conv_layer->getOutput(0), nvinfer1::ActivationType::kRELU); + if (activation_layer == nullptr) { + MS_LOG(ERROR) << "addActivation for conv failed"; + return RET_ERROR; + } + activation_layer->setName((op_name_ +
"_relu").c_str()); + } else if (conv_op->activation_type() == schema::ActivationType::ActivationType_RELU6) { + auto activation = network->addActivation(*conv_layer->getOutput(0), nvinfer1::ActivationType::kCLIP); + if (activation == nullptr) { + MS_LOG(ERROR) << "addActivation for conv failed"; + return RET_ERROR; + } + activation->setName((op_name_ + "_relu6").c_str()); // name the clip layer through its own handle: activation_layer is still nullptr at this point + activation->setAlpha(0); + activation->setBeta(6); + activation_layer = activation; + } else { + MS_LOG(DEBUG) << "Unsupported op action type for conv TensorRT: " << conv_op->activation_type(); + return RET_ERROR; + } + + // transpose: NCHW->NHWC + nvinfer1::IShuffleLayer *transpose_layer_out = NCHW2NHWC(network, *activation_layer->getOutput(0)); + if (transpose_layer_out == nullptr) { + MS_LOG(ERROR) << "transpose: NCHW->NHWC failed"; + return RET_ERROR; + } + transpose_layer_out->setName((op_name_ + "_transpose2NHWC").c_str()); + + this->AddInnerOutTensors(transpose_layer_out->getOutput(0)); + return RET_OK; +} + +void ConvolutionTensorRT::SetAttributes(const schema::Conv2DFusion *conv_op, nvinfer1::IConvolutionLayer *conv_layer) { + auto stride = conv_op->stride(); + if (stride != nullptr) { + auto stride_val = std::vector(stride->begin(), stride->end()); + auto dims = ConvertCudaDims(stride_val); + conv_layer->setStrideNd(dims); + } + + auto dilation = conv_op->dilation(); + if (dilation != nullptr) { + auto dilation_val = std::vector(dilation->begin(), dilation->end()); + auto dims = ConvertCudaDims(dilation_val); + conv_layer->setDilationNd(dims); + } + int nbGroups = conv_op->group(); + if (nbGroups > 0) { + conv_layer->setNbGroups(nbGroups); + } + + schema::PadMode pad_mode = conv_op->pad_mode(); + if (pad_mode == schema::PadMode::PadMode_SAME) { + conv_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } else { + auto padding = conv_op->pad_list(); + if (padding != nullptr) { + auto 
padding_val = std::vector(padding->begin(), padding->end()); + nvinfer1::Dims dims{}; +
dims.nbDims = 2; + dims.d[0] = padding_val[0]; + dims.d[1] = padding_val[2]; + conv_layer->setPaddingNd(dims); + } + } +} + +ConvolutionTensorRT::~ConvolutionTensorRT() { + if (pack_weight_ != nullptr) { + free(pack_weight_); + pack_weight_ = nullptr; + } +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h new file mode 100644 index 00000000000..0ac71159180 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ +#define MINDSPORE_LITE_SRC_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ +#include +#include +#include "src/delegate/tensorrt/op/tensorrt_op.h" + +namespace mindspore::lite { +class ConvolutionTensorRT : public TensorRTOp { + public: + ConvolutionTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, const std::string &name) + : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + + ~ConvolutionTensorRT() override; + + int AddInnerOp(nvinfer1::INetworkDefinition *network) override; + + int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) override; + + private: + void SetAttributes(const schema::Conv2DFusion *ms_op, nvinfer1::IConvolutionLayer *current_layer_); + + void *pack_weight_{nullptr}; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ diff --git a/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.cc new file mode 100644 index 00000000000..cf7170f9671 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "src/delegate/tensorrt/op/elementwise_tensorrt.h"
+#include <cmath>
+#include <map>
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+// Maps this->type_ to a TensorRT element-wise operation and validates the tensor counts.
+// element_wise_op_ is updated as soon as the lookup succeeds, i.e. before the count checks run.
+// Returns RET_OK for a mapped type with exactly two inputs and one output, RET_ERROR otherwise.
+int ElementWiseTensorRT::IsSupport(const schema::Primitive *primitive,
+                                   const std::vector<tensor::MSTensor *> &in_tensors,
+                                   const std::vector<tensor::MSTensor *> &out_tensors) {
+  // The table never changes, so it is built once instead of on every call.
+  static const std::map<schema::PrimitiveType, nvinfer1::ElementWiseOperation> element_wise_ops = {
+    {schema::PrimitiveType_AddFusion, nvinfer1::ElementWiseOperation::kSUM},
+    {schema::PrimitiveType_PowFusion, nvinfer1::ElementWiseOperation::kPOW},
+    {schema::PrimitiveType_DivFusion, nvinfer1::ElementWiseOperation::kDIV},
+    {schema::PrimitiveType_SubFusion, nvinfer1::ElementWiseOperation::kSUB},
+  };
+  auto iter = element_wise_ops.find(this->type_);
+  if (iter == element_wise_ops.end()) {
+    MS_LOG(ERROR) << "invalid PrimitiveType for ElementWiseTensorRT, PrimitiveType: " << this->type_;
+    return RET_ERROR;
+  }
+  element_wise_op_ = iter->second;
+  if (in_tensors.size() != 2) {
+    MS_LOG(ERROR) << "invalid input tensort size: " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "invalid output tensort size: " << out_tensors.size();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Adds an IElementWiseLayer combining the two inner input tensors with element_wise_op_.
+// When the second MS input has an empty shape it is first converted to an ITensor and
+// appended to the inner inputs. The layer output is registered as the inner output tensor.
+// Returns RET_OK on success, RET_ERROR on any failure.
+int ElementWiseTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  // in_tensors_[0] and in_tensors_[1] are indexed below, so the count is validated together
+  // with the network pointer (the log message already promised both checks).
+  if (network == nullptr || this->in_tensors_.size() != 2) {
+    MS_LOG(ERROR) << "network or input tensor size is invalid";
+    return RET_ERROR;
+  }
+  // create ITensor from MS scalar
+  if (this->in_tensors_[1]->shape().size() == 0) {
+    nvinfer1::ITensor *scalar_input =
+      lite::ConvertScalarToITensor(network, this->in_tensors_[0]->shape().size(), this->in_tensors_[1]->data());
+    if (scalar_input == nullptr) {
+      MS_LOG(ERROR) << "create Itensor from scalar failed";
+      return RET_ERROR;
+    }
+    this->AddInnerInTensors(scalar_input);
+  }
+  // add elementwise
+  if (this->tensorrt_in_tensors_.size() != 2) {
+    MS_LOG(ERROR) << "invalid inner in tensors cnt: " << this->tensorrt_in_tensors_.size();
+    return RET_ERROR;
+  }
+  nvinfer1::IElementWiseLayer *cal_layer =
+    network->addElementWise(*tensorrt_in_tensors_[0], *tensorrt_in_tensors_[1], element_wise_op_);
+
+  if (cal_layer == nullptr) {
+    MS_LOG(ERROR) << "addElementWise failed for TensorRT.";
+    return RET_ERROR;
+  }
+  cal_layer->setName(op_name_.c_str());
+
+  nvinfer1::ITensor *op_out_tensor = cal_layer->getOutput(0);
+  if (op_out_tensor == nullptr) {
+    MS_LOG(ERROR) << "addElementWise out tensor is nullptr.";
+    return RET_ERROR;
+  }
+  // add activation: a null result means "nothing was added", so the layer output is kept.
+  nvinfer1::ITensor *activation_out_tensor = AddActivation(network, op_out_tensor);
+  op_out_tensor = (activation_out_tensor == nullptr) ? op_out_tensor : activation_out_tensor;
+
+  // scale and shift: currently only reported, not applied.
+  if (element_wise_op_ == nvinfer1::ElementWiseOperation::kPOW) {
+    auto pow_op = op_primitive_->value_as_PowFusion();
+    if (pow_op == nullptr) {
+      MS_LOG(ERROR) << "PowFusion convert failed.";
+      return RET_ERROR;
+    }
+    float scale = pow_op->scale();
+    float shift = pow_op->shift();
+    // std::fabs keeps the comparison in floating point; an unqualified abs() may resolve to
+    // the integer overload and truncate the operand before comparing.
+    if (std::fabs(scale - 1) >= 1.0e-05 || std::fabs(shift - 0) >= 1.0e-05) {
+      MS_LOG(WARNING) << "deal with scale and shift for pow op";
+    }
+  }
+
+  op_out_tensor->setName(out_tensors_[0]->tensor_name().c_str());
+  this->AddInnerOutTensors(op_out_tensor);
+  return RET_OK;
+}
+
+// Reads the fused activation type of the Add/Div/Sub primitive and warns when one is set.
+// No activation layer is created here yet, so the returned pointer is always nullptr;
+// `network` and `in_tensor` are currently unused.
+nvinfer1::ITensor *ElementWiseTensorRT::AddActivation(nvinfer1::INetworkDefinition *network,
+                                                      nvinfer1::ITensor *in_tensor) {
+  schema::ActivationType activation = schema::ActivationType::ActivationType_NO_ACTIVATION;
+  switch (element_wise_op_) {
+    case nvinfer1::ElementWiseOperation::kSUM: {
+      auto sum_op = op_primitive_->value_as_AddFusion();
+      if (sum_op == nullptr) {
+        MS_LOG(ERROR) << "AddFusion convert failed.";
+        return nullptr;
+      }
+      activation = sum_op->activation_type();
+      break;
+    }
+    case nvinfer1::ElementWiseOperation::kDIV: {
+      auto div_op = op_primitive_->value_as_DivFusion();
+      if (div_op == nullptr) {
+        MS_LOG(ERROR) << "DivFusion convert failed.";
+        return nullptr;
+      }
+      activation = div_op->activation_type();
+      break;
+    }
+    case nvinfer1::ElementWiseOperation::kSUB: {
+      auto sub_op = op_primitive_->value_as_SubFusion();
+      if (sub_op == nullptr) {
+        MS_LOG(ERROR) << "SubFusion convert failed.";
+        return nullptr;
+      }
+      activation = sub_op->activation_type();
+      break;
+    }
+    default:
+      MS_LOG(INFO) << "no activation need for: " << op_name_;
+  }
+  nvinfer1::ITensor *activation_out_tensor = nullptr;
+  if (activation != schema::ActivationType::ActivationType_NO_ACTIVATION) {
+    MS_LOG(WARNING) << "op: " << op_name_ << " has activation";
+  }
+  return activation_out_tensor;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h
new file mode 100644
index 00000000000..38fe8bfe4f7
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/elementwise_tensorrt.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_
+#include <string>
+#include <vector>
+#include <map>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+// TensorRT delegate op for binary element-wise primitives.
+class ElementWiseTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  ElementWiseTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                      const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~ElementWiseTensorRT() override = default;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+ private:
+  // Default-initialised so the member never holds an indeterminate value.
+  nvinfer1::ElementWiseOperation element_wise_op_{nvinfer1::ElementWiseOperation::kSUM};
+  // Helper taking the network and a tensor; returns a tensor pointer that may be null.
+  nvinfer1::ITensor *AddActivation(nvinfer1::INetworkDefinition *network, nvinfer1::ITensor *in_tensor);
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.cc
new file mode 100644
index 00000000000..d9240a798e1
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/gather_tensorrt.h"
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+// Validates a Gather op: exactly three inputs (data, Int32 indices, single-element axis)
+// and one output. On success axis_ holds the value read from the third input.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int GatherTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                              const std::vector<tensor::MSTensor *> &out_tensors) {
+  if (in_tensors.size() != 3) {
+    MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size();
+    return RET_ERROR;
+  }
+  if (in_tensors[1]->data_type() != kNumberTypeInt32) {
+    MS_LOG(ERROR) << "Gather indices only support Int32";
+    return RET_ERROR;
+  }
+  if (in_tensors[2]->ElementsNum() == 1) {
+    // The axis value is read straight from the tensor buffer, so it must be present.
+    const void *axis_data = in_tensors[2]->data();
+    if (axis_data == nullptr) {
+      MS_LOG(ERROR) << "Gather axis tensor has no data.";
+      return RET_ERROR;
+    }
+    // NOTE(review): the buffer is read as int without checking the axis tensor's data type.
+    axis_ = static_cast<const int *>(axis_data)[0];
+  } else {
+    MS_LOG(ERROR) << "TensorRT axis is attribute.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Converts the indices input to an ITensor and adds an IGatherLayer over the first inner
+// input tensor along axis_. The layer output is registered as the inner output tensor.
+// Returns RET_OK on success, RET_ERROR on any failure.
+int GatherTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  // convert constant MSTensor to ITensor
+  nvinfer1::ITensor *add_tensor = lite::ConvertConstantTensor(network, this->in_tensors_[1]);
+  if (add_tensor == nullptr) {
+    MS_LOG(ERROR) << "add a new tensor failed for TensorRT GatherTensorRTOp.";
+    return RET_ERROR;
+  }
+  nvinfer1::IGatherLayer *gather_layer =
+    network->addGather(*tensorrt_in_tensors_[0], *add_tensor /*indices*/, axis_ /*axis*/);
+  if (gather_layer == nullptr) {
+    MS_LOG(ERROR) << "addGather failed for TensorRT.";
+    return RET_ERROR;
+  }
+  gather_layer->setName(op_name_.c_str());
+  this->AddInnerOutTensors(gather_layer->getOutput(0));
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h
new file mode 100644
index 00000000000..f2c2daf00b6
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/gather_tensorrt.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021
Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_
+#include <string>
+#include <vector>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+// TensorRT delegate op for Gather.
+class GatherTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  GatherTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                 const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~GatherTensorRT() override = default;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+ private:
+  // Both members are value-initialised so they never hold indeterminate values.
+  int axis_{0};
+  tensor::MSTensor *indices_{nullptr};
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc
new file mode 100644
index 00000000000..564e3e6f7ef
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/matmul_tensorrt.h"
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+// Accepts MatMul with two or three inputs (the third is added as a bias in AddInnerOp) and one output.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int MatMulTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                              const std::vector<tensor::MSTensor *> &out_tensors) {
+  if (in_tensors.size() != 2 && in_tensors.size() != 3) {
+    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Adds a matrix-multiply layer (first inner input x converted second input) and, when a third
+// input exists, an element-wise sum with it. The final layer output becomes the inner output.
+// Every pointer obtained from the primitive or the network is checked before use.
+// Returns RET_OK on success, RET_ERROR on any failure.
+int MatMulTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  auto primitive = this->GetPrimitive()->value_as_MatMul();
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "MatMul convert failed.";
+    return RET_ERROR;
+  }
+  transpose_a_ = primitive->transpose_a() ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
+  transpose_b_ = primitive->transpose_b() ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
+  auto weight = ConvertTensorWithExpandDims(network, in_tensors_[1], in_tensors_[0]->shape().size());
+  if (weight == nullptr) {
+    MS_LOG(ERROR) << "convert weight tensor failed for: " << op_name_;
+    return RET_ERROR;
+  }
+
+  auto matmul_layer = network->addMatrixMultiply(*tensorrt_in_tensors_[0], transpose_a_, *weight, transpose_b_);
+  if (matmul_layer == nullptr) {
+    MS_LOG(ERROR) << "addMatrixMultiply failed for: " << op_name_;
+    return RET_ERROR;
+  }
+  matmul_layer->setName(op_name_.c_str());
+
+  if (in_tensors_.size() == 3) {
+    auto bias = ConvertTensorWithExpandDims(network, in_tensors_[2], in_tensors_[0]->shape().size());
+    if (bias == nullptr) {
+      MS_LOG(ERROR) << "convert bias tensor failed for: " << op_name_;
+      return RET_ERROR;
+    }
+    auto bias_layer = network->addElementWise(*matmul_layer->getOutput(0), *bias, nvinfer1::ElementWiseOperation::kSUM);
+    if (bias_layer == nullptr) {
+      MS_LOG(ERROR) << "add bias layer failed for: " << op_name_;
+      return RET_ERROR;
+    }
+    auto bias_layer_name = op_name_ + "_bias";
+    bias_layer->setName(bias_layer_name.c_str());
+    this->AddInnerOutTensors(bias_layer->getOutput(0));
+  } else {
+    this->AddInnerOutTensors(matmul_layer->getOutput(0));
+  }
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h
new file mode 100644
index 00000000000..5fd08670fd2
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_
+#include <utility>
+#include <vector>
+#include <string>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+// TensorRT delegate op for MatMul.
+class MatMulTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  MatMulTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                 const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~MatMulTensorRT() override = default;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+ private:
+  // Matrix operations for the two operands; both default to kNONE.
+  nvinfer1::MatrixOperation transpose_a_ = nvinfer1::MatrixOperation::kNONE;
+  nvinfer1::MatrixOperation transpose_b_ = nvinfer1::MatrixOperation::kNONE;
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.cc
new file mode 100644
index 00000000000..dcf0d456490
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.cc
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/reduce_tensorrt.h"
+
+namespace mindspore::lite {
+// Validates a ReduceFusion op: two inputs (data, axes), one output and a reduce mode that
+// has an entry in reduce_ops_. On success reduce_op_ holds the mapped TensorRT operation.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int ReduceTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                              const std::vector<tensor::MSTensor *> &out_tensors) {
+  auto reduce_op = primitive->value_as_ReduceFusion();
+  if (reduce_op == nullptr) {
+    MS_LOG(ERROR) << "convert failed";
+    return RET_ERROR;
+  }
+  // AddInnerOp indexes in_tensors_[1] and out_tensors_[0], so a wrong count must be rejected
+  // here rather than only logged.
+  if (in_tensors.size() != 2) {
+    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
+    return RET_ERROR;
+  }
+  auto it = reduce_ops_.find(reduce_op->mode());
+  if (it != reduce_ops_.end()) {
+    reduce_op_ = it->second;
+  } else {
+    MS_LOG(ERROR) << "unsupported ReduceMode: " << reduce_op->mode();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Builds the reduce-axes bitmask from the axes input tensor and adds an IReduceLayer over the
+// first inner input tensor. The layer output is named after the MS output tensor and
+// registered as the inner output. Returns RET_OK on success, RET_ERROR on any failure.
+int ReduceTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  auto reduce_op = op_primitive_->value_as_ReduceFusion();
+  if (reduce_op == nullptr) {
+    MS_LOG(ERROR) << "convert failed";
+    return RET_ERROR;
+  }
+  bool keep_dims = reduce_op->keep_dims();
+  // axis
+  uint32_t reduceAxes = 0;
+  tensor::MSTensor *axis_tensor = this->in_tensors_[1];
+  if (axis_tensor->data() == nullptr) {
+    MS_LOG(ERROR) << "invalid axis_tensor";
+    return RET_ERROR;
+  }
+  if (axis_tensor->data_type() != TypeId::kNumberTypeInt32) {
+    // NOTE(review): the buffer is still read as int below; confirm whether this should be an error.
+    MS_LOG(WARNING) << "not int data type";
+  }
+  const int input_rank = static_cast<int>(this->in_tensors_[0]->shape().size());
+  const int *axis_data = reinterpret_cast<const int *>(axis_tensor->data());
+  for (int i = 0; i < axis_tensor->ElementsNum(); i++) {
+    // Negative axes count from the last dimension; anything still outside [0, rank) would
+    // make the shift below meaningless (or undefined), so it is rejected.
+    int axis = axis_data[i];
+    if (axis < 0) {
+      axis += input_rank;
+    }
+    if (axis < 0 || axis >= input_rank) {
+      MS_LOG(ERROR) << "invalid reduce axis: " << axis_data[i] << " for: " << op_name_;
+      return RET_ERROR;
+    }
+    // One bit per reduced dimension. The previous `16 - (1u << axis)` also set every higher
+    // bit up to bit 3, i.e. it reduced all dimensions from `axis` through 3.
+    reduceAxes |= (1u << axis);
+  }
+  MS_LOG(INFO) << "reduceAxes: " << reduceAxes;
+  nvinfer1::IReduceLayer *layer = network->addReduce(*tensorrt_in_tensors_[0], reduce_op_, reduceAxes, keep_dims);
+  if (layer == nullptr) {
+    MS_LOG(ERROR) << "addReduce failed for TensorRT.";
+    return RET_ERROR;
+  }
+  layer->setName(op_name_.c_str());
+
+  nvinfer1::ITensor *out_tensor = layer->getOutput(0);
+  if (out_tensor == nullptr) {
+    MS_LOG(ERROR) << "addReduce output tensor create failed for TensorRT.";
+    return RET_ERROR;
+  }
+  out_tensor->setName(out_tensors_[0]->tensor_name().c_str());
+  this->AddInnerOutTensors(out_tensor);
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h
new file mode 100644
index 00000000000..82db48991bc
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/reduce_tensorrt.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_
+
+#include <string>
+#include <vector>
+#include <map>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+// TensorRT delegate op for ReduceFusion (mean, max, min, prod and sum modes).
+class ReduceTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  ReduceTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                 const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~ReduceTensorRT() override = default;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+ private:
+  // Schema reduce mode -> TensorRT reduce operation; modes without an entry are unsupported.
+  std::map<schema::ReduceMode, nvinfer1::ReduceOperation> reduce_ops_ = {
+    {schema::ReduceMode::ReduceMode_ReduceMean, nvinfer1::ReduceOperation::kAVG},
+    {schema::ReduceMode::ReduceMode_ReduceMax, nvinfer1::ReduceOperation::kMAX},
+    {schema::ReduceMode::ReduceMode_ReduceMin, nvinfer1::ReduceOperation::kMIN},
+    {schema::ReduceMode::ReduceMode_ReduceProd, nvinfer1::ReduceOperation::kPROD},
+    {schema::ReduceMode::ReduceMode_ReduceSum, nvinfer1::ReduceOperation::kSUM},
+  };
+  // Default-initialised so the member never holds an indeterminate value.
+  nvinfer1::ReduceOperation reduce_op_{nvinfer1::ReduceOperation::kSUM};
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.cc
new file mode 100644
index 00000000000..0d41750c5e0
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.cc
@@ -0,0 +1,132 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <numeric>
+#include <functional>
+#include "src/delegate/tensorrt/op/scale_tensorrt.h"
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+// Accepts Scale with two to four inputs (data, scale, optional shift, optional power) and one output.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int ScaleTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                             const std::vector<tensor::MSTensor *> &out_tensors) {
+  if (in_tensors.size() != 2 && in_tensors.size() != 3 && in_tensors.size() != 4) {
+    MS_LOG(ERROR) << "Unsupported input tensor size, size is: " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "Unsupported output tensor size, size is: " << out_tensors.size();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Adds an IScaleLayer computing (input * scale + shift) ^ power. The scale mode is derived from
+// the weight shape relative to the input shape; inputs with fewer than four dims are first
+// reshaped through AddUnsqueezeOp(). The layer output is named after the MS output tensor and
+// registered as the inner output. Returns RET_OK on success, RET_ERROR on any failure.
+int ScaleTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  // in_tensors_[1] is read unconditionally below.
+  if (in_tensors_.size() < 2) {
+    MS_LOG(ERROR) << "Unsupported input tensor size, size is: " << in_tensors_.size();
+    return RET_ERROR;
+  }
+  auto scale_op = op_primitive_->value_as_ScaleFusion();
+  if (scale_op == nullptr) {
+    MS_LOG(ERROR) << "convert failed";
+    return RET_ERROR;
+  }
+
+  schema::ActivationType activation_type = scale_op->activation_type();
+  nvinfer1::ITensor *scale_in_tensor = tensorrt_in_tensors_[0];
+  // unsqueeze input Itensor to 4 dims
+  if (in_tensors_[0]->shape().size() < 4) {
+    scale_in_tensor = AddUnsqueezeOp(network);
+    if (scale_in_tensor == nullptr) {
+      MS_LOG(ERROR) << "AddUnsqueezeOp failed";
+      return RET_ERROR;
+    }
+  }
+  // mode of scale
+  size_t axis = scale_op->axis();
+  nvinfer1::ScaleMode mode;
+  auto input_data_shape = in_tensors_[0]->shape();
+  auto input_weight_shape = in_tensors_[1]->shape();
+  int total = std::accumulate(input_data_shape.begin(), input_data_shape.end(), 1, std::multiplies<int>());
+  MS_LOG(INFO) << "input tensor element cnt: " << total;
+  if (input_weight_shape.size() == 0 || (input_weight_shape.size() == 1 && input_weight_shape[0] == 1)) {
+    mode = nvinfer1::ScaleMode::kUNIFORM;
+  } else if (axis < input_data_shape.size() && input_weight_shape.size() == 1 &&
+             input_data_shape[axis] == input_weight_shape[0]) {
+    mode = nvinfer1::ScaleMode::kCHANNEL;
+  } else if (input_weight_shape.size() == 1 && input_weight_shape[0] == total) {
+    mode = nvinfer1::ScaleMode::kELEMENTWISE;
+  } else {
+    MS_LOG(ERROR) << "ScaleMode create failed";
+    return RET_ERROR;
+  }
+  // (input * scale + shift) ^ power
+  nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0};
+  nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
+  nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, 0};
+  scale.values = in_tensors_[1]->data();
+  scale.count = in_tensors_[1]->ElementsNum();
+  // Any weight whose shape is not exactly 1-D takes the addScaleNd path.
+  const bool nd = input_weight_shape.size() != 1;
+  if (in_tensors_.size() >= 3) {
+    shift.values = in_tensors_[2]->data();
+    shift.count = in_tensors_[2]->ElementsNum();
+  }
+  if (in_tensors_.size() >= 4) {
+    power.values = in_tensors_[3]->data();
+    power.count = in_tensors_[3]->ElementsNum();
+  }
+  nvinfer1::IScaleLayer *cal_layer = nullptr;
+  if (nd) {
+    MS_LOG(WARNING) << "multi dims ScaleMode enter";
+    cal_layer = network->addScaleNd(*scale_in_tensor, mode, shift, scale, power, axis);
+  } else {
+    cal_layer = network->addScale(*scale_in_tensor, mode, shift, scale, power);
+  }
+
+  if (cal_layer == nullptr) {
+    MS_LOG(ERROR) << "addScaleNd failed for: " << op_name_;
+    return RET_ERROR;
+  }
+  cal_layer->setName(op_name_.c_str());
+  nvinfer1::ITensor *op_out_tensor = cal_layer->getOutput(0);
+  if (op_out_tensor == nullptr) {
+    MS_LOG(ERROR) << "scale output tensor is nullptr for: " << op_name_;
+    return RET_ERROR;
+  }
+
+  // add activation: currently only reported, no activation layer is created.
+  if (activation_type != schema::ActivationType::ActivationType_NO_ACTIVATION) {
+    MS_LOG(WARNING) << "need activation for: " << op_name_;
+  }
+  op_out_tensor->setName(out_tensors_[0]->tensor_name().c_str());
+  this->AddInnerOutTensors(op_out_tensor);
+  return RET_OK;
+}
+
+// Adds a shuffle layer that reshapes the first inner input tensor to four dims by appending
+// trailing 1s to the MS input shape. Returns the reshaped tensor, or nullptr on failure.
+nvinfer1::ITensor *ScaleTensorRT::AddUnsqueezeOp(nvinfer1::INetworkDefinition *network) {
+  nvinfer1::IShuffleLayer *unsqueeze_layer = network->addShuffle(*this->tensorrt_in_tensors_[0]);
+  if (unsqueeze_layer == nullptr) {
+    MS_LOG(ERROR) << "addShuffle failed for: " << op_name_;
+    return nullptr;
+  }
+  unsqueeze_layer->setName((op_name_ + "_unsqueeze").c_str());
+  auto unsqueeze_shape = in_tensors_[0]->shape();
+  // Pad until the shape really has four entries. The previous loop compared its counter with
+  // `4 - size()` while size() was growing, so it stopped early and never produced four dims
+  // for inputs with fewer than three.
+  while (unsqueeze_shape.size() < 4) {
+    unsqueeze_shape.push_back(1);
+  }
+  nvinfer1::Dims unsqueeze_dims = lite::ConvertCudaDims(unsqueeze_shape);
+  unsqueeze_layer->setReshapeDimensions(unsqueeze_dims);
+  return unsqueeze_layer->getOutput(0);
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h
new file mode 100644
index 00000000000..1b596c72637
--- /dev/null
+++
b/mindspore/lite/src/delegate/tensorrt/op/scale_tensorrt.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_
+#include <string>
+#include <vector>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+// NOTE(review): these using-declarations are visible to every file including this header.
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+namespace mindspore::lite {
+// TensorRT delegate op for Scale.
+class ScaleTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  ScaleTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~ScaleTensorRT() override = default;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+ private:
+  // Helper taking the network; returns a tensor pointer that may be null.
+  nvinfer1::ITensor *AddUnsqueezeOp(nvinfer1::INetworkDefinition *network);
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.cc
new file mode 100644
index 00000000000..c9bd4add943
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.cc
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei
Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/shape_tensorrt.h"
+
+namespace mindspore::lite {
+// Accepts Shape with exactly one input and one output.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int ShapeTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                             const std::vector<tensor::MSTensor *> &out_tensors) {
+  if (in_tensors.size() != 1) {
+    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Adds an IShapeLayer over the first inner input tensor and registers its output as the
+// inner output tensor. Returns RET_OK on success, RET_ERROR on any failure.
+int ShapeTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  nvinfer1::IShapeLayer *shape_layer = network->addShape(*tensorrt_in_tensors_[0]);
+
+  if (shape_layer == nullptr) {
+    // This path returns an error, so it is logged at ERROR like the sibling ops do.
+    MS_LOG(ERROR) << "add shape op failed for TensorRT.";
+    return RET_ERROR;
+  }
+  shape_layer->setName(op_name_.c_str());
+  nvinfer1::ITensor *out_tensor = shape_layer->getOutput(0);
+  if (out_tensor == nullptr) {
+    MS_LOG(ERROR) << "shape output tensor is nullptr for: " << op_name_;
+    return RET_ERROR;
+  }
+  this->AddInnerOutTensors(out_tensor);
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h
new file mode 100644
index 00000000000..d7500cc7f63
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/shape_tensorrt.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License,
Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_
+#include <string>
+#include <vector>
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+// TensorRT delegate op for Shape.
+class ShapeTensorRT : public TensorRTOp {
+ public:
+  // Forwards the primitive, the input/output tensor lists and the op name to TensorRTOp.
+  ShapeTensorRT(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}
+
+  ~ShapeTensorRT() override = default;
+
+  // Adds this op to the given TensorRT network definition; returns a status code.
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
+
+  // Reports through the returned status code whether this primitive and tensor lists are supported.
+  int IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                const std::vector<tensor::MSTensor *> &out_tensors) override;
+
+ protected:
+  // Layer pointer, null by default. NOTE(review): no use of it is visible in this header.
+  nvinfer1::ILayer *layer_ = nullptr;
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.cc
new file mode 100644
index 00000000000..339a67c5d8b
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.cc
@@ -0,0 +1,231 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/shuffle_tensorrt.h"
+#include <algorithm>
+#include <cstdint>
+#include <vector>
+
+namespace mindspore::lite {
+// Validates the tensor counts: Squeeze/Unsqueeze need one input, Transpose needs two
+// (data, perm), and every op handled here needs exactly one output.
+// Returns RET_OK when supported, RET_ERROR otherwise.
+int ShuffleTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<tensor::MSTensor *> &in_tensors,
+                               const std::vector<tensor::MSTensor *> &out_tensors) {
+  if ((type_ == schema::PrimitiveType::PrimitiveType_Squeeze ||
+       type_ == schema::PrimitiveType::PrimitiveType_Unsqueeze) &&
+      in_tensors.size() != 1) {
+    MS_LOG(ERROR) << "invalid input tensort size: " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if ((type_ == schema::PrimitiveType::PrimitiveType_Transpose) && in_tensors.size() != 2) {
+    MS_LOG(ERROR) << "invalid input tensort size: " << in_tensors.size();
+    return RET_ERROR;
+  }
+  if (out_tensors.size() != 1) {
+    MS_LOG(ERROR) << "invalid output tensort size: " << out_tensors.size();
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+// Adds one IShuffleLayer over the first inner input tensor and configures it according to the
+// op type (Unsqueeze, Squeeze, Transpose or Reshape). The layer output is named after the MS
+// output tensor and registered as the inner output.
+// Returns RET_OK on success, otherwise the failing helper's status or RET_ERROR.
+int ShuffleTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
+  if (network == nullptr) {
+    MS_LOG(ERROR) << "network is invalid";
+    return RET_ERROR;
+  }
+  nvinfer1::IShuffleLayer *shuffle_layer = network->addShuffle(*tensorrt_in_tensors_[0]);
+  if (shuffle_layer == nullptr) {
+    MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
+    return RET_ERROR;
+  }
+  shuffle_layer->setName(op_name_.c_str());
+
+  switch (this->type()) {
+    case schema::PrimitiveType_Unsqueeze: {
+      int ret = AddUnsqueezeOp(shuffle_layer);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "AddUnSqueezeOp failed.";
+        return ret;
+      }
+      break;
+    }
+    case schema::PrimitiveType_Squeeze: {
+      int ret = AddSqueezeOp(shuffle_layer);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "AddSqueezeOp failed.";
+        return ret;
+      }
+      break;
+    }
+    case schema::PrimitiveType_Transpose: {
+      int ret = AddTransposeOp(shuffle_layer);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "AddTransposeOpss failed.";
+        return ret;
+      }
+      break;
+    }
+    case schema::PrimitiveType_Reshape: {
+      int ret = AddReshapeOp(shuffle_layer);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "AddReshapeOp failed.";
+        return ret;
+      }
+      break;
+    }
+    default:
+      MS_LOG(ERROR) << "Unsupported op type.";
+      return RET_ERROR;
+  }
+
+  nvinfer1::ITensor *out_tensor = shuffle_layer->getOutput(0);
+  if (out_tensor == nullptr) {
+    MS_LOG(ERROR) << "output tensor create failed";
+    return RET_ERROR;
+  }
+  out_tensor->setName(out_tensors_[0]->tensor_name().c_str());
+  this->AddInnerOutTensors(out_tensor);
+  return RET_OK;
+}
+
+// Configures the shuffle layer to reshape to the MS input shape with the primitive's axes removed.
+// Negative axes count from the last dimension; axes outside the shape are rejected.
+// Returns RET_OK on success, RET_ERROR otherwise.
+int ShuffleTensorRT::AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
+  // squeeze
+  auto squeeze_op = this->op_primitive_->value_as_Squeeze();
+  if (squeeze_op == nullptr) {
+    MS_LOG(ERROR) << "SqueezeOp convert failed";
+    return RET_ERROR;
+  }
+
+  // axis
+  auto squeeze_shape = in_tensors_[0]->shape();
+  auto axis = squeeze_op->axis();
+  if (axis == nullptr) {
+    MS_LOG(ERROR) << "AddSqueezeOp has invalid axis";
+    return RET_ERROR;
+  }
+
+  const int64_t rank = static_cast<int64_t>(squeeze_shape.size());
+  std::vector<int64_t> squeeze_axes;
+  squeeze_axes.reserve(axis->size());
+  for (size_t i = 0; i < axis->size(); i++) {
+    int64_t dim = axis->Get(i);
+    if (dim < 0) {
+      dim += rank;
+    }
+    if (dim < 0 || dim >= rank) {
+      MS_LOG(ERROR) << "AddSqueezeOp has invalid axis";
+      return RET_ERROR;
+    }
+    if (squeeze_shape[dim] != 1) {
+      MS_LOG(WARNING) << "squeeze_shape value is not 1, need check";
+    }
+    squeeze_axes.push_back(dim);
+  }
+  // Every axis refers to the original shape, so dims are erased from the highest index down;
+  // erasing in list order would shift the remaining indices and drop the wrong dims.
+  std::sort(squeeze_axes.begin(), squeeze_axes.end());
+  squeeze_axes.erase(std::unique(squeeze_axes.begin(), squeeze_axes.end()), squeeze_axes.end());
+  for (auto it = squeeze_axes.rbegin(); it != squeeze_axes.rend(); ++it) {
+    squeeze_shape.erase(squeeze_shape.begin() + *it);
+  }
+
+  nvinfer1::Dims squeeze_dims = lite::ConvertCudaDims(squeeze_shape);
+  MS_LOG(INFO) << "AddSqueezeOp: " << op_name_ << " squeeze_dims.nbDims: " << squeeze_dims.nbDims;
+
+  shuffle_layer->setReshapeDimensions(squeeze_dims);
+  return shuffle_layer->getOutput(0) == nullptr ? RET_ERROR : RET_OK;
+}
+
+// Configures the shuffle layer to reshape to the MS input shape with a 1 inserted at each of the
+// primitive's axes, applied in list order. Negative axes count from one past the last dimension
+// of the shape built so far; out-of-range axes are rejected.
+// Returns RET_OK on success, RET_ERROR otherwise.
+int ShuffleTensorRT::AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
+  // Unsqueeze
+  auto unsqueeze_op = this->op_primitive_->value_as_Unsqueeze();
+  if (unsqueeze_op == nullptr) {
+    MS_LOG(ERROR) << "AddUnsqueezeOp convert failed";
+    return RET_ERROR;
+  }
+  if (in_tensors_.size() != 1) {
+    MS_LOG(WARNING) << "AddUnsqueezeOp size of in tensort needs check: " << in_tensors_.size();
+  }
+  // axis
+  auto unsqueeze_shape = in_tensors_[0]->shape();
+  auto axis = unsqueeze_op->axis();
+  if (axis == nullptr) {
+    MS_LOG(ERROR) << "AddUnsqueezeOp has invalid axis";
+    return RET_ERROR;
+  }
+
+  for (size_t i = 0; i < axis->size(); i++) {
+    const int64_t cur_rank = static_cast<int64_t>(unsqueeze_shape.size());
+    int64_t pos = axis->Get(i);
+    if (pos < 0) {
+      pos += cur_rank + 1;
+    }
+    if (pos < 0 || pos > cur_rank) {
+      MS_LOG(ERROR) << "AddUnsqueezeOp has invalid axis";
+      return RET_ERROR;
+    }
+    // begin() is re-evaluated on every pass: insert() may reallocate and invalidate iterators.
+    unsqueeze_shape.insert(unsqueeze_shape.begin() + pos, 1);
+  }
+
+  nvinfer1::Dims unsqueeze_dims = lite::ConvertCudaDims(unsqueeze_shape);
+  MS_LOG(INFO) << "AddUnsqueezeOp: " << op_name_ << " unsqueeze_dims.nbDims: " << unsqueeze_dims.nbDims;
+
+  shuffle_layer->setReshapeDimensions(unsqueeze_dims);
+  return shuffle_layer->getOutput(0) == nullptr ? RET_ERROR : RET_OK;
+}
+
+// Sets the shuffle layer's first transpose from the permutation stored in the second MS input.
+// The permutation must have data and as many entries as the inner input tensor has dims.
+// Returns RET_OK on success, RET_ERROR otherwise.
+int ShuffleTensorRT::AddTransposeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
+  auto transpose_op = this->op_primitive_->value_as_Transpose();
+  if (transpose_op == nullptr) {
+    MS_LOG(ERROR) << "AddTransposeOp convert failed";
+    return RET_ERROR;
+  }
+  if (in_tensors_.size() != 2) {
+    MS_LOG(ERROR) << "AddTransposeOp size of in tensort needs check: " << in_tensors_.size();
+    return RET_ERROR;
+  }
+  // perm
+  tensor::MSTensor *perm_ternsor = in_tensors_[1];
+  if (perm_ternsor->data() == nullptr ||
+      perm_ternsor->ElementsNum() != tensorrt_in_tensors_[0]->getDimensions().nbDims) {
+    MS_LOG(ERROR) << "AddTransposeOp perm_ternsor data is invalid.";
+    return RET_ERROR;
+  }
+  // NOTE(review): the buffer is read as int without checking the perm tensor's data type.
+  const int *perm_data = reinterpret_cast<const int *>(perm_ternsor->data());
+
+  nvinfer1::Permutation perm{};
+  for (int i = 0; i < perm_ternsor->ElementsNum(); i++) {
+    perm.order[i] = perm_data[i];
+  }
+  shuffle_layer->setFirstTranspose(perm);
+  return RET_OK;
+}
+int ShuffleTensorRT::AddReshapeOp(nvinfer1::IShuffleLayer
*shuffle_layer) { + auto reshape_op = this->op_primitive_->value_as_Reshape(); + if (reshape_op == nullptr) { + MS_LOG(ERROR) << "AddReshapeOp convert failed"; + return RET_ERROR; + } + if (in_tensors_.size() != 2) { + MS_LOG(ERROR) << "AddReshapeOp size of in tensort needs check: " << in_tensors_.size(); + return RET_ERROR; + } + tensor::MSTensor *shape_tensor = in_tensors_[1]; + nvinfer1::Dims reshape_dims = ConvertCudaDims(shape_tensor->data(), shape_tensor->ElementsNum()); + int ret = InferReshapeDims(tensorrt_in_tensors_[0]->getDimensions(), &reshape_dims); + if (ret != RET_OK) { + MS_LOG(ERROR) << "invalid dims for reshape " << op_name_; + return ret; + } + shuffle_layer->setReshapeDimensions(reshape_dims); + return RET_OK; +} +int ShuffleTensorRT::InferReshapeDims(nvinfer1::Dims input_dims, nvinfer1::Dims *reshape_dims) { + int infer_index = -1; + int known_cnt = 1; + for (int i = 0; i < reshape_dims->nbDims; i++) { + if (reshape_dims->d[i] == 0) { + reshape_dims->d[i] = input_dims.d[i]; + known_cnt *= input_dims.d[i]; + } else if (reshape_dims->d[i] == -1) { + if (infer_index != -1) { + MS_LOG(ERROR) << "invalid dims (more than one infer dim) for reshape " << op_name_; + return RET_ERROR; + } + infer_index = i; + } else { + known_cnt *= input_dims.d[i]; + } + } + if (infer_index != -1) { + size_t tot_cnt = 1; + for (int i = 0; i < input_dims.nbDims; i++) { + tot_cnt *= input_dims.d[i]; + } + reshape_dims->d[infer_index] = tot_cnt / known_cnt; + MS_LOG(INFO) << "reshape infer_index: " << infer_index << ", reshape infer value: " << reshape_dims->d[infer_index]; + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h new file mode 100644 index 00000000000..09243a1ebb1 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/shuffle_tensorrt.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_
+#include
+#include
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+#include "src/delegate/tensorrt/tensorrt_utils.h"
+
+namespace mindspore::lite {
+class ShuffleTensorRT : public TensorRTOp {  // one class for Squeeze/Unsqueeze/Transpose/Reshape (see helpers below)
+ public:
+  ShuffleTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors,
+                  const std::vector &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}  // forwards every argument to the base class
+
+  ~ShuffleTensorRT() override = default;
+
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;  // adds the op's layer(s) to `network`
+
+  int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors,
+                const std::vector &out_tensors) override;  // returns an int status code
+
+ private:
+  int AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer);    // per-op-type helper; returns an int status code
+  int AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer);  // per-op-type helper; returns an int status code
+  int AddTransposeOp(nvinfer1::IShuffleLayer *shuffle_layer);  // per-op-type helper; returns an int status code
+  int AddReshapeOp(nvinfer1::IShuffleLayer *shuffle_layer);    // per-op-type helper; returns an int status code
+  int InferReshapeDims(nvinfer1::Dims input_dims, nvinfer1::Dims *reshape_dims);  // reshape_dims is an in/out pointer
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.cc
new file mode 100644
index 00000000000..b4bf6ceca2a
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/op/softmax_tensorrt.h" + +namespace mindspore::lite { +int SoftMaxTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (primitive->value_type() == schema::PrimitiveType::PrimitiveType_LogSoftmax) { + with_log_ = true; + auto softmax_op = primitive->value_as_LogSoftmax(); + if (softmax_op == nullptr) { + MS_LOG(ERROR) << "LogSoftmax convert failed"; + return RET_ERROR; + } + } else { + auto softmax_op = primitive->value_as_Softmax(); + if (softmax_op == nullptr) { + MS_LOG(ERROR) << "convert failed"; + return RET_ERROR; + } + } + + if (in_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + return RET_ERROR; + } + if (out_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + return RET_ERROR; + } + return RET_OK; +} + +int SoftMaxTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is invalid"; + return RET_ERROR; + } + nvinfer1::ISoftMaxLayer *softmax_layer_ = AddSoftMaxOp(network); + if (softmax_layer_ == nullptr) { + MS_LOG(ERROR) << "add softmax op failed for TensorRT."; + return 
RET_ERROR; + } + softmax_layer_->setName((op_name_ + "_softmax").c_str()); + + nvinfer1::ITensor *out_tensor = softmax_layer_->getOutput(0); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "softmax output tensor create failed for TensorRT."; + return RET_ERROR; + } + if (with_log_) { + nvinfer1::IUnaryLayer *log_layer = network->addUnary(*out_tensor, nvinfer1::UnaryOperation::kLOG); + if (log_layer == nullptr) { + MS_LOG(ERROR) << "add log op failed for TensorRT."; + return RET_ERROR; + } + log_layer->setName((op_name_ + "_log").c_str()); + out_tensor = log_layer->getOutput(0); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "softmax log output tensor create failed for TensorRT."; + return RET_ERROR; + } + } + out_tensor->setName(out_tensors_[0]->tensor_name().c_str()); + this->AddInnerOutTensors(out_tensor); + return RET_OK; +} + +nvinfer1::ISoftMaxLayer *SoftMaxTensorRT::AddSoftMaxOp(nvinfer1::INetworkDefinition *network) { + nvinfer1::ISoftMaxLayer *current_layer_ = network->addSoftMax(*this->GetInnerInTensors()[0]); + if (current_layer_ == nullptr) { + MS_LOG(ERROR) << "add softmax op failed for TensorRT."; + return nullptr; + } + std::vector axis_val; + if (with_log_) { + auto softmax_op = this->GetPrimitive()->value_as_LogSoftmax(); + if (softmax_op == nullptr) { + MS_LOG(ERROR) << "LogSoftmax convert failed"; + return nullptr; + } + int64_t axis = softmax_op->axis(); + axis_val.push_back(axis); + } else { + auto softmax_op = this->GetPrimitive()->value_as_Softmax(); + if (softmax_op == nullptr) { + MS_LOG(ERROR) << "Softmax convert failed"; + return nullptr; + } + auto axis = softmax_op->axis(); + axis_val = std::vector(axis->begin(), axis->end()); + } + + if (axis_val.size() != 1) { + MS_LOG(WARNING) << "axis needs check"; + } + + if (axis_val[0] >= this->tensorrt_in_tensors_[0]->getDimensions().nbDims) { + MS_LOG(ERROR) << "axis is larger than input tensor dims."; + return nullptr; + } + current_layer_->setAxes(axis_val[0]); + return current_layer_; +} 
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h
new file mode 100644
index 00000000000..26108ae9fd3
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/softmax_tensorrt.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_
+#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_
+#include
+#include
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {
+class SoftMaxTensorRT : public TensorRTOp {  // TensorRTOp subclass for softmax; with_log_ selects a log variant
+ public:
+  SoftMaxTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors,
+                  const std::vector &out_tensors, const std::string &name)
+      : TensorRTOp(primitive, in_tensors, out_tensors, name) {}  // forwards every argument to the base class
+
+  ~SoftMaxTensorRT() override = default;
+
+  int AddInnerOp(nvinfer1::INetworkDefinition *network) override;  // adds the op's layer(s) to `network`
+
+  int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors,
+                const std::vector &out_tensors) override;  // returns an int status code
+
+ private:
+  bool with_log_ = false;  // defaults to false; presumably set when the primitive is LogSoftmax (see the .cc)
+  nvinfer1::ISoftMaxLayer *AddSoftMaxOp(nvinfer1::INetworkDefinition *network);  // returns the created layer pointer
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc
b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc
new file mode 100644
index 00000000000..e39ad7274b7
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/delegate/tensorrt/op/tensorrt_op.h"
+
+namespace mindspore::lite {  // plain accessors and mutators for TensorRTOp's members follow
+const schema::Primitive *TensorRTOp::GetPrimitive() { return this->op_primitive_; }
+
+void TensorRTOp::AddInnerInTensors(nvinfer1::ITensor *tensor) { this->tensorrt_in_tensors_.push_back(tensor); }
+
+void TensorRTOp::AddInnerOutTensors(nvinfer1::ITensor *tensor) { this->tensorrt_out_tensors_.push_back(tensor); }
+
+std::vector &TensorRTOp::GetInnerOutTensor() { return this->tensorrt_out_tensors_; }  // mutable reference
+
+std::vector &TensorRTOp::GetInnerInTensors() { return this->tensorrt_in_tensors_; }  // mutable reference
+
+std::string TensorRTOp::GetOpName() { return this->op_name_; }  // returns a copy of the name
+
+std::vector &TensorRTOp::inputs() { return this->in_tensors_; }  // mutable reference
+
+std::vector &TensorRTOp::outputs() { return this->out_tensors_; }  // mutable reference
+
+schema::PrimitiveType TensorRTOp::type() const { return this->type_; }
+
+void TensorRTOp::set_in_ops(const std::vector &in_ops) { this->in_ops_ = in_ops; }  // copies the vector
+
+void TensorRTOp::set_out_ops(const std::vector &out_ops) { this->out_ops_ = out_ops; }  // copies the vector
+
+const std::vector &TensorRTOp::in_ops() const { return this->in_ops_; }
+
+const std::vector &TensorRTOp::out_ops() const { return this->out_ops_; }
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h
new file mode 100644
index 00000000000..1a28d2406dd
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
+#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
+
+#include
+#include
+#include
+#include
+#include "include/kernel.h"
+#include "src/common/log_adapter.h"
+#include "include/errorcode.h"
+
+namespace mindspore::lite {
+class TensorRTOp {  // abstract base: subclasses implement IsSupport() and AddInnerOp()
+ public:
+  explicit TensorRTOp(const schema::Primitive *primitive, std::vector in_tensors,
+                      std::vector out_tensors, std::string name)
+      : op_primitive_(primitive),
+        in_tensors_(std::move(in_tensors)),
+        out_tensors_(std::move(out_tensors)),
+        op_name_(std::move(name)) {
+    if (primitive != nullptr) {  // type_ keeps its PrimitiveType_NONE default when primitive is null
+      this->type_ = primitive->value_type();
+    }
+  }
+
+  virtual ~TensorRTOp() = default;
+
+  virtual int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors,
+                        const std::vector &out_tensors) = 0;  // GetTensorRTOp() treats RET_OK as "supported"
+
+  virtual int AddInnerOp(nvinfer1::INetworkDefinition *network) = 0;  // pure virtual; takes the network being built
+
+  const schema::Primitive *GetPrimitive();
+
+  void AddInnerInTensors(nvinfer1::ITensor *tensor);  // appends to tensorrt_in_tensors_
+
+  void AddInnerOutTensors(nvinfer1::ITensor *tensor);  // appends to tensorrt_out_tensors_
+
+  std::vector &GetInnerOutTensor();
+
+  std::vector &GetInnerInTensors();
+
+  std::string GetOpName();
+
+  std::vector &inputs();
+
+  std::vector &outputs();
+
+  schema::PrimitiveType type() const;
+
+  void set_in_ops(const std::vector &in_ops);
+
+  void set_out_ops(const std::vector &out_ops);
+
+  const std::vector &in_ops() const;
+
+  const std::vector &out_ops() const;
+
+ protected:
+  std::vector layers_;  // not referenced by the member functions defined in tensorrt_op.cc
+
+  const schema::Primitive *op_primitive_;  // stored as passed to the constructor; may be nullptr
+
+  std::vector in_tensors_;  // moved in from the constructor argument
+
+  std::vector out_tensors_;  // moved in from the constructor argument
+
+  std::vector tensorrt_in_tensors_;  // filled through AddInnerInTensors()
+
+  std::vector tensorrt_out_tensors_;  // filled through AddInnerOutTensors()
+
+  std::vector in_ops_;  // set through set_in_ops()
+
+  std::vector out_ops_;  // set through set_out_ops()
+
+  std::string op_name_;
+
+  schema::PrimitiveType type_ = schema::PrimitiveType_NONE;
+};
+
+template
+TensorRTOp *GetTensorRTOp(const schema::Primitive *primitive, const std::vector &in_tensors,
+                          const std::vector &out_tensors, const std::string &name) {
+  auto *op = new (std::nothrow) T(primitive, in_tensors, out_tensors, name);  // caller owns the returned pointer
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "TensorRT is nullptr.";
+    return nullptr;
+  }
+
+  auto ret = op->IsSupport(primitive, in_tensors, out_tensors);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "TensorRT op is not supported.";
+    delete op;  // unsupported ops are destroyed here, so nullptr is the only failure result
+    return nullptr;
+  }
+  return op;
+}
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_
diff --git a/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.cc
new file mode 100644
index 00000000000..4d34f3b090e
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/op/unary_tensorrt.h" + +namespace mindspore::lite { +int UnaryTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + } + if (out_tensors.size() != 1) { + MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); + } + auto it = unary_ops_.find(primitive->value_type()); + if (it != unary_ops_.end()) { + unary_op_ = it->second; + } else { + MS_LOG(ERROR) << "unsupported unary ops type: " << schema::EnumNamePrimitiveType(primitive->value_type()); + return RET_ERROR; + } + return RET_OK; +} + +int UnaryTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { + if (network == nullptr || this->tensorrt_in_tensors_.size() != 1) { + MS_LOG(ERROR) << "network or input tensor is invalid"; + return RET_ERROR; + } + nvinfer1::IUnaryLayer *cal_layer = network->addUnary(*tensorrt_in_tensors_[0], unary_op_); + if (cal_layer == nullptr) { + MS_LOG(ERROR) << "addUnary failed for: " << op_name_; + return RET_ERROR; + } + cal_layer->setName(op_name_.c_str()); + + nvinfer1::ITensor *op_out_tensor = cal_layer->getOutput(0); + op_out_tensor->setName(out_tensors_[0]->tensor_name().c_str()); + this->AddInnerOutTensors(op_out_tensor); + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h new file 
mode 100644 index 00000000000..2b430f39af3 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/op/unary_tensorrt.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ +#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ +#include +#include +#include +#include "src/delegate/tensorrt/op/tensorrt_op.h" + +namespace mindspore::lite { +class UnaryTensorRT : public TensorRTOp { + public: + UnaryTensorRT(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors, const std::string &name) + : TensorRTOp(primitive, in_tensors, out_tensors, name) {} + + ~UnaryTensorRT() override = default; + + int AddInnerOp(nvinfer1::INetworkDefinition *network) override; + + int IsSupport(const schema::Primitive *primitive, const std::vector &in_tensors, + const std::vector &out_tensors) override; + + private: + std::map unary_ops_ = { + {schema::PrimitiveType::PrimitiveType_Sqrt, nvinfer1::UnaryOperation::kSQRT}, + }; + nvinfer1::UnaryOperation unary_op_; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.cc new file mode 100644 index 00000000000..4637eb6564a --- /dev/null +++ 
b/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/tensorrt_allocator.h" +#include +#include +#include "src/common/log_adapter.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" + +namespace mindspore::lite { +void *TensorRTAllocator::MallocDeviceMem(mindspore::tensor::MSTensor *host_tensor, size_t size) { + if (host_tensor == nullptr) { + return nullptr; + } + if (cuda_tensor_map_.find(host_tensor->tensor_name()) != cuda_tensor_map_.end()) { + return nullptr; + } + + auto cuda_type = ConvertDataType(host_tensor->data_type()); + if (static_cast(cuda_type) == -1) { + MS_LOG(ERROR) << "Unsupported Tensor Type:" << host_tensor->data_type(); + return nullptr; + } + void *device_ptr; + auto cuda_ret = cudaMalloc(&device_ptr, size); + if (cuda_ret != cudaSuccess) { + MS_LOG(ERROR) << "Cuda Malloc failed for size:" << size; + return nullptr; + } + cuda_tensor_map_[host_tensor->tensor_name()] = device_ptr; + return device_ptr; +} + +void *TensorRTAllocator::GetDevicePtr(const std::string &tensor_name) { + if (tensor_name.empty()) { + return nullptr; + } + if (cuda_tensor_map_.find(tensor_name) == cuda_tensor_map_.end()) { + return nullptr; + } + return this->cuda_tensor_map_.find(tensor_name)->second; +} + +int TensorRTAllocator::SyncMemInHostAndDevice(mindspore::tensor::MSTensor *host_tensor, + 
const std::string &device_tensor_name, bool is_host2device, bool sync) { + if (host_tensor == nullptr || host_tensor->data() == nullptr || + cuda_tensor_map_.find(device_tensor_name) == cuda_tensor_map_.end()) { + MS_LOG(ERROR) << " host or device ptr is null."; + return RET_ERROR; + } + auto device_ptr = cuda_tensor_map_.find(device_tensor_name)->second; + + void *src_ptr = is_host2device ? host_tensor->data() : device_ptr; + void *dst_ptr = is_host2device ? device_ptr : host_tensor->data(); + cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost; + auto cuda_ret = cudaMemcpy(dst_ptr, src_ptr, host_tensor->Size(), kind); + if (cuda_ret != cudaSuccess) { + MS_LOG(ERROR) << "copy mem failed."; + return RET_ERROR; + } + return RET_OK; +} + +int TensorRTAllocator::ClearDeviceMem() { + for (const auto &iter : cuda_tensor_map_) { + auto cuda_ret = cudaFree(iter.second); + if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) { + MS_LOG(WARNING) << "free cuda failed for " << cudaGetErrorName(cuda_ret); + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.h new file mode 100644 index 00000000000..1c6d0ca2c76 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_allocator.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H
+#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H
+#include "src/delegate/tensorrt/tensorrt_allocator.h"  // NOTE(review): header includes itself; the guard makes it a no-op
+#include
+#include
+#include "include/ms_tensor.h"
+
+namespace mindspore::lite {
+class TensorRTAllocator {  // keeps a map from a string key to a device pointer (cuda_tensor_map_)
+ public:
+  TensorRTAllocator() = default;
+  void *MallocDeviceMem(mindspore::tensor::MSTensor *host_tensor, size_t size);  // returns a device pointer
+  void *GetDevicePtr(const std::string &tensor_name);  // lookup by tensor name
+  int SyncMemInHostAndDevice(mindspore::tensor::MSTensor *host_tensor, const std::string &device_tensor_name,
+                             bool is_host2device, bool sync = true);  // is_host2device picks the copy direction
+  int ClearDeviceMem();  // no destructor is declared, so presumably callers must invoke this explicitly
+
+ private:
+  std::map cuda_tensor_map_;
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H
diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc
new file mode 100644
index 00000000000..95987e954ce
--- /dev/null
+++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.cc
@@ -0,0 +1,140 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "src/delegate/tensorrt/tensorrt_delegate.h" +#include +#include "src/delegate/delegate_utils.h" +#include "src/delegate/tensorrt/op/activation_tensorrt.h" +#include "src/delegate/tensorrt/op/shape_tensorrt.h" +#include "src/delegate/tensorrt/op/gather_tensorrt.h" +#include "src/delegate/tensorrt/op/shuffle_tensorrt.h" +#include "src/delegate/tensorrt/op/concate_tensorrt.h" +#include "src/delegate/tensorrt/op/convolution_tensorrt.h" +#include "src/delegate/tensorrt/op/elementwise_tensorrt.h" +#include "src/delegate/tensorrt/op/reduce_tensorrt.h" +#include "src/delegate/tensorrt/op/softmax_tensorrt.h" +#include "src/delegate/tensorrt/op/unary_tensorrt.h" +#include "src/delegate/tensorrt/op/matmul_tensorrt.h" +#include "src/delegate/tensorrt/op/scale_tensorrt.h" + +namespace mindspore::lite { +int TensorRTDelegate::Init() { + op_func_lists_.clear(); + op_func_lists_ = { + {schema::PrimitiveType_Activation, GetTensorRTOp}, + {schema::PrimitiveType_Unsqueeze, GetTensorRTOp}, + {schema::PrimitiveType_Squeeze, GetTensorRTOp}, + {schema::PrimitiveType_Reshape, GetTensorRTOp}, + {schema::PrimitiveType_Concat, GetTensorRTOp}, + {schema::PrimitiveType_Conv2DFusion, GetTensorRTOp}, + {schema::PrimitiveType_SubFusion, GetTensorRTOp}, + {schema::PrimitiveType_DivFusion, GetTensorRTOp}, + {schema::PrimitiveType_PowFusion, GetTensorRTOp}, + {schema::PrimitiveType_AddFusion, GetTensorRTOp}, + {schema::PrimitiveType_Transpose, GetTensorRTOp}, + {schema::PrimitiveType_ReduceFusion, GetTensorRTOp}, + {schema::PrimitiveType_Sqrt, GetTensorRTOp}, + {schema::PrimitiveType_MatMul, GetTensorRTOp}, + {schema::PrimitiveType_ScaleFusion, GetTensorRTOp}, + }; + return RET_OK; +} + +int TensorRTDelegate::Build(DelegateModel *model) { + KernelIter from, end; + std::vector tensorrt_ops; + int graph_index = 0; + for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { + kernel::Kernel *kernel = *iter; + auto tensorrt_op = 
FindTensorRTOp(kernel, model->GetPrimitive(kernel)); + + if (tensorrt_op != nullptr) { + // If tensorrt_ops does not equal nullptr, this kernel can be supported by delegate + if (tensorrt_ops.size() == 0) { + from = iter; + } + tensorrt_ops.push_back(tensorrt_op); + end = iter; + } else { + if (tensorrt_ops.size() > 0) { + auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end); + if (tensorrt_subgraph == nullptr) { + MS_LOG(ERROR) << "Create TensorRT Graph failed."; + return RET_ERROR; + } + tensorrt_subgraph->set_name("TensorRtGraph" + std::to_string(graph_index++)); + iter = model->Replace(from, end + 1, tensorrt_subgraph); + tensorrt_ops.clear(); + } + } + } + if (tensorrt_ops.size() > 0) { + auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end); + if (tensorrt_subgraph == nullptr) { + MS_LOG(DEBUG) << "Create TensorRT Graph failed."; + return RET_ERROR; + } + tensorrt_subgraph->set_name("TensorRtGraph" + std::to_string(graph_index++)); + model->Replace(from, end + 1, tensorrt_subgraph); + tensorrt_ops.clear(); + } + return RET_OK; +} + +TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) { + auto in_tensors = kernel->inputs(); + auto out_tensors = kernel->outputs(); + auto name = kernel->name(); + auto node_type = primitive->value_type(); + + if (op_func_lists_.find(node_type) != op_func_lists_.end()) { + return op_func_lists_[node_type](primitive, in_tensors, out_tensors, name); + } else { + MS_LOG(WARNING) << "Unsupported op type for TensorRT. 
kernel->name:" << kernel->name() + << " type:" << schema::EnumNamePrimitiveType(primitive->value_type()); + return nullptr; + } +} + +TensorRTSubGraph *TensorRTDelegate::CreateTensorRTGraph(const std::vector &ops, DelegateModel *model, + KernelIter from, KernelIter end) { + auto in_tensors = GraphInTensors(ops, model, from, end); + auto out_tensors = GraphOutTensors(ops, model, from, end); + auto *tensorrt_graph = new (std::nothrow) TensorRTSubGraph(ops, in_tensors, out_tensors); + if (tensorrt_graph == nullptr) { + MS_LOG(ERROR) << "new tensorrt_graph failed."; + return nullptr; + } + // 1. For every op, find pre and next ops + FindPreNextOps(ops); + + // 2. Init TensorRT SubGraph. + auto ret = tensorrt_graph->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "TensorRTGraph init failed."; + return nullptr; + } + + // 3. Build TensorRT Model. + ret = tensorrt_graph->BuildTensorRTGraph(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "TensorRTGraph build failed."; + return nullptr; + } + + return tensorrt_graph; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h new file mode 100644 index 00000000000..d2f47a30775 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_delegate.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_DELEGATE_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_DELEGATE_ +#include +#include +#include +#include "include/delegate.h" +#include "src/delegate/tensorrt/tensorrt_subgraph.h" +#include "include/kernel.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" + +namespace mindspore::lite { +typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive, + const std::vector &in_tensors, + const std::vector &out_tensors, const std::string &name); + +class TensorRTDelegate : public Delegate { + public: + TensorRTDelegate() = default; + + ~TensorRTDelegate() override = default; + + int Init() override; + + int Build(DelegateModel *model) override; + + private: + TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive); + + TensorRTSubGraph *CreateTensorRTGraph(const std::vector &ops, DelegateModel *model, KernelIter from, + KernelIter end); + + std::map op_func_lists_; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_DELEGATE_ diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.cc new file mode 100644 index 00000000000..87457329d80 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/delegate/tensorrt/tensorrt_runtime.h" +#include +#include + +namespace mindspore::lite { +static std::mutex g_mtx; +TensorRTRuntime *TensorRTRuntime::cuda_runtime_instance_ = nullptr; +TensorRTRuntime *TensorRTRuntime::GetInstance() { + std::unique_lock lck(g_mtx); + static TensorRTRuntime cuda_runtime; + if (cuda_runtime_instance_ == nullptr) { + cuda_runtime_instance_ = &cuda_runtime; + } + return cuda_runtime_instance_; +} +int TensorRTRuntime::Init() { + if (is_init_) { + return RET_OK; + } + builder_ = nvinfer1::createInferBuilder(this->logger_); + if (builder_ == nullptr) { + MS_LOG(ERROR) << "create infer builder failed."; + return RET_ERROR; + } + builder_->setMaxBatchSize(MAX_BATCH_SIZE); + + allocator_ = new (std::nothrow) TensorRTAllocator(); + if (allocator_ == nullptr) { + MS_LOG(ERROR) << "Create allocator failed."; + return RET_ERROR; + } + is_init_ = true; + return RET_OK; +} + +TensorRTRuntime::~TensorRTRuntime() { + if (builder_ != nullptr) { + builder_->destroy(); + builder_ = nullptr; + } + if (allocator_ != nullptr) { + allocator_->ClearDeviceMem(); + delete allocator_; + allocator_ = nullptr; + } +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.h new file mode 100644 index 00000000000..67839ed1d58 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_runtime.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_BUILDER_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_BUILDER_ +#include +#include "include/errorcode.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" +#include "src/delegate/tensorrt/tensorrt_allocator.h" +#define MAX_BATCH_SIZE 64 + +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; + +namespace mindspore::lite { +class TensorRTLogger : public nvinfer1::ILogger { + void log(Severity severity, const char *msg) override { + if (severity == Severity::kINTERNAL_ERROR || severity == Severity::kERROR) { + MS_LOG(ERROR) << msg; + } else if (severity == Severity::kWARNING) { + MS_LOG(WARNING) << msg; + } else if (severity == Severity::kINFO) { + MS_LOG(INFO) << msg; + } else { + MS_LOG(DEBUG) << msg; + } + } +}; + +class TensorRTRuntime { + public: + TensorRTRuntime() = default; + + ~TensorRTRuntime(); + + static TensorRTRuntime *GetInstance(); + + int Init(); + + nvinfer1::IBuilder *GetBuilder() { return this->builder_; } + + int GetBatchSize() { return batch_size_; } + + void SetBatchSize(int batch_size) { batch_size_ = batch_size; } + + TensorRTAllocator *GetAllocator() { return this->allocator_; } + + private: + static TensorRTRuntime *cuda_runtime_instance_; + bool is_init_ = false; + nvinfer1::IBuilder *builder_{nullptr}; + TensorRTLogger logger_; + TensorRTAllocator *allocator_{nullptr}; + int batch_size_ = 1; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_BUILDER_ diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc new file mode 100644 index 00000000000..73f3306132b --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/delegate/tensorrt/tensorrt_subgraph.h" +#include +#include +#include +#include "src/delegate/delegate_utils.h" + +namespace mindspore::lite { +TensorRTSubGraph::~TensorRTSubGraph() { + if (network_ != nullptr) { + network_->destroy(); + network_ = nullptr; + } + if (config_ != nullptr) { + config_->destroy(); + config_ = nullptr; + } + if (context_ != nullptr) { + context_->destroy(); + context_ = nullptr; + } + if (engine_ != nullptr) { + engine_->destroy(); + engine_ = nullptr; + } + if (tensor_bindings_ != nullptr) { + delete tensor_bindings_; + tensor_bindings_ = nullptr; + } + for (auto op : all_ops_) { + delete op; + } +} + +int TensorRTSubGraph::Init() { + auto ret = GetGraphInOutOps(inputs_, outputs_, &in_ops_, &out_ops_, all_ops_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get NPU subgraph input and output ops failed."; + return RET_ERROR; + } + runtime_ = TensorRTRuntime::GetInstance(); + ret = runtime_->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "TensorRTRuntime init failed."; + return RET_ERROR; + } + this->network_ = runtime_->GetBuilder()->createNetworkV2( + 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); + if (this->network_ == nullptr) { + MS_LOG(ERROR) << "New network failed."; + return RET_ERROR; + } + return RET_OK; +} + +int TensorRTSubGraph::BuildEngine() { + this->config_ = runtime_->GetBuilder()->createBuilderConfig(); + if (this->config_ == nullptr) { + 
MS_LOG(ERROR) << "create builder config failed."; + return RET_ERROR; + } + engine_ = runtime_->GetBuilder()->buildEngineWithConfig(*this->network_, *this->config_); + if (engine_ == nullptr) { + MS_LOG(ERROR) << "Create engine failed in TensorRT network"; + return RET_ERROR; + } + return RET_OK; +} + +int TensorRTSubGraph::BuildTensorRTGraph() { + MS_ASSERT(!all_ops_.empty()); + // Connect NetWork. + int ret; + for (auto cur_op : all_ops_) { + for (auto in_tensor : cur_op->inputs()) { + // Data From CPU + if (IsSubGraphInputTensor(this->inputs(), in_tensor)) { + auto cuda_dtype = ConvertDataType(in_tensor->data_type()); + if (static_cast(cuda_dtype) == -1) { + MS_LOG(ERROR) << "Unsupported input data type " << in_tensor->data_type(); + return RET_ERROR; + } + auto trt_tensor = + this->network_->addInput(in_tensor->tensor_name().c_str(), cuda_dtype, ConvertCudaDims(in_tensor->shape())); + cur_op->AddInnerInTensors(trt_tensor); + continue; + } + + auto trt_tensor = FindTensorRTInputs(cur_op, in_tensor); + // weight tensor + if (trt_tensor == nullptr) { + if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) { + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "Weight Tensor is nullptr."; + return RET_ERROR; + } + trt_tensor = lite::ConvertConstantTensor(this->network_, in_tensor); + cur_op->AddInnerInTensors(trt_tensor); + } + } else { + cur_op->AddInnerInTensors(trt_tensor); + } + } + + ret = cur_op->AddInnerOp(this->network_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Add op failed in TensorRT network"; + return RET_ERROR; + } + } + + // Mark NetWork Output Tensor. 
+ for (auto out_tensor : outputs_) { + for (auto out_op : this->out_ops_) { + for (size_t index = 0; index < out_op->outputs().size(); index++) { + if (out_op->outputs()[index] == out_tensor) { + out_op->GetInnerOutTensor()[index]->setName(out_tensor->tensor_name().c_str()); + this->network_->markOutput(*out_op->GetInnerOutTensor()[index]); + } + } + } + } + + ret = BuildEngine(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Create engine failed in TensorRT network"; + return ret; + } + return RET_OK; +} + +int TensorRTSubGraph::Prepare() { + if (runtime_->GetBatchSize() <= 0) { + MS_LOG(ERROR) << "TensorRTSubGraph has invalid batch size."; + return RET_ERROR; + } + if (this->engine_ == nullptr) { + MS_LOG(ERROR) << "engine_ is null in this builder_"; + return RET_ERROR; + } + this->context_ = this->engine_->createExecutionContext(); + if (this->context_ == nullptr) { + MS_LOG(ERROR) << "TensorRTSubGraph create context failed."; + return RET_ERROR; + } + int binding_num = this->engine_->getNbBindings(); + tensor_bindings_ = new (std::nothrow) void *[binding_num]; + if (tensor_bindings_ == nullptr) { + MS_LOG(ERROR) << "malloc tensor binding array failed."; + return RET_ERROR; + } + + for (auto tensor : inputs_) { + auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, tensor->Size()); + int index = this->engine_->getBindingIndex(tensor->tensor_name().c_str()); + tensor_bindings_[index] = device_ptr; + trt_in_tensor_name_.push_back(tensor->tensor_name()); + } + + for (auto tensor : outputs_) { + tensor->MutableData(); + auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, tensor->Size()); + int index = this->engine_->getBindingIndex(tensor->tensor_name().c_str()); + tensor_bindings_[index] = device_ptr; + trt_out_tensor_name_.push_back(tensor->tensor_name()); + } + return RET_OK; +} + +int TensorRTSubGraph::Execute() { + for (size_t i = 0; i < inputs_.size(); i++) { + runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], 
trt_in_tensor_name_[i], true); + } + auto ret = this->context_->executeV2(tensor_bindings_); + if (!ret) { + MS_LOG(ERROR) << "TensorRT execute failed."; + return RET_ERROR; + } + for (size_t i = 0; i < outputs_.size(); i++) { + if (outputs_[i]->MutableData() == nullptr) { + MS_LOG(ERROR) << "Malloc output tensor data failed."; + } + runtime_->GetAllocator()->SyncMemInHostAndDevice(outputs_[i], trt_out_tensor_name_[i], false); + } + return RET_OK; +} + +nvinfer1::ITensor *TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, tensor::MSTensor *in_tensor) { + for (auto input_op : cur_op->in_ops()) { + for (size_t i = 0; i < input_op->outputs().size(); i++) { + auto out_tensor = input_op->outputs().at(i); + if (in_tensor == out_tensor) { + return input_op->GetInnerOutTensor().at(i); + } + } + } + return nullptr; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h new file mode 100644 index 00000000000..447ca715963 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h @@ -0,0 +1,82 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GTAPH_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GTAPH_ +#include +#include +#include +#include +#include "include/kernel.h" +#include "src/delegate/tensorrt/tensorrt_runtime.h" +#include "src/delegate/tensorrt/tensorrt_utils.h" + +namespace mindspore::lite { +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +class TensorRTSubGraph : public kernel::Kernel { + public: + TensorRTSubGraph(std::vector ops, const std::vector &inputs, + const std::vector &outputs) + : kernel::Kernel(inputs, outputs, nullptr, nullptr), all_ops_(std::move(ops)) { + trt_specific_weight_nodes_ = { + schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_ReduceFusion, schema::PrimitiveType_Transpose, + schema::PrimitiveType_Gather, schema::PrimitiveType_Reshape, schema::PrimitiveType_PowFusion, + schema::PrimitiveType_DivFusion, schema::PrimitiveType_MatMul, schema::PrimitiveType_ScaleFusion}; + } + + ~TensorRTSubGraph() override; + + int Prepare() override; + + int Execute() override; + + int ReSize() override { + MS_LOG(ERROR) << "TensorRT does not support the resize function temporarily."; + return lite::RET_ERROR; + } + + int BuildTensorRTGraph(); + + int Init(); + + private: + int BuildEngine(); + + static nvinfer1::ITensor *FindTensorRTInputs(TensorRTOp *cur_op, tensor::MSTensor *in_tensor); + + TensorRTRuntime *runtime_{nullptr}; + + std::vector all_ops_{}; + // subgraph input nodes. + std::vector in_ops_{}; + // subgraph output nodes. + std::vector out_ops_{}; + + void **tensor_bindings_{nullptr}; + + std::set trt_specific_weight_nodes_; + + // save in/out tensor name for subgraph isolate. 
+ std::vector trt_in_tensor_name_; + std::vector trt_out_tensor_name_; + + nvinfer1::INetworkDefinition *network_{nullptr}; + nvinfer1::IBuilderConfig *config_{nullptr}; + nvinfer1::ICudaEngine *engine_{nullptr}; + nvinfer1::IExecutionContext *context_{nullptr}; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GTAPH_ diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc new file mode 100644 index 00000000000..0a8cf4e9635 --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/delegate/tensorrt/tensorrt_utils.h" +#include + +namespace mindspore::lite { +nvinfer1::Dims ConvertCudaDims(const std::vector &shape) { + nvinfer1::Dims dims{}; + if (!shape.empty()) { + dims.nbDims = shape.size(); + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = shape[i]; + } + } + return dims; +} +nvinfer1::Dims ConvertCudaDims(int data, size_t size) { + nvinfer1::Dims dims{}; + dims.nbDims = size; + for (size_t i = 0; i < size; i++) { + dims.d[i] = data; + } + return dims; +} + +nvinfer1::Dims ConvertCudaDims(void *data, size_t size) { + nvinfer1::Dims dims{}; + dims.nbDims = size; + int *dims_data = reinterpret_cast(data); + for (size_t i = 0; i < size; i++) { + dims.d[i] = *(dims_data + i); + } + return dims; +} + +nvinfer1::IShuffleLayer *SetTranspose(nvinfer1::INetworkDefinition *network, const nvinfer1::ITensor &input, + nvinfer1::Permutation permutation) { + nvinfer1::IShuffleLayer *layer = network->addShuffle(const_cast(input)); + if (layer == nullptr) { + MS_LOG(ERROR) << "failed to create ShuffleLayer when create transpose op."; + return nullptr; + } + layer->setFirstTranspose(permutation); + return layer; +} + +nvinfer1::DataType ConvertDataType(TypeId type_id) { + std::map data_type_map = {{TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8}, + {TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}, + {TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT}, + {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF}}; + auto iter = data_type_map.find(type_id); + nvinfer1::DataType data_type; + if (iter != data_type_map.end()) { + data_type = iter->second; + } else { + data_type = nvinfer1::DataType::kFLOAT; + MS_LOG(WARNING) << "invalid data_type for TensorRT, need check"; + } + return data_type; +} + +nvinfer1::IShuffleLayer *NHWC2NCHW(nvinfer1::INetworkDefinition *network, const nvinfer1::ITensor &input) { + // NHWC 0123 NCHW 0312 + nvinfer1::Permutation perm{{0, 3, 1, 2}}; + return SetTranspose(network, input, perm); +} 
+ +nvinfer1::IShuffleLayer *NCHW2NHWC(nvinfer1::INetworkDefinition *network, const nvinfer1::ITensor &input) { + // NCHW 0123 NHWC 0231 + nvinfer1::Permutation perm{{0, 2, 3, 1}}; + return SetTranspose(network, input, perm); +} + +nvinfer1::ITensor *ConvertConstantTensor(nvinfer1::INetworkDefinition *network, tensor::MSTensor *ms_tensor) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is null for ConvertConstantTensor"; + return nullptr; + } + nvinfer1::Dims dims = ConvertCudaDims(ms_tensor->shape()); + nvinfer1::DataType data_type = ConvertDataType(ms_tensor->data_type()); + + nvinfer1::Weights weights{data_type, ms_tensor->data(), ms_tensor->ElementsNum()}; + nvinfer1::IConstantLayer *constant_tensor = network->addConstant(dims, weights); + if (constant_tensor == nullptr) { + MS_LOG(ERROR) << "create constant_tensor failed."; + return nullptr; + } + auto name = ms_tensor->tensor_name() + "_constant_layer"; + constant_tensor->setName(name.c_str()); + return constant_tensor->getOutput(0); +} + +nvinfer1::ITensor *ConvertScalarToITensor(nvinfer1::INetworkDefinition *network, size_t shape_size, void *value) { + nvinfer1::Dims dims = ConvertCudaDims(1, shape_size); + nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, value, 1}; + nvinfer1::IConstantLayer *constant_tensor = network->addConstant(dims, weights); + if (constant_tensor == nullptr) { + MS_LOG(ERROR) << "create constant_tensor failed."; + return nullptr; + } + return constant_tensor->getOutput(0); +} + +nvinfer1::ActivationType ConvertActivationType(schema::ActivationType activation_type) { + std::map action_map = { + {schema::ActivationType_RELU, nvinfer1::ActivationType::kRELU}, + {schema::ActivationType_SIGMOID, nvinfer1::ActivationType::kSIGMOID}, + {schema::ActivationType_TANH, nvinfer1::ActivationType::kTANH}, + {schema::ActivationType_LEAKY_RELU, nvinfer1::ActivationType::kLEAKY_RELU}, + {schema::ActivationType_ELU, nvinfer1::ActivationType::kELU}, + {schema::ActivationType_SELU, 
nvinfer1::ActivationType::kSELU}, + {schema::ActivationType_SOFTSIGN, nvinfer1::ActivationType::kSOFTSIGN}, + {schema::ActivationType_SOFTPLUS, nvinfer1::ActivationType::kSOFTPLUS}, + {schema::ActivationType_THRESHOLDRELU, nvinfer1::ActivationType::kTHRESHOLDED_RELU}}; + auto iter = action_map.find(activation_type); + nvinfer1::ActivationType action_code = nvinfer1::ActivationType::kRELU; + if (iter != action_map.end()) { + action_code = iter->second; + } else { + MS_LOG(WARNING) << "Unsupported op action type for TensorRT: " << activation_type; + } + return action_code; +} + +nvinfer1::ITensor *ConvertTensorWithExpandDims(nvinfer1::INetworkDefinition *network, tensor::MSTensor *ms_tensor, + size_t expand_shape_size) { + if (network == nullptr) { + MS_LOG(ERROR) << "network is null for ConvertConstantTensor"; + return nullptr; + } + std::vector shape(expand_shape_size); + size_t shape_size = ms_tensor->shape().size(); + size_t expand_size = expand_shape_size - shape_size; + for (size_t i = 0; i < expand_shape_size; ++i) { + if (i < expand_size) { + shape[i] = 1; + } else { + shape[i] = ms_tensor->shape()[i - expand_size]; + } + } + nvinfer1::Dims dims = ConvertCudaDims(shape); + nvinfer1::DataType data_type = ConvertDataType(ms_tensor->data_type()); + + nvinfer1::Weights weights{data_type, ms_tensor->data(), ms_tensor->ElementsNum()}; + nvinfer1::IConstantLayer *constant_tensor = network->addConstant(dims, weights); + if (constant_tensor == nullptr) { + MS_LOG(ERROR) << "create constant_tensor failed."; + return nullptr; + } + auto name = ms_tensor->tensor_name() + "_constant_layer"; + constant_tensor->setName(name.c_str()); + return constant_tensor->getOutput(0); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h new file mode 100644 index 00000000000..2f33765135d --- /dev/null +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_utils.h @@ -0,0 +1,48 @@ +/** 
+ * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_ +#include +#include +#include "src/delegate/tensorrt/op/tensorrt_op.h" +#include "mindspore/core/ir/dtype/type_id.h" +#include "schema/ops_generated.h" + +namespace mindspore::lite { +// Convert shape to Cuda Dims. +nvinfer1::Dims ConvertCudaDims(const std::vector &shape); + +// Convert Tensor data to Cuda dims. 
+nvinfer1::Dims ConvertCudaDims(void *data, size_t size); + +nvinfer1::Dims ConvertCudaDims(int data, size_t size); + +nvinfer1::DataType ConvertDataType(TypeId type_id); + +nvinfer1::IShuffleLayer *NHWC2NCHW(nvinfer1::INetworkDefinition *network, const nvinfer1::ITensor &input); + +nvinfer1::IShuffleLayer *NCHW2NHWC(nvinfer1::INetworkDefinition *network, const nvinfer1::ITensor &input); + +nvinfer1::ActivationType ConvertActivationType(schema::ActivationType activation_type); + +nvinfer1::ITensor *ConvertConstantTensor(nvinfer1::INetworkDefinition *network, tensor::MSTensor *ms_tensor); + +nvinfer1::ITensor *ConvertTensorWithExpandDims(nvinfer1::INetworkDefinition *network, tensor::MSTensor *ms_tensor, + size_t expand_shape_size); + +nvinfer1::ITensor *ConvertScalarToITensor(nvinfer1::INetworkDefinition *network, size_t shape_size, void *value); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_ diff --git a/mindspore/lite/src/lite_kernel_util.cc b/mindspore/lite/src/lite_kernel_util.cc index 1c3477fca4b..564ced8f523 100644 --- a/mindspore/lite/src/lite_kernel_util.cc +++ b/mindspore/lite/src/lite_kernel_util.cc @@ -197,6 +197,9 @@ void LiteKernelUtil::InitTensorInitRefCount(const std::vector &inputs) { return -1; } bool LiteKernelUtil::IsSwitchCall(kernel::LiteKernel *kernel) { + if (kernel->desc().delegate != nullptr) { + return false; + } auto *subgraph_kernel = reinterpret_cast(kernel); if (subgraph_kernel == nullptr) { return false; diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index e3b79c5aef0..2ba987f019c 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -38,6 +38,9 @@ #if GPU_OPENCL #include "src/runtime/kernel/opencl/opencl_subgraph.h" #endif +#if GPU_TENSORRT +#include "src/delegate/tensorrt/tensorrt_delegate.h" +#endif namespace mindspore { namespace lite { @@ -527,6 +530,7 @@ int LiteSession::CompileGraph(Model *model) { 
#ifdef ENABLE_MINDRT } #endif + if (executor_ == nullptr) { MS_LOG(ERROR) << "New Executor failed"; is_running_.store(false); @@ -650,6 +654,9 @@ int LiteSession::Init(const Context *context) { return RET_ERROR; } } +#endif +#if GPU_TENSORRT + delegate_ = std::shared_ptr(new (std::nothrow) TensorRTDelegate()); #endif if (delegate_ != nullptr) { auto delegate_ret = delegate_->Init(); diff --git a/mindspore/lite/tools/cropper/build_cropper_config.sh b/mindspore/lite/tools/cropper/build_cropper_config.sh index 9d6f72c954a..8b5d0e186f0 100644 --- a/mindspore/lite/tools/cropper/build_cropper_config.sh +++ b/mindspore/lite/tools/cropper/build_cropper_config.sh @@ -159,8 +159,6 @@ getCommonFile() { while IFS='' read -r line; do common_files+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/lite/src/common/*.cc) runtime_files_cc=() while IFS='' read -r line; do runtime_files_cc+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/lite/src/runtime/*.cc) - runtime_files_c=() - while IFS='' read -r line; do runtime_files_c+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/lite/src/runtime/*.c) # sava all assembly files assembly_files=() while IFS='' read -r line; do assembly_files+=("$line"); done < <(ls ${MINDSPORE_HOME}/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/*/*.S) @@ -173,7 +171,7 @@ getCommonFile() { "${MINDSPORE_HOME}"/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c ) all_files=("${src_files[@]}" "${regist_files[@]}" "${common_files[@]}" "${runtime_files_cc[@]}" - "${runtime_files_c[@]}" "${others_files_c[@]}" "${assembly_files[@]}" "${mindrt_files[@]}" + "${others_files_c[@]}" "${assembly_files[@]}" "${mindrt_files[@]}" "${cxx_api_files[@]}" ) # shellcheck disable=SC2068