Refactor Ops MaxUnpool2D, MaxUnpool3D, MultiMarginLoss, MultilabelMarginLoss, TripletMarginLoss, BartlettWindow

hedongdong 2022-06-22 21:50:53 +08:00
parent 0914ecf339
commit 4de4eb1a43
60 changed files with 1426 additions and 1537 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,11 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/bartlett_window_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <cmath>
#include <string>
#include <algorithm>
#include "plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@@ -24,102 +27,94 @@ constexpr size_t kBartlettWindowInputsNum = 1;
constexpr size_t kBartlettWindowOutputsNum = 1;
} // namespace
void BartlettWindowCpuKernelMod<T, S>::InitKernel(const CNodePtr &kernel_node) {
void BartlettWindowCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
input_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
output_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
periodic_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PERIODIC);
if ((input_dtype != kNumberTypeInt32) && (input_dtype != kNumberTypeInt64)) {
MS_LOG(EXCEPTION) << "Input tensor types must be int32 or int64";
}
input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
periodic_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "periodic");
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape.size() > 0) {
MS_EXCEPTION(ValueError) << "The dim of window_length must be 0.";
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim of window_length should be 0, but got "
<< input_shape.size();
}
node_wpt_ = kernel_node;
cnode_ptr_ = kernel_node;
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
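This is the core pattern of the refactor: instead of instantiating the whole kernel class per type pair, InitKernel matches the node's type signature against func_list_ once and caches the chosen templated instantiation in kernel_func_, so Launch stays non-templated. A minimal self-contained sketch of the same table-dispatch idea (all names here are illustrative, not MindSpore APIs):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

class Dispatcher {
 public:
  using Func = std::function<void(const void *, void *)>;

  // A concrete templated instantiation; the table below pins one per signature.
  template <typename In, typename Out>
  static void Kernel(const void *in, void *out) {
    *static_cast<Out *>(out) = static_cast<Out>(*static_cast<const In *>(in));
  }

  // Analogue of InitKernel: match once, cache the chosen instantiation.
  explicit Dispatcher(const std::string &sig) {
    static const std::vector<std::pair<std::string, Func>> table = {
        {"int32->float", &Kernel<int32_t, float>},
        {"int64->double", &Kernel<int64_t, double>},
    };
    for (const auto &entry : table) {
      if (entry.first == sig) {
        func_ = entry.second;
        return;
      }
    }
    throw std::runtime_error("unsupported signature: " + sig);  // MatchKernelAttr failure analogue
  }

  // Analogue of Launch: no templates needed at the call site.
  void Run(const void *in, void *out) const { func_(in, out); }

 private:
  Func func_;
};

int main() {
  int64_t in = 7;
  double out = 0.0;
  Dispatcher d("int64->double");
  d.Run(&in, &out);
  std::cout << out << "\n";  // prints 7
  return 0;
}
```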
// template <typename T, typename S>
// void BartlettWindowCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
// const std::vector<AddressPtr> &outputs) {
// auto input = reinterpret_cast<T *>(inputs[0]->addr);
// auto output = reinterpret_cast<S *>(outputs[0]->addr);
// auto input_data = *input;
// const size_t window_length = static_cast<size_t>(*input);
// const S output_one = static_cast<S>(1.);
// if (input_data < 0) {
// MS_EXCEPTION(ValueError) << "Input window_length must be >= 0!";
// }
// if (input_data == 1) {
// *output = output_one;
// } else {
// if (periodic_) {
// input_data += 1;
// }
// const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
// const double x = static_cast<double>(input_data);
// for (size_t i = 0; i <= first_half_size; i++) {
// auto value = static_cast<S>((2. * i) / (x - 1.));
// *(output + i) = value;
// }
// for (size_t i = first_half_size + 1; i < window_length; i++) {
// auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
// *(output + i) = value;
// }
// }
// }
template <typename T1, typename T2>
bool BartlettWindowCpuKernelMod::BartlettWindowKernelFunc(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto node_ = cnode_ptr_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_ is expired.";
}
template <typename T, typename S>
bool BartlettWindowCpuKernelMod<T, S>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBartlettWindowInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBartlettWindowOutputsNum, kernel_name_);
auto input = reinterpret_cast<T *>(inputs[0]->addr);
auto output = reinterpret_cast<S *>(outputs[0]->addr);
auto input_data = *input;
const size_t window_length = static_cast<size_t>(*input);
const S output_one = static_cast<S>(1.);
if (input_data < 0) {
MS_EXCEPTION(ValueError) << "Input window_length must ≥ 0!";
auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
if (*input < 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input window_length should be >= 0, but got " << *input;
}
if (input_data == 1) {
*output = output_one;
auto window_length = static_cast<int64_t>(*input);
double pre_window_length = static_cast<double>(window_length);
const size_t OUTPUTISONE = 1;
ShapeVector out_shape = {window_length};
std::vector<TypeId> dtypes = {AnfAlgo::GetOutputDeviceDataType(node_, 0)};
if (*input == 1) {
*output = static_cast<T2>(OUTPUTISONE);
} else {
if (periodic_) {
input_data += 1;
window_length += 1;
}
const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
const double x = static_cast<double>(input_data);
const size_t first_half_size = static_cast<size_t>((window_length - 1) / 2);
const double x = static_cast<double>(window_length);
for (size_t i = 0; i <= first_half_size; i++) {
auto value = static_cast<S>((2. * i) / (x - 1.));
auto value = static_cast<T2>((2. * i) / (x - 1.));
*(output + i) = value;
}
for (size_t i = first_half_size + 1; i < window_length; i++) {
auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
for (size_t i = first_half_size + 1; i < pre_window_length; i++) {
auto value = static_cast<T2>(2. - (2. * i) / (x - 1.));
*(output + i) = value;
}
}
// if (output_dtype == kNumberTypeFloat16) {
// if (input_dtype == kNumberTypeInt32) {
// LaunchKernel<int32_t, float16>(inputs, outputs);
// } else if (input_dtype == kNumberTypeInt64) {
// LaunchKernel<int64_t, float16>(inputs, outputs);
// }
// } else if (output_dtype == kNumberTypeFloat32) {
// if (input_dtype == kNumberTypeInt32) {
// LaunchKernel<int32_t, float>(inputs, outputs);
// } else if (input_dtype == kNumberTypeInt64) {
// LaunchKernel<int64_t, float>(inputs, outputs);
// }
// } else if (output_dtype == kNumberTypeFloat64) {
// if (input_dtype == kNumberTypeInt32) {
// LaunchKernel<int32_t, double>(inputs, outputs);
// } else if (input_dtype == kNumberTypeInt64) {
// LaunchKernel<int64_t, double>(inputs, outputs);
// }
// }
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get());
return true;
}
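For reference, the two loops implement the standard (symmetric) Bartlett triangular window; when periodic is set, the code computes the symmetric window of length N + 1 and writes only its first N samples. For a symmetric window of length N:

```latex
w[n] =
  \begin{cases}
    \dfrac{2n}{N-1}, & 0 \le n \le \left\lfloor \tfrac{N-1}{2} \right\rfloor, \\[4pt]
    2 - \dfrac{2n}{N-1}, & \left\lfloor \tfrac{N-1}{2} \right\rfloor < n \le N-1.
  \end{cases}
```

In the code above, first_half_size and x play the roles of ⌊(N−1)/2⌋ and N respectively.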
std::vector<std::pair<KernelAttr, BartlettWindowCpuKernelMod::BartlettWindowFunc>>
BartlettWindowCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, double>}};
std::vector<KernelAttr> BartlettWindowCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BartlettWindowFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BartlettWindow, BartlettWindowCpuKernelMod);
} // namespace kernel
} // namespace mindspore
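MS_KERNEL_FACTORY_REG replaces the old per-type MS_REG_CPU_KERNEL_T_S macros with a single factory registration per op. A generic sketch of the static-registrar idiom such a macro conventionally wraps (illustrative names only, not MindSpore's actual expansion):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

struct KernelMod {  // stand-in for the NativeCpuKernelMod base
  virtual ~KernelMod() = default;
};

class KernelFactory {
 public:
  static KernelFactory &Instance() {
    static KernelFactory factory;  // constructed on first use
    return factory;
  }
  void Register(const std::string &name, std::function<std::unique_ptr<KernelMod>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::unique_ptr<KernelMod> Create(const std::string &name) const {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<KernelMod>()>> creators_;
};

// The macro reduces to a file-scope static whose initializer runs at load
// time and records a creator under the op name.
#define REG_KERNEL(NAME, CLASS)                                                            \
  static const bool g_##CLASS##_registered = [] {                                          \
    KernelFactory::Instance().Register(                                                    \
        #NAME, []() -> std::unique_ptr<KernelMod> { return std::make_unique<CLASS>(); });  \
    return true;                                                                           \
  }()

struct BartlettWindowKernel : KernelMod {};
REG_KERNEL(BartlettWindow, BartlettWindowKernel);  // one line per op, as above
```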

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,46 +13,45 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class BartlettWindowCpuKernelMod : public NativeCpuKernelMod {
class BartlettWindowCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
BartlettWindowCpuKernelMod() = default;
~BartlettWindowCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
// template <typename T, typename S>
// void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T, typename T2>
bool BartlettWindowKernelFunc(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
bool periodic_{true};
TypeId output_dtype{kNumberTypeFloat32};
TypeId input_dtype{kTypeUnknown};
std::vector<size_t> input_shape;
using BartlettWindowFunc =
std::function<bool(BartlettWindowCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, BartlettWindowFunc>> func_list_;
BartlettWindowFunc kernel_func_;
ShapeVector input_shape;
CNodePtr node_wpt_;
};
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
BartlettWindowCpuKernelMod, int32_t, float);
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
BartlettWindowCpuKernelMod, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
BartlettWindowCpuKernelMod, int32_t, double);
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
BartlettWindowCpuKernelMod, int64_t, float);
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
BartlettWindowCpuKernelMod, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
BartlettWindowCpuKernelMod, int64_t, double);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_

View File

@@ -15,8 +15,10 @@
*/
#include <cmath>
#include <map>
#include "backend/kernel_compiler/cpu/max_unpool2d_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@@ -29,28 +31,40 @@ constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
} // namespace
template <typename DATA_T, typename INDICES_T>
void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
void MaxUnpool2DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
node_wpt_ = kernel_node;
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
return;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "MaxUnpool2D does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename DATA_T, typename INDICES_T>
void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
template <typename DATA_T>
void MaxUnpool2DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
for (size_t s = 0; s < length; s++) {
raw_output[s] = (DATA_T)0;
}
}
template <typename DATA_T, typename INDICES_T>
bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MaxUnpool2DCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +80,14 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex1]->addr);
auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
if (data_format_ == "NHWC") {
size_t num_batch = input_shape_[kInputIndex0];
size_t input_height = input_shape_[kInputIndex1];
size_t input_width = input_shape_[kInputIndex2];
size_t num_channels = input_shape_[kInputIndex3];
size_t oheight = output_shape_[kInputIndex1];
size_t owidth = output_shape_[kInputIndex2];
size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
size_t input_height = LongToSize(input_shape_[kInputIndex1]);
size_t input_width = LongToSize(input_shape_[kInputIndex2]);
size_t num_channels = LongToSize(input_shape_[kInputIndex3]);
size_t oheight = LongToSize(output_shape_[kInputIndex1]);
size_t owidth = LongToSize(output_shape_[kInputIndex2]);
size_t length = num_batch * oheight * owidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * owidth * oheight;
size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -100,14 +114,14 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
}
}
} else {
size_t num_batch = input_shape_[kInputIndex0];
size_t input_height = input_shape_[kInputIndex2];
size_t input_width = input_shape_[kInputIndex3];
size_t num_channels = input_shape_[kInputIndex1];
size_t oheight = output_shape_[kInputIndex2];
size_t owidth = output_shape_[kInputIndex3];
size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
size_t input_height = LongToSize(input_shape_[kInputIndex2]);
size_t input_width = LongToSize(input_shape_[kInputIndex3]);
size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
size_t oheight = LongToSize(output_shape_[kInputIndex2]);
size_t owidth = LongToSize(output_shape_[kInputIndex3]);
size_t length = num_batch * oheight * owidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * owidth * oheight;
size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -139,5 +153,60 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
}
return true;
}
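The per-element loops (elided from this view) scatter each pooled value to the flat position stored in indices, after zero-filling the output via OutPutInitKernel. A self-contained sketch of that scatter semantics on one plane (assumed behavior, with indices produced by a matching MaxPool-with-argmax; not the MindSpore code itself):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // One 1x1x2x2 pooled plane scattered into a 1x1x4x4 output plane.
  std::vector<float> input = {9.f, 8.f, 7.f, 6.f};
  std::vector<int32_t> indices = {5, 6, 9, 10};  // flat H*W offsets of the maxima
  std::vector<float> output(16, 0.f);            // zero-filled, like OutPutInitKernel
  for (size_t i = 0; i < input.size(); ++i) {
    output[indices[i]] = input[i];               // scatter by recorded argmax
  }
  for (size_t i = 0; i < output.size(); ++i) {
    std::cout << output[i] << ((i + 1) % 4 ? ' ' : '\n');
  }
  return 0;
}
```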
std::vector<std::pair<KernelAttr, MaxUnpool2DCpuKernelMod::MaxUnpool2DFunc>> MaxUnpool2DCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool2DCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool2DCpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> MaxUnpool2DCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2D, MaxUnpool2DCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -19,117 +19,43 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool2DCPUKernel : public CPUKernel {
class MaxUnpool2DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MaxUnpool2DCPUKernel() = default;
~MaxUnpool2DCPUKernel() override = default;
MaxUnpool2DCpuKernelMod() = default;
~MaxUnpool2DCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
};
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename DATA_T, typename INDICES_T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using MaxUnpool2DFunc = std::function<bool(MaxUnpool2DCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MaxUnpool2DFunc>> func_list_;
MaxUnpool2DFunc kernel_func_;
template <typename DATA_T>
void OutPutInitKernel(DATA_T *rawOutput, size_t length);
CNodeWeakPtr node_wpt_;
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
ShapeVector input_shape_;
ShapeVector indices_shape_;
ShapeVector output_shape_;
std::string data_format_;
};
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
MaxUnpool2DCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
MaxUnpool2DCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
MaxUnpool2DCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
MaxUnpool2DCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
MaxUnpool2DCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
MaxUnpool2DCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
MaxUnpool2DCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
MaxUnpool2DCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
MaxUnpool2DCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
MaxUnpool2DCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
MaxUnpool2DCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
MaxUnpool2DCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaxUnpool2DCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MaxUnpool2DCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
MaxUnpool2DCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MaxUnpool2DCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
MaxUnpool2DCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MaxUnpool2DCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MaxUnpool2DCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MaxUnpool2DCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
MaxUnpool2DCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
MaxUnpool2DCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore

View File

@@ -15,8 +15,10 @@
*/
#include <cmath>
#include <map>
#include "backend/kernel_compiler/cpu/max_unpool2d_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@@ -28,29 +30,42 @@ constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
} // namespace
template <typename DATA_T, typename INDICES_T>
void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
void MaxUnpool2DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
node_wpt_ = kernel_node;
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) {
return;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "MaxUnpool2DGrad does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename DATA_T, typename INDICES_T>
void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
template <typename DATA_T>
void MaxUnpool2DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
for (size_t s = 0; s < length; s++) {
raw_output[s] = (DATA_T)0;
}
}
template <typename DATA_T, typename INDICES_T>
bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MaxUnpool2DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +81,14 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex2]->addr);
auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
if (data_format_ == "NHWC") {
size_t num_batch = grads_shape_[kInputIndex0];
size_t oheight = grads_shape_[kInputIndex1];
size_t owidth = grads_shape_[kInputIndex2];
size_t num_channels = grads_shape_[kInputIndex3];
size_t iheight = output_shape_[kInputIndex1];
size_t iwidth = output_shape_[kInputIndex2];
size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
size_t oheight = LongToSize(grads_shape_[kInputIndex1]);
size_t owidth = LongToSize(grads_shape_[kInputIndex2]);
size_t num_channels = LongToSize(grads_shape_[kInputIndex3]);
size_t iheight = LongToSize(output_shape_[kInputIndex1]);
size_t iwidth = LongToSize(output_shape_[kInputIndex2]);
size_t length = num_batch * iheight * iwidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * iwidth * iheight;
size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -100,14 +115,14 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
}
}
} else {
size_t num_batch = grads_shape_[kInputIndex0];
size_t oheight = grads_shape_[kInputIndex2];
size_t owidth = grads_shape_[kInputIndex3];
size_t num_channels = grads_shape_[kInputIndex1];
size_t iheight = output_shape_[kInputIndex2];
size_t iwidth = output_shape_[kInputIndex3];
size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
size_t oheight = LongToSize(grads_shape_[kInputIndex2]);
size_t owidth = LongToSize(grads_shape_[kInputIndex3]);
size_t num_channels = LongToSize(grads_shape_[kInputIndex1]);
size_t iheight = LongToSize(output_shape_[kInputIndex2]);
size_t iwidth = LongToSize(output_shape_[kInputIndex3]);
size_t length = num_batch * iheight * iwidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * iwidth * iheight;
size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -139,5 +154,148 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
}
return true;
}
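The grad kernel is the adjoint of the forward scatter: for each pooled position it gathers the incoming gradient from the recorded argmax location in the unpooled tensor. A minimal sketch under the same assumptions as the forward example above:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> grad_out(16, 1.f);           // gradient w.r.t. the 4x4 unpooled output
  grad_out[5] = 2.f;                              // pretend upstream emphasized one cell
  std::vector<int32_t> indices = {5, 6, 9, 10};   // same argmax offsets as the forward pass
  std::vector<float> grad_in(indices.size(), 0.f);
  for (size_t i = 0; i < indices.size(); ++i) {
    grad_in[i] = grad_out[indices[i]];            // gather: the adjoint of scatter
  }
  for (float g : grad_in) std::cout << g << ' ';  // 2 1 1 1
  std::cout << '\n';
  return 0;
}
```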
std::vector<std::pair<KernelAttr, MaxUnpool2DGradCpuKernelMod::MaxUnpool2DGradFunc>>
MaxUnpool2DGradCpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> MaxUnpool2DGradCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2DGrad, MaxUnpool2DGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -19,187 +19,45 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool2DGradCPUKernel : public CPUKernel {
class MaxUnpool2DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MaxUnpool2DGradCPUKernel() = default;
~MaxUnpool2DGradCPUKernel() override = default;
MaxUnpool2DGradCpuKernelMod() = default;
~MaxUnpool2DGradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
};
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename DATA_T, typename INDICES_T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using MaxUnpool2DGradFunc = std::function<bool(MaxUnpool2DGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MaxUnpool2DGradFunc>> func_list_;
MaxUnpool2DGradFunc kernel_func_;
template <typename DATA_T>
void OutPutInitKernel(DATA_T *rawOutput, size_t length);
CNodeWeakPtr node_wpt_;
std::vector<size_t> input_shape_;
std::vector<size_t> grads_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
ShapeVector input_shape_;
ShapeVector grads_shape_;
ShapeVector indices_shape_;
ShapeVector output_shape_;
std::string data_format_;
};
MS_REG_CPU_KERNEL_T_S(MaxUnpool2D,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
MaxUnpool2DGradCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
MaxUnpool2DGradCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
MaxUnpool2DGradCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
MaxUnpool2DGradCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
MaxUnpool2DGradCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
MaxUnpool2DGradCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
MaxUnpool2DGradCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
MaxUnpool2DGradCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
MaxUnpool2DGradCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
MaxUnpool2DGradCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
MaxUnpool2DGradCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
MaxUnpool2DGradCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaxUnpool2DGradCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
MaxUnpool2DGradCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
MaxUnpool2DGradCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaxUnpool2DGradCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
MaxUnpool2DGradCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
MaxUnpool2DGradCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
MaxUnpool2DGradCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MaxUnpool2DGradCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
MaxUnpool2DGradCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
MaxUnpool2DGradCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaxUnpool2DGradGRAD_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAXUNPOOL2DGRAD_CPU_KERNEL_H_

View File

@@ -15,8 +15,10 @@
*/
#include <cmath>
#include <map>
#include "backend/kernel_compiler/cpu/max_unpool3d_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@@ -29,28 +31,41 @@ constexpr size_t kInputIndex2 = 2;
constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
} // namespace
template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
void MaxUnpool3DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
node_wpt_ = kernel_node;
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
return;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "MaxUnpool3D does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
template <typename DATA_T>
void MaxUnpool3DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
for (size_t s = 0; s < length; s++) {
raw_output[s] = (DATA_T)0;
}
}
template <typename DATA_T, typename INDICES_T>
bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MaxUnpool3DCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,15 +81,15 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
size_t num_batch = input_shape_[kInputIndex0];
if (data_format_ == "NDHWC") {
size_t input_depth = input_shape_[kInputIndex1];
size_t input_height = input_shape_[kInputIndex2];
size_t input_width = input_shape_[kInputIndex3];
size_t num_channels = input_shape_[kInputIndex4];
size_t odepth = output_shape_[kInputIndex1];
size_t oheight = output_shape_[kInputIndex2];
size_t owidth = output_shape_[kInputIndex3];
size_t input_depth = LongToSize(input_shape_[kInputIndex1]);
size_t input_height = LongToSize(input_shape_[kInputIndex2]);
size_t input_width = LongToSize(input_shape_[kInputIndex3]);
size_t num_channels = LongToSize(input_shape_[kInputIndex4]);
size_t odepth = LongToSize(output_shape_[kInputIndex1]);
size_t oheight = LongToSize(output_shape_[kInputIndex2]);
size_t owidth = LongToSize(output_shape_[kInputIndex3]);
size_t length = num_batch * odepth * oheight * owidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -105,15 +120,15 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
}
}
} else {
size_t input_depth = input_shape_[kInputIndex2];
size_t input_height = input_shape_[kInputIndex3];
size_t input_width = input_shape_[kInputIndex4];
size_t num_channels = input_shape_[kInputIndex1];
size_t odepth = output_shape_[kInputIndex2];
size_t oheight = output_shape_[kInputIndex3];
size_t owidth = output_shape_[kInputIndex4];
size_t input_depth = LongToSize(input_shape_[kInputIndex2]);
size_t input_height = LongToSize(input_shape_[kInputIndex3]);
size_t input_width = LongToSize(input_shape_[kInputIndex4]);
size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
size_t odepth = LongToSize(output_shape_[kInputIndex2]);
size_t oheight = LongToSize(output_shape_[kInputIndex3]);
size_t owidth = LongToSize(output_shape_[kInputIndex4]);
size_t length = num_batch * odepth * oheight * owidth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -148,5 +163,60 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
}
return true;
}
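The 3D variant only extends the offset arithmetic by a depth axis. A sketch of the flat-offset convention the NCDHW and NDHWC branches above appear to follow within one batch (illustrative helpers, not from the source):

```cpp
#include <cstddef>

// NCDHW: offset = ((c * D + d) * H + h) * W + w, applied after the per-batch
// base offset n * C * D * H * W (called noutput_offset in the loops above).
inline size_t OffsetNCDHW(size_t c, size_t d, size_t h, size_t w,
                          size_t D, size_t H, size_t W) {
  return ((c * D + d) * H + h) * W + w;
}

// NDHWC: same idea, but the channel index varies fastest.
inline size_t OffsetNDHWC(size_t d, size_t h, size_t w, size_t c,
                          size_t H, size_t W, size_t C) {
  return ((d * H + h) * W + w) * C + c;
}
```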
std::vector<std::pair<KernelAttr, MaxUnpool3DCpuKernelMod::MaxUnpool3DFunc>> MaxUnpool3DCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool3DCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool3DCpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> MaxUnpool3DCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3D, MaxUnpool3DCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -19,117 +19,43 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool3DCPUKernel : public CPUKernel {
class MaxUnpool3DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MaxUnpool3DCPUKernel() = default;
~MaxUnpool3DCPUKernel() override = default;
MaxUnpool3DCpuKernelMod() = default;
~MaxUnpool3DCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
};
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename DATA_T, typename INDICES_T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using MaxUnpool3DFunc = std::function<bool(MaxUnpool3DCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MaxUnpool3DFunc>> func_list_;
MaxUnpool3DFunc kernel_func_;
template <typename DATA_T>
void OutPutInitKernel(DATA_T *rawOutput, size_t length);
CNodeWeakPtr node_wpt_;
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
ShapeVector input_shape_;
ShapeVector indices_shape_;
ShapeVector output_shape_;
std::string data_format_;
};
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
MaxUnpool3DCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
MaxUnpool3DCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
MaxUnpool3DCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
MaxUnpool3DCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
MaxUnpool3DCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
MaxUnpool3DCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
MaxUnpool3DCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
MaxUnpool3DCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
MaxUnpool3DCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
MaxUnpool3DCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
MaxUnpool3DCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
MaxUnpool3DCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaxUnpool3DCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MaxUnpool3DCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
MaxUnpool3DCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MaxUnpool3DCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
MaxUnpool3DCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MaxUnpool3DCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MaxUnpool3DCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MaxUnpool3DCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
MaxUnpool3DCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(
MaxUnpool3D,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
MaxUnpool3DCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore
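
For orientation, MaxUnpool is a scatter: each pooled value is written back to the flat position its indices tensor recorded, and every other output element stays zero (which is what OutPutInitKernel prepares). A minimal 1-D sketch of that semantic, under the assumption of flat row-major indices; MaxUnpool1D is an illustrative name, not an op in this commit:

#include <cstdint>
#include <iostream>
#include <vector>

template <typename DATA_T, typename INDICES_T>
std::vector<DATA_T> MaxUnpool1D(const std::vector<DATA_T> &input,
                                const std::vector<INDICES_T> &indices,
                                size_t output_len) {
  std::vector<DATA_T> output(output_len, DATA_T(0));  // zero-init, as in OutPutInitKernel
  for (size_t i = 0; i < input.size(); ++i) {
    output[static_cast<size_t>(indices[i])] = input[i];  // scatter by recorded argmax
  }
  return output;
}

int main() {
  // Pooled maxima 5 and 8 originally sat at positions 1 and 6 of a length-8 row.
  auto out = MaxUnpool1D<float, int64_t>({5.0f, 8.0f}, {1, 6}, 8);
  for (float v : out) std::cout << v << ' ';  // prints: 0 5 0 0 0 0 8 0
  std::cout << '\n';
  return 0;
}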

View File

@ -15,8 +15,10 @@
*/
#include <cmath>
#include <map>
#include "backend/kernel_compiler/cpu/max_unpool3d_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@ -30,29 +32,41 @@ constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
} // namespace
template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
void MaxUnpool3DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
node_wpt_ = kernel_node;
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) {
return;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool3DGradFunc> &pair) { return pair.first; });
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "MaxUnpool3DGrad does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
template <typename DATA_T>
void MaxUnpool3DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
for (size_t s = 0; s < length; s++) {
raw_output[s] = (DATA_T)0;
}
}
template <typename DATA_T, typename INDICES_T>
bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MaxUnpool3DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto node = node_wpt_.lock();
if (!node) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@ -66,17 +80,17 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
auto *raw_grads = reinterpret_cast<DATA_T *>(inputs[kInputIndex1]->addr);
auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex2]->addr);
auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
size_t num_batch = grads_shape_[kInputIndex0];
auto num_batch = LongToSize(grads_shape_[kInputIndex0]);
if (data_format_ == "NDHWC") {
size_t odepth = grads_shape_[kInputIndex1];
size_t oheight = grads_shape_[kInputIndex2];
size_t owidth = grads_shape_[kInputIndex3];
size_t num_channels = grads_shape_[kInputIndex4];
size_t idepth = output_shape_[kInputIndex1];
size_t iheight = output_shape_[kInputIndex2];
size_t iwidth = output_shape_[kInputIndex3];
size_t odepth = LongToSize(grads_shape_[kInputIndex1]);
size_t oheight = LongToSize(grads_shape_[kInputIndex2]);
size_t owidth = LongToSize(grads_shape_[kInputIndex3]);
size_t num_channels = LongToSize(grads_shape_[kInputIndex4]);
size_t idepth = LongToSize(output_shape_[kInputIndex1]);
size_t iheight = LongToSize(output_shape_[kInputIndex2]);
size_t iwidth = LongToSize(output_shape_[kInputIndex3]);
size_t length = num_batch * iheight * iwidth * idepth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * iwidth * iheight * idepth;
size_t n_grads_offset = n * num_channels * owidth * oheight * odepth;
@ -106,15 +120,15 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
}
}
} else {
size_t odepth = grads_shape_[kInputIndex2];
size_t oheight = grads_shape_[kInputIndex3];
size_t owidth = grads_shape_[kInputIndex4];
size_t num_channels = grads_shape_[kInputIndex1];
size_t idepth = output_shape_[kInputIndex2];
size_t iheight = output_shape_[kInputIndex3];
size_t iwidth = output_shape_[kInputIndex4];
size_t odepth = LongToSize(grads_shape_[kInputIndex2]);
size_t oheight = LongToSize(grads_shape_[kInputIndex3]);
size_t owidth = LongToSize(grads_shape_[kInputIndex4]);
size_t num_channels = LongToSize(grads_shape_[kInputIndex1]);
size_t idepth = LongToSize(output_shape_[kInputIndex2]);
size_t iheight = LongToSize(output_shape_[kInputIndex3]);
size_t iwidth = LongToSize(output_shape_[kInputIndex4]);
size_t length = num_batch * iheight * iwidth * idepth * num_channels;
OutPutInitKernel(raw_output, length);
OutPutInitKernel<DATA_T>(raw_output, length);
for (size_t n = 0; n < num_batch; n++) {
size_t noutput_offset = n * num_channels * iwidth * iheight * idepth;
size_t n_grads_offset = n * num_channels * owidth * oheight * odepth;
@ -149,5 +163,148 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
}
return true;
}
std::vector<std::pair<KernelAttr, MaxUnpool3DGradCpuKernelMod::MaxUnpool3DGradFunc>>
MaxUnpool3DGradCpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<float16, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<float16, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&MaxUnpool3DGradCpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> MaxUnpool3DGradCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MaxUnpool3DGradFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3DGrad, MaxUnpool3DGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
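
The NDHWC/NCDHW branch above exists only because the flat offset of a logical element differs between the two layouts; the LongToSize calls then keep all of that offset arithmetic in size_t. A small sketch of the two offset formulas (function names and shapes here are illustrative):

#include <cstddef>
#include <iostream>

// Flat offset of element (n, c, d, h, w) in a C-channel, DxHxW volume.
size_t OffsetNCDHW(size_t n, size_t c, size_t d, size_t h, size_t w,
                   size_t C, size_t D, size_t H, size_t W) {
  return (((n * C + c) * D + d) * H + h) * W + w;
}

size_t OffsetNDHWC(size_t n, size_t c, size_t d, size_t h, size_t w,
                   size_t C, size_t D, size_t H, size_t W) {
  return (((n * D + d) * H + h) * W + w) * C + c;
}

int main() {
  // Same logical element, different flat position per layout; hence the
  // data_format_ branch before the gather loops above.
  std::cout << OffsetNCDHW(0, 1, 1, 0, 1, 2, 2, 2, 2) << ' '    // 13
            << OffsetNDHWC(0, 1, 1, 0, 1, 2, 2, 2, 2) << '\n';  // 11
  return 0;
}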

View File

@ -19,196 +19,44 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool3DGradCPUKernel : public CPUKernel {
class MaxUnpool3DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MaxUnpool3DGradCPUKernel() = default;
~MaxUnpool3DGradCPUKernel() override = default;
MaxUnpool3DGradCpuKernelMod() = default;
~MaxUnpool3DGradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
};
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename DATA_T, typename INDICES_T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using MaxUnpool3DGradFunc = std::function<bool(MaxUnpool3DGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MaxUnpool3DGradFunc>> func_list_;
MaxUnpool3DGradFunc kernel_func_;
template <typename DATA_T>
void OutPutInitKernel(DATA_T *rawOutput, size_t length);
CNodeWeakPtr node_wpt_;
std::vector<size_t> input_shape_;
std::vector<size_t> grads_shape_;
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
ShapeVector input_shape_;
ShapeVector grads_shape_;
ShapeVector indices_shape_;
ShapeVector output_shape_;
std::string data_format_;
};
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
MaxUnpool3DGradCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
MaxUnpool3DGradCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
MaxUnpool3DGradCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
MaxUnpool3DGradCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
MaxUnpool3DGradCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt32),
MaxUnpool3DGradCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
MaxUnpool3DGradCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt64),
MaxUnpool3DGradCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
MaxUnpool3DGradCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
MaxUnpool3DGradCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
MaxUnpool3DGradCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
MaxUnpool3DGradCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
MaxUnpool3DGradCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
MaxUnpool3DGradCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
MaxUnpool3DGradCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaxUnpool3DGradCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
MaxUnpool3DGradCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
MaxUnpool3DGradCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
MaxUnpool3DGradCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MaxUnpool3DGradCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
MaxUnpool3DGradCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
MaxUnpool3DGradCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore
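
A recurring change in these headers is std::vector<size_t> giving way to ShapeVector, which holds int64_t so dynamic shapes can carry negative placeholder dimensions; every extent is then funneled through LongToSize before use as a loop bound. A hypothetical stand-in for that conversion, just to make the intent concrete (the real LongToSize lives in MindSpore's utilities):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>

inline size_t LongToSizeChecked(int64_t v) {
  if (v < 0) {
    throw std::invalid_argument("negative dimension (shape still dynamic?)");
  }
  return static_cast<size_t>(v);
}

int main() {
  int64_t batch = 8;  // a ShapeVector element, now signed
  size_t num_batch = LongToSizeChecked(batch);
  std::cout << num_batch << '\n';
  return 0;
}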

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* specific language governing permissions and limitations under the License.
*/
#include "backend/kernel_compiler/cpu/multi_margin_loss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@ -23,24 +23,30 @@ namespace {
constexpr size_t kMultiMarginLossInputNumWithWeight = 3;
constexpr size_t kMultiMarginLossInputNumWithoutWeight = 2;
constexpr size_t kMultiMarginLossOutputsNum = 1;
const size_t kZero = 0;
const size_t kOne = 1;
const size_t kTwo = 2;
constexpr char kKernelName[] = "MultiMarginLoss";
} // namespace
void MultiMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void MultiMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
batch_size = x_shape[0];
dims = x_shape[1];
reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
margin = AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_num = AnfAlgo::GetInputTensorNum(kernel_node);
ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero);
if (IsDynamic({x_shape})) {
return;
}
batch_size = LongToSize(x_shape[kZero]);
dims = LongToSize(x_shape[kOne]);
reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
margin = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
}
bool MultiMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MultiMarginLossCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernelFP16<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
@ -54,10 +60,10 @@ bool MultiMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
}
template <typename T>
void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[1]->addr);
void MultiMarginLossCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[kOne]->addr);
for (size_t i = 0; i < batch_size; i++) {
if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
MS_EXCEPTION(ValueError) << "Target out of range.";
@ -66,9 +72,9 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
T *weight_addr = nullptr;
bool weight_defined_ = (input_num == 3);
if (weight_defined_) {
weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
weight_addr = reinterpret_cast<T *>(inputs[kTwo]->addr);
}
auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto y_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
std::vector<T> tmp_loss(batch_size);
auto task = [&](size_t start, size_t end) {
start *= dims;
@ -117,10 +123,10 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
}
template <typename T>
void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[1]->addr);
void MultiMarginLossCPUKernelMod::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[kOne]->addr);
for (size_t i = 0; i < batch_size; i++) {
if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
MS_EXCEPTION(ValueError) << "Target out of range.";
@ -129,9 +135,9 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::Addres
T *weight_addr = nullptr;
bool weight_defined_ = (input_num == 3);
if (weight_defined_) {
weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
weight_addr = reinterpret_cast<T *>(inputs[kTwo]->addr);
}
auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto y_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
std::vector<float> tmp_loss(batch_size);
auto task = [&](size_t start, size_t end) {
start *= dims;
@ -180,13 +186,15 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::Addres
}
}
void MultiMarginLossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
void MultiMarginLossCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kMultiMarginLossInputNumWithoutWeight && input_num != kMultiMarginLossInputNumWithWeight) {
MS_LOG(EXCEPTION) << "Invalid number of inputs: expected 2 or 3, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kMultiMarginLossOutputsNum, kKernelName);
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MultiMarginLoss, MultiMarginLossCPUKernelMod);
} // namespace kernel
} // namespace mindspore
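
For context, each parallel task in this file evaluates the standard multi-class margin loss per sample. A single-threaded reference of that inner computation, simplified to plain float (a hedged sketch, not the kernel itself; the FP16 path follows the same formula but accumulates in float):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

float MultiMarginSample(const std::vector<float> &x, int64_t target,
                        int64_t p, float margin, const float *weight) {
  float loss = 0.0f;
  const size_t dims = x.size();
  for (size_t i = 0; i < dims; ++i) {
    if (static_cast<int64_t>(i) == target) continue;
    float t = margin + x[i] - x[static_cast<size_t>(target)];
    if (t <= 0.0f) continue;              // hinge: only violated margins count
    t = (p == 1) ? t : t * t;             // the op restricts p to 1 or 2
    if (weight != nullptr) t *= weight[target];  // optional per-class weight
    loss += t;
  }
  return loss / static_cast<float>(dims);
}

int main() {
  std::vector<float> logits = {0.1f, 0.7f, 0.2f};
  // (1 + 0.1 - 0.7) + (1 + 0.2 - 0.7) = 0.9, divided by dims = 3.
  std::cout << MultiMarginSample(logits, 1, 1, 1.0f, nullptr) << '\n';  // 0.3
  return 0;
}

This also explains why tmp_loss in the FP16 branch above is std::vector<float> rather than a vector of float16: the per-element hinge terms are summed at full precision before the final cast.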

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,29 +20,52 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MultiMarginLossCPUKernel : public CPUKernel {
class MultiMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MultiMarginLossCPUKernel() = default;
MultiMarginLossCPUKernelMod() = default;
~MultiMarginLossCPUKernel() override = default;
~MultiMarginLossCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchKernelFP16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
size_t batch_size = 2;
size_t dims = 1;
@ -52,45 +75,6 @@ class MultiMarginLossCPUKernel : public CPUKernel {
size_t input_num = 1;
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(MultiMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MultiMarginLossCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MultiMarginLossCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MultiMarginLossCPUKernel);
MS_REG_CPU_KERNEL(
MultiMarginLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
MultiMarginLossCPUKernel);
MS_REG_CPU_KERNEL(
MultiMarginLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MultiMarginLossCPUKernel);
MS_REG_CPU_KERNEL(
MultiMarginLoss,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
MultiMarginLossCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_CPU_KERNEL_H_
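
The six KernelAttr entries above are exactly the cross product of the three float dtypes with the two input arities (with and without the optional weight tensor). A throwaway sketch that enumerates the same set, with signatures as plain strings:

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> dtypes = {"float16", "float32", "float64"};
  std::vector<std::vector<std::string>> support;
  for (bool with_weight : {true, false}) {
    for (const std::string &t : dtypes) {
      std::vector<std::string> attr = {t, "int64"};  // x, target
      if (with_weight) attr.push_back(t);            // optional weight
      attr.push_back(t);                             // output y
      support.push_back(attr);
    }
  }
  std::cout << support.size() << " supported signatures\n";  // 6
  return 0;
}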

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/multi_margin_loss_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
@ -23,25 +23,32 @@ namespace {
constexpr size_t kMultiMarginLossGradInputNumWithWeight = 4;
constexpr size_t kMultiMarginLossGradInputNumWithoutWeight = 3;
constexpr size_t kMultiMarginLossGradOutputsNum = 1;
const size_t kZero = 0;
const size_t kOne = 1;
const size_t kTwo = 2;
const size_t kThree = 3;
constexpr char kKernelName[] = "MultiMarginLossGrad";
} // namespace
void MultiMarginLossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void MultiMarginLossGradCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
batch_size = x_shape[0];
dims = x_shape[1];
reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
margin = AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
y_grad_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size();
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
input_num = AnfAlgo::GetInputTensorNum(kernel_node);
ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne);
if (IsDynamic({x_shape})) {
return;
}
batch_size = LongToSize(x_shape[kZero]);
dims = LongToSize(x_shape[kOne]);
reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
margin = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
y_grad_dims = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero).size();
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
}
bool MultiMarginLossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool MultiMarginLossGradCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernelFP16<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
@ -55,11 +62,11 @@ bool MultiMarginLossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr>
}
template <typename T>
void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto y_grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto x_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
void MultiMarginLossGradCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto y_grad_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
auto x_addr = reinterpret_cast<T *>(inputs[kOne]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[kTwo]->addr);
for (size_t i = 0; i < batch_size; i++) {
if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
MS_EXCEPTION(ValueError) << "Target out of range.";
@ -68,12 +75,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
T *weight_addr = nullptr;
bool weight_defined_ = (input_num == 4);
if (weight_defined_) {
weight_addr = reinterpret_cast<T *>(inputs[3]->addr);
weight_addr = reinterpret_cast<T *>(inputs[kThree]->addr);
}
auto x_grad_addr = reinterpret_cast<T *>(outputs[0]->addr);
T weights;
weights = reduction == MEAN ? (static_cast<T>(1) / (static_cast<T>(dims) * static_cast<T>(batch_size)))
: (static_cast<T>(1) / static_cast<T>(dims));
auto x_grad_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
auto task = [&](size_t start, size_t end) {
start *= dims;
end *= dims;
@ -91,6 +95,8 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
continue;
}
if (calc_data[d] > static_cast<T>(0)) {
auto weights = reduction == MEAN ? (static_cast<T>(1) / (static_cast<T>(dims) * static_cast<T>(batch_size)))
: (static_cast<T>(1) / static_cast<T>(dims));
calc_data[d] = (p == 1) ? weights : static_cast<T>(2) * weights * calc_data[d];
if (weight_defined_) {
calc_data[d] *= static_cast<T>(weight_addr[target_idx]);
@ -122,11 +128,11 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
}
template <typename T>
void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto y_grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto x_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
void MultiMarginLossGradCPUKernelMod::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto y_grad_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
auto x_addr = reinterpret_cast<T *>(inputs[kOne]->addr);
auto target_addr = reinterpret_cast<int64_t *>(inputs[kTwo]->addr);
for (size_t i = 0; i < batch_size; i++) {
if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
MS_EXCEPTION(ValueError) << "Target out of range.";
@ -135,12 +141,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
T *weight_addr = nullptr;
bool weight_defined_ = (input_num == 4);
if (weight_defined_) {
weight_addr = reinterpret_cast<T *>(inputs[3]->addr);
weight_addr = reinterpret_cast<T *>(inputs[kThree]->addr);
}
auto x_grad_addr = reinterpret_cast<T *>(outputs[0]->addr);
float weights;
weights = reduction == MEAN ? (static_cast<float>(1) / (static_cast<float>(dims) * static_cast<float>(batch_size)))
: (static_cast<float>(1) / static_cast<float>(dims));
auto x_grad_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
auto task = [&](size_t start, size_t end) {
start *= dims;
end *= dims;
@ -158,6 +161,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
continue;
}
if (calc_data[d] > static_cast<float>(0)) {
auto weights = reduction == MEAN
? (static_cast<float>(1) / (static_cast<float>(dims) * static_cast<float>(batch_size)))
: (static_cast<float>(1) / static_cast<float>(dims));
calc_data[d] = (p == 1) ? weights : static_cast<float>(2) * weights * calc_data[d];
if (weight_defined_) {
calc_data[d] *= static_cast<float>(weight_addr[target_idx]);
@ -189,13 +195,15 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
}
}
void MultiMarginLossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
void MultiMarginLossGradCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != kMultiMarginLossGradInputNumWithoutWeight && input_num != kMultiMarginLossGradInputNumWithWeight) {
MS_LOG(EXCEPTION) << "Invalid number of inputs: expected 3 or 4, but got " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kMultiMarginLossGradOutputsNum, kKernelName);
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MultiMarginLossGrad, MultiMarginLossGradCPUKernelMod);
} // namespace kernel
} // namespace mindspore
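
The diff above also moves the `weights` scale factor inside the lambda, next to its only use; the factor is 1/(dims*batch_size) under mean reduction and 1/dims otherwise. A single-sample reference for the gradient rule the task applies, under the same simplifications as the forward sketch earlier (illustrative names, not the kernel):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> MultiMarginGradSample(const std::vector<float> &x,
                                         int64_t target, int64_t p,
                                         float margin, bool mean_reduction,
                                         size_t batch_size) {
  const size_t dims = x.size();
  const float weights =
      mean_reduction ? 1.0f / (static_cast<float>(dims) * batch_size)
                     : 1.0f / static_cast<float>(dims);
  std::vector<float> grad(dims, 0.0f);
  for (size_t i = 0; i < dims; ++i) {
    if (static_cast<int64_t>(i) == target) continue;
    const float t = margin + x[i] - x[static_cast<size_t>(target)];
    if (t <= 0.0f) continue;                    // inactive hinge term
    const float g = (p == 1) ? weights : 2.0f * weights * t;
    grad[i] += g;                               // non-target column gets +g
    grad[static_cast<size_t>(target)] -= g;     // target accumulates -g
  }
  return grad;
}

int main() {
  auto g = MultiMarginGradSample({0.1f, 0.7f, 0.2f}, 1, 1, 1.0f, false, 1);
  for (float v : g) std::cout << v << ' ';  // 0.333333 -0.666667 0.333333
  std::cout << '\n';
  return 0;
}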

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,29 +20,66 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MultiMarginLossGradCPUKernel : public CPUKernel {
class MultiMarginLossGradCPUKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MultiMarginLossGradCPUKernel() = default;
MultiMarginLossGradCPUKernelMod() = default;
~MultiMarginLossGradCPUKernel() override = default;
~MultiMarginLossGradCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchKernelFP16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
size_t batch_size = 2;
size_t dims = 1;
@ -53,57 +90,6 @@ class MultiMarginLossGradCPUKernel : public CPUKernel {
size_t y_grad_dims = 1;
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MultiMarginLossGradCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MultiMarginLossGradCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MultiMarginLossGradCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
MultiMarginLossGradCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
MultiMarginLossGradCPUKernel);
MS_REG_CPU_KERNEL(MultiMarginLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
MultiMarginLossGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_GRAD_CPU_KERNEL_H_
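
Note how the optional weight input never appears as a flag: the grad kernel infers it purely from arity (input_num == 4), and the support list simply registers both shapes of the signature. Sketched outside the framework, with Address as a stand-in for the kernel AddressPtr type:

#include <cstddef>
#include <iostream>
#include <vector>

struct Address {
  void *addr = nullptr;
  size_t size = 0;
};

void LaunchGradSketch(const std::vector<Address> &inputs) {
  const bool weight_defined = (inputs.size() == 4);  // y_grad, x, target[, weight]
  const float *weight =
      weight_defined ? static_cast<const float *>(inputs[3].addr) : nullptr;
  std::cout << (weight != nullptr ? "weighted path\n" : "unweighted path\n");
}

int main() {
  std::vector<float> w = {1.0f, 2.0f};
  LaunchGradSketch({{}, {}, {}});  // 3 inputs: no weight tensor
  LaunchGradSketch({{}, {}, {}, {w.data(), 2 * sizeof(float)}});
  return 0;
}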

View File

@ -91,4 +91,4 @@ std::vector<KernelAttr> MvlgammaCpuKernelMod::GetOpSupport() {
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Mvlgamma, MvlgammaCpuKernelMod);
} // namespace kernel
} // namespace mindspore
} // namespace mindspore

View File

@ -24,7 +24,6 @@
namespace mindspore {
namespace kernel {
class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MvlgammaCpuKernelMod() = default;
@ -51,8 +50,6 @@ class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_

View File

@ -152,6 +152,5 @@ std::vector<KernelAttr> MvlgammaGradCpuKernelMod::GetOpSupport() {
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MvlgammaGrad, MvlgammaGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
} // namespace mindspore

View File

@ -24,7 +24,6 @@
namespace mindspore {
namespace kernel {
class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MvlgammaGradCpuKernelMod() = default;
@ -54,8 +53,6 @@ class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,32 +14,35 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/triplet_margin_loss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void TripletMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
constexpr int kzero = 0;
constexpr int kone = 1;
constexpr int ktwo = 2;
constexpr int kthree = 3;
constexpr int kZero = 0;
constexpr int kOne = 1;
constexpr int kTwo = 2;
constexpr int kThree = 3;
constexpr int kParallel = 28;
constexpr int kParallelunit = 1024;
p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
swap = AnfAlgo::GetNodeAttr<bool>(kernel_node, "swap");
eps = AnfAlgo::GetNodeAttr<float>(kernel_node, "eps");
reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kzero);
dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kone);
dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, ktwo);
dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kthree);
x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kzero);
positive_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kone);
negative_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, ktwo);
p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
swap = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "swap");
eps = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "eps");
reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kOne);
dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, kTwo);
dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kThree);
x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero);
positive_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne);
negative_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kTwo);
if (AnfAlgo::IsShapesDynamic({x_shape, positive_shape, negative_shape})) {
return;
}
kParallelDataNum = kParallel * kParallelunit;
std::vector<size_t> broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape);
auto broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape);
broadcast_shape = CPUKernelUtils::GetBroadcastShape(broadcast_shape_x_and_positive, negative_shape);
size_t dim_x = x_shape.size();
size_t dim_positive = positive_shape.size();
@ -51,30 +54,28 @@ void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::reverse(x_reshape_vector.begin(), x_reshape_vector.end());
std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end());
std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end());
if (dim_x < max_size) x_reshape_vector.resize(max_size, kone);
if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kone);
if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kone);
if (dim_x < max_size) x_reshape_vector.resize(max_size, kOne);
if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kOne);
if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kOne);
std::reverse(x_reshape_vector.begin(), x_reshape_vector.end());
std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end());
std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end());
numelements = 1;
for (size_t i = 0; i < broadcast_shape.size(); i++) {
numelements *= broadcast_shape[i];
}
data_num = (numelements) / (broadcast_shape[1]);
data_num_each_batch = (numelements) / (broadcast_shape[0]);
index = data_num / (broadcast_shape[0]);
batch_size = broadcast_shape[0];
once_compute_size = broadcast_shape[1];
numelements = LongToSize(SizeOf(broadcast_shape));
data_num = (numelements) / LongToSize(broadcast_shape[1]);
data_num_each_batch = (numelements) / LongToSize(broadcast_shape[0]);
index = data_num / LongToSize(broadcast_shape[0]);
batch_size = LongToSize(broadcast_shape[0]);
once_compute_size = LongToSize(broadcast_shape[1]);
broadcast = false;
if (x_shape != positive_shape || x_shape != negative_shape || positive_shape != negative_shape) {
broadcast = true;
}
}
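
To make the bookkeeping above concrete: the three input shapes are broadcast pairwise into one shape, and every counter (numelements, data_num, batch_size, once_compute_size) falls out of that result. A rough standalone model of GetBroadcastShape and the derived counts, assuming the shapes are broadcast-compatible:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> BroadcastShape(std::vector<int64_t> a,
                                    std::vector<int64_t> b) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = (a[i] == 1) ? b[i] : a[i];  // no compatibility check here
  }
  return out;
}

int main() {
  auto shape = BroadcastShape(BroadcastShape({4, 1, 8}, {4, 3, 1}), {1, 3, 8});
  int64_t numelements = 1;
  for (int64_t d : shape) numelements *= d;
  std::cout << "batch_size=" << shape[0]                 // 4
            << " once_compute_size=" << shape[1]         // 3
            << " data_num=" << numelements / shape[1]    // 32
            << " numelements=" << numelements << '\n';   // 96
  return 0;
}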
bool TripletMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool TripletMarginLossCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
switch (dtype_0) {
case kNumberTypeFloat16:
TripletMarginLossCompute_realtype<float>(inputs, outputs);
@ -123,73 +124,23 @@ bool TripletMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &i
}
template <typename T>
void TripletMarginLossCPUKernel::TripletMarginLossCompute_realtype(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
auto task_nobroadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernel::realtype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
};
auto task_broadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernel::realtype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
};
if (broadcast == true) {
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
} else {
TripletMarginLossCPUKernel::realtype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
*(out_data + i) = *(output_reduction_none_data + i);
}
}
if (reduction == MEAN) {
*(out_data) = (out.mean());
}
if (reduction == SUM) {
*(out_data) = (out.sum());
}
return;
}
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
} else {
TripletMarginLossCPUKernel::realtype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
*(out_data + i) = *(output_reduction_none_data + i);
}
}
if (reduction == MEAN) {
*(out_data) = (out.mean());
}
if (reduction == SUM) {
*(out_data) = (out.sum());
}
return;
}
template <typename T>
void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std::vector<kernel::AddressPtr> &inputs,
void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_realtype(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
auto task_nobroadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernel::complextype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
TripletMarginLossCPUKernelMod::realtype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
outputs);
};
auto task_broadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernel::complextype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
TripletMarginLossCPUKernelMod::realtype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
};
if (broadcast == true) {
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
} else {
TripletMarginLossCPUKernel::complextype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
TripletMarginLossCPUKernelMod::realtype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
@ -207,7 +158,7 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std:
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
} else {
TripletMarginLossCPUKernel::complextype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
@ -224,9 +175,62 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std:
}
template <typename T>
void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t end, float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_complextype(
const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
auto task_nobroadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernelMod::complextype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
outputs);
};
auto task_broadcast = [&](size_t start, size_t end) {
TripletMarginLossCPUKernelMod::complextype_broadcast_task<T>(start, end, output_reduction_none_data, inputs,
outputs);
};
if (broadcast == true) {
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
} else {
TripletMarginLossCPUKernelMod::complextype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
*(out_data + i) = *(output_reduction_none_data + i);
}
}
if (reduction == MEAN) {
*(out_data) = (out.mean());
}
if (reduction == SUM) {
*(out_data) = (out.sum());
}
return;
}
if (numelements * sizeof(T) > kParallelDataNum) {
CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
} else {
TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
}
if (reduction == NONE) {
for (size_t i = 0; i < data_num; i++) {
*(out_data + i) = *(output_reduction_none_data + i);
}
}
if (reduction == MEAN) {
*(out_data) = (out.mean());
}
if (reduction == SUM) {
*(out_data) = (out.sum());
}
return;
}
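
All of the task/compute variants in this file reduce to the same per-sample quantity. A reference sketch follows; broadcasting, the parallel cut-over, and the complex-typed magnitude are elided, and the eps handling assumes the common PyTorch-style convention of adding eps inside the pairwise difference, which is an assumption rather than something this diff shows:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

float PNormDistance(const std::vector<float> &a, const std::vector<float> &b,
                    int64_t p, float eps) {
  float sum = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) {
    sum += std::pow(std::fabs(a[i] - b[i] + eps), static_cast<float>(p));
  }
  return std::pow(sum, 1.0f / static_cast<float>(p));
}

float TripletMarginSample(const std::vector<float> &anchor,
                          const std::vector<float> &positive,
                          const std::vector<float> &negative, int64_t p,
                          float eps, float margin, bool swap) {
  float d_pos = PNormDistance(anchor, positive, p, eps);
  float d_neg = PNormDistance(anchor, negative, p, eps);
  if (swap) {  // "swap" attribute: use the harder of the two negatives
    d_neg = std::min(d_neg, PNormDistance(positive, negative, p, eps));
  }
  return std::max(d_pos - d_neg + margin, 0.0f);
}

int main() {
  // d_pos = 1, d_neg = 3: the margin of 1 is already satisfied, loss = 0.
  std::cout << TripletMarginSample({0, 0}, {1, 0}, {3, 0}, 2, 0.0f, 1.0f, false)
            << '\n';
  return 0;
}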
template <typename T>
void TripletMarginLossCPUKernelMod::realtype_nobroadcast_task(size_t start, size_t end,
float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -298,9 +302,9 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t
}
template <typename T>
void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -349,8 +353,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en
}
calc_1_sum += calculate_positive[k];
calc_2_sum += calculate_negative[k];
TripletMarginLossCPUKernel::realtype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap, j,
k, calc_swap_sum, inputs, outputs);
TripletMarginLossCPUKernelMod::realtype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap,
j, k, calc_swap_sum, inputs, outputs);
}
positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@ -375,9 +379,9 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en
}
template <typename T>
void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::realtype_broadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -424,8 +428,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct
}
calc_1_sum += calculate_positive[k];
calc_2_sum += calculate_negative[k];
TripletMarginLossCPUKernel::realtype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
calculate_swap, j, k, calc_swap_sum, inputs, outputs);
TripletMarginLossCPUKernelMod::realtype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
calculate_swap, j, k, calc_swap_sum, inputs, outputs);
}
positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@ -449,9 +453,9 @@ void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct
}
template <typename T>
void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -518,10 +522,10 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_redu
}
template <typename T>
void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size_t end,
float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::complextype_nobroadcast_task(size_t start, size_t end,
float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
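// Complex-dtype variant: element-wise differences are taken in the complex
// domain and their magnitudes feed the p-norm, so the loss stays a real float.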
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -575,9 +579,10 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size
}
template <typename T>
void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::complextype_broadcast_task(size_t start, size_t end,
float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -630,8 +635,8 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t
}
calc_1_sum += calculate_positive_float;
calc_2_sum += calculate_negative_float;
TripletMarginLossCPUKernel::complextype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap,
j, k, calc_swap_sum, inputs, outputs);
TripletMarginLossCPUKernelMod::complextype_swap<T>(start, positive_broadcast, negative_broadcast,
calculate_swap, j, k, calc_swap_sum, inputs, outputs);
}
positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@ -656,9 +661,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t
}
template <typename T>
void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::complextype_broadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -707,8 +712,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red
}
calc_1_sum += calculate_positive_float;
calc_2_sum += calculate_negative_float;
TripletMarginLossCPUKernel::complextype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
calculate_swap, j, k, calc_swap_sum, inputs, outputs);
TripletMarginLossCPUKernelMod::complextype_swap<T>(i * data_num_each_batch, positive_broadcast,
negative_broadcast, calculate_swap, j, k, calc_swap_sum,
inputs, outputs);
}
positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@ -732,9 +738,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red
}
template <typename T>
void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute(float *output_reduction_none_data,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@ -786,11 +792,11 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_r
}
template <typename T>
void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector<T> &positive_broadcast,
std::vector<T> &negative_broadcast, std::vector<float> &calculate_swap,
size_t j, size_t k, float &calc_swap_sum,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::realtype_swap(size_t start, std::vector<T> &positive_broadcast,
std::vector<T> &negative_broadcast,
std::vector<float> &calculate_swap, size_t j, size_t k,
float &calc_swap_sum, const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
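// Distance-swap option of TripletMarginLoss: additionally accumulate the
// positive-to-negative distance so the caller can take the smaller of the
// anchor-negative and positive-negative distances.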
if (swap == true) {
calculate_swap[k] = std::abs(static_cast<float>(positive_broadcast[start + j + k * index]) -
static_cast<float>(negative_broadcast[start + j + k * index]) + eps);
@ -803,11 +809,11 @@ void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector<T> &pos
}
template <typename T>
void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector<T> &positive_broadcast,
std::vector<T> &negative_broadcast, std::vector<T> &calculate_swap,
size_t j, size_t k, float &calc_swap_sum,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
void TripletMarginLossCPUKernelMod::complextype_swap(size_t start, std::vector<T> &positive_broadcast,
std::vector<T> &negative_broadcast, std::vector<T> &calculate_swap,
size_t j, size_t k, float &calc_swap_sum,
const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
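// Complex counterpart of realtype_swap: the difference (plus eps) is kept
// complex here; its magnitude is folded into calc_swap_sum afterwards.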
if (swap == true) {
calculate_swap[k] =
positive_broadcast[start + j + k * index] - negative_broadcast[start + j + k * index] + static_cast<T>(eps);
@ -821,17 +827,19 @@ void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector<T> &
}
}
void TripletMarginLossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
constexpr int kone = 1;
void TripletMarginLossCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
constexpr int kOne = 1;
constexpr int kfour = 4;
if (input_num != kfour) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernel needs 4 inputs.";
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernelMod needs 4 inputs.";
}
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != kone) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TripletMarginLossCPUKernel needs 1 output.";
auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != kOne) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TripletMarginLossCPUKernelMod needs 1 output.";
}
}
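// Register with the kernel factory; the supported dtype combinations are the
// KernelAttr list reported by GetOpSupport().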
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TripletMarginLoss, TripletMarginLossCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,21 +25,99 @@
#include <map>
#include <complex>
#include <iostream>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class TripletMarginLossCPUKernel : public CPUKernel {
class TripletMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod {
public:
TripletMarginLossCPUKernel() = default;
~TripletMarginLossCPUKernel() override = default;
TripletMarginLossCPUKernelMod() = default;
~TripletMarginLossCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
@ -96,7 +174,6 @@ class TripletMarginLossCPUKernel : public CPUKernel {
std::vector<T> &calculate_swap, size_t j, size_t k, float &calc_swap_sum,
const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
int64_t p = 2;
bool swap = false;
@ -109,13 +186,13 @@ class TripletMarginLossCPUKernel : public CPUKernel {
TypeId dtype_1{kTypeUnknown};
TypeId dtype_2{kTypeUnknown};
TypeId dtype_3{kTypeUnknown};
std::vector<size_t> x_shape;
std::vector<size_t> positive_shape;
std::vector<size_t> negative_shape;
std::vector<size_t> broadcast_shape;
std::vector<size_t> x_reshape_vector;
std::vector<size_t> positive_reshape_vector;
std::vector<size_t> negative_reshape_vector;
ShapeVector x_shape;
ShapeVector positive_shape;
ShapeVector negative_shape;
ShapeVector broadcast_shape;
ShapeVector x_reshape_vector;
ShapeVector positive_reshape_vector;
ShapeVector negative_reshape_vector;
size_t numelements = 1;
size_t data_num = 1;
size_t data_num_each_batch = 1;
@ -124,114 +201,6 @@ class TripletMarginLossCPUKernel : public CPUKernel {
size_t once_compute_size = 1;
bool broadcast = false;
};
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
MS_REG_CPU_KERNEL(TripletMarginLoss,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TripletMarginLossCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIPLET_MARGIN_LOSS_CPU_KERNEL_H_

View File

@ -592,7 +592,7 @@ GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradWithArgmax, std::make_shared<Primitive>("
GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradGradWithArgmax, std::make_shared<Primitive>("MaxPoolGradGradWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPool3DWithArgmax, std::make_shared<Primitive>("MaxPool3DWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxPool3DGradWithArgmax, std::make_shared<Primitive>("MaxPool3DGradWithArgmax"));
GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared<Primitive>(kMaxUnpool2DGrad));
GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared<Primitive>(kMaxUnpool2D));
GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2DGrad, std::make_shared<Primitive>(kMaxUnpool2DGrad));
GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3D, std::make_shared<Primitive>(kMaxUnpool3D));
GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3DGrad, std::make_shared<Primitive>(kMaxUnpool3DGrad));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,9 +16,13 @@
#include "ops/grad/max_unpool2d_grad.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
constexpr int64_t k4DInputDims = 4;
namespace {
abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -34,11 +38,10 @@ abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive,
auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
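// The gradient w.r.t. x keeps x's shape, so the output shape is read straight
// from the first input.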
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x1_shape);
}
@ -59,6 +62,7 @@ TypePtr MaxUnpool2DGradInferType(const PrimitivePtr &primitive, const std::vecto
}
} // namespace
MIND_API_OPERATOR_IMPL(MaxUnpool2DGrad, BaseOperator);
AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool2DGrad = "MaxUnpool2DGrad";
class MS_CORE_API MaxUnpool2DGrad : public PrimitiveC {
class MIND_API MaxUnpool2DGrad : public BaseOperator {
public:
MaxUnpool2DGrad() : PrimitiveC(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
~MaxUnpool2DGrad() = default;
MS_DECLARE_PARENT(MaxUnpool2DGrad, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool2DGrad);
MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
};
AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool2DGradPtr = std::shared_ptr<MaxUnpool2DGrad>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,10 +16,13 @@
#include "ops/grad/max_unpool3d_grad.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
constexpr int64_t k5DInputDims = 5;
namespace {
abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -35,11 +38,10 @@ abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive,
auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x1_shape);
}
@ -60,6 +62,7 @@ TypePtr MaxUnpool3DGradInferType(const PrimitivePtr &primitive, const std::vecto
}
} // namespace
MIND_API_OPERATOR_IMPL(MaxUnpool3DGrad, BaseOperator);
AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool3DGrad = "MaxUnpool3DGrad";
class MS_CORE_API MaxUnpool3DGrad : public PrimitiveC {
class MIND_API MaxUnpool3DGrad : public BaseOperator {
public:
MaxUnpool3DGrad() : PrimitiveC(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
~MaxUnpool3DGrad() = default;
MS_DECLARE_PARENT(MaxUnpool3DGrad, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool3DGrad);
MaxUnpool3DGrad() : BaseOperator(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
};
AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool3DGradPtr = std::shared_ptr<MaxUnpool3DGrad>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,16 +15,14 @@
*/
#include "ops/grad/multi_margin_loss_grad.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
const size_t kone = 1;
const size_t ktwo = 2;
const size_t kfour = 4;
TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex2]->BuildType(), {kInt64},
prim->name());
@ -32,7 +30,7 @@ TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector
std::map<std::string, TypePtr> types;
(void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex1]->BuildType());
if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
@ -50,7 +48,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
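// x (input 1) carries per-class scores of shape [N, C]; target (input 2)
// carries class indices of shape [N], and the batch sizes must match.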
if (x_shape.size() != ktwo || target_shape.size() != kone) {
if (x_shape.size() != kDim2 || target_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For MultiMarginLossGrad, the rank of input x should be 2, and "
"the rank of target should be 1,"
<< " while rank of x is " << x_shape.size() << ", rank of target is "
@ -61,7 +59,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
<< " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is "
<< target_shape[kInputIndex0];
}
if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
@ -69,7 +67,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
if (element->type_id() != kMetaTypeNone) {
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
if (weight_shape.size() != kone) {
if (weight_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1,"
<< " but get " << weight_shape.size();
}
@ -84,6 +82,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
}
} // namespace
MIND_API_OPERATOR_IMPL(MultiMarginLossGrad, BaseOperator);
AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr size_t kInputNumWithWeight = 4;
@ -96,7 +95,7 @@ AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, co
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
if (input_args.size() == kfour) {
if (input_args.size() == kInputNumWithWeight) {
MS_EXCEPTION_IF_NULL(input_args[kInputIndex3]);
}
auto types = MultiMarginLossGradInferType(primitive, input_args);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,24 +23,23 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMultiMarginLossGrad = "MultiMarginLossGrad";
class MS_CORE_API MultiMarginLossGrad : public PrimitiveC {
class MIND_API MultiMarginLossGrad : public BaseOperator {
public:
MultiMarginLossGrad() : PrimitiveC(kNameMultiMarginLossGrad) {
MIND_API_BASE_MEMBER(MultiMarginLossGrad);
MultiMarginLossGrad() : BaseOperator(kNameMultiMarginLossGrad) {
InitIOName({"y_grad", "x", "target", "weight"}, {"x_grad"});
}
~MultiMarginLossGrad() = default;
MS_DECLARE_PARENT(MultiMarginLossGrad, PrimitiveC);
};
AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultiMarginLossGradPtr = std::shared_ptr<MultiMarginLossGrad>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,9 @@
#include "ops/grad/multilabel_margin_loss_grad.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -27,9 +28,7 @@ abstract::ShapePtr MultilabelMarginLossGradInferShape(const PrimitivePtr &primit
auto op_name = primitive->name();
auto x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
const size_t kone = 1;
const size_t ktwo = 2;
if ((x.size() != kone && x.size() != ktwo) || (target.size() != kone && target.size() != ktwo)) {
if ((x.size() != kDim1 && x.size() != kDim2) || (target.size() != kDim1 && target.size() != kDim2)) {
MS_EXCEPTION(ValueError) << "For " << op_name << ", the rank of input x and target should be 1 or 2, "
<< "while rank of x is : " << x.size() << ", rank of target is : " << target.size() << ".";
}
@ -57,6 +56,7 @@ TypePtr MultilabelMarginLossGradInferType(const PrimitivePtr &primitive,
}
} // namespace
MIND_API_OPERATOR_IMPL(MultilabelMarginLossGrad, BaseOperator);
AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,24 +23,24 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMultilabelMarginLossGrad = "MultilabelMarginLossGrad";
class MS_CORE_API MultilabelMarginLossGrad : public PrimitiveC {
class MIND_API MultilabelMarginLossGrad : public BaseOperator {
public:
MultilabelMarginLossGrad() : PrimitiveC(kNameMultilabelMarginLossGrad) {
MIND_API_BASE_MEMBER(MultilabelMarginLossGrad);
MultilabelMarginLossGrad() : BaseOperator(kNameMultilabelMarginLossGrad) {
InitIOName({"y_grad", "x", "target", "is_target"}, {"x_grad"});
}
~MultilabelMarginLossGrad() = default;
MS_DECLARE_PARENT(MultilabelMarginLossGrad, PrimitiveC);
};
AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultilabelMarginLossGradPtr = std::shared_ptr<MultilabelMarginLossGrad>;
} // namespace ops
} // namespace mindspore

View File

@ -27,19 +27,18 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(y_grad_shape);
}
TypePtr MvlgammaGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
std::map<std::string, TypePtr> types;
(void)types.emplace("y_grad", input_args[0]->BuildType());
(void)types.emplace("x", input_args[1]->BuildType());
(void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}

View File

@ -32,7 +32,7 @@ class MIND_API MvlgammaGrad : public BaseOperator {
};
abstract::AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMvlgammaGradPtr = std::shared_ptr<MvlgammaGrad>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,17 +20,17 @@
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t k4DInputDims = 4;
abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads, const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
abstract::ShapePtr MaxUnpool2DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads,
const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
if (data_format == "NCHW") {
int64_t out_h = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
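// Inverse of the pooling size formula: out = (in - 1) * stride - 2 * pad + ksize.
// e.g. in_h = 4, stride = 2, pad = 0, ksize = 2 gives out_h = (4 - 1) * 2 - 0 + 2 = 8.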
@ -40,7 +40,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
}
if (attr_output_shape.size() == k4DInputDims) {
if (attr_output_shape.size() == kDim4) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual,
@ -74,7 +74,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
}
if (attr_output_shape.size() == k4DInputDims) {
if (attr_output_shape.size() == kDim4) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[3]", attr_output_shape[kInputIndex3], kEqual,
@ -100,7 +100,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
}
}
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MaxUnpool2DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
@ -109,28 +110,27 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto data_format = GetValue<std::string>(primitive->GetAttr("format"));
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
auto ksize = GetValue<std::vector<int64_t>>(primitive->GetAttr("ksize"));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr("strides"));
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr("pads"));
auto attr_output_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr("output_shape"));
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim4, op_name);
if (attr_output_shape.size() != k4DInputDims && attr_output_shape.size() != 0) {
if (attr_output_shape.size() != kDim4 && attr_output_shape.size() != kDim0) {
MS_EXCEPTION(ValueError) << "MaxUnpool2D: Output_shape size must be 0 or 4.";
}
return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
return MaxUnpool2DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
TypePtr MaxUnpool2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -143,13 +143,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
} // namespace
MIND_API_OPERATOR_IMPL(MaxUnpool2D, BaseOperator);
AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
auto infer_type = MaxUnpool2DInferType(primitive, input_args);
auto infer_shape = MaxUnpool2DInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2D, prim::kPrimMaxUnpool2D, MaxUnpool2DInfer, nullptr, true);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool2D = "MaxUnpool2D";
class MS_CORE_API MaxUnpool2D : public PrimitiveC {
class MIND_API MaxUnpool2D : public BaseOperator {
public:
MaxUnpool2D() : PrimitiveC(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
~MaxUnpool2D() = default;
MS_DECLARE_PARENT(MaxUnpool2D, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool2D);
MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
};
AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool2DPtr = std::shared_ptr<MaxUnpool2D>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,17 +20,17 @@
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t k5DInputDims = 5;
abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads, const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
abstract::ShapePtr MaxUnpool3DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads,
const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
if (data_format == "NCDHW") {
int64_t out_d = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
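// Same inverse-pooling formula per spatial dim: out = (in - 1) * stride - 2 * pad + ksize,
// e.g. in_d = 5, stride = 2, pad = 1, ksize = 3 gives out_d = (5 - 1) * 2 - 2 + 3 = 9.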
@ -42,7 +42,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid.";
}
if (attr_output_shape.size() == k5DInputDims) {
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual,
@ -79,7 +79,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid.";
}
if (attr_output_shape.size() == k5DInputDims) {
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[4]", attr_output_shape[kInputIndex4], kEqual,
@ -108,7 +108,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
}
}
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
@ -117,26 +118,25 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto data_format = GetValue<std::string>(primitive->GetAttr("format"));
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
auto ksize = GetValue<std::vector<int64_t>>(primitive->GetAttr("ksize"));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr("strides"));
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr("pads"));
auto attr_output_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr("output_shape"));
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim5, op_name);
if (attr_output_shape.size() != k5DInputDims && attr_output_shape.size() != 0) {
if (attr_output_shape.size() != kDim5 && attr_output_shape.size() != kDim0) {
MS_EXCEPTION(ValueError) << "MaxUnpool3D: Output_shape size must be 0 or 5.";
}
return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
return MaxUnpool3DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
TypePtr MaxUnpool3DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -149,13 +149,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
} // namespace
MIND_API_OPERATOR_IMPL(MaxUnpool3D, BaseOperator);
AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
auto infer_type = MaxUnpool3DInferType(primitive, input_args);
auto infer_shape = MaxUnpool3DInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool3D, prim::kPrimMaxUnpool3D, MaxUnpool3DInfer, nullptr, true);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool3D = "MaxUnpool3D";
class MS_CORE_API MaxUnpool3D : public PrimitiveC {
class MIND_API MaxUnpool3D : public BaseOperator {
public:
MaxUnpool3D() : PrimitiveC(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); }
~MaxUnpool3D() = default;
MS_DECLARE_PARENT(MaxUnpool3D, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool3D);
MaxUnpool3D() : BaseOperator(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); }
};
AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool3DPtr = std::shared_ptr<MaxUnpool3D>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,21 +15,20 @@
*/
#include "ops/multi_margin_loss.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
const size_t kone = 1;
const size_t ktwo = 2;
const size_t kthree = 3;
TypePtr MultiMarginLossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[1]->BuildType(), {kInt64}, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex1]->BuildType(), {kInt64},
prim->name());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
if (input_args.size() == kInputIndex3 && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
@ -51,7 +50,7 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
if (x_shape.size() != ktwo || target_shape.size() != kone) {
if (x_shape.size() != kDim2 || target_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For MultiMarginLoss, the rank of input "
"x and target should be 2 and 1,"
<< " while rank of x is " << x_shape.size() << ", rank of target is "
@ -62,14 +61,15 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
<< " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is "
<< target_shape[kInputIndex0];
}
if (input_args.size() == kthree && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
if (input_args.size() == kDim3 && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
if (element->type_id() != kMetaTypeNone) {
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (weight_shape.size() != kone) {
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (weight_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1,"
<< " but get " << weight_shape.size();
}
@ -90,16 +90,17 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
}
} // namespace
MIND_API_OPERATOR_IMPL(MultiMarginLoss, BaseOperator);
AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
if (input_args.size() == kthree) {
if (input_args.size() == kDim3) {
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
}
(void)CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth,
{ktwo, kthree}, primitive->name());
{kDim2, kDim3}, primitive->name());
auto types = MultiMarginLossInferType(primitive, input_args);
auto shapes = MultiMarginLossInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,22 +22,21 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMultiMarginLoss = "MultiMarginLoss";
class MS_CORE_API MultiMarginLoss : public PrimitiveC {
class MIND_API MultiMarginLoss : public BaseOperator {
public:
MultiMarginLoss() : PrimitiveC(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); }
~MultiMarginLoss() = default;
MS_DECLARE_PARENT(MultiMarginLoss, PrimitiveC);
MIND_API_BASE_MEMBER(MultiMarginLoss);
MultiMarginLoss() : BaseOperator(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); }
};
AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultiMarginLossPtr = std::shared_ptr<MultiMarginLoss>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,9 @@
#include "ops/multilabel_margin_loss.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
@ -37,7 +38,7 @@ abstract::TupleShapePtr MultilabelMarginLossInferShape(const PrimitivePtr &primi
MS_EXCEPTION(ValueError) << "For " << op_name << ", x_shape and target_shape should be the same, "
<< "while x_shape is : " << x << ", target_shape is : " << target << ".";
}
int64_t batch = x[0];
int64_t batch = x[kInputIndex0];
ShapeVector out_shape0 = {batch};
ShapeVector out_shape1 = target;
int64_t reduction;
@ -66,6 +67,7 @@ TuplePtr MultilabelMarginLossInferType(const PrimitivePtr &primitive, const std:
}
} // namespace
MIND_API_OPERATOR_IMPL(MultilabelMarginLoss, BaseOperator);
AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,27 +21,24 @@
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
constexpr auto kNameMultilabelMarginLoss = prim::kMultilabelMarginLoss;
namespace ops {
constexpr auto kNameMultilabelMarginLoss = "MultilabelMarginLoss";
/// \brief Creates a criterion that optimizes a multi-class multi-classification hinge loss.
/// Refer to Python API @ref mindspore.ops.MultilabelMarginLoss for more details.
class MS_CORE_API MultilabelMarginLoss : public PrimitiveC {
class MIND_API MultilabelMarginLoss : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MultilabelMarginLoss);
/// \brief Constructor.
MultilabelMarginLoss() : PrimitiveC(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); }
/// \brief Destructor.
~MultilabelMarginLoss() = default;
MS_DECLARE_PARENT(MultilabelMarginLoss, PrimitiveC);
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MultilabelMarginLoss for the inputs.
void Init() const {}
MultilabelMarginLoss() : BaseOperator(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); }
};
AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultilabelMarginLossPtr = std::shared_ptr<MultilabelMarginLoss>;
} // namespace ops
} // namespace mindspore

View File

@ -28,14 +28,14 @@ namespace ops {
namespace {
abstract::ShapePtr MvlgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr MvlgammaInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto input_type = input_args[0]->BuildType();
auto input_type = input_args[kInputIndex0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, prim->name());
}
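MvlgammaInferShape simply forwards the input shape and MvlgammaInferType restricts inputs to float32/float64. A quick reference check using SciPy (an assumption for illustration; SciPy is not used anywhere in this commit):

import numpy as np
from scipy.special import multigammaln  # reference multivariate log-gamma

x = np.array([[3., 4., 5.], [4., 2., 6.]], np.float32)  # data from the test case below
y = np.vectorize(lambda v: multigammaln(v, 1))(x)       # p=1, as in the Mvlgamma test
assert y.shape == x.shape  # output shape equals input shape, as the infer returns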

View File

@ -30,12 +30,12 @@ constexpr auto kNameMvlgamma = "Mvlgamma";
class MIND_API Mvlgamma : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Mvlgamma);
/// \brief Constructor.
/// \brief Constructor.
Mvlgamma() : BaseOperator(kNameMvlgamma) { InitIOName({"x"}, {"y"}); }
};
abstract::AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMvlgammaPtr = std::shared_ptr<Mvlgamma>;
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,13 @@
#include "ops/triplet_margin_loss.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 4;
abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
@ -30,21 +30,19 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
auto positive = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto negative = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto margin = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
const int64_t keight = 8;
if (x.size() >= keight || positive.size() >= keight || negative.size() >= keight) {
if (x.size() >= kDim8 || positive.size() >= kDim8 || negative.size() >= kDim8) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< ", dimensions of input x positive and negative must be smaller than 8, x_dim: "
<< x.size() << ", positive_dim: " << positive.size()
<< ", negative_dim: " << negative.size() << ".";
}
const int64_t kone = 1;
if (x.size() <= kone && positive.size() <= kone && negative.size() <= kone) {
if (x.size() <= kDim1 && positive.size() <= kDim1 && negative.size() <= kDim1) {
MS_EXCEPTION(ValueError)
<< "For " << op_name
<< ", dimensions of input x, positive and negative cannot be less than 1 at the same time, x_dim: " << x.size()
<< ", positive_dim: " << positive.size() << ", negative_dim: " << negative.size() << ".";
}
if (margin.size() != 0) {
if (margin.size() != kDim0) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< ", the dimension of input margin must be 0, margin_dim: " << margin.size() << ".";
}
@ -61,8 +59,8 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
ShapeVector out_shape;
for (size_t i = 0; i < dims; i++) {
out_shape.push_back((int64_t)std::max(std::max(x[i], positive[i]), negative[i]));
if ((x[i] != out_shape[i] && x[i] != kone) || (positive[i] != out_shape[i] && positive[i] != kone) ||
(negative[i] != out_shape[i] && negative[i] != kone)) {
if ((x[i] != out_shape[i] && x[i] != kDim1) || (positive[i] != out_shape[i] && positive[i] != kDim1) ||
(negative[i] != out_shape[i] && negative[i] != kDim1)) {
MS_EXCEPTION(ValueError) << "For " << op_name << ", inputs' shape can't broadcast.";
}
}
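The loop above is an element-wise three-way broadcast: each output dimension is the maximum of the three inputs, and an input may disagree with that maximum only where it equals 1. A sketch of the same rule (shapes assumed already padded to equal rank, as they are at this point in the C++):

def triplet_broadcast(x, positive, negative):
    # Mirrors the C++ loop: out[i] = max of the three dims; a dim may
    # differ from the max only if it is 1, otherwise shapes can't broadcast.
    out = []
    for a, b, c in zip(x, positive, negative):
        m = max(a, b, c)
        if any(d != m and d != 1 for d in (a, b, c)):
            raise ValueError("inputs' shapes cannot broadcast")
        out.append(m)
    return out

assert triplet_broadcast([2, 1, 3], [1, 4, 3], [2, 4, 1]) == [2, 4, 3]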
@ -98,6 +96,7 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
}
} // namespace
MIND_API_OPERATOR_IMPL(TripletMarginLoss, BaseOperator);
AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,24 +21,23 @@
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTripletMarginLoss = "TripletMarginLoss";
class TripletMarginLoss : public PrimitiveC {
class MIND_API TripletMarginLoss : public BaseOperator {
public:
TripletMarginLoss() : PrimitiveC(kNameTripletMarginLoss) {
MIND_API_BASE_MEMBER(TripletMarginLoss);
TripletMarginLoss() : BaseOperator(kNameTripletMarginLoss) {
InitIOName({"x", "positive", "negative", "margin"}, {"y"});
}
~TripletMarginLoss() = default;
MS_DECLARE_PARENT(TripletMarginLoss, PrimitiveC);
};
AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimTripletMarginLossPtr = std::shared_ptr<TripletMarginLoss>;
} // namespace ops
} // namespace mindspore

View File

@ -19,6 +19,9 @@ from mindspore import log
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp
from mindspore.ops import functional as F
from mindspore import nn
from mindspore.ops.primitive import constexpr
@ -1127,7 +1130,7 @@ class MultiMarginLoss(LossBase):
def __init__(self, p=1, margin=1.0, reduction='mean'):
"""Initialize MultiMarginLoss."""
super(MultiMarginLoss, self).__init__()
self.multi_margin_loss = P.MultiMarginLoss(p=p, margin=margin, reduction=reduction)
self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction)
self.ones = P.Ones()
def construct(self, x, target, weight=None):
@ -1369,7 +1372,7 @@ class MultilabelMarginLoss(LossBase):
def __init__(self, reduction='mean'):
super(MultilabelMarginLoss, self).__init__()
self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction)
self.multilabel_margin_loss = MultilabelMarginLossOp(reduction=reduction)
def construct(self, x, target):
return self.multilabel_margin_loss(x, target)
@ -1497,74 +1500,6 @@ def _check_input_dtype(labels_dtype, cls_name):
[mstype.int32, mstype.int64, mstype.float16, mstype.float32], cls_name)
class MultilabelMarginLoss(LossBase):
r"""
MultilabelMarginLoss operation.
Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
and output :math:`y` (which is a 2D `Tensor` of target class indices).
For each sample in the mini-batch:
.. math::
\text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
:math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
:math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
:math:`y` and :math:`x` must have the same size.
The criterion only considers a contiguous block of non-negative targets that
starts at the front.
This allows different samples to have varying numbers of target classes.
Args:
reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
Inputs:
- **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N`
is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32.
- **target** (Tensor) - Ground truth data, with the same shape as `x`. Its data type must be int32, and
label targets are padded by -1.
Outputs:
- **y** (Union[Tensor, Scalar]) - The loss of MultilabelMarginLoss. If `reduction` is "none", its shape
is :math:`(N)`. Otherwise, a scalar value will be returned.
- **is_target** (Tensor) - Output tensor for backward input, with the same shape as `target`,
data type must be int32.
Raises:
TypeError: If `x` or `target` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
TypeError: If dtype of `target` is not int32.
ValueError: If length of shape of `x` is neither 1 nor 2.
ValueError: If shape of `x` is not the same as `target`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend``
Examples:
>>> loss = nn.MultilabelMarginLoss()
>>> x = Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]]), mindspore.float32)
>>> target = Tensor(np.array([[1, 2, 0, 3], [2, 3, -1, 1]]), mindspore.int32)
>>> output = loss(x, target)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 0.325), Tensor(shape=[2, 4], dtype=Int32, value=
[[1, 1, 1, 1], [0, 0, 1, 1]]))
"""
def __init__(self, reduction='mean'):
super(MultilabelMarginLoss, self).__init__()
self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction)
def construct(self, x, target):
return self.multilabel_margin_loss(x, target)
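The formula in the docstring above can be sanity-checked against its own example with a short NumPy sketch (illustrative only; it reproduces the documented 0.325 under mean reduction):

import numpy as np

def multilabel_margin_loss_ref(x, target):
    # Follows the docstring: only the contiguous block of non-negative
    # targets at the front of each row counts, and i ranges over non-targets.
    n, c = x.shape
    losses = np.zeros(n)
    for k in range(n):
        ys = []
        for t in target[k]:
            if t < 0:
                break
            ys.append(int(t))
        non_targets = [i for i in range(c) if i not in ys]
        for j in ys:
            for i in non_targets:
                losses[k] += max(0.0, 1.0 - (x[k, j] - x[k, i])) / c
    return losses.mean()

x = np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]], np.float32)
target = np.array([[1, 2, 0, 3], [2, 3, -1, 1]], np.int32)
print(multilabel_margin_loss_ref(x, target))  # 0.325, matching the docstring example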
class FocalLoss(LossBase):
r"""
The loss function proposed by Kaiming He's team in their paper ``Focal Loss for Dense Object Detection`` improves the
@ -1867,7 +1802,7 @@ class TripletMarginLoss(LossBase):
def __init__(self, p=2, swap=False, eps=1e-6, reduction='mean'):
super(TripletMarginLoss, self).__init__()
self.triplet_margin_loss = P.TripletMarginLoss(p=p, swap=swap, eps=eps, reduction=reduction)
self.triplet_margin_loss = TripletMarginLossOp(p=p, swap=swap, eps=eps, reduction=reduction)
def construct(self, x, positive, negative, margin):
return self.triplet_margin_loss(x, positive, negative, margin)

View File

@ -40,7 +40,6 @@ from ..operations.array_ops import Expand
from ..operations.array_ops import SegmentMean
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
from .._utils.utils import is_shape_unknown
from ..operations import _grad_ops as G

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,12 +24,16 @@ from .._grad.grad_base import bprop_getters
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations.nn_ops import MaxUnpool2D
from ..operations.nn_ops import MaxUnpool3D
from ..operations.nn_ops import FractionalMaxPool
from ..operations._grad_ops import FractionalMaxPoolGrad
from ..operations.nn_ops import FractionalMaxPool3DWithFixedKsize
from ..operations._grad_ops import FractionalMaxPool3DGradWithFixedKsize
from ..operations.nn_ops import FractionalAvgPool
from ..operations._grad_ops import FractionalAvgPoolGrad
from ..operations.nn_ops import MultiMarginLoss
from ..operations.nn_ops import MultilabelMarginLoss
from ..operations.nn_ops import NthElement
from ..operations.nn_ops import PSROIPooling
from ..operations._grad_ops import PSROIPoolingGrad
@ -92,7 +96,7 @@ def get_bprop_hshrink(self):
return bprop
@bprop_getters.register(P.MultilabelMarginLoss)
@bprop_getters.register(MultilabelMarginLoss)
def get_bprop_multilabel_margin_loss(self):
"""Grad definition for `MultilabelMarginLoss` operation."""
input_grad = G.MultilabelMarginLossGrad(reduction=self.reduction)
@ -120,7 +124,7 @@ def get_bprop_celu(self):
return bprop
@bprop_getters.register(P.MultiMarginLoss)
@bprop_getters.register(MultiMarginLoss)
def get_bprop_multi_margin_loss(self):
"""Grad definition for `MultiMarginLoss` operation."""
input_grad = G.MultiMarginLossGrad(p=self.p, margin=self.margin, reduction=self.reduction)
@ -155,7 +159,7 @@ def get_bprop_relu(self):
return bprop
@bprop_getters.register(P.MaxUnpool2D)
@bprop_getters.register(MaxUnpool2D)
def get_bprop_maxunpool2d(self):
"""Grad definition for `MaxUnpool2D` operation."""
maxunpool2d_grad = G.MaxUnpool2DGrad(
@ -173,7 +177,7 @@ def get_bprop_maxunpool2d(self):
return bprop
@bprop_getters.register(P.MaxUnpool3D)
@bprop_getters.register(MaxUnpool3D)
def get_bprop_maxunpool3d(self):
"""Grad definition for `MaxUnpool3D` operation."""
maxunpool3d_grad = G.MaxUnpool3DGrad(
@ -188,6 +192,8 @@ def get_bprop_maxunpool3d(self):
dargmax = zeros_like(argmax)
return (dx, dargmax)
return bprop
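Each getter above follows the same pattern: build the grad primitive once at registration time, then return a bprop closure. A condensed sketch of the MaxUnpool3D case (the attribute names and the grad op's argument order are assumptions inferred from the visible fragments, not taken from this diff):

def get_bprop_maxunpool3d_sketch(self):
    # In the real registration this is decorated with
    # @bprop_getters.register(MaxUnpool3D) -- the class object, not the P.* alias.
    maxunpool3d_grad = G.MaxUnpool3DGrad(ksize=self.ksize, strides=self.strides,
                                         pads=self.pads)  # attrs assumed

    def bprop(x, argmax, out, dout):
        dx = maxunpool3d_grad(x, dout, argmax)  # argument order assumed
        dargmax = zeros_like(argmax)            # indices get a zero gradient
        return (dx, dargmax)

    return bprop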
@bprop_getters.register(NthElement)
def get_bprop_nth_element(self):

View File

@ -29,6 +29,7 @@ bartlett_window_op_info = AiCPURegOp("BartlettWindow") \
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(bartlett_window_op_info)
def _bartlett_window_aicpu():
"""BartlettWindow AiCPU register"""

View File

@ -50,6 +50,7 @@ max_unpool2d_op_info = AiCPURegOp("MaxUnpool2D") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(max_unpool2d_op_info)
def _max_unpool2d_aicpu():
"""MaxUnpool2D aicpu register"""

View File

@ -51,6 +51,7 @@ max_unpool2d_grad_op_info = AiCPURegOp("MaxUnpool2DGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(max_unpool2d_grad_op_info)
def _max_unpool2d_grad_aicpu():
"""MaxUnpool2DGrad aicpu register"""

View File

@ -50,6 +50,7 @@ max_unpool3d_op_info = AiCPURegOp("MaxUnpool3D") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(max_unpool3d_op_info)
def _max_unpool3d_aicpu():
"""MaxUnpool3D aicpu register"""

View File

@ -51,6 +51,7 @@ max_unpool3d_grad_op_info = AiCPURegOp("MaxUnpool3DGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(max_unpool3d_grad_op_info)
def _max_unpool3d_grad_aicpu():
"""MaxUnpool3DGrad aicpu register"""

View File

@ -30,6 +30,7 @@ multi_margin_loss_op_info = AiCPURegOp("MultiMarginLoss") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(multi_margin_loss_op_info)
def _multi_margin_loss_aicpu():
"""MultiMarginLoss aicpu register"""

View File

@ -34,6 +34,7 @@ multi_margin_loss_grad_op_info = AiCPURegOp("MultiMarginLossGrad") \
DataType.F64_Default) \
.get_op_info()
@op_info_register(multi_margin_loss_grad_op_info)
def _multi_margin_loss_grad_aicpu():
"""MultiMarginLossGrad aicpu register"""

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ multilabel_margin_loss_grad_op_info = AiCPURegOp("MultilabelMarginLossGrad") \
DataType.F32_Default) \
.get_op_info()
@op_info_register(multilabel_margin_loss_grad_op_info)
def _multilabel_margin_loss_grad_aicpu():
"""MultilabelMarginLossGrad aicpu register"""

View File

@ -25,6 +25,7 @@ mvlgamma_op_info = AiCPURegOp("Mvlgamma") \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(mvlgamma_op_info)
def _mvlgamma_aicpu():
"""Mvlgamma AiCPU register"""

View File

@ -26,6 +26,7 @@ mvlgamma_grad_op_info = AiCPURegOp("MvlgammaGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,) \
.get_op_info()
@op_info_register(mvlgamma_grad_op_info)
def _mvlgamma_grad_aicpu():
"""MvlgammaGrad AiCPU register"""

View File

@ -591,7 +591,6 @@ from .extract_volume_patches import _extract_volume_patches_tbe
from .multilabel_margin_loss import _multilabel_margin_loss_tbe
from .round_ds import _round_ds_tbe
from .is_close import _is_close_tbe
from .multilabel_margin_loss import _multilabel_margin_loss_tbe
from .apply_adam_with_amsgrad import _apply_adam_with_amsgrad_tbe
from .apply_adam_with_amsgrad_ds import _apply_adam_with_amsgrad_ds_tbe
from .expm1_ds import _expm1_ds_tbe

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -29,7 +29,8 @@ from ...common._decorator import deprecated
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False,
ret_four=False, strict_positive=True):
"""
Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements.
"""
@ -54,8 +55,11 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
ret_value = _get_return_value()
for item in ret_value:
if isinstance(item, int) and not isinstance(item, bool) and item > 0:
continue
if isinstance(item, int) and not isinstance(item, bool):
if item > 0:
continue
if not strict_positive and item == 0:
continue
_raise_message()
return ret_value
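The new strict_positive flag relaxes the per-item rule so that zeros pass when explicitly allowed, which MaxUnpool2D needs for pads below. The rule in isolation:

def item_ok(item, strict_positive=True):
    # Extracted per-item rule from the loop above, for illustration.
    if isinstance(item, int) and not isinstance(item, bool):
        if item > 0:
            return True
        if not strict_positive and item == 0:
            return True
    return False

assert item_ok(3)                           # positive ints always pass
assert not item_ok(0)                       # zero fails the strict default
assert item_ok(0, strict_positive=False)    # but passes when relaxed (pads)
assert not item_ok(True)                    # bools are never valid sizes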
@ -2117,7 +2121,7 @@ class MaxUnpool2D(Primitive):
if strides in (0, (0, 0)):
strides = ksize
self.strides = _check_positive_int_or_tuple('strides', strides, self.name, ret_four=True)
self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, greater_zero=False)
self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, strict_positive=False)
self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
if data_format == "NHWC":

View File

@ -84,6 +84,13 @@ from mindspore.ops.operations.nn_ops import FractionalAvgPool
from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad
from mindspore.ops.operations.nn_ops import GridSampler2D
from mindspore.ops.operations.nn_ops import GridSampler3D
from mindspore.ops.operations.nn_ops import MaxUnpool2D
from mindspore.ops.operations.nn_ops import MaxUnpool3D
from mindspore.nn.loss.loss import MultiMarginLoss
from mindspore.nn.loss.loss import MultilabelMarginLoss
from mindspore.nn.loss.loss import TripletMarginLoss
from mindspore.ops.operations.array_ops import Mvlgamma
from mindspore.ops.operations.other_ops import BartlettWindow
from mindspore.ops.operations.nn_ops import NthElement
from mindspore.ops.operations.nn_ops import SparseApplyAdagradDA
from mindspore.ops.operations.nn_ops import PSROIPooling
@ -2352,27 +2359,15 @@ test_case_nn_ops = [
'desc_inputs': [[10, 3, 28, 31, 24]],
'desc_bprop': [[10, 3, 14, 16, 12]]}),
('MaxUnpool2D', {
'block': P.MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)),
'block': MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)),
'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}),
([4, 3, 6, 6], {'dtype': np.int64})],
'desc_bprop': [([4, 3, 10, 10], {'dtype': np.float32})]}),
('MaxUnpool2DGrad', {
'block': G.MaxUnpool2DGrad(ksize=(1, 1, 4, 4), strides=(1, 1, 2, 2), pads=(1, 1, 2, 2)),
'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}),
([4, 3, 10, 10], {'dtype': np.float32}),
([4, 3, 6, 6], {'dtype': np.int64})],
'skip': ['backward']}),
('MaxUnpool3D', {
'block': P.MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)),
'block': MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)),
'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}),
([4, 3, 6, 6, 5], {'dtype': np.int64})],
'desc_bprop': [([4, 3, 10, 10, 8], {'dtype': np.float32})]}),
('MaxUnpool3DGrad', {
'block': G.MaxUnpool3DGrad(ksize=(1, 1, 4, 4, 4), strides=(1, 1, 2, 2, 2), pads=(1, 1, 2, 2, 2)),
'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}),
([4, 3, 10, 10, 8], {'dtype': np.float32}),
([4, 3, 6, 6, 5], {'dtype': np.int64})],
'skip': ['backward']}),
('MaxPoolWithArgmax', {
'block': P.MaxPoolWithArgmax(kernel_size=2, strides=2),
'desc_inputs': [[128, 32, 32, 64]],
@ -2752,18 +2747,10 @@ test_case_nn_ops = [
'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))],
'skip': ['backward']}),
('MultiMarginLoss', {
'block': nn.MultiMarginLoss(reduction="mean"),
'block': MultiMarginLoss(reduction="mean"),
'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
Tensor(np.array([0, 0]).astype(np.int64))],
'desc_bprop': [[1]]}),
('MultiMarginLossGrad', {
'block': G.MultiMarginLossGrad(),
'desc_inputs': [Tensor(np.array([1]).astype(np.float32)),
Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
Tensor(np.array([1, 1]).astype(np.int64)),
Tensor(np.array([1, 1]).astype(np.float32))],
'desc_bprop': [Tensor([1], mstype.float32)],
'skip': ['backward']}),
('L2Loss_1', {
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
@ -2903,7 +2890,7 @@ test_case_nn_ops = [
Tensor(np.array([[-4, -3, -2], [1, 2, 4]]), mstype.float16)],
'skip': ['backward']}),
('TripletMarginLoss', {
'block': P.TripletMarginLoss(reduction="none"),
'block': TripletMarginLoss(reduction="none"),
'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
Tensor(np.array([[0.4, 0.6], [0.4, 0.6]]).astype(np.float32)),
Tensor(np.array([[0.2, 0.9], [0.3, 0.7]]).astype(np.float32)),
@ -2952,10 +2939,11 @@ test_case_nn_ops = [
Tensor(0.99, mstype.float32)],
'skip': ['backward']}),
('MultilabelMarginLoss', {
'block': P.MultilabelMarginLoss(reduction="none"),
'block': MultilabelMarginLoss(reduction="none"),
'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.3, 0.4]]).astype(np.float32)),
Tensor(np.array([[2, 1, -1, 1], [1, -1, 2, 1]]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32))]}),
'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32)),
Tensor(np.array([[1, 1, 2, 1], [1, 1, 2, 1]]).astype(np.int32))]}),
('GridSampler3D', {
'block': GridSampler3D(interpolation_mode='bilinear', padding_mode='zeros', align_corners=False),
'desc_inputs': [Tensor(np.arange(32).reshape((2, 2, 2, 2, 2)).astype(np.float32)),
@ -2983,7 +2971,6 @@ test_case_nn_ops = [
Tensor(1, mstype.int64)],
'skip': ['backward']}),
]
test_case_array_ops = [
('LeftShift', {
'block': LeftShift(),
@ -3158,15 +3145,10 @@ test_case_array_ops = [
'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))],
'skip': ['backward']}),
('Mvlgamma', {
'block': P.Mvlgamma(p=1),
'block': Mvlgamma(p=1),
'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))]
}),
('MvlgammaGrad', {
'block': G.MvlgammaGrad(p=1),
'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32)),
Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))],
'skip': ['backward']}),
('ConcatV2_0', {
'block': NetForConcat1(),
'desc_inputs': [
@ -3506,6 +3488,7 @@ test_case_array_ops = [
}),
]
test_case_image_ops = [
('AdjustHue', {
'block': AdjustHue(),
@ -3597,10 +3580,10 @@ test_case_other_ops = [
Tensor(np.array([[[0.38, 0.17, 0.95, 0.40]]], np.float32)),
Tensor(np.array([0.8], np.float32))),
'skip': ['backward']}),
('BartlettWindow', {
'block': P.BartlettWindow(periodic=True, dtype=mstype.float32),
'desc_inputs': (Tensor(np.array([10], np.int32))),
'skip': ['backward']}),
('BartlettWindow', {
'block': BartlettWindow(periodic=True, dtype=mstype.float32),
'desc_inputs': (Tensor(np.array([10], np.int32))),
'skip': ['backward']}),
('GatherNd', {
'block': P.GatherNd(),
'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),