forked from mindspore-Ecosystem/mindspore
!42851 fix dynamic rank infer bug in vmap
Merge pull request !42851 from Erpim/master_0925
This commit is contained in:
commit
d9f25439aa
|
@ -669,6 +669,9 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
|
|||
if (shape != nullptr) {
|
||||
orig_shape = shape->shape();
|
||||
}
|
||||
if (std::any_of(orig_shape.begin(), orig_shape.end(), [](ShapeValueDType s) { return s == UNKNOWN_RANK; })) {
|
||||
return orig_abs;
|
||||
}
|
||||
}
|
||||
int shape_len = SizeToInt(orig_shape.size() + 1);
|
||||
if (*axis < -shape_len || *axis >= shape_len) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
||||
|
@ -39,7 +40,8 @@ class VmapNet(nn.Cell):
|
|||
self.vmap_op = vmap(self.op, in_axes=(0, None, None, None, None), out_axes=0)
|
||||
|
||||
def construct(self, x, k, num_rows, num_cols, padding_value):
|
||||
output = self.vmap_op(x, k, num_rows, num_cols, padding_value)
|
||||
out = self.vmap_op(x, k, num_rows, num_cols, padding_value)
|
||||
output = ops.expand_dims(out, 0)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -208,10 +210,10 @@ def test_matrix_diag_v3_vmap(data_type):
|
|||
[6, 7, 9],
|
||||
[0, 9, 1]]]).astype(data_type)
|
||||
k = (-1, 1)
|
||||
expect = np.array([[[1, 8, 0],
|
||||
[4, 2, 9],
|
||||
[0, 5, 3]],
|
||||
[[6, 2, 0],
|
||||
[9, 7, 3],
|
||||
[0, 1, 9]]]).astype(data_type)
|
||||
expect = np.array([[[[1, 8, 0],
|
||||
[4, 2, 9],
|
||||
[0, 5, 3]],
|
||||
[[6, 2, 0],
|
||||
[9, 7, 3],
|
||||
[0, 1, 9]]]]).astype(data_type)
|
||||
benchmark(diagonal, expect, k=k, align="LEFT_RIGHT", is_vmap=True)
|
||||
|
|
Loading…
Reference in New Issue