forked from mindspore-Ecosystem/mindspore
!48602 修复softplus算子空输入的bug,修改了mish算子的代码示例
Merge pull request !48602 from wangtongyu6/fix_softplus_bug_and_mish_example
This commit is contained in:
commit
855acc3407
|
@ -42,6 +42,7 @@ class SoftplusInfer : public abstract::OpInferBase {
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// check
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_index_types, prim_name);
|
||||
|
|
|
@ -649,9 +649,8 @@ class Mish(PrimitiveWithInfer):
|
|||
>>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
|
||||
>>> mish = ops.Mish()
|
||||
>>> output = mish(x)
|
||||
>>> print(output)
|
||||
[[-0.3034014 3.9974129 -0.0026832]
|
||||
[ 1.9439590 -0.0033576 9.0000000]]
|
||||
>>> print(output.shape)
|
||||
(2, 3)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue