!48602 修复softplus算子空输入的bug,修改了mish算子的代码示例

Merge pull request !48602 from wangtongyu6/fix_softplus_bug_and_mish_example
This commit is contained in:
i-robot 2023-02-09 08:38:14 +00:00 committed by Gitee
commit 855acc3407
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 3 additions and 3 deletions

View File

@ -42,6 +42,7 @@ class SoftplusInfer : public abstract::OpInferBase {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
// check
MS_EXCEPTION_IF_NULL(input_args[0]);
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_index_types, prim_name);

View File

@ -649,9 +649,8 @@ class Mish(PrimitiveWithInfer):
>>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
>>> mish = ops.Mish()
>>> output = mish(x)
>>> print(output)
[[-0.3034014 3.9974129 -0.0026832]
[ 1.9439590 -0.0033576 9.0000000]]
>>> print(output.shape)
(2, 3)
"""
@prim_attr_register