forked from mindspore-Ecosystem/mindspore
!37125 [feat][assistant][I4XJGX,I4XJGY]add NLLLoss, NLLLossGrad
Merge pull request !37125 from 李定维/NLLLoss,NLLLossGrad
This commit is contained in:
commit
01623bc3dd
|
@ -27,7 +27,7 @@ mindspore.ops.NLLLoss
|
|||
|
||||
输入:
|
||||
- **logits** (Tensor) - 输入预测值,shape为 :math:`(N, C)` 。数据类型仅支持float32或float16。
|
||||
- **labels** (Tensor) - 输入目标值,shape为 :math:`(N,)` 。数据类型仅支持int32。
|
||||
- **labels** (Tensor) - 输入目标值,shape为 :math:`(N,)` ,取值范围为 :math:`[0, C-1]` 。数据类型仅支持int32或int64。
|
||||
- **weight** (Tensor) - 指定各类别的权重,shape为 :math:`(C,)` ,数据类型仅支持float32或float16。
|
||||
|
||||
输出:
|
||||
|
@ -37,6 +37,7 @@ mindspore.ops.NLLLoss
|
|||
- **total_weight** (Tensor) - `total_weight` 是scalar,数据类型与 `weight` 相同。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `logits` 或 `weight` 的数据类型既不是float16也不是float32, `labels` 不是int32。
|
||||
- **TypeError** - `logits` 或 `weight` 的数据类型既不是float16也不是float32。
|
||||
- **TypeError** - `labels` 的数据类型既不是int32也不是int64。
|
||||
- **ValueError** - `logits` 不是二维Tensor, `labels` 和 `weight` 不是一维Tensor。 `logits` 的第一个维度不等于 `labels` , `logits` 的第二个维度不等于 `weight` 。
|
||||
- **ValueError** - `labels` 的取值超出 :math:`[0, C-1]` ,其中 :math:`C` 表示类的数量。
|
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/cpu/kernel/nllloss_cpu_kernel.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "mindspore/core/ops/nllloss.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
|
@ -41,11 +42,12 @@ bool NLLLossCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
|
|||
kernel_name_ = kernel_ptr->GetPrim()->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
|
||||
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
||||
auto reduction = kernel_ptr->get_reduction();
|
||||
auto pair = kReductionMap.find(reduction);
|
||||
|
@ -72,14 +74,15 @@ int NLLLossCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
|
|||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
template <typename T>
|
||||
bool NLLLossCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(kNLLLossInputsNum, inputs.size(), kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(kNLLLossOutputsNum, outputs.size(), kernel_name_);
|
||||
|
||||
const auto *logits = reinterpret_cast<float *>(inputs[kIndex0]->addr);
|
||||
const auto *labels = reinterpret_cast<int *>(inputs[kIndex1]->addr);
|
||||
const auto *labels = reinterpret_cast<T *>(inputs[kIndex1]->addr);
|
||||
const auto *weight = reinterpret_cast<float *>(inputs[kIndex2]->addr);
|
||||
auto *loss = reinterpret_cast<float *>(outputs[kIndex0]->addr);
|
||||
auto *total_weight = reinterpret_cast<float *>(outputs[kIndex1]->addr);
|
||||
|
@ -93,13 +96,53 @@ bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
}
|
||||
}
|
||||
|
||||
int ret = NLLLoss(logits, labels, weight, loss, total_weight, &nllloss_param_);
|
||||
if (ret != static_cast<int>(NNACL_OK)) {
|
||||
MS_LOG(EXCEPTION) << "Launch " << kernel_name_ << " failed, the nnacl error code " << ret;
|
||||
if (logits == NULL || labels == NULL || weight == NULL) {
|
||||
MS_LOG(ERROR) << "For NLLLoss, it does not support NULL input";
|
||||
}
|
||||
|
||||
float total_loss = 0.0;
|
||||
float tmp_total_weight = 0.0;
|
||||
ReductionType reduction_type = nllloss_param_.reduction_type_;
|
||||
for (int i = 0; i < nllloss_param_.batch_; i++) {
|
||||
if (!(labels[i] < nllloss_param_.class_num_)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', the labels should be smaller than the number of classes, but got " << labels[i];
|
||||
}
|
||||
int index = i * nllloss_param_.class_num_ + labels[i];
|
||||
float n_weight = weight[labels[i]];
|
||||
float n_loss = -logits[index] * n_weight;
|
||||
tmp_total_weight += n_weight;
|
||||
total_loss += n_loss;
|
||||
if (reduction_type == Reduction_None) {
|
||||
loss[i] = n_loss;
|
||||
}
|
||||
}
|
||||
|
||||
*total_weight = tmp_total_weight;
|
||||
if (reduction_type == Reduction_Sum) {
|
||||
*loss = total_loss;
|
||||
} else if (reduction_type == Reduction_Mean) {
|
||||
*loss = total_loss / tmp_total_weight;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, NLLLossCpuKernelMod::NLLLossFunc>> NLLLossCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&NLLLossCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&NLLLossCpuKernelMod::LaunchKernel<int64_t>}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NLLLoss, NLLLossCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "nnacl/fp32/nllloss_fp32.h"
|
||||
|
@ -37,7 +38,9 @@ class NLLLossCpuKernelMod : public NativeCpuKernelMod {
|
|||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
|
@ -46,11 +49,25 @@ class NLLLossCpuKernelMod : public NativeCpuKernelMod {
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using NLLLossFunc =
|
||||
std::function<bool(NLLLossCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
NLLLossFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, NLLLossFunc>> func_list_;
|
||||
NLLLossParameter nllloss_param_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "mindspore/core/ops/grad/nllloss_grad.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
@ -40,10 +41,11 @@ bool NLLLossGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
|
|||
auto kernel_name = kernel_ptr->GetPrim()->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
|
||||
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << kernel_name << " does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
|
||||
auto reduction = kernel_ptr->get_reduction();
|
||||
|
||||
|
@ -71,26 +73,60 @@ int NLLLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool NLLLossGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
template <typename T>
|
||||
bool NLLLossGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(kNLLLossGradInputsNum, inputs.size(), kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(kNLLLossGradOutputsNum, outputs.size(), kernel_name_);
|
||||
|
||||
const auto *logits = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
const auto *loss_grad = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
const auto *labels = reinterpret_cast<int *>(inputs[2]->addr);
|
||||
const auto *labels = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
const auto *weight = reinterpret_cast<float *>(inputs[3]->addr);
|
||||
const auto *total_weight = reinterpret_cast<float *>(inputs[4]->addr);
|
||||
auto *logits_grad = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
|
||||
int ret = NLLLossGrad(logits, loss_grad, labels, weight, total_weight, logits_grad, &nllloss_param_);
|
||||
if (ret != static_cast<int>(NNACL_OK)) {
|
||||
MS_LOG(EXCEPTION) << "Launch " << kernel_name_ << " failed, the nnacl error code " << ret;
|
||||
if (logits == NULL || loss_grad == NULL || labels == NULL || weight == NULL || total_weight == NULL) {
|
||||
MS_LOG(ERROR) << "For NLLLossGrad, it does not support NULL input";
|
||||
}
|
||||
memset(logits_grad, 0, nllloss_param_.batch_ * nllloss_param_.class_num_ * sizeof(float));
|
||||
for (int i = 0; i < nllloss_param_.batch_; i++) {
|
||||
if (!(labels[i] < nllloss_param_.class_num_)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', the labels should be smaller than the number of classes, but got " << labels[i];
|
||||
}
|
||||
int index = i * nllloss_param_.class_num_ + labels[i];
|
||||
float n_weight = weight[labels[i]];
|
||||
if (nllloss_param_.reduction_type_ == Reduction_Sum) {
|
||||
logits_grad[index] = -loss_grad[0] * n_weight;
|
||||
} else if (nllloss_param_.reduction_type_ == Reduction_Mean) {
|
||||
logits_grad[index] = -loss_grad[0] * n_weight / *total_weight;
|
||||
} else {
|
||||
logits_grad[index] = -loss_grad[i] * n_weight;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, NLLLossGradCpuKernelMod::NLLLossGradFunc>> NLLLossGradCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&NLLLossGradCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&NLLLossGradCpuKernelMod::LaunchKernel<int64_t>}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NLLLossGrad, NLLLossGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "nnacl/fp32_grad/nllloss_grad_fp32.h"
|
||||
|
@ -37,7 +38,9 @@ class NLLLossGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
|
@ -47,11 +50,26 @@ class NLLLossGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using NLLLossGradFunc =
|
||||
std::function<bool(NLLLossGradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
NLLLossGradFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, NLLLossGradFunc>> func_list_;
|
||||
NLLLossParameter nllloss_param_{};
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -117,11 +117,11 @@ class NLLLossGradInfer : public abstract::OpInferBase {
|
|||
auto t_dtype = input_args[kInputIndex2]->BuildType();
|
||||
auto w_dtype = input_args[kInputIndex3]->BuildType();
|
||||
auto tw_dtype = input_args[kInputIndex4]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits dtype", x_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad dtype", y_grad_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels dtype", t_dtype, {kInt32}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight dtype", w_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight dtype", tw_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", x_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad", y_grad_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels", t_dtype, {kInt32, kInt64}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", w_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight", tw_dtype, valid_types, prim_name);
|
||||
CheckAndConvertUtils::Check("weight dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
|
||||
std::vector<TypeId>{w_dtype->type_id()}, prim_name);
|
||||
return x_dtype;
|
||||
|
@ -150,6 +150,6 @@ Reduction NLLLossGrad::get_reduction() const {
|
|||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(NLLLossGrad, BaseOperator);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(NLLLossGrad, std::make_shared<Primitive>("NLLLossGrad"), NLLLossGradInfer, false);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(NLLLossGrad, prim::kPrimNLLLossGrad, NLLLossGradInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -86,7 +86,7 @@ class NLLLossInfer : public abstract::OpInferBase {
|
|||
auto logits_data_type = input_args[kIndex0]->BuildType();
|
||||
auto target_type = input_args[kIndex1]->BuildType();
|
||||
auto weight_data_type = input_args[kIndex2]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32, kInt64}, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", logits_data_type, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_data_type, valid_types, prim->name());
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, weight_data_type});
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""NLLLoss op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
nll_loss_op_info = AiCPURegOp("NLLLoss") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("reduction", "str") \
|
||||
.attr("ignore_index", "int") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "target", "required") \
|
||||
.input(2, "weight", "optional") \
|
||||
.output(0, "y", "required") \
|
||||
.output(1, "total_weight", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(nll_loss_op_info)
|
||||
def _nll_loss_aicpu():
|
||||
"""NLLLoss aicpu register"""
|
||||
return
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""NLLLossGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
nll_loss_grad_op_info = AiCPURegOp("NLLLossGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("reduction", "str") \
|
||||
.attr("ignore_index", "int") \
|
||||
.input(0, "x", "required") \
|
||||
.input(1, "y_grad", "required") \
|
||||
.input(2, "target", "required") \
|
||||
.input(3, "weight", "require") \
|
||||
.input(4, "total_weight", "require") \
|
||||
.output(0, "x_grad", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(nll_loss_grad_op_info)
|
||||
def _nll_loss_grad_aicpu():
|
||||
"""NLLLossGrad aicpu register"""
|
||||
return
|
|
@ -2667,7 +2667,8 @@ class NLLLoss(Primitive):
|
|||
|
||||
Inputs:
|
||||
- **logits** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type only supports float32 or float16.
|
||||
- **labels** (Tensor) - Ground truth labels, with shape :math:`(N,)`. Data type only supports int32.
|
||||
- **labels** (Tensor) - Ground truth labels, with shape :math:`(N,)`, where each value belong to
|
||||
:math:`[0, C-1]`. Data type only supports int32 or int64.
|
||||
- **weight** (Tensor) - The rescaling weight to each class, with shape :math:`(C,)` and data type only
|
||||
supports float32 or float16.
|
||||
|
||||
|
@ -2679,13 +2680,15 @@ class NLLLoss(Primitive):
|
|||
- **total_weight** (Tensor) - The `total_weight` is a scalar. The data type is the same with `weight's`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `logits` or `weight` is neither float16 nor float32, `labels` is not int32.
|
||||
TypeError: If dtype of `logits` or `weight` is neither float16 nor float32.
|
||||
TypeError: If dtype of `labels` is neither int32 nor int64.
|
||||
ValueError: If `logits` is not a one or two dimension tensor, `labels` and `weight` are not
|
||||
one dimension tensors.
|
||||
When `logits` is a two dimension tensor, the first dimension of `logits` is not equal to `labels`,
|
||||
and second dimension of `logits` is not equal to `weight`.
|
||||
When `logits` is a one dimension tensor, the dimensions of `logits`, `labels`
|
||||
and `weight` should be equal to each other.
|
||||
ValueError: If the value of `labels` exceed :math:`[0, C-1]`, where :math:`C` is the number of classes.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
|
@ -2089,6 +2089,14 @@ test_case_math_ops = [
|
|||
Tensor(np.random.rand(3), mstype.int32),
|
||||
Tensor(np.random.rand(16), mstype.float32)],
|
||||
'desc_bprop': [(Tensor(np.random.rand(1), mstype.float32), Tensor(np.random.rand(1), mstype.float32))]}),
|
||||
('NLLLossGrad', {
|
||||
'block': G.NLLLossGrad(reduction="mean"),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 16), mstype.float32),
|
||||
Tensor(np.random.rand(1), mstype.float32),
|
||||
Tensor(np.random.rand(3), mstype.int32),
|
||||
Tensor(np.random.rand(16), mstype.float32),
|
||||
Tensor(np.random.rand(1), mstype.float32)],
|
||||
'skip': ['backward']}),
|
||||
('BatchNorm3d', {
|
||||
'block': BatchNorm3d(num_features=3),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))],
|
||||
|
|
Loading…
Reference in New Issue