forked from mindspore-Ecosystem/mindspore
!11978 Add grad impl for op MatrixInverse
From: @yuan_shen_zhou Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
c2d120e714
|
@ -165,6 +165,22 @@ def get_bprop_tensor_add(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.MatrixInverse)
|
||||
def get_bprop_matrix_inverse(self):
|
||||
"""Grad definition for `MatrixInverse` operation."""
|
||||
batchmatmul_a = P.math_ops.BatchMatMul(transpose_a=True)
|
||||
batchmatmul_b = P.math_ops.BatchMatMul(transpose_b=True)
|
||||
neg = P.Neg()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = batchmatmul_b(dout, out)
|
||||
dx = batchmatmul_a(out, dx)
|
||||
dx = neg(dx)
|
||||
return dx
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Neg)
|
||||
def get_bprop_neg(self):
|
||||
"""Grad definition for `Neg` operation."""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -4136,6 +4136,9 @@ class MatrixInverse(PrimitiveWithInfer):
|
|||
Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
|
||||
result may be returned
|
||||
|
||||
Note:
|
||||
The parameter 'adjoint' is only supporting False right now. Because complex number is not supported at present.
|
||||
|
||||
Args:
|
||||
adjoint (bool) : An optional bool. Default: False.
|
||||
|
||||
|
@ -4146,6 +4149,9 @@ class MatrixInverse(PrimitiveWithInfer):
|
|||
Outputs:
|
||||
Tensor, has the same type and shape as input `x`.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> mindspore.set_seed(1)
|
||||
>>> x = Tensor(np.random.uniform(-2, 2, (2, 2, 2)), mindspore.float32)
|
||||
|
@ -4161,7 +4167,7 @@ class MatrixInverse(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, adjoint=False):
|
||||
"""Initialize MatrixInverse"""
|
||||
validator.check_value_type("adjoint", adjoint, [bool], self.name)
|
||||
validator.check_type_name("adjoint", adjoint, False, self.name)
|
||||
self.adjoint = adjoint
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
|
|
Loading…
Reference in New Issue