forked from mindspore-Ecosystem/mindspore
!42007 adagrad_gpu_kernel
Merge pull request !42007 from KylinMoriarty/r1.9
This commit is contained in:
commit
edebb3425e
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include "mindspore/core/ops/adam.h"
|
||||
#include "plugin/device/cpu/kernel/adam_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/fp32/adam_fp32.h"
|
||||
|
@ -110,6 +111,8 @@ bool AdamCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
|
|||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kAdamInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kAdamOutputsNum, kernel_name_);
|
||||
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::Adam>(base_operator);
|
||||
use_nesterov_ = kernel_ptr_->get_use_nesterov();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
|
|
|
@ -215,6 +215,8 @@ std::vector<std::pair<KernelAttr, ScatterNdCpuKernelMod::ScatterNdFunc>> Scatter
|
|||
DTYPE_REGISTER_ATTR(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t),
|
||||
DTYPE_REGISTER_ATTR(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t),
|
||||
DTYPE_REGISTER_ATTR(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t),
|
||||
DTYPE_REGISTER_ATTR(kNumberTypeComplex128, kNumberTypeComplex128, complex128),
|
||||
DTYPE_REGISTER_ATTR(kNumberTypeComplex64, kNumberTypeComplex64, complex64),
|
||||
DTYPE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
|
||||
DTYPE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
|
||||
DTYPE_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
|
||||
|
@ -225,6 +227,8 @@ std::vector<std::pair<KernelAttr, ScatterNdCpuKernelMod::ScatterNdFunc>> Scatter
|
|||
DTYPE_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
|
||||
DTYPE_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
|
||||
DTYPE_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
|
||||
DTYPE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, complex128),
|
||||
DTYPE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, complex64),
|
||||
};
|
||||
|
||||
std::vector<KernelAttr> ScatterNdCpuKernelMod::GetOpSupport() {
|
||||
|
|
|
@ -15,62 +15,131 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/nn/adagrad_gpu_kernel.h"
#include "abstract/utils.h"
#include "mindspore/core/ops/apply_adagrad.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdagradGpuKernelMod, float, float, float)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
AdagradGpuKernelMod, half, half, half)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
AdagradGpuKernelMod, half, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdagradGpuKernelMod, float, float, half)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdagradGpuKernelMod, float, half, float)
|
||||
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
AdagradGpuKernelMod, half, float, float)
|
||||
bool AdagradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::ApplyAdagrad>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(kernel_ptr_);
|
||||
update_slots = kernel_ptr_->get_update_slots();
|
||||
kernel_name_ = kernel_ptr_->name();
|
||||
constexpr size_t input_num = 4;
|
||||
constexpr size_t output_num = 2;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int AdagradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
variable_size_ = sizeof(inputs.at(kIndex0)->GetDtype());
|
||||
accumulation_size_ = sizeof(inputs.at(kIndex1)->GetDtype());
|
||||
learning_rate_size_ = sizeof(inputs.at(kIndex2)->GetDtype());
|
||||
gradient_size_ = sizeof(inputs.at(kIndex3)->GetDtype());
|
||||
|
||||
auto variable_shape = inputs[kIndex0]->GetShapeVector();
|
||||
auto accumulation_shape = inputs[kIndex1]->GetShapeVector();
|
||||
auto gradient_shape = inputs[kIndex3]->GetShapeVector();
|
||||
|
||||
variable_size_ *= SizeOf(variable_shape);
|
||||
accumulation_size_ *= SizeOf(accumulation_shape);
|
||||
gradient_size_ *= SizeOf(gradient_shape);
|
||||
|
||||
return KRET_OK;
|
||||
}
|
||||
template <typename T, typename S, typename G>
|
||||
bool AdagradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
T *variable = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, kIndex1);
|
||||
S *learning_rate = GetDeviceAddress<S>(inputs, kIndex2);
|
||||
G *gradient = GetDeviceAddress<G>(inputs, kIndex3);
|
||||
T *variable_out = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
T *accumulation_out = GetDeviceAddress<T>(outputs, kIndex1);
|
||||
|
||||
ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
return true;
|
||||
}
|
||||
std::vector<std::pair<KernelAttr, AdagradGpuKernelMod::AdagradLaunchFunc>> AdagradGpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&AdagradGpuKernelMod::LaunchKernel<float, float, float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&AdagradGpuKernelMod::LaunchKernel<half, half, half>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&AdagradGpuKernelMod::LaunchKernel<half, float, half>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&AdagradGpuKernelMod::LaunchKernel<float, float, half>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&AdagradGpuKernelMod::LaunchKernel<float, half, float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&AdagradGpuKernelMod::LaunchKernel<half, float, float>},
|
||||
};
|
||||
|
||||
std::vector<KernelAttr> AdagradGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(
|
||||
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, AdagradGpuKernelMod::AdagradLaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
// Registers AdagradGpuKernelMod with the NativeGpuKernelMod factory under the name ApplyAdagrad.
// NOTE(review): the macro definition is not visible here; presumably a single registration suffices
// because the dtype-specific launcher is picked at Init time from func_list_ — confirm.
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ApplyAdagrad, AdagradGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,19 +14,20 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
||||
#define MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_ADAM_ADAGRAD_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_ADAM_ADAGRAD_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adagrad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S, typename G>
|
||||
class AdagradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class AdagradGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
AdagradGpuKernelMod()
|
||||
: variable_size_(0),
|
||||
|
@ -34,89 +35,43 @@ class AdagradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
learning_rate_size_(0),
|
||||
gradient_size_(0),
|
||||
update_slots(true),
|
||||
is_null_input_(false),
|
||||
kernel_name_("ApplyAdagrad") {}
|
||||
|
||||
~AdagradGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
T *accumulation = GetDeviceAddress<T>(inputs, 1);
|
||||
S *learning_rate = GetDeviceAddress<S>(inputs, 2);
|
||||
G *gradient = GetDeviceAddress<G>(inputs, 3);
|
||||
T *variable_out = GetDeviceAddress<T>(outputs, 0);
|
||||
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
|
||||
ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(variable_out, variable, variable_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(accumulation_out, accumulation, accumulation_size_,
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
|
||||
return true;
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
update_slots = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
|
||||
kernel_node_ = kernel_node;
|
||||
if (input_num != 4) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 4, but got " << input_num;
|
||||
}
|
||||
variable_size_ = sizeof(T);
|
||||
accumulation_size_ = sizeof(T);
|
||||
learning_rate_size_ = sizeof(S);
|
||||
gradient_size_ = sizeof(G);
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
auto variable_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto accumulation_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto gradient_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(variable_shape, kernel_name_, "var") ||
|
||||
CHECK_SHAPE_NULL(accumulation_shape, kernel_name_, "accum") ||
|
||||
CHECK_SHAPE_NULL(gradient_shape, kernel_name_, "grad");
|
||||
if (is_null_input_ || AnfAlgo::IsShapesDynamic({variable_shape, accumulation_shape, gradient_shape})) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
variable_size_ *= SizeOf(variable_shape);
|
||||
accumulation_size_ *= SizeOf(accumulation_shape);
|
||||
gradient_size_ *= SizeOf(gradient_shape);
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(variable_size_);
|
||||
input_size_list_.push_back(accumulation_size_);
|
||||
input_size_list_.push_back(learning_rate_size_);
|
||||
input_size_list_.push_back(gradient_size_);
|
||||
output_size_list_.push_back(variable_size_);
|
||||
output_size_list_.push_back(accumulation_size_);
|
||||
}
|
||||
template <typename T, typename S, typename G>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
|
||||
using AdagradLaunchFunc =
|
||||
std::function<bool(AdagradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
|
||||
|
||||
private:
|
||||
AdagradLaunchFunc kernel_func_;
|
||||
static std::vector<std::pair<KernelAttr, AdagradLaunchFunc>> func_list_;
|
||||
size_t variable_size_;
|
||||
size_t accumulation_size_;
|
||||
size_t learning_rate_size_;
|
||||
size_t gradient_size_;
|
||||
bool update_slots;
|
||||
bool is_null_input_;
|
||||
std::string kernel_name_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_NN_ADAM_ADAGRAD_GPU_KERNEL_H
|
||||
|
|
|
@ -101,6 +101,13 @@ class ApplyAdagradInfer : public abstract::OpInferBase {
|
|||
return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type});
|
||||
}
|
||||
};
|
||||
bool ApplyAdagrad::get_update_slots() const {
|
||||
auto value_ptr = this->GetAttr(kUpdateSlots);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
// Stores `update_slots` on this operator under the kUpdateSlots attribute.
void ApplyAdagrad::set_update_slots(const bool update_slots) {
  const auto attr_value = api::MakeValue(update_slots);
  (void)this->AddAttr(kUpdateSlots, attr_value);
}
|
||||
// Registers ApplyAdagradInfer as the infer implementation for prim::kPrimApplyAdagrad.
// NOTE(review): the meaning of the trailing `false` argument is not visible in this file — confirm
// against the macro definition.
REGISTER_PRIMITIVE_OP_INFER_IMPL(ApplyAdagrad, prim::kPrimApplyAdagrad, ApplyAdagradInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,12 @@ class MIND_API ApplyAdagrad : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(ApplyAdagrad);
|
||||
ApplyAdagrad() : BaseOperator(kNameApplyAdagrad) { InitIOName({"var", "accum", "lr", "grad"}, {"var", "accum"}); }
|
||||
/// \brief Set update_slots, A bool where if True, accum will be updated. Default: True.
|
||||
void set_update_slots(const bool update_slots);
|
||||
/// \brief Get update_slots.
|
||||
///
|
||||
/// \return update_slots.
|
||||
bool get_update_slots() const;
|
||||
};
|
||||
|
||||
using kPrimApplyAdagradPtr = std::shared_ptr<ApplyAdagrad>;
|
||||
|
|
Loading…
Reference in New Issue