[feat] [assistant] [I40GHB] add Cauchy
[feat] [assistant] [I40GHB] add Cauchy
This commit is contained in:
parent
3942f0a0d8
commit
22d82e2a0a
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/cauchy_cpu_kernel.h"
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const size_t kCauchyOutputNum = 1;
|
||||
|
||||
// namespace
|
||||
|
||||
void CauchyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCauchyOutputNum, common::AnfAlgo::GetCNodeName(kernel_node));
|
||||
|
||||
std::vector<int64_t> size_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "size");
|
||||
sigma_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "sigma");
|
||||
median_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "median");
|
||||
auto y_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < size_.size(); i++) {
|
||||
if (size_[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For Cauchy, each dimension of size must be greater than zero.";
|
||||
}
|
||||
if (size_[i] != y_shape[i]) {
|
||||
MS_EXCEPTION(ValueError) << "For Cauchy, output shape not equal with size in dimension " << i << " .";
|
||||
}
|
||||
}
|
||||
}
|
||||
bool CauchyCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
LaunchKernel<float>(outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CauchyCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &outputs) {
|
||||
T *y_data = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
std::random_device rd;
|
||||
std::default_random_engine generator(rd());
|
||||
std::cauchy_distribution<float> cauchy_d(median_, sigma_);
|
||||
auto end = outputs[0]->size / sizeof(T);
|
||||
|
||||
for (size_t i = 0; i < end; ++i) {
|
||||
float data = cauchy_d(generator);
|
||||
y_data[i] = static_cast<T>(data);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CauchyCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr().AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr().AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cauchy, CauchyCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CAUCHY_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CAUCHY_CPU_KERNEL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class CauchyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
CauchyCpuKernelMod() = default;
|
||||
~CauchyCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &outputs);
|
||||
|
||||
float sigma_ = 1.0, median_ = 0;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CAUCHY_CPU_KERNEL_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/cauchy.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::ShapePtr CauchyInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 0, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(primitive->GetAttr("size"));
|
||||
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr("size"));
|
||||
(void)CheckAndConvertUtils::CheckInteger("the length of 'size'", size.size(), kGreaterThan, 0, prim_name);
|
||||
return std::make_shared<abstract::Shape>(size);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(Cauchy, BaseOperator);
|
||||
|
||||
abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
||||
auto infer_shape = CauchyInferShape(primitive, input_args);
|
||||
auto infer_type = std::make_shared<TensorType>(kFloat32);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Cauchy, prim::kPrimCauchy, CauchyInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CAUCHY_H_
|
||||
#define MINDSPORE_CORE_OPS_CAUCHY_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCauchy = "Cauchy";
|
||||
class MIND_API Cauchy : public BaseOperator {
|
||||
public:
|
||||
Cauchy() : BaseOperator(kNameCauchy) {}
|
||||
MIND_API_BASE_MEMBER(Cauchy);
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimCauchy = std::shared_ptr<Cauchy>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_Cauchy_H_
|
|
@ -133,6 +133,7 @@ constexpr auto kBernoulli = "Bernoulli";
|
|||
constexpr auto kLinearSumAssignment = "LinearSumAssignment";
|
||||
|
||||
// Math
|
||||
constexpr auto kCauchy = "Cauchy";
|
||||
constexpr auto kCross = "Cross";
|
||||
constexpr auto kEditDistance = "EditDistance";
|
||||
constexpr auto kNextAfter = "NextAfter";
|
||||
|
@ -1128,6 +1129,7 @@ GVAR_DEF(PrimitivePtr, kPrimTensorListStack, std::make_shared<Primitive>("Tensor
|
|||
GVAR_DEF(PrimitivePtr, kPrimTensorListSetItem, std::make_shared<Primitive>("TensorListSetItem"));
|
||||
|
||||
// Maths
|
||||
GVAR_DEF(PrimitivePtr, kPrimCauchy, std::make_shared<Primitive>(kCauchy));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNextAfter, std::make_shared<Primitive>(kNextAfter));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCross, std::make_shared<Primitive>(kCross));
|
||||
GVAR_DEF(PrimitivePtr, kPrimEditDistance, std::make_shared<Primitive>(kEditDistance));
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cauchy op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
cauchy_op_info = AiCPURegOp("Cauchy") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("size", "listInt") \
|
||||
.attr("sigma", "float") \
|
||||
.attr("median", "float") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default)\
|
||||
.dtype_format(DataType.F32_Default)\
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(cauchy_op_info)
|
||||
def _cauchy_aicpu():
|
||||
"""Cauchy AiCPU register"""
|
||||
return
|
|
@ -7162,3 +7162,47 @@ class Qr(Primitive):
|
|||
def __init__(self, full_matrices=False):
|
||||
"""Initialize Qr"""
|
||||
validator.check_value_type('full_matrices', full_matrices, [bool], self.name)
|
||||
|
||||
|
||||
class Cauchy(Primitive):
|
||||
r"""
|
||||
Create a tensor of shape `size` with random numbers drawn from Cauchy distribution
|
||||
|
||||
.. math::
|
||||
\f(x)= \frac{1}{\pi} \frac{\sigma}{(x-median)^2 +\sigma^2}
|
||||
|
||||
Args:
|
||||
size (list(int)): The size of tensor.
|
||||
sigma (float): the location parameter, specifying the location
|
||||
of the peak of the distribution. Default: 1.0.
|
||||
median (float): the scale parameter which specifies the half-width
|
||||
at half-maximum. Default: 0.0.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - Tensor with cauchy distribution data. Tensor shape is size, and data type is float32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `sigma` is not a float.
|
||||
TypeError: If `median` is not a float.
|
||||
TypeError: If `size` is not a list.
|
||||
ValueError: If `size` list is empty.
|
||||
ValueError: If data of `size` is not a positive integer.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> size = [1]
|
||||
>>> net = ops.Cauchy(size)
|
||||
>>> y = net()
|
||||
>>> print(y)
|
||||
[0.03128606]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, size, median=0.0, sigma=1.0):
|
||||
validator.check_value_type('median', median, [float], self.name)
|
||||
validator.check_value_type('sigma', sigma, [float], self.name)
|
||||
validator.check_value_type('size', size, (list), self.name)
|
||||
for index, size_ in enumerate(size):
|
||||
validator.check_positive_int(size_, 'size[%d]' % index, self.name)
|
||||
|
|
|
@ -56,6 +56,7 @@ from mindspore.ops.operations.math_ops import CompareAndBitpack
|
|||
from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
|
||||
from mindspore.ops.operations.math_ops import STFT
|
||||
from mindspore.ops.operations.math_ops import Qr
|
||||
from mindspore.ops.operations.math_ops import Cauchy
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import FillDiagonal
|
||||
from mindspore.ops.operations.array_ops import Im2Col
|
||||
|
@ -1432,6 +1433,10 @@ class BincountNet(nn.Cell):
|
|||
|
||||
|
||||
test_case_math_ops = [
|
||||
('Cauchy', {
|
||||
'block': Cauchy(size=[2, 3]),
|
||||
'desc_inputs': [],
|
||||
'skip': ['backward']}),
|
||||
('Betainc', {
|
||||
'block': Betainc(),
|
||||
'desc_inputs': [Tensor([1, 1, 1, 1], mstype.float32),
|
||||
|
|
Loading…
Reference in New Issue