forked from mindspore-Ecosystem/mindspore
gpu support logsoftmax & logsoftmaxgrad kernel
This commit is contained in:
parent afe048474d
commit 1eb60df5d4
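This change adds GPU support for LogSoftmax and LogSoftmaxGrad by reusing the existing cuDNN-based Softmax kernels: the forward op is registered against SoftmaxGpuKernel and the backward op against a new SoftmaxGradGpuKernel, and both switch the cuDNN algorithm to CUDNN_SOFTMAX_LOG when the node is a log variant. For orientation, a minimal NumPy sketch of the math these kernels compute (the helper names below are illustrative and not part of the change):

import numpy as np

def log_softmax(x, axis=-1):
    # log_softmax(x) = x - logsumexp(x) along `axis`, computed in a numerically stable way
    x_max = x.max(axis=axis, keepdims=True)
    return x - x_max - np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))

def log_softmax_grad(y, dy, axis=-1):
    # for y = log_softmax(x): dx = dy - exp(y) * sum(dy) along `axis`
    return dy - np.exp(y) * dy.sum(axis=axis, keepdims=True)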
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add
                      SoftmaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      SoftmaxGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      SoftmaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      SoftmaxGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -117,9 +117,16 @@ class SoftmaxGpuKernel : public GpuKernel {
    if (shape_size_ != 2) {
      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax only supports 2-D inputs.";
    }

    auto axis = GetAttr<std::vector<int>>(kernel_node, "axis");
    InitSizeByAxis(input_shape, axis[0]);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == "LogSoftmax") {
      algo_ = CUDNN_SOFTMAX_LOG;
      auto axis = GetAttr<int>(kernel_node, "axis");
      InitSizeByAxis(input_shape, axis);
    } else {
      algo_ = CUDNN_SOFTMAX_ACCURATE;
      auto axis = GetAttr<std::vector<int>>(kernel_node, "axis");
      InitSizeByAxis(input_shape, axis[0]);
    }
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
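The existing SoftmaxGpuKernel::Init now dispatches on the node name: LogSoftmax selects CUDNN_SOFTMAX_LOG and reads its axis attribute as a single int, while Softmax keeps CUDNN_SOFTMAX_ACCURATE and reads the attribute as a vector, using its first element. At the Python level this roughly corresponds to how the two ops are constructed (a sketch assuming the usual operations API; attribute handling may differ by version):

from mindspore.ops import operations as P

log_softmax = P.LogSoftmax(axis=-1)  # stores a scalar axis attr -> GetAttr<int> in the kernel
softmax = P.Softmax(axis=-1)         # axis is kept as a sequence -> GetAttr<std::vector<int>>, axis[0] is used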
@@ -0,0 +1,30 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernel/gpu/nn/softmax_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  LogSoftmaxGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  SoftmaxGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  LogSoftmaxGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  SoftmaxGradGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,219 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/transpose_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class SoftmaxGradGpuKernel : public GpuKernel {
 public:
  SoftmaxGradGpuKernel()
      : cudnn_handle_(nullptr),
        y_desc_(nullptr),
        algo_(CUDNN_SOFTMAX_ACCURATE),
        mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        is_null_input_(false),
        input_size_(0),
        output_size_(0),
        workspace_size_(0),
        axis_(0),
        shape_size_(0),
        batch_size_(0),
        channel_size_(0),
        height_(0),
        width_(0) {}
  ~SoftmaxGradGpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *y_addr = GetDeviceAddress<T>(inputs, 0);
    T *dy_addr = GetDeviceAddress<T>(inputs, 1);
    T *dx_addr = GetDeviceAddress<T>(outputs, 0);

    T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0);
    T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1);
    T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2);
    int *input_shape = GetDeviceAddress<int>(workspace, 3);
    int *transpose_shape = GetDeviceAddress<int>(workspace, 4);
    int *transpose_axis = GetDeviceAddress<int>(workspace, 5);
    const float alpha = 1;
    const float beta = 0;

    if (axis_ == 1) {
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_,
                                                       dy_addr, &beta, y_desc_, dx_addr),
                                  "cudnnSoftmaxBackward failed");
    } else {
      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync input_shape failed");
      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_shape failed");
      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_axis failed");
      int size = SizeToInt(input_size_ / sizeof(T));
      CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
      CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr,
                                                       y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr),
                                  "cudnnSoftmaxBackward failed");
      CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    shape_size_ = SizeToInt(input_shape.size());
    if (shape_size_ != 2) {
      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs.";
    }
    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == "LogSoftmaxGrad") {
      algo_ = CUDNN_SOFTMAX_LOG;
      auto axis = GetAttr<int>(kernel_node, "axis");
      InitSizeByAxis(input_shape, axis);
    } else {
      algo_ = CUDNN_SOFTMAX_ACCURATE;
      auto axis = GetAttr<std::vector<int>>(kernel_node, "axis");
      InitSizeByAxis(input_shape, axis[0]);
    }
    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
      "set input_descriptor failed");
    InitSizeLists();
    return true;
  }

 protected:
  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed");
  }

  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    workspace_size_list_.push_back(input_size_);
    workspace_size_list_.push_back(input_size_);
    workspace_size_list_.push_back(output_size_);
    workspace_size_list_.push_back(workspace_size_);
    workspace_size_list_.push_back(workspace_size_);
    workspace_size_list_.push_back(workspace_size_);
    return;
  }

 private:
  void DestroyResource() noexcept {
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed");
  }

  void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
    axis_ = axis;
    if (axis_ < 0) {
      axis_ += shape_size_;
    }
    if (axis_ == 1) {
      batch_size_ = input_shape[0];
      channel_size_ = input_shape[1];
    } else if (axis_ == 0) {
      batch_size_ = input_shape[1];
      channel_size_ = input_shape[0];
      input_shape_.push_back(input_shape[0]);
      input_shape_.push_back(input_shape[1]);
      transpose_shape_.push_back(input_shape[1]);
      transpose_shape_.push_back(input_shape[0]);
      transpose_axis_.push_back(1);
      transpose_axis_.push_back(0);
    } else {
      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid.";
    }

    height_ = 1;
    width_ = 1;
    input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
    output_size_ = input_size_;
    workspace_size_ = IntToSize(shape_size_) * sizeof(int);
  }

  cudnnHandle_t cudnn_handle_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnSoftmaxAlgorithm_t algo_;
  cudnnSoftmaxMode_t mode_;
  cudnnDataType_t cudnn_data_type_;
  bool is_null_input_;
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  std::vector<int> input_shape_;
  std::vector<int> transpose_shape_;
  std::vector<int> transpose_axis_;
  int axis_;
  int shape_size_;

  size_t batch_size_;
  size_t channel_size_;
  size_t height_;
  size_t width_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
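cudnnSoftmaxBackward reduces over the channel dimension of the NCHW descriptor, so SoftmaxGradGpuKernel handles axis 1 of the 2-D input directly; for axis 0 it transposes y and dy on the device with CalTranspose, runs the backward pass, and transposes the result back. A small NumPy sketch of why that round trip is equivalent (the log_softmax_grad helper is the illustrative one from above, not part of the diff):

import numpy as np

def log_softmax_grad(y, dy, axis):
    # y is the LogSoftmax output; dx = dy - exp(y) * sum(dy) along `axis`
    return dy - np.exp(y) * dy.sum(axis=axis, keepdims=True)

y = np.random.randn(3, 4).astype(np.float32)
dy = np.random.randn(3, 4).astype(np.float32)

direct = log_softmax_grad(y, dy, axis=0)               # reduce along axis 0
via_transpose = log_softmax_grad(y.T, dy.T, axis=1).T  # transpose, reduce along axis 1, transpose back
assert np.allclose(direct, via_transpose)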
@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
import mindspore.context as context


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logsoftmax():
    x = np.array([[-0.08082921, -0.13706027, -0.4711177, -0.05606057],
                  [-0.46082982, 1.1761844, -1.016654, -1.743829],
                  [-1.5062045, 0.6910976, 0.4839723, 1.1502692]]).astype(np.float32)
    expect = np.array([[-1.2939762, -1.3502073, -1.6842647, -1.2692076],
                       [-1.9445671, -0.3075528, -2.5003912, -3.2275662],
                       [-3.452001, -1.2546989, -1.4618242, -0.79552734]]).astype(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    LogSoftmax = P.LogSoftmax()
    output = LogSoftmax(Tensor(x))
    assert np.allclose(output.asnumpy(), expect)

class LogSoftmax(nn.Cell):
    def __init__(self, axis=-1):
        super(LogSoftmax, self).__init__()
        self.logsoftmax = P.LogSoftmax(axis)

    def construct(self, x):
        return self.logsoftmax(x)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data, sens):
        gout = self.grad(self.network)(input_data, sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad():
    x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655, -0.7725506, 1.4481013],
                  [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024, -0.27965206, -0.702805],
                  [0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758, -0.4099178, 1.1861311],
                  [1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422, -0.9686862],
                  [1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694, -0.4553867, -1.5423119]]).astype(np.float32)
    dy = np.array([[1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259, -0.6709239, 0.79757756],
                   [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155, 0.758519, -0.25322974],
                   [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864, -0.11677749, -1.2131723],
                   [0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179, 0.29770762, -0.16246222],
                   [0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136, 0.2151897, 0.30908248]]).astype(np.float32)
    expect = np.array([[1.4219905, -0.39837134, 0.5452743, -0.09062839, -0.02375537, -1.5890603, 0.10658137, 0.6185817, -0.7411523, 0.15054005],
                       [-0.94926417, 0.13830578, 0.7609547, -0.31733334, 1.8485254, -1.4657221, 1.2625053, -1.523396, 0.601499, -0.35607445],
                       [-0.14447737, -1.0622973, 0.80294746, -0.32016528, 0.33523226, 0.63443416, 0.23186903, 0.53539133, -0.0633494, -0.9495847],
                       [-0.36894822, 0.253609, -0.5127511, -0.33366728, -0.18740037, 0.19628316, -0.20430653, 1.1471655, 0.24743511, -0.23741922],
                       [-1.2582518, 0.57718843, -1.0812542, 1.4944922, -0.8770549, 0.1476463, 0.40500447, 0.23499368, 0.09027944, 0.26695627]]).astype(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = LogSoftmax()
    dx = Grad(net)(Tensor(x), Tensor(dy))
    assert np.allclose(dx[0].asnumpy(), expect)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad1():
    x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655, -0.7725506, 1.4481013],
                  [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024, -0.27965206, -0.702805],
                  [0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758, -0.4099178, 1.1861311],
                  [1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422, -0.9686862],
                  [1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694, -0.4553867, -1.5423119]]).astype(np.float32)
    dy = np.array([[1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259, -0.6709239, 0.79757756],
                   [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155, 0.758519, -0.25322974],
                   [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864, -0.11677749, -1.2131723],
                   [0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179, 0.29770762, -0.16246222],
                   [0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136, 0.2151897, 0.30908248]]).astype(np.float32)
    expect = np.array([[1.464194, -0.29578894, 0.5296974, -0.39600563, -0.1479242, -1.0869746, 0.04521982, 0.5064515, -0.7515615, 1.0554069],
                       [-0.5774203, 0.793861, 0.7805745, -0.32800734, 1.8334473, -1.236596, 1.2463496, -1.5765365, 0.6265108, -0.22322391],
                       [-0.34437084, -1.4687154, 0.27432096, -0.42420125, -0.22908019, 0.640983, -1.4210342, 0.10155854, -0.23266247, -1.0147638],
                       [-0.01768187, 0.26872346, -0.5037259, -0.3376058, -0.3291146, 1.4752979, -0.25972134, 0.8869053, 0.25325722, -0.13946185],
                       [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117, 0.10445589, 0.3220427]]).astype(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = LogSoftmax(0)
    dx = Grad(net)(Tensor(x), Tensor(dy))
    assert np.allclose(dx[0].asnumpy(), expect)
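Assuming the new test file is saved as something like test_logsoftmax_op.py under the GPU ST tests (the exact path is not shown in this diff), the cases can be run locally with pytest; the level0/platform/env markers are only CI filters. A minimal sketch:

import pytest

# run just this file, verbosely, on a machine with a GPU build of MindSpore
pytest.main(["-v", "test_logsoftmax_op.py"])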