trt converter

This commit is contained in:
wilfChen 2021-04-02 17:31:07 +08:00
parent 870c799cb6
commit 943b992458
6 changed files with 500 additions and 81 deletions

View File

@ -68,7 +68,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {
return true;
}
TrtKernel::ReleaseResource() {
void TrtKernel::ReleaseResource() {
// Make sure destroy trt object before TrtLoader destruct.
context_.reset();
engine_.reset();

View File

@ -50,6 +50,33 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) {
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const auto &n = node->cast<CNodePtr>()->input(2);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam);
break;
}
}
return adam;
}
} // namespace
const BaseRef AdamFusion::DefinePattern() const {
@ -118,51 +145,19 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
// Fused into a FusedAdam operator.
auto prim = std::make_shared<Primitive>(kFusedAdamName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input};
auto prim_value = NewValueNode(prim);
std::vector<AnfNodePtr> inputs = {
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
m_input, v_input, gradient_input};
auto adam = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
adam->set_scope(node->scope());
auto build_info = GenerateKernelBuildInfo(adam);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam);
break;
}
}
return adam;
return RelpaceOutputEdge(node, adam, u_input);
}
} // namespace opt
} // namespace mindspore

View File

@ -50,6 +50,34 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
builder.SetOutputsFormat(outputs_format);
return builder.Build();
}
AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, AnfNodePtr u_input) {
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
const auto &n = node->cast<CNodePtr>()->input(2);
MS_EXCEPTION_IF_NULL(n);
const auto &fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
break;
}
}
return adam_weight_decay;
}
} // namespace
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
@ -122,18 +150,10 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
// Fused into a FusedAdamWeightDecay operator.
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input,
weight_decay_input};
auto prim_value = NewValueNode(prim);
std::vector<AnfNodePtr> inputs = {
prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param,
m_input, v_input, gradient_input, weight_decay_input};
auto adam_weight_decay = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam_weight_decay);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
@ -143,31 +163,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
break;
}
}
return adam_weight_decay;
return ReplaceOutputEdge(node, adam_weight_decay, u_input);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,339 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/trt_pass/trt_converter_context.h"
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <sstream>
#include <algorithm>
#include "runtime/device/gpu/trt_loader.h"
#include "backend/optimizer/trt_pass/trt_op_factory.h"
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
#include "utils/convert_utils.h"
#include "utils/utils.h"
#include "utils/singleton.h"
namespace mindspore::opt {
namespace {
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
std::vector<session::KernelWithIndex> *inputs) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return inputs->push_back(std::make_pair(node, 0));
}
// Skip control node
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
return GetRealOutputRecursive(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
}
// Bypass TupleGetItem
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto tuple_get_item = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_get_item);
auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
// Conceal MakeTuple + TupleGetItem pair.
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
auto make_tuple = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
return GetRealOutputRecursive(real_input, 0, inputs);
}
// Skip TupleGetItem.
return GetRealOutputRecursive(input, index, inputs);
}
// Flatten MakeTuple inputs.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
GetRealOutputRecursive(input_node, 0, inputs);
}
return;
}
return inputs->push_back(std::make_pair(node, output_index));
}
/* Get node real inputs bypass control nodes.
* Examples:
* Case 1:
* c = Conv2D(a, b)
* d = ReLU(c)
* result: d--> (c)
*
* Case 2:
* c = Conv2D(a, b)
* d = Depend(c, v)
* e = ReLU(d)
* result: d -> (c)
*
* Case 3:
* (f, g, h, i, j) = BatchNorm(a, b, c, d, e)
* k = TupleGetItem((f, g, h, i, j), 0)
* l = ReLU(k)
* result: l -> (f)
*
* Case 4:
* c = Conv2D(a, b)
* e = MakeTuple(c, d)
* f = TupleGetItem(e, 0)
* g = ReLU(k)
* result: g -> (c)
*
* Case 5:
* b = MakeTuple(a1, a2, a3)
* c = MakeTuple(b, a4)
* d = return(c)
* result d -> (a1, a2, a3, a4)
*/
void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
GetRealOutputRecursively(input_node, 0, inputs);
}
}
} // namespace
bool TrtConverterContext::Init() {
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
MS_EXCEPTION_IF_NULL(builder_);
auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
network_ = TrtPtr(builder_->createNetworkV2(batch_type));
MS_EXCEPTION_IF_NULL(network_);
config_ = TrtPtr(builder_->createBuilderConfig());
MS_EXCEPTION_IF_NULL(config_);
return true;
}
bool TrtConverterContext::Parser() {
InitInputTable();
InitValueNodeTable();
std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
const auto &converter_factory = TrtOpFactory::GetInstance();
for (auto node : node_list) {
if (!node->isa<CNode>()) {
continue;
}
// Mark graph outputs
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
if (op_name == kReturnOpName) {
std::vector<LayerInput> inputs;
(void)LoadLayerInput(node, &inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
const auto &input = inputs[i].tensor();
std::string name = "return_output_" + std::to_string(i);
input->setName(name.c_str());
network_->markOutput(*input);
}
return true;
}
// Transform AnfNode To Trt layer.
// Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
if (!AnfAlgo::IsRealKernel(node)) {
continue;
}
ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
auto result = convert_func(node, this->shared_from_this());
if (!result.first) {
MS_LOG(ERROR) << op_name << " converter failed.";
return false;
}
auto ret = StoreLayerOutput(node, result.second);
if (!ret) {
MS_LOG(ERROR) << op_name << " converter failed.";
return false;
}
}
MS_LOG(ERROR) << "Graph ended without return node.";
return false;
}
bool TrtConverterContext::Serialize(std::string *model) {
MS_EXCEPTION_IF_NULL(model);
builder_->setMaxBatchSize(batch_size_);
config_->setMaxWorkspaceSize(workspace_size_);
engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
MS_EXCEPTION_IF_NULL(engine_);
std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
*model = string(static_cast<const char *>(model_data->data()), model_data->size());
return true;
}
bool TrtConverterContext::InitInputTable() {
const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
for (auto input_node : graph_inputs) {
if (!input_node->isa<Parameter>()) {
continue;
}
auto input = input_node->cast<ParameterPtr>();
if (AnfAlgo::IsParameterWeight(input)) {
const auto &param_value = input->default_param();
MS_EXCEPTION_IF_NULL(param_value);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
MS_EXCEPTION_IF_NULL(tensor);
nvinfer1::Weights weight;
weight.values = tensor->data_c();
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
weight.count = tensor->DataSize();
output_map_[input_node][0] = LayerInput(weight);
} else {
nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
output_map_[input_node][0] = LayerInput(tensor);
}
}
return true;
}
bool TrtConverterContext::InitValueNodeTable() {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_);
MS_EXCEPTION_IF_NULL(kernel_graph);
for (auto &value_node : kernel_graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
nvinfer1::Weights weight;
weight.values = tensor->data_c();
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
weight.count = tensor->DataSize();
output_map_[value_node][i] = LayerInput(weight);
}
}
}
return true;
}
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) {
if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
<< ", while got: " << nv_tensors.size();
}
for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
if (nv_tensors[tensor_index].tensor() != nullptr) {
output_map_[node][tensor_index] = nv_tensors[tensor_index];
std::ostringstream oss;
nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions();
oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
oss << dim.d[dim_index] << " ";
}
oss << "]";
MS_LOG(INFO) << oss.str();
}
}
return true;
}
bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
std::vector<session::KernelWithIndex> real_inputs;
GetRealInputs(node, &real_inputs);
for (auto item : real_inputs) {
auto node_iter = output_map_.find(item.first);
if (node_iter == output_map_.end()) {
MS_LOG(ERROR) << "node: " << node->DebugString() << " not found.";
return false;
}
auto out_iter = node_iter->second.find(item.second);
if (out_iter == node_iter->second.end()) {
MS_LOG(ERROR) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
return false;
}
inputs->push_back(out_iter->second);
}
return true;
}
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
// Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
std::unordered_map<std::string, AnfNodePtr> graph_inputs;
for (const auto &input_node : func_graph_->parameters()) {
if (!input_node->isa<Parameter>()) {
continue;
}
auto input = input_node->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(input)) {
graph_inputs.insert(std::make_pair(input->name(), input_node));
}
}
// Keep the graph inputs in order of the binding name.
std::vector<AnfNodePtr> trt_inputs;
for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
if (!engine_->bindingIsInput(i)) {
continue;
}
auto iter = graph_inputs.find(engine_->getBindingName(i));
if (iter == graph_inputs.end()) {
MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
}
trt_inputs.push_back(iter->second);
}
return trt_inputs;
}
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() {
std::vector<session::KernelWithIndex> graph_outputs;
GetRealInputs(func_graph_->get_return(), &graph_outputs);
return graph_outputs;
}
std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type,
const std::vector<size_t> &shape) {
ShapeVector shape_int;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong);
auto tensor = std::make_shared<tensor::Tensor>(type, shape_int);
temp_weights_.push_back(tensor);
return tensor;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
#include <unordered_map>
#include <vector>
#include <string>
#include <memory>
#include <NvInfer.h>
#include "base/base.h"
#include "ir/anf.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/trt_pass/layer_input.h"
namespace mindspore {
namespace opt {
class TrtConverterContext : public std::enable_shared_from_this<TrtConverterContext> {
public:
explicit TrtConverterContext(FuncGraphPtr fg)
: func_graph_(fg),
batch_size_(1),
workspace_size_(4UL << 30),
builder_(nullptr),
network_(nullptr),
config_(nullptr),
engine_(nullptr) {}
~TrtConverterContext() = default;
bool Init();
// Parser KernelGraph to trt graph
bool Parser();
// Serialize trt models.
bool Serialize(std::string *model);
// Get trt graph inputs without weights. The inputs keep same order as binding name.
std::vector<AnfNodePtr> GetGraphInputs();
// Get trt graph outputs. All outputs are flatten to vector with concret shape.
std::vector<session::KernelWithIndex> GetGraphOutputs();
// Store trt layer outputs to the cache.
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs);
// Get trt layer inputs from the cache.
bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);
// Create and keep temporary weight, as constant folding demanding new weight excluded in graph,
// which should release until building finish.
std::shared_ptr<tensor::Tensor> CreateTempWeight(const TypeId &type, const std::vector<size_t> &shape);
std::shared_ptr<nvinfer1::INetworkDefinition> network() const { return network_; }
private:
bool InitInputTable();
bool InitValueNodeTable();
FuncGraphPtr func_graph_;
uint32_t batch_size_;
size_t workspace_size_;
std::shared_ptr<nvinfer1::IBuilder> builder_;
std::shared_ptr<nvinfer1::INetworkDefinition> network_;
std::shared_ptr<nvinfer1::IBuilderConfig> config_;
std::shared_ptr<nvinfer1::ICudaEngine> engine_;
// Cache (AnfNode + output_index : ILayer output).
std::unordered_map<AnfNodePtr, std::unordered_map<size_t, LayerInput>> output_map_;
std::vector<std::shared_ptr<tensor::Tensor>> temp_weights_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_HELPER_H_

View File

@ -29,9 +29,9 @@
namespace mindspore {
namespace opt {
class LayerInput;
class TrtConverterHelper;
class TrtConverterContext;
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
class TrtOpFactory {
public:
@ -69,10 +69,10 @@ class TrtOpRegister {
};
// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context)
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context)
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_