From c739f140389a514425b56a1a1ac91909036f677c Mon Sep 17 00:00:00 2001
From: zhu_xiaochen <940619583@qq.com>
Date: Tue, 20 Oct 2020 17:17:24 +0800
Subject: [PATCH] simplify transpose matmul reduce

---
 .../graph_kernel/arithmetic_simplify.cc      | 234 +++++++++++++++++-
 mindspore/core/ir/pattern_matcher.h          |  37 ++-
 tests/st/ops/graph_kernel/test_simplify.py   |  15 +-
 3 files changed, 273 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
index ef06307ce4a..511bdbccaf6 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
@@ -14,18 +14,18 @@
  * limitations under the License.
  */
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
+
 #include 
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
-#include "ir/pattern_matcher.h"
 #include "frontend/operator/ops.h"
+#include "ir/pattern_matcher.h"
 #include "utils/convert_utils.h"
 #include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
-
 AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) {
   auto func_graph = ori_node->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -401,10 +401,236 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
     (FLAG) = true;                                                 \
   }
 
+// A Transpose is equivalent to a Reshape when, after dropping size-1 axes,
+// the permutation keeps the remaining axes in their original order.
+bool TryTransposeToReshape(const AnfNodePtr &node) {
+  auto perm = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
+  auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  std::vector<int64_t> remove_one_perm;
+  for (auto idx : perm) {
+    if (idx < 0 || LongToSize(idx) >= ori_shape.size()) {
+      MS_EXCEPTION(ValueError) << "Transpose perm value " << idx << " is out of range";
+    }
+    if (ori_shape[idx] != 1) {
+      remove_one_perm.emplace_back(idx);
+    }
+  }
+  if (remove_one_perm.size() < 2) {
+    return true;
+  }
+  for (size_t idx = 1; idx < remove_one_perm.size(); idx++) {
+    if (remove_one_perm[idx] < remove_one_perm[idx - 1]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+AnfNodePtr SimplifyTranspose(const AnfNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimTranspose)) {
+    return nullptr;
+  }
+  if (TryTransposeToReshape(node)) {
+    auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReshape), node->cast<CNodePtr>()->input(1)}, node);
+    return new_cnode;
+  }
+  return nullptr;
+}
+
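The guard above encodes a standard layout argument: if the permutation only relocates size-1 axes, the row-major element order is unchanged, so the Transpose is a pure metadata change. A minimal numpy sketch of that claim, with invented shapes:

    import numpy as np

    x = np.arange(12).reshape(1, 3, 4)   # axis 0 has size 1
    perm = (1, 0, 2)                     # only the size-1 axis moves
    t = np.transpose(x, perm)
    r = x.reshape(3, 1, 4)               # the Reshape the pass would emit
    assert np.array_equal(t, r)                  # same values
    assert np.array_equal(t.ravel(), x.ravel())  # same row-major order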
+AnfNodePtr SimplifyMatMul(const AnfNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
+    return nullptr;
+  }
+  PatternNode<AnfNodePtr> x, y;
+  auto matmul_transpose_lambda = [&node, &x, &y]() -> AnfNodePtr {
+    auto new_matmul = NewCNodeWithInfo({NewValueNode(prim::kPrimMatMul), y.GetNode(node), x.GetNode(node)}, node);
+    auto new_abstract = node->abstract()->Clone();
+    auto ori_shape = node->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>();
+    auto shape_value = ori_shape->shape();
+    ShapeVector new_shape_value;
+    std::copy(shape_value.rbegin(), shape_value.rend(), std::back_inserter(new_shape_value));
+    auto new_shape = std::make_shared<abstract::Shape>(new_shape_value);
+    new_abstract->set_shape(new_shape);
+    new_matmul->set_abstract(new_abstract);
+    auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTranspose), new_matmul}, node);
+    // The transpose flags swap together with the operands.
+    auto transpose_a = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_a");
+    auto transpose_b = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_b");
+    auto transpose_x1 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x1");
+    auto transpose_x2 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x2");
+    auto perm = AnfAlgo::GetNodeAttr<ValuePtr>(node->cast<CNodePtr>()->input(1), "perm");
+    AnfAlgo::SetNodeAttr("transpose_a", transpose_b, new_matmul);
+    AnfAlgo::SetNodeAttr("transpose_b", transpose_a, new_matmul);
+    AnfAlgo::SetNodeAttr("transpose_x1", transpose_x2, new_matmul);
+    AnfAlgo::SetNodeAttr("transpose_x2", transpose_x1, new_matmul);
+    AnfAlgo::SetNodeAttr("perm", perm, new_cnode);
+    return new_cnode;
+  };
+  // MatMul(Transpose(x), Transpose(y)) ==> Transpose(MatMul(y, x))
+  MATCH_REPLACE_LAMBDA(node,
+                       PBinOperation(prim::kPrimMatMul, PUnaryOperation(prim::kPrimTranspose, x),
+                                     PUnaryOperation(prim::kPrimTranspose, y), false),
+                       matmul_transpose_lambda);
+  return nullptr;
+}
+
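SimplifyMatMul rests on the transpose identity x^T . y^T = (y . x)^T, which trades two Transposes plus a MatMul for one MatMul plus one Transpose, with the transpose flags swapped along with the operands. A quick numpy check with invented shapes:

    import numpy as np

    x = np.random.rand(3, 4)
    y = np.random.rand(5, 3)
    lhs = np.matmul(x.T, y.T)     # MatMul(Transpose(x), Transpose(y))
    rhs = np.matmul(y, x).T       # Transpose(MatMul(y, x))
    assert np.allclose(lhs, rhs)  # both are (4, 5)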
+ShapeVector TransAxisValueToVector(const ValuePtr &value) {
+  MS_EXCEPTION_IF_NULL(value);
+  ShapeVector axis_vector;
+  if (value->isa<Int64Imm>()) {
+    axis_vector.emplace_back(GetValue<int64_t>(value));
+  }
+  if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
+    axis_vector = GetValue<std::vector<int64_t>>(value);
+  }
+  return axis_vector;
+}
+
+ShapeVector GetNodeShape(const AnfNodePtr &node) {
+  auto base_shape = node->Shape()->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(base_shape);
+  return base_shape->shape();
+}
+
+// Return the axis pairs (i, j) with a[i] == b[j] that a reshape from shape
+// `a` to shape `b` leaves unmodified, i.e. both axes close the same
+// contiguous block of elements.
+std::vector<std::pair<size_t, size_t>> GetUnmodifiedDim(const ShapeVector &a, const ShapeVector &b) {
+  std::vector<std::pair<size_t, size_t>> unmodified;
+  size_t i = 0;
+  size_t j = 0;
+  int64_t partial_a = a.empty() ? 1 : a[0];
+  int64_t partial_b = b.empty() ? 1 : b[0];
+  while (i < a.size() && j < b.size()) {
+    if (partial_a == partial_b) {
+      // The running products line up: a[0..i] and b[0..j] cover the same
+      // elements. The axes match only if they also have the same size.
+      if (a[i] == b[j]) {
+        unmodified.emplace_back(std::make_pair(i, j));
+      }
+      ++i;
+      ++j;
+      if (i < a.size()) {
+        partial_a *= a[i];
+      }
+      if (j < b.size()) {
+        partial_b *= b[j];
+      }
+    } else if (partial_a < partial_b) {
+      ++i;
+      if (i < a.size()) {
+        partial_a *= a[i];
+      }
+    } else {
+      ++j;
+      if (j < b.size()) {
+        partial_b *= b[j];
+      }
+    }
+  }
+  return unmodified;
+}
+
+AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimReduceMax) && !IsPrimitiveCNode(node, prim::kPrimReduceMin) &&
+      !IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
+    return nullptr;
+  }
+  PatternNode<AnfNodePtr> x;
+  auto trans_reduce_lambda = [&node, &x](const PrimitivePtr &operation) -> AnfNodePtr {
+    auto shape = GetNodeShape(node);
+    if (shape.size() != 0 && shape.size() != 1) {
+      return node;
+    }
+    // The result is a scalar or vector, so the Transpose can be absorbed by
+    // remapping the reduce axes through its permutation.
+    auto tmp_node = node->cast<CNodePtr>();
+    auto transpose_node = tmp_node->input(1);
+    auto transpose_dimensions =
+      GetValue<std::vector<int64_t>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm"));
+    ShapeVector new_dimensions;
+    auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
+    std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions),
+                   [&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; });
+    std::sort(new_dimensions.begin(), new_dimensions.end());
+    auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
+    AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
+    AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
+    return new_cnode;
+  };
+  auto reduce_reduce_lambda = [&node, &x](const PrimitivePtr &operation) -> AnfNodePtr {
+    auto tmp_node = node->cast<CNodePtr>();
+    auto arg_node = tmp_node->input(1);
+    auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis"));
+    auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
+    ShapeVector new_dimensions;
+    // The outer reduce axes index the inner reduce's output, so shift them
+    // past every axis the inner reduce has already removed.
+    for (size_t i = 0; i < arg_dimensions.size(); ++i) {
+      for (size_t j = 0; j < reduce_dimensions.size(); ++j) {
+        if (reduce_dimensions[j] >= arg_dimensions[i]) {
+          ++reduce_dimensions[j];
+        }
+      }
+    }
+    std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(),
+               std::back_inserter(new_dimensions));
+    auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
+    AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
+    AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
+    return new_cnode;
+  };
+  auto reshape_reduce_lambda = [&node, &x](const PrimitivePtr &operation) -> AnfNodePtr {
+    auto tmp_node = node->cast<CNodePtr>();
+    auto arg_node = tmp_node->input(1);
+    auto input_shape = GetNodeShape(arg_node->cast<CNodePtr>()->input(1));
+    auto re_shape = GetNodeShape(arg_node);
+    auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
+    auto unmodified_dim_pair = GetUnmodifiedDim(input_shape, re_shape);
+    std::vector<bool> dim_in_output(re_shape.size(), true);
+    std::vector<bool> dim_unmodified(re_shape.size(), false);
+    for (auto dim : reduce_dimensions) {
+      dim_in_output[dim] = false;
+    }
+    for (auto pair_dim : unmodified_dim_pair) {
+      dim_unmodified[pair_dim.second] = true;
+    }
+    // The Reshape can be bypassed only if every axis that survives the
+    // reduce is untouched by the reshape.
+    bool replace = true;
+    for (size_t i = 0; i < dim_in_output.size(); ++i) {
+      if (dim_in_output[i] && !dim_unmodified[i]) {
+        replace = false;
+      }
+    }
+    if (replace) {
+      ShapeVector un_dimensions;
+      for (auto pair_dim : unmodified_dim_pair) {
+        if (dim_in_output[pair_dim.second]) {
+          un_dimensions.emplace_back(pair_dim.first);
+        }
+      }
+      ShapeVector new_dimensions;
+      for (size_t i = 0; i < input_shape.size(); ++i) {
+        if (std::find(un_dimensions.begin(), un_dimensions.end(), i) == un_dimensions.end()) {
+          new_dimensions.emplace_back(i);
+        }
+      }
+      auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
+      AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
+      AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
+      return new_cnode;
+    }
+    return node;
+  };
+  std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
+  for (auto operation : ReduceOperations) {
+    // Reduce(Transpose(A)) = Reduce(A) if the result is a scalar or vector
+    MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda,
+                              operation);
+    // Reduce(Reduce(A)) = Reduce(A) over the merged axes
+    MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lambda, operation);
+    // Reduce(Reshape(A)) = Reduce(A) if the reduced axes do not cross reshape boundaries
+    MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda,
+                              operation);
+  }
+  return nullptr;
+}
+
 AnfNodePtr TrySimplify(const AnfNodePtr &node) {
   std::list<std::function<AnfNodePtr(const AnfNodePtr &)>> SimplifyFuncList = {
-      SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg,
-      SimplifyPow, SimplifyRsqrt, SimplifySelect, SimplifySqrt, SimplifySub};
+      SimplifyAdd,    SimplifyDiv,  SimplifyLog, SimplifyMul,       SimplifyNeg,    SimplifyPow,   SimplifyRsqrt,
+      SimplifySelect, SimplifySqrt, SimplifySub, SimplifyTranspose, SimplifyMatMul, SimplifyReduce};
   for (auto f : SimplifyFuncList) {
     auto ret = f(node);
     if (ret != nullptr) {
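The axis bookkeeping in SimplifyReduce is the subtle part of this hunk: for Reduce(Transpose(A)) the reduce axes are mapped through perm, and for Reduce(Reduce(A)) the outer axes are shifted past the axes the inner reduce removed before merging. Both identities in numpy, with invented shapes:

    import numpy as np

    a = np.random.rand(2, 3, 4)

    # Reduce(Transpose(A)) == Reduce(A) with the axes mapped through perm;
    # the pass only applies this when the result is a scalar or vector.
    perm = (2, 0, 1)
    assert np.allclose(np.sum(np.transpose(a, perm), axis=(0, 1)),
                       np.sum(a, axis=(perm[0], perm[1])))

    # Reduce(Reduce(A)) == one Reduce over the merged axes; axis 1 of the
    # inner result is axis 2 of `a`, hence the shift in reduce_reduce_lambda.
    inner = np.sum(a, axis=1)
    assert np.allclose(np.sum(inner, axis=1), np.sum(a, axis=(1, 2)))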
#include "utils/shape_utils.h" namespace mindspore { @@ -750,9 +750,18 @@ class PConstant : public PBase > { if (value->isa()) { tensor::TensorPtr tensor_ptr = dyn_cast(value); TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + ShapeVector tensor_shape = tensor_abstract->shape()->shape(); + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || (tensor_type == TypeId::kNumberTypeFloat64)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); + float *data = reinterpret_cast(tensor_ptr->data_c()); + float *data2 = reinterpret_cast(new_tensor_ptr->data_c()); + if (memcpy_s(data2, mem_size, data, mem_size) != 0) { + return nullptr; + } for (int i = 0; i < tensor_ptr->DataSize(); i++) { if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { return nullptr; @@ -761,7 +770,11 @@ class PConstant : public PBase > { } } if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); + int *data = reinterpret_cast(tensor_ptr->data_c()); + int *data2 = reinterpret_cast(new_tensor_ptr->data_c()); + if (memcpy_s(data2, mem_size, data, mem_size) != 0) { + return nullptr; + } for (int i = 0; i < tensor_ptr->DataSize(); i++) { if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { return nullptr; @@ -770,7 +783,11 @@ class PConstant : public PBase > { } } if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); + double *data = reinterpret_cast(tensor_ptr->data_c()); + double *data2 = reinterpret_cast(new_tensor_ptr->data_c()); + if (memcpy_s(data2, mem_size, data, mem_size) != 0) { + return nullptr; + } for (int i = 0; i < tensor_ptr->DataSize(); i++) { if (data2[i] == 0 && calcu_type == prim::kPrimReciprocal) { return nullptr; @@ -778,7 +795,9 @@ class PConstant : public PBase > { data2[i] = CalcuConstant(data2[i], calcu_type); } } - return node; + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(tensor_ptr->ToAbstract()); + return new_vnode; } return nullptr; } @@ -1005,6 +1024,14 @@ BIN_OPERATION_PATTERN(operator-, prim::kPrimSub, false); return rep; \ } \ } + +#define MATCH_REPLACE_LAMBDA_FLAG(OrigNode, CaptureNode, Lambda, Flag) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + auto rep = (Lambda)(Flag); \ + if (rep != nullptr) { \ + return rep; \ + } \ + } } // namespace mindspore #endif // MINDSPORE_CORE_IR_PATTERN_MATCHER_H_ diff --git a/tests/st/ops/graph_kernel/test_simplify.py b/tests/st/ops/graph_kernel/test_simplify.py index 57aaa205e21..5a12613b488 100644 --- a/tests/st/ops/graph_kernel/test_simplify.py +++ b/tests/st/ops/graph_kernel/test_simplify.py @@ -20,7 +20,8 @@ from mindspore import Tensor from mindspore.nn import Cell import mindspore.ops.operations as P -context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") +context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="GPU") class Net(Cell): @@ -33,6 +34,8 @@ class Net(Cell): self.sqrt = P.Sqrt() self.pow = P.Pow() self.neg = P.Neg() + self.reducemin = P.ReduceMin() + self.reshape = P.Reshape() def construct(self, x, y): add_res1 = self.add(x, 4) @@ -42,7 +45,9 @@ 
diff --git a/tests/st/ops/graph_kernel/test_simplify.py b/tests/st/ops/graph_kernel/test_simplify.py
index 57aaa205e21..5a12613b488 100644
--- a/tests/st/ops/graph_kernel/test_simplify.py
+++ b/tests/st/ops/graph_kernel/test_simplify.py
@@ -20,7 +20,8 @@ from mindspore import Tensor
 from mindspore.nn import Cell
 import mindspore.ops.operations as P
 
-context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+context.set_context(mode=context.GRAPH_MODE,
+                    enable_graph_kernel=True, device_target="GPU")
 
 
 class Net(Cell):
@@ -33,6 +34,8 @@ class Net(Cell):
         self.sqrt = P.Sqrt()
         self.pow = P.Pow()
         self.neg = P.Neg()
+        self.reducemin = P.ReduceMin()
+        self.reshape = P.Reshape()
 
     def construct(self, x, y):
         add_res1 = self.add(x, 4)
@@ -42,7 +45,9 @@ class Net(Cell):
         div_res = self.div(mul_res, self.sqrt(mul_res))
         pow_res = self.pow(y, 2)
         neg_res = self.neg(self.neg(pow_res))
-        return self.add(div_res, neg_res)
+        add_res3 = self.add(neg_res, div_res)
+        resh_res = self.reshape(add_res3, (2, 12, 3))
+        return self.reducemin(resh_res, 1)
 
 
 @pytest.mark.level0
@@ -58,10 +63,12 @@ def test_basic():
     div_res = np.sqrt(mul_res)
     pow_res = input_y * input_y
     neg_res = pow_res
-    expect = div_res + neg_res
+    add_res3 = neg_res + div_res
+    expect = np.min(add_res3, (1, 2))
     net = Net()
     result = net(Tensor(input_x), Tensor(input_y))
-    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
+    res = np.allclose(expect, result.asnumpy(), rtol=1.e-4,
+                      atol=1.e-7, equal_nan=True)
     assert res
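As a sanity check on the updated expectation: for inputs whose trailing axes multiply out to the (2, 12, 3) reshape (the concrete test shapes are set up in the elided part of the file; (2, 4, 3, 3) is assumed here purely for illustration), reducing axis 1 of the reshaped tensor equals reducing axes (1, 2) of the original:

    import numpy as np

    a = np.random.rand(2, 4, 3, 3).astype(np.float32)  # hypothetical shape
    lhs = np.min(a.reshape(2, 12, 3), axis=1)  # what Net.construct computes
    rhs = np.min(a, axis=(1, 2))               # what test_basic expects
    assert np.allclose(lhs, rhs)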