adapt Cauchy and LogNormalReverse kernelmod.

This commit is contained in:
y00451588 2022-11-03 11:48:34 +08:00
parent af9bf8fe71
commit 389e6d93d9
6 changed files with 100 additions and 49 deletions

View File

@ -21,34 +21,31 @@
#include <memory>
#include <functional>
#include <random>
#include "mindspore/core/ops/cauchy.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h"
namespace mindspore {
namespace kernel {
const size_t kCauchyOutputNum = 1;
namespace {
constexpr size_t kCauchyOutputNum = 1;
constexpr auto kAttrSigma = "sigma";
constexpr auto kAttrMedian = "median";
} // namespace
// namespace
void CauchyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCauchyOutputNum, common::AnfAlgo::GetCNodeName(kernel_node));
std::vector<int64_t> size_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "size");
sigma_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "sigma");
median_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "median");
auto y_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < size_.size(); i++) {
if (size_[i] <= 0) {
MS_EXCEPTION(ValueError) << "For Cauchy, each dimension of size must be greater than zero.";
}
if (size_[i] != y_shape[i]) {
MS_EXCEPTION(ValueError) << "For Cauchy, output shape not equal with size in dimension " << i << " .";
}
}
bool CauchyCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCauchyOutputNum, kernel_name_);
auto prim = std::dynamic_pointer_cast<ops::Cauchy>(base_operator);
MS_ERROR_IF_NULL(prim);
sigma_ = prim->get_sigma();
median_ = prim->get_median();
return true;
}
bool CauchyCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
LaunchKernel<float>(outputs);

View File

@ -27,11 +27,14 @@
namespace mindspore {
namespace kernel {
class CauchyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class CauchyCpuKernelMod : public NativeCpuKernelMod {
public:
CauchyCpuKernelMod() = default;
~CauchyCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
@ -42,7 +45,8 @@ class CauchyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &outputs);
float sigma_ = 1.0, median_ = 0;
float sigma_{1.0};
float median_{0};
};
} // namespace kernel
} // namespace mindspore

View File

@ -26,29 +26,38 @@
namespace mindspore {
namespace kernel {
namespace {
const uint32_t kNumInput = 1;
const uint32_t kNumOutput = 1;
constexpr uint32_t kNumInput = 1;
constexpr uint32_t kNumOutput = 1;
constexpr auto kAttrMean = "mean";
constexpr auto kAttrStd = "std";
} // namespace
void LogNormalReverseCpuKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (input_dtype_ != kNumberTypeFloat32 && input_dtype_ != kNumberTypeFloat16) {
if (input_dtype_ != kNumberTypeFloat64) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< ", the datatype of the input1 not support, support datatype: float16, float32, float64.";
}
bool LogNormalReverseCpuKernel::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_ERROR_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto prim = base_operator->GetPrim();
MS_ERROR_IF_NULL(prim);
input_mean_ = GetValue<float>(prim->GetAttr(kAttrMean));
input_std_ = GetValue<float>(prim->GetAttr(kAttrStd));
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
if (input_dtype_ != output_dtype_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< ", the data type of the input does not match the data type of the output.";
return true;
}
int LogNormalReverseCpuKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
input_mean_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "mean");
input_std_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "std");
input_dtype_ = inputs[kIndex0]->GetDtype();
return KRET_OK;
}
bool LogNormalReverseCpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,

View File

@ -17,17 +17,22 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOGNORMALREVERSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOGNORMALREVERSE_CPU_KERNEL_H_
#include <map>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class LogNormalReverseCpuKernel : public DeprecatedNativeCpuKernelMod {
class LogNormalReverseCpuKernel : public NativeCpuKernelMod {
public:
LogNormalReverseCpuKernel() = default;
~LogNormalReverseCpuKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -43,9 +48,6 @@ class LogNormalReverseCpuKernel : public DeprecatedNativeCpuKernelMod {
return support_list;
}
TypeId input_dtype_{kTypeUnknown};
TypeId output_dtype_{kTypeUnknown};
std::vector<int64_t> output_shape_;
std::vector<int64_t> input_shape_;
float input_mean_;
float input_std_;
};

View File

@ -25,16 +25,48 @@
namespace mindspore {
namespace ops {
namespace {
constexpr auto kSigma = "sigma";
constexpr auto kMedian = "median";
constexpr auto kAttrSize = "size";
} // namespace
abstract::ShapePtr CauchyInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 0, prim_name);
MS_EXCEPTION_IF_NULL(primitive->GetAttr("size"));
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr("size"));
MS_EXCEPTION_IF_NULL(primitive->GetAttr(kAttrSize));
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrSize));
(void)CheckAndConvertUtils::CheckInteger("the length of 'size'", size.size(), kGreaterThan, 0, prim_name);
for (size_t i = 0; i < size.size(); ++i) {
if (size[i] <= 0) {
MS_EXCEPTION(ValueError) << "For Cauchy, each dimension of size must be greater than zero.";
}
}
return std::make_shared<abstract::Shape>(size);
}
void Cauchy::set_sigma(float sigma) { (void)this->AddAttr(kSigma, api::MakeValue(sigma)); }
float Cauchy::get_sigma() {
auto value_ptr = this->GetAttr(kSigma);
return GetValue<float>(value_ptr);
}
void Cauchy::set_median(float median) { (void)this->AddAttr(kMedian, api::MakeValue(median)); }
float Cauchy::get_median() {
auto value_ptr = this->GetAttr(kMedian);
return GetValue<float>(value_ptr);
}
void Cauchy::set_size(std::vector<int64_t> size) { (void)this->AddAttr(kAttrSize, api::MakeValue(size)); }
std::vector<int64_t> Cauchy::get_size() {
auto value_ptr = this->GetAttr(kAttrSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
MIND_API_OPERATOR_IMPL(Cauchy, BaseOperator);
abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -30,6 +30,13 @@ class MIND_API Cauchy : public BaseOperator {
Cauchy() : BaseOperator(kNameCauchy) {}
MIND_API_BASE_MEMBER(Cauchy);
void Init() const {}
void set_sigma(float);
float get_sigma();
void set_median(float);
float get_median();
void set_size(std::vector<int64_t>);
std::vector<int64_t> get_size();
};
abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);