Change softmax grad to NativeGpuKernelMod and add ST test cases
parent d6d4235fa8
commit 26b5cf48c8

@@ -15,16 +15,146 @@
 */

#include "plugin/device/gpu/kernel/nn/softmax_grad_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "mindspore/core/ops/grad/softmax_grad.h"
#include "mindspore/core/ops/grad/log_softmax_grad.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  LogSoftmaxGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  SoftmaxGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
  LogSoftmaxGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  SoftmaxGradGpuKernelMod, half)
constexpr size_t INPUT_NUM = 2;
constexpr size_t OUTPUT_NUM = 1;
constexpr size_t SUPPORT_SIZE = 3;

bool SoftmaxGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->GetPrim()->name();
  cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&y_desc_),
                                      kernel_name_ + " create input_descriptor failed");
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), INPUT_NUM, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), OUTPUT_NUM, kernel_name_);
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  auto input_data_type = inputs.at(kIndex0)->GetDtype();
  type_id_size_ = abstract::TypeIdSize(input_data_type);
  cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(input_data_type));
  return true;
}
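
Init above resolves a concrete launch function by matching the runtime KernelAttr against the registered func_list_. A minimal standalone sketch of that attr-to-launcher dispatch pattern follows; Attr, Mod, and the numeric dtype codes are simplified stand-ins for illustration, not the real MindSpore types.

// dispatch_sketch.cc - sketch of the KernelAttr -> LaunchKernel<T> lookup.
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct Attr {  // stands in for KernelAttr
  int in_dtype;
  int out_dtype;
  bool operator==(const Attr &o) const { return in_dtype == o.in_dtype && out_dtype == o.out_dtype; }
};

struct Mod {  // stands in for SoftmaxGradGpuKernelMod
  using LaunchFunc = std::function<bool(Mod *)>;
  template <typename T>
  bool LaunchKernel() {
    std::cout << "launching with element size " << sizeof(T) << "\n";
    return true;
  }
  bool Init(const Attr &runtime_attr) {  // mirrors MatchKernelAttr + func_list_[index].second
    for (const auto &entry : func_list_) {
      if (entry.first == runtime_attr) {
        kernel_func_ = entry.second;
        return true;
      }
    }
    return false;  // no registered overload matches the runtime dtypes
  }
  LaunchFunc kernel_func_;
  static std::vector<std::pair<Attr, LaunchFunc>> func_list_;
};

std::vector<std::pair<Attr, Mod::LaunchFunc>> Mod::func_list_ = {
  {{32, 32}, [](Mod *m) { return m->LaunchKernel<float>(); }},
  {{16, 16}, [](Mod *m) { return m->LaunchKernel<uint16_t>(); }},  // half treated as a 16-bit payload here
};

int main() {
  Mod mod;
  if (mod.Init({32, 32})) {
    mod.kernel_func_(&mod);
  }
  return 0;
}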

int SoftmaxGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs,
                                    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
    return ret;
  }
  ResetResource();
  auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
  shape_size_ = input_shape.size();
  if (shape_size_ > SUPPORT_SIZE) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the dimension of input must be less than or equal to 3, but got " << shape_size_;
  }
  if (kernel_name_ == "LogSoftmaxGrad") {
    algo_ = CUDNN_SOFTMAX_LOG;
    auto log_soft_max_grad_ptr = std::dynamic_pointer_cast<ops::LogSoftmaxGrad>(base_operator);
    auto axis = LongToInt(log_soft_max_grad_ptr->get_axis());
    InitSizeByAxis(input_shape, axis);
  } else {
    algo_ = CUDNN_SOFTMAX_ACCURATE;
    std::vector<int> axis;
    auto soft_max_grad_ptr = std::dynamic_pointer_cast<ops::SoftmaxGrad>(base_operator);
    auto axis_me = soft_max_grad_ptr->get_axis();
    (void)std::transform(axis_me.begin(), axis_me.end(), std::back_inserter(axis),
                         [](const int64_t &value) { return LongToInt(value); });
    if (axis.size() < 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be equal to 0, but got "
                        << axis.size();
    }
    if (axis.size() > 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be greater than 1, but got "
                        << axis.size();
    }
    InitSizeByAxis(input_shape, axis[0]);
  }
  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
    cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
                               SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
    kernel_name_ + " set input_descriptor failed");
  InitSizeLists();
  return KRET_OK;
}
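
When the softmax axis is not already the cuDNN channel position, Resize prepares transpose bookkeeping (transpose_axis_, transpose_shape_) that LaunchKernel later copies to device memory. A hypothetical sketch of that bookkeeping is below; BuildTranspose is an illustrative helper, since InitSizeByAxis itself appears only partially in this diff.

// Build a permutation that moves `axis` to the last position, plus the shape
// that results from applying it; a sketch, not the kernel's exact layout rules.
#include <cstddef>
#include <vector>

void BuildTranspose(const std::vector<size_t> &shape, size_t axis,
                    std::vector<size_t> *perm, std::vector<size_t> *out_shape) {
  perm->clear();
  out_shape->clear();
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i != axis) perm->push_back(i);  // non-softmax dims keep their order
  }
  perm->push_back(axis);  // softmax axis moves to the end
  for (size_t p : *perm) out_shape->push_back(shape[p]);
}
// Example: shape {2, 3, 4} with axis 1 gives perm {0, 2, 1} and out_shape {2, 4, 3}.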

template <typename T>
bool SoftmaxGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspace,
                                           const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  T *y_addr = GetDeviceAddress<T>(inputs, kIndex0);
  T *dy_addr = GetDeviceAddress<T>(inputs, kIndex1);
  T *dx_addr = GetDeviceAddress<T>(outputs, kIndex0);

  T *transpose_y_addr = GetDeviceAddress<T>(workspace, kIndex0);
  T *transpose_dy_addr = GetDeviceAddress<T>(workspace, kIndex1);
  T *transpose_dx_addr = GetDeviceAddress<T>(workspace, kIndex2);
  size_t *input_shape = GetDeviceAddress<size_t>(workspace, kIndex3);
  size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, kIndex4);
  size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, kIndex5);
  const float alpha = 1;
  const float beta = 0;

  if (axis_ == 1) {
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr,
                                                             y_desc_, dy_addr, &beta, y_desc_, dx_addr),
                                        kernel_name_ + " cudnnSoftmaxBackward failed");
  } else {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                      reinterpret_cast<cudaStream_t>(stream_ptr)),
      kernel_name_ + " cudaMemcpyAsync input_shape failed");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                      reinterpret_cast<cudaStream_t>(stream_ptr)),
      kernel_name_ + " cudaMemcpyAsync transpose_shape failed");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
                      reinterpret_cast<cudaStream_t>(stream_ptr)),
      kernel_name_ + " cudaMemcpyAsync transpose_axis failed");
    size_t size = input_size_ / sizeof(T);
    CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, y_desc_, transpose_dy_addr,
                           &beta, y_desc_, transpose_dx_addr),
      kernel_name_ + " cudnnSoftmaxBackward failed");
    CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
  }
  return true;
}
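
For reference, the mathematics behind the cudnnSoftmaxBackward call: with y the forward output and dy the incoming gradient, softmax backward computes dx_i = y_i * (dy_i - sum_j dy_j * y_j), and log-softmax backward computes dx_i = dy_i - exp(y_i) * sum_j dy_j. A CPU sketch of both formulas over one softmax row, for intuition only (cuDNN's own implementation is not shown by this commit):

#include <cmath>
#include <cstddef>
#include <vector>

// CUDNN_SOFTMAX_ACCURATE backward: dx = y * (dy - <dy, y>), with y = softmax(x).
std::vector<float> SoftmaxBackward(const std::vector<float> &y, const std::vector<float> &dy) {
  float dot = 0.0f;
  for (size_t i = 0; i < y.size(); ++i) dot += dy[i] * y[i];
  std::vector<float> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) dx[i] = y[i] * (dy[i] - dot);
  return dx;
}

// CUDNN_SOFTMAX_LOG backward: dx = dy - exp(y) * sum(dy), with y = log_softmax(x).
std::vector<float> LogSoftmaxBackward(const std::vector<float> &y, const std::vector<float> &dy) {
  float sum = 0.0f;
  for (float v : dy) sum += v;
  std::vector<float> dx(y.size());
  for (size_t i = 0; i < y.size(); ++i) dx[i] = dy[i] - std::exp(y[i]) * sum;
  return dx;
}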

std::vector<std::pair<KernelAttr, SoftmaxGradGpuKernelMod::SoftmaxGradGpuLaunchFunc>>
  SoftmaxGradGpuKernelMod::func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &SoftmaxGradGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &SoftmaxGradGpuKernelMod::LaunchKernel<half>},
};

std::vector<KernelAttr> SoftmaxGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, SoftmaxGradGpuKernelMod::SoftmaxGradGpuLaunchFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SoftmaxGrad, SoftmaxGradGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogSoftmaxGrad, SoftmaxGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore

@@ -14,195 +14,50 @@
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <algorithm>
#include <map>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class SoftmaxGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class SoftmaxGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  SoftmaxGradGpuKernelMod()
      : cudnn_handle_(nullptr),
        y_desc_(nullptr),
        algo_(CUDNN_SOFTMAX_ACCURATE),
        mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
        cudnn_data_type_(CUDNN_DATA_FLOAT),
        is_null_input_(false),
        kernel_name_("SoftmaxGrad"),
        input_size_(0),
        output_size_(0),
        workspace_size_(0),
        axis_(0),
        shape_size_(0),
        batch_size_(0),
        channel_size_(0),
        height_(0),
        width_(0) {}
  SoftmaxGradGpuKernelMod() = default;
  ~SoftmaxGradGpuKernelMod() override { DestroyResource(); }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *y_addr = GetDeviceAddress<T>(inputs, 0);
    T *dy_addr = GetDeviceAddress<T>(inputs, 1);
    T *dx_addr = GetDeviceAddress<T>(outputs, 0);

    T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0);
    T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1);
    T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2);
    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 3);
    size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 4);
    size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 5);
    const float alpha = 1;
    const float beta = 0;

    if (axis_ == 1) {
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_,
                                                       dy_addr, &beta, y_desc_, dx_addr),
                                  "cudnnSoftmaxBackward failed");
    } else {
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync input_shape failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_shape failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_axis failed");
      size_t size = input_size_ / sizeof(T);
      CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
      CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr,
                                                       y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr),
                                  "cudnnSoftmaxBackward failed");
      CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
    kernel_node_ = kernel_node;
    InitResource();
    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
    }
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
    }
    auto temp_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (AnfAlgo::IsShapesDynamic({temp_shape})) {
      InitSizeLists();
      return true;
    }
    std::vector<size_t> input_shape(temp_shape.begin(), temp_shape.end());
    auto axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
    if (axis == -1 || axis == SizeToInt(input_shape.size())) {
      axis = 1;
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

      std::vector<size_t> reshape;
      size_t dim0 = 1;
      for (size_t i = 0; i < input_shape.size() - 1; i++) {
        dim0 *= input_shape[i];
      }
      reshape.push_back(dim0);
      reshape.push_back(input_shape[input_shape.size() - 1]);
      input_shape = reshape;
    }
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

    is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
    if (is_null_input_) {
      InitSizeLists();
      return true;
    }

    shape_size_ = input_shape.size();
    if (shape_size_ > 3) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the dimension of input must be less than or equal to 3, but got " << shape_size_;
    }
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == "LogSoftmaxGrad") {
      algo_ = CUDNN_SOFTMAX_LOG;
      auto axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
      InitSizeByAxis(input_shape, axis);
    } else {
      algo_ = CUDNN_SOFTMAX_ACCURATE;
      std::vector<int> axis;
      std::vector<int64_t> axis_me = GetAttr<std::vector<int64_t>>(kernel_node, "axis");
      (void)std::transform(axis_me.begin(), axis_me.end(), std::back_inserter(axis),
                           [](const int64_t &value) { return static_cast<int>(value); });
      if (axis.size() < 1) {
        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'axis' cannot be equal to 0, but got "
                          << axis.size();
      }
      InitSizeByAxis(input_shape, axis[0]);
    }

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
      "set input_descriptor failed");
    InitSizeLists();
    return true;
  }
 protected:
  std::vector<KernelAttr> GetOpSupport() override;

  void DestroyResource() noexcept override {
    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed");
    CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(y_desc_),
                                       kernel_name_ + " destroy output_descriptor failed");
  }

  void ResetResource() noexcept override {
    cudnn_handle_ = nullptr;
    y_desc_ = nullptr;
    algo_ = CUDNN_SOFTMAX_ACCURATE;
    mode_ = CUDNN_SOFTMAX_MODE_INSTANCE;
    cudnn_data_type_ = CUDNN_DATA_FLOAT;
    is_null_input_ = false;
    input_size_ = 0;
    output_size_ = 0;
    workspace_size_ = 0;
    axis_ = 0;
    shape_size_ = 0;
    batch_size_ = 0;
    channel_size_ = 0;
    height_ = 0;
    width_ = 0;
  void ResetResource() {
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed");
  }

  void InitSizeLists() override {
  void InitSizeLists() {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    workspace_size_list_.push_back(input_size_);
@@ -214,6 +69,13 @@ class SoftmaxGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
    return;
  }

  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
  using SoftmaxGradGpuLaunchFunc =
    std::function<bool(SoftmaxGradGpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                       const std::vector<AddressPtr> &, void *)>;

 private:
  void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
    axis_ = axis;
@@ -239,34 +101,38 @@ class SoftmaxGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {

    height_ = 1;
    width_ = 1;
    input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
    input_size_ = type_id_size_ * batch_size_ * channel_size_ * height_ * width_;
    output_size_ = input_size_;
    workspace_size_ = shape_size_ * sizeof(size_t);
  }
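
A worked example of the size bookkeeping above, under stated assumptions: for a float32 input of shape {5, 10} with the softmax axis mapped to the channel dimension, batch_size_ = 5, channel_size_ = 10, and height_ = width_ = 1, so input_size_ = 4 * 5 * 10 = 200 bytes and, on a 64-bit target, workspace_size_ = 2 * sizeof(size_t) = 16 bytes for each shape buffer.

#include <cassert>
#include <cstddef>

int main() {
  const size_t type_id_size = 4;  // sizeof a float32 element
  const size_t batch = 5, channel = 10, height = 1, width = 1, shape_size = 2;
  assert(type_id_size * batch * channel * height * width == 200);  // input_size_
  assert(shape_size * sizeof(size_t) == 16);  // workspace_size_, 64-bit size_t assumed
  return 0;
}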

  cudnnHandle_t cudnn_handle_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnSoftmaxAlgorithm_t algo_;
  cudnnSoftmaxMode_t mode_;
  cudnnDataType_t cudnn_data_type_;
  bool is_null_input_;
  std::string kernel_name_;
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  SoftmaxGradGpuLaunchFunc kernel_func_;
  static std::vector<std::pair<KernelAttr, SoftmaxGradGpuLaunchFunc>> func_list_;

  cudnnHandle_t cudnn_handle_{nullptr};
  cudnnTensorDescriptor_t y_desc_{nullptr};
  cudnnSoftmaxAlgorithm_t algo_{CUDNN_SOFTMAX_ACCURATE};
  cudnnSoftmaxMode_t mode_{CUDNN_SOFTMAX_MODE_INSTANCE};
  cudnnDataType_t cudnn_data_type_{CUDNN_DATA_FLOAT};
  bool is_null_input_{false};
  std::string kernel_name_{"SoftmaxGrad"};
  size_t input_size_{0};
  size_t output_size_{0};
  size_t workspace_size_{0};

  std::vector<size_t> input_shape_;
  std::vector<size_t> transpose_shape_;
  std::vector<size_t> transpose_axis_;
  int axis_;
  size_t shape_size_;
  int axis_{0};
  size_t shape_size_{0};

  size_t batch_size_;
  size_t channel_size_;
  size_t height_;
  size_t width_;
  size_t batch_size_{0};
  size_t channel_size_{0};
  size_t height_{0};
  size_t width_{0};
  size_t type_id_size_{0};
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_

@@ -705,6 +705,7 @@ GVAR_DEF(PrimitivePtr, kPrimFlatten, std::make_shared<Primitive>("Flatten"));
GVAR_DEF(PrimitivePtr, kPrimCrop, std::make_shared<Primitive>("Crop"));
GVAR_DEF(PrimitivePtr, kPrimFlattenGrad, std::make_shared<Primitive>("FlattenGrad"));
GVAR_DEF(PrimitivePtr, kPrimSoftmax, std::make_shared<Primitive>("Softmax"));
GVAR_DEF(PrimitivePtr, kPrimSoftmaxGrad, std::make_shared<Primitive>("SoftmaxGrad"));
GVAR_DEF(PrimitivePtr, kPrimSoftsign, std::make_shared<Primitive>("Softsign"));
GVAR_DEF(PrimitivePtr, kPrimSparseSoftmaxCrossEntropy, std::make_shared<Primitive>("SparseSoftmaxCrossEntropy"));
GVAR_DEF(PrimitivePtr, kPrimSoftmaxV2WithDropoutDoMaskV3, std::make_shared<Primitive>("SoftmaxV2WithDropoutDoMaskV3"));

@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/grad/softmax_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SoftmaxGradInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));

  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  auto in_shape = shape_map[kShape];
  if (!IsDynamicRank(in_shape)) {
    auto rank = SizeToLong(in_shape.size());
    (void)CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-rank, rank}, primitive->name());
  }
  return std::make_shared<abstract::Shape>(in_shape);
}
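
The CheckInRange call above uses kIncludeLeft over {-rank, rank}, i.e. it accepts -rank <= axis < rank, with negative values counting from the last dimension. A small sketch of the equivalent normalization (NormalizeAxis is an illustrative helper, not part of this commit):

#include <cstdint>
#include <stdexcept>

int64_t NormalizeAxis(int64_t axis, int64_t rank) {
  if (axis < -rank || axis >= rank) {
    throw std::out_of_range("axis must lie in [-rank, rank)");
  }
  return axis < 0 ? axis + rank : axis;  // e.g. axis = -1, rank = 4 -> 3
}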

TypePtr SoftmaxGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = prim->name();
  MS_EXCEPTION_IF_NULL(input_args[1]);
  auto x_type = input_args[1]->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  if (!x_type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input must be a Tensor, but got: " << x_type->ToString()
                            << ".";
  }
  return x_type;
}
} // namespace

MIND_API_OPERATOR_IMPL(SoftmaxGrad, BaseOperator);
void SoftmaxGrad::Init(const int64_t axis) { this->set_axis(axis); }

void SoftmaxGrad::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }

std::vector<int64_t> SoftmaxGrad::get_axis() const {
  auto value_ptr = GetAttr(kAxis);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

AbstractBasePtr SoftmaxGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
  auto type = SoftmaxGradInferType(primitive, input_args);
  auto shape = SoftmaxGradInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SoftmaxGrad, prim::kPrimSoftmaxGrad, SoftmaxGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

@@ -0,0 +1,44 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_GRAD_SOFTMAX_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_SOFTMAX_GRAD_H_
#include <string>
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSoftmaxGrad = "SoftmaxGrad";
class MIND_API SoftmaxGrad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(SoftmaxGrad);
  SoftmaxGrad() : BaseOperator(kNameSoftmaxGrad) { InitIOName({"x", "grad"}, {"y"}); }
  explicit SoftmaxGrad(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "grad"}, {"y"}); }
  void Init(const int64_t axis = -1);
  void set_axis(const int64_t axis);
  std::vector<int64_t> get_axis() const;
};

abstract::AbstractBasePtr SoftmaxGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_GRAD_SOFTMAX_GRAD_H_

@@ -19,6 +19,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap
@@ -143,6 +144,55 @@ def test_logsoftmaxgrad1():
    assert np.allclose(dx[0].asnumpy(), expect)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad1_dynamic_shape():
    """
    Feature: test logsoftmax in gpu.
    Description: test the ops in dynamic shape.
    Expectation: expect correct result.
    """
    x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
                   -0.7725506, 1.4481013],
                  [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
                   -0.27965206, -0.702805],
                  [0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758,
                   -0.4099178, 1.1861311],
                  [1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422,
                   -0.9686862],
                  [1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694,
                   -0.4553867, -1.5423119]]).astype(np.float32)
    dy = np.array([[1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259,
                    -0.6709239, 0.79757756],
                   [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155,
                    0.758519, -0.25322974],
                   [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864,
                    -0.11677749, -1.2131723],
                   [0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179,
                    0.29770762, -0.16246222],
                   [0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136,
                    0.2151897, 0.30908248]]).astype(np.float32)
    expect = np.array([[1.464194, -0.29578894, 0.5296974, -0.39600563, -0.1479242, -1.0869746, 0.04521982, 0.5064515,
                        -0.7515615, 1.0554069],
                       [-0.5774203, 0.793861, 0.7805745, -0.32800734, 1.8334473, -1.236596, 1.2463496, -1.5765365,
                        0.6265108, -0.22322391],
                       [-0.34437084, -1.4687154, 0.27432096, -0.42420125, -0.22908019, 0.640983, -1.4210342, 0.10155854,
                        -0.23266247, -1.0147638],
                       [-0.01768187, 0.26872346, -0.5037259, -0.3376058, -0.3291146, 1.4752979, -0.25972134, 0.8869053,
                        0.25325722, -0.13946185],
                       [-0.5247209, 0.70192003, -1.0808672, 1.4858199, -1.1273282, 0.20728993, 0.38918605, 0.08162117,
                        0.10445589, 0.3220427]]).astype(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = LogSoftmax(0)
    dx = Grad(net)
    x_dyn = Tensor(shape=[5, None], dtype=ms.float32)
    dx.set_inputs(x_dyn, Tensor(dy))
    dx_out = dx(Tensor(x), Tensor(dy))
    assert np.allclose(dx_out[0].asnumpy(), expect)


class LogSoftmaxForForward(nn.Cell):
    def __init__(self, axis=0):
        super().__init__()

@@ -19,6 +19,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore as ms
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -239,3 +240,97 @@ def test_softmax_functional():
    diff2 = np.abs(output_sum2 - expect2)
    assert np.all(diff1 < error1)
    assert np.all(diff2 < error2)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_dynamic_shape():
    """
    Feature: test softmax in gpu.
    Description: test the ops in dynamic shape.
    Expectation: expect correct result.
    """

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = np.array([[[[2.7866030e-01, 8.5578346e-01, -2.7546784e-01, -8.5833269e-01, 1.5753637e-01],
                    [-4.5145524e-01, 1.5590921e-01, -6.1947298e-01, -6.3499230e-01, -1.0625143e+00],
                    [-6.8716180e-01, -3.5565588e-01, 9.9680430e-01, -3.5519487e-01, 5.2122700e-01],
                    [-9.8125875e-01, 9.0505141e-01, 6.5961617e-01, 6.5950197e-01, 1.0319239e+00]],
                   [[-7.6588345e-01, -1.6929083e-01, 9.4459933e-01, -8.3931917e-01, 1.4916732e+00],
                    [8.1874236e-02, -1.9288104e-02, 7.3255712e-01, -1.4598954e-01, 1.1225560e+00],
                    [2.7356184e-01, 1.2557162e-01, 1.3796539e+00, 1.0073920e-01, 7.9203087e-01],
                    [-3.6947381e-01, 4.7919992e-01, 2.2421131e+00, -8.3911163e-01, 1.0814662e+00]],
                   [[-2.5838584e-01, 2.0765430e-01, -1.9366746e-01, 6.7511219e-01, -3.7492469e-01],
                    [4.4170797e-01, -9.9537361e-01, -3.5100895e-01, -7.8317386e-01, 1.1672008e-02],
                    [1.6037937e+00, -1.7059358e+00, -9.3724984e-01, -1.5016698e+00, -2.7605603e-02],
                    [1.6392696e-01, 1.0074581e+00, -2.7704465e+00, 8.1361882e-02, 7.9730105e-01]]],
                  [[[2.9516423e-01, 4.6354745e-02, 1.7318316e-01, 1.5894413e+00, -1.2769363e+00],
                    [2.8939021e-01, -3.8801813e-01, -1.3376296e+00, -4.9808905e-01, -3.2318991e-02],
                    [-1.1740140e+00, -1.1140432e+00, -1.4198960e-01, 5.8953021e-02, -3.6763316e-01],
                    [1.8660797e+00, -5.8705074e-01, 6.8757606e-01, -4.0573463e-01, -7.1130061e-01]],
                   [[2.6170531e-01, 5.4814044e-02, 1.3891056e-01, 3.4492522e-02, -1.0920379e-01],
                    [1.1420644e-01, 1.6939731e-01, -1.0413316e+00, -1.4040415e-01, -3.3280477e-01],
                    [-3.0776244e-01, 1.0526397e+00, 2.9497927e-01, 1.1266683e+00, 8.4419928e-02],
                    [-2.1593940e+00, -1.0187222e+00, 1.7475771e+00, -3.5802367e-01, -1.2900480e+00]],
                   [[3.2892069e-01, -1.6604670e+00, -5.7856506e-01, 5.8143520e-01, 5.9596705e-01],
                    [-1.5992336e-01, -5.9647644e-01, 1.2957820e+00, -1.0650631e-01, 7.0879894e-01],
                    [4.1372257e-01, 3.6408889e-01, -6.3091749e-01, 1.0573713e+00, 1.0981073e+00],
                    [-1.9162457e-01, 3.6392561e-05, -1.8338780e-01, 1.7549801e+00, -9.3534666e-01]]]]).astype(
        np.float32)

    dy = np.array([[[[2.98213929e-01, 3.10518718e+00, -1.64306939e-01, -7.33681679e-01, 5.23136854e-02],
                     [-3.47142726e-01, -1.52662742e+00, 5.26977003e-01, 5.29672280e-02, -4.34386432e-01],
                     [1.34674394e+00, 1.69386661e+00, 3.17139983e-01, 5.77129781e-01, 1.25290680e+00],
                     [-1.71099675e+00, -1.62872851e+00, -7.89083183e-01, 8.64615321e-01, -1.74364686e+00]],
                    [[1.11915946e+00, -7.06878662e-01, -6.71557069e-01, -4.50884640e-01, 2.95763493e-01],
                     [-7.64747679e-01, 1.62951392e-03, -2.84069944e-02, 7.55402744e-01, -1.02387452e+00],
                     [-5.92088878e-01, 4.47980821e-01, 4.50127304e-01, -3.99038166e-01, -5.24561822e-01],
                     [1.92535609e-01, 2.44671494e-01, -8.70469391e-01, -8.30129832e-02, -4.04477213e-03]],
                    [[-1.94159836e-01, -8.50215256e-01, -1.01224804e+00, 2.64235616e-01, 5.34391068e-02],
                     [-6.71353936e-01, 3.73690695e-01, 4.48037744e-01, -2.84973383e-01, -2.80129910e+00],
                     [6.69475198e-01, 2.08404279e+00, 4.49459851e-01, 2.50908136e+00, 9.80683088e-01],
                     [1.18290365e+00, -1.28790128e+00, -1.70202863e+00, -1.37078688e-01, 9.53227460e-01]]],
                   [[[-6.44128084e-01, 1.37707603e+00, -8.60912442e-01, -3.83467346e-01, 6.68365955e-01],
                     [-3.32795471e-01, 3.05202007e-01, 2.20850635e+00, 6.93960607e-01, -1.94968760e-01],
                     [-3.35764170e-01, 1.10562348e+00, -1.13264215e+00, -1.08296621e+00, -6.53923571e-01],
                     [-4.64974046e-01, 8.83257568e-01, -1.70353889e+00, -4.48120385e-01, -1.76938546e+00]],
                    [[-3.80976290e-01, -1.49393475e+00, -8.51393223e-01, -1.49780405e+00, -1.24160886e-01],
                     [-7.18508661e-02, 2.44543999e-01, 3.29225749e-01, 7.09274471e-01, -9.26648498e-01],
                     [6.67312503e-01, -1.08737612e+00, -9.63039994e-01, -3.22715081e-02, -4.03802067e-01],
                     [-5.97982287e-01, -1.40739769e-01, 2.80631828e+00, 5.72278857e-01, 2.05998325e+00]],
                    [[3.46207246e-02, 7.34213948e-01, 1.45563519e+00, 1.02045703e+00, 1.40984225e+00],
                     [4.14457440e-01, -8.74118507e-01, -4.21902031e-01, 7.87168801e-01, -1.48280108e+00],
                     [1.42688036e+00, -2.02695489e+00, 9.26816165e-01, 9.37691629e-01, 7.85577714e-01],
                     [-6.59893751e-01, 1.14681525e-02, -5.79456389e-01, -1.65206456e+00, 4.37116653e-01]]]]).astype(
        np.float32)

    expect_dx = np.array([[[[-0.20103945, 0.737705, -0.17376284, -0.1370458, -0.22585672],
                            [0.04461281, -0.34632078, 0.18386088, 0.10299816, 0.01484894],
                            [0.04113413, 0.09592049, -0.22135337, -0.02833145, 0.11263024],
                            [-0.0284293, -0.1661311, 0.04058228, 0.37645525, -0.22247711]],
                           [[0.06355994, -0.06061868, -0.17428297, -0.01839012, 0.1897318],
                            [-0.04652473, 0.05094835, 0.10032654, 0.12546772, -0.23021786],
                            [-0.07882182, 0.05314343, 0.18712361, -0.04438123, -0.11706398],
                            [0.03219109, 0.08079126, -0.22419631, 0.01224192, 0.09897206]],
                           [[0.01057316, -0.1305348, -0.11175273, 0.19124077, 0.04047358],
                            [0.07448982, 0.11195826, 0.2260284, 0.06497248, -0.47744888],
                            [-0.09664576, 0.03458005, -0.02039931, 0.05646288, 0.02600216],
                            [0.1973966, -0.47014874, -0.01431374, -0.01483214, 0.30189803]]],
                          [[[-0.06132338, 0.19386888, -0.08370841, -0.07789247, 0.02905542],
                            [-0.16714299, 0.0274538, 0.14029635, 0.08591694, -0.08652411],
                            [0.03585254, 0.18327834, -0.11158065, -0.12024056, 0.01269035],
                            [0.14654502, 0.0863447, -0.19723451, 0.01621746, -0.05187264]],
                           [[0.11614501, -0.12182987, 0.00329342, -0.12011584, 0.12250728],
                            [-0.03623635, 0.05001016, 0.02194443, 0.13183522, -0.16755345],
                            [0.09322704, -0.18807998, -0.06984743, 0.15454148, 0.01015892],
                            [-0.04743218, -0.12545264, 0.35787603, -0.1735842, -0.01140684]],
                           [[-0.21854429, -0.00674347, 0.05053139, 0.02567403, 0.14908233],
                            [0.09731252, -0.02596174, 0.03463032, 0.14460044, -0.2505815],
                            [0.1478814, -0.3902862, 0.02360253, 0.13103928, 0.087763],
                            [0.04834083, 0.13455458, 0.05632052, -0.3109298, 0.07171366]]]]).astype(np.float32)
    dx = Grad(Net())
    x_dyn = Tensor(shape=[2, 3, 4, None], dtype=ms.float32)
    dx.set_inputs(x_dyn, Tensor(dy))
    out_dx = dx(Tensor(x), Tensor(dy))
    assert np.allclose(out_dx[0].asnumpy(), expect_dx)