!62768 [MS] fix dtype cast bugs

Merge pull request !62768 from XianglongZeng/dyn_cpu
This commit is contained in:
i-robot 2023-12-11 06:25:33 +00:00 committed by Gitee
commit 6699df751c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 25 additions and 21 deletions

View File

@ -3,7 +3,7 @@ mindspore.nn.SiLU
.. py:class:: mindspore.nn.SiLU
逐元素计算SiLU激活函数。
逐元素计算SiLU激活函数。有时也被称作Swish函数。
SiLU函数定义为

View File

@ -3,7 +3,7 @@ mindspore.ops.silu
.. py:function:: mindspore.ops.silu(x)
按输入逐元素计算激活函数SiLUSigmoid Linear Unit。该激活函数定义为
按输入逐元素计算激活函数SiLUSigmoid Linear Unit有时也被称作Swish函数。该激活函数定义为:
.. math::
\text{SiLU}(x) = x * \sigma(x),

View File

@ -400,6 +400,8 @@ inline aclScalar *ConvertType(const ScalarPtr &value) {
converter.ConvertValue(value, AttrDeclType<bool>(), &acl_scalar);
} else if (value->isa<Int64Imm>()) {
converter.ConvertValue(value, AttrDeclType<int64_t>(), &acl_scalar);
} else if (value->isa<FP64Imm>()) {
converter.ConvertValue(value, AttrDeclType<double>(), &acl_scalar);
} else if (value->isa<FP32Imm>()) {
converter.ConvertValue(value, AttrDeclType<float>(), &acl_scalar);
} else if (value->isa<Int32Imm>()) {

View File

@ -140,14 +140,17 @@ bool ValidateArgsType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Float>() || abs_type->isa<BFloat>());
}
case OP_DTYPE::DT_NUMBER: {
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Number>());
return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Number>();
}
case OP_DTYPE::DT_STR: {
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<String>());
return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<String>();
}
case OP_DTYPE::DT_TENSOR: {
return abs_arg->isa<abstract::AbstractTensor>();
}
case OP_DTYPE::DT_TYPE: {
return abs_arg->isa<abstract::AbstractType>() && abs_type->isa<Type>();
}
default: {
return ValidateArgsSequenceType(abs_arg, type_arg);
}

View File

@ -1,6 +1,7 @@
silu:
description: |
Computes SiLU (Sigmoid Linear Unit activation function) of input tensors element-wise.
Computes SiLU (Sigmoid Linear Unit activation function) of input tensors element-wise. Also known as the Swish
function.
Refer to :func:`mindspore.ops.silu` for more details.

View File

@ -29,25 +29,23 @@ def int_to_float(data):
return float(data)
def scalar_to_tuple(data, dst_type):
if dst_type == PY_DT_TUPLE_INT:
return (int(data),)
def scalar_to_tuple(data):
return (data,)
def list_to_tuple(data, dst_type):
def list_to_tuple(data):
# tuple() currently does not support Any from JIT Fallback.
res = ()
for element in data:
if dst_type == PY_DT_TUPLE_INT:
res += (int(element),)
else:
res += (element,)
res += (element,)
return res
def tensor_to_tuple(data, dst_type):
if dst_type == PY_DT_TUPLE_INT:
def tensor_to_tuple(data):
# Since tuple is not supported for precision conversion during KernelSelect, the original int32 tensor input cases
# would be failed. Thus, raise the tuple precision from int32 to int64 at frontend. But sequence data type cast
# must be adapted in future version.
if data.dtype == ms.int32:
data = ops.cast(data, ms.int64)
return tensor_to_tuple_(data)
@ -66,8 +64,8 @@ def tuple_to_tensor(data):
return ops.tuple_to_array(data)
def list_to_tensor(data, dst_type):
return ops.tuple_to_array(list_to_tuple(data, dst_type))
def list_to_tensor(data):
return ops.tuple_to_array(list_to_tuple(data))
# type
PY_DT_TYPE = OpDtype.PY_DT_TYPE.value
@ -199,18 +197,18 @@ def do_type_cast(data, dst_type):
return int_to_float(data)
elif is_tuple(dst_type):
if isinstance(data, (int, float, bool)):
return scalar_to_tuple(data, dst_type)
return scalar_to_tuple(data)
if isinstance(data, list):
return list_to_tuple(data, dst_type)
return list_to_tuple(data)
if isinstance(data, Tensor):
return tensor_to_tuple(data, dst_type)
return tensor_to_tuple(data)
elif dst_type == PY_DT_TENSOR:
if isinstance(data, (int, float, bool)):
return scalar_to_tensor(data)
if isinstance(data, tuple):
return tuple_to_tensor(data)
if isinstance(data, list):
return list_to_tensor(data, dst_type)
return list_to_tensor(data)
elif is_number(dst_type):
if isinstance(data, Tensor):
if dst_type == PY_DT_INT: