refactor arithmetic simplify

Yang Jiao 2021-07-27 11:29:26 +08:00
parent b86ce1e832
commit e2cfc516eb
8 changed files with 711 additions and 842 deletions

mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.h

@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,16 +17,28 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
#include <memory>
#include <vector>
#include <unordered_map>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "ir/func_graph.h"
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
namespace mindspore {
namespace opt {
class PatternTree;
using PatternTreePtr = std::shared_ptr<PatternTree>;
class ArithmeticSimplify : public Pass {
public:
ArithmeticSimplify() : Pass("arithmetic_simplify") {}
~ArithmeticSimplify() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool DoArithmeticTrans(const graphkernel::LiteGraphPtr &litegraph);
bool DoConstantFold(const graphkernel::LiteGraphPtr &litegraph);
std::unordered_map<std::string, std::vector<PatternTreePtr>> expressions_map_;
};
using ArithmeticSimplifyPtr = std::shared_ptr<ArithmeticSimplify>;
} // namespace opt
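
The header shows the shape of the refactor: both helpers operate on a graphkernel::LiteGraphPtr, DoArithmeticTrans applying pattern rewrites and DoConstantFold folding constant subgraphs, with the rewrite rules held in expressions_map_, a string-keyed bucket of PatternTreePtr lists. A minimal standalone sketch of such a bucketed rule table, with PatternTree reduced to a stub and the keys and expression strings invented for illustration (the real matching interface is not part of this diff):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Stub standing in for the real PatternTree; only the stored expression matters here.
struct PatternTree {
  explicit PatternTree(std::string expr) : expr_(std::move(expr)) {}
  std::string expr_;
};
using PatternTreePtr = std::shared_ptr<PatternTree>;

int main() {
  // Rules are bucketed under a string key, so a candidate node only has to
  // try the handful of patterns filed under its own key.
  std::unordered_map<std::string, std::vector<PatternTreePtr>> expressions_map;
  expressions_map["Neg"].push_back(std::make_shared<PatternTree>("Neg(Neg(A))=A"));
  expressions_map["Add"].push_back(std::make_shared<PatternTree>("A+0=A"));

  auto it = expressions_map.find("Neg");
  if (it != expressions_map.end()) {
    std::cout << it->second.size() << " candidate rule(s) for Neg\n";  // 1
  }
  return 0;
}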

mindspore/ccsrc/backend/optimizer/graph_kernel/model/lite_graph.cc

@@ -115,12 +115,10 @@ NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &basei
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
if (creators.empty()) {
-    creators = {
-      {"Add", Elemwise},
-      {"Sub", Elemwise},
-      {"ReduceSum", Reduce},
-      {"Conv2D", Conv2d},
-    };
+    creators = {{"Add", Elemwise},  {"Sub", Elemwise},  {"RealDiv", Elemwise},    {"Mul", Elemwise},
+                {"Log", Elemwise},  {"Pow", Elemwise},  {"Sqrt", Elemwise},       {"Rsqrt", Elemwise},
+                {"Exp", Elemwise},  {"Neg", Elemwise},  {"Reciprocal", Elemwise}, {"Abs", Elemwise},
+                {"ReduceSum", Reduce}, {"ReduceMax", Reduce}, {"ReduceMin", Reduce}, {"Conv2D", Conv2d}};
}
auto iter = creators.find(op);
auto creator = (iter == creators.end() ? Opaque : iter->second);
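
CreateOp is a factory dispatch: a lazily initialized name-to-creator map, with Opaque as the fallback for any operator not listed. A self-contained sketch of the same registry-with-fallback idiom (the factories here just return strings for demonstration):

#include <functional>
#include <iostream>
#include <map>
#include <string>

using Factory = std::function<std::string(const std::string &)>;

int main() {
  // Known ops get specialized creators; unknown ops fall back to Opaque.
  std::map<std::string, Factory> creators = {
      {"Add", [](const std::string &n) { return "Elemwise:" + n; }},
      {"ReduceSum", [](const std::string &n) { return "Reduce:" + n; }},
  };
  Factory opaque = [](const std::string &n) { return "Opaque:" + n; };

  for (const std::string op : {"Add", "Tanh"}) {
    auto iter = creators.find(op);
    auto creator = (iter == creators.end() ? opaque : iter->second);
    std::cout << creator(op) << "\n";  // Elemwise:Add, then Opaque:Tanh
  }
  return 0;
}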

mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc

@@ -15,11 +15,15 @@
*/
#include "backend/optimizer/graph_kernel/model/op_node.h"
#include <math.h>
#include <sstream>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "backend/optimizer/graph_kernel/model/node.h"
@@ -60,6 +64,87 @@ void PrimOp::Dump(std::ostringstream &os) const {
}
}
template <typename TM, typename TD>
tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &op, TypeId tid) {
std::vector<TM> inputs_tm;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_tm), [](const NodePtr &i) {
return *static_cast<TM *>(std::static_pointer_cast<graphkernel::ConstTensorNode>(i)->data()->data_c());
});
std::unordered_map<std::string, std::function<TM(const std::vector<TM> &)>> func_map;
func_map["Add"] = [](const std::vector<TM> &n) { return n[0] + n[1]; };
func_map["Sub"] = [](const std::vector<TM> &n) { return n[0] - n[1]; };
func_map["Mul"] = [](const std::vector<TM> &n) { return n[0] * n[1]; };
func_map["RealDiv"] = [](const std::vector<TM> &n) { return n[0] / n[1]; };
func_map["Neg"] = [](const std::vector<TM> &n) { return -n[0]; };
func_map["Reciprocal"] = [](const std::vector<TM> &n) { return TM(1) / n[0]; };
func_map["Log"] = [](const std::vector<TM> &n) { return log(n[0]); };
func_map["Exp"] = [](const std::vector<TM> &n) { return exp(n[0]); };
func_map["Abs"] = [](const std::vector<TM> &n) { return n[0] < TM(0) ? (-n[0]) : n[0]; };
func_map["Sqrt"] = [](const std::vector<TM> &n) { return sqrt(n[0]); };
func_map["Rsqrt"] = [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); };
if (func_map.find(op) == func_map.end()) return nullptr;
return std::make_shared<tensor::Tensor>(static_cast<TD>(func_map[op](inputs_tm)), TypeIdToType(tid));
}
NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op) {
for (auto i : inputs) {
if (i->NodeType() != NType::Value) return nullptr;
}
TypeId output_type = InferType(inputs, attrs);
tensor::TensorPtr res = nullptr;
switch (output_type) {
case TypeId::kNumberTypeUInt8: {
res = CalcByOperator<uint8_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeInt8: {
res = CalcByOperator<int8_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeInt16: {
res = CalcByOperator<int16_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeInt32: {
res = CalcByOperator<int32_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeInt64: {
res = CalcByOperator<int64_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeUInt16: {
res = CalcByOperator<uint16_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeUInt32: {
res = CalcByOperator<uint32_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeUInt64: {
res = CalcByOperator<uint64_t, int64_t>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeFloat16: {
res = CalcByOperator<float16, double>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeFloat32: {
res = CalcByOperator<float, double>(inputs, op, output_type);
break;
}
case TypeId::kNumberTypeFloat64: {
res = CalcByOperator<double, double>(inputs, op, output_type);
break;
}
default:
return nullptr;
}
return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
}
void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
PrimOp::Infer(inputs, attrs);
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
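
In CalcByOperator the two template parameters split the work: TM is the type the scalar arithmetic runs in, TD the wider type the folded result is stored as, which is why every integral case in InferValue widens to int64_t and every floating case to double. A reduced sketch of that fold, with plain scalars standing in for tensor::Tensor:

#include <cmath>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// TM: computation type, TD: (wider) storage type, mirroring CalcByOperator.
template <typename TM, typename TD>
TD FoldScalar(const std::string &op, const std::vector<TM> &n) {
  std::unordered_map<std::string, std::function<TM(const std::vector<TM> &)>> func_map;
  func_map["Add"] = [](const std::vector<TM> &v) { return v[0] + v[1]; };
  func_map["Rsqrt"] = [](const std::vector<TM> &v) { return TM(1) / std::sqrt(v[0]); };
  return static_cast<TD>(func_map.at(op)(n));
}

int main() {
  // Computed in float, stored as double -- like the <float, double> instantiation.
  std::cout << FoldScalar<float, double>("Add", {1.5f, 2.5f}) << "\n";  // 4
  std::cout << FoldScalar<float, double>("Rsqrt", {4.0f}) << "\n";      // 0.5
  return 0;
}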

mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.h

@@ -45,6 +45,7 @@ class PrimOp : public Node {
const std::string &op() const { return op_; }
ComputeType compute_type() const { return compute_type_; }
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
protected:
std::string op_;

graph_kernel_flags.cc

@@ -198,6 +198,8 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
reg.AddFlag("enable_simplify_exprs_only", &enable_simplify_exprs_only);
reg.AddFlag("disable_simplify_exprs", &disable_simplify_exprs);
reg.AddFlag("enable_pass", &enable_pass);
reg.AddFlag("disable_pass", &disable_pass);
}
@@ -221,6 +223,8 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["enable_cluster_ops"] = enable_cluster_ops;
json["enable_cluster_ops_only"] = enable_cluster_ops_only;
json["disable_cluster_ops"] = disable_cluster_ops;
json["enable_simplify_exprs_only"] = enable_simplify_exprs_only;
json["disable_simplify_exprs"] = disable_simplify_exprs;
json["enable_pass"] = enable_pass;
json["disable_pass"] = disable_pass;
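
Both new flags ride the existing machinery: RegisterFlags binds each flag name to a member field via reg.AddFlag, and DumpAllFlags echoes them back as JSON. A generic sketch of that bind-by-name pattern, parsing a comma-separated value into a string-list member (this Registrar and the expression names are stand-ins, not MindSpore's real types):

#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Stand-in registrar: if a flag with this name was passed, split its
// comma-separated value into the bound member field.
struct Registrar {
  explicit Registrar(const std::map<std::string, std::string> &m) : flags(m) {}
  void AddFlag(const std::string &name, std::vector<std::string> *field) {
    auto it = flags.find(name);
    if (it == flags.end()) return;
    std::stringstream ss(it->second);
    std::string item;
    while (std::getline(ss, item, ',')) field->push_back(item);
  }
  const std::map<std::string, std::string> &flags;
};

int main() {
  std::map<std::string, std::string> flag_map{{"disable_simplify_exprs", "expr_a,expr_b"}};
  std::vector<std::string> disable_simplify_exprs;
  Registrar reg(flag_map);
  reg.AddFlag("disable_simplify_exprs", &disable_simplify_exprs);
  std::cout << disable_simplify_exprs.size() << " expression(s) disabled\n";  // 2
  return 0;
}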

graph_kernel_flags.h

@@ -140,6 +140,18 @@ class GraphKernelFlags {
*/
std::vector<std::string> disable_cluster_ops;
/**
* Arithmetic simplification expressions to be enabled (case sensitive).
* The default list will be overwritten by this list.
* Note that "disable_simplify_exprs" will be ignored if this flag is set.
*/
std::vector<std::string> enable_simplify_exprs_only;
/**
* Arithmetic simplification expressions to be disabled (case sensitive).
*/
std::vector<std::string> disable_simplify_exprs;
/**
* Passes to be enabled.
* By default, the passes are controlled by "opt_level" and the target device,
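
The comments above pin down the precedence between the two lists: a non-empty enable_simplify_exprs_only decides membership on its own, and disable_simplify_exprs is ignored. A small sketch of that rule (ExprEnabled and the expression names are illustrative, not MindSpore API):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Returns whether a named simplify expression should run, given the two flags.
// If the enable-only list is non-empty, it wins and the disable list is ignored.
bool ExprEnabled(const std::string &name,
                 const std::vector<std::string> &enable_only,
                 const std::vector<std::string> &disable) {
  if (!enable_only.empty()) {
    return std::find(enable_only.begin(), enable_only.end(), name) != enable_only.end();
  }
  return std::find(disable.begin(), disable.end(), name) == disable.end();
}

int main() {
  std::vector<std::string> enable_only{"expr_a"};
  std::vector<std::string> disable{"expr_a"};
  // enable_only is set, so the disable entry for "expr_a" has no effect.
  std::cout << std::boolalpha << ExprEnabled("expr_a", enable_only, disable) << "\n";  // true
  return 0;
}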

tests/st/ops/graph_kernel/test_simplify.py

@@ -32,6 +32,7 @@ class Net(Cell):
self.pow = P.Pow()
self.neg = P.Neg()
self.reducemin = P.ReduceMin()
self.reducesum = P.ReduceSum(keep_dims=True)
self.reshape = P.Reshape()
def construct(self, x, y):
@@ -44,7 +45,9 @@ class Net(Cell):
neg_res = self.neg(self.neg(pow_res))
add_res3 = self.add(neg_res, div_res)
resh_res = self.reshape(add_res3, (2, 12, 3))
- return self.reducemin(resh_res, 1)
+ neg_res = self.neg(resh_res)
+ red_res = self.reducesum(neg_res, 0)
+ return self.reducemin(self.reducemin(red_res, 1), 1)
def test_basic():
@@ -58,7 +61,9 @@ def test_basic():
pow_res = input_y * input_y
neg_res = pow_res
add_res3 = neg_res + div_res
- expect = np.min(add_res3, (1, 2))
+ neg_res = np.negative(add_res3)
+ red_res = np.sum(neg_res, axis=0, keepdims=True)
+ expect = np.min(red_res, (1, 2, 3))
net = Net()
result = net(Tensor(input_x), Tensor(input_y))