open trt pass

wilfChen 2021-06-03 19:23:22 +08:00
parent 44dd7d994b
commit 5373a2bb1b
4 changed files with 28 additions and 8 deletions

View File

@@ -86,6 +86,20 @@ class TrtUtils {
                        [](const uint32_t &value) { return static_cast<int64_t>(value); });
     return shape;
   }
+
+  static bool IsSameShape(const nvinfer1::Dims &lhs, const nvinfer1::Dims &rhs) {
+    if (lhs.nbDims != rhs.nbDims) {
+      return false;
+    }
+
+    for (int32_t i = 0; i < lhs.nbDims; i++) {
+      if (lhs.d[i] != rhs.d[i]) {
+        return false;
+      }
+    }
+
+    return true;
+  }
 };

 class TrtLogger : public nvinfer1::ILogger {
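
Note: the IsSameShape helper added above fills a gap in the TensorRT API: nvinfer1::Dims has no operator==, so converters that need to compare shapes must check rank and extents by hand. Below is a minimal standalone sketch of the helper's behavior, with the comparison logic inlined so it compiles against the TensorRT headers alone; the MakeDims builder is illustrative, not part of this commit.

#include <NvInfer.h>

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Illustrative builder: fill an nvinfer1::Dims from a brace list.
static nvinfer1::Dims MakeDims(std::initializer_list<int32_t> values) {
  nvinfer1::Dims dims{};
  dims.nbDims = static_cast<int32_t>(values.size());
  int32_t i = 0;
  for (int32_t v : values) {
    dims.d[i++] = v;
  }
  return dims;
}

// Same logic as TrtUtils::IsSameShape above: equal rank, then equal extents.
static bool IsSameShape(const nvinfer1::Dims &lhs, const nvinfer1::Dims &rhs) {
  if (lhs.nbDims != rhs.nbDims) {
    return false;
  }
  for (int32_t i = 0; i < lhs.nbDims; i++) {
    if (lhs.d[i] != rhs.d[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  assert(IsSameShape(MakeDims({1, 3, 224, 224}), MakeDims({1, 3, 224, 224})));
  assert(!IsSameShape(MakeDims({1, 3, 224, 224}), MakeDims({1, 3, 224})));       // rank differs
  assert(!IsSameShape(MakeDims({1, 3, 224, 224}), MakeDims({1, 3, 224, 225})));  // extent differs
  return 0;
}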

View File

@@ -80,7 +80,7 @@ std::unordered_map<AnfNodePtr, NodeInfo> CollectNodeInfo(const FuncGraphPtr &fun
     const auto &converter_factory = TrtOpFactory::GetInstance();
     ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
     if (!convert_func) {
-      res[node] = NodeInfo(NodeType::kUnSupport, i);
+      res[node] = NodeInfo(NodeType::kUnsupported, i);
       continue;
     }

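Note: the rename to kUnsupported is cosmetic, but the surrounding logic is the heart of the partitioner: any op without a registered TRT converter is tagged unsupported so it stays on the native GPU kernels instead of entering a TRT subgraph. A standalone sketch of this lookup-or-fallback pattern; the registry and names below are illustrative stand-ins, not MindSpore APIs.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

enum class NodeType { kSupported, kUnsupported };

int main() {
  // Illustrative converter registry: op name -> TRT conversion callback.
  std::unordered_map<std::string, std::function<void()>> registry{
      {"ReLU", [] { /* emit TRT layers */ }}};

  for (const std::string &op : {"ReLU", "SomeCustomOp"}) {
    // No registered converter means the node falls back to the native backend.
    NodeType type = registry.count(op) ? NodeType::kSupported : NodeType::kUnsupported;
    std::cout << op << (type == NodeType::kSupported ? ": TRT subgraph\n" : ": native GPU kernel\n");
  }
  return 0;
}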
View File

@@ -236,7 +236,7 @@ ConvertResult AddReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContex
     nvinfer1::Dims dim;
     dim.nbDims = 1;
-    dim.d[1] = 1;
+    dim.d[0] = 1;
     reshape_layer->setReshapeDimensions(dim);
     return {true, {reshape_layer->getOutput(0)}};
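
Note: this one-character change is a real bug fix, not a cleanup: with nbDims set to 1, TensorRT reads only d[0], so the old code wrote d[1] and left d[0] uninitialized, producing an arbitrary reshape extent. A minimal sketch of the corrected pattern (requires the TensorRT headers; no network is built):

#include <NvInfer.h>

#include <cstdio>

int main() {
  // A reduction that collapses to a scalar is re-expanded as a rank-1 tensor of size 1.
  nvinfer1::Dims dim{};
  dim.nbDims = 1;
  dim.d[0] = 1;  // d[0] is the only slot read when nbDims == 1; the old code wrote d[1] instead
  std::printf("reshape dims: nbDims=%d, d[0]=%d\n", dim.nbDims, dim.d[0]);
  return 0;
}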
@@ -496,18 +496,18 @@ MS_TRT_CONVERTER_FUNC_REG(HSwish) {
     return layer->getOutput(0);
   };
-  // y = x * (Relu6(x) + 3.0) / 6.0
-  // relu6(x) = min(max(x, 0.0), 6.0)
+  // y = x * Relu6(x + 3.0) / 6.0
+  // Relu6(x) = min(max(x, 0.0), 6.0)
   auto *c0 = AddConst(0.0f);
   auto *c1 = AddConst(3.0f);
   auto *c2 = AddConst(6.0f);
   auto *x = inputs[0].tensor();
-  nvinfer1::ILayer *layer = context->network()->addElementWise(*x, *c0, nvinfer1::ElementWiseOperation::kMAX);
+  nvinfer1::ILayer *layer = context->network()->addElementWise(*x, *c1, nvinfer1::ElementWiseOperation::kSUM);
   MS_EXCEPTION_IF_NULL(layer);
+  layer = context->network()->addElementWise(*layer->getOutput(0), *c0, nvinfer1::ElementWiseOperation::kMAX);
+  MS_EXCEPTION_IF_NULL(layer);
   layer = context->network()->addElementWise(*layer->getOutput(0), *c2, nvinfer1::ElementWiseOperation::kMIN);
   MS_EXCEPTION_IF_NULL(layer);
-  layer = context->network()->addElementWise(*layer->getOutput(0), *c1, nvinfer1::ElementWiseOperation::kSUM);
-  MS_EXCEPTION_IF_NULL(layer);
   layer = context->network()->addElementWise(*layer->getOutput(0), *c2, nvinfer1::ElementWiseOperation::kDIV);
   MS_EXCEPTION_IF_NULL(layer);
   layer = context->network()->addElementWise(*x, *layer->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
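
Note: both the comment and the layer order change here. The converter now computes y = x * Relu6(x + 3) / 6, the standard h-swish, instead of the previous x * (Relu6(x) + 3) / 6; the element-wise chain reads SUM(x, 3) -> MAX(., 0) -> MIN(., 6) -> DIV(., 6) -> PROD(x, .). A scalar reference sketch to sanity-check the formula (plain C++, no TensorRT needed):

#include <algorithm>
#include <cstdio>

// Reference h-swish matching the corrected comment:
//   y = x * Relu6(x + 3) / 6,  Relu6(v) = min(max(v, 0), 6)
static float HSwish(float x) {
  float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * relu6 / 6.0f;
}

int main() {
  // Saturates to 0 below x = -3 and approaches the identity above x = +3.
  for (float x : {-4.0f, -1.0f, 0.0f, 1.0f, 4.0f}) {
    std::printf("hswish(%+.1f) = %+.4f\n", x, HSwish(x));
  }
  return 0;
}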
@@ -526,7 +526,7 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
   const auto &transpose_a = AnfAlgo::GetNodeAttr<bool>(node, "transpose_a");
   const auto &transpose_b = AnfAlgo::GetNodeAttr<bool>(node, "transpose_b");
-  if (inputs[0].IsTensor() && inputs[1].IsWeight()) {
+  if (inputs[0].IsTensor() && inputs[1].IsWeight() && transpose_a == false && transpose_b == true) {
     // Reshape x from (M, K) to (M, K, 1, 1)
     nvinfer1::Dims unsqueeze_dims = inputs[0].tensor()->getDimensions();
     for (size_t i = 0; i < 2; i++) {

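Note: the tightened condition guards the fully-connected shortcut. Reshaping x from (M, K) to (M, K, 1, 1) and feeding it to a fully-connected layer only reproduces MatMul when transpose_a is false and transpose_b is true, because TensorRT's fully-connected layer computes y = x * w^T with the weight stored as (N, K). A naive reference multiply illustrating that shape contract (sizes are illustrative):

#include <cstdio>

int main() {
  // MatMul(x, w, transpose_b = true) with x: (M, K) and w stored as (N, K),
  // i.e. y[m][n] = sum_k x[m][k] * w[n][k] -- exactly what a fully-connected
  // layer computes once x is unsqueezed to (M, K, 1, 1).
  const int M = 2, K = 3, N = 2;
  float x[M][K] = {{1, 2, 3}, {4, 5, 6}};
  float w[N][K] = {{1, 0, 1}, {0, 1, 0}};  // weight laid out (N, K)
  float y[M][N] = {};
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      for (int k = 0; k < K; ++k) {
        y[m][n] += x[m][k] * w[n][k];  // x . w^T
      }
    }
  }
  std::printf("y = [[%g, %g], [%g, %g]]\n", y[0][0], y[0][1], y[1][0], y[1][1]);
  return 0;
}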
View File

@@ -46,6 +46,9 @@
 #include "backend/optimizer/gpu/add_relu_v2_fusion.h"
 #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
 #include "backend/optimizer/gpu/matmul_biasadd_fusion.h"
+#if ENABLE_GPU_INFER
+#include "backend/optimizer/trt_pass/graph_converter.h"
+#endif
 #include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
@@ -134,6 +137,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
+#if ENABLE_GPU_INFER
+  pm->AddPass(std::make_shared<opt::GraphConverter>());
+#endif
   pm->AddPass(std::make_shared<opt::MatMulBiasAddFusion>());
   pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
   pm->AddPass(std::make_shared<opt::AdamFusion>());
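
Note: GraphConverter is registered ahead of the fusion passes, so when ENABLE_GPU_INFER is set the TRT pass sees the graph first and can claim whole subgraphs before op fusion rewrites them. A minimal sketch of this compile-time-guarded, order-sensitive registration pattern; the Pass struct and pass names below are illustrative stand-ins, not the MindSpore classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative pass: each pass runs in the order it was registered.
struct Pass {
  std::string name;
  void Run() const { std::cout << "run pass: " << name << "\n"; }
};

int main() {
  std::vector<std::unique_ptr<Pass>> pm;
#if ENABLE_GPU_INFER  // same build flag as above; an undefined macro evaluates to 0 in #if
  pm.push_back(std::make_unique<Pass>(Pass{"GraphConverter"}));  // TRT pass goes first
#endif
  pm.push_back(std::make_unique<Pass>(Pass{"MatMulBiasAddFusion"}));
  pm.push_back(std::make_unique<Pass>(Pass{"AdamWeightDecayFusion"}));
  for (const auto &pass : pm) {
    pass->Run();
  }
  return 0;
}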