diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h b/mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h
index a5becb8d853..5e3c815b5b5 100644
--- a/mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h
+++ b/mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
 
+#include <vector>
 #include <NvInfer.h>
 
 namespace mindspore::opt {
@@ -26,8 +27,10 @@ namespace mindspore::opt {
 class LayerInput {
  public:
   LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
-  explicit LayerInput(nvinfer1::Weights &w) : type_(InputType::kWeight), weight_(w), tensor_(nullptr) {}
-  explicit LayerInput(nvinfer1::ITensor *t) : type_(InputType::kTensor), weight_(), tensor_(t) {}
+  explicit LayerInput(nvinfer1::Weights &w, const std::vector<int64_t> &s)
+      : type_(InputType::kWeight), weight_(w), tensor_(nullptr), shape_(s) {}
+  explicit LayerInput(nvinfer1::ITensor *t, const std::vector<int64_t> &s)
+      : type_(InputType::kTensor), weight_(), tensor_(t), shape_(s) {}
 
   bool IsTensor() const { return type_ == InputType::kTensor; }
   bool IsWeight() const { return type_ == InputType::kWeight; }
@@ -48,6 +51,8 @@ class LayerInput {
     return tensor_;
   }
 
+  const std::vector<int64_t> &shape() const { return shape_; }
+
  private:
   enum class InputType : char { kUnknown = 0, kTensor, kWeight };
   InputType type_;
@@ -55,6 +60,8 @@ class LayerInput {
   nvinfer1::Weights weight_;
   // Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
   nvinfer1::ITensor *tensor_;
+  // Keep the shape of tensor or weight.
+  std::vector<int64_t> shape_;
 };
 }  // namespace mindspore::opt
diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc
index bff99837e68..edd9b66f0f4 100644
--- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc
+++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc
@@ -24,7 +24,6 @@
 #include
 #include
 #include "runtime/device/gpu/trt_loader.h"
-#include "runtime/device/gpu/cuda_driver.h"
 #include "backend/optimizer/trt_pass/trt_op_factory.h"
 #include "backend/kernel_compiler/gpu/trt/trt_utils.h"
 #include "utils/convert_utils.h"
@@ -33,105 +32,7 @@
 #include "utils/ms_context.h"
 
 namespace mindspore::opt {
-namespace {
-void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
-                              std::vector<session::KernelWithIndex> *inputs) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (node->isa<ValueNode>() || node->isa<Parameter>()) {
-    return inputs->push_back(std::make_pair(node, 0));
-  }
-
-  // Skip control node
-  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
-      AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
-    return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
-  }
-
-  // Bypass TupleGetItem
-  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
-    auto tuple_get_item = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(tuple_get_item);
-    auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
-    auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
-
-    // Conceal MakeTuple + TupleGetItem pair.
-    if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
-      auto make_tuple = input->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(make_tuple);
-      auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
-      return GetRealOutputRecursively(real_input, 0, inputs);
-    }
-
-    // Skip TupleGetItem.
-    return GetRealOutputRecursively(input, index, inputs);
-  }
-
-  // Flatten MakeTuple inputs.
-  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
-    auto make_tuple = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(make_tuple);
-    size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
-    for (size_t input_index = 0; input_index < input_num; ++input_index) {
-      auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
-      GetRealOutputRecursively(input_node, 0, inputs);
-    }
-    return;
-  }
-
-  return inputs->push_back(std::make_pair(node, output_index));
-}
-
-/* Get node real inputs bypass control nodes.
- * Examples:
- * Case 1:
- *   c = Conv2D(a, b)
- *   d = ReLU(c)
- *   result: d--> (c)
- *
- * Case 2:
- *   c = Conv2D(a, b)
- *   d = Depend(c, v)
- *   e = ReLU(d)
- *   result: d -> (c)
- *
- * Case 3:
- *   (f, g, h, i, j) = BatchNorm(a, b, c, d, e)
- *   k = TupleGetItem((f, g, h, i, j), 0)
- *   l = ReLU(k)
- *   result: l -> (f)
- *
- * Case 4:
- *   c = Conv2D(a, b)
- *   e = MakeTuple(c, d)
- *   f = TupleGetItem(e, 0)
- *   g = ReLU(k)
- *   result: g -> (c)
- *
- * Case 5:
- *   b = MakeTuple(a1, a2, a3)
- *   c = MakeTuple(b, a4)
- *   d = return(c)
- *   result d -> (a1, a2, a3, a4)
- */
-void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
-  size_t input_num = AnfAlgo::GetInputTensorNum(node);
-  for (size_t input_index = 0; input_index < input_num; ++input_index) {
-    auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
-    GetRealOutputRecursively(input_node, 0, inputs);
-  }
-}
-}  // namespace
-
 bool TrtConverterContext::Init() {
-  // Set device id before invoke trt api as cudaSetDevice is thread level config.
-  const auto &context = MsContext::GetInstance();
-  const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-  bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
-  if (!ret) {
-    MS_LOG(ERROR) << "Failed to set device id:" << device_id;
-    return false;
-  }
-
   auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
   builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
   MS_EXCEPTION_IF_NULL(builder_);
@@ -142,13 +43,13 @@ bool TrtConverterContext::Init() {
   config_ = TrtPtr(builder_->createBuilderConfig());
   MS_EXCEPTION_IF_NULL(config_);
+
+  InitInputTable();
+  InitValueNodeTable();
   return true;
 }
 
 bool TrtConverterContext::Parser() {
-  InitInputTable();
-  InitValueNodeTable();
-
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
   const auto &converter_factory = TrtOpFactory::GetInstance();
   for (auto node : node_list) {
     if (!node->isa<CNode>()) {
       continue;
     }
 
-    // Mark graph outputs
-    std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
-    if (op_name == kReturnOpName) {
-      std::vector<LayerInput> inputs;
-      (void)LoadLayerInput(node, &inputs);
-
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        const auto &input = inputs[i].tensor();
-        std::string name = "return_output_" + std::to_string(i);
-        input->setName(name.c_str());
-        network_->markOutput(*input);
-      }
-      return true;
-    }
-
     // Transform AnfNode To Trt layer.
     // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
-    if (!AnfAlgo::IsRealKernel(node)) {
+    std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
+    if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
       continue;
     }
@@ -190,8 +77,7 @@ bool TrtConverterContext::Parser() {
     }
   }
 
-  MS_LOG(ERROR) << "Graph ended without return node.";
-  return false;
+  return true;
 }
 
 bool TrtConverterContext::Serialize(std::string *model) {
@@ -233,54 +119,58 @@ bool TrtConverterContext::InitInputTable() {
       weight.values = tensor->data_c();
       weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
       weight.count = tensor->DataSize();
-      output_map_[input_node][0] = LayerInput(weight);
+      output_map_[input_node][0] = LayerInput(weight, tensor->shape());
     } else {
       nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
       nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
       nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
-      output_map_[input_node][0] = LayerInput(tensor);
+      const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
+      output_map_[input_node][0] = LayerInput(tensor, shape);
     }
   }
   return true;
 }
 
 bool TrtConverterContext::InitValueNodeTable() {
-  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_);
-  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(func_graph_);
+  const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
+  for (const auto &node : node_list) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<ValueNode>() && !IsValueNode<Primitive>(node)) {
+      auto value_node = node->cast<ValueNodePtr>();
+      auto &node_value = value_node->value();
+      MS_EXCEPTION_IF_NULL(node_value);
 
-  for (auto &value_node : kernel_graph->graph_value_nodes()) {
-    MS_EXCEPTION_IF_NULL(value_node);
-    auto &node_value = value_node->value();
-    MS_EXCEPTION_IF_NULL(node_value);
-
-    if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
-      std::vector<tensor::TensorPtr> tensors;
-      TensorValueToTensor(node_value, &tensors);
-      for (size_t i = 0; i < tensors.size(); i++) {
-        const auto &tensor = tensors[i];
-        nvinfer1::Weights weight;
-        weight.values = tensor->data_c();
-        weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
-        weight.count = tensor->DataSize();
-        output_map_[value_node][i] = LayerInput(weight);
+      if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
+        std::vector<tensor::TensorPtr> tensors;
+        TensorValueToTensor(node_value, &tensors);
+        for (size_t i = 0; i < tensors.size(); i++) {
+          const auto &tensor = tensors[i];
+          nvinfer1::Weights weight;
+          weight.values = tensor->data_c();
+          weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
+          weight.count = tensor->DataSize();
+          output_map_[value_node][i] = LayerInput(weight, tensor->shape());
+        }
       }
     }
   }
   return true;
 }
 
-bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) {
+bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
   if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
     MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
                  << ", while got: " << nv_tensors.size();
expect: " << AnfAlgo::GetOutputTensorNum(node) << ", while got: " << nv_tensors.size(); } for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) { - if (nv_tensors[tensor_index].tensor() != nullptr) { - output_map_[node][tensor_index] = nv_tensors[tensor_index]; + if (nv_tensors[tensor_index] != nullptr) { + const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions(); + const std::vector &shape = TrtUtils::TrtDimsToMsDims(dim); + output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape); std::ostringstream oss; - nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions(); oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ "; for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) { oss << dim.d[dim_index] << " "; @@ -294,7 +184,7 @@ bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::ve bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector *inputs) { std::vector real_inputs; - GetRealInputs(node, &real_inputs); + AnfAlgo::GetRealInputs(node, &real_inputs); for (auto item : real_inputs) { auto node_iter = output_map_.find(item.first); if (node_iter == output_map_.end()) { @@ -313,7 +203,7 @@ bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector TrtConverterContext::GetGraphInputs() { +std::vector TrtConverterContext::GetGraphInputs() const { // Get Anf-graph inputs without weights. All weights were binded to Trt-graph. std::unordered_map graph_inputs; for (const auto &input_node : func_graph_->parameters()) { @@ -342,9 +232,9 @@ std::vector TrtConverterContext::GetGraphInputs() { return trt_inputs; } -std::vector TrtConverterContext::GetGraphOutputs() { +std::vector TrtConverterContext::GetGraphOutputs() const { std::vector graph_outputs; - GetRealInputs(func_graph_->get_return(), &graph_outputs); + AnfAlgo::GetRealInputs(func_graph_->get_return(), &graph_outputs); return graph_outputs; } diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h index c9e51571143..a73de0d324c 100644 --- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h +++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.h @@ -50,13 +50,13 @@ class TrtConverterContext : public std::enable_shared_from_this GetGraphInputs(); + std::vector GetGraphInputs() const; // Get trt graph outputs. All outputs are flatten to vector with concret shape. - std::vector GetGraphOutputs(); + std::vector GetGraphOutputs() const; // Store trt layer outputs to the cache. - bool StoreLayerOutput(const AnfNodePtr &node, const std::vector &inputs); + bool StoreLayerOutput(const AnfNodePtr &node, const std::vector &inputs); // Get trt layer inputs from the cache. 
   bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);
diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_converter.cc b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_converter.cc
index ca416467b17..a6bf21887f1 100644
--- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_converter.cc
+++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_converter.cc
@@ -51,7 +51,7 @@ ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
   layer->setReshapeDimensions(dims);
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@@ -97,7 +97,7 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
   auto *layer = context->network()->addElementWise(*x1, *x2, op_type);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@@ -129,7 +129,7 @@ ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
     layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
   }
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@@ -144,7 +144,7 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
   auto *layer = context->network()->addActivation(*inputs[0].tensor(), act_type);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@@ -159,7 +159,7 @@ ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
   auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@@ -191,7 +191,7 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
                  << node->DebugString();
-    return {true, {LayerInput(inputs[0].tensor())}};
+    return {true, {inputs[0].tensor()}};
   }
 
   bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
@@ -215,10 +215,10 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
     reshape_layer->setReshapeDimensions(dim);
-    return {true, {LayerInput(reshape_layer->getOutput(0))}};
+    return {true, {reshape_layer->getOutput(0)}};
   }
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 }  // namespace
@@ -265,7 +265,10 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2D) {
     layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
   }
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  const auto &group = AnfAlgo::GetNodeAttr(node, "group");
+  layer->setNbGroups(SizeToInt(group));
+
+  return {true, {layer->getOutput(0)}};
 }
 
 // Binary broadcast operators.
@@ -367,7 +370,7 @@ MS_TRT_CONVERTER_FUNC_REG(GeLU) {
   layer = context->network()->addElementWise(*c1, *layer->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(MatMul) {
@@ -402,7 +405,7 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
     auto *squeeze_y = context->network()->addShuffle(*layer->getOutput(0));
     squeeze_y->setReshapeDimensions(y_dims);
 
-    return {true, {LayerInput(squeeze_y->getOutput(0))}};
+    return {true, {squeeze_y->getOutput(0)}};
   } else {
     // convert weight to tensor and appy addMatrixMultiply
     MS_LOG(ERROR) << "Operator not implemented: " << node->DebugString();
@@ -446,7 +449,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
   auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
@@ -477,13 +480,14 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
   auto *layer = context->network()->addElementWise(*inputs[0].tensor(), *bias, nvinfer1::ElementWiseOperation::kSUM);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 // NoOp
 MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
 MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
 MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
+MS_TRT_CONVERTER_FUNC_REG(Flatten) { return AddReshapeLayer(node, context); }
 
 MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
   std::vector<LayerInput> inputs;
@@ -537,7 +541,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
   auto *layer = context->network()->addScale(*inputs[0].tensor(), nvinfer1::ScaleMode::kCHANNEL, shift, scale, pow);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Concat) {
@@ -567,7 +571,7 @@ MS_TRT_CONVERTER_FUNC_REG(Concat) {
   }
   layer->setAxis(axis);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
@@ -608,7 +612,7 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
     layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
   }
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Slice) {
@@ -632,7 +636,7 @@ MS_TRT_CONVERTER_FUNC_REG(Slice) {
   auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Transpose) {
@@ -653,7 +657,7 @@ MS_TRT_CONVERTER_FUNC_REG(Transpose) {
   MS_EXCEPTION_IF_NULL(layer);
   layer->setFirstTranspose(trt_perm);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Softmax) {
@@ -684,7 +688,7 @@ MS_TRT_CONVERTER_FUNC_REG(Softmax) {
   auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
   MS_EXCEPTION_IF_NULL(layer);
   layer->setAxes(reduce_axes);
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
@@ -707,7 +711,7 @@ MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
   auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
   MS_EXCEPTION_IF_NULL(log_layer);
 
-  return {true, {LayerInput(log_layer->getOutput(0))}};
+  return {true, {log_layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Gather) {
@@ -729,7 +733,7 @@ MS_TRT_CONVERTER_FUNC_REG(Gather) {
   auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
   MS_EXCEPTION_IF_NULL(layer);
 
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(Cast) {
@@ -744,7 +748,7 @@ MS_TRT_CONVERTER_FUNC_REG(Cast) {
   auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
   auto *layer = context->network()->addIdentity(*inputs[0].tensor());
   layer->setOutputType(0, trt_type);
-  return {true, {LayerInput(layer->getOutput(0))}};
+  return {true, {layer->getOutput(0)}};
 }
 
 MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
@@ -810,7 +814,33 @@ MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
                                              nvinfer1::ElementWiseOperation::kSUM);
   MS_EXCEPTION_IF_NULL(add);
 
-  return {true, {LayerInput(add->getOutput(0))}};
+  return {true, {add->getOutput(0)}};
+}
+
+MS_TRT_CONVERTER_FUNC_REG(Return) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret) {
+    return {false, {}};
+  }
+
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    nvinfer1::ITensor *input = nullptr;
+    if (inputs[i].IsTensor()) {
+      input = inputs[i].tensor();
+    } else {
+      std::vector<size_t> shape;
+      std::transform(inputs[i].shape().begin(), inputs[i].shape().end(), std::back_inserter(shape),
+                     [](int64_t d) { return LongToSize(d); });
+      input = ToTensor(&inputs[i], shape, context);
+    }
+
+    const std::string &name = "return_output_" + std::to_string(i);
+    input->setName(name.c_str());
+    context->network()->markOutput(*input);
+  }
+
+  return {true, {}};
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h
index 498c6479097..5cc8d8da88a 100644
--- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h
+++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include <NvInfer.h>
 #include "base/base.h"
 #include "ir/anf.h"
 
@@ -30,7 +31,7 @@ namespace mindspore {
 namespace opt {
 class LayerInput;
 class TrtConverterContext;
-using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
+using ConvertResult = std::pair<bool, std::vector<nvinfer1::ITensor *>>;
 using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
 
 class TrtOpFactory {