Modify inference log level
parent 1753f930d9
commit 41709acd31

@@ -24,6 +24,7 @@
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <variant>
 #include <NvInfer.h>
 #include "utils/log_adapter.h"
 #include "utils/singleton.h"

@@ -47,7 +48,7 @@ class TrtUtils {
     return iter->second;
   }

-  static nvinfer1::DataType MsDtypeToTrtDtype(const TypeId &ms_dtype) {
+  static std::variant<bool, nvinfer1::DataType> MsDtypeToTrtDtype(const TypeId &ms_dtype) {
     static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
                                                              {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
                                                              {TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},

@@ -55,7 +56,8 @@ class TrtUtils {
                                                              {TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
     auto iter = type_list.find(ms_dtype);
     if (iter == type_list.end()) {
-      MS_LOG(EXCEPTION) << "data type not support: " << ms_dtype;
+      MS_LOG(WARNING) << "data type not support: " << ms_dtype;
+      return false;
     }
     return iter->second;
   }

@@ -160,5 +162,12 @@ inline std::shared_ptr<T> TrtPtr(T *obj) {
     if (obj) obj->destroy();
   });
 }
+
+#define TRT_VARIANT_CHECK(input, expect, ret) \
+  do {                                        \
+    if ((input.index()) != (expect)) {        \
+      return ret;                             \
+    }                                         \
+  } while (0)
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
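
Note on the new contract (a sketch, not part of the commit): MsDtypeToTrtDtype now returns std::variant<bool, nvinfer1::DataType>, where index 0 (bool) signals an unsupported dtype and index 1 carries the mapped TensorRT type, and TRT_VARIANT_CHECK returns from the calling function when the index is not the expected one. The standalone sketch below mirrors that pattern with stand-in enums so it compiles without NvInfer.h or the MindSpore headers; every name other than the macro shape is a placeholder.

#include <iostream>
#include <map>
#include <variant>

// Stand-ins for nvinfer1::DataType and mindspore TypeId; placeholders only.
enum class DataType { kFLOAT, kHALF, kINT8, kINT32 };
enum class TypeId { kFloat32, kFloat16, kInt8, kInt32, kUnknown };

// Same shape as the TRT_VARIANT_CHECK macro added in this header.
#define TRT_VARIANT_CHECK(input, expect, ret) \
  do {                                        \
    if ((input.index()) != (expect)) {        \
      return ret;                             \
    }                                         \
  } while (0)

// Mirrors the patched MsDtypeToTrtDtype: bool (index 0) on failure,
// the mapped type (index 1) on success, instead of logging an exception.
std::variant<bool, DataType> MsDtypeToTrtDtype(TypeId ms_dtype) {
  static const std::map<TypeId, DataType> type_list = {{TypeId::kFloat32, DataType::kFLOAT},
                                                       {TypeId::kFloat16, DataType::kHALF},
                                                       {TypeId::kInt8, DataType::kINT8},
                                                       {TypeId::kInt32, DataType::kINT32}};
  auto iter = type_list.find(ms_dtype);
  if (iter == type_list.end()) {
    return false;  // failure alternative
  }
  return iter->second;
}

// Caller pattern used throughout the .cc changes below.
bool Convert(TypeId ms_dtype) {
  std::variant<bool, DataType> type = MsDtypeToTrtDtype(ms_dtype);
  TRT_VARIANT_CHECK(type, 1UL, false);           // return false if the lookup failed
  DataType trt_type = std::get<DataType>(type);  // safe: index is known to be 1
  std::cout << "mapped to enum value " << static_cast<int>(trt_type) << "\n";
  return true;
}

int main() {
  Convert(TypeId::kFloat16);                 // succeeds
  return Convert(TypeId::kUnknown) ? 0 : 1;  // fails cleanly instead of throwing
}
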
@@ -23,6 +23,7 @@
 #include <memory>
 #include <sstream>
 #include <algorithm>
+#include <variant>
 #include "runtime/device/gpu/trt_loader.h"
 #include "backend/optimizer/trt_pass/trt_op_factory.h"
 #include "backend/kernel_compiler/gpu/trt/trt_utils.h"

@@ -89,10 +90,11 @@ bool TrtConverterContext::Serialize(std::string *model) {
   const auto &context = MsContext::GetInstance();
   const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
   if (precision_mode == "fp16") {
-    MS_LOG(WARNING) << "Inference with mixed precision mode. It will take few minutes for operators selection.";
+    MS_LOG(INFO) << "Inference with mixed precision mode";
     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
   }

+  MS_LOG(WARNING) << "It will take few minutes for operators selection.";
   engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
   MS_EXCEPTION_IF_NULL(engine_);

@@ -117,7 +119,9 @@ bool TrtConverterContext::InitInputTable() {

     nvinfer1::Weights weight;
     weight.values = tensor->data_c();
-    weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
+    std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
+    TRT_VARIANT_CHECK(type, 1UL, false);
+    weight.type = std::get<nvinfer1::DataType>(type);
     weight.count = tensor->DataSize();
     output_map_[input_node][0] = LayerInput(weight, tensor->shape());
   }

@@ -142,7 +146,9 @@ bool TrtConverterContext::InitValueNodeTable() {
     const auto &tensor = tensors[i];
     nvinfer1::Weights weight;
     weight.values = tensor->data_c();
-    weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
+    std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
+    TRT_VARIANT_CHECK(type, 1UL, false);
+    weight.type = std::get<nvinfer1::DataType>(type);
     weight.count = tensor->DataSize();
     output_map_[value_node][i] = LayerInput(weight, tensor->shape());
   }

@@ -179,7 +185,9 @@ bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::ve
 LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto input = node->cast<ParameterPtr>();
-  const nvinfer1::DataType &trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(node, 0));
+  std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(node, 0));
+  TRT_VARIANT_CHECK(type, 1UL, nullptr);
+  const auto &trt_dtype = std::get<nvinfer1::DataType>(type);
   const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(node, 0), false);
   nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
   const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);

@@ -195,6 +203,10 @@ bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<Lay
     if (node_iter == output_map_.end()) {
       if (item.first->isa<Parameter>()) {
         LayerInput *input = LoadInputOnDemand(item.first);
+        if (input == nullptr) {
+          MS_LOG(WARNING) << "LoadLayerInput failed.";
+          return false;
+        }
         inputs->push_back(*input);
         continue;
       }

@@ -16,6 +16,7 @@

 #include <vector>
 #include <string>
+#include <variant>
 #include <NvInfer.h>
 #include "backend/optimizer/trt_pass/trt_converter_context.h"
 #include "backend/optimizer/trt_pass/trt_op_factory.h"

@@ -879,7 +880,11 @@ MS_TRT_CONVERTER_FUNC_REG(Cast) {
   nvinfer1::ITensor *input = ToTensor(&inputs[0], input_shape, context);

   const TypeId &dst_type = AnfAlgo::GetOutputInferDataType(node, 0);
-  auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
+  std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(dst_type);
+  if (type.index() != 1) {
+    return {false, {}};
+  }
+  auto trt_type = std::get<nvinfer1::DataType>(type);
   auto *layer = context->network()->addIdentity(*input);
   layer->setOutputType(0, trt_type);

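
Side note (a sketch, not part of this commit): the Cast converter checks type.index() != 1 inline rather than using TRT_VARIANT_CHECK, presumably because its failure value {false, {}} contains a comma that the preprocessor would split into separate macro arguments. An equivalent check can be written with std::get_if, which avoids hard-coding the variant index; DataType and ConvertResult below are stand-ins for the real TensorRT and MindSpore types.

#include <variant>

// Placeholders for nvinfer1::DataType and the converter's {ok, output} result.
enum class DataType { kFLOAT, kHALF };
struct ConvertResult { bool ok; int output; };

ConvertResult CastConvert(const std::variant<bool, DataType> &type) {
  // std::get_if returns nullptr when the variant holds the failure (bool) alternative.
  if (const auto *trt_type = std::get_if<DataType>(&type)) {
    return {true, static_cast<int>(*trt_type)};
  }
  return {false, {}};
}

int main() { return CastConvert(DataType::kHALF).ok ? 0 : 1; }
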