CombinedNMS, SparseApplyFtrl, ConjugateTranspose
update
parent 73a7bc9b28
commit 1e96180574
@@ -342,26 +342,49 @@ void CombinedNonMaxSuppressionCpuKernelMod::CheckOutput() {
|
|||
}
|
||||
}
|
||||
|
||||
void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
node_wpt_ = kernel_node;
|
||||
input0_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
input1_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex1);
|
||||
input2_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex2);
|
||||
input3_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex3);
|
||||
input4_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex4);
|
||||
input5_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, KIndex5);
|
||||
bool CombinedNonMaxSuppressionCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
return true;
|
||||
}
|
||||
|
||||
int CombinedNonMaxSuppressionCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t input_num = inputs.size();
|
||||
size_t output_num = outputs.size();
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_);
|
||||
|
||||
input0_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
input1_shape_ = inputs.at(KIndex1)->GetDeviceShapeAdaptively();
|
||||
input2_shape_ = inputs.at(KIndex2)->GetDeviceShapeAdaptively();
|
||||
input3_shape_ = inputs.at(KIndex3)->GetDeviceShapeAdaptively();
|
||||
input4_shape_ = inputs.at(KIndex4)->GetDeviceShapeAdaptively();
|
||||
input5_shape_ = inputs.at(KIndex5)->GetDeviceShapeAdaptively();
|
||||
|
||||
output0_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();
|
||||
output1_shape_ = outputs.at(kIndex1)->GetDeviceShapeAdaptively();
|
||||
output2_shape_ = outputs.at(kIndex2)->GetDeviceShapeAdaptively();
|
||||
output3_shape_ = outputs.at(kIndex3)->GetDeviceShapeAdaptively();
|
||||
|
||||
soft_nms_sigma_ = 0.0;
|
||||
num_bath_ = static_cast<int>(input0_shape_[0]);
|
||||
num_boxes_ = static_cast<int>(input0_shape_[KIndex1]);
|
||||
q_ = static_cast<int>(input0_shape_[KIndex2]);
|
||||
num_class_ = static_cast<int>((input1_shape_[KIndex2]));
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
||||
pad_per_class_ = false;
|
||||
clip_boxes_ = true;
|
||||
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
|
||||
PrimitivePtr prim = base_operator->GetPrim();
|
||||
auto pad_per_class = prim->GetAttr("pad_per_class");
|
||||
auto clip_boxes = prim->GetAttr("clip_boxes");
|
||||
if (pad_per_class != nullptr) {
|
||||
|
@@ -370,8 +393,11 @@ void CombinedNonMaxSuppressionCpuKernelMod::InitKernel(const CNodePtr &kernel_no
|
|||
if (clip_boxes != nullptr) {
|
||||
clip_boxes_ = GetValue<bool>(clip_boxes);
|
||||
}
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kCombinedNonMaxSuppressionInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(output_num, kCombinedNonMaxSuppressionOutputsNum, kernel_name_);
|
||||
|
||||
CheckInput();
|
||||
CheckOutput();
|
||||
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
@@ -392,24 +418,9 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::Add
|
|||
} else {
|
||||
num_detection_ = max_total_size_;
|
||||
}
|
||||
auto node_ = node_wpt_.lock();
|
||||
if (!node_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_(kernel_node) is expired. Error no: " << node_ << ".";
|
||||
}
|
||||
ShapeVector shape0 = {input0_shape_[0], static_cast<int64_t>(num_detection_), DimSize4};
|
||||
ShapeVector shape1 = {input0_shape_[0], static_cast<int64_t>(num_detection_)};
|
||||
ShapeVector shape2 = {input0_shape_[0], static_cast<int64_t>(num_detection_)};
|
||||
ShapeVector shape3 = {input0_shape_[0]};
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32}, {shape0, shape1, shape2, shape3},
|
||||
node_.get());
|
||||
output0_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex0);
|
||||
output1_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex1);
|
||||
output2_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex2);
|
||||
output3_shape_ = AnfAlgo::GetOutputDeviceShape(node_, KIndex3);
|
||||
|
||||
size_per_class_ = max_output_size_per_class_ < num_boxes_ ? max_output_size_per_class_ : num_boxes_;
|
||||
CheckInput();
|
||||
CheckOutput();
|
||||
|
||||
if (max_total_size_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "For " << kernel_name_ << " max_total_size must be > 0, but got " << max_total_size_ << ".";
|
||||
}
|
||||
|
@@ -435,6 +446,7 @@ bool CombinedNonMaxSuppressionCpuKernelMod::Launch(const std::vector<kernel::Add
|
|||
(void)nms_perbath(boxes, scores, nmsed_boxes, nmsed_scores, nmsed_class, valid_detection);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CombinedNonMaxSuppressionCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
KernelAttr()
|
||||
|
|
|
@@ -50,11 +50,17 @@ bool result_cmp(const result_para &a, const result_para &b) { return a.score > b

namespace mindspore {
namespace kernel {
class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class CombinedNonMaxSuppressionCpuKernelMod : public NativeCpuKernelMod {
 public:
  CombinedNonMaxSuppressionCpuKernelMod() = default;
  ~CombinedNonMaxSuppressionCpuKernelMod() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

@@ -69,6 +75,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo
  void nms_perclass(float *, float *, std::vector<non_max_suppression_local::result_para> &, int &);
  void CheckInput();
  void CheckOutput();

  int num_bath_ = 0;
  int num_boxes_ = 0;
  int q_ = 0;
@@ -87,7 +94,7 @@ class CombinedNonMaxSuppressionCpuKernelMod : public DeprecatedNativeCpuKernelMo
  float soft_nms_sigma_ = 0.0;
  bool pad_per_class_ = 0;
  bool clip_boxes_ = 1;
  CNodeWeakPtr node_wpt_;

  std::vector<int64_t> input0_shape_;
  std::vector<int64_t> input1_shape_;
  std::vector<int64_t> input2_shape_;
@@ -44,14 +44,13 @@ using complex128 = std::complex<double>;
constexpr size_t kMaxTransposeSerialSize = 50331648;
}  // namespace

void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);

bool ConjugateTransposeCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                          const std::vector<KernelTensorPtr> &inputs,
                                          const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  dtype_ = inputs.at(kIndex0)->GetDtype();
  perm_type_ = inputs.at(kIndex1)->GetDtype();
  launch_map_[kNumberTypeBool] = &ConjugateTransposeCpuKernelMod::LaunchKernel<bool>;
  launch_map_[kNumberTypeInt8] = &ConjugateTransposeCpuKernelMod::LaunchKernel<int8_t>;
  launch_map_[kNumberTypeInt16] = &ConjugateTransposeCpuKernelMod::LaunchKernel<int16_t>;
@@ -72,6 +71,21 @@ void ConjugateTransposeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  } else {
    MS_LOG(EXCEPTION) << "For ConjugateTranspose: unsupported input data type: " << dtype_;
  }
  return true;
}

int ConjugateTransposeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs,
                                           const std::map<uint32_t, tensor::TensorPtr> &) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
    return ret;
  }

  input_shape_ = inputs.at(kIndex0)->GetDeviceShapeAdaptively();
  output_shape_ = outputs.at(kIndex0)->GetDeviceShapeAdaptively();

  return KRET_OK;
}

bool ConjugateTransposeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONJUGATE_TRANSPOSE_CPU_KERNEL_H_

#include <map>
#include <vector>
#include <unordered_map>
#include <memory>
@@ -27,14 +28,20 @@

namespace mindspore {
namespace kernel {
class ConjugateTransposeCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class ConjugateTransposeCpuKernelMod : public NativeCpuKernelMod {
 public:
  ConjugateTransposeCpuKernelMod() = default;
  ~ConjugateTransposeCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  static void ConjComplexFunc(T *input, T *output, size_t start, size_t end);
@@ -14,7 +14,7 @@
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/sparse_ftrl_gpu_kernel.h"
#include "plugin/device/gpu/kernel/nn/sparse_apply_ftrl_gpu_kernel.h"

namespace mindspore {
namespace kernel {
@@ -0,0 +1,133 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "ops/sparse_apply_ftrl.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t INPUT_NUM = 5;
|
||||
template <typename T, typename S>
|
||||
class SparseFtrlGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
SparseFtrlGpuKernelMod() { ResetResource(); }
|
||||
~SparseFtrlGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 1);
|
||||
T *linear = GetDeviceAddress<T>(inputs, 2);
|
||||
T *gradient = GetDeviceAddress<T>(inputs, 3);
|
||||
S *indices = GetDeviceAddress<S>(inputs, 4);
|
||||
T *variable_out = GetDeviceAddress<T>(outputs, 0);
|
||||
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
|
||||
T *linear_out = GetDeviceAddress<T>(outputs, 2);
|
||||
CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
|
||||
accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
kernel_name_ = base_operator->name();
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::SparseApplyFtrl>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(kernel_ptr);
|
||||
lr_ = kernel_ptr->get_lr();
|
||||
l1_ = kernel_ptr->get_l1();
|
||||
l2_ = kernel_ptr->get_l2();
|
||||
lr_power_ = kernel_ptr->get_lr_power();
|
||||
use_locking_ = kernel_ptr->get_use_locking();
|
||||
return true;
|
||||
}
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("input num", inputs.size(), kEqual, INPUT_NUM, kernel_name_);
|
||||
|
||||
variable_size_ = sizeof(T);
|
||||
accumulation_size_ = sizeof(T);
|
||||
linear_size_ = sizeof(T);
|
||||
n_stride_ = 1;
|
||||
|
||||
auto variable_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
auto accumulation_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
auto linear_shape = inputs.at(kIndex2)->GetShapeVector();
|
||||
auto indices_shape = inputs.at(kIndex4)->GetShapeVector();
|
||||
|
||||
for (size_t i = 0; i < variable_shape.size(); i++) {
|
||||
variable_size_ *= variable_shape[i];
|
||||
if (i > 0) {
|
||||
n_stride_ *= variable_shape[i];
|
||||
}
|
||||
}
|
||||
accumulation_size_ *= SizeOf(accumulation_shape);
|
||||
linear_size_ *= SizeOf(linear_shape);
|
||||
num_index_ = indices_shape[0];
|
||||
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept {
|
||||
lr_ = 0.0f;
|
||||
l1_ = 0.0f;
|
||||
l2_ = 0.0f;
|
||||
lr_power_ = 0.0f;
|
||||
use_locking_ = false;
|
||||
num_index_ = 0;
|
||||
n_stride_ = 1;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t variable_size_;
|
||||
size_t accumulation_size_;
|
||||
size_t linear_size_;
|
||||
float lr_;
|
||||
float l1_;
|
||||
float l2_;
|
||||
float lr_power_;
|
||||
bool use_locking_;
|
||||
int num_index_;
|
||||
size_t n_stride_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
|
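A reading aid for the Resize() bookkeeping in the new SparseFtrlGpuKernelMod header above: the loop only counts bytes and per-row strides from the input shapes. The sketch below is plain illustrative Python, not MindSpore API; the helper name and arguments are assumptions.

# Mirrors the size/stride bookkeeping in SparseFtrlGpuKernelMod::Resize() (illustration only).
def resize_bookkeeping(var_shape, indices_shape, elem_size):
    variable_size = elem_size          # starts at sizeof(T)
    n_stride = 1                       # elements per indexed row
    for i, dim in enumerate(var_shape):
        variable_size *= dim
        if i > 0:
            n_stride *= dim
    num_index = indices_shape[0]       # number of rows to update
    return variable_size, n_stride, num_index

# e.g. a [3, 3, 3] float32 var with 3 indices -> (108 bytes, stride 9, 3 indices)
print(resize_bookkeeping([3, 3, 3], [3], 4))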
@@ -1,168 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_ftrl_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t INPUT_NUM = 5;
|
||||
template <typename T, typename S>
|
||||
class SparseFtrlGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
public:
|
||||
SparseFtrlGpuKernelMod() { ResetResource(); }
|
||||
~SparseFtrlGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 1);
|
||||
T *linear = GetDeviceAddress<T>(inputs, 2);
|
||||
T *gradient = GetDeviceAddress<T>(inputs, 3);
|
||||
S *indices = GetDeviceAddress<S>(inputs, 4);
|
||||
T *variable_out = GetDeviceAddress<T>(outputs, 0);
|
||||
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
|
||||
T *linear_out = GetDeviceAddress<T>(outputs, 2);
|
||||
CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
|
||||
accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_,
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(linear_out, linear, linear_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != INPUT_NUM) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be " << INPUT_NUM << ", but got "
|
||||
<< input_num;
|
||||
}
|
||||
|
||||
variable_size_ = sizeof(T);
|
||||
accumulation_size_ = sizeof(T);
|
||||
linear_size_ = sizeof(T);
|
||||
gradient_size_ = sizeof(T);
|
||||
indices_size_ = sizeof(S);
|
||||
|
||||
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (IsDynamic(shape_signed)) {
|
||||
return true;
|
||||
}
|
||||
auto variable_shape = Convert2SizeTClipNeg(shape_signed);
|
||||
auto accumulation_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto linear_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
auto gradient_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(variable_shape, kernel_name, "var") ||
|
||||
CHECK_SHAPE_NULL(accumulation_shape, kernel_name, "accum") ||
|
||||
CHECK_SHAPE_NULL(linear_shape, kernel_name, "linear") ||
|
||||
CHECK_SHAPE_NULL(gradient_shape, kernel_name, "grad") ||
|
||||
CHECK_SHAPE_NULL(indices_shape, kernel_name, "indices");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < variable_shape.size(); i++) {
|
||||
variable_size_ *= variable_shape[i];
|
||||
if (i > 0) {
|
||||
n_stride_ *= variable_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
accumulation_size_ *= SizeOf(accumulation_shape);
|
||||
linear_size_ *= SizeOf(linear_shape);
|
||||
gradient_size_ *= SizeOf(gradient_shape);
|
||||
indices_size_ *= SizeOf(indices_shape);
|
||||
|
||||
lr_ = GetAttr<float>(kernel_node, "lr");
|
||||
l1_ = GetAttr<float>(kernel_node, "l1");
|
||||
l2_ = GetAttr<float>(kernel_node, "l2");
|
||||
lr_power_ = GetAttr<float>(kernel_node, "lr_power");
|
||||
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
|
||||
num_index_ = LongToSizeClipNeg(indices_shape[0]);
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(variable_size_);
|
||||
input_size_list_.push_back(accumulation_size_);
|
||||
input_size_list_.push_back(linear_size_);
|
||||
input_size_list_.push_back(gradient_size_);
|
||||
input_size_list_.push_back(indices_size_);
|
||||
output_size_list_.push_back(variable_size_);
|
||||
output_size_list_.push_back(accumulation_size_);
|
||||
output_size_list_.push_back(linear_size_);
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
variable_size_ = 0;
|
||||
accumulation_size_ = 0;
|
||||
linear_size_ = 0;
|
||||
gradient_size_ = 0;
|
||||
indices_size_ = 0;
|
||||
lr_ = 0.0f;
|
||||
l1_ = 0.0f;
|
||||
l2_ = 0.0f;
|
||||
lr_power_ = 0.0f;
|
||||
use_locking_ = false;
|
||||
is_null_input_ = false;
|
||||
num_index_ = 0;
|
||||
n_stride_ = 1;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
size_t variable_size_;
|
||||
size_t accumulation_size_;
|
||||
size_t linear_size_;
|
||||
size_t gradient_size_;
|
||||
size_t indices_size_;
|
||||
float lr_;
|
||||
float l1_;
|
||||
float l2_;
|
||||
float lr_power_;
|
||||
bool use_locking_;
|
||||
bool is_null_input_;
|
||||
int num_index_;
|
||||
size_t n_stride_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SPARSE_FTRL_GPU_KERNEL_H_
|
|
@@ -31,6 +31,7 @@ const int64_t kInputDimension1 = 3;
|
|||
const int64_t kDimsize = 4;
|
||||
const int64_t kInputs = 6;
|
||||
const size_t ksecond = 2;
|
||||
|
||||
tensor::TensorPtr Get_Value(const std::vector<AbstractBasePtr> &input_args, size_t index) {
|
||||
auto input = input_args[index]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@@ -38,15 +39,14 @@ tensor::TensorPtr Get_Value(const std::vector<AbstractBasePtr> &input_args, size
|
|||
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
|
||||
return input_shape_value_ptr->cast<tensor::TensorPtr>();
|
||||
}
|
||||
abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
||||
|
||||
void CombinedNonMaxSuppressionCheckShapeSize(const ShapeVector &input0_shape, const ShapeVector &input1_shape,
|
||||
const ShapeVector &input2_shape, const ShapeVector &input3_shape,
|
||||
const ShapeVector &input4_shape, const ShapeVector &input5_shape,
|
||||
const bool &is_dynamic_rank, const std::string &prim_name) {
|
||||
if (is_dynamic_rank) {
|
||||
return;
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("boxes dim", SizeToLong(input0_shape.size()), kEqual, kInputDimension0,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("scores dim", SizeToLong(input1_shape.size()), kEqual, kInputDimension1,
|
||||
|
@@ -55,6 +55,13 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
|
|||
(void)CheckAndConvertUtils::CheckInteger("max_total_size dim", SizeToLong(input3_shape.size()), kEqual, 0, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("iou_threshold", SizeToLong(input4_shape.size()), kEqual, 0, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("score_threshold", SizeToLong(input5_shape.size()), kEqual, 0, prim_name);
|
||||
}
|
||||
|
||||
void CombinedNonMaxSuppressionCheckShapeValue(const ShapeVector &input0_shape, const ShapeVector &input1_shape,
|
||||
const bool &is_dynamic, const std::string &prim_name) {
|
||||
if (is_dynamic) {
|
||||
return;
|
||||
}
|
||||
if (input0_shape[0] != input1_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 1st dim must be same with the scores's"
|
||||
<< " 1st dim, but got" << input0_shape[0] << " and " << input1_shape[0] << ".";
|
||||
|
@@ -72,45 +79,35 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
|
|||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the boxes's 4th dim must be equal to 4, but got"
|
||||
<< input0_shape[kInputIndex3] << ".";
|
||||
}
|
||||
for (int64_t i = 0; i < kInputs; i++) {
|
||||
if (!input_args[i]->isa<abstract::AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr CombinedNonMaxSuppressionGetOutputShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args,
|
||||
const bool &is_dynamic) {
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto pad_per_class_ptr = primitive->GetAttr("pad_per_class");
|
||||
MS_EXCEPTION_IF_NULL(pad_per_class_ptr);
|
||||
bool pad_per_class = GetValue<bool>(pad_per_class_ptr);
|
||||
auto input2_tensor = Get_Value(input_args, kInputIndex2);
|
||||
auto input3_tensor = Get_Value(input_args, kInputIndex3);
|
||||
auto input4_tensor = Get_Value(input_args, kInputIndex4);
|
||||
auto input5_tensor = Get_Value(input_args, kInputIndex5);
|
||||
if (IsValue(input_args[kInputIndex2]->BuildValue()) && IsValue(input_args[kInputIndex3]->BuildValue())) {
|
||||
if (IsValue(input_args[kInputIndex4]->BuildValue()) && input_args[kInputIndex5]->BuildValue()) {
|
||||
auto iou_threshold = *(static_cast<float *>(input4_tensor->data_c()));
|
||||
auto score_threshold = *(static_cast<float *>(input5_tensor->data_c()));
|
||||
if (iou_threshold < 0 || iou_threshold > 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold
|
||||
<< ".";
|
||||
}
|
||||
if (score_threshold < 0 && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim "
|
||||
<< "is not 1 and score_threshold is less than 1.";
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_dynamic && IsValue(input_args[kInputIndex2]->BuildValue()) &&
|
||||
IsValue(input_args[kInputIndex3]->BuildValue())) {
|
||||
auto input2_tensor = Get_Value(input_args, kInputIndex2);
|
||||
auto input3_tensor = Get_Value(input_args, kInputIndex3);
|
||||
auto max_output_size_per_class = *(static_cast<int32_t *>(input2_tensor->data_c()));
|
||||
auto max_total_size = *(static_cast<int32_t *>(input3_tensor->data_c()));
|
||||
if (max_total_size <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << " max_total_size must be > 0, but got " << max_total_size
|
||||
<< ".";
|
||||
}
|
||||
if (max_output_size_per_class <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << " max_output_size_per_class must be > 0, but got "
|
||||
<< max_output_size_per_class << ".";
|
||||
}
|
||||
|
||||
const int32_t kNumZero = 0;
|
||||
CheckAndConvertUtils::CheckInteger("max_total_size", max_total_size, kGreaterThan, kNumZero, primitive->name());
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("max_output_size_per_clas", max_output_size_per_class, kGreaterThan, kNumZero,
|
||||
primitive->name());
|
||||
|
||||
auto num_detection = max_total_size;
|
||||
if (pad_per_class) {
|
||||
num_detection = std::min(max_total_size, max_output_size_per_class * static_cast<int32_t>(input1_shape[ksecond]));
|
||||
}
|
||||
|
||||
int64_t bs = input0_shape[0];
|
||||
ShapeVector shape1 = {bs, num_detection, 4};
|
||||
ShapeVector shape2 = {bs, num_detection};
|
||||
|
@@ -122,14 +119,58 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
|
|||
auto out4 = std::make_shared<abstract::Shape>(shape4);
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out1, out2, out3, out4});
|
||||
} else {
|
||||
auto shape1 = std::make_shared<abstract::Shape>(ShapeVector{-2});
|
||||
auto shape2 = std::make_shared<abstract::Shape>(ShapeVector{-2});
|
||||
auto shape3 = std::make_shared<abstract::Shape>(ShapeVector{-2});
|
||||
auto shape4 = std::make_shared<abstract::Shape>(ShapeVector{-2});
|
||||
auto shape1 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1, 4});
|
||||
auto shape2 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1});
|
||||
auto shape3 = std::make_shared<abstract::Shape>(ShapeVector{-1, -1});
|
||||
auto shape4 = std::make_shared<abstract::Shape>(ShapeVector{-1});
|
||||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{shape1, shape2, shape3, shape4});
|
||||
}
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto input3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
auto input4_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
auto input5_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
||||
|
||||
std::vector<ShapeVector> all_shapes = {input0_shape, input1_shape, input2_shape,
|
||||
input3_shape, input4_shape, input5_shape};
|
||||
auto is_dynamic = (IsDynamic(input0_shape) || IsDynamic(input1_shape));
|
||||
auto is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank);
|
||||
|
||||
CombinedNonMaxSuppressionCheckShapeSize(input0_shape, input1_shape, input2_shape, input3_shape, input4_shape,
|
||||
input5_shape, is_dynamic_rank, prim_name);
|
||||
|
||||
CombinedNonMaxSuppressionCheckShapeValue(input0_shape, input1_shape, is_dynamic, prim_name);
|
||||
|
||||
for (int64_t i = 0; i < kInputs; i++) {
|
||||
if (!input_args[i]->isa<abstract::AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "For " << prim_name << " input" << i << " only support tensor!";
|
||||
}
|
||||
}
|
||||
|
||||
if (IsValue(input_args[kInputIndex4]->BuildValue()) && IsValue(input_args[kInputIndex5]->BuildValue())) {
|
||||
auto input4_tensor = Get_Value(input_args, kInputIndex4);
|
||||
auto input5_tensor = Get_Value(input_args, kInputIndex5);
|
||||
auto iou_threshold = *(static_cast<float *>(input4_tensor->data_c()));
|
||||
auto score_threshold = *(static_cast<float *>(input5_tensor->data_c()));
|
||||
if (iou_threshold < 0 || iou_threshold > 1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", iou_threshold must be in [0,1], but got " << iou_threshold
|
||||
<< ".";
|
||||
}
|
||||
if (score_threshold < 0 && !is_dynamic && input0_shape[kInputIndex2] == input1_shape[kInputIndex2]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", it is temporarily unsupported when boxes's 2'nd dim "
|
||||
<< "is not 1 and score_threshold is less than 1.";
|
||||
}
|
||||
}
|
||||
|
||||
return CombinedNonMaxSuppressionGetOutputShape(primitive, input_args, is_dynamic);
|
||||
}
|
||||
|
||||
TuplePtr CombinedNonMaxSuppressionInferType(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
|
|
|
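The static output shapes built in CombinedNonMaxSuppressionGetOutputShape above follow one small rule; the sketch below restates it as plain Python for readability. It is a reading aid only, and the function name and argument names are assumptions, not the operator's API.

# Mirrors the shape rule in CombinedNonMaxSuppressionGetOutputShape (illustration only).
def cnms_output_shapes(batch, num_classes, max_output_size_per_class, max_total_size, pad_per_class):
    num_detection = max_total_size
    if pad_per_class:
        num_detection = min(max_total_size, max_output_size_per_class * num_classes)
    return [(batch, num_detection, 4),  # nmsed_boxes
            (batch, num_detection),     # nmsed_scores
            (batch, num_detection),     # nmsed_classes
            (batch,)]                   # valid_detections

# e.g. the dynamic-shape test later in this commit uses max_output_size_per_class=4, max_total_size=1:
print(cnms_output_shapes(1, 3, 4, 1, False))  # [(1, 1, 4), (1, 1), (1, 1), (1,)]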
@@ -32,33 +32,48 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
ShapeVector p_value;
|
||||
ShapeVector p_value_raw;
|
||||
auto is_dynamic_rank = IsDynamicRank(x_shape);
|
||||
if (is_dynamic_rank) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||
}
|
||||
|
||||
constexpr int64_t dim_7 = 7;
|
||||
(void)CheckAndConvertUtils::CheckInteger("[x] rank", static_cast<int64_t>(x_shape.size()), kLessEqual, dim_7,
|
||||
op_name);
|
||||
|
||||
auto perm_value = input_args[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(perm_value);
|
||||
if (perm_value->isa<ValueTuple>()) {
|
||||
if (perm_value->isa<AnyValue>()) {
|
||||
std::vector<int64_t> output_shape(static_cast<int>(x_shape.size()), -1);
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
ShapeVector p_value;
|
||||
ShapeVector p_value_raw;
|
||||
if (perm_value->isa<tensor::Tensor>()) {
|
||||
p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("input[perm]", perm_value, op_name);
|
||||
} else if (perm_value->isa<ValueTuple>()) {
|
||||
p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For '" << op_name << "', the type of perm must be Tuple, but got "
|
||||
<< input_args[1]->BuildType()->ToString() << " .";
|
||||
}
|
||||
|
||||
for (auto p : p_value_raw) {
|
||||
p = (p >= 0) ? p : (static_cast<int64_t>(p_value_raw.size()) + p);
|
||||
p_value.emplace_back(p);
|
||||
}
|
||||
|
||||
if (x_shape.size() != p_value.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of x " << x_shape.size() << " and perm "
|
||||
<< p_value.size() << " must be equal, but got the dimension of x " << x_shape.size()
|
||||
<< " and perm " << p_value.size() << " .";
|
||||
}
|
||||
constexpr int64_t dim_7 = 7;
|
||||
if (p_value.size() > dim_7) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dimension of perm must be less than 8, but get "
|
||||
<< p_value.size() << " .";
|
||||
}
|
||||
|
||||
for (auto i : p_value) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
|
||||
}
|
||||
|
||||
std::vector<int64_t> tmp(p_value);
|
||||
for (auto it = tmp.begin(); it != tmp.end();) {
|
||||
auto dim = *it;
|
||||
|
@@ -69,6 +84,7 @@ abstract::ShapePtr ConjugateTransposeInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm must be different.";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> in_shape(p_value);
|
||||
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; });
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
|
@@ -85,7 +101,6 @@ TypePtr ConjugateTransposeInferType(const PrimitivePtr &prim, const std::vector<
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator);
|
||||
AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@@ -95,6 +110,9 @@ AbstractBasePtr ConjugateTransposeInfer(const abstract::AnalysisEnginePtr &, con
|
|||
auto shape = ConjugateTransposeInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
|
||||
REGISTER_HOST_DEPENDS(kNameConjugateTranspose, {1});
|
||||
MIND_API_OPERATOR_IMPL(ConjugateTranspose, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ConjugateTranspose, prim::kPrimConjugateTranspose, ConjugateTransposeInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
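The perm handling in ConjugateTransposeInferShape above boils down to wrapping negative axes by the rank and permuting the input shape. Below is a minimal Python restatement as a reading aid, not the operator implementation; the helper name is an assumption.

# Mirrors the perm normalization and shape permutation in ConjugateTransposeInferShape (illustration only).
def conjugate_transpose_shape(x_shape, perm):
    rank = len(perm)
    # negative perm entries wrap around: p stays if p >= 0, otherwise rank + p
    norm = [p if p >= 0 else rank + p for p in perm]
    assert len(x_shape) == len(norm) and len(set(norm)) == rank  # ranks match, no repeated axes
    return [x_shape[p] for p in norm]

# e.g. the ConjugateTranspose test later in this commit: perm=(2, 1, 0, 3) on an (8, 4, 5, 7) input
print(conjugate_transpose_shape([8, 4, 5, 7], (2, 1, 0, 3)))  # [5, 4, 8, 7]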
@@ -136,7 +136,6 @@ void SparseApplyFtrl::Init(float lr, float l1, float l2, float lr_power, bool us
|
|||
set_use_locking(use_locking);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator);
|
||||
AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@@ -154,14 +153,14 @@ AbstractBasePtr SparseApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const
|
|||
(void)CheckAndConvertUtils::CheckValue(kL1, l1, kGreaterEqual, 0.0f, op_name);
|
||||
(void)CheckAndConvertUtils::CheckValue(kL2, l2, kGreaterEqual, 0.0f, op_name);
|
||||
(void)CheckAndConvertUtils::CheckValue(kLrPower, lr_power, kLessEqual, 0.0f, op_name);
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args),
|
||||
kEqual, sparse_apply_ftrl::kSparseApplyFtrlInputNum, op_name);
|
||||
auto types = sparse_apply_ftrl::SparseApplyFtrlInferType(primitive, input_args);
|
||||
auto shapes = sparse_apply_ftrl::SparseApplyFtrlInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseApplyFtrl, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyFtrl, prim::kPrimSparseApplyFtrl, SparseApplyFtrlInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@@ -0,0 +1,77 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.operations.image_ops as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Tensor, nn, context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.op = P.CombinedNonMaxSuppression()
|
||||
|
||||
def construct(self, boxes, scores, max_output_size_per_class,
|
||||
max_total_size, iou_threshold, score_threshold):
|
||||
return self.op(boxes, scores, max_output_size_per_class,
|
||||
max_total_size, iou_threshold, score_threshold)
|
||||
|
||||
|
||||
def dyn_case():
|
||||
net = Net()
|
||||
|
||||
boxes_dyn = Tensor(shape=[None, None, None, 4], dtype=mstype.float32)
|
||||
scores_dyn = Tensor(shape=[None, None, None], dtype=mstype.float32)
|
||||
max_output_size_per_class = Tensor(4, mstype.int32)
|
||||
max_total_size = Tensor(1, mstype.int32)
|
||||
iou_threshold = Tensor(0, mstype.float32)
|
||||
score_threshold = Tensor(0, mstype.float32)
|
||||
|
||||
net.set_inputs(boxes_dyn, scores_dyn, max_output_size_per_class,
|
||||
max_total_size, iou_threshold, score_threshold)
|
||||
|
||||
boxes = Tensor(
|
||||
np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]],
|
||||
[[190, 110, 150, 100]], [[210, 112, 150,
|
||||
100]]]])).astype('float32')
|
||||
scores = Tensor(
|
||||
np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
|
||||
[0.3000, 0.6000, 0.1000], [0.0500, 0.9000,
|
||||
0.0500]]])).astype('float32')
|
||||
|
||||
out = net(boxes, scores, max_output_size_per_class, max_total_size,
|
||||
iou_threshold, score_threshold)
|
||||
expect_shapes = [(1, 1, 4), (1, 1), (1, 1), (1,)]
|
||||
for i in range(4):
|
||||
assert out[i].asnumpy().shape == expect_shapes[i]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_combined_non_max_suppression_dyn():
|
||||
"""
|
||||
Feature: test CombinedNonMaxSuppression in PyNative and Graph modes.
|
||||
Description: test dynamic shape case.
|
||||
Expectation: expect correct shape result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
dyn_case()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
|
||||
dyn_case()
|
|
@@ -0,0 +1,61 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.operations.array_ops as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Tensor, nn, context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.op = P.ConjugateTranspose()
|
||||
|
||||
def construct(self, x, perm):
|
||||
return self.op(x, perm)
|
||||
|
||||
|
||||
def dyn_case():
|
||||
net = Net()
|
||||
|
||||
x_dyn = Tensor(shape=[None, 4, None, 7], dtype=mstype.float32)
|
||||
perm = (2, 1, 0, 3)
|
||||
|
||||
net.set_inputs(x_dyn, perm)
|
||||
|
||||
x = Tensor(np.random.random((8, 4, 5, 7)).astype(np.float32))
|
||||
out = net(x, perm)
|
||||
|
||||
expect_shape = (5, 4, 8, 7)
|
||||
assert out.asnumpy().shape == expect_shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_conjugate_transpose_dyn():
|
||||
"""
|
||||
Feature: test ConjugateTranspose in PyNative and Graph modes.
|
||||
Description: test dynamic shape case.
|
||||
Expectation: expect correct shape result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
dyn_case()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
|
||||
dyn_case()
|
|
@@ -24,31 +24,82 @@ import mindspore.common.dtype as mstype
|
|||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
|
||||
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="linear")
|
||||
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001,
|
||||
l1=0.0,
|
||||
l2=0.0,
|
||||
lr_power=-0.5,
|
||||
use_locking=False)
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
|
||||
name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
|
||||
name="accum")
|
||||
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)),
|
||||
name="linear")
|
||||
|
||||
def construct(self, grad, indices):
|
||||
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
||||
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad,
|
||||
indices)
|
||||
return out
|
||||
|
||||
|
||||
class Net_half(nn.Cell):
|
||||
class NetHalf(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net_half, self).__init__()
|
||||
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5, use_locking=False)
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="accum")
|
||||
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)), name="linear")
|
||||
super(NetHalf, self).__init__()
|
||||
self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001,
|
||||
l1=0.0,
|
||||
l2=0.0,
|
||||
lr_power=-0.5,
|
||||
use_locking=False)
|
||||
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
|
||||
name="var")
|
||||
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
|
||||
name="accum")
|
||||
self.linear = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float16)),
|
||||
name="linear")
|
||||
|
||||
def construct(self, grad, indices):
|
||||
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
||||
out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad,
|
||||
indices)
|
||||
return out
|
||||
|
||||
|
||||
def dyn_case():
|
||||
net = Net()
|
||||
|
||||
grad_dyn = Tensor(shape=[3, None, None], dtype=mstype.float32)
|
||||
indices_dyn = Tensor(shape=[None], dtype=mstype.int32)
|
||||
|
||||
net.set_inputs(grad_dyn, indices_dyn)
|
||||
|
||||
grad = Tensor(np.ones([3, 3, 3]).astype(np.float32))
|
||||
indices = Tensor([0, 1, 2], mstype.int32)
|
||||
|
||||
out = net(grad, indices)
|
||||
|
||||
expect_shape = (3, 3, 3)
|
||||
for i in range(3):
|
||||
assert out[i].asnumpy().shape == expect_shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_sparse_apply_ftrl_dyn():
|
||||
"""
|
||||
Feature: test SparseApplyFtrl in PyNative and Graph modes.
|
||||
Description: test dynamic shape case.
|
||||
Expectation: expect correct shape result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
dyn_case()
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
dyn_case()
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@@ -63,7 +114,8 @@ def test_ftrl():
|
|||
[0.291479, 0.291479, 0.291479]],
|
||||
[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
|
||||
[0.291479, 0.291479,
|
||||
0.291479]]]).astype(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
|
@@ -83,12 +135,11 @@ def test_ftrl_sparse_int64_ind():
|
|||
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]],
|
||||
[[1, 1, 1],
|
||||
[1, 1, 1],
|
||||
[1, 1, 1]],
|
||||
[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
|
||||
[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]]]).astype(np.float32)
|
||||
[0.291479, 0.291479,
|
||||
0.291479]]]).astype(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
|
@@ -113,13 +164,14 @@ def test_ftrl_half():
|
|||
[0.291479, 0.291479, 0.291479]],
|
||||
[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
|
||||
[0.291479, 0.291479,
|
||||
0.291479]]]).astype(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
|
||||
|
||||
|
@@ -133,21 +185,21 @@ def test_ftrl_sparse_half_int64_ind():
|
|||
expect_var = np.array([[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]],
|
||||
[[1, 1, 1],
|
||||
[1, 1, 1],
|
||||
[1, 1, 1]],
|
||||
[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
|
||||
[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
|
||||
[0.291479, 0.291479,
|
||||
0.291479]]]).astype(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(sparse_apply_ftrl.var.data.asnumpy() == expect_var)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@@ -162,12 +214,13 @@ def test_ftrl_half_return_output():
|
|||
[0.291479, 0.291479, 0.291479]],
|
||||
[[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479],
|
||||
[0.291479, 0.291479, 0.291479]]]).astype(np.float16)
|
||||
[0.291479, 0.291479,
|
||||
0.291479]]]).astype(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
output = sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(output[0].asnumpy() == expect_var)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
sparse_apply_ftrl = Net_half()
|
||||
sparse_apply_ftrl = NetHalf()
|
||||
sparse_apply_ftrl(gradient, indices)
|
||||
assert np.all(output[0].asnumpy() == expect_var)
|