forked from mindspore-Ecosystem/mindspore
!31227 don't handle reshape when using adafactor optimizer in auto parallel
Merge pull request !31227 from yangzhenzhang/adafactor-parallel-skip-handle-reshape
commit 01078763df
@@ -53,6 +53,14 @@ class ReduceOneEliminater : public AnfVisitor {
       return nullptr;
     }

+    // if node has keep_alive attr, it would not be eliminated.
+    if (IsPrimitiveCNode(node, prim::kPrimReduceMean)) {
+      if (prim->HasAttr("keep_alive") && GetValue<bool>(prim->GetAttr("keep_alive"))) {
+        MS_LOG(INFO) << "keep node " << node->fullname_with_scope() << " alive";
+        return nullptr;
+      }
+    }
+
     // consider keep_dims
     auto keep_dims = prim->GetAttr("keep_dims");
     auto is_keep_dims = GetValue<bool>(keep_dims);
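Note: judging by its name and the keep_dims handling above, ReduceOneEliminater folds away reductions whose reduced axes all have size 1, where the reduction is numerically a no-op. A minimal NumPy sketch of the identity such a pass relies on (illustrative, not MindSpore code):

    import numpy as np

    x = np.arange(64 * 32, dtype=np.float32).reshape(64, 1, 32)
    # Averaging over a size-1 axis changes no values, so the ReduceMean can
    # normally be folded away -- unless the new keep_alive attribute asks the
    # pass to leave the node in place.
    assert np.array_equal(x.mean(axis=1), x.reshape(64, 32))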
@@ -2903,6 +2903,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
     if (prim->name() != RESHAPE) {
       continue;
     }

+    Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode->input(2)->cast<ValueNodePtr>()->value());
+    if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
+      continue;
+    }
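Note: a destination shape of (-1,) is a full flatten, which is the Reshape pattern this early continue now skips when saving root-graph strategies. The same predicate in plain Python (hypothetical helper name, for illustration only):

    def is_full_flatten(dst_shape):
        # Mirrors the C++ check: exactly one dimension, and it is -1.
        return len(dst_shape) == 1 and dst_shape[0] == -1

    assert is_full_flatten((-1,))
    assert not is_full_flatten((64, 32))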
@@ -44,6 +44,7 @@ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
     return P.Mul()(r_factor, c_factor)


+reduce_mean_keep_alive = P.ReduceMean().add_prim_attr("keep_alive", True)
 _adafactor_opt = C.MultitypeFuncGraph("adafactor_opt")


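Note: the primitive is tagged once at module level so every call site below shares the attributed instance. A small sketch of inspecting the attribute from Python (this assumes the Primitive.attrs mapping that MindSpore primitives expose):

    from mindspore.ops import operations as P

    reduce_mean_keep_alive = P.ReduceMean().add_prim_attr("keep_alive", True)
    # The attribute travels with the primitive into the compiled graph, where
    # ReduceOneEliminater reads it back via prim->GetAttr("keep_alive").
    assert reduce_mean_keep_alive.attrs["keep_alive"] is True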
@@ -78,13 +79,13 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
     if factored:
         exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
         exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
-        update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
+        update_mean = reduce_mean_keep_alive(update, -1) * (1.0 - beta2t)
         exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
         exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))

         exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
         exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
-        update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
+        update_mean = reduce_mean_keep_alive(update, -2) * (1.0 - beta2t)
         exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
         exp_avg_sq_col_update = F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))

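Note: in the factored branch, AdaFactor keeps row and column means of the squared-gradient statistic instead of a full second-moment matrix, and the two edited lines are the EMA updates of those means. A rough NumPy sketch of the same arithmetic (illustrative only; update stands in for the squared-gradient term):

    import numpy as np

    def factored_moment_step(row, col, update, beta2t):
        # EMA of the mean over the last axis (the per-row statistic).
        new_row = row * beta2t + update.mean(axis=-1) * (1.0 - beta2t)
        # EMA of the mean over the second-to-last axis (the per-column statistic).
        new_col = col * beta2t + update.mean(axis=-2) * (1.0 - beta2t)
        return new_row, new_col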
@@ -27,7 +27,7 @@ class Net(Cell):
         super().__init__()
         self.add = P.TensorAdd()
         self.matmul = P.MatMul().shard(strategy1)
-        self.bias_add = P.BiasAdd().shard(strategy2)
+        self.add2 = P.TensorAdd().shard(strategy2)
         self.add_weight = Parameter(add_weight, "w1")
         self.mul_weight = Parameter(matmul_weight, "w2")
         self.bias = Parameter(bias, "bias")
@@ -37,14 +37,14 @@ class Net(Cell):
         out = self.add(x, self.add_weight)
         out = self.reshape(out, (64, 32))
         out = self.matmul(out, self.mul_weight)
-        out = self.add(out, self.bias)
+        out = self.add2(out, self.bias)
         return out


 _x = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
 _w0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
 _w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
-_w2 = Tensor(np.ones([32]), dtype=ms.float32)
+_w2 = Tensor(np.ones([1, 32]), dtype=ms.float32)
 _b = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)


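Note: the shapes in the test line up as follows, and the new (1, 32) bias still broadcasts cleanly against the matmul output. A NumPy check of the data flow in Net.construct:

    import numpy as np

    x = np.ones([64, 16, 2], np.float32)
    w0 = np.ones([64, 16, 2], np.float32)
    w1 = np.ones([32, 32], np.float32)
    w2 = np.ones([1, 32], np.float32)

    out = (x + w0).reshape(64, 32) @ w1  # (64, 32) after reshape and matmul
    out = out + w2                       # (1, 32) broadcasts over the 64 rows
    assert out.shape == (64, 32)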
@@ -71,7 +71,7 @@ def test_opt_data_parallel():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((16, 1), (1, 1))
-    strategy2 = ((16, 1), (1,))
+    strategy2 = ((16, 1), (1, 1))
     net = Net(_w0, _w1, _w2, strategy1, strategy2)
     compile_net(net)

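Note: a shard strategy supplies one slice count per dimension of each input, so the old rank-1 bias strategies ((1,) here, (2,) below) matched the old 1-D bias, while the new rank-2 tuples match the new (1, 32) bias. A hypothetical validity check, for illustration:

    def strategy_matches(shape, strat):
        # One slice count per dimension, each dimension divisible by its count.
        return len(shape) == len(strat) and all(d % s == 0 for d, s in zip(shape, strat))

    assert strategy_matches((1, 32), (1, 1))    # data-parallel test
    assert strategy_matches((1, 32), (1, 2))    # model-parallel tests
    assert not strategy_matches((32,), (1, 2))  # the old 1-D bias: rank mismatch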
@@ -84,7 +84,7 @@ def test_opt_model_parallel():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((4, 2), (2, 2))
-    strategy2 = ((4, 2), (2,))
+    strategy2 = ((4, 2), (1, 2))
     net = Net(_w0, _w1, _w2, strategy1, strategy2)
     compile_net(net)

@@ -98,6 +98,6 @@ def test_opt_shard():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                       enable_parallel_optimizer=True)
     strategy1 = ((4, 2), (2, 2))
-    strategy2 = ((4, 2), (2,))
+    strategy2 = ((4, 2), (1, 2))
     net = Net(_w0, _w1, _w2, strategy1, strategy2)
     compile_net(net)