!48988 Add GRUV2Grad ops implementation

Merge pull request !48988 from yangluhang/gruv2_grad
i-robot 2023-02-24 07:26:53 +00:00 committed by Gitee
commit 026c8d9084
8 changed files with 762 additions and 0 deletions
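In short, the new GRUV2Grad primitive consumes the nine forward-pass tensors (x, hx, w, seq_lengths, y, hy, dy, dhy, workspace) and produces (dx, dh, dw). A minimal sketch of driving the kernel directly, condensed from the test_gru_grad case in tests/st/ops/test_gru.py added below (all shapes and attribute values come from that test):

import numpy as np
import mindspore.ops.operations._rl_inner_ops as rl_ops
import mindspore.ops.operations._grad_ops as grad_ops
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)
# Attributes: input_size, hidden_size, num_layers, has_bias, bidirectional, dropout.
net = rl_ops.GRUV2(10, 2, 1, True, False, 0.0)
x = Tensor(np.ones([5, 2, 10]).astype(np.float32))   # (max_seq_len, batch, input_size)
h0 = Tensor(np.ones([1, 2, 2]).astype(np.float32))   # (num_layers * num_directions, batch, hidden_size)
w = Tensor(np.ones([84, 1, 1]).astype(np.float32))   # packed weights plus bias
seq_lengths = Tensor(np.array([4, 3]).astype(np.int32))
y, hn, workspace, _ = net(x, h0, w, seq_lengths)

grad_net = grad_ops.GRUV2Grad(10, 2, 1, True, False, 0.0)
# Nine inputs: x, hx, w, seq_lengths, y, hy, dy, dhy, workspace (y/hn reused as dy/dhy here).
dx, dh, dw = grad_net(x, h0, w, seq_lengths, y, hn, y, hn, workspace)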

@@ -0,0 +1,273 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/mkldnn/gru_grad_cpu_kernel.h"
#include <cstddef>
#include <cstring>
#include <string>
#include "utils/ms_utils.h"
#include "mindspore/core/ops/grad/gru_v2_grad.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kGruGradInputsNum = 9;
constexpr size_t kGruGradOutputsNum = 3;
constexpr size_t kNumberOne = 1;
constexpr size_t kNumberTwo = 2;
constexpr size_t kGateNum = 3;
constexpr size_t kDims = 3;
constexpr int kMaxGRULayer = 100;
constexpr int kSrcLayerIdx = 0;
constexpr int kSrcIterIdx = 1;
constexpr int kDstLayerIdx = 4;
constexpr int kDstIterIdx = 5;
constexpr int kWorkSpaceIdx = 8;
constexpr int kDiffSrcLayerIdx = 0;
constexpr int kDiffSrcIterIdx = 1;
constexpr int kDiffDstLayerIdx = 6;
constexpr int kDiffDstIterIdx = 7;
using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims;
using dt = dnnl::memory::data_type;
} // namespace

bool GRUGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGruGradInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGruGradOutputsNum, kernel_name_);
  auto op_prim = std::dynamic_pointer_cast<ops::GRUV2Grad>(base_operator);
  MS_EXCEPTION_IF_NULL(op_prim);
  bidirectional_ = op_prim->get_bidirectional();
  input_size_ = op_prim->get_input_size();
  hidden_size_ = op_prim->get_hidden_size();
  num_layers_ = op_prim->get_num_layers();
  has_bias_ = op_prim->get_has_bias();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto match = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!match.first) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  return true;
}

int GRUGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs,
                                const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  input_size_list_[kIndex8] = reserve_size_;
  auto src_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
  auto src_h_shape = inputs[kIndex1]->GetDeviceShapeAdaptively();
  if (src_shape.size() != kDims || src_h_shape.size() != kDims) {
    MS_LOG(ERROR) << "GRU only supports 3-D inputs, but the src_shape dim is " << src_shape.size()
                  << " and the src_h_shape dim is " << src_h_shape.size();
    return KRET_RESIZE_FAILED;
  }
  batch_size_ = src_shape[1];
  seq_len_ = src_shape[0];
  num_directions_ = kNumberOne;
  if (bidirectional_) {
    num_directions_ = kNumberTwo;
  }
  const int64_t gate_size = kGateNum * hidden_size_;
  if (num_layers_ <= 0) {
    MS_LOG(ERROR) << "num_layers must be greater than zero, but got " << num_layers_;
    return KRET_RESIZE_FAILED;
  }
  if (num_layers_ > kMaxGRULayer) {
    MS_LOG(ERROR) << "num_layers must be less than or equal to " << kMaxGRULayer << ", but got " << num_layers_;
    return KRET_RESIZE_FAILED;
  }
  for (int i = 0; i < num_layers_; ++i) {
    weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
    weight_h_size_ += gate_size * hidden_size_;
  }
  weight_size_ = weight_size_ * num_directions_;
  weight_h_size_ = weight_h_size_ * num_directions_;
  weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_};
  weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_};
  bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_};
  if (num_directions_ * num_layers_ != src_h_shape[0]) {
    MS_LOG(ERROR) << "Wrong iteration shape: shape[0] is required to be " << num_directions_ * num_layers_
                  << ", but got " << src_h_shape[0];
    return KRET_RESIZE_FAILED;
  }
  InitDnnl();
  return KRET_OK;
}

void GRUGradCpuKernelMod::InitDnnl() {
  auto eng = engine_;
  dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
  if (bidirectional_) {
    direction = dnnl::rnn_direction::bidirectional_concat;
  }
  dim src_dims = {seq_len_, batch_size_, input_size_};
  dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
  weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_};
  weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_};
  bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_};
  dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_};
  dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
  dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
  dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
  dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
  dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
  dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
  auto weights_desc = formatted_md(weights_dims_, tag::any);
  auto weights_h_desc = formatted_md(weights_h_dims_, tag::any);
  auto forward_desc =
    CreatePrimitive<dnnl::gru_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc,
                                             weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc);
  auto prim_forward_desc = CreateDesc<dnnl::gru_forward::primitive_desc>(*forward_desc, eng);
  auto backward_desc = CreatePrimitive<dnnl::gru_backward::desc>(
    dnnl::prop_kind::backward, direction, src_desc, src_h_desc, weights_desc, weights_h_desc, bias_desc, dst_desc,
    dst_h_desc, src_desc, src_h_desc, weights_desc, weights_h_desc, bias_desc, dst_desc, dst_h_desc);
  prim_backward_desc_ = CreateDesc<dnnl::gru_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc);
  primitive_ = CreatePrimitive<dnnl::gru_backward>(prim_backward_desc_);
  auto wksp_desc = GetWorkspaceDesc(prim_forward_desc);
  reserve_size_ = GetSize(wksp_desc);
  AddArgumentOp(src_desc, src_h_desc, bias_desc, dst_desc, dst_h_desc, wksp_desc);

  // construct forward-pass memory
  weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_);
  weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_);
  bias_desc_ = GetBiasDesc(prim_backward_desc_);
  auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi);
  auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi);
  user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
  user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
  weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng);
  weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng);
  bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng);

  // construct backward-pass memory
  diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_);
  diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_);
  diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_);
  diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng);
  diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng);
  diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng);
  user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng);
  user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng);
}

void GRUGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
                                        const dnnl::memory::desc &bias_desc, const dnnl::memory::desc &dst_desc,
                                        const dnnl::memory::desc &dst_h_desc, const dnnl::memory::desc &wksp_desc) {
  AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_);
  AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_);
  AddArgument(DNNL_ARG_BIAS, bias_desc);
  AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_);
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_);
  AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);
  AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_WORKSPACE, wksp_desc);
}

void GRUGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[kSrcLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[kSrcIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_));
  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_));
  SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_));
  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[kWorkSpaceIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[kDiffSrcLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[kDiffSrcIterIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, GetDataHandle(diff_weights_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_));
  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[kDiffDstLayerIdx]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr);
}

void GRUGradCpuKernelMod::ResetMemory(const dnnl::memory &mem, const string name) const {
  auto dst_ptr = GetDataHandle(mem);
  auto mem_desc = GetMemDesc(mem);
  auto size = GetSize(mem_desc);
  if (memset_s(dst_ptr, size, 0, size) != EOK) {
    MS_LOG(EXCEPTION) << name << " memset error";
  }
}

bool GRUGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                 const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kGruGradInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kGruGradOutputsNum, kernel_name_);
  SetDataHandle(user_weights_memory_, inputs[kIndex2]->addr);
  SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kIndex2]->addr) + weight_size_);
  Reorder(&user_weights_memory_, &weights_memory_);
  Reorder(&user_weights_h_memory_, &weights_h_memory_);
  if (has_bias_) {
    SetDataHandle(bias_memory_, reinterpret_cast<float *>(inputs[kIndex2]->addr) + weight_size_ + weight_h_size_);
  } else {
    auto dst_ptr = GetDataHandle(bias_memory_);
    auto size = GetSize(bias_desc_);
    if (memset_s(dst_ptr, size, 0, size) != EOK) {
      MS_LOG(EXCEPTION) << "Bias memset error";
    }
  }

  SetDataHandle(user_diff_weights_memory_, outputs[kIndex2]->addr);
  SetDataHandle(user_diff_weights_h_memory_, reinterpret_cast<float *>(outputs[kIndex2]->addr) + weight_size_);
  ResetMemory(user_diff_weights_memory_, "user weights grad");
  ResetMemory(user_diff_weights_h_memory_, "user weights iter grad");
  ResetMemory(diff_weights_memory_, "weights grad");
  ResetMemory(diff_weights_h_memory_, "weights iter grad");
  if (has_bias_) {
    SetDataHandle(diff_bias_memory_, reinterpret_cast<float *>(outputs[kIndex2]->addr) + weight_size_ + weight_h_size_);
  }
  auto dst_ptr = GetDataHandle(diff_bias_memory_);
  auto size = GetSize(diff_bias_desc_);
  if (memset_s(dst_ptr, size, 0, size) != EOK) {
    MS_LOG(EXCEPTION) << "Bias grad memset error";
  }
  SetArgumentHandleOp(inputs, outputs);
  ExecutePrimitive();
  Reorder(&diff_weights_memory_, &user_diff_weights_memory_);
  Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_);
  return true;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, GRUV2Grad, GRUGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
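
A note on the data layout Launch assumes: the op carries all parameters in a single packed float32 tensor (input index 2), which the kernel slices at fixed element offsets computed in Resize — input-to-hidden weights first, then hidden-to-hidden weights, then (when has_bias is set) the bias. A minimal sketch for the single-layer, unidirectional case; for stacked or bidirectional GRUs the offsets accumulate per the loop in Resize:

# Offsets (in float32 elements) into the packed weight tensor, as sliced by
# Launch above for num_layers=1, num_directions=1.
def packed_offsets(input_size, hidden_size):
    gate_size = 3 * hidden_size                # kGateNum * hidden_size
    weight_size = gate_size * input_size       # weight_size_ in Resize
    weight_h_size = gate_size * hidden_size    # weight_h_size_ in Resize
    return {
        "weights_layer": 0,                    # -> DNNL_ARG_WEIGHTS_LAYER
        "weights_iter": weight_size,           # -> DNNL_ARG_WEIGHTS_ITER
        "bias": weight_size + weight_h_size,   # read only when has_bias_ is set
    }

print(packed_offsets(10, 2))  # {'weights_layer': 0, 'weights_iter': 60, 'bias': 72}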

@@ -0,0 +1,107 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class GRUGradCpuKernelMod : public MKLCpuKernelMod {
 public:
  GRUGradCpuKernelMod() = default;
  ~GRUGradCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {KernelAttr()
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeInt32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddOutputAttr(kNumberTypeFloat32)
                                                     .AddOutputAttr(kNumberTypeFloat32)
                                                     .AddOutputAttr(kNumberTypeFloat32)};
    return support_list;
  }

 private:
  void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
                     const dnnl::memory::desc &bias_desc, const dnnl::memory::desc &dst_desc,
                     const dnnl::memory::desc &dst_h_desc, const dnnl::memory::desc &wksp_desc);
  void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
                           const std::vector<kernel::AddressPtr> &outputs);
  void ResetMemory(const dnnl::memory &mem, const string name) const;
  void InitDnnl();

  int weight_size_{0};
  int weight_h_size_{0};
  int input_size_{0};
  int hidden_size_{0};
  int num_layers_{0};
  int batch_size_{0};
  int seq_len_{0};
  int num_directions_{0};
  bool bidirectional_{false};
  bool has_bias_{false};
  size_t reserve_size_{1};
  dnnl::memory::dims weights_dims_;
  dnnl::memory::dims weights_h_dims_;
  dnnl::memory::dims bias_dims_;
  dnnl::gru_backward::primitive_desc prim_backward_desc_;
  dnnl::memory::desc weights_layer_desc_;
  dnnl::memory::desc weights_iter_desc_;
  dnnl::memory::desc bias_desc_;
  dnnl::memory::desc diff_weights_layer_desc_;
  dnnl::memory::desc diff_weights_iter_desc_;
  dnnl::memory::desc diff_bias_desc_;
  dnnl::memory user_weights_memory_;
  dnnl::memory user_weights_h_memory_;
  dnnl::memory weights_memory_;
  dnnl::memory weights_h_memory_;
  dnnl::memory bias_memory_;
  dnnl::memory diff_weights_memory_;
  dnnl::memory diff_weights_h_memory_;
  dnnl::memory diff_bias_memory_;
  dnnl::memory user_diff_weights_memory_;
  dnnl::memory user_diff_weights_h_memory_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GRU_GRAD_CPU_KERNEL_H_

@@ -673,6 +673,7 @@ GVAR_DEF(PrimitivePtr, kPrimExtractImagePatches, std::make_shared<Primitive>("Ex
GVAR_DEF(PrimitivePtr, kPrimDynamicRNN, std::make_shared<Primitive>("DynamicRNN"));
GVAR_DEF(PrimitivePtr, kPrimCudnnGRU, std::make_shared<Primitive>("CudnnGRU"));
GVAR_DEF(PrimitivePtr, kPrimGRUV2, std::make_shared<Primitive>("GRUV2"));
GVAR_DEF(PrimitivePtr, kPrimGRUV2Grad, std::make_shared<Primitive>("GRUV2Grad"));
GVAR_DEF(PrimitivePtr, kPrimLSTMV2, std::make_shared<Primitive>("LSTMV2"));
GVAR_DEF(PrimitivePtr, kPrimDynamicRNNGrad, std::make_shared<Primitive>("DynamicRNNGrad"));
GVAR_DEF(PrimitivePtr, kPrimDynamicGRUV2, std::make_shared<Primitive>("DynamicGRUV2"));

@@ -0,0 +1,157 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/gru_v2_grad.h"
#include <algorithm>
#include <cstdint>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void GRUV2Grad::set_input_size(const int64_t input_size) {
  (void)CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
  (void)AddAttr(kInput_size, api::MakeValue(input_size));
}
int64_t GRUV2Grad::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }

void GRUV2Grad::set_hidden_size(const int64_t hidden_size) {
  (void)CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
  (void)AddAttr(kHidden_size, api::MakeValue(hidden_size));
}
int64_t GRUV2Grad::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }

void GRUV2Grad::set_num_layers(const int64_t num_layers) {
  (void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
  (void)AddAttr(kNumLayers, api::MakeValue(num_layers));
}
int64_t GRUV2Grad::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }

void GRUV2Grad::set_has_bias(const bool has_bias) { (void)AddAttr(kHasBias, api::MakeValue(has_bias)); }
bool GRUV2Grad::get_has_bias() const {
  auto value_ptr = this->GetAttr(kHasBias);
  return GetValue<bool>(value_ptr);
}

void GRUV2Grad::set_dropout(const float dropout) {
  CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
  (void)AddAttr(kDropout, api::MakeValue(dropout));
}
float GRUV2Grad::get_dropout() const {
  auto value_ptr = this->GetAttr(kDropout);
  return GetValue<float>(value_ptr);
}

void GRUV2Grad::set_bidirectional(const bool bidirectional) {
  (void)AddAttr(kBidirectional, api::MakeValue(bidirectional));
}
bool GRUV2Grad::get_bidirectional() const {
  auto value_ptr = this->GetAttr(kBidirectional);
  return GetValue<bool>(value_ptr);
}

void GRUV2Grad::set_num_directions(const int64_t num_directions) {
  (void)AddAttr(kNumDirections, api::MakeValue(num_directions));
}
int64_t GRUV2Grad::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }

void GRUV2Grad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
                     const float dropout, const bool bidirectional) {
  this->set_input_size(input_size);
  this->set_hidden_size(hidden_size);
  this->set_num_layers(num_layers);
  this->set_has_bias(has_bias);
  this->set_dropout(dropout);
  this->set_bidirectional(bidirectional);
  if (bidirectional) {
    constexpr int k2Directions = 2;
    this->set_num_directions(k2Directions);
  } else {
    this->set_num_directions(1);
  }
}

class GruGradInfer : public abstract::OpInferBase {
  const int kInputNum = 9;
  const int64_t kNumber1 = 1;
  const int64_t kNumber2 = 2;
  const int64_t kNumber3 = 3;
  const size_t kShapeSize = 3;
  const int kIndex0 = 0;
  const int kIndex1 = 1;
  const int kIndex2 = 2;
  const int kHxIdx = 1;
  const int kYIdx = 4;
  const int kDyIdx = 6;
  const int kDhyIdx = 7;

 public:
  GruGradInfer() = default;

  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    auto prim_name = primitive->name();
    (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputNum,
                                             prim_name);
    auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kYIdx]->BuildShape())[kShape];
    auto dy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDyIdx]->BuildShape())[kShape];
    auto dhy_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDhyIdx]->BuildShape())[kShape];
    (void)CheckAndConvertUtils::CheckInteger("dhy_shape size", SizeToLong(dhy_shape.size()), kEqual, kShapeSize,
                                             prim_name);
    (void)CheckAndConvertUtils::CheckInteger("dy_shape size", SizeToLong(dy_shape.size()), kEqual, kShapeSize,
                                             prim_name);
    int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
    bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional));
    int64_t num_directions = kNumber1;
    if (bidirectional) {
      num_directions = kNumber2;
    }
    int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
    auto weight_size = GetWeightSize(primitive, num_layers, num_directions);
    // dx matches the forward input x: (seq_len, batch_size, input_size); y is (seq_len, batch_size, hidden * dir).
    ShapeVector dx_shape = {y_shape[kIndex0], y_shape[kIndex1], input_size};
    ShapeVector weight_shape = {weight_size, kNumber1, kNumber1};
    std::vector<abstract::BaseShapePtr> output_shapes;
    output_shapes.push_back(std::make_shared<abstract::Shape>(dx_shape));
    output_shapes.push_back(std::make_shared<abstract::Shape>(dhy_shape));
    output_shapes.push_back(std::make_shared<abstract::Shape>(weight_shape));
    return std::make_shared<abstract::TupleShape>(output_shapes);
  }

  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    auto hx_type_ptr = input_args[kHxIdx]->BuildType();
    auto dy_type_ptr = input_args[kDyIdx]->BuildType();
    std::vector<TypePtr> types = {dy_type_ptr, dy_type_ptr, hx_type_ptr};
    return std::make_shared<Tuple>(types);
  }

 private:
  int64_t GetWeightSize(const PrimitivePtr &primitive, int64_t num_layers, int64_t num_directions) const {
    int64_t weight_size = 0;
    bool has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
    int64_t input_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
    int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size));
    int64_t gate_size = hidden_size * kNumber3;
    weight_size += input_size * gate_size * num_directions +
                   (num_layers - 1) * (hidden_size * num_directions) * gate_size * num_directions;
    int64_t temp = num_directions * num_layers;
    weight_size += gate_size * hidden_size * temp;
    if (has_bias) {
      // Both b_ih and b_hh are packed, hence 2 * gate_size per layer-direction.
      weight_size += kNumber2 * gate_size * temp;
    }
    return weight_size;
  }
};

MIND_API_OPERATOR_IMPL(GRUV2Grad, BaseOperator);
REGISTER_PRIMITIVE_OP_INFER_IMPL(GRUV2Grad, prim::kPrimGRUV2Grad, GruGradInfer, false);
} // namespace ops
} // namespace mindspore
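
A quick worked check of GetWeightSize against the packed weight sizes the new tests rely on (84 floats in test_gru_grad; 135/117 in test_gru_backward). The assumption, consistent with those sizes, is that the bias block packs both b_ih and b_hh, i.e. 2 * gate_size per layer-direction:

def weight_size(input_size, hidden_size, num_layers, num_directions, has_bias):
    gate_size = 3 * hidden_size
    # First layer takes input_size features; deeper layers take hidden_size * num_directions.
    size = (input_size * gate_size * num_directions
            + (num_layers - 1) * (hidden_size * num_directions) * gate_size * num_directions)
    size += gate_size * hidden_size * num_layers * num_directions   # hidden-to-hidden weights
    if has_bias:
        size += 2 * gate_size * num_layers * num_directions         # b_ih and b_hh
    return size

assert weight_size(10, 2, 1, 1, True) == 84    # test_gru_grad
assert weight_size(10, 3, 1, 1, True) == 135   # test_gru_backward, has_bias=True
assert weight_size(10, 3, 1, 1, False) == 117  # test_gru_backward, has_bias=False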

@@ -0,0 +1,52 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameGRUV2Grad = "GRUV2Grad";
class MIND_API GRUV2Grad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(GRUV2Grad);
  GRUV2Grad() : BaseOperator(kNameGRUV2Grad) {}
  void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
            const float dropout, const bool bidirectional = false);
  void set_input_size(const int64_t input_size);
  int64_t get_input_size() const;
  void set_hidden_size(const int64_t hidden_size);
  int64_t get_hidden_size() const;
  void set_num_layers(const int64_t num_layers);
  int64_t get_num_layers() const;
  void set_has_bias(const bool has_bias);
  bool get_has_bias() const;
  void set_dropout(const float dropout);
  float get_dropout() const;
  void set_bidirectional(const bool bidirectional);
  bool get_bidirectional() const;
  void set_num_directions(const int64_t num_directions);
  int64_t get_num_directions() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_GRU_V2_GRAD_H_

@@ -1102,6 +1102,25 @@ def get_bprop_lstm(self):
    return bprop


@bprop_getters.register(rl_ops.GRUV2)
def get_bppro_gru_v2(self):
    """Grad definition for `GRUV2` operation."""
    gru_grad_v2 = G.GRUV2Grad(
        self.input_size,
        self.hidden_size,
        self.num_layers,
        self.has_bias,
        self.bidirectional,
        self.dropout
    )

    def bpro(x, hx, w, seq_length, out, dout):
        y, hy, reverse, _ = out
        dy, dhy, _, _ = dout
        dx, dhx, dw = gru_grad_v2(x, hx, w, seq_length, y, hy, dy, dhy, reverse)
        return dx, dhx, dw, (0)

    return bpro


@bprop_getters.register(rl_ops.CudnnGRU)
def get_bprop_gru(self):
    """Grad definition for `GRU` operation."""

@@ -1616,6 +1616,25 @@ class GruGradWeight(PrimitiveWithInfer):
        return hx_dtype


class GRUV2Grad(Primitive):
    """Computes the gradients of GRUV2."""

    @prim_attr_register
    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1


class DynamicGRUV2Grad(Primitive):
    r"""
    Computes the input gradients of DynamicGRUV2.

tests/st/ops/test_gru.py
@@ -0,0 +1,134 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.ops.operations._rl_inner_ops as rl_ops
import mindspore.ops.operations._grad_ops as grad_ops
from mindspore import context, Tensor
from mindspore.common.parameter import ParameterTuple
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as c


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_gru_grad(mode):
    """
    Feature: test gru_grad cpu operation.
    Description: test gru_grad cpu operation.
    Expectation: no exception.
    """
    input_size = 10
    hidden_size = 2
    num_layers = 1
    max_seq_len = 5
    batch_size = 2

    context.set_context(mode=mode)
    net = rl_ops.GRUV2(input_size, hidden_size, num_layers, True, False, 0.0)
    input_tensor = Tensor(
        np.ones([max_seq_len, batch_size, input_size]).astype(np.float32))
    h0 = Tensor(
        np.ones([num_layers, batch_size, hidden_size]).astype(np.float32))
    w = Tensor(np.ones([84, 1, 1]).astype(np.float32))
    seq_lengths = Tensor(np.array([4, 3]).astype(np.int32))
    output, hn, out1, _ = net(input_tensor, h0, w, seq_lengths)

    grad_net = grad_ops.GRUV2Grad(
        input_size, hidden_size, num_layers, True, False, 0.0)
    dx, dh, dw = grad_net(input_tensor, h0, w, seq_lengths,
                          output, hn, output, hn, out1)
    print("dx:", dx)
    print("dh:", dh)
    print("dw:", dw)


class GradOfAllInputsAndParams(nn.Cell):
    def __init__(self, network, sens_param):
        super().__init__()
        self.grad = c.GradOperation(
            get_all=True, get_by_list=True, sens_param=sens_param)
        self.network = network
        self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        gout = self.grad(self.network, self.params)(*inputs)
        return gout


class NetGruV2(nn.Cell):
    def __init__(self, input_size, hidden_size, num_layers, has_bias, weights, is_train):
        super(NetGruV2, self).__init__()
        self.gruv2 = rl_ops.GRUV2(
            input_size, hidden_size, num_layers, has_bias, False, 0.0, is_train)
        self.weights = weights

    def construct(self, x, h_0, seq_len):
        return self.gruv2(
            x, h_0, self.weights.astype(x.dtype), seq_len)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("is_train", [True, False])
def test_gru_backward(has_bias, is_train):
    """
    Feature: test GRUV2 backward.
    Description: test gru_grad cpu operation.
    Expectation: no exception.
    """
    batch_size = 3
    max_seq_length = 5
    input_size = 10
    hidden_size = 3
    num_layers = 1
    num_directions = 1
    seq_lengths = Tensor([5, 3, 2], ms.int32)
    dtype = ms.float32
    x = Tensor(np.random.normal(
        0.0, 1.0, (max_seq_length, batch_size, input_size)), dtype)
    h0 = Tensor(np.random.normal(
        0.0, 1.0, (num_layers * num_directions, batch_size, hidden_size)), dtype)
    weight_size = 135 if has_bias else 117
    weights = Tensor(np.ones([weight_size, 1, 1]).astype(np.float32))

    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    gru_v2_net = NetGruV2(input_size, hidden_size,
                          num_layers, has_bias, weights, is_train)
    grad_net_inp = GradOfAllInputsAndParams(gru_v2_net, sens_param=False)
    grad_net_inp.set_train()
    out_grad, _ = grad_net_inp(x, h0, seq_lengths)

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    pynative_gru_v2_net = NetGruV2(input_size, hidden_size,
                                   num_layers, has_bias, weights, is_train)
    pynative_grad_net_inp = GradOfAllInputsAndParams(
        pynative_gru_v2_net, sens_param=False)
    pynative_grad_net_inp.set_train()
    py_native_out_grad, _ = pynative_grad_net_inp(x, h0, seq_lengths)

    assert np.allclose(out_grad[0].asnumpy(),
                       py_native_out_grad[0].asnumpy(), 0.001, 0.001)
    assert np.allclose(out_grad[1].asnumpy(),
                       py_native_out_grad[1].asnumpy(), 0.001, 0.001)