!62768 [MS] fix dtype cast bugs
Merge pull request !62768 from XianglongZeng/dyn_cpu
This commit is contained in:
commit
6699df751c
|
@ -3,7 +3,7 @@ mindspore.nn.SiLU
|
|||
|
||||
.. py:class:: mindspore.nn.SiLU
|
||||
|
||||
逐元素计算SiLU激活函数。
|
||||
逐元素计算SiLU激活函数。有时也被称作Swish函数。
|
||||
|
||||
SiLU函数定义为:
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ mindspore.ops.silu
|
|||
|
||||
.. py:function:: mindspore.ops.silu(x)
|
||||
|
||||
按输入逐元素计算激活函数SiLU(Sigmoid Linear Unit)。该激活函数定义为:
|
||||
按输入逐元素计算激活函数SiLU(Sigmoid Linear Unit)。有时也被称作Swish函数。该激活函数定义为:
|
||||
|
||||
.. math::
|
||||
\text{SiLU}(x) = x * \sigma(x),
|
||||
|
|
|
@ -400,6 +400,8 @@ inline aclScalar *ConvertType(const ScalarPtr &value) {
|
|||
converter.ConvertValue(value, AttrDeclType<bool>(), &acl_scalar);
|
||||
} else if (value->isa<Int64Imm>()) {
|
||||
converter.ConvertValue(value, AttrDeclType<int64_t>(), &acl_scalar);
|
||||
} else if (value->isa<FP64Imm>()) {
|
||||
converter.ConvertValue(value, AttrDeclType<double>(), &acl_scalar);
|
||||
} else if (value->isa<FP32Imm>()) {
|
||||
converter.ConvertValue(value, AttrDeclType<float>(), &acl_scalar);
|
||||
} else if (value->isa<Int32Imm>()) {
|
||||
|
|
|
@ -140,14 +140,17 @@ bool ValidateArgsType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
|
|||
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Float>() || abs_type->isa<BFloat>());
|
||||
}
|
||||
case OP_DTYPE::DT_NUMBER: {
|
||||
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Number>());
|
||||
return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Number>();
|
||||
}
|
||||
case OP_DTYPE::DT_STR: {
|
||||
return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<String>());
|
||||
return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<String>();
|
||||
}
|
||||
case OP_DTYPE::DT_TENSOR: {
|
||||
return abs_arg->isa<abstract::AbstractTensor>();
|
||||
}
|
||||
case OP_DTYPE::DT_TYPE: {
|
||||
return abs_arg->isa<abstract::AbstractType>() && abs_type->isa<Type>();
|
||||
}
|
||||
default: {
|
||||
return ValidateArgsSequenceType(abs_arg, type_arg);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
silu:
|
||||
description: |
|
||||
Computes SiLU (Sigmoid Linear Unit activation function) of input tensors element-wise.
|
||||
Computes SiLU (Sigmoid Linear Unit activation function) of input tensors element-wise. Also known as the Swish
|
||||
function.
|
||||
|
||||
Refer to :func:`mindspore.ops.silu` for more details.
|
||||
|
||||
|
|
|
@ -29,25 +29,23 @@ def int_to_float(data):
|
|||
return float(data)
|
||||
|
||||
|
||||
def scalar_to_tuple(data, dst_type):
|
||||
if dst_type == PY_DT_TUPLE_INT:
|
||||
return (int(data),)
|
||||
def scalar_to_tuple(data):
|
||||
return (data,)
|
||||
|
||||
|
||||
def list_to_tuple(data, dst_type):
|
||||
def list_to_tuple(data):
|
||||
# tuple() currently does not support Any from JIT Fallback.
|
||||
res = ()
|
||||
for element in data:
|
||||
if dst_type == PY_DT_TUPLE_INT:
|
||||
res += (int(element),)
|
||||
else:
|
||||
res += (element,)
|
||||
res += (element,)
|
||||
return res
|
||||
|
||||
|
||||
def tensor_to_tuple(data, dst_type):
|
||||
if dst_type == PY_DT_TUPLE_INT:
|
||||
def tensor_to_tuple(data):
|
||||
# Since tuple is not supported for precision conversion during KernelSelect, the original int32 tensor input cases
|
||||
# would be failed. Thus, raise the tuple precision from int32 to int64 at frontend. But sequence data type cast
|
||||
# must be adapted in future version.
|
||||
if data.dtype == ms.int32:
|
||||
data = ops.cast(data, ms.int64)
|
||||
return tensor_to_tuple_(data)
|
||||
|
||||
|
@ -66,8 +64,8 @@ def tuple_to_tensor(data):
|
|||
return ops.tuple_to_array(data)
|
||||
|
||||
|
||||
def list_to_tensor(data, dst_type):
|
||||
return ops.tuple_to_array(list_to_tuple(data, dst_type))
|
||||
def list_to_tensor(data):
|
||||
return ops.tuple_to_array(list_to_tuple(data))
|
||||
|
||||
# type
|
||||
PY_DT_TYPE = OpDtype.PY_DT_TYPE.value
|
||||
|
@ -199,18 +197,18 @@ def do_type_cast(data, dst_type):
|
|||
return int_to_float(data)
|
||||
elif is_tuple(dst_type):
|
||||
if isinstance(data, (int, float, bool)):
|
||||
return scalar_to_tuple(data, dst_type)
|
||||
return scalar_to_tuple(data)
|
||||
if isinstance(data, list):
|
||||
return list_to_tuple(data, dst_type)
|
||||
return list_to_tuple(data)
|
||||
if isinstance(data, Tensor):
|
||||
return tensor_to_tuple(data, dst_type)
|
||||
return tensor_to_tuple(data)
|
||||
elif dst_type == PY_DT_TENSOR:
|
||||
if isinstance(data, (int, float, bool)):
|
||||
return scalar_to_tensor(data)
|
||||
if isinstance(data, tuple):
|
||||
return tuple_to_tensor(data)
|
||||
if isinstance(data, list):
|
||||
return list_to_tensor(data, dst_type)
|
||||
return list_to_tensor(data)
|
||||
elif is_number(dst_type):
|
||||
if isinstance(data, Tensor):
|
||||
if dst_type == PY_DT_INT:
|
||||
|
|
Loading…
Reference in New Issue