forked from mindspore-Ecosystem/mindspore
fixed validator for CumProd, ReduceProd, ApplyRMSProp
This commit is contained in:
parent
63bb429633
commit
51affc2f1b
|
@ -19,17 +19,17 @@
|
|||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
|
||||
__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable,
|
||||
T* mean_square, T*moment, T* gradients, const size_t size) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
|
||||
mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
|
||||
moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i];
|
||||
mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i];
|
||||
moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i];
|
||||
variable[i] -= moment[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
|
||||
void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon,
|
||||
T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) {
|
||||
RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon,
|
||||
variable, mean_square, moment, gradients, size);
|
||||
|
@ -58,7 +58,7 @@ void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, co
|
|||
}
|
||||
|
||||
template
|
||||
void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
|
||||
void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon,
|
||||
float* variable, float* mean_square, float* moment, float* gradients, const size_t size,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square,
|
||||
void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square,
|
||||
T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -25,9 +25,6 @@ MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
RMSPropGpuKernel, float)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class RMSPropGpuKernel : public GpuKernel {
|
||||
public:
|
||||
RMSPropGpuKernel() : size_(1), use_center_(false) {}
|
||||
RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {}
|
||||
~RMSPropGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -40,13 +40,10 @@ class RMSPropGpuKernel : public GpuKernel {
|
|||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
T *mean_square = GetDeviceAddress<T>(inputs, 1);
|
||||
T *moment = GetDeviceAddress<T>(inputs, 2);
|
||||
T *gradients = GetDeviceAddress<T>(inputs, 3);
|
||||
T *learning_rate = GetDeviceAddress<T>(inputs, 4);
|
||||
T *decay = GetDeviceAddress<T>(inputs, 5);
|
||||
T *momentum = GetDeviceAddress<T>(inputs, 6);
|
||||
T *epsilon = GetDeviceAddress<T>(inputs, 7);
|
||||
T *learning_rate = GetDeviceAddress<T>(inputs, 3);
|
||||
T *gradients = GetDeviceAddress<T>(inputs, 4);
|
||||
|
||||
RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_,
|
||||
RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_,
|
||||
reinterpret_cast<cudaStream_t>(stream));
|
||||
} else {
|
||||
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||
|
@ -70,6 +67,11 @@ class RMSPropGpuKernel : public GpuKernel {
|
|||
use_center_ = true;
|
||||
}
|
||||
|
||||
if (node_name == "ApplyRMSProp") {
|
||||
decay_ = GetAttr<float>(kernel_node, "rho");
|
||||
momentum_ = GetAttr<float>(kernel_node, "momentum");
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (auto &dim : input_shape) {
|
||||
size_ *= dim;
|
||||
|
@ -81,24 +83,33 @@ class RMSPropGpuKernel : public GpuKernel {
|
|||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = size_ * sizeof(T);
|
||||
input_size_list_.push_back(input_size);
|
||||
if (use_center_) {
|
||||
if (!use_center_) {
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(input_size);
|
||||
output_size_list_.push_back(input_size);
|
||||
} else {
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
output_size_list_.push_back(input_size);
|
||||
}
|
||||
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
input_size_list_.push_back(sizeof(T));
|
||||
output_size_list_.push_back(0);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
bool use_center_;
|
||||
float decay_;
|
||||
float momentum_;
|
||||
float epsilon_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
|
|
@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
|
|||
else:
|
||||
quant_fun = P.FakeQuantPerLayer
|
||||
ema_fun = P.FakeQuantMinMaxPerLayerUpdate
|
||||
|
||||
|
||||
self.fake_quant = quant_fun(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=self.ema_decay,
|
||||
|
@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
|
|||
0, self.out_channels)]).astype(np.float32)
|
||||
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
|
||||
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||
|
||||
|
||||
if per_channel:
|
||||
quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
|
||||
else:
|
||||
|
|
|
@ -175,8 +175,7 @@ class Adam(Optimizer):
|
|||
If True, updates the gradients using NAG.
|
||||
If False, updates the gradients without using NAG. Default: False.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
|
||||
1.0.
|
||||
loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -210,7 +209,7 @@ class Adam(Optimizer):
|
|||
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
|
||||
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.beta1 = Tensor(beta1, mstype.float32)
|
||||
self.beta2 = Tensor(beta2, mstype.float32)
|
||||
|
|
|
@ -122,7 +122,8 @@ class SameTypeShape(PrimitiveWithInfer):
|
|||
Checks whether data type and shape of two tensors are the same.
|
||||
|
||||
Raises:
|
||||
TypeError or ValueError: If not the same.
|
||||
TypeError - If data type not the same.
|
||||
ValueError - If shape of two tensors not the same.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
@ -1031,7 +1032,7 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
- **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple
|
||||
integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
|
||||
The values must include 0. There can be no duplicate values or negative values.
|
||||
If the input is Tensor, it must be 1-d and the dtype is int.
|
||||
If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed.
|
||||
|
||||
|
||||
Outputs:
|
||||
|
@ -1061,7 +1062,9 @@ class InvertPermutation(PrimitiveWithInfer):
|
|||
z = [x_value[i] for i in range(len(x_value))]
|
||||
z.sort()
|
||||
|
||||
validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
|
||||
for i in range(1, len(z)):
|
||||
if z[i-1] == z[i]:
|
||||
raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
|
||||
validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
|
||||
validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
|
||||
|
||||
|
|
|
@ -258,6 +258,8 @@ class _Reduce(PrimitiveWithInfer):
|
|||
args = {'input_x': input_x['dtype']}
|
||||
validator.check_tensor_type_same(args, valid_dtype, self.name)
|
||||
|
||||
if axis_v is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
return {'shape': input_shp,
|
||||
'dtype': input_x['dtype'],
|
||||
|
@ -445,8 +447,9 @@ class ReduceProd(_Reduce):
|
|||
Default : False, don't keep these reduced dimensions.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[Number]) - The input tensor.
|
||||
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
|
||||
- **input_x** (Tensor[Number]) - The input tensor.
|
||||
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
|
||||
Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same dtype as the 'input_x'.
|
||||
|
@ -474,8 +477,9 @@ class CumProd(PrimitiveWithInfer):
|
|||
reverse (bool): If True, reverse the result along axis. Default: False
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor[Number]) - The input tensor.
|
||||
- **axis** (int) - The dimensions to compute the cumulative product.
|
||||
- **input_x** (Tensor[Number]) - The input tensor.
|
||||
- **axis** (int) - The dimensions to compute the cumulative product.
|
||||
Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and dtype as the 'input_x'.
|
||||
|
@ -507,6 +511,10 @@ class CumProd(PrimitiveWithInfer):
|
|||
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
|
||||
return x_type
|
||||
|
||||
def infer_value(self, x, axis):
|
||||
if axis is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
|
||||
|
||||
class MatMul(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -669,6 +677,10 @@ class CumSum(PrimitiveWithInfer):
|
|||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
||||
def infer_value(self, x, axis):
|
||||
if axis is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
|
||||
|
||||
class AddN(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -1707,9 +1707,9 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
- **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
|
||||
- **learning_rate** (Union[Number, Tensor]) - Learning rate.
|
||||
- **grad** (Tensor) - Gradients, must have the same type as `var`.
|
||||
- **decay** (float) - Decay rate.
|
||||
- **momentum** (float) - Momentum.
|
||||
- **epsilon** (float) - Ridge term.
|
||||
- **decay** (float) - Decay rate. Only constant value is allowed.
|
||||
- **momentum** (float) - Momentum. Only constant value is allowed.
|
||||
- **epsilon** (float) - Ridge term. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, parameters to be update.
|
||||
|
@ -1759,6 +1759,13 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
return var_dtype, var_dtype, var_dtype
|
||||
return var_dtype
|
||||
|
||||
def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
|
||||
if decay is None or momentum is None or epsilon is None:
|
||||
raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
|
||||
if not self.is_ge and self.is_d:
|
||||
return None, None, None
|
||||
return None
|
||||
|
||||
|
||||
class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -379,7 +379,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
|
|||
Inputs:
|
||||
- **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
|
||||
- **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
|
||||
the shape same as `labels` and the dtype must be non-negative Integer.
|
||||
the shape same as `labels` and the dtype must be non-negative Integer.
|
||||
- **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.
|
||||
|
||||
Outputs:
|
||||
|
|
|
@ -24,19 +24,25 @@ from mindspore.ops import operations as P
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class NetRMSProp(nn.Cell):
|
||||
def __init__(self, use_centered):
|
||||
super(NetRMSProp, self).__init__()
|
||||
self.use_centered = use_centered
|
||||
if use_centered:
|
||||
self.rms_opt = P.ApplyCenteredRMSProp()
|
||||
else:
|
||||
self.rms_opt = P.ApplyRMSProp()
|
||||
class NetCenteredRMSProp(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetCenteredRMSProp, self).__init__()
|
||||
self.rms_opt = P.ApplyCenteredRMSProp()
|
||||
|
||||
def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
|
||||
if self.use_centered:
|
||||
return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
|
||||
return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
|
||||
return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
|
||||
|
||||
|
||||
class NetRMSProp(nn.Cell):
|
||||
def __init__(self, decay, momentum, epsilon):
|
||||
super(NetRMSProp, self).__init__()
|
||||
self.decay = decay
|
||||
self.momentum = momentum
|
||||
self.epsilon = epsilon
|
||||
self.rms_opt = P.ApplyRMSProp()
|
||||
|
||||
def construct(self, var, g, mg, rms, mom, lr):
|
||||
return self.rms_opt(var, rms, mom, lr, g, self.decay, self.momentum, self.epsilon)
|
||||
|
||||
|
||||
def rmsprop_numpy(variable, gradients, mean_square, moment,
|
||||
|
@ -76,13 +82,16 @@ def test_rmsprop():
|
|||
if centered:
|
||||
rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
net = NetCenteredRMSProp()
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
|
||||
moment_ms, learning_rate, decay, momentum, epsilon)
|
||||
|
||||
else:
|
||||
rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
|
||||
net = NetRMSProp(centered)
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
|
||||
moment_ms, learning_rate, decay, momentum, epsilon)
|
||||
net = NetRMSProp(decay, momentum, epsilon)
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
|
||||
moment_ms, learning_rate)
|
||||
|
||||
error = np.ones(shape=variable_np.shape) * 10e-6
|
||||
diff = variable_ms.asnumpy() - variable_np
|
||||
|
@ -126,13 +135,15 @@ def test_rmspropcenter():
|
|||
if centered:
|
||||
rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
net = NetCenteredRMSProp()
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
else:
|
||||
rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
|
||||
net = NetRMSProp(centered)
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
|
||||
learning_rate, decay, momentum, epsilon)
|
||||
net = NetRMSProp(decay, momentum, epsilon)
|
||||
_ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
|
||||
learning_rate)
|
||||
|
||||
error = np.ones(shape=variable_np.shape) * 10e-6
|
||||
diff = variable_ms.asnumpy() - variable_np
|
Loading…
Reference in New Issue