forked from mindspore-Ecosystem/mindspore
refctor arithmetic simplify
This commit is contained in:
parent
b86ce1e832
commit
e2cfc516eb
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,16 +17,28 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ARITHMETIC_SIMPLIFY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PatternTree;
|
||||
using PatternTreePtr = std::shared_ptr<PatternTree>;
|
||||
class ArithmeticSimplify : public Pass {
|
||||
public:
|
||||
ArithmeticSimplify() : Pass("arithmetic_simplify") {}
|
||||
~ArithmeticSimplify() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool DoArithmeticTrans(const graphkernel::LiteGraphPtr &litegraph);
|
||||
bool DoConstantFold(const graphkernel::LiteGraphPtr &litegraph);
|
||||
std::unordered_map<std::string, std::vector<PatternTreePtr>> expressions_map_;
|
||||
};
|
||||
using ArithmeticSimplifyPtr = std::shared_ptr<ArithmeticSimplify>;
|
||||
} // namespace opt
|
||||
|
|
|
@ -115,12 +115,10 @@ NodePtr LiteGraph::GraphBuilder::Op(const std::string &op, const NodeBase &basei
|
|||
PrimOpPtr LiteGraph::GraphBuilder::CreateOp(const std::string &op, const std::string &node_name) {
|
||||
static std::map<std::string, std::function<PrimOpPtr(const std::string &, const std::string &)>> creators;
|
||||
if (creators.empty()) {
|
||||
creators = {
|
||||
{"Add", Elemwise},
|
||||
{"Sub", Elemwise},
|
||||
{"ReduceSum", Reduce},
|
||||
{"Conv2D", Conv2d},
|
||||
};
|
||||
creators = {{"Add", Elemwise}, {"Sub", Elemwise}, {"RealDiv", Elemwise}, {"Mul", Elemwise},
|
||||
{"Log", Elemwise}, {"Pow", Elemwise}, {"Sqrt", Elemwise}, {"Rsqrt", Elemwise},
|
||||
{"Rsqrt", Elemwise}, {"Neg", Elemwise}, {"Reciprocal", Elemwise}, {"Abs", Elemwise},
|
||||
{"ReduceSum", Reduce}, {"ReduceMax", Reduce}, {"ReduceMin", Reduce}, {"Conv2D", Conv2d}};
|
||||
}
|
||||
auto iter = creators.find(op);
|
||||
auto creator = (iter == creators.end() ? Opaque : iter->second);
|
||||
|
|
|
@ -15,11 +15,15 @@
|
|||
*/
|
||||
#include "backend/optimizer/graph_kernel/model/op_node.h"
|
||||
|
||||
#include <math.h>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/model/node.h"
|
||||
|
||||
|
@ -60,6 +64,87 @@ void PrimOp::Dump(std::ostringstream &os) const {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename TM, typename TD>
|
||||
tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &op, TypeId tid) {
|
||||
std::vector<TM> inputs_tm;
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_tm), [](const NodePtr &i) {
|
||||
return *static_cast<TM *>(std::static_pointer_cast<graphkernel::ConstTensorNode>(i)->data()->data_c());
|
||||
});
|
||||
|
||||
std::unordered_map<std::string, std::function<TM(const std::vector<TM> &)>> func_map;
|
||||
func_map["Add"] = [](const std::vector<TM> &n) { return n[0] + n[1]; };
|
||||
func_map["Sub"] = [](const std::vector<TM> &n) { return n[0] - n[1]; };
|
||||
func_map["Mul"] = [](const std::vector<TM> &n) { return n[0] * n[1]; };
|
||||
func_map["RealDiv"] = [](const std::vector<TM> &n) { return n[0] / n[1]; };
|
||||
func_map["Neg"] = [](const std::vector<TM> &n) { return -n[0]; };
|
||||
func_map["Reciprocal"] = [](const std::vector<TM> &n) { return TM(1) / n[0]; };
|
||||
func_map["Log"] = [](const std::vector<TM> &n) { return log(n[0]); };
|
||||
func_map["Exp"] = [](const std::vector<TM> &n) { return exp(n[0]); };
|
||||
func_map["Abs"] = [](const std::vector<TM> &n) { return n[0] < TM(0) ? (-n[0]) : n[0]; };
|
||||
func_map["Sqrt"] = [](const std::vector<TM> &n) { return sqrt(n[0]); };
|
||||
func_map["Rsqrt"] = [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); };
|
||||
|
||||
if (func_map.find(op) == func_map.end()) return nullptr;
|
||||
return std::make_shared<tensor::Tensor>(static_cast<TD>(func_map[op](inputs_tm)), TypeIdToType(tid));
|
||||
}
|
||||
|
||||
NodePtr PrimOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op) {
|
||||
for (auto i : inputs) {
|
||||
if (i->NodeType() != NType::Value) return nullptr;
|
||||
}
|
||||
TypeId output_type = InferType(inputs, attrs);
|
||||
tensor::TensorPtr res = nullptr;
|
||||
switch (output_type) {
|
||||
case TypeId::kNumberTypeUInt8: {
|
||||
res = CalcByOperator<uint8_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt8: {
|
||||
res = CalcByOperator<int8_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt16: {
|
||||
res = CalcByOperator<int16_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
res = CalcByOperator<int32_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
res = CalcByOperator<int64_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt16: {
|
||||
res = CalcByOperator<uint16_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt32: {
|
||||
res = CalcByOperator<uint32_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt64: {
|
||||
res = CalcByOperator<uint64_t, int64_t>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat16: {
|
||||
res = CalcByOperator<float16, double>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
res = CalcByOperator<float, double>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat64: {
|
||||
res = CalcByOperator<double, double>(inputs, op, output_type);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
|
||||
}
|
||||
|
||||
void ElemwiseOp::Infer(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
PrimOp::Infer(inputs, attrs);
|
||||
auto IsBroadcast = [this](const NodePtrList &inputs) -> bool {
|
||||
|
|
|
@ -45,6 +45,7 @@ class PrimOp : public Node {
|
|||
|
||||
const std::string &op() const { return op_; }
|
||||
ComputeType compute_type() const { return compute_type_; }
|
||||
virtual NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs, const std::string &op);
|
||||
|
||||
protected:
|
||||
std::string op_;
|
||||
|
|
|
@ -198,6 +198,8 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
|
|||
reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
|
||||
reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
|
||||
reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
|
||||
reg.AddFlag("enable_simplify_exprs_only", &enable_simplify_exprs_only);
|
||||
reg.AddFlag("disable_simplify_exprs", &disable_simplify_exprs);
|
||||
reg.AddFlag("enable_pass", &enable_pass);
|
||||
reg.AddFlag("disable_pass", &disable_pass);
|
||||
}
|
||||
|
@ -221,6 +223,8 @@ std::string GraphKernelFlags::DumpAllFlags() const {
|
|||
json["enable_cluster_ops"] = enable_cluster_ops;
|
||||
json["enable_cluster_ops_only"] = enable_cluster_ops_only;
|
||||
json["disable_cluster_ops"] = disable_cluster_ops;
|
||||
json["enable_simplify_exprs_only"] = enable_simplify_exprs_only;
|
||||
json["disable_simplify_exprs"] = disable_simplify_exprs;
|
||||
json["enable_pass"] = enable_pass;
|
||||
json["disable_pass"] = disable_pass;
|
||||
|
||||
|
|
|
@ -140,6 +140,18 @@ class GraphKernelFlags {
|
|||
*/
|
||||
std::vector<std::string> disable_cluster_ops;
|
||||
|
||||
/**
|
||||
* Arithmetic simplify expressions to be enabled (case sensitive).
|
||||
* The default list will be overwritten by this list.
|
||||
* Note that "disable_simplify_exprs" will be ignored if this flag is set.
|
||||
*/
|
||||
std::vector<std::string> enable_simplify_exprs_only;
|
||||
|
||||
/**
|
||||
* Arithmetic simplify expressions to be disabled (case sensitive).
|
||||
*/
|
||||
std::vector<std::string> disable_simplify_exprs;
|
||||
|
||||
/**
|
||||
* Passes to be enabled.
|
||||
* By default, the passes is controlled by "opt_level" and target device,
|
||||
|
|
|
@ -32,6 +32,7 @@ class Net(Cell):
|
|||
self.pow = P.Pow()
|
||||
self.neg = P.Neg()
|
||||
self.reducemin = P.ReduceMin()
|
||||
self.reducesum = P.ReduceSum(keep_dims=True)
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x, y):
|
||||
|
@ -44,7 +45,9 @@ class Net(Cell):
|
|||
neg_res = self.neg(self.neg(pow_res))
|
||||
add_res3 = self.add(neg_res, div_res)
|
||||
resh_res = self.reshape(add_res3, (2, 12, 3))
|
||||
return self.reducemin(resh_res, 1)
|
||||
neg_res = self.neg(resh_res)
|
||||
red_res = self.reducesum(neg_res, 0)
|
||||
return self.reducemin(self.reducemin(red_res, 1), 1)
|
||||
|
||||
|
||||
def test_basic():
|
||||
|
@ -58,7 +61,9 @@ def test_basic():
|
|||
pow_res = input_y * input_y
|
||||
neg_res = pow_res
|
||||
add_res3 = neg_res + div_res
|
||||
expect = np.min(add_res3, (1, 2))
|
||||
neg_res = np.negative(add_res3)
|
||||
red_res = np.sum(neg_res, axis=0, keepdims=True)
|
||||
expect = np.min(red_res, (1, 2, 3))
|
||||
|
||||
net = Net()
|
||||
result = net(Tensor(input_x), Tensor(input_y))
|
||||
|
|
Loading…
Reference in New Issue