forked from mindspore-Ecosystem/mindspore
Fix vmap problem
This commit is contained in:
parent
ce82191fbe
commit
03df4e4bf5
|
@ -707,9 +707,9 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
|
|||
return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
|
||||
});
|
||||
if (physical_view_abs->isa<AbstractList>()) {
|
||||
return std::make_shared<AbstractList>(logical_view_abs_list);
|
||||
return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
|
||||
}
|
||||
return std::make_shared<AbstractTuple>(logical_view_abs_list);
|
||||
return std::make_shared<AbstractTuple>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
|
||||
}
|
||||
ValuePtr in_axis = in_axes;
|
||||
if (in_axis->isa<Int64Imm>()) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import mindspore.ops.functional as F
|
|||
from mindspore import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.ops.functional import vmap
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -351,3 +352,40 @@ def test_vmap_nested_axes():
|
|||
assert res3 == expect_res3
|
||||
assert np.allclose(res4.asnumpy(), expect_res4.asnumpy())
|
||||
assert np.allclose(res5.asnumpy(), expect_res5.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_with_tuple_input():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: When vmap use tuple inputs in graph, it must ensure the inputs is not eliminated.
|
||||
Expectation: success
|
||||
"""
|
||||
def real_fn(x, y):
|
||||
return x * y
|
||||
|
||||
def foo(fn):
|
||||
@ms_function
|
||||
def wrapped(*args):
|
||||
def fn2(x, y):
|
||||
return F.jvp(fn, x, y)
|
||||
res = F.vmap(fn2)(args, args)
|
||||
return res
|
||||
return wrapped
|
||||
|
||||
shape = (2, 3)
|
||||
a = F.ones(shape, mstype.int32)
|
||||
b = F.ones(shape, mstype.int32) * 2
|
||||
res = foo(real_fn)(a, b)
|
||||
|
||||
assert isinstance(res, tuple)
|
||||
assert len(res) == 2
|
||||
assert isinstance(res[0], Tensor)
|
||||
assert isinstance(res[1], Tensor)
|
||||
assert np.allclose(res[0].asnumpy(), np.array([[2, 2, 2], [2, 2, 2]]))
|
||||
assert np.allclose(res[1].asnumpy(), np.array([[4, 4, 4], [4, 4, 4]]))
|
||||
|
|
Loading…
Reference in New Issue