forked from mindspore-Ecosystem/mindspore
!46489 add functional shard
Merge pull request !46489 from suteng/functional_shard
commit df23fbbb8d
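In brief, this PR wires a functional shard interface through the front end: a Cell can be sharded as a whole, with an input-layout strategy and an optional per-parameter plan, instead of (or in addition to) calling shard on individual operators. Below is a minimal sketch of the two call styles, assuming only the API exercised by the new tests at the bottom of this diff; it needs a multi-device AUTO_PARALLEL context (see those tests) to actually run.

# Sketch only: the two shard styles exercised by the tests in this PR.
# Assumes 8 devices, NCCL initialized, and AUTO_PARALLEL with
# sharding_propagation already configured (see the tests below).
import mindspore.ops as ops
from mindspore import Parameter
from mindspore.common.initializer import initializer
from mindspore.nn import Cell


class TinyNet(Cell):
    def __init__(self):
        super().__init__()
        self.weight = Parameter(initializer(0.03, [32, 8]), "w")
        self.matmul = ops.MatMul()

    def construct(self, x):
        return self.matmul(x, self.weight)


net = TinyNet()
# Cell-level functional shard: split the (batch, feature) input 2x4 across
# devices and shard the parameter named "weight" 4x1 via the plan.
net.shard(in_strategy=((2, 4),), parameter_plan={"weight": (4, 1)})
# Per-operator alternative: strategies for both MatMul inputs; sharding
# propagation derives strategies for the remaining operators.
# net.matmul.shard(((2, 4), (4, 1)))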
@@ -181,8 +181,6 @@ class COMMON_EXPORT ParallelContext {
   void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
   void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                             const AbstractBasePtr &ptr) const;
-  void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                         const AbstractBasePtr &ptr) const;
   void set_sharding_propagation(const bool stra_pto);
   bool sharding_propagation() const { return sharding_propagation_; }

@@ -194,7 +192,7 @@ class COMMON_EXPORT ParallelContext {
  private:
   ParallelContext();
-  bool IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) const;
+  bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;

   bool gradients_mean_;
   bool full_batch_;
@@ -741,7 +741,6 @@ abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
       auto param_abs = GetDefaultValueAbstract(param_node);
       context->ParallelParameterContextRestoreShape(func_graph, param_node, param_abs);
       (void)args_abs.emplace_back(param_abs);
-      context->ParallelParameterContextCkptShape(func_graph, param_node, param_abs);
     }
   }
   return args_abs;
@@ -24,8 +24,6 @@
 namespace mindspore::parallel {
 namespace {
 std::map<std::string, std::vector<int64_t>> param_shapes;

 std::vector<std::string> kParallelModeList = {kStandalone, kDataParallel, kHybridParallel, kSemiAutoParallel,
                                               kAutoParallel};
 std::vector<std::string> kStrategySearchModeList = {kDynamicProgramming, kRecursiveProgramming, kShardingPropagation};
@@ -235,14 +233,12 @@ bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
   return true;
 }

 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
 void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  if (!IsAutoParallelCareGraph(func_graph)) {
+  if (!ParallelContextCareGraph(func_graph)) {
     return;
   }
   if (func_graph->has_flag(kIsFirstIteration)) {
     param_shapes.clear();
     init_param_shape_ = true;
     MS_LOG(INFO) << "Init the parameter shape dict in increment predict with two graph";
     return;
@@ -257,7 +253,6 @@ void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
     init_param_shape_ = false;
     MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
   } else {
     param_shapes.clear();
     init_param_shape_ = true;
     MS_LOG(INFO) << "Init the parameter shape dict";
   }
@@ -270,7 +265,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(param_node);
   MS_EXCEPTION_IF_NULL(ptr);
-  if (!IsAutoParallelCareGraph(func_graph)) {
+  if (!ParallelContextCareGraph(func_graph)) {
     return;
   }
@@ -291,30 +286,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
   MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
 }

-// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
-// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
-void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                                        const AbstractBasePtr &ptr) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(param_node);
-  MS_EXCEPTION_IF_NULL(ptr);
-  if (!IsAutoParallelCareGraph(func_graph)) {
-    return;
-  }
-
-  if (!init_param_shape_) {
-    return;
-  }
-  std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
-  auto ret = param_shapes.try_emplace(param_node->name(), shape);
-  if (!ret.second) {
-    MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
-  }
-
-  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
-}
-
-bool ParallelContext::IsAutoParallelCareGraph(const FuncGraphPtr &func_graph) const {
+bool ParallelContext::ParallelContextCareGraph(const FuncGraphPtr &func_graph) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   if (func_graph->has_flag(kSkipAutoParallelCompile)) {
     return false;
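For orientation, the ParallelParameterContextCkptShape removed above recorded each parameter's shape into param_shapes exactly once per session: try_emplace inserts only if the key is absent, and a duplicate name raised an exception, while ParallelParameterContextRestoreShape read the entry back. A rough Python analogue of that bookkeeping, purely illustrative (the names below are hypothetical stand-ins, not MindSpore API):

# Illustrative analogue of the removed C++ shape checkpointing.
param_shapes = {}


def ckpt_shape(name, shape):
    # Mirrors std::map::try_emplace: insert only if absent, error otherwise.
    if name in param_shapes:
        raise RuntimeError(f"The shape for parameter name {name} already exists")
    param_shapes[name] = list(shape)


def restore_shape(name):
    # The lookup side: None if the parameter was never recorded.
    return param_shapes.get(name)


ckpt_shape("w", (32, 8))
assert restore_shape("w") == [32, 8]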
@@ -181,6 +181,7 @@ class Optimizer(Cell):
         self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

         self._init_opt_attrs(learning_rate, parameters, weight_decay)
+        self.add_flags(skip_auto_parallel_compile=True)

     def _init_opt_attrs(self, learning_rate, parameters, weight_decay):
         """initialize optimizer attributions"""
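This one-line optimizer change is the producer side of the renamed check above: add_flags stores skip_auto_parallel_compile on the optimizer cell, and ParallelContextCareGraph bails out for any graph carrying kSkipAutoParallelCompile, so optimizer sub-graphs are excluded from auto-parallel handling. A minimal sketch of setting and reading a cell flag, assuming the stock Cell.add_flags/Cell.get_flags pair behaves as in mainline MindSpore:

# Minimal sketch; the flag name comes from the diff above.
import mindspore.nn as nn


class Passthrough(nn.Cell):
    def construct(self, x):
        return x


cell = Passthrough()
cell.add_flags(skip_auto_parallel_compile=True)
# get_flags returns the dict of flags attached to this cell.
assert cell.get_flags().get("skip_auto_parallel_compile") is True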
@@ -0,0 +1,169 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

from mindspore import Parameter, jit
from mindspore.nn import Cell, Momentum
from mindspore.nn import MSELoss
import mindspore.dataset as ds
import mindspore.ops as ops
import mindspore as ms
from mindspore.common.initializer import initializer
from mindspore.communication import init


def get_dataset(batch_size, step_per_epoch, in_dim, out_dim):
    input_data = np.ones((batch_size, in_dim), dtype=np.float32) * 0.1
    label_data = np.ones((batch_size, out_dim), dtype=np.float32) * 0.1

    def generate():
        for _ in range(step_per_epoch):
            yield (input_data, label_data)

    return generate


class Net(Cell):
    """define net"""
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.weight = Parameter(initializer(0.03, [self.in_dim, self.hidden_dim]), "w")
        self.weight2 = Parameter(initializer(0.04, [self.hidden_dim, self.out_dim]), "w2")
        self.matmul = ops.MatMul()
        self.relu = ops.ReLU()
        self.matmul2 = ops.MatMul()

    def construct(self, x):
        out = self.matmul(x, self.weight)
        out = self.relu(out)
        out = self.matmul2(out, self.weight2)
        return out


def test_pynative_func():
    '''
    Feature: Object Oriented and Functional Mixed Programming
    Description: pynative mode, run one step
    Expectation: Run success
    '''
    var_step_per_epoch = 1
    var_single_batch_size = 16
    var_in_dim = 32
    var_hidden_dim = 8
    var_out_dim = 16

    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation",
                                 device_num=8)
    init("nccl")

    # dataset
    fake_dataset = get_dataset(var_single_batch_size, var_step_per_epoch, var_in_dim, var_out_dim)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])

    # define net
    net = Net(var_in_dim, var_hidden_dim, var_out_dim)

    # define shard
    net.shard(in_strategy=((2, 4),), parameter_plan={"weight": (4, 1)})

    # define loss
    loss_fn = MSELoss()

    # define opt
    learning_rate = 0.3
    momentum = 0.1
    opt = Momentum(net.trainable_params(), learning_rate, momentum)

    # define forward function
    def net_forward(x, y):
        out = net(x)
        loss = loss_fn(out, y)
        return loss

    grad_net = ops.value_and_grad(net_forward, grad_position=None, weights=net.trainable_params())

    def train_one_step(x, y):
        loss, grads = grad_net(x, y)
        opt(grads)
        return loss

    loss = 0.0
    for _ in range(1):
        for input_x, label in dataset:
            loss = train_one_step(input_x, label)
    assert np.allclose(np.array([loss.asnumpy()]), np.array([0.004799718]), 0.0001, 0.0001)
    ms.reset_auto_parallel_context()


def test_graph_func():
    '''
    Feature: Object Oriented and Functional Mixed Programming
    Description: graph mode, run two steps
    Expectation: Run success
    '''
    var_step_per_epoch = 2
    var_single_batch_size = 16
    var_in_dim = 32
    var_hidden_dim = 8
    var_out_dim = 16

    ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL, search_mode="sharding_propagation")

    init("nccl")

    # dataset
    fake_dataset = get_dataset(var_single_batch_size, var_step_per_epoch, var_in_dim, var_out_dim)
    dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])

    # define net
    net = Net(var_in_dim, var_hidden_dim, var_out_dim)

    # define shard
    net.matmul.shard(((2, 4), (4, 1)))

    # define loss
    loss_fn = MSELoss()

    # define opt
    learning_rate = 0.3
    momentum = 0.1
    opt = Momentum(net.trainable_params(), learning_rate, momentum)

    # define forward function
    def net_forward(x, y):
        out = net(x)
        loss = loss_fn(out, y)
        return loss

    grad_net = ops.value_and_grad(net_forward, grad_position=None, weights=net.trainable_params())

    @jit
    def train_one_step(x, y):
        loss, grads = grad_net(x, y)
        opt(grads)
        return loss

    loss = 0.0
    for _ in range(1):
        for input_x, label in dataset:
            loss = train_one_step(input_x, label)
    assert np.allclose(np.array([loss.asnumpy()]), np.array([0.0047495714]), 0.0001, 0.0001)
    ms.reset_auto_parallel_context()
@@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import pytest


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_pynative_functional_train_gpu():
    '''
    Feature: Object Oriented and Functional Mixed Programming
    Description: pynative mode
    Expectation: Run success
    '''
    ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v functional_train.py::test_pynative_func")
    assert ret == 0


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_single
def test_graph_functional_train_gpu():
    '''
    Feature: Object Oriented and Functional Mixed Programming
    Description: graph mode
    Expectation: Run success
    '''
    ret = os.system("mpirun -n 8 --allow-run-as-root pytest -s -v functional_train.py::test_graph_func")
    assert ret == 0