CombinedNMS, SparseApplyFtrl, ConjugateTranspose

update

update
hw_hz 2022-10-29 16:32:54 +08:00
parent 73a7bc9b28
commit 1e96180574
13 changed files with 552 additions and 298 deletions


@@ -342,26 +342,49 @@ void CombinedNonMaxSuppressionCpuKernelMod::CheckOutput() {
}
}
void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
node_wpt_ = kernel_node;
input0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input1_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex1);
input2_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex2);
input3_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex3);
input4_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex4);
input5_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex5);
bool CombinedNonMaxSuppressionCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
return true;
}
int CombinedNonMaxSuppressionCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
size_t input_num = inputs.size();
size_t output_num = outputs.size();
CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_);
input0_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
input1_shape_ = inputs.at(KIndex1)->GetDeviceShapeAdaptively();
input2_shape_ = inputs.at(KIndex2)->GetDeviceShapeAdaptively();
input3_shape_ = inputs.at(KIndex3)->GetDeviceShapeAdaptively();
input4_shape_ = inputs.at(KIndex4)->GetDeviceShapeAdaptively();
input5_shape_ = inputs.at(KIndex5)->GetDeviceShapeAdaptively();
output0_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
output1_shape_ = outputs.at(kIndex1)->GetDeviceShapeAdaptively();
output2_shape_ = outputs.at(kIndex2)->GetDeviceShapeAdaptively();
output3_shape_ = outputs.at(kIndex3)->GetDeviceShapeAdaptively();
soft_nms_sigma_ = 0.0;
num_bath_ = static_cast<int>(input0_shape_[0]);
num_boxes_ = static_cast<int>(input0_shape_[KIndex1]);
q_ = static_cast<int>(input0_shape_[KIndex2]);
num_class_ = static_cast<int>((input1_shape_[KIndex2]));
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
pad_per_class_ = false;
clip_boxes_ = true;
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
PrimitivePtr prim = base_operator->GetPrim();
auto pad_per_class = prim->GetAttr("pad_per_class");
auto clip_boxes = prim->GetAttr("clip_boxes");
if (pad_per_class != nullptr) {
@@ -370,8 +393,11 @@ void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_no
if (clip_boxes != nullptr) {
clip_boxes_ = GetValue<bool>(clip_boxes);
}
CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_);
CheckInput();
CheckOutput();
return KRET_OK;
}
bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -392,24 +418,9 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::Add
} else {
num_detection_ = max_total_size_;
}
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_(kernel_node) is expired. Error no: " << node_ << ".";
}
ShapeVector shape0 = {input0_shape_[0], static_cast<int64_t>(num_detection_), DimSize4};
ShapeVector shape1 = {input0_shape_[0], static_cast<int64_t>(num_detection_)};
ShapeVector shape2 = {input0_shape_[0], static_cast<int64_t>(num_detection_)};
ShapeVector shape3 = {input0_shape_[0]};
common::AnfAlgo::SetOutputInferTypeAndShape(
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32}, {shape0, shape1, shape2, shape3},
node_.get());
output0_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex0);
output1_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex1);
output2_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex2);
output3_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex3);
size_per_class_ = max_output_size_per_class_ < num_boxes_ ? max_output_size_per_class_ : num_boxes_;
CheckInput();
CheckOutput();
if (max_total_size_ <= 0) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << " max_total_size must be > 0, but got " << max_total_size_ << ".";
}
@@ -435,6 +446,7 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::Add
(void)nms_perbath(boxes, scores, nmsed_boxes, nmsed_scores, nmsed_class, valid_detection);
return true;
}
std::vector<KernelAttr> CombinedNonMaxSuppressionCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr()


@@ -50,11 +50,17 @@ bool result_cmp(const result_para &a, const result_para &b) { return a.score > b
namespace mindspore {
namespace kernel {
class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class CombinedNonMaxSuppressionCpuKernelMod : public NativeCpuKernelMod {
public:
CombinedNonMaxSuppressionCpuKernelMod() = default;
~CombinedNonMaxSuppressionCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@@ -69,6 +75,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo
void nms_perclass(float *, float *, std::vector<non_max_suppression_local::result_para> &, int &);
void CheckInput();
void CheckOutput();
int num_bath_ = 0;
int num_boxes_ = 0;
int q_ = 0;
@@ -87,7 +94,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo
float soft_nms_sigma_ = 0.0;
bool pad_per_class_ = 0;
bool clip_boxes_ = 1;
CNodeWeakPtr node_wpt_;
std::vector<int64_t> input0_shape_;
std::vector<int64_t> input1_shape_;
std::vector<int64_t> input2_shape_;


@@ -44,14 +44,13 @@ using complex128 = std::complex<double>;
constexpr size_t kMaxTransposeSerialSize = 50331648;
} // namespace
void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
bool ConjugateTransposeCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
dtype_ = inputs.at(kIndex0)->GetDtype();
perm_type_ = inputs.at(kIndex1)->GetDtype();
launch_map_[kNumberTypeBool] = &ConjugateTransposeCpuKernelMod::LaunchKernel<bool>;
launch_map_[kNumberTypeInt8] = &ConjugateTransposeCpuKernelMod::LaunchKernel<int8_t>;
launch_map_[kNumberTypeInt16] = &ConjugateTransposeCpuKernelMod::LaunchKernel<int16_t>;
@@ -72,6 +71,21 @@ void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
} else {
MS_LOG(EXCEPTION) << "For ConjugateTranspose: unsupported input data type: " << dtype_;
}
return true;
}
int ConjugateTransposeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
input_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
output_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
return KRET_OK;
}
bool ConjugateTransposeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,


@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_
#include <map>
#include <vector>
#include <unordered_map>
#include <memory>
@@ -27,14 +28,20 @@
namespace mindspore {
namespace kernel {
class ConjugateTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ConjugateTransposeCpuKernelMod : public NativeCpuKernelMod {
public:
ConjugateTransposeCpuKernelMod() = default;
~ConjugateTransposeCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
static void ConjComplexFunc(T *input, T *output, size_t start, size_t end);


@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h"
#include "plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h"
namespace mindspore {
namespace kernel {


@@ -0,0 +1,133 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#include <vector>
#include <map>
#include "ops/sparse_apply_ftrl.h"
#include "utils/check_convert_utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 5;
template <typename T, typename S>
class SparseFtrlGpuKernelMod : public NativeGpuKernelMod {
public:
SparseFtrlGpuKernelMod() { ResetResource(); }
~SparseFtrlGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
T *linear = GetDeviceAddress<T>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
S *indices = GetDeviceAddress<S>(inputs, 4);
T *variable_out = GetDeviceAddress<T>(outputs, 0);
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
T *linear_out = GetDeviceAddress<T>(outputs, 2);
CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
return true;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseApplyFtrl>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
lr_ = kernel_ptr->get_lr();
l1_ = kernel_ptr->get_l1();
l2_ = kernel_ptr->get_l2();
lr_power_ = kernel_ptr->get_lr_power();
use_locking_ = kernel_ptr->get_use_locking();
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
(void)CheckAndConvertUtils::CheckInteger("input num", inputs.size(), kEqual, INPUT_NUM, kernel_name_);
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
linear_size_ = sizeof(T);
n_stride_ = 1;
auto variable_shape = inputs.at(kIndex0)->GetShapeVector();
auto accumulation_shape = inputs.at(kIndex1)->GetShapeVector();
auto linear_shape = inputs.at(kIndex2)->GetShapeVector();
auto indices_shape = inputs.at(kIndex4)->GetShapeVector();
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
if (i > 0) {
n_stride_ *= variable_shape[i];
}
}
accumulation_size_ *= SizeOf(accumulation_shape);
linear_size_ *= SizeOf(linear_shape);
num_index_ = indices_shape[0];
return KRET_OK;
}
protected:
void ResetResource() noexcept {
lr_ = 0.0f;
l1_ = 0.0f;
l2_ = 0.0f;
lr_power_ = 0.0f;
use_locking_ = false;
num_index_ = 0;
n_stride_ = 1;
}
private:
size_t variable_size_;
size_t accumulation_size_;
size_t linear_size_;
float lr_;
float l1_;
float l2_;
float lr_power_;
bool use_locking_;
int num_index_;
size_t n_stride_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_


@@ -1,168 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
#include <vector>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 5;
template <typename T, typename S>
class SparseFtrlGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
SparseFtrlGpuKernelMod() { ResetResource(); }
~SparseFtrlGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
T *linear = GetDeviceAddress<T>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
S *indices = GetDeviceAddress<S>(inputs, 4);
T *variable_out = GetDeviceAddress<T>(outputs, 0);
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
T *linear_out = GetDeviceAddress<T>(outputs, 2);
CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_,
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != INPUT_NUM) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be " << INPUT_NUM << ", but got "
<< input_num;
}
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
linear_size_ = sizeof(T);
gradient_size_ = sizeof(T);
indices_size_ = sizeof(S);
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (IsDynamic(shape_signed)) {
return true;
}
auto variable_shape = Convert2SizeTClipNeg(shape_signed);
auto accumulation_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto linear_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
auto gradient_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
is_null_input_ = CHECK_SHAPE_NULL(variable_shape, kernel_name, "var") ||
CHECK_SHAPE_NULL(accumulation_shape, kernel_name, "accum") ||
CHECK_SHAPE_NULL(linear_shape, kernel_name, "linear") ||
CHECK_SHAPE_NULL(gradient_shape, kernel_name, "grad") ||
CHECK_SHAPE_NULL(indices_shape, kernel_name, "indices");
if (is_null_input_) {
InitSizeLists();
return true;
}
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
if (i > 0) {
n_stride_ *= variable_shape[i];
}
}
accumulation_size_ *= SizeOf(accumulation_shape);
linear_size_ *= SizeOf(linear_shape);
gradient_size_ *= SizeOf(gradient_shape);
indices_size_ *= SizeOf(indices_shape);
lr_ = GetAttr<float>(kernel_node, "lr");
l1_ = GetAttr<float>(kernel_node, "l1");
l2_ = GetAttr<float>(kernel_node, "l2");
lr_power_ = GetAttr<float>(kernel_node, "lr_power");
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
num_index_ = LongToSizeClipNeg(indices_shape[0]);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(accumulation_size_);
input_size_list_.push_back(linear_size_);
input_size_list_.push_back(gradient_size_);
input_size_list_.push_back(indices_size_);
output_size_list_.push_back(variable_size_);
output_size_list_.push_back(accumulation_size_);
output_size_list_.push_back(linear_size_);
}
void ResetResource() noexcept override {
variable_size_ = 0;
accumulation_size_ = 0;
linear_size_ = 0;
gradient_size_ = 0;
indices_size_ = 0;
lr_ = 0.0f;
l1_ = 0.0f;
l2_ = 0.0f;
lr_power_ = 0.0f;
use_locking_ = false;
is_null_input_ = false;
num_index_ = 0;
n_stride_ = 1;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
private:
size_t variable_size_;
size_t accumulation_size_;
size_t linear_size_;
size_t gradient_size_;
size_t indices_size_;
float lr_;
float l1_;
float l2_;
float lr_power_;
bool use_locking_;
bool is_null_input_;
int num_index_;
size_t n_stride_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_


@@ -31,6 +31,7 @@ const int64_t kInputDimension1 = 3;
const int64_t kDimsize = 4;
const int64_t kInputs = 6;
const size_t ksecond = 2;
tensor::TensorPtr Get_Value(const std::vector<AbstractBasePtr> &input_args, size_t index) {
auto input = input_args[index]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input);
@@ -38,15 +39,14 @@ tensor::TensorPtr Get_Value(const std::vector<AbstractBasePtr> &input_args, size
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
return input_shape_value_ptr->cast<tensor::TensorPtr>();
}
abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
void CombinedNonMaxSuppressionCheckShapeSize(const ShapeVector &input0_shape, const ShapeVector &input1_shape,
const ShapeVector &input2_shape, const ShapeVector &input3_shape,
const ShapeVector &input4_shape, const ShapeVector &input5_shape,
const bool &is_dynamic_rank, const std::string &prim_name) {
if (is_dynamic_rank) {
return;
}
(void)CheckAndConvertUtils::CheckInteger("boxes dim", SizeToLong(input0_shape.size()), kEqual, kInputDimension0,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("scores dim", SizeToLong(input1_shape.size()), kEqual, kInputDimension1,
@@ -55,6 +55,13 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
(void)CheckAndConvertUtils::CheckInteger("max_total_size dim", SizeToLong(input3_shape.size()), kEqual, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("iou_threshold", SizeToLong(input4_shape.size()), kEqual, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("score_threshold", SizeToLong(input5_shape.size()), kEqual, 0, prim_name);
}
void CombinedNonMaxSuppressionCheckShapeValue(const ShapeVector &input0_shape, const ShapeVector &input1_shape,
const bool &is_dynamic, const std::string &prim_name) {
if (is_dynamic) {
return;
}
if (input0_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 1st dim must be same with the scores's"
<< " 1st dim, but got" << input0_shape[0] << " and " << input1_shape[0] << ".";
@@ -72,45 +79,35 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 4th dim must be equal to 4, but got"
<< input0_shape[kInputIndex3] << ".";
}
for (int64_t i = 0; i < kInputs; i++) {
if (!input_args[i]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!";
}
}
}
abstract::TupleShapePtr CombinedNonMaxSuppressionGetOutputShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args,
const bool &is_dynamic) {
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto pad_per_class_ptr = primitive->GetAttr("pad_per_class");
MS_EXCEPTION_IF_NULL(pad_per_class_ptr);
bool pad_per_class = GetValue<bool>(pad_per_class_ptr);
auto input2_tensor = Get_Value(input_args, kInputIndex2);
auto input3_tensor = Get_Value(input_args, kInputIndex3);
auto input4_tensor = Get_Value(input_args, kInputIndex4);
auto input5_tensor = Get_Value(input_args, kInputIndex5);
if (IsValue(input_args[kInputIndex2]->BuildValue()) && IsValue(input_args[kInputIndex3]->BuildValue())) {
if (IsValue(input_args[kInputIndex4]->BuildValue()) && input_args[kInputIndex5]->BuildValue()) {
auto iou_threshold = *(static_cast<float *>(input4_tensor->data_c()));
auto score_threshold = *(static_cast<float *>(input5_tensor->data_c()));
if (iou_threshold < 0 || iou_threshold > 1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold
<< ".";
}
if (score_threshold < 0 && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim "
<< "is not 1 and score_threshold is less than 1.";
}
}
if (!is_dynamic && IsValue(input_args[kInputIndex2]->BuildValue()) &&
IsValue(input_args[kInputIndex3]->BuildValue())) {
auto input2_tensor = Get_Value(input_args, kInputIndex2);
auto input3_tensor = Get_Value(input_args, kInputIndex3);
auto max_output_size_per_class = *(static_cast<int32_t *>(input2_tensor->data_c()));
auto max_total_size = *(static_cast<int32_t *>(input3_tensor->data_c()));
if (max_total_size <= 0) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " max_total_size must be > 0, but got " << max_total_size
<< ".";
}
if (max_output_size_per_class <= 0) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " max_output_size_per_class must be > 0, but got "
<< max_output_size_per_class << ".";
}
const int32_t kNumZero = 0;
CheckAndConvertUtils::CheckInteger("max_total_size", max_total_size, kGreaterThan, kNumZero, primitive->name());
CheckAndConvertUtils::CheckInteger("max_output_size_per_clas", max_output_size_per_class, kGreaterThan, kNumZero,
primitive->name());
auto num_detection = max_total_size;
if (pad_per_class) {
num_detection = std::min(max_total_size, max_output_size_per_class * static_cast<int32_t>(input1_shape[ksecond]));
}
int64_t bs = input0_shape[0];
ShapeVector shape1 = {bs, num_detection, 4};
ShapeVector shape2 = {bs, num_detection};
@@ -122,14 +119,58 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
auto out4 = std::make_shared<abstract::Shape>(shape4);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out1, out2, out3, out4});
} else {
auto shape1 = std::make_shared<abstract::Shape>(ShapeVector{-2});
auto shape2 = std::make_shared<abstract::Shape>(ShapeVector{-2});
auto shape3 = std::make_shared<abstract::Shape>(ShapeVector{-2});
auto shape4 = std::make_shared<abstract::Shape>(ShapeVector{-2});
auto shape1 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1, 4});
auto shape2 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1});
auto shape3 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1});
auto shape4 = std::make_shared<abstract::Shape>(ShapeVector{-1});
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{shape1, shape2, shape3, shape4});
}
}
abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
std::vector<ShapeVector> all_shapes = {input0_shape, input1_shape, input2_shape,
input3_shape, input4_shape, input5_shape};
auto is_dynamic = (IsDynamic(input0_shape) || IsDynamic(input1_shape));
auto is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank);
CombinedNonMaxSuppressionCheckShapeSize(input0_shape, input1_shape, input2_shape, input3_shape, input4_shape,
input5_shape, is_dynamic_rank, prim_name);
CombinedNonMaxSuppressionCheckShapeValue(input0_shape, input1_shape, is_dynamic, prim_name);
for (int64_t i = 0; i < kInputs; i++) {
if (!input_args[i]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!";
}
}
if (IsValue(input_args[kInputIndex4]->BuildValue()) && IsValue(input_args[kInputIndex5]->BuildValue())) {
auto input4_tensor = Get_Value(input_args, kInputIndex4);
auto input5_tensor = Get_Value(input_args, kInputIndex5);
auto iou_threshold = *(static_cast<float *>(input4_tensor->data_c()));
auto score_threshold = *(static_cast<float *>(input5_tensor->data_c()));
if (iou_threshold < 0 || iou_threshold > 1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold
<< ".";
}
if (score_threshold < 0 && !is_dynamic && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim "
<< "is not 1 and score_threshold is less than 1.";
}
}
return CombinedNonMaxSuppressionGetOutputShape(primitive, input_args, is_dynamic);
}
TuplePtr CombinedNonMaxSuppressionInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();

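The static output shapes built above follow a single rule: num_detection equals max_total_size, capped at max_output_size_per_class * num_classes when pad_per_class is set. A minimal Python sketch of that rule (the helper name cnms_output_shapes is illustrative, not part of this change):

def cnms_output_shapes(batch, num_classes, max_output_size_per_class,
                       max_total_size, pad_per_class=False):
    # Mirrors CombinedNonMaxSuppressionGetOutputShape for the static-shape branch.
    num_detection = max_total_size
    if pad_per_class:
        num_detection = min(max_total_size,
                            max_output_size_per_class * num_classes)
    return [(batch, num_detection, 4), (batch, num_detection),
            (batch, num_detection), (batch,)]

# The dynamic-shape CPU test below runs with batch=1, 3 classes,
# max_output_size_per_class=4 and max_total_size=1:
print(cnms_output_shapes(1, 3, 4, 1))  # [(1, 1, 4), (1, 1), (1, 1), (1,)]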

@@ -32,33 +32,48 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
ShapeVector p_value;
ShapeVector p_value_raw;
auto is_dynamic_rank = IsDynamicRank(x_shape);
if (is_dynamic_rank) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
}
constexpr int64_t dim_7 = 7;
(void)CheckAndConvertUtils::CheckInteger("[x] rank", static_cast<int64_t>(x_shape.size()), kLessEqual, dim_7,
op_name);
auto perm_value = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(perm_value);
if (perm_value->isa<ValueTuple>()) {
if (perm_value->isa<AnyValue>()) {
std::vector<int64_t> output_shape(static_cast<int>(x_shape.size()), -1);
return std::make_shared<abstract::Shape>(output_shape);
}
ShapeVector p_value;
ShapeVector p_value_raw;
if (perm_value->isa<tensor::Tensor>()) {
p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("input[perm]", perm_value, op_name);
} else if (perm_value->isa<ValueTuple>()) {
p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
} else {
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the type of perm must be Tuple, but got "
<< input_args[1]->BuildType()->ToString() << " .";
}
for (auto p : p_value_raw) {
p = (p >= 0) ? p : (static_cast<int64_t>(p_value_raw.size()) + p);
p_value.emplace_back(p);
}
if (x_shape.size() != p_value.size()) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of x " << x_shape.size() << " and perm "
<< p_value.size() << " must be equal, but got the dimension of x " << x_shape.size()
<< " and perm " << p_value.size() << " .";
}
constexpr int64_t dim_7 = 7;
if (p_value.size() > dim_7) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of perm must be less than 8, but get "
<< p_value.size() << " .";
}
for (auto i : p_value) {
(void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
}
std::vector<int64_t> tmp(p_value);
for (auto it = tmp.begin(); it != tmp.end();) {
auto dim = *it;
@@ -69,6 +84,7 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm must be different.";
}
}
std::vector<int64_t> in_shape(p_value);
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; });
return std::make_shared<abstract::Shape>(in_shape);
@@ -85,7 +101,6 @@ TypePtr ConjugateTransposeInferType(const PrimitivePtr &prim, const std::vector<
}
} // namespace
MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator);
AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@@ -95,6 +110,9 @@ AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, con
auto shape = ConjugateTransposeInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_HOST_DEPENDS(kNameConjugateTranspose, {1});
MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(ConjugateTranspose, prim::kPrimConjugateTranspose, ConjugateTransposeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

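For reference, ConjugateTranspose is expected to behave like a transpose followed by element-wise complex conjugation, with out_shape[i] = x_shape[perm[i]] as computed above. A small NumPy sketch of the intended semantics (an illustration, not the kernel code):

import numpy as np

x = np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=np.complex64)
perm = (1, 0)
expected = np.conj(np.transpose(x, perm))  # conjugate of the permuted tensor
print(expected.shape)  # (2, 2); in general out_shape[i] == x.shape[perm[i]]
# For real dtypes the conjugation is a no-op, which is why the dynamic-shape
# test below only checks the transposed output shape.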

@@ -136,7 +136,6 @@ void SparseApplyFtrl::Init(float lr, float l1, float l2, float lr_power, bool us
set_use_locking(use_locking);
}
MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator);
AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@@ -154,14 +153,14 @@ AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const
(void)CheckAndConvertUtils::CheckValue(kL1, l1, kGreaterEqual, 0.0f, op_name);
(void)CheckAndConvertUtils::CheckValue(kL2, l2, kGreaterEqual, 0.0f, op_name);
(void)CheckAndConvertUtils::CheckValue(kLrPower, lr_power, kLessEqual, 0.0f, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual,
sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args),
kEqual, sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name);
auto types = sparse_apply_ftrl::SparseApplyFtrlInferType(primitive, input_args);
auto shapes = sparse_apply_ftrl::SparseApplyFtrlInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyFtrl, prim::kPrimSparseApplyFtrl, SparseApplyFtrlInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,77 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.ops.operations.image_ops as P
from mindspore.common import dtype as mstype
from mindspore import Tensor, nn, context
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.CombinedNonMaxSuppression()

    def construct(self, boxes, scores, max_output_size_per_class,
                  max_total_size, iou_threshold, score_threshold):
        return self.op(boxes, scores, max_output_size_per_class,
                       max_total_size, iou_threshold, score_threshold)


def dyn_case():
    net = Net()

    boxes_dyn = Tensor(shape=[None, None, None, 4], dtype=mstype.float32)
    scores_dyn = Tensor(shape=[None, None, None], dtype=mstype.float32)
    max_output_size_per_class = Tensor(4, mstype.int32)
    max_total_size = Tensor(1, mstype.int32)
    iou_threshold = Tensor(0, mstype.float32)
    score_threshold = Tensor(0, mstype.float32)

    net.set_inputs(boxes_dyn, scores_dyn, max_output_size_per_class,
                   max_total_size, iou_threshold, score_threshold)

    boxes = Tensor(
        np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]],
                   [[190, 110, 150, 100]], [[210, 112, 150,
                                             100]]]])).astype('float32')
    scores = Tensor(
        np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
                   [0.3000, 0.6000, 0.1000], [0.0500, 0.9000,
                                              0.0500]]])).astype('float32')

    out = net(boxes, scores, max_output_size_per_class, max_total_size,
              iou_threshold, score_threshold)

    expect_shapes = [(1, 1, 4), (1, 1), (1, 1), (1,)]
    for i in range(4):
        assert out[i].asnumpy().shape == expect_shapes[i]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_combined_non_max_suppression_dyn():
    """
    Feature: test CombinedNonMaxSuppression in PyNative and Graph modes.
    Description: test dynamic shape case.
    Expectation: expect correct shape result.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    dyn_case()
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    dyn_case()


@@ -0,0 +1,61 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.ops.operations.array_ops as P
from mindspore.common import dtype as mstype
from mindspore import Tensor, nn, context
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.ConjugateTranspose()

    def construct(self, x, perm):
        return self.op(x, perm)


def dyn_case():
    net = Net()

    x_dyn = Tensor(shape=[None, 4, None, 7], dtype=mstype.float32)
    perm = (2, 1, 0, 3)
    net.set_inputs(x_dyn, perm)

    x = Tensor(np.random.random((8, 4, 5, 7)).astype(np.float32))
    out = net(x, perm)

    expect_shape = (5, 4, 8, 7)
    assert out.asnumpy().shape == expect_shape


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conjugate_transpose_dyn():
    """
    Feature: test ConjugateTranspose in PyNative and Graph modes.
    Description: test dynamic shape case.
    Expectation: expect correct shape result.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    dyn_case()
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    dyn_case()


@@ -24,31 +24,82 @@ import mindspore.common.dtype as mstype
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001,
l1=0.0,
l2=0.0,
lr_power=-0.5,
use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad,
indices)
return out
class Net_half(nn.Cell):
class NetHalf(nn.Cell):
def __init__(self):
super(Net_half, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="linear")
super(NetHalf, self).__init__()
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001,
l1=0.0,
l2=0.0,
lr_power=-0.5,
use_locking=False)
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
name="accum")
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
name="linear")
def construct(self, grad, indices):
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad,
indices)
return out
def dyn_case():
net = Net()
grad_dyn = Tensor(shape=[3, None, None], dtype=mstype.float32)
indices_dyn = Tensor(shape=[None], dtype=mstype.int32)
net.set_inputs(grad_dyn, indices_dyn)
grad = Tensor(np.ones([3, 3, 3]).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)
out = net(grad, indices)
expect_shape = (3, 3, 3)
for i in range(3):
assert out[i].asnumpy().shape == expect_shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_sparse_apply_ftrl_dyn():
"""
Feature: test SparseApplyFtrl in PyNative and Graph modes.
Description: test dynamic shape case.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
dyn_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
dyn_case()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -63,7 +114,8 @@ def test_ftrl():
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
[0.291479, 0.291479,
0.291479]]]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
@@ -83,12 +135,11 @@ def test_ftrl_sparse_int64_ind():
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
[0.291479, 0.291479,
0.291479]]]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net()
sparse_apply_ftrl(gradient, indices)
@@ -113,13 +164,14 @@ def test_ftrl_half():
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
[0.291479, 0.291479,
0.291479]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
@@ -133,21 +185,21 @@ def test_ftrl_sparse_half_int64_ind():
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
[0.291479, 0.291479,
0.291479]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
sparse_apply_ftrl(gradient, indices)
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -162,12 +214,13 @@ def test_ftrl_half_return_output():
[0.291479, 0.291479, 0.291479]],
[[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479],
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
[0.291479, 0.291479,
0.291479]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
output = sparse_apply_ftrl(gradient, indices)
assert np.all(output[0].asnumpy() == expect_var)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
sparse_apply_ftrl = Net_half()
sparse_apply_ftrl = NetHalf()
sparse_apply_ftrl(gradient, indices)
assert np.all(output[0].asnumpy() == expect_var)
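The expected value 0.291479 used throughout these tests can be reproduced with a short NumPy sketch of the standard FTRL-proximal update (a reference sketch only; ftrl_step is a hypothetical helper, and the GPU kernel may differ in implementation details):

import numpy as np

def ftrl_step(var, accum, linear, grad, lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5):
    # One dense FTRL-proximal step; with lr_power = -0.5, accum ** -lr_power is sqrt(accum).
    accum_new = accum + grad * grad
    linear = linear + grad - (accum_new ** -lr_power - accum ** -lr_power) / lr * var
    quadratic = accum_new ** -lr_power / lr + 2.0 * l2
    var = np.where(np.abs(linear) > l1,
                   (np.sign(linear) * l1 - linear) / quadratic, 0.0)
    return var, accum_new, linear

var = np.ones([3, 3, 3], np.float32)
accum = np.ones_like(var)
linear = np.ones_like(var)
grad = np.ones_like(var)
var, accum, linear = ftrl_step(var, accum, linear, grad)
print(var[0, 0, 0])  # ~0.291479, matching expect_var in the tests above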