!32341 add parallel ops for Invert, CheckValid and PopulationCount

Merge pull request !32341 from yangzhenzhang/add-parallel-operators
i-robot 2022-04-01 02:53:58 +00:00 committed by Gitee
commit acce047dfe
8 changed files with 289 additions and 20 deletions


@@ -240,6 +240,9 @@ REGISTER(XdivyInfo);
REGISTER(XlogyInfo);
REGISTER(InplaceAddInfo);
REGISTER(InplaceSubInfo);
REGISTER(CheckValidInfo);       // no bprop defined
REGISTER(InvertInfo);           // no bprop defined
REGISTER(PopulationCountInfo);  // no bprop defined
REGISTER(CdistInfo);
REGISTER(L2LossInfo);
REGISTER(LerpInfo);
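Once registered, the three primitives accept an explicit shard() strategy like any other parallel-aware operator. A minimal usage sketch (it simply mirrors the strategies exercised by the new tests at the end of this change set):

from mindspore.ops import operations as P

# Elementwise ops: any dimension may be split.
invert = P.Invert().shard(((2, 4),))
popcnt = P.PopulationCount().shard(((2, 4),))
# CheckValid: only the batch dimension of the first input (bboxes) may split.
check = P.CheckValid().shard(((8, 1), (1,)))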


@@ -391,6 +391,24 @@ class ErfinvInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfinvCost>()) {}
~ErfinvInfo() override = default;
};
// The Invert operator has no backward (bprop) implementation.
class InvertInfo : public ActivationOther {
public:
InvertInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
~InvertInfo() override = default;
};
// The PopulationCount operator has no backward (bprop) implementation.
class PopulationCountInfo : public ActivationOther {
public:
PopulationCountInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
~PopulationCountInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
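Invert and PopulationCount are plain elementwise ops, which is why they can reuse ActivationOther (ReLUCost presumably standing in as the cost model for a cheap elementwise op). A NumPy sketch of their semantics and of why any dimension is shardable:

import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.int16)

# Invert: bitwise NOT per element (e.g. ~1 == -2 for int16).
inverted = ~x

# PopulationCount: the number of set bits per element.
popcount = np.vectorize(lambda v: bin(np.uint16(v)).count("1"))(x)

# Each output element depends only on the matching input element, so a shard
# strategy may split any dimension: every device applies the op to its local
# slice and no communication is needed.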


@@ -74,19 +74,10 @@ Status BatchParallelInfo::InferTensorMap() {
return SUCCESS;
}
if (strategy[0].empty()) {
MS_LOG(INFO) << name_ << ": the first element of strategy is empty";
return FAILED;
}
if (strategy[0][0] != stage_device_size_) {
MS_LOG(ERROR) << name_ << ": It is not a valid data parallel strategy.";
return FAILED;
}
for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape tensor_map_index;
for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-if (strategy[i][j] == stage_device_size_ && j == 0) {
+if (strategy[i][j] > 1 && j == 0) {
tensor_map_index.push_back(0);
} else {
tensor_map_index.push_back(MAP_NONE);
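The relaxed condition maps the batch dimension onto the device matrix whenever it is split at all (strategy value > 1), not only when it equals the full stage device count. A simplified Python sketch of the resulting tensor-map logic (MAP_NONE marks an unmapped dimension):

MAP_NONE = -1

def infer_batch_tensor_map(strategy):
    # Map only a split leading (batch) dimension to device-matrix axis 0;
    # every other dimension stays unmapped.
    return [[0 if (j == 0 and s > 1) else MAP_NONE
             for j, s in enumerate(dims)]
            for dims in strategy]

print(infer_batch_tensor_map([(2, 1), (1,)]))  # [[0, -1], [-1]]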
@@ -133,6 +124,8 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_id) {
StrategyPtr sp;
Strategys strategy;
ComputeBatchSplitFlagList();
for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape temp(inputs_shape_[i].size(), 1);
if (split_flag_list_[i]) {
@@ -146,12 +139,6 @@ std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_i
return sp_vector;
}
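GenerateOpStrategies now recomputes the split-flag list before building the default strategy, so a subclass such as CheckValidInfo can veto splitting for individual inputs. A simplified Python sketch of the strategy this produces (assuming, as in the C++ above, that a flagged input has its batch dimension split across the whole stage):

def generate_batch_strategy(inputs_shape, split_flag_list, stage_device_size):
    # Split the batch dimension of every flagged input; leave the rest unsplit.
    strategy = []
    for shape, flagged in zip(inputs_shape, split_flag_list):
        dims = [1] * len(shape)
        if flagged:
            dims[0] = stage_device_size
        strategy.append(tuple(dims))
    return tuple(strategy)

print(generate_batch_strategy([(16, 4), (3,)], [True, False], 8))  # ((8, 1), (1,))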
-void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
-for (size_t i = 0; i < inputs_shape_.size(); i++) {
-split_flag_list_[i] = true;
-}
-}
Status BatchParallelInfo::InferAsLossDivisor() {
as_loss_divisor_ = 1;
return SUCCESS;
@@ -181,5 +168,38 @@ void BatchParallelInfo::ReplaceNodeInputOrAttrs() {
AnfNodePtr val = NewValueNode(replace_shape);
(void)manager->Replace(cnode->input(1), val);
}
void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;
}
}
Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED;
}
Strategys stra = strategy->GetInputDim();
if (stra[0][1] != 1) {
MS_LOG(ERROR) << name_ << ": The second dimension of the first input can not be split, but got " << stra[0][1];
return FAILED;
}
if (stra[1][0] != 1) {
MS_LOG(ERROR) << name_ << ": The second input can not be split, but got " << stra[1][0];
return FAILED;
}
return SUCCESS;
}
Status CheckValidInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
dev_matrix_shape_.push_back(stra[0][0]);
return SUCCESS;
}
void CheckValidInfo::ReComputeBatchSplitFlagList() { split_flag_list_[0] = true; }
} // namespace parallel
} // namespace mindspore
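CheckStrategy enforces the sharding contract for CheckValid(bboxes, img_metas): only the batch dimension of bboxes may be split, and the device matrix collapses to that single split factor. A Python sketch mirroring the validation above (check_valid_strategy is a hypothetical helper name):

def check_valid_strategy(strategy):
    # strategy = ((bbox_batch_split, bbox_box_split), (img_metas_split,))
    bbox_strategy, meta_strategy = strategy
    if bbox_strategy[1] != 1:
        raise ValueError("the second dimension of bboxes cannot be split")
    if meta_strategy[0] != 1:
        raise ValueError("img_metas cannot be split")
    return [bbox_strategy[0]]  # the inferred device-matrix shape

assert check_valid_strategy(((8, 1), (1,))) == [8]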


@@ -31,10 +31,10 @@ class BatchParallelInfo : public OperatorInfo {
public:
BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
-: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
+: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
-: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
+: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~BatchParallelInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
@@ -50,7 +50,6 @@ class BatchParallelInfo : public OperatorInfo {
Status InferAsLossDivisor() override;
private:
-int64_t dev_num_ = 1;
bool need_replace_input_ = false;
Shape replace_shape_;
};
@@ -64,6 +63,20 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};
// For the CheckValid operator, only the first dimension of the first input (bboxes) can be split.
class CheckValidInfo : public BatchParallelInfo {
public:
CheckValidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~CheckValidInfo() override = default;
void ReComputeBatchSplitFlagList() override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
};
} // namespace parallel
} // namespace mindspore
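For context on why only the batch dimension is shardable: CheckValid takes bboxes of shape (N, 4) and img_metas as (height, width, ratio), and emits one validity flag per box. A hedged NumPy reference of the semantics, based on the public operator docs (the exact border handling is an assumption):

import numpy as np

def check_valid_reference(bboxes, img_metas):
    # bboxes: (N, 4) rows of (x1, y1, x2, y2); img_metas: (height, width, ratio).
    h, w, ratio = img_metas
    x1, y1, x2, y2 = bboxes.T
    # One independent flag per box, so splitting N needs no communication.
    return (x1 >= 0) & (y1 >= 0) & (x2 <= w * ratio - 1) & (y2 <= h * ratio - 1)

boxes = np.array([[0.0, 0.0, 10.0, 10.0], [-1.0, 0.0, 5.0, 5.0]])
print(check_valid_reference(boxes, np.array([32.0, 32.0, 1.0])))  # [ True False]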


@@ -469,6 +469,9 @@ constexpr char SELU[] = "SeLU";
constexpr char SOFT_SHRINK[] = "SoftShrink";
constexpr char XLOGY[] = "Xlogy";
constexpr char XDIVY[] = "Xdivy";
constexpr char CHECK_VALID[] = "CheckValid";
constexpr char INVERT[] = "Invert";
constexpr char POPULATION_COUNT[] = "PopulationCount";
constexpr char BITWISE_AND[] = "BitwiseAnd";
constexpr char BITWISE_OR[] = "BitwiseOr";
constexpr char BITWISE_XOR[] = "BitwiseXor";


@@ -179,7 +179,8 @@ bool IsSplittableOperator(const std::string &op_name) {
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
-CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE};
+CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
+POPULATION_COUNT};
// clang-format on
auto iter = splittable_op.find(op_name);


@@ -0,0 +1,139 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.train import Model
class Net(Cell):
    def __init__(self, weight, strategy):
        super().__init__()
        self.check_valid = P.CheckValid().shard(strategy)
        self.mul = P.Mul()
        cast_strategy = None
        if strategy:
            cast_strategy = (strategy[0],)
        self.cast = P.Cast().shard(cast_strategy)
        self.relu = P.ReLU()
        self.weight = Parameter(weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.check_valid(out, b)
        out = self.cast(out, ms.float32)
        out = self.relu(out)
        return out

_x = Tensor(np.ones([16, 4]), dtype=ms.float32)
_w = Tensor(np.ones([16, 4]), dtype=ms.float32)
_b = Tensor(np.ones([3]), dtype=ms.float32)

def compile_net(net):
    model = Model(net)
    model.predict(_x, _b)
    context.reset_auto_parallel_context()

def test_check_valid_data_parallel():
    """
    Feature: test check valid data parallel
    Description: shard the batch dimension of the first input across all 8 devices
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((8, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)

def test_check_valid_repeated_calc():
    """
    Feature: test check valid repeated calculation
    Description: split the batch dimension by 2 on 8 devices, so the computation is repeated across device groups
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)

def test_check_valid_no_shard():
    """
    Feature: test check valid no shard
    Description: an all-ones strategy, so no dimension is split
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((1, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)

def test_check_valid_strategy_none():
    """
    Feature: test check valid strategy none
    Description: no strategy is given, so a batch parallel strategy is generated
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(_w, strategy)
    compile_net(net)

def test_check_valid_auto_parallel():
    """
    Feature: test check valid auto parallel
    Description: no strategy is given; auto_parallel searches one
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(_w, strategy)
    compile_net(net)

def test_check_valid_shard_img():
    """
    Feature: test check valid shard img
    Description: try to split img_metas (the second input), which is not allowed
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 1), (4,))
    net = Net(_w, strategy)
    with pytest.raises(RuntimeError):
        compile_net(net)

def test_check_valid_shard_bbox_second_dimension():
    """
    Feature: test check valid shard bbox second dimension
    Description: try to split the second dimension of bboxes, which is not allowed
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 2), (1,))
    net = Net(_w, strategy)
    with pytest.raises(RuntimeError):
        compile_net(net)


@@ -0,0 +1,72 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.train import Model
class Net(Cell):
    def __init__(self, strategy):
        super().__init__()
        self.invert = P.Invert().shard(strategy)
        self.pop = P.PopulationCount().shard(strategy)
        self.cast = P.Cast().shard(strategy)
        self.relu = P.ReLU().shard(strategy)

    def construct(self, x, b):
        out = self.invert(x)
        out = self.pop(out)
        out = self.cast(out, ms.float32)
        out = self.relu(out)
        return out

_x = Tensor(np.ones([16, 16]), dtype=ms.int16)
_w = Tensor(np.ones([16, 16]), dtype=ms.int16)
_b = Tensor(np.ones([16, 16]), dtype=ms.int16)

def compile_net(net):
    model = Model(net)
    model.predict(_x, _b)
    context.reset_auto_parallel_context()

def test_invert_population_count_semi():
    """
    Feature: semi auto parallel
    Description: shard both dimensions of the elementwise ops with strategy ((2, 4),)
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 4),)
    net = Net(strategy)
    compile_net(net)

def test_invert_population_count_auto():
    """
    Feature: auto parallel
    Description: no strategy is given; auto_parallel searches one
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(strategy)
    compile_net(net)
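
As a sanity check on the ((2, 4),) strategy above: a shard strategy divides each tensor dimension by its split factor, so on 8 devices each device holds an [8, 4] slice of the [16, 16] input. A small sketch:

def local_slice_shape(full_shape, strategy_dims):
    # Each dimension is divided evenly by its split factor.
    return [d // s for d, s in zip(full_shape, strategy_dims)]

print(local_slice_shape([16, 16], (2, 4)))  # [8, 4]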