forked from mindspore-Ecosystem/mindspore
optimize lamb
This commit is contained in:
parent c29d6bb764
commit b2dd894707
@@ -0,0 +1,122 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh"
#include "include/cuda_fp16.h"

const int32_t kSqareNum = 2;

template <typename T>
__global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
                                     const float *epsilon, const T *decay, const int32_t *global_step,
                                     const T *gradient, const bool *decay_flag, float *update, float *var_float,
                                     float *grad_float, float *g_hat_var) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
    float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
    float next_mm = next_m / (1 - pow(beta1[0], global_step[0]));
    float next_vv = next_v / (1 - pow(beta2[0], global_step[0]));
    var_float[i] = variable[i];
    grad_float[i] = gradient[i];
    g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
    update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
    if (decay_flag[0]) {
      update[i] += decay[0] * variable[i];
    }
    m[i] = next_m;
    v[i] = next_v;
  }
}

template <>
__global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
                                     const float *beta2, const float *epsilon, const half *decay,
                                     const int32_t *global_step, const half *gradient, const bool *decay_flag,
                                     float *update, float *var_float, float *grad_float, float *g_hat_var) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    float float_gradient = __half2float(gradient[i]);
    float float_var = __half2float(variable[i]);
    float float_decay = __half2float(decay[0]);

    float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
    float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
    float next_mm = next_m / (1 - pow(beta1[0], global_step[0]));
    float next_vv = next_v / (1 - pow(beta2[0], global_step[0]));
    var_float[i] = float_var;
    grad_float[i] = float_gradient;
    g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
    update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
    if (decay_flag[0]) {
      update[i] += float_decay * float_var;
    }
    m[i] = __float2half(next_m);
    v[i] = __float2half(next_v);
  }
}

template <typename T>
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update,
                                         const float *trust_ratio) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
  }
}

template <>
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update,
                                         const float *trust_ratio) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]);
  }
}

template <typename T>
void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
                    const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
                    const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var,
                    cudaStream_t cuda_stream) {
  ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
                                                                          decay, global_step, gradient, decay_flag,
                                                                          update, var_float, grad_float, g_hat_var);
}

template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
                                    const float *trust_ratio, cudaStream_t cuda_stream) {
  ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
}

template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
                                                    const float *beta1, const float *beta2, const float *epsilon,
                                                    const float *decay, const int32_t *global_step,
                                                    const float *gradient, const bool *decay_flag, float *update,
                                                    float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
                                                    cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
                                                   const float *beta1, const float *beta2, const float *epsilon,
                                                   const half *decay, const int32_t *global_step, const half *gradient,
                                                   const bool *decay_flag, float *update, float *w_square_ptr,
                                                   float *g_square_ptr, float *g_hat_square_ptr,
                                                   cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
                                                    const float *update, const float *trust_ratio,
                                                    cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const half *lr,
                                                   const float *update, const float *trust_ratio,
                                                   cudaStream_t cuda_stream);
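
For reference, a hedged reading of what ApplyLambEralyKernel computes for each element i; the symbols below are my own shorthand (w = variable, g = gradient, lambda = decay, t = global_step), not names from the diff. The kernel also stores float copies of w and g for the norm reductions done later by the GPU kernel mod.

  m_t = \beta_1 m_{t-1} + (1 - \beta_1) g
  v_t = \beta_2 v_{t-1} + (1 - \beta_2) g^2
  \hat{m} = m_t / (1 - \beta_1^{t}), \qquad \hat{v} = v_t / (1 - \beta_2^{t})
  \mathit{update} = \hat{m} / (\sqrt{\hat{v}} - \epsilon) \; [\, + \lambda w \text{ if decay\_flag} \,]
  \hat{g} = \hat{m} / \sqrt{\hat{v} + \epsilon} + \lambda w

As written, the epsilon placement differs between update (subtracted outside the square root) and g-hat (added inside it). ApplyLambAfterNormKernel later applies w \leftarrow w - \mathit{trust\_ratio} \cdot lr \cdot \mathit{update}.
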
@@ -0,0 +1,30 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
                                    const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
                                    const bool *decay_flag, float *update, float *var_float, float *grad_float,
                                    float *g_hat_var, cudaStream_t cuda_stream);

template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
                                    const float *trust_ratio, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
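
A minimal host-side sketch of how the two launchers declared above could be driven, assuming a MindSpore GPU build where this header is on the include path and the explicit float instantiations are linked in. All names below (RunLambStep, d_var, ...) are illustrative, not part of the diff, and in the actual kernel mod the trust ratio is produced by cuDNN norm reductions (see lamb_gpu_kernel.h later in this commit) rather than passed in by the caller.

#include <cuda_runtime.h>
#include <cstdint>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh"

// Hypothetical driver: one Lamb step over n float parameters already resident on the GPU.
// d_var, d_m, d_v, d_grad and d_lr are device buffers owned by the caller.
void RunLambStep(size_t n, float *d_var, float *d_m, float *d_v, const float *d_grad, const float *d_lr,
                 float beta1, float beta2, float epsilon, float decay, int32_t global_step, bool decay_flag,
                 float trust_ratio, cudaStream_t stream) {
  // Every scalar hyperparameter is read by the kernels through a device pointer, so stage it on the GPU.
  auto to_device = [](const void *host, size_t bytes) {
    void *dev = nullptr;
    cudaMalloc(&dev, bytes);
    cudaMemcpy(dev, host, bytes, cudaMemcpyHostToDevice);
    return dev;
  };
  float *d_beta1 = static_cast<float *>(to_device(&beta1, sizeof(float)));
  float *d_beta2 = static_cast<float *>(to_device(&beta2, sizeof(float)));
  float *d_epsilon = static_cast<float *>(to_device(&epsilon, sizeof(float)));
  float *d_decay = static_cast<float *>(to_device(&decay, sizeof(float)));
  int32_t *d_step = static_cast<int32_t *>(to_device(&global_step, sizeof(int32_t)));
  bool *d_flag = static_cast<bool *>(to_device(&decay_flag, sizeof(bool)));
  float *d_trust = static_cast<float *>(to_device(&trust_ratio, sizeof(float)));

  // Per-element float workspaces filled by the "early" phase.
  float *d_update = nullptr, *d_var_f = nullptr, *d_grad_f = nullptr, *d_g_hat = nullptr;
  cudaMalloc(&d_update, n * sizeof(float));
  cudaMalloc(&d_var_f, n * sizeof(float));
  cudaMalloc(&d_grad_f, n * sizeof(float));
  cudaMalloc(&d_g_hat, n * sizeof(float));

  // Phase 1: moments, bias correction and the raw update vector.
  ApplyLambEraly(n, d_var, d_m, d_v, d_beta1, d_beta2, d_epsilon, d_decay, d_step, d_grad, d_flag,
                 d_update, d_var_f, d_grad_f, d_g_hat, stream);
  // Phase 2: scale the update by trust_ratio * lr and apply it to the variable.
  ApplyLambLater(n, d_var, d_lr, d_update, d_trust, stream);
  cudaStreamSynchronize(stream);

  void *scratch[] = {d_beta1, d_beta2, d_epsilon, d_decay, d_step, d_flag, d_trust,
                     d_update, d_var_f, d_grad_f, d_g_hat};
  for (void *p : scratch) {
    cudaFree(p);
  }
}
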
@@ -0,0 +1,56 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Lamb,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      LambGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Lamb,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      LambGpuKernelMod, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,407 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAMB_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAMB_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <algorithm>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh"
#include "ops/lamb.h"

namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 11;
constexpr size_t kArgMaxDim = 7;
constexpr float ten = 10;

// input param index
constexpr size_t kVarIndex = 0;
constexpr size_t kMIndex = 1;
constexpr size_t kVIndex = 2;
constexpr size_t kLearningRateIndex = 3;
constexpr size_t kBeta1Index = 4;
constexpr size_t kBeta2Index = 5;
constexpr size_t kEpsilonIndex = 6;
constexpr size_t kWeightDecayIndex = 7;
constexpr size_t kGlobalStepIndex = 8;
constexpr size_t kGradIndex = 9;
constexpr size_t kDecayFlagIndex = 10;

// workspaces param index
constexpr size_t kUpdateIndex = 0;
constexpr size_t kVarFloatIndex = 1;
constexpr size_t kGradFloatIndex = 2;
constexpr size_t kGHatValIndex = 3;
constexpr size_t kTrustRatioIndex = 4;
constexpr size_t kReduceWorkspaceIndex = 5;
constexpr size_t kWNormIndex = 6;
constexpr size_t kGNormIndex = 7;
constexpr size_t kGHatNormIndex = 8;

template <typename T>
class LambGpuKernelMod : public NativeGpuKernelMod {
 public:
  LambGpuKernelMod() = default;

  ~LambGpuKernelMod() override { DestroyResource(); }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }

    T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
    T *m = GetDeviceAddress<T>(inputs, kMIndex);
    T *v = GetDeviceAddress<T>(inputs, kVIndex);
    T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex);
    float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
    float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
    float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
    T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex);
    int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
    T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
    bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
    float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
    float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
    float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
    float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);

    // Phase 1: element-wise moment update, bias correction and raw update vector.
    ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
                   decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));

    // Phase 2: reduce the float copies of the weights and of g_hat to norms and derive the layer-wise trust ratio.
    float trust_ratio{0};
    CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);

    float *trust_ratio_ptr = GetDeviceAddress<float>(workspaces, kTrustRatioIndex);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(trust_ratio_ptr, &trust_ratio, sizeof(float), cudaMemcpyHostToDevice,
                      reinterpret_cast<cudaStream_t>(stream_ptr)),
      "For " + kernel_name_ + " cudaMemcpyAsync trust_ratio failed.");

    // Phase 3: scale the update by trust_ratio * lr and apply it to the variable.
    ApplyLambLater(inputs[0]->size / sizeof(T), variable, learning_rate, update, trust_ratio_ptr,
                   reinterpret_cast<cudaStream_t>(stream_ptr));

    return true;
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override {
    if (inputs.size() != INPUT_NUM) {
      MS_LOG(EXCEPTION) << "For 'Lamb', the number of inputs should be " << INPUT_NUM << ", but got " << inputs.size();
    }

    auto kernel_ptr = std::dynamic_pointer_cast<ops::Lamb>(base_operator);
    kernel_name_ = kernel_ptr->name();

    InitResource();
    InitParamSizeByType();

    auto convert_int64_shape_to_sizet_shape = [=](std::vector<int64_t> int64_shape) -> std::vector<size_t> {
      std::vector<size_t> size_t_shape;
      (void)std::transform(int64_shape.begin(), int64_shape.end(), std::back_inserter(size_t_shape), LongToSize);
      return size_t_shape;
    };

    auto variable_int64_shape = inputs[kVarIndex]->GetShapeVector();
    auto m_int64_shape = inputs[kMIndex]->GetShapeVector();
    auto v_int64_shape = inputs[kVIndex]->GetShapeVector();
    auto gradient_int64_shape = inputs[kGradIndex]->GetShapeVector();

    std::vector<size_t> variable_shape = convert_int64_shape_to_sizet_shape(variable_int64_shape);
    std::vector<size_t> m_shape = convert_int64_shape_to_sizet_shape(m_int64_shape);
    std::vector<size_t> v_shape = convert_int64_shape_to_sizet_shape(v_int64_shape);
    std::vector<size_t> gradient_shape = convert_int64_shape_to_sizet_shape(gradient_int64_shape);

    is_null_input_ = CHECK_SHAPE_NULL(variable_shape, kernel_name_, "var") ||
                     CHECK_SHAPE_NULL(m_shape, kernel_name_, "m") || CHECK_SHAPE_NULL(v_shape, kernel_name_, "v") ||
                     CHECK_SHAPE_NULL(gradient_shape, kernel_name_, "gradient");
    if (is_null_input_) {
      InitSizeLists();
      return true;
    }

    InitParamSizeByShape(variable_shape, m_shape, v_shape, gradient_shape);

    auto output_int64_shape = outputs[0]->GetShapeVector();
    std::vector<size_t> output_shape = convert_int64_shape_to_sizet_shape(output_int64_shape);

    size_t input_dim = variable_shape.size();
    if (!CheckValidShape(variable_shape, output_shape, input_dim)) {
      return true;
    }

    InitShapeInfo(variable_shape, output_shape);
    // Determine the reduce operation.
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_NORM2, CUDNN_DATA_FLOAT, nan_prop_,
                                     reduce_indices_, CUDNN_32BIT_INDICES),
      "For " + kernel_name_ + " cudnnSetReduceTensorDescriptor failed");

    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() {
    input_size_list_.push_back(variable_size_);
    input_size_list_.push_back(m_size_);
    input_size_list_.push_back(v_size_);
    input_size_list_.push_back(learning_rate_size_);
    input_size_list_.push_back(beta1_size_);
    input_size_list_.push_back(beta2_size_);
    input_size_list_.push_back(epsilon_size_);
    input_size_list_.push_back(decay_size_);
    input_size_list_.push_back(global_step_size_);
    input_size_list_.push_back(gradient_size_);
    input_size_list_.push_back(decay_flag_size_);
    workspace_size_list_.push_back(update_size_);
    workspace_size_list_.push_back(var_float_size_);
    workspace_size_list_.push_back(grad_float_size_);
    workspace_size_list_.push_back(g_hat_val_size_);
    workspace_size_list_.push_back(trust_ratio_size_);
    size_t workspace_size{0};
    // Init the workspace size for the cuDNN tensor norm reduction.
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, input_descriptor_, output_descriptor_,
                                     &workspace_size),
      "For " + kernel_name_ + " cudnnGetReductionWorkspaceSize failed.");
    workspace_size_list_.emplace_back(workspace_size);

    // Three reduction outputs: the norms of var_float, grad_float and g_hat_var.
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(output_descriptor_, &reduce_output_size_),
                                        "For " + kernel_name_ + " cudnnGetTensorSizeInBytes failed.");
    workspace_size_list_.emplace_back(reduce_output_size_);
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(output_descriptor_, &reduce_output_size_),
                                        "For " + kernel_name_ + " cudnnGetTensorSizeInBytes failed.");
    workspace_size_list_.emplace_back(reduce_output_size_);
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(output_descriptor_, &reduce_output_size_),
                                        "For " + kernel_name_ + " cudnnGetTensorSizeInBytes failed.");
    workspace_size_list_.emplace_back(reduce_output_size_);

    output_size_list_.push_back(0);
    output_size_list_.push_back(0);
    output_size_list_.push_back(0);
  }

  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_),
                                        "For " + kernel_name_ + " cudnnCreateReduceTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&input_descriptor_),
                                        "For " + kernel_name_ + " cudnnCreateTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&output_descriptor_),
                                        "For " + kernel_name_ + " cudnnCreateTensorDescriptor failed.");
  }

 private:
  void DestroyResource() noexcept override {
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
                                        "For " + kernel_name_ + " cudnnDestroyReduceTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(input_descriptor_),
                                        "For " + kernel_name_ + " cudnnDestroyTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(output_descriptor_),
                                        "For " + kernel_name_ + " cudnnDestroyTensorDescriptor failed.");
  }

  bool CheckValidShape(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape,
                       size_t input_dim) {
    is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(output_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "For 'LambGpuKernelMod', input or output is null.";
      InitSizeLists();
      return false;
    }
    if (input_shape.size() != output_shape.size()) {
      MS_LOG(EXCEPTION) << "The size of input shape: " << input_shape.size()
                        << " and the size of output shape: " << output_shape.size() << " are different.";
    }
    if (input_dim > kArgMaxDim) {
      MS_LOG(EXCEPTION) << "Broadcast operation is not supported when dim exceeds " << kArgMaxDim;
    }
    CheckTensorSize({input_shape, output_shape});
    return true;
  }

  void InitParamSizeByType() {
    variable_size_ = sizeof(T);
    m_size_ = sizeof(T);
    v_size_ = sizeof(T);
    learning_rate_size_ = sizeof(T);
    beta1_size_ = sizeof(T);
    beta2_size_ = sizeof(T);
    epsilon_size_ = sizeof(T);
    decay_size_ = sizeof(T);
    global_step_size_ = sizeof(int32_t);
    gradient_size_ = sizeof(T);
    decay_flag_size_ = sizeof(bool);
    update_size_ = sizeof(float);
    var_float_size_ = sizeof(float);
    grad_float_size_ = sizeof(float);
    g_hat_val_size_ = sizeof(float);
    trust_ratio_size_ = sizeof(float);
  }

  void InitParamSizeByShape(const std::vector<size_t> &variable_shape, const std::vector<size_t> &m_shape,
                            const std::vector<size_t> &v_shape, const std::vector<size_t> &gradient_shape) {
    for (size_t i = 0; i < variable_shape.size(); i++) {
      variable_size_ *= variable_shape[i];
      // save intermediate value
      update_size_ *= variable_shape[i];
      var_float_size_ *= variable_shape[i];
      grad_float_size_ *= variable_shape[i];
      g_hat_val_size_ *= variable_shape[i];
    }

    for (size_t i = 0; i < m_shape.size(); i++) {
      m_size_ *= m_shape[i];
    }

    for (size_t i = 0; i < v_shape.size(); i++) {
      v_size_ *= v_shape[i];
    }

    for (size_t i = 0; i < gradient_shape.size(); i++) {
      gradient_size_ *= gradient_shape[i];
    }
  }

  void CalcTrustRatio(const std::vector<AddressPtr> &workspaces, float *var_float, float *grad_float, float *g_hat_var,
                      void *stream_ptr, float *trust_ratio) {
    if (var_float == nullptr || grad_float == nullptr || g_hat_var == nullptr) {
      MS_LOG(EXCEPTION) << "var_float or grad_float or g_hat_var is null";
    }

    float *reduce_workspace_addr = GetPossiblyNullDeviceAddress<float>(workspaces, kReduceWorkspaceIndex);
    float *w_norm_ptr = GetDeviceAddress<float>(workspaces, kWNormIndex);
    float *g_norm_ptr = GetDeviceAddress<float>(workspaces, kGNormIndex);
    float *g_hat_norm_ptr = GetDeviceAddress<float>(workspaces, kGHatNormIndex);

    // Compute the L2 norms of var_float, grad_float and g_hat_var with the NORM2 reduction.
    constexpr float alpha = 1;
    constexpr float beta = 0;
    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, var_float, &beta,
                        output_descriptor_, w_norm_ptr),
      "For " + kernel_name_ + " cudnnReduceTensor for 'var_float' failed");

    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, grad_float, &beta,
                        output_descriptor_, g_norm_ptr),
      "For " + kernel_name_ + " cudnnReduceTensor for 'grad_float' failed");

    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
      cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
                        workspace_size_list_[kReduceWorkspaceIndex], &alpha, input_descriptor_, g_hat_var, &beta,
                        output_descriptor_, g_hat_norm_ptr),
      "For " + kernel_name_ + " cudnnReduceTensor for 'g_hat_var' failed");

    float w_norm = 0;
    float g_norm = 0;
    float g_norm_hat = 0;
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&w_norm, w_norm_ptr, reduce_output_size_, cudaMemcpyDeviceToHost,
                                                       reinterpret_cast<cudaStream_t>(stream_ptr)),
                                       "For " + kernel_name_ + " cudaMemcpyAsync w_square_sum failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&g_norm, g_norm_ptr, reduce_output_size_, cudaMemcpyDeviceToHost,
                                                       reinterpret_cast<cudaStream_t>(stream_ptr)),
                                       "For " + kernel_name_ + " cudaMemcpyAsync g_square_sum failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&g_norm_hat, g_hat_norm_ptr, reduce_output_size_, cudaMemcpyDeviceToHost,
                      reinterpret_cast<cudaStream_t>(stream_ptr)),
      "For " + kernel_name_ + " cudaMemcpyAsync g_hat_square_sum failed.");

    *trust_ratio = w_norm > 0 ? (g_norm_hat > 0 ? (w_norm / g_norm_hat) : 1) : 1;
    if (*trust_ratio < 0 || std::isnan(*trust_ratio)) {
      *trust_ratio = 0;
    } else if (*trust_ratio > ten) {
      *trust_ratio = ten;
    }
  }

  void InitShapeInfo(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
    // Determine which dimension will be reduced.
    std::vector<size_t> reduce_output_shape = output_shape;
    std::fill(reduce_output_shape.begin(), reduce_output_shape.end(), 1);

    // Infer input and output descriptor.
    InferInAndOutDesc(input_shape, reduce_output_shape);
  }

  void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &reduce_output_shape) {
    constexpr size_t split_dim = 4;
    constexpr size_t dim_idx_two = 2;
    constexpr size_t dim_idx_three = 3;
    if (input_shape.size() <= split_dim) {
      std::vector<size_t> new_input_shape;
      ShapeNdTo4d(input_shape, &new_input_shape);
      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
        cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, data_type_, new_input_shape[0],
                                   new_input_shape[1], new_input_shape[dim_idx_two], new_input_shape[dim_idx_three]),
        "For " + kernel_name_ + " cudnnSetTensor4dDescriptor failed");
    } else {
      CudnnSetTensorNdDescriptor(input_shape, input_descriptor_, data_type_, kernel_name_);
    }
    if (reduce_output_shape.size() <= split_dim) {
      std::vector<size_t> new_reduce_output_shape;
      ShapeNdTo4d(reduce_output_shape, &new_reduce_output_shape);
      CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
        cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, data_type_, new_reduce_output_shape[0],
                                   new_reduce_output_shape[1], new_reduce_output_shape[dim_idx_two],
                                   new_reduce_output_shape[dim_idx_three]),
        "For " + kernel_name_ + " cudnnSetTensor4dDescriptor failed");
    } else {
      CudnnSetTensorNdDescriptor(reduce_output_shape, output_descriptor_, data_type_, kernel_name_);
    }
  }

  size_t variable_size_{0};
  size_t m_size_{0};
  size_t v_size_{0};
  size_t learning_rate_size_{0};
  size_t beta1_size_{0};
  size_t beta2_size_{0};
  size_t epsilon_size_{0};
  size_t decay_size_{0};
  size_t global_step_size_{0};
  size_t gradient_size_{0};
  size_t decay_flag_size_{0};
  size_t update_size_{0};
  size_t var_float_size_{0};
  size_t grad_float_size_{0};
  size_t g_hat_val_size_{0};
  size_t trust_ratio_size_{0};
  size_t reduce_output_size_{0};
  bool is_null_input_{false};

  cudnnHandle_t cudnn_handle_{nullptr};
  cudnnDataType_t data_type_{CUDNN_DATA_FLOAT};
  cudnnNanPropagation_t nan_prop_{CUDNN_NOT_PROPAGATE_NAN};
  cudnnReduceTensorIndices_t reduce_indices_{CUDNN_REDUCE_TENSOR_NO_INDICES};
  cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_{nullptr};
  cudnnTensorDescriptor_t input_descriptor_{nullptr};
  cudnnTensorDescriptor_t output_descriptor_{nullptr};
  std::string kernel_name_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LAMB_GPU_KERNEL_H_
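
Putting the pieces of LambGpuKernelMod::Launch together: CalcTrustRatio reduces the float copies of the weights and of g-hat to L2 norms with cuDNN (CUDNN_REDUCE_TENSOR_NORM2) and derives a layer-wise trust ratio, roughly as follows (my notation, mirroring the code):

  \mathit{trust\_ratio} =
    \begin{cases}
      \min\!\left(\dfrac{\lVert w \rVert_2}{\lVert \hat{g} \rVert_2},\; 10\right) & \text{if } \lVert w \rVert_2 > 0 \text{ and } \lVert \hat{g} \rVert_2 > 0 \\
      1 & \text{otherwise}
    \end{cases}

(the ratio is additionally forced to 0 if it comes out negative or NaN). ApplyLambLater then performs w \leftarrow w - \mathit{trust\_ratio} \cdot lr \cdot \mathit{update} for every element of the variable.
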
@@ -421,6 +421,7 @@ GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(
GVAR_DEF(PrimitivePtr, kPrimCeLU, std::make_shared<Primitive>("CeLU"));
GVAR_DEF(PrimitivePtr, kPrimAdam, std::make_shared<Primitive>("Adam"));
GVAR_DEF(PrimitivePtr, kPrimAdamWeightDecay, std::make_shared<Primitive>("AdamWeightDecay"));
GVAR_DEF(PrimitivePtr, kPrimLamb, std::make_shared<Primitive>("Lamb"));
GVAR_DEF(PrimitivePtr, kPrimApplyAdaMax, std::make_shared<Primitive>("ApplyAdaMax"));
GVAR_DEF(PrimitivePtr, kPrimAudioSpectrogram, std::make_shared<Primitive>("AudioSpectrogram"));
GVAR_DEF(PrimitivePtr, kPrimFlatten, std::make_shared<Primitive>("Flatten"));
@@ -0,0 +1,105 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/lamb.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
TuplePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto var_type = input_args[kInputIndex0]->BuildType();
  auto m_type = input_args[kInputIndex1]->BuildType();
  auto v_type = input_args[kInputIndex2]->BuildType();
  auto lr_type = input_args[kInputIndex3]->BuildType();
  auto beta1_type = input_args[kInputIndex4]->BuildType();
  auto beta2_type = input_args[kInputIndex5]->BuildType();
  auto epsilon_type = input_args[kInputIndex6]->BuildType();
  auto decay_type = input_args[kInputIndex7]->BuildType();
  auto global_step_type = input_args[kInputIndex8]->BuildType();
  auto grad_type = input_args[kInputIndex9]->BuildType();
  auto decay_flag_type = input_args[kInputIndex10]->BuildType();

  std::map<std::string, TypePtr> type_dict;
  type_dict.emplace("var", var_type);
  type_dict.emplace("m", m_type);
  type_dict.emplace("v", v_type);
  type_dict.emplace("grad", grad_type);
  type_dict.emplace("lr", lr_type);
  type_dict.emplace("decay", decay_type);
  std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
                                kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
  std::map<std::string, TypePtr> type_dict1;
  type_dict1.emplace("beta1", beta1_type);
  type_dict1.emplace("beta2", beta2_type);
  type_dict1.emplace("epsilon", epsilon_type);
  std::set<TypePtr> float_set = {kFloat16, kFloat32};
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);

  std::set<TypePtr> bool_set = {kBool};
  (void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
  (void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);

  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type});
}
abstract::TupleShapePtr LambInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  auto var_shape_ptr = input_args[kInputIndex0]->BuildShape();
  auto m_shape_ptr = input_args[kInputIndex1]->BuildShape();
  auto v_shape_ptr = input_args[kInputIndex2]->BuildShape();
  auto grad_shape_ptr = input_args[kInputIndex9]->BuildShape();
  if (var_shape_ptr->IsDynamic() || m_shape_ptr->IsDynamic() || v_shape_ptr->IsDynamic() ||
      grad_shape_ptr->IsDynamic()) {
    MS_LOG(WARNING) << "var is dynamic: " << var_shape_ptr->IsDynamic() << ", m is dynamic: " << m_shape_ptr->IsDynamic()
                    << ", v is dynamic: " << v_shape_ptr->IsDynamic() << ", grad is dynamic: "
                    << grad_shape_ptr->IsDynamic();
    return std::make_shared<abstract::TupleShape>(
      std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
  }
  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->BuildShape())[kShape];
  CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, m_shape, prim_name);
  CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, v_shape, prim_name);
  CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, grad_shape, prim_name);
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{var_shape_ptr, m_shape_ptr, v_shape_ptr});
}
}  // namespace

MIND_API_OPERATOR_IMPL(Lamb, BaseOperator);

AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  for (auto item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const int64_t kInputNum = 11;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, prim_name);
  auto infer_type = LambInferType(primitive, input_args);
  auto infer_shape = LambInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Lamb, prim::kPrimLamb, LambInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_LAMB_H_
#define MINDSPORE_CORE_OPS_LAMB_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameLamb = "Lamb";

class MIND_API Lamb : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Lamb);
  Lamb() : BaseOperator(kNameLamb) {}
};
abstract::AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimLambPtr = std::shared_ptr<Lamb>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_LAMB_H_
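
As a self-contained cross-check of the math spread across lamb_impl.cu and lamb_gpu_kernel.h above, here is a plain C++ sketch of one LAMB step for a single float parameter tensor, written to mirror the GPU code (host-only, no MindSpore or CUDA dependencies; all names are mine). It is a reference sketch under those assumptions, not the implementation the kernel uses.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors ApplyLambEralyKernel + CalcTrustRatio + ApplyLambAfterNormKernel on the CPU.
void LambStepReference(std::vector<float> *var, std::vector<float> *m, std::vector<float> *v,
                       const std::vector<float> &grad, float lr, float beta1, float beta2, float epsilon,
                       float decay, int32_t global_step, bool decay_flag) {
  const size_t n = var->size();
  std::vector<float> update(n), g_hat(n);
  double w_sq = 0.0, g_hat_sq = 0.0;
  for (size_t i = 0; i < n; ++i) {
    float next_m = beta1 * (*m)[i] + (1 - beta1) * grad[i];
    float next_v = beta2 * (*v)[i] + (1 - beta2) * grad[i] * grad[i];
    float next_mm = next_m / (1 - std::pow(beta1, global_step));
    float next_vv = next_v / (1 - std::pow(beta2, global_step));
    g_hat[i] = next_mm / std::sqrt(next_vv + epsilon) + decay * (*var)[i];
    update[i] = next_mm / (std::sqrt(next_vv) - epsilon);  // epsilon placement as in the kernel
    if (decay_flag) {
      update[i] += decay * (*var)[i];
    }
    (*m)[i] = next_m;
    (*v)[i] = next_v;
    w_sq += static_cast<double>((*var)[i]) * (*var)[i];
    g_hat_sq += static_cast<double>(g_hat[i]) * g_hat[i];
  }
  // Layer-wise trust ratio, clamped to [0, 10] as in CalcTrustRatio.
  float w_norm = static_cast<float>(std::sqrt(w_sq));
  float g_hat_norm = static_cast<float>(std::sqrt(g_hat_sq));
  float trust_ratio = (w_norm > 0 && g_hat_norm > 0) ? w_norm / g_hat_norm : 1.0f;
  if (trust_ratio < 0 || std::isnan(trust_ratio)) {
    trust_ratio = 0;
  } else if (trust_ratio > 10.0f) {
    trust_ratio = 10.0f;
  }
  // Final variable update, as in ApplyLambAfterNormKernel.
  for (size_t i = 0; i < n; ++i) {
    (*var)[i] -= trust_ratio * lr * update[i];
  }
}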