From 58b468fdce2d8dceae83d293b85691bce166ff3d Mon Sep 17 00:00:00 2001 From: ZCX <769213934@qq.com> Date: Mon, 20 Dec 2021 21:45:41 +0800 Subject: [PATCH] [feat] [assistant] [I471DN] add new Ascend operator TruncatedNormal --- mindspore/ccsrc/include/common/utils/utils.h | 19 +-- .../cpu/kernel/truncated_normal_cpu_kernel.cc | 151 ++++++++++++++++++ .../cpu/kernel/truncated_normal_cpu_kernel.h | 57 +++++++ .../core/abstract/ops/primitive_infer_map.cc | 2 + mindspore/core/ops/core_ops.h | 1 + mindspore/core/ops/truncated_normal.cc | 133 +++++++++++++++ mindspore/core/ops/truncated_normal.h | 42 +++++ .../mindspore/ops/_op_impl/aicpu/__init__.py | 1 + .../ops/_op_impl/aicpu/truncated_normal.py | 37 +++++ .../mindspore/ops/operations/__init__.py | 3 +- .../mindspore/ops/operations/array_ops.py | 39 ----- .../mindspore/ops/operations/random_ops.py | 58 +++++++ tests/ut/python/ops/test_ops.py | 11 +- 13 files changed, 494 insertions(+), 60 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h create mode 100644 mindspore/core/ops/truncated_normal.cc create mode 100644 mindspore/core/ops/truncated_normal.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/truncated_normal.py diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index e6d76e76a73..2c9af21fa3a 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -353,6 +353,7 @@ constexpr auto kEnvironGetOpName = "EnvironGet"; constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll"; constexpr auto kNonDeterministicInts = "NonDeterministicInts"; constexpr auto kUpdateStateOpName = "UpdateState"; +constexpr auto kTruncatedNormal = "TruncatedNormal"; constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate"; constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush"; constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample"; @@ -779,19 +780,11 @@ const std::set kHWSpecialFormatSet = { const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; -const std::set kComputeDepend = {kUniqueOpName, - kComputeAccidentalHitsOpName, - kSubAndFilterOpName, - kPadAndShiftOpName, - kCTCGreedyDecoderOpName, - kDropoutGenMaskOpName, - kMaskedSelectOpName, - kDynamicStitchOpName, - kGetNextOpName, - kNonMaxSuppressionV3OpName, - kCoalesceOpName, - kNonDeterministicInts, - kFractionalAvgPoolGradOpName}; +const std::set kComputeDepend = { + kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName, + kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName, + kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName, kTruncatedNormal, + kNonDeterministicInts, kFractionalAvgPoolGradOpName}; const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D, kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC}; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.cc new file mode 100644 index 00000000000..3798e326a56 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h"
+#include <algorithm>
+#include <functional>
+#include <random>
+#include <utility>
+#include <vector>
+#include "Eigen/Core"
+#include "unsupported/Eigen/CXX11/Tensor"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+#include "kernel/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+const int32_t kMax = 2;
+const uint32_t kInputNum = 1;
+const uint32_t kInputDims = 1;
+const uint32_t kOutputNum = 1;
+const uint32_t kInputSizes = 2;
+}  // namespace
+
+void TruncatedNormalCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+  output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
+  seed_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
+  seed2_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
+  if (input_shape[0] < kInputSizes) {
+    MS_EXCEPTION(ValueError) << "The input tensor must have at least 2 elements.";
+  }
+  if (input_shape.size() != kInputDims) {
+    MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "TruncatedNormal does not support this kernel data type: " << kernel_attr;
+  }
+
+  kernel_func_ = func_list_[index].second;
+}
+
+bool TruncatedNormalCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
+  if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat16) {
+    LaunchKernel<int32_t, float16>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat32) {
+    LaunchKernel<int32_t, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat64) {
+    LaunchKernel<int32_t, double>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat16) {
+    LaunchKernel<int64_t, float16>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat32) {
+    LaunchKernel<int64_t, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat64) {
+    LaunchKernel<int64_t, double>(inputs, outputs);
+  } else {
+    MS_EXCEPTION(TypeError) << "The output data type must be one of float16, float32 and float64.";
+  }
+  return true;
+}
+
+template <typename T1, typename T2>
+bool TruncatedNormalCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                               const std::vector<AddressPtr> &outputs) {
+  auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
+  size_t input_elem_num = inputs[0]->size / sizeof(T1);
+  // Every entry of the shape tensor must be a positive dimension.
+  for (size_t i = 0; i < input_elem_num; i++) {
+    if (input[i] <= 0) {
+      MS_EXCEPTION(ValueError) << "Each dimension must be greater than zero.";
+    }
+  }
+
+  auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
+  size_t output_elem_num = outputs[0]->size / sizeof(T2);
+  std::random_device rd;
+  seedc_ = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : rd());
+  std::default_random_engine final_seed(seedc_);
+  if (seed_ != 0 || seed2_ != 0) {
+    flag_ = false;
+  }
+
+  // Rejection sampling: draw from N(0, 1) and keep only values inside [-kMax, kMax].
+  std::normal_distribution<double> dis(0, 1);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t j = start; j < end;) {
+      auto data = dis(final_seed);
+      if (data >= -kMax && data <= kMax) {
+        output[j++] = static_cast<T2>(data);
+      }
+    }
+  };
+  if (flag_) {
+    CPUKernelUtils::ParallelFor(task, output_elem_num);
+  } else {
+    // With a user-provided seed the generation stays single-threaded so results are reproducible.
+    for (size_t i = 0; i < output_elem_num;) {
+      auto data = dis(final_seed);
+      if (data >= -kMax && data <= kMax) {
+        output[i++] = static_cast<T2>(data);
+      }
+    }
+  }
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, TruncatedNormalCPUKernelMod::TruncatedNormalFunc>>
+  TruncatedNormalCPUKernelMod::func_list_ = {
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, double>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, double>}};
+
+std::vector<KernelAttr> TruncatedNormalCPUKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                 [](const std::pair<KernelAttr, TruncatedNormalFunc> &pair) { return pair.first; });
+  return support_list;
+}
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TruncatedNormal, TruncatedNormalCPUKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h
new file mode 100644
index 00000000000..957de75ceae
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_ + +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class TruncatedNormalCPUKernelMod : public DeprecatedNativeCpuKernelMod { + public: + TruncatedNormalCPUKernelMod() = default; + ~TruncatedNormalCPUKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + using TruncatedNormalFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + + TruncatedNormalFunc kernel_func_; + TypeId output_type_; + TypeId input_type_; + size_t seed_{0}; + size_t seed2_{0}; + size_t seedc_{0}; + bool flag_{true}; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_ diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index e52fdcd79a6..6061939870d 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -78,6 +78,7 @@ std::set GetDependsFormMap(const std::string &prim_name, size_t input_n static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name(); static const auto &kSliceGrad = prim::kPrimSliceGrad->name(); static const auto &kReshape = prim::kPrimReshape->name(); + static const auto &kTruncatedNormal = prim::kPrimTruncatedNormal->name(); static const auto &kFillV2 = prim::kPrimFillV2->name(); static const auto &kFractionalAvgPoolGrad = prim::kPrimFractionalAvgPoolGrad->name(); // Common dynamic shape depends. @@ -106,6 +107,7 @@ std::set GetDependsFormMap(const std::string &prim_name, size_t input_n {kDynamicBroadcastTo, ShapeSet{1}}, {kNonDeterministicInts, ShapeSet{0}}, {kReduceSum, ShapeSet{1}}, + {kTruncatedNormal, ShapeSet{0}}, {kRaggedRange, ShapeSet{0, 1, 2}}}; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index b28985ab240..611b2ac9dc6 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -990,6 +990,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared("StandardNormal")); GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared("RandomNormal")); GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared("NonDeterministicInts")); +GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared("TruncatedNormal")); // RL Ops GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared("TensorArrayStack")); diff --git a/mindspore/core/ops/truncated_normal.cc b/mindspore/core/ops/truncated_normal.cc new file mode 100644 index 00000000000..49c639976cd --- /dev/null +++ b/mindspore/core/ops/truncated_normal.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ops/truncated_normal.h"
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <set>
+#include <vector>
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr TruncatedNormalInferShape(const PrimitivePtr &primitive,
+                                             const std::vector<AbstractBasePtr> &input_args) {
+  if (!input_args[0]->isa<abstract::AbstractTensor>()) {
+    MS_EXCEPTION(TypeError) << "Input[0] must be a tensor!";
+  }
+  MS_EXCEPTION_IF_NULL(primitive);
+  const uint32_t kInputDims = 1;
+  const uint32_t kInputSizes = 2;
+  auto max_length_ptr = primitive->GetAttr("max_length");
+  MS_EXCEPTION_IF_NULL(max_length_ptr);
+  int64_t max_length = GetValue<int64_t>(max_length_ptr);
+  auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(input_shape);
+  auto input_shape_value_ptr = input_shape->BuildValue();
+  MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
+  auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
+  auto input_type = input_args[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(input_type);
+  auto input_type_id = input_type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(input_type_id);
+  auto input_type_element = input_type_id->element();
+  MS_EXCEPTION_IF_NULL(input_type_element);
+  auto shape_ptr = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
+  auto shape_v = shape_ptr->shape();
+  if (shape_v.size() != kInputDims) {
+    MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
+  }
+  if (shape_v[0] < kInputSizes) {
+    MS_EXCEPTION(ValueError) << "The input tensor must have at least 2 elements.";
+  }
+  if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
+    // The shape tensor is a known constant: infer a static output shape from its values.
+    std::vector<int64_t> out_shape;
+    auto shape_m = 1;
+    if (input_type_element->type_id() == kNumberTypeInt32) {
+      auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
+      for (auto i = 0; i < shape_v[0]; ++i) {
+        if (input_shape_ptr[i] > 0) {
+          out_shape.push_back(input_shape_ptr[i]);
+          shape_m *= input_shape_ptr[i];
+        } else {
+          MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
+        }
+      }
+    } else if (input_type_element->type_id() == kNumberTypeInt64) {
+      auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
+      for (auto i = 0; i < shape_v[0]; ++i) {
+        if (input_shape_ptr[i] > 0) {
+          out_shape.push_back(input_shape_ptr[i]);
+          shape_m *= input_shape_ptr[i];
+        } else {
+          MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
+        }
+      }
+    }
+    if (shape_m > max_length) {
+      MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
+                               << ", but got " << shape_m
+                               << "! The shape of output should be reduced or max_length should be increased.";
+    }
+    return std::make_shared<abstract::Shape>(out_shape);
+  } else {
+    // The shape tensor is not a known constant yet: return a dynamic shape bounded by max_length.
+    const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
+    std::vector<int64_t> output_shape;
+    ShapeVector shape_min;
+    ShapeVector shape_max;
+    for (int i = 0; i < shape_v[0]; i++) {
+      output_shape.push_back(abstract::Shape::SHP_ANY);
+      shape_min.push_back(0);
+      shape_max.push_back(input_shapes);
+    }
+    return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
+  }
+}
+
+TypePtr TruncatedNormalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = prim->name();
+  const uint32_t input_num = 1;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+  const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
+  auto dtype_value = prim->GetAttr("dtype");
+  if (!dtype_value->isa<Type>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of " + prim_name + " is invalid!";
+  }
+  auto output_type = dtype_value->cast<TypePtr>();
+  const std::set<TypePtr> valid_output_types = {kFloat16, kFloat32, kFloat64};
+  return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
+}
+}  // namespace
+
+MIND_API_BASE_IMPL(TruncatedNormal, PrimitiveC, BaseOperator);
+AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputNum = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
+  auto infer_type = TruncatedNormalInferType(primitive, input_args);
+  auto infer_shape = TruncatedNormalInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(TruncatedNormal, prim::kPrimTruncatedNormal, TruncatedNormalInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/truncated_normal.h b/mindspore/core/ops/truncated_normal.h
new file mode 100644
index 00000000000..43fa6634650
--- /dev/null
+++ b/mindspore/core/ops/truncated_normal.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_ +#define MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_ + +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kTruncatedNormal = "TruncatedNormal"; +class MIND_API TruncatedNormal : public BaseOperator { + public: + MIND_API_BASE_MEMBER(TruncatedNormal); + TruncatedNormal() : BaseOperator(kTruncatedNormal) { InitIOName({"shape"}, {"output"}); } +}; + +abstract::AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimTruncatedNormalPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore +#endif diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index bea3f046711..a759c321799 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -141,6 +141,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu from .cross import _cross_aicpu from .cummax import _cummax_aicpu from .round import _round_aicpu +from .truncated_normal import _truncated_normal_aicpu from .floor_div import _floor_div_aicpu from .non_deterministic_ints import _non_deterministic_ints_aicpu from .one_hot import _one_hot_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/truncated_normal.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/truncated_normal.py new file mode 100644 index 00000000000..7244b660ef5 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/truncated_normal.py @@ -0,0 +1,37 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""TruncatedNormal op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +truncated_normal_op_info = AiCPURegOp("TruncatedNormal")\ + .fusion_type("OPAQUE")\ + .input(0, "shape", "required")\ + .output(0, "output", "required")\ + .attr("seed", "int")\ + .attr("seed2", "int")\ + .dtype_format(DataType.I32_Default, DataType.F16_Default)\ + .dtype_format(DataType.I32_Default, DataType.F32_Default)\ + .dtype_format(DataType.I32_Default, DataType.F64_Default)\ + .dtype_format(DataType.I64_Default, DataType.F16_Default)\ + .dtype_format(DataType.I64_Default, DataType.F32_Default)\ + .dtype_format(DataType.I64_Default, DataType.F64_Default)\ + .get_op_info() + + +@op_info_register(truncated_normal_op_info) +def _truncated_normal_aicpu(): + """TruncatedNormal aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index d7e274ad90b..5ebd1a8ba3f 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -40,7 +40,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta Shape, DynamicShape, TensorShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat, Padding, UniqueWithPad, ScatterNdMax, ScatterNdMin, ScatterNdAdd, ScatterNdSub, ScatterNdMul, ScatterNdDiv, ScatterNonAliasingAdd, ReverseV2, Rint, - Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TruncatedNormal, TupleToArray, + Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, @@ -246,7 +246,6 @@ __all__ = [ 'DivNoNan', 'Inv', 'Invert', - 'TruncatedNormal', 'Fill', 'Ones', 'Zeros', diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 9b76a67d44d..912f49ed097 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -1133,45 +1133,6 @@ class Rank(PrimitiveWithInfer): return out -class TruncatedNormal(PrimitiveWithInfer): - """ - Returns a tensor of the specified shape filled with truncated normal values. - - The generated values follow a normal distribution. - - Args: - seed (int): A integer number used to create random seed. Default: 0. - dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32. - - Inputs: - - **shape** (tuple[int]) - The shape of the output tensor, is a tuple of positive integer. - - Outputs: - Tensor, the data type of output tensor is the same as attribute `dtype`. 
-
-    Examples:
-        >>> shape = (1, 2, 3)
-        >>> truncated_normal = ops.TruncatedNormal()
-        >>> output = truncated_normal(shape)
-    """
-
-    @prim_attr_register
-    def __init__(self, seed=0, dtype=mstype.float32):
-        """Initialize TruncatedNormal"""
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
-
-    def __infer__(self, shape):
-        shape_value = shape['value']
-        validator.check_value_type("shape", shape_value, [tuple], self.name)
-        for i, value in enumerate(shape_value):
-            validator.check_positive_int(value, f'{i}th value of shape', self.name)
-        out = {'shape': shape_value,
-               'dtype': mstype.tensor_type(self.dtype),
-               'value': None}
-        return out
-
-
 class Size(PrimitiveWithInfer):
     r"""
     Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
diff --git a/mindspore/python/mindspore/ops/operations/random_ops.py b/mindspore/python/mindspore/ops/operations/random_ops.py
index f8f11e35cd5..df5eeb9f08e 100644
--- a/mindspore/python/mindspore/ops/operations/random_ops.py
+++ b/mindspore/python/mindspore/ops/operations/random_ops.py
@@ -71,6 +71,64 @@ class NonDeterministicInts(Primitive):
         Validator.check_type_name("dtype", dtype, valid_values, self.name)
 
 
+class TruncatedNormal(Primitive):
+    """
+    Returns a tensor of the specified shape filled with truncated normal values.
+
+    The generated values follow a standard normal distribution; values outside the interval [-2, 2]
+    are discarded and redrawn.
+
+    .. warning::
+        Every element of `shape` must be greater than zero. The output may hold at most 1000000 elements.
+
+    Args:
+        seed (int): An optional int. Default: 0. If either `seed` or `seed2` is non-zero, the generator
+            is seeded with the given value. Otherwise, it is seeded by a random seed.
+        seed2 (int): An optional int. Default: 0. A second seed to avoid seed collision.
+        dtype (mindspore.dtype): Must be one of the following types: mindspore.float16, mindspore.float32 and
+            mindspore.float64. Default: mindspore.float32.
+
+    Inputs:
+        - **shape** (Tensor) - The shape of the random tensor to be generated. Its type must be one of the
+          following types: mindspore.int32 and mindspore.int64.
+
+    Outputs:
+        Tensor. Its shape is specified by the input `shape`. Its type is specified by `dtype`.
+        Its values are in [-2, 2].
+
+    Raises:
+        TypeError: If `shape` is not a Tensor.
+        TypeError: If `dtype` or the data type of `shape` is not allowed.
+        ValueError: If `shape` elements are not positive.
+        ValueError: If `shape` has less than 2 elements.
+        ValueError: If `shape` is not a 1-D tensor.
+        ValueError: If the number of elements of the output is more than 1000000.
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> shape = Tensor(np.array([2, 2]), mstype.int32) + >>> seed = 0 + >>> seed2 = 0 + >>> truncated_normal = ops.TruncatedNormal(seed=seed, seed2=seed2) + >>> output = truncated_normal(shape) + >>> print(output) + [[ -1.303105 0.641905 ] + [ -0.917926 0.650655 ]] + """ + + @prim_attr_register + def __init__(self, dtype=mstype.float32, seed=0, seed2=0): + """Initialize TruncatedNormal""" + self.dtype = dtype + self.add_prim_attr("max_length", 1000000) + self.init_prim_io_names(inputs=["shape"], outputs=["output"]) + Validator.check_value_type('seed', seed, [int], self.name) + Validator.check_value_type('seed2', seed2, [int], self.name) + valid_values = (mstype.float16, mstype.float32, mstype.float64) + Validator.check_type_name("dtype", dtype, valid_values, self.name) + + class StandardNormal(PrimitiveWithInfer): r""" Generates random numbers according to the standard Normal (or Gaussian) random number distribution. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index a3cee81e44e..dfb1eebafd0 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -34,6 +34,7 @@ from mindspore.ops.operations.math_ops import BesselK0, BesselK1, BesselK0e, Bes from mindspore.ops.operations import nn_ops as nps from mindspore.ops.operations.array_ops import Tril from mindspore.ops.operations.random_ops import NonDeterministicInts +from mindspore.ops.operations.random_ops import TruncatedNormal from mindspore.ops.operations.array_ops import Triu from mindspore.ops.operations.array_ops import MatrixDiagV3 from mindspore.ops.operations.array_ops import MatrixDiagPartV3 @@ -1469,12 +1470,6 @@ test_case_math_ops = [ 'block': P.Sub(), 'desc_inputs': [[3], [3]], 'desc_bprop': [[3]]}), - ('TruncatedNormal', { - 'block': P.TruncatedNormal(), - 'desc_const': [(1, 2, 3)], - 'desc_inputs': [], - 'skip': ['backward'], - 'add_fake_input': True}), ('Select', { 'block': P.Select(), 'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])), @@ -3054,6 +3049,10 @@ test_case_other_ops = [ 'block': NonDeterministicInts(dtype=mstype.int32), 'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)], 'skip': ['backward']}), + ('TruncatedNormal', { + 'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1), + 'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)], + 'skip': ['backward']}), ('ScalarLog', { 'block': F.scalar_log, 'desc_const': [0.0],
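
Reviewer note, not part of the patch: the snippet below is a minimal usage sketch of the new primitive, assembled from the docstring and the test case added above. It assumes a MindSpore build that already contains this change and a CPU (or Ascend) device target; the printed values differ between runs unless seed and seed2 are non-zero.

import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.ops.operations.random_ops import TruncatedNormal

# The patch registers both a CPU kernel and an AiCPU (Ascend) kernel; pick one of those targets.
context.set_context(device_target="CPU")

# `shape` must be a 1-D int32/int64 tensor with at least two positive elements,
# as enforced by the infer function and the kernel.
shape = Tensor(np.array([2, 3]), mstype.int32)

# Non-zero seed/seed2 make the draw reproducible; with both left at 0 a random seed is used.
truncated_normal = TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1)
output = truncated_normal(shape)

print(output.shape)  # (2, 3)
print(output)        # samples from N(0, 1) with values outside [-2, 2] rejected and redrawn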