forked from mindspore-Ecosystem/mindspore

commit 4a7a9e4b5f: fix_parallel_op_info_bug
parent 7cddb2c437
@@ -432,6 +432,7 @@ void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_a
 // Convert ValueTuple/ValueList to vector
 Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input) {
   MS_EXCEPTION_IF_NULL(input_value);
+  input->clear();
   if (!input_value->isa<ValueSequeue>()) {
     MS_LOG(ERROR) << "Input value must be ValueTuplePtr.";
     return FAILED;
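The added input->clear() guards against stale contents when the caller reuses the output vector. A minimal Python sketch of the failure mode this prevents (hypothetical names, not MindSpore code):

def trans_value_sequence_to_vector(value_seq, out):
    out.clear()  # the fix: without this, repeated calls accumulate stale values
    out.extend(int(v) for v in value_seq)

buf = []
trans_value_sequence_to_vector((1, 2, 3), buf)
trans_value_sequence_to_vector((4, 5), buf)
assert buf == [4, 5]  # would be [1, 2, 3, 4, 5] without the clear()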
@@ -141,6 +141,9 @@ Status ConcatInfo::InferTensorMap() {
 void ConcatInfo::ReComputeBatchSplitFlagList() {
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     split_flag_list_[i] = true;
+    if (axis_ == 0) {
+      split_flag_list_[i] = false;
+    }
   }
 }
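Concat with axis=0 joins its inputs along the batch dimension itself, so batch-splitting the inputs would shard the very dimension being concatenated; the new branch therefore disables the batch split flag. A NumPy illustration of the shapes involved:

import numpy as np

# The output batch dimension (160) belongs to neither input, so splitting the
# inputs along dimension 0 would not correspond to a batch split of the output.
x = np.ones((128, 128))
y = np.ones((32, 128))
assert np.concatenate((x, y), axis=0).shape == (160, 128)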
@@ -335,14 +335,13 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  // param slice shape need 32Byte aligned
+  // param slice shape preferably 32Byte aligned
   auto param_shape = inputs_shape_.at(0);
   auto input_dim = strategy->GetInputDim();
   auto param_strategy = input_dim.at(0);
   auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
   if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
-    ReportError(name_ + ": Last dim of param slice shape need 32Byte aligned.");
-    return FAILED;
+    MS_LOG(WARNING) << "Gather: Last dim of param slice shape is not 32Byte aligned.";
   }
 
   if (manual_split_) {
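The divisibility-by-8 test assumes the common case of float32 parameters: 8 elements x 4 bytes = 32 bytes, so a last-dimension slice whose length is a multiple of 8 elements is 32-byte aligned (slice_shape == 1 is exempted above). After this change a misaligned slice only warns instead of failing. A quick sanity check of the arithmetic:

ELEM_BYTES = 4  # float32, the assumption behind the modulo-8 test

def last_dim_aligned(slice_shape):
    return (slice_shape * ELEM_BYTES) % 32 == 0

assert last_dim_aligned(8) and last_dim_aligned(16)
assert not last_dim_aligned(12)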
@@ -353,6 +352,11 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
     return SUCCESS;
   }
 
+  // parameter not split axis
+  if (param_strategy.at(LongToSize(axis_)) == 1) {
+    return SUCCESS;
+  }
+
   // only support 1-dim and 2-dim param
   if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
     MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
@@ -505,9 +505,24 @@ std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t stage_id) {
 }
 
 std::shared_ptr<Strategies> BatchMatMulInfo::GenerateBatchStrategies() {
-  Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
-  (void)batch_strategy.insert(batch_strategy.cbegin(), stage_device_size_);
-  Strategies strategy_v = {batch_strategy, batch_strategy};
+  Dimensions batch_strategy_a(inputs_shape_[0].size(), 1);
+  Dimensions batch_strategy_b(inputs_shape_[1].size(), 1);
+  MS_EXCEPTION_IF_ZERO("device_num", stage_device_size_);
+  Strategies strategy_v;
+  // input's shape equals to weight's shape
+  if (inputs_shape_[0].size() == inputs_shape_[1].size()) {
+    batch_strategy_a[0] = stage_device_size_;
+    if (inputs_shape_[0].size() > MATMUL_INPUTS_SIZE) {
+      batch_strategy_b[0] = stage_device_size_;
+    }
+  }
+  if (inputs_shape_[0].size() > inputs_shape_[1].size()) {
+    batch_strategy_a[0] = stage_device_size_;
+  }
+  if (inputs_shape_[0].size() < inputs_shape_[1].size()) {
+    batch_strategy_b[0] = stage_device_size_;
+  }
+  strategy_v = {batch_strategy_a, batch_strategy_b};
   return std::make_shared<Strategies>(strategy_v);
 }
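A Python sketch of the new rank-aware logic (names are hypothetical; MATMUL_INPUTS_SIZE is assumed here to be 2, the rank of a plain matmul operand):

def batch_matmul_strategies(shape_a, shape_b, device_num):
    strat_a = [1] * len(shape_a)
    strat_b = [1] * len(shape_b)
    if len(shape_a) == len(shape_b):
        strat_a[0] = device_num
        if len(shape_a) > 2:  # both operands carry a real batch dimension
            strat_b[0] = device_num
    elif len(shape_a) > len(shape_b):
        strat_a[0] = device_num
    else:
        strat_b[0] = device_num
    return strat_a, strat_b

# Matches test_batch_matmul below: a 3-dim input with a 2-dim weight shards
# only the input's batch dimension.
assert batch_matmul_strategies([128, 128, 32], [32, 128], 8) == ([8, 1, 1], [1, 1])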
@@ -241,7 +241,10 @@ std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
 Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 std::shared_ptr<Strategies> OneHotInfo::GenerateBatchStrategies() {
-  Dimensions strategy = {stage_device_size_, 1};
+  Dimensions strategy(inputs_shape_[0].size() + 1, 1);
+  if (inputs_shape_[0].front() % stage_device_size_ == 0) {
+    strategy[0] = {stage_device_size_};
+  }
   Dimensions empty_strategy;
   Strategies strategy_v = {strategy, empty_strategy, empty_strategy};
   return std::make_shared<Strategies>(strategy_v);
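OneHot's output has one more dimension than its input (the new depth axis), so the strategy now has rank + 1 entries, and the batch axis is sharded only when it divides evenly across devices. A sketch of the generated strategy (hypothetical helper, mirroring the new C++ logic):

def onehot_batch_strategy(input_shape, device_num):
    strategy = [1] * (len(input_shape) + 1)
    if input_shape[0] % device_num == 0:
        strategy[0] = device_num
    return strategy

assert onehot_batch_strategy([2, 128], 8) == [1, 1, 1]   # 2 % 8 != 0, as in test_onehot
assert onehot_batch_strategy([32, 128], 8) == [8, 1, 1]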
@@ -31,6 +31,7 @@
 namespace mindspore {
 namespace parallel {
 Status SliceInfo::GetInput(const ValuePtr &input_value, std::vector<int64_t> *input) {
+  input->clear();
   MS_EXCEPTION_IF_NULL(input_value);
   ValueTuplePtr value_tuple = input_value->cast<ValueTuplePtr>();
   if (value_tuple == nullptr) {
@@ -151,7 +152,14 @@ Status SliceInfo::InferMirrorOps() {
 
+// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
 std::shared_ptr<Strategies> SliceInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": generate batch parallel strategies failed.";
   }
   split_flag_list_ = {true};
+  bool no_fully_fetch = ((begin_[0] != 0) || (size_[0] < inputs_shape_[0][0]));
+  if (no_fully_fetch) {
+    split_flag_list_ = {false};
+  }
   return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
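StridedSlice below receives the same guard, with end_[0] < input_shape_in_process_[0] as its partial-fetch test. A sketch of the condition for Slice:

def batch_dim_fully_fetched(begin0, size0, input_dim0):
    # The batch dimension may only be split when the slice reads all of it.
    return begin0 == 0 and size0 >= input_dim0

# Slice(x, begin=(0, 0), size=(64, 128)) on a (128, 128) input reads only half
# the batch dimension, so the batch strategy must not split it (see test_slice).
assert not batch_dim_fully_fetched(0, 64, 128)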
@@ -152,9 +152,7 @@ std::shared_ptr<Strategies> SplitInfo::GenerateBatchStrategies() {
   Dimensions input_strategy(inputs_shape_[0].size(), 1);
   // axis can't split
-  if (inputs_shape_[0].size() > 1) {
-    if (axis_ == 0) {
-      input_strategy[1] = stage_device_size_;
-    } else {
+  if (axis_ != 0) {
     input_strategy[0] = stage_device_size_;
   }
-  }
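A sketch of the simplified rule: the axis Split operates on can never be sharded, so the batch dimension is sharded only when it is not the split axis (hypothetical helper):

def split_batch_strategy(rank, axis, device_num):
    strategy = [1] * rank
    if axis != 0:
        strategy[0] = device_num
    return strategy

assert split_batch_strategy(2, 0, 8) == [1, 1]   # axis=0, as in test_split below
assert split_batch_strategy(2, 1, 8) == [8, 1]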
@@ -405,7 +405,14 @@ Status StridedSliceInfo::InferMirrorOps() {
 
+// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
 std::shared_ptr<Strategies> StridedSliceInfo::GenerateBatchStrategies() {
   if (GetAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": generate batch parallel strategies failed.";
   }
   split_flag_list_ = {true};
+  bool no_fully_fetch = ((begin_[0] != 0) || (end_[0] < input_shape_in_process_[0]));
+  if (no_fully_fetch) {
+    split_flag_list_ = {false};
+  }
   return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
@@ -57,7 +57,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
         current_stra_shard_num = dim;
       }
     }
-    if (i == 0) {
+    if (shard_num_ == 1) {
      shard_num_ = current_stra_shard_num;
     } else if (current_stra_shard_num != 1 && current_stra_shard_num != shard_num_) {
       MS_LOG(ERROR) << name_
@@ -40,12 +40,6 @@ Status VirtualOutputInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
   Dimensions strategy_first = stra.at(0);
-  for (auto dim = strategy_first.begin() + 1; dim != strategy_first.end(); ++dim) {
-    if (*dim != 1) {
-      MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of the strategy must be 1.";
-      return FAILED;
-    }
-  }
   if (!strategy_first.empty()) {
     shard_num_ = strategy_first[0];
   }
@@ -911,7 +911,7 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
       }
     } else {
       Shapes shape_outputs = GetNodeShape(out_node);
-      if (shape_outputs[0].empty()) {
+      if (shape_outputs[0].empty() || out_node->isa<Parameter>()) {
         return;
       }
       auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
@@ -0,0 +1,248 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator
def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)
def test_concat():
    """
    Feature: Test Concat with axis=0 and generate batch parallel strategy.
    Description: With axis=0, the generated batch parallel strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.concat = P.Concat()

        def construct(self, x, y):
            out = self.concat((x, y))
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((1, 1), (1, 1))'}
    assert validator.check_node_attrs("Concat-0", expect_strategy)
def test_batch_matmul():
    """
    Feature: Test BatchMatMul with 2-dim weight and generate batch parallel strategy.
    Description: The weight's batch parallel strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.batch_matmul = P.BatchMatMul()

        def construct(self, x, y):
            out = self.batch_matmul(x, y)
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True)

    x = Tensor(np.ones([128, 128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((8, 1, 1), (1, 1))'}
    assert validator.check_node_attrs("BatchMatMul-0", expect_strategy)
def test_onehot():
    """
    Feature: Test OneHot with 2-dim input and generate batch parallel strategy.
    Description: The batch dimension (2) is not divisible by the device number, so the strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.one_hot = P.OneHot()
            self.on = Tensor(1.0, ms.float32)
            self.off = Tensor(0.0, ms.float32)

        def construct(self, x, y):
            out = self.one_hot(x, 1, self.on, self.off)
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True)

    x = Tensor(np.ones([2, 128]), dtype=ms.int32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((1, 1, 1), (), ())'}
    assert validator.check_node_attrs("OneHot-0", expect_strategy)
def test_slice():
    """
    Feature: Test Slice whose input is not fully fetched, and generate batch parallel strategy.
    Description: The batch parallel strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.slice = P.Slice()

        def construct(self, x, y):
            out = self.slice(x, (0, 0), (64, 128))
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((1, 1))'}
    assert validator.check_node_attrs("Slice-0", expect_strategy)
def test_strided_slice():
    """
    Feature: Test StridedSlice whose input is not fully fetched, and generate batch parallel strategy.
    Description: The batch parallel strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.slice = P.StridedSlice()

        def construct(self, x, y):
            out = self.slice(x, (0, 0), (64, 128), (1, 1))
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((1, 1))'}
    assert validator.check_node_attrs("StridedSlice-0", expect_strategy)
def test_split():
    """
    Feature: Test Split with axis=0 and generate batch parallel strategy.
    Description: The batch parallel strategy must be all ones.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.split = P.Split(0, 2)

        def construct(self, x, y):
            out, _ = self.split(x)
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 128]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    net.set_train()
    phase, _ = _cell_graph_executor.compile(net, x, y)
    validator = ParallelValidator(net, phase)
    expect_strategy = {"gen_strategy": '((1, 1))'}
    assert validator.check_node_attrs("Split-0", expect_strategy)
def test_virtual_output():
    """
    Feature: Test VirtualOutput when the network returns a Parameter in predict mode.
    Description: No VirtualOutput node needs to be inserted.
    Expectation: Successful graph compilation.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param = Parameter(Tensor(np.ones([32, 32]), ms.float32))

        def construct(self):
            return self.param

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    net = Net()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    phase, _ = _cell_graph_executor.compile(net)
    validator = ParallelValidator(net, phase)
    sub_graph = {'Return-0': ['param']}
    assert validator.check_graph_structure(sub_graph)