forked from mindspore-Ecosystem/mindspore
!31346 add vmap ut and st
Merge pull request !31346 from Erpim/vmap_v12
This commit is contained in:
commit
91d9585df7
|
@ -848,23 +848,23 @@ ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = fal
|
|||
ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
|
||||
int in_axes_size = SizeToInt(in_axes_seq->size());
|
||||
if (nparam != in_axes_size) {
|
||||
MS_LOG(EXCEPTION) << "When vmap`s `" << axes_name
|
||||
<< "` is a tuple or list, and its size must be equal to the number of arguments of `fn`: "
|
||||
MS_LOG(EXCEPTION) << "When vmap`s '" << axes_name
|
||||
<< "' is a tuple or list, and its size must be equal to the number of arguments of 'fn': "
|
||||
<< nparam << ", but got size: " << in_axes_size << ".";
|
||||
}
|
||||
}
|
||||
bool elem_all_none = IsAxesAllNone(axes_value);
|
||||
if (elem_all_none) {
|
||||
MS_LOG(EXCEPTION) << "The `" << axes_name << "` of `vmap` cannot be all None, but got " << axes_value->ToString()
|
||||
MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be all None, but got " << axes_value->ToString()
|
||||
<< ".";
|
||||
}
|
||||
} else {
|
||||
axes_value = axes_abs->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(axes_value);
|
||||
if (axes_value->isa<None>()) {
|
||||
MS_LOG(EXCEPTION) << "The `" << axes_name << "` of `vmap` cannot be a single None.";
|
||||
MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be a single None.";
|
||||
} else if (!axes_value->isa<Int64Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap`s `" << axes_name << "` can only be of type Int or None, but got "
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap`s '" << axes_name << "' can only be of type Int or None, but got "
|
||||
<< axes_abs->ToString() << ".";
|
||||
}
|
||||
}
|
||||
|
@ -892,7 +892,7 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
|
||||
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
|
||||
if (real_fn == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to `FuncGraphAbstractClosure` failed.";
|
||||
MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to 'FuncGraphAbstractClosure' failed.";
|
||||
}
|
||||
|
||||
FuncGraphPtr orig_graph = real_fn->func_graph();
|
||||
|
|
|
@ -623,7 +623,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
|
|||
ShapeVector orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
|
||||
int shape_len = SizeToInt(orig_shape.size());
|
||||
if (*axis < -shape_len || *axis >= shape_len) {
|
||||
MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in `in_axes` is out of bounds for array of dimension ["
|
||||
MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
|
||||
<< -shape_len << "," << shape_len << ").";
|
||||
}
|
||||
*axis = *axis < 0 ? shape_len + *axis : *axis;
|
||||
|
@ -631,7 +631,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
|
|||
if (*axis_size == -1) {
|
||||
*axis_size = LongToInt(temp_axes_size);
|
||||
} else if (*axis_size != temp_axes_size) {
|
||||
MS_LOG(EXCEPTION) << "The `axes_size` of each argument in the scope of `vmap` should be equal, but got "
|
||||
MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
|
||||
<< *axis_size << " and " << temp_axes_size << ".";
|
||||
}
|
||||
(void)orig_shape.erase(orig_shape.begin() + *axis);
|
||||
|
@ -666,15 +666,15 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
|
|||
return std::make_shared<AbstractTuple>(logical_view_abs_list);
|
||||
}
|
||||
ValuePtr in_axis = in_axes;
|
||||
if (!in_axis->isa<Int64Imm>() && !in_axis->isa<None>()) {
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap's `in_axes` should be a None or a scalar of type Int64Imm, but got a "
|
||||
<< in_axis->ToString() << ".";
|
||||
}
|
||||
if (in_axis->isa<Int64Imm>()) {
|
||||
int axis = dyn_cast<Int64Imm>(in_axis)->value();
|
||||
auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
|
||||
return logical_view_abs;
|
||||
}
|
||||
if (!in_axis->isa<None>()) {
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a "
|
||||
<< in_axis->ToString() << ".";
|
||||
}
|
||||
// in_axis is None.
|
||||
return physical_view_abs;
|
||||
}
|
||||
|
@ -688,7 +688,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
|
|||
}
|
||||
int shape_len = SizeToInt(orig_shape.size() + 1);
|
||||
if (*axis < -shape_len || *axis >= shape_len) {
|
||||
MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in `out_axes` is out of bounds for array of dimension ["
|
||||
MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in 'out_axes' is out of bounds for array of dimension ["
|
||||
<< -shape_len << "," << shape_len << ").";
|
||||
}
|
||||
*axis = *axis < 0 ? shape_len + *axis : *axis;
|
||||
|
@ -700,7 +700,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
|
|||
} else if (orig_abs->isa<AbstractScalar>()) {
|
||||
out_abs = std::make_shared<abstract::AbstractTensor>(orig_abs, new_shape);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The outputs of vmap's `fn` should be consisting of tensors or constants, but got "
|
||||
MS_LOG(EXCEPTION) << "The outputs of vmap's 'fn' should be consisting of tensors or constants, but got "
|
||||
<< orig_abs->ToString() << ".";
|
||||
}
|
||||
return out_abs;
|
||||
|
@ -715,7 +715,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
auto out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
|
||||
if (out_axes_seq != nullptr) {
|
||||
if (logical_view_abs_list.size() != out_axes_seq->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of vmap's `out_axes` should be equal to the number of results of `fn`: "
|
||||
MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
|
||||
<< logical_view_abs_list.size() << ", but got size: " << out_axes_seq->size() << ".";
|
||||
}
|
||||
}
|
||||
|
@ -737,7 +737,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
} else if (sub_out_axes->isa<None>()) {
|
||||
return arg_spec;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap's `out_axes` should be a None or a scalar of type Int64Imm, but got a "
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
|
||||
<< sub_out_axes->ToString() << ".";
|
||||
});
|
||||
if (logical_view_abs->isa<AbstractList>()) {
|
||||
|
@ -746,18 +746,24 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
|
|||
return std::make_shared<AbstractTuple>(physical_view_abs_list);
|
||||
}
|
||||
|
||||
int axis = 0;
|
||||
if (out_axes->isa<None>()) {
|
||||
return logical_view_abs;
|
||||
} else if (out_axes->isa<ValueSequeue>()) {
|
||||
ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
|
||||
// for the single output case, outputs: A, and out_axes: 1 or (1,).
|
||||
ValuePtr sub_out_axes = out_axes;
|
||||
ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
|
||||
if (out_axes_seq != nullptr) {
|
||||
if (out_axes_seq->size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "The size of vmap's `out_axes` should be equal to the result size: 1, but got size: "
|
||||
MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the result size: 1, but got size: "
|
||||
<< out_axes_seq->size() << ".";
|
||||
}
|
||||
axis = dyn_cast<Int64Imm>((*out_axes_seq)[0])->value();
|
||||
} else if (out_axes->isa<Int64Imm>()) {
|
||||
axis = dyn_cast<Int64Imm>(out_axes)->value();
|
||||
sub_out_axes = (*out_axes_seq)[0];
|
||||
}
|
||||
|
||||
int axis = 0;
|
||||
auto axis_int_ptr = dyn_cast<Int64Imm>(sub_out_axes);
|
||||
if (axis_int_ptr != nullptr) {
|
||||
axis = LongToInt(axis_int_ptr->value());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
|
||||
<< sub_out_axes->ToString() << ".";
|
||||
}
|
||||
return ExtendDim(&axis, logical_view_abs, axis_size);
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
|||
}
|
||||
|
||||
py::function PrimitivePy::GetVmapRuleFunction(const bool is_side_effect, int axis_size) {
|
||||
static const char *const get_vmap_rule_func_name = "get_vmap_rule";
|
||||
constexpr char get_vmap_rule_func_name[] = "get_vmap_rule";
|
||||
if (py::hasattr(python_obj_, get_vmap_rule_func_name)) {
|
||||
py::function fn = python_obj_.attr(get_vmap_rule_func_name)().cast<py::function>();
|
||||
return fn;
|
||||
|
|
|
@ -122,20 +122,19 @@ def vmap_general_rule(prim, axis_size):
|
|||
vals_in_tuple = ()
|
||||
for val_in in args:
|
||||
val, dim = val_in
|
||||
if isinstance(val, Tensor):
|
||||
# Handle case such as args:(..., (A, 0), (B, 1), ...)
|
||||
if dim is None:
|
||||
val = _broadcast_by_axis(val, 0, axis_size)
|
||||
dim = 0
|
||||
out = P.Unstack(dim)(val)
|
||||
else:
|
||||
# Handle scalar case such as args:(..., (1, None), ...)
|
||||
if dim is not None:
|
||||
_raise_value_error("A variable of type other than `Tensor` is accepted, "
|
||||
"but the source axis is not `None`")
|
||||
out = ()
|
||||
out = ()
|
||||
if dim is None:
|
||||
# Handle case such as args:(..., (A, None), (1, None), ...)
|
||||
for _ in range(axis_size):
|
||||
out = out + (val,)
|
||||
else:
|
||||
if isinstance(val, Tensor):
|
||||
# Handle case such as args:(..., (A, 0), (B, 1), ...)
|
||||
out = P.Unstack(dim)(val)
|
||||
else:
|
||||
_raise_value_error("A variable of type other than `Tensor` is accepted, "
|
||||
"but the source axis is not `None`")
|
||||
|
||||
vals_in_tuple = vals_in_tuple + (out,)
|
||||
|
||||
if wrapped_tuple:
|
||||
|
|
|
@ -0,0 +1,354 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vmap in graph mode"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.numpy as mnp
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.ops.functional import vmap
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_cond():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. The `fn` is a `Cell`, which contains control flow operators, such as `if` and `while`.
|
||||
2. The specific VmapRule of `Switch` and `Add` operation.
|
||||
3. The `in_axes` is a single integer, which automatically match to multiple arguments.
|
||||
Expectation: success
|
||||
"""
|
||||
class CondNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CondNet, self).__init__()
|
||||
self.inner_tensor_a = Tensor(2, mstype.int32)
|
||||
self.inner_tensor_b = Tensor(5, mstype.int32)
|
||||
|
||||
def construct(self, x, y):
|
||||
a = self.inner_tensor_a + 1
|
||||
b = self.inner_tensor_b
|
||||
if a < b:
|
||||
b += a
|
||||
else:
|
||||
b -= a
|
||||
b += 5
|
||||
i = 0
|
||||
while i < 4:
|
||||
x += 1
|
||||
i += 1
|
||||
out = b + x + y
|
||||
return out
|
||||
|
||||
x_hat = Tensor([2, 3, 1], mstype.int32)
|
||||
y_hat = Tensor([5, 4, 3], mstype.int32)
|
||||
result = vmap(CondNet(), 0, 0)(x_hat, y_hat)
|
||||
expect_result = Tensor([24, 24, 21], mstype.int32)
|
||||
assert np.allclose(result.asnumpy(), expect_result.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_gradient():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. `vmap` and `grad` are used in combination.
|
||||
2. `vmap` and `jvp` are used in combination.
|
||||
Expectation: success
|
||||
"""
|
||||
def forward_fn(x, y):
|
||||
out = x + 2 * y
|
||||
out = F.sin(out)
|
||||
return F.reduce_sum(out)
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(GradNet, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self, x, y):
|
||||
out = F.grad(self.fn, grad_position=(0, 1))(x, y)
|
||||
return out
|
||||
|
||||
def vmap_fn(x, y):
|
||||
output = vmap(forward_fn, 1, 0)(x, y)
|
||||
return F.reduce_sum(output)
|
||||
|
||||
def jvp_fn(x, y, v):
|
||||
out = F.jvp(forward_fn, (x, y), (v, v))
|
||||
return out
|
||||
|
||||
x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mstype.float32)
|
||||
y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mstype.float32)
|
||||
expect_x_grad = Tensor([[0.28366217, -0.14550003, 0.0044257],
|
||||
[-0.14550003, 0.0044257, 0.13673723]], mstype.float32)
|
||||
expect_y_grad = Tensor([[0.56732434, -0.29100007, 0.0088514],
|
||||
[-0.29100007, 0.0088514, 0.27347445]], mstype.float32)
|
||||
|
||||
vmap_grad_x, vmap_grad_y = vmap(GradNet(forward_fn), 1, 1)(x_hat, y_hat)
|
||||
assert np.allclose(vmap_grad_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(vmap_grad_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
grad_vmap_x, grad_vmap_y = GradNet(vmap_fn)(x_hat, y_hat)
|
||||
assert np.allclose(grad_vmap_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(grad_vmap_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
x_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
|
||||
y_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
|
||||
v_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
|
||||
|
||||
vmap_jvp_x, vmap_jvp_y = vmap(jvp_fn, 0, 0)(x_hat, y_hat, v_hat)
|
||||
expect_x_jvp = Tensor([0.141120002, -0.279415488, 0.412118465], mstype.float32)
|
||||
expect_y_jvp = Tensor([-2.96997738, 5.76102161, -8.20017242], mstype.float32)
|
||||
assert np.allclose(vmap_jvp_x.asnumpy(), expect_x_jvp.asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(vmap_jvp_y.asnumpy(), expect_y_jvp.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_monad():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. The `fn` is a `Cell`, which contains side effect operators, such as `AssignAdd`, `Assign`,
|
||||
`Print`, `ScatterAdd`.
|
||||
2. Parameter as argument.
|
||||
Expectation: success
|
||||
"""
|
||||
class AssignNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AssignNet, self).__init__()
|
||||
self.assign = P.Assign()
|
||||
self.assign_add = P.AssignAdd()
|
||||
self.scatter_add = P.ScatterAdd()
|
||||
self.assign_ref = Parameter(Tensor([[0, 0, 0], [1, 1, 1]], mstype.float32), name='assign_ref')
|
||||
self.replace_tensor = Tensor([[1, 1, 1], [2, 2, 2]], mstype.float32)
|
||||
|
||||
def construct(self, assign_add_val, assign_add_var, scatter_ref, indices, updates):
|
||||
self.assign(self.assign_ref, self.replace_tensor)
|
||||
F.print(self.assign_ref)
|
||||
out = self.assign_add(assign_add_var, assign_add_val) + self.scatter_add(scatter_ref, indices, updates)
|
||||
return out
|
||||
|
||||
class VmapMonadNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(VmapMonadNet, self).__init__()
|
||||
self.net = net
|
||||
self.assign_add_var = Parameter(
|
||||
Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2], [2, 2, 2]]], mstype.float32),
|
||||
name='assign_add_var')
|
||||
self.scatter_ref = Parameter(
|
||||
Tensor([[[0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2]]], mstype.float32),
|
||||
name='scatter_ref')
|
||||
|
||||
def construct(self, assign_add_val, scatter_indices, scatter_updates):
|
||||
output = vmap(self.net, (0, 1, 0, 0, None), 1)(assign_add_val, self.assign_add_var,
|
||||
self.scatter_ref, scatter_indices, scatter_updates)
|
||||
return output, self.assign_add_var
|
||||
|
||||
assign_add_val = Tensor([[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]], mstype.float32)
|
||||
scatter_indices = Tensor([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]], mstype.int32)
|
||||
scatter_updates = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], mstype.int32)
|
||||
output, assign_add_var = VmapMonadNet(AssignNet())(assign_add_val, scatter_indices, scatter_updates)
|
||||
|
||||
expect_output = Tensor([[[3, 3, 3], [7, 7, 7], [8, 8, 8]], [[13, 13, 13], [11, 11, 11], [12, 12, 12]]],
|
||||
mstype.float32)
|
||||
expect_assign_add_var = Tensor([[[2, 2, 2], [2, 2, 2], [2, 2, 2]], [[4, 4, 4], [4, 4, 4], [4, 4, 4]]],
|
||||
mstype.float32)
|
||||
assert np.allclose(output.asnumpy(), expect_output.asnumpy())
|
||||
assert np.allclose(assign_add_var.asnumpy(), expect_assign_add_var.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_reduce():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. The specific VmapRule of `ReduceSum` operation.
|
||||
2. The `out_axes` is a single integer, which automatically match to multiple outputs.
|
||||
Expectation: success
|
||||
"""
|
||||
class ReduceNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ReduceNet, self).__init__()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.reduce_sum_keep_dims = P.ReduceSum(keep_dims=True)
|
||||
|
||||
def construct(self, x):
|
||||
out1 = self.reduce_sum(x)
|
||||
out2 = self.reduce_sum_keep_dims(x)
|
||||
out3 = self.reduce_sum(x, 1)
|
||||
out4 = self.reduce_sum_keep_dims(x, 1)
|
||||
out5 = self.reduce_sum(x, (0, 1))
|
||||
out6 = self.reduce_sum_keep_dims(x, (0, 1))
|
||||
output = (out1, out2, out3, out4, out5, out6)
|
||||
return output
|
||||
|
||||
class VmapNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(VmapNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
vmap_function = F.vmap(self.net, 1, 0)
|
||||
output = vmap_function(x)
|
||||
return output
|
||||
|
||||
x_hat = Tensor(np.array([[[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
|
||||
[[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
|
||||
[[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]],
|
||||
[[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
|
||||
[[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
|
||||
[[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]],
|
||||
[[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
|
||||
[[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
|
||||
[[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]]), mstype.float32)
|
||||
|
||||
result1, result2, result3, result4, result5, result6 = VmapNet(ReduceNet())(x_hat)
|
||||
expect_result1 = Tensor([108, 270, 432], mstype.float32)
|
||||
assert np.allclose(result1.asnumpy(), expect_result1.asnumpy())
|
||||
expect_result2 = Tensor([[[[108]]], [[[270]]], [[[432]]]], mstype.float32)
|
||||
assert np.allclose(result2.asnumpy(), expect_result2.asnumpy())
|
||||
expect_result3 = Tensor([[[6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6]],
|
||||
[[15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15]],
|
||||
[[24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24]]],
|
||||
mstype.float32)
|
||||
assert np.allclose(result3.asnumpy(), expect_result3.asnumpy())
|
||||
expect_result4 = Tensor([[[[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]]],
|
||||
[[[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]]],
|
||||
[[[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]]]],
|
||||
mstype.float32)
|
||||
assert np.allclose(result4.asnumpy(), expect_result4.asnumpy())
|
||||
expect_result5 = Tensor([[18, 18, 18, 18, 18, 18], [45, 45, 45, 45, 45, 45], [72, 72, 72, 72, 72, 72]],
|
||||
mstype.float32)
|
||||
assert np.allclose(result5.asnumpy(), expect_result5.asnumpy())
|
||||
expect_result6 = Tensor([[[[18, 18, 18, 18, 18, 18]]], [[[45, 45, 45, 45, 45, 45]]], [[[72, 72, 72, 72, 72, 72]]]],
|
||||
mstype.float32)
|
||||
assert np.allclose(result6.asnumpy(), expect_result6.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_general_rule():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. The general VmapRule.
|
||||
2. The specific VmapRule of `Reshape` operation.
|
||||
3. The same `vmap` object is called multiple times.
|
||||
4. The `mindspore.numpy` objects as the arguments.
|
||||
Expectation: success
|
||||
"""
|
||||
def convolve(x, w):
|
||||
output = []
|
||||
for i in range(1, len(x) - 1):
|
||||
output.append(mnp.dot(x[i - 1 : i + 2], w))
|
||||
return mnp.stack(output)
|
||||
|
||||
x = mnp.arange(5).astype('float32')
|
||||
w = mnp.array([1., 2., 3.])
|
||||
vmap_function = vmap(convolve)
|
||||
|
||||
x1 = mnp.stack([x, x, x])
|
||||
w1 = mnp.stack([w, w, w])
|
||||
result1 = vmap_function(x1, w1)
|
||||
expect_result1 = Tensor([[8, 14, 20], [8, 14, 20], [8, 14, 20]], mstype.float32)
|
||||
assert np.allclose(result1.asnumpy(), expect_result1.asnumpy())
|
||||
|
||||
x2 = mnp.stack([x, x + 1, x + 2])
|
||||
w2 = mnp.stack([w, w * 2, w * 3])
|
||||
result2 = vmap_function(x2, w2)
|
||||
expect_result2 = Tensor([[8, 14, 20], [28, 40, 52], [60, 78, 96]], mstype.float32)
|
||||
assert np.allclose(result2.asnumpy(), expect_result2.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_nested_axes():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in graph mode:
|
||||
1. The nested inputs as the vmap's arguments.
|
||||
2. One element of the `in_axes` is a minus integer.
|
||||
3. Some outputs of the function is scalars with destination axis non-None.
|
||||
4. The `in_axes` is nested Tuple and List.
|
||||
5. VmapRule for that operators with indefinite length as input, such as `Stack`.
|
||||
Expectation: success
|
||||
"""
|
||||
class AddNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AddNet, self).__init__()
|
||||
self.inner_tensor = Tensor([5, 6], mstype.float32)
|
||||
self.inner_para = Parameter(Tensor([5, 6], mstype.float32), name='inner_para')
|
||||
|
||||
def construct(self, x, y):
|
||||
a = 1
|
||||
b = 2
|
||||
c = 3
|
||||
d = self.inner_tensor + a
|
||||
e = F.stack((self.inner_para, self.inner_para))
|
||||
return ((a, b), c), d, e
|
||||
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
|
||||
((res1, res2), res3), res4, res5 = \
|
||||
vmap(AddNet(), in_axes=(1, [-1, None]), out_axes=((0, None), 0, None))(x_hat, (y_hat, z_hat))
|
||||
expect_res1 = Tensor([1, 1, 1], mstype.float32)
|
||||
expect_res2 = Tensor([2, 2, 2], mstype.float32)
|
||||
expect_res3 = 3
|
||||
expect_res4 = Tensor([[6, 7], [6, 7], [6, 7]], mstype.float32)
|
||||
expect_res5 = Tensor([[5, 6], [5, 6]], mstype.float32)
|
||||
|
||||
assert np.allclose(res1.asnumpy(), expect_res1.asnumpy())
|
||||
assert np.allclose(res2.asnumpy(), expect_res2.asnumpy())
|
||||
assert res3 == expect_res3
|
||||
assert np.allclose(res4.asnumpy(), expect_res4.asnumpy())
|
||||
assert np.allclose(res5.asnumpy(), expect_res5.asnumpy())
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vmap in pynative mode"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.ops.functional import vmap
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vmap_nested():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: This case mainly tests the following `vmap` application scenarios in PyNative mode:
|
||||
1.Calling nested `vmap` functions.
|
||||
2.`fn` is a function wrapped `ms_function`.
|
||||
3.Function contains free variables.
|
||||
Expectation: success
|
||||
"""
|
||||
outter_tensor = Tensor([1], mstype.float32)
|
||||
|
||||
def add_fn(x):
|
||||
return F.add(x, outter_tensor)
|
||||
|
||||
@ms_function
|
||||
def inner_vmap_fn(x, outter_tensor):
|
||||
vmap_funtion = vmap(add_fn, 1)
|
||||
out = vmap_funtion(x)
|
||||
output = out + outter_tensor
|
||||
return output
|
||||
|
||||
def outter_vmap_fn(x):
|
||||
output = vmap(inner_vmap_fn, (0, None), 1)(x, outter_tensor)
|
||||
return output
|
||||
|
||||
x_hat = Tensor([[[1., 2., 3.], [4., 5., 6.]],
|
||||
[[2., 3., 4.], [5., 6., 7.]],
|
||||
[[3., 4., 5.], [6., 7., 8.]],
|
||||
[[4., 5., 6.], [7., 8., 9.]]], mstype.float32)
|
||||
|
||||
result = outter_vmap_fn(x_hat)
|
||||
expect_result = Tensor([[[3., 6.], [4., 7.], [5., 8.], [6., 9.]],
|
||||
[[4., 7.], [5., 8.], [6., 9.], [7., 10.]],
|
||||
[[5., 8.], [6., 9.], [7., 10.], [8., 11.]]], mstype.float32)
|
||||
assert np.allclose(result.asnumpy(), expect_result.asnumpy())
|
|
@ -0,0 +1,220 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test vmap in graph mode"""
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class ThreeInputsTwoOutputsNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
return x + y, z
|
||||
|
||||
|
||||
def test_lambda_fn():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The first argument of `vmap` is a lambda function.
|
||||
Expectation: throw TypeError:"Parse Lambda Function Fail. Node type must be Lambda, but got Call."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(TypeError) as ex:
|
||||
vmap(lambda x, y, z: x + y + z, in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "Parse Lambda Function Fail. Node type must be Lambda, but got Call." in str(ex.value)
|
||||
|
||||
|
||||
def test_single_op():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The first argument of `vmap` is a single primitive.
|
||||
Expectation: throw RuntimeError:"'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(P.Add(), in_axes=(1, 1), out_axes=0)(x_hat, y_hat)
|
||||
assert "'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed." in str(ex.value)
|
||||
|
||||
|
||||
def test_none_in_axes():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The `in_axis` argument of `vmap` is a single None, and it's invalid when apply `vmap`.
|
||||
Expectation: throw RuntimeError:"The 'in_axes' of 'vmap' cannot be a single None."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=None, out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "The 'in_axes' of 'vmap' cannot be a single None." in str(ex.value)
|
||||
|
||||
|
||||
def test_none_out_axes():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The `out_axes` argument of `vmap` is a nested None, and it's invalid when apply `vmap`.
|
||||
Expectation: throw RuntimeError:"The 'out_axes' of 'vmap' cannot be all None, but got
|
||||
(None, None, None, (None, None))."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None),
|
||||
out_axes=(None, None, None, (None, None)))(x_hat, y_hat, z_hat)
|
||||
assert "The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." in str(ex.value)
|
||||
|
||||
|
||||
def test_mismatch_out_axes():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The `out_axes` of `vmap` sets to (0, 0, 0), but the outputs of `fn` is x + y, z.
|
||||
Expectation: throw RuntimeError:"The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2,
|
||||
but got size: 3."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 0, 0))(x_hat, y_hat, z_hat)
|
||||
assert "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2, but got size: 3." \
|
||||
in str(ex.value)
|
||||
|
||||
|
||||
def test_axis_type():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The `in_axes` of `vmap` contains elements of Float type.
|
||||
Expectation: throw RuntimeError:"The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm,
|
||||
but got a 1."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1., 1., None), out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a 1." in str(ex.value)
|
||||
|
||||
|
||||
def test_axis_out_of_bounds():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The dimension of X is 2, but the corresponding axis -3 is set.
|
||||
Expectation: throw RuntimeError:"The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(-3, 2, None), out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)." in str(ex.value)
|
||||
|
||||
|
||||
def test_mismatch_none_axis():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The source axis of the first output of `fn` is non-None, but the `out_axes` for that is None,
|
||||
it's invalid when apply `vmap`.
|
||||
Expectation: throw RuntimeError:"It is invalid that source is not None and dst is None."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(None, 0))(x_hat, y_hat, z_hat)
|
||||
assert "It is invalid that source is not None and dst is None." in str(ex.value)
|
||||
|
||||
|
||||
def test_mismatch_parameters_number():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The arguments of the cell is (x, y, z), but the arguments of vmap-ed function is (x_hat, y_hat).
|
||||
Expectation: throw TypeError:"The parameters number of the function is 3, but the number of provided arguments
|
||||
is 2."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
with pytest.raises(TypeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat)
|
||||
assert "The parameters number of the function is 3, but the number of provided arguments is 2." in str(ex.value)
|
||||
|
||||
|
||||
def test_mismatch_axis_size():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The `axis_size` of X is 3, and the `axis_size` of Y is 2, they are not equal, vmap needs to ensure
|
||||
that the `axis_size` of all parameters are uniform.
|
||||
Expectation: throw RuntimeError:"The 'axis_size' of each argument in the scope of 'vmap' should be equal,
|
||||
but got 3 and 2."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 0, None), out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got 3 and 2." in str(ex.value)
|
||||
|
||||
|
||||
def test_vmap_non_input():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The arguments of the cell is empty, it's invalid when apply `vmap`.
|
||||
Expectation: throw RuntimeError:"Failed to get 'axis_size' within the scope of vmap."
|
||||
"""
|
||||
class NonInputSingleOutputNet(nn.Cell):
|
||||
def construct(self):
|
||||
return 1
|
||||
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(NonInputSingleOutputNet())()
|
||||
assert "Failed to get 'axis_size' within the scope of vmap." in str(ex.value)
|
||||
|
||||
|
||||
def test_non_fn():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The first argument of `vmap` not provided, which is required positional argument.
|
||||
Expectation: throw TypeError:"vmap() missing 1 required positional argument: 'fn'"
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(TypeError) as ex:
|
||||
vmap(in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
|
||||
assert "vmap() missing 1 required positional argument: 'fn'" in str(ex.value)
|
||||
|
||||
|
||||
def test_scalar_with_non_zero_axis():
|
||||
"""
|
||||
Feature: vmap
|
||||
Description: The second output of `fn` is a scalar with source axis None, but get a destination axis 1, and it's
|
||||
invalid when apply `vmap`.
|
||||
Expectation: throw RuntimeError:"The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)."
|
||||
"""
|
||||
x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
|
||||
z_hat = 1
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 1))(x_hat, y_hat, z_hat)
|
||||
assert "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." in str(ex.value)
|
Loading…
Reference in New Issue