forked from mindspore-Ecosystem/mindspore
!1861 fix ApplyCenteredRMSProp Python API params and TBE params order map BUG
Merge pull request !1861 from zhouneng/fix_applycenterrmsprop_params_map_bug
This commit is contained in:
commit
63bb429633
|
@ -177,6 +177,18 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
|
||||||
for (size_t i = 3; i < inputs_list.size(); ++i) {
|
for (size_t i = 3; i < inputs_list.size(); ++i) {
|
||||||
inputs_json->push_back(inputs_list[i]);
|
inputs_json->push_back(inputs_list[i]);
|
||||||
}
|
}
|
||||||
|
} else if (op_name == "ApplyCenteredRMSProp") {
|
||||||
|
// The parameter order of ApplyCenteredRMSProp's TBE implementation differs from the Python API, so map
|
||||||
|
// each TBE parameter to the corresponding Python API parameter using the latter's hard-coded index
|
||||||
|
inputs_json->push_back(inputs_list[0]);
|
||||||
|
inputs_json->push_back(inputs_list[1]);
|
||||||
|
inputs_json->push_back(inputs_list[2]);
|
||||||
|
inputs_json->push_back(inputs_list[3]);
|
||||||
|
inputs_json->push_back(inputs_list[5]);
|
||||||
|
inputs_json->push_back(inputs_list[6]);
|
||||||
|
inputs_json->push_back(inputs_list[7]);
|
||||||
|
inputs_json->push_back(inputs_list[8]);
|
||||||
|
inputs_json->push_back(inputs_list[4]);
|
||||||
} else {
|
} else {
|
||||||
inputs_json->push_back(inputs_list[1]);
|
inputs_json->push_back(inputs_list[1]);
|
||||||
inputs_json->push_back(inputs_list[0]);
|
inputs_json->push_back(inputs_list[0]);
|
||||||
|
|
|
@ -1807,18 +1807,23 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> centered_rms_prop = P.ApplyCenteredRMSProp()
|
>>> centered_rms_prop = P.ApplyCenteredRMSProp()
|
||||||
>>> input_x = Tensor(1., mindspore.float32)
|
>>> input_x = Tensor(np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||||
>>> mean_grad = Tensor(2., mindspore.float32)
|
>>> mean_grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||||
>>> mean_square = Tensor(1., mindspore.float32)
|
>>> mean_square = Tensor(np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||||
>>> moment = Tensor(2., mindspore.float32)
|
>>> moment = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||||
>>> grad = Tensor(1., mindspore.float32)
|
>>> grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
|
||||||
>>> learning_rate = Tensor(0.9, mindspore.float32)
|
>>> learning_rate = Tensor(0.9, mindspore.float32)
|
||||||
>>> decay = 0.0
|
>>> decay = 0.0
|
||||||
>>> momentum = 1e-10
|
>>> momentum = 1e-10
|
||||||
>>> epsilon = 0.001
|
>>> epsilon = 0.05
|
||||||
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
|
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
|
||||||
>>> learning_rate, decay, momentum, epsilon)
|
>>> learning_rate, decay, momentum, epsilon)
|
||||||
-27.460497
|
[[[ -6. -9.024922]
|
||||||
|
[-12.049845 -15.074766]
|
||||||
|
[-18.09969 -21.124613]]
|
||||||
|
[[-24.149532 -27.174456]
|
||||||
|
[-30.199379 -33.2243 ]
|
||||||
|
[-36.249226 -39.274143]]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
|
Loading…
Reference in New Issue