add logsoftmax gpu bprop n-d support

Author: fandawei
Date: 2023-03-07 11:44:33 +08:00
parent 9bec955a7d
commit 05ecf44ef4
3 changed files with 108 additions and 31 deletions
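
This change lifts the former 2-D-only restriction on the Softmax/LogSoftmax GPU backprop kernel: the softmax axis is swapped into the innermost position, every other dimension is folded into a batch dimension for cuDNN, and the result is transposed back. For reference, a minimal NumPy sketch of the gradient being computed (an illustration, not MindSpore code):

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    x_max = np.max(x, axis=axis, keepdims=True)
    return x - x_max - np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))

def log_softmax_grad(x, dy, axis=-1):
    # For y = log_softmax(x): dx = dy - softmax(x) * sum(dy) along `axis`.
    return dy - np.exp(log_softmax(x, axis)) * np.sum(dy, axis=axis, keepdims=True)
```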

@@ -23,7 +23,6 @@ namespace mindspore {
 namespace kernel {
 constexpr size_t INPUT_NUM = 2;
 constexpr size_t OUTPUT_NUM = 1;
-constexpr size_t SUPPORT_SIZE = 2;
 
 bool SoftmaxGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
@@ -55,10 +54,6 @@ int SoftmaxGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
   ResetResource();
   auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
   shape_size_ = input_shape.size();
-  if (shape_size_ != SUPPORT_SIZE) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be equal to 2, but got "
-                      << shape_size_;
-  }
   if (kernel_name_ == "LogSoftmaxGrad") {
     algo_ = CUDNN_SOFTMAX_LOG;
     auto log_soft_max_grad_ptr = std::dynamic_pointer_cast<ops::LogSoftmaxGrad>(base_operator);
@@ -106,7 +101,7 @@ bool SoftmaxGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
   const float alpha = 1;
   const float beta = 0;
-  if (axis_ == 1) {
+  if (axis_ == static_cast<int>(input_shape_.size()) - 1) {
     CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr,
                                                              y_desc_, dy_addr, &beta, y_desc_, dx_addr),
                                         kernel_name_ + "cudnnSoftmaxBackward failed");
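
With the hunk above, `cudnnSoftmaxBackward` is called directly only when the softmax axis is already the innermost dimension; any other axis goes through a transpose path. A hedged NumPy sketch of why the two paths agree (`log_softmax_grad` is the helper sketched above; the kernel performs the transposes on device):

```python
import numpy as np

def log_softmax_grad_any_axis(x, dy, axis):
    xt = np.moveaxis(x, axis, -1)        # bring `axis` to the innermost position
    dyt = np.moveaxis(dy, axis, -1)
    dxt = log_softmax_grad(xt, dyt, -1)  # last-axis gradient: the cuDNN fast path
    return np.moveaxis(dxt, -1, axis)    # restore the original layout
```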

@@ -22,6 +22,7 @@
 #include <algorithm>
 #include <map>
 #include <utility>
+#include <functional>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -82,23 +83,28 @@ class SoftmaxGradGpuKernelMod : public NativeGpuKernelMod {
     if (axis_ < 0) {
       axis_ += SizeToInt(shape_size_);
     }
-    if (axis_ == 1) {
-      batch_size_ = input_shape[0];
-      channel_size_ = input_shape[1];
-    } else if (axis_ == 0) {
-      batch_size_ = input_shape[1];
-      channel_size_ = input_shape[0];
-      input_shape_.push_back(input_shape[0]);
-      input_shape_.push_back(input_shape[1]);
-      transpose_shape_.push_back(input_shape[1]);
-      transpose_shape_.push_back(input_shape[0]);
-      transpose_axis_.push_back(1);
-      transpose_axis_.push_back(0);
-    } else {
+    if (axis_ >= SizeToInt(shape_size_) || axis_ < 0) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'axis' must be in range [-" << shape_size_
                         << ", " << shape_size_ << "), but got " << axis;
     }
+    input_shape_ = input_shape;
+    transpose_shape_ = input_shape;
+    for (size_t i = 0; i < input_shape.size(); ++i) {
+      transpose_axis_.emplace_back(i);
+    }
+    std::swap(transpose_shape_[IntToSize(axis_)], transpose_shape_.back());
+    std::swap(transpose_axis_[IntToSize(axis_)], transpose_axis_.back());
+    size_t size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1UL, std::multiplies<size_t>());
+    channel_size_ = transpose_shape_.back();
+    if (channel_size_ != 0) {
+      batch_size_ = size_ / channel_size_;
+    } else {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_
+                        << "', the value of the shape of the input along the 'axis' dimension should be greater than 0"
+                        << ", but got " << channel_size_;
+    }
     height_ = 1;
     width_ = 1;
     input_size_ = type_id_size_ * batch_size_ * channel_size_ * height_ * width_;
@@ -124,7 +130,6 @@ class SoftmaxGradGpuKernelMod : public NativeGpuKernelMod {
   std::vector<size_t> transpose_axis_;
   int axis_{0};
   size_t shape_size_{0};
   size_t batch_size_{0};
   size_t channel_size_{0};
   size_t height_{0};
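
The shape bookkeeping in Resize boils down to: swap the softmax axis into the last slot of both the shape and the permutation, then fold all remaining dimensions into cuDNN's batch dimension, so an n-d input is presented to cuDNN as (N, C, 1, 1). A small Python sketch of the same computation (illustrative, mirroring the C++ above rather than quoting it):

```python
from math import prod

def plan_softmax_shapes(input_shape, axis):
    transpose_shape = list(input_shape)
    transpose_axis = list(range(len(input_shape)))
    # Swap the softmax axis with the last position, as the kernel does.
    transpose_shape[axis], transpose_shape[-1] = transpose_shape[-1], transpose_shape[axis]
    transpose_axis[axis], transpose_axis[-1] = transpose_axis[-1], transpose_axis[axis]
    channel = transpose_shape[-1]          # softmax runs over this dimension
    batch = prod(input_shape) // channel   # every other dimension folds into N
    return transpose_shape, transpose_axis, batch, channel

# e.g. a (1, 2, 2, 2) input with axis=1 yields batch=4, channel=2.
```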

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,6 +29,11 @@ from mindspore.ops.functional import vmap
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_logsoftmax():
+    """
+    Feature: logsoftmax
+    Description: Verify the result of logsoftmax
+    Expectation: success
+    """
     x = np.array([[-0.08082921, -0.13706027, -0.4711177, -0.05606057],
                   [-0.46082982, 1.1761844, -1.016654, -1.743829],
                   [-1.5062045, 0.6910976, 0.4839723, 1.1502692]]).astype(np.float32)
@@ -62,10 +67,16 @@ class Grad(nn.Cell):
         return gout
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 2d input, dim=-1
+    Expectation: success
+    """
     x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
                    -0.7725506, 1.4481013],
                   [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
@@ -97,16 +108,81 @@ def test_logsoftmaxgrad():
                   [-1.2582518, 0.57718843, -1.0812542, 1.4944922, -0.8770549, 0.1476463, 0.40500447, 0.23499368,
                    0.09027944, 0.26695627]]).astype(np.float32)
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax()
     dx = Grad(net)(Tensor(x), Tensor(dy))
     assert np.allclose(dx[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad1():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad_4d_lastdim(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 4d input, dim=-1
+    Expectation: success
+    """
+    x = np.array([[[[0.9342035, 0.41253936],
+                    [0.96119386, 0.45106655]],
+                   [[0.9795543, 0.70140046],
+                    [0.34018862, 0.31537667]]]], dtype=np.float32)
+    dy = np.array([[[[0.28354234, 0.23482183],
+                     [0.06688348, 0.7837496]],
+                    [[0.14290118, 0.47044736],
+                     [0.46478033, 0.46465948]]]], dtype=np.float32)
+    expect = np.array([[[[-0.04175026, 0.04175026],
+                         [-0.46462297, 0.46462294]],
+                        [[-0.2061515, 0.20615152],
+                         [-0.00570459, 0.00570457]]]], dtype=np.float32)
+    context.set_context(mode=mode, device_target="GPU")
+    net = LogSoftmax()
+    dx = Grad(net)(Tensor(x), Tensor(dy))
+    assert np.allclose(dx[0].asnumpy(), expect, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad_4d_dim1(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 4d input, dim=1
+    Expectation: success
+    """
+    dim = 1
+    x = np.array([[[[0.9342035, 0.41253936],
+                    [0.96119386, 0.45106655]],
+                   [[0.9795543, 0.70140046],
+                    [0.34018862, 0.31537667]]]], dtype=np.float32)
+    dy = np.array([[[[0.28354234, 0.23482183],
+                     [0.06688348, 0.7837496]],
+                    [[0.14290118, 0.47044736],
+                     [0.46478033, 0.46465948]]]], dtype=np.float32)
+    expect = np.array([[[[0.07515464, -0.06723279],
+                         [-0.27893573, 0.11726081]],
+                        [[-0.07515462, 0.06723274],
+                         [0.2789357, -0.11726078]]]], dtype=np.float32)
+    context.set_context(mode=mode, device_target="GPU")
+    net = LogSoftmax(dim)
+    dx = Grad(net)(Tensor(x), Tensor(dy))
+    assert np.allclose(dx[0].asnumpy(), expect, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad1(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 2d input, dim=0
+    Expectation: success
+    """
     x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
                    -0.7725506, 1.4481013],
                   [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
@@ -138,16 +214,17 @@ def test_logsoftmaxgrad1():
                   [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117,
                    0.10445589, 0.3220427]],).astype(np.float32)
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax(0)
     dx = Grad(net)(Tensor(x), Tensor(dy))
     assert np.allclose(dx[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad1_dynamic_shape():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad1_dynamic_shape(mode):
     """
     Feature: test logsoftmax in gpu.
     Description: test the ops in dynamic shape.
@@ -184,7 +261,7 @@ def test_logsoftmaxgrad1_dynamic_shape():
                   [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117,
                    0.10445589, 0.3220427]],).astype(np.float32)
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax(0)
     dx = Grad(net)
     x_dyn = Tensor(shape=[5, None], dtype=ms.float32)
@@ -193,7 +270,7 @@ def test_logsoftmaxgrad1_dynamic_shape():
     assert np.allclose(dx_out[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_logsoftmaxgrad_vmap():
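
The expected arrays in these tests can be cross-checked against a NumPy reference. A hedged sketch for the 4d, dim=1 case above, reusing `log_softmax_grad_any_axis` from the sketch earlier in this commit message (tolerances match the test's atol/rtol):

```python
import numpy as np

x = np.array([[[[0.9342035, 0.41253936],
                [0.96119386, 0.45106655]],
               [[0.9795543, 0.70140046],
                [0.34018862, 0.31537667]]]], dtype=np.float32)
dy = np.array([[[[0.28354234, 0.23482183],
                 [0.06688348, 0.7837496]],
                [[0.14290118, 0.47044736],
                 [0.46478033, 0.46465948]]]], dtype=np.float32)
expect = np.array([[[[0.07515464, -0.06723279],
                     [-0.27893573, 0.11726081]],
                    [[-0.07515462, 0.06723274],
                     [0.2789357, -0.11726078]]]], dtype=np.float32)

ref = log_softmax_grad_any_axis(x, dy, axis=1)
assert np.allclose(ref, expect, atol=1e-5, rtol=1e-5)
```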