support opt parallel for adafactor
parent db7d28f5c8
commit 2a0b528084
@@ -360,7 +360,7 @@ void HandleNoUsedParameter(const FuncGraphPtr &root) {
      continue;
    }
    Shape slice_shape = parameter_shape[0];
    if (slice_shape.empty()) {
    if (slice_shape.empty() || slice_shape[0] < dev_num) {
      continue;
    }
    slice_shape[0] = slice_shape[0] / dev_num;
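The new condition skips any parameter whose first dimension is smaller than dev_num, which the integer division below would otherwise reduce to zero. A tiny illustrative sketch in plain Python (hypothetical names, not MindSpore code):

# Illustration only: the guard added above, with hypothetical names.
def slice_first_dim(slice_shape, dev_num):
    if not slice_shape or slice_shape[0] < dev_num:
        return None  # skip: too small to split across dev_num devices
    sliced = list(slice_shape)
    sliced[0] //= dev_num
    return sliced


print(slice_first_dim([64, 32], 16))  # [4, 32]
print(slice_first_dim([8, 32], 16))   # None -> parameter left as is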
@@ -554,29 +554,78 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  }
}

// For the adafactor optimizer, the relationship between the parameter's shape and the states' shapes is as follows:
// 1) parameter: [A, B, C, D] (shape_size > 2), exp_avg_sq_row: [A, B, C], exp_avg_sq_col: [A, B, D], exp_avg_sq: [1]
//    If the parameter is opt sharded, the exp_avg_sq_row and exp_avg_sq_col need to be sharded accordingly.
//
// 2) parameter: [A, B] (shape_size = 2), exp_avg_sq_row: [A], exp_avg_sq_col: [B], exp_avg_sq: [1]
//    If the parameter is opt sharded, the exp_avg_sq_row needs to be sharded accordingly.
//
// 3) parameter: [A] (shape_size = 1), exp_avg_sq_row: [1], exp_avg_sq_col: [1], exp_avg_sq: [A]
//    If the parameter is opt sharded, the exp_avg_sq needs to be sharded accordingly.
static bool AdafactorStateIsOptShard(const std::string &opt_shard_group, size_t shape_size,
                                     const std::string &param_name, const std::string &state_name) {
  if (opt_shard_group.empty()) {
    return false;
  }

  std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
  std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
  std::string exp_avg_name = EXP_AVG_SQ + param_name;

  if (shape_size > 2 && state_name == exp_avg_name) {
    return false;
  }

  if (shape_size == 2 && (state_name == exp_col_name || state_name == exp_avg_name)) {
    return false;
  }

  if (shape_size == 1 && (state_name == exp_row_name || state_name == exp_col_name)) {
    return false;
  }

  MS_LOG(INFO) << "The parameter " << param_name << " is opt shard";
  return true;
}

static bool IsOriginWeight(const ParameterPtr &param) {
  std::string param_name = param->name();
  if (param_name.find(EXP_AVG) != std::string::npos) {
    return false;
  }

  auto tensor_layout = param->user_data<TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  return true;
}

void HandleAdaFactorOpt(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    std::string param_name = param->name();
    if (param_name.find(EXP_AVG) != std::string::npos) {
      continue;
    }

    auto tensor_layout = param->user_data<TensorLayout>();
    if (tensor_layout == nullptr) {
    if (!IsOriginWeight(param)) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      if (row_col_count == 2 && exp_avg_sq_count == 1) {
        break;
      }

      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string param_name = param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;
@@ -586,14 +635,23 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
        continue;
      }

      auto tensor_layout = param->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(tensor_layout);
      auto slice_shape = tensor_layout->slice_shape().array();
      Shape opt_shard_slice_shape = slice_shape;
      if (!tensor_layout->opt_shard_group().empty()) {
        opt_shard_slice_shape = tensor_layout->opt_shard_slice_shape();
      }

      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
      if (is_row_or_col_param && shape_size <= 1) {
        row_col_count++;
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        exp_avg_sq_count++;
        continue;
      }

@@ -602,12 +660,13 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name) {
        slice_shape.pop_back();
        opt_shard_slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name) {
        (void)slice_shape.erase(slice_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)opt_shard_slice_shape.erase(opt_shard_slice_shape.begin() +
                                          static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
@@ -620,19 +679,19 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      if (AdafactorStateIsOptShard(tensor_layout->opt_shard_group(), shape_size, param_name, row_col_param_name)) {
        new_tensor_layout.set_opt_shard_group(tensor_layout->opt_shard_group());
      }

      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(opt_shard_slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
                   << ", new slice shape is " << slice_shape;

      if (row_col_count == 2 || exp_avg_sq_count == 1) {
        break;
      }
                   << ", new slice shape is " << opt_shard_slice_shape;
    }
  }
}
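Before the header and test changes, the shape rules documented in the comment block above can be restated compactly. The sketch below is illustrative Python only (the helper names are invented, not MindSpore APIs): it derives the full AdaFactor state shapes from a parameter's shape and lists which states follow the parameter's optimizer sharding, mirroring what AdafactorStateIsOptShard decides.

# Illustrative sketch of the documented AdaFactor shape rules (hypothetical helpers,
# not MindSpore source).
def adafactor_state_shapes(param_shape):
    """Full shapes of (exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq) for a parameter."""
    ndim = len(param_shape)
    if ndim > 2:
        # [A, B, C, D] -> row: [A, B, C], col: [A, B, D], exp_avg_sq: [1]
        return param_shape[:-1], param_shape[:-2] + param_shape[-1:], [1]
    if ndim == 2:
        # [A, B] -> row: [A], col: [B], exp_avg_sq: [1]
        return [param_shape[0]], [param_shape[1]], [1]
    # [A] -> row: [1], col: [1], exp_avg_sq: [A]
    return [1], [1], list(param_shape)


def states_sharded_with_param(ndim):
    """States whose dim 0 is sliced (by the opt-shard group size) when the parameter is opt-sharded."""
    if ndim > 2:
        return ["exp_avg_sq_row", "exp_avg_sq_col"]
    if ndim == 2:
        return ["exp_avg_sq_row"]
    return ["exp_avg_sq"]


print(adafactor_state_shapes([64, 16, 2]))   # ([64, 16], [64, 2], [1])
print(states_sharded_with_param(3))          # ['exp_avg_sq_row', 'exp_avg_sq_col']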
@@ -26,7 +26,7 @@
namespace mindspore {
namespace parallel {
constexpr char EXP_AVG[] = "exp_avg_";
constexpr char EXP_AVG[] = "exp_avg";
constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_";
constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_";
constexpr char EXP_AVG_SQ[] = "exp_avg_sq_";
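These prefixes are how HandleAdaFactorOpt and IsOriginWeight tell optimizer state apart from origin weights: each state parameter's name is a prefix concatenated with the weight's name, and anything whose name contains the (now underscore-free) exp_avg marker is treated as state. A small illustrative sketch of that matching in Python (hypothetical helper names, not MindSpore source):

# Sketch of the name-prefix convention above (hypothetical helpers, not MindSpore source).
EXP_AVG = "exp_avg"
EXP_AVG_SQ_ROW = "exp_avg_sq_row_"
EXP_AVG_SQ_COL = "exp_avg_sq_col_"
EXP_AVG_SQ = "exp_avg_sq_"


def state_names(param_name):
    # State parameters are the prefix concatenated with the origin weight's name.
    return (EXP_AVG_SQ_ROW + param_name,
            EXP_AVG_SQ_COL + param_name,
            EXP_AVG_SQ + param_name)


def is_origin_weight(name):
    # Mirrors IsOriginWeight's name check: any parameter containing "exp_avg"
    # is optimizer state, not an origin weight.
    return EXP_AVG not in name


print(state_names("w1"))                      # ('exp_avg_sq_row_w1', 'exp_avg_sq_col_w1', 'exp_avg_sq_w1')
print(is_origin_weight("w1"))                 # True
print(is_origin_weight("exp_avg_sq_row_w1"))  # False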
@@ -23,23 +23,29 @@ from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, matmul_weight, add_weight, strategy1=None, strategy2=None):
    def __init__(self, add_weight, matmul_weight, bias, strategy1=None, strategy2=None):
        super().__init__()
        self.add = P.TensorAdd()
        self.matmul = P.MatMul().shard(strategy1)
        self.add = P.BiasAdd().shard(strategy2)
        self.bias_add = P.BiasAdd().shard(strategy2)
        self.add_weight = Parameter(add_weight, "w1")
        self.mul_weight = Parameter(matmul_weight, "w1")
        self.bias = Parameter(add_weight, "bias")
        self.bias = Parameter(bias, "bias")
        self.reshape = P.Reshape()

    def construct(self, x, b):
        out = self.matmul(x, self.mul_weight)
        out = self.add(x, self.add_weight)
        out = self.reshape(out, (64, 32))
        out = self.matmul(out, self.mul_weight)
        out = self.add(out, self.bias)
        return out


_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
_x = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)


def compile_net(net):
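With the new 3-D input, construct first adds x and the [64, 16, 2] add_weight element-wise, reshapes to (64, 32), multiplies by the [32, 32] matmul weight, and finally adds the [32] bias. A quick numpy stand-in (illustration only, not MindSpore) confirms the shapes line up:

# numpy stand-in for the data flow in Net.construct (illustration only).
import numpy as np

x = np.ones([64, 16, 2])
w0 = np.ones([64, 16, 2])
w1 = np.ones([32, 32])
w2 = np.ones([32])

out = x + w0                # TensorAdd                  -> (64, 16, 2)
out = out.reshape(64, 32)   # Reshape                    -> (64, 32)
out = out @ w1              # MatMul                     -> (64, 32)
out = out + w2              # bias add, broadcast [32]   -> (64, 32)
print(out.shape)            # (64, 32)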
@@ -58,16 +64,40 @@ def compile_net(net):


def test_opt_data_parallel():
    """
    Feature: test adafactor data parallel
    Description: data parallel strategies for matmul and bias_add
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((16, 1), (1,))
    net = Net(_w1, _w2, strategy1, strategy2)
    net = Net(_w0, _w1, _w2, strategy1, strategy2)
    compile_net(net)


def test_opt_model_parallel():
    """
    Feature: test adafactor model parallel
    Description: model parallel strategies for matmul and bias_add
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = Net(_w1, _w2, strategy1, strategy2)
    net = Net(_w0, _w1, _w2, strategy1, strategy2)
    compile_net(net)


def test_opt_shard():
    """
    Feature: test adafactor optimizer parallel
    Description: only shard batch dimension
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                      enable_parallel_optimizer=True)
    strategy1 = ((4, 2), (2, 2))
    strategy2 = ((4, 2), (2,))
    net = Net(_w0, _w1, _w2, strategy1, strategy2)
    compile_net(net)
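The compile_net body itself falls outside the hunks shown here. For orientation only, a hedged sketch of the kind of wrapper such parallel compile tests typically use follows; the AdaFactor construction, the loss-free TrainOneStepCell wrapping, and the private _cell_graph_executor call are assumptions, not lines from this test file.

# Hedged sketch only; assumed imports and helper shape, not the test's actual compile_net.
import mindspore.nn as nn
from mindspore import context
from mindspore.common.api import _cell_graph_executor  # private compile entry point (assumed)


def compile_net_sketch(net):
    # Default AdaFactor derives its own learning rate (relative_step=True), so none is passed here.
    optimizer = nn.AdaFactor(net.trainable_params())
    # Attaching the optimizer creates the exp_avg / exp_avg_sq_row / exp_avg_sq_col / exp_avg_sq
    # state parameters that HandleAdaFactorOpt re-slices during parallel compilation.
    train_net = nn.TrainOneStepCell(net, optimizer)
    train_net.set_train()
    _cell_graph_executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()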