parallel func mode in sink mode

yao_yf 2022-10-15 16:47:04 +08:00
parent 85e85ae2bf
commit bc5a7d8d1e
6 changed files with 219 additions and 20 deletions


@@ -89,6 +89,17 @@ struct OutStrategyValueRegister {
    });
  }
} out_regist;

int64_t MaxCommonDivisor(int64_t n1, int64_t n2) {
  while (n1 != n2) {
    if (n1 > n2) {
      n1 -= n2;
    } else {
      n2 -= n1;
    }
  }
  return n1;
}
}  // namespace

std::string StrategyToString(const Strategies &strategy) {
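The new helper computes the greatest common divisor by repeated subtraction (Euclid's original scheme), assuming positive inputs. A minimal Python sketch of the same logic, with illustrative checks (the Python name is ours, not the commit's):

    def max_common_divisor(n1, n2):
        # Subtract the smaller value from the larger until both are equal;
        # the remaining value is the greatest common divisor.
        while n1 != n2:
            if n1 > n2:
                n1 -= n2
            else:
                n2 -= n1
        return n1

    assert max_common_divisor(128, 512) == 128
    assert max_common_divisor(34, 8) == 2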
@@ -1064,6 +1075,41 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, con
  succ_edges_ = update_pre_edges;
}

std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategiesWithCheck() {
  std::shared_ptr<Strategies> batch_strategy = GenerateBatchStrategies();
  if (batch_strategy->size() != inputs_shape_.size()) {
    MS_LOG(WARNING) << "The inputs size:" << inputs_shape_.size()
                    << " is not equal to the generated batch parallel strategies size:" << batch_strategy->size();
    return batch_strategy;
  }
  int64_t shard_size = g_device_manager->stage_device_num();
  std::vector<std::pair<size_t, size_t>> changed_pos;
  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    auto stra = batch_strategy->at(i);
    auto input_shape = inputs_shape_.at(i);
    if (stra.size() != input_shape.size()) {
      MS_LOG(WARNING) << "The " << i << " input size:" << input_shape.size() << " is not equal to the " << i
                      << " generated batch parallel strategy size:" << stra.size();
      return batch_strategy;
    }
    for (size_t j = 0; j < input_shape.size(); ++j) {
      if (stra[j] == 1) {
        continue;
      }
      if (stra[j] != g_device_manager->stage_device_num()) {
        MS_LOG(WARNING) << "The batch parallel value is not equal to the device num, skip adjusting it.";
        return batch_strategy;
      }
      shard_size = MaxCommonDivisor(input_shape[j], shard_size);
      changed_pos.push_back({i, j});
    }
  }
  for (auto &pair : changed_pos) {
    batch_strategy->at(pair.first).at(pair.second) = shard_size;
  }
  return batch_strategy;
}

std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
                                                               const std::vector<bool> &split_flag_list) {
  if (shapes.size() != split_flag_list.size()) {
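GenerateBatchStrategiesWithCheck keeps the auto-generated batch strategy unless every sharded dimension is split by exactly the stage device number; in that case it shrinks the split to the largest divisor shared by all of those dimensions and the device number, so the strategy stays valid even when a batch dimension is smaller than the device count. A rough Python rendering of the adjustment (function and variable names are illustrative):

    from math import gcd

    def adjust_batch_strategy(input_shapes, batch_strategy, stage_device_num):
        shard_size = stage_device_num
        changed = []
        for i, (shape, stra) in enumerate(zip(input_shapes, batch_strategy)):
            for j, split in enumerate(stra):
                if split == 1:
                    continue  # unsharded dimension, nothing to adjust
                if split != stage_device_num:
                    return batch_strategy  # unexpected split, leave the strategy as-is
                shard_size = gcd(shape[j], shard_size)
                changed.append((i, j))
        for i, j in changed:
            batch_strategy[i][j] = shard_size
        return batch_strategy

    # A 128-sample batch on a 512-device stage is split 128 ways, not 512:
    print(adjust_batch_strategy([[128, 16, 34, 34]], [[512, 1, 1, 1]], 512))  # [[128, 1, 1, 1]]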


@@ -107,6 +107,7 @@ class OperatorInfo {
  virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
  virtual void ReComputeBatchSplitFlagList();
  std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
  void ComputeBatchSplitFlagList();
  double GetForwardMemoryCostFromCNode();


@@ -1723,7 +1723,7 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
  MS_EXCEPTION_IF_NULL(operator_);
  MS_EXCEPTION_IF_NULL(prim);
  StrategyPtr strategyPtr;
  std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategies();
  std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategiesWithCheck();
  MS_EXCEPTION_IF_NULL(strategy_v_ptr);
  strategyPtr = NewStrategy(0, *strategy_v_ptr);
  std::vector<ValuePtr> elements;


@@ -52,6 +52,9 @@ def _init_sink_dataset(dataset, steps, sink_size, input_signature):
    _check_inputs(input_signature, dataset_shapes, dataset_types)
    queue_name = transfer_dataset.queue_name
    if _need_to_full():
        device_num = _get_device_num() // _get_pipeline_stages()
        dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
    next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
    _set_dataset_mode_config('sink')
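In full-batch mode each rank's GetNext must report the global batch shape rather than its local slice, so the shapes are scaled by the per-stage data-parallel device number before the op is built. A hedged sketch of what _to_full_shapes is assumed to do here (scaling the first, batch, dimension is the essential point):

    def to_full_shapes(shapes, device_num):
        # Scale the batch (first) dimension so each rank reports the global batch.
        return [[device_num * s if i == 0 else s for i, s in enumerate(shape)]
                for shape in shapes]

    # device_num=8 with 2 pipeline stages -> 4 data-parallel ranks per stage:
    print(to_full_shapes([(32, 128)], 8 // 2))  # [[128, 128]]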


@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -83,23 +83,28 @@ class NetConv(nn.Cell):
        return self.conv(input_x)

class Net(nn.Cell):
    def __init__(self, strategy1, strategy2, strategy3):
        super().__init__()
        self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
        self.mul1 = P.Mul().shard(strategy2)
        self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
        self.mul2 = P.Mul().shard(strategy3)

    def construct(self, x, w1, w2):
        out1 = self.conv1(x)
        out2 = self.mul1(out1, w1)
        out3 = self.conv2(out2)
        out4 = self.mul2(out3, w2)
        return out4

def test_batch():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
            self.mul1 = P.Mul().shard(strategy2)
            self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, w1, w2):
            out1 = self.conv1(x)
            out2 = self.mul1(out1, w1)
            out3 = self.conv2(out2)
            out4 = self.mul2(out3, w2)
            return out4
    """
    Feature: Batch parallel
    Description: test batch parallel
    Expectation: compile ok
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
@@ -116,5 +121,23 @@
    _cell_graph_executor.compile(net, x, w1, w2)

if __name__ == '__main__':
    test_batch()

def test_batch_shape_less_than_devices():
    """
    Feature: Batch parallel
    Description: test batch parallel with input shapes smaller than the device number.
    Expectation: compile ok
    """
    context.set_auto_parallel_context(device_num=512, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    net.set_auto_parallel()
    x = Tensor(np.ones([128, 16, 34, 34]), dtype=ms.float32)
    w1 = Tensor(np.ones([128, 8, 32, 32]), dtype=ms.float32)
    w2 = Tensor(np.ones([128, 64, 24, 24]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, w1, w2)
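This new test exercises the capped split: with 512 devices and a batch dimension of 128, the unchecked batch strategy would try to split 128 samples 512 ways. The check reduces the split to the shared divisor, which can be verified by hand:

    from math import gcd

    device_num = 512
    batch_dims = [128, 128, 128]  # first dimension of x, w1 and w2 above
    shard = device_num
    for b in batch_dims:
        shard = gcd(b, shard)
    print(shard)  # 128 -> the generated strategy becomes (128, 1, 1, 1) per input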


@@ -0,0 +1,126 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context, ms_function
from mindspore import dataset as ds
from mindspore.ops import operations as P
from mindspore.common.jit_config import JitConfig

context.set_context(mode=ms.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch")

def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")

class Attention(nn.Cell):
    def __init__(self):
        super(Attention, self).__init__()
        self.fc_a = nn.Dense(128, 768, activation='relu')
        self.fc_b = nn.Dense(128, 768, activation='relu')
        self.fc_c = nn.Dense(128, 768, activation='relu')
        self.fc_a.matmul.shard(((1, 1), (8, 1)))
        self.fc_b.matmul.shard(((1, 1), (8, 1)))
        self.fc_c.matmul.shard(((1, 1), (8, 1)))

    def construct(self, x):
        q = self.fc_a(x)
        k = self.fc_b(x)
        v = self.fc_c(x)
        return q, k, v

attention = Attention()
relu = nn.ReLU()

@ms_function
def dense_func(x, label):
    q, k, v = attention(x)
    k = P.Transpose()(k, (1, 0))  # (768, 32)
    c = relu(P.MatMul()(q, k))  # (32, 32)
    s = relu(P.MatMul()(c, v))  # (32, 768)
    s = s - label
    return P.ReduceMean()(s * s)

optimizer_adam = nn.Adam(attention.trainable_params(), learning_rate=0.001)
attention.set_train()
attention.update_parameters_name("attn")
optimizer_adam.update_parameters_name("opt")
grad_dens_func = ms.ops.value_and_grad(dense_func, None, optimizer_adam.parameters)

@ms_function
def train_step(input_, label_):
    loss, grad = grad_dens_func(input_, label_)
    optimizer_adam(grad)
    return loss

def test_sink():
    """
    Feature: Function mode in auto parallel
    Description: sink mode
    Expectation: compile ok
    """
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      dataset_strategy="data_parallel", device_num=8)
    data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
    dataset = ds.NumpySlicesDataset(data=data)
    jitconfig = JitConfig(jit_level="O1", task_sink=True)
    sink_process = ms.train.data_sink(dense_func, dataset, steps=2, sink_size=4, jit_config=jitconfig)
    _ = sink_process()

def test_no_sink():
    """
    Feature: Function mode in auto parallel
    Description: no sink mode
    Expectation: compile ok
    """
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch", device_num=8)
    _ = dense_func(Tensor(np.ones([32, 128]).astype(np.float32)), Tensor(np.zeros([32, 768]).astype(np.float32)))

def test_sink_with_grad():
    """
    Feature: Function mode in auto parallel
    Description: sink mode with grad
    Expectation: compile ok
    """
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      dataset_strategy="data_parallel", device_num=8)
    data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
    dataset = ds.NumpySlicesDataset(data=data)
    jitconfig = JitConfig(jit_level="O1", task_sink=True)
    sink_process = ms.train.data_sink(train_step, dataset, steps=2, sink_size=4, jit_config=jitconfig)
    _ = sink_process()

def test_no_sink_with_grad():
    """
    Feature: Function mode in auto parallel
    Description: no sink mode with grad
    Expectation: compile ok
    """
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch", device_num=8)
    _ = train_step(Tensor(np.ones([32, 128]).astype(np.float32)), Tensor(np.zeros([32, 768]).astype(np.float32)))
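As a sanity check of the shape comments in dense_func, here is a plain NumPy sketch that ignores sharding and activation details; the single weight matrix stands in for all three Dense layers (our illustration, not part of the commit):

    import numpy as np

    x = np.ones((32, 128), np.float32)
    w = np.ones((128, 768), np.float32)   # stand-in for one Dense weight (in=128, out=768)
    q = k = v = np.maximum(x @ w, 0)      # (32, 768), relu activation
    kt = k.T                              # (768, 32)
    c = np.maximum(q @ kt, 0)             # (32, 32)
    s = np.maximum(c @ v, 0)              # (32, 768), matches the label shape
    print(c.shape, s.shape)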