!42380 Add dynamic shape support for NLLLoss and related operators

Merge pull request !42380 from panzhihui/nllloss_dyn_r1.9
i-robot 2022-09-22 08:33:04 +00:00 committed by Gitee
commit 6414523ead
25 changed files with 1085 additions and 589 deletions
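
This change migrates the MatrixInverse, NLLLoss and NLLLossGrad kernels (CPU and GPU) from the deprecated InitKernel(CNodePtr) entry point to the Init/Resize/Launch interface, and moves the EmbeddingLookup/NLLLossGrad shape inference to the OpInferBase style, so shapes are re-read on every Resize call and dynamic shapes can be handled. A minimal lifecycle sketch, using only standard C++ and none of the actual MindSpore types (DynShapeKernelSketch and its members are illustrative names, not part of the framework):

// Minimal sketch (not the actual MindSpore interface): Init() parses shape-independent
// attributes once, Resize() re-reads shapes whenever they change, Launch() only uses
// the cached values.
#include <cstddef>
#include <cstdint>
#include <vector>

struct DynShapeKernelSketch {
  bool adjoint_ = false;              // static attribute, read once in Init()
  std::vector<int64_t> input_shape_;  // refreshed on every Resize()
  size_t batch_size_ = 1;

  bool Init(bool adjoint) {  // attribute parsing only, no shapes involved
    adjoint_ = adjoint;
    return true;
  }

  int Resize(const std::vector<int64_t> &shape) {  // called each time the shape changes
    input_shape_ = shape;
    batch_size_ = 1;
    for (size_t i = 0; i + 2 < shape.size(); ++i) {  // all dims except the trailing matrix dims
      batch_size_ *= static_cast<size_t>(shape[i]);
    }
    return 0;  // 0 plays the role of KRET_OK
  }

  bool Launch() const { return batch_size_ > 0; }  // consumes cached state only
};

Attributes such as adjoint or reduction stay in Init; everything derived from a shape (batch size, class count, workspace sizes) moves into Resize, which is the split the hunks below apply.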


@@ -15,9 +15,11 @@
*/
#include "plugin/device/cpu/kernel/matrix_inverse_cpu_kernel.h"
#include <map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "Eigen/Core"
#include "Eigen/LU"
#include "mindspore/core/ops/matrix_inverse.h"
namespace mindspore {
namespace kernel {
@@ -29,19 +31,22 @@ static constexpr int kNumber2 = 2;
constexpr size_t kParallelDataNums = 1 * 1024;
} // namespace
void MatrixInverseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
node_wpt_ = kernel_node;
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
bool MatrixInverseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::MatrixInverse>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast " << kernel_name_ << " ops failed!";
return false;
}
dtype_ = inputs[kIndex0]->GetDtype();
adjoint_ = kernel_ptr->get_adjoint();
return true;
}
bool MatrixInverseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputSize, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputSize, kernel_name_);
if (dtype_ == kNumberTypeFloat32) {
LaunchMatrixInverse<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
@@ -56,43 +61,50 @@ bool MatrixInverseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
return true;
}
int MatrixInverseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputSize, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputSize, kernel_name_);
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
// Judge whether the input shape matches
input_shape_ = inputs[kIndex0]->GetShapeVector();
if (input_shape_.size() < kNumber2) {
MS_LOG(EXCEPTION) << "Input x must be at least rank 2.";
}
if (input_shape_[input_shape_.size() - kNumber1] != input_shape_[input_shape_.size() - kNumber2]) {
MS_LOG(EXCEPTION) << "The last two dimensions of Input x must be equal.";
}
return KRET_OK;
}
template <typename T>
void MatrixInverseCpuKernelMod::LaunchMatrixInverse(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto node_ = node_wpt_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
}
T *input_ptr = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input_ptr);
T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_ptr);
// Judge whether the input shape matches
auto shape = Convert2SizeT(common::AnfAlgo::GetPrevNodeOutputInferShape(node_, 0));
if (shape.size() < kNumber2) {
MS_LOG(EXCEPTION) << "Input x must be at least rank 2.";
}
if (shape[shape.size() - kNumber1] != shape[shape.size() - kNumber2]) {
MS_LOG(EXCEPTION) << "The last two dimensions of Input x must be equal.";
}
auto last_dimsize = shape[shape.size() - 1];
auto last_dimsize = LongToSize(input_shape_[input_shape_.size() - 1]);
// Output length
size_t input_num = 1;
for (size_t i = 0; i < shape.size(); i++) {
input_num *= shape[i];
for (size_t i = 0; i < input_shape_.size(); i++) {
input_num *= input_shape_[i];
}
size_t matrix_size = last_dimsize * last_dimsize;
auto matrix_size = last_dimsize * last_dimsize;
// Number of matrices
size_t matrix_num = input_num / matrix_size;
auto matrix_num = input_num / matrix_size;
// Store two-dimensional array of data for slicing
std::vector<std::vector<T>> temp(matrix_num, std::vector<T>(matrix_size));
for (size_t i = 0; i < matrix_num; i++) {
for (size_t j = 0; j < matrix_size; j++) {
temp[i][j] = *(input_ptr + i * matrix_size + j);
}
}
// Gets the value of the property adjoint
adjoint_ = common::AnfAlgo::GetNodeAttr<bool>(node_, "adjoint");
auto one_size = sizeof(*input_ptr);
if ((one_size * input_num) <= kParallelDataNums) {
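
For reference, the per-matrix work this CPU kernel performs can be summarized as: flatten the input into matrix_num square blocks of size last_dimsize x last_dimsize and invert each block. A standalone Eigen sketch of that loop (BatchedInverse is an illustrative helper, not the kernel's actual code):

// Standalone Eigen sketch of the batched inversion; independent of the kernel classes above.
#include <cstddef>
#include "Eigen/Core"
#include "Eigen/LU"

template <typename T>
void BatchedInverse(const T *in, T *out, size_t matrix_num, size_t n, bool adjoint) {
  using Mat = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  for (size_t i = 0; i < matrix_num; ++i) {
    Eigen::Map<const Mat> a(in + i * n * n, n, n);   // one square block of the input
    Eigen::Map<Mat> r(out + i * n * n, n, n);        // matching block of the output
    if (adjoint) {
      r = Mat(a.adjoint()).inverse();                // invert the (conjugate) transpose
    } else {
      r = a.inverse();
    }
  }
}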


@@ -17,17 +17,22 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_INVERSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_INVERSE_CPU_KERNEL_H_
#include <vector>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixInverseCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class MatrixInverseCpuKernelMod : public NativeCpuKernelMod {
public:
MatrixInverseCpuKernelMod() = default;
~MatrixInverseCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@@ -46,6 +51,7 @@ class MatrixInverseCpuKernelMod : public DeprecatedNativeCpuKernelMod {
TypeId dtype_{kTypeUnknown};
template <typename T>
void LaunchMatrixInverse(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
ShapeVector input_shape_;
};
} // namespace kernel
} // namespace mindspore


@@ -15,10 +15,9 @@
*/
#include "plugin/device/cpu/kernel/nllloss_cpu_kernel.h"
#include <map>
#include <string>
#include <unordered_map>
#include "mindspore/core/ops/nllloss.h"
#include "nnacl/errorcode.h"
namespace mindspore {
@@ -26,30 +25,50 @@ namespace kernel {
namespace {
constexpr size_t kNLLLossInputsNum = 3;
constexpr size_t kNLLLossOutputsNum = 2;
const std::unordered_map<std::string, ReductionType> kReductionMap = {
{MEAN, Reduction_Mean}, {SUM, Reduction_Sum}, {NONE, Reduction_None}};
const std::map<Reduction, ReductionType> kReductionMap = {
{Reduction::MEAN, Reduction_Mean}, {Reduction::REDUCTION_SUM, Reduction_Sum}, {Reduction::NONE, Reduction_None}};
} // namespace
void NLLLossCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
KernelAttr kernel_attr = GetKernelAttrFromNode(kernel_node);
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
bool NLLLossCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::NLLLoss>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast NLLLoss ops failed!";
return false;
}
auto logits_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto reduction = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, REDUCTION);
kernel_name_ = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
auto reduction = kernel_ptr->get_reduction();
auto pair = kReductionMap.find(reduction);
if (pair == kReductionMap.end()) {
MS_LOG(EXCEPTION) << "For " << kernel_name_
<< ", the attr 'reduction' only support 'mean', 'sum' and 'none', but got " << reduction;
}
nllloss_param_.batch_ = LongToInt(logits_shape[0]);
nllloss_param_.class_num_ = LongToInt(logits_shape[1]);
nllloss_param_.reduction_type_ = pair->second;
return true;
}
int NLLLossCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != 0) {
return ret;
}
auto logits_shape = inputs[kIndex0]->GetShapeVector();
nllloss_param_.batch_ = LongToInt(logits_shape[kIndex0]);
nllloss_param_.class_num_ = LongToInt(logits_shape[kIndex1]);
return KRET_OK;
}
bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -58,11 +77,11 @@ bool NLLLossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
CHECK_KERNEL_INPUTS_NUM(kNLLLossInputsNum, inputs.size(), kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(kNLLLossOutputsNum, outputs.size(), kernel_name_);
const auto *logits = reinterpret_cast<float *>(inputs[0]->addr);
const auto *labels = reinterpret_cast<int *>(inputs[1]->addr);
const auto *weight = reinterpret_cast<float *>(inputs[2]->addr);
auto *loss = reinterpret_cast<float *>(outputs[0]->addr);
auto *total_weight = reinterpret_cast<float *>(outputs[1]->addr);
const auto *logits = reinterpret_cast<float *>(inputs[kIndex0]->addr);
const auto *labels = reinterpret_cast<int *>(inputs[kIndex1]->addr);
const auto *weight = reinterpret_cast<float *>(inputs[kIndex2]->addr);
auto *loss = reinterpret_cast<float *>(outputs[kIndex0]->addr);
auto *total_weight = reinterpret_cast<float *>(outputs[kIndex1]->addr);
int ret = NLLLoss(logits, labels, weight, loss, total_weight, &nllloss_param_);
if (ret != static_cast<int>(NNACL_OK)) {
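
The Launch above delegates the arithmetic to nnacl's NLLLoss. As a reference for what that call computes over logits of shape (N, C), a plain-C++ sketch (NLLLossRef is illustrative, not the nnacl implementation; it assumes at least one selected weight is nonzero when 'mean' is used):

// Reference sketch only (not the nnacl code): loss_i = -weight[t_i] * logits[i, t_i],
// reduced per the 'reduction' attribute.
#include <cstdint>

enum class Red { None, Sum, Mean };

void NLLLossRef(const float *logits, const int32_t *labels, const float *weight,
                float *loss, float *total_weight, int batch, int class_num, Red red) {
  float loss_sum = 0.0f;
  float weight_sum = 0.0f;
  for (int i = 0; i < batch; ++i) {
    const int t = labels[i];                          // target class of sample i
    const float w = weight[t];
    const float l = -w * logits[i * class_num + t];   // negative log-likelihood term
    if (red == Red::None) loss[i] = l;                // per-sample losses, shape (N,)
    loss_sum += l;
    weight_sum += w;
  }
  *total_weight = weight_sum;
  if (red == Red::Sum) loss[0] = loss_sum;                 // scalar loss
  if (red == Red::Mean) loss[0] = loss_sum / weight_sum;   // weighted mean
}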


@@ -18,23 +18,28 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NLLLOSS_CPU_KERNEL_H_
#include <vector>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "nnacl/fp32/nllloss_fp32.h"
namespace mindspore {
namespace kernel {
class NLLLossCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class NLLLossCpuKernelMod : public NativeCpuKernelMod {
public:
NLLLossCpuKernelMod() = default;
~NLLLossCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)


@@ -15,10 +15,10 @@
*/
#include "plugin/device/cpu/kernel/nllloss_grad_cpu_kernel.h"
#include <map>
#include <string>
#include <unordered_map>
#include "mindspore/core/ops/grad/nllloss_grad.h"
#include "nnacl/errorcode.h"
namespace mindspore {
@@ -26,30 +26,49 @@ namespace kernel {
namespace {
constexpr size_t kNLLLossGradInputsNum = 5;
constexpr size_t kNLLLossGradOutputsNum = 1;
const std::unordered_map<std::string, ReductionType> kReductionMap = {
{MEAN, Reduction_Mean}, {SUM, Reduction_Sum}, {NONE, Reduction_None}};
const std::unordered_map<Reduction, ReductionType> kReductionMap = {
{Reduction::MEAN, Reduction_Mean}, {Reduction::REDUCTION_SUM, Reduction_Sum}, {Reduction::NONE, Reduction_None}};
} // namespace
void NLLLossGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
KernelAttr kernel_attr = GetKernelAttrFromNode(kernel_node);
bool NLLLossGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::NLLLossGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast NLLLossGrad ops failed!";
return false;
}
auto kernel_name = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
bool is_match = MatchKernelAttr(kernel_attr, GetOpSupport()).first;
if (!is_match) {
MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
MS_LOG(EXCEPTION) << kernel_name << " does not support this kernel data type: " << kernel_attr;
}
auto logits_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto reduction = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, REDUCTION);
auto reduction = kernel_ptr->get_reduction();
auto pair = kReductionMap.find(reduction);
if (pair == kReductionMap.end()) {
MS_LOG(EXCEPTION) << "For " << kernel_name_
<< ", the attr 'reduction' only support 'mean', 'sum' and 'none', but got " << reduction;
}
nllloss_param_.reduction_type_ = pair->second;
return true;
}
int NLLLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost)) != 0) {
return ret;
}
auto logits_shape = inputs[0]->GetShapeVector();
nllloss_param_.batch_ = LongToInt(logits_shape[0]);
nllloss_param_.class_num_ = LongToInt(logits_shape[1]);
nllloss_param_.reduction_type_ = pair->second;
return KRET_OK;
}
bool NLLLossGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
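
The matching backward kernel only has to scatter gradients into the target column of each row. A companion sketch (NLLLossGradRef is illustrative, not the nnacl code; dloss is a scalar for 'mean'/'sum' and an (N,) vector for 'none'):

// Companion sketch for the backward pass; Red is the same reduction enum as in the
// forward sketch above.
#include <cstdint>
#include <cstring>

enum class Red { None, Sum, Mean };

void NLLLossGradRef(const int32_t *labels, const float *weight, const float *dloss,
                    float total_weight, float *dlogits, int batch, int class_num, Red red) {
  std::memset(dlogits, 0, sizeof(float) * batch * class_num);  // gradient is sparse
  for (int i = 0; i < batch; ++i) {
    const int t = labels[i];
    float g = -weight[t] * (red == Red::None ? dloss[i] : dloss[0]);
    if (red == Red::Mean) g /= total_weight;  // undo the weighted-mean normalization
    dlogits[i * class_num + t] = g;           // only the target column gets a gradient
  }
}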


@@ -18,23 +18,28 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NLLLOSS_GRAD_CPU_KERNEL_H_
#include <vector>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "nnacl/fp32_grad/nllloss_grad_fp32.h"
namespace mindspore {
namespace kernel {
class NLLLossGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class NLLLossGradCpuKernelMod : public NativeCpuKernelMod {
public:
NLLLossGradCpuKernelMod() = default;
~NLLLossGradCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)


@@ -15,12 +15,145 @@
*/
#include "plugin/device/gpu/kernel/math/matrix_inverse_gpu_kernel.h"
#include <map>
#include <utility>
#include <algorithm>
#include "mindspore/core/ops/matrix_inverse.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatrixInverseGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(MatrixInverse, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MatrixInverseGpuKernelMod, double)
bool MatrixInverseGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::MatrixInverse>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "Cast op from BaseOperator to MaxPoolingGradWithArgmax failed.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto pair = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!pair.first) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[pair.second].second;
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
adjoint_ = kernel_ptr->get_adjoint();
return true;
}
int MatrixInverseGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs[kIndex0]->GetShapeVector();
size_t kMinDim = 2;
if (input_shape.size() < kMinDim) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be less than 2, but got "
<< input_shape.size();
}
size_t last_index = input_shape.size() - 1;
if (input_shape[last_index] != input_shape[last_index - 1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the last two dimensions of the input matrix should be equal, "
<< "but got one: " << input_shape[last_index] << ", another: " << input_shape[last_index - 1];
}
size_ = input_shape[last_index];
batch_size_ = 1;
for (size_t i = 0; i < last_index - 1; i++) {
batch_size_ *= input_shape[i];
}
auto dtype = inputs[kIndex0]->GetDtype();
dtype_size_ = sizeof(TypeIdToType(dtype));
input_size_ = dtype_size_;
for (auto dim : input_shape) {
input_size_ *= dim;
}
InitSizeLists();
return KRET_OK;
}
template <typename T>
bool MatrixInverseGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
auto lu_batch_addr = GetDeviceAddress<T *>(workspace, 1);
auto inv_batch_addr = GetDeviceAddress<T *>(workspace, 2);
auto pivo_addr = GetDeviceAddress<int>(workspace, 3);
auto info_addr = GetDeviceAddress<int>(workspace, 4);
std::vector<T *> lu_addr(batch_size_);
std::vector<T *> inv_addr(batch_size_);
int len = SizeToInt(size_);
int batchsize = SizeToInt(batch_size_);
for (size_t i = 0; i < batch_size_; i++) {
lu_addr[i] = compute_input_addr + i * len * len;
inv_addr[i] = output_addr + i * len * len;
}
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(compute_input_addr, input_addr, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(lu_batch_addr, lu_addr.data(), sizeof(T *) * batch_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaMemcpyAsync(inv_batch_addr, inv_addr.data(), sizeof(T *) * batch_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
if (std::is_same<T, float>::value) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasSgetrfBatched(handle_, len, reinterpret_cast<float **>(lu_batch_addr),
len, pivo_addr, info_addr, batchsize),
"cublas trsm batched Fail");
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasSgetriBatched(handle_, len, reinterpret_cast<float **>(lu_batch_addr), len, pivo_addr,
reinterpret_cast<float **>(inv_batch_addr), len, info_addr, batchsize),
"cublas trsm batched Fail");
} else if (std::is_same<T, double>::value) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasDgetrfBatched(handle_, len, reinterpret_cast<double **>(lu_batch_addr),
len, pivo_addr, info_addr, batchsize),
"cublas trsm batched Fail");
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasDgetriBatched(handle_, len, reinterpret_cast<double **>(lu_batch_addr), len, pivo_addr,
reinterpret_cast<double **>(inv_batch_addr), len, info_addr, batchsize),
"cublas trsm batched Fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type entered must be float or double.";
}
return true;
}
void MatrixInverseGpuKernelMod::InitSizeLists() {
workspace_size_list_.emplace_back(input_size_);
size_t lu_size = batch_size_ * dtype_size_;
workspace_size_list_.emplace_back(lu_size);
size_t inv_size = batch_size_ * dtype_size_;
workspace_size_list_.emplace_back(inv_size);
size_t pivo_size = batch_size_ * size_ * sizeof(int);
workspace_size_list_.emplace_back(pivo_size);
size_t info_size = batch_size_ * sizeof(int);
workspace_size_list_.emplace_back(info_size);
}
std::vector<std::pair<KernelAttr, MatrixInverseGpuKernelMod::MatrixInverseFunc>> MatrixInverseGpuKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MatrixInverseGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&MatrixInverseGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MatrixInverseGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixInverseFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MatrixInverse, MatrixInverseGpuKernelMod);
} // namespace kernel
} // namespace mindspore
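
The GPU path inverts the whole batch with two cuBLAS calls: getrfBatched LU-factorizes every matrix in place, then getriBatched forms the inverses from the factors. A stripped-down sketch of that sequence (BatchedInverseCublas and its parameters are illustrative; error handling and the pointer-array setup shown in the kernel above are elided):

// Stripped-down sketch of the cuBLAS sequence; d_lu_ptrs / d_inv_ptrs are device arrays of
// per-matrix pointers, and the LU buffers are assumed to already hold copies of the inputs
// (getrfBatched factorizes in place). Return codes are ignored here for brevity.
#include <cublas_v2.h>
#include <cuda_runtime.h>

void BatchedInverseCublas(cublasHandle_t handle, float **d_lu_ptrs, float **d_inv_ptrs,
                          int *d_pivots, int *d_info, int n, int batch, cudaStream_t stream) {
  cublasSetStream(handle, stream);
  // Batched LU factorization, overwriting the LU buffers with the factors.
  cublasSgetrfBatched(handle, n, d_lu_ptrs, n, d_pivots, d_info, batch);
  // Batched inversion from the LU factors into the output matrices.
  cublasSgetriBatched(handle, n, d_lu_ptrs, n, d_pivots, d_inv_ptrs, n, d_info, batch);
}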


@@ -21,142 +21,51 @@
#include <vector>
#include <string>
#include <type_traits>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MatrixInverseGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class MatrixInverseGpuKernelMod : public NativeGpuKernelMod {
public:
MatrixInverseGpuKernelMod()
: input_size_(0), adjoint_(false), is_null_input_(false), handle_(nullptr), batch_size_(1), size_(1) {}
MatrixInverseGpuKernelMod() = default;
~MatrixInverseGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cublasSetStream failed");
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
auto compute_input_addr = GetDeviceAddress<T>(workspace, 0);
auto lu_batch_addr = GetDeviceAddress<T *>(workspace, 1);
auto inv_batch_addr = GetDeviceAddress<T *>(workspace, 2);
auto pivo_addr = GetDeviceAddress<int>(workspace, 3);
auto info_addr = GetDeviceAddress<int>(workspace, 4);
int len = SizeToInt(size_);
int batchsize = SizeToInt(batch_size_);
for (size_t i = 0; i < batch_size_; i++) {
lu_addr_[i] = compute_input_addr + i * len * len;
inv_addr_[i] = output_addr + i * len * len;
}
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(compute_input_addr, input_addr, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(lu_batch_addr, lu_addr_.data(), sizeof(T *) * batch_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(inv_batch_addr, inv_addr_.data(), sizeof(T *) * batch_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
if (std::is_same<T, float>::value) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasSgetrfBatched(handle_, len, reinterpret_cast<float **>(lu_batch_addr), len,
pivo_addr, info_addr, batchsize),
"cublas trsm batched Fail");
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasSgetriBatched(handle_, len, reinterpret_cast<float **>(lu_batch_addr), len, pivo_addr,
reinterpret_cast<float **>(inv_batch_addr), len, info_addr, batchsize),
"cublas trsm batched Fail");
} else if (std::is_same<T, double>::value) {
CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
cublasDgetrfBatched(handle_, len, reinterpret_cast<double **>(lu_batch_addr), len,
pivo_addr, info_addr, batchsize),
"cublas trsm batched Fail");
CHECK_CUBLAS_RET_WITH_EXCEPT(
kernel_node_,
cublasDgetriBatched(handle_, len, reinterpret_cast<double **>(lu_batch_addr), len, pivo_addr,
reinterpret_cast<double **>(inv_batch_addr), len, info_addr, batchsize),
"cublas trsm batched Fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type entered must be float or double.";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto shape_signed = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (IsDynamic(shape_signed)) {
return true;
}
auto input_shape = Convert2SizeTClipNeg(shape_signed);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (input_shape.size() < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be less than 2, but got "
<< input_shape.size();
}
size_t last_index = input_shape.size() - 1;
if (input_shape[last_index] != input_shape[last_index - 1]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the last two dimensions of the input matrix should be equal, "
<< "but got one: " << input_shape[last_index] << ", another: " << input_shape[last_index - 1];
}
size_ = input_shape[last_index];
for (size_t i = 0; i < last_index - 1; i++) {
batch_size_ *= input_shape[i];
}
input_size_ = sizeof(T);
for (auto dim : input_shape) {
input_size_ *= dim;
}
adjoint_ = GetAttr<bool>(kernel_node, "adjoint");
lu_addr_.resize(batch_size_);
inv_addr_.resize(batch_size_);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.emplace_back(input_size_);
output_size_list_.emplace_back(input_size_);
workspace_size_list_.emplace_back(input_size_);
size_t lu_size = batch_size_ * sizeof(T *);
workspace_size_list_.emplace_back(lu_size);
size_t inv_size = batch_size_ * sizeof(T *);
workspace_size_list_.emplace_back(inv_size);
size_t pivo_size = batch_size_ * size_ * sizeof(int);
workspace_size_list_.emplace_back(pivo_size);
size_t info_size = batch_size_ * sizeof(int);
workspace_size_list_.emplace_back(info_size);
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
size_t input_size_;
bool adjoint_;
bool is_null_input_;
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs, void *stream_ptr);
using MatrixInverseFunc =
std::function<bool(MatrixInverseGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, MatrixInverseFunc>> func_list_;
MatrixInverseFunc kernel_func_{nullptr};
void InitSizeLists();
size_t input_size_{1};
bool adjoint_{false};
cublasHandle_t handle_;
size_t batch_size_;
size_t size_;
std::vector<T *> lu_addr_;
std::vector<T *> inv_addr_;
size_t batch_size_{1};
size_t size_{1};
size_t dtype_size_{1};
};
} // namespace kernel
} // namespace mindspore


@@ -15,40 +15,133 @@
*/
#include "plugin/device/gpu/kernel/nn/nll_loss_gpu_kernel.h"
#include <map>
#include <utility>
#include "mindspore/core/ops/nllloss.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(NLLLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGpuKernelMod, float, float)
MS_REG_GPU_KERNEL_TWO(NLLLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGpuKernelMod, float, half)
MS_REG_GPU_KERNEL_TWO(NLLLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGpuKernelMod, half, float)
MS_REG_GPU_KERNEL_TWO(NLLLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGpuKernelMod, half, half)
namespace {
std::map<Reduction, ReductionMode> kReductionMap = {{Reduction::MEAN, ReductionMode::kMean},
{Reduction::REDUCTION_SUM, ReductionMode::kSum},
{Reduction::NONE, ReductionMode::kNone}};
}
bool NLLLossGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::NLLLoss>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast NLLLoss ops failed!";
return false;
}
kernel_name_ = kernel_ptr->GetPrim()->name();
auto reduction = kernel_ptr->get_reduction();
reduction_ = kReductionMap[reduction];
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
logits_data_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
weight_data_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
return true;
}
int NLLLossGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
auto logits_shape = inputs[kIndex0]->GetShapeVector();
size_t kMinShapeSize = 2;
if (logits_shape.size() < kMinShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of logits cannot be less than 2, but "
<< "got the " << logits_shape.size();
}
n_ = LongToInt(logits_shape[0]);
c_ = LongToInt(logits_shape[1]);
input_size_ = SizeOf(logits_shape);
if ((reduction_ == ReductionMode::kSum) || (reduction_ == ReductionMode::kMean)) {
tmp_loss_size_ = logits_data_size_ * n_;
}
tmp_target_weight_size_ = n_ * weight_data_size_;
InitSizeLists();
return KRET_OK;
}
template <typename T, typename S>
bool NLLLossGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *input_device = GetDeviceAddress<T>(inputs, 0);
int32_t *target_device = GetDeviceAddress<int32_t>(inputs, 1); // nll_loss only supports int32 target
S *weight_device = GetDeviceAddress<S>(inputs, 2);
T *loss_device = GetDeviceAddress<T>(outputs, 0);
S *total_weight_device = GetDeviceAddress<S>(outputs, 1);
T *tmp_loss_device = reduction_ != ReductionMode::kNone ? GetDeviceAddress<T>(workspace, 0)
: GetPossiblyNullDeviceAddress<T>(workspace, 0);
S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, loss_device, total_weight_device,
tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
void NLLLossGpuKernelMod::InitSizeLists() {
workspace_size_list_.push_back(tmp_loss_size_);
workspace_size_list_.push_back(tmp_target_weight_size_);
}
std::vector<std::pair<KernelAttr, NLLLossGpuKernelMod::NLLLossLaunchFunc>> NLLLossGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&NLLLossGpuKernelMod::LaunchKernel<float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
&NLLLossGpuKernelMod::LaunchKernel<float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32),
&NLLLossGpuKernelMod::LaunchKernel<half, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&NLLLossGpuKernelMod::LaunchKernel<half, half>}};
std::vector<KernelAttr> NLLLossGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, NLLLossGpuKernelMod::NLLLossLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NLLLoss, NLLLossGpuKernelMod);
} // namespace kernel
} // namespace mindspore
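
The new GPU kernels all use the same dispatch idiom: a static func_list_ of (KernelAttr, templated launch function) pairs, matched once in Init and cached as a std::function for Launch. A self-contained sketch of the idiom with a plain string standing in for KernelAttr (DispatchSketch and its members are illustrative names only):

// Self-contained sketch of the func_list_ idiom, with a plain string standing in for
// KernelAttr; names here are illustrative only.
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct DispatchSketch {
  using LaunchFunc = std::function<bool(DispatchSketch *)>;
  static std::vector<std::pair<std::string, LaunchFunc>> table_;
  LaunchFunc kernel_func_;

  bool Init(const std::string &signature) {       // plays the role of MatchKernelAttr()
    for (const auto &entry : table_) {
      if (entry.first == signature) {
        kernel_func_ = entry.second;              // cache the typed launch function
        return true;
      }
    }
    return false;                                 // unsupported type combination
  }

  template <typename T>
  bool LaunchKernel() { return sizeof(T) > 0; }   // stand-in for the real typed kernel body

  bool Launch() { return kernel_func_(this); }
};

std::vector<std::pair<std::string, DispatchSketch::LaunchFunc>> DispatchSketch::table_ = {
    {"f32->f32", [](DispatchSketch *k) { return k->LaunchKernel<float>(); }},
    {"f64->f64", [](DispatchSketch *k) { return k->LaunchKernel<double>(); }}};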


@@ -19,17 +19,20 @@
#include <vector>
#include <string>
#include <map>
#include <unordered_map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh"
#include "kernel/common_utils.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class NLLLossGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class NLLLossGpuKernelMod : public NativeGpuKernelMod {
public:
NLLLossGpuKernelMod() { ResetResource(); }
NLLLossGpuKernelMod() = default;
~NLLLossGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -37,87 +40,43 @@ class NLLLossGpuKernelMod : public DeprecatedNativeGpuKernelMod {
if (is_null_input_) {
return true;
}
T *input_device = GetDeviceAddress<T>(inputs, 0);
int32_t *target_device = GetDeviceAddress<int32_t>(inputs, 1); // nll_loss only supports int32 target
S *weight_device = GetDeviceAddress<S>(inputs, 2);
T *loss_device = GetDeviceAddress<T>(outputs, 0);
S *total_weight_device = GetDeviceAddress<S>(outputs, 1);
T *tmp_loss_device = reduction_ != ReductionMode::kNone ? GetDeviceAddress<T>(workspace, 0)
: GetPossiblyNullDeviceAddress<T>(workspace, 0);
S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, loss_device, total_weight_device,
tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
kernel_node_ = kernel_node;
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "logits");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (input_shape.size() < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of logits cannot be less than 2, but "
<< "got the " << input_shape.size();
}
n_ = LongToInt(input_shape[0]);
c_ = LongToInt(input_shape[1]);
input_size_ *= SizeOf(input_shape);
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = kReductionModeMap[reduction];
if ((reduction_ == ReductionMode::kSum) || (reduction_ == ReductionMode::kMean)) {
tmp_loss_size_ = sizeof(T) * n_;
}
tmp_target_weight_size_ = n_ * sizeof(S);
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
n_ = 0;
c_ = 0;
is_null_input_ = false;
reduction_ = ReductionMode::kMean; // default value
tmp_loss_size_ = 0;
tmp_target_weight_size_ = 0; // tmp_target_weight (N,) array
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T)); // input tensor with shape (N, C)
input_size_list_.push_back(n_ * sizeof(int32_t)); // target tensor with shape (N)
input_size_list_.push_back(c_ * sizeof(S)); // weight tensor with shape (C)
if (reduction_ == ReductionMode::kNone) {
output_size_list_.push_back(n_ * sizeof(T)); // loss output of shape (N,)
} else {
output_size_list_.push_back(sizeof(T)); // scalar loss output
}
output_size_list_.push_back(sizeof(S)); // total weight
workspace_size_list_.push_back(tmp_loss_size_);
workspace_size_list_.push_back(tmp_target_weight_size_);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitSizeLists();
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using NLLLossLaunchFunc =
std::function<bool(NLLLossGpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, NLLLossLaunchFunc>> func_list_;
NLLLossLaunchFunc kernel_func_;
bool is_null_input_;
size_t input_size_;
ReductionMode reduction_;
size_t logits_data_size_;
size_t weight_data_size_;
size_t tmp_loss_size_;
size_t tmp_target_weight_size_;
int n_;
int c_;
bool is_null_input_;
string kernel_name_;
};
} // namespace kernel
} // namespace mindspore


@@ -15,44 +15,124 @@
*/
#include "plugin/device/gpu/kernel/nn/nll_loss_grad_gpu_kernel.h"
#include <map>
#include <utility>
#include "mindspore/core/ops/grad/nllloss_grad.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGradGpuKernelMod, float, float)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGradGpuKernelMod, float, half)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGradGpuKernelMod, half, float)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGradGpuKernelMod, half, half)
namespace {
std::map<Reduction, ReductionMode> kReductionMap = {{Reduction::MEAN, ReductionMode::kMean},
{Reduction::REDUCTION_SUM, ReductionMode::kSum},
{Reduction::NONE, ReductionMode::kNone}};
}
bool NLLLossGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::NLLLossGrad>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast NLLLossGrad ops failed!";
return false;
}
auto reduction = kernel_ptr->get_reduction();
reduction_ = kReductionMap[reduction];
auto kernel_name_ = kernel_ptr->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int NLLLossGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = 0;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
auto logits_shape = inputs[kIndex0]->GetShapeVector();
size_t kMinShapeSize = 2;
if (logits_shape.size() < kMinShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of logits cannot be less than 2, but "
<< "got the " << logits_shape.size();
}
n_ = LongToInt(logits_shape[0]);
c_ = LongToInt(logits_shape[1]);
if (reduction_ == ReductionMode::kNone) {
num_dloss_ = n_;
}
return KRET_OK;
}
template <typename T, typename S>
bool NLLLossGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *input_device = GetDeviceAddress<T>(inputs, 0);
T *dloss_device = GetDeviceAddress<T>(inputs, 1);
int32_t *target_device = GetDeviceAddress<int32_t>(inputs, 2); // nll_loss_grad only supports int32 target
S *weight_device = GetDeviceAddress<S>(inputs, 3);
S *total_weight_device = GetDeviceAddress<S>(inputs, 4);
T *dinput_device = GetDeviceAddress<T>(outputs, 0);
NLLLossGrad(n_, c_, reduction_, input_device, target_device, weight_device, total_weight_device, dloss_device,
dinput_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
std::vector<std::pair<KernelAttr, NLLLossGradGpuKernelMod::NLLLossGradLaunchFunc>> NLLLossGradGpuKernelMod::func_list_ =
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&NLLLossGradGpuKernelMod::LaunchKernel<float, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32),
&NLLLossGradGpuKernelMod::LaunchKernel<float, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
&NLLLossGradGpuKernelMod::LaunchKernel<half, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&NLLLossGradGpuKernelMod::LaunchKernel<half, half>}};
std::vector<KernelAttr> NLLLossGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, NLLLossGradGpuKernelMod::NLLLossGradLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NLLLossGrad, NLLLossGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore
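
One small difference from the CPU kernels above: these GPU Init functions index kReductionMap with operator[], so an unexpected Reduction value would silently produce a value-initialized ReductionMode instead of raising. A defensive lookup, sketched with locally defined enums that mirror the values used in this diff (the real enums live in mindapi/base/types.h and the cuda loss header; the local definitions here are only for the sketch):

// Defensive variant of the reduction mapping; enums below are assumed mirrors.
#include <map>
#include <stdexcept>

enum class Reduction { NONE, MEAN, REDUCTION_SUM };   // assumed mirror of mindapi/base/types.h
enum class ReductionMode { kNone, kMean, kSum };      // assumed mirror of the cuda loss header

ReductionMode ToReductionMode(Reduction r) {
  static const std::map<Reduction, ReductionMode> kMap = {
      {Reduction::MEAN, ReductionMode::kMean},
      {Reduction::REDUCTION_SUM, ReductionMode::kSum},
      {Reduction::NONE, ReductionMode::kNone}};
  auto it = kMap.find(r);
  if (it == kMap.end()) {
    throw std::invalid_argument("reduction must be 'none', 'mean' or 'sum'");
  }
  return it->second;
}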


@@ -19,17 +19,18 @@
#include <vector>
#include <string>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh"
#include "plugin/factory/ms_factory.h"
#include "kernel/common_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class NLLLossGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class NLLLossGradGpuKernelMod : public NativeGpuKernelMod {
public:
NLLLossGradGpuKernelMod() { ResetResource(); }
NLLLossGradGpuKernelMod() = default;
~NLLLossGradGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -37,71 +38,33 @@ class NLLLossGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
if (is_null_input_) {
return true;
}
T *input_device = GetDeviceAddress<T>(inputs, 0);
T *dloss_device = GetDeviceAddress<T>(inputs, 1);
int32_t *target_device = GetDeviceAddress<int32_t>(inputs, 2); // nll_loss_grad only supports int32 target
S *weight_device = GetDeviceAddress<S>(inputs, 3);
S *total_weight_device = GetDeviceAddress<S>(inputs, 4);
T *dinput_device = GetDeviceAddress<T>(outputs, 0);
NLLLossGrad(n_, c_, reduction_, input_device, target_device, weight_device, total_weight_device, dloss_device,
dinput_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
kernel_node_ = kernel_node;
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "logits");
if (is_null_input_) {
InitSizeLists();
return true;
}
if (input_shape.size() < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of logits cannot be less than 2, but "
<< "got the " << input_shape.size();
}
n_ = LongToInt(input_shape[0]);
c_ = LongToInt(input_shape[1]);
input_size_ *= SizeOf(input_shape);
string reduction = GetAttr<string>(kernel_node, "reduction");
reduction_ = kReductionModeMap[reduction];
if (reduction_ == ReductionMode::kNone) {
num_dloss_ = n_;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
n_ = 0;
c_ = 0;
is_null_input_ = false;
reduction_ = ReductionMode::kMean; // default value
num_dloss_ = 1; // default size (scalar)
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T)); // input tensor with shape (N, C)
input_size_list_.push_back(num_dloss_ * sizeof(T)); // dloss tensor (either scalar or size N)
input_size_list_.push_back(n_ * sizeof(int32_t)); // target tensor with shape (N)
input_size_list_.push_back(c_ * sizeof(S)); // weight tensor with shape (C)
input_size_list_.push_back(sizeof(S)); // total_weight scalar
output_size_list_.push_back(input_size_ * sizeof(T)); // dinput
}
std::vector<KernelAttr> GetOpSupport() override;
private:
size_t input_size_;
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using NLLLossGradLaunchFunc =
std::function<bool(NLLLossGradGpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, NLLLossGradLaunchFunc>> func_list_;
NLLLossGradLaunchFunc kernel_func_;
string kernel_name_;
ReductionMode reduction_;
int n_;
int c_;


@@ -241,6 +241,8 @@ constexpr auto kMaxUnpool2DGrad = "MaxUnpool2DGrad";
constexpr auto kMaxUnpool3D = "MaxUnpool3D";
constexpr auto kMaxUnpool3DGrad = "MaxUnpool3DGrad";
constexpr auto kCTCLoss = "CTCLoss";
constexpr auto kNLLLoss = "NLLLoss";
constexpr auto kNLLLossGrad = "NLLLossGrad";
constexpr auto kMultiMarginLoss = "MultiMarginLoss";
constexpr auto kMultiMarginLossGrad = "MultiMarginLossGrad";
constexpr auto kLayerNorm = "LayerNorm";
@@ -762,6 +764,8 @@ GVAR_DEF(PrimitivePtr, kPrimMultilabelMarginLossGrad, std::make_shared<Primitive
GVAR_DEF(PrimitivePtr, kPrimCTCLossV2, std::make_shared<Primitive>("CTCLossV2"));
GVAR_DEF(PrimitivePtr, kPrimCTCLossV2Grad, std::make_shared<Primitive>("CTCLossV2Grad"));
GVAR_DEF(PrimitivePtr, kPrimCTCLoss, std::make_shared<Primitive>(kCTCLoss));
GVAR_DEF(PrimitivePtr, kPrimNLLLoss, std::make_shared<Primitive>(kNLLLoss));
GVAR_DEF(PrimitivePtr, kPrimNLLLossGrad, std::make_shared<Primitive>(kNLLLossGrad));
GVAR_DEF(PrimitivePtr, kPrimFullConnection, std::make_shared<Primitive>("FullConnection"));
GVAR_DEF(PrimitivePtr, kPrimConv2DTranspose, std::make_shared<Primitive>(kConv2DTranspose));
GVAR_DEF(PrimitivePtr, kPrimConv3DTranspose, std::make_shared<Primitive>("Conv3DTranspose"));


@@ -44,61 +44,60 @@ int64_t EmbeddingLookup::get_offset() const {
return GetValue<int64_t>(value_ptr);
}
namespace {
abstract::ShapePtr EmbeddingLookupInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
auto params_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
MS_EXCEPTION_IF_NULL(params_shape_ptr);
auto params_shape = params_shape_ptr->shape();
constexpr int64_t kEmbeddingLookupInputParamsMaxDim = 2;
CheckAndConvertUtils::CheckInRange<int64_t>("dimension of params", params_shape.size(), kIncludeBoth,
{1, kEmbeddingLookupInputParamsMaxDim}, op_name);
auto indices_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
auto indices_shape = indices_shape_ptr->shape();
CheckAndConvertUtils::CheckValue<int64_t>("dimension of indices ", indices_shape.size(), kGreaterThan, 0, op_name);
ShapeVector out_shape = indices_shape;
if (params_shape.size() != 1) {
out_shape.push_back(params_shape.back());
}
return std::make_shared<abstract::Shape>(out_shape);
}
class EmbeddingLookupInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
auto params_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, kInputIndex0);
MS_EXCEPTION_IF_NULL(params_shape_ptr);
auto params_shape = params_shape_ptr->shape();
constexpr int64_t kEmbeddingLookupInputParamsMaxDim = 2;
CheckAndConvertUtils::CheckInRange<int64_t>("dimension of params", params_shape.size(), kIncludeBoth,
{1, kEmbeddingLookupInputParamsMaxDim}, op_name);
auto indices_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, kInputIndex1);
MS_EXCEPTION_IF_NULL(indices_shape_ptr);
auto indices_shape = indices_shape_ptr->shape();
CheckAndConvertUtils::CheckValue<int64_t>("dimension of indices ", indices_shape.size(), kGreaterThan, 0, op_name);
TypePtr EmbeddingLookupInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
constexpr int64_t input_num_dynamic = 3;
constexpr int64_t input_num = 2;
CheckAndConvertUtils::CheckInRange<int64_t>("input number", SizeToLong(input_args.size()), kIncludeBoth,
{input_num, input_num_dynamic}, op_name);
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("params", input_args[kInputIndex0]->BuildType(), valid_params_types,
op_name);
std::set<TypePtr> int_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex1]->BuildType(), int_types,
op_name);
if (SizeToLong(input_args.size()) == input_num_dynamic) {
std::set<TypePtr> int_type = {kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("offset", input_args[kInputIndex2]->BuildType(), int_type, op_name);
ShapeVector out_shape;
if (!params_shape_ptr->IsDimUnknown() && !indices_shape_ptr->IsDimUnknown()) {
out_shape = indices_shape;
if (params_shape.size() != 1) {
out_shape.push_back(params_shape.back());
}
} else {
out_shape.push_back(UNKNOWN_RANK);
}
return std::make_shared<abstract::Shape>(out_shape);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, 0, op_name);
abstract::AbstractTensorPtr params =
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
return params->BuildType();
}
} // namespace
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const std::string &op_name = primitive->name();
constexpr int64_t input_num_dynamic = 3;
constexpr int64_t input_num = 2;
CheckAndConvertUtils::CheckInRange<int64_t>("input number", SizeToLong(input_args.size()), kIncludeBoth,
{input_num, input_num_dynamic}, op_name);
std::set<TypePtr> valid_params_types = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("params", input_args[kInputIndex0]->BuildType(), valid_params_types,
op_name);
std::set<TypePtr> int_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex1]->BuildType(), int_types,
op_name);
if (SizeToLong(input_args.size()) == input_num_dynamic) {
std::set<TypePtr> int_type = {kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("offset", input_args[kInputIndex2]->BuildType(), int_type, op_name);
}
AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = EmbeddingLookupInferType(primitive, input_args);
auto infer_shape = EmbeddingLookupInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, 0, op_name);
abstract::AbstractTensorPtr params =
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
return params->BuildType();
}
};
REGISTER_PRIMITIVE_EVAL_IMPL(EmbeddingLookup, prim::kPrimEmbeddingLookup, EmbeddingLookupInfer, nullptr, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(EmbeddingLookup, prim::kPrimEmbeddingLookup, EmbeddingLookupInfer, false);
} // namespace ops
} // namespace mindspore
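
The rewritten EmbeddingLookupInfer keeps the original shape rule but adds a dynamic-rank fallback. The rule it encodes, as a standalone sketch (EmbeddingLookupOutShape is an illustrative helper; the value of UNKNOWN_RANK is an assumption here):

// Standalone sketch of the shape rule; kUnknownRank stands in for MindSpore's UNKNOWN_RANK.
#include <cstdint>
#include <vector>

std::vector<int64_t> EmbeddingLookupOutShape(const std::vector<int64_t> &params_shape,
                                             const std::vector<int64_t> &indices_shape,
                                             bool ranks_known) {
  constexpr int64_t kUnknownRank = -2;  // assumed value of UNKNOWN_RANK
  if (!ranks_known) {
    return {kUnknownRank};              // rank of the output is still unknown
  }
  std::vector<int64_t> out = indices_shape;
  if (params_shape.size() != 1) {       // 2-D params contribute their embedding dimension
    out.push_back(params_shape.back());
  }
  return out;
}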


@@ -51,8 +51,6 @@ class MIND_API EmbeddingLookup : public BaseOperator {
/// \return offset.
int64_t get_offset() const;
};
abstract::AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore


@@ -15,10 +15,14 @@
*/
#include "ops/grad/nllloss_grad.h"
#include <map>
#include <vector>
#include <memory>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "ops/nllloss.h"
namespace mindspore {
namespace ops {
@@ -62,85 +66,88 @@ void CheckNLLLossGradShapeValid(const std::string &prim_name, const ShapeVector
}
} // namespace
void NLLLossGrad::Init(const Reduction &reduction) { set_reduction(reduction); }
class NLLLossGradInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
auto prim_name = primitive->name();
void NLLLossGrad::set_reduction(const Reduction &reduction) {
int64_t reduce = reduction;
(void)AddAttr(kReduction, api::MakeValue(reduce));
}
// Check valid.
const size_t x_idx = 0;
const size_t t_idx = 2;
const size_t w_idx = 3;
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, x_idx);
auto x = input_args[x_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, t_idx);
auto t = input_args[t_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(t);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, w_idx);
auto w = input_args[w_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(w);
Reduction NLLLossGrad::get_reduction() const {
auto value_ptr = GetAttr(kReduction);
return Reduction(GetValue<int64_t>(value_ptr));
}
auto x_shape = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(x_shape);
abstract::ShapePtr NLLLossGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
if (x->IsDynamic() || t->IsDynamic() || w->IsDynamic()) {
return x_shape;
}
// Check valid.
const size_t x_idx = 0;
const size_t t_idx = 2;
const size_t w_idx = 3;
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, x_idx);
auto x = input_args[x_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, t_idx);
auto t = input_args[t_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(t);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, w_idx);
auto w = input_args[w_idx]->BuildShape();
MS_EXCEPTION_IF_NULL(w);
auto t_shape = t->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(t_shape);
auto w_shape = w->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(w_shape);
auto x_shape = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(x_shape);
CheckNLLLossGradShapeValid(prim_name, x_shape->shape(), t_shape->shape(), w_shape->shape());
if (x->IsDynamic() || t->IsDynamic() || w->IsDynamic()) {
return x_shape;
}
auto t_shape = t->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(t_shape);
auto w_shape = w->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(w_shape);
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
    // Check dtypes: logits, loss grad, weight and total_weight must be float16/float32,
    // labels must be int32, and total_weight must share weight's dtype.
std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto x_dtype = input_args[kInputIndex0]->BuildType();
auto y_grad_dtype = input_args[kInputIndex1]->BuildType();
auto t_dtype = input_args[kInputIndex2]->BuildType();
auto w_dtype = input_args[kInputIndex3]->BuildType();
auto tw_dtype = input_args[kInputIndex4]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits dtype", x_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad dtype", y_grad_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels dtype", t_dtype, {kInt32}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight dtype", w_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight dtype", tw_dtype, valid_types, prim_name);
CheckAndConvertUtils::Check("weight dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
std::vector<TypeId>{w_dtype->type_id()}, prim_name);
return x_dtype;
}
};
CheckNLLLossGradShapeValid(prim_name, x_shape->shape(), t_shape->shape(), w_shape->shape());
void NLLLossGrad::Init(const Reduction &reduction) { set_reduction(reduction); }
return x_shape;
void NLLLossGrad::set_reduction(const Reduction &reduction) {
std::string reduce;
if (reduction == Reduction::REDUCTION_SUM) {
reduce = "sum";
} else if (reduction == Reduction::MEAN) {
reduce = "mean";
} else {
reduce = "none";
}
(void)this->AddAttr(kReduction, api::MakeValue(reduce));
}
TypePtr NLLLossGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto x_dtype = input_args[kInputIndex0]->BuildType();
auto y_grad_dtype = input_args[kInputIndex1]->BuildType();
auto t_dtype = input_args[kInputIndex2]->BuildType();
auto w_dtype = input_args[kInputIndex3]->BuildType();
auto tw_dtype = input_args[kInputIndex4]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits dtype", x_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad dtype", y_grad_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels dtype", t_dtype, {kInt32}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight dtype", w_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight dtype", tw_dtype, valid_types, prim_name);
CheckAndConvertUtils::Check("weight dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
std::vector<TypeId>{w_dtype->type_id()}, prim_name);
return x_dtype;
}
AbstractBasePtr NLLLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = NLLLossGradInferType(primitive, input_args);
auto shapes = NLLLossGradInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
Reduction NLLLossGrad::get_reduction() const {
auto value_ptr = MakeValue(GetValue<std::string>(GetAttr(kReduction)));
int64_t reduction = 0;
CheckAndConvertUtils::GetReductionEnumValue(value_ptr, &reduction);
return Reduction(reduction);
}
MIND_API_OPERATOR_IMPL(NLLLossGrad, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(NLLLossGrad, std::make_shared<Primitive>("NLLLossGrad"), NLLLossGradInfer, nullptr, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(NLLLossGrad, std::make_shared<Primitive>("NLLLossGrad"), NLLLossGradInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -35,7 +35,7 @@ class MIND_API NLLLossGrad : public BaseOperator {
}
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.NLLLossGrad for the inputs.
void Init(const Reduction &reduction = NONE);
void Init(const Reduction &reduction = Reduction::NONE);
/// \brief Set reduction.
void set_reduction(const Reduction &reduction);

View File

@ -25,41 +25,48 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixInverseInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
const constexpr int64_t kNumber1 = 1;
const constexpr int64_t kNumber2 = 2;
(void)CheckAndConvertUtils::CheckInteger("x rank", x_rank, kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::Check("row size", x_shape[x_rank - kNumber1], kEqual, x_shape[x_rank - kNumber2], prim_name);
(void)CheckAndConvertUtils::CheckInteger("row size", x_shape[LongToSize(x_rank - kNumber1)], kGreaterEqual, kNumber2,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("column size", x_shape[LongToSize(x_rank - kNumber2)], kGreaterEqual,
kNumber2, prim_name);
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr MatrixInverseInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
auto infer_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixInverse, BaseOperator);
AbstractBasePtr MatrixInverseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infertype = MatrixInverseInferType(primitive, input_args);
auto infershape = MatrixInverseInferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
void MatrixInverse::Init(const bool adjoint) { this->set_adjoint(adjoint); }
void MatrixInverse::set_adjoint(const bool adjoint) { (void)this->AddAttr(kAdjoint, api::MakeValue(adjoint)); }
bool MatrixInverse::get_adjoint() const {
  auto value_ptr = GetAttr(kAdjoint);
return GetValue<bool>(value_ptr);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixInverse, prim::kPrimMatrixInverse, MatrixInverseInfer, nullptr, true);
class MatrixInverseInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
auto prim_name = primitive->name();
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto x_rank = SizeToLong(x_shape.size());
    constexpr int64_t kNumber1 = 1;
    constexpr int64_t kNumber2 = 2;
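    // Skip the static rank and square-matrix checks when the shape is dynamic; the input
    // shape (possibly with unknown dims) is propagated unchanged, since the inverse keeps
    // the shape of x.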
if (!x_shape_ptr->IsDynamic()) {
(void)CheckAndConvertUtils::CheckInteger("x rank", x_rank, kGreaterEqual, kNumber2, prim_name);
CheckAndConvertUtils::Check("row size", x_shape[x_rank - kNumber1], kEqual, x_shape[x_rank - kNumber2],
prim_name);
(void)CheckAndConvertUtils::CheckInteger("row size", x_shape[LongToSize(x_rank - kNumber1)], kGreaterEqual,
kNumber2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("column size", x_shape[LongToSize(x_rank - kNumber2)], kGreaterEqual,
kNumber2, prim_name);
}
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
auto infer_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
return infer_type;
}
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(MatrixInverse, prim::kPrimMatrixInverse, MatrixInverseInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -29,11 +29,12 @@ class MIND_API MatrixInverse : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixInverse);
MatrixInverse() : BaseOperator(kNameMatrixInverse) { InitIOName({"x"}, {"y"}); }
};
abstract::AbstractBasePtr MatrixInverseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimMatrixInversePtr = std::shared_ptr<MatrixInverse>;
void Init(const bool adjoint = false);
void set_adjoint(const bool adjoint);
bool get_adjoint() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_INVERSE_H_

View File

@ -15,25 +15,104 @@
*/
#include "ops/nllloss.h"
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace ops {
void NLLLoss::Init(const Reduction &reduction) { set_reduction(reduction); }
MIND_API_OPERATOR_IMPL(NLLLoss, BaseOperator);
class NLLLossInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
const auto prim_name = primitive->name();
auto logits_shape_ptr = input_args[kInputIndex0]->BuildShape();
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape_ptr)[kShape];
auto target_shape_ptr = input_args[kInputIndex1]->BuildShape();
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(target_shape_ptr)[kShape];
auto weight_shape_ptr = input_args[kInputIndex2]->BuildShape();
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
(void)CheckAndConvertUtils::CheckInteger("rank of target", SizeToLong(target_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("rank of weight", SizeToLong(weight_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInRange("rank of logits", SizeToLong(logits_shape.size()), kIncludeBoth, {1, 2},
prim_name);
if (!logits_shape_ptr->IsDynamic()) {
if (!target_shape_ptr->IsDynamic() && logits_shape[kInputIndex0] != target_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the 'logits_dim0' and the shape of 'target' should be equal, but got "
<< logits_shape[kInputIndex0] << " and " << target_shape[kInputIndex0] << ".";
}
      int64_t weight_dim = SizeToLong(logits_shape.size()) - 1;
if (!weight_shape_ptr->IsDynamic() && logits_shape[weight_dim] != weight_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the last dim of 'logits' and the shape of 'weight' should be equal, but got "
<< logits_shape[weight_dim] << " and " << weight_shape[kInputIndex0] << ".";
}
}
ShapeVector loss_shape;
ShapeVector total_weight_shape;
auto reduction_ptr = primitive->GetAttr(kReduction);
bool reduction_is_none;
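    // The reduction attribute may arrive either as a string ("none"/"mean"/"sum") or as a
    // Reduction enum stored as an int, so handle both representations here.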
if (reduction_ptr->isa<StringImm>()) {
auto reduction = GetValue<std::string>(reduction_ptr);
reduction_is_none = reduction == kNone;
} else {
auto reduction = Reduction(GetValue<int64_t>(reduction_ptr));
reduction_is_none = reduction == Reduction::NONE;
}
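    // With reduction "none" the loss keeps one element per sample, i.e. shape (N,);
    // otherwise it is reduced to a scalar. total_weight is always a scalar.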
if (reduction_is_none) {
loss_shape.push_back(logits_shape[kInputIndex0]);
}
abstract::ShapePtr loss_shape_ptr = std::make_shared<abstract::Shape>(loss_shape);
abstract::ShapePtr total_weight_shape_ptr = std::make_shared<abstract::Shape>(total_weight_shape);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{loss_shape_ptr, total_weight_shape_ptr});
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto target_type = input_args[kIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
std::map<std::string, TypePtr> types;
auto logits_data_type = input_args[kIndex0]->BuildType();
(void)types.emplace("logits", logits_data_type);
(void)types.emplace("weight", input_args[kIndex2]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, logits_data_type});
}
};
void NLLLoss::Init(const Reduction &reduction) { this->set_reduction(reduction); }
void NLLLoss::set_reduction(const Reduction &reduction) {
int64_t reduce = reduction;
(void)AddAttr(kReduction, api::MakeValue(reduce));
std::string reduce;
if (reduction == Reduction::REDUCTION_SUM) {
reduce = "sum";
} else if (reduction == Reduction::MEAN) {
reduce = "mean";
} else {
reduce = "none";
}
(void)this->AddAttr(kReduction, api::MakeValue(reduce));
}
Reduction NLLLoss::get_reduction() const {
auto value_ptr = GetAttr(kReduction);
return Reduction(GetValue<int64_t>(value_ptr));
auto value_ptr = MakeValue(GetValue<std::string>(GetAttr(kReduction)));
int64_t reduction = 0;
CheckAndConvertUtils::GetReductionEnumValue(value_ptr, &reduction);
return Reduction(reduction);
}
MIND_API_OPERATOR_IMPL(NLLLoss, BaseOperator);
REGISTER_PRIMITIVE_C(kNameNLLLoss, NLLLoss);
REGISTER_PRIMITIVE_OP_INFER_IMPL(NLLLoss, prim::kPrimNLLLoss, NLLLossInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -18,7 +18,8 @@
#define MINDSPORE_CORE_OPS_NLLLOSS_H_
#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
@ -33,7 +34,7 @@ class MIND_API NLLLoss : public BaseOperator {
NLLLoss() : BaseOperator(kNameNLLLoss) { InitIOName({"logits", "labels", "weight"}, {"loss", "total_weight"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.NLLLoss for the inputs.
void Init(const Reduction &reduction = NONE);
void Init(const Reduction &reduction = Reduction::NONE);
/// \brief Set reduction.
void set_reduction(const Reduction &reduction);
@ -45,5 +46,4 @@ class MIND_API NLLLoss : public BaseOperator {
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_NLLLOSS_H_

View File

@ -2553,7 +2553,7 @@ class BiasAdd(Primitive):
self.add_prim_attr('data_format', self.format)
class NLLLoss(PrimitiveWithInfer):
class NLLLoss(Primitive):
r"""
Gets the negative log likelihood loss between logits and labels.
@ -2622,29 +2622,8 @@ class NLLLoss(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, reduction="mean"):
"""Initialize NLLLoss"""
self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss', 'total_weight'])
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
self.add_prim_attr('reduction', self.reduction)
def infer_shape(self, x_shape, t_shape, w_shape):
validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
if len(x_shape) == 1:
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
else:
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
if self.reduction == "none":
return t_shape, ()
return (), ()
def infer_dtype(self, x_dtype, t_dtype, w_dtype):
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
return x_dtype, w_dtype
class SoftmaxCrossEntropyWithLogits(Primitive):

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test embedding_lookup dynamic shape"""
import numpy as np
import pytest
from mindspore import context
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
self.offset = 4
def construct(self, param, index):
return self.embedding(param, index, self.offset)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_embedding_look_up0():
"""
Feature: test embedding_lookup op
Description: test the ops in dynamic shape
Expectation: expect correct shape result.
"""
params = Tensor(
np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.float32)
indices = Tensor(np.array([5, 2, 8, 5]), mstype.int32)
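    # Placeholder tensors with None dims mark every dimension as dynamic; set_inputs
    # compiles the network for dynamic shapes before the concrete call below.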
params_dyn = Tensor(shape=[None, None], dtype=params.dtype)
indices_dyn = Tensor(shape=[None], dtype=indices.dtype)
embedding = Net()
embedding.set_inputs(params_dyn, indices_dyn)
out = embedding(params, indices)
expect_shape = (4, 2)
assert out.asnumpy().shape == expect_shape

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test matrix_inverse dynamic shape"""
import numpy as np
import pytest
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import operations as P
np.random.seed(1)
class NetMatrixInverse(nn.Cell):
def __init__(self):
super().__init__()
self.matrix_inverse = P.MatrixInverse()
def construct(self, x):
return self.matrix_inverse(x)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_matrix_inverse():
"""
Feature: test matrix_inverse op
Description: test the ops in dynamic shape
Expectation: expect correct shape result.
"""
x_np = np.random.uniform(-2, 2, (3, 4, 4)).astype(np.float32)
x = Tensor(x_np)
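    # All-None dims make the input fully dynamic; set_inputs registers the placeholder
    # so the compiled graph accepts any batch of square matrices.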
x_dyn = Tensor(shape=[None, None, None], dtype=x.dtype)
context.set_context(device_target="GPU")
matrix_inverse = NetMatrixInverse()
matrix_inverse.set_inputs(x_dyn)
output0 = matrix_inverse(x)
assert output0.asnumpy().shape == (3, 4, 4)

View File

@ -0,0 +1,107 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test NLLLoss forward and backward dynamic shape"""
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=True)
class NLLLoss(nn.Cell):
def __init__(self, reduction="none"):
super().__init__()
self.nllloss = P.NLLLoss(reduction=reduction)
def construct(self, x, t, w):
return self.nllloss(x, t, w)
class NLLLossGrad(nn.Cell):
def __init__(self, forward, sens):
super().__init__()
self.forward = forward
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.sens = sens
def construct(self, x, t, w):
return self.grad(self.forward)(x, t, w, self.sens)
np_type = np.float32
logits = Tensor(np.array([[-1.3739, -2.2700, -3.2333, -2.4589, -0.6566],
[-1.2156, -2.6026, -1.2200, -1.8731, -1.7119],
[-0.7130, -3.3672, -1.5368, -1.8289, -2.3058]]).astype(np_type))
target = Tensor(np.array([1, 0, 4]).astype(np.int32))
weight = Tensor(np.array([0.2, 0.3, 0.1, 0.15, 0.25]).astype(np_type))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nllloss_cpu_none_dynamic_shape():
"""
Feature: test nllloss op with reduction none.
Description: test the ops in dynamic shape.
Expectation: expect correct output shape.
"""
nllloss = NLLLoss("none")
logits_dyn = Tensor(shape=[None]*len(logits.shape), dtype=logits.dtype)
target_dyn = Tensor(shape=[None]*len(target.shape), dtype=target.dtype)
weight_dyn = Tensor(shape=[None]*len(weight.shape), dtype=weight.dtype)
nllloss.set_inputs(logits_dyn, target_dyn, weight_dyn)
loss, total_weight = nllloss(logits, target, weight)
assert loss.asnumpy().shape == (logits.shape[0],)
assert total_weight.asnumpy().shape == tuple()
nllloss_grad = NLLLossGrad(nllloss, sens=(loss + 0.5, total_weight + 0.5))
nllloss_grad.set_inputs(logits_dyn, target_dyn, weight_dyn)
expect_grad = nllloss_grad(logits, target, weight)
assert expect_grad[0].asnumpy().shape == logits.asnumpy().shape
assert expect_grad[1].asnumpy().shape == target.asnumpy().shape
assert expect_grad[2].asnumpy().shape == weight.asnumpy().shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_nllloss_cpu_mean_dynamic_shape():
"""
Feature: test nllloss op with reduction mean.
Description: test the ops in dynamic shape
Expectation: expect correct shape result.
"""
nllloss = NLLLoss("mean")
logits_dyn = Tensor(shape=[None]*len(logits.shape), dtype=logits.dtype)
target_dyn = Tensor(shape=[None]*len(target.shape), dtype=target.dtype)
weight_dyn = Tensor(shape=[None]*len(weight.shape), dtype=weight.dtype)
nllloss.set_inputs(logits_dyn, target_dyn, weight_dyn)
loss, total_weight = nllloss(logits, target, weight)
assert loss.asnumpy().shape == tuple()
assert total_weight.asnumpy().shape == tuple()
nllloss_grad = NLLLossGrad(nllloss, sens=(loss + 0.5, total_weight + 0.5))
nllloss_grad.set_inputs(logits_dyn, target_dyn, weight_dyn)
expect_grad = nllloss_grad(logits, target, weight)
assert expect_grad[0].asnumpy().shape == logits.asnumpy().shape
assert expect_grad[1].asnumpy().shape == target.asnumpy().shape
assert expect_grad[2].asnumpy().shape == weight.asnumpy().shape