forked from mindspore-Ecosystem/mindspore
commit
b554d27a2d
|
@ -56,7 +56,7 @@ def _check_dtype(dtype):
|
||||||
# convert the string dtype to mstype.dtype
|
# convert the string dtype to mstype.dtype
|
||||||
if isinstance(dtype, str):
|
if isinstance(dtype, str):
|
||||||
dtype = dtype.lower()
|
dtype = dtype.lower()
|
||||||
dtype = dtype_map[dtype]
|
dtype = dtype_map.get(dtype, "")
|
||||||
elif isinstance(dtype, type):
|
elif isinstance(dtype, type):
|
||||||
if dtype is int:
|
if dtype is int:
|
||||||
dtype = mstype.int32
|
dtype = mstype.int32
|
||||||
|
@ -264,13 +264,19 @@ def _promote(dtype1, dtype2):
|
||||||
if dtype1 == dtype2:
|
if dtype1 == dtype2:
|
||||||
return dtype1
|
return dtype1
|
||||||
if (dtype1, dtype2) in promotion_rule:
|
if (dtype1, dtype2) in promotion_rule:
|
||||||
return promotion_rule[dtype1, dtype2]
|
return promotion_rule.get((dtype1, dtype2))
|
||||||
return promotion_rule[dtype2, dtype1]
|
res = promotion_rule.get((dtype2, dtype1))
|
||||||
|
if res is None:
|
||||||
|
raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _promote_for_trigonometric(dtype):
|
def _promote_for_trigonometric(dtype):
|
||||||
return rule_for_trigonometric[dtype]
|
res_dtype = rule_for_trigonometric.get(dtype)
|
||||||
|
if res_dtype is None:
|
||||||
|
raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.")
|
||||||
|
return res_dtype
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
|
@ -88,6 +88,7 @@ def get_bprop_sparse_tensor_dense_matmul(self):
|
||||||
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(_csr_ops.CSRReduceSum)
|
@bprop_getters.register(_csr_ops.CSRReduceSum)
|
||||||
def get_bprop_csr_reduce_sum(self):
|
def get_bprop_csr_reduce_sum(self):
|
||||||
"Back-propagation for CSRReduceSum."
|
"Back-propagation for CSRReduceSum."
|
||||||
|
@ -103,6 +104,7 @@ def get_bprop_csr_reduce_sum(self):
|
||||||
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
|
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(_csr_ops.CSRMV)
|
@bprop_getters.register(_csr_ops.CSRMV)
|
||||||
def get_bprop_csr_mv(self):
|
def get_bprop_csr_mv(self):
|
||||||
"Back-propagation for CSRMV."
|
"Back-propagation for CSRMV."
|
||||||
|
@ -129,6 +131,7 @@ def get_bprop_csr_mv(self):
|
||||||
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
|
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(_csr_ops.CSRMul)
|
@bprop_getters.register(_csr_ops.CSRMul)
|
||||||
def get_bprop_csr_mul(self):
|
def get_bprop_csr_mul(self):
|
||||||
"Back-propagation for CSRMul."
|
"Back-propagation for CSRMul."
|
||||||
|
@ -153,12 +156,14 @@ def get_bprop_csr_mul(self):
|
||||||
return csr_tensor_grad, dense_grad
|
return csr_tensor_grad, dense_grad
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(_csr_ops.CSR2COO)
|
@bprop_getters.register(_csr_ops.CSR2COO)
|
||||||
def get_bprop_csr2coo(self):
|
def get_bprop_csr2coo(self):
|
||||||
def bprop(indptr, nnz, out, dout):
|
def bprop(indptr, nnz, out, dout):
|
||||||
return zeros_like(dout)
|
return zeros_like(dout)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(_csr_ops.COO2CSR)
|
@bprop_getters.register(_csr_ops.COO2CSR)
|
||||||
def get_bprop_coo2csr(self):
|
def get_bprop_coo2csr(self):
|
||||||
def bprop(row_indices, height, out, dout):
|
def bprop(row_indices, height, out, dout):
|
||||||
|
|
|
@ -1006,6 +1006,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def print_info(info):
|
def print_info(info):
|
||||||
|
"""Print given error info"""
|
||||||
print(info)
|
print(info)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue