forked from mindspore-Ecosystem/mindspore
add logsoftmax gpu bprop xd support (remove the 2-D-only restriction; support arbitrary rank and axis)
This commit is contained in:
parent 9bec955a7d
commit 05ecf44ef4
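In summary: the kernel change removes SoftmaxGradGpuKernelMod's 2-D-only input check so the LogSoftmaxGrad GPU kernel (the bprop of LogSoftmax) handles inputs of arbitrary rank and axis, and the tests gain 4-D cases plus GRAPH_MODE/PYNATIVE_MODE parametrization. For reference, a minimal NumPy sketch (not MindSpore code) of the rule being exercised: for y = log_softmax(x), the backward is dx = dy - exp(y) * sum(dy) over the softmax axis.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def log_softmax_grad(y, dy, axis=-1):
    # Backward rule: dx = dy - exp(y) * sum(dy) along the softmax axis.
    return dy - np.exp(y) * dy.sum(axis=axis, keepdims=True)

# Cross-check against the 4-D, dim=-1 test data added below.
x = np.array([[[[0.9342035, 0.41253936],
                [0.96119386, 0.45106655]],
               [[0.9795543, 0.70140046],
                [0.34018862, 0.31537667]]]], dtype=np.float32)
dy = np.array([[[[0.28354234, 0.23482183],
                 [0.06688348, 0.7837496]],
                [[0.14290118, 0.47044736],
                 [0.46478033, 0.46465948]]]], dtype=np.float32)
dx = log_softmax_grad(log_softmax(x), dy)  # should reproduce the test's `expect` values up to float tolerance
```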
@@ -23,7 +23,6 @@ namespace mindspore {
 namespace kernel {
 constexpr size_t INPUT_NUM = 2;
 constexpr size_t OUTPUT_NUM = 1;
-constexpr size_t SUPPORT_SIZE = 2;
 
 bool SoftmaxGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
@@ -55,10 +54,6 @@ int SoftmaxGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
   ResetResource();
   auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
   shape_size_ = input_shape.size();
-  if (shape_size_ != SUPPORT_SIZE) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be equal to 2, but got "
-                      << shape_size_;
-  }
   if (kernel_name_ == "LogSoftmaxGrad") {
     algo_ = CUDNN_SOFTMAX_LOG;
     auto log_soft_max_grad_ptr = std::dynamic_pointer_cast<ops::LogSoftmaxGrad>(base_operator);
@@ -106,7 +101,7 @@ bool SoftmaxGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs
   const float alpha = 1;
   const float beta = 0;
 
-  if (axis_ == 1) {
+  if (axis_ == static_cast<int>(input_shape_.size()) - 1) {
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr,
                                                             y_desc_, dy_addr, &beta, y_desc_, dx_addr),
                                        kernel_name_ + "cudnnSoftmaxBackward failed");
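Note on the hunk above: the fast path previously required axis_ == 1 on a 2-D input; it now fires whenever the softmax axis is the innermost dimension, and any other axis goes through the transpose path prepared in the header below. A hedged sketch of that dispatch in plain NumPy (illustrative names, not the kernel's API):

```python
import numpy as np

def log_softmax_backward_lastdim(y, dy):
    # Stand-in for cudnnSoftmaxBackward with CUDNN_SOFTMAX_LOG:
    # backward of log-softmax over the innermost axis.
    return dy - np.exp(y) * dy.sum(axis=-1, keepdims=True)

def launch(y, dy, axis):
    axis %= y.ndim
    if axis == y.ndim - 1:
        # Fast path: axis is already innermost, run the kernel directly.
        return log_softmax_backward_lastdim(y, dy)
    # Slow path: swap `axis` with the last dimension, run the kernel,
    # then undo the transpose (a swap is its own inverse).
    perm = list(range(y.ndim))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    return log_softmax_backward_lastdim(y.transpose(perm), dy.transpose(perm)).transpose(perm)
```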
@@ -22,6 +22,7 @@
 #include <algorithm>
 #include <map>
 #include <utility>
+#include <functional>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/kernel_constants.h"
@@ -82,23 +83,28 @@ class SoftmaxGradGpuKernelMod : public NativeGpuKernelMod {
     if (axis_ < 0) {
       axis_ += SizeToInt(shape_size_);
     }
     if (axis_ == 1) {
       batch_size_ = input_shape[0];
       channel_size_ = input_shape[1];
-    } else if (axis_ == 0) {
-      batch_size_ = input_shape[1];
-      channel_size_ = input_shape[0];
-      input_shape_.push_back(input_shape[0]);
-      input_shape_.push_back(input_shape[1]);
-      transpose_shape_.push_back(input_shape[1]);
-      transpose_shape_.push_back(input_shape[0]);
-      transpose_axis_.push_back(1);
-      transpose_axis_.push_back(0);
     } else {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'axis' must be in range [-" << shape_size_
-                        << ", " << shape_size_ << "), but got " << axis;
+      if (axis_ >= SizeToInt(shape_size_) || axis_ < 0) {
+        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'axis' must be in range [-" << shape_size_
+                          << ", " << shape_size_ << "), but got " << axis;
+      }
+
+      input_shape_ = input_shape;
+      transpose_shape_ = input_shape;
+      for (size_t i = 0; i < input_shape.size(); ++i) {
+        transpose_axis_.emplace_back(i);
+      }
+      std::swap(transpose_shape_[IntToSize(axis_)], transpose_shape_.back());
+      std::swap(transpose_axis_[IntToSize(axis_)], transpose_axis_.back());
+
+      size_t size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1UL, std::multiplies<size_t>());
+      channel_size_ = transpose_shape_.back();
+      if (channel_size_ != 0) {
+        batch_size_ = size_ / channel_size_;
+      } else {
+        MS_LOG(EXCEPTION) << "For '" << kernel_name_
+                          << "', the value of the shape of the input along the 'axis' dimension should be greater than 0"
+                          << ", but got " << channel_size_;
+      }
     }
     height_ = 1;
     width_ = 1;
     input_size_ = type_id_size_ * batch_size_ * channel_size_ * height_ * width_;
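The generic branch above reduces an arbitrary-rank input to the (batch, channel) layout the 2-D kernel expects: the target axis is swapped to the back, channel becomes the trailing extent, and batch is the product of the remaining extents (hence the guard against a zero-sized axis before the division). A small sketch of that bookkeeping, with illustrative names (not MindSpore API):

```python
import numpy as np

def to_batch_channel(shape, axis):
    # Mirror of the kernel's shape bookkeeping: swap `axis` to the back
    # and view the tensor as (batch, channel).
    axis %= len(shape)  # normalize a negative axis
    transpose_axis = list(range(len(shape)))
    transpose_axis[axis], transpose_axis[-1] = transpose_axis[-1], transpose_axis[axis]
    transpose_shape = [shape[i] for i in transpose_axis]
    channel = transpose_shape[-1]
    if channel == 0:
        raise ValueError("input extent along 'axis' must be greater than 0")
    batch = int(np.prod(shape)) // channel
    return batch, channel, transpose_axis

# e.g. the 4-D tests below: shape (1, 2, 2, 2) with axis=1 -> batch 4, channel 2
print(to_batch_channel((1, 2, 2, 2), 1))  # (4, 2, [0, 3, 2, 1])
```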
@@ -124,7 +130,6 @@ class SoftmaxGradGpuKernelMod : public NativeGpuKernelMod {
   std::vector<size_t> transpose_axis_;
   int axis_{0};
   size_t shape_size_{0};
 
   size_t batch_size_{0};
   size_t channel_size_{0};
   size_t height_{0};
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,6 +29,11 @@ from mindspore.ops.functional import vmap
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_logsoftmax():
+    """
+    Feature: logsoftmax
+    Description: Verify the result of logsoftmax
+    Expectation: success
+    """
     x = np.array([[-0.08082921, -0.13706027, -0.4711177, -0.05606057],
                   [-0.46082982, 1.1761844, -1.016654, -1.743829],
                   [-1.5062045, 0.6910976, 0.4839723, 1.1502692]]).astype(np.float32)
@@ -62,10 +67,16 @@ class Grad(nn.Cell):
         return gout
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 2d input, dim=-1
+    Expectation: success
+    """
     x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
                    -0.7725506, 1.4481013],
                   [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
@@ -97,16 +108,81 @@ def test_logsoftmaxgrad():
                        [-1.2582518, 0.57718843, -1.0812542, 1.4944922, -0.8770549, 0.1476463, 0.40500447, 0.23499368,
                         0.09027944, 0.26695627]]).astype(np.float32)
 
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax()
     dx = Grad(net)(Tensor(x), Tensor(dy))
     assert np.allclose(dx[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad1():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad_4d_lastdim(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 4d input, dim=-1
+    Expectation: success
+    """
+    x = np.array([[[[0.9342035, 0.41253936],
+                    [0.96119386, 0.45106655]],
+                   [[0.9795543, 0.70140046],
+                    [0.34018862, 0.31537667]]]], dtype=np.float32)
+    dy = np.array([[[[0.28354234, 0.23482183],
+                     [0.06688348, 0.7837496]],
+                    [[0.14290118, 0.47044736],
+                     [0.46478033, 0.46465948]]]], dtype=np.float32)
+    expect = np.array([[[[-0.04175026, 0.04175026],
+                         [-0.46462297, 0.46462294]],
+                        [[-0.2061515, 0.20615152],
+                         [-0.00570459, 0.00570457]]]], dtype=np.float32)
+
+    context.set_context(mode=mode, device_target="GPU")
+    net = LogSoftmax()
+    dx = Grad(net)(Tensor(x), Tensor(dy))
+    assert np.allclose(dx[0].asnumpy(), expect, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad_4d_dim1(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 4d input, dim=1
+    Expectation: success
+    """
+    dim = 1
+    x = np.array([[[[0.9342035, 0.41253936],
+                    [0.96119386, 0.45106655]],
+                   [[0.9795543, 0.70140046],
+                    [0.34018862, 0.31537667]]]], dtype=np.float32)
+    dy = np.array([[[[0.28354234, 0.23482183],
+                     [0.06688348, 0.7837496]],
+                    [[0.14290118, 0.47044736],
+                     [0.46478033, 0.46465948]]]], dtype=np.float32)
+    expect = np.array([[[[0.07515464, -0.06723279],
+                         [-0.27893573, 0.11726081]],
+                        [[-0.07515462, 0.06723274],
+                         [0.2789357, -0.11726078]]]], dtype=np.float32)
+
+    context.set_context(mode=mode, device_target="GPU")
+    net = LogSoftmax(dim)
+    dx = Grad(net)(Tensor(x), Tensor(dy))
+    assert np.allclose(dx[0].asnumpy(), expect, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad1(mode):
+    """
+    Feature: logsoftmaxgrad
+    Description: Verify the result of logsoftmaxgrad with 2d input, dim=0
+    Expectation: success
+    """
     x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
                    -0.7725506, 1.4481013],
                   [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
@@ -138,16 +214,17 @@ def test_logsoftmaxgrad1():
                        [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117,
                         0.10445589, 0.3220427]],).astype(np.float32)
 
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax(0)
     dx = Grad(net)(Tensor(x), Tensor(dy))
     assert np.allclose(dx[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_logsoftmaxgrad1_dynamic_shape():
+@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_logsoftmaxgrad1_dynamic_shape(mode):
     """
     Feature: test logsoftmax in gpu.
     Description: test the ops in dynamic shape.
@@ -184,7 +261,7 @@ def test_logsoftmaxgrad1_dynamic_shape():
                        [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117,
                         0.10445589, 0.3220427]],).astype(np.float32)
 
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_context(mode=mode, device_target="GPU")
     net = LogSoftmax(0)
     dx = Grad(net)
    x_dyn = Tensor(shape=[5, None], dtype=ms.float32)
@@ -193,7 +270,7 @@ def test_logsoftmaxgrad1_dynamic_shape():
     assert np.allclose(dx_out[0].asnumpy(), expect)
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_logsoftmaxgrad_vmap():