Refactor Ops MaxUnpool2D, MaxUnpool3D, MultiMarginLoss, MultilabelMarginLoss, TripletMarginLoss, BartlettWindow
parent 0914ecf339
commit 4de4eb1a43
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,11 @@
  * limitations under the License.
  */

-#include "backend/kernel_compiler/cpu/bartlett_window_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <cmath>
+#include <string>
+#include <algorithm>
+#include "plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"

 namespace mindspore {
 namespace kernel {
@@ -24,102 +27,94 @@ constexpr size_t kBartlettWindowInputsNum = 1;
 constexpr size_t kBartlettWindowOutputsNum = 1;
 }  // namespace

-void BartlettWindowCpuKernelMod<T, S>::InitKernel(const CNodePtr &kernel_node) {
+void BartlettWindowCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
-  input_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-  output_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
-  periodic_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PERIODIC);
-  if ((input_dtype != kNumberTypeInt32) && (input_dtype != kNumberTypeInt64)) {
-    MS_LOG(EXCEPTION) << "Input tensor types must be int32 or int64";
-  }
-  input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  periodic_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "periodic");
+  auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   if (input_shape.size() > 0) {
-    MS_EXCEPTION(ValueError) << "The dim of window_length must be 0.";
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim of window_length should be 0, but got "
+                             << input_shape.size();
   }
-  node_wpt_ = kernel_node;
+  cnode_ptr_ = kernel_node;
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }

-// template <typename T, typename S>
-// void BartlettWindowCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-//                                               const std::vector<AddressPtr> &outputs) {
-//   auto input = reinterpret_cast<T *>(inputs[0]->addr);
-//   auto output = reinterpret_cast<S *>(outputs[0]->addr);
-//   auto input_data = *input;
-//   const size_t window_length = static_cast<size_t>(*input);
-//   const S output_one = static_cast<S>(1.);
-//   if (input_data < 0) {
-//     MS_EXCEPTION(ValueError) << "Input window_length must ≥ 0!";
-//   }
-//   if (input_data == 1) {
-//     *output = output_one;
-//   } else {
-//     if (periodic_) {
-//       input_data += 1;
-//     }
-//     const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
-//     const double x = static_cast<double>(input_data);
-//     for (size_t i = 0; i <= first_half_size; i++) {
-//       auto value = static_cast<S>((2. * i) / (x - 1.));
-//       *(output + i) = value;
-//     }
-//     for (size_t i = first_half_size + 1; i < window_length; i++) {
-//       auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
-//       *(output + i) = value;
-//     }
-//   }
-// }
+template <typename T1, typename T2>
+bool BartlettWindowCpuKernelMod::BartlettWindowKernelFunc(const std::vector<kernel::AddressPtr> &inputs,
+                                                          const std::vector<kernel::AddressPtr> &,
+                                                          const std::vector<kernel::AddressPtr> &outputs) {
+  auto node_ = cnode_ptr_.lock();
+  if (!node_) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_ is expired.";
+  }

-template <typename T, typename S>
-bool BartlettWindowCpuKernelMod<T, S>::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                              const std::vector<kernel::AddressPtr> &,
-                                              const std::vector<kernel::AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBartlettWindowInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBartlettWindowOutputsNum, kernel_name_);
-  auto input = reinterpret_cast<T *>(inputs[0]->addr);
-  auto output = reinterpret_cast<S *>(outputs[0]->addr);
-  auto input_data = *input;
-  const size_t window_length = static_cast<size_t>(*input);
-  const S output_one = static_cast<S>(1.);
-  if (input_data < 0) {
-    MS_EXCEPTION(ValueError) << "Input window_length must ≥ 0!";
+  auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
+  auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
+
+  if (*input < 0) {
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input window_length should be >= 0, but got " << *input;
   }
-  if (input_data == 1) {
-    *output = output_one;
+
+  auto window_length = static_cast<int64_t>(*input);
+  double pre_window_length = static_cast<double>(window_length);
+  const size_t OUTPUTISONE = 1.0;
+
+  ShapeVector out_shape = {window_length};
+  std::vector<TypeId> dtypes = {AnfAlgo::GetOutputDeviceDataType(node_, 0)};
+
+  if (*input == 1) {
+    *output = static_cast<T2>(OUTPUTISONE);
   } else {
     if (periodic_) {
-      input_data += 1;
+      window_length += 1;
     }
-    const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
-    const double x = static_cast<double>(input_data);
+    const size_t first_half_size = static_cast<size_t>((window_length - 1) / 2);
+    const double x = static_cast<double>(window_length);
     for (size_t i = 0; i <= first_half_size; i++) {
-      auto value = static_cast<S>((2. * i) / (x - 1.));
+      auto value = static_cast<T2>((2. * i) / (x - 1.));
       *(output + i) = value;
     }
-    for (size_t i = first_half_size + 1; i < window_length; i++) {
-      auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
+    for (size_t i = first_half_size + 1; i < pre_window_length; i++) {
+      auto value = static_cast<T2>(2. - (2. * i) / (x - 1.));
       *(output + i) = value;
     }
   }
-  // if (output_dtype == kNumberTypeFloat16) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, float16>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, float16>(inputs, outputs);
-  //   }
-  // } else if (output_dtype == kNumberTypeFloat32) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, float>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, float>(inputs, outputs);
-  //   }
-  // } else if (output_dtype == kNumberTypeFloat64) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, double>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, double>(inputs, outputs);
-  //   }
-  // }
+
+  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get());
   return true;
 }

+std::vector<std::pair<KernelAttr, BartlettWindowCpuKernelMod::BartlettWindowFunc>>
+  BartlettWindowCpuKernelMod::func_list_ = {
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, double>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, double>}};

+std::vector<KernelAttr> BartlettWindowCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, BartlettWindowFunc> &pair) { return pair.first; });
+  return support_list;
+}
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BartlettWindow, BartlettWindowCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
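For reference, the math that BartlettWindowKernelFunc implements above is the standard Bartlett (triangular) window: for an N-point window, w[i] = 2i/(N-1) on the rising half and w[i] = 2 - 2i/(N-1) on the falling half, and periodic=true evaluates an (N+1)-point window while keeping the first N samples. A minimal standalone sketch of the same computation (plain C++ over doubles; the helper name is illustrative, not part of the kernel):

// Bartlett window over plain doubles, mirroring the kernel's two-branch fill.
#include <cstdint>
#include <vector>

std::vector<double> BartlettWindow(int64_t window_length, bool periodic) {
  if (window_length <= 0) return {};
  if (window_length == 1) return {1.0};  // the kernel's *input == 1 fast path
  std::vector<double> out(static_cast<size_t>(window_length));
  // Periodic mode evaluates an (N + 1)-point window and keeps the first N samples.
  const int64_t n = periodic ? window_length + 1 : window_length;
  const double x = static_cast<double>(n);
  const size_t first_half_size = static_cast<size_t>((n - 1) / 2);
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = (i <= first_half_size) ? (2.0 * i) / (x - 1.0)         // rising edge
                                    : 2.0 - (2.0 * i) / (x - 1.0);  // falling edge
  }
  return out;
}

For example, BartlettWindow(4, false) yields {0, 2/3, 2/3, 0}, while BartlettWindow(4, true) yields {0, 0.5, 1.0, 0.5}.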
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,46 +13,45 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */

 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"

 namespace mindspore {
 namespace kernel {
-template <typename T, typename S>
-class BartlettWindowCpuKernelMod : public NativeCpuKernelMod {
+class BartlettWindowCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
   BartlettWindowCpuKernelMod() = default;
   ~BartlettWindowCpuKernelMod() override = default;

   void InitKernel(const CNodePtr &kernel_node) override;

   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, workspace, outputs);
+  }
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;

  private:
-  // template <typename T, typename S>
-  // void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  template <typename T, typename T2>
+  bool BartlettWindowKernelFunc(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> &workspace,
+                                const std::vector<kernel::AddressPtr> &outputs);
   bool periodic_{true};
-  TypeId output_dtype{kNumberTypeFloat32};
-  TypeId input_dtype{kTypeUnknown};
-  std::vector<size_t> input_shape;
+  using BartlettWindowFunc =
+    std::function<bool(BartlettWindowCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, BartlettWindowFunc>> func_list_;
+  BartlettWindowFunc kernel_func_;
+  ShapeVector input_shape;
-  CNodePtr node_wpt_;
 };

-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-                      BartlettWindowCpuKernelMod, int32_t, float);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-                      BartlettWindowCpuKernelMod, int32_t, float16);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-                      BartlettWindowCpuKernelMod, int32_t, double);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-                      BartlettWindowCpuKernelMod, int64_t, float);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-                      BartlettWindowCpuKernelMod, int64_t, float16);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-                      BartlettWindowCpuKernelMod, int64_t, double);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
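The pattern above replaces the old per-dtype MS_REG_CPU_KERNEL_T_S registrations with one factory registration plus a static func_list_ of (KernelAttr, typed member function) pairs that InitKernel matches once against the node's dtypes. A minimal self-contained sketch of that dispatch idea (simplified stand-in types, not the actual MindSpore KernelAttr/factory API):

// Attr is a toy stand-in for KernelAttr: one input dtype id, one output dtype id.
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

enum DTypeId { kI32, kI64, kF32, kF64 };

struct Attr {
  int in, out;
  bool operator==(const Attr &o) const { return in == o.in && out == o.out; }
};

class Kernel {
 public:
  using Func = std::function<bool(Kernel *)>;

  // Init-time matching: pick the typed body whose attr equals the node's attr,
  // so Launch() is a single indirect call with no per-launch dtype switch.
  bool Init(const Attr &node_attr) {
    for (const auto &entry : func_list_) {
      if (entry.first == node_attr) {
        kernel_func_ = entry.second;
        return true;
      }
    }
    return false;  // mirrors the "does not support this kernel data type" path
  }
  bool Launch() { return kernel_func_(this); }

 private:
  template <typename T_IN, typename T_OUT>
  bool KernelFunc() { /* typed computation would run here */ return true; }

  static const std::vector<std::pair<Attr, Func>> func_list_;
  Func kernel_func_;
};

const std::vector<std::pair<Attr, Kernel::Func>> Kernel::func_list_ = {
    {{kI32, kF32}, &Kernel::KernelFunc<int32_t, float>},
    {{kI64, kF64}, &Kernel::KernelFunc<int64_t, double>}};

GetOpSupport then falls out for free: it is just the projection of func_list_ onto its first elements, which is exactly what the std::transform calls in the .cc files compute.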
@@ -15,8 +15,10 @@
  */
 #include <cmath>
 #include <map>
-#include "backend/kernel_compiler/cpu/max_unpool2d_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <utility>
+#include <algorithm>
+#include "plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"

 namespace mindspore {
 namespace kernel {
@@ -29,28 +31,40 @@ constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 }  // namespace

-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+void MaxUnpool2DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool2D does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }

-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool2DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }

 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                                     const std::vector<kernel::AddressPtr> &,
-                                                     const std::vector<kernel::AddressPtr> &outputs) {
+bool MaxUnpool2DCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                           const std::vector<kernel::AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +80,14 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
   auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex1]->addr);
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   if (data_format_ == "NHWC") {
-    size_t num_batch = input_shape_[kInputIndex0];
-    size_t input_height = input_shape_[kInputIndex1];
-    size_t input_width = input_shape_[kInputIndex2];
-    size_t num_channels = input_shape_[kInputIndex3];
-    size_t oheight = output_shape_[kInputIndex1];
-    size_t owidth = output_shape_[kInputIndex2];
+    size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex2]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex3]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex1]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex2]);
     size_t length = num_batch * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -100,14 +114,14 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
       }
     }
   } else {
-    size_t num_batch = input_shape_[kInputIndex0];
-    size_t input_height = input_shape_[kInputIndex2];
-    size_t input_width = input_shape_[kInputIndex3];
-    size_t num_channels = input_shape_[kInputIndex1];
-    size_t oheight = output_shape_[kInputIndex2];
-    size_t owidth = output_shape_[kInputIndex3];
+    size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex3]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -139,5 +153,60 @@ bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
   }
   return true;
 }

+std::vector<std::pair<KernelAttr, MaxUnpool2DCpuKernelMod::MaxUnpool2DFunc>> MaxUnpool2DCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<double, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<double, int64_t>}};

+std::vector<KernelAttr> MaxUnpool2DCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
+  return support_list;
+}

+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2D, MaxUnpool2DCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
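The NHWC/NCHW branches above differ only in which axes hold height, width, and channels; the core of MaxUnpool2D itself is a zero-fill (OutPutInitKernel) followed by a scatter of each pooled value to the flat position that the forward max-pool recorded. A minimal single-channel sketch of that step (illustrative helper, simplified from the per-batch/per-channel loops; indices are assumed to be flat offsets into the output plane):

// One-channel max-unpool: zero the output plane, then scatter by saved offsets.
#include <cstddef>
#include <vector>

template <typename DataT, typename IndexT>
std::vector<DataT> UnpoolChannel(const std::vector<DataT> &pooled,
                                 const std::vector<IndexT> &indices,
                                 size_t out_height, size_t out_width) {
  std::vector<DataT> out(out_height * out_width, DataT(0));  // OutPutInitKernel step
  for (size_t i = 0; i < pooled.size(); ++i) {
    out[static_cast<size_t>(indices[i])] = pooled[i];  // scatter by forward argmax offset
  }
  return out;
}

For a 2x2 max-pool over a 4x4 plane, pooled holds 4 values and indices 4 offsets; unpooling writes those 4 values back and leaves the remaining 12 output cells zero.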
@@ -19,117 +19,43 @@
 #include <memory>
 #include <vector>
 #include <string>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"

 namespace mindspore {
 namespace kernel {
-template <typename DATA_T, typename INDICES_T>
-class MaxUnpool2DCPUKernel : public CPUKernel {
+class MaxUnpool2DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MaxUnpool2DCPUKernel() = default;
-  ~MaxUnpool2DCPUKernel() override = default;
+  MaxUnpool2DCpuKernelMod() = default;
+  ~MaxUnpool2DCpuKernelMod() override = default;

   void InitKernel(const CNodePtr &kernel_node) override;

-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  };
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;

  private:
+  template <typename DATA_T, typename INDICES_T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxUnpool2DFunc = std::function<bool(MaxUnpool2DCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                                             const std::vector<kernel::AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MaxUnpool2DFunc>> func_list_;
+  MaxUnpool2DFunc kernel_func_;
+
+  template <typename DATA_T>
   void OutPutInitKernel(DATA_T *rawOutput, size_t length);
   CNodeWeakPtr node_wpt_;
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> indices_shape_;
-  std::vector<size_t> output_shape_;
+  ShapeVector input_shape_;
+  ShapeVector indices_shape_;
+  ShapeVector output_shape_;
   std::string data_format_;
 };

-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool2DCPUKernel, uint8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool2DCPUKernel, uint8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool2DCPUKernel, uint16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool2DCPUKernel, uint16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool2DCPUKernel, uint32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool2DCPUKernel, uint32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool2DCPUKernel, uint64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool2DCPUKernel, uint64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool2DCPUKernel, int8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool2DCPUKernel, int8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool2DCPUKernel, int16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool2DCPUKernel, int16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  MaxUnpool2DCPUKernel, int32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  MaxUnpool2DCPUKernel, int32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
-  MaxUnpool2DCPUKernel, int64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-  MaxUnpool2DCPUKernel, int64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  MaxUnpool2DCPUKernel, float16, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-  MaxUnpool2DCPUKernel, float16, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  MaxUnpool2DCPUKernel, float, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  MaxUnpool2DCPUKernel, float, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-  MaxUnpool2DCPUKernel, double, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-  MaxUnpool2DCPUKernel, double, int64_t);
 }  // namespace kernel
 }  // namespace mindspore
@@ -15,8 +15,10 @@
  */
 #include <cmath>
 #include <map>
-#include "backend/kernel_compiler/cpu/max_unpool2d_grad_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <utility>
+#include <algorithm>
+#include "plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"

 namespace mindspore {
 namespace kernel {
@@ -28,29 +30,42 @@ constexpr size_t kInputIndex1 = 1;
 constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 }  // namespace
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+
+void MaxUnpool2DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool2DGrad does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }

-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool2DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }

 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                                         const std::vector<kernel::AddressPtr> &,
-                                                         const std::vector<kernel::AddressPtr> &outputs) {
+bool MaxUnpool2DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                               const std::vector<kernel::AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +81,14 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
   auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex2]->addr);
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   if (data_format_ == "NHWC") {
-    size_t num_batch = grads_shape_[kInputIndex0];
-    size_t oheight = grads_shape_[kInputIndex1];
-    size_t owidth = grads_shape_[kInputIndex2];
-    size_t num_channels = grads_shape_[kInputIndex3];
-    size_t iheight = output_shape_[kInputIndex1];
-    size_t iwidth = output_shape_[kInputIndex2];
+    size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
+    size_t oheight = LongToSize(grads_shape_[kInputIndex1]);
+    size_t owidth = LongToSize(grads_shape_[kInputIndex2]);
+    size_t num_channels = LongToSize(grads_shape_[kInputIndex3]);
+    size_t iheight = LongToSize(output_shape_[kInputIndex1]);
+    size_t iwidth = LongToSize(output_shape_[kInputIndex2]);
     size_t length = num_batch * iheight * iwidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * iwidth * iheight;
       size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -100,14 +115,14 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
       }
     }
   } else {
-    size_t num_batch = grads_shape_[kInputIndex0];
-    size_t oheight = grads_shape_[kInputIndex2];
-    size_t owidth = grads_shape_[kInputIndex3];
-    size_t num_channels = grads_shape_[kInputIndex1];
-    size_t iheight = output_shape_[kInputIndex2];
-    size_t iwidth = output_shape_[kInputIndex3];
+    size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
+    size_t oheight = LongToSize(grads_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(grads_shape_[kInputIndex3]);
+    size_t num_channels = LongToSize(grads_shape_[kInputIndex1]);
+    size_t iheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t iwidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * iheight * iwidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * iwidth * iheight;
       size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -139,5 +154,148 @@ bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
   }
   return true;
 }

+std::vector<std::pair<KernelAttr, MaxUnpool2DGradCpuKernelMod::MaxUnpool2DGradFunc>>
+  MaxUnpool2DGradCpuKernelMod::func_list_ = {
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt8)
+       .AddInputAttr(kNumberTypeUInt8)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeUInt8),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt8)
+       .AddInputAttr(kNumberTypeUInt8)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeUInt8),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt16)
+       .AddInputAttr(kNumberTypeUInt16)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeUInt16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt16)
+       .AddInputAttr(kNumberTypeUInt16)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeUInt16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt32)
+       .AddInputAttr(kNumberTypeUInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeUInt32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt32)
+       .AddInputAttr(kNumberTypeUInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeUInt32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt64)
+       .AddInputAttr(kNumberTypeUInt64)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeUInt64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeUInt64)
+       .AddInputAttr(kNumberTypeUInt64)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeUInt64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt8)
+       .AddInputAttr(kNumberTypeInt8)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt8),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt8)
+       .AddInputAttr(kNumberTypeInt8)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt8),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt16)
+       .AddInputAttr(kNumberTypeInt16)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt16)
+       .AddInputAttr(kNumberTypeInt16)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeInt64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeInt64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeFloat16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeFloat16)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeFloat16),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeFloat32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeFloat32)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeFloat32),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int64_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat64)
+       .AddInputAttr(kNumberTypeFloat64)
+       .AddInputAttr(kNumberTypeInt32)
+       .AddOutputAttr(kNumberTypeFloat64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int32_t>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeFloat64)
+       .AddInputAttr(kNumberTypeFloat64)
+       .AddInputAttr(kNumberTypeInt64)
+       .AddOutputAttr(kNumberTypeFloat64),
+     &MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int64_t>}};

+std::vector<KernelAttr> MaxUnpool2DGradCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
+  return support_list;
+}

+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2DGrad, MaxUnpool2DGradCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
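The grad kernel walks the same batch/channel layout as the forward op but in the opposite direction: where MaxUnpool2D scatters values out to saved offsets, MaxUnpool2DGrad gathers the incoming gradient back from those offsets. A single-channel sketch of that relationship (assumed semantics, simplified from the loops above; the helper name is illustrative):

// One-channel unpool gradient: for each pooled position, read the output
// gradient at the offset the forward pass scattered to.
#include <cstddef>
#include <vector>

template <typename DataT, typename IndexT>
std::vector<DataT> UnpoolGradChannel(const std::vector<DataT> &grad_out,  // w.r.t. unpooled output
                                     const std::vector<IndexT> &indices) {
  std::vector<DataT> grad_in(indices.size(), DataT(0));
  for (size_t i = 0; i < indices.size(); ++i) {
    grad_in[i] = grad_out[static_cast<size_t>(indices[i])];  // gather, not scatter
  }
  return grad_in;
}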
@@ -19,187 +19,45 @@
 #include <memory>
 #include <vector>
 #include <string>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"

 namespace mindspore {
 namespace kernel {
-template <typename DATA_T, typename INDICES_T>
-class MaxUnpool2DGradCPUKernel : public CPUKernel {
+class MaxUnpool2DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MaxUnpool2DGradCPUKernel() = default;
-  ~MaxUnpool2DGradCPUKernel() override = default;
+  MaxUnpool2DGradCpuKernelMod() = default;
+  ~MaxUnpool2DGradCpuKernelMod() override = default;

   void InitKernel(const CNodePtr &kernel_node) override;

-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  };
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;

  private:
+  template <typename DATA_T, typename INDICES_T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxUnpool2DGradFunc = std::function<bool(MaxUnpool2DGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                                                 const std::vector<kernel::AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MaxUnpool2DGradFunc>> func_list_;
+  MaxUnpool2DGradFunc kernel_func_;
+
+  template <typename DATA_T>
   void OutPutInitKernel(DATA_T *rawOutput, size_t length);
   CNodeWeakPtr node_wpt_;
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> grads_shape_;
-  std::vector<size_t> indices_shape_;
-  std::vector<size_t> output_shape_;
+  ShapeVector input_shape_;
+  ShapeVector grads_shape_;
+  ShapeVector indices_shape_;
+  ShapeVector output_shape_;
   std::string data_format_;
 };

-MS_REG_CPU_KERNEL_T_S(MaxUnpool2D,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt8),
-                      MaxUnpool2DGradCPUKernel, uint8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt8),
-                      MaxUnpool2DGradCPUKernel, uint8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt16),
-                      MaxUnpool2DGradCPUKernel, uint16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt16),
-                      MaxUnpool2DGradCPUKernel, uint16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt32),
-                      MaxUnpool2DGradCPUKernel, uint32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt32),
-                      MaxUnpool2DGradCPUKernel, uint32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt64),
-                      MaxUnpool2DGradCPUKernel, uint64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt64),
-                      MaxUnpool2DGradCPUKernel, uint64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt8),
-                      MaxUnpool2DGradCPUKernel, int8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt8),
-                      MaxUnpool2DGradCPUKernel, int8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt16),
-                      MaxUnpool2DGradCPUKernel, int16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt16),
-                      MaxUnpool2DGradCPUKernel, int16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      MaxUnpool2DGradCPUKernel, int32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      MaxUnpool2DGradCPUKernel, int32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt64),
-                      MaxUnpool2DGradCPUKernel, int64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt64),
-                      MaxUnpool2DGradCPUKernel, int64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MaxUnpool2DGradCPUKernel, float16, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MaxUnpool2DGradCPUKernel, float16, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MaxUnpool2DGradCPUKernel, float, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MaxUnpool2DGradCPUKernel, float, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      MaxUnpool2DGradCPUKernel, double, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      MaxUnpool2DGradCPUKernel, double, int64_t);
 }  // namespace kernel
 }  // namespace mindspore

-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaxUnpool2DGradGRAD_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAXUNPOOL2DGRAD_CPU_KERNEL_H_
@@ -15,8 +15,10 @@
  */
 #include <cmath>
 #include <map>
-#include "backend/kernel_compiler/cpu/max_unpool3d_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <utility>
+#include <algorithm>
+#include "plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"

 namespace mindspore {
 namespace kernel {
@@ -29,28 +31,41 @@ constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 constexpr size_t kInputIndex4 = 4;
 }  // namespace
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+
+void MaxUnpool3DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool3D does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }

-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool3DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }

 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                                     const std::vector<kernel::AddressPtr> &,
-                                                     const std::vector<kernel::AddressPtr> &outputs) {
+bool MaxUnpool3DCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                           const std::vector<kernel::AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,15 +81,15 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   size_t num_batch = input_shape_[kInputIndex0];
   if (data_format_ == "NDHWC") {
-    size_t input_depth = input_shape_[kInputIndex1];
-    size_t input_height = input_shape_[kInputIndex2];
-    size_t input_width = input_shape_[kInputIndex3];
-    size_t num_channels = input_shape_[kInputIndex4];
-    size_t odepth = output_shape_[kInputIndex1];
-    size_t oheight = output_shape_[kInputIndex2];
-    size_t owidth = output_shape_[kInputIndex3];
+    size_t input_depth = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex3]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex4]);
+    size_t odepth = LongToSize(output_shape_[kInputIndex1]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * odepth * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -105,15 +120,15 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
       }
     }
   } else {
-    size_t input_depth = input_shape_[kInputIndex2];
-    size_t input_height = input_shape_[kInputIndex3];
-    size_t input_width = input_shape_[kInputIndex4];
-    size_t num_channels = input_shape_[kInputIndex1];
-    size_t odepth = output_shape_[kInputIndex2];
-    size_t oheight = output_shape_[kInputIndex3];
-    size_t owidth = output_shape_[kInputIndex4];
+    size_t input_depth = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex3]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex4]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
+    size_t odepth = LongToSize(output_shape_[kInputIndex2]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex3]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex4]);
     size_t length = num_batch * odepth * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -148,5 +163,60 @@ bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::A
   }
   return true;
 }

+std::vector<std::pair<KernelAttr, MaxUnpool3DCpuKernelMod::MaxUnpool3DFunc>> MaxUnpool3DCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&MaxUnpool3DCpuKernelMod::LaunchKernel<double, int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> MaxUnpool3DCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3D, MaxUnpool3DCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
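The refactor above replaces one templated kernel class per type pair with a single kernel mod that dispatches through a static table of (KernelAttr, member-function-pointer) entries: InitKernel matches the node's attr against the table once, and Launch just forwards to the chosen instantiation. A minimal standalone sketch of that pattern follows; the names Attr, Mod, Init, and the string keys are illustrative stand-ins, not MindSpore API.

#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Stand-in for KernelAttr: a plain type-signature key.
using Attr = std::string;

class Mod {
 public:
  // InitKernel-time: pick the typed implementation once.
  void Init(const Attr &attr) {
    for (const auto &entry : func_list_) {
      if (entry.first == attr) {
        kernel_func_ = entry.second;
        return;
      }
    }
    throw std::runtime_error("Mod does not support this kernel data type: " + attr);
  }
  // Launch stays untyped and just forwards to the selected instantiation.
  bool Launch() { return kernel_func_(this); }

 private:
  template <typename DATA_T, typename INDICES_T>
  bool LaunchKernel() {
    std::cout << "typed kernel body runs here\n";
    return true;
  }
  using Func = std::function<bool(Mod *)>;
  static std::vector<std::pair<Attr, Func>> func_list_;
  Func kernel_func_;
};

// One table row per supported (data, indices) type pair, mirroring func_list_ above.
std::vector<std::pair<Attr, Mod::Func>> Mod::func_list_ = {
  {"float32,int32", &Mod::LaunchKernel<float, int32_t>},
  {"float32,int64", &Mod::LaunchKernel<float, int64_t>},
};

int main() {
  Mod m;
  m.Init("float32,int64");
  return m.Launch() ? 0 : 1;
}

The design point: adding a new supported type pair now means adding one table row, not declaring another class template instantiation and registration macro.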
@@ -19,117 +19,43 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool3DCPUKernel : public CPUKernel {
class MaxUnpool3DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MaxUnpool3DCPUKernel() = default;
  ~MaxUnpool3DCPUKernel() override = default;
  MaxUnpool3DCpuKernelMod() = default;
  ~MaxUnpool3DCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  };

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename DATA_T, typename INDICES_T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  using MaxUnpool3DFunc = std::function<bool(MaxUnpool3DCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                             const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MaxUnpool3DFunc>> func_list_;
  MaxUnpool3DFunc kernel_func_;

  template <typename DATA_T>
  void OutPutInitKernel(DATA_T *rawOutput, size_t length);
  CNodeWeakPtr node_wpt_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> indices_shape_;
  std::vector<size_t> output_shape_;
  ShapeVector input_shape_;
  ShapeVector indices_shape_;
  ShapeVector output_shape_;
  std::string data_format_;
};

MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
                      MaxUnpool3DCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
                      MaxUnpool3DCPUKernel, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
                      MaxUnpool3DCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
                      MaxUnpool3DCPUKernel, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
                      MaxUnpool3DCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
                      MaxUnpool3DCPUKernel, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
                      MaxUnpool3DCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
                      MaxUnpool3DCPUKernel, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
                      MaxUnpool3DCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
                      MaxUnpool3DCPUKernel, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
                      MaxUnpool3DCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
                      MaxUnpool3DCPUKernel, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      MaxUnpool3DCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
                      MaxUnpool3DCPUKernel, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
                      MaxUnpool3DCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      MaxUnpool3DCPUKernel, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
                      MaxUnpool3DCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
                      MaxUnpool3DCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                      MaxUnpool3DCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
                      MaxUnpool3DCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
                      MaxUnpool3DCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3D,
                      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
                      MaxUnpool3DCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore
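For context on the registration change in this header: the old MS_REG_CPU_KERNEL_T_S lines bound one template instantiation per (data, index) type pair, which is why the class itself had to be a template; the new MS_KERNEL_FACTORY_REG registers the single untemplated mod once under the op name and leaves the per-type choice to func_list_. A rough sketch of such a name-to-constructor factory follows; the Factory, Register, and REG_KERNEL names are hypothetical, not the MindSpore implementation.

#include <functional>
#include <map>
#include <memory>
#include <string>

struct KernelMod { virtual ~KernelMod() = default; };

class Factory {
 public:
  static Factory &Instance() {
    static Factory f;  // one process-wide registry
    return f;
  }
  void Register(const std::string &name, std::function<std::unique_ptr<KernelMod>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::unique_ptr<KernelMod> Create(const std::string &name) {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<KernelMod>()>> creators_;
};

// A registration macro expands to a static registrar that runs before main().
#define REG_KERNEL(OP, CLASS)                                                      \
  static const bool g_##OP##_reg = [] {                                            \
    Factory::Instance().Register(#OP, [] { return std::make_unique<CLASS>(); });   \
    return true;                                                                   \
  }()

// Usage sketch:
//   struct MyMod : KernelMod {};
//   REG_KERNEL(MyOp, MyMod);
//   auto mod = Factory::Instance().Create("MyOp");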
@@ -15,8 +15,10 @@
 */
#include <cmath>
#include <map>
#include "backend/kernel_compiler/cpu/max_unpool3d_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {

@@ -30,29 +32,41 @@ constexpr size_t kInputIndex3 = 3;
constexpr size_t kInputIndex4 = 4;
} // namespace

template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
void MaxUnpool3DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  node_wpt_ = kernel_node;
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
  grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
  indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
  data_format_ = AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);
  data_format_ = common::AnfAlgo::GetNodeAttr<string>(kernel_node, FORMAT);

  if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) {
    return;
  }

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, MaxUnpool3DGradFunc> &pair) { return pair.first; });
  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
  if (!is_match) {
    MS_LOG(EXCEPTION) << "MaxUnpool3DGrad does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[index].second;
}

template <typename DATA_T, typename INDICES_T>
void MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
template <typename DATA_T>
void MaxUnpool3DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
  for (size_t s = 0; s < length; s++) {
    raw_output[s] = (DATA_T)0;
  }
}

template <typename DATA_T, typename INDICES_T>
bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                                         const std::vector<kernel::AddressPtr> &,
                                                         const std::vector<kernel::AddressPtr> &outputs) {
bool MaxUnpool3DGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  auto node = node_wpt_.lock();
  if (!node) {
    MS_LOG(EXCEPTION) << "node_wpt_ is expired.";

@@ -66,17 +80,17 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
  auto *raw_grads = reinterpret_cast<DATA_T *>(inputs[kInputIndex1]->addr);
  auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex2]->addr);
  auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
  size_t num_batch = grads_shape_[kInputIndex0];
  auto num_batch = LongToSize(grads_shape_[kInputIndex0]);
  if (data_format_ == "NDHWC") {
    size_t odepth = grads_shape_[kInputIndex1];
    size_t oheight = grads_shape_[kInputIndex2];
    size_t owidth = grads_shape_[kInputIndex3];
    size_t num_channels = grads_shape_[kInputIndex4];
    size_t idepth = output_shape_[kInputIndex1];
    size_t iheight = output_shape_[kInputIndex2];
    size_t iwidth = output_shape_[kInputIndex3];
    size_t odepth = LongToSize(grads_shape_[kInputIndex1]);
    size_t oheight = LongToSize(grads_shape_[kInputIndex2]);
    size_t owidth = LongToSize(grads_shape_[kInputIndex3]);
    size_t num_channels = LongToSize(grads_shape_[kInputIndex4]);
    size_t idepth = LongToSize(output_shape_[kInputIndex1]);
    size_t iheight = LongToSize(output_shape_[kInputIndex2]);
    size_t iwidth = LongToSize(output_shape_[kInputIndex3]);
    size_t length = num_batch * iheight * iwidth * idepth * num_channels;
    OutPutInitKernel(raw_output, length);
    OutPutInitKernel<DATA_T>(raw_output, length);
    for (size_t n = 0; n < num_batch; n++) {
      size_t noutput_offset = n * num_channels * iwidth * iheight * idepth;
      size_t n_grads_offset = n * num_channels * owidth * oheight * odepth;

@@ -106,15 +120,15 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
      }
    }
  } else {
    size_t odepth = grads_shape_[kInputIndex2];
    size_t oheight = grads_shape_[kInputIndex3];
    size_t owidth = grads_shape_[kInputIndex4];
    size_t num_channels = grads_shape_[kInputIndex1];
    size_t idepth = output_shape_[kInputIndex2];
    size_t iheight = output_shape_[kInputIndex3];
    size_t iwidth = output_shape_[kInputIndex4];
    size_t odepth = LongToSize(grads_shape_[kInputIndex2]);
    size_t oheight = LongToSize(grads_shape_[kInputIndex3]);
    size_t owidth = LongToSize(grads_shape_[kInputIndex4]);
    size_t num_channels = LongToSize(grads_shape_[kInputIndex1]);
    size_t idepth = LongToSize(output_shape_[kInputIndex2]);
    size_t iheight = LongToSize(output_shape_[kInputIndex3]);
    size_t iwidth = LongToSize(output_shape_[kInputIndex4]);
    size_t length = num_batch * iheight * iwidth * idepth * num_channels;
    OutPutInitKernel(raw_output, length);
    OutPutInitKernel<DATA_T>(raw_output, length);
    for (size_t n = 0; n < num_batch; n++) {
      size_t noutput_offset = n * num_channels * iwidth * iheight * idepth;
      size_t n_grads_offset = n * num_channels * owidth * oheight * odepth;

@@ -149,5 +163,148 @@ bool MaxUnpool3DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<kerne
  }
  return true;
}

std::vector<std::pair<KernelAttr, MaxUnpool3DGradCpuKernelMod::MaxUnpool3DGradFunc>>
  MaxUnpool3DGradCpuKernelMod::func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<float16, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<float16, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     &MaxUnpool3DGradCpuKernelMod::LaunchKernel<double, int64_t>}};

std::vector<KernelAttr> MaxUnpool3DGradCpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, MaxUnpool3DGradFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3DGrad, MaxUnpool3DGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
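The recurring LongToSize edits above follow from the shape members changing from std::vector<size_t> to ShapeVector, whose elements are int64_t so that dynamic shapes can carry negative placeholders; each extent is now converted explicitly before being used as a size_t. A sketch of what such a checked conversion can look like (illustrative only; the real helper lives in MindSpore's conversion utilities and its exact behavior is not shown in this diff):

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Convert an int64_t shape extent to size_t, rejecting negative values
// such as the placeholder used for not-yet-known dynamic dimensions.
inline size_t LongToSize(int64_t v) {
  if (v < 0) {
    throw std::out_of_range("negative extent cannot size a buffer");
  }
  return static_cast<size_t>(v);
}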
@@ -19,196 +19,44 @@
#include <memory>
#include <vector>
#include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
template <typename DATA_T, typename INDICES_T>
class MaxUnpool3DGradCPUKernel : public CPUKernel {
class MaxUnpool3DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MaxUnpool3DGradCPUKernel() = default;
  ~MaxUnpool3DGradCPUKernel() override = default;
  MaxUnpool3DGradCpuKernelMod() = default;
  ~MaxUnpool3DGradCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  };

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename DATA_T, typename INDICES_T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  using MaxUnpool3DGradFunc = std::function<bool(MaxUnpool3DGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                                 const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MaxUnpool3DGradFunc>> func_list_;
  MaxUnpool3DGradFunc kernel_func_;

  template <typename DATA_T>
  void OutPutInitKernel(DATA_T *rawOutput, size_t length);
  CNodeWeakPtr node_wpt_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> grads_shape_;
  std::vector<size_t> indices_shape_;
  std::vector<size_t> output_shape_;
  ShapeVector input_shape_;
  ShapeVector grads_shape_;
  ShapeVector indices_shape_;
  ShapeVector output_shape_;
  std::string data_format_;
};

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
                      MaxUnpool3DGradCPUKernel, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
                      MaxUnpool3DGradCPUKernel, uint8_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
                      MaxUnpool3DGradCPUKernel, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
                      MaxUnpool3DGradCPUKernel, uint16_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
                      MaxUnpool3DGradCPUKernel, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
                      MaxUnpool3DGradCPUKernel, uint32_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
                      MaxUnpool3DGradCPUKernel, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
                      MaxUnpool3DGradCPUKernel, uint64_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
                      MaxUnpool3DGradCPUKernel, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
                      MaxUnpool3DGradCPUKernel, int8_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
                      MaxUnpool3DGradCPUKernel, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
                      MaxUnpool3DGradCPUKernel, int16_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      MaxUnpool3DGradCPUKernel, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
                      MaxUnpool3DGradCPUKernel, int32_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
                      MaxUnpool3DGradCPUKernel, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      MaxUnpool3DGradCPUKernel, int64_t, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
                      MaxUnpool3DGradCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
                      MaxUnpool3DGradCPUKernel, float16, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                      MaxUnpool3DGradCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
                      MaxUnpool3DGradCPUKernel, float, int64_t);

MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
                      MaxUnpool3DGradCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad,
                      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
                      MaxUnpool3DGradCPUKernel, double, int64_t);
} // namespace kernel
} // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -14,8 +14,8 @@
 * specific language governing permissions and limitations under the License.
 */

#include "backend/kernel_compiler/cpu/multi_margin_loss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {

@@ -23,24 +23,30 @@ namespace {
constexpr size_t kMultiMarginLossInputNumWithWeight = 3;
constexpr size_t kMultiMarginLossInputNumWithoutWeight = 2;
constexpr size_t kMultiMarginLossOutputsNum = 1;
const size_t kZero = 0;
const size_t kOne = 1;
const size_t kTwo = 2;
constexpr char kKernelName[] = "MultiMarginLoss";
} // namespace

void MultiMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void MultiMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  batch_size = x_shape[0];
  dims = x_shape[1];
  reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
  p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
  margin = AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero);
  if (IsDynamic({x_shape})) {
    return;
  }
  batch_size = LongToSize(x_shape[kZero]);
  dims = LongToSize(x_shape[kOne]);
  reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
  p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
  margin = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
  input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
}

bool MultiMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
bool MultiMarginLossCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                         const std::vector<kernel::AddressPtr> &,
                                         const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernelFP16<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {

@@ -54,10 +60,10 @@ bool MultiMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
}

template <typename T>
void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[1]->addr);
void MultiMarginLossCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  auto x_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[kOne]->addr);
  for (size_t i = 0; i < batch_size; i++) {
    if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
      MS_EXCEPTION(ValueError) << "Target out of range.";

@@ -66,9 +72,9 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
  T *weight_addr = nullptr;
  bool weight_defined_ = (input_num == 3);
  if (weight_defined_) {
    weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
    weight_addr = reinterpret_cast<T *>(inputs[kTwo]->addr);
  }
  auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto y_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
  std::vector<T> tmp_loss(batch_size);
  auto task = [&](size_t start, size_t end) {
    start *= dims;

@@ -117,10 +123,10 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
}

template <typename T>
void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
                                                const std::vector<kernel::AddressPtr> &outputs) {
  auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[1]->addr);
void MultiMarginLossCPUKernelMod::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
                                                   const std::vector<kernel::AddressPtr> &outputs) {
  auto x_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[kOne]->addr);
  for (size_t i = 0; i < batch_size; i++) {
    if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
      MS_EXCEPTION(ValueError) << "Target out of range.";

@@ -129,9 +135,9 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::Addres
  T *weight_addr = nullptr;
  bool weight_defined_ = (input_num == 3);
  if (weight_defined_) {
    weight_addr = reinterpret_cast<T *>(inputs[2]->addr);
    weight_addr = reinterpret_cast<T *>(inputs[kTwo]->addr);
  }
  auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto y_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
  std::vector<float> tmp_loss(batch_size);
  auto task = [&](size_t start, size_t end) {
    start *= dims;

@@ -180,13 +186,15 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector<kernel::Addres
  }
}

void MultiMarginLossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
void MultiMarginLossCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != kMultiMarginLossInputNumWithoutWeight && input_num != kMultiMarginLossInputNumWithWeight) {
    MS_LOG(EXCEPTION) << "Invalid input numbers, expect input number 2 or 3, but actual input number " << input_num;
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kMultiMarginLossOutputsNum, kKernelName);
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MultiMarginLoss, MultiMarginLossCPUKernelMod);
} // namespace kernel
} // namespace mindspore
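For reference, the quantity each LaunchKernel above accumulates per sample is the multi-class hinge (multi-margin) loss over the `dims` class scores. A plain scalar sketch of the unreduced per-sample value, assuming the same semantics as the kernel (p restricted to 1 or 2, optional class weight applied via the target class):

#include <cstddef>
#include <cstdint>
#include <vector>

// Unreduced multi-margin loss for one sample over dims classes:
//   loss = (1 / dims) * sum_{i != target} w * max(0, margin - x[target] + x[i])^p
// where w is weight[target] when a class-weight vector is supplied, else 1.
double MultiMarginLossSample(const std::vector<double> &x, int64_t target, double margin, int64_t p,
                             const std::vector<double> &weight) {
  const size_t dims = x.size();
  const size_t t = static_cast<size_t>(target);
  double loss = 0.0;
  for (size_t i = 0; i < dims; ++i) {
    if (i == t) {
      continue;  // the target class contributes no hinge term of its own
    }
    const double z = margin - x[t] + x[i];
    if (z > 0.0) {
      double term = (p == 1) ? z : z * z;  // the kernel only supports p in {1, 2}
      if (!weight.empty()) {
        term *= weight[t];
      }
      loss += term;
    }
  }
  return loss / static_cast<double>(dims);
}

Under "mean" reduction the batch of per-sample values is further averaged, which is where the extra 1/batch_size factor in the kernel's `weights` term comes from.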
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -20,29 +20,52 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MultiMarginLossCPUKernel : public CPUKernel {
class MultiMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MultiMarginLossCPUKernel() = default;
  MultiMarginLossCPUKernelMod() = default;

  ~MultiMarginLossCPUKernel() override = default;
  ~MultiMarginLossCPUKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64)};
    return support_list;
  }

 private:
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  template <typename T>
  void LaunchKernelFP16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  void CheckParam(const CNodePtr &kernel_node);
  size_t batch_size = 2;
  size_t dims = 1;

@@ -52,45 +75,6 @@ class MultiMarginLossCPUKernel : public CPUKernel {
  size_t input_num = 1;
  TypeId dtype_{kTypeUnknown};
};

MS_REG_CPU_KERNEL(MultiMarginLoss,
                  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                  MultiMarginLossCPUKernel);

MS_REG_CPU_KERNEL(MultiMarginLoss,
                  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  MultiMarginLossCPUKernel);

MS_REG_CPU_KERNEL(MultiMarginLoss,
                  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                  MultiMarginLossCPUKernel);

MS_REG_CPU_KERNEL(
  MultiMarginLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
  MultiMarginLossCPUKernel);

MS_REG_CPU_KERNEL(
  MultiMarginLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
  MultiMarginLossCPUKernel);

MS_REG_CPU_KERNEL(
  MultiMarginLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
  MultiMarginLossCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_CPU_KERNEL_H_
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -14,8 +14,8 @@
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/multi_margin_loss_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {

@@ -23,25 +23,32 @@ namespace {
constexpr size_t kMultiMarginLossGradInputNumWithWeight = 4;
constexpr size_t kMultiMarginLossGradInputNumWithoutWeight = 3;
constexpr size_t kMultiMarginLossGradOutputsNum = 1;
const size_t kZero = 0;
const size_t kOne = 1;
const size_t kTwo = 2;
const size_t kThree = 3;
constexpr char kKernelName[] = "MultiMarginLossGrad";
} // namespace

void MultiMarginLossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void MultiMarginLossGradCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  batch_size = x_shape[0];
  dims = x_shape[1];
  reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
  p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
  margin = AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
  y_grad_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size();
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne);
  if (IsDynamic({x_shape})) {
    return;
  }
  batch_size = LongToSize(x_shape[kZero]);
  dims = LongToSize(x_shape[kOne]);
  reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, REDUCTION);
  p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
  margin = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "margin");
  y_grad_dims = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero).size();
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
  input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
}

bool MultiMarginLossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                          const std::vector<kernel::AddressPtr> &,
                                          const std::vector<kernel::AddressPtr> &outputs) {
bool MultiMarginLossGradCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                             const std::vector<kernel::AddressPtr> &,
                                             const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernelFP16<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {

@@ -55,11 +62,11 @@ bool MultiMarginLossGradCPUKernel::Launch(const std::vector<kernel::AddressPtr>
}

template <typename T>
void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                const std::vector<kernel::AddressPtr> &outputs) {
  auto y_grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
void MultiMarginLossGradCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                   const std::vector<kernel::AddressPtr> &outputs) {
  auto y_grad_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
  auto x_addr = reinterpret_cast<T *>(inputs[kOne]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[kTwo]->addr);
  for (size_t i = 0; i < batch_size; i++) {
    if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
      MS_EXCEPTION(ValueError) << "Target out of range.";

@@ -68,12 +75,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
  T *weight_addr = nullptr;
  bool weight_defined_ = (input_num == 4);
  if (weight_defined_) {
    weight_addr = reinterpret_cast<T *>(inputs[3]->addr);
    weight_addr = reinterpret_cast<T *>(inputs[kThree]->addr);
  }
  auto x_grad_addr = reinterpret_cast<T *>(outputs[0]->addr);
  T weights;
  weights = reduction == MEAN ? (static_cast<T>(1) / (static_cast<T>(dims) * static_cast<T>(batch_size)))
                              : (static_cast<T>(1) / static_cast<T>(dims));
  auto x_grad_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
  auto task = [&](size_t start, size_t end) {
    start *= dims;
    end *= dims;

@@ -91,6 +95,8 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
        continue;
      }
      if (calc_data[d] > static_cast<T>(0)) {
        auto weights = reduction == MEAN ? (static_cast<T>(1) / (static_cast<T>(dims) * static_cast<T>(batch_size)))
                                         : (static_cast<T>(1) / static_cast<T>(dims));
        calc_data[d] = (p == 1) ? weights : static_cast<T>(2) * weights * calc_data[d];
        if (weight_defined_) {
          calc_data[d] *= static_cast<T>(weight_addr[target_idx]);

@@ -122,11 +128,11 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector<kernel::Addres
}

template <typename T>
void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
                                                    const std::vector<kernel::AddressPtr> &outputs) {
  auto y_grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto x_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[2]->addr);
void MultiMarginLossGradCPUKernelMod::LaunchKernelFP16(const std::vector<kernel::AddressPtr> &inputs,
                                                       const std::vector<kernel::AddressPtr> &outputs) {
  auto y_grad_addr = reinterpret_cast<T *>(inputs[kZero]->addr);
  auto x_addr = reinterpret_cast<T *>(inputs[kOne]->addr);
  auto target_addr = reinterpret_cast<int64_t *>(inputs[kTwo]->addr);
  for (size_t i = 0; i < batch_size; i++) {
    if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) {
      MS_EXCEPTION(ValueError) << "Target out of range.";

@@ -135,12 +141,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
  T *weight_addr = nullptr;
  bool weight_defined_ = (input_num == 4);
  if (weight_defined_) {
    weight_addr = reinterpret_cast<T *>(inputs[3]->addr);
    weight_addr = reinterpret_cast<T *>(inputs[kThree]->addr);
  }
  auto x_grad_addr = reinterpret_cast<T *>(outputs[0]->addr);
  float weights;
  weights = reduction == MEAN ? (static_cast<float>(1) / (static_cast<float>(dims) * static_cast<float>(batch_size)))
                              : (static_cast<float>(1) / static_cast<float>(dims));
  auto x_grad_addr = reinterpret_cast<T *>(outputs[kZero]->addr);
  auto task = [&](size_t start, size_t end) {
    start *= dims;
    end *= dims;

@@ -158,6 +161,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
        continue;
      }
      if (calc_data[d] > static_cast<float>(0)) {
        auto weights = reduction == MEAN
                         ? (static_cast<float>(1) / (static_cast<float>(dims) * static_cast<float>(batch_size)))
                         : (static_cast<float>(1) / static_cast<float>(dims));
        calc_data[d] = (p == 1) ? weights : static_cast<float>(2) * weights * calc_data[d];
        if (weight_defined_) {
          calc_data[d] *= static_cast<float>(weight_addr[target_idx]);

@@ -189,13 +195,15 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector<kernel::Ad
  }
}

void MultiMarginLossGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
void MultiMarginLossGradCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != kMultiMarginLossGradInputNumWithoutWeight && input_num != kMultiMarginLossGradInputNumWithWeight) {
    MS_LOG(EXCEPTION) << "Invalid input numbers, expect input number 3 or 4, but actual input number " << input_num;
  }
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kMultiMarginLossGradOutputsNum, kKernelName);
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MultiMarginLossGrad, MultiMarginLossGradCPUKernelMod);
} // namespace kernel
} // namespace mindspore
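The backward kernel above distributes d(loss)/dx as the hinge structure implies: every non-target class with an active margin term receives a positive contribution, and the target class accumulates the negated sum; the `weights` factor (1/dims, or 1/(dims * batch_size) under "mean" reduction) and the 2*z scaling for p = 2 match the forward definition. A scalar sketch under those assumptions:

#include <cstddef>
#include <cstdint>
#include <vector>

// Gradient of the per-sample multi-margin loss w.r.t. x, scaled by the incoming
// gradient y_grad. w stands for the factor the kernel folds into `weights`.
std::vector<double> MultiMarginLossGradSample(const std::vector<double> &x, int64_t target,
                                              double margin, int64_t p, double y_grad, double w) {
  const size_t t = static_cast<size_t>(target);
  std::vector<double> dx(x.size(), 0.0);
  for (size_t i = 0; i < x.size(); ++i) {
    if (i == t) {
      continue;
    }
    const double z = margin - x[t] + x[i];
    if (z > 0.0) {
      const double g = ((p == 1) ? w : 2.0 * w * z) * y_grad;  // w * d(z^p)/dz, p in {1, 2}
      dx[i] += g;  // z rises with x[i]
      dx[t] -= g;  // and falls with x[target]
    }
  }
  return dx;
}

Note also the subtle change visible in the diff: the old code computed `weights` once per launch, while the new code computes it inside the active-margin branch, leaving the result unchanged.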
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,29 +20,66 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-class MultiMarginLossGradCPUKernel : public CPUKernel {
+class MultiMarginLossGradCPUKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MultiMarginLossGradCPUKernel() = default;
+  MultiMarginLossGradCPUKernelMod() = default;
 
-  ~MultiMarginLossGradCPUKernel() override = default;
+  ~MultiMarginLossGradCPUKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
 
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override {
+    static std::vector<KernelAttr> support_list = {KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat16)
+                                                     .AddInputAttr(kNumberTypeFloat16)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeFloat16)
+                                                     .AddOutputAttr(kNumberTypeFloat16),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddOutputAttr(kNumberTypeFloat64),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat16)
+                                                     .AddInputAttr(kNumberTypeFloat16)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat16),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddOutputAttr(kNumberTypeFloat64)};
+    return support_list;
+  }
+
  private:
   template <typename T>
   void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 
   template <typename T>
   void LaunchKernelFP16(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 
  private:
   void CheckParam(const CNodePtr &kernel_node);
   size_t batch_size = 2;
   size_t dims = 1;
@@ -53,57 +90,6 @@ class MultiMarginLossGradCPUKernel : public CPUKernel {
   size_t y_grad_dims = 1;
   TypeId dtype_{kTypeUnknown};
 };
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat16)
-                    .AddInputAttr(kNumberTypeFloat16)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeFloat16)
-                    .AddOutputAttr(kNumberTypeFloat16),
-                  MultiMarginLossGradCPUKernel);
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  MultiMarginLossGradCPUKernel);
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddOutputAttr(kNumberTypeFloat64),
-                  MultiMarginLossGradCPUKernel);
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat16)
-                    .AddInputAttr(kNumberTypeFloat16)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddOutputAttr(kNumberTypeFloat16),
-                  MultiMarginLossGradCPUKernel);
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  MultiMarginLossGradCPUKernel);
-
-MS_REG_CPU_KERNEL(MultiMarginLossGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddOutputAttr(kNumberTypeFloat64),
-                  MultiMarginLossGradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_GRAD_CPU_KERNEL_H_
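The header above is where the refactor's registration change is easiest to see: the per-dtype MS_REG_CPU_KERNEL macros are deleted and their KernelAttr lists move into a GetOpSupport() override, with a single MS_KERNEL_FACTORY_REG call per op. Below is a minimal sketch of that pattern, assuming the hypothetical names MyOp and MyOpCpuKernelMod; the constructs themselves (DeprecatedNativeCpuKernelMod, KernelAttr, MS_KERNEL_FACTORY_REG) are taken from the diff.

// Sketch only: MyOp / MyOpCpuKernelMod are hypothetical names for illustration.
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MyOpCpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MyOpCpuKernelMod() = default;
  ~MyOpCpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override {}
  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &) override {
    return true;
  }

 protected:
  // Each KernelAttr entry lists one supported input/output dtype combination,
  // replacing one MS_REG_CPU_KERNEL macro invocation from the old scheme.
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
    return support_list;
  }
};

// One factory registration per op, instead of one macro per dtype combination.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MyOp, MyOpCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore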
@@ -91,4 +91,4 @@ std::vector<KernelAttr> MvlgammaCpuKernelMod::GetOpSupport() {
 }
 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Mvlgamma, MvlgammaCpuKernelMod);
 }  // namespace kernel
-}  // namespace mindspore
+}  // namespace mindspore
 
@@ -24,7 +24,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
   MvlgammaCpuKernelMod() = default;
@@ -51,8 +50,6 @@ class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  protected:
   std::vector<KernelAttr> GetOpSupport() override;
 };
-
 }  // namespace kernel
 }  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_
 
@@ -152,6 +152,5 @@ std::vector<KernelAttr> MvlgammaGradCpuKernelMod::GetOpSupport() {
 }
 
 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MvlgammaGrad, MvlgammaGradCpuKernelMod);
-
 }  // namespace kernel
-}  // namespace mindspore
+}  // namespace mindspore
 
@@ -24,7 +24,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
   MvlgammaGradCpuKernelMod() = default;
@@ -54,8 +53,6 @@ class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  protected:
   std::vector<KernelAttr> GetOpSupport() override;
 };
-
 }  // namespace kernel
 }  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,32 +14,35 @@
  * limitations under the License.
  */
 
-#include "backend/kernel_compiler/cpu/triplet_margin_loss_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include "plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
-void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+void TripletMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
-  constexpr int kzero = 0;
-  constexpr int kone = 1;
-  constexpr int ktwo = 2;
-  constexpr int kthree = 3;
+  constexpr int kZero = 0;
+  constexpr int kOne = 1;
+  constexpr int kTwo = 2;
+  constexpr int kThree = 3;
   constexpr int kParallel = 28;
   constexpr int kParallelunit = 1024;
-  p = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
-  swap = AnfAlgo::GetNodeAttr<bool>(kernel_node, "swap");
-  eps = AnfAlgo::GetNodeAttr<float>(kernel_node, "eps");
-  reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
-  dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kzero);
-  dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kone);
-  dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, ktwo);
-  dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kthree);
-  x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kzero);
-  positive_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kone);
-  negative_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, ktwo);
+  p = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "p");
+  swap = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "swap");
+  eps = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "eps");
+  reduction = common::AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
+  dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero);
+  dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kOne);
+  dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, kTwo);
+  dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kThree);
+  x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero);
+  positive_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne);
+  negative_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kTwo);
+  if (AnfAlgo::IsShapesDynamic({x_shape, positive_shape, negative_shape})) {
+    return;
+  }
   kParallelDataNum = kParallel * kParallelunit;
-  std::vector<size_t> broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape);
+  auto broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape);
   broadcast_shape = CPUKernelUtils::GetBroadcastShape(broadcast_shape_x_and_positive, negative_shape);
   size_t dim_x = x_shape.size();
   size_t dim_positive = positive_shape.size();
@@ -51,30 +54,28 @@ void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::reverse(x_reshape_vector.begin(), x_reshape_vector.end());
   std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end());
   std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end());
-  if (dim_x < max_size) x_reshape_vector.resize(max_size, kone);
-  if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kone);
-  if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kone);
+  if (dim_x < max_size) x_reshape_vector.resize(max_size, kOne);
+  if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kOne);
+  if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kOne);
   std::reverse(x_reshape_vector.begin(), x_reshape_vector.end());
   std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end());
   std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end());
-  numelements = 1;
-  for (size_t i = 0; i < broadcast_shape.size(); i++) {
-    numelements *= broadcast_shape[i];
-  }
-  data_num = (numelements) / (broadcast_shape[1]);
-  data_num_each_batch = (numelements) / (broadcast_shape[0]);
-  index = data_num / (broadcast_shape[0]);
-  batch_size = broadcast_shape[0];
-  once_compute_size = broadcast_shape[1];
+  numelements = LongToSize(SizeOf(broadcast_shape));
+
+  data_num = (numelements) / LongToSize(broadcast_shape[1]);
+  data_num_each_batch = (numelements) / LongToSize(broadcast_shape[0]);
+  index = data_num / LongToSize(broadcast_shape[0]);
+  batch_size = LongToSize(broadcast_shape[0]);
+  once_compute_size = LongToSize(broadcast_shape[1]);
   broadcast = false;
   if (x_shape != positive_shape || x_shape != negative_shape || positive_shape != negative_shape) {
     broadcast = true;
   }
 }
 
-bool TripletMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                        const std::vector<kernel::AddressPtr> &,
-                                        const std::vector<kernel::AddressPtr> &outputs) {
+bool TripletMarginLossCPUKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                           const std::vector<kernel::AddressPtr> &,
+                                           const std::vector<kernel::AddressPtr> &outputs) {
   switch (dtype_0) {
     case kNumberTypeFloat16:
       TripletMarginLossCompute_realtype<float>(inputs, outputs);
@@ -123,73 +124,23 @@ bool TripletMarginLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &i
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::TripletMarginLossCompute_realtype(const std::vector<kernel::AddressPtr> &inputs,
-                                                                   const std::vector<kernel::AddressPtr> &outputs) {
-  auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
-  Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
-  float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
-  auto task_nobroadcast = [&](size_t start, size_t end) {
-    TripletMarginLossCPUKernel::realtype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
-  };
-  auto task_broadcast = [&](size_t start, size_t end) {
-    TripletMarginLossCPUKernel::realtype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
-  };
-  if (broadcast == true) {
-    if (numelements * sizeof(T) > kParallelDataNum) {
-      CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
-    } else {
-      TripletMarginLossCPUKernel::realtype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
-    }
-    if (reduction == NONE) {
-      for (size_t i = 0; i < data_num; i++) {
-        *(out_data + i) = *(output_reduction_none_data + i);
-      }
-    }
-    if (reduction == MEAN) {
-      *(out_data) = (out.mean());
-    }
-    if (reduction == SUM) {
-      *(out_data) = (out.sum());
-    }
-    return;
-  }
-  if (numelements * sizeof(T) > kParallelDataNum) {
-    CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
-  } else {
-    TripletMarginLossCPUKernel::realtype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
-  }
-  if (reduction == NONE) {
-    for (size_t i = 0; i < data_num; i++) {
-      *(out_data + i) = *(output_reduction_none_data + i);
-    }
-  }
-  if (reduction == MEAN) {
-    *(out_data) = (out.mean());
-  }
-  if (reduction == SUM) {
-    *(out_data) = (out.sum());
-  }
-  return;
-}
-
-template <typename T>
-void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std::vector<kernel::AddressPtr> &inputs,
+void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_realtype(const std::vector<kernel::AddressPtr> &inputs,
                                                                       const std::vector<kernel::AddressPtr> &outputs) {
   auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
   Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
   float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
   auto task_nobroadcast = [&](size_t start, size_t end) {
-    TripletMarginLossCPUKernel::complextype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
+    TripletMarginLossCPUKernelMod::realtype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
                                                                 outputs);
   };
   auto task_broadcast = [&](size_t start, size_t end) {
-    TripletMarginLossCPUKernel::complextype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
+    TripletMarginLossCPUKernelMod::realtype_broadcast_task<T>(start, end, output_reduction_none_data, inputs, outputs);
   };
   if (broadcast == true) {
     if (numelements * sizeof(T) > kParallelDataNum) {
       CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
     } else {
-      TripletMarginLossCPUKernel::complextype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
+      TripletMarginLossCPUKernelMod::realtype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
     }
     if (reduction == NONE) {
       for (size_t i = 0; i < data_num; i++) {
@@ -207,7 +158,7 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std:
   if (numelements * sizeof(T) > kParallelDataNum) {
     CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
   } else {
-    TripletMarginLossCPUKernel::complextype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
+    TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
   }
   if (reduction == NONE) {
     for (size_t i = 0; i < data_num; i++) {
@@ -224,9 +175,62 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std:
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t end, float *output_reduction_none_data,
-                                                           const std::vector<kernel::AddressPtr> &inputs,
-                                                           const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_complextype(
+  const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
+  auto out_data = reinterpret_cast<float *>(outputs[0]->addr);
+  Eigen::Array<float, Eigen::Dynamic, 1> out(data_num, 1);
+  float *output_reduction_none_data = reinterpret_cast<float *>(out.data());
+  auto task_nobroadcast = [&](size_t start, size_t end) {
+    TripletMarginLossCPUKernelMod::complextype_nobroadcast_task<T>(start, end, output_reduction_none_data, inputs,
+                                                                   outputs);
+  };
+  auto task_broadcast = [&](size_t start, size_t end) {
+    TripletMarginLossCPUKernelMod::complextype_broadcast_task<T>(start, end, output_reduction_none_data, inputs,
+                                                                 outputs);
+  };
+  if (broadcast == true) {
+    if (numelements * sizeof(T) > kParallelDataNum) {
+      CPUKernelUtils::ParallelFor(task_broadcast, batch_size);
+    } else {
+      TripletMarginLossCPUKernelMod::complextype_broadcast_compute<T>(output_reduction_none_data, inputs, outputs);
+    }
+    if (reduction == NONE) {
+      for (size_t i = 0; i < data_num; i++) {
+        *(out_data + i) = *(output_reduction_none_data + i);
+      }
+    }
+    if (reduction == MEAN) {
+      *(out_data) = (out.mean());
+    }
+    if (reduction == SUM) {
+      *(out_data) = (out.sum());
+    }
+    return;
+  }
+  if (numelements * sizeof(T) > kParallelDataNum) {
+    CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size);
+  } else {
+    TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute<T>(output_reduction_none_data, inputs, outputs);
+  }
+  if (reduction == NONE) {
+    for (size_t i = 0; i < data_num; i++) {
+      *(out_data + i) = *(output_reduction_none_data + i);
+    }
+  }
+  if (reduction == MEAN) {
+    *(out_data) = (out.mean());
+  }
+  if (reduction == SUM) {
+    *(out_data) = (out.sum());
+  }
+  return;
+}
+
+template <typename T>
+void TripletMarginLossCPUKernelMod::realtype_nobroadcast_task(size_t start, size_t end,
+                                                              float *output_reduction_none_data,
+                                                              const std::vector<kernel::AddressPtr> &inputs,
+                                                              const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -298,9 +302,9 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
-                                                         const std::vector<kernel::AddressPtr> &inputs,
-                                                         const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
+                                                            const std::vector<kernel::AddressPtr> &inputs,
+                                                            const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -349,8 +353,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en
       }
       calc_1_sum += calculate_positive[k];
       calc_2_sum += calculate_negative[k];
-      TripletMarginLossCPUKernel::realtype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap, j,
-                                                   k, calc_swap_sum, inputs, outputs);
+      TripletMarginLossCPUKernelMod::realtype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap,
+                                                      j, k, calc_swap_sum, inputs, outputs);
     }
     positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
     if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@@ -375,9 +379,9 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduction_none_data,
-                                                            const std::vector<kernel::AddressPtr> &inputs,
-                                                            const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::realtype_broadcast_compute(float *output_reduction_none_data,
+                                                               const std::vector<kernel::AddressPtr> &inputs,
+                                                               const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -424,8 +428,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct
       }
      calc_1_sum += calculate_positive[k];
       calc_2_sum += calculate_negative[k];
-      TripletMarginLossCPUKernel::realtype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
-                                                   calculate_swap, j, k, calc_swap_sum, inputs, outputs);
+      TripletMarginLossCPUKernelMod::realtype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
+                                                      calculate_swap, j, k, calc_swap_sum, inputs, outputs);
     }
     positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
     if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@@ -449,9 +453,9 @@ void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_reduction_none_data,
-                                                              const std::vector<kernel::AddressPtr> &inputs,
-                                                              const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute(float *output_reduction_none_data,
+                                                                 const std::vector<kernel::AddressPtr> &inputs,
+                                                                 const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -518,10 +522,10 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_redu
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size_t end,
-                                                              float *output_reduction_none_data,
-                                                              const std::vector<kernel::AddressPtr> &inputs,
-                                                              const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::complextype_nobroadcast_task(size_t start, size_t end,
+                                                                 float *output_reduction_none_data,
+                                                                 const std::vector<kernel::AddressPtr> &inputs,
+                                                                 const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -575,9 +579,10 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data,
-                                                            const std::vector<kernel::AddressPtr> &inputs,
-                                                            const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::complextype_broadcast_task(size_t start, size_t end,
+                                                               float *output_reduction_none_data,
+                                                               const std::vector<kernel::AddressPtr> &inputs,
+                                                               const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -630,8 +635,8 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t
       }
       calc_1_sum += calculate_positive_float;
       calc_2_sum += calculate_negative_float;
-      TripletMarginLossCPUKernel::complextype_swap<T>(start, positive_broadcast, negative_broadcast, calculate_swap,
-                                                      j, k, calc_swap_sum, inputs, outputs);
+      TripletMarginLossCPUKernelMod::complextype_swap<T>(start, positive_broadcast, negative_broadcast,
+                                                         calculate_swap, j, k, calc_swap_sum, inputs, outputs);
     }
     positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
     if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@@ -656,9 +661,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_reduction_none_data,
-                                                               const std::vector<kernel::AddressPtr> &inputs,
-                                                               const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::complextype_broadcast_compute(float *output_reduction_none_data,
+                                                                  const std::vector<kernel::AddressPtr> &inputs,
+                                                                  const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -707,8 +712,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red
       }
       calc_1_sum += calculate_positive_float;
       calc_2_sum += calculate_negative_float;
-      TripletMarginLossCPUKernel::complextype_swap<T>(i * data_num_each_batch, positive_broadcast, negative_broadcast,
-                                                      calculate_swap, j, k, calc_swap_sum, inputs, outputs);
+      TripletMarginLossCPUKernelMod::complextype_swap<T>(i * data_num_each_batch, positive_broadcast,
+                                                         negative_broadcast, calculate_swap, j, k, calc_swap_sum,
+                                                         inputs, outputs);
     }
     positive_distance = std::pow(static_cast<double>(calc_1_sum), (1 / static_cast<float>(p)));
     if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) {
@@ -732,9 +738,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_reduction_none_data,
-                                                                 const std::vector<kernel::AddressPtr> &inputs,
-                                                                 const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute(float *output_reduction_none_data,
+                                                                    const std::vector<kernel::AddressPtr> &inputs,
+                                                                    const std::vector<kernel::AddressPtr> &outputs) {
   auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto positive_addr = reinterpret_cast<T *>(inputs[1]->addr);
   auto negative_addr = reinterpret_cast<T *>(inputs[2]->addr);
@@ -786,11 +792,11 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_r
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector<T> &positive_broadcast,
-                                               std::vector<T> &negative_broadcast, std::vector<float> &calculate_swap,
-                                               size_t j, size_t k, float &calc_swap_sum,
-                                               const std::vector<kernel::AddressPtr> &inputs,
-                                               const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::realtype_swap(size_t start, std::vector<T> &positive_broadcast,
+                                                  std::vector<T> &negative_broadcast,
+                                                  std::vector<float> &calculate_swap, size_t j, size_t k,
+                                                  float &calc_swap_sum, const std::vector<kernel::AddressPtr> &inputs,
+                                                  const std::vector<kernel::AddressPtr> &outputs) {
   if (swap == true) {
     calculate_swap[k] = abs(static_cast<float>(positive_broadcast[start + j + k * index]) -
                             static_cast<float>(negative_broadcast[start + j + k * index]) + eps);
@@ -803,11 +809,11 @@ void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector<T> &pos
 }
 
 template <typename T>
-void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector<T> &positive_broadcast,
-                                                  std::vector<T> &negative_broadcast, std::vector<T> &calculate_swap,
-                                                  size_t j, size_t k, float &calc_swap_sum,
-                                                  const std::vector<kernel::AddressPtr> &inputs,
-                                                  const std::vector<kernel::AddressPtr> &outputs) {
+void TripletMarginLossCPUKernelMod::complextype_swap(size_t start, std::vector<T> &positive_broadcast,
+                                                     std::vector<T> &negative_broadcast, std::vector<T> &calculate_swap,
+                                                     size_t j, size_t k, float &calc_swap_sum,
+                                                     const std::vector<kernel::AddressPtr> &inputs,
+                                                     const std::vector<kernel::AddressPtr> &outputs) {
   if (swap == true) {
     calculate_swap[k] =
       positive_broadcast[start + j + k * index] - negative_broadcast[start + j + k * index] + static_cast<T>(eps);
@@ -821,17 +827,19 @@ void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector<T> &
   }
 }
 
-void TripletMarginLossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
-  auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  constexpr int kone = 1;
+void TripletMarginLossCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
+  auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
+  constexpr int kOne = 1;
   constexpr int kfour = 4;
   if (input_num != kfour) {
-    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernel needs 4 inputs.";
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernelMod needs 4 inputs.";
   }
-  auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  if (output_num != kone) {
-    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TripletMarginLossCPUKernel needs 1 output.";
+  auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != kOne) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TripletMarginLossCPUKernelMod needs 1 output.";
   }
 }
 
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TripletMarginLoss, TripletMarginLossCPUKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
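Every compute path above finishes with the same reduction over the Eigen array of per-sample losses: copy each value through for reduction NONE, or collapse to out.mean() / out.sum() otherwise. A standalone sketch of just that step, assuming the NONE/MEAN/SUM constants in the kernel correspond to the reduction strings "none"/"mean"/"sum" (ApplyReduction is a hypothetical helper name, not MindSpore API):

#include <Eigen/Dense>
#include <cstddef>
#include <string>

// Writes the reduced loss into out_data: all per-sample values for "none",
// a single scalar for "mean" or "sum".
void ApplyReduction(const Eigen::Array<float, Eigen::Dynamic, 1> &out,
                    const std::string &reduction, float *out_data, std::size_t data_num) {
  if (reduction == "none") {
    for (std::size_t i = 0; i < data_num; ++i) {
      out_data[i] = out(static_cast<Eigen::Index>(i));
    }
  } else if (reduction == "mean") {
    out_data[0] = out.mean();
  } else if (reduction == "sum") {
    out_data[0] = out.sum();
  }
}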
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,21 +25,99 @@
 #include <map>
 #include <complex>
 #include <iostream>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-class TripletMarginLossCPUKernel : public CPUKernel {
+class TripletMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  TripletMarginLossCPUKernel() = default;
-  ~TripletMarginLossCPUKernel() override = default;
+  TripletMarginLossCPUKernelMod() = default;
+  ~TripletMarginLossCPUKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
 
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override {
+    static std::vector<KernelAttr> support_list = {KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt8)
+                                                     .AddInputAttr(kNumberTypeInt8)
+                                                     .AddInputAttr(kNumberTypeInt8)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeInt16)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeInt32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeInt64)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeUInt8)
+                                                     .AddInputAttr(kNumberTypeUInt8)
+                                                     .AddInputAttr(kNumberTypeUInt8)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeUInt16)
+                                                     .AddInputAttr(kNumberTypeUInt16)
+                                                     .AddInputAttr(kNumberTypeUInt16)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeUInt32)
+                                                     .AddInputAttr(kNumberTypeUInt32)
+                                                     .AddInputAttr(kNumberTypeUInt32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeUInt64)
+                                                     .AddInputAttr(kNumberTypeUInt64)
+                                                     .AddInputAttr(kNumberTypeUInt64)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat64)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeComplex64)
+                                                     .AddInputAttr(kNumberTypeComplex64)
+                                                     .AddInputAttr(kNumberTypeComplex64)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32),
+                                                   KernelAttr()
+                                                     .AddInputAttr(kNumberTypeComplex128)
+                                                     .AddInputAttr(kNumberTypeComplex128)
+                                                     .AddInputAttr(kNumberTypeComplex128)
+                                                     .AddInputAttr(kNumberTypeFloat32)
+                                                     .AddOutputAttr(kNumberTypeFloat32)};
+    return support_list;
+  }
+
  private:
   template <typename T>
   void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 
@@ -96,7 +174,6 @@ class TripletMarginLossCPUKernel : public CPUKernel {
                         std::vector<T> &calculate_swap, size_t j, size_t k, float &calc_swap_sum,
                         const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
 
- private:
   void CheckParam(const CNodePtr &kernel_node);
   int64_t p = 2;
   bool swap = false;
@@ -109,13 +186,13 @@ class TripletMarginLossCPUKernel : public CPUKernel {
   TypeId dtype_1{kTypeUnknown};
   TypeId dtype_2{kTypeUnknown};
   TypeId dtype_3{kTypeUnknown};
-  std::vector<size_t> x_shape;
-  std::vector<size_t> positive_shape;
-  std::vector<size_t> negative_shape;
-  std::vector<size_t> broadcast_shape;
-  std::vector<size_t> x_reshape_vector;
-  std::vector<size_t> positive_reshape_vector;
-  std::vector<size_t> negative_reshape_vector;
+  ShapeVector x_shape;
+  ShapeVector positive_shape;
+  ShapeVector negative_shape;
+  ShapeVector broadcast_shape;
+  ShapeVector x_reshape_vector;
+  ShapeVector positive_reshape_vector;
+  ShapeVector negative_reshape_vector;
   size_t numelements = 1;
   size_t data_num = 1;
   size_t data_num_each_batch = 1;
@@ -124,114 +201,6 @@ class TripletMarginLossCPUKernel : public CPUKernel {
   size_t once_compute_size = 1;
   bool broadcast = false;
 };
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeInt8)
-                    .AddInputAttr(kNumberTypeInt8)
-                    .AddInputAttr(kNumberTypeInt8)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeInt16)
-                    .AddInputAttr(kNumberTypeInt16)
-                    .AddInputAttr(kNumberTypeInt16)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeInt32)
-                    .AddInputAttr(kNumberTypeInt32)
-                    .AddInputAttr(kNumberTypeInt32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeInt64)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeUInt8)
-                    .AddInputAttr(kNumberTypeUInt8)
-                    .AddInputAttr(kNumberTypeUInt8)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeUInt16)
-                    .AddInputAttr(kNumberTypeUInt16)
-                    .AddInputAttr(kNumberTypeUInt16)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeUInt32)
-                    .AddInputAttr(kNumberTypeUInt32)
-                    .AddInputAttr(kNumberTypeUInt32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeUInt64)
-                    .AddInputAttr(kNumberTypeUInt64)
-                    .AddInputAttr(kNumberTypeUInt64)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeComplex64)
-                    .AddInputAttr(kNumberTypeComplex64)
-                    .AddInputAttr(kNumberTypeComplex64)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
-
-MS_REG_CPU_KERNEL(TripletMarginLoss,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeComplex128)
-                    .AddInputAttr(kNumberTypeComplex128)
-                    .AddInputAttr(kNumberTypeComplex128)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  TripletMarginLossCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIPLET_MARGIN_LOSS_CPU_KERNEL_H_
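InitKernel in the class above folds the three input shapes pairwise through CPUKernelUtils::GetBroadcastShape before sizing its work buffers. Assuming that helper follows NumPy-style broadcasting rules, the semantics can be sketched self-contained as below; BroadcastShape is a hypothetical stand-in, not the MindSpore API.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// NumPy-style broadcast of two shapes: align from the trailing dimension;
// a dimension of 1 stretches to match the other operand.
std::vector<int64_t> BroadcastShape(std::vector<int64_t> a, std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  b.insert(b.begin(), a.size() - b.size(), 1);  // left-pad the shorter shape with 1s
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) {
      throw std::invalid_argument("shapes are not broadcastable");
    }
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}
// The kernel folds its three inputs the same way:
// BroadcastShape(BroadcastShape(x_shape, positive_shape), negative_shape).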
@@ -592,7 +592,7 @@ GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradWithArgmax, std::make_shared<Primitive>("
 GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradGradWithArgmax, std::make_shared<Primitive>("MaxPoolGradGradWithArgmax"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPool3DWithArgmax, std::make_shared<Primitive>("MaxPool3DWithArgmax"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPool3DGradWithArgmax, std::make_shared<Primitive>("MaxPool3DGradWithArgmax"));
-GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared<Primitive>(kMaxUnpool2DGrad));
+GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared<Primitive>(kMaxUnpool2D));
 GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2DGrad, std::make_shared<Primitive>(kMaxUnpool2DGrad));
 GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3D, std::make_shared<Primitive>(kMaxUnpool3D));
 GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3DGrad, std::make_shared<Primitive>(kMaxUnpool3DGrad));
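The one-line fix above matters because a Primitive carries the name it is constructed with, and downstream dispatch keys on that name; kPrimMaxUnpool2D built from the grad name would resolve to the grad op's entries. A toy model of that failure mode, not the MindSpore implementation:

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Toy stand-in: a primitive object remembers its construction name, and
// kernels are looked up by that name.
struct Primitive {
  explicit Primitive(std::string n) : name(std::move(n)) {}
  std::string name;
};

int main() {
  const std::string kMaxUnpool2D = "MaxUnpool2D";
  const std::string kMaxUnpool2DGrad = "MaxUnpool2DGrad";
  auto kPrimMaxUnpool2D = std::make_shared<Primitive>(kMaxUnpool2D);  // the corrected definition

  std::map<std::string, std::string> kernel_table = {
    {kMaxUnpool2D, "MaxUnpool2D CPU kernel"},
    {kMaxUnpool2DGrad, "MaxUnpool2DGrad CPU kernel"}};
  // With the old definition (constructed from kMaxUnpool2DGrad), this lookup
  // would have returned the grad kernel for the forward primitive.
  std::cout << kernel_table[kPrimMaxUnpool2D->name] << std::endl;
  return 0;
}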
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,13 @@
 
 #include "ops/grad/max_unpool2d_grad.h"
+#include <set>
+#include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
-constexpr int64_t k4DInputDims = 4;
 namespace {
 abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive,
                                              const std::vector<AbstractBasePtr> &input_args) {
@@ -34,11 +38,10 @@ abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive,
   auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
   auto argmax_shape =
     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
-  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k4DInputDims, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims,
-                                           op_name);
-  CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim4, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name);
+  CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
   auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
   return std::make_shared<abstract::Shape>(x1_shape);
 }
@@ -59,6 +62,7 @@ TypePtr MaxUnpool2DGradInferType(const PrimitivePtr &primitive, const std::vecto
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MaxUnpool2DGrad, BaseOperator);
 AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,23 +19,20 @@
 #include <vector>
 #include <memory>
 
-#include "ops/primitive_c.h"
 #include "ops/op_utils.h"
-#include "abstract/abstract_value.h"
-#include "utils/check_convert_utils.h"
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
 
 namespace mindspore {
 namespace ops {
 constexpr auto kNameMaxUnpool2DGrad = "MaxUnpool2DGrad";
-class MS_CORE_API MaxUnpool2DGrad : public PrimitiveC {
+class MIND_API MaxUnpool2DGrad : public BaseOperator {
  public:
-  MaxUnpool2DGrad() : PrimitiveC(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
-  ~MaxUnpool2DGrad() = default;
-  MS_DECLARE_PARENT(MaxUnpool2DGrad, PrimitiveC);
+  MIND_API_BASE_MEMBER(MaxUnpool2DGrad);
+  MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
 };
 
-AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                     const std::vector<AbstractBasePtr> &input_args);
+abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const std::vector<abstract::AbstractBasePtr> &input_args);
 using PrimMaxUnpool2DGradPtr = std::shared_ptr<MaxUnpool2DGrad>;
 }  // namespace ops
 }  // namespace mindspore
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,13 @@
 
 #include "ops/grad/max_unpool3d_grad.h"
+#include <set>
+#include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
-constexpr int64_t k5DInputDims = 5;
-
 namespace {
 abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive,
                                              const std::vector<AbstractBasePtr> &input_args) {
@@ -35,11 +38,10 @@ abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive,
   auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
   auto argmax_shape =
     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
-  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k5DInputDims, op_name);
-  (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims,
-                                           op_name);
-  CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim5, op_name);
+  (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name);
+  CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
   auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
   return std::make_shared<abstract::Shape>(x1_shape);
 }
@@ -60,6 +62,7 @@ TypePtr MaxUnpool3DGradInferType(const PrimitivePtr &primitive, const std::vecto
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MaxUnpool3DGrad, BaseOperator);
 AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
 
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,23 +19,20 @@
 #include <vector>
 #include <memory>
 
-#include "ops/primitive_c.h"
 #include "ops/op_utils.h"
-#include "abstract/abstract_value.h"
-#include "utils/check_convert_utils.h"
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
 
 namespace mindspore {
 namespace ops {
 constexpr auto kNameMaxUnpool3DGrad = "MaxUnpool3DGrad";
-class MS_CORE_API MaxUnpool3DGrad : public PrimitiveC {
+class MIND_API MaxUnpool3DGrad : public BaseOperator {
  public:
-  MaxUnpool3DGrad() : PrimitiveC(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
-  ~MaxUnpool3DGrad() = default;
-  MS_DECLARE_PARENT(MaxUnpool3DGrad, PrimitiveC);
+  MIND_API_BASE_MEMBER(MaxUnpool3DGrad);
+  MaxUnpool3DGrad() : BaseOperator(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
 };
 
-AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                     const std::vector<AbstractBasePtr> &input_args);
+abstract::AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const std::vector<abstract::AbstractBasePtr> &input_args);
 using PrimMaxUnpool3DGradPtr = std::shared_ptr<MaxUnpool3DGrad>;
 }  // namespace ops
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,16 +15,14 @@
  */
 
 #include "ops/grad/multi_margin_loss_grad.h"
-#include "abstract/primitive_infer_map.h"
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
 namespace {
-const size_t kone = 1;
-const size_t ktwo = 2;
-const size_t kfour = 4;
-
 TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   (void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex2]->BuildType(), {kInt64},
                                                    prim->name());
@@ -32,7 +30,7 @@ TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector
   std::map<std::string, TypePtr> types;
   (void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType());
   (void)types.emplace("x", input_args[kInputIndex1]->BuildType());
-  if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
+  if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
    auto tensor_type = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>();
     MS_EXCEPTION_IF_NULL(tensor_type);
     auto element = tensor_type->element();
@@ -50,7 +48,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
   auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
-  if (x_shape.size() != ktwo || target_shape.size() != kone) {
+  if (x_shape.size() != kDim2 || target_shape.size() != kDim1) {
     MS_EXCEPTION(ValueError) << "For MultiMarginLossGrad, the rank of input x should be 2, and "
                                 "the rank of target should be 1,"
                              << " while rank of x is " << x_shape.size() << ", rank of target is "
@@ -61,7 +59,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
                              << " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is "
                              << target_shape[kInputIndex0];
   }
-  if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
+  if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa<TensorType>()) {
     auto tensor_type = input_args[kInputIndex3]->BuildType()->cast<TensorTypePtr>();
     MS_EXCEPTION_IF_NULL(tensor_type);
     auto element = tensor_type->element();
@@ -69,7 +67,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
     if (element->type_id() != kMetaTypeNone) {
       auto weight_shape =
         CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
-      if (weight_shape.size() != kone) {
+      if (weight_shape.size() != kDim1) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1,"
                                  << " but get " << weight_shape.size();
       }
@@ -84,6 +82,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive,
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MultiMarginLossGrad, BaseOperator);
 AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
   constexpr size_t kInputNumWithWeight = 4;
@@ -96,7 +95,7 @@ AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, co
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
   MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
-  if (input_args.size() == kfour) {
+  if (input_args.size() == kInputNumWithWeight) {
     MS_EXCEPTION_IF_NULL(input_args[kInputIndex3]);
   }
   auto types = MultiMarginLossGradInferType(primitive, input_args);
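The infer function above enforces a simple shape contract: x is rank 2, target is rank 1, they agree on the batch dimension, and the optional weight is rank 1. Restated below as a hypothetical standalone check (CheckMultiMarginLossGradShapes is not a real API, just a sketch of the contract):

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Sketch of the MultiMarginLossGrad shape contract enforced by the infer pass.
void CheckMultiMarginLossGradShapes(const std::vector<int64_t> &x_shape,
                                    const std::vector<int64_t> &target_shape,
                                    const std::optional<std::vector<int64_t>> &weight_shape) {
  if (x_shape.size() != 2 || target_shape.size() != 1) {
    throw std::invalid_argument("x must be rank 2 and target rank 1");
  }
  if (x_shape[0] != target_shape[0]) {
    throw std::invalid_argument("x and target must agree on the batch dimension");
  }
  if (weight_shape.has_value() && weight_shape->size() != 1) {
    throw std::invalid_argument("weight, when given, must be rank 1");
  }
}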
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,24 +23,23 @@
 #include <set>
 #include <string>
 #include <vector>
 #include "abstract/abstract_value.h"
-#include "ops/primitive_c.h"
-#include "utils/check_convert_utils.h"
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
 
 namespace mindspore {
 namespace ops {
 constexpr auto kNameMultiMarginLossGrad = "MultiMarginLossGrad";
-class MS_CORE_API MultiMarginLossGrad : public PrimitiveC {
+class MIND_API MultiMarginLossGrad : public BaseOperator {
  public:
-  MultiMarginLossGrad() : PrimitiveC(kNameMultiMarginLossGrad) {
+  MIND_API_BASE_MEMBER(MultiMarginLossGrad);
+  MultiMarginLossGrad() : BaseOperator(kNameMultiMarginLossGrad) {
     InitIOName({"y_grad", "x", "target", "weight"}, {"x_grad"});
   }
-  ~MultiMarginLossGrad() = default;
-  MS_DECLARE_PARENT(MultiMarginLossGrad, PrimitiveC);
 };
 
-AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const std::vector<AbstractBasePtr> &input_args);
+abstract::AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                   const std::vector<abstract::AbstractBasePtr> &input_args);
 using PrimMultiMarginLossGradPtr = std::shared_ptr<MultiMarginLossGrad>;
 }  // namespace ops
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,8 +16,9 @@
 
 #include "ops/grad/multilabel_margin_loss_grad.h"
 #include "ops/op_utils.h"
-#include "utils/tensor_construct_utils.h"
-#include "abstract/primitive_infer_map.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
@@ -27,9 +28,7 @@ abstract::ShapePtr MultilabelMarginLossGradInferShape(const PrimitivePtr &primit
   auto op_name = primitive->name();
   auto x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
   auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
-  const size_t kone = 1;
-  const size_t ktwo = 2;
-  if ((x.size() != kone && x.size() != ktwo) || (target.size() != kone && target.size() != ktwo)) {
+  if ((x.size() != kDim1 && x.size() != kDim2) || (target.size() != kDim1 && target.size() != kDim2)) {
     MS_EXCEPTION(ValueError) << "For " << op_name << ", the rank of input x and target should be 1 or 2, "
                              << "while rank of x is : " << x.size() << ", rank of target is : " << target.size() << ".";
   }
@@ -57,6 +56,7 @@ TypePtr MultilabelMarginLossGradInferType(const PrimitivePtr &primitive,
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MultilabelMarginLossGrad, BaseOperator);
 AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,24 +23,24 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMultilabelMarginLossGrad = "MultilabelMarginLossGrad";
class MS_CORE_API MultilabelMarginLossGrad : public PrimitiveC {
class MIND_API MultilabelMarginLossGrad : public BaseOperator {
public:
MultilabelMarginLossGrad() : PrimitiveC(kNameMultilabelMarginLossGrad) {
MIND_API_BASE_MEMBER(MultilabelMarginLossGrad);
MultilabelMarginLossGrad() : BaseOperator(kNameMultilabelMarginLossGrad) {
InitIOName({"y_grad", "x", "target", "is_target"}, {"x_grad"});
}
~MultilabelMarginLossGrad() = default;
MS_DECLARE_PARENT(MultilabelMarginLossGrad, PrimitiveC);
};

AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultilabelMarginLossGradPtr = std::shared_ptr<MultilabelMarginLossGrad>;
} // namespace ops
} // namespace mindspore
@ -27,19 +27,18 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(y_grad_shape);
}

TypePtr MvlgammaGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
std::map<std::string, TypePtr> types;
(void)types.emplace("y_grad", input_args[0]->BuildType());
(void)types.emplace("x", input_args[1]->BuildType());
(void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
@ -32,7 +32,7 @@ class MIND_API MvlgammaGrad : public BaseOperator {
};

abstract::AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMvlgammaGradPtr = std::shared_ptr<MvlgammaGrad>;
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,17 +20,17 @@
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int64_t k4DInputDims = 4;

abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads, const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
abstract::ShapePtr MaxUnpool2DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads,
const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
if (data_format == "NCHW") {
int64_t out_h = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
@ -40,7 +40,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
}
if (attr_output_shape.size() == k4DInputDims) {
if (attr_output_shape.size() == kDim4) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual,
@ -74,7 +74,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
}
if (attr_output_shape.size() == k4DInputDims) {
if (attr_output_shape.size() == kDim4) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[3]", attr_output_shape[kInputIndex3], kEqual,
@ -100,7 +100,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
}
}

abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MaxUnpool2DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
@ -109,28 +110,27 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto data_format = GetValue<std::string>(primitive->GetAttr("format"));
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);

auto ksize = GetValue<std::vector<int64_t>>(primitive->GetAttr("ksize"));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr("strides"));
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr("pads"));
auto attr_output_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr("output_shape"));

(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k4DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim4, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim4, op_name);

if (attr_output_shape.size() != k4DInputDims && attr_output_shape.size() != 0) {
if (attr_output_shape.size() != kDim4 && attr_output_shape.size() != kDim0) {
MS_EXCEPTION(ValueError) << "MaxUnpool2D: Output_shape size must be 0 or 4.";
}

return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
return MaxUnpool2DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
TypePtr MaxUnpool2DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -143,13 +143,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
} // namespace

MIND_API_OPERATOR_IMPL(MaxUnpool2D, BaseOperator);
AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
auto infer_type = MaxUnpool2DInferType(primitive, input_args);
auto infer_shape = MaxUnpool2DInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2D, prim::kPrimMaxUnpool2D, MaxUnpool2DInfer, nullptr, true);
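The shape inference above derives each spatial output dimension from the input size, stride, padding, and kernel size. A minimal sanity check of that formula, as a standalone sketch (the helper name unpool_out_size is made up for this note and is not part of the commit):

def unpool_out_size(in_size, ksize, stride, pad):
    # Mirrors the formula in the shape compute above:
    # out = (in - 1) * stride - 2 * pad + ksize
    return (in_size - 1) * stride - 2 * pad + ksize

# The MaxUnpool2D test case later in this commit uses a 6x6 input with
# ksize=4, strides=2, pads=2 and expects a 10x10 output:
assert unpool_out_size(6, 4, 2, 2) == 10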
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool2D = "MaxUnpool2D";
class MS_CORE_API MaxUnpool2D : public PrimitiveC {
class MIND_API MaxUnpool2D : public BaseOperator {
public:
MaxUnpool2D() : PrimitiveC(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
~MaxUnpool2D() = default;
MS_DECLARE_PARENT(MaxUnpool2D, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool2D);
MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
};

AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool2DPtr = std::shared_ptr<MaxUnpool2D>;
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,17 +20,17 @@
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int64_t k5DInputDims = 5;

abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads, const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
abstract::ShapePtr MaxUnpool3DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape,
const std::vector<int64_t> &ksize, const std::vector<int64_t> &strides,
const std::vector<int64_t> &pads,
const std::vector<int64_t> &attr_output_shape,
const std::string &op_name) {
if (data_format == "NCDHW") {
int64_t out_d = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
@ -42,7 +42,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid.";
}
if (attr_output_shape.size() == k5DInputDims) {
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual,
@ -79,7 +79,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid.";
}
if (attr_output_shape.size() == k5DInputDims) {
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
in_shape[kInputIndex0], op_name);
(void)CheckAndConvertUtils::CheckInteger("output_shape[4]", attr_output_shape[kInputIndex4], kEqual,
@ -108,7 +108,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape
}
}

abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
@ -117,26 +118,25 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
auto data_format = GetValue<std::string>(primitive->GetAttr("format"));
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims,
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name);
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
auto ksize = GetValue<std::vector<int64_t>>(primitive->GetAttr("ksize"));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr("strides"));
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr("pads"));
auto attr_output_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr("output_shape"));
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k5DInputDims, op_name);
(void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim5, op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim5, op_name);

if (attr_output_shape.size() != k5DInputDims && attr_output_shape.size() != 0) {
if (attr_output_shape.size() != kDim5 && attr_output_shape.size() != kDim0) {
MS_EXCEPTION(ValueError) << "MaxUnpool3D: Output_shape size must be 0 or 5.";
}

return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
return MaxUnpool3DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
TypePtr MaxUnpool3DInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -149,13 +149,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
} // namespace

MIND_API_OPERATOR_IMPL(MaxUnpool3D, BaseOperator);
AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
auto infer_type = MaxUnpool3DInferType(primitive, input_args);
auto infer_shape = MaxUnpool3DInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool3D, prim::kPrimMaxUnpool3D, MaxUnpool3DInfer, nullptr, true);
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,23 +19,20 @@
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMaxUnpool3D = "MaxUnpool3D";
class MS_CORE_API MaxUnpool3D : public PrimitiveC {
class MIND_API MaxUnpool3D : public BaseOperator {
public:
MaxUnpool3D() : PrimitiveC(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); }
~MaxUnpool3D() = default;
MS_DECLARE_PARENT(MaxUnpool3D, PrimitiveC);
MIND_API_BASE_MEMBER(MaxUnpool3D);
MaxUnpool3D() : BaseOperator(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); }
};

AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMaxUnpool3DPtr = std::shared_ptr<MaxUnpool3D>;
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,21 +15,20 @@
*/

#include "ops/multi_margin_loss.h"
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
const size_t kone = 1;
const size_t ktwo = 2;
const size_t kthree = 3;

TypePtr MultiMarginLossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[1]->BuildType(), {kInt64}, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex1]->BuildType(), {kInt64},
prim->name());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
if (input_args.size() == kInputIndex3 && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
@ -51,7 +50,7 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
if (x_shape.size() != ktwo || target_shape.size() != kone) {
if (x_shape.size() != kDim2 || target_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For MultiMarginLoss, the rank of input "
"x and target should be 2 and 1,"
<< " while rank of x is " << x_shape.size() << ", rank of target is "
@ -62,14 +61,15 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
<< " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is "
<< target_shape[kInputIndex0];
}
if (input_args.size() == kthree && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
if (input_args.size() == kDim3 && input_args[kInputIndex2]->BuildType()->isa<TensorType>()) {
auto tensor_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
if (element->type_id() != kMetaTypeNone) {
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (weight_shape.size() != kone) {
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (weight_shape.size() != kDim1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1,"
<< " but get " << weight_shape.size();
}
@ -90,16 +90,17 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive,
}
} // namespace

MIND_API_OPERATOR_IMPL(MultiMarginLoss, BaseOperator);
AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
if (input_args.size() == kthree) {
if (input_args.size() == kDim3) {
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
}
(void)CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth,
{ktwo, kthree}, primitive->name());
{kDim2, kDim3}, primitive->name());
auto types = MultiMarginLossInferType(primitive, input_args);
auto shapes = MultiMarginLossInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,22 +22,21 @@
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMultiMarginLoss = "MultiMarginLoss";
class MS_CORE_API MultiMarginLoss : public PrimitiveC {
class MIND_API MultiMarginLoss : public BaseOperator {
public:
MultiMarginLoss() : PrimitiveC(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); }
~MultiMarginLoss() = default;
MS_DECLARE_PARENT(MultiMarginLoss, PrimitiveC);
MIND_API_BASE_MEMBER(MultiMarginLoss);
MultiMarginLoss() : BaseOperator(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); }
};

AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultiMarginLossPtr = std::shared_ptr<MultiMarginLoss>;
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,9 @@

#include "ops/multilabel_margin_loss.h"
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
@ -37,7 +38,7 @@ abstract::TupleShapePtr MultilabelMarginLossInferShape(const PrimitivePtr &primi
MS_EXCEPTION(ValueError) << "For " << op_name << ", x_shape and target_shape should be the same, "
<< "while x_shape is : " << x << ", target_shape is : " << target << ".";
}
int64_t batch = x[0];
int64_t batch = x[kInputIndex0];
ShapeVector out_shape0 = {batch};
ShapeVector out_shape1 = target;
int64_t reduction;
@ -66,6 +67,7 @@ TuplePtr MultilabelMarginLossInferType(const PrimitivePtr &primitive, const std:
}
} // namespace

MIND_API_OPERATOR_IMPL(MultilabelMarginLoss, BaseOperator);
AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,27 +21,24 @@
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
constexpr auto kNameMultilabelMarginLoss = prim::kMultilabelMarginLoss;
namespace ops {
constexpr auto kNameMultilabelMarginLoss = "MultilabelMarginLoss";
/// \brief Creates a criterion that optimizes a multi-class multi-classification hinge loss.
/// Refer to Python API @ref mindspore.ops.MultilabelMarginLoss for more details.
class MS_CORE_API MultilabelMarginLoss : public PrimitiveC {
class MIND_API MultilabelMarginLoss : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MultilabelMarginLoss);
/// \brief Constructor.
MultilabelMarginLoss() : PrimitiveC(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); }
/// \brief Destructor.
~MultilabelMarginLoss() = default;
MS_DECLARE_PARENT(MultilabelMarginLoss, PrimitiveC);
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MultilabelMarginLoss for the inputs.
void Init() const {}
MultilabelMarginLoss() : BaseOperator(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); }
};

AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMultilabelMarginLossPtr = std::shared_ptr<MultilabelMarginLoss>;
} // namespace ops
} // namespace mindspore
@ -28,14 +28,14 @@ namespace ops {
namespace {
abstract::ShapePtr MvlgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
}

TypePtr MvlgammaInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto input_type = input_args[0]->BuildType();
auto input_type = input_args[kInputIndex0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, prim->name());
}
@ -30,12 +30,12 @@ constexpr auto kNameMvlgamma = "Mvlgamma";
class MIND_API Mvlgamma : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Mvlgamma);
/// \brief Constructor.
/// \brief Constructor.
Mvlgamma() : BaseOperator(kNameMvlgamma) { InitIOName({"x"}, {"y"}); }
};

abstract::AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMvlgammaPtr = std::shared_ptr<Mvlgamma>;
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,13 +16,13 @@
#include "ops/triplet_margin_loss.h"
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 4;
abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
@ -30,21 +30,19 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
auto positive = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto negative = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto margin = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
const int64_t keight = 8;
if (x.size() >= keight || positive.size() >= keight || negative.size() >= keight) {
if (x.size() >= kDim8 || positive.size() >= kDim8 || negative.size() >= kDim8) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< ", dimensions of input x positive and negative must be smaller than 8, x_dim: "
<< x.size() << ", positive_dim: " << positive.size()
<< ", negative_dim: " << negative.size() << ".";
}
const int64_t kone = 1;
if (x.size() <= kone && positive.size() <= kone && negative.size() <= kone) {
if (x.size() <= kDim1 && positive.size() <= kDim1 && negative.size() <= kDim1) {
MS_EXCEPTION(ValueError)
<< "For " << op_name
<< ", dimensions of input x, positive and negative cannot be less than 1 at the same time, x_dim: " << x.size()
<< ", positive_dim: " << positive.size() << ", negative_dim: " << negative.size() << ".";
}
if (margin.size() != 0) {
if (margin.size() != kDim0) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< ", the dimension of input margin must be 0, margin_dim: " << margin.size() << ".";
}
@ -61,8 +59,8 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive,
ShapeVector out_shape;
for (size_t i = 0; i < dims; i++) {
out_shape.push_back((int64_t)std::max(std::max(x[i], positive[i]), negative[i]));
if ((x[i] != out_shape[i] && x[i] != kone) || (positive[i] != out_shape[i] && positive[i] != kone) ||
(negative[i] != out_shape[i] && negative[i] != kone)) {
if ((x[i] != out_shape[i] && x[i] != kDim1) || (positive[i] != out_shape[i] && positive[i] != kDim1) ||
(negative[i] != out_shape[i] && negative[i] != kDim1)) {
MS_EXCEPTION(ValueError) << "For " << op_name << ", inputs' shape can't broadcast.";
}
}
@ -98,6 +96,7 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec
}
} // namespace

MIND_API_OPERATOR_IMPL(TripletMarginLoss, BaseOperator);
AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,24 +21,23 @@
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameTripletMarginLoss = "TripletMarginLoss";
class TripletMarginLoss : public PrimitiveC {
class MIND_API TripletMarginLoss : public BaseOperator {
public:
TripletMarginLoss() : PrimitiveC(kNameTripletMarginLoss) {
MIND_API_BASE_MEMBER(TripletMarginLoss);
TripletMarginLoss() : BaseOperator(kNameTripletMarginLoss) {
InitIOName({"x", "positive", "negative", "margin"}, {"y"});
}
~TripletMarginLoss() = default;
MS_DECLARE_PARENT(TripletMarginLoss, PrimitiveC);
};

AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
abstract::AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimTripletMarginLossPtr = std::shared_ptr<TripletMarginLoss>;
} // namespace ops
} // namespace mindspore
@ -19,6 +19,9 @@ from mindspore import log
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp
from mindspore.ops import functional as F
from mindspore import nn
from mindspore.ops.primitive import constexpr
@ -1127,7 +1130,7 @@ class MultiMarginLoss(LossBase):
def __init__(self, p=1, margin=1.0, reduction='mean'):
"""Initialize MultiMarginLoss."""
super(MultiMarginLoss, self).__init__()
self.multi_margin_loss = P.MultiMarginLoss(p=p, margin=margin, reduction=reduction)
self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction)
self.ones = P.Ones()

def construct(self, x, target, weight=None):
@ -1369,7 +1372,7 @@ class MultilabelMarginLoss(LossBase):

def __init__(self, reduction='mean'):
super(MultilabelMarginLoss, self).__init__()
self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction)
self.multilabel_margin_loss = MultilabelMarginLossOp(reduction=reduction)

def construct(self, x, target):
return self.multilabel_margin_loss(x, target)
@ -1497,74 +1500,6 @@ def _check_input_dtype(labels_dtype, cls_name):
[mstype.int32, mstype.int64, mstype.float16, mstype.float32], cls_name)


class MultilabelMarginLoss(LossBase):
r"""
MultilabelMarginLoss operation.

Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
and output :math:`y` (which is a 2D `Tensor` of target class indices).
For each sample in the mini-batch:

.. math::
\text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}

where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
:math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
:math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.

:math:`y` and :math:`x` must have the same size.

The criterion only considers a contiguous block of non-negative targets that
starts at the front.

This allows for different samples to have variable amounts of target classes.

Args:
reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".

Inputs:
- **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N`
is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32.
- **target** (Tensor) - Ground truth data, with the same shape as `x`, data type must be int32 and
label targets padded by -1.

Outputs:
- **y** (Union[Tensor, Scalar]) - The loss of MultilabelMarginLoss. If `reduction` is "none", its shape
is :math:`(N)`. Otherwise, a scalar value will be returned.
- **is_target** (Tensor) - Output tensor for backward input, with the same shape as `target`,
data type must be int32.

Raises:
TypeError: If `x` or `target` is not a Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
TypeError: If dtype of `target` is not int32.
ValueError: If length of shape of `x` is neither 1 nor 2.
ValueError: If shape of `x` is not the same as `target`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

Supported Platforms:
``Ascend``

Examples:
>>> loss = nn.MultilabelMarginLoss()
>>> x = Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]]), mindspore.float32)
>>> target = Tensor(np.array([[1, 2, 0, 3], [2, 3, -1, 1]]), mindspore.int32)
>>> output = loss(x, target)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 0.325), Tensor(shape=[2, 4], dtype=Int32, value=
[[1, 1, 1, 1], [0, 0, 1, 1]]))
"""

def __init__(self, reduction='mean'):
super(MultilabelMarginLoss, self).__init__()
self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction)

def construct(self, x, target):
return self.multilabel_margin_loss(x, target)


class FocalLoss(LossBase):
r"""
The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the
@ -1867,7 +1802,7 @@ class TripletMarginLoss(LossBase):

def __init__(self, p=2, swap=False, eps=1e-6, reduction='mean'):
super(TripletMarginLoss, self).__init__()
self.triplet_margin_loss = P.TripletMarginLoss(p=p, swap=swap, eps=eps, reduction=reduction)
self.triplet_margin_loss = TripletMarginLossOp(p=p, swap=swap, eps=eps, reduction=reduction)

def construct(self, x, positive, negative, margin):
return self.triplet_margin_loss(x, positive, negative, margin)
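After this refactor the nn wrappers build their operators from the classes imported out of mindspore.ops.operations.nn_ops instead of resolving them through the P alias. A minimal usage sketch of the mean-reduction wrapper, reusing the tensor values from the test case later in this commit (illustrative only; it assumes a MindSpore build that includes this refactor):

import numpy as np
import mindspore
from mindspore import Tensor, nn

loss = nn.MultiMarginLoss(p=1, margin=1.0, reduction='mean')
x = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
target = Tensor(np.array([0, 0]), mindspore.int64)
# weight is optional: construct(self, x, target, weight=None)
output = loss(x, target)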
@ -40,7 +40,6 @@ from ..operations.array_ops import Expand
from ..operations.array_ops import SegmentMean
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
from .._utils.utils import is_shape_unknown
from ..operations import _grad_ops as G
@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,12 +24,16 @@ from .._grad.grad_base import bprop_getters
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations.nn_ops import MaxUnpool2D
from ..operations.nn_ops import MaxUnpool3D
from ..operations.nn_ops import FractionalMaxPool
from ..operations._grad_ops import FractionalMaxPoolGrad
from ..operations.nn_ops import FractionalMaxPool3DWithFixedKsize
from ..operations._grad_ops import FractionalMaxPool3DGradWithFixedKsize
from ..operations.nn_ops import FractionalAvgPool
from ..operations._grad_ops import FractionalAvgPoolGrad
from ..operations.nn_ops import MultiMarginLoss
from ..operations.nn_ops import MultilabelMarginLoss
from ..operations.nn_ops import NthElement
from ..operations.nn_ops import PSROIPooling
from ..operations._grad_ops import PSROIPoolingGrad
@ -92,7 +96,7 @@ def get_bprop_hshrink(self):
return bprop


@bprop_getters.register(P.MultilabelMarginLoss)
@bprop_getters.register(MultilabelMarginLoss)
def get_bprop_multilabel_margin_loss(self):
"""Grad definition for `MultilabelMarginLoss` operation."""
input_grad = G.MultilabelMarginLossGrad(reduction=self.reduction)
@ -120,7 +124,7 @@ def get_bprop_celu(self):
return bprop


@bprop_getters.register(P.MultiMarginLoss)
@bprop_getters.register(MultiMarginLoss)
def get_bprop_multi_margin_loss(self):
"""Grad definition for `MultiMarginLoss` operation."""
input_grad = G.MultiMarginLossGrad(p=self.p, margin=self.margin, reduction=self.reduction)
@ -155,7 +159,7 @@ def get_bprop_relu(self):
return bprop


@bprop_getters.register(P.MaxUnpool2D)
@bprop_getters.register(MaxUnpool2D)
def get_bprop_maxunpool2d(self):
"""Grad definition for `MaxUnpool2D` operation."""
maxunpool2d_grad = G.MaxUnpool2DGrad(
@ -173,7 +177,7 @@ def get_bprop_maxunpool2d(self):
return bprop


@bprop_getters.register(P.MaxUnpool3D)
@bprop_getters.register(MaxUnpool3D)
def get_bprop_maxunpool3d(self):
"""Grad definition for `MaxUnpool3D` operation."""
maxunpool3d_grad = G.MaxUnpool3DGrad(
@ -188,6 +192,8 @@ def get_bprop_maxunpool3d(self):
dargmax = zeros_like(argmax)
return (dx, dargmax)

return bprop


@bprop_getters.register(NthElement)
def get_bprop_nth_element(self):
@ -29,6 +29,7 @@ bartlett_window_op_info = AiCPURegOp("BartlettWindow") \
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(bartlett_window_op_info)
def _bartlett_window_aicpu():
"""BartlettWindow AiCPU register"""
@ -50,6 +50,7 @@ max_unpool2d_op_info = AiCPURegOp("MaxUnpool2D") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(max_unpool2d_op_info)
def _max_unpool2d_aicpu():
"""MaxUnpool2D aicpu register"""
@ -51,6 +51,7 @@ max_unpool2d_grad_op_info = AiCPURegOp("MaxUnpool2DGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(max_unpool2d_grad_op_info)
def _max_unpool2d_grad_aicpu():
"""MaxUnpool2DGrad aicpu register"""
@ -50,6 +50,7 @@ max_unpool3d_op_info = AiCPURegOp("MaxUnpool3D") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(max_unpool3d_op_info)
def _max_unpool3d_aicpu():
"""MaxUnpool3D aicpu register"""
@ -51,6 +51,7 @@ max_unpool3d_grad_op_info = AiCPURegOp("MaxUnpool3DGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(max_unpool3d_grad_op_info)
def _max_unpool3d_grad_aicpu():
"""MaxUnpool3DGrad aicpu register"""
@ -30,6 +30,7 @@ multi_margin_loss_op_info = AiCPURegOp("MultiMarginLoss") \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(multi_margin_loss_op_info)
def _multi_margin_loss_aicpu():
"""MultiMarginLoss aicpu register"""
@ -34,6 +34,7 @@ multi_margin_loss_grad_op_info = AiCPURegOp("MultiMarginLossGrad") \
DataType.F64_Default) \
.get_op_info()


@op_info_register(multi_margin_loss_grad_op_info)
def _multi_margin_loss_grad_aicpu():
"""MultiMarginLossGrad aicpu register"""
@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@ multilabel_margin_loss_grad_op_info = AiCPURegOp("MultilabelMarginLossGrad") \
DataType.F32_Default) \
.get_op_info()


@op_info_register(multilabel_margin_loss_grad_op_info)
def _multilabel_margin_loss_grad_aicpu():
"""MultilabelMarginLossGrad aicpu register"""
@ -25,6 +25,7 @@ mvlgamma_op_info = AiCPURegOp("Mvlgamma") \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.get_op_info()


@op_info_register(mvlgamma_op_info)
def _mvlgamma_aicpu():
"""Mvlgamma AiCPU register"""
@ -26,6 +26,7 @@ mvlgamma_grad_op_info = AiCPURegOp("MvlgammaGrad") \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,) \
.get_op_info()


@op_info_register(mvlgamma_grad_op_info)
def _mvlgamma_grad_aicpu():
"""MvlgammaGrad AiCPU register"""
@ -591,7 +591,6 @@ from .extract_volume_patches import _extract_volume_patches_tbe
from .multilabel_margin_loss import _multilabel_margin_loss_tbe
from .round_ds import _round_ds_tbe
from .is_close import _is_close_tbe
from .multilabel_margin_loss import _multilabel_margin_loss_tbe
from .apply_adam_with_amsgrad import _apply_adam_with_amsgrad_tbe
from .apply_adam_with_amsgrad_ds import _apply_adam_with_amsgrad_ds_tbe
from .expm1_ds import _expm1_ds_tbe
@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -29,7 +29,8 @@ from ...common._decorator import deprecated
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register


def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False,
ret_four=False, strict_positive=True):
"""
Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements.
"""
@ -54,8 +55,11 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
ret_value = _get_return_value()
for item in ret_value:
if isinstance(item, int) and not isinstance(item, bool) and item > 0:
continue
if isinstance(item, int) and not isinstance(item, bool):
if item > 0:
continue
if not strict_positive and item == 0:
continue
_raise_message()
return ret_value

@ -2117,7 +2121,7 @@ class MaxUnpool2D(Primitive):
if strides in (0, (0, 0)):
strides = ksize
self.strides = _check_positive_int_or_tuple('strides', strides, self.name, ret_four=True)
self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, greater_zero=False)
self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, strict_positive=False)
self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)

if data_format == "NHWC":
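The new strict_positive flag relaxes the all-positive requirement so that pads may contain zeros while ksize and strides stay strictly positive. A standalone sketch of the intended behavior (check_items is a hypothetical re-implementation for this note, not the validator itself):

def check_items(items, strict_positive=True):
    # Mirrors the loop added above: ints (excluding bools) must be > 0,
    # or == 0 when strict_positive is False; anything else is rejected.
    for item in items:
        if isinstance(item, int) and not isinstance(item, bool):
            if item > 0:
                continue
            if not strict_positive and item == 0:
                continue
        raise ValueError("invalid element: %r" % (item,))

check_items((2, 2, 2, 2))                         # ksize/strides: all > 0
check_items((0, 0, 0, 0), strict_positive=False)  # pads may now be zero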
@ -84,6 +84,13 @@ from mindspore.ops.operations.nn_ops import FractionalAvgPool
|
|||
from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad
|
||||
from mindspore.ops.operations.nn_ops import GridSampler2D
|
||||
from mindspore.ops.operations.nn_ops import GridSampler3D
|
||||
from mindspore.ops.operations.nn_ops import MaxUnpool2D
|
||||
from mindspore.ops.operations.nn_ops import MaxUnpool3D
|
||||
from mindspore.nn.loss.loss import MultiMarginLoss
|
||||
from mindspore.nn.loss.loss import MultilabelMarginLoss
|
||||
from mindspore.nn.loss.loss import TripletMarginLoss
|
||||
from mindspore.ops.operations.array_ops import Mvlgamma
|
||||
from mindspore.ops.operations.other_ops import BartlettWindow
|
||||
from mindspore.ops.operations.nn_ops import NthElement
|
||||
from mindspore.ops.operations.nn_ops import SparseApplyAdagradDA
|
||||
from mindspore.ops.operations.nn_ops import PSROIPooling
|
||||
|
@ -2352,27 +2359,15 @@ test_case_nn_ops = [
        'desc_inputs': [[10, 3, 28, 31, 24]],
        'desc_bprop': [[10, 3, 14, 16, 12]]}),
    ('MaxUnpool2D', {
        'block': P.MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)),
        'block': MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)),
        'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}),
                        ([4, 3, 6, 6], {'dtype': np.int64})],
        'desc_bprop': [([4, 3, 10, 10], {'dtype': np.float32})]}),
    ('MaxUnpool2DGrad', {
        'block': G.MaxUnpool2DGrad(ksize=(1, 1, 4, 4), strides=(1, 1, 2, 2), pads=(1, 1, 2, 2)),
        'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}),
                        ([4, 3, 10, 10], {'dtype': np.float32}),
                        ([4, 3, 6, 6], {'dtype': np.int64})],
        'skip': ['backward']}),
    ('MaxUnpool3D', {
        'block': P.MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)),
        'block': MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)),
        'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}),
                        ([4, 3, 6, 6, 5], {'dtype': np.int64})],
        'desc_bprop': [([4, 3, 10, 10, 8], {'dtype': np.float32})]}),
    ('MaxUnpool3DGrad', {
        'block': G.MaxUnpool3DGrad(ksize=(1, 1, 4, 4, 4), strides=(1, 1, 2, 2, 2), pads=(1, 1, 2, 2, 2)),
        'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}),
                        ([4, 3, 10, 10, 8], {'dtype': np.float32}),
                        ([4, 3, 6, 6, 5], {'dtype': np.int64})],
        'skip': ['backward']}),
    ('MaxPoolWithArgmax', {
        'block': P.MaxPoolWithArgmax(kernel_size=2, strides=2),
        'desc_inputs': [[128, 32, 32, 64]],
@ -2752,18 +2747,10 @@ test_case_nn_ops = [
        'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))],
        'skip': ['backward']}),
    ('MultiMarginLoss', {
        'block': nn.MultiMarginLoss(reduction="mean"),
        'block': MultiMarginLoss(reduction="mean"),
        'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
                        Tensor(np.array([0, 0]).astype(np.int64))],
        'desc_bprop': [[1]]}),
    ('MultiMarginLossGrad', {
        'block': G.MultiMarginLossGrad(),
        'desc_inputs': [Tensor(np.array([1]).astype(np.float32)),
                        Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
                        Tensor(np.array([1, 1]).astype(np.int64)),
                        Tensor(np.array([1, 1]).astype(np.float32))],
        'desc_bprop': [Tensor([1], mstype.float32)],
        'skip': ['backward']}),
    ('L2Loss_1', {
        'block': P.L2Loss(),
        'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
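For reference, a NumPy sketch of the multi-class margin loss these cases exercise, assuming the common defaults p=1, margin=1.0 and no class weights:

    import numpy as np

    def multi_margin_loss(x, target, margin=1.0):
        # x: (N, C) scores, target: (N,) class indices
        n, c = x.shape
        correct = x[np.arange(n), target][:, None]     # (N, 1) score of the true class
        hinge = np.maximum(0.0, margin - correct + x)  # (N, C) hinge terms
        hinge[np.arange(n), target] = 0.0              # the i == target term is skipped
        return hinge.sum(axis=1) / c                   # per-sample loss

    x = np.array([[0.3, 0.7], [0.5, 0.5]], np.float32)
    t = np.array([0, 0])
    print(multi_margin_loss(x, t).mean())              # reduction="mean" -> 0.6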
@ -2903,7 +2890,7 @@ test_case_nn_ops = [
                        Tensor(np.array([[-4, -3, -2], [1, 2, 4]]), mstype.float16)],
        'skip': ['backward']}),
    ('TripletMarginLoss', {
        'block': P.TripletMarginLoss(reduction="none"),
        'block': TripletMarginLoss(reduction="none"),
        'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
                        Tensor(np.array([[0.4, 0.6], [0.4, 0.6]]).astype(np.float32)),
                        Tensor(np.array([[0.2, 0.9], [0.3, 0.7]]).astype(np.float32)),
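The triplet loss itself is compact enough to restate; a NumPy sketch assuming the usual defaults margin=1.0 and the p=2 (Euclidean) distance:

    import numpy as np

    def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2):
        # one loss value per row, matching reduction="none"
        d_pos = np.linalg.norm(anchor - positive, ord=p, axis=-1)
        d_neg = np.linalg.norm(anchor - negative, ord=p, axis=-1)
        return np.maximum(0.0, d_pos - d_neg + margin)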
@ -2952,10 +2939,11 @@ test_case_nn_ops = [
                        Tensor(0.99, mstype.float32)],
        'skip': ['backward']}),
    ('MultilabelMarginLoss', {
        'block': P.MultilabelMarginLoss(reduction="none"),
        'block': MultilabelMarginLoss(reduction="none"),
        'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.3, 0.4]]).astype(np.float32)),
                        Tensor(np.array([[2, 1, -1, 1], [1, -1, 2, 1]]).astype(np.int32))],
        'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32))]}),
        'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32)),
                       Tensor(np.array([[1, 1, 2, 1], [1, 1, 2, 1]]).astype(np.int32))]}),
    ('GridSampler3D', {
        'block': GridSampler3D(interpolation_mode='bilinear', padding_mode='zeros', align_corners=False),
        'desc_inputs': [Tensor(np.arange(32).reshape((2, 2, 2, 2, 2)).astype(np.float32)),
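The extra sens tensor added to desc_bprop suggests the refactored MultilabelMarginLoss returns an auxiliary is_target output alongside the loss, so the bprop needs one sensitivity per output. A NumPy sketch of the loss itself, assuming the PyTorch-style convention that each target row lists class indices and ends at the first -1:

    import numpy as np

    def multilabel_margin_loss(x, target):
        # x: (N, C) scores; target: (N, C) int, indices before the first -1 are valid
        n, c = x.shape
        out = np.zeros(n, np.float32)
        for k in range(n):
            stop = np.where(target[k] < 0)[0]
            idx = target[k][:stop[0]] if stop.size else target[k]
            rest = np.setdiff1d(np.arange(c), idx)  # non-target classes
            for j in idx:
                out[k] += np.maximum(0.0, 1.0 - (x[k, j] - x[k, rest])).sum()
        return out / c                              # reduction="none"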
@ -2983,7 +2971,6 @@ test_case_nn_ops = [
                        Tensor(1, mstype.int64)],
        'skip': ['backward']}),
]

test_case_array_ops = [
    ('LeftShift', {
        'block': LeftShift(),
@ -3158,15 +3145,10 @@ test_case_array_ops = [
        'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))],
        'skip': ['backward']}),
    ('Mvlgamma', {
        'block': P.Mvlgamma(p=1),
        'block': Mvlgamma(p=1),
        'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))],
        'desc_bprop': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))]
        }),
    ('MvlgammaGrad', {
        'block': G.MvlgammaGrad(p=1),
        'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32)),
                        Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))],
        'skip': ['backward']}),
    ('ConcatV2_0', {
        'block': NetForConcat1(),
        'desc_inputs': [
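Mvlgamma is the multivariate log-gamma function; with p=1 it reduces to the ordinary lgamma, which is why the test shapes pass through elementwise. A SciPy-based sketch of the definition (scipy is assumed only for this illustration):

    import numpy as np
    from scipy.special import gammaln

    def mvlgamma(a, p):
        # log Gamma_p(a) = p(p-1)/4 * ln(pi) + sum_{i=1..p} lgamma(a + (1 - i) / 2)
        a = np.asarray(a, np.float64)
        out = np.full_like(a, p * (p - 1) / 4.0 * np.log(np.pi))
        for i in range(1, p + 1):
            out += gammaln(a + (1 - i) / 2.0)
        return out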
@ -3506,6 +3488,7 @@ test_case_array_ops = [
        }),
]


test_case_image_ops = [
    ('AdjustHue', {
        'block': AdjustHue(),
@ -3597,10 +3580,10 @@ test_case_other_ops = [
                        Tensor(np.array([[[0.38, 0.17, 0.95, 0.40]]], np.float32)),
                        Tensor(np.array([0.8], np.float32))),
        'skip': ['backward']}),
    ('BartlettWindow', {
        'block': P.BartlettWindow(periodic=True, dtype=mstype.float32),
        'desc_inputs': (Tensor(np.array([10], np.int32))),
        'skip': ['backward']}),
    ('BartlettWindow', {
        'block': BartlettWindow(periodic=True, dtype=mstype.float32),
        'desc_inputs': (Tensor(np.array([10], np.int32))),
        'skip': ['backward']}),
    ('GatherNd', {
        'block': P.GatherNd(),
        'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),
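BartlettWindow produces a triangular window from a scalar window_length. A NumPy reference for the expected values, assuming the standard convention that the periodic variant is the (length + 1)-point symmetric window with the last sample dropped:

    import numpy as np

    def bartlett_window(length, periodic=True, dtype=np.float32):
        if length == 1:
            return np.ones(1, dtype)
        n = length + 1 if periodic else length
        w = 1.0 - np.abs(2.0 * np.arange(n) / (n - 1) - 1.0)  # triangular ramp
        return w[:length].astype(dtype)

    # the symmetric case matches NumPy's built-in Bartlett window
    np.testing.assert_allclose(bartlett_window(10, periodic=False), np.bartlett(10), rtol=1e-6)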