forked from mindspore-Ecosystem/mindspore
!2871 fix MatrixSetDiag
Merge pull request !2871 from jiangjinsheng/issue_fix4
This commit is contained in:
commit
e9bccf18b1
|
@ -616,7 +616,7 @@ class MatrixDiagPart(PrimitiveWithInfer):
|
|||
Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])].
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
|
||||
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
|
||||
>>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
|
||||
>>> matrix_diag_part = P.MatrixDiagPart()
|
||||
>>> result = matrix_diag_part(x, assist)
|
||||
|
@ -658,11 +658,11 @@ class MatrixSetDiag(PrimitiveWithInfer):
|
|||
Tensor, data type same as input `x`. The shape same as `x`.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
|
||||
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
|
||||
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
|
||||
>>> matrix_set_diag = P.MatrixSetDiag()
|
||||
>>> result = matrix_set_diag(x, diagonal)
|
||||
[[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
|
||||
[[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
|
||||
|
||||
"""
|
||||
|
||||
|
@ -681,10 +681,10 @@ class MatrixSetDiag(PrimitiveWithInfer):
|
|||
validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
|
||||
|
||||
if x_shape[-2] < x_shape[-1]:
|
||||
validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape",
|
||||
diagonal_shape, Rel.EQ, self.name)
|
||||
validator.check("diagnoal shape", diagonal_shape, "x shape excluding the last dimension",
|
||||
x_shape[:-1], Rel.EQ, self.name)
|
||||
else:
|
||||
validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:],
|
||||
"diagonal shape", diagonal_shape, Rel.EQ, self.name)
|
||||
validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
|
||||
x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)
|
||||
|
||||
return assist_shape
|
||||
|
|
|
@ -1851,7 +1851,7 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
>>> decay = 0.0
|
||||
>>> momentum = 1e-10
|
||||
>>> epsilon = 0.001
|
||||
>>> result = apply_rms(input_x, mean_square, moment, grad, learning_rate, decay, momentum, epsilon)
|
||||
>>> result = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon)
|
||||
(-2.9977674, 0.80999994, 1.9987665)
|
||||
"""
|
||||
|
||||
|
|
Loading…
Reference in New Issue