Fix vmap problem

This commit is contained in:
liangzhibo 2022-06-28 20:19:53 +08:00
parent ce82191fbe
commit 03df4e4bf5
2 changed files with 40 additions and 2 deletions

View File

@ -707,9 +707,9 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
});
if (physical_view_abs->isa<AbstractList>()) {
return std::make_shared<AbstractList>(logical_view_abs_list);
return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
}
return std::make_shared<AbstractTuple>(logical_view_abs_list);
return std::make_shared<AbstractTuple>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
}
ValuePtr in_axis = in_axes;
if (in_axis->isa<Int64Imm>()) {

View File

@ -24,6 +24,7 @@ import mindspore.ops.functional as F
from mindspore import dtype as mstype
from mindspore.common import Tensor
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE)
@ -351,3 +352,40 @@ def test_vmap_nested_axes():
assert res3 == expect_res3
assert np.allclose(res4.asnumpy(), expect_res4.asnumpy())
assert np.allclose(res5.asnumpy(), expect_res5.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_with_tuple_input():
"""
Feature: vmap
Description: When vmap use tuple inputs in graph, it must ensure the inputs is not eliminated.
Expectation: success
"""
def real_fn(x, y):
return x * y
def foo(fn):
@ms_function
def wrapped(*args):
def fn2(x, y):
return F.jvp(fn, x, y)
res = F.vmap(fn2)(args, args)
return res
return wrapped
shape = (2, 3)
a = F.ones(shape, mstype.int32)
b = F.ones(shape, mstype.int32) * 2
res = foo(real_fn)(a, b)
assert isinstance(res, tuple)
assert len(res) == 2
assert isinstance(res[0], Tensor)
assert isinstance(res[1], Tensor)
assert np.allclose(res[0].asnumpy(), np.array([[2, 2, 2], [2, 2, 2]]))
assert np.allclose(res[1].asnumpy(), np.array([[4, 4, 4], [4, 4, 4]]))