From f26890becca6508682ef1d693554d48aeae45f6f Mon Sep 17 00:00:00 2001 From: WangXiaoyin <506413424@qq.com> Date: Wed, 7 Dec 2022 15:05:58 +0800 Subject: [PATCH] [feat] [assistant] [I4XJIC] Add Uniform --- .../device/cpu/kernel/uniform_cpu_kernel.cc | 146 ++++++++++++++++++ .../device/cpu/kernel/uniform_cpu_kernel.h | 83 ++++++++++ mindspore/core/ops/uniform.cc | 18 ++- mindspore/core/ops/uniform.h | 14 +- .../mindspore/ops/_op_impl/aicpu/uniform.py | 34 ++++ .../mindspore/ops/function/math_func.py | 4 +- .../mindspore/ops/operations/random_ops.py | 24 +-- tests/ut/python/ops/test_ops.py | 5 + 8 files changed, 314 insertions(+), 14 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/uniform.py diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.cc new file mode 100644 index 00000000000..592d340a2b5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/uniform_cpu_kernel.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mindspore/core/ops/uniform.h" +#include "kernel/common_utils.h" +#include "utils/ms_utils.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +const size_t kUniformInputsNum = 1; +const size_t kUniformOutputsNum = 1; +} // namespace + +uint64_t UniformCpuKernelMod::New64() { + std::random_device device("/dev/urandom"); + static std::mt19937_64 rng = std::mt19937_64(device()); + return (rng)(); +} + +void UniformCpuKernelMod::InitMSPhiloxRandom(int64_t seed_, int64_t offset_) { + if (seed_ == 0 && offset_ == 0) { + seed_ = New64(); + offset_ = New64(); + } + generator_ = random::MSPhiloxRandom(seed_, offset_); +} + +float UniformCpuKernelMod::RandFloat() { + uint32_t x = GenerateSingle(); + const uint32_t man = x & 0x7fffffu; // 23 bit mantissa + const uint32_t exp = static_cast(127); + const uint32_t val = (exp << 23) | man; + + float result; + memcpy_s(&result, sizeof(result), &val, sizeof(val)); + return result - 1.0f; +} + +uint32_t UniformCpuKernelMod::GenerateSingle() { + if (used_result_index_ == random::MSPhiloxRandom::kResultElementCount) { + unused_results_ = generator_(); + used_result_index_ = 0; + } + return unused_results_[used_result_index_++]; +} + +bool UniformCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + auto op = std::dynamic_pointer_cast(base_operator); + kernel_name_ = op->name(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + kernel_ptr_ = std::make_shared(base_operator->GetPrim()); + if (!is_match) { + MS_LOG(EXCEPTION) << "Uniform does not support this kernel data type: " << kernel_attr; + } + from_ = op->get_from(); + to_ = op->get_to(); + seed_ = op->get_seed(); + offset_ = op->get_offset(); + if (from_ > to_) { + MS_LOG(ERROR) << "For Uniform, 'minval' must <= 'maxval', but got 'minval'=" << from_ << " ,'maxval'=" << to_; + } + kernel_func_ = func_list_[index].second; + return true; +} + +int UniformCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + int ret = KRET_OK; + if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) { + return ret; + } + std::vector input_shape = inputs.at(kIndex0)->GetShapeVector(); + std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize); + return ret; +} + +template +bool UniformCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kUniformInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniformOutputsNum, kernel_name_); + + InitMSPhiloxRandom(seed_, offset_); + + auto y = reinterpret_cast(outputs[0]->addr); + input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), int64_t(1), std::multiplies()); + for (int64_t i = 0; i < input_elements_; i++) { + y[i] = static_cast(RandFloat() * (to_ - from_) + from_); + } + + return true; +} + +std::vector> UniformCpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &UniformCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &UniformCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &UniformCpuKernelMod::LaunchKernel}}; + +std::vector UniformCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Uniform, UniformCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.h new file mode 100644 index 00000000000..f4912dfb4ed --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/uniform_cpu_kernel.h @@ -0,0 +1,83 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_UNIFORM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_UNIFORM_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/cpu/kernel/random_util.h" + +namespace mindspore { +namespace kernel { +class UniformCpuKernelMod : public NativeCpuKernelMod { + public: + UniformCpuKernelMod() = default; + ~UniformCpuKernelMod() override = default; + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + std::vector GetOpSupport() override; + + private: + bool CheckUniformShape(); + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + using UniformFunc = std::function &, + const std::vector &)>; + + private: + random::MSPhiloxRandom generator_; + using ResType = random::Array; + ResType unused_results_; + size_t used_result_index_ = random::MSPhiloxRandom::kResultElementCount; + + float RandFloat(); + uint64_t New64(); + void InitMSPhiloxRandom(int64_t seed, int64_t offset); + uint32_t GenerateSingle(); + + static std::vector> func_list_; + UniformFunc kernel_func_; + std::vector input_shape_; + std::vector output_shape_; + int64_t input_elements_; + float from_{0.0}; + float to_{1.0}; + int64_t seed_{0}; + int64_t offset_{0}; + BaseOperatorPtr kernel_ptr_{nullptr}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_UNIFORM_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/uniform.cc b/mindspore/core/ops/uniform.cc index a0604bace5e..95b1c3acb73 100644 --- a/mindspore/core/ops/uniform.cc +++ b/mindspore/core/ops/uniform.cc @@ -21,14 +21,20 @@ namespace mindspore { namespace ops { -void Uniform::Init(float from, float to) { +void Uniform::Init(float from, float to, int64_t seed, int64_t offset) { this->set_from(from); this->set_to(to); + this->set_seed(seed); + this->set_offset(offset); } void Uniform::set_from(float from) { (void)this->AddAttr(kFrom, api::MakeValue(from)); } void Uniform::set_to(float to) { (void)this->AddAttr(kTo, api::MakeValue(to)); } +void Uniform::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); } + +void Uniform::set_offset(int64_t offset) { (void)this->AddAttr(kOffset, api::MakeValue(offset)); } + float Uniform::get_from() const { auto value_ptr = GetAttr(kFrom); return GetValue(value_ptr); @@ -39,6 +45,16 @@ float Uniform::get_to() const { return GetValue(value_ptr); } +int64_t Uniform::get_seed() const { + auto value_ptr = GetAttr(kSeed); + return GetValue(value_ptr); +} + +int64_t Uniform::get_offset() const { + auto value_ptr = GetAttr(kOffset); + return GetValue(value_ptr); +} + namespace { abstract::ShapePtr UniformInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; diff --git a/mindspore/core/ops/uniform.h b/mindspore/core/ops/uniform.h index 68e5af7ed49..3f8b611e22c 100644 --- a/mindspore/core/ops/uniform.h +++ b/mindspore/core/ops/uniform.h @@ -34,11 +34,15 @@ class MIND_API Uniform : public BaseOperator { public: Uniform() : BaseOperator(kNameUniform) { InitIOName({"x"}, {"y"}); } /// \brief Method to init the ops attributes. - void Init(const float from, const float to); + void Init(const float from, const float to, const int64_t seed, const int64_t offset); /// \brief Set from. void set_from(const float from); /// \brief Set to. void set_to(const float to); + /// \brief Set seed. + void set_seed(const int64_t seed); + /// \brief Set offset. + void set_offset(const int64_t offset); /// \brief Get from. /// /// \return from. @@ -47,6 +51,14 @@ class MIND_API Uniform : public BaseOperator { /// /// \return to. float get_to() const; + /// \brief Get seed. + /// + /// \return seed. + int64_t get_seed() const; + /// \brief Get offset. + /// + /// \return offset. + int64_t get_offset() const; MIND_API_BASE_MEMBER(Uniform); }; diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/uniform.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/uniform.py new file mode 100644 index 00000000000..a9d9ac0ef24 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/uniform.py @@ -0,0 +1,34 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Uniform op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +uniform_op_info = AiCPURegOp("Uniform") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .attr("from", "float") \ + .attr("to", "float") \ + .attr("seed", "int") \ + .attr("offset", "int") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(uniform_op_info) +def _uniform_aicpu(): + """Uniform aicpu register""" + return diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index ee2713a485f..9cbf8c529c3 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -5525,7 +5525,7 @@ def tril_indices(row, col, offset=0, dtype=mstype.int64): ``GPU`` ``CPU`` Examples: - >>> net = ops.TrilIndices(4, 3, -1, mindspore.int64) + >>> net = ops.tril_indices(4, 3, -1, mindspore.int64) >>> output = net() >>> print(output) [[1 2 2 3 3 3] @@ -5570,7 +5570,7 @@ def triu_indices(row, col, offset=0, dtype=mstype.int64): ``GPU`` ``CPU`` Examples: - >>> net = ops.TriuIndices(5, 4, 2, mindspore.int64) + >>> net = ops.triu_indices(5, 4, 2, mindspore.int64) >>> output = net() >>> print(output) [[0 0 1] diff --git a/mindspore/python/mindspore/ops/operations/random_ops.py b/mindspore/python/mindspore/ops/operations/random_ops.py index 5ac1f3a75e9..bf2642c8ce4 100755 --- a/mindspore/python/mindspore/ops/operations/random_ops.py +++ b/mindspore/python/mindspore/ops/operations/random_ops.py @@ -1005,39 +1005,43 @@ class Uniform(Primitive): Generates random numbers according to the Uniform random number distribution. Args: - min_val(float):must be non-negative. Default: 0.0. - max_val(float):must be non-negative. Default: 1.0. + minval(float):must be non-negative. Default: 0.0. + maxval(float):must be non-negative. Default: 1.0. Inputs: - **x** (Tensor) - The x of random tensor to be generated. Only constant value is allowed, and the date type is float16, float32, float64. Raises: - TypeError: If `min_val` or `max_val` is not a float. + TypeError: If `minval` or `maxval` is not a float. TypeError: If `x`is not a Tensor. + ValueError: If `minval` is larger than `maxval`. Outputs: - **output** (Tensor) - With the same type and shape as the 'x'. Supported Platforms: - ``GPU`` + ``GPU`` ``CPU`` Examples: >>> x = Tensor(np.random.randn(3,4), mstype.float64) - >>> uniform = Uniform(min_val=1.0, max_val=2.0) + >>> uniform = Uniform(minval=1.0, maxval=2.0) >>> y = uniform(x) >>> print(y.shape) (3, 4) """ @prim_attr_register - def __init__(self, min_val=0, max_val=1): + def __init__(self, minval=0., maxval=1., seed=0, offset=0): """Initialize Uniform""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - self.add_prim_attr("from", 0.0) - self.add_prim_attr("to", 1.0) - Validator.check_non_negative_float(min_val, "from", self.name) - Validator.check_non_negative_float(max_val, "to", self.name) + self.add_prim_attr("from", minval) + self.add_prim_attr("to", maxval) + Validator.check_value_type('seed', seed, [int], self.name) + Validator.check_value_type('offset', offset, [int], self.name) + Validator.check('minval', minval, 'maxval', maxval, Rel.LE, self.name) + Validator.check_non_negative_float(minval, "minval", self.name) + Validator.check_non_negative_float(maxval, "maxval", self.name) class RandpermV2(Primitive): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 840186a0e01..e00b802bb56 100644 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -82,6 +82,7 @@ from mindspore.ops.operations.array_ops import SegmentProd from mindspore.ops.operations.array_ops import ScatterAddWithAxis from mindspore.ops.operations.array_ops import ConcatOffsetV1 from mindspore.ops.operations.random_ops import NonDeterministicInts +from mindspore.ops.operations.random_ops import Uniform from mindspore.ops.operations.random_ops import TruncatedNormal from mindspore.ops.operations.random_ops import MultinomialWithReplacement from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal @@ -4310,6 +4311,10 @@ test_case_other_ops = [ 'block': NonDeterministicInts(dtype=mstype.int32), 'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)], 'skip': ['backward']}), + ('UniformOps', { + 'block': Uniform(minval=0., maxval=1., seed=1, offset=1), + 'desc_inputs': [Tensor(np.array([2, 2]), mstype.float32)], + 'skip': ['backward']}), ('TruncatedNormal', { 'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1), 'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],