!68753 fix testcase
Merge pull request !68753 from 邹文祥/master-fix-uniform-testcase
Commit: 4ec81c944a

@@ -79,138 +79,60 @@ REG_FALLBACK_BUILDER("BatchMatMulExt").SetBody(BODYFUNC(ib) {
  return {ib->BatchMatMul(x, y, false, false)};
});

DEF_PURE_SHAPE_CALC(g_matmul_ext_fallback_shapecalc)
  .SetCalc([](const ShapeArray &inputs) -> ShapeArray {
    auto &input_shape = inputs.at(kIndex0);
    auto &weight_shape = inputs.at(kIndex1);

    bool is_weight_scalar = weight_shape.size() == 1;

    ShapeVector multiplication_shape = ops::CheckMatMulShapes(input_shape, weight_shape);
    ShapeVector broadcast_shape_input = ops::GetMatMulExtBroadcastShape(multiplication_shape, input_shape);
    ShapeVector broadcast_shape_weight = ops::GetMatMulExtBroadcastShape(multiplication_shape, weight_shape);
    ShapeVector output_shape = ops::InferShapeRem(multiplication_shape, input_shape, weight_shape, is_weight_scalar);
    ShapeVector transpose_order;
    size_t max_dim_count = multiplication_shape.size() + 2;

    for (size_t i = 0; i < max_dim_count; ++i) {
      transpose_order.push_back(i);
    }

    int64_t total_batch_size = 1;
    for (auto dim_size : multiplication_shape) {
      total_batch_size *= dim_size;
    }

    ShapeVector final_input_shape = {total_batch_size, broadcast_shape_input[broadcast_shape_input.size() - 2],
                                     broadcast_shape_input[broadcast_shape_input.size() - 1]};
    ShapeVector final_weight_shape = {total_batch_size, broadcast_shape_weight[broadcast_shape_weight.size() - 2],
                                      broadcast_shape_weight[broadcast_shape_weight.size() - 1]};

    if (is_weight_scalar) {
      std::swap(transpose_order[max_dim_count - 1], transpose_order[max_dim_count - 2]);
      std::swap(final_weight_shape[final_weight_shape.size() - 1], final_weight_shape[final_weight_shape.size() - 2]);
    }

    return {broadcast_shape_input, broadcast_shape_weight, transpose_order,
            final_input_shape, final_weight_shape, output_shape};
  })
  .SetInfer([](const ShapeArray &inputs, const HashSet<size_t> &) -> std::vector<int64_t> {
    int64_t broadcast_rank_input = -1LL;
    int64_t broadcast_rank_weight = -1LL;
    int64_t transpose_order_rank = -1LL;
    int64_t final_input_shape_rank = -1LL;
    int64_t final_weight_shape_rank = -1LL;
    int64_t output_shape_rank = -1LL;

    if (!IsDynamicRank(inputs[0]) && !IsDynamicRank(inputs[1])) {
      auto &input_shape = inputs.at(kIndex0);
      auto &weight_shape = inputs.at(kIndex1);

      size_t max_dim_count = std::max(input_shape.size(), weight_shape.size());
      max_dim_count = std::max(max_dim_count, static_cast<size_t>(2));

      if (input_shape.size() == 1 && weight_shape.size() == 1) {
        output_shape_rank = 0;
      } else if (input_shape.size() == 1 || weight_shape.size() == 1) {
        output_shape_rank = max_dim_count - 1;
      } else {
        output_shape_rank = max_dim_count;
      }

      broadcast_rank_input = broadcast_rank_weight = transpose_order_rank = max_dim_count;
      final_input_shape_rank = final_weight_shape_rank = 3;
    }
    return {broadcast_rank_input, broadcast_rank_weight, transpose_order_rank,
            final_input_shape_rank, final_weight_shape_rank, output_shape_rank};
  });
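
For orientation, the shape calc above boils down to: broadcast the leading (batch) dimensions of the two operands, flatten them into a single batch dimension for BatchMatMul, and keep the trailing matrix dimensions. Below is a minimal Python sketch of that bookkeeping, assuming both operands already have rank >= 2; the 1-D and scalar-weight cases handled via the transpose order are omitted, and the function name and example shapes are illustrative only, not part of the patch.

# Illustrative sketch only (not part of the patch): the batch-flattening idea
# behind g_matmul_ext_fallback_shapecalc, assuming both operands have rank >= 2.
import math


def flatten_for_batch_matmul(input_shape, weight_shape):
    batch_rank = max(len(input_shape), len(weight_shape)) - 2
    in_batch = [1] * (batch_rank - (len(input_shape) - 2)) + list(input_shape[:-2])
    wt_batch = [1] * (batch_rank - (len(weight_shape) - 2)) + list(weight_shape[:-2])
    backbone = [max(a, b) for a, b in zip(in_batch, wt_batch)]  # broadcast batch dims
    total_batch = math.prod(backbone)
    final_input = [total_batch, input_shape[-2], input_shape[-1]]    # 3-D shape fed to BatchMatMul
    final_weight = [total_batch, weight_shape[-2], weight_shape[-1]]
    output = backbone + [input_shape[-2], weight_shape[-1]]
    return final_input, final_weight, output


# (4, 1, 5, 6) @ (2, 6, 7): backbone (4, 2), so the operands are flattened to
# (8, 5, 6) and (8, 6, 7) and the result is reshaped back to (4, 2, 5, 7).
print(flatten_for_batch_matmul((4, 1, 5, 6), (2, 6, 7)))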

REG_FALLBACK_BUILDER("MatMulExt").SetBody(BODYFUNC(ib) {
  NodePtr input = ib->GetInput(kIndex0);
  NodePtr other = ib->GetInput(kIndex1);
  if (IsDynamic(input->shape()) || IsDynamic(other->shape())) {
    auto shapes = ib->ShapeCalc(g_matmul_ext_fallback_shapecalc, {input, other});
    input = ib->Emit("BroadcastTo", {input, shapes[0]});
    other = ib->Emit("BroadcastTo", {other, shapes[1]});
    other = ib->Transpose(other, shapes[2]);
    input = ib->Reshape(input, shapes[3]);
    other = ib->Reshape(other, shapes[4]);
    auto ret = ib->BatchMatMul(input, other);
    ret = ib->Reshape(ret, shapes[5]);
    return {ret};
  } else {
    auto input_rank = input->shape().size();
    auto other_rank = other->shape().size();
    if (input_rank == 2 && other_rank == 2) {
      auto ret = ib->MatMul(input, other);
      return {ret};
    }
    const ShapeVector &shape1_orig = input->shape();
    const ShapeVector &shape2_orig = other->shape();
    bool is_empty_tensor =
      std::any_of(shape1_orig.begin(), shape1_orig.end(), [](const auto &element) { return element == 0; });
    if (is_empty_tensor) {
      return {ib->Tensor(0, input->dtype())};
    }
    bool transpose_b = other_rank == 1;
    ShapeVector shape_backbone = ops::CheckMatMulShapes(shape1_orig, shape2_orig);
    ShapeVector shape_out = ops::InferShapeRem(shape_backbone, shape1_orig, shape2_orig, transpose_b);
    input = Expand(ib, input, 2);
    other = Expand(ib, other, 2);
    NodePtr ret;
    if (Rank(other) == 2) {
      if (Rank(input) > 2) {
        int64_t new_shape_dim0 = 1;
        for (size_t i = 0; i < shape1_orig.size() - 1; ++i) {
          new_shape_dim0 *= shape1_orig[i];
        }
        std::vector<int64_t> new_shape_vector = {new_shape_dim0, shape1_orig.back()};
        input = ib->Reshape(input, ib->Value(new_shape_vector));
      }
      ret = ib->MatMul(input, other, false, transpose_b);
    } else {
      size_t ndim_aligned = std::max(input_rank, other_rank);
      input = Expand(ib, input, ndim_aligned);
      other = Expand(ib, other, ndim_aligned);
      ShapeVector shape1_aligned = input->shape();
      ShapeVector shape2_aligned = other->shape();
      ShapeVector shape_cur1(shape1_aligned.begin(), shape1_aligned.end() - 2);
      ShapeVector shape_cur2(shape2_aligned.begin(), shape2_aligned.end() - 2);
      const ShapeVector &broadcast_shape1 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape1_orig);
      const ShapeVector &broadcast_shape2 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape2_orig);
      if (input->shape() != broadcast_shape1) {
        input = ib->Emit("BroadcastTo", {input, ib->Value(broadcast_shape1)});
      }
      if (other->shape() != broadcast_shape2) {
        other = ib->Emit("BroadcastTo", {other, ib->Value(broadcast_shape2)});
      }
      input = ib->Reshape(input, To3D(input->shape()));
      other = ib->Reshape(other, To3D(other->shape()));
      ret = ib->BatchMatMul(input, other, false, transpose_b);
    }
    ret = ib->Reshape(ret, ib->Value(shape_out));

  auto input_rank = input->shape().size();
  auto other_rank = other->shape().size();
  if (input_rank == 2 && other_rank == 2) {
    auto ret = ib->MatMul(input, other);
    return {ret};
  }
  const ShapeVector &shape1_orig = input->shape();
  const ShapeVector &shape2_orig = other->shape();
  bool is_empty_tensor =
    std::any_of(shape1_orig.begin(), shape1_orig.end(), [](const auto &element) { return element == 0; });
  if (is_empty_tensor) {
    return {ib->Tensor(0, input->dtype())};
  }
  bool transpose_b = other_rank == 1;
  ShapeVector shape_backbone = ops::CheckMatMulShapes(shape1_orig, shape2_orig);
  ShapeVector shape_out = ops::InferShapeRem(shape_backbone, shape1_orig, shape2_orig, transpose_b);
  input = Expand(ib, input, 2);
  other = Expand(ib, other, 2);
  NodePtr ret;
  if (Rank(other) == 2) {
    if (Rank(input) > 2) {
      int64_t new_shape_dim0 = 1;
      for (size_t i = 0; i < shape1_orig.size() - 1; ++i) {
        new_shape_dim0 *= shape1_orig[i];
      }
      std::vector<int64_t> new_shape_vector = {new_shape_dim0, shape1_orig.back()};
      input = ib->Reshape(input, ib->Value(new_shape_vector));
    }
    ret = ib->MatMul(input, other, false, transpose_b);
  } else {
    size_t ndim_aligned = std::max(input_rank, other_rank);
    input = Expand(ib, input, ndim_aligned);
    other = Expand(ib, other, ndim_aligned);
    ShapeVector shape1_aligned = input->shape();
    ShapeVector shape2_aligned = other->shape();
    ShapeVector shape_cur1(shape1_aligned.begin(), shape1_aligned.end() - 2);
    ShapeVector shape_cur2(shape2_aligned.begin(), shape2_aligned.end() - 2);
    const ShapeVector &broadcast_shape1 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape1_orig);
    const ShapeVector &broadcast_shape2 = ops::GetMatMulExtBroadcastShape(shape_backbone, shape2_orig);
    if (input->shape() != broadcast_shape1) {
      input = ib->Emit("BroadcastTo", {input, ib->Value(broadcast_shape1)});
    }
    if (other->shape() != broadcast_shape2) {
      other = ib->Emit("BroadcastTo", {other, ib->Value(broadcast_shape2)});
    }
    input = ib->Reshape(input, To3D(input->shape()));
    other = ib->Reshape(other, To3D(other->shape()));
    ret = ib->BatchMatMul(input, other, false, transpose_b);
  }
  ret = ib->Reshape(ret, ib->Value(shape_out));
  return {ret};
});
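
One detail worth calling out in the static-shape branch: when `other` is 1-D, the builder lifts it to rank 2, multiplies with `transpose_b`, and reshapes the result back to `shape_out`. A NumPy analogue of that special case (illustrative only, not the builder API):

# Illustrative NumPy analogue of the 1-D "other" case: the vector is lifted to
# a (1, n) matrix, multiplied with transpose_b, then reshaped back to drop the
# trailing unit dimension.
import numpy as np

x = np.random.randn(5, 6)
v = np.random.randn(6)

ref = np.matmul(x, v)                            # shape (5,)
lifted = np.expand_dims(v, 0)                    # (1, 6), mirrors Expand(ib, other, 2)
out = np.matmul(x, lifted.T).reshape(ref.shape)  # transpose_b, then reshape to shape_out
assert np.allclose(ref, out)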

REG_FALLBACK_BUILDER("MeanExt").SetBody(BODYFUNC(ib) {

@@ -46,7 +46,22 @@ bool MMAclnnKernelMod::Launch(const std::vector<KernelTensor *> &inputs, const s
  RunOp(stream_ptr, workspace);
  return true;
}

void MMExtAclnnKernelMod::GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
                                           const std::vector<KernelTensor *> &outputs) {
  GetWorkspaceForResize(inputs[kIndex0], inputs[kIndex1], outputs[kIndex0], OpApiUtil::GetCubeMathType());
}

bool MMExtAclnnKernelMod::Launch(const std::vector<KernelTensor *> &inputs,
                                 const std::vector<KernelTensor *> &workspace,
                                 const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
  ParseGenExecutor(GEN_EXECUTOR_BOOST(op_type_, hash_id_, inputs[kIndex0], inputs[kIndex1], outputs[kIndex0],
                                      OpApiUtil::GetCubeMathType()));
  RunOp(stream_ptr, workspace);
  return true;
}
MS_ACLNN_KERNEL_FACTORY_REG(MatMul, MMAclnnKernelMod);
MS_ACLNN_KERNEL_FACTORY_REG(MatMulV2, MMAclnnKernelMod);
MS_ACLNN_KERNEL_FACTORY_REG(MatMulExt, MMExtAclnnKernelMod);
} // namespace kernel
} // namespace mindspore

@@ -40,6 +40,19 @@ class MMAclnnKernelMod : public AclnnKernelMod {
  std::pair<KernelTensor *, bool> input_a_;
  std::pair<KernelTensor *, bool> input_b_;
};

class MMExtAclnnKernelMod : public AclnnKernelMod {
 public:
  MMExtAclnnKernelMod() : AclnnKernelMod("aclnnMatmul") {}
  ~MMExtAclnnKernelMod() = default;

  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override;

 private:
  DEFINE_GET_WORKSPACE_FOR_RESIZE()
};
} // namespace kernel
} // namespace mindspore
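
For context on where these kernels are used: a functional matmul entry point such as mindspore.mint.matmul is expected to lower to MatMulExt, which the registration above maps onto aclnnMatmul. A minimal usage sketch; the dispatch path and an Ascend environment are assumptions on my part, not something this diff shows:

# Hedged usage sketch; assumes an Ascend build and that mint.matmul lowers to
# MatMulExt (not shown in this diff).
import numpy as np
from mindspore import Tensor, mint

x = Tensor(np.random.randn(4, 2, 5, 6).astype(np.float32))
w = Tensor(np.random.randn(6, 7).astype(np.float32))
y = mint.matmul(x, w)  # broadcasts like numpy.matmul
print(y.shape)         # (4, 2, 5, 7)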

@@ -1282,6 +1282,8 @@ class OnnxExporter {
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimNotEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDynamicRNN(const FuncGraphPtr &func_graph, const CNodePtr &node,

@@ -3220,6 +3222,19 @@ void OnnxExporter::ExportPrimNotEqual(const FuncGraphPtr &, const CNodePtr &node
  AddOp("Not", {equal_name}, {node_name}, graph_proto);
}

void OnnxExporter::ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
  auto input_x = matmul_node->input(kOneNum);  // matmul input x
  auto input_y = matmul_node->input(kTwoNum);  // matmul input y
  auto input_b = node->input(kTwoNum);         // matmul bias

  PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}
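
ExportPrimDense reuses the MatMul primitive attached to the node's first input and appends the bias as a third input. A rough functional analogue of what the exported subgraph computes (assuming the inner MatMul carries no transpose attributes; that detail is not shown here):

# Rough functional analogue of the exported Dense node; the no-transpose
# assumption is mine, the real export reuses whatever attributes the inner
# MatMul primitive carries.
import numpy as np

def dense_reference(x, y, b):
    return np.matmul(x, y) + b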

void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {

@@ -3926,6 +3941,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
    {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
    {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
    {prim::kPrimNotEqual, &OnnxExporter::ExportPrimNotEqual},
    {prim::kPrimDense, &OnnxExporter::ExportPrimDense},
    {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
    {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
    {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},

@@ -19,6 +19,7 @@ from __future__ import absolute_import
import mindspore.numpy as mnp
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.auto_generate import MatMulExt
from mindspore.ops.primitive import _primexpr
from mindspore.common import Tensor
from mindspore.ops.operations import math_ops

@@ -310,6 +311,27 @@ def get_matmul_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(MatMulExt)
def get_matmul_ext_vmap_rule(prim, axis_size):
    """VmapRule for `*MatMulExt` operation."""
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(a_bdim, b_bdim):
        is_all_none, result = vmap_general_preprocess(prim, a_bdim, b_bdim)
        if is_all_none:
            return result

        a, _ = a_bdim
        b, _ = b_bdim

        matmul_ext = MatMulExt()
        out = matmul_ext(a, b)
        return out, 0

    return vmap_rule
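
The new rule can hand the batched inputs straight to MatMulExt and report the output batch dimension as 0 because a matmul that broadcasts over a leading dimension already yields the per-sample products. A NumPy illustration of that property (shapes are made up; this is not the vmap machinery itself):

# One broadcasted call over the leading dimension equals an explicit
# per-sample loop, which is why the vmap rule can call the primitive directly.
import numpy as np

a = np.random.randn(8, 5, 6)   # 8 is the vmapped (batch) dimension
b = np.random.randn(8, 6, 7)

batched = np.matmul(a, b)
looped = np.stack([np.matmul(a[i], b[i]) for i in range(8)])
assert np.allclose(batched, looped)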


@vmap_rules_getters.register(P.math_ops.MatrixSolve)
def get_matrix_solve_vmap_rule(prim, axis_size):
    """VmapRule for `*MatMul` operation."""

@@ -33,7 +33,7 @@ class UniformExtCell(Cell):
        return self.uniform(x, from_, to, generator)


@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.parametrize("context_mode", [

@@ -59,11 +59,11 @@ def test_basic(context_mode):
    g1.manual_seed(41)

    g2 = Generator()
    g2.manual_seed(43)
    g2.manual_seed(41)

    output1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
    expect1 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
    output2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
    output2 = uniform_cell(mindspore.tensor(x), from_, to, g1).numpy()
    expect1 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
    expect2 = uniform_cell(mindspore.tensor(x), from_, to, g2).numpy()
    np.testing.assert_allclose(output1, expect1, rtol=rtol)
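
Judging from the before/after pairs above, the fix seeds both generators with 41 and regenerates the expected values from the second generator, so the assertion now checks that identically seeded generators reproduce the same stream rather than comparing two different seeds. A NumPy analogue of that property (illustrative only, not MindSpore's Generator API):

# Two generators seeded identically produce identical random streams, so the
# "expect" values can be regenerated from a second, freshly seeded generator.
import numpy as np

g1 = np.random.default_rng(41)
g2 = np.random.default_rng(41)

out1, out2 = g1.uniform(2.0, 4.0, size=(3, 4)), g1.uniform(2.0, 4.0, size=(3, 4))
exp1, exp2 = g2.uniform(2.0, 4.0, size=(3, 4)), g2.uniform(2.0, 4.0, size=(3, 4))

np.testing.assert_allclose(out1, exp1)
np.testing.assert_allclose(out2, exp2)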

@@ -58,7 +58,6 @@ class MatMulGradCell(Cell):
    [[4, 2, 5, 6], [3, 4, 2, 6, 5], (3, 4, 2, 5, 5), (4, 2, 5, 6)],
    [[4, 1, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 1, 5, 6)],
    [[4, 2, 5, 6], [4, 2, 6, 5], (4, 2, 5, 5), (4, 2, 5, 6)],
    [[2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2, 2], (2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2)],
])
def test_ops(context_mode, shape1, shape2, output_shape1, output_shape2):
    """