!68753 fix testcase

Merge pull request !68753 from 邹文祥/master-fix-uniform-testcase
i-robot 2024-04-29 01:41:57 +00:00 committed by Gitee
commit 4ec81c944a
7 changed files with 119 additions and 132 deletions


@@ -79,138 +79,60 @@ REG_FALLBACK_BUILDER("BatchMatMulExt").SetBody(BODYFUNC(ib) {
return {ib->BatchMatMul(x, y, false, false)};
});
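// Shape calculator used by the MatMulExt fallback for dynamic-shape inputs: at
// runtime it derives the broadcast shapes of both operands, a transpose order,
// the 3-D shapes fed to BatchMatMul, and the final output shape.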
DEF_PURE_SHAPE_CALC(g_matmul_ext_fallback_shapecalc)
.SetCalc([](const ShapeArray &inputs) -> ShapeArray {
auto &input_shape = inputs.at(kIndex0);
auto &weight_shape = inputs.at(kIndex1);
bool is_weight_scalar = weight_shape.size() == 1;
ShapeVector multiplication_shape = ops::CheckMatMulShapes(input_shape, weight_shape);
ShapeVector broadcast_shape_input = ops::GetMatMulExtBroadcastShape(multiplication_shape, input_shape);
ShapeVector broadcast_shape_weight = ops::GetMatMulExtBroadcastShape(multiplication_shape, weight_shape);
ShapeVector output_shape = ops::InferShapeRem(multiplication_shape, input_shape, weight_shape, is_weight_scalar);
ShapeVector transpose_order;
size_t max_dim_count = multiplication_shape.size() + 2;
for (size_t i = 0; i < max_dim_count; ++i) {
transpose_order.push_back(i);
}
int64_t total_batch_size = 1;
for (auto dim_size : multiplication_shape) {
total_batch_size *= dim_size;
}
ShapeVector final_input_shape = {total_batch_size, broadcast_shape_input[broadcast_shape_input.size() - 2],
broadcast_shape_input[broadcast_shape_input.size() - 1]};
ShapeVector final_weight_shape = {total_batch_size, broadcast_shape_weight[broadcast_shape_weight.size() - 2],
broadcast_shape_weight[broadcast_shape_weight.size() - 1]};
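// A rank-1 weight is handled by swapping the last two axes of the transpose
// order and of the reshaped weight, mirroring the transpose_b handling used on
// the static path below.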
if (is_weight_scalar) {
std::swap(transpose_order[max_dim_count - 1], transpose_order[max_dim_count - 2]);
std::swap(final_weight_shape[final_weight_shape.size() - 1], final_weight_shape[final_weight_shape.size() - 2]);
}
return {broadcast_shape_input, broadcast_shape_weight, transpose_order,
final_input_shape, final_weight_shape, output_shape};
})
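// The infer callback only reports the rank of each shape produced above; -1
// marks an unknown rank when either operand is still dynamic-rank.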
.SetInfer([](const ShapeArray &inputs, const HashSet<size_t> &) -> std::vector<int64_t> {
int64_t broadcast_rank_input = -1LL;
int64_t broadcast_rank_weight = -1LL;
int64_t transpose_order_rank = -1LL;
int64_t final_input_shape_rank = -1LL;
int64_t final_weight_shape_rank = -1LL;
int64_t output_shape_rank = -1LL;
if (!IsDynamicRank(inputs[0]) && !IsDynamicRank(inputs[1])) {
auto &input_shape = inputs.at(kIndex0);
auto &weight_shape = inputs.at(kIndex1);
size_t max_dim_count = std::max(input_shape.size(), weight_shape.size());
max_dim_count = std::max(max_dim_count, static_cast<size_t>(2));
if (input_shape.size() == 1 && weight_shape.size() == 1) {
output_shape_rank = 0;
} else if (input_shape.size() == 1 || weight_shape.size() == 1) {
output_shape_rank = max_dim_count - 1;
} else {
output_shape_rank = max_dim_count;
}
broadcast_rank_input = broadcast_rank_weight = transpose_order_rank = max_dim_count;
final_input_shape_rank = final_weight_shape_rank = 3;
}
return {broadcast_rank_input, broadcast_rank_weight, transpose_order_rank,
final_input_shape_rank, final_weight_shape_rank, output_shape_rank};
});
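// Fallback expansion for MatMulExt. Dynamic-shape inputs go through the shape
// calculator above (BroadcastTo + Transpose + Reshape + BatchMatMul); static
// inputs are dispatched directly to MatMul or BatchMatMul depending on rank.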
REG_FALLBACK_BUILDER("MatMulExt").SetBody(BODYFUNC(ib) {
NodePtr input = ib->GetInput(kIndex0);
NodePtr other = ib->GetInput(kIndex1);
if (IsDynamic(input->shape()) || IsDynamic(other->shape())) {
auto shapes = ib->ShapeCalc(g_matmul_ext_fallback_shapecalc, {input, other});
input = ib->Emit("BroadcastTo", {input, shapes[0]});
other = ib->Emit("BroadcastTo", {other, shapes[1]});
other = ib->Transpose(other, shapes[2]);
input = ib->Reshape(input, shapes[3]);
other = ib->Reshape(other, shapes[4]);
auto ret = ib->BatchMatMul(input, other);
ret = ib->Reshape(ret, shapes[5]);
return {ret};
} else {
auto input_rank = input->shape().size();
auto other_rank = other->shape().size();
if (input_rank == 2 && other_rank == 2) {
auto ret = ib->MatMul(input, other);
return {ret};
}
const ShapeVector &shape1_orig = input->shape();
const ShapeVector &shape2_orig = other->shape();
bool is_empty_tensor =
std::any_of(shape1_orig.begin(), shape1_orig.end(), [](const auto &element) { return element == 0; });
if (is_empty_tensor) {
return {ib->Tensor(0, input->dtype())};
}
bool transpose_b = other_rank == 1;
ShapeVector shape_backbone = ops::CheckMatMulShapes(shape1_orig, shape2_orig);
ShapeVector shape_out = ops::InferShapeRem(shape_backbone, shape1_orig, shape2_orig, transpose_b);
input = Expand(ib, input, 2);
other = Expand(ib, other, 2);
NodePtr ret;
if (Rank(other) == 2) {
if (Rank(input) > 2) {
int64_t new_shape_dim0 = 1;
for (size_t i = 0; i < shape1_orig.size() - 1; ++i) {
new_shape_dim0 *= shape1_orig[i];
}
std::vector<int64_t> new_shape_vector = {new_shape_dim0, shape1_orig.back()};
input = ib->Reshape(input, ib->Value(new_shape_vector));
}
ret = ib->MatMul(input, other, false, transpose_b);
} else {
size_t ndim_aligned = std::max(input_rank, other_rank);
input = Expand(ib, input, ndim_aligned);
other = Expand(ib, other, ndim_aligned);
ShapeVector shape1_aligned = input->shape();
ShapeVector shape2_aligned = other->shape();
ShapeVector shape_cur1(shape1_aligned.begin(), shape1_aligned.end() - 2);
ShapeVector shape_cur2(shape2_aligned.begin(), shape2_aligned.end() - 2);
const ShapeVector &broadcast_shape1 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape1_orig);
const ShapeVector &broadcast_shape2 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape2_orig);
if (input->shape() != broadcast_shape1) {
input = ib->Emit("BroadcastTo", {input, ib->Value(broadcast_shape1)});
}
if (other->shape() != broadcast_shape2) {
other = ib->Emit("BroadcastTo", {other, ib->Value(broadcast_shape2)});
}
input = ib->Reshape(input, To3D(input->shape()));
other = ib->Reshape(other, To3D(other->shape()));
ret = ib->BatchMatMul(input, other, false, transpose_b);
}
ret = ib->Reshape(ret, ib->Value(shape_out));
return {ret};
}
});
REG_FALLBACK_BUILDER("MeanExt").SetBody(BODYFUNC(ib) {


@@ -46,7 +46,22 @@ bool MMAclnnKernelMod::Launch(const std::vector<KernelTensor *> &inputs, const s
RunOp(stream_ptr, workspace);
return true;
}
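// MatMulExt is lowered to the generic aclnnMatmul operator: both the workspace
// query and Launch forward the two inputs, the single output and the cube math
// type returned by OpApiUtil::GetCubeMathType() to the ACLNN call.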
void MMExtAclnnKernelMod::GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &outputs) {
GetWorkspaceForResize(inputs[kIndex0], inputs[kIndex1], outputs[kIndex0], OpApiUtil::GetCubeMathType());
}
bool MMExtAclnnKernelMod::Launch(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &workspace,
const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
ParseGenExecutor(GEN_EXECUTOR_BOOST(op_type_, hash_id_, inputs[kIndex0], inputs[kIndex1], outputs[kIndex0],
OpApiUtil::GetCubeMathType()));
RunOp(stream_ptr, workspace);
return true;
}
MS_ACLNN_KERNEL_FACTORY_REG(MatMul, MMAclnnKernelMod);
MS_ACLNN_KERNEL_FACTORY_REG(MatMulV2, MMAclnnKernelMod);
MS_ACLNN_KERNEL_FACTORY_REG(MatMulExt, MMExtAclnnKernelMod);
} // namespace kernel
} // namespace mindspore


@@ -40,6 +40,19 @@ class MMAclnnKernelMod : public AclnnKernelMod {
std::pair<KernelTensor *, bool> input_a_;
std::pair<KernelTensor *, bool> input_b_;
};
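// ACLNN kernel for the MatMulExt primitive. It wraps aclnnMatmul, so only the
// workspace query and Launch need to be overridden.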
class MMExtAclnnKernelMod : public AclnnKernelMod {
public:
MMExtAclnnKernelMod() : AclnnKernelMod("aclnnMatmul") {}
~MMExtAclnnKernelMod() = default;
void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
const std::vector<KernelTensor *> &outputs, void *stream_ptr) override;
private:
DEFINE_GET_WORKSPACE_FOR_RESIZE()
};
} // namespace kernel
} // namespace mindspore


@@ -1282,6 +1282,8 @@ class OnnxExporter {
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimNotEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimDynamicRNN(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -3220,6 +3222,19 @@ void OnnxExporter::ExportPrimNotEqual(const FuncGraphPtr &, const CNodePtr &node
AddOp("Not", {equal_name}, {node_name}, graph_proto);
}
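// Dense is exported by reading the x/y operands from the nested MatMul CNode,
// taking the bias from the Dense node itself, and re-exporting the underlying
// MatMul primitive with all three inputs.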
void OnnxExporter::ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
auto input_x = matmul_node->input(kOneNum); // matmul input x
auto input_y = matmul_node->input(kTwoNum); // matmul input y
auto input_b = node->input(kTwoNum); // matmul bias
PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}
void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
@@ -3926,6 +3941,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
{prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
{prim::kPrimNotEqual, &OnnxExporter::ExportPrimNotEqual},
{prim::kPrimDense, &OnnxExporter::ExportPrimDense},
{prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
{prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
{prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},


@@ -19,6 +19,7 @@ from __future__ import absolute_import
import mindspore.numpy as mnp
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.auto_generate import MatMulExt
from mindspore.ops.primitive import _primexpr
from mindspore.common import Tensor
from mindspore.ops.operations import math_ops
@@ -310,6 +311,27 @@ def get_matmul_vmap_rule(prim, axis_size):
return vmap_rule
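# MatMulExt follows matmul broadcasting over leading batch dimensions, so its
# vmap rule can call the primitive on the batched operands directly; returning
# batch axis 0 appears to assume both inputs carry their vmap axis at dim 0.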
@vmap_rules_getters.register(MatMulExt)
def get_matmul_ext_vmap_rule(prim, axis_size):
"""VmapRule for `*MatMulExt` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(a_bdim, b_bdim):
is_all_none, result = vmap_general_preprocess(prim, a_bdim, b_bdim)
if is_all_none:
return result
a, _ = a_bdim
b, _ = b_bdim
matmul_ext = MatMulExt()
out = matmul_ext(a, b)
return out, 0
return vmap_rule
@vmap_rules_getters.register(P.math_ops.MatrixSolve)
def get_matrix_solve_vmap_rule(prim, axis_size):
"""VmapRule for `*MatMul` operation."""


@@ -33,7 +33,7 @@ class UniformExtCell(Cell):
return self.uniform(x, from_, to, generator)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize("context_mode", [
@@ -59,11 +59,11 @@ def test_basic(context_mode):
g1.manual_seed(41)
g2 = Generator()
g2.manual_seed(43)
g2.manual_seed(41)
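# The fix seeds both generators with 41: draws from g1 (outputs) are compared
# against draws from g2 (expectations), so identically seeded generators are
# expected to produce identical uniform samples.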
output1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
expect1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
output2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
output2 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
expect1 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
expect2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
np.testing.assert_allclose(output1, expect1, rtol=rtol)


@@ -58,7 +58,6 @@ class MatMulGradCell(Cell):
[[4, 2, 5, 6], [3, 4, 2, 6, 5], (3, 4, 2, 5, 5), (4, 2, 5, 6)],
[[4, 1, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 1, 5, 6)],
[[4, 2, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 2, 5, 6)],
[[2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2, 2], (2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2)],
])
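# Each case appears to be (shape1, shape2, expected matmul output shape,
# expected gradient shape); the gradient w.r.t. the first operand keeps that
# operand's shape.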
def test_ops(context_mode, shape1, shape2, output_shape1, output_shape2):
"""