!25944 [MSLITE] add Equal and Cast ops and bug fixes for OCR models in TensorRT delegate

Merge pull request !25944 from Liu_Xuu/trt_1104_ocr_merge
i-robot 2021-11-09 07:42:09 +00:00 committed by Gitee
commit be13447b73
27 changed files with 913 additions and 43 deletions


@@ -112,6 +112,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor_utils.cc
${TOP_DIR}/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
${TOP_DIR}/mindspore/lite/src/tensor.cc
${TOP_DIR}/mindspore/lite/src/common/utils.cc
${TOP_DIR}/mindspore/lite/src/ms_tensor.cc
${TOP_DIR}/mindspore/lite/src/common/string_util.cc
${TOP_DIR}/mindspore/lite/src/common/lite_utils.cc
@@ -424,6 +425,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
${CORE_DIR}/utils/status.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/cxx_api/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/utils.cc
)
add_library(minddata-lite SHARED


@@ -326,8 +326,8 @@ if(SUPPORT_TENSORRT)
include_directories(${TENSORRT_PATH}/include)
include_directories(${CUDA_PATH}/include)
add_subdirectory(delegate/tensorrt)
target_link_libraries(mindspore-lite tensorrt_kernel_mid)
target_link_libraries(mindspore-lite_static tensorrt_kernel_mid)
target_link_libraries(mindspore-lite tensorrt_kernel_mid cuda_kernel_mid)
target_link_libraries(mindspore-lite_static tensorrt_kernel_mid cuda_kernel_mid)
endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)


@@ -19,9 +19,13 @@
#include <asm/hwcap.h>
#endif
#include "src/common/utils.h"
#ifdef _MSC_VER
#if defined(_MSC_VER) || defined(_WIN32)
#include <windows.h>
#undef ERROR
#else
#include <unistd.h>
#include <sys/types.h>
#include <sys/param.h>
#endif
namespace mindspore {
@@ -150,5 +154,17 @@ bool IsSupportSDot() {
#endif
return status;
}
size_t GetMaxMallocSize() {
size_t max_malloc_size = 0;
#if defined(_MSC_VER) || defined(_WIN32)
MEMORYSTATUSEX status;
status.dwLength = sizeof(status);
GlobalMemoryStatusEx(&status);
max_malloc_size = static_cast<size_t>(status.ullTotalPhys);
#else
max_malloc_size = static_cast<size_t>(sysconf(_SC_PHYS_PAGES)) * static_cast<size_t>(sysconf(_SC_PAGESIZE));
#endif
return max_malloc_size;
}
} // namespace lite
} // namespace mindspore
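Note: GetMaxMallocSize() replaces the hard-coded MAX_MALLOC_SIZE / kMaxMallocSize limits that the allocator, scheduler and tensor changes later in this commit stop using. A minimal sketch of the intended guard pattern (GuardedMalloc is a hypothetical illustration, not part of the commit):
#include <cstdlib>
#include "src/common/utils.h"
// Sketch only: refuse a single allocation larger than total physical memory,
// mirroring the DefaultAllocator::Malloc change further down in this commit.
void *GuardedMalloc(std::size_t size) {
  if (size > mindspore::lite::GetMaxMallocSize()) {
    return nullptr;
  }
  return std::malloc(size);
}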


@@ -28,6 +28,10 @@
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
#ifndef EXPORT_WRAPPER
#define EXPORT_WRAPPER __attribute__((visibility("default")))
#endif
namespace mindspore {
namespace lite {
enum NodeType {
@@ -42,6 +46,8 @@ uint64_t GetTimeUs();
bool IsSupportSDot();
size_t EXPORT_WRAPPER GetMaxMallocSize();
#ifdef __ANDROID__
uint32_t getHwCap(int hwcap_type);
#endif


@@ -25,3 +25,13 @@ target_link_libraries(
libcudart
libnvinfer
)
# cuda
find_package(CUDA)
file(GLOB_RECURSE CUDA_KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/cuda_impl/*.cu
)
set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -fPIC")
SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++14;)
cuda_add_library(cuda_kernel_mid STATIC ${CUDA_KERNEL_SRC})


@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/cuda_impl/cast.cuh"
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"
// Generic cast
template <typename S, typename T>
__device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) {
*output_addr = static_cast<T>((*input_addr));
}
template <typename S, typename T>
__global__ void CastKernel(const int input_size, const S *input_addr, T *output_addr) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
CastBase(input_addr + pos, output_addr + pos);
}
}
template <typename S, typename T>
void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream) {
CastKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_addr, output_addr);
}
template void Cast(const int input_size, const int8_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, float *output_addr, cudaStream_t stream);


@@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CAST_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CAST_H_
#include <cuda_runtime.h>
template <typename S, typename T>
void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream);
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CAST_H_


@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"
CudaHelper &CudaHelper::GetInstance() {
static CudaHelper instance;
return instance;
}
int CudaHelper::GetThreadNum() const { return threads_per_block_; }
int CudaHelper::GetBlocksNum(const int total_threads) const {
return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_);
}
CudaHelper::CudaHelper() {
int device_id = 0;
(void)cudaGetDevice(&device_id);
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, device_id);
threads_per_block_ = prop.maxThreadsPerBlock;
max_blocks_ = prop.multiProcessorCount;
}


@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_
#include <cuda_runtime.h>
#include <algorithm>
class CudaHelper {
public:
int GetThreadNum() const;
int GetBlocksNum(const int total_threads) const;
static CudaHelper &GetInstance();
private:
CudaHelper();
~CudaHelper() = default;
CudaHelper(const CudaHelper &) = delete;
CudaHelper &operator=(const CudaHelper &) = delete;
int max_blocks_;
int threads_per_block_;
};
#define GET_BLOCKS(total_threads) CudaHelper::GetInstance().GetBlocksNum(total_threads)
#define GET_THREADS CudaHelper::GetInstance().GetThreadNum()
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_
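For reference, the GET_BLOCKS/GET_THREADS macros above centralize the launch configuration for the kernels in cuda_impl; cast.cu and equal.cu both follow the same grid-stride pattern. A minimal sketch of another kernel written against these helpers (Scale/ScaleKernel are illustrative names only, not part of the commit):
#include <cuda_runtime.h>
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"
// Sketch only: grid-stride elementwise kernel launched via the shared helpers.
__global__ void ScaleKernel(const float *input, float *output, float factor, int element_cnt) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) {
    output[pos] = input[pos] * factor;
  }
}
void Scale(const float *input, float *output, float factor, int element_cnt, cudaStream_t stream) {
  ScaleKernel<<<GET_BLOCKS(element_cnt), GET_THREADS, 0, stream>>>(input, output, factor, element_cnt);
}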


@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/cuda_impl/equal.cuh"
#include <stdio.h>
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"
template <typename T>
__global__ void EqualKernel(const T *input1, const T *input2, T *output, int element_cnt) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) {
output[pos] = (input1[pos] - input2[pos] < 1e-6 && input1[pos] - input2[pos] > -1e-6);
}
}
template <typename T>
void Equal(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) {
EqualKernel<<<GET_BLOCKS(element_cnt), GET_THREADS, 0, stream>>>(input1, input2, output, element_cnt);
return;
}
template void Equal(const float *input1, const float *input2, float *output, int element_cnt, cudaStream_t stream);
template void Equal(const int *input1, const int *input2, int *output, int element_cnt, cudaStream_t stream);


@@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_EQUAL_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_EQUAL_H_
#include <cuda_runtime.h>
template <typename T>
void Equal(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream);
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_EQUAL_H_


@@ -0,0 +1,190 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/op/cast_tensorrt.h"
#include <numeric>
#include <memory>
#include <functional>
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
namespace mindspore::lite {
const char *CAST_PLUGIN_VERSION{"1"};
const char *CAST_PLUGIN_NAME{"CastPluginCreater"};
nvinfer1::PluginFieldCollection CastPluginCreater::field_collection_{};
std::vector<nvinfer1::PluginField> CastPluginCreater::fields_;
REGISTER_TENSORRT_PLUGIN(CastPluginCreater);
int CastTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int CastTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
nvinfer1::ITensor *inputTensors[] = {tensorrt_in_tensors_[0].trt_tensor_};
// cast to type tensor
auto type_tensor = in_tensors_[1];
if (type_tensor.Data() == nullptr) {
MS_LOG(ERROR) << "unknown cast type of " << op_name_;
return RET_ERROR;
}
if (type_tensor.DataType() != DataType::kNumberTypeInt32) {
MS_LOG(WARNING) << "unknown type_tensor data type of " << op_name_;
}
auto type_data = static_cast<const int *>(type_tensor.Data().get());
DataType data_type = static_cast<DataType>(type_data[0]);
MS_LOG(INFO) << op_name_ << " cast to data type(43 float): " << type_data[0];
nvinfer1::DataType dest_datatype = ConvertDataType(data_type);
auto plugin = std::make_shared<CastPlugin>(op_name_, tensorrt_in_tensors_[0].trt_tensor_->getType(), dest_datatype);
nvinfer1::IPluginV2Layer *cast_layer = network->addPluginV2(inputTensors, 1, *plugin);
if (cast_layer == nullptr) {
MS_LOG(ERROR) << "create cast layer failed for: " << op_name_;
return RET_ERROR;
}
nvinfer1::ITensor *cast_out = cast_layer->getOutput(0);
cast_layer->setName(op_name_.c_str());
cast_out->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{cast_out, tensorrt_in_tensors_[0].format_});
return RET_OK;
}
// CastPluginCreater
CastPluginCreater::CastPluginCreater() {
// Fill PluginFieldCollection with PluginField arguments metadata
field_collection_.nbFields = fields_.size();
field_collection_.fields = fields_.data();
}
const char *CastPluginCreater::getPluginName() const noexcept { return CAST_PLUGIN_NAME; }
const char *CastPluginCreater::getPluginVersion() const noexcept { return CAST_PLUGIN_VERSION; }
const nvinfer1::PluginFieldCollection *CastPluginCreater::getFieldNames() noexcept { return &field_collection_; }
nvinfer1::IPluginV2 *CastPluginCreater::createPlugin(const char *name,
const nvinfer1::PluginFieldCollection *fc) noexcept {
const nvinfer1::PluginField *fields = fc->fields;
nvinfer1::DataType origin_datatype = static_cast<const nvinfer1::DataType *>(fields[0].data)[0];
nvinfer1::DataType dest_datatype = static_cast<const nvinfer1::DataType *>(fields[1].data)[0];
return new CastPlugin(name, origin_datatype, dest_datatype);
}
nvinfer1::IPluginV2 *CastPluginCreater::deserializePlugin(const char *name, const void *serialData,
size_t serialLength) noexcept {
MS_LOG(ERROR) << name << " does not support deserialization";
return nullptr;
}
void CastPluginCreater::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
const char *CastPluginCreater::getPluginNamespace() const noexcept { return name_space_.c_str(); }
// CastPlugin
int CastPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) noexcept {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies<int64_t>());
if (inputDesc->type == outputDesc->type) {
int element_size = (outputDesc->type == nvinfer1::DataType::kFLOAT)
? sizeof(float)
: ((outputDesc->type == nvinfer1::DataType::kINT32) ? sizeof(int) : 0);
auto cuda_ret = cudaMemcpy(outputs[0], inputs[0], element_cnt * element_size, cudaMemcpyDeviceToDevice);
if (cuda_ret != cudaSuccess) {
MS_LOG(ERROR) << "copy mem failed for " << layer_name_;
return RET_ERROR;
}
return RET_OK;
}
if (inputDesc->type == nvinfer1::DataType::kINT32 && dest_datatype_ == nvinfer1::DataType::kFLOAT) {
auto input = static_cast<const int *>(inputs[0]);
auto output = static_cast<float *>(outputs[0]);
Cast(element_cnt, input, output, stream);
} else if (inputDesc->type == nvinfer1::DataType::kFLOAT && dest_datatype_ == nvinfer1::DataType::kINT32) {
auto input = static_cast<const float *>(inputs[0]);
auto output = static_cast<int *>(outputs[0]);
Cast(element_cnt, input, output, stream);
} else {
MS_LOG(ERROR) << "unsupported data type cast " << layer_name_;
}
return RET_OK;
}
nvinfer1::IPluginV2DynamicExt *CastPlugin::clone() const noexcept {
auto *plugin = new CastPlugin(*this);
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
nvinfer1::DimsExprs CastPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
return *inputs;
}
bool CastPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
return true;
}
void CastPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
size_t CastPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
return 0;
}
nvinfer1::DataType CastPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept {
return dest_datatype_;
}
const char *CastPlugin::getPluginType() const noexcept { return CAST_PLUGIN_NAME; }
const char *CastPlugin::getPluginVersion() const noexcept { return CAST_PLUGIN_VERSION; }
int CastPlugin::getNbOutputs() const noexcept { return 1; }
int CastPlugin::initialize() noexcept { return 0; }
void CastPlugin::terminate() noexcept {}
size_t CastPlugin::getSerializationSize() const noexcept { return 0; }
void CastPlugin::serialize(void *buffer) const noexcept {}
void CastPlugin::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
delete this;
}
void CastPlugin::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
const char *CastPlugin::getPluginNamespace() const noexcept { return name_space_.c_str(); }
} // namespace mindspore::lite


@@ -0,0 +1,110 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_
#include <string>
#include <vector>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
namespace mindspore::lite {
class CastTensorRT : public TensorRTOp {
public:
CastTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~CastTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
// CastTensorRT
};
class CastPluginCreater : public nvinfer1::IPluginCreator {
public:
CastPluginCreater();
const char *getPluginName() const noexcept override;
const char *getPluginVersion() const noexcept override;
const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) noexcept override;
void setPluginNamespace(const char *pluginNamespace) noexcept override;
const char *getPluginNamespace() const noexcept override;
private:
static nvinfer1::PluginFieldCollection field_collection_;
static std::vector<nvinfer1::PluginField> fields_;
std::string name_space_;
};
class CastPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
CastPlugin(const std::string name, nvinfer1::DataType origin_datatype, nvinfer1::DataType dest_datatype)
: layer_name_(name), origin_datatype_(origin_datatype), dest_datatype_(dest_datatype) {}
// It doesn't make sense to construct a CastPlugin without arguments, so the
// default constructor is deleted.
CastPlugin() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept override;
// IPluginV2 Methods
const char *getPluginType() const noexcept override;
const char *getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
void destroy() noexcept override;
void setPluginNamespace(const char *pluginNamespace) noexcept override;
const char *getPluginNamespace() const noexcept override;
private:
const std::string layer_name_;
std::string name_space_;
nvinfer1::DataType origin_datatype_;
nvinfer1::DataType dest_datatype_;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_


@@ -0,0 +1,166 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/op/equal_tensorrt.h"
#include <numeric>
#include <memory>
#include <functional>
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
namespace mindspore::lite {
const char *EQUAL_PLUGIN_VERSION{"1"};
const char *EQUAL_PLUGIN_NAME{"EqualPluginCreater"};
nvinfer1::PluginFieldCollection EqualPluginCreater::field_collection_{};
std::vector<nvinfer1::PluginField> EqualPluginCreater::fields_;
REGISTER_TENSORRT_PLUGIN(EqualPluginCreater);
int EqualTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) {
if (!IsShapeKnown()) {
MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
return RET_ERROR;
}
if (in_tensors.size() != INPUT_SIZE2) {
MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int EqualTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
nvinfer1::ITensor *inputTensors[] = {tensorrt_in_tensors_[0].trt_tensor_, tensorrt_in_tensors_[1].trt_tensor_};
auto plugin = std::make_shared<EqualPlugin>(op_name_);
nvinfer1::IPluginV2Layer *equal_layer = network->addPluginV2(inputTensors, 2, *plugin);
if (equal_layer == nullptr) {
MS_LOG(ERROR) << "create equal layer failed for: " << op_name_;
return RET_ERROR;
}
nvinfer1::ITensor *equal_out = equal_layer->getOutput(0);
equal_layer->setName(op_name_.c_str());
equal_out->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{equal_out, tensorrt_in_tensors_[0].format_});
return RET_OK;
}
// EqualPluginCreater
EqualPluginCreater::EqualPluginCreater() {
// Fill PluginFieldCollection with PluginField arguments metadata
field_collection_.nbFields = fields_.size();
field_collection_.fields = fields_.data();
}
const char *EqualPluginCreater::getPluginName() const noexcept { return EQUAL_PLUGIN_NAME; }
const char *EqualPluginCreater::getPluginVersion() const noexcept { return EQUAL_PLUGIN_VERSION; }
const nvinfer1::PluginFieldCollection *EqualPluginCreater::getFieldNames() noexcept { return &field_collection_; }
nvinfer1::IPluginV2 *EqualPluginCreater::createPlugin(const char *name,
const nvinfer1::PluginFieldCollection *fc) noexcept {
return new EqualPlugin(name);
}
nvinfer1::IPluginV2 *EqualPluginCreater::deserializePlugin(const char *name, const void *serialData,
size_t serialLength) noexcept {
MS_LOG(ERROR) << name << " does not support deserialization";
return nullptr;
}
void EqualPluginCreater::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
const char *EqualPluginCreater::getPluginNamespace() const noexcept { return name_space_.c_str(); }
// EqualPlugin
int EqualPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) noexcept {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies<int64_t>());
MS_LOG(INFO) << layer_name_ << " element_cnt: " << element_cnt;
if (inputDesc->type == nvinfer1::DataType::kINT32) {
const int *input1 = static_cast<const int *>(inputs[0]);
const int *input2 = static_cast<const int *>(inputs[1]);
int *output = static_cast<int *>(outputs[0]);
Equal(input1, input2, output, element_cnt, stream);
} else if (inputDesc->type == nvinfer1::DataType::kFLOAT) {
const float *input1 = static_cast<const float *>(inputs[0]);
const float *input2 = static_cast<const float *>(inputs[1]);
float *output = static_cast<float *>(outputs[0]);
Equal(input1, input2, output, element_cnt, stream);
} else {
MS_LOG(ERROR) << "unsupported equal data type";
}
return RET_OK;
}
nvinfer1::IPluginV2DynamicExt *EqualPlugin::clone() const noexcept {
auto *plugin = new EqualPlugin(*this);
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
nvinfer1::DimsExprs EqualPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
return *inputs;
}
bool EqualPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
return true;
}
void EqualPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
size_t EqualPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
return 0;
}
nvinfer1::DataType EqualPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept {
return inputTypes[0];
}
const char *EqualPlugin::getPluginType() const noexcept { return EQUAL_PLUGIN_NAME; }
const char *EqualPlugin::getPluginVersion() const noexcept { return EQUAL_PLUGIN_VERSION; }
int EqualPlugin::getNbOutputs() const noexcept { return 1; }
int EqualPlugin::initialize() noexcept { return 0; }
void EqualPlugin::terminate() noexcept {}
size_t EqualPlugin::getSerializationSize() const noexcept { return 0; }
void EqualPlugin::serialize(void *buffer) const noexcept {}
void EqualPlugin::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
delete this;
}
void EqualPlugin::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; }
const char *EqualPlugin::getPluginNamespace() const noexcept { return name_space_.c_str(); }
} // namespace mindspore::lite


@@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_
#include <string>
#include <vector>
#include "src/delegate/tensorrt/op/tensorrt_op.h"
#include "src/delegate/tensorrt/cuda_impl/equal.cuh"
namespace mindspore::lite {
class EqualTensorRT : public TensorRTOp {
public:
EqualTensorRT(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors, const std::string &name)
: TensorRTOp(primitive, in_tensors, out_tensors, name) {}
~EqualTensorRT() override = default;
int AddInnerOp(nvinfer1::INetworkDefinition *network) override;
int IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
const std::vector<mindspore::MSTensor> &out_tensors) override;
private:
// EqualTensorRT
};
class EqualPluginCreater : public nvinfer1::IPluginCreator {
public:
EqualPluginCreater();
const char *getPluginName() const noexcept override;
const char *getPluginVersion() const noexcept override;
const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) noexcept override;
void setPluginNamespace(const char *pluginNamespace) noexcept override;
const char *getPluginNamespace() const noexcept override;
private:
static nvinfer1::PluginFieldCollection field_collection_;
static std::vector<nvinfer1::PluginField> fields_;
std::string name_space_;
};
class EqualPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
explicit EqualPlugin(const std::string name) : layer_name_(name) {}
// It doesn't make sense to construct an EqualPlugin without arguments, so the
// default constructor is deleted.
EqualPlugin() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const
noexcept override;
// IPluginV2 Methods
const char *getPluginType() const noexcept override;
const char *getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
void destroy() noexcept override;
void setPluginNamespace(const char *pluginNamespace) noexcept override;
const char *getPluginNamespace() const noexcept override;
private:
const std::string layer_name_;
std::string name_space_;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_


@@ -89,9 +89,20 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
activation_layer->setName((op_name_ + "_activation").c_str());
}
activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), Format::NCHW});
MS_LOG(DEBUG) << "output " << GetTensorFormat(activation_layer->getOutput(0), Format::NCHW);
nvinfer1::ITensor *out_trt_tensor = activation_layer->getOutput(0);
if (out_trt_tensor->getDimensions().nbDims == DIMENSION_4D) {
// transpose output from nchw to nhwc
nvinfer1::IShuffleLayer *transpose_layer_out = NCHW2NHWC(network, *out_trt_tensor);
if (transpose_layer_out == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
transpose_layer_out->setName((op_name_ + "_transpose2NHWC").c_str());
out_trt_tensor = transpose_layer_out->getOutput(0);
}
out_trt_tensor->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{out_trt_tensor, Format::NHWC});
MS_LOG(DEBUG) << "output " << GetTensorFormat(out_trt_tensor, Format::NHWC);
return RET_OK;
}


@@ -151,6 +151,7 @@ int ResizeTensorRT::SetOutputDims(nvinfer1::ITensor *resize_in_tensor, nvinfer1:
float scales[out_tensors_[0].Shape().size()];
for (size_t i = 0; i < out_tensors_[0].Shape().size(); i++) {
scales[i] = static_cast<float>(out_tensors_[0].Shape()[i]) / static_cast<float>(in_tensors_[0].Shape()[i]);
MS_LOG(DEBUG) << op_name_ << " scale at " << i << ": " << scales[i];
}
resize_layer->setScales(scales, out_tensors_[0].Shape().size());
}


@@ -73,26 +73,28 @@ int ShuffleTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
MS_LOG(ERROR) << "network is invalid";
return RET_ERROR;
}
nvinfer1::ITensor *shuffler_input = tensorrt_in_tensors_[0].trt_tensor_;
MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(shuffler_input, tensorrt_in_tensors_[0].format_);
if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
!SameDims(tensorrt_in_tensors_[0].trt_tensor_->getDimensions(), in_tensors_[0].Shape())) {
shuffler_input_ = tensorrt_in_tensors_[0].trt_tensor_;
MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(shuffler_input_, tensorrt_in_tensors_[0].format_);
if (shuffler_input_->getDimensions().nbDims == DIMENSION_4D &&
!SameDims(shuffler_input_->getDimensions(), in_tensors_[0].Shape())) {
// only valid for nchw or nhwc
if (tensorrt_in_tensors_[0].format_ == Format::NCHW) {
nvinfer1::IShuffleLayer *transpose_layer = NCHW2NHWC(network, *tensorrt_in_tensors_[0].trt_tensor_);
nvinfer1::IShuffleLayer *transpose_layer = NCHW2NHWC(network, *shuffler_input_);
if (transpose_layer == nullptr) {
MS_LOG(ERROR) << "create transpose layer failed for " << op_name_;
return RET_ERROR;
}
transpose_layer->setName((op_name_ + "_transpose_in").c_str());
shuffler_input = transpose_layer->getOutput(0);
shuffler_input_ = transpose_layer->getOutput(0);
out_format_ = Format::NHWC;
} else if (tensorrt_in_tensors_[0].format_ == Format::NHWC) {
nvinfer1::IShuffleLayer *transpose_layer = NHWC2NCHW(network, *tensorrt_in_tensors_[0].trt_tensor_);
nvinfer1::IShuffleLayer *transpose_layer = NHWC2NCHW(network, *shuffler_input_);
if (transpose_layer == nullptr) {
MS_LOG(ERROR) << "create transpose layer failed for " << op_name_;
return RET_ERROR;
}
transpose_layer->setName((op_name_ + "_transpose_in").c_str());
shuffler_input = transpose_layer->getOutput(0);
shuffler_input_ = transpose_layer->getOutput(0);
out_format_ = Format::NCHW;
} else {
MS_LOG(ERROR) << "invalid input format for " << op_name_;
@@ -101,9 +103,9 @@ int ShuffleTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
} else {
out_format_ = tensorrt_in_tensors_[0].format_;
}
MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(shuffler_input, out_format_);
MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(shuffler_input_, out_format_);
nvinfer1::IShuffleLayer *shuffle_layer = network->addShuffle(*shuffler_input);
nvinfer1::IShuffleLayer *shuffle_layer = network->addShuffle(*shuffler_input_);
if (shuffle_layer == nullptr) {
MS_LOG(ERROR) << "add Shuffle op failed for TensorRT.";
return RET_ERROR;
@@ -165,7 +167,8 @@ int ShuffleTensorRT::AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
}
// axis
auto squeeze_shape = std::vector<int64_t>(in_tensors_[0].Shape().begin(), in_tensors_[0].Shape().end());
auto squeeze_shape = shuffler_input_->getDimensions();
std::vector<int64_t> new_shape(squeeze_shape.d, squeeze_shape.d + squeeze_shape.nbDims);
auto axis = squeeze_op->axis();
if (axis == nullptr) {
MS_LOG(ERROR) << "AddSqueezeOp has invalid axis";
@@ -173,13 +176,13 @@ int ShuffleTensorRT::AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
}
for (int i = axis->size() - 1; i >= 0; i--) {
if (squeeze_shape[axis->Get(i)] != 1) {
MS_LOG(WARNING) << "squeeze_shape value is not 1, need check";
if (new_shape[axis->Get(i)] != 1) {
MS_LOG(WARNING) << "squeeze_shape value at " << i << " is " << axis->Get(i) << ", need check " << op_name_;
}
squeeze_shape.erase(squeeze_shape.begin() + axis->Get(i));
new_shape.erase(new_shape.begin() + axis->Get(i));
}
nvinfer1::Dims squeeze_dims = lite::ConvertCudaDims(squeeze_shape);
nvinfer1::Dims squeeze_dims = lite::ConvertCudaDims(new_shape);
shuffle_layer->setReshapeDimensions(squeeze_dims);
return shuffle_layer->getOutput(0) == nullptr ? RET_ERROR : RET_OK;
@@ -196,7 +199,7 @@ int ShuffleTensorRT::AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) {
MS_LOG(WARNING) << "AddUnsqueezeOp size of in tensort needs check: " << in_tensors_.size();
}
// axis
auto unsqueeze_shape = tensorrt_in_tensors_[0].trt_tensor_->getDimensions();
auto unsqueeze_shape = shuffler_input_->getDimensions();
std::vector<int64_t> new_shape(unsqueeze_shape.d, unsqueeze_shape.d + unsqueeze_shape.nbDims);
auto axis = unsqueeze_op->axis();
@@ -278,7 +281,7 @@ int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) {
}
auto axis_data = static_cast<const int *>(in_tensors_[1].Data().get());
int axis = axis_data[0];
auto input_dims = tensorrt_in_tensors_[0].trt_tensor_->getDimensions();
auto input_dims = shuffler_input_->getDimensions();
std::vector<int64_t> new_shape;
for (int i = 0; i < input_dims.nbDims; i++) {
if (axis == i) {


@@ -44,6 +44,7 @@ class ShuffleTensorRT : public TensorRTOp {
int InferReshapeDims(nvinfer1::Dims input_dims, nvinfer1::Dims *reshape_dims);
Format out_format_ = Format::NHWC;
nvinfer1::ITensor *shuffler_input_ = nullptr;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_


@@ -25,9 +25,11 @@ int TopKTensorRT::IsSupport(const schema::Primitive *primitive, const std::vecto
}
if (in_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
if (out_tensors.size() != 1) {
MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
@@ -63,26 +65,62 @@ int TopKTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
keep_dims = mim_prim->keep_dims();
} else {
MS_LOG(ERROR) << "invalid op primitive for " << op_name_;
}
if (keep_dims) {
MS_LOG(WARNING) << "keep dims is unsupported for " << op_name_;
return RET_ERROR;
}
if (tensorrt_in_tensors_[0].format_ == Format::NCHW) {
axis_value = ConvertAxisFromNHWC2NCHW(axis_value);
nvinfer1::ITensor *topk_input = tensorrt_in_tensors_[0].trt_tensor_;
Format output_format = tensorrt_in_tensors_[0].format_;
if (tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
tensorrt_in_tensors_[0].format_ == Format::NCHW) {
nvinfer1::IShuffleLayer *transpose_layer = NCHW2NHWC(network, *topk_input);
if (transpose_layer == nullptr) {
MS_LOG(ERROR) << "create transpose layer failed for " << op_name_;
return RET_ERROR;
}
transpose_layer->setName((op_name_ + "_transpose_in").c_str());
topk_input = transpose_layer->getOutput(0);
output_format = Format::NHWC;
}
uint32_t reduce_axes = 1 << axis_value;
nvinfer1::ITopKLayer *topk_layer = network->addTopK(*tensorrt_in_tensors_[0].trt_tensor_, red_op, topk, reduce_axes);
nvinfer1::ITopKLayer *topk_layer = network->addTopK(*topk_input, red_op, topk, reduce_axes);
if (topk_layer == nullptr) {
MS_LOG(ERROR) << "addTopK failed for: " << op_name_;
return RET_ERROR;
}
topk_layer->setName(op_name_.c_str());
nvinfer1::ITensor *op_out_tensor = topk_layer->getOutput(1);
// output 0 is data value, output 1 is index
if (!keep_dims) {
MS_LOG(DEBUG) << op_name_ << "add squeeze for not keep dims at index " << axis_value;
if (op_out_tensor->getDimensions().d[axis_value] != 1) {
MS_LOG(ERROR) << "output dims is invalid for squeeze: " << op_name_;
return RET_ERROR;
}
nvinfer1::IShuffleLayer *squeeze_layer = network->addShuffle(*op_out_tensor);
if (squeeze_layer == nullptr) {
MS_LOG(ERROR) << "add squeeze layer failed for: " << op_name_;
return RET_ERROR;
}
nvinfer1::Dims squeeze_dims{};
squeeze_dims.nbDims = op_out_tensor->getDimensions().nbDims - 1;
if (axis_value != squeeze_dims.nbDims) {
MS_LOG(ERROR) << op_name_ << " reduce squeeze dims need check for axis: " << axis_value;
return RET_ERROR;
}
for (int i = 0; i < squeeze_dims.nbDims; i++) {
squeeze_dims.d[i] = 0;
// 0 copies the corresponding dimension from the input
}
squeeze_layer->setReshapeDimensions(squeeze_dims);
squeeze_layer->setName((op_name_ + "_squeeze").c_str());
op_out_tensor = squeeze_layer->getOutput(0);
}
nvinfer1::ITensor *op_out_tensor = topk_layer->getOutput(0);
op_out_tensor->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{op_out_tensor, tensorrt_in_tensors_[0].format_});
this->AddInnerOutTensors(ITensorHelper{op_out_tensor, output_format});
return RET_OK;
}
} // namespace mindspore::lite


@@ -37,6 +37,8 @@
#include "src/delegate/tensorrt/op/pool_tensorrt.h"
#include "src/delegate/tensorrt/op/pad_tensorrt.h"
#include "src/delegate/tensorrt/op/resize_tensorrt.h"
#include "src/delegate/tensorrt/op/equal_tensorrt.h"
#include "src/delegate/tensorrt/op/cast_tensorrt.h"
#include "src/delegate/tensorrt/op/topk_tensorrt.h"
namespace mindspore::lite {
@@ -78,6 +80,7 @@ Status TensorRTDelegate::Init() {
{schema::PrimitiveType_Activation, GetTensorRTOp<ActivationTensorRT>},
{schema::PrimitiveType_Concat, GetTensorRTOp<ConcateTensorRT>},
{schema::PrimitiveType_Conv2DFusion, GetTensorRTOp<ConvolutionTensorRT>},
{schema::PrimitiveType_Cast, GetTensorRTOp<CastTensorRT>},
{schema::PrimitiveType_Conv2dTransposeFusion, GetTensorRTOp<DeconvolutionTensorRT>},
{schema::PrimitiveType_SubFusion, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_DivFusion, GetTensorRTOp<ElementWiseTensorRT>},
@@ -88,6 +91,7 @@ Status TensorRTDelegate::Init() {
{schema::PrimitiveType_Minimum, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_Maximum, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_BiasAdd, GetTensorRTOp<ElementWiseTensorRT>},
{schema::PrimitiveType_Equal, GetTensorRTOp<EqualTensorRT>},
{schema::PrimitiveType_Gather, GetTensorRTOp<GatherTensorRT>},
{schema::PrimitiveType_MatMul, GetTensorRTOp<MatMulTensorRT>},
{schema::PrimitiveType_FullConnection, GetTensorRTOp<MatMulTensorRT>},


@@ -17,7 +17,8 @@
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_DELEGATE_H_
#include <string>
#include <vector>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <memory>
#include "include/api/delegate.h"
#include "src/delegate/tensorrt/tensorrt_subgraph.h"
@@ -47,11 +48,11 @@ class TensorRTDelegate : public Delegate {
TensorRTSubGraph *CreateTensorRTGraph(const std::vector<TensorRTOp *> &ops, DelegateModel *model, KernelIter from,
KernelIter end);
std::map<schema::PrimitiveType, TensorRTGetOp> op_func_lists_;
std::unordered_map<schema::PrimitiveType, TensorRTGetOp> op_func_lists_;
std::vector<schema::PrimitiveType> unsupport_hw_op_lists_;
std::unordered_set<schema::PrimitiveType> unsupport_hw_op_lists_;
std::vector<schema::PrimitiveType> unsupport_resize_op_list_;
std::unordered_set<schema::PrimitiveType> unsupport_resize_op_list_;
mindspore::Context *context_;


@@ -94,7 +94,7 @@ int CheckInfershapeResult(int result, const std::vector<lite::Tensor *> &inputs,
}
for (auto output : outputs) {
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
if (static_cast<size_t>(output->ElementsNum()) >= GetMaxMallocSize() / sizeof(int64_t)) {
MS_LOG(ERROR) << "The size of output tensor is too big, output size: " << output->ElementsNum();
return RET_INFER_ERR;
}


@@ -17,6 +17,7 @@
#include "src/runtime/inner_allocator.h"
#include <utility>
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
namespace mindspore {
std::shared_ptr<Allocator> Allocator::Create() { return std::make_shared<DefaultAllocator>(); }
@@ -48,11 +49,11 @@ bool DefaultAllocator::ReuseMemory(size_t free_size, size_t size) {
}
void *DefaultAllocator::Malloc(size_t size) {
if (size > MAX_MALLOC_SIZE) {
if (size > lite::GetMaxMallocSize()) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
return nullptr;
}
if (this->total_size_ >= MAX_THREAD_POOL_SIZE) {
if (this->total_size_ >= lite::GetMaxMallocSize()) {
MS_LOG(ERROR) << "Memory pool is exhausted";
return nullptr;
}


@@ -598,7 +598,7 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node) {
if (ret == RET_OK) {
for (auto &output : outputs) {
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
if (static_cast<size_t>(output->ElementsNum()) >= GetMaxMallocSize() / sizeof(int64_t)) {
MS_LOG(ERROR) << "The size of output tensor is too big";
FreeOpParameters();
return RET_ERROR;


@@ -41,9 +41,6 @@ namespace lite {
: (((y) >= 0) ? (INT64_MAX / (x)) > (-1 * (y)) : (INT64_MAX / (x)) > (y))))
#endif
namespace {
constexpr int kMaxMallocSize = 1024 * 1024 * 300;
} // namespace
Tensor::Tensor(const TypeId data_type, std::vector<int> shape, const mindspore::Format &format, Category category)
: data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {}
@@ -308,7 +305,7 @@ int Tensor::MallocData(const AllocatorPtr allocator) {
}
auto data_size = this->Size();
if (data_size > kMaxMallocSize) {
if (data_size > GetMaxMallocSize()) {
MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes";
return RET_ERROR;
}


@@ -12,6 +12,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
${KERNEL_REG_DIR}/../runtime/inner_allocator.cc
${KERNEL_REG_DIR}/../common/string_util.cc
${KERNEL_REG_DIR}/../common/lite_utils.cc
${KERNEL_REG_DIR}/../common/utils.cc
${CORE_DIR}/utils/log_adapter.cc
${CORE_DIR}/utils/status.cc
${CORE_DIR}/gvar/log_adapter_common.cc