forked from mindspore-Ecosystem/mindspore
!48988 Add GRUV2Grad OPS IMPL
Merge pull request !48988 from yangluhang/gruv2_grad
This commit is contained in:
commit
026c8d9084
|
@ -0,0 +1,273 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/cpu/kernel/mkldnn/gru_grad_cpu_kernel.h"
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include "utils/ms_utils.h"
|
||||||
|
#include "mindspore/core/ops/grad/gru_v2_grad.h"
|
||||||
|
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kGruGradInputsNum = 9;
|
||||||
|
constexpr size_t kGruGradOutputsNum = 3;
|
||||||
|
constexpr size_t kNumberOne = 1;
|
||||||
|
constexpr size_t kNumberTwo = 2;
|
||||||
|
constexpr size_t kGateNum = 3;
|
||||||
|
constexpr size_t kDims = 3;
|
||||||
|
constexpr int kMaxGRULayer = 100;
|
||||||
|
|
||||||
|
constexpr int kSrcLayerIdx = 0;
|
||||||
|
constexpr int kSrcIterIdx = 1;
|
||||||
|
constexpr int kDstLayerIdx = 4;
|
||||||
|
constexpr int kDstIterIdx = 5;
|
||||||
|
constexpr int kWorkSpaceIdx = 8;
|
||||||
|
constexpr int kDiffSrcLayerIdx = 0;
|
||||||
|
constexpr int kDiffSrcIterIdx = 1;
|
||||||
|
constexpr int kDiffDstLayerIdx = 6;
|
||||||
|
constexpr int kDiffDstIterIdx = 7;
|
||||||
|
|
||||||
|
using tag = dnnl::memory::format_tag;
|
||||||
|
using dim = dnnl::memory::dims;
|
||||||
|
using dt = dnnl::memory::data_type;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool GRUGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(base_operator);
|
||||||
|
kernel_name_ = base_operator->name();
|
||||||
|
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGruGradInputsNum, kernel_name_);
|
||||||
|
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGruGradOutputsNum, kernel_name_);
|
||||||
|
auto op_prim = std::dynamic_pointer_cast<ops::GRUV2Grad>(base_operator);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_prim);
|
||||||
|
bidirectional_ = op_prim->get_bidirectional();
|
||||||
|
input_size_ = op_prim->get_input_size();
|
||||||
|
hidden_size_ = op_prim->get_hidden_size();
|
||||||
|
num_layers_ = op_prim->get_num_layers();
|
||||||
|
has_bias_ = op_prim->get_has_bias();
|
||||||
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto match = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||||
|
if (!match.first) {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GRUGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||||
|
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||||
|
if (ret != KRET_OK) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
input_size_list_[kIndex8] = reserve_size_;
|
||||||
|
auto src_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||||
|
auto src_h_shape = inputs[kIndex1]->GetDeviceShapeAdaptively();
|
||||||
|
if (src_shape.size() != kDims || src_h_shape.size() != kDims) {
|
||||||
|
MS_LOG(ERROR) << "GRU only support 3-D input!,but the src_shape dim is" << src_shape.size()
|
||||||
|
<< ", the src_shape dim is" << src_h_shape.size();
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
batch_size_ = src_shape[1];
|
||||||
|
seq_len_ = src_shape[0];
|
||||||
|
num_directions_ = kNumberOne;
|
||||||
|
if (bidirectional_) {
|
||||||
|
num_directions_ = kNumberTwo;
|
||||||
|
}
|
||||||
|
const int64_t gate_size = kGateNum * hidden_size_;
|
||||||
|
if (num_layers_ <= 0) {
|
||||||
|
MS_LOG(ERROR) << "Layers must be greater than zero! but the num_layers is " << num_layers_;
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
if (num_layers_ > kMaxGRULayer) {
|
||||||
|
MS_LOG(ERROR) << "Layers must be less than or equal to 100! but the num_layers_ is " << num_layers_;
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num_layers_; ++i) {
|
||||||
|
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
|
||||||
|
weight_h_size_ += gate_size * hidden_size_;
|
||||||
|
}
|
||||||
|
weight_size_ = weight_size_ * num_directions_;
|
||||||
|
weight_h_size_ = weight_h_size_ * num_directions_;
|
||||||
|
|
||||||
|
weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_};
|
||||||
|
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_};
|
||||||
|
bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_};
|
||||||
|
|
||||||
|
if (num_directions_ * num_layers_ != src_h_shape[0]) {
|
||||||
|
MS_LOG(ERROR) << "Error iteration shape!, iteration shape[0] is required to be " << num_directions_ * num_layers_
|
||||||
|
<< " but " << src_h_shape[0];
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
InitDnnl();
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GRUGradCpuKernelMod::InitDnnl() {
|
||||||
|
auto eng = engine_;
|
||||||
|
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
|
||||||
|
if (bidirectional_) {
|
||||||
|
direction = dnnl::rnn_direction::bidirectional_concat;
|
||||||
|
}
|
||||||
|
dim src_dims = {seq_len_, batch_size_, input_size_};
|
||||||
|
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
||||||
|
weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_};
|
||||||
|
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_};
|
||||||
|
bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_};
|
||||||
|
dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_};
|
||||||
|
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
|
||||||
|
|
||||||
|
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
|
||||||
|
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
|
||||||
|
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
|
||||||
|
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
|
||||||
|
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
|
||||||
|
|
||||||
|
auto weights_desc = formatted_md(weights_dims_, tag::any);
|
||||||
|
auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
|
||||||
|
|
||||||
|
auto forward_desc =
|
||||||
|
CreatePrimitive<dnnl::gru_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc,
|
||||||
|
weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc);
|
||||||
|
|
||||||
|
auto prim_forward_desc = CreateDesc<dnnl::gru_forward::primitive_desc>(*forward_desc, eng);
|
||||||
|
auto backward_desc = CreatePrimitive<dnnl::gru_backward::desc>(
|
||||||
|
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, weights_desc, weights_h_desc, bias_desc, dst_desc,
|
||||||
|
dst_h_desc, src_desc, src_h_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc);
|
||||||
|
prim_backward_desc_ = CreateDesc<dnnl::gru_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc);
|
||||||
|
primitive_ = CreatePrimitive<dnnl::gru_backward>(prim_backward_desc_);
|
||||||
|
auto wksp_desc = GetWorkspaceDesc(prim_forward_desc);
|
||||||
|
reserve_size_ = GetSize(wksp_desc);
|
||||||
|
AddArgumentOp(src_desc, src_h_desc, bias_desc, dst_desc, dst_h_desc, wksp_desc);
|
||||||
|
|
||||||
|
// construct fw memory
|
||||||
|
weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_);
|
||||||
|
weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_);
|
||||||
|
bias_desc_ = GetBiasDesc(prim_backward_desc_);
|
||||||
|
auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
|
||||||
|
auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
|
||||||
|
user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
|
||||||
|
user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
|
||||||
|
weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng);
|
||||||
|
weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng);
|
||||||
|
bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);
|
||||||
|
|
||||||
|
// construct bw memory
|
||||||
|
diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_);
|
||||||
|
diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_);
|
||||||
|
diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_);
|
||||||
|
diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng);
|
||||||
|
diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng);
|
||||||
|
diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng);
|
||||||
|
user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
|
||||||
|
user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers one argument per tensor consumed or produced by the GRU backward primitive.
// Each forward tensor is registered next to its gradient counterpart, which uses the same
// descriptor except for the weights, whose descriptors come from the class members.
void GRUGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
                                        const dnnl::memory::desc &bias_desc, const dnnl::memory::desc &dst_desc,
                                        const dnnl::memory::desc &dst_h_desc, const dnnl::memory::desc &wksp_desc) {
  // Layer input and initial hidden state.
  AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);

  // Weights and bias.
  AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_);
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_);
  AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_);
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_);
  AddArgument(DNNL_ARG_BIAS, bias_desc);
  AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);

  // Layer output and final hidden state.
  AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);

  // Workspace handed over from the forward pass.
  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
}
|
||||||
|
|
||||||
|
// Binds a data pointer to every registered primitive argument: framework buffers for the
// activations and their gradients, internal oneDNN memories for the weights and bias.
void GRUGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  // Tensors read directly from the kernel inputs.
  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[kSrcLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[kSrcIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[kDiffDstLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[kWorkSpaceIdx]->addr);

  // Weights and bias come from the memories owned by this kernel.
  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
  SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, GetDataHandle(diff_weights_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_));

  // Gradients written straight into the kernel outputs.
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[kDiffSrcLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[kDiffSrcIterIdx]->addr);
}
|
||||||
|
|
||||||
|
// Zero-fills the buffer currently attached to `mem`, using the byte size of its descriptor.
// `name` only labels the exception raised when memset_s reports a failure.
void GRUGradCpuKernelMod::ResetMemory(const dnnl::memory &mem, const string name) const {
  auto buffer = GetDataHandle(mem);
  auto byte_count = GetSize(GetMemDesc(mem));
  const auto status = memset_s(buffer, byte_count, 0, byte_count);
  if (status == EOK) {
    return;
  }
  MS_LOG(EXCEPTION) << name << " memset error";
}
|
||||||
|
|
||||||
|
// Runs the GRU backward primitive.
//
// inputs[2] and outputs[2] are treated as packed float buffers laid out as
// [layer weights | iter weights | bias (only when has_bias_)], using the element counts
// weight_size_ and weight_h_size_ computed in Resize(). The packed weights are reordered into
// the layout chosen by the primitive, the gradient buffers are cleared, the primitive is
// executed and the weight gradients are reordered back into the packed output buffer.
// Always returns true; failures surface as exceptions from the helpers.
bool GRUGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                 const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGruGradInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGruGradOutputsNum, kernel_name_);

  auto *packed_weights = reinterpret_cast<float *>(inputs[kIndex2]->addr);
  auto *packed_weights_grad = reinterpret_cast<float *>(outputs[kIndex2]->addr);

  // Forward weights: user layout -> primitive layout.
  SetDataHandle(user_weights_memory_, inputs[kIndex2]->addr);
  SetDataHandle(user_weights_h_memory_, packed_weights + weight_size_);
  Reorder(&user_weights_memory_, &weights_memory_);
  Reorder(&user_weights_h_memory_, &weights_h_memory_);
  if (has_bias_) {
    SetDataHandle(bias_memory_, packed_weights + weight_size_ + weight_h_size_);
  } else {
    // Without a user bias the primitive still reads one, so feed it zeros.
    ResetMemory(bias_memory_, "Bias");
  }

  // Gradient buffers are cleared before the primitive runs.
  SetDataHandle(user_diff_weights_memory_, outputs[kIndex2]->addr);
  SetDataHandle(user_diff_weights_h_memory_, packed_weights_grad + weight_size_);
  ResetMemory(user_diff_weights_memory_, "user weights grad");
  ResetMemory(user_diff_weights_h_memory_, "user weights iter grad");
  ResetMemory(diff_weights_memory_, "weights grad");
  ResetMemory(diff_weights_h_memory_, "weights iter grad");
  if (has_bias_) {
    // The bias gradient is written directly into the tail of the packed output.
    SetDataHandle(diff_bias_memory_, packed_weights_grad + weight_size_ + weight_h_size_);
  }
  ResetMemory(diff_bias_memory_, "Bias grad");

  SetArgumentHandleOp(inputs, outputs);
  ExecutePrimitive();

  // Weight gradients: primitive layout -> user layout.
  Reorder(&diff_weights_memory_, &user_diff_weights_memory_);
  Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_);
  return true;
}
|
||||||
|
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GRUV2Grad, GRUGradCpuKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,107 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class GRUGradCpuKernelMod : public MKLCpuKernelMod {
|
||||||
|
public:
|
||||||
|
GRUGradCpuKernelMod() = default;
|
||||||
|
~GRUGradCpuKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
int Resize(
|
||||||
|
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override {
|
||||||
|
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32)};
|
||||||
|
return support_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_c_desc,
|
||||||
|
const dnnl::memory::desc &bias_desc, const dnnl::memory::desc &dst_desc,
|
||||||
|
const dnnl::memory::desc &dst_h_desc, const dnnl::memory::desc &wksp_desc);
|
||||||
|
void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
|
const std::vector<kernel::AddressPtr> &outputs);
|
||||||
|
void ResetMemory(const dnnl::memory &mem, const string name) const;
|
||||||
|
void InitDnnl();
|
||||||
|
|
||||||
|
int weight_size_{0};
|
||||||
|
int weight_h_size_{0};
|
||||||
|
int input_size_{0};
|
||||||
|
int hidden_size_{0};
|
||||||
|
int num_layers_{0};
|
||||||
|
int batch_size_{0};
|
||||||
|
int seq_len_{0};
|
||||||
|
int num_directions_{0};
|
||||||
|
bool bidirectional_{false};
|
||||||
|
bool has_bias_{false};
|
||||||
|
size_t reserve_size_{1};
|
||||||
|
|
||||||
|
dnnl::memory::dims weights_dims_;
|
||||||
|
dnnl::memory::dims weights_h_dims_;
|
||||||
|
dnnl::memory::dims bias_dims_;
|
||||||
|
|
||||||
|
dnnl::gru_backward::primitive_desc prim_backward_desc_;
|
||||||
|
dnnl::memory::desc weights_layer_desc_;
|
||||||
|
dnnl::memory::desc weights_iter_desc_;
|
||||||
|
dnnl::memory::desc bias_desc_;
|
||||||
|
|
||||||
|
dnnl::memory::desc diff_weights_layer_desc_;
|
||||||
|
dnnl::memory::desc diff_weights_iter_desc_;
|
||||||
|
dnnl::memory::desc diff_bias_desc_;
|
||||||
|
|
||||||
|
dnnl::memory user_weights_memory_;
|
||||||
|
dnnl::memory user_weights_h_memory_;
|
||||||
|
dnnl::memory weights_memory_;
|
||||||
|
dnnl::memory weights_h_memory_;
|
||||||
|
dnnl::memory bias_memory_;
|
||||||
|
dnnl::memory diff_weights_memory_;
|
||||||
|
dnnl::memory diff_weights_h_memory_;
|
||||||
|
dnnl::memory diff_bias_memory_;
|
||||||
|
dnnl::memory user_diff_weights_memory_;
|
||||||
|
dnnl::memory user_diff_weights_h_memory_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_
|
|
@ -673,6 +673,7 @@ GVAR_DEF(PrimitivePtr, kPrimExtractImagePatches, std::make_shared<Primitive>("Ex
|
||||||
GVAR_DEF(PrimitivePtr, kPrimDynamicRNN, std::make_shared<Primitive>("DynamicRNN"));
|
GVAR_DEF(PrimitivePtr, kPrimDynamicRNN, std::make_shared<Primitive>("DynamicRNN"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimCudnnGRU, std::make_shared<Primitive>("CudnnGRU"));
|
GVAR_DEF(PrimitivePtr, kPrimCudnnGRU, std::make_shared<Primitive>("CudnnGRU"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimGRUV2, std::make_shared<Primitive>("GRUV2"));
|
GVAR_DEF(PrimitivePtr, kPrimGRUV2, std::make_shared<Primitive>("GRUV2"));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimGRUV2Grad, std::make_shared<Primitive>("GRUV2Grad"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimLSTMV2, std::make_shared<Primitive>("LSTMV2"));
|
GVAR_DEF(PrimitivePtr, kPrimLSTMV2, std::make_shared<Primitive>("LSTMV2"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimDynamicRNNGrad, std::make_shared<Primitive>("DynamicRNNGrad"));
|
GVAR_DEF(PrimitivePtr, kPrimDynamicRNNGrad, std::make_shared<Primitive>("DynamicRNNGrad"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimDynamicGRUV2, std::make_shared<Primitive>("DynamicGRUV2"));
|
GVAR_DEF(PrimitivePtr, kPrimDynamicGRUV2, std::make_shared<Primitive>("DynamicGRUV2"));
|
||||||
|
|
|
@ -0,0 +1,157 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/grad/gru_v2_grad.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "mindapi/src/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
void GRUV2Grad::set_input_size(const int64_t input_size) {
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
|
||||||
|
(void)AddAttr(kInput_size, api::MakeValue(input_size));
|
||||||
|
}
|
||||||
|
int64_t GRUV2Grad::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
|
||||||
|
void GRUV2Grad::set_hidden_size(const int64_t hidden_size) {
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
|
||||||
|
(void)AddAttr(kHidden_size, api::MakeValue(hidden_size));
|
||||||
|
}
|
||||||
|
int64_t GRUV2Grad::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
|
||||||
|
void GRUV2Grad::set_num_layers(const int64_t num_layers) {
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
|
||||||
|
(void)AddAttr(kNumLayers, api::MakeValue(num_layers));
|
||||||
|
}
|
||||||
|
int64_t GRUV2Grad::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
|
||||||
|
void GRUV2Grad::set_has_bias(const bool has_bias) { (void)AddAttr(kHasBias, api::MakeValue(has_bias)); }
|
||||||
|
bool GRUV2Grad::get_has_bias() const {
|
||||||
|
auto value_ptr = this->GetAttr(kHasBias);
|
||||||
|
return GetValue<bool>(value_ptr);
|
||||||
|
}
|
||||||
|
void GRUV2Grad::set_dropout(const float dropout) {
|
||||||
|
CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||||
|
(void)AddAttr(kDropout, api::MakeValue(dropout));
|
||||||
|
}
|
||||||
|
float GRUV2Grad::get_dropout() const {
|
||||||
|
auto value_ptr = this->GetAttr(kDropout);
|
||||||
|
return GetValue<float>(value_ptr);
|
||||||
|
}
|
||||||
|
void GRUV2Grad::set_bidirectional(const bool bidirectional) {
|
||||||
|
(void)AddAttr(kBidirectional, api::MakeValue(bidirectional));
|
||||||
|
}
|
||||||
|
bool GRUV2Grad::get_bidirectional() const {
|
||||||
|
auto value_ptr = this->GetAttr(kBidirectional);
|
||||||
|
return GetValue<bool>(value_ptr);
|
||||||
|
}
|
||||||
|
void GRUV2Grad::set_num_directions(const int64_t num_directions) {
|
||||||
|
(void)AddAttr(kNumDirections, api::MakeValue(num_directions));
|
||||||
|
}
|
||||||
|
int64_t GRUV2Grad::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
|
||||||
|
|
||||||
|
void GRUV2Grad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
|
||||||
|
const float dropout, const bool bidirectional) {
|
||||||
|
this->set_input_size(input_size);
|
||||||
|
this->set_hidden_size(hidden_size);
|
||||||
|
this->set_num_layers(num_layers);
|
||||||
|
this->set_has_bias(has_bias);
|
||||||
|
this->set_dropout(dropout);
|
||||||
|
this->set_bidirectional(bidirectional);
|
||||||
|
if (bidirectional) {
|
||||||
|
constexpr int k2Directions = 2;
|
||||||
|
this->set_num_directions(k2Directions);
|
||||||
|
} else {
|
||||||
|
this->set_num_directions(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class GruGradInfer : public abstract::OpInferBase {
|
||||||
|
const int kInputNum = 9;
|
||||||
|
const int64_t kNumber1 = 1;
|
||||||
|
const int64_t kNumber2 = 2;
|
||||||
|
const int64_t kNumber3 = 3;
|
||||||
|
const size_t kShapeSize = 3;
|
||||||
|
const int kIndex0 = 0;
|
||||||
|
const int kIndex2 = 2;
|
||||||
|
const int kHxIdx = 1;
|
||||||
|
const int kYIdx = 4;
|
||||||
|
const int kDyIdx = 6;
|
||||||
|
const int kDhyIdx = 7;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GruGradInfer() = default;
|
||||||
|
|
||||||
|
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputNum,
|
||||||
|
prim_name);
|
||||||
|
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kYIdx]->BuildShape())[kShape];
|
||||||
|
auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDyIdx]->BuildShape())[kShape];
|
||||||
|
auto dhy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDhyIdx]->BuildShape())[kShape];
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("dhy_shape size", SizeToLong(dhy_shape.size()), kEqual, kShapeSize,
|
||||||
|
prim_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("dy_shape size", SizeToLong(dy_shape.size()), kEqual, kShapeSize,
|
||||||
|
prim_name);
|
||||||
|
|
||||||
|
int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
|
||||||
|
bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional));
|
||||||
|
int64_t num_directions = kNumber1;
|
||||||
|
if (bidirectional) {
|
||||||
|
num_directions = kNumber2;
|
||||||
|
}
|
||||||
|
int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
|
||||||
|
auto weight_size = GetWeightSize(primitive, num_layers, num_directions);
|
||||||
|
ShapeVector dx_shape = {y_shape[kIndex0], y_shape[kIndex2], input_size};
|
||||||
|
ShapeVector weight_shape = {weight_size, kNumber1, kNumber1};
|
||||||
|
std::vector<abstract::BaseShapePtr> output_shapes;
|
||||||
|
output_shapes.push_back(std::make_shared<abstract::Shape>(dx_shape));
|
||||||
|
output_shapes.push_back(std::make_shared<abstract::Shape>(dhy_shape));
|
||||||
|
output_shapes.push_back(std::make_shared<abstract::Shape>(weight_shape));
|
||||||
|
return std::make_shared<abstract::TupleShape>(output_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
auto hx_type_ptr = input_args[kHxIdx]->BuildType();
|
||||||
|
auto dy_type_ptr = input_args[kDyIdx]->BuildType();
|
||||||
|
std::vector<TypePtr> types = {dy_type_ptr, dy_type_ptr, hx_type_ptr};
|
||||||
|
return std::make_shared<Tuple>(types);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t GetWeightSize(const PrimitivePtr &primitive, int64_t num_layers, int64_t num_directions) const {
|
||||||
|
int64_t weight_size = 0;
|
||||||
|
bool has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
|
||||||
|
int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
|
||||||
|
int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size));
|
||||||
|
int64_t gate_size = hidden_size * kNumber3;
|
||||||
|
weight_size += input_size * gate_size * num_directions +
|
||||||
|
(num_layers - 1) * (hidden_size * num_directions) * gate_size * num_directions;
|
||||||
|
int64_t temp = num_directions * num_layers;
|
||||||
|
weight_size += gate_size * hidden_size * temp;
|
||||||
|
if (has_bias) {
|
||||||
|
weight_size += gate_size * temp;
|
||||||
|
}
|
||||||
|
return weight_size;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
MIND_API_OPERATOR_IMPL(GRUV2Grad, BaseOperator);
|
||||||
|
REGISTER_PRIMITIVE_OP_INFER_IMPL(GRUV2Grad, prim::kPrimGRUV2Grad, GruGradInfer, false);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/base_operator.h"
|
||||||
|
#include "mindapi/base/types.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameGRUV2Grad = "GRUV2Grad";
|
||||||
|
/// \brief Operator definition of GRUV2Grad. Every setter stores one attribute on the
/// operator and the matching getter reads it back.
class MIND_API GRUV2Grad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(GRUV2Grad);
  /// \brief Creates the operator under the name kNameGRUV2Grad.
  GRUV2Grad() : BaseOperator(kNameGRUV2Grad) {}

  /// \brief Sets all attributes at once; num_directions is derived from bidirectional.
  void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
            const float dropout, const bool bidirectional = false);

  /// \brief Size attributes; the setters require a value greater than zero.
  void set_input_size(const int64_t input_size);
  void set_hidden_size(const int64_t hidden_size);
  void set_num_layers(const int64_t num_layers);
  int64_t get_input_size() const;
  int64_t get_hidden_size() const;
  int64_t get_num_layers() const;

  /// \brief Flag attributes.
  void set_has_bias(const bool has_bias);
  void set_bidirectional(const bool bidirectional);
  bool get_has_bias() const;
  bool get_bidirectional() const;

  /// \brief Dropout attribute; the setter requires a value in [0.0, 1.0].
  void set_dropout(const float dropout);
  float get_dropout() const;

  /// \brief Number of directions, stored without validation.
  void set_num_directions(const int64_t num_directions);
  int64_t get_num_directions() const;
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_
|
|
@ -1102,6 +1102,25 @@ def get_bprop_lstm(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(rl_ops.GRUV2)
def get_bppro_gru_v2(self):
    """Build the backward function for the `GRUV2` operation.

    A `GRUV2Grad` primitive is created once from the attributes of the forward
    primitive and is reused by the returned closure.
    """
    grad_op = G.GRUV2Grad(self.input_size, self.hidden_size, self.num_layers,
                          self.has_bias, self.bidirectional, self.dropout)

    def bprop(x, hx, w, seq_length, out, dout):
        # The forward op yields four outputs; the fourth one is not needed here.
        y, hy, workspace, _ = out
        # Only the gradients of the first two forward outputs are propagated.
        dy, dhy, _, _ = dout
        dx, dhx, dw = grad_op(x, hx, w, seq_length, y, hy, dy, dhy, workspace)
        # NOTE(review): the slot for `seq_length` is filled with the plain int 0
        # (the original `(0)` is not a tuple) -- confirm a scalar placeholder is intended.
        return dx, dhx, dw, 0

    return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(rl_ops.CudnnGRU)
|
@bprop_getters.register(rl_ops.CudnnGRU)
|
||||||
def get_bprop_gru(self):
|
def get_bprop_gru(self):
|
||||||
"""Grad definition for `GRU` operation."""
|
"""Grad definition for `GRU` operation."""
|
||||||
|
|
|
@ -1616,6 +1616,25 @@ class GruGradWeight(PrimitiveWithInfer):
|
||||||
return hx_dtype
|
return hx_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class GRUV2Grad(Primitive):
    """Computes the gradients of GRUV2.

    The constructor validates every argument and derives `num_directions`
    (2 when `bidirectional` is true, otherwise 1).
    """

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        prim_name = self.name
        self.input_size = validator.check_positive_int(input_size, 'input_size', prim_name)
        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', prim_name)
        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', prim_name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), prim_name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), prim_name)
        # Validate the type of `dropout` before checking that it lies in [0, 1].
        validator.check_value_type("dropout", dropout, [float], prim_name)
        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', prim_name)

        self.num_directions = 2 if bidirectional else 1
|
||||||
|
|
||||||
|
|
||||||
class DynamicGRUV2Grad(Primitive):
|
class DynamicGRUV2Grad(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Computes the input gradients of DynamicGRUV2.
|
Computes the input gradients of DynamicGRUV2.
|
||||||
|
|
|
@ -0,0 +1,134 @@
|
||||||
|
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.ops.operations._rl_inner_ops as rl_ops
|
||||||
|
import mindspore.ops.operations._grad_ops as grad_ops
|
||||||
|
from mindspore import context, Tensor
|
||||||
|
from mindspore.common.parameter import ParameterTuple
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import composite as c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_gru_grad(mode):
    """
    Feature: test gru_grad cpu operation.
    Description: run GRUV2 forward, then call GRUV2Grad directly on its outputs.
    Expectation: no exception.
    """
    input_size, hidden_size, num_layers = 10, 2, 1
    max_seq_len, batch_size = 5, 2
    # The forward and the backward primitive are built from the same argument list.
    op_args = (input_size, hidden_size, num_layers, True, False, 0.0)

    def ones_f32(*shape):
        return Tensor(np.ones(list(shape)).astype(np.float32))

    context.set_context(mode=mode)
    forward_op = rl_ops.GRUV2(*op_args)
    input_tensor = ones_f32(max_seq_len, batch_size, input_size)
    h0 = ones_f32(num_layers, batch_size, hidden_size)
    # 84 presumably equals 3 * hidden * (input + hidden) weights plus 2 * 3 * hidden biases -- TODO confirm.
    w = ones_f32(84, 1, 1)
    seq_lengths = Tensor(np.array([4, 3]).astype(np.int32))
    output, hn, out1, _ = forward_op(input_tensor, h0, w, seq_lengths)

    backward_op = grad_ops.GRUV2Grad(*op_args)
    # The forward results double as the incoming gradients (dy = output, dhy = hn).
    dx, dh, dw = backward_op(input_tensor, h0, w, seq_lengths,
                             output, hn, output, hn, out1)
    for label, grad in (("dx:", dx), ("dh:", dh), ("dw:", dw)):
        print(label, grad)
|
||||||
|
|
||||||
|
|
||||||
|
class GradOfAllInputsAndParams(nn.Cell):
    """Wrap `network` so that calling the cell returns its gradients.

    The `GradOperation` is configured with `get_all=True` and `get_by_list=True`
    and is applied to the trainable parameters of `network`.
    """

    def __init__(self, network, sens_param):
        super().__init__()
        self.network = network
        self.params = ParameterTuple(network.trainable_params())
        self.grad = c.GradOperation(get_all=True, get_by_list=True, sens_param=sens_param)

    def construct(self, *inputs):
        grad_fn = self.grad(self.network, self.params)
        return grad_fn(*inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class NetGruV2(nn.Cell):
    """Thin cell around `GRUV2` that keeps its weights as a plain attribute."""

    def __init__(self, input_size, hidden_size, num_layers, has_bias, weights, is_train):
        super().__init__()
        # False / 0.0 are presumably `bidirectional` / `dropout`, mirroring the
        # argument order of GRUV2Grad -- TODO confirm against GRUV2's signature.
        op_args = (input_size, hidden_size, num_layers, has_bias, False, 0.0, is_train)
        self.gruv2 = rl_ops.GRUV2(*op_args)
        self.weights = weights

    def construct(self, x, h_0, seq_len):
        # Cast the stored weights to the dtype of the input before the call.
        cast_weights = self.weights.astype(x.dtype)
        return self.gruv2(x, h_0, cast_weights, seq_len)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("is_train", [True, False])
def test_gru_backward(has_bias, is_train):
    """
    Feature: test GRUV2 backward.
    Description: compute the input gradients of GRUV2 in graph mode and in PyNative mode.
    Expectation: the gradients of both modes agree within tolerance.
    """
    batch_size, max_seq_length = 3, 5
    input_size, hidden_size = 10, 3
    num_layers, num_directions = 1, 1
    seq_lengths = Tensor([5, 3, 2], ms.int32)
    dtype = ms.float32

    x_shape = (max_seq_length, batch_size, input_size)
    h0_shape = (num_layers * num_directions, batch_size, hidden_size)
    x = Tensor(np.random.normal(0.0, 1.0, x_shape), dtype)
    h0 = Tensor(np.random.normal(0.0, 1.0, h0_shape), dtype)
    # 117 presumably equals 3 * hidden * (input + hidden); the bias variant adds 18 -- TODO confirm.
    weight_size = 135 if has_bias else 117
    weights = Tensor(np.ones([weight_size, 1, 1]).astype(np.float32))

    def input_grads(mode):
        """Build a fresh network under `mode` and return the gradients w.r.t. its inputs."""
        context.set_context(mode=mode)
        net = NetGruV2(input_size, hidden_size, num_layers, has_bias, weights, is_train)
        grad_net = GradOfAllInputsAndParams(net, sens_param=False)
        grad_net.set_train()
        grads, _ = grad_net(x, h0, seq_lengths)
        return grads

    out_grad = input_grads(context.GRAPH_MODE)
    py_native_out_grad = input_grads(context.PYNATIVE_MODE)

    # Compare the first two input gradients of both modes.
    for idx in (0, 1):
        assert np.allclose(out_grad[idx].asnumpy(), py_native_out_grad[idx].asnumpy(),
                           rtol=0.001, atol=0.001)
|
Loading…
Reference in New Issue