forked from mindspore-Ecosystem/mindspore
!7529 complex arithmetic_simplify
Merge pull request !7529 from zhuxiaochen/1020_allsimplify_1.0
This commit is contained in:
commit
8d39a8a4b2
|
@ -14,18 +14,18 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||||
|
|
||||||
#include <list>
|
#include <list>
|
||||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "ir/pattern_matcher.h"
|
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "ir/pattern_matcher.h"
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
|
||||||
AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) {
|
AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) {
|
||||||
auto func_graph = ori_node->func_graph();
|
auto func_graph = ori_node->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -401,10 +401,236 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
|
||||||
(FLAG) = true; \
|
(FLAG) = true; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TryTransposeToReshape(const AnfNodePtr &node) {
|
||||||
|
auto perm = AnfAlgo::GetNodeAttr<std::vector<int>>(node, "perm");
|
||||||
|
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
||||||
|
std::vector<int> remove_one_perm;
|
||||||
|
for (auto idx : perm) {
|
||||||
|
if (idx < 0 || IntToSize(idx) >= ori_shape.size()) {
|
||||||
|
MS_EXCEPTION(ValueError);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (ori_shape[idx] != 1) {
|
||||||
|
remove_one_perm.emplace_back(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (remove_one_perm.size() < 2) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
for (size_t idx = 1; idx < remove_one_perm.size(); idx++) {
|
||||||
|
if (remove_one_perm[idx] < remove_one_perm[idx - 1]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr SimplifyTranspose(const AnfNodePtr &node) {
|
||||||
|
if (!IsPrimitiveCNode(node, prim::kPrimTranspose)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (TryTransposeToReshape(node)) {
|
||||||
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReshape), node->cast<CNodePtr>()->input(1)}, node);
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr SimplifyMatMul(const AnfNodePtr &node) {
|
||||||
|
if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
PatternNode<AnfNodePtr> x, y;
|
||||||
|
auto matmul_transpose_lambda = [&node, &x, &y]() -> AnfNodePtr {
|
||||||
|
auto new_matmul = NewCNodeWithInfo({NewValueNode(prim::kPrimMatMul), y.GetNode(node), x.GetNode(node)}, node);
|
||||||
|
auto new_abstract = node->abstract()->Clone();
|
||||||
|
auto ori_shape = node->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>();
|
||||||
|
auto shape_value = ori_shape->shape();
|
||||||
|
ShapeVector new_shape_value;
|
||||||
|
std::copy(shape_value.rbegin(), shape_value.rend(), std::back_inserter(new_shape_value));
|
||||||
|
auto new_shape = std::make_shared<abstract::Shape>(new_shape_value);
|
||||||
|
new_abstract->set_shape(new_shape);
|
||||||
|
new_matmul->set_abstract(new_abstract);
|
||||||
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTranspose), new_matmul}, node);
|
||||||
|
auto transpose_a = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_a");
|
||||||
|
auto transpose_b = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_b");
|
||||||
|
auto transpose_x1 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x1");
|
||||||
|
auto transpose_x2 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x2");
|
||||||
|
auto perm = AnfAlgo::GetNodeAttr<ValuePtr>(node->cast<CNodePtr>()->input(1), "perm");
|
||||||
|
AnfAlgo::SetNodeAttr("transpose_a", transpose_b, new_matmul);
|
||||||
|
AnfAlgo::SetNodeAttr("transpose_b", transpose_a, new_matmul);
|
||||||
|
AnfAlgo::SetNodeAttr("transpose_x1", transpose_x2, new_matmul);
|
||||||
|
AnfAlgo::SetNodeAttr("transpose_x2", transpose_x1, new_matmul);
|
||||||
|
AnfAlgo::SetNodeAttr("perm", perm, new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
};
|
||||||
|
// MatMul(Transpose(x), Transpose(y)) ==> Transpose(MatMul(y, x))
|
||||||
|
MATCH_REPLACE_LAMBDA(node,
|
||||||
|
PBinOperation(prim::kPrimMatMul, PUnaryOperation(prim::kPrimTranspose, x),
|
||||||
|
PUnaryOperation(prim::kPrimTranspose, y), false),
|
||||||
|
matmul_transpose_lambda);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapeVector TransAxisValueToVector(const ValuePtr &value) {
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
ShapeVector axis_vector;
|
||||||
|
if (value->isa<Int32Imm>()) {
|
||||||
|
axis_vector.emplace_back(GetValue<int>(value));
|
||||||
|
}
|
||||||
|
if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
|
||||||
|
axis_vector = GetValue<std::vector<int>>(value);
|
||||||
|
}
|
||||||
|
return axis_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapeVector GetNodeShape(const AnfNodePtr &node) {
|
||||||
|
auto base_shape = node->Shape()->cast<abstract::ShapePtr>();
|
||||||
|
std::vector<int> shape;
|
||||||
|
std::transform(base_shape->shape().begin(), base_shape->shape().end(), std::back_inserter(shape), IntToSize);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<int, int>> GetUnmodifiedDim(const ShapeVector &a, const ShapeVector &b) {
|
||||||
|
std::vector<std::pair<int, int>> unmodified;
|
||||||
|
for (size_t i = 0, j = 0, patial_a = 1, patial_b = 1;;) {
|
||||||
|
if (i >= a.size() && j >= b.size()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
patial_a *= a[i];
|
||||||
|
patial_b *= b[j];
|
||||||
|
if (patial_a == patial_b && a[i] == b[j]) {
|
||||||
|
unmodified.emplace_back(std::make_pair(i, j));
|
||||||
|
++i;
|
||||||
|
++j;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (patial_a < patial_b && b[j] > a[i]) {
|
||||||
|
++i;
|
||||||
|
patial_a *= a[i];
|
||||||
|
if (patial_a == patial_b) {
|
||||||
|
++i;
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (patial_a > patial_b && b[j] < a[i]) {
|
||||||
|
++j;
|
||||||
|
patial_b *= b[j];
|
||||||
|
if (patial_a == patial_b) {
|
||||||
|
++i;
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return unmodified;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
|
||||||
|
if (!IsPrimitiveCNode(node, prim::kPrimReduceMax) && !IsPrimitiveCNode(node, prim::kPrimReduceMin) &&
|
||||||
|
!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
PatternNode<AnfNodePtr> x;
|
||||||
|
auto trans_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||||
|
auto shape = GetNodeShape(node);
|
||||||
|
if (shape.size() != 0 && shape.size() != 1) {
|
||||||
|
return node;
|
||||||
|
} else {
|
||||||
|
auto tmp_node = node->cast<CNodePtr>();
|
||||||
|
auto transpose_node = tmp_node->input(1);
|
||||||
|
auto transpose_dimensions = GetValue<std::vector<int>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm"));
|
||||||
|
ShapeVector new_dimensions;
|
||||||
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||||
|
std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions),
|
||||||
|
[&transpose_dimensions](const int &dim) { return transpose_dimensions[dim]; });
|
||||||
|
std::sort(new_dimensions.begin(), new_dimensions.end());
|
||||||
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||||
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||||
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto reduce_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||||
|
auto tmp_node = node->cast<CNodePtr>();
|
||||||
|
auto arg_node = tmp_node->input(1);
|
||||||
|
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
|
||||||
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||||
|
ShapeVector new_dimensions;
|
||||||
|
for (size_t i = 0; i < arg_dimensions.size(); ++i) {
|
||||||
|
for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
|
||||||
|
if (reduce_dimensions[j] >= arg_dimensions[i]) {
|
||||||
|
++reduce_dimensions[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
|
||||||
|
std::back_inserter(new_dimensions));
|
||||||
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||||
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||||
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
};
|
||||||
|
auto reshape_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
|
||||||
|
auto tmp_node = node->cast<CNodePtr>();
|
||||||
|
auto arg_node = tmp_node->input(1);
|
||||||
|
auto input_shape = GetNodeShape(arg_node->cast<CNodePtr>()->input(1));
|
||||||
|
auto re_shape = GetNodeShape(arg_node);
|
||||||
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
|
||||||
|
auto unmodified_dim_pair = GetUnmodifiedDim(input_shape, re_shape);
|
||||||
|
std::vector<bool> dim_in_output(re_shape.size(), true);
|
||||||
|
std::vector<bool> dim_unmodified(re_shape.size(), false);
|
||||||
|
for (auto dim : reduce_dimensions) {
|
||||||
|
dim_in_output[dim] = false;
|
||||||
|
}
|
||||||
|
for (auto pair_dim : unmodified_dim_pair) {
|
||||||
|
dim_unmodified[pair_dim.second] = true;
|
||||||
|
}
|
||||||
|
bool replace = true;
|
||||||
|
for (size_t i = 0; i < dim_in_output.size(); ++i) {
|
||||||
|
if (dim_in_output[i] && !dim_unmodified[i]) {
|
||||||
|
replace = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (replace) {
|
||||||
|
ShapeVector un_dimensions;
|
||||||
|
for (auto pair_dim : unmodified_dim_pair) {
|
||||||
|
if (dim_in_output[pair_dim.second]) {
|
||||||
|
un_dimensions.emplace_back(pair_dim.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ShapeVector new_dimensions;
|
||||||
|
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||||
|
if (std::find(un_dimensions.begin(), un_dimensions.end(), i) == un_dimensions.end()) {
|
||||||
|
new_dimensions.emplace_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
|
||||||
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
|
||||||
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
};
|
||||||
|
std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
|
||||||
|
for (auto operation : ReduceOperations) {
|
||||||
|
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector
|
||||||
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lamda,
|
||||||
|
operation);
|
||||||
|
// Reduce(Reduce(A)) = Reduce(A)
|
||||||
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lamda, operation);
|
||||||
|
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions
|
||||||
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lamda,
|
||||||
|
operation);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr TrySimplify(const AnfNodePtr &node) {
|
AnfNodePtr TrySimplify(const AnfNodePtr &node) {
|
||||||
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = {
|
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = {
|
||||||
SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg,
|
SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg, SimplifyPow, SimplifyRsqrt,
|
||||||
SimplifyPow, SimplifyRsqrt, SimplifySelect, SimplifySqrt, SimplifySub};
|
SimplifySelect, SimplifySqrt, SimplifySub, SimplifyTranspose, SimplifyMatMul, SimplifyReduce};
|
||||||
for (auto f : SimplifyFuncList) {
|
for (auto f : SimplifyFuncList) {
|
||||||
auto ret = f(node);
|
auto ret = f(node);
|
||||||
if (ret != nullptr) {
|
if (ret != nullptr) {
|
||||||
|
|
|
@ -22,8 +22,8 @@
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ir/visitor.h"
|
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
|
#include "ir/visitor.h"
|
||||||
#include "utils/shape_utils.h"
|
#include "utils/shape_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -750,9 +750,18 @@ class PConstant : public PBase<PConstant<T> > {
|
||||||
if (value->isa<tensor::Tensor>()) {
|
if (value->isa<tensor::Tensor>()) {
|
||||||
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
|
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
|
||||||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||||
|
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||||
|
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
|
||||||
|
ShapeVector tensor_shape = tensor_abstract->shape()->shape();
|
||||||
|
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
|
||||||
|
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
|
||||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
|
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) ||
|
||||||
(tensor_type == TypeId::kNumberTypeFloat64)) {
|
(tensor_type == TypeId::kNumberTypeFloat64)) {
|
||||||
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
|
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||||
|
float *data2 = reinterpret_cast<float *>(new_tensor_ptr->data_c());
|
||||||
|
if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||||
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -761,7 +770,11 @@ class PConstant : public PBase<PConstant<T> > {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
|
if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
|
||||||
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
|
int *data = reinterpret_cast<int *>(tensor_ptr->data_c());
|
||||||
|
int *data2 = reinterpret_cast<int *>(new_tensor_ptr->data_c());
|
||||||
|
if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||||
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -770,7 +783,11 @@ class PConstant : public PBase<PConstant<T> > {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tensor_type == TypeId::kNumberTypeFloat64) {
|
if (tensor_type == TypeId::kNumberTypeFloat64) {
|
||||||
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
|
double *data = reinterpret_cast<double *>(tensor_ptr->data_c());
|
||||||
|
double *data2 = reinterpret_cast<double *>(new_tensor_ptr->data_c());
|
||||||
|
if (memcpy_s(data2, mem_size, data, mem_size) != 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
|
||||||
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -778,7 +795,9 @@ class PConstant : public PBase<PConstant<T> > {
|
||||||
data2[i] = CalcuConstant(data2[i], calcu_type);
|
data2[i] = CalcuConstant(data2[i], calcu_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return node;
|
auto new_vnode = NewValueNode(new_tensor_ptr);
|
||||||
|
new_vnode->set_abstract(tensor_ptr->ToAbstract());
|
||||||
|
return new_vnode;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1005,6 +1024,14 @@ BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false);
|
||||||
return rep; \
|
return rep; \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \
|
||||||
|
if ((CaptureNode).TryCapture(OrigNode)) { \
|
||||||
|
auto rep = (Lambda)(Flag); \
|
||||||
|
if (rep != nullptr) { \
|
||||||
|
return rep; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
|
#endif // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_
|
||||||
|
|
|
@ -20,7 +20,8 @@ from mindspore import Tensor
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
import mindspore.ops.operations as P
|
import mindspore.ops.operations as P
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
enable_graph_kernel=True, device_target="GPU")
|
||||||
|
|
||||||
|
|
||||||
class Net(Cell):
|
class Net(Cell):
|
||||||
|
@ -33,6 +34,8 @@ class Net(Cell):
|
||||||
self.sqrt = P.Sqrt()
|
self.sqrt = P.Sqrt()
|
||||||
self.pow = P.Pow()
|
self.pow = P.Pow()
|
||||||
self.neg = P.Neg()
|
self.neg = P.Neg()
|
||||||
|
self.reducemin = P.ReduceMin()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
add_res1 = self.add(x, 4)
|
add_res1 = self.add(x, 4)
|
||||||
|
@ -42,7 +45,9 @@ class Net(Cell):
|
||||||
div_res = self.div(mul_res, self.sqrt(mul_res))
|
div_res = self.div(mul_res, self.sqrt(mul_res))
|
||||||
pow_res = self.pow(y, 2)
|
pow_res = self.pow(y, 2)
|
||||||
neg_res = self.neg(self.neg(pow_res))
|
neg_res = self.neg(self.neg(pow_res))
|
||||||
return self.add(div_res, neg_res)
|
add_res3 = self.add(neg_res, div_res)
|
||||||
|
resh_res = self.reshape(add_res3, (2, 12, 3))
|
||||||
|
return self.reducemin(resh_res, 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
@ -58,10 +63,12 @@ def test_basic():
|
||||||
div_res = np.sqrt(mul_res)
|
div_res = np.sqrt(mul_res)
|
||||||
pow_res = input_y * input_y
|
pow_res = input_y * input_y
|
||||||
neg_res = pow_res
|
neg_res = pow_res
|
||||||
expect = div_res + neg_res
|
add_res3 = neg_res + div_res
|
||||||
|
expect = np.min(add_res3, (1, 2))
|
||||||
|
|
||||||
net = Net()
|
net = Net()
|
||||||
result = net(Tensor(input_x), Tensor(input_y))
|
result = net(Tensor(input_x), Tensor(input_y))
|
||||||
|
|
||||||
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
|
||||||
|
atol=1.e-7, equal_nan=True)
|
||||||
assert res
|
assert res
|
||||||
|
|
Loading…
Reference in New Issue