forked from OSSInnovation/mindspore
trt converter
This commit is contained in:
parent
870c799cb6
commit
943b992458
|
@ -68,7 +68,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
TrtKernel::ReleaseResource() {
|
||||
void TrtKernel::ReleaseResource() {
|
||||
// Make sure destroy trt object before TrtLoader destruct.
|
||||
context_.reset();
|
||||
engine_.reset();
|
||||
|
|
|
@ -50,6 +50,33 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
builder.SetOutputsFormat(outputs_format);
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
// Redirects the UpdateState user of the matched node's input(2) to the fused Adam node and returns `adam`.
//
// node:    matched node; it is read as a CNode and its input(2) is looked up in the graph manager.
// adam:    the fused node, installed as input 2 of the first UpdateState user found.
// u_input: installed as input 1 of that same UpdateState user.
//
// Raises (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when `node` is not a CNode, when the graph or
// its manager is missing, or when input(2) has no entry in the manager's user table.
AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) {
  MS_EXCEPTION_IF_NULL(node);
  // cast<CNodePtr>() gives a null pointer for a non-CNode; check it instead of dereferencing blindly.
  const auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // Replace the parameters of the last UpdateState to maintain
  // the execution order of FusedAdam and the following operators.
  // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
  const auto &n = cnode->input(2);
  MS_EXCEPTION_IF_NULL(n);
  const auto &fg = n->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto mgr = fg->manager();
  MS_EXCEPTION_IF_NULL(mgr);
  auto &node_users = mgr->node_users();
  auto iter = node_users.find(n);
  if (iter == node_users.end()) {
    MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
  }

  auto &users = iter->second;
  for (auto &user : users) {
    if (!IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
      continue;
    }
    const auto update_state = (user.first)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(update_state);
    update_state->set_input(1, u_input);
    update_state->set_input(2, adam);
    // Only the first UpdateState user is rewired.
    break;
  }
  return adam;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef AdamFusion::DefinePattern() const {
|
||||
|
@ -118,51 +145,19 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
|
|||
// Fused into a FusedAdam operator.
|
||||
auto prim = std::make_shared<Primitive>(kFusedAdamName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
|
||||
beta1_input,
|
||||
one_sub_beta1_input,
|
||||
beta2_input,
|
||||
one_sub_beta2_input,
|
||||
eps_input,
|
||||
lr_input,
|
||||
param,
|
||||
m_input,
|
||||
v_input,
|
||||
gradient_input};
|
||||
auto prim_value = NewValueNode(prim);
|
||||
std::vector<AnfNodePtr> inputs = {
|
||||
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
|
||||
m_input, v_input, gradient_input};
|
||||
auto adam = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(adam);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
|
||||
adam->set_scope(node->scope());
|
||||
|
||||
auto build_info = GenerateKernelBuildInfo(adam);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
|
||||
|
||||
// Replace the parameters of the last UpdateState to maintain
|
||||
// the execution order of FusedAdam and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
auto n = node->cast<CNodePtr>()->input(2);
|
||||
auto fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
auto &node_users = mgr->node_users();
|
||||
auto iter = node_users.find(n);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
|
||||
}
|
||||
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(2, adam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return adam;
|
||||
return RelpaceOutputEdge(node, adam, u_input);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,6 +50,34 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
builder.SetOutputsFormat(outputs_format);
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
// Redirects the UpdateState user of the matched node's input(2) to the fused AdamWeightDecay node and
// returns `adam_weight_decay`.
//
// node:              matched node; it is read as a CNode and its input(2) is looked up in the graph manager.
// adam_weight_decay: the fused node, installed as input 2 of the first UpdateState user found.
// u_input:           installed as input 1 of that same UpdateState user.
//
// Raises (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)) when `node` is not a CNode, when the graph or
// its manager is missing, or when input(2) has no entry in the manager's user table.
AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, AnfNodePtr u_input) {
  MS_EXCEPTION_IF_NULL(node);
  // cast<CNodePtr>() gives a null pointer for a non-CNode; check it instead of dereferencing blindly.
  const auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // Replace the parameters of the last UpdateState to maintain
  // the execution order of FusedAdamWeightDecay and the following operators.
  // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
  const auto &n = cnode->input(2);
  MS_EXCEPTION_IF_NULL(n);
  const auto &fg = n->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto mgr = fg->manager();
  MS_EXCEPTION_IF_NULL(mgr);
  auto &node_users = mgr->node_users();
  auto iter = node_users.find(n);
  if (iter == node_users.end()) {
    MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
  }

  auto &users = iter->second;
  for (auto &user : users) {
    if (!IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
      continue;
    }
    const auto update_state = (user.first)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(update_state);
    update_state->set_input(1, u_input);
    update_state->set_input(2, adam_weight_decay);
    // Only the first UpdateState user is rewired.
    break;
  }

  return adam_weight_decay;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
|
||||
|
@ -122,18 +150,10 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
|
|||
// Fused into a FusedAdamWeightDecay operator.
|
||||
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
|
||||
beta1_input,
|
||||
one_sub_beta1_input,
|
||||
beta2_input,
|
||||
one_sub_beta2_input,
|
||||
eps_input,
|
||||
lr_input,
|
||||
param,
|
||||
m_input,
|
||||
v_input,
|
||||
gradient_input,
|
||||
weight_decay_input};
|
||||
auto prim_value = NewValueNode(prim);
|
||||
std::vector<AnfNodePtr> inputs = {
|
||||
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
|
||||
m_input, v_input, gradient_input, weight_decay_input};
|
||||
auto adam_weight_decay = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(adam_weight_decay);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
@ -143,31 +163,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
|
|||
|
||||
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
|
||||
|
||||
// Replace the parameters of the last UpdateState to maintain
|
||||
// the execution order of FusedAdamWeightDecay and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
auto n = node->cast<CNodePtr>()->input(2);
|
||||
auto fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
auto &node_users = mgr->node_users();
|
||||
auto iter = node_users.find(n);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
|
||||
}
|
||||
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return adam_weight_decay;
|
||||
return ReplaceOutputEdge(node, adam_weight_decay, u_input);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,339 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/trt_pass/trt_converter_context.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "runtime/device/gpu/trt_loader.h"
|
||||
#include "backend/optimizer/trt_pass/trt_op_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/singleton.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
|
||||
std::vector<session::KernelWithIndex> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
||||
return inputs->push_back(std::make_pair(node, 0));
|
||||
}
|
||||
|
||||
// Skip control node
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||
return GetRealOutputRecursive(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
|
||||
}
|
||||
|
||||
// Bypass TupleGetItem
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_get_item = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
|
||||
auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
|
||||
|
||||
// Conceal MakeTuple + TupleGetItem pair.
|
||||
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
|
||||
return GetRealOutputRecursive(real_input, 0, inputs);
|
||||
}
|
||||
|
||||
// Skip TupleGetItem.
|
||||
return GetRealOutputRecursive(input, index, inputs);
|
||||
}
|
||||
|
||||
// Flatten MakeTuple inputs.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
|
||||
GetRealOutputRecursive(input_node, 0, inputs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
return inputs->push_back(std::make_pair(node, output_index));
|
||||
}
|
||||
|
||||
/* Get node real inputs bypass control nodes.
|
||||
* Examples:
|
||||
* Case 1:
|
||||
* c = Conv2D(a, b)
|
||||
* d = ReLU(c)
|
||||
* result: d--> (c)
|
||||
*
|
||||
* Case 2:
|
||||
* c = Conv2D(a, b)
|
||||
* d = Depend(c, v)
|
||||
* e = ReLU(d)
|
||||
* result: d -> (c)
|
||||
*
|
||||
* Case 3:
|
||||
* (f, g, h, i, j) = BatchNorm(a, b, c, d, e)
|
||||
* k = TupleGetItem((f, g, h, i, j), 0)
|
||||
* l = ReLU(k)
|
||||
* result: l -> (f)
|
||||
*
|
||||
* Case 4:
|
||||
* c = Conv2D(a, b)
|
||||
* e = MakeTuple(c, d)
|
||||
* f = TupleGetItem(e, 0)
|
||||
* g = ReLU(k)
|
||||
* result: g -> (c)
|
||||
*
|
||||
* Case 5:
|
||||
* b = MakeTuple(a1, a2, a3)
|
||||
* c = MakeTuple(b, a4)
|
||||
* d = return(c)
|
||||
* result d -> (a1, a2, a3, a4)
|
||||
*/
|
||||
void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
|
||||
GetRealOutputRecursively(input_node, 0, inputs);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool TrtConverterContext::Init() {
|
||||
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
||||
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
|
||||
MS_EXCEPTION_IF_NULL(builder_);
|
||||
|
||||
auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
|
||||
network_ = TrtPtr(builder_->createNetworkV2(batch_type));
|
||||
MS_EXCEPTION_IF_NULL(network_);
|
||||
|
||||
config_ = TrtPtr(builder_->createBuilderConfig());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::Parser() {
|
||||
InitInputTable();
|
||||
InitValueNodeTable();
|
||||
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
|
||||
const auto &converter_factory = TrtOpFactory::GetInstance();
|
||||
for (auto node : node_list) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Mark graph outputs
|
||||
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
|
||||
if (op_name == kReturnOpName) {
|
||||
std::vector<LayerInput> inputs;
|
||||
(void)LoadLayerInput(node, &inputs);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto &input = inputs[i].tensor();
|
||||
std::string name = "return_output_" + std::to_string(i);
|
||||
input->setName(name.c_str());
|
||||
network_->markOutput(*input);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Transform AnfNode To Trt layer.
|
||||
// Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
|
||||
auto result = convert_func(node, this->shared_from_this());
|
||||
if (!result.first) {
|
||||
MS_LOG(ERROR) << op_name << " converter failed.";
|
||||
return false;
|
||||
}
|
||||
auto ret = StoreLayerOutput(node, result.second);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << op_name << " converter failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Graph ended without return node.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::Serialize(std::string *model) {
|
||||
MS_EXCEPTION_IF_NULL(model);
|
||||
builder_->setMaxBatchSize(batch_size_);
|
||||
config_->setMaxWorkspaceSize(workspace_size_);
|
||||
engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
|
||||
MS_EXCEPTION_IF_NULL(engine_);
|
||||
|
||||
std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
|
||||
*model = string(static_cast<const char *>(model_data->data()), model_data->size());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::InitInputTable() {
|
||||
const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
|
||||
for (auto input_node : graph_inputs) {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto input = input_node->cast<ParameterPtr>();
|
||||
if (AnfAlgo::IsParameterWeight(input)) {
|
||||
const auto ¶m_value = input->default_param();
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
|
||||
nvinfer1::Weights weight;
|
||||
weight.values = tensor->data_c();
|
||||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[input_node][0] = LayerInput(weight);
|
||||
} else {
|
||||
nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
|
||||
nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
|
||||
nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
|
||||
output_map_[input_node][0] = LayerInput(tensor);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::InitValueNodeTable() {
|
||||
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
||||
for (auto &value_node : kernel_graph->graph_value_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
|
||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(node_value, &tensors);
|
||||
for (size_t i = 0; i < tensors.size(); i++) {
|
||||
const auto &tensor = tensors[i];
|
||||
nvinfer1::Weights weight;
|
||||
weight.values = tensor->data_c();
|
||||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[value_node][i] = LayerInput(weight);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) {
|
||||
if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
|
||||
MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
|
||||
<< ", while got: " << nv_tensors.size();
|
||||
}
|
||||
|
||||
for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
|
||||
if (nv_tensors[tensor_index].tensor() != nullptr) {
|
||||
output_map_[node][tensor_index] = nv_tensors[tensor_index];
|
||||
|
||||
std::ostringstream oss;
|
||||
nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions();
|
||||
oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
|
||||
for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
|
||||
oss << dim.d[dim_index] << " ";
|
||||
}
|
||||
oss << "]";
|
||||
MS_LOG(INFO) << oss.str();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
|
||||
std::vector<session::KernelWithIndex> real_inputs;
|
||||
GetRealInputs(node, &real_inputs);
|
||||
for (auto item : real_inputs) {
|
||||
auto node_iter = output_map_.find(item.first);
|
||||
if (node_iter == output_map_.end()) {
|
||||
MS_LOG(ERROR) << "node: " << node->DebugString() << " not found.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto out_iter = node_iter->second.find(item.second);
|
||||
if (out_iter == node_iter->second.end()) {
|
||||
MS_LOG(ERROR) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
|
||||
return false;
|
||||
}
|
||||
|
||||
inputs->push_back(out_iter->second);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
|
||||
// Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
|
||||
std::unordered_map<std::string, AnfNodePtr> graph_inputs;
|
||||
for (const auto &input_node : func_graph_->parameters()) {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto input = input_node->cast<ParameterPtr>();
|
||||
if (!AnfAlgo::IsParameterWeight(input)) {
|
||||
graph_inputs.insert(std::make_pair(input->name(), input_node));
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the graph inputs in order of the binding name.
|
||||
std::vector<AnfNodePtr> trt_inputs;
|
||||
for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
|
||||
if (!engine_->bindingIsInput(i)) {
|
||||
continue;
|
||||
}
|
||||
auto iter = graph_inputs.find(engine_->getBindingName(i));
|
||||
if (iter == graph_inputs.end()) {
|
||||
MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
|
||||
}
|
||||
trt_inputs.push_back(iter->second);
|
||||
}
|
||||
return trt_inputs;
|
||||
}
|
||||
|
||||
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() {
|
||||
std::vector<session::KernelWithIndex> graph_outputs;
|
||||
GetRealInputs(func_graph_->get_return(), &graph_outputs);
|
||||
return graph_outputs;
|
||||
}
|
||||
|
||||
std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type,
|
||||
const std::vector<size_t> &shape) {
|
||||
ShapeVector shape_int;
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong);
|
||||
auto tensor = std::make_shared<tensor::Tensor>(type, shape_int);
|
||||
temp_weights_.push_back(tensor);
|
||||
return tensor;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <NvInfer.h>
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/trt_pass/layer_input.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TrtConverterContext : public std::enable_shared_from_this<TrtConverterContext> {
|
||||
public:
|
||||
explicit TrtConverterContext(FuncGraphPtr fg)
|
||||
: func_graph_(fg),
|
||||
batch_size_(1),
|
||||
workspace_size_(4UL << 30),
|
||||
builder_(nullptr),
|
||||
network_(nullptr),
|
||||
config_(nullptr),
|
||||
engine_(nullptr) {}
|
||||
~TrtConverterContext() = default;
|
||||
|
||||
bool Init();
|
||||
|
||||
// Parser KernelGraph to trt graph
|
||||
bool Parser();
|
||||
|
||||
// Serialize trt models.
|
||||
bool Serialize(std::string *model);
|
||||
|
||||
// Get trt graph inputs without weights. The inputs keep same order as binding name.
|
||||
std::vector<AnfNodePtr> GetGraphInputs();
|
||||
|
||||
// Get trt graph outputs. All outputs are flatten to vector with concret shape.
|
||||
std::vector<session::KernelWithIndex> GetGraphOutputs();
|
||||
|
||||
// Store trt layer outputs to the cache.
|
||||
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs);
|
||||
|
||||
// Get trt layer inputs from the cache.
|
||||
bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);
|
||||
|
||||
// Create and keep temporary weight, as constant folding demanding new weight excluded in graph,
|
||||
// which should release until building finish.
|
||||
std::shared_ptr<tensor::Tensor> CreateTempWeight(const TypeId &type, const std::vector<size_t> &shape);
|
||||
|
||||
std::shared_ptr<nvinfer1::INetworkDefinition> network() const { return network_; }
|
||||
|
||||
private:
|
||||
bool InitInputTable();
|
||||
bool InitValueNodeTable();
|
||||
|
||||
FuncGraphPtr func_graph_;
|
||||
uint32_t batch_size_;
|
||||
size_t workspace_size_;
|
||||
std::shared_ptr<nvinfer1::IBuilder> builder_;
|
||||
std::shared_ptr<nvinfer1::INetworkDefinition> network_;
|
||||
std::shared_ptr<nvinfer1::IBuilderConfig> config_;
|
||||
std::shared_ptr<nvinfer1::ICudaEngine> engine_;
|
||||
|
||||
// Cache (AnfNode + output_index : ILayer output).
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<size_t, LayerInput>> output_map_;
|
||||
std::vector<std::shared_ptr<tensor::Tensor>> temp_weights_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif  // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
|
|
@ -29,9 +29,9 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
class LayerInput;
|
||||
class TrtConverterHelper;
|
||||
class TrtConverterContext;
|
||||
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
|
||||
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>;
|
||||
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
|
||||
|
||||
class TrtOpFactory {
|
||||
public:
|
||||
|
@ -69,10 +69,10 @@ class TrtOpRegister {
|
|||
};
|
||||
|
||||
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
|
||||
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
|
||||
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \
|
||||
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
|
||||
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context)
|
||||
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
|
||||
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context); \
|
||||
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
|
||||
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context)
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
|
||||
|
|
Loading…
Reference in New Issue