diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc index c416d1f318..6d24c02760 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc @@ -68,7 +68,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) { return true; } -TrtKernel::ReleaseResource() { +void TrtKernel::ReleaseResource() { // Make sure destroy trt object before TrtLoader destruct. context_.reset(); engine_.reset(); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc index 9caedbe23c..b28e95b2d5 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -50,6 +50,33 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { builder.SetOutputsFormat(outputs_format); return builder.Build(); } + +AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) { + // Replace the parameters of the last UpdateState to maintain + // the execution order of FusedAdam and the following operators. 
+ // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} + const auto &n = node->cast()->input(2); + MS_EXCEPTION_IF_NULL(n); + const auto &fg = n->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto mgr = fg->manager(); + MS_EXCEPTION_IF_NULL(mgr); + auto &node_users = mgr->node_users(); + auto iter = node_users.find(n); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); + } + + auto &users = iter->second; + for (auto &user : users) { + if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { + (user.first)->cast()->set_input(1, u_input); + (user.first)->cast()->set_input(2, adam); + break; + } + } + return adam; +} } // namespace const BaseRef AdamFusion::DefinePattern() const { @@ -118,51 +145,19 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr // Fused into a FusedAdam operator. auto prim = std::make_shared(kFusedAdamName); MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), - beta1_input, - one_sub_beta1_input, - beta2_input, - one_sub_beta2_input, - eps_input, - lr_input, - param, - m_input, - v_input, - gradient_input}; + auto prim_value = NewValueNode(prim); + std::vector inputs = { + prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param, + m_input, v_input, gradient_input}; auto adam = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(adam); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); adam->set_scope(node->scope()); - auto build_info = GenerateKernelBuildInfo(adam); AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); - - // Replace the parameters of the last UpdateState to maintain - // the execution order of FusedAdam and the following operators. 
- // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} - auto n = node->cast()->input(2); - auto fg = n->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto mgr = fg->manager(); - MS_EXCEPTION_IF_NULL(mgr); - auto &node_users = mgr->node_users(); - auto iter = node_users.find(n); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); - } - - auto &users = iter->second; - for (auto &user : users) { - if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { - (user.first)->cast()->set_input(1, u_input); - (user.first)->cast()->set_input(2, adam); - break; - } - } - - return adam; + return RelpaceOutputEdge(node, adam, u_input); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc index 0a2924790c..ee3a7dffc5 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc @@ -50,6 +50,34 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { builder.SetOutputsFormat(outputs_format); return builder.Build(); } + +AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, AnfNodePtr u_input) { + // Replace the parameters of the last UpdateState to maintain + // the execution order of FusedAdamWeightDecay and the following operators. 
+ // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} + const auto &n = node->cast()->input(2); + MS_EXCEPTION_IF_NULL(n); + const auto &fg = n->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto mgr = fg->manager(); + MS_EXCEPTION_IF_NULL(mgr); + auto &node_users = mgr->node_users(); + auto iter = node_users.find(n); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); + } + + auto &users = iter->second; + for (auto &user : users) { + if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { + (user.first)->cast()->set_input(1, u_input); + (user.first)->cast()->set_input(2, adam_weight_decay); + break; + } + } + + return adam_weight_decay; +} } // namespace const BaseRef AdamWeightDecayFusion::DefinePattern() const { @@ -122,18 +150,10 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const // Fused into a FusedAdamWeightDecay operator. auto prim = std::make_shared(kFusedAdamWeightDecayName); MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), - beta1_input, - one_sub_beta1_input, - beta2_input, - one_sub_beta2_input, - eps_input, - lr_input, - param, - m_input, - v_input, - gradient_input, - weight_decay_input}; + auto prim_value = NewValueNode(prim); + std::vector inputs = { + prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param, + m_input, v_input, gradient_input, weight_decay_input}; auto adam_weight_decay = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(adam_weight_decay); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; @@ -143,31 +163,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const auto build_info = GenerateKernelBuildInfo(adam_weight_decay); AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); - - // Replace the parameters of the last UpdateState to maintain - // the execution order of FusedAdamWeightDecay and the 
following operators. - // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} - auto n = node->cast()->input(2); - auto fg = n->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto mgr = fg->manager(); - MS_EXCEPTION_IF_NULL(mgr); - auto &node_users = mgr->node_users(); - auto iter = node_users.find(n); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); - } - - auto &users = iter->second; - for (auto &user : users) { - if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { - (user.first)->cast()->set_input(1, u_input); - (user.first)->cast()->set_input(2, adam_weight_decay); - break; - } - } - - return adam_weight_decay; + return ReplaceOutputEdge(node, adam_weight_decay, u_input); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc new file mode 100644 index 0000000000..730efb9cb6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc @@ -0,0 +1,339 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/trt_pass/trt_converter_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include "runtime/device/gpu/trt_loader.h" +#include "backend/optimizer/trt_pass/trt_op_factory.h" +#include "backend/kernel_compiler/gpu/trt/trt_utils.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" +#include "utils/singleton.h" + +namespace mindspore::opt { +namespace { +void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, + std::vector *inputs) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa() || node->isa()) { + return inputs->push_back(std::make_pair(node, 0)); + } + + // Skip control node + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { + return GetRealOutputRecursive(node->cast()->input(kRealInputIndexInDepend), 0, inputs); + } + + // Bypass TupleGetItem + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + auto tuple_get_item = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_get_item); + auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item); + auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item); + + // Conceal MakeTuple + TupleGetItem pair. + if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { + auto make_tuple = input->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + auto real_input = AnfAlgo::GetInputNode(make_tuple, index); + return GetRealOutputRecursive(real_input, 0, inputs); + } + + // Skip TupleGetItem. + return GetRealOutputRecursive(input, index, inputs); + } + + // Flatten MakeTuple inputs. 
+ if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + auto make_tuple = node->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index); + GetRealOutputRecursive(input_node, 0, inputs); + } + return; + } + + return inputs->push_back(std::make_pair(node, output_index)); +} + +/* Get node real inputs, bypassing control nodes. + * Examples: + * Case 1: + * c = Conv2D(a, b) + * d = ReLU(c) + * result: d -> (c) + * + * Case 2: + * c = Conv2D(a, b) + * d = Depend(c, v) + * e = ReLU(d) + * result: e -> (c) + * + * Case 3: + * (f, g, h, i, j) = BatchNorm(a, b, c, d, e) + * k = TupleGetItem((f, g, h, i, j), 0) + * l = ReLU(k) + * result: l -> (f) + * + * Case 4: + * c = Conv2D(a, b) + * e = MakeTuple(c, d) + * f = TupleGetItem(e, 0) + * g = ReLU(f) + * result: g -> (c) + * + * Case 5: + * b = MakeTuple(a1, a2, a3) + * c = MakeTuple(b, a4) + * d = return(c) + * result: d -> (a1, a2, a3, a4) + */ +void GetRealInputs(const AnfNodePtr &node, std::vector *inputs) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + auto input_node = AnfAlgo::GetInputNode(node->cast(), input_index); + GetRealOutputRecursively(input_node, 0, inputs); + } +} +} // namespace + +bool TrtConverterContext::Init() { + auto trt_loader = Singleton::Instance(); + builder_ = trt_loader.CreateInferBuilder(&Singleton::Instance()); + MS_EXCEPTION_IF_NULL(builder_); + + auto batch_type = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + network_ = TrtPtr(builder_->createNetworkV2(batch_type)); + MS_EXCEPTION_IF_NULL(network_); + + config_ = TrtPtr(builder_->createBuilderConfig()); + MS_EXCEPTION_IF_NULL(config_); + return true; +} + +bool TrtConverterContext::Parser() { + InitInputTable(); + InitValueNodeTable(); 
+ + std::vector node_list = TopoSort(func_graph_->get_return()); + const auto &converter_factory = TrtOpFactory::GetInstance(); + for (auto node : node_list) { + if (!node->isa()) { + continue; + } + + // Mark graph outputs + std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name(); + if (op_name == kReturnOpName) { + std::vector inputs; + (void)LoadLayerInput(node, &inputs); + + for (size_t i = 0; i < inputs.size(); ++i) { + const auto &input = inputs[i].tensor(); + std::string name = "return_output_" + std::to_string(i); + input->setName(name.c_str()); + network_->markOutput(*input); + } + return true; + } + + // Transform AnfNode To Trt layer. + // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple. + if (!AnfAlgo::IsRealKernel(node)) { + continue; + } + + ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name); + auto result = convert_func(node, this->shared_from_this()); + if (!result.first) { + MS_LOG(ERROR) << op_name << " converter failed."; + return false; + } + auto ret = StoreLayerOutput(node, result.second); + if (!ret) { + MS_LOG(ERROR) << op_name << " converter failed."; + return false; + } + } + + MS_LOG(ERROR) << "Graph ended without return node."; + return false; +} + +bool TrtConverterContext::Serialize(std::string *model) { + MS_EXCEPTION_IF_NULL(model); + builder_->setMaxBatchSize(batch_size_); + config_->setMaxWorkspaceSize(workspace_size_); + engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_)); + MS_EXCEPTION_IF_NULL(engine_); + + std::shared_ptr model_data = TrtPtr(engine_->serialize()); + *model = string(static_cast(model_data->data()), model_data->size()); + return true; +} + +bool TrtConverterContext::InitInputTable() { + const std::vector graph_inputs = func_graph_->parameters(); + for (auto input_node : graph_inputs) { + if (!input_node->isa()) { + continue; + } + + auto input = input_node->cast(); + if (AnfAlgo::IsParameterWeight(input)) { + const auto &param_value = 
input->default_param(); + MS_EXCEPTION_IF_NULL(param_value); + auto tensor = std::dynamic_pointer_cast(param_value); + MS_EXCEPTION_IF_NULL(tensor); + + nvinfer1::Weights weight; + weight.values = tensor->data_c(); + weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type()); + weight.count = tensor->DataSize(); + output_map_[input_node][0] = LayerInput(weight); + } else { + nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0)); + nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false); + nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims); + output_map_[input_node][0] = LayerInput(tensor); + } + } + return true; +} + +bool TrtConverterContext::InitValueNodeTable() { + auto kernel_graph = std::dynamic_pointer_cast(func_graph_); + MS_EXCEPTION_IF_NULL(kernel_graph); + + for (auto &value_node : kernel_graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + + if (node_value->isa() || node_value->isa()) { + std::vector tensors; + TensorValueToTensor(node_value, &tensors); + for (size_t i = 0; i < tensors.size(); i++) { + const auto &tensor = tensors[i]; + nvinfer1::Weights weight; + weight.values = tensor->data_c(); + weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type()); + weight.count = tensor->DataSize(); + output_map_[value_node][i] = LayerInput(weight); + } + } + } + return true; +} + +bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector &nv_tensors) { + if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) { + MS_LOG(INFO) << node->DebugString() << " output num not match. 
expect: " << AnfAlgo::GetOutputTensorNum(node) + << ", while got: " << nv_tensors.size(); + } + + for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) { + if (nv_tensors[tensor_index].tensor() != nullptr) { + output_map_[node][tensor_index] = nv_tensors[tensor_index]; + + std::ostringstream oss; + nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions(); + oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ "; + for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) { + oss << dim.d[dim_index] << " "; + } + oss << "]"; + MS_LOG(INFO) << oss.str(); + } + } + return true; +} + +bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector *inputs) { + std::vector real_inputs; + GetRealInputs(node, &real_inputs); + for (auto item : real_inputs) { + auto node_iter = output_map_.find(item.first); + if (node_iter == output_map_.end()) { + MS_LOG(ERROR) << "node: " << node->DebugString() << " not found."; + return false; + } + + auto out_iter = node_iter->second.find(item.second); + if (out_iter == node_iter->second.end()) { + MS_LOG(ERROR) << "node: " << node->DebugString() << "output index: " << item.second << " not found."; + return false; + } + + inputs->push_back(out_iter->second); + } + return true; +} + +std::vector TrtConverterContext::GetGraphInputs() { + // Get Anf-graph inputs without weights. All weights were binded to Trt-graph. + std::unordered_map graph_inputs; + for (const auto &input_node : func_graph_->parameters()) { + if (!input_node->isa()) { + continue; + } + + auto input = input_node->cast(); + if (!AnfAlgo::IsParameterWeight(input)) { + graph_inputs.insert(std::make_pair(input->name(), input_node)); + } + } + + // Keep the graph inputs in order of the binding name. 
+ std::vector trt_inputs; + for (int32_t i = 0; i < engine_->getNbBindings(); ++i) { + if (!engine_->bindingIsInput(i)) { + continue; + } + auto iter = graph_inputs.find(engine_->getBindingName(i)); + if (iter == graph_inputs.end()) { + MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i); + } + trt_inputs.push_back(iter->second); + } + return trt_inputs; +} + +std::vector TrtConverterContext::GetGraphOutputs() { + std::vector graph_outputs; + GetRealInputs(func_graph_->get_return(), &graph_outputs); + return graph_outputs; +} + +std::shared_ptr TrtConverterContext::CreateTempWeight(const TypeId &type, + const std::vector &shape) { + ShapeVector shape_int; + std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong); + auto tensor = std::make_shared(type, shape_int); + temp_weights_.push_back(tensor); + return tensor; +} +} // namespace mindspore::opt diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h new file mode 100644 index 0000000000..c9e5157114 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/anf.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/trt_pass/layer_input.h" + +namespace mindspore { +namespace opt { +class TrtConverterContext : public std::enable_shared_from_this { + public: + explicit TrtConverterContext(FuncGraphPtr fg) + : func_graph_(fg), + batch_size_(1), + workspace_size_(4UL << 30), + builder_(nullptr), + network_(nullptr), + config_(nullptr), + engine_(nullptr) {} + ~TrtConverterContext() = default; + + bool Init(); + + // Parse KernelGraph into a trt graph. + bool Parser(); + + // Serialize trt models. + bool Serialize(std::string *model); + + // Get trt graph inputs without weights. The inputs keep the same order as the binding names. + std::vector GetGraphInputs(); + + // Get trt graph outputs. All outputs are flattened into a vector with concrete shapes. + std::vector GetGraphOutputs(); + + // Store trt layer outputs to the cache. + bool StoreLayerOutput(const AnfNodePtr &node, const std::vector &inputs); + + // Get trt layer inputs from the cache. + bool LoadLayerInput(const AnfNodePtr &node, std::vector *inputs); + + // Create and keep a temporary weight, as constant folding demands new weights that are not part of the graph; + // they should not be released until building finishes. + std::shared_ptr CreateTempWeight(const TypeId &type, const std::vector &shape); + + std::shared_ptr network() const { return network_; } + + private: + bool InitInputTable(); + bool InitValueNodeTable(); + + FuncGraphPtr func_graph_; + uint32_t batch_size_; + size_t workspace_size_; + std::shared_ptr builder_; + std::shared_ptr network_; + std::shared_ptr config_; + std::shared_ptr engine_; + + // Cache (AnfNode + output_index : ILayer output). 
+ std::unordered_map> output_map_; + std::vector> temp_weights_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h index c42f11bf8e..87febe09c8 100644 --- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h +++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h @@ -29,9 +29,9 @@ namespace mindspore { namespace opt { class LayerInput; -class TrtConverterHelper; +class TrtConverterContext; using ConvertResult = std::pair>; -using ConvertFunc = std::function)>; +using ConvertFunc = std::function)>; class TrtOpFactory { public: @@ -69,10 +69,10 @@ class TrtOpRegister { }; // Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition. -#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \ - ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr context); \ - static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \ - ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr context) +#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \ + ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr context); \ + static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \ + ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr context) } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_