!3043 Revert PR 2923

Merge pull request !3043 from BowenK/master
This commit is contained in:
mindspore-ci-bot 2020-07-14 11:46:45 +08:00 committed by Gitee
commit d95a54c321
5 changed files with 754 additions and 477 deletions

View File

@ -17,16 +17,14 @@
#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
#include <memory>
#include <functional>
#include <tuple>
#include <vector>
#include "ir/anf.h"
#include "operator/ops.h"
#include "optimizer/optimizer.h"
namespace mindspore {
///
/// Base class for all recognizable patterns.
/// We implement an Expression Template approach using static polymorphism based on
@ -62,7 +60,7 @@ class PIsEqual {
bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
};
template <typename T = AnfNodePtr>
template <typename T>
class PatternNode : public PBase<PatternNode<T> > {
public:
T GetNode(const AnfNodePtr &node) const {
@ -92,13 +90,12 @@ class PatternNode : public PBase<PatternNode<T> > {
template <typename T, typename T2>
class PBinOperation : public PBase<PBinOperation<T, T2> > {
public:
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false)
: prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {}
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {}
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtr lhs = x_.GetNode(node->func_graph());
AnfNodePtr rhs = y_.GetNode(node->func_graph());
AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs};
AnfNodePtrList list = {prim_->cast<AnfNodePtr>(), lhs, rhs};
return NewCNode(list, node->func_graph());
}
@ -109,14 +106,6 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
if (inputs.size() == 3) {
// Binary Prim assumes only two inputs
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
// If the operation is commutative, then check with inversed operands
if (is_commutative_) {
Reset();
if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) {
return false;
}
return true;
}
return false;
}
return true;
@ -124,6 +113,7 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
}
return false;
}
void Reset() const {
x_.Reset();
y_.Reset();
@ -133,7 +123,6 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
const PrimitivePtr prim_;
typename T::Internal x_;
typename T2::Internal y_;
bool is_commutative_{false};
};
///
@ -225,6 +214,7 @@ class PCNode : public PBase<PCNode<TArgs...> > {
return false;
}
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
@ -265,12 +255,6 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
return false;
}
// If set to true, TryCapture will also try to capture the nodes with the operands reversed (only for the two-input case)
const PPrimitive<TArgs...> &Commutative(const bool &is_commutative = true) const {
is_commutative_ = is_commutative;
return *this;
}
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
@ -279,435 +263,46 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
private:
const PrimitivePtr prim_;
std::tuple<typename TArgs::Internal...> args_;
mutable bool is_commutative_{false};
};
///
/// PConstant class can capture a value node of a specified value (check_value_)
/// or a non-specified one (any_value = true).
/// It can be configured to capture a scalar constant as well (is_scalar_ = true)
///
template <typename T = AnfNodePtr>
class PConstant : public PBase<PConstant<T> > {
public:
explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int check_value = 0,
const bool is_scalar = false)
: as_node_(as_node),
captured_node_(as_node),
any_value_(any_value),
check_value_(check_value),
is_scalar_(is_scalar) {}
// Sets as_node_ as the node received as argument to produce a same-shape node with GetNode
const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const {
as_node_ = node;
changed_shape_ = true;
return *this;
}
/// Sets captured_node_ as the node captured by the Pattern received as argument
/// to produce a new node with its contents when calling GetNode.
const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const {
if (!any_value_) {
MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node.";
}
captured_node_ = pnode.GetNode(captured_node_);
changed_shape_ = true;
return *this;
}
/// Create a new Value Node filled up with check_value.
/// This function must be used immediately before GetNode to avoid replacing the expected result.
const PConstant<T> &NewValue() const {
auto value_node_ = MakeValue(check_value_);
captured_node_ = NewValueNode(value_node_);
is_new_value_node_ = true;
return *this;
}
AnfNodePtr GetNode(const AnfNodePtr &node) const {
// If a NewValueNode was requested (using NewValue function) then return that created node.
if (is_new_value_node_) {
return captured_node_;
}
/// Return a NewTensorFilledWithData if the node was initialized to have a specific value
/// even if it wasn't captured. Usually for zero constants (x - x => zero).
/// If the shape was changed, use the new shape.
if (changed_shape_ || !captured_) {
if (!any_value_) {
return NewTensorFilledWithData(as_node_, check_value_);
}
return NewTensorFilledWithData(as_node_, captured_node_);
}
return captured_node_;
}
bool TryCapture_(const AnfNodePtr &node) const {
if (IsValueNode<Value>(node)) {
// If any_value_ is set don't check for the node's value. Just capture it.
if (any_value_) {
captured_node_ = node;
captured_ = true;
return true;
}
auto value = node->cast<ValueNodePtr>()->value();
if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) {
captured_node_ = node;
captured_ = true;
return true;
}
auto value_node_ = MakeValue(check_value_);
if (*GetValueNode(node) == *value_node_) {
captured_node_ = node;
captured_ = true;
return true;
}
}
return false;
}
void Reset() const {
captured_ = false;
changed_shape_ = false;
is_new_value_node_ = false;
}
// Support function used for checking if all values of a Tensor are equal to `check_value_`
// Supported data types: double, float/float32, int/int32
bool IsTensorConstant(const ValuePtr &value) const {
if (!value->isa<tensor::Tensor>()) {
return false;
}
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
return false;
}
}
return true;
} else if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
return false;
}
}
return true;
} else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] != check_value_) {
return false;
}
}
return true;
}
// Input Data Type is not supported
return false;
}
bool IsTensorScalarConstant(const ValuePtr &value) const {
if (!value->isa<tensor::Tensor>()) {
return false;
}
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
return false;
}
return IsTensorConstant(value);
}
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const {
if (!node->isa<ValueNode>()) {
return nullptr;
}
auto value = node->cast<ValueNodePtr>()->value();
if (!value->isa<tensor::Tensor>()) {
return nullptr;
}
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
return tensor_ptr->data_c();
}
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const {
if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
if (x == nullptr) {
std::memset(data, 0, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
// x is not nullptr
if (x->isa<CNode>()) {
if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
std::vector<int> x_shape = x_abstract->shape()->shape();
if (x_shape != tensor_shape) {
return nullptr;
}
return x;
}
if (!x->isa<ValueNode>()) {
return nullptr;
}
auto x_value = x->cast<ValueNodePtr>()->value();
if (!x_value->isa<tensor::Tensor>()) {
return nullptr;
}
auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
return nullptr;
}
char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
if (x_tensor_ptr->DataSize() == 1) {
for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr));
}
} else {
memcpy(data, source_data, mem_size);
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const {
if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
std::memset(data, value, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename TM>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size) const {
TM *data_1 = reinterpret_cast<TM *>(in_data_1);
TM *data_2 = reinterpret_cast<TM *>(in_data_2);
TM *data_out = new TM[out_data_size];
if (in_data_1_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] = data_1[0];
}
} else {
for (int i = 0; i < out_data_size; i++) {
data_out[i] = data_1[i];
}
}
if (in_data_2_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[0];
}
} else {
if (in_data_2_size < out_data_size) {
MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size.";
}
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[i];
}
}
*out_data = reinterpret_cast<void *>(data_out);
return;
}
AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const {
AnfNodePtr vnode_1 = this->GetNode(captured_node_);
AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_);
return MulConstantTensors(vnode_1, vnode_2, node_3);
}
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const {
if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
(vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
return nullptr;
}
auto value_1 = GetValueNode(vnode_1);
auto value_2 = GetValueNode(vnode_2);
if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
return nullptr;
}
auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
(tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
return nullptr;
}
std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
return nullptr;
}
if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
return nullptr;
}
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
int ret = 0;
void *data_out = nullptr;
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<float *>(data_out);
} else {
if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<double *>(data_out);
} else {
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
ret = memcpy_s(data, mem_size, data_out, mem_size);
delete[] reinterpret_cast<int *>(data_out);
} else {
// Un-support data types
return nullptr;
}
}
}
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size"
<< new_tensor_ptr->DataSize();
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
using Internal = const PConstant<T> &;
protected:
mutable AnfNodePtr as_node_;
mutable AnfNodePtr captured_node_;
bool any_value_{true};
int check_value_{0};
bool is_scalar_{false};
mutable bool is_new_value_node_{false};
mutable bool captured_{false};
mutable bool changed_shape_{false};
};
// Macro for binary operation functions
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \
template <typename T, typename T2> \
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \
template <typename T, typename T2> \
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \
}
// Arithmetic operations
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true);
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);
// Macros for match and replace
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
if ((CaptureNode).TryCapture(OrigNode)) { \
auto rep = (ReplaceWith).GetNode(OrigNode); \
if (rep != nullptr) { \
return rep; \
} \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
auto rep = (ReplaceWith).GetNode(OrigNode); \
if (rep != nullptr) { \
return rep; \
} \
return (ReplaceWith).GetNode(OrigNode); \
}
#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
if ((CaptureNode).TryCapture(OrigNode)) { \
if ((Condition)) { \
auto rep = (ReplaceWith).GetNode(OrigNode); \
if (rep != nullptr) { \
return (ReplaceWith).GetNode(OrigNode); \
} \
} else { \
auto rep = (ElseNode).GetNode(OrigNode); \
if (rep != nullptr) { \
return (ElseNode).GetNode(OrigNode); \
} \
return (ReplaceWith).GetNode(OrigNode); \
} \
return (ElseNode).GetNode(OrigNode); \
}
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
if ((CaptureNode).TryCapture(OrigNode)) { \
auto rep = (Lambda)(); \
if (rep != nullptr) { \
return rep; \
} \
return (Lambda)(); \
}
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
auto rep = (Lambda)(); \
if (rep != nullptr) { \
return rep; \
} \
return (Lambda)(); \
}
} // namespace mindspore

View File

@ -14,67 +14,542 @@
* limitations under the License.
*/
#include <algorithm>
#include <memory>
#include <vector>
#include <functional>
#include "optimizer/irpass/arithmetic_simplify.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y, z, xs;
PConstant one_(node, false, 1);
PConstant one_scalar_(node, false, 1, true);
PConstant zero_(node, false, 0);
PConstant zero_scalar_(node, false, 0, true);
PConstant const_(node);
PConstant const_2(node);
PConstant any_const(node);
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
Reset();
AnfVisitor::Match(prim::kPrimScalarMul)(node);
MATCH_REPLACE(node, x + zero_, x); // Add by zero
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero
MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), x.CheckFunc(IsVNode, node)); // Multiply by one
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_.NewValue()); // Scalar Mul by zero
MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_.NewValue()); // Scalar Mul by zero
// Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
// ConstantDuplicateMul
auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node));
auto mul_node = node->cast<CNodePtr>()->inputs()[0];
if (new_mul_tensor == nullptr) {
auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph());
return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph());
}
return NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph());
};
MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda);
if (node->func_graph() == nullptr) {
return nullptr;
if (is_zero_) {
return NewValueNode(zero_);
}
if (is_one_) {
return x_;
}
// OptUpdateZeroTensor
MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs),
PPrimitive(prim::kPrimMakeTuple, z, y));
// PowerOneEliminate
MATCH_REPLACE(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x);
return nullptr;
}
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y;
PConstant zero_(node, false, 0);
void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) {
if (is_one_ || node->isa<CNode>()) {
x_ = node;
return;
}
MATCH_REPLACE(node, x * zero_, zero_); // Multiply by zero
MATCH_REPLACE(node, x * PPrimitive(prim::kPrimZerosLike, y), zero_); // Multiply by zero
AnfVisitor::Visit(node);
if (!is_one_) {
x_ = node;
}
}
// Classifies a ValueNode operand: latches is_zero_/is_one_ when its value
// equals the zero_/one_ member constants (presumably scalar 0 and 1 set up by
// the pass — confirm against the class declaration, which is not visible here).
void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) {
  auto value = vnode->value();
  if (*value == *zero_) {
    is_zero_ = true;
  } else if (*value == *one_) {
    is_one_ = true;
  }
}
// Clears all per-node state so the visitor can be reused on the next node.
void MultiplyByZeroOrOne::Reset() {
  x_ = nullptr;
  is_one_ = false;
  is_zero_ = false;
}
// Support function used for checking if all values of a Tensor are equal to `check_value_`.
// Supported data types: double/float64, float/float32, int/int32.
// Non-Tensor values and any other dtype report false.
bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) {
  if (!value->isa<tensor::Tensor>()) {
    return false;
  }
  auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
  TypeId tensor_type = tensor_ptr->Dtype()->type_id();
  if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
    float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
    for (int i = 0; i < tensor_ptr->DataSize(); i++) {
      // Epsilon-tolerance compare: avoids false negatives from float rounding.
      if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
        return false;
      }
    }
    return true;
  } else if (tensor_type == TypeId::kNumberTypeFloat64) {
    double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
    for (int i = 0; i < tensor_ptr->DataSize(); i++) {
      if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
        return false;
      }
    }
    return true;
  } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
    int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
    for (int i = 0; i < tensor_ptr->DataSize(); i++) {
      if (data2[i] != check_value_) {
        return false;
      }
    }
    return true;
  }
  // Input data type is not supported.
  return false;
}
bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) {
if (!value->isa<tensor::Tensor>()) {
return false;
}
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
return false;
}
return IsTensorConstant(value);
}
// Returns the raw data pointer of the tensor held by `node`, or nullptr when
// `node` is not a ValueNode wrapping a tensor::Tensor.
// NOTE(review): the `writable` flag is never read in this body — confirm
// whether it was meant to select a const vs. mutable view of the data.
void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) {
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value = node->cast<ValueNodePtr>()->value();
  if (!value->isa<tensor::Tensor>()) {
    return nullptr;
  }
  tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
  return tensor_ptr->data_c();
}
// Make a new tensor (when possible) with the same shape as of `node`.
// If x is nullptr then fill the new tensor with "0".
// If x is a tensor with empty shape then fill the new tensor with the single value of x.
// If x is a tensor with same shape as `node` then return x as result.
// Returns nullptr when `node` (or a CNode x) lacks an AbstractTensor, or when
// shapes/sizes are incompatible.
AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) {
  if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
    return nullptr;
  }
  // Allocate an output tensor with node's element type and shape.
  auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
  TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
  std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
  auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
  size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
  char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
  if (x == nullptr) {
    // No source: produce an all-zero tensor.
    std::memset(data, 0, mem_size);
    auto new_vnode = NewValueNode(new_tensor_ptr);
    new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
    return new_vnode;
  }
  // x is not nullptr
  if (x->isa<CNode>()) {
    // A computed node can only be reused as-is, and only if shapes match exactly.
    if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
      return nullptr;
    }
    auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
    std::vector<int> x_shape = x_abstract->shape()->shape();
    if (x_shape != tensor_shape) {
      return nullptr;
    }
    return x;
  }
  if (!x->isa<ValueNode>()) {
    return nullptr;
  }
  auto x_value = x->cast<ValueNodePtr>()->value();
  if (!x_value->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
  // Only a scalar (size 1) or an exactly same-sized tensor can fill the output.
  if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
    return nullptr;
  }
  char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
  if (x_tensor_ptr->DataSize() == 1) {
    // Broadcast the single element across every output slot.
    for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
      memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr));
    }
  } else {
    memcpy(data, source_data, mem_size);
  }
  auto new_vnode = NewValueNode(new_tensor_ptr);
  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  return new_vnode;
}
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// If either operand of a Mul is a zero tensor, replace the whole node with a
// zero tensor of the node's own shape/type.
AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  Reset();
  AnfVisitor::Match(prim::kPrimMul)(node);
  if (is_zero_) {
    // Don't rewrite across graphs: the surviving operand must live in the
    // same func_graph as the node being replaced.
    if (x_->func_graph() != node->func_graph()) {
      return nullptr;
    }
    return NewTensorFilledWithData(node);
  }
  return nullptr;
}
// Classifies one Mul operand: latches is_zero_ when the operand is a
// {kPrimZerosLike, ...} CNode or an all-zero constant tensor; otherwise
// records the operand in x_ as the surviving factor.
void TensorMultiplyByZero::Visit(const AnfNodePtr &node) {
  if (is_zero_) {
    // A zero factor was already found; this operand is the surviving X.
    x_ = node;
    return;
  }
  if (IsParam(node)) {
    x_ = node;
    return;
  }
  if (IsCNode(node)) {
    CNodePtr cnode = node->cast<CNodePtr>();
    if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
      is_zero_ = true;
      return;
    }
    x_ = node;
    return;
  }
  // Bug fix: guard the cast. A node that is neither Param, CNode nor
  // ValueNode previously yielded a null ValueNodePtr and crashed on ->value().
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    x_ = node;
    return;
  }
  auto value = value_node->value();
  if (CheckTensorConstant(0).IsTensorConstant(value)) {
    is_zero_ = true;
    return;
  }
  x_ = node;
}
// ValueNode overload: latch is_zero_ for an all-zero constant tensor,
// otherwise keep the node as the surviving factor.
void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) {
  auto value = vnode->value();
  if (CheckTensorConstant(0).IsTensorConstant(value)) {
    is_zero_ = true;
    return;
  }
  x_ = vnode;
}
// Clears per-node state so the visitor can be reused on the next candidate.
void TensorMultiplyByZero::Reset() {
  x_ = nullptr;
  is_zero_ = false;
}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
// If either operand of a Mul is an all-one tensor, replace the node with the
// other operand (reshaped/filled to the node's shape when necessary).
AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  Reset();
  AnfVisitor::Match(prim::kPrimMul)(node);
  if (is_one_) {
    return NewTensorFilledWithData(node, x_);
  }
  return nullptr;
}
// Classifies one Mul operand: latches is_one_ when the operand is an all-one
// constant tensor; otherwise records it in x_ as the surviving factor.
void TensorMultiplyByOne::Visit(const AnfNodePtr &node) {
  if (is_one_) {
    // A one factor was already found; this operand is the surviving X.
    x_ = node;
    return;
  }
  if (IsParam(node) || IsCNode(node)) {
    x_ = node;
    return;
  }
  // Bug fix: guard the cast. A node that is neither Param, CNode nor
  // ValueNode previously yielded a null ValueNodePtr and crashed on ->value().
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    x_ = node;
    return;
  }
  auto value = value_node->value();
  if (CheckTensorConstant(1).IsTensorConstant(value)) {
    is_one_ = true;
    return;
  }
  x_ = node;
}
// ValueNode overload: latch is_one_ for an all-one constant tensor,
// otherwise keep the node as the surviving factor.
void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) {
  auto value = vnode->value();
  if (CheckTensorConstant(1).IsTensorConstant(value)) {
    is_one_ = true;
    return;
  }
  x_ = vnode;
}
// Clears per-node state so the visitor can be reused on the next candidate.
void TensorMultiplyByOne::Reset() {
  x_ = nullptr;
  is_one_ = false;
}
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
// Scalar addition with zero is the identity: return the non-zero operand.
AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  Reset();
  AnfVisitor::Match(prim::kPrimScalarAdd)(node);
  if (is_zero_) {
    return x_;
  }
  return nullptr;
}
// Latches is_zero_ when the operand is a ValueNode equal to the zero_ member
// constant, or a scalar all-zero tensor; otherwise keeps it as the survivor.
void AddByZero::Visit(const AnfNodePtr &node) {
  if (node->isa<ValueNode>() &&
      ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) {
    is_zero_ = true;
    return;
  }
  x_ = node;
}
// Clears per-node state so the visitor can be reused on the next candidate.
void AddByZero::Reset() {
  x_ = nullptr;
  is_zero_ = false;
}
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
// Tensor addition with a zero tensor is the identity: return the other operand.
AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  Reset();
  AnfVisitor::Match(prim::kPrimTensorAdd)(node);
  if (is_zero_) {
    return x_;
  }
  return nullptr;
}
// Latches is_zero_ when the operand is a scalar all-zero constant tensor;
// otherwise keeps it as the surviving operand.
void TensorAddByZero::Visit(const AnfNodePtr &node) {
  if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
    is_zero_ = true;
    return;
  }
  x_ = node;
}
// ValueNode overload: latch is_zero_ for an all-zero constant tensor.
// NOTE(review): unlike the AnfNodePtr overload, this path never assigns x_ —
// confirm whether the non-zero ValueNode operand should be recorded here.
void TensorAddByZero::Visit(const ValueNodePtr &vnode) {
  auto value = vnode->value();
  if (CheckTensorConstant(0).IsTensorConstant(value)) {
    is_zero_ = true;
    return;
  }
}
// Clears per-node state so the visitor can be reused on the next candidate.
void TensorAddByZero::Reset() {
  x_ = nullptr;
  is_zero_ = false;
}
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
// A Momentum update whose gradient input is a zeros-like tensor is a no-op;
// rewrite it to a tuple of the unchanged (accumulator, variable) pair.
AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) {
    return nullptr;
  }
  // {PrimMomentum, {...}, Y, Z, Xs}
  auto &inputs = node->cast<CNodePtr>()->inputs();
  if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
    return nullptr;
  }
  auto y = inputs[2];
  auto z = inputs[3];
  // {kPrimZerosLike, X} — the zeros-like CNode must have exactly one operand.
  if (inputs[1]->cast<CNodePtr>()->size() != 2) {
    return nullptr;
  }
  // {prim::kPrimMakeTuple, Z, Y}
  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y});
}
// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
// Support function to multiply two constant tensors: partially support broadcasting shapes.
// Allocates *out_data with new[] (out_data_size elements); the caller owns and
// must delete[] it. A size of 1 broadcasts that operand across the output.
// NOTE(review): when an input size is >1 it is assumed to be >= out_data_size —
// callers must have validated sizes beforehand; confirm at the call sites.
template <typename T>
void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size,
                                    void **out_data, int out_data_size) {
  T *data_1 = reinterpret_cast<T *>(in_data_1);
  T *data_2 = reinterpret_cast<T *>(in_data_2);
  T *data_out = new T[out_data_size];
  if (in_data_1_size == 1) {
    // Broadcast the single element of operand 1.
    for (int i = 0; i < out_data_size; i++) {
      data_out[i] = data_1[0];
    }
  } else {
    for (int i = 0; i < out_data_size; i++) {
      data_out[i] = data_1[i];
    }
  }
  if (in_data_2_size == 1) {
    // Broadcast the single element of operand 2.
    for (int i = 0; i < out_data_size; i++) {
      data_out[i] *= data_2[0];
    }
  } else {
    for (int i = 0; i < out_data_size; i++) {
      data_out[i] *= data_2[i];
    }
  }
  *out_data = reinterpret_cast<void *>(data_out);
  return;
}
// Folds the product of two constant tensors (vnode_1 * vnode_2) into a single
// new ValueNode whose element type and shape follow node_3 (the node being
// replaced). Returns nullptr when folding is impossible: non-constant inputs,
// missing abstracts, element-type mismatch with the output, incompatible
// sizes, or an unsupported dtype.
AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2,
                                                    const AnfNodePtr &node_3) {
  if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
      (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
    return nullptr;
  }
  auto value_1 = GetValueNode(vnode_1);
  auto value_2 = GetValueNode(vnode_2);
  if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
  auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
  auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
  // Bug fix: was vnode_1->abstract(), so tensor 2's element type was never
  // actually checked against the output type.
  auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
  auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
  TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
  TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
  TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
  if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
      (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
    return nullptr;
  }
  std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
  int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
  // Each input must be either a scalar (size 1) or exactly output-sized.
  if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
    return nullptr;
  }
  if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
    return nullptr;
  }
  auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
  size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
  char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
  // Bug fix: data_out is allocated with new[] inside Multiply and was never
  // released; copy into the tensor and delete[] it in each dtype branch.
  void *data_out = nullptr;
  if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
      (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
    Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(),
                    &data_out, data_out_size);
    memcpy(data, data_out, mem_size);
    delete[] reinterpret_cast<float *>(data_out);
  } else {
    if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
      Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                       tensor_ptr_2->DataSize(), &data_out, data_out_size);
      memcpy(data, data_out, mem_size);
      delete[] reinterpret_cast<double *>(data_out);
    } else {
      if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
          (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
        Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                      tensor_ptr_2->DataSize(), &data_out, data_out_size);
        memcpy(data, data_out, mem_size);
        delete[] reinterpret_cast<int *>(data_out);
      } else {
        // Unsupported data type.
        return nullptr;
      }
    }
  }
  auto new_vnode = NewValueNode(new_tensor_ptr);
  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  return new_vnode;
}
AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // Rewrites {Mul, Tensor1, {Mul, Tensor2, X}} so that the two constant tensors
  // are multiplied together: {Mul, X, {Mul, Tensor1, Tensor2}}, folding the
  // constant product at compile time when MulConstantTensors succeeds.
  Reset();
  // Match the outer multiplication: {prim::kPrimMul, Tensor1, {...}}
  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
  if (vnode_ == nullptr || c_p_node_ == nullptr || !IsCNode(c_p_node_)) {
    return nullptr;
  }
  auto outer_tensor = vnode_;
  auto inner_mul = c_p_node_->cast<CNodePtr>();
  Reset();
  // Match the inner multiplication: {prim::kPrimMul, Tensor2, {...}}
  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(inner_mul);
  if (vnode_ == nullptr || c_p_node_ == nullptr) {
    return nullptr;
  }
  auto inner_tensor = vnode_;
  auto other_operand = c_p_node_;
  auto mul_prim = GetValueNode<PrimitivePtr>(inner_mul->input(0));
  auto graph = node->func_graph();
  // Fold the two constants if possible; otherwise keep {Tensor1 * Tensor2} as a
  // node so later passes may still evaluate it.
  auto folded = MulConstantTensors(outer_tensor, inner_tensor, other_operand);
  if (folded != nullptr) {
    return NewCNode({NewValueNode(mul_prim), other_operand, folded}, graph);
  }
  auto tensors_mul = NewCNode({NewValueNode(mul_prim), outer_tensor, inner_tensor}, graph);
  return NewCNode({NewValueNode(mul_prim), other_operand, tensors_mul}, graph);
}
void ConstantDuplicateMul::Visit(const AnfNodePtr &node) {
  // Capture the two kinds of operands the pattern cares about. The checks are
  // independent (they assign different members), so their order is irrelevant.
  if (IsCNode(node) || IsParam(node)) {
    c_p_node_ = node;  // a CNode or Parameter operand
  }
  if (IsValueNode<tensor::Tensor>(node)) {
    vnode_ = node;  // a constant tensor value node
  }
}
void ConstantDuplicateMul::Reset() {
  // Drop any operands captured by a previous pattern match.
  c_p_node_ = nullptr;
  vnode_ = nullptr;
}
AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  // {prim::kPrimPow, X, 1} -> X, for an integer or float literal exponent.
  if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
    return nullptr;
  }
  auto &inputs = node->cast<CNodePtr>()->inputs();
  if (!IsValueNode<Scalar>(inputs[2])) {
    return nullptr;
  }
  auto exponent = GetValueNode<ScalarPtr>(inputs[2]);
  // `IntergerImm` is the (misspelled) class name used by this codebase.
  bool exponent_is_one = (exponent->isa<FloatImm>() && GetValue<float>(exponent) == 1.0) ||
                         (exponent->isa<IntergerImm>() && GetValue<int>(exponent) == 1);
  return exponent_is_one ? inputs[1] : nullptr;
}
@ -179,6 +654,27 @@ void AdjustAllReduceMulAdd::Reset() {
all_reduce_fg_ = nullptr;
}
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  // Offer the node to each eliminater in registration order; the first
  // successful rewrite wins. Returns nullptr when no rule applies.
  for (auto &eliminater : eliminaters_) {
    AnfNodePtr simplified = (*eliminater)(optimizer, node);
    if (simplified != nullptr) {
      return simplified;
    }
  }
  return nullptr;
}
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  // Same dispatch scheme as ArithmeticSimplify: first non-null rewrite wins.
  for (auto &eliminater : eliminaters_) {
    AnfNodePtr simplified = (*eliminater)(optimizer, node);
    if (simplified != nullptr) {
      return simplified;
    }
  }
  return nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -22,14 +22,158 @@
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
// Scalar multiplication simplifications:
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
class MultiplyByZeroOrOne : public AnfVisitor {
 public:
  MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
  ~MultiplyByZeroOrOne() override = default;
  // Returns the simplified node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  void Visit(const ValueNodePtr &vnode) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  bool is_zero_{false}, is_one_{false};  // whether a 0 / 1 operand was captured
  ValuePtr zero_, one_;                  // canonical scalar constants to compare against
  AnfNodePtr x_{nullptr};                // the other captured operand
};
// Support class used for checking if all values of a Tensor are equal to `check_value_`.
// Supported data types: double, float/float32, int/int32
class CheckTensorConstant {
 public:
  explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
  ~CheckTensorConstant() = default;
  // True when `value` is a tensor whose elements all equal check_value_.
  bool IsTensorConstant(const ValuePtr &value);
  // Scalar-tensor variant; exact semantics live in the out-of-line definition.
  bool IsTensorScalarConstant(const ValuePtr &value);
 private:
  int check_value_;  // target value each tensor element is compared against
};
// Shared helpers for the tensor-multiplication simplifiers below.
class TensorMultiplyBase : public AnfVisitor {
 protected:
  // Returns a raw pointer to the tensor data held by `node`; `writable`
  // requests mutable access to the underlying buffer.
  void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false);
  // Make a new tensor (when possible) with the same shape as `node`:
  // If x is nullptr then fill the new tensor with "0"
  // If x is a tensor with empty shape then fill the new tensor with the single value of x
  // If x is a tensor with the same shape as `node` then return x as the result
  AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr);
  AnfNodePtr x_{nullptr};  // scratch capture used by the derived visitors
};
// Tensor multiplication by a zero constant:
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class TensorMultiplyByZero : public TensorMultiplyBase {
 public:
  TensorMultiplyByZero() : zero_(MakeValue(0)) {}
  ~TensorMultiplyByZero() override = default;
  // Returns the simplified node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  void Visit(const ValueNodePtr &vnode) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  bool is_zero_{false};  // whether a zero operand was captured
  ValuePtr zero_;        // canonical zero constant to compare against
};
// Tensor multiplication by a one constant:
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByOne : public TensorMultiplyBase {
 public:
  TensorMultiplyByOne() {}
  ~TensorMultiplyByOne() override = default;
  // Returns the simplified node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  void Visit(const ValueNodePtr &vnode) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  bool is_one_{false};  // whether a one operand was captured
};
// Scalar addition with zero:
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
class AddByZero : public AnfVisitor {
 public:
  AddByZero() : zero_(MakeValue(0)) {}
  ~AddByZero() override = default;
  // Returns the simplified node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  bool is_zero_{false};    // whether a zero operand was captured
  ValuePtr zero_;          // canonical zero constant to compare against
  AnfNodePtr x_{nullptr};  // the other captured operand
};
// Tensor addition with a zeros-like operand:
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class TensorAddByZero : public AnfVisitor {
 public:
  // Returns the simplified node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  void Visit(const ValueNodePtr &vnode) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  bool is_zero_{false};    // whether a zeros-like operand was captured
  AnfNodePtr x_{nullptr};  // the other captured operand
};
// Momentum update whose gradient is a zeros-like tensor:
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class OptUpdateZeroTensor : public AnfVisitor {
 public:
  // Returns the rewritten node, or nullptr when the pattern does not match.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class ConstantDuplicateMul : public AnfVisitor {
 public:
  // Support function to multiply two constant tensors: partially supports
  // broadcasting shapes (an operand may be a scalar of size 1).
  template <typename T>
  void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
                int out_data_size);
  // Folds vnode_1 * vnode_2 into one value node shaped like node_3;
  // returns nullptr when folding is not possible.
  AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3);
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
  void Visit(const AnfNodePtr &node) override;
  // Clears all captured state between matches.
  void Reset();
 private:
  AnfNodePtr vnode_;     // captured constant tensor value node
  AnfNodePtr c_p_node_;  // captured CNode or Parameter operand
};
// {prim::kPrimPow, X, 1} -> X
class PowerOneEliminate : public AnfVisitor {
 public:
  // Returns X when the exponent is the literal 1, otherwise nullptr.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decay
// ->
@ -56,7 +200,39 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
class ArithmeticSimplify : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
ArithmeticSimplify()
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_);
eliminaters_.emplace_back(opt_update_zero_tensor_);
eliminaters_.emplace_back(constant_duplicate_mul_);
eliminaters_.emplace_back(power_one_);
}
~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
private:
OptimizerCallerPtr multiply_by_zero_or_one_;
OptimizerCallerPtr tensor_multiply_by_one_;
OptimizerCallerPtr add_by_zero_;
OptimizerCallerPtr tensor_add_by_zero_;
OptimizerCallerPtr identity_;
OptimizerCallerPtr opt_update_zero_tensor_;
OptimizerCallerPtr constant_duplicate_mul_;
OptimizerCallerPtr power_one_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// Arithmetic Simplifications should be done after step_parallel.
@ -66,9 +242,17 @@ class ArithmeticSimplify : public OptimizerCaller {
// ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
private:
OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -25,8 +25,10 @@
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {

View File

@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common {
};
void SetUp() {
elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);