!42851 fix dynamic rank infer bug in vmap

Merge pull request !42851 from Erpim/master_0925
This commit is contained in:
i-robot 2022-09-26 01:06:55 +00:00 committed by Gitee
commit d9f25439aa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 12 additions and 7 deletions

View File

@ -669,6 +669,9 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
if (shape != nullptr) {
orig_shape = shape->shape();
}
if (std::any_of(orig_shape.begin(), orig_shape.end(), [](ShapeValueDType s) { return s == UNKNOWN_RANK; })) {
return orig_abs;
}
}
int shape_len = SizeToInt(orig_shape.size() + 1);
if (*axis < -shape_len || *axis >= shape_len) {

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore import dtype as mstype
from mindspore.ops.operations.array_ops import MatrixDiagV3
@ -39,7 +40,8 @@ class VmapNet(nn.Cell):
self.vmap_op = vmap(self.op, in_axes=(0, None, None, None, None), out_axes=0)
def construct(self, x, k, num_rows, num_cols, padding_value):
output = self.vmap_op(x, k, num_rows, num_cols, padding_value)
out = self.vmap_op(x, k, num_rows, num_cols, padding_value)
output = ops.expand_dims(out, 0)
return output
@ -208,10 +210,10 @@ def test_matrix_diag_v3_vmap(data_type):
[6, 7, 9],
[0, 9, 1]]]).astype(data_type)
k = (-1, 1)
expect = np.array([[[1, 8, 0],
[4, 2, 9],
[0, 5, 3]],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]).astype(data_type)
expect = np.array([[[[1, 8, 0],
[4, 2, 9],
[0, 5, 3]],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]]).astype(data_type)
benchmark(diagonal, expect, k=k, align="LEFT_RIGHT", is_vmap=True)