forked from mindspore-Ecosystem/mindspore
!1861 fix ApplyCenteredRMSProp Python API params and TBE params order map BUG
Merge pull request !1861 from zhouneng/fix_applycenterrmsprop_params_map_bug
This commit is contained in:
commit
63bb429633
|
@ -177,6 +177,18 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
|
|||
for (size_t i = 3; i < inputs_list.size(); ++i) {
|
||||
inputs_json->push_back(inputs_list[i]);
|
||||
}
|
||||
} else if (op_name == "ApplyCenteredRMSProp") {
|
||||
// Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map
|
||||
// TBE parameter to correspond python API parameter by latter's index using hardcode
|
||||
inputs_json->push_back(inputs_list[0]);
|
||||
inputs_json->push_back(inputs_list[1]);
|
||||
inputs_json->push_back(inputs_list[2]);
|
||||
inputs_json->push_back(inputs_list[3]);
|
||||
inputs_json->push_back(inputs_list[5]);
|
||||
inputs_json->push_back(inputs_list[6]);
|
||||
inputs_json->push_back(inputs_list[7]);
|
||||
inputs_json->push_back(inputs_list[8]);
|
||||
inputs_json->push_back(inputs_list[4]);
|
||||
} else {
|
||||
inputs_json->push_back(inputs_list[1]);
|
||||
inputs_json->push_back(inputs_list[0]);
|
||||
|
|
|
@ -1807,18 +1807,23 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
|||
|
||||
Examples:
|
||||
>>> centered_rms_prop = P.ApplyCenteredRMSProp()
|
||||
>>> input_x = Tensor(1., mindspore.float32)
|
||||
>>> mean_grad = Tensor(2., mindspore.float32)
|
||||
>>> mean_square = Tensor(1., mindspore.float32)
|
||||
>>> moment = Tensor(2., mindspore.float32)
|
||||
>>> grad = Tensor(1., mindspore.float32)
|
||||
>>> input_x = Tensor(np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||
>>> mean_grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||
>>> mean_square = Tensor(np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||
>>> moment = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||
>>> grad = Tensor(np.arange(12).astype(np.float32).rehspae(2, 3, 2), mindspore.float32)
|
||||
>>> learning_rate = Tensor(0.9, mindspore.float32)
|
||||
>>> decay = 0.0
|
||||
>>> momentum = 1e-10
|
||||
>>> epsilon = 0.001
|
||||
>>> epsilon = 0.05
|
||||
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
|
||||
>>> learning_rate, decay, momentum, epsilon)
|
||||
-27.460497
|
||||
[[[ -6. -9.024922]
|
||||
[-12.049845 -15.074766]
|
||||
[-18.09969 -21.124613]]
|
||||
[[-24.149532 -27.174456]
|
||||
[-30.199379 -33.2243 ]
|
||||
[-36.249226 -39.274143]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue