!46489 add functional shard

Merge pull request !46489 from suteng/functional_shard
i-robot 2022-12-08 01:38:51 +00:00 committed by Gitee
commit df23fbbb8d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 217 additions and 35 deletions


@@ -181,8 +181,6 @@ class COMMON_EXPORT ParallelContext {
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) const;
void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) const;
void set_sharding_propagation(const bool stra_pto);
bool sharding_propagation() const { return sharding_propagation_; }
@@ -194,7 +192,7 @@ class COMMON_EXPORT ParallelContext {
private:
ParallelContext();
bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) const;
bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;
bool gradients_mean_;
bool full_batch_;


@@ -741,7 +741,6 @@ abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
auto param_abs = GetDefaultValueAbstract(param_node);
context->ParallelParameterContextRestoreShape(func_graph, param_node, param_abs);
(void)args_abs.emplace_back(param_abs);
context->ParallelParameterContextCkptShape(func_graph, param_node, param_abs);
}
}
return args_abs;


@@ -24,8 +24,6 @@
namespace mindspore::parallel {
namespace {
std::map<std::string, std::vector<int64_t>> param_shapes;
std::vector<std::string> kParallelModeList = {kStandalone, kDataParallel, kHybridParallel, kSemiAutoParallel,
kAutoParallel};
std::vector<std::string> kStrategySearchModeList = {kDynamicProgramming, kRecursiveProgramming, kShardingPropagation};
@@ -235,14 +233,12 @@ bool ParallelContext::set_communi_parallel_mode(const std::string &communi_paral
return true;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (!IsAutoParallelCareGraph(func_graph)) {
if (!ParallelContextCareGraph(func_graph)) {
return;
}
if (func_graph->has_flag(kIsFirstIteration)) {
param_shapes.clear();
init_param_shape_ = true;
MS_LOG(INFO) << "Init the parameter shape dict in increment predict with two graph";
return;
@@ -257,7 +253,6 @@ void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func
init_param_shape_ = false;
MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
} else {
param_shapes.clear();
init_param_shape_ = true;
MS_LOG(INFO) << "Init the parameter shape dict";
}
@@ -270,7 +265,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!IsAutoParallelCareGraph(func_graph)) {
if (!ParallelContextCareGraph(func_graph)) {
return;
}
@@ -291,30 +286,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!IsAutoParallelCareGraph(func_graph)) {
return;
}
if (!init_param_shape_) {
return;
}
std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) {
MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
}
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
bool ParallelContext::IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) const {
bool ParallelContext::ParallelContextCareGraph(const FuncGraphPtr &func_graph) const {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_flag(kSkipAutoParallelCompile)) {
return false;


@@ -181,6 +181,7 @@ class Optimizer(Cell):
self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
self._init_opt_attrs(learning_rate, parameters, weight_decay)
self.add_flags(skip_auto_parallel_compile=True)
def _init_opt_attrs(self, learning_rate, parameters, weight_decay):
"""initialize optimizer attributions"""

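The skip_auto_parallel_compile flag added to the Optimizer constructor above is what the renamed ParallelContextCareGraph check looks for on the C++ side (kSkipAutoParallelCompile in the parallel_context.cc hunk earlier), so optimizer graphs are excluded from auto-parallel handling. Below is a minimal sketch of how the flag surfaces on an optimizer instance; it assumes Cell.get_flags() exposes flags set through add_flags, and the network and hyper-parameters are purely illustrative.

from mindspore import nn

# Illustrative network and optimizer; sizes and hyper-parameters are arbitrary.
net = nn.Dense(4, 2)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.3, momentum=0.1)

# With this change the Optimizer constructor calls
# add_flags(skip_auto_parallel_compile=True) on itself, so the flag should be
# visible here (assuming get_flags() returns the cell's flag dictionary).
print(opt.get_flags().get("skip_auto_parallel_compile"))  # expected: True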

@@ -0,0 +1,169 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import Parameter, jit
from mindspore.nn import Cell, Momentum
from mindspore.nn import MSELoss
import mindspore.dataset as ds
import mindspore.ops as ops
import mindspore as ms
from mindspore.common.initializer import initializer
from mindspore.communication import init
def get_dataset(batch_size, step_per_epoch, in_dim, out_dim):
input_data = np.ones((batch_size, in_dim), dtype=np.float32) * 0.1
label_data = np.ones((batch_size, out_dim), dtype=np.float32) * 0.1
def generate():
for _ in range(step_per_epoch):
yield (input_data, label_data)
return generate
class Net(Cell):
"""define net"""
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.weight = Parameter(initializer(0.03, [self.in_dim, self.hidden_dim]), "w")
self.weight2 = Parameter(initializer(0.04, [self.hidden_dim, self.out_dim]), "w2")
self.matmul = ops.MatMul()
self.relu = ops.ReLU()
self.matmul2 = ops.MatMul()
def construct(self, x):
out = self.matmul(x, self.weight)
out = self.relu(out)
out = self.matmul2(out, self.weight2)
return out
def test_pynative_func():
'''
Feature: Object Oriented and Functional Mixed Programming
Description: pynative mode, run one step
Expectation: Run success
'''
var_step_per_epoch = 1
var_single_batch_size = 16
var_in_dim = 32
var_hidden_dim = 8
var_out_dim = 16
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation",
device_num=8)
init("nccl")
# dataset
fake_dataset = get_dataset(var_single_batch_size, var_step_per_epoch, var_in_dim, var_out_dim)
dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
# define net
net = Net(var_in_dim, var_hidden_dim, var_out_dim)
# define shard
net.shard(in_strategy=((2, 4),), parameter_plan={"weight": (4, 1)})
# define loss
loss_fn = MSELoss()
# define opt
learning_rate = 0.3
momentum = 0.1
opt = Momentum(net.trainable_params(), learning_rate, momentum)
# define forward function
def net_forward(x, y):
out = net(x)
loss = loss_fn(out, y)
return loss
grad_net = ops.value_and_grad(net_forward, grad_position=None, weights=net.trainable_params())
def train_one_step(x, y):
loss, grads = grad_net(x, y)
opt(grads)
return loss
loss = 0.0
for _ in range(1):
for input_x, label in dataset:
loss = train_one_step(input_x, label)
assert np.allclose(np.array([loss.asnumpy()]), np.array([0.004799718]), 0.0001, 0.0001)
ms.reset_auto_parallel_context()
def test_graph_func():
'''
Feature: Object Oriented and Functional Mixed Programming
Description: graph mode, run two step
Expectation: Run success
'''
var_step_per_epoch = 2
var_single_batch_size = 16
var_in_dim = 32
var_hidden_dim = 8
var_out_dim = 16
ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation")
init("nccl")
# dataset
fake_dataset = get_dataset(var_single_batch_size, var_step_per_epoch, var_in_dim, var_out_dim)
dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
# define net
net = Net(var_in_dim, var_hidden_dim, var_out_dim)
# define shard
net.matmul.shard(((2, 4), (4, 1)))
# define loss
loss_fn = MSELoss()
# define opt
learning_rate = 0.3
momentum = 0.1
opt = Momentum(net.trainable_params(), learning_rate, momentum)
# define forward function
def net_forward(x, y):
out = net(x)
loss = loss_fn(out, y)
return loss
grad_net = ops.value_and_grad(net_forward, grad_position=None, weights=net.trainable_params())
@jit
def train_one_step(x, y):
loss, grads = grad_net(x, y)
opt(grads)
return loss
loss = 0.0
for _ in range(1):
for input_x, label in dataset:
loss = train_one_step(input_x, label)
assert np.allclose(np.array([loss.asnumpy()]), np.array([0.0047495714]), 0.0001, 0.0001)
ms.reset_auto_parallel_context()
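For comparison, the two tests above exercise the two sharding entry points this PR targets: the functional Cell.shard call on the whole network (test_pynative_func) and the per-primitive shard on a single operator (test_graph_func). A condensed side-by-side sketch, reusing the Net class defined above (dimensions are illustrative, and both forms still need the AUTO_PARALLEL context and init("nccl") set up as in the tests):

# Functional shard: one call on the cell, with an input strategy and an
# optional parameter placement plan (as in test_pynative_func).
net_a = Net(32, 8, 16)
net_a.shard(in_strategy=((2, 4),), parameter_plan={"weight": (4, 1)})

# Primitive-level shard: a strategy attached directly to a single operator
# (as in test_graph_func).
net_b = Net(32, 8, 16)
net_b.matmul.shard(((2, 4), (4, 1)))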


@@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_pynative_functional_train_gpu():
'''
Feature: Object Oriented and Functional Mixed Programming
Description: pynative mode
Expectation: Run success
'''
ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v functional_train.py::test_pynative_func")
assert ret == 0
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_graph_functional_train_gpu():
'''
Feature: Object Oriented and Functional Mixed Programming
Description: graph mode
Expectation: Run success
'''
ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v functional_train.py::test_graph_func")
assert ret == 0