forked from mindspore-Ecosystem/mindspore
!17562 graph partition
From: @wilfchen Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
e5144573ad
|
@ -0,0 +1,275 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/trt_pass/graph_converter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <tuple>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/trt_pass/trt_converter_context.h"
|
||||
#include "utils/singleton.h"
|
||||
#include "runtime/device/gpu/cuda_driver.h"
|
||||
#include "runtime/device/gpu/trt_loader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Propagate the inferred output dtypes and shapes of the original graph outputs
// onto the fused TrtNode so downstream passes see the same abstract.
void CopyGraphOutputTypeAndShape(const std::vector<session::KernelWithIndex> &graph_outputs, CNodePtr trt_node) {
  std::vector<TypeId> output_types;
  std::vector<std::vector<size_t>> output_shapes;
  output_types.reserve(graph_outputs.size());
  output_shapes.reserve(graph_outputs.size());
  for (const auto &output : graph_outputs) {
    output_types.push_back(AnfAlgo::GetOutputInferDataType(output.first, output.second));
    output_shapes.push_back(AnfAlgo::GetOutputInferShape(output.first, output.second));
  }
  AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, trt_node.get());
}
|
||||
|
||||
// Create a TrtNode CNode that carries the serialized TensorRT model as the
// "serialize_model" attribute of its primitive. The node's inputs are the
// useful graph inputs; its abstract mirrors the original graph outputs.
CNodePtr NewTrtNode(const FuncGraphPtr &graph, const std::string &model_data, const AnfNodePtrList &graph_inputs,
                    const std::vector<session::KernelWithIndex> &graph_outputs) {
  auto prim = std::make_shared<Primitive>("TrtNode");
  MS_EXCEPTION_IF_NULL(prim);
  prim->AddAttr("serialize_model", MakeValue(model_data));

  // CNode inputs: [value_node(prim), graph_input_0, graph_input_1, ...].
  std::vector<AnfNodePtr> node_inputs;
  node_inputs.reserve(graph_inputs.size() + 1);
  node_inputs.push_back(NewValueNode(prim));
  node_inputs.insert(node_inputs.end(), graph_inputs.begin(), graph_inputs.end());

  auto trt_node = graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(trt_node);

  // Update output shape and type from the original graph outputs.
  CopyGraphOutputTypeAndShape(graph_outputs, trt_node);
  return trt_node;
}
|
||||
|
||||
// Build a MakeTuple node re-exporting the TrtNode outputs in the original ANF output order.
// `anf_trt_index_map` maps each ANF graph output index to the TensorRT binding (output) index.
// Throws (MS_LOG(EXCEPTION)) when an ANF output index has no TRT binding.
CNodePtr BuildMakeTupleNode(const FuncGraphPtr root, const std::map<size_t, size_t> &anf_trt_index_map,
                            CNodePtr trt_node) {
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  std::vector<TypeId> make_tuple_types;
  std::vector<std::vector<size_t>> make_tuple_shapes;

  for (size_t out_idx = 0; out_idx < anf_trt_index_map.size(); out_idx++) {
    // Translate the ANF output index to the TrtNode output index.
    auto iter = anf_trt_index_map.find(out_idx);
    if (iter == anf_trt_index_map.end()) {
      // Fix: the message previously read "Output node found" for the NOT-found case.
      MS_LOG(EXCEPTION) << "Output node not found: " << out_idx;
    }
    size_t trt_index = iter->second;

    // Create tuple_getitem(trt_node, trt_index).
    std::vector<AnfNodePtr> tuple_getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), trt_node,
                                                    NewValueNode(MakeValue(SizeToLong(trt_index)))};
    const CNodePtr &tuple_getitem_cnode = root->NewCNode(tuple_getitem_inputs);
    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);

    // Set tuple_getitem_cnode abstract from the TrtNode output it selects.
    std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(trt_node, trt_index)};
    std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(trt_node, trt_index)};
    AnfAlgo::SetOutputInferTypeAndShape(types, shapes, tuple_getitem_cnode.get());

    // Accumulate make_tuple inputs and its aggregated abstract.
    make_tuple_inputs.push_back(tuple_getitem_cnode);
    make_tuple_types.push_back(types[0]);
    make_tuple_shapes.push_back(shapes[0]);
  }

  const CNodePtr &make_tuple_cnode = root->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple_cnode);
  AnfAlgo::SetOutputInferTypeAndShape(make_tuple_types, make_tuple_shapes, make_tuple_cnode.get());

  return make_tuple_cnode;
}
|
||||
} // namespace
|
||||
|
||||
// Map the subgraph's formal parameters to the actual arguments from the root graph,
// then pick out only the arguments whose parameters are actually bound by the TRT
// network (`useful_parameters`). Throws when an argument is missing or null.
AnfNodePtrList GraphConverter::GetUsefulArguments(const AnfNodePtrList &arguments, const AnfNodePtrList &parameters,
                                                  const AnfNodePtrList &useful_parameters) {
  // Robustness fix: the loop below indexes arguments[i] for every parameter, so the
  // argument list must be at least as long as the parameter list.
  if (parameters.size() > arguments.size()) {
    MS_LOG(EXCEPTION) << "Arguments size " << arguments.size() << " less than parameters size " << parameters.size();
  }

  // Present map between formal parameter and actual argument.
  std::unordered_map<AnfNodePtr, AnfNodePtr> args_map;
  for (size_t i = 0; i < parameters.size(); i++) {
    args_map.emplace(parameters[i], arguments[i]);
  }

  AnfNodePtrList useful_arguments;
  useful_arguments.reserve(useful_parameters.size());
  for (const auto &param : useful_parameters) {
    auto iter = args_map.find(param);
    if (iter == args_map.end() || iter->second == nullptr) {
      MS_LOG(EXCEPTION) << "Argument not found. Arg: " << param->DebugString();
    }
    useful_arguments.push_back(iter->second);
  }

  return useful_arguments;
}
|
||||
|
||||
std::tuple<std::map<size_t, size_t>, CNodePtr> GraphConverter::BuildTrtNode(const FuncGraphPtr &root_graph,
|
||||
const FuncGraphPtr &sub_graph,
|
||||
const AnfNodePtrList &arguments) {
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
auto converter = std::make_shared<TrtConverterContext>(sub_graph);
|
||||
bool ret = converter->Init();
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Graph convert init failed.";
|
||||
}
|
||||
|
||||
ret = converter->Parser();
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Graph converter parse failed.";
|
||||
}
|
||||
|
||||
std::string model_data;
|
||||
ret = converter->Serialize(&model_data);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Graph converte serialize failed.";
|
||||
}
|
||||
|
||||
// Get actual arguments by useful formal parameters
|
||||
const AnfNodePtrList ¶meters = sub_graph->parameters();
|
||||
const AnfNodePtrList &useful_parameters = converter->GetGraphInputs();
|
||||
const AnfNodePtrList &useful_arguments = GetUsefulArguments(arguments, parameters, useful_parameters);
|
||||
|
||||
// Get outputs by the TensorRT binding order.
|
||||
std::map<size_t, size_t> anf_trt_index_map;
|
||||
std::vector<session::KernelWithIndex> trt_output_list;
|
||||
std::tie(anf_trt_index_map, trt_output_list) = converter->GetGraphOutputs();
|
||||
CNodePtr trt_node = NewTrtNode(root_graph, model_data, useful_arguments, trt_output_list);
|
||||
|
||||
return std::make_tuple(anf_trt_index_map, trt_node);
|
||||
}
|
||||
|
||||
void GraphConverter::RemoveParameterWithoutUser(const FuncGraphPtr &graph) {
|
||||
std::vector<AnfNodePtr> graph_inputs;
|
||||
|
||||
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const AnfNodePtrList &inputs = kernel_graph->inputs();
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto &input = inputs[i];
|
||||
|
||||
// Keep inputs of graph.
|
||||
if (!input->isa<Parameter>() || !AnfAlgo::IsParameterWeight(input->cast<ParameterPtr>())) {
|
||||
graph_inputs.push_back(input);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Remove useless parameters of graph.
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
const NodeUsersMap &users = manager->node_users();
|
||||
const auto &iter = users.find(input);
|
||||
if (iter != users.end() && !iter->second.empty()) {
|
||||
graph_inputs.push_back(input);
|
||||
}
|
||||
MS_LOG(INFO) << "Useless input: " << input->DebugString();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Graph total inputs num: " << graph_inputs.size();
|
||||
kernel_graph->SetGraphInputs(graph_inputs);
|
||||
kernel_graph->set_parameters(graph_inputs);
|
||||
}
|
||||
|
||||
bool GraphConverter::ReplaceSubgraphWithTrtNode(const FuncGraphPtr &root, const Subgraph &sub_graph_info) {
|
||||
FuncGraphPtr sub_graph;
|
||||
AnfNodePtrList args;
|
||||
AnfNodePtrList outputs;
|
||||
std::tie(sub_graph, args, outputs) = sub_graph_info;
|
||||
|
||||
std::map<size_t, size_t> anf_trt_index_map;
|
||||
CNodePtr trt_node;
|
||||
std::tie(anf_trt_index_map, trt_node) = BuildTrtNode(root, sub_graph, args);
|
||||
if (trt_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Convert to Tensor-RT network failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto manager = root->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (outputs.size() == 1) {
|
||||
if (AnfAlgo::CheckPrimitiveType(outputs[0], prim::kPrimMakeTuple)) {
|
||||
const CNodePtr &make_tuple_cnode = BuildMakeTupleNode(root, anf_trt_index_map, trt_node);
|
||||
manager->Replace(outputs[0], make_tuple_cnode);
|
||||
} else {
|
||||
manager->Replace(outputs[0], trt_node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
|
||||
size_t trt_index = anf_trt_index_map[out_idx];
|
||||
std::vector<AnfNodePtr> fn_inputs = {NewValueNode(prim::kPrimTupleGetItem), trt_node,
|
||||
NewValueNode(MakeValue(SizeToLong(trt_index)))};
|
||||
const CNodePtr &new_out = root->NewCNode(fn_inputs);
|
||||
new_out->set_abstract(outputs[out_idx]->abstract());
|
||||
manager->Replace(outputs[out_idx], new_out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GraphConverter::Run(const FuncGraphPtr &fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
||||
const auto &context = MsContext::GetInstance();
|
||||
if (!context->get_param<bool>(MS_CTX_ENABLE_INFER_OPT)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Set device id before invoke trt api as cudaSetDevice is thread level config.
|
||||
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto &trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
||||
if (!trt_loader.nvinfer_loaded()) {
|
||||
MS_LOG(WARNING) << "Load Tensor-RT so failed. Inference with native backend.";
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
auto graph_partition = std::make_shared<GraphPartitioner>();
|
||||
const std::map<std::string, AnfNodePtrList> &segments = graph_partition->Partition(fg);
|
||||
for (const auto &segment : segments) {
|
||||
// Do not fusion when segment only contain 1 node.
|
||||
if (segment.second.size() == 1) {
|
||||
continue;
|
||||
}
|
||||
const Subgraph &sub_graph = graph_partition->CreateNewGraph(segment.second);
|
||||
ret = ReplaceSubgraphWithTrtNode(fg, sub_graph);
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Failed replace sub graph with TrtNode.";
|
||||
continue;
|
||||
}
|
||||
// Remove useless parameters folded in TensorRT network.
|
||||
RemoveParameterWithoutUser(fg);
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(WARNING) << "Convert to Tensor-RT network failed. " << e.what();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_TRT_PASS_GRAPH_CONVERTER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_TRT_PASS_GRAPH_CONVERTER_H_
|
||||
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/trt_pass/graph_partitioner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pass replace MindIR operators to TrtNode contains serialized data generated by TensorRT.
|
||||
// It mainly includes three steps:
|
||||
// 1. Segment the network with `GraphPartition`.
|
||||
// 2. Build the TRT network with segmentation and takes its serialized data as an attribute of the `TrtNode`.
|
||||
// 3. Replace the segmentation with `TrtNode`.
|
||||
class GraphConverter : public Pass {
|
||||
public:
|
||||
GraphConverter() : Pass("mindir_to_trt_pass") {}
|
||||
~GraphConverter() override = default;
|
||||
|
||||
// Run the pass replace subgraph to TrtNode.
|
||||
bool Run(const FuncGraphPtr &fg) override;
|
||||
|
||||
private:
|
||||
// Replace subgraph with TrtNode which keep model data serialized by TensorRT network.
|
||||
bool ReplaceSubgraphWithTrtNode(const FuncGraphPtr &root_graph, const Subgraph &sub_graph);
|
||||
|
||||
// Build the TrtNode from subgraph including serialized model data and input shapes and dtypes.
|
||||
std::tuple<std::map<size_t, size_t>, CNodePtr> BuildTrtNode(const FuncGraphPtr &root_graph,
|
||||
const FuncGraphPtr &sub_graph,
|
||||
const AnfNodePtrList &arguments);
|
||||
|
||||
// Remove useless parameters which had been folded in Tensor-RT network.
|
||||
void RemoveParameterWithoutUser(const FuncGraphPtr &graph);
|
||||
|
||||
// Get useful arguments in the root graph after conversion.
|
||||
AnfNodePtrList GetUsefulArguments(const AnfNodePtrList &arguments, const AnfNodePtrList ¶meters,
|
||||
const AnfNodePtrList &used_parameter);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_TRT_PASS_GRAPH_CONVERTER_H_
|
|
@ -120,12 +120,6 @@ bool TrtConverterContext::InitInputTable() {
|
|||
weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
|
||||
weight.count = tensor->DataSize();
|
||||
output_map_[input_node][0] = LayerInput(weight, tensor->shape());
|
||||
} else {
|
||||
nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0));
|
||||
nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false);
|
||||
nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
|
||||
const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
|
||||
output_map_[input_node][0] = LayerInput(tensor, shape);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -182,12 +176,28 @@ bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::ve
|
|||
return true;
|
||||
}
|
||||
|
||||
// Lazily register a graph Parameter as a TensorRT network input the first time it is
// referenced, caching the resulting LayerInput in output_map_.
LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto input = node->cast<ParameterPtr>();
  // Fix: cast<ParameterPtr>() returns null for non-Parameter nodes; guard before
  // dereferencing input->name() below.
  MS_EXCEPTION_IF_NULL(input);
  const nvinfer1::DataType &trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(node, 0));
  const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(node, 0), false);
  nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
  const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
  output_map_[node][0] = LayerInput(tensor, shape);
  return &output_map_[node][0];
}
|
||||
|
||||
bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
|
||||
std::vector<session::KernelWithIndex> real_inputs;
|
||||
AnfAlgo::GetRealInputs(node, &real_inputs);
|
||||
for (auto item : real_inputs) {
|
||||
auto node_iter = output_map_.find(item.first);
|
||||
if (node_iter == output_map_.end()) {
|
||||
if (item.first->isa<Parameter>()) {
|
||||
LayerInput *input = LoadInputOnDemand(item.first);
|
||||
inputs->push_back(*input);
|
||||
continue;
|
||||
}
|
||||
MS_LOG(ERROR) << "node: " << node->DebugString() << " not found.";
|
||||
return false;
|
||||
}
|
||||
|
@ -243,7 +253,7 @@ std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> TrtC
|
|||
for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
|
||||
if (!engine_->bindingIsInput(i)) {
|
||||
const std::string &name = engine_->getBindingName(i);
|
||||
size_t pos = name.find_last_not_of("return_output_");
|
||||
size_t pos = name.find_first_not_of("return_output_");
|
||||
size_t anf_index = atoi(name.substr(pos).c_str());
|
||||
|
||||
anf_trt_index_map.insert(std::make_pair(anf_index, trt_index));
|
||||
|
|
|
@ -79,6 +79,7 @@ class TrtConverterContext : public std::enable_shared_from_this<TrtConverterCont
|
|||
private:
|
||||
bool InitInputTable();
|
||||
bool InitValueNodeTable();
|
||||
LayerInput *LoadInputOnDemand(const AnfNodePtr &node);
|
||||
|
||||
FuncGraphPtr func_graph_;
|
||||
uint32_t batch_size_;
|
||||
|
|
|
@ -24,11 +24,35 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Reshape `input` (which must hold an ITensor) to `shape` via a Shuffle layer.
// Returns the original tensor untouched when the shapes already match, or nullptr
// when the input is a weight rather than a tensor.
nvinfer1::ITensor *ToShape(LayerInput *input, const std::vector<size_t> &shape,
                           std::shared_ptr<TrtConverterContext> context) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(context);

  if (!input->IsTensor()) {
    MS_LOG(ERROR) << "Expect Tensor but got weight";
    return nullptr;
  }

  nvinfer1::ITensor *src_tensor = input->tensor();
  const nvinfer1::Dims &src_dim = src_tensor->getDimensions();
  const nvinfer1::Dims &dst_dim = TrtUtils::MsDimsToTrtDims(shape, false);
  if (TrtUtils::IsSameShape(src_dim, dst_dim)) {
    // Already the requested shape: no layer needed.
    return src_tensor;
  }

  auto *reshape_layer = context->network()->addShuffle(*src_tensor);
  MS_EXCEPTION_IF_NULL(reshape_layer);
  reshape_layer->setReshapeDimensions(dst_dim);
  return reshape_layer->getOutput(0);
}
|
||||
|
||||
nvinfer1::ITensor *ToTensor(LayerInput *input, const std::vector<size_t> &shape,
|
||||
std::shared_ptr<TrtConverterContext> context) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input->IsTensor()) {
|
||||
return input->tensor();
|
||||
return ToShape(input, shape, context);
|
||||
}
|
||||
|
||||
const nvinfer1::Dims &dim = TrtUtils::MsDimsToTrtDims(shape, false);
|
||||
|
@ -553,31 +577,42 @@ MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
|
|||
|
||||
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
|
||||
|
||||
auto SwapLastDims = [](std::vector<size_t> shape, const bool &transpose) {
|
||||
if (shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Operation not support: input rank should >= 2";
|
||||
}
|
||||
|
||||
if (!transpose) {
|
||||
return shape;
|
||||
}
|
||||
|
||||
size_t tmp = shape[shape.size() - 2];
|
||||
shape[shape.size() - 2] = shape[shape.size() - 1];
|
||||
shape[shape.size() - 1] = tmp;
|
||||
return shape;
|
||||
};
|
||||
|
||||
nvinfer1::ITensor *tensor1 = ToTensor(&inputs[0], SwapLastDims(shape1, transpose_a), context);
|
||||
nvinfer1::ITensor *tensor2 = ToTensor(&inputs[1], SwapLastDims(shape2, transpose_b), context);
|
||||
nvinfer1::ITensor *tensor1 = ToTensor(&inputs[0], shape1, context);
|
||||
nvinfer1::ITensor *tensor2 = ToTensor(&inputs[1], shape2, context);
|
||||
auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
|
||||
MS_EXCEPTION_IF_NULL(layer);
|
||||
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kSUM); }
|
||||
// BiasAdd: add a 1-D bias along the channel axis of x. The bias weight is unsqueezed
// to x's rank (size-1 dims everywhere except the channel axis) so TensorRT broadcasts it.
MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
  std::vector<LayerInput> inputs;
  bool ret = context->LoadLayerInput(node, &inputs);
  if (!ret || inputs.size() != 2) {
    // Fix: BiasAdd takes exactly two inputs (x and bias); the message previously said 1.
    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
    return {false, {}};
  }

  const auto &x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
  const auto &bias_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
  const auto &format = AnfAlgo::GetNodeAttr<std::string>(node, "format");
  // Locate the channel axis from the format string, e.g. "NCHW" -> 1, "NHWC" -> 3.
  const string::size_type &pos = format.find("C");
  if (pos == std::string::npos || pos >= x_shape.size()) {
    // Fix: message had a mismatched quote ("The format " << format << "' invalid").
    MS_LOG(ERROR) << "The format '" << format << "' is invalid";
    return {false, {}};
  }

  // Convert bias to an ITensor with the same rank as x (1 on every non-channel axis).
  std::vector<size_t> unsqueeze_bias_dims(x_shape.size(), 1);
  unsqueeze_bias_dims[pos] = SizeToInt(bias_shape[0]);
  nvinfer1::ITensor *bias = ToTensor(&inputs[1], unsqueeze_bias_dims, context);

  // Create broadcast Add layer.
  auto *layer = context->network()->addElementWise(*inputs[0].tensor(), *bias, nvinfer1::ElementWiseOperation::kSUM);
  MS_EXCEPTION_IF_NULL(layer);

  return {true, {layer->getOutput(0)}};
}
|
||||
|
||||
// NoOp
|
||||
MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
|
||||
|
@ -847,13 +882,20 @@ MS_TRT_CONVERTER_FUNC_REG(Cast) {
|
|||
auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
|
||||
auto *layer = context->network()->addIdentity(*input);
|
||||
layer->setOutputType(0, trt_type);
|
||||
|
||||
if (trt_type == nvinfer1::DataType::kHALF) {
|
||||
MS_LOG(WARNING) << "The model is exported with auto-mixed-precsion or manual precision mode. "
|
||||
<< "Retreat inference with native backend. It is recommended that export FP32 model "
|
||||
<< "and then inference with FP16 precision mode configuration.";
|
||||
return {false, {}};
|
||||
}
|
||||
return {true, {layer->getOutput(0)}};
|
||||
}
|
||||
|
||||
MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
|
||||
std::vector<LayerInput> inputs;
|
||||
bool ret = context->LoadLayerInput(node, &inputs);
|
||||
if (!ret || inputs.size() != 3 || !inputs[0].IsTensor() || !inputs[1].IsWeight() || !inputs[2].IsWeight()) {
|
||||
if (!ret || inputs.size() != 3 || !inputs[0].IsTensor()) {
|
||||
MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
|
||||
return {false, {}};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue