keep shape in class LayerInput

This commit is contained in:
wilfChen 2021-05-26 11:22:32 +08:00
parent 1e8868cecc
commit c490abd459
5 changed files with 104 additions and 176 deletions

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
#include <vector>
#include <NvInfer.h>
namespace mindspore::opt {
@ -26,8 +27,10 @@ namespace mindspore::opt {
class LayerInput {
public:
LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
explicit LayerInput(nvinfer1::Weights &w) : type_(InputType::kWeight), weight_(w), tensor_(nullptr) {}
explicit LayerInput(nvinfer1::ITensor *t) : type_(InputType::kTensor), weight_(), tensor_(t) {}
explicit LayerInput(nvinfer1::Weights &w, const std::vector<int64_t> &s)
: type_(InputType::kWeight), weight_(w), tensor_(nullptr), shape_(s) {}
explicit LayerInput(nvinfer1::ITensor *t, const std::vector<int64_t> &s)
: type_(InputType::kTensor), weight_(), tensor_(t), shape_(s) {}
bool IsTensor() const { return type_ == InputType::kTensor; }
bool IsWeight() const { return type_ == InputType::kWeight; }
@ -48,6 +51,8 @@ class LayerInput {
return tensor_;
}
const std::vector<int64_t> &shape() const { return shape_; }
private:
enum class InputType : char { kUnknown = 0, kTensor, kWeight };
InputType type_;
@ -55,6 +60,8 @@ class LayerInput {
nvinfer1::Weights weight_;
// Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
nvinfer1::ITensor *tensor_;
// Keep the shape of tensor or weight.
std::vector<int64_t> shape_;
};
} // namespace mindspore::opt

View File

@ -24,7 +24,6 @@
#include <sstream>
#include <algorithm>
#include "runtime/device/gpu/trt_loader.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "backend/optimizer/trt_pass/trt_op_factory.h"
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
#include "utils/convert_utils.h"
@ -33,105 +32,7 @@
#include "utils/ms_context.h"
namespace mindspore::opt {
namespace {
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
std::vector<session::KernelWithIndex> *inputs) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return inputs->push_back(std::make_pair(node, 0));
}
// Skip control node
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
}
// Bypass TupleGetItem
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto tuple_get_item = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_get_item);
auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
// Conceal MakeTuple + TupleGetItem pair.
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
auto make_tuple = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
return GetRealOutputRecursively(real_input, 0, inputs);
}
// Skip TupleGetItem.
return GetRealOutputRecursively(input, index, inputs);
}
// Flatten MakeTuple inputs.
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
auto make_tuple = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
GetRealOutputRecursively(input_node, 0, inputs);
}
return;
}
return inputs->push_back(std::make_pair(node, output_index));
}
/* Get node real inputs bypass control nodes.
* Examples:
* Case 1:
* c = Conv2D(a, b)
* d = ReLU(c)
* result: d--> (c)
*
* Case 2:
* c = Conv2D(a, b)
* d = Depend(c, v)
* e = ReLU(d)
* result: d -> (c)
*
* Case 3:
* (f, g, h, i, j) = BatchNorm(a, b, c, d, e)
* k = TupleGetItem((f, g, h, i, j), 0)
* l = ReLU(k)
* result: l -> (f)
*
* Case 4:
* c = Conv2D(a, b)
* e = MakeTuple(c, d)
* f = TupleGetItem(e, 0)
* g = ReLU(k)
* result: g -> (c)
*
* Case 5:
* b = MakeTuple(a1, a2, a3)
* c = MakeTuple(b, a4)
* d = return(c)
* result d -> (a1, a2, a3, a4)
*/
void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
GetRealOutputRecursively(input_node, 0, inputs);
}
}
} // namespace
bool TrtConverterContext::Init() {
// Set device id before invoke trt api as cudaSetDevice is thread level config.
const auto &context = MsContext::GetInstance();
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
if (!ret) {
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
return false;
}
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
MS_EXCEPTION_IF_NULL(builder_);
@ -142,13 +43,13 @@ bool TrtConverterContext::Init() {
config_ = TrtPtr(builder_->createBuilderConfig());
MS_EXCEPTION_IF_NULL(config_);
InitInputTable();
InitValueNodeTable();
return true;
}
bool TrtConverterContext::Parser() {
InitInputTable();
InitValueNodeTable();
std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
const auto &converter_factory = TrtOpFactory::GetInstance();
for (auto node : node_list) {
@ -156,24 +57,10 @@ bool TrtConverterContext::Parser() {
continue;
}
// Mark graph outputs
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
if (op_name == kReturnOpName) {
std::vector<LayerInput> inputs;
(void)LoadLayerInput(node, &inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
const auto &input = inputs[i].tensor();
std::string name = "return_output_" + std::to_string(i);
input->setName(name.c_str());
network_->markOutput(*input);
}
return true;
}
// Transform AnfNode To Trt layer.
// Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
if (!AnfAlgo::IsRealKernel(node)) {
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
continue;
}
@ -190,8 +77,7 @@ bool TrtConverterContext::Parser() {
}
}
MS_LOG(ERROR) << "Graph ended without return node.";
return false;
return true;
}
bool TrtConverterContext::Serialize(std::string *model) {
@ -233,54 +119,58 @@ bool TrtConverterContext::InitInputTable() {
weight.values = tensor->data_c();
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
weight.count = tensor->DataSize();
output_map_[input_node][0] = LayerInput(weight);
output_map_[input_node][0] = LayerInput(weight, tensor->shape());
} else {
nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
output_map_[input_node][0] = LayerInput(tensor);
const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
output_map_[input_node][0] = LayerInput(tensor, shape);
}
}
return true;
}
bool TrtConverterContext::InitValueNodeTable() {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(func_graph_);
const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
auto value_node = node->cast<ValueNodePtr>();
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
for (auto &value_node : kernel_graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
nvinfer1::Weights weight;
weight.values = tensor->data_c();
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
weight.count = tensor->DataSize();
output_map_[value_node][i] = LayerInput(weight);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t i = 0; i < tensors.size(); i++) {
const auto &tensor = tensors[i];
nvinfer1::Weights weight;
weight.values = tensor->data_c();
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
weight.count = tensor->DataSize();
output_map_[value_node][i] = LayerInput(weight, tensor->shape());
}
}
}
}
return true;
}
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) {
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
<< ", while got: " << nv_tensors.size();
}
for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
if (nv_tensors[tensor_index].tensor() != nullptr) {
output_map_[node][tensor_index] = nv_tensors[tensor_index];
if (nv_tensors[tensor_index] != nullptr) {
const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
std::ostringstream oss;
nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions();
oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
oss << dim.d[dim_index] << " ";
@ -294,7 +184,7 @@ bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::ve
bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
std::vector<session::KernelWithIndex> real_inputs;
GetRealInputs(node, &real_inputs);
AnfAlgo::GetRealInputs(node, &real_inputs);
for (auto item : real_inputs) {
auto node_iter = output_map_.find(item.first);
if (node_iter == output_map_.end()) {
@ -313,7 +203,7 @@ bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<Lay
return true;
}
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
// Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
std::unordered_map<std::string, AnfNodePtr> graph_inputs;
for (const auto &input_node : func_graph_->parameters()) {
@ -342,9 +232,9 @@ std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
return trt_inputs;
}
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() {
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() const {
std::vector<session::KernelWithIndex> graph_outputs;
GetRealInputs(func_graph_->get_return(), &graph_outputs);
AnfAlgo::GetRealInputs(func_graph_->get_return(), &graph_outputs);
return graph_outputs;
}

View File

@ -50,13 +50,13 @@ class TrtConverterContext : public std::enable_shared_from_this<TrtConverterCont
bool Serialize(std::string *model);
// Get trt graph inputs without weights. The inputs keep same order as binding name.
std::vector<AnfNodePtr> GetGraphInputs();
std::vector<AnfNodePtr> GetGraphInputs() const;
// Get trt graph outputs. All outputs are flatten to vector with concret shape.
std::vector<session::KernelWithIndex> GetGraphOutputs();
std::vector<session::KernelWithIndex> GetGraphOutputs() const;
// Store trt layer outputs to the cache.
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs);
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &inputs);
// Get trt layer inputs from the cache.
bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);

View File

@ -51,7 +51,7 @@ ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
const nvinfer1::Dims &dims = TrtUtils::MsDimsToTrtDims(output_shape, false);
layer->setReshapeDimensions(dims);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@ -97,7 +97,7 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
auto *layer = context->network()->addElementWise(*x1, *x2, op_type);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@ -129,7 +129,7 @@ ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@ -144,7 +144,7 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterCo
auto *layer = context->network()->addActivation(*inputs[0].tensor(), act_type);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@ -159,7 +159,7 @@ ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext
auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
@ -191,7 +191,7 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContex
// Skip reduce operator if reduce_axes == 0
if (reduce_axes == 0) {
MS_LOG(WARNING) << "No dimension be be reduced. " << node->DebugString();
return {true, {LayerInput(inputs[0].tensor())}};
return {true, {inputs[0].tensor()}};
}
bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
@ -215,10 +215,10 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContex
dim.d[1] = 1;
reshape_layer->setReshapeDimensions(dim);
return {true, {LayerInput(reshape_layer->getOutput(0))}};
return {true, {reshape_layer->getOutput(0)}};
}
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
} // namespace
@ -265,7 +265,10 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2D) {
layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
}
return {true, {LayerInput(layer->getOutput(0))}};
const auto &group = AnfAlgo::GetNodeAttr<int64_t>(node, "group");
layer->setNbGroups(SizeToInt(group));
return {true, {layer->getOutput(0)}};
}
// Binary broadcast operators.
@ -367,7 +370,7 @@ MS_TRT_CONVERTER_FUNC_REG(GeLU) {
layer = context->network()->addElementWise(*c1, *layer->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(MatMul) {
@ -402,7 +405,7 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
auto *squeeze_y = context->network()->addShuffle(*layer->getOutput(0));
squeeze_y->setReshapeDimensions(y_dims);
return {true, {LayerInput(squeeze_y->getOutput(0))}};
return {true, {squeeze_y->getOutput(0)}};
} else {
// convert weight to tensor and appy addMatrixMultiply
MS_LOG(ERROR) << "Operator not implemented: " << node->DebugString();
@ -446,7 +449,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
@ -477,13 +480,14 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
auto *layer = context->network()->addElementWise(*inputs[0].tensor(), *bias, nvinfer1::ElementWiseOperation::kSUM);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
// NoOp
MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(Flatten) { return AddReshapeLayer(node, context); }
MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
std::vector<LayerInput> inputs;
@ -537,7 +541,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
auto *layer = context->network()->addScale(*inputs[0].tensor(), nvinfer1::ScaleMode::kCHANNEL, shift, scale, pow);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Concat) {
@ -567,7 +571,7 @@ MS_TRT_CONVERTER_FUNC_REG(Concat) {
}
layer->setAxis(axis);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
@ -608,7 +612,7 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
}
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Slice) {
@ -632,7 +636,7 @@ MS_TRT_CONVERTER_FUNC_REG(Slice) {
auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Transpose) {
@ -653,7 +657,7 @@ MS_TRT_CONVERTER_FUNC_REG(Transpose) {
MS_EXCEPTION_IF_NULL(layer);
layer->setFirstTranspose(trt_perm);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Softmax) {
@ -684,7 +688,7 @@ MS_TRT_CONVERTER_FUNC_REG(Softmax) {
auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
MS_EXCEPTION_IF_NULL(layer);
layer->setAxes(reduce_axes);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
@ -707,7 +711,7 @@ MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
MS_EXCEPTION_IF_NULL(log_layer);
return {true, {LayerInput(log_layer->getOutput(0))}};
return {true, {log_layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Gather) {
@ -729,7 +733,7 @@ MS_TRT_CONVERTER_FUNC_REG(Gather) {
auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
MS_EXCEPTION_IF_NULL(layer);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Cast) {
@ -744,7 +748,7 @@ MS_TRT_CONVERTER_FUNC_REG(Cast) {
auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
auto *layer = context->network()->addIdentity(*inputs[0].tensor());
layer->setOutputType(0, trt_type);
return {true, {LayerInput(layer->getOutput(0))}};
return {true, {layer->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
@ -810,7 +814,33 @@ MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
nvinfer1::ElementWiseOperation::kSUM);
MS_EXCEPTION_IF_NULL(add);
return {true, {LayerInput(add->getOutput(0))}};
return {true, {add->getOutput(0)}};
}
MS_TRT_CONVERTER_FUNC_REG(Return) {
std::vector<LayerInput> inputs;
bool ret = context->LoadLayerInput(node, &inputs);
if (!ret) {
return {false, {}};
}
for (size_t i = 0; i < inputs.size(); ++i) {
nvinfer1::ITensor *input = nullptr;
if (inputs[i].IsTensor()) {
input = inputs[i].tensor();
} else {
std::vector<size_t> shape;
std::transform(inputs[i].shape().begin(), inputs[i].shape().end(), std::back_inserter(shape),
[](int64_t d) { return LongToSize(d); });
input = ToTensor(&inputs[i], shape, context);
}
const std::string &name = "return_output_" + std::to_string(i);
input->setName(name.c_str());
context->network()->markOutput(*input);
}
return {true, {}};
}
} // namespace opt
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include <utility>
#include <string>
#include <memory>
#include <NvInfer.h>
#include "base/base.h"
#include "ir/anf.h"
@ -30,7 +31,7 @@ namespace mindspore {
namespace opt {
class LayerInput;
class TrtConverterContext;
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
using ConvertResult = std::pair<bool, std::vector<nvinfer1::ITensor *>>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
class TrtOpFactory {