!31227 don't handle reshape when using adafactor optimizer in auto parallel

Merge pull request !31227 from yangzhenzhang/adafactor-parallel-skip-handle-reshape
i-robot 2022-03-15 02:47:31 +00:00 committed by Gitee
commit 01078763df
4 changed files with 18 additions and 8 deletions


@@ -53,6 +53,14 @@ class ReduceOneEliminater : public AnfVisitor {
return nullptr;
}
// if node has keep_alive attr, it would not be eliminated.
if (IsPrimitiveCNode(node, prim::kPrimReduceMean)) {
if (prim->HasAttr("keep_alive") && GetValue<bool>(prim->GetAttr("keep_alive"))) {
MS_LOG(INFO) << "keep node " << node->fullname_with_scope() << " alive";
return nullptr;
}
}
// consider keep_dims
auto keep_dims = prim->GetAttr("keep_dims");
auto is_keep_dims = GetValue<bool>(keep_dims);
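The hunk above gives ReduceOneEliminater an escape hatch: this pass normally folds away a ReduceMean whose reduced axis has length 1, which would delete the ReduceMean nodes the adafactor optimizer depends on before the parallel passes see them. A node carrying the keep_alive attribute is now left in the graph. A minimal sketch of tagging a primitive this way on the Python side (the Cell wrapper is illustrative; the attribute is the one the hunk checks):

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class RowMean(nn.Cell):
        """Illustrative cell whose ReduceMean survives ReduceOneEliminater."""
        def __init__(self):
            super().__init__()
            # Same attribute the eliminator inspects in the hunk above.
            self.reduce_mean = P.ReduceMean().add_prim_attr("keep_alive", True)

        def construct(self, x):
            # Kept in the graph even when the last axis has length 1.
            return self.reduce_mean(x, -1)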


@@ -2903,6 +2903,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
if (prim->name() != RESHAPE) {
continue;
}
Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode->input(2)->cast<ValueNodePtr>()->value());
if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
continue;
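This hunk is the heart of the fix: HandleRootReshapeAndSaveStrategy now skips a Reshape whose destination shape is exactly (-1,), i.e. a pure flatten, rather than saving a parallel strategy for it; adafactor inserts such flattens internally. A hedged restatement of the guard in Python (the function name is illustrative):

    def is_flatten_reshape(origin_dst_shape):
        """True for a dst shape of exactly (-1,), the case the pass now skips."""
        return len(origin_dst_shape) == 1 and origin_dst_shape[0] == -1

    assert is_flatten_reshape([-1])           # flatten: strategy handling skipped
    assert not is_flatten_reshape([64, 32])   # explicit shape: handled as before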


@@ -44,6 +44,7 @@ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
return P.Mul()(r_factor, c_factor)
reduce_mean_keep_alive = P.ReduceMean().add_prim_attr("keep_alive", True)
_adafactor_opt = C.MultitypeFuncGraph("adafactor_opt")
@@ -78,13 +79,13 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
if factored:
exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
update_mean = reduce_mean_keep_alive(update, -1) * (1.0 - beta2t)
exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
update_mean = reduce_mean_keep_alive(update, -2) * (1.0 - beta2t)
exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
exp_avg_sq_col_update = F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))
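For context on the two edited lines: when a parameter is factored, adafactor keeps only row and column statistics of the squared update, refreshing each as an exponential moving average of a ReduceMean over one axis; _approx_sq_grad later recombines them into an approximate per-element second moment. A numpy sketch of just that update rule, assuming update already holds the squared-gradient term (the real code additionally casts dtypes and assigns in place):

    import numpy as np

    def factored_second_moment(exp_avg_sq_row, exp_avg_sq_col, update, beta2t):
        # The two means below correspond to the two reduce_mean_keep_alive
        # calls in the hunk above (axis -1 for rows, axis -2 for columns).
        row = beta2t * exp_avg_sq_row + (1.0 - beta2t) * update.mean(axis=-1)
        col = beta2t * exp_avg_sq_col + (1.0 - beta2t) * update.mean(axis=-2)
        return row, col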


@@ -27,7 +27,7 @@ class Net(Cell):
super().__init__()
self.add = P.TensorAdd()
self.matmul = P.MatMul().shard(strategy1)
self.bias_add = P.BiasAdd().shard(strategy2)
self.add2 = P.TensorAdd().shard(strategy2)
self.add_weight = Parameter(add_weight, "w1")
self.mul_weight = Parameter(matmul_weight, "w2")
self.bias = Parameter(bias, "bias")
@@ -37,14 +37,14 @@ class Net(Cell):
out = self.add(x, self.add_weight)
out = self.reshape(out, (64, 32))
out = self.matmul(out, self.mul_weight)
out = self.add(out, self.bias)
out = self.add2(out, self.bias)
return out
_x = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32]), dtype=ms.float32)
_w2 = Tensor(np.ones([1, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
@@ -71,7 +71,7 @@ def test_opt_data_parallel():
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1), (1, 1))
strategy2 = ((16, 1), (1,))
strategy2 = ((16, 1), (1, 1))
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)
@@ -84,7 +84,7 @@ def test_opt_model_parallel():
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((4, 2), (2, 2))
strategy2 = ((4, 2), (2,))
strategy2 = ((4, 2), (1, 2))
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)
@@ -98,6 +98,6 @@ def test_opt_shard():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
enable_parallel_optimizer=True)
strategy1 = ((4, 2), (2, 2))
strategy2 = ((4, 2), (2,))
strategy2 = ((4, 2), (1, 2))
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)
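The test edits all follow one rule: a shard strategy supplies one tuple per operator input, and each tuple needs one element per dimension of that input. BiasAdd requires a 1-D bias, so the test switches to a second TensorAdd, reshapes the bias to [1, 32], and widens the second tuple of strategy2 to rank 2; making the bias rank-2 also, presumably, routes it through adafactor's factored row/column path shown earlier. A small sanity check of the rank rule, using shapes from the test above:

    out_shape = (64, 32)    # shape of the matmul output
    bias_shape = (1, 32)    # was (32,); now rank-2 to match a 2-tuple strategy
    strategy2 = ((4, 2), (1, 2))

    assert len(strategy2[0]) == len(out_shape)
    assert len(strategy2[1]) == len(bias_shape)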