diff --git a/mindspore/ccsrc/backend/common/expander/fallback/math_ops.cc b/mindspore/ccsrc/backend/common/expander/fallback/math_ops.cc
index 8e3772ca040..3cf360582dd 100644
--- a/mindspore/ccsrc/backend/common/expander/fallback/math_ops.cc
+++ b/mindspore/ccsrc/backend/common/expander/fallback/math_ops.cc
@@ -79,138 +79,60 @@ REG_FALLBACK_BUILDER("BatchMatMulExt").SetBody(BODYFUNC(ib) {
   return {ib->BatchMatMul(x, y, false, false)};
 });
 
-DEF_PURE_SHAPE_CALC(g_matmul_ext_fallback_shapecalc)
-  .SetCalc([](const ShapeArray &inputs) -> ShapeArray {
-    auto &input_shape = inputs.at(kIndex0);
-    auto &weight_shape = inputs.at(kIndex1);
-
-    bool is_weight_scalar = weight_shape.size() == 1;
-
-    ShapeVector multiplication_shape = ops::CheckMatMulShapes(input_shape, weight_shape);
-    ShapeVector broadcast_shape_input = ops::GetMatMulExtBroadcastShape(multiplication_shape, input_shape);
-    ShapeVector broadcast_shape_weight = ops::GetMatMulExtBroadcastShape(multiplication_shape, weight_shape);
-    ShapeVector output_shape = ops::InferShapeRem(multiplication_shape, input_shape, weight_shape, is_weight_scalar);
-    ShapeVector transpose_order;
-    size_t max_dim_count = multiplication_shape.size() + 2;
-
-    for (size_t i = 0; i < max_dim_count; ++i) {
-      transpose_order.push_back(i);
-    }
-
-    int64_t total_batch_size = 1;
-    for (auto dim_size : multiplication_shape) {
-      total_batch_size *= dim_size;
-    }
-
-    ShapeVector final_input_shape = {total_batch_size, broadcast_shape_input[broadcast_shape_input.size() - 2],
-                                     broadcast_shape_input[broadcast_shape_input.size() - 1]};
-    ShapeVector final_weight_shape = {total_batch_size, broadcast_shape_weight[broadcast_shape_weight.size() - 2],
-                                      broadcast_shape_weight[broadcast_shape_weight.size() - 1]};
-
-    if (is_weight_scalar) {
-      std::swap(transpose_order[max_dim_count - 1], transpose_order[max_dim_count - 2]);
-      std::swap(final_weight_shape[final_weight_shape.size() - 1], final_weight_shape[final_weight_shape.size() - 2]);
-    }
-
-    return {broadcast_shape_input, broadcast_shape_weight, transpose_order,
-            final_input_shape, final_weight_shape, output_shape};
-  })
-  .SetInfer([](const ShapeArray &inputs, const HashSet<size_t> &) -> std::vector<int64_t> {
-    int64_t broadcast_rank_input = -1LL;
-    int64_t broadcast_rank_weight = -1LL;
-    int64_t transpose_order_rank = -1LL;
-    int64_t final_input_shape_rank = -1LL;
-    int64_t final_weight_shape_rank = -1LL;
-    int64_t output_shape_rank = -1LL;
-
-    if (!IsDynamicRank(inputs[0]) && !IsDynamicRank(inputs[1])) {
-      auto &input_shape = inputs.at(kIndex0);
-      auto &weight_shape = inputs.at(kIndex1);
-
-      size_t max_dim_count = std::max(input_shape.size(), weight_shape.size());
-      max_dim_count = std::max(max_dim_count, static_cast<size_t>(2));
-
-      if (input_shape.size() == 1 && weight_shape.size() == 1) {
-        output_shape_rank = 0;
-      } else if (input_shape.size() == 1 || weight_shape.size() == 1) {
-        output_shape_rank = max_dim_count - 1;
-      } else {
-        output_shape_rank = max_dim_count;
-      }
-
-      broadcast_rank_input = broadcast_rank_weight = transpose_order_rank = max_dim_count;
-      final_input_shape_rank = final_weight_shape_rank = 3;
-    }
-    return {broadcast_rank_input, broadcast_rank_weight, transpose_order_rank,
-            final_input_shape_rank, final_weight_shape_rank, output_shape_rank};
-  });
-
 REG_FALLBACK_BUILDER("MatMulExt").SetBody(BODYFUNC(ib) {
   NodePtr input = ib->GetInput(kIndex0);
   NodePtr other = ib->GetInput(kIndex1);
-  if (IsDynamic(input->shape()) || IsDynamic(other->shape())) {
-    auto shapes = ib->ShapeCalc(g_matmul_ext_fallback_shapecalc, {input, other});
-    input = ib->Emit("BroadcastTo", {input, shapes[0]});
-    other = ib->Emit("BroadcastTo", {other, shapes[1]});
-    other = ib->Transpose(other, shapes[2]);
-    input = ib->Reshape(input, shapes[3]);
-    other = ib->Reshape(other, shapes[4]);
-    auto ret = ib->BatchMatMul(input, other);
-    ret = ib->Reshape(ret, shapes[5]);
-    return {ret};
-  } else {
-    auto input_rank = input->shape().size();
-    auto other_rank = other->shape().size();
-    if (input_rank == 2 && other_rank == 2) {
-      auto ret = ib->MatMul(input, other);
-      return {ret};
-    }
-    const ShapeVector &shape1_orig = input->shape();
-    const ShapeVector &shape2_orig = other->shape();
-    bool is_empty_tensor =
-      std::any_of(shape1_orig.begin(), shape1_orig.end(), [](const auto &element) { return element == 0; });
-    if (is_empty_tensor) {
-      return {ib->Tensor(0, input->dtype())};
-    }
-    bool transpose_b = other_rank == 1;
-    ShapeVector shape_backbone = ops::CheckMatMulShapes(shape1_orig, shape2_orig);
-    ShapeVector shape_out = ops::InferShapeRem(shape_backbone, shape1_orig, shape2_orig, transpose_b);
-    input = Expand(ib, input, 2);
-    other = Expand(ib, other, 2);
-    NodePtr ret;
-    if (Rank(other) == 2) {
-      if (Rank(input) > 2) {
-        int64_t new_shape_dim0 = 1;
-        for (size_t i = 0; i < shape1_orig.size() - 1; ++i) {
-          new_shape_dim0 *= shape1_orig[i];
-        }
-        std::vector<int64_t> new_shape_vector = {new_shape_dim0, shape1_orig.back()};
-        input = ib->Reshape(input, ib->Value(new_shape_vector));
-      }
-      ret = ib->MatMul(input, other, false, transpose_b);
-    } else {
-      size_t ndim_aligned = std::max(input_rank, other_rank);
-      input = Expand(ib, input, ndim_aligned);
-      other = Expand(ib, other, ndim_aligned);
-      ShapeVector shape1_aligned = input->shape();
-      ShapeVector shape2_aligned = other->shape();
-      ShapeVector shape_cur1(shape1_aligned.begin(), shape1_aligned.end() - 2);
-      ShapeVector shape_cur2(shape2_aligned.begin(), shape2_aligned.end() - 2);
-      const ShapeVector &broadcast_shape1 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape1_orig);
-      const ShapeVector &broadcast_shape2 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape2_orig);
-      if (input->shape() != broadcast_shape1) {
-        input = ib->Emit("BroadcastTo", {input, ib->Value(broadcast_shape1)});
-      }
-      if (other->shape() != broadcast_shape2) {
-        other = ib->Emit("BroadcastTo", {other, ib->Value(broadcast_shape2)});
-      }
-      input = ib->Reshape(input, To3D(input->shape()));
-      other = ib->Reshape(other, To3D(other->shape()));
-      ret = ib->BatchMatMul(input, other, false, transpose_b);
-    }
-    ret = ib->Reshape(ret, ib->Value(shape_out));
+  auto input_rank = input->shape().size();
+  auto other_rank = other->shape().size();
+  if (input_rank == 2 && other_rank == 2) {
+    auto ret = ib->MatMul(input, other);
     return {ret};
   }
+  const ShapeVector &shape1_orig = input->shape();
+  const ShapeVector &shape2_orig = other->shape();
+  bool is_empty_tensor =
+    std::any_of(shape1_orig.begin(), shape1_orig.end(), [](const auto &element) { return element == 0; });
+  if (is_empty_tensor) {
+    return {ib->Tensor(0, input->dtype())};
+  }
+  bool transpose_b = other_rank == 1;
+  ShapeVector shape_backbone = ops::CheckMatMulShapes(shape1_orig, shape2_orig);
+  ShapeVector shape_out = ops::InferShapeRem(shape_backbone, shape1_orig, shape2_orig, transpose_b);
+  input = Expand(ib, input, 2);
+  other = Expand(ib, other, 2);
+  NodePtr ret;
+  if (Rank(other) == 2) {
+    if (Rank(input) > 2) {
+      int64_t new_shape_dim0 = 1;
+      for (size_t i = 0; i < shape1_orig.size() - 1; ++i) {
+        new_shape_dim0 *= shape1_orig[i];
+      }
+      std::vector<int64_t> new_shape_vector = {new_shape_dim0, shape1_orig.back()};
+      input = ib->Reshape(input, ib->Value(new_shape_vector));
+    }
+    ret = ib->MatMul(input, other, false, transpose_b);
+  } else {
+    size_t ndim_aligned = std::max(input_rank, other_rank);
+    input = Expand(ib, input, ndim_aligned);
+    other = Expand(ib, other, ndim_aligned);
+    ShapeVector shape1_aligned = input->shape();
+    ShapeVector shape2_aligned = other->shape();
+    ShapeVector shape_cur1(shape1_aligned.begin(), shape1_aligned.end() - 2);
+    ShapeVector shape_cur2(shape2_aligned.begin(), shape2_aligned.end() - 2);
+    const ShapeVector &broadcast_shape1 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape1_orig);
+    const ShapeVector &broadcast_shape2 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape2_orig);
+    if (input->shape() != broadcast_shape1) {
+      input = ib->Emit("BroadcastTo", {input, ib->Value(broadcast_shape1)});
+    }
+    if (other->shape() != broadcast_shape2) {
+      other = ib->Emit("BroadcastTo", {other, ib->Value(broadcast_shape2)});
+    }
+    input = ib->Reshape(input, To3D(input->shape()));
+    other = ib->Reshape(other, To3D(other->shape()));
+    ret = ib->BatchMatMul(input, other, false, transpose_b);
+  }
+  ret = ib->Reshape(ret, ib->Value(shape_out));
+  return {ret};
 });
 
 REG_FALLBACK_BUILDER("MeanExt").SetBody(BODYFUNC(ib) {
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.cc
index c18c6e6045c..b7f8546ade6 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.cc
@@ -46,7 +46,22 @@ bool MMAclnnKernelMod::Launch(const std::vector &inputs, const s
   RunOp(stream_ptr, workspace);
   return true;
 }
+
+void MMExtAclnnKernelMod::GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
+                                           const std::vector<KernelTensor *> &outputs) {
+  GetWorkspaceForResize(inputs[kIndex0], inputs[kIndex1], outputs[kIndex0], OpApiUtil::GetCubeMathType());
+}
+
+bool MMExtAclnnKernelMod::Launch(const std::vector<KernelTensor *> &inputs,
+                                 const std::vector<KernelTensor *> &workspace,
+                                 const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
+  ParseGenExecutor(GEN_EXECUTOR_BOOST(op_type_, hash_id_, inputs[kIndex0], inputs[kIndex1], outputs[kIndex0],
+                                      OpApiUtil::GetCubeMathType()));
+  RunOp(stream_ptr, workspace);
+  return true;
+}
 MS_ACLNN_KERNEL_FACTORY_REG(MatMul, MMAclnnKernelMod);
 MS_ACLNN_KERNEL_FACTORY_REG(MatMulV2, MMAclnnKernelMod);
+MS_ACLNN_KERNEL_FACTORY_REG(MatMulExt, MMExtAclnnKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.h b/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.h
index 669332dc8f5..b6642054a01 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/opapi/aclnn/matmul_aclnn_kernel.h
@@ -40,6 +40,19 @@ class MMAclnnKernelMod : public AclnnKernelMod {
   std::pair input_a_;
   std::pair input_b_;
 };
+
+class MMExtAclnnKernelMod : public AclnnKernelMod {
+ public:
+  MMExtAclnnKernelMod() : AclnnKernelMod("aclnnMatmul") {}
+  ~MMExtAclnnKernelMod() = default;
+
+  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
+  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
+              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override;
+
+ private:
+  DEFINE_GET_WORKSPACE_FOR_RESIZE()
+};
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
index 7ed4d5100e4..c2b307add78 100644
--- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
@@ -1282,6 +1282,8 @@ class OnnxExporter {
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimNotEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
+  void ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
   void ExportPrimDynamicRNN(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -3220,6 +3222,19 @@ void OnnxExporter::ExportPrimNotEqual(const FuncGraphPtr &, const CNodePtr &node
   AddOp("Not", {equal_name}, {node_name}, graph_proto);
 }
 
+void OnnxExporter::ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
+                                   onnx::GraphProto *const graph_proto) {
+  auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
+  auto input_x = matmul_node->input(kOneNum);  // matmul input x
+  auto input_y = matmul_node->input(kTwoNum);  // matmul input y
+  auto input_b = node->input(kTwoNum);         // matmul bias
+
+  PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
+  std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
+  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
+}
+
 void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
@@ -3926,6 +3941,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
     {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
     {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
     {prim::kPrimNotEqual, &OnnxExporter::ExportPrimNotEqual},
+    {prim::kPrimDense, &OnnxExporter::ExportPrimDense},
     {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
     {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
     {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
index fe4dd20e202..26a3b6724d4 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 import mindspore.numpy as mnp
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops.auto_generate import MatMulExt
 from mindspore.ops.primitive import _primexpr
 from mindspore.common import Tensor
 from mindspore.ops.operations import math_ops
@@ -310,6 +311,27 @@ def get_matmul_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register(MatMulExt)
+def get_matmul_ext_vmap_rule(prim, axis_size):
+    """VmapRule for `*MatMulExt` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    def vmap_rule(a_bdim, b_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, a_bdim, b_bdim)
+        if is_all_none:
+            return result
+
+        a, _ = a_bdim
+        b, _ = b_bdim
+
+        matmul_ext = MatMulExt()
+        out = matmul_ext(a, b)
+        return out, 0
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register(P.math_ops.MatrixSolve)
 def get_matrix_solve_vmap_rule(prim, axis_size):
     """VmapRule for `*MatMul` operation."""
diff --git a/tests/st/ops/test_f_uniform.py b/tests/st/ops/test_f_uniform.py
index f4e0b75c2f1..2608973c3ba 100644
--- a/tests/st/ops/test_f_uniform.py
+++ b/tests/st/ops/test_f_uniform.py
@@ -33,7 +33,7 @@ class UniformExtCell(Cell):
         return self.uniform(x, from_, to, generator)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.env_onecard
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.parametrize("context_mode", [
@@ -59,11 +59,11 @@ def test_basic(context_mode):
 
     g1.manual_seed(41)
     g2 = Generator()
-    g2.manual_seed(43)
+    g2.manual_seed(41)
 
     output1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
-    expect1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
-    output2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
+    output2 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
+    expect1 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
     expect2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
 
     np.testing.assert_allclose(output1, expect1, rtol=rtol)
diff --git a/tests/st/ops/test_ops_matmul_ext.py b/tests/st/ops/test_ops_matmul_ext.py
index f9ed90f0019..62a45c3b72b 100644
--- a/tests/st/ops/test_ops_matmul_ext.py
+++ b/tests/st/ops/test_ops_matmul_ext.py
@@ -58,7 +58,6 @@ class MatMulGradCell(Cell):
     [[4, 2, 5, 6], [3, 4, 2, 6, 5], (3, 4, 2, 5, 5), (4, 2, 5, 6)],
     [[4, 1, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 1, 5, 6)],
     [[4, 2, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 2, 5, 6)],
-    [[2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2, 2], (2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2)],
 ])
 def test_ops(context_mode, shape1, shape2, output_shape1, output_shape2):
     """