forked from mindspore-Ecosystem/mindspore
keep shape in class LayerInput
This commit is contained in:
parent
1e8868cecc
commit
c490abd459
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <NvInfer.h>
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
@ -26,8 +27,10 @@ namespace mindspore::opt {
|
|||
class LayerInput {
|
||||
public:
|
||||
LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
|
||||
explicit LayerInput(nvinfer1::Weights &w) : type_(InputType::kWeight), weight_(w), tensor_(nullptr) {}
|
||||
explicit LayerInput(nvinfer1::ITensor *t) : type_(InputType::kTensor), weight_(), tensor_(t) {}
|
||||
explicit LayerInput(nvinfer1::Weights &w, const std::vector<int64_t> &s)
|
||||
: type_(InputType::kWeight), weight_(w), tensor_(nullptr), shape_(s) {}
|
||||
explicit LayerInput(nvinfer1::ITensor *t, const std::vector<int64_t> &s)
|
||||
: type_(InputType::kTensor), weight_(), tensor_(t), shape_(s) {}
|
||||
|
||||
bool IsTensor() const { return type_ == InputType::kTensor; }
|
||||
bool IsWeight() const { return type_ == InputType::kWeight; }
|
||||
|
@ -48,6 +51,8 @@ class LayerInput {
|
|||
return tensor_;
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &shape() const { return shape_; }
|
||||
|
||||
private:
|
||||
enum class InputType : char { kUnknown = 0, kTensor, kWeight };
|
||||
InputType type_;
|
||||
|
@ -55,6 +60,8 @@ class LayerInput {
|
|||
nvinfer1::Weights weight_;
|
||||
// Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
|
||||
nvinfer1::ITensor *tensor_;
|
||||
// Keep the shape of tensor or weight.
|
||||
std::vector<int64_t> shape_;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "runtime/device/gpu/trt_loader.h"
|
||||
#include "runtime/device/gpu/cuda_driver.h"
|
||||
#include "backend/optimizer/trt_pass/trt_op_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
@ -33,105 +32,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index,
|
||||
std::vector<session::KernelWithIndex> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
|
||||
return inputs->push_back(std::make_pair(node, 0));
|
||||
}
|
||||
|
||||
// Skip control node
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) {
|
||||
return GetRealOutputRecursively(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs);
|
||||
}
|
||||
|
||||
// Bypass TupleGetItem
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_get_item = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item);
|
||||
auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item);
|
||||
|
||||
// Conceal MakeTuple + TupleGetItem pair.
|
||||
if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto real_input = AnfAlgo::GetInputNode(make_tuple, index);
|
||||
return GetRealOutputRecursively(real_input, 0, inputs);
|
||||
}
|
||||
|
||||
// Skip TupleGetItem.
|
||||
return GetRealOutputRecursively(input, index, inputs);
|
||||
}
|
||||
|
||||
// Flatten MakeTuple inputs.
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index);
|
||||
GetRealOutputRecursively(input_node, 0, inputs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
return inputs->push_back(std::make_pair(node, output_index));
|
||||
}
|
||||
|
||||
/* Get node real inputs bypass control nodes.
|
||||
* Examples:
|
||||
* Case 1:
|
||||
* c = Conv2D(a, b)
|
||||
* d = ReLU(c)
|
||||
* result: d--> (c)
|
||||
*
|
||||
* Case 2:
|
||||
* c = Conv2D(a, b)
|
||||
* d = Depend(c, v)
|
||||
* e = ReLU(d)
|
||||
* result: d -> (c)
|
||||
*
|
||||
* Case 3:
|
||||
* (f, g, h, i, j) = BatchNorm(a, b, c, d, e)
|
||||
* k = TupleGetItem((f, g, h, i, j), 0)
|
||||
* l = ReLU(k)
|
||||
* result: l -> (f)
|
||||
*
|
||||
* Case 4:
|
||||
* c = Conv2D(a, b)
|
||||
* e = MakeTuple(c, d)
|
||||
* f = TupleGetItem(e, 0)
|
||||
* g = ReLU(k)
|
||||
* result: g -> (c)
|
||||
*
|
||||
* Case 5:
|
||||
* b = MakeTuple(a1, a2, a3)
|
||||
* c = MakeTuple(b, a4)
|
||||
* d = return(c)
|
||||
* result d -> (a1, a2, a3, a4)
|
||||
*/
|
||||
void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index);
|
||||
GetRealOutputRecursively(input_node, 0, inputs);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool TrtConverterContext::Init() {
|
||||
// Set device id before invoke trt api as cudaSetDevice is thread level config.
|
||||
const auto &context = MsContext::GetInstance();
|
||||
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
||||
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
|
||||
MS_EXCEPTION_IF_NULL(builder_);
|
||||
|
@ -142,13 +43,13 @@ bool TrtConverterContext::Init() {
|
|||
|
||||
config_ = TrtPtr(builder_->createBuilderConfig());
|
||||
MS_EXCEPTION_IF_NULL(config_);
|
||||
|
||||
InitInputTable();
|
||||
InitValueNodeTable();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::Parser() {
|
||||
InitInputTable();
|
||||
InitValueNodeTable();
|
||||
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
|
||||
const auto &converter_factory = TrtOpFactory::GetInstance();
|
||||
for (auto node : node_list) {
|
||||
|
@ -156,24 +57,10 @@ bool TrtConverterContext::Parser() {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Mark graph outputs
|
||||
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
|
||||
if (op_name == kReturnOpName) {
|
||||
std::vector<LayerInput> inputs;
|
||||
(void)LoadLayerInput(node, &inputs);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto &input = inputs[i].tensor();
|
||||
std::string name = "return_output_" + std::to_string(i);
|
||||
input->setName(name.c_str());
|
||||
network_->markOutput(*input);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Transform AnfNode To Trt layer.
|
||||
// Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
|
||||
if (!AnfAlgo::IsRealKernel(node)) {
|
||||
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
|
||||
if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -190,8 +77,7 @@ bool TrtConverterContext::Parser() {
|
|||
}
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Graph ended without return node.";
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::Serialize(std::string *model) {
|
||||
|
@ -233,54 +119,58 @@ bool TrtConverterContext::InitInputTable() {
|
|||
weight.values = tensor->data_c();
|
||||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[input_node][0] = LayerInput(weight);
|
||||
output_map_[input_node][0] = LayerInput(weight, tensor->shape());
|
||||
} else {
|
||||
nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
|
||||
nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
|
||||
nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
|
||||
output_map_[input_node][0] = LayerInput(tensor);
|
||||
const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
|
||||
output_map_[input_node][0] = LayerInput(tensor, shape);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::InitValueNodeTable() {
|
||||
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
|
||||
for (const auto &node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
|
||||
for (auto &value_node : kernel_graph->graph_value_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
|
||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(node_value, &tensors);
|
||||
for (size_t i = 0; i < tensors.size(); i++) {
|
||||
const auto &tensor = tensors[i];
|
||||
nvinfer1::Weights weight;
|
||||
weight.values = tensor->data_c();
|
||||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[value_node][i] = LayerInput(weight);
|
||||
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(node_value, &tensors);
|
||||
for (size_t i = 0; i < tensors.size(); i++) {
|
||||
const auto &tensor = tensors[i];
|
||||
nvinfer1::Weights weight;
|
||||
weight.values = tensor->data_c();
|
||||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[value_node][i] = LayerInput(weight, tensor->shape());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) {
|
||||
bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
|
||||
if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
|
||||
MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
|
||||
<< ", while got: " << nv_tensors.size();
|
||||
}
|
||||
|
||||
for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
|
||||
if (nv_tensors[tensor_index].tensor() != nullptr) {
|
||||
output_map_[node][tensor_index] = nv_tensors[tensor_index];
|
||||
if (nv_tensors[tensor_index] != nullptr) {
|
||||
const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
|
||||
const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
|
||||
output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
|
||||
|
||||
std::ostringstream oss;
|
||||
nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions();
|
||||
oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
|
||||
for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
|
||||
oss << dim.d[dim_index] << " ";
|
||||
|
@ -294,7 +184,7 @@ bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::ve
|
|||
|
||||
bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
|
||||
std::vector<session::KernelWithIndex> real_inputs;
|
||||
GetRealInputs(node, &real_inputs);
|
||||
AnfAlgo::GetRealInputs(node, &real_inputs);
|
||||
for (auto item : real_inputs) {
|
||||
auto node_iter = output_map_.find(item.first);
|
||||
if (node_iter == output_map_.end()) {
|
||||
|
@ -313,7 +203,7 @@ bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<Lay
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
|
||||
std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
|
||||
// Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
|
||||
std::unordered_map<std::string, AnfNodePtr> graph_inputs;
|
||||
for (const auto &input_node : func_graph_->parameters()) {
|
||||
|
@ -342,9 +232,9 @@ std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() {
|
|||
return trt_inputs;
|
||||
}
|
||||
|
||||
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() {
|
||||
std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() const {
|
||||
std::vector<session::KernelWithIndex> graph_outputs;
|
||||
GetRealInputs(func_graph_->get_return(), &graph_outputs);
|
||||
AnfAlgo::GetRealInputs(func_graph_->get_return(), &graph_outputs);
|
||||
return graph_outputs;
|
||||
}
|
||||
|
||||
|
|
|
@ -50,13 +50,13 @@ class TrtConverterContext : public std::enable_shared_from_this<TrtConverterCont
|
|||
bool Serialize(std::string *model);
|
||||
|
||||
// Get trt graph inputs without weights. The inputs keep same order as binding name.
|
||||
std::vector<AnfNodePtr> GetGraphInputs();
|
||||
std::vector<AnfNodePtr> GetGraphInputs() const;
|
||||
|
||||
// Get trt graph outputs. All outputs are flatten to vector with concret shape.
|
||||
std::vector<session::KernelWithIndex> GetGraphOutputs();
|
||||
std::vector<session::KernelWithIndex> GetGraphOutputs() const;
|
||||
|
||||
// Store trt layer outputs to the cache.
|
||||
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs);
|
||||
bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &inputs);
|
||||
|
||||
// Get trt layer inputs from the cache.
|
||||
bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs);
|
||||
|
|
|
@ -51,7 +51,7 @@ ConvertResult AddReshapeLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
const nvinfer1::Dims &dims = TrtUtils::MsDimsToTrtDims(output_shape, false);
|
||||
layer->setReshapeDimensions(dims);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
|
@ -97,7 +97,7 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
auto *layer = context->network()->addElementWise(*x1, *x2, op_type);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
|
@ -129,7 +129,7 @@ ConvertResult AddPoolingLayer(AnfNodePtr node, std::shared_ptr<TrtConverterConte
|
|||
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
|
||||
}
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
|
@ -144,7 +144,7 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterCo
|
|||
auto *layer = context->network()->addActivation(*inputs[0].tensor(), act_type);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
|
@ -159,7 +159,7 @@ ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext
|
|||
auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
|
||||
|
@ -191,7 +191,7 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContex
|
|||
// Skip reduce operator if reduce_axes == 0
|
||||
if (reduce_axes == 0) {
|
||||
MS_LOG(WARNING) << "No dimension be be reduced. " << node->DebugString();
|
||||
return {true, {LayerInput(inputs[0].tensor())}};
|
||||
return {true, {inputs[0].tensor()}};
|
||||
}
|
||||
|
||||
bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
|
||||
|
@ -215,10 +215,10 @@ ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContex
|
|||
dim.d[1] = 1;
|
||||
reshape_layer->setReshapeDimensions(dim);
|
||||
|
||||
return {true, {LayerInput(reshape_layer->getOutput(0))}};
|
||||
return {true, {reshape_layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -265,7 +265,10 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2D) {
|
|||
layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
|
||||
}
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
const auto &group = AnfAlgo::GetNodeAttr<int64_t>(node, "group");
|
||||
layer->setNbGroups(SizeToInt(group));
|
||||
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
// Binary broadcast operators.
|
||||
|
@ -367,7 +370,7 @@ MS_TRT_CONVERTER_FUNC_REG(GeLU) {
|
|||
layer = context->network()->addElementWise(*c1, *layer->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(MatMul) {
|
||||
|
@ -402,7 +405,7 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
|
|||
auto *squeeze_y = context->network()->addShuffle(*layer->getOutput(0));
|
||||
squeeze_y->setReshapeDimensions(y_dims);
|
||||
|
||||
return {true, {LayerInput(squeeze_y->getOutput(0))}};
|
||||
return {true, {squeeze_y->getOutput(0)}};
|
||||
} else {
|
||||
// convert weight to tensor and appy addMatrixMultiply
|
||||
MS_LOG(ERROR) << "Operator not implemented: " << node->DebugString();
|
||||
|
@ -446,7 +449,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
|
|||
auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
|
||||
|
@ -477,13 +480,14 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
|
|||
auto *layer = context->network()->addElementWise(*inputs[0].tensor(), *bias, nvinfer1::ElementWiseOperation::kSUM);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
// NoOp
|
||||
MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
|
||||
MS_TRT_CONVERTER_FUNC_REG(Flatten) { return AddReshapeLayer(node, context); }
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
|
||||
std::vector<LayerInput> inputs;
|
||||
|
@ -537,7 +541,7 @@ MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
|
|||
auto *layer = context->network()->addScale(*inputs[0].tensor(), nvinfer1::ScaleMode::kCHANNEL, shift, scale, pow);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Concat) {
|
||||
|
@ -567,7 +571,7 @@ MS_TRT_CONVERTER_FUNC_REG(Concat) {
|
|||
}
|
||||
layer->setAxis(axis);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
|
||||
|
@ -608,7 +612,7 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
|
|||
layer->setPostPadding(nvinfer1::DimsHW{LongToInt(pad_list[1]), LongToInt(pad_list[3])});
|
||||
}
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Slice) {
|
||||
|
@ -632,7 +636,7 @@ MS_TRT_CONVERTER_FUNC_REG(Slice) {
|
|||
auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Transpose) {
|
||||
|
@ -653,7 +657,7 @@ MS_TRT_CONVERTER_FUNC_REG(Transpose) {
|
|||
MS_EXCEPTION_IF_NULL(layer);
|
||||
layer->setFirstTranspose(trt_perm);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Softmax) {
|
||||
|
@ -684,7 +688,7 @@ MS_TRT_CONVERTER_FUNC_REG(Softmax) {
|
|||
auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
layer->setAxes(reduce_axes);
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
|
||||
|
@ -707,7 +711,7 @@ MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
|
|||
auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
|
||||
MS_EXCEPTION_IF_NULL(log_layer);
|
||||
|
||||
return {true, {LayerInput(log_layer->getOutput(0))}};
|
||||
return {true, {log_layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Gather) {
|
||||
|
@ -729,7 +733,7 @@ MS_TRT_CONVERTER_FUNC_REG(Gather) {
|
|||
auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Cast) {
|
||||
|
@ -744,7 +748,7 @@ MS_TRT_CONVERTER_FUNC_REG(Cast) {
|
|||
auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
|
||||
auto *layer = context->network()->addIdentity(*inputs[0].tensor());
|
||||
layer->setOutputType(0, trt_type);
|
||||
return {true, {LayerInput(layer->getOutput(0))}};
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
|
||||
|
@ -810,7 +814,33 @@ MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
|
|||
nvinfer1::ElementWiseOperation::kSUM);
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
|
||||
return {true, {LayerInput(add->getOutput(0))}};
|
||||
return {true, {add->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(Return) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret) {
|
||||
return {false, {}};
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
nvinfer1::ITensor *input = nullptr;
|
||||
if (inputs[i].IsTensor()) {
|
||||
input = inputs[i].tensor();
|
||||
} else {
|
||||
std::vector<size_t> shape;
|
||||
std::transform(inputs[i].shape().begin(), inputs[i].shape().end(), std::back_inserter(shape),
|
||||
[](int64_t d) { return LongToSize(d); });
|
||||
input = ToTensor(&inputs[i], shape, context);
|
||||
}
|
||||
|
||||
const std::string &name = "return_output_" + std::to_string(i);
|
||||
input->setName(name.c_str());
|
||||
context->network()->markOutput(*input);
|
||||
}
|
||||
|
||||
return {true, {}};
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <NvInfer.h>
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
|
@ -30,7 +31,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class LayerInput;
|
||||
class TrtConverterContext;
|
||||
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
|
||||
using ConvertResult = std::pair<bool, std::vector<nvinfer1::ITensor *>>;
|
||||
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>;
|
||||
|
||||
class TrtOpFactory {
|
||||
|
|
Loading…
Reference in New Issue