!1862 fixed validator for ApplyRMSProp, CumProd, CumSum, ReduceProd etc.

Merge pull request !1862 from jiangjinsheng/issue_doc
mindspore-ci-bot 2020-06-09 09:57:22 +08:00 committed by Gitee
commit 5499161531
10 changed files with 100 additions and 60 deletions

View File

@@ -19,17 +19,17 @@
 #include "device/gpu/cuda_common.h"
 template <typename T>
-__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
+__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable,
                               T* mean_square, T*moment, T* gradients, const size_t size) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
-    mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
-    moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i];
+    mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i];
+    moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i];
     variable[i] -= moment[i];
   }
 }
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon,
              T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) {
   RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon,
                                                                    variable, mean_square, moment, gradients, size);
@@ -58,7 +58,7 @@ void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, co
 }
 template
-void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
+void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon,
              float* variable, float* mean_square, float* moment, float* gradients, const size_t size,
             cudaStream_t cuda_stream);
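For reference, the per-element update that RmsPropKernel performs (readable directly from the kernel body above; decay, momentum and epsilon are now passed by value instead of as device pointers):

    mean_square = decay * mean_square + (1 - decay) * grad * grad
    moment      = momentum * moment + learning_rate * grad / sqrt(mean_square + epsilon)
    variable    = variable - moment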

View File

@@ -19,7 +19,7 @@
 #include "device/gpu/cuda_common.h"
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square,
              T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);
 template <typename T>

View File

@@ -25,9 +25,6 @@ MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       RMSPropGpuKernel, float)

View File

@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class RMSPropGpuKernel : public GpuKernel {
  public:
-  RMSPropGpuKernel() : size_(1), use_center_(false) {}
+  RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {}
   ~RMSPropGpuKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,13 +40,10 @@ class RMSPropGpuKernel : public GpuKernel {
       T *variable = GetDeviceAddress<T>(inputs, 0);
       T *mean_square = GetDeviceAddress<T>(inputs, 1);
       T *moment = GetDeviceAddress<T>(inputs, 2);
-      T *gradients = GetDeviceAddress<T>(inputs, 3);
-      T *learning_rate = GetDeviceAddress<T>(inputs, 4);
-      T *decay = GetDeviceAddress<T>(inputs, 5);
-      T *momentum = GetDeviceAddress<T>(inputs, 6);
-      T *epsilon = GetDeviceAddress<T>(inputs, 7);
-      RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_,
+      T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+      T *gradients = GetDeviceAddress<T>(inputs, 4);
+      RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_,
               reinterpret_cast<cudaStream_t>(stream));
     } else {
       T *variable = GetDeviceAddress<T>(inputs, 0);
@@ -70,6 +67,11 @@ class RMSPropGpuKernel : public GpuKernel {
       use_center_ = true;
     }
+    if (node_name == "ApplyRMSProp") {
+      decay_ = GetAttr<float>(kernel_node, "rho");
+      momentum_ = GetAttr<float>(kernel_node, "momentum");
+      epsilon_ = GetAttr<float>(kernel_node, "epsilon");
+    }
     auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     for (auto &dim : input_shape) {
       size_ *= dim;
@@ -81,24 +83,33 @@ class RMSPropGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     size_t input_size = size_ * sizeof(T);
-    input_size_list_.push_back(input_size);
-    if (use_center_) {
-      input_size_list_.push_back(input_size);
-    }
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    output_size_list_.push_back(0);
+    if (!use_center_) {
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(input_size);
+      output_size_list_.push_back(input_size);
+    } else {
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      output_size_list_.push_back(input_size);
+    }
   }

  private:
   size_t size_;
   bool use_center_;
+  float decay_;
+  float momentum_;
+  float epsilon_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;

View File

@@ -182,8 +182,7 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
-            1.0.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -220,7 +219,7 @@ class Adam(Optimizer):
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
         validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
-        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)

         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
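A minimal usage sketch of the relaxed check (hypothetical; `net` is assumed to be an existing nn.Cell): a loss scale between 0 and 1, which the old validator rejected, now passes.

    from mindspore import nn

    # loss_scale only has to be positive now, so a value such as 0.5 is accepted
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3, loss_scale=0.5)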

View File

@@ -122,7 +122,8 @@ class SameTypeShape(PrimitiveWithInfer):
     Checks whether data type and shape of two tensors are the same.

     Raises:
-        TypeError or ValueError: If not the same.
+        TypeError - If data type not the same.
+        ValueError - If shape of two tensors not the same.

     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -1040,7 +1041,7 @@ class InvertPermutation(PrimitiveWithInfer):
         - **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple
           integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
           The values must include 0. There can be no duplicate values or negative values.
-          If the input is Tensor, it must be 1-d and the dtype is int.
+          If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed.

     Outputs:
@@ -1070,7 +1071,9 @@ class InvertPermutation(PrimitiveWithInfer):
         z = [x_value[i] for i in range(len(x_value))]
         z.sort()
-        validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
+        for i in range(1, len(z)):
+            if z[i-1] == z[i]:
+                raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
         validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
         validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
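A short sketch of the constant-input contract this check enforces (hypothetical values; the op returns the inverse permutation, i.e. output[x[i]] = i):

    from mindspore.ops import operations as P

    invert = P.InvertPermutation()
    # the input must be constant, contain 0, and have no negative or duplicate values
    out = invert((3, 4, 0, 2, 1))   # expected: (2, 4, 3, 0, 1)
    # a duplicated value such as (0, 1, 1) now raises "1 is duplicated in the input"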

View File

@@ -260,6 +260,8 @@ class _Reduce(PrimitiveWithInfer):
         args = {'input_x': input_x['dtype']}
         validator.check_tensor_type_same(args, valid_dtype, self.name)

+        if axis_v is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
         input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
@@ -447,8 +449,9 @@ class ReduceProd(_Reduce):
         Default : False, don't keep these reduced dimensions.

     Inputs:
         - **input_x** (Tensor[Number]) - The input tensor.
         - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+          Only constant value is allowed.

     Outputs:
         Tensor, has the same dtype as the 'input_x'.
@@ -476,8 +479,9 @@ class CumProd(PrimitiveWithInfer):
         reverse (bool): If True, reverse the result along axis. Default: False

     Inputs:
         - **input_x** (Tensor[Number]) - The input tensor.
         - **axis** (int) - The dimensions to compute the cumulative product.
+          Only constant value is allowed.

     Outputs:
         Tensor, has the same shape and dtype as the 'input_x'.
@@ -509,6 +513,10 @@ class CumProd(PrimitiveWithInfer):
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type

+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+

 class MatMul(PrimitiveWithInfer):
     """
@@ -671,6 +679,10 @@ class CumSum(PrimitiveWithInfer):
                 'dtype': x['dtype'],
                 'value': None}

+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+

 class AddN(PrimitiveWithInfer):
     """

View File

@@ -1707,9 +1707,9 @@ class ApplyRMSProp(PrimitiveWithInfer):
         - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
         - **learning_rate** (Union[Number, Tensor]) - Learning rate.
         - **grad** (Tensor) - Gradients, must have the same type as `var`.
-        - **decay** (float) - Decay rate.
-        - **momentum** (float) - Momentum.
-        - **epsilon** (float) - Ridge term.
+        - **decay** (float) - Decay rate. Only constant value is allowed.
+        - **momentum** (float) - Momentum. Only constant value is allowed.
+        - **epsilon** (float) - Ridge term. Only constant value is allowed.

     Outputs:
         Tensor, parameters to be update.
@@ -1759,6 +1759,13 @@ class ApplyRMSProp(PrimitiveWithInfer):
             return var_dtype, var_dtype, var_dtype
         return var_dtype

+    def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
+        if decay is None or momentum is None or epsilon is None:
+            raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
+        if not self.is_ge and self.is_d:
+            return None, None, None
+        return None
+

 class ApplyCenteredRMSProp(PrimitiveWithInfer):
     """

View File

@@ -379,7 +379,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
     Inputs:
         - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
         - **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
          the shape same as `labels` and the dtype must be non-negative Integer.
         - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.

     Outputs:

View File

@@ -24,19 +24,25 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

-class NetRMSProp(nn.Cell):
-    def __init__(self, use_centered):
-        super(NetRMSProp, self).__init__()
-        self.use_centered = use_centered
-        if use_centered:
-            self.rms_opt = P.ApplyCenteredRMSProp()
-        else:
-            self.rms_opt = P.ApplyRMSProp()
+class NetCenteredRMSProp(nn.Cell):
+    def __init__(self):
+        super(NetCenteredRMSProp, self).__init__()
+        self.rms_opt = P.ApplyCenteredRMSProp()

     def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
-        if self.use_centered:
-            return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
-        return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
+        return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
+
+
+class NetRMSProp(nn.Cell):
+    def __init__(self, decay, momentum, epsilon):
+        super(NetRMSProp, self).__init__()
+        self.decay = decay
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.rms_opt = P.ApplyRMSProp()
+
+    def construct(self, var, g, mg, rms, mom, lr):
+        return self.rms_opt(var, rms, mom, lr, g, self.decay, self.momentum, self.epsilon)


 def rmsprop_numpy(variable, gradients, mean_square, moment,
@@ -76,13 +82,16 @@ def test_rmsprop():
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate, decay, momentum, epsilon)
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
-            moment_ms, learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate)

     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np
@@ -126,13 +135,15 @@ def test_rmspropcenter():
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate, decay, momentum, epsilon)
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
-            learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate)

     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np
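For context, the body of rmsprop_numpy is outside this diff; a hypothetical NumPy reconstruction of the non-centered reference update it checks against (mirroring RmsPropKernel above) could look like:

    import numpy as np

    def rmsprop_reference(variable, gradients, mean_square, moment,
                          learning_rate, decay, momentum, epsilon):
        # in-place, element-wise update, same form as the CUDA kernel
        mean_square[:] = decay * mean_square + (1.0 - decay) * gradients * gradients
        moment[:] = momentum * moment + learning_rate * gradients / np.sqrt(mean_square + epsilon)
        variable[:] -= moment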