support opt parallel for adafactor

yangzhenzhang 2021-12-10 11:10:45 +08:00
parent db7d28f5c8
commit 2a0b528084
3 changed files with 113 additions and 24 deletions

View File

@@ -360,7 +360,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
continue;
}
Shape slice_shape = parameter_shape[0];
if (slice_shape.empty()) {
if (slice_shape.empty() || slice_shape[0] < dev_num) {
continue;
}
slice_shape[0] = slice_shape[0] / dev_num;
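
The new condition skips optimizer-shard slicing when a parameter's first dimension is smaller than the device count, since such a dimension cannot be split across all devices. A minimal Python sketch of the same check (hypothetical helper, shapes chosen only for illustration):

# Skip slicing when the first dimension cannot be split across dev_num devices,
# mirroring the `continue` guard added above.
def shard_first_dim(slice_shape, dev_num):
    if not slice_shape or slice_shape[0] < dev_num:
        return list(slice_shape)          # left unsliced
    sliced = list(slice_shape)
    sliced[0] //= dev_num                 # slice_shape[0] / dev_num
    return sliced

print(shard_first_dim([64, 32], 8))       # [8, 32]
print(shard_first_dim([4, 32], 8))        # [4, 32]  (first dim < dev_num, skipped)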
@@ -554,29 +554,78 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
}
}
// For adafactor optimizer, the relationship between parameter and state's shape as follows:
// 1) parameter: [A, B, C, D] (shape_size > 2), exp_avg_sq_row: [A, B, C], exp_avg_sq_col: [A, B, D], exp_avg_sq: [1]
// If the parameter is opt shard, the exp_avg_sq_row and exp_avg_sq_col need to be shard accordingly.
//
// 2) parameter: [A, B] (shape_size = 2), exp_avg_sq_row: [A], exp_avg_sq_col: [B], exp_avg_sq: [1]
// If the parameter is opt shard, the exp_avg_sq_row needs to be shard accordingly.
//
// 3) parameter: [A] (shape_size = 1), exp_avg_sq_row: [1], exp_avg_sq_col: [1], exp_avg_sq: [A]
// If the parameter is opt shard, the exp_avg_sq needs to be shard accordingly.
static bool AdafactorStateIsOptShard(const std::string &opt_shard_group, size_t shape_size,
const std::string &param_name, const std::string &state_name) {
if (opt_shard_group.empty()) {
return false;
}
std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
std::string exp_avg_name = EXP_AVG_SQ + param_name;
if (shape_size > 2 && state_name == exp_avg_name) {
return false;
}
if (shape_size == 2 && (state_name == exp_col_name || state_name == exp_avg_name)) {
return false;
}
if (shape_size == 1 && (state_name == exp_row_name || state_name == exp_col_name)) {
return false;
}
MS_LOG(INFO) << "The parameter " << param_name << " is opt shard";
return true;
}
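
A short Python sketch (illustrative only; helper names are hypothetical) of the relationship described in the comment above, and of which states AdafactorStateIsOptShard lets follow an optimizer-sharded parameter:

# Shapes of the AdaFactor states for a given parameter shape (cases 1-3 above).
def adafactor_state_shapes(param_shape):
    if len(param_shape) >= 2:
        return {
            "exp_avg_sq_row": param_shape[:-1],                      # drop last dim
            "exp_avg_sq_col": param_shape[:-2] + param_shape[-1:],   # drop second-from-end dim
            "exp_avg_sq": [1],
        }
    return {"exp_avg_sq_row": [1], "exp_avg_sq_col": [1], "exp_avg_sq": list(param_shape)}

# States that are sharded together with an optimizer-sharded parameter.
def states_following_opt_shard(param_shape):
    if len(param_shape) > 2:
        return ["exp_avg_sq_row", "exp_avg_sq_col"]
    if len(param_shape) == 2:
        return ["exp_avg_sq_row"]
    return ["exp_avg_sq"]

print(adafactor_state_shapes([64, 16, 2]))       # row [64, 16], col [64, 2], sq [1]
print(states_following_opt_shard([64, 16, 2]))   # ['exp_avg_sq_row', 'exp_avg_sq_col']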
static bool IsOriginWeight(const ParameterPtr &param) {
std::string param_name = param->name();
if (param_name.find(EXP_AVG) != std::string::npos) {
return false;
}
auto tensor_layout = param->user_data<TensorLayout>();
if (tensor_layout == nullptr) {
return false;
}
return true;
}
void HandleAdaFactorOpt(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
for (auto &param_node : root->parameters()) {
MS_EXCEPTION_IF_NULL(param_node);
auto param = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
std::string param_name = param->name();
if (param_name.find(EXP_AVG) != std::string::npos) {
continue;
}
auto tensor_layout = param->user_data<TensorLayout>();
if (tensor_layout == nullptr) {
if (!IsOriginWeight(param)) {
continue;
}
int64_t row_col_count = 0;
int64_t exp_avg_sq_count = 0;
for (auto &row_col_node : root->parameters()) {
if (row_col_count == 2 && exp_avg_sq_count == 1) {
break;
}
MS_EXCEPTION_IF_NULL(row_col_node);
auto row_col_param = row_col_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(row_col_param);
std::string row_col_param_name = row_col_param->name();
std::string param_name = param->name();
std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
std::string exp_avg_name = EXP_AVG_SQ + param_name;
@@ -586,14 +635,23 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
continue;
}
auto tensor_layout = param->user_data<TensorLayout>();
MS_EXCEPTION_IF_NULL(tensor_layout);
auto slice_shape = tensor_layout->slice_shape().array();
Shape opt_shard_slice_shape = slice_shape;
if (!tensor_layout->opt_shard_group().empty()) {
opt_shard_slice_shape = tensor_layout->opt_shard_slice_shape();
}
auto shape_size = slice_shape.size();
bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
if (is_row_or_col_param && shape_size <= 1) {
row_col_count++;
continue;
}
if (row_col_param_name == exp_avg_name && shape_size != 1) {
exp_avg_sq_count++;
continue;
}
@@ -602,12 +660,13 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
auto tensor_map = tensor_layout->tensor_map().array();
if (row_col_param_name == exp_row_name) {
slice_shape.pop_back();
opt_shard_slice_shape.pop_back();
origin_shape.pop_back();
tensor_map.pop_back();
row_col_count++;
} else if (row_col_param_name == exp_col_name) {
(void)slice_shape.erase(slice_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
(void)opt_shard_slice_shape.erase(opt_shard_slice_shape.begin() +
static_cast<different_type>(SECOND_FROM_END(shape_size)));
(void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
(void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
row_col_count++;
@@ -620,19 +679,19 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
MS_LOG(EXCEPTION) << "Init tensor layout failed";
}
if (AdafactorStateIsOptShard(tensor_layout->opt_shard_group(), shape_size, param_name, row_col_param_name)) {
new_tensor_layout.set_opt_shard_group(tensor_layout->opt_shard_group());
}
auto cloned_abstract = row_col_node->abstract()->Clone();
MS_EXCEPTION_IF_NULL(cloned_abstract);
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(opt_shard_slice_shape);
MS_EXCEPTION_IF_NULL(parallel_shape);
cloned_abstract->set_shape(parallel_shape);
row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
row_col_node->set_abstract(cloned_abstract);
MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
<< ", new slice shape is " << slice_shape;
if (row_col_count == 2 || exp_avg_sq_count == 1) {
break;
}
<< ", new slice shape is " << opt_shard_slice_shape;
}
}
}
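
The loop above derives each state's layout from its parameter's layout by dropping one dimension; a compact Python sketch of that dimension handling (hypothetical names, only to illustrate the pop_back and second-from-end erase):

# exp_avg_sq_row drops the parameter's last dimension; exp_avg_sq_col drops the
# second-from-end dimension (the index SECOND_FROM_END selects above). The same
# operation is applied to slice_shape, opt_shard_slice_shape, origin_shape and
# tensor_map before the new TensorLayout is built.
def row_col_slice_shapes(opt_shard_slice_shape):
    row = list(opt_shard_slice_shape)
    row.pop()                       # pop_back
    col = list(opt_shard_slice_shape)
    del col[len(col) - 2]           # erase the second-from-end element
    return row, col

# e.g. a [64, 16, 2] weight whose optimizer-shard slice is [8, 16, 2]
print(row_col_slice_shapes([8, 16, 2]))   # ([8, 16], [8, 2])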

View File

@@ -26,7 +26,7 @@
namespace mindspore {
namespace parallel {
constexpr char EXP_AVG[] = "exp_avg_";
constexpr char EXP_AVG[] = "exp_avg";
constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_";
constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_";
constexpr char EXP_AVG_SQ[] = "exp_avg_sq_";
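
Dropping the trailing underscore lets the substring check in IsOriginWeight also match first-moment state names that are not followed by an underscore. A tiny illustration of that filter (the parameter names here are hypothetical):

# Substring check used by IsOriginWeight above, with hypothetical state names.
EXP_AVG = "exp_avg"
names = ["fc1.weight", "exp_avg.fc1.weight", "exp_avg_sq_row_fc1.weight"]
origin_weights = [n for n in names if EXP_AVG not in n]
print(origin_weights)   # ['fc1.weight'] -- optimizer states are filtered out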

View File

@@ -23,23 +23,29 @@ from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, matmul_weight, add_weight, strategy1=None, strategy2=None):
def __init__(self, add_weight, matmul_weight, bias, strategy1=None, strategy2=None):
super().__init__()
self.add = P.TensorAdd()
self.matmul = P.MatMul().shard(strategy1)
self.add = P.BiasAdd().shard(strategy2)
self.bias_add = P.BiasAdd().shard(strategy2)
self.add_weight = Parameter(add_weight, "w1")
self.mul_weight = Parameter(matmul_weight, "w1")
self.bias = Parameter(add_weight, "bias")
self.bias = Parameter(bias, "bias")
self.reshape = P.Reshape()
def construct(self, x, b):
out = self.matmul(x, self.mul_weight)
out = self.add(x, self.add_weight)
out = self.reshape(out, (64, 32))
out = self.matmul(out, self.mul_weight)
out = self.add(out, self.bias)
return out
_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_x = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
def compile_net(net):
@@ -58,16 +64,40 @@ def compile_net(net):
def test_opt_data_parallel():
"""
Feature: test adafactor data parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1), (1, 1))
strategy2 = ((16, 1), (1,))
net = Net(_w1, _w2, strategy1, strategy2)
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)
def test_opt_model_parallel():
"""
Feature: test adafactor model parallel
Description:
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((4, 2), (2, 2))
strategy2 = ((4, 2), (2,))
net = Net(_w1, _w2, strategy1, strategy2)
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)
def test_opt_shard():
"""
Feature: test adafactor optimizer parallel
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
enable_parallel_optimizer=True)
strategy1 = ((4, 2), (2, 2))
strategy2 = ((4, 2), (2,))
net = Net(_w0, _w1, _w2, strategy1, strategy2)
compile_net(net)