forked from mindspore-Ecosystem/mindspore
!34045 refactor lp_norm gpu kernel for gpu backend.
Merge pull request !34045 from zhuzhongrui/pub_master1
This commit is contained in:
commit
7c7a8c3d9b
|
@ -72,21 +72,22 @@ template <>
|
|||
void CalLpNorm<float>(const float *input, const size_t *input_shape, size_t input_shape_length, size_t input_elements,
|
||||
const size_t *output_axis, const size_t *output_stride, size_t output_shape_length,
|
||||
size_t output_elements, float p, float eps, float *middle_output, float *output,
|
||||
cudaStream_t cuda_stream) {
|
||||
LpCalKernel<<<GET_BLOCKS(input_elements), GET_THREADS, 0, cuda_stream>>>(input, input_shape, input_shape_length,
|
||||
input_elements, output_axis, output_stride,
|
||||
output_shape_length, p, eps, output);
|
||||
NormCalKernel<<<GET_BLOCKS(output_elements), GET_THREADS, 0, cuda_stream>>>(output, output_elements, p, eps);
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream) {
|
||||
LpCalKernel<<<CUDA_BLOCKS(device_id, input_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
input, input_shape, input_shape_length, input_elements, output_axis, output_stride, output_shape_length, p, eps,
|
||||
output);
|
||||
NormCalKernel<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
output, output_elements, p, eps);
|
||||
}
|
||||
|
||||
template <>
|
||||
void CalLpNorm<half>(const half *input, const size_t *input_shape, size_t input_shape_length, size_t input_elements,
|
||||
const size_t *output_axis, const size_t *output_stride, size_t output_shape_length,
|
||||
size_t output_elements, float p, float eps, float *middle_output, half *output,
|
||||
cudaStream_t cuda_stream) {
|
||||
LpCalKernel<<<GET_BLOCKS(input_elements), GET_THREADS, 0, cuda_stream>>>(input, input_shape, input_shape_length,
|
||||
input_elements, output_axis, output_stride,
|
||||
output_shape_length, p, eps, middle_output);
|
||||
NormCalHighPrecisionKernel<<<GET_BLOCKS(output_elements), GET_THREADS, 0, cuda_stream>>>(middle_output, output,
|
||||
output_elements, p, eps);
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream) {
|
||||
LpCalKernel<<<CUDA_BLOCKS(device_id, input_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
input, input_shape, input_shape_length, input_elements, output_axis, output_stride, output_shape_length, p, eps,
|
||||
middle_output);
|
||||
NormCalHighPrecisionKernel<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
middle_output, output, output_elements, p, eps);
|
||||
}
|
||||
|
|
|
@ -16,11 +16,12 @@
|
|||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalLpNorm(const T *input, const size_t *input_shape, size_t input_shape_length,
|
||||
size_t input_elements, const size_t *output_axis, const size_t *output_stride,
|
||||
size_t output_shape_length, size_t output_elements, float p, float eps,
|
||||
float *middle_output, T *output, cudaStream_t cuda_stream_);
|
||||
float *middle_output, T *output, const uint32_t &device_id, cudaStream_t cuda_stream_);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
|
||||
|
|
|
@ -27,117 +27,84 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void LpNormGpuKernelMod::GetLpNormAttr() {
|
||||
const std::string axis = "axis";
|
||||
if (!kernel_ptr_->HasAttr(axis)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << axis;
|
||||
bool LpNormGpuKernelMod::GetLpNormAttr(const BaseOperatorPtr &base_operator) {
|
||||
if (kernel_name_ != prim::kPrimLpNorm->name()) {
|
||||
MS_LOG(ERROR) << "For '" << prim::kPrimLpNorm->name() << "' , it's kernel name must be equal to LpNorm, but got "
|
||||
<< kernel_name_;
|
||||
return false;
|
||||
}
|
||||
axis_ = GetValue<std::vector<int64_t>>(kernel_ptr_->GetAttr(axis));
|
||||
const std::string p = "p";
|
||||
if (!kernel_ptr_->HasAttr(p)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << p;
|
||||
}
|
||||
p_ = static_cast<float>(GetValue<int64_t>(kernel_ptr_->GetAttr(p)));
|
||||
auto kernel_ptr = std::make_shared<ops::LpNorm>(base_operator->GetPrim());
|
||||
|
||||
axis_ = kernel_ptr->get_axis();
|
||||
p_ = static_cast<float>(kernel_ptr->get_p());
|
||||
if (p_ == 0.0f) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "''s op attribute " << p << " equals to zero is invalid.";
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's op attribute 'p' equals to zero is invalid.";
|
||||
return false;
|
||||
}
|
||||
const std::string epsilon = "epsilon";
|
||||
if (!kernel_ptr_->HasAttr(epsilon)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << epsilon;
|
||||
}
|
||||
epsilon_ = GetValue<float>(kernel_ptr_->GetAttr(epsilon));
|
||||
epsilon_ = kernel_ptr->get_epsilon();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LpNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// A Code Block For getting launch_kernel function.
|
||||
{
|
||||
kernel_ptr_ = std::make_shared<ops::LpNorm>(base_operator->GetPrim());
|
||||
kernel_name_ = kernel_ptr_->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
|
||||
GetLpNormAttr();
|
||||
|
||||
// A Code Block For setting input and output shape.
|
||||
{
|
||||
input_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
|
||||
is_null_input_ = (input_elements_ == 0);
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
outputs_ = outputs;
|
||||
output_shape_ = std::vector<size_t>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
|
||||
std::vector<size_t> output_shape;
|
||||
// Ignore dim equal to one.
|
||||
std::copy_if(output_shape_.begin(), output_shape_.end(), std::back_inserter(output_shape),
|
||||
[](size_t dim) { return dim != 1; });
|
||||
output_shape_ = output_shape;
|
||||
std::set<size_t> axis_set(axis_.begin(), axis_.end());
|
||||
for (size_t i = 0; i < input_shape_.size(); ++i) {
|
||||
if (!axis_set.count(i)) {
|
||||
output_axis_.emplace_back(i);
|
||||
}
|
||||
}
|
||||
output_stride_.resize(output_shape_.size());
|
||||
output_stride_[output_stride_.size() - 1] = 1;
|
||||
for (int i = static_cast<int>(output_stride_.size() - 2); i >= 0; --i) {
|
||||
output_stride_[i] = output_stride_[i + 1] * output_shape[i + 1];
|
||||
}
|
||||
output_elements_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
|
||||
InitSizeLists();
|
||||
}
|
||||
|
||||
// A Code Block For dealing with input_dynamic_shape.
|
||||
{
|
||||
if (!is_input_dynamic_shape_.has_value()) {
|
||||
bool is_input_dynamic_shape = false;
|
||||
for (const auto &input : inputs) {
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t dim) { return dim < 0; })) {
|
||||
is_input_dynamic_shape = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
is_input_dynamic_shape_ = is_input_dynamic_shape;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return GetLpNormAttr(base_operator);
|
||||
}
|
||||
|
||||
int LpNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (is_input_dynamic_shape_.has_value() && is_input_dynamic_shape_.value()) {
|
||||
DestroyResource();
|
||||
ResetResource();
|
||||
if (!Init(base_operator, inputs, outputs)) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
kernel_ptr_ = base_operator;
|
||||
outputs_ = outputs;
|
||||
return 0;
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
|
||||
|
||||
input_shape_.clear();
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
|
||||
input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
|
||||
is_null_input_ = (input_elements_ == 0);
|
||||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
output_shape_.clear();
|
||||
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
// Ignore dim equal to one.
|
||||
for (const auto &dim : output_shape) {
|
||||
if (dim != 1) {
|
||||
output_shape_.emplace_back(LongToSize(dim));
|
||||
}
|
||||
}
|
||||
|
||||
output_axis_.clear();
|
||||
std::set<size_t> axis_set(axis_.begin(), axis_.end());
|
||||
for (size_t i = 0; i < input_shape_.size(); ++i) {
|
||||
if (!axis_set.count(i)) {
|
||||
output_axis_.emplace_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
output_stride_.clear();
|
||||
output_stride_.resize(output_shape_.size());
|
||||
output_stride_[output_stride_.size() - 1] = 1;
|
||||
for (int i = static_cast<int>(output_stride_.size() - 2); i >= 0; --i) {
|
||||
output_stride_[i] = output_stride_[i + 1] * output_shape[i + 1];
|
||||
}
|
||||
output_elements_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
|
||||
InitWorkSpaceSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -153,17 +120,17 @@ bool LpNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
|
|||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(device_input_shape, &input_shape_[0], input_shape_.size() * sizeof(size_t), cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"cudaMemcpyAsync input_shape_ failed");
|
||||
"LpNormGpuKernelMod cudaMemcpyAsync input_shape_ failed");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(device_axis_output, &output_axis_[0], output_axis_.size() * sizeof(size_t), cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"cudaMemcpyAsync output_axis_ failed");
|
||||
"LpNormGpuKernelMod cudaMemcpyAsync output_axis_ failed");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(device_output_stride, &output_stride_[0], output_stride_.size() * sizeof(size_t),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"cudaMemcpyAsync output_shape_ failed");
|
||||
"LpNormGpuKernelMod cudaMemcpyAsync output_shape_ failed");
|
||||
|
||||
// The workspace for device output high precision.
|
||||
if constexpr (std::is_same_v<T, half>) {
|
||||
|
@ -171,18 +138,18 @@ bool LpNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
|
|||
"cudaStremSynchronize failed");
|
||||
constexpr auto high_precision_unit = 2;
|
||||
size_t device_output_stride_size = output_elements_ * unit_size_ * high_precision_unit;
|
||||
float *middle_output = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(&middle_output, device_output_stride_size),
|
||||
"cudaMalloc output_shape_ failed");
|
||||
auto middle_output = reinterpret_cast<float *>(
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(device_output_stride_size));
|
||||
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemset(middle_output, 0, device_output_stride_size),
|
||||
"LpNormGpuKernelMod failed to set cuda memory to zeros.");
|
||||
CalLpNorm(input, device_input_shape, input_shape_.size(), input_elements_, device_axis_output, device_output_stride,
|
||||
output_axis_.size(), output_elements_, p_, epsilon_, middle_output, output,
|
||||
output_axis_.size(), output_elements_, p_, epsilon_, middle_output, output, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
} else {
|
||||
CalLpNorm(input, device_input_shape, input_shape_.size(), input_elements_, device_axis_output, device_output_stride,
|
||||
output_axis_.size(), output_elements_, p_, epsilon_, nullptr, output,
|
||||
output_axis_.size(), output_elements_, p_, epsilon_, nullptr, output, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LPNORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LPNORM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
@ -27,7 +28,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
class LpNormGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
LpNormGpuKernelMod() { ResetResource(); }
|
||||
LpNormGpuKernelMod() = default;
|
||||
~LpNormGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
@ -47,42 +48,24 @@ class LpNormGpuKernelMod : public NativeGpuKernelMod {
|
|||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
|
||||
|
||||
void ResetResource() noexcept {
|
||||
is_null_input_ = false;
|
||||
cuda_stream_ = nullptr;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
|
||||
input_size_list_.emplace_back(input_elements_ * unit_size_);
|
||||
// The workspace for device input shape.
|
||||
size_t device_input_shape_size = input_shape_.size() * sizeof(size_t);
|
||||
// The workspace for device output shape.
|
||||
size_t device_output_shape_size = output_shape_.size() * sizeof(size_t);
|
||||
// The workspace for device output axis.
|
||||
size_t device_axis_shape_size = output_axis_.size() * sizeof(size_t);
|
||||
// The workspace for device output stride.
|
||||
size_t device_output_stride_size = output_stride_.size() * sizeof(size_t);
|
||||
|
||||
workspace_size_list_.emplace_back(device_input_shape_size);
|
||||
workspace_size_list_.emplace_back(device_output_shape_size);
|
||||
workspace_size_list_.emplace_back(device_axis_shape_size);
|
||||
workspace_size_list_.emplace_back(device_output_stride_size);
|
||||
output_size_list_.emplace_back(output_elements_ * unit_size_);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitWorkSpaceSizeList() {
|
||||
// The workspace for device input shape.
|
||||
const size_t device_input_shape_size = input_shape_.size() * sizeof(size_t);
|
||||
// The workspace for device output shape.
|
||||
const size_t device_output_shape_size = output_shape_.size() * sizeof(size_t);
|
||||
// The workspace for device output axis.
|
||||
const size_t device_axis_shape_size = output_axis_.size() * sizeof(size_t);
|
||||
// The workspace for device output stride.
|
||||
const size_t device_output_stride_size = output_stride_.size() * sizeof(size_t);
|
||||
workspace_size_list_.clear();
|
||||
workspace_size_list_ = {device_input_shape_size, device_output_shape_size, device_axis_shape_size,
|
||||
device_output_stride_size};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
@ -90,25 +73,20 @@ class LpNormGpuKernelMod : public NativeGpuKernelMod {
|
|||
std::function<bool(LpNormGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
void GetLpNormAttr();
|
||||
bool GetLpNormAttr(const BaseOperatorPtr &base_operator);
|
||||
|
||||
private:
|
||||
size_t unit_size_{1};
|
||||
float p_{2.0};
|
||||
float epsilon_{1e-12};
|
||||
std::vector<int64_t> axis_;
|
||||
void *cuda_stream_{nullptr};
|
||||
bool is_null_input_{false};
|
||||
|
||||
std::optional<bool> is_input_dynamic_shape_{};
|
||||
BaseOperatorPtr kernel_ptr_{nullptr};
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_axis_;
|
||||
std::vector<size_t> output_stride_;
|
||||
size_t input_elements_{};
|
||||
size_t output_elements_{};
|
||||
std::vector<KernelTensorPtr> outputs_ = {};
|
||||
LpNormFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, LpNormFunc>> func_list_;
|
||||
};
|
||||
|
|
|
@ -110,11 +110,20 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
void LpNorm::Init(const int64_t p, const float epsilon) {
|
||||
void LpNorm::Init(const std::vector<int64_t> &axis, const int64_t p, const bool keep_dims, const float epsilon) {
|
||||
this->set_axis(axis);
|
||||
this->set_p(p);
|
||||
this->set_keep_dims(keep_dims);
|
||||
this->set_epsilon(epsilon);
|
||||
}
|
||||
|
||||
void LpNorm::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
|
||||
std::vector<int64_t> LpNorm::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void LpNorm::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
|
||||
|
||||
int64_t LpNorm::get_p() const {
|
||||
|
@ -122,6 +131,13 @@ int64_t LpNorm::get_p() const {
|
|||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void LpNorm::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
|
||||
|
||||
bool LpNorm::get_keep_dims() const {
|
||||
auto value_ptr = this->GetAttr(kKeepDims);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void LpNorm::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, api::MakeValue(epsilon)); }
|
||||
|
||||
float LpNorm::get_epsilon() const {
|
||||
|
|
|
@ -30,14 +30,23 @@ class MIND_API LpNorm : public BaseOperator {
|
|||
MIND_API_BASE_MEMBER(LpNorm);
|
||||
LpNorm() : BaseOperator(kNameLpNorm) { InitIOName({"input"}, {"output"}); }
|
||||
|
||||
void Init(const int64_t p = 2, const float epsilon = 1e-12);
|
||||
void Init(const std::vector<int64_t> &axis, const int64_t p = 2, const bool keep_dims = false,
|
||||
const float epsilon = 1e-12);
|
||||
|
||||
void set_axis(const std::vector<int64_t> &axis);
|
||||
|
||||
void set_keep_dims(const bool keep_dims);
|
||||
|
||||
void set_p(const int64_t p);
|
||||
|
||||
void set_epsilon(const float epsilon);
|
||||
|
||||
std::vector<int64_t> get_axis() const;
|
||||
|
||||
int64_t get_p() const;
|
||||
|
||||
bool get_keep_dims() const;
|
||||
|
||||
float get_epsilon() const;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue