!1501 add vm support for operators in mindspore.ops.operaitons package. Inlcude ACos, ACosGrad, Acosh, AcoshGrad, ArgMinD and ApplyCenteredRMSProp
Merge pull request !1501 from zhouneng/add_vm_support
This commit is contained in:
commit
6f92c4a124
|
@ -80,6 +80,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
|
||||
{"pad", "pad_d"},
|
||||
{"argmax", "arg_max_d"},
|
||||
{"argmin", "arg_min_d"},
|
||||
{"space_to_batch", "space_to_batch_d"},
|
||||
{"batch_to_space", "batch_to_space_d"},
|
||||
{"space_to_batch_nd", "space_to_batch_nd_d"},
|
||||
|
@ -100,7 +101,9 @@ static std::map<string, string> tbe_func_adapter_map = {
|
|||
{"reduce_all", "reduce_all_d"},
|
||||
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
|
||||
{"unsorted_segment_min", "unsorted_segment_min_d"},
|
||||
{"reduce_prod", "reduce_prod_d"}};
|
||||
{"reduce_prod", "reduce_prod_d"},
|
||||
{"a_cos", "acos"},
|
||||
{"a_cos_grad", "acos_grad"}};
|
||||
|
||||
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
|
||||
if (func_name == nullptr) {
|
||||
|
@ -156,8 +159,8 @@ void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node
|
|||
}
|
||||
|
||||
std::unordered_set<std::string> input_order_adjusted_ops = {
|
||||
"Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad",
|
||||
"LayerNormXBackprop", "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad"};
|
||||
"Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
|
||||
"LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};
|
||||
|
||||
void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
|
||||
nlohmann::json *inputs_json) {
|
||||
|
|
|
@ -179,6 +179,7 @@ const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
|
|||
const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
|
||||
const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
|
||||
const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
|
||||
const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
|
||||
const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
|
||||
const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
|
|
|
@ -181,6 +181,7 @@ extern const PrimitivePtr kPrimCumProd;
|
|||
extern const PrimitivePtr kPrimFlatten;
|
||||
extern const PrimitivePtr kPrimLogSoftmax;
|
||||
extern const PrimitivePtr kPrimLogSoftmaxGrad;
|
||||
extern const PrimitivePtr kPrimApplyCenteredRMSProp;
|
||||
extern const PrimitivePtr kPrimTanh;
|
||||
extern const PrimitivePtr kPrimTanhGrad;
|
||||
extern const PrimitivePtr kPrimPooling;
|
||||
|
|
|
@ -875,7 +875,9 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
|
|||
{prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
|
||||
{prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
|
||||
{prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
|
||||
{prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}};
|
||||
{prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
|
||||
{prim::kPrimApplyCenteredRMSProp->name(),
|
||||
{{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
|
||||
size_t ret = cur_index;
|
||||
auto node_name = AnfAlgo::GetCNodeName(anf_node);
|
||||
if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
|
||||
|
|
|
@ -79,6 +79,7 @@ constexpr auto kApplyAdamOpName = "Adam";
|
|||
constexpr auto kApplyAdaMaxOpName = "ApplyAdaMax";
|
||||
constexpr auto kApplyAddSignOpName = "ApplyAddSign";
|
||||
constexpr auto kApplyCenteredRMSPOpName = "ApplyCenteredRMSP";
|
||||
constexpr auto kApplyCenteredRMSPropOpName = "ApplyCenteredRMSProp";
|
||||
constexpr auto kApplyFtrlOpName = "ApplyFtrl";
|
||||
constexpr auto kApplyFtrlV2OpName = "ApplyFtrlV2";
|
||||
constexpr auto kApplyGradientDescentOpName = "ApplyGradientDescent";
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
"""tbe ops"""
|
||||
from .abs import _abs_tbe
|
||||
from .abs_grad import _abs_grad_tbe
|
||||
from .acos import _acos_tbe
|
||||
from .acos_grad import _acos_grad_tbe
|
||||
from .acosh import _acosh_tbe
|
||||
from .acosh_grad import _acosh_grad_tbe
|
||||
from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
|
||||
from .add import _add_tbe
|
||||
from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
|
||||
from .add_n import _add_n_tbe
|
||||
from .apply_ftrl import _apply_ftrl_tbe
|
||||
from .apply_momentum import _apply_momentum_tbe
|
||||
|
@ -183,6 +188,7 @@ from .arg_max import _arg_max_tbe
|
|||
from .nms_with_mask import _nms_with_mask_tbe
|
||||
from .sgd import _sgd_tbe
|
||||
from .lars_update import _lars_update_tbe
|
||||
from .arg_min import _arg_min_tbe
|
||||
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
||||
from .square_sum_all import _square_sum_all_tbe
|
||||
from .pack import _pack_tbe
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ACos op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acos_op_info = TBERegOp("ACos") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acos.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acos") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acos_op_info)
|
||||
def _acos_tbe():
|
||||
"""ACos TBE register"""
|
||||
return
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ACosGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acos_grad_op_info = TBERegOp("ACosGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acos_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acos_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "y", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acos_grad_op_info)
|
||||
def _acos_grad_tbe():
|
||||
"""ACosGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Acosh op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acosh_op_info = TBERegOp("Acosh") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acosh.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acosh") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acosh_op_info)
|
||||
def _acosh_tbe():
|
||||
"""Acosh TBE register"""
|
||||
return
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""AcoshGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acosh_grad_op_info = TBERegOp("AcoshGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acosh_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acosh_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "y", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acosh_grad_op_info)
|
||||
def _acosh_grad_tbe():
|
||||
"""AcoshGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ApplyCenteredRMSProp op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("apply_centered_rms_prop.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("apply_centered_rms_prop") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "mg", False, "required", "all") \
|
||||
.input(2, "ms", False, "required", "all") \
|
||||
.input(3, "mom", False, "required", "all") \
|
||||
.input(4, "lr", False, "required", "all") \
|
||||
.input(5, "rho", False, "required", "all") \
|
||||
.input(6, "momentum", False, "required", "all") \
|
||||
.input(7, "epsilon", False, "required", "all") \
|
||||
.input(8, "grad", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(apply_centered_rms_prop_op_info)
|
||||
def _apply_centered_rms_prop_tbe():
|
||||
"""ApplyCenteredRMSProp TBE register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Argmin op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
arg_min_op_info = TBERegOp("Argmin") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("arg_min_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("arg_min_d") \
|
||||
.partial_flag(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.attr("output_dtype", "optional", "type", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(arg_min_op_info)
|
||||
def _arg_min_tbe():
|
||||
"""Argmin TBE register"""
|
||||
return
|
|
@ -1114,8 +1114,8 @@ class Argmin(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
axis (int): Axis on which Argmin operation applies. Default: -1.
|
||||
output_type (:class:`mindspore.dtype`): An optional data type from: `mindspore.dtype.int32`,
|
||||
`mindspore.dtype.int64`. Default: `mindspore.dtype.int64`.
|
||||
output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
|
||||
Default: `mindspore.dtype.int32`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Input tensor.
|
||||
|
@ -1124,13 +1124,13 @@ class Argmin(PrimitiveWithInfer):
|
|||
Tensor, indices of the min value of input tensor across the axis.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]))
|
||||
>>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
|
||||
>>> index = P.Argmin()(input_x)
|
||||
>>> assert index == Tensor(2, mindspore.int64)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, axis=-1, output_type=mstype.int64):
|
||||
def __init__(self, axis=-1, output_type=mstype.int32):
|
||||
"""init Argmin"""
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||
validator.check_value_type("axis", axis, [int], self.name)
|
||||
|
|
|
@ -1173,10 +1173,8 @@ class AvgPool(_Pool):
|
|||
>>> result = net(input_x)
|
||||
[[[[ 2.5 3.5 4.5]
|
||||
[ 6.5 7.5 8.5]]
|
||||
|
||||
[[ 14.5 15.5 16.5]
|
||||
[ 18.5 19.5 20.5]]
|
||||
|
||||
[[ 26.5 27.5 28.5]
|
||||
[ 30.5 31.5 32.5]]]]
|
||||
"""
|
||||
|
@ -1718,15 +1716,16 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> apply_rms = P.ApplyRMSProp()
|
||||
>>> input_x = Tensor(np.random.randint(0, 256, (3, 3)),mindspore.float32)
|
||||
>>> mean_square = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
|
||||
>>> moment = Tensor(np.random.randn(3, 3), mindspore.float32)
|
||||
>>> grad = Tensor(np.random.randint(-32, 16, (3, 3)), mindspore.float32 )
|
||||
>>> learning_rate = 0.9
|
||||
>>> input_x = Tensor(1., mindspore.float32)
|
||||
>>> mean_square = Tensor(2., mindspore.float32)
|
||||
>>> moment = Tensor(1., mindspore.float32)
|
||||
>>> grad = Tensor(2., mindspore.float32 )
|
||||
>>> learning_rate = Tensor(0.9, mindspore.float32)
|
||||
>>> decay = 0.0
|
||||
>>> momentum = 1e-10
|
||||
>>> epsilon = 0.001
|
||||
>>> result = apply_rms(input_x, mean_square, moment, grad, learning_rate, decay, momentum, epsilon)
|
||||
(-2.9977674, 0.80999994, 1.9987665)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -1808,17 +1807,18 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> centered_rms_prop = P.ApplyCenteredRMSProp()
|
||||
>>> input_x = Tensor(np.random.randint(0, 256, (3, 3)),mindspore.float32)
|
||||
>>> mean_grad = Tensor(np.random.randint(-8, 8, (3, 3)), mindspore.float32)
|
||||
>>> mean_square = Tensor(np.random.randint(0, 256, (3, 3)), mindspore.float32)
|
||||
>>> moment = Tensor(np.random.randn(3, 3), mindspore.float32)
|
||||
>>> grad = Tensor(np.random.randint(-32, 16, (3, 3)), mindspore.float32 )
|
||||
>>> learning_rate = 0.9
|
||||
>>> input_x = Tensor(1., mindspore.float32)
|
||||
>>> mean_grad = Tensor(2., mindspore.float32)
|
||||
>>> mean_square = Tensor(1., mindspore.float32)
|
||||
>>> moment = Tensor(2., mindspore.float32)
|
||||
>>> grad = Tensor(1., mindspore.float32)
|
||||
>>> learning_rate = Tensor(0.9, mindspore.float32)
|
||||
>>> decay = 0.0
|
||||
>>> momentum = 1e-10
|
||||
>>> epsilon = 0.001
|
||||
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
|
||||
>>> learning_rate, decay, momentum, epsilon)
|
||||
>>> learning_rate, decay, momentum, epsilon)
|
||||
-27.460497
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -1927,7 +1927,6 @@ class L2Normalize(PrimitiveWithInfer):
|
|||
[[[-0.47247353 -0.30934513 -0.4991462 0.8185567 ]
|
||||
[-0.08070751 -0.9961299 -0.5741758 0.09262337]
|
||||
[-0.9916556 -0.3049123 0.5730487 -0.40579924]
|
||||
|
||||
[[-0.88134485 0.9509498 -0.86651784 0.57442576]
|
||||
[ 0.99673784 0.08789381 -0.8187321 0.9957012 ]
|
||||
[ 0.12891524 -0.9523804 -0.81952125 0.91396334]]]
|
||||
|
|
|
@ -359,12 +359,20 @@ test_case_math_ops = [
|
|||
'skip': ['backward']}),
|
||||
('ACos', {
|
||||
'block': P.ACos(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
'desc_bprop': [[2, 3]]}),
|
||||
'desc_inputs': [Tensor(np.array([2., 3.]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([2., 3.]).astype(np.float32))]}),
|
||||
('ACosGrad', {
|
||||
'block': G.ACosGrad(),
|
||||
'desc_inputs': [[2, 3], [2, 3]],
|
||||
'skip': ['backward']}),
|
||||
('Acosh', {
|
||||
'block': P.Acosh(),
|
||||
'desc_inputs': [[3, 4, 5]],
|
||||
'desc_bprop': [[3, 4, 5]]}),
|
||||
'desc_inputs': [Tensor(np.array([2., 3.]).astype(np.float32))],
|
||||
'desc_bprop': [Tensor(np.array([2., 3.]).astype(np.float32))]}),
|
||||
('AcoshGrad', {
|
||||
'block': G.AcoshGrad(),
|
||||
'desc_inputs': [[2, 3], [2, 3]],
|
||||
'skip': ['backward']}),
|
||||
('Sin', {
|
||||
'block': P.Sin(),
|
||||
'desc_inputs': [[2, 3]],
|
||||
|
@ -1012,8 +1020,9 @@ test_case_nn_ops = [
|
|||
('ApplyCenteredRMSProp', {
|
||||
'block': P.ApplyCenteredRMSProp(),
|
||||
'desc_const': [0.9, 0.0, 1e-10, 0.001],
|
||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
|
||||
'desc_bprop': [3, 3],
|
||||
'desc_inputs': [Tensor(1., mstype.float32), Tensor(2., mstype.float32), Tensor(1., mstype.float32),
|
||||
Tensor(2., mstype.float32), Tensor(1., mstype.float32)],
|
||||
'desc_bprop': [1],
|
||||
'skip': ['backward']}),
|
||||
('CTCLoss', {
|
||||
'block': P.CTCLoss(),
|
||||
|
|
Loading…
Reference in New Issue