diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cu
index bca840e6de9..e5f19aaecbb 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cu
@@ -72,21 +72,22 @@ template <>
 void CalLpNorm<float>(const float *input, const size_t *input_shape, size_t input_shape_length, size_t input_elements,
                       const size_t *output_axis, const size_t *output_stride, size_t output_shape_length,
                       size_t output_elements, float p, float eps, float *middle_output, float *output,
-                      cudaStream_t cuda_stream) {
-  LpCalKernel<<<GET_BLOCKS(input_elements), GET_THREADS, 0, cuda_stream>>>(input, input_shape, input_shape_length,
-                                                                           input_elements, output_axis, output_stride,
-                                                                           output_shape_length, p, eps, output);
-  NormCalKernel<<<GET_BLOCKS(output_elements), GET_THREADS, 0, cuda_stream>>>(output, output_elements, p, eps);
+                      const uint32_t &device_id, cudaStream_t cuda_stream) {
+  LpCalKernel<<<CUDA_BLOCKS(device_id, input_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    input, input_shape, input_shape_length, input_elements, output_axis, output_stride, output_shape_length, p, eps,
+    output);
+  NormCalKernel<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    output, output_elements, p, eps);
 }
 
 template <>
 void CalLpNorm<half>(const half *input, const size_t *input_shape, size_t input_shape_length, size_t input_elements,
                      const size_t *output_axis, const size_t *output_stride, size_t output_shape_length,
                      size_t output_elements, float p, float eps, float *middle_output, half *output,
-                     cudaStream_t cuda_stream) {
-  LpCalKernel<<<GET_BLOCKS(input_elements), GET_THREADS, 0, cuda_stream>>>(input, input_shape, input_shape_length,
-                                                                           input_elements, output_axis, output_stride,
-                                                                           output_shape_length, p, eps, middle_output);
-  NormCalHighPrecisionKernel<<<GET_BLOCKS(output_elements), GET_THREADS, 0, cuda_stream>>>(middle_output, output,
-                                                                                           output_elements, p, eps);
+                     const uint32_t &device_id, cudaStream_t cuda_stream) {
+  LpCalKernel<<<CUDA_BLOCKS(device_id, input_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    input, input_shape, input_shape_length, input_elements, output_axis, output_stride, output_shape_length, p, eps,
+    middle_output);
+  NormCalHighPrecisionKernel<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    middle_output, output, output_elements, p, eps);
 }
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cuh
index af8237026be..15abe956c91 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lp_norm_impl.cuh
@@ -16,11 +16,12 @@
 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
+
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
 
 template <typename T>
 CUDA_LIB_EXPORT void CalLpNorm(const T *input, const size_t *input_shape, size_t input_shape_length,
                                size_t input_elements, const size_t *output_axis, const size_t *output_stride,
                                size_t output_shape_length, size_t output_elements, float p, float eps,
-                               float *middle_output, T *output, cudaStream_t cuda_stream_);
+                               float *middle_output, T *output, const uint32_t &device_id, cudaStream_t cuda_stream_);
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LPNORM_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.cc
index 64b9210eeb7..a30abeb1547 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.cc
@@ -27,117 +27,84 @@ namespace mindspore {
 namespace kernel {
-void LpNormGpuKernelMod::GetLpNormAttr() {
-  const std::string axis = "axis";
-  if (!kernel_ptr_->HasAttr(axis)) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << axis;
+bool LpNormGpuKernelMod::GetLpNormAttr(const BaseOperatorPtr &base_operator) {
+  if (kernel_name_ != prim::kPrimLpNorm->name()) {
+    MS_LOG(ERROR) << "For '" << prim::kPrimLpNorm->name() << "' , it's kernel name must be equal to LpNorm, but got "
+                  << kernel_name_;
+    return false;
   }
-  axis_ = GetValue<std::vector<int64_t>>(kernel_ptr_->GetAttr(axis));
-  const std::string p = "p";
-  if (!kernel_ptr_->HasAttr(p)) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << p;
-  }
-  p_ = static_cast<float>(GetValue<int64_t>(kernel_ptr_->GetAttr(p)));
+  auto kernel_ptr = std::make_shared<ops::LpNorm>(base_operator->GetPrim());
+
+  axis_ = kernel_ptr->get_axis();
+  p_ = static_cast<float>(kernel_ptr->get_p());
   if (p_ == 0.0f) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "''s op attribute " << p << " equals to zero is invalid.";
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's op attribute 'p' equals to zero is invalid.";
+    return false;
   }
-  const std::string epsilon = "epsilon";
-  if (!kernel_ptr_->HasAttr(epsilon)) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "' has no kernel attribute: " << epsilon;
-  }
-  epsilon_ = GetValue<float>(kernel_ptr_->GetAttr(epsilon));
+  epsilon_ = kernel_ptr->get_epsilon();
+  return true;
 }
 
 bool LpNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
+  kernel_name_ = base_operator->name();
   if (inputs.empty() || outputs.empty()) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
     return false;
   }
-
-  // A Code Block For getting launch_kernel function.
-  {
-    kernel_ptr_ = std::make_shared<ops::LpNorm>(base_operator->GetPrim());
-    kernel_name_ = kernel_ptr_->name();
-    auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
-    auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
-    if (!is_match) {
-      MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
-      return false;
-    }
-    kernel_func_ = func_list_[index].second;
-    unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
   }
-
-  GetLpNormAttr();
-
-  // A Code Block For setting input and output shape.
-  {
-    input_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
-                                       inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
-    input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
-    is_null_input_ = (input_elements_ == 0);
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-
-    outputs_ = outputs;
-    output_shape_ = std::vector<size_t>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
-                                        outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
-
-    std::vector<size_t> output_shape;
-    // Ignore dim equal to one.
-    std::copy_if(output_shape_.begin(), output_shape_.end(), std::back_inserter(output_shape),
-                 [](size_t dim) { return dim != 1; });
-    output_shape_ = output_shape;
-    std::set<size_t> axis_set(axis_.begin(), axis_.end());
-    for (size_t i = 0; i < input_shape_.size(); ++i) {
-      if (!axis_set.count(i)) {
-        output_axis_.emplace_back(i);
-      }
-    }
-    output_stride_.resize(output_shape_.size());
-    output_stride_[output_stride_.size() - 1] = 1;
-    for (int i = static_cast<int>(output_stride_.size() - 2); i >= 0; --i) {
-      output_stride_[i] = output_stride_[i + 1] * output_shape[i + 1];
-    }
-    output_elements_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
-    InitSizeLists();
-  }
-
-  // A Code Block For dealing with input_dynamic_shape.
-  {
-    if (!is_input_dynamic_shape_.has_value()) {
-      bool is_input_dynamic_shape = false;
-      for (const auto &input : inputs) {
-        auto input_shape = input->GetShapeVector();
-        if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t dim) { return dim < 0; })) {
-          is_input_dynamic_shape = true;
-          break;
-        }
-      }
-      is_input_dynamic_shape_ = is_input_dynamic_shape;
-    }
-  }
-  return true;
+  kernel_func_ = func_list_[index].second;
+  return GetLpNormAttr(base_operator);
 }
 
 int LpNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs,
                                const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
-  if (is_input_dynamic_shape_.has_value() && is_input_dynamic_shape_.value()) {
-    DestroyResource();
-    ResetResource();
-    if (!Init(base_operator, inputs, outputs)) {
-      return KRET_RESIZE_FAILED;
-    }
-    return 0;
-  } else {
-    kernel_ptr_ = base_operator;
-    outputs_ = outputs;
-    return 0;
+  if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
   }
+  unit_size_ = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
+
+  input_shape_.clear();
+  auto input_shape = inputs.at(kIndex0)->GetShapeVector();
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
+  input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
+  is_null_input_ = (input_elements_ == 0);
+  if (is_null_input_) {
+    return KRET_OK;
+  }
+
+  output_shape_.clear();
+  auto output_shape = outputs.at(kIndex0)->GetShapeVector();
+  // Ignore dim equal to one.
+  for (const auto &dim : output_shape) {
+    if (dim != 1) {
+      output_shape_.emplace_back(LongToSize(dim));
+    }
+  }
+
+  output_axis_.clear();
+  std::set<size_t> axis_set(axis_.begin(), axis_.end());
+  for (size_t i = 0; i < input_shape_.size(); ++i) {
+    if (!axis_set.count(i)) {
+      output_axis_.emplace_back(i);
+    }
+  }
+
+  output_stride_.clear();
+  output_stride_.resize(output_shape_.size());
+  output_stride_[output_stride_.size() - 1] = 1;
+  for (int i = static_cast<int>(output_stride_.size() - 2); i >= 0; --i) {
+    output_stride_[i] = output_stride_[i + 1] * output_shape[i + 1];
+  }
+  output_elements_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
+  InitWorkSpaceSizeList();
+  return KRET_OK;
 }
 
 template <typename T>
@@ -153,17 +120,17 @@ bool LpNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(device_input_shape, &input_shape_[0], input_shape_.size() * sizeof(size_t), cudaMemcpyHostToDevice,
                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
-    "cudaMemcpyAsync input_shape_ failed");
+    "LpNormGpuKernelMod cudaMemcpyAsync input_shape_ failed");
 
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(device_axis_output, &output_axis_[0], output_axis_.size() * sizeof(size_t), cudaMemcpyHostToDevice,
                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
-    "cudaMemcpyAsync output_axis_ failed");
+    "LpNormGpuKernelMod cudaMemcpyAsync output_axis_ failed");
 
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
     cudaMemcpyAsync(device_output_stride, &output_stride_[0], output_stride_.size() * sizeof(size_t),
                     cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
-    "cudaMemcpyAsync output_shape_ failed");
+    "LpNormGpuKernelMod cudaMemcpyAsync output_shape_ failed");
 
   // The workspace for device output high precision.
   if constexpr (std::is_same_v<T, half>) {
@@ -171,18 +138,18 @@ bool LpNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
                                        "cudaStremSynchronize failed");
     constexpr auto high_precision_unit = 2;
     size_t device_output_stride_size = output_elements_ * unit_size_ * high_precision_unit;
-    float *middle_output = nullptr;
-    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(&middle_output, device_output_stride_size),
-                                       "cudaMalloc output_shape_ failed");
+    auto middle_output = reinterpret_cast<float *>(
+      device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(device_output_stride_size));
+    CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemset(middle_output, 0, device_output_stride_size),
+                                      "LpNormGpuKernelMod failed to set cuda memory to zeros.");
 
     CalLpNorm(input, device_input_shape, input_shape_.size(), input_elements_, device_axis_output, device_output_stride,
-              output_axis_.size(), output_elements_, p_, epsilon_, middle_output, output,
+              output_axis_.size(), output_elements_, p_, epsilon_, middle_output, output, device_id_,
              reinterpret_cast<cudaStream_t>(cuda_stream_));
   } else {
     CalLpNorm(input, device_input_shape, input_shape_.size(), input_elements_, device_axis_output, device_output_stride,
-              output_axis_.size(), output_elements_, p_, epsilon_, nullptr, output,
+              output_axis_.size(), output_elements_, p_, epsilon_, nullptr, output, device_id_,
              reinterpret_cast<cudaStream_t>(cuda_stream_));
   }
-
   return true;
 }
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.h
index 3dd5ea2dc1e..5530677eb6d 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/lp_norm_gpu_kernel.h
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LPNORM_GPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_LPNORM_GPU_KERNEL_H_
+
 #include <map>
 #include <string>
 #include <vector>
@@ -27,7 +28,7 @@ namespace mindspore {
 namespace kernel {
 class LpNormGpuKernelMod : public NativeGpuKernelMod {
  public:
-  LpNormGpuKernelMod() { ResetResource(); }
+  LpNormGpuKernelMod() = default;
   ~LpNormGpuKernelMod() override = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -47,42 +48,24 @@ class LpNormGpuKernelMod : public NativeGpuKernelMod {
              const std::vector<KernelTensorPtr> &outputs,
              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
 
-  std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
-
-  void ResetResource() noexcept {
-    is_null_input_ = false;
-    cuda_stream_ = nullptr;
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-  }
-
  protected:
-  void InitSizeLists() {
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-
-    input_size_list_.emplace_back(input_elements_ * unit_size_);
-    // The workspace for device input shape.
-    size_t device_input_shape_size = input_shape_.size() * sizeof(size_t);
-    // The workspace for device output shape.
-    size_t device_output_shape_size = output_shape_.size() * sizeof(size_t);
-    // The workspace for device output axis.
-    size_t device_axis_shape_size = output_axis_.size() * sizeof(size_t);
-    // The workspace for device output stride.
-    size_t device_output_stride_size = output_stride_.size() * sizeof(size_t);
-
-    workspace_size_list_.emplace_back(device_input_shape_size);
-    workspace_size_list_.emplace_back(device_output_shape_size);
-    workspace_size_list_.emplace_back(device_axis_shape_size);
-    workspace_size_list_.emplace_back(device_output_stride_size);
-    output_size_list_.emplace_back(output_elements_ * unit_size_);
-  }
-
   std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  void InitWorkSpaceSizeList() {
+    // The workspace for device input shape.
+    const size_t device_input_shape_size = input_shape_.size() * sizeof(size_t);
+    // The workspace for device output shape.
+    const size_t device_output_shape_size = output_shape_.size() * sizeof(size_t);
+    // The workspace for device output axis.
+    const size_t device_axis_shape_size = output_axis_.size() * sizeof(size_t);
+    // The workspace for device output stride.
+    const size_t device_output_stride_size = output_stride_.size() * sizeof(size_t);
+    workspace_size_list_.clear();
+    workspace_size_list_ = {device_input_shape_size, device_output_shape_size, device_axis_shape_size,
+                            device_output_stride_size};
+  }
+
   template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs);
@@ -90,25 +73,20 @@ class LpNormGpuKernelMod : public NativeGpuKernelMod {
     std::function<bool(LpNormGpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                        const std::vector<AddressPtr> &)>;
 
-  void GetLpNormAttr();
+  bool GetLpNormAttr(const BaseOperatorPtr &base_operator);
 
- private:
   size_t unit_size_{1};
   float p_{2.0};
   float epsilon_{1e-12};
   std::vector<int64_t> axis_;
   void *cuda_stream_{nullptr};
   bool is_null_input_{false};
-
-  std::optional<bool> is_input_dynamic_shape_{};
-  BaseOperatorPtr kernel_ptr_{nullptr};
   std::vector<size_t> input_shape_;
   std::vector<size_t> output_shape_;
   std::vector<size_t> output_axis_;
   std::vector<size_t> output_stride_;
   size_t input_elements_{};
   size_t output_elements_{};
-  std::vector<KernelTensorPtr> outputs_ = {};
   LpNormFunc kernel_func_;
   static std::vector<std::pair<KernelAttr, LpNormFunc>> func_list_;
 };
diff --git a/mindspore/core/ops/lp_norm.cc b/mindspore/core/ops/lp_norm.cc
index 87cd4ee125f..0dbd1221eee 100644
--- a/mindspore/core/ops/lp_norm.cc
+++ b/mindspore/core/ops/lp_norm.cc
@@ -110,11 +110,20 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
   return abstract::MakeAbstract(infer_shape, infer_type);
 }
 
-void LpNorm::Init(const int64_t p, const float epsilon) {
+void LpNorm::Init(const std::vector<int64_t> &axis, const int64_t p, const bool keep_dims, const float epsilon) {
+  this->set_axis(axis);
   this->set_p(p);
+  this->set_keep_dims(keep_dims);
   this->set_epsilon(epsilon);
 }
 
+void LpNorm::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
+
+std::vector<int64_t> LpNorm::get_axis() const {
+  auto value_ptr = this->GetAttr(kAxis);
+  return GetValue<std::vector<int64_t>>(value_ptr);
+}
+
 void LpNorm::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
 
 int64_t LpNorm::get_p() const {
@@ -122,6 +131,13 @@ int64_t LpNorm::get_p() const {
   return GetValue<int64_t>(value_ptr);
 }
 
+void LpNorm::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
+
+bool LpNorm::get_keep_dims() const {
+  auto value_ptr = this->GetAttr(kKeepDims);
+  return GetValue<bool>(value_ptr);
+}
+
 void LpNorm::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, api::MakeValue(epsilon)); }
 
 float LpNorm::get_epsilon() const {
diff --git a/mindspore/core/ops/lp_norm.h b/mindspore/core/ops/lp_norm.h
index 81263c736c1..72b28533088 100644
--- a/mindspore/core/ops/lp_norm.h
+++ b/mindspore/core/ops/lp_norm.h
@@ -30,14 +30,23 @@ class MIND_API LpNorm : public BaseOperator {
   MIND_API_BASE_MEMBER(LpNorm);
   LpNorm() : BaseOperator(kNameLpNorm) { InitIOName({"input"}, {"output"}); }
 
-  void Init(const int64_t p = 2, const float epsilon = 1e-12);
+  void Init(const std::vector<int64_t> &axis, const int64_t p = 2, const bool keep_dims = false,
+            const float epsilon = 1e-12);
+
+  void set_axis(const std::vector<int64_t> &axis);
+
+  void set_keep_dims(const bool keep_dims);
 
   void set_p(const int64_t p);
 
   void set_epsilon(const float epsilon);
 
+  std::vector<int64_t> get_axis() const;
+
   int64_t get_p() const;
 
+  bool get_keep_dims() const;
+
   float get_epsilon() const;
 };
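
Usage sketch (not part of the patch): the snippet below shows how the LpNorm attribute API changed above is expected to be exercised. Init now takes the axis vector and keep_dims in addition to p and epsilon, and the GPU kernel's GetLpNormAttr reads the values back through the typed getters rather than raw attribute lookups. The include path, the free function name, and the sample axis values are assumptions made only for illustration.

// Illustrative sketch; assumes the patched mindspore/core/ops/lp_norm.h is reachable as "ops/lp_norm.h".
#include <cstdint>
#include <vector>
#include "ops/lp_norm.h"

void BuildLpNormAttrsExample() {
  mindspore::ops::LpNorm lp_norm;
  // New signature: axis is explicit; p, keep_dims and epsilon keep their defaults (2, false, 1e-12).
  lp_norm.Init({0, 1});

  // These are the typed getters the GPU kernel now relies on instead of HasAttr/GetAttr string lookups.
  std::vector<int64_t> axis = lp_norm.get_axis();
  int64_t p = lp_norm.get_p();
  bool keep_dims = lp_norm.get_keep_dims();
  float epsilon = lp_norm.get_epsilon();
  (void)axis;
  (void)p;
  (void)keep_dims;
  (void)epsilon;
}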