diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.cc index 676d7529e56..b78cc47c709 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.cc @@ -342,26 +342,49 @@ void CombinedNonMaxSuppressionCpuKernelMod::CheckOutput() { } } -void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); - node_wpt_ = kernel_node; - input0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - input1_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex1); - input2_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex2); - input3_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex3); - input4_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex4); - input5_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex5); +bool CombinedNonMaxSuppressionCpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + kernel_name_ = base_operator->name(); + return true; +} + +int CombinedNonMaxSuppressionCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + + size_t input_num = inputs.size(); + size_t output_num = outputs.size(); + CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_); + + input0_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively(); + input1_shape_ = inputs.at(KIndex1)->GetDeviceShapeAdaptively(); + input2_shape_ = inputs.at(KIndex2)->GetDeviceShapeAdaptively(); + input3_shape_ = inputs.at(KIndex3)->GetDeviceShapeAdaptively(); + input4_shape_ = inputs.at(KIndex4)->GetDeviceShapeAdaptively(); + input5_shape_ = inputs.at(KIndex5)->GetDeviceShapeAdaptively(); + + output0_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively(); + output1_shape_ = outputs.at(kIndex1)->GetDeviceShapeAdaptively(); + output2_shape_ = outputs.at(kIndex2)->GetDeviceShapeAdaptively(); + output3_shape_ = outputs.at(kIndex3)->GetDeviceShapeAdaptively(); + soft_nms_sigma_ = 0.0; num_bath_ = static_cast(input0_shape_[0]); num_boxes_ = static_cast(input0_shape_[KIndex1]); q_ = static_cast(input0_shape_[KIndex2]); num_class_ = static_cast((input1_shape_[KIndex2])); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); pad_per_class_ = false; clip_boxes_ = true; - auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node); + + PrimitivePtr prim = base_operator->GetPrim(); auto pad_per_class = prim->GetAttr("pad_per_class"); auto clip_boxes = prim->GetAttr("clip_boxes"); if (pad_per_class != nullptr) { @@ -370,8 +393,11 @@ void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_no if (clip_boxes != nullptr) { clip_boxes_ = GetValue(clip_boxes); } - CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_); + + CheckInput(); + CheckOutput(); + + return KRET_OK; } bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector &inputs, @@ -392,24 +418,9 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector(num_detection_), DimSize4}; - ShapeVector shape1 = {input0_shape_[0], static_cast(num_detection_)}; - ShapeVector shape2 = {input0_shape_[0], static_cast(num_detection_)}; - ShapeVector shape3 = {input0_shape_[0]}; - common::AnfAlgo::SetOutputInferTypeAndShape( - {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32}, {shape0, shape1, shape2, shape3}, - node_.get()); - output0_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex0); - output1_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex1); - output2_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex2); - output3_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex3); + size_per_class_ = max_output_size_per_class_ < num_boxes_ ? max_output_size_per_class_ : num_boxes_; - CheckInput(); - CheckOutput(); + if (max_total_size_ <= 0) { MS_LOG(EXCEPTION) << "For " << kernel_name_ << " max_total_size must be > 0, but got " << max_total_size_ << "."; } @@ -435,6 +446,7 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector CombinedNonMaxSuppressionCpuKernelMod::GetOpSupport() { static std::vector kernel_attr_list = { KernelAttr() diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.h index cada125392a..d89200c0d0b 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/combined_non_max_suppression_cpu_kernel.h @@ -50,11 +50,17 @@ bool result_cmp(const result_para &a, const result_para &b) { return a.score > b namespace mindspore { namespace kernel { -class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMod { +class CombinedNonMaxSuppressionCpuKernelMod : public NativeCpuKernelMod { public: CombinedNonMaxSuppressionCpuKernelMod() = default; ~CombinedNonMaxSuppressionCpuKernelMod() override = default; - void InitKernel(const CNodePtr &kernel_node) override; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; @@ -69,6 +75,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo void nms_perclass(float *, float *, std::vector &, int &); void CheckInput(); void CheckOutput(); + int num_bath_ = 0; int num_boxes_ = 0; int q_ = 0; @@ -87,7 +94,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo float soft_nms_sigma_ = 0.0; bool pad_per_class_ = 0; bool clip_boxes_ = 1; - CNodeWeakPtr node_wpt_; + std::vector input0_shape_; std::vector input1_shape_; std::vector input2_shape_; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.cc index de120a993cf..bec03dc7e37 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.cc @@ -44,14 +44,13 @@ using complex128 = std::complex; constexpr size_t kMaxTransposeSerialSize = 50331648; } // namespace -void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); - perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); - +bool ConjugateTransposeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + kernel_name_ = base_operator->name(); + dtype_ = inputs.at(kIndex0)->GetDtype(); + perm_type_ = inputs.at(kIndex1)->GetDtype(); launch_map_[kNumberTypeBool] = &ConjugateTransposeCpuKernelMod::LaunchKernel; launch_map_[kNumberTypeInt8] = &ConjugateTransposeCpuKernelMod::LaunchKernel; launch_map_[kNumberTypeInt16] = &ConjugateTransposeCpuKernelMod::LaunchKernel; @@ -72,6 +71,21 @@ void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { } else { MS_LOG(EXCEPTION) << "For ConjugateTranspose: unsupported input data type: " << dtype_; } + return true; +} + +int ConjugateTransposeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + + input_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively(); + output_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively(); + + return KRET_OK; } bool ConjugateTransposeCpuKernelMod::Launch(const std::vector &inputs, diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.h index 9f162a8ce09..a5785cb75c2 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/conjugate_transpose_cpu_kernel.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_ +#include #include #include #include @@ -27,14 +28,20 @@ namespace mindspore { namespace kernel { -class ConjugateTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod { +class ConjugateTransposeCpuKernelMod : public NativeCpuKernelMod { public: ConjugateTransposeCpuKernelMod() = default; ~ConjugateTransposeCpuKernelMod() override = default; - void InitKernel(const CNodePtr &kernel_node) override; + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; + template static void ConjComplexFunc(T *input, T *output, size_t start, size_t end); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.cc similarity index 97% rename from mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.cc rename to mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.cc index fb49dccf5a7..edf83f9a117 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h" +#include "plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h new file mode 100644 index 00000000000..d7023ac7f54 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h @@ -0,0 +1,133 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ + +#include +#include +#include "ops/sparse_apply_ftrl.h" +#include "utils/check_convert_utils.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr size_t INPUT_NUM = 5; +template +class SparseFtrlGpuKernelMod : public NativeGpuKernelMod { + public: + SparseFtrlGpuKernelMod() { ResetResource(); } + ~SparseFtrlGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + T *linear = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + S *indices = GetDeviceAddress(inputs, 4); + T *variable_out = GetDeviceAddress(outputs, 0); + T *accumulation_out = GetDeviceAddress(outputs, 1); + T *linear_out = GetDeviceAddress(outputs, 2); + CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable, + accumulation, linear, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync output failed"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync output failed"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync output failed"); + return true; + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + kernel_name_ = base_operator->name(); + auto kernel_ptr = std::dynamic_pointer_cast(base_operator); + MS_EXCEPTION_IF_NULL(kernel_ptr); + lr_ = kernel_ptr->get_lr(); + l1_ = kernel_ptr->get_l1(); + l2_ = kernel_ptr->get_l2(); + lr_power_ = kernel_ptr->get_lr_power(); + use_locking_ = kernel_ptr->get_use_locking(); + return true; + } + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + + (void)CheckAndConvertUtils::CheckInteger("input num", inputs.size(), kEqual, INPUT_NUM, kernel_name_); + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + linear_size_ = sizeof(T); + n_stride_ = 1; + + auto variable_shape = inputs.at(kIndex0)->GetShapeVector(); + auto accumulation_shape = inputs.at(kIndex1)->GetShapeVector(); + auto linear_shape = inputs.at(kIndex2)->GetShapeVector(); + auto indices_shape = inputs.at(kIndex4)->GetShapeVector(); + + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + if (i > 0) { + n_stride_ *= variable_shape[i]; + } + } + accumulation_size_ *= SizeOf(accumulation_shape); + linear_size_ *= SizeOf(linear_shape); + num_index_ = indices_shape[0]; + + return KRET_OK; + } + + protected: + void ResetResource() noexcept { + lr_ = 0.0f; + l1_ = 0.0f; + l2_ = 0.0f; + lr_power_ = 0.0f; + use_locking_ = false; + num_index_ = 0; + n_stride_ = 1; + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t linear_size_; + float lr_; + float l1_; + float l2_; + float lr_power_; + bool use_locking_; + int num_index_; + size_t n_stride_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h deleted file mode 100644 index 56b4984c691..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ - -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh" -namespace mindspore { -namespace kernel { -constexpr size_t INPUT_NUM = 5; -template -class SparseFtrlGpuKernelMod : public DeprecatedNativeGpuKernelMod { - public: - SparseFtrlGpuKernelMod() { ResetResource(); } - ~SparseFtrlGpuKernelMod() override = default; - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *variable = GetDeviceAddress(inputs, 0); - T *accumulation = GetDeviceAddress(inputs, 1); - T *linear = GetDeviceAddress(inputs, 2); - T *gradient = GetDeviceAddress(inputs, 3); - S *indices = GetDeviceAddress(inputs, 4); - T *variable_out = GetDeviceAddress(outputs, 0); - T *accumulation_out = GetDeviceAddress(outputs, 1); - T *linear_out = GetDeviceAddress(outputs, 2); - CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable, - accumulation, linear, reinterpret_cast(stream_ptr)); - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync output failed"); - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, - cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync output failed"); - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync output failed"); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - kernel_node_ = kernel_node; - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != INPUT_NUM) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be " << INPUT_NUM << ", but got " - << input_num; - } - - variable_size_ = sizeof(T); - accumulation_size_ = sizeof(T); - linear_size_ = sizeof(T); - gradient_size_ = sizeof(T); - indices_size_ = sizeof(S); - - auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (IsDynamic(shape_signed)) { - return true; - } - auto variable_shape = Convert2SizeTClipNeg(shape_signed); - auto accumulation_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto linear_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - auto gradient_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - is_null_input_ = CHECK_SHAPE_NULL(variable_shape, kernel_name, "var") || - CHECK_SHAPE_NULL(accumulation_shape, kernel_name, "accum") || - CHECK_SHAPE_NULL(linear_shape, kernel_name, "linear") || - CHECK_SHAPE_NULL(gradient_shape, kernel_name, "grad") || - CHECK_SHAPE_NULL(indices_shape, kernel_name, "indices"); - if (is_null_input_) { - InitSizeLists(); - return true; - } - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - if (i > 0) { - n_stride_ *= variable_shape[i]; - } - } - - accumulation_size_ *= SizeOf(accumulation_shape); - linear_size_ *= SizeOf(linear_shape); - gradient_size_ *= SizeOf(gradient_shape); - indices_size_ *= SizeOf(indices_shape); - - lr_ = GetAttr(kernel_node, "lr"); - l1_ = GetAttr(kernel_node, "l1"); - l2_ = GetAttr(kernel_node, "l2"); - lr_power_ = GetAttr(kernel_node, "lr_power"); - use_locking_ = GetAttr(kernel_node, "use_locking"); - num_index_ = LongToSizeClipNeg(indices_shape[0]); - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(accumulation_size_); - input_size_list_.push_back(linear_size_); - input_size_list_.push_back(gradient_size_); - input_size_list_.push_back(indices_size_); - output_size_list_.push_back(variable_size_); - output_size_list_.push_back(accumulation_size_); - output_size_list_.push_back(linear_size_); - } - - void ResetResource() noexcept override { - variable_size_ = 0; - accumulation_size_ = 0; - linear_size_ = 0; - gradient_size_ = 0; - indices_size_ = 0; - lr_ = 0.0f; - l1_ = 0.0f; - l2_ = 0.0f; - lr_power_ = 0.0f; - use_locking_ = false; - is_null_input_ = false; - num_index_ = 0; - n_stride_ = 1; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - private: - size_t variable_size_; - size_t accumulation_size_; - size_t linear_size_; - size_t gradient_size_; - size_t indices_size_; - float lr_; - float l1_; - float l2_; - float lr_power_; - bool use_locking_; - bool is_null_input_; - int num_index_; - size_t n_stride_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/core/ops/combined_non_max_suppression.cc b/mindspore/core/ops/combined_non_max_suppression.cc index 29dfc35617d..237f2a347fb 100644 --- a/mindspore/core/ops/combined_non_max_suppression.cc +++ b/mindspore/core/ops/combined_non_max_suppression.cc @@ -31,6 +31,7 @@ const int64_t kInputDimension1 = 3; const int64_t kDimsize = 4; const int64_t kInputs = 6; const size_t ksecond = 2; + tensor::TensorPtr Get_Value(const std::vector &input_args, size_t index) { auto input = input_args[index]->cast(); MS_EXCEPTION_IF_NULL(input); @@ -38,15 +39,14 @@ tensor::TensorPtr Get_Value(const std::vector &input_args, size MS_EXCEPTION_IF_NULL(input_shape_value_ptr); return input_shape_value_ptr->cast(); } -abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive, - const std::vector &input_args) { - auto prim_name = primitive->name(); - auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; - auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; - auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; - auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; - auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape]; - auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape]; + +void CombinedNonMaxSuppressionCheckShapeSize(const ShapeVector &input0_shape, const ShapeVector &input1_shape, + const ShapeVector &input2_shape, const ShapeVector &input3_shape, + const ShapeVector &input4_shape, const ShapeVector &input5_shape, + const bool &is_dynamic_rank, const std::string &prim_name) { + if (is_dynamic_rank) { + return; + } (void)CheckAndConvertUtils::CheckInteger("boxes dim", SizeToLong(input0_shape.size()), kEqual, kInputDimension0, prim_name); (void)CheckAndConvertUtils::CheckInteger("scores dim", SizeToLong(input1_shape.size()), kEqual, kInputDimension1, @@ -55,6 +55,13 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr & (void)CheckAndConvertUtils::CheckInteger("max_total_size dim", SizeToLong(input3_shape.size()), kEqual, 0, prim_name); (void)CheckAndConvertUtils::CheckInteger("iou_threshold", SizeToLong(input4_shape.size()), kEqual, 0, prim_name); (void)CheckAndConvertUtils::CheckInteger("score_threshold", SizeToLong(input5_shape.size()), kEqual, 0, prim_name); +} + +void CombinedNonMaxSuppressionCheckShapeValue(const ShapeVector &input0_shape, const ShapeVector &input1_shape, + const bool &is_dynamic, const std::string &prim_name) { + if (is_dynamic) { + return; + } if (input0_shape[0] != input1_shape[0]) { MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 1st dim must be same with the scores's" << " 1st dim, but got" << input0_shape[0] << " and " << input1_shape[0] << "."; @@ -72,45 +79,35 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr & MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 4th dim must be equal to 4, but got" << input0_shape[kInputIndex3] << "."; } - for (int64_t i = 0; i < kInputs; i++) { - if (!input_args[i]->isa()) { - MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!"; - } - } +} + +abstract::TupleShapePtr CombinedNonMaxSuppressionGetOutputShape(const PrimitivePtr &primitive, + const std::vector &input_args, + const bool &is_dynamic) { + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; auto pad_per_class_ptr = primitive->GetAttr("pad_per_class"); MS_EXCEPTION_IF_NULL(pad_per_class_ptr); bool pad_per_class = GetValue(pad_per_class_ptr); - auto input2_tensor = Get_Value(input_args, kInputIndex2); - auto input3_tensor = Get_Value(input_args, kInputIndex3); - auto input4_tensor = Get_Value(input_args, kInputIndex4); - auto input5_tensor = Get_Value(input_args, kInputIndex5); - if (IsValue(input_args[kInputIndex2]->BuildValue()) && IsValue(input_args[kInputIndex3]->BuildValue())) { - if (IsValue(input_args[kInputIndex4]->BuildValue()) && input_args[kInputIndex5]->BuildValue()) { - auto iou_threshold = *(static_cast(input4_tensor->data_c())); - auto score_threshold = *(static_cast(input5_tensor->data_c())); - if (iou_threshold < 0 || iou_threshold > 1) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold - << "."; - } - if (score_threshold < 0 && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim " - << "is not 1 and score_threshold is less than 1."; - } - } + + if (!is_dynamic && IsValue(input_args[kInputIndex2]->BuildValue()) && + IsValue(input_args[kInputIndex3]->BuildValue())) { + auto input2_tensor = Get_Value(input_args, kInputIndex2); + auto input3_tensor = Get_Value(input_args, kInputIndex3); auto max_output_size_per_class = *(static_cast(input2_tensor->data_c())); auto max_total_size = *(static_cast(input3_tensor->data_c())); - if (max_total_size <= 0) { - MS_EXCEPTION(ValueError) << "For " << prim_name << " max_total_size must be > 0, but got " << max_total_size - << "."; - } - if (max_output_size_per_class <= 0) { - MS_EXCEPTION(ValueError) << "For " << prim_name << " max_output_size_per_class must be > 0, but got " - << max_output_size_per_class << "."; - } + + const int32_t kNumZero = 0; + CheckAndConvertUtils::CheckInteger("max_total_size", max_total_size, kGreaterThan, kNumZero, primitive->name()); + + CheckAndConvertUtils::CheckInteger("max_output_size_per_clas", max_output_size_per_class, kGreaterThan, kNumZero, + primitive->name()); + auto num_detection = max_total_size; if (pad_per_class) { num_detection = std::min(max_total_size, max_output_size_per_class * static_cast(input1_shape[ksecond])); } + int64_t bs = input0_shape[0]; ShapeVector shape1 = {bs, num_detection, 4}; ShapeVector shape2 = {bs, num_detection}; @@ -122,14 +119,58 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr & auto out4 = std::make_shared(shape4); return std::make_shared(std::vector{out1, out2, out3, out4}); } else { - auto shape1 = std::make_shared(ShapeVector{-2}); - auto shape2 = std::make_shared(ShapeVector{-2}); - auto shape3 = std::make_shared(ShapeVector{-2}); - auto shape4 = std::make_shared(ShapeVector{-2}); + auto shape1 = std::make_shared(ShapeVector{-1, -1, 4}); + auto shape2 = std::make_shared(ShapeVector{-1, -1}); + auto shape3 = std::make_shared(ShapeVector{-1, -1}); + auto shape4 = std::make_shared(ShapeVector{-1}); return std::make_shared(std::vector{shape1, shape2, shape3, shape4}); } } +abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + auto prim_name = primitive->name(); + auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape]; + auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape]; + + std::vector all_shapes = {input0_shape, input1_shape, input2_shape, + input3_shape, input4_shape, input5_shape}; + auto is_dynamic = (IsDynamic(input0_shape) || IsDynamic(input1_shape)); + auto is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank); + + CombinedNonMaxSuppressionCheckShapeSize(input0_shape, input1_shape, input2_shape, input3_shape, input4_shape, + input5_shape, is_dynamic_rank, prim_name); + + CombinedNonMaxSuppressionCheckShapeValue(input0_shape, input1_shape, is_dynamic, prim_name); + + for (int64_t i = 0; i < kInputs; i++) { + if (!input_args[i]->isa()) { + MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!"; + } + } + + if (IsValue(input_args[kInputIndex4]->BuildValue()) && IsValue(input_args[kInputIndex5]->BuildValue())) { + auto input4_tensor = Get_Value(input_args, kInputIndex4); + auto input5_tensor = Get_Value(input_args, kInputIndex5); + auto iou_threshold = *(static_cast(input4_tensor->data_c())); + auto score_threshold = *(static_cast(input5_tensor->data_c())); + if (iou_threshold < 0 || iou_threshold > 1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold + << "."; + } + if (score_threshold < 0 && !is_dynamic && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim " + << "is not 1 and score_threshold is less than 1."; + } + } + + return CombinedNonMaxSuppressionGetOutputShape(primitive, input_args, is_dynamic); +} + TuplePtr CombinedNonMaxSuppressionInferType(const PrimitivePtr &primitive, const std::vector &input_args) { auto prim_name = primitive->name(); diff --git a/mindspore/core/ops/conjugate_transpose.cc b/mindspore/core/ops/conjugate_transpose.cc index f7c59e15727..29710d516dd 100644 --- a/mindspore/core/ops/conjugate_transpose.cc +++ b/mindspore/core/ops/conjugate_transpose.cc @@ -32,33 +32,48 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - ShapeVector p_value; - ShapeVector p_value_raw; + auto is_dynamic_rank = IsDynamicRank(x_shape); + if (is_dynamic_rank) { + return std::make_shared(std::vector{-2}); + } + + constexpr int64_t dim_7 = 7; + (void)CheckAndConvertUtils::CheckInteger("[x] rank", static_cast(x_shape.size()), kLessEqual, dim_7, + op_name); + auto perm_value = input_args[1]->BuildValue(); MS_EXCEPTION_IF_NULL(perm_value); - if (perm_value->isa()) { + if (perm_value->isa()) { + std::vector output_shape(static_cast(x_shape.size()), -1); + return std::make_shared(output_shape); + } + + ShapeVector p_value; + ShapeVector p_value_raw; + if (perm_value->isa()) { + p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("input[perm]", perm_value, op_name); + } else if (perm_value->isa()) { p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name); } else { MS_EXCEPTION(TypeError) << "For '" << op_name << "', the type of perm must be Tuple, but got " << input_args[1]->BuildType()->ToString() << " ."; } + for (auto p : p_value_raw) { p = (p >= 0) ? p : (static_cast(p_value_raw.size()) + p); p_value.emplace_back(p); } + if (x_shape.size() != p_value.size()) { MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of x " << x_shape.size() << " and perm " << p_value.size() << " must be equal, but got the dimension of x " << x_shape.size() << " and perm " << p_value.size() << " ."; } - constexpr int64_t dim_7 = 7; - if (p_value.size() > dim_7) { - MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of perm must be less than 8, but get " - << p_value.size() << " ."; - } + for (auto i : p_value) { (void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name); } + std::vector tmp(p_value); for (auto it = tmp.begin(); it != tmp.end();) { auto dim = *it; @@ -69,6 +84,7 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive, MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm must be different."; } } + std::vector in_shape(p_value); (void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; }); return std::make_shared(in_shape); @@ -85,7 +101,6 @@ TypePtr ConjugateTransposeInferType(const PrimitivePtr &prim, const std::vector< } } // namespace -MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator); AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); @@ -95,6 +110,9 @@ AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, con auto shape = ConjugateTransposeInferShape(primitive, input_args); return abstract::MakeAbstract(shape, type); } + +REGISTER_HOST_DEPENDS(kNameConjugateTranspose, {1}); +MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator); REGISTER_PRIMITIVE_EVAL_IMPL(ConjugateTranspose, prim::kPrimConjugateTranspose, ConjugateTransposeInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/sparse_apply_ftrl.cc b/mindspore/core/ops/sparse_apply_ftrl.cc index cee73993385..5f904cadd1e 100644 --- a/mindspore/core/ops/sparse_apply_ftrl.cc +++ b/mindspore/core/ops/sparse_apply_ftrl.cc @@ -136,7 +136,6 @@ void SparseApplyFtrl::Init(float lr, float l1, float l2, float lr_power, bool us set_use_locking(use_locking); } -MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator); AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); @@ -154,14 +153,14 @@ AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const (void)CheckAndConvertUtils::CheckValue(kL1, l1, kGreaterEqual, 0.0f, op_name); (void)CheckAndConvertUtils::CheckValue(kL2, l2, kGreaterEqual, 0.0f, op_name); (void)CheckAndConvertUtils::CheckValue(kLrPower, lr_power, kLessEqual, 0.0f, op_name); - - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, - sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name); + (void)CheckAndConvertUtils::CheckInteger("input numbers", CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args), + kEqual, sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name); auto types = sparse_apply_ftrl::SparseApplyFtrlInferType(primitive, input_args); auto shapes = sparse_apply_ftrl::SparseApplyFtrlInferShape(primitive, input_args); return abstract::MakeAbstract(shapes, types); } +MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator); REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyFtrl, prim::kPrimSparseApplyFtrl, SparseApplyFtrlInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/tests/st/ops/cpu/test_combined_non_max_suppression_op.py b/tests/st/ops/cpu/test_combined_non_max_suppression_op.py new file mode 100644 index 00000000000..0843ab1b60a --- /dev/null +++ b/tests/st/ops/cpu/test_combined_non_max_suppression_op.py @@ -0,0 +1,77 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np + +import mindspore.ops.operations.image_ops as P +from mindspore.common import dtype as mstype +from mindspore import Tensor, nn, context + + +class Net(nn.Cell): + + def __init__(self): + super(Net, self).__init__() + self.op = P.CombinedNonMaxSuppression() + + def construct(self, boxes, scores, max_output_size_per_class, + max_total_size, iou_threshold, score_threshold): + return self.op(boxes, scores, max_output_size_per_class, + max_total_size, iou_threshold, score_threshold) + + +def dyn_case(): + net = Net() + + boxes_dyn = Tensor(shape=[None, None, None, 4], dtype=mstype.float32) + scores_dyn = Tensor(shape=[None, None, None], dtype=mstype.float32) + max_output_size_per_class = Tensor(4, mstype.int32) + max_total_size = Tensor(1, mstype.int32) + iou_threshold = Tensor(0, mstype.float32) + score_threshold = Tensor(0, mstype.float32) + + net.set_inputs(boxes_dyn, scores_dyn, max_output_size_per_class, + max_total_size, iou_threshold, score_threshold) + + boxes = Tensor( + np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], + [[190, 110, 150, 100]], [[210, 112, 150, + 100]]]])).astype('float32') + scores = Tensor( + np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000], + [0.3000, 0.6000, 0.1000], [0.0500, 0.9000, + 0.0500]]])).astype('float32') + + out = net(boxes, scores, max_output_size_per_class, max_total_size, + iou_threshold, score_threshold) + expect_shapes = [(1, 1, 4), (1, 1), (1, 1), (1,)] + for i in range(4): + assert out[i].asnumpy().shape == expect_shapes[i] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_combined_non_max_suppression_dyn(): + """ + Feature: test CombinedNonMaxSuppression in PyNative and Graph modes. + Description: test dynamic shape case. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + dyn_case() diff --git a/tests/st/ops/cpu/test_conjugate_transpose_op.py b/tests/st/ops/cpu/test_conjugate_transpose_op.py new file mode 100644 index 00000000000..be2316762fb --- /dev/null +++ b/tests/st/ops/cpu/test_conjugate_transpose_op.py @@ -0,0 +1,61 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np + +import mindspore.ops.operations.array_ops as P +from mindspore.common import dtype as mstype +from mindspore import Tensor, nn, context + + +class Net(nn.Cell): + + def __init__(self): + super(Net, self).__init__() + self.op = P.ConjugateTranspose() + + def construct(self, x, perm): + return self.op(x, perm) + + +def dyn_case(): + net = Net() + + x_dyn = Tensor(shape=[None, 4, None, 7], dtype=mstype.float32) + perm = (2, 1, 0, 3) + + net.set_inputs(x_dyn, perm) + + x = Tensor(np.random.random((8, 4, 5, 7)).astype(np.float32)) + out = net(x, perm) + + expect_shape = (5, 4, 8, 7) + assert out.asnumpy().shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_conjugate_transpose_dyn(): + """ + Feature: test ConjugateTranspose in PyNative and Graph modes. + Description: test dynamic shape case. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU') + dyn_case() diff --git a/tests/st/ops/gpu/test_sparse_apply_ftrl_op.py b/tests/st/ops/gpu/test_sparse_apply_ftrl.py similarity index 69% rename from tests/st/ops/gpu/test_sparse_apply_ftrl_op.py rename to tests/st/ops/gpu/test_sparse_apply_ftrl.py index da83c41eb15..79cd95652cb 100644 --- a/tests/st/ops/gpu/test_sparse_apply_ftrl_op.py +++ b/tests/st/ops/gpu/test_sparse_apply_ftrl.py @@ -24,31 +24,82 @@ import mindspore.common.dtype as mstype class Net(nn.Cell): + def __init__(self): super(Net, self).__init__() - self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False) - self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var") - self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum") - self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear") + self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, + l1=0.0, + l2=0.0, + lr_power=-0.5, + use_locking=False) + self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), + name="var") + self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), + name="accum") + self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), + name="linear") def construct(self, grad, indices): - out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) + out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, + indices) return out -class Net_half(nn.Cell): +class NetHalf(nn.Cell): + def __init__(self): - super(Net_half, self).__init__() - self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False) - self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="var") - self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="accum") - self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="linear") + super(NetHalf, self).__init__() + self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, + l1=0.0, + l2=0.0, + lr_power=-0.5, + use_locking=False) + self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), + name="var") + self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), + name="accum") + self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), + name="linear") def construct(self, grad, indices): - out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) + out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, + indices) return out +def dyn_case(): + net = Net() + + grad_dyn = Tensor(shape=[3, None, None], dtype=mstype.float32) + indices_dyn = Tensor(shape=[None], dtype=mstype.int32) + + net.set_inputs(grad_dyn, indices_dyn) + + grad = Tensor(np.ones([3, 3, 3]).astype(np.float32)) + indices = Tensor([0, 1, 2], mstype.int32) + + out = net(grad, indices) + + expect_shape = (3, 3, 3) + for i in range(3): + assert out[i].asnumpy().shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu +@pytest.mark.env_onecard +def test_sparse_apply_ftrl_dyn(): + """ + Feature: test SparseApplyFtrl in PyNative and Graph modes. + Description: test dynamic shape case. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + dyn_case() + + @pytest.mark.level1 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -63,7 +114,8 @@ def test_ftrl(): [0.291479, 0.291479, 0.291479]], [[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], - [0.291479, 0.291479, 0.291479]]]).astype(np.float32) + [0.291479, 0.291479, + 0.291479]]]).astype(np.float32) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") sparse_apply_ftrl = Net() sparse_apply_ftrl(gradient, indices) @@ -83,12 +135,11 @@ def test_ftrl_sparse_int64_ind(): expect_var = np.array([[[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479]], - [[1, 1, 1], - [1, 1, 1], - [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], - [0.291479, 0.291479, 0.291479]]]).astype(np.float32) + [0.291479, 0.291479, + 0.291479]]]).astype(np.float32) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") sparse_apply_ftrl = Net() sparse_apply_ftrl(gradient, indices) @@ -113,13 +164,14 @@ def test_ftrl_half(): [0.291479, 0.291479, 0.291479]], [[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], - [0.291479, 0.291479, 0.291479]]]).astype(np.float16) + [0.291479, 0.291479, + 0.291479]]]).astype(np.float16) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() sparse_apply_ftrl(gradient, indices) assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() sparse_apply_ftrl(gradient, indices) assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var) @@ -133,21 +185,21 @@ def test_ftrl_sparse_half_int64_ind(): expect_var = np.array([[[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479]], - [[1, 1, 1], - [1, 1, 1], - [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], - [0.291479, 0.291479, 0.291479]]]).astype(np.float16) + [0.291479, 0.291479, + 0.291479]]]).astype(np.float16) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() sparse_apply_ftrl(gradient, indices) assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() sparse_apply_ftrl(gradient, indices) assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var) + @pytest.mark.level1 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -162,12 +214,13 @@ def test_ftrl_half_return_output(): [0.291479, 0.291479, 0.291479]], [[0.291479, 0.291479, 0.291479], [0.291479, 0.291479, 0.291479], - [0.291479, 0.291479, 0.291479]]]).astype(np.float16) + [0.291479, 0.291479, + 0.291479]]]).astype(np.float16) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() output = sparse_apply_ftrl(gradient, indices) assert np.all(output[0].asnumpy() == expect_var) context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - sparse_apply_ftrl = Net_half() + sparse_apply_ftrl = NetHalf() sparse_apply_ftrl(gradient, indices) assert np.all(output[0].asnumpy() == expect_var)