fix_parallel_op_info_bug

lichen 2023-02-06 21:09:05 +08:00
parent 7cddb2c437
commit 4a7a9e4b5f
12 changed files with 299 additions and 18 deletions

View File

@@ -432,6 +432,7 @@ void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_a
// Convert ValueTuple/ValueList to vector
Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input) {
MS_EXCEPTION_IF_NULL(input_value);
input->clear();
if (!input_value->isa<ValueSequeue>()) {
MS_LOG(ERROR) << "Input value must be ValueTuplePtr.";
return FAILED;

View File

@@ -141,6 +141,9 @@ Status ConcatInfo::InferTensorMap() {
void ConcatInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;
if (axis_ == 0) {
split_flag_list_[i] = false;
}
}
}
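A minimal plain-Python sketch of the new rule (illustrative only; the helper name is hypothetical, not MindSpore API): when the concat axis is 0, the batch dimension cannot be split, so every input's split flag is forced to false.

def recompute_batch_split_flags(num_inputs, axis):
    # Mirrors the hunk above: each input may be split along the batch
    # dimension unless the concat axis is 0.
    return [axis != 0 for _ in range(num_inputs)]

assert recompute_batch_split_flags(2, axis=0) == [False, False]
assert recompute_batch_split_flags(2, axis=1) == [True, True]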

View File

@@ -335,14 +335,13 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
// param slice shape need 32Byte aligned
// param slice shape preferably 32Byte aligned
auto param_shape = inputs_shape_.at(0);
auto input_dim = strategy->GetInputDim();
auto param_strategy = input_dim.at(0);
auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
ReportError(name_ + ": Last dim of param slice shape need 32Byte aligned.");
return FAILED;
MS_LOG(WARNING) << "Gather: Last dim of param slice shape is not 32Byte aligned.";
}
if (manual_split_) {
@@ -353,6 +352,11 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
// the parameter is not split along the gather axis
if (param_strategy.at(LongToSize(axis_)) == 1) {
return SUCCESS;
}
// only support 1-dim and 2-dim param
if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
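A rough plain-Python sketch of the relaxed check (illustrative only; the helper name and argument layout are assumptions): the last dimension of the parameter slice should preferably be a multiple of 8 elements (32 bytes for fp32), but a misaligned slice now only produces a warning instead of a failure.

def check_gather_slice_alignment(param_shape, param_strategy, target="Ascend"):
    # Last dim of the param slice is preferably 32-byte aligned (8 fp32
    # elements); the check is skipped on CPU and for slice length 1.
    slice_shape = param_shape[-1] // param_strategy[-1]
    if target != "CPU" and slice_shape % 8 != 0 and slice_shape != 1:
        print("Gather: last dim of param slice shape is not 32Byte aligned.")
    return True  # no longer a hard failure, only a warning

check_gather_slice_alignment([64, 20], [1, 4])  # 20 / 4 = 5 -> prints the warning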

View File

@@ -505,9 +505,24 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
}
std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
(void)batch_strategy.insert(batch_strategy.cbegin(), stage_device_size_);
Strategies strategy_v = {batch_strategy, batch_strategy};
Dimensions batch_strategy_a(inputs_shape_[0].size(), 1);
Dimensions batch_strategy_b(inputs_shape_[1].size(), 1);
MS_EXCEPTION_IF_ZERO("device_num", stage_device_size_);
Strategies strategy_v;
// input and weight have the same rank
if (inputs_shape_[0].size() == inputs_shape_[1].size()) {
batch_strategy_a[0] = stage_device_size_;
if (inputs_shape_[0].size() > MATMUL_INPUTS_SIZE) {
batch_strategy_b[0] = stage_device_size_;
}
}
if (inputs_shape_[0].size() > inputs_shape_[1].size()) {
batch_strategy_a[0] = stage_device_size_;
}
if (inputs_shape_[0].size() < inputs_shape_[1].size()) {
batch_strategy_b[0] = stage_device_size_;
}
strategy_v = {batch_strategy_a, batch_strategy_b};
return std::make_shared<Strategies>(strategy_v);
}
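A plain-Python sketch of the new strategy generation (illustrative only; the helper name is hypothetical and MATMUL_INPUTS_SIZE = 2 is an assumption): both strategies start as all ones, and only the batch dimension of the higher-rank input is split, or of both inputs when the ranks match and exceed 2.

MATMUL_INPUTS_SIZE = 2  # assumed value of the C++ constant

def generate_batch_matmul_strategy(shape_a, shape_b, device_num):
    # Start from all-ones strategies and split only the batch dimension.
    strategy_a = [1] * len(shape_a)
    strategy_b = [1] * len(shape_b)
    if len(shape_a) == len(shape_b):
        strategy_a[0] = device_num
        if len(shape_a) > MATMUL_INPUTS_SIZE:
            strategy_b[0] = device_num
    elif len(shape_a) > len(shape_b):
        strategy_a[0] = device_num
    else:
        strategy_b[0] = device_num
    return strategy_a, strategy_b

# 3-dim input with a 2-dim weight: only the input's batch dim is split,
# matching the test expectation '((8, 1, 1), (1, 1))'.
assert generate_batch_matmul_strategy([128, 128, 32], [32, 128], 8) == ([8, 1, 1], [1, 1])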

View File

@@ -241,7 +241,10 @@ std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
Dimensions strategy = {stage_device_size_, 1};
Dimensions strategy(inputs_shape_[0].size() + 1, 1);
if (inputs_shape_[0].front() % stage_device_size_ == 0) {
strategy[0] = {stage_device_size_};
}
Dimensions empty_strategy;
Strategies strategy_v = {strategy, empty_strategy, empty_strategy};
return std::make_shared<Strategies>(strategy_v);
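A plain-Python sketch of the new behaviour (illustrative only; the helper name is hypothetical): the strategy now matches the output rank (input rank plus the depth dimension), and the batch dimension is split only when it is divisible by the device number.

def generate_onehot_batch_strategy(input_shape, device_num):
    # OneHot output has one extra (depth) dimension; start from all ones and
    # split the batch dimension only when it divides evenly across devices.
    strategy = [1] * (len(input_shape) + 1)
    if input_shape[0] % device_num == 0:
        strategy[0] = device_num
    return strategy

# 2-dim input [2, 128] on 8 devices: 2 % 8 != 0, so the strategy stays all
# ones, matching the test expectation '((1, 1, 1), (), ())'.
assert generate_onehot_batch_strategy([2, 128], 8) == [1, 1, 1]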

View File

@@ -31,6 +31,7 @@
namespace mindspore {
namespace parallel {
Status SliceInfo::GetInput(const ValuePtr &input_value, std::vector<int64_t> *input) {
input->clear();
MS_EXCEPTION_IF_NULL(input_value);
ValueTuplePtr value_tuple = input_value->cast<ValueTuplePtr>();
if (value_tuple == nullptr) {
@@ -151,7 +152,14 @@ Status SliceInfo::InferMirrorOps() {
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategies> SliceInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": generate batch parallel strategies failed.";
}
split_flag_list_ = {true};
bool no_fully_fetch = ((begin_[0] != 0) || (size_[0] < inputs_shape_[0][0]));
if (no_fully_fetch) {
split_flag_list_ = {false};
}
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
}
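A plain-Python sketch of the new rule (illustrative only; the helper name is hypothetical): the batch dimension may be split only when Slice fully fetches it, i.e. the slice starts at 0 and its size covers the whole first dimension.

def slice_batch_split_allowed(begin, size, input_shape):
    # The batch dimension is splittable only when it is fully fetched.
    return begin[0] == 0 and size[0] >= input_shape[0]

# slice(x, (0, 0), (64, 128)) on a [128, 128] input fetches only half of the
# batch dimension, so the strategy falls back to '((1, 1))'.
assert slice_batch_split_allowed((0, 0), (64, 128), [128, 128]) is False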

View File

@@ -152,9 +152,7 @@ std::shared_ptr<Strategies> SplitInfo::GenerateBatchStrategies() {
Dimensions input_strategy(inputs_shape_[0].size(), 1);
// axis can't split
if (inputs_shape_[0].size() > 1) {
if (axis_ == 0) {
input_strategy[1] = stage_device_size_;
} else {
if (axis_ != 0) {
input_strategy[0] = stage_device_size_;
}
}
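A plain-Python sketch of the fix (illustrative only; the helper name is hypothetical): previously a Split along axis 0 sharded dimension 1 instead, which is not a batch split; now the batch dimension is sharded only when the split axis is not 0.

def generate_split_batch_strategy(input_shape, axis, device_num):
    # The split axis itself cannot be sharded; shard the batch dimension
    # only when it is not the axis being split.
    strategy = [1] * len(input_shape)
    if len(input_shape) > 1 and axis != 0:
        strategy[0] = device_num
    return strategy

# Split along axis 0 of a [128, 128] input: nothing is sharded,
# matching the test expectation '((1, 1))'.
assert generate_split_batch_strategy([128, 128], axis=0, device_num=8) == [1, 1]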

View File

@@ -405,7 +405,14 @@ Status StridedSliceInfo::InferMirrorOps() {
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": generate batch parallel strategies failed.";
}
split_flag_list_ = {true};
bool no_fully_fetch = ((begin_[0] != 0) || (end_[0] < input_shape_in_process_[0]));
if (no_fully_fetch) {
split_flag_list_ = {false};
}
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
}
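The same rule as for Slice, except that the fully-fetched test uses the end index rather than a size; a plain-Python sketch (illustrative only; the helper name is hypothetical):

def strided_slice_batch_split_allowed(begin, end, input_shape):
    # The batch dimension is splittable only when begin is 0 and end reaches
    # the full extent of the first dimension.
    return begin[0] == 0 and end[0] >= input_shape[0]

# strided_slice(x, (0, 0), (64, 128), (1, 1)) on a [128, 128] input is not
# fully fetched, so the strategy falls back to '((1, 1))'.
assert strided_slice_batch_split_allowed((0, 0), (64, 128), [128, 128]) is False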

View File

@@ -57,7 +57,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
current_stra_shard_num = dim;
}
}
if (i == 0) {
if (shard_num_ == 1) {
shard_num_ = current_stra_shard_num;
} else if (current_stra_shard_num != 1 && current_stra_shard_num != shard_num_) {
MS_LOG(ERROR) << name_
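A rough plain-Python sketch of the changed condition (illustrative only; how the per-input shard number is extracted lies outside this hunk and is an assumption): the reference shard number is now the first non-1 value seen rather than whatever the first input reports, so an unsharded first input no longer triggers a false conflict.

def check_virtual_dataset_shard_num(per_input_shard_nums):
    # per_input_shard_nums: one shard number per input strategy (assumed to
    # be pre-computed). The first non-1 value becomes the reference; any
    # later non-1 value must match it.
    shard_num = 1
    for current in per_input_shard_nums:
        if shard_num == 1:
            shard_num = current
        elif current != 1 and current != shard_num:
            raise ValueError("conflicting shard numbers: %d vs %d" % (current, shard_num))
    return shard_num

# First input unsharded, second sharded by 8: the reference becomes 8
# (the old `i == 0` condition would have reported this as a conflict).
assert check_virtual_dataset_shard_num([1, 8]) == 8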

View File

@@ -40,12 +40,6 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
Dimensions strategy_first = stra.at(0);
for (auto dim = strategy_first.begin() + 1; dim != strategy_first.end(); ++dim) {
if (*dim != 1) {
MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of the strategy must be 1.";
return FAILED;
}
}
if (!strategy_first.empty()) {
shard_num_ = strategy_first[0];
}

View File

@@ -911,7 +911,7 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
}
} else {
Shapes shape_outputs = GetNodeShape(out_node);
if (shape_outputs[0].empty()) {
if (shape_outputs[0].empty() || out_node->isa<Parameter>()) {
return;
}
auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);

View File

@@ -0,0 +1,248 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator
def setup_function():
context.set_auto_parallel_context(dataset_strategy="full_batch")
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
def test_concat():
"""
Feature: Test Concat with axis=0 when generating the batch parallel strategy.
Description: with axis=0 the generated batch parallel strategy must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.concat = P.Concat()
def construct(self, x, y):
out = self.concat((x, y))
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((1, 1), (1, 1))'}
assert validator.check_node_attrs("Concat-0", expect_strategy)
def test_batch_matmul():
"""
Feature: Test BatchMatMul with a 2-dim weight when generating the batch parallel strategy.
Description: the batch parallel strategy for the weight must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.batch_matmul = P.BatchMatMul()
def construct(self, x, y):
out = self.batch_matmul(x, y)
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True)
x = Tensor(np.ones([128, 128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((8, 1, 1), (1, 1))'}
assert validator.check_node_attrs("BatchMatMul-0", expect_strategy)
def test_onehot():
"""
Feature: Test OneHot with a 2-dim input when generating the batch parallel strategy.
Description: the generated batch parallel strategy must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.ont_hot = P.OneHot()
self.on = Tensor(1.0, ms.float32)
self.off = Tensor(0.0, ms.float32)
def construct(self, x, y):
out = self.ont_hot(x, 1, self.on, self.off)
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True)
x = Tensor(np.ones([2, 128]), dtype=ms.int32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((1, 1, 1), (), ())'}
assert validator.check_node_attrs("OneHot-0", expect_strategy)
def test_slice():
"""
Feature: Test Slice with the input not fully fetched when generating the batch parallel strategy.
Description: the generated batch parallel strategy must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.slice = P.Slice()
def construct(self, x, y):
out = self.slice(x, (0, 0), (64, 128))
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((1, 1))'}
assert validator.check_node_attrs("Slice-0", expect_strategy)
def test_strided_slice():
"""
Feature: Test StridedSlice with the input not fully fetched when generating the batch parallel strategy.
Description: the generated batch parallel strategy must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.slice = P.StridedSlice()
def construct(self, x, y):
out = self.slice(x, (0, 0), (64, 128), (1, 1))
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((1, 1))'}
assert validator.check_node_attrs("StridedSlice-0", expect_strategy)
def test_split():
"""
Feature: Test Split with axis=0 when generating the batch parallel strategy.
Description: the generated batch parallel strategy must be all ones.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.split = P.Split(0, 2)
def construct(self, x, y):
out, _ = self.split(x)
return out
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 128]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_strategy = {"gen_strategy": '((1, 1))'}
assert validator.check_node_attrs("Split-0", expect_strategy)
def test_virtual_output():
"""
Feature: Test VirtualOutput when a Parameter is returned directly in predict mode.
Description: no VirtualOutput node needs to be inserted.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.param = Parameter(Tensor(np.ones([32, 32]), ms.float32))
def construct(self):
return self.param
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
net = Net()
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
phase, _ = _cell_graph_executor.compile(net)
validator = ParallelValidator(net, phase)
sub_graph = {'Return-0': ['param']}
assert validator.check_graph_structure(sub_graph)