parallel func mode in sink mode
This commit is contained in:
parent
85e85ae2bf
commit
bc5a7d8d1e
|
@ -89,6 +89,17 @@ struct OutStrategyValueRegister {
|
|||
});
|
||||
}
|
||||
} out_regist;
|
||||
|
||||
int64_t MaxCommonDivisor(int64_t n1, int64_t n2) {
|
||||
while (n1 != n2) {
|
||||
if (n1 > n2) {
|
||||
n1 -= n2;
|
||||
} else {
|
||||
n2 -= n1;
|
||||
}
|
||||
}
|
||||
return n1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::string StrategyToString(const Strategies &strategy) {
|
||||
|
@ -1064,6 +1075,41 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, con
|
|||
succ_edges_ = update_pre_edges;
|
||||
}
|
||||
|
||||
std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategiesWithCheck() {
|
||||
std::shared_ptr<Strategies> batch_strategy = GenerateBatchStrategies();
|
||||
if (batch_strategy->size() != inputs_shape_.size()) {
|
||||
MS_LOG(WARNING) << "The inputs size:" << inputs_shape_.size()
|
||||
<< " is not equal to the generated batch parallel strategies size:" << batch_strategy->size();
|
||||
return batch_strategy;
|
||||
}
|
||||
int64_t shard_size = g_device_manager->stage_device_num();
|
||||
std::vector<std::pair<size_t, size_t>> changed_pos;
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
auto stra = batch_strategy->at(i);
|
||||
auto input_shape = inputs_shape_.at(i);
|
||||
if (stra.size() != input_shape.size()) {
|
||||
MS_LOG(WARNING) << "The " << i << " input size:" << input_shape.size() << " is not equal to the " << i
|
||||
<< " generated batch parallel strategy size:" << stra.size();
|
||||
return batch_strategy;
|
||||
}
|
||||
for (size_t j = 0; j < input_shape.size(); ++j) {
|
||||
if (stra[j] == 1) {
|
||||
continue;
|
||||
}
|
||||
if (stra[j] != g_device_manager->stage_device_num()) {
|
||||
MS_LOG(WARNING) << "The batch parallel value is not equal to device num, skip adjust it.";
|
||||
return batch_strategy;
|
||||
}
|
||||
shard_size = MaxCommonDivisor(input_shape[j], shard_size);
|
||||
changed_pos.push_back({i, j});
|
||||
}
|
||||
}
|
||||
for (auto &pair : changed_pos) {
|
||||
batch_strategy->at(pair.first).at(pair.second) = shard_size;
|
||||
}
|
||||
return batch_strategy;
|
||||
}
|
||||
|
||||
std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
|
||||
const std::vector<bool> &split_flag_list) {
|
||||
if (shapes.size() != split_flag_list.size()) {
|
||||
|
|
|
@ -107,6 +107,7 @@ class OperatorInfo {
|
|||
|
||||
virtual std::shared_ptr<Strategies> GenerateBatchStrategies();
|
||||
virtual void ReComputeBatchSplitFlagList();
|
||||
std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
|
||||
void ComputeBatchSplitFlagList();
|
||||
|
||||
double GetForwardMemoryCostFromCNode();
|
||||
|
|
|
@ -1723,7 +1723,7 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
|
|||
MS_EXCEPTION_IF_NULL(operator_);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
StrategyPtr strategyPtr;
|
||||
std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategies();
|
||||
std::shared_ptr<Strategies> strategy_v_ptr = operator_->GenerateBatchStrategiesWithCheck();
|
||||
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
|
||||
strategyPtr = NewStrategy(0, *strategy_v_ptr);
|
||||
std::vector<ValuePtr> elements;
|
||||
|
|
|
@ -52,6 +52,9 @@ def _init_sink_dataset(dataset, steps, sink_size, input_signature):
|
|||
_check_inputs(input_signature, dataset_shapes, dataset_types)
|
||||
|
||||
queue_name = transfer_dataset.queue_name
|
||||
if _need_to_full():
|
||||
device_num = _get_device_num() // _get_pipeline_stages()
|
||||
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
|
||||
_set_dataset_mode_config('sink')
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -83,23 +83,28 @@ class NetConv(nn.Cell):
|
|||
return self.conv(input_x)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, strategy3):
|
||||
super().__init__()
|
||||
self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
|
||||
self.mul1 = P.Mul().shard(strategy2)
|
||||
self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
|
||||
self.mul2 = P.Mul().shard(strategy3)
|
||||
|
||||
def construct(self, x, w1, w2):
|
||||
out1 = self.conv1(x)
|
||||
out2 = self.mul1(out1, w1)
|
||||
out3 = self.conv2(out2)
|
||||
out4 = self.mul2(out3, w2)
|
||||
return out4
|
||||
|
||||
|
||||
def test_batch():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, strategy3):
|
||||
super().__init__()
|
||||
self.conv1 = NetConv(16, 8, (3, 3), bias_init='zeros', strategy=strategy1)
|
||||
self.mul1 = P.Mul().shard(strategy2)
|
||||
self.conv2 = NetConv(8, 64, (9, 9), bias_init='zeros', strategy=strategy1)
|
||||
self.mul2 = P.Mul().shard(strategy3)
|
||||
|
||||
def construct(self, x, w1, w2):
|
||||
out1 = self.conv1(x)
|
||||
out2 = self.mul1(out1, w1)
|
||||
out3 = self.conv2(out2)
|
||||
out4 = self.mul2(out3, w2)
|
||||
|
||||
return out4
|
||||
|
||||
"""
|
||||
Feature: Batch parallel
|
||||
Description: test batch parallel
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
|
@ -116,5 +121,23 @@ def test_batch():
|
|||
_cell_graph_executor.compile(net, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch()
|
||||
def test_batch_shape_less_than_devices():
|
||||
"""
|
||||
Feature: Batch parallel
|
||||
Description: test batch parallel, shapes less than device nums.
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.set_auto_parallel_context(device_num=512, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
strategy1 = None
|
||||
strategy2 = None
|
||||
strategy3 = None
|
||||
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
|
||||
net.set_auto_parallel()
|
||||
|
||||
x = Tensor(np.ones([128, 16, 34, 34]), dtype=ms.float32)
|
||||
w1 = Tensor(np.ones([128, 8, 32, 32]), dtype=ms.float32)
|
||||
w2 = Tensor(np.ones([128, 64, 24, 24]), dtype=ms.float32)
|
||||
net.set_train()
|
||||
_cell_graph_executor.compile(net, x, w1, w2)
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context, ms_function
|
||||
from mindspore import dataset as ds
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.jit_config import JitConfig
|
||||
|
||||
|
||||
context.set_context(mode=ms.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch")
|
||||
|
||||
|
||||
def setup_function():
|
||||
context.set_auto_parallel_context(dataset_strategy="full_batch")
|
||||
|
||||
|
||||
class Attention(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Attention, self).__init__()
|
||||
self.fc_a = nn.Dense(128, 768, activation='relu')
|
||||
self.fc_b = nn.Dense(128, 768, activation='relu')
|
||||
self.fc_c = nn.Dense(128, 768, activation='relu')
|
||||
self.fc_a.matmul.shard(((1, 1), (8, 1)))
|
||||
self.fc_b.matmul.shard(((1, 1), (8, 1)))
|
||||
self.fc_c.matmul.shard(((1, 1), (8, 1)))
|
||||
|
||||
def construct(self, x):
|
||||
q = self.fc_a(x)
|
||||
k = self.fc_b(x)
|
||||
v = self.fc_c(x)
|
||||
return q, k, v
|
||||
|
||||
attention = Attention()
|
||||
relu = nn.ReLU()
|
||||
|
||||
|
||||
@ms_function
|
||||
def dense_func(x, label):
|
||||
q, k, v = attention(x)
|
||||
k = P.Transpose()(k, (1, 0)) # (728, 32)
|
||||
c = relu(P.MatMul()(q, k)) # (32, 32)
|
||||
s = relu(P.MatMul()(c, v)) # (32, 768)
|
||||
s = s - label
|
||||
return P.ReduceMean()(s * s)
|
||||
|
||||
optimizer_adam = nn.Adam(attention.trainable_params(), learning_rate=0.001)
|
||||
attention.set_train()
|
||||
attention.update_parameters_name("attn")
|
||||
optimizer_adam.update_parameters_name("opt")
|
||||
grad_dens_func = ms.ops.value_and_grad(dense_func, None, optimizer_adam.parameters)
|
||||
|
||||
|
||||
@ms_function
|
||||
def train_step(input_, label_):
|
||||
loss, grad = grad_dens_func(input_, label_)
|
||||
optimizer_adam(grad)
|
||||
return loss
|
||||
|
||||
|
||||
def test_sink():
|
||||
"""
|
||||
Feature: Function mode in auto parallel
|
||||
Description: sink mode
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
|
||||
dataset_strategy="data_parallel", device_num=8)
|
||||
data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
|
||||
dataset = ds.NumpySlicesDataset(data=data)
|
||||
jitconfig = JitConfig(jit_level="O1", task_sink=True)
|
||||
sink_process = ms.train.data_sink(dense_func, dataset, steps=2, sink_size=4, jit_config=jitconfig)
|
||||
_ = sink_process()
|
||||
|
||||
|
||||
def test_no_sink():
|
||||
"""
|
||||
Feature: Function mode in auto parallel
|
||||
Description: no sink mode
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch", device_num=8)
|
||||
_ = dense_func(Tensor(np.ones([32, 128]).astype(np.float32)), Tensor(np.zeros([32, 768]).astype(np.float32)))
|
||||
|
||||
|
||||
def test_sink_with_grad():
|
||||
"""
|
||||
Feature: Function mode in auto parallel
|
||||
Description: sink mode with grad
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
|
||||
dataset_strategy="data_parallel", device_num=8)
|
||||
data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
|
||||
dataset = ds.NumpySlicesDataset(data=data)
|
||||
jitconfig = JitConfig(jit_level="O1", task_sink=True)
|
||||
sink_process = ms.train.data_sink(train_step, dataset, steps=2, sink_size=4, jit_config=jitconfig)
|
||||
_ = sink_process()
|
||||
|
||||
|
||||
def test_no_sink_with_grad():
|
||||
"""
|
||||
Feature: Function mode in auto parallel
|
||||
Description: no sink mode with grad
|
||||
Expectation: compile ok
|
||||
"""
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", dataset_strategy="full_batch", device_num=8)
|
||||
_ = train_step(Tensor(np.ones([32, 128]).astype(np.float32)), Tensor(np.zeros([32, 768]).astype(np.float32)))
|
Loading…
Reference in New Issue