adapt Cauchy and LogNormalReverse kernelmod.
parent af9bf8fe71
commit 389e6d93d9
@@ -21,34 +21,31 @@
#include <memory>
#include <functional>
#include <random>
#include "mindspore/core/ops/cauchy.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/arithmetic_cpu_kernel.h"

namespace mindspore {
namespace kernel {
const size_t kCauchyOutputNum = 1;
namespace {
constexpr size_t kCauchyOutputNum = 1;
constexpr auto kAttrSigma = "sigma";
constexpr auto kAttrMedian = "median";
}  // namespace

// namespace

void CauchyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kCauchyOutputNum, common::AnfAlgo::GetCNodeName(kernel_node));

  std::vector<int64_t> size_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "size");
  sigma_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "sigma");
  median_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "median");
  auto y_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
  for (size_t i = 0; i < size_.size(); i++) {
    if (size_[i] <= 0) {
      MS_EXCEPTION(ValueError) << "For Cauchy, each dimension of size must be greater than zero.";
    }
    if (size_[i] != y_shape[i]) {
      MS_EXCEPTION(ValueError) << "For Cauchy, output shape not equal with size in dimension " << i << " .";
    }
  }
bool CauchyCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCauchyOutputNum, kernel_name_);
  auto prim = std::dynamic_pointer_cast<ops::Cauchy>(base_operator);
  MS_ERROR_IF_NULL(prim);
  sigma_ = prim->get_sigma();
  median_ = prim->get_median();
  return true;
}

bool CauchyCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  LaunchKernel<float>(outputs);
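The LaunchKernel<float> body is cut off by the hunk above. As a rough sketch of the technique only (an assumption, not this commit's code), a Cauchy CPU kernel typically fills its output buffer with std::cauchy_distribution samples parameterized by median_ (location) and sigma_ (scale):

// Hypothetical sketch (not from the diff): fill a buffer with Cauchy(median, sigma) samples,
// mirroring what CauchyCpuKernelMod::LaunchKernel<float> presumably does with outputs[0].
#include <cstddef>
#include <random>
#include <vector>

std::vector<float> CauchySample(std::size_t n, float median, float sigma) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::cauchy_distribution<float> dist(median, sigma);  // location = median, scale = sigma
  std::vector<float> out(n);
  for (auto &v : out) {
    v = dist(gen);
  }
  return out;
}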
@@ -27,11 +27,14 @@

namespace mindspore {
namespace kernel {
class CauchyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class CauchyCpuKernelMod : public NativeCpuKernelMod {
 public:
  CauchyCpuKernelMod() = default;
  ~CauchyCpuKernelMod() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override;
@@ -42,7 +45,8 @@ class CauchyCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &outputs);

  float sigma_ = 1.0, median_ = 0;
  float sigma_{1.0};
  float median_{0};
};
}  // namespace kernel
}  // namespace mindspore
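One thing these header hunks do not show is kernel registration. With the move off DeprecatedNativeCpuKernelMod, the accompanying .cc presumably registers the class with the CPU kernel factory; the invocation below is an assumption based on the usual MindSpore pattern, not text from this diff:

// Assumed registration (not shown in the diff): exposes CauchyCpuKernelMod to the
// CPU kernel factory under the "Cauchy" op name.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cauchy, CauchyCpuKernelMod);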
@@ -26,29 +26,38 @@
namespace mindspore {
namespace kernel {
namespace {
const uint32_t kNumInput = 1;
const uint32_t kNumOutput = 1;
constexpr uint32_t kNumInput = 1;
constexpr uint32_t kNumOutput = 1;
constexpr auto kAttrMean = "mean";
constexpr auto kAttrStd = "std";
}  // namespace

void LogNormalReverseCpuKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  if (input_dtype_ != kNumberTypeFloat32 && input_dtype_ != kNumberTypeFloat16) {
    if (input_dtype_ != kNumberTypeFloat64) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << ", the datatype of the input1 not support, support datatype: float16, float32, float64.";
    }
bool LogNormalReverseCpuKernel::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs) {
  MS_ERROR_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  auto prim = base_operator->GetPrim();
  MS_ERROR_IF_NULL(prim);
  input_mean_ = GetValue<float>(prim->GetAttr(kAttrMean));
  input_std_ = GetValue<float>(prim->GetAttr(kAttrStd));
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto is_match = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match.first) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  if (input_dtype_ != output_dtype_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << ", the data type of the input does not match the data type of the output.";
  return true;
}

int LogNormalReverseCpuKernel::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs,
                                      const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  input_mean_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "mean");
  input_std_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "std");
  input_dtype_ = inputs[kIndex0]->GetDtype();
  return KRET_OK;
}

bool LogNormalReverseCpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
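The Launch body is truncated by the hunk above. As a sketch of the underlying technique only (an assumption, not this commit's code), log-normal samples with the configured input_mean_ and input_std_ can be drawn with std::lognormal_distribution:

// Hypothetical sketch (not from the diff): fill a buffer with LogNormal(mean, std) samples,
// analogous to what LogNormalReverseCpuKernel::Launch presumably writes to outputs[0].
#include <cstddef>
#include <random>
#include <vector>

std::vector<float> LogNormalSample(std::size_t n, float mean, float std_dev) {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::lognormal_distribution<float> dist(mean, std_dev);  // parameters of the underlying normal
  std::vector<float> out(n);
  for (auto &v : out) {
    v = dist(gen);
  }
  return out;
}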
@@ -17,17 +17,22 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOGNORMALREVERSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LOGNORMALREVERSE_CPU_KERNEL_H_

#include <map>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class LogNormalReverseCpuKernel : public DeprecatedNativeCpuKernelMod {
class LogNormalReverseCpuKernel : public NativeCpuKernelMod {
 public:
  LogNormalReverseCpuKernel() = default;
  ~LogNormalReverseCpuKernel() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
@@ -43,9 +48,6 @@ class LogNormalReverseCpuKernel : public DeprecatedNativeCpuKernelMod {
    return support_list;
  }
  TypeId input_dtype_{kTypeUnknown};
  TypeId output_dtype_{kTypeUnknown};
  std::vector<int64_t> output_shape_;
  std::vector<int64_t> input_shape_;
  float input_mean_;
  float input_std_;
};
@@ -25,16 +25,48 @@

namespace mindspore {
namespace ops {
namespace {
constexpr auto kSigma = "sigma";
constexpr auto kMedian = "median";
constexpr auto kAttrSize = "size";
}  // namespace

abstract::ShapePtr CauchyInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 0, prim_name);
  MS_EXCEPTION_IF_NULL(primitive->GetAttr("size"));
  auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr("size"));
  MS_EXCEPTION_IF_NULL(primitive->GetAttr(kAttrSize));
  auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrSize));
  (void)CheckAndConvertUtils::CheckInteger("the length of 'size'", size.size(), kGreaterThan, 0, prim_name);
  for (size_t i = 0; i < size.size(); ++i) {
    if (size[i] <= 0) {
      MS_EXCEPTION(ValueError) << "For Cauchy, each dimension of size must be greater than zero.";
    }
  }
  return std::make_shared<abstract::Shape>(size);
}

void Cauchy::set_sigma(float sigma) { (void)this->AddAttr(kSigma, api::MakeValue(sigma)); }

float Cauchy::get_sigma() {
  auto value_ptr = this->GetAttr(kSigma);
  return GetValue<float>(value_ptr);
}

void Cauchy::set_median(float median) { (void)this->AddAttr(kMedian, api::MakeValue(median)); }

float Cauchy::get_median() {
  auto value_ptr = this->GetAttr(kMedian);
  return GetValue<float>(value_ptr);
}

void Cauchy::set_size(std::vector<int64_t> size) { (void)this->AddAttr(kAttrSize, api::MakeValue(size)); }

std::vector<int64_t> Cauchy::get_size() {
  auto value_ptr = this->GetAttr(kAttrSize);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

MIND_API_OPERATOR_IMPL(Cauchy, BaseOperator);

abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
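For reference (a standard probability fact, not taken from the diff): the sigma and median attributes read back by get_sigma() and get_median() are the scale and location of the Cauchy density

  f(x) = 1 / (pi * sigma * (1 + ((x - median) / sigma)^2))

so the operator is fully specified by these two scalars plus the output size.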
@@ -30,6 +30,13 @@ class MIND_API Cauchy : public BaseOperator {
  Cauchy() : BaseOperator(kNameCauchy) {}
  MIND_API_BASE_MEMBER(Cauchy);
  void Init() const {}

  void set_sigma(float);
  float get_sigma();
  void set_median(float);
  float get_median();
  void set_size(std::vector<int64_t>);
  std::vector<int64_t> get_size();
};
abstract::AbstractBasePtr CauchyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<abstract::AbstractBasePtr> &input_args);
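A quick illustration of how the accessors declared above are meant to be used together (hypothetical snippet; it mirrors how CauchyCpuKernelMod::Init reads the attributes back from the operator):

// Hypothetical usage of the ops::Cauchy accessors declared in this header.
ops::Cauchy op;
op.set_size({2, 3});   // "size" attribute: output shape, validated by CauchyInferShape
op.set_sigma(2.0f);    // "sigma" attribute: scale
op.set_median(0.5f);   // "median" attribute: location
float sigma = op.get_sigma();                // 2.0f, as read in CauchyCpuKernelMod::Init
std::vector<int64_t> size = op.get_size();   // {2, 3}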