forked from mindspore-Ecosystem/mindspore
commit
b554d27a2d
|
@ -56,7 +56,7 @@ def _check_dtype(dtype):
|
|||
# convert the string dtype to mstype.dtype
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
dtype = dtype_map[dtype]
|
||||
dtype = dtype_map.get(dtype, "")
|
||||
elif isinstance(dtype, type):
|
||||
if dtype is int:
|
||||
dtype = mstype.int32
|
||||
|
@ -264,13 +264,19 @@ def _promote(dtype1, dtype2):
|
|||
if dtype1 == dtype2:
|
||||
return dtype1
|
||||
if (dtype1, dtype2) in promotion_rule:
|
||||
return promotion_rule[dtype1, dtype2]
|
||||
return promotion_rule[dtype2, dtype1]
|
||||
return promotion_rule.get((dtype1, dtype2))
|
||||
res = promotion_rule.get((dtype2, dtype1))
|
||||
if res is None:
|
||||
raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
|
||||
return res
|
||||
|
||||
|
||||
@constexpr
|
||||
def _promote_for_trigonometric(dtype):
|
||||
return rule_for_trigonometric[dtype]
|
||||
res_dtype = rule_for_trigonometric.get(dtype)
|
||||
if res_dtype is None:
|
||||
raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.")
|
||||
return res_dtype
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -88,6 +88,7 @@ def get_bprop_sparse_tensor_dense_matmul(self):
|
|||
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRReduceSum)
|
||||
def get_bprop_csr_reduce_sum(self):
|
||||
"Back-propagation for CSRReduceSum."
|
||||
|
@ -103,6 +104,7 @@ def get_bprop_csr_reduce_sum(self):
|
|||
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRMV)
|
||||
def get_bprop_csr_mv(self):
|
||||
"Back-propagation for CSRMV."
|
||||
|
@ -129,6 +131,7 @@ def get_bprop_csr_mv(self):
|
|||
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRMul)
|
||||
def get_bprop_csr_mul(self):
|
||||
"Back-propagation for CSRMul."
|
||||
|
@ -153,12 +156,14 @@ def get_bprop_csr_mul(self):
|
|||
return csr_tensor_grad, dense_grad
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSR2COO)
|
||||
def get_bprop_csr2coo(self):
|
||||
def bprop(indptr, nnz, out, dout):
|
||||
return zeros_like(dout)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_csr_ops.COO2CSR)
|
||||
def get_bprop_coo2csr(self):
|
||||
def bprop(row_indices, height, out, dout):
|
||||
|
|
|
@ -1006,6 +1006,7 @@ coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')
|
|||
|
||||
@constexpr
|
||||
def print_info(info):
|
||||
"""Print given error info"""
|
||||
print(info)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue