!32341 add parallel ops about Invert CheckValid PopulationCount
Merge pull request !32341 from yangzhenzhang/add-parallel-operators
commit acce047dfe
@@ -240,6 +240,9 @@ REGISTER(XdivyInfo);
 REGISTER(XlogyInfo);
 REGISTER(InplaceAddInfo);
 REGISTER(InplaceSubInfo);
+REGISTER(CheckValidInfo);       // has no bprop
+REGISTER(InvertInfo);           // has no bprop
+REGISTER(PopulationCountInfo);  // has no bprop
 REGISTER(CdistInfo);
 REGISTER(L2LossInfo);
 REGISTER(LerpInfo);
@@ -391,6 +391,24 @@ class ErfinvInfo : public ActivationOther {
       : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfinvCost>()) {}
   ~ErfinvInfo() override = default;
 };
+
+// Invert has no backward (bprop) defined.
+class InvertInfo : public ActivationOther {
+ public:
+  InvertInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
+  ~InvertInfo() = default;
+};
+
+// PopulationCount has no backward (bprop) defined.
+class PopulationCountInfo : public ActivationOther {
+ public:
+  PopulationCountInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                      const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {}
+  ~PopulationCountInfo() = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
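Both new classes reuse ActivationOther with ReLUCost: like ReLU, Invert and PopulationCount are single-input elementwise operators, so the existing cost model and sharding logic carry over unchanged. A minimal numpy sketch (illustrative only, not MindSpore internals) of why any sharding of an elementwise op's input is valid:

import numpy as np

x = np.arange(16, dtype=np.int16).reshape(4, 4)
shards = np.split(x, 2, axis=0)  # "place" the two halves on two devices

# Applying the op shard-by-shard and concatenating equals applying it to
# the full tensor, so every split of the input is a legal strategy.
sharded = np.concatenate([np.invert(s) for s in shards], axis=0)
assert np.array_equal(sharded, np.invert(x))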
@@ -74,19 +74,10 @@ Status BatchParallelInfo::InferTensorMap() {
     return SUCCESS;
   }
-
-  if (strategy[0].empty()) {
-    MS_LOG(INFO) << name_ << ": the first element of strategy is empty";
-    return FAILED;
-  }
-
-  if (strategy[0][0] != stage_device_size_) {
-    MS_LOG(ERROR) << name_ << ": It is not a valid data parallel strategy.";
-    return FAILED;
-  }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy[i][j] == stage_device_size_ && j == 0) {
+      if (strategy[i][j] > 1 && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
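The simplified loop maps a tensor dimension onto the device matrix only when it is the first (batch) dimension and is actually split; every other dimension gets MAP_NONE. A hedged Python mirror of that logic (hypothetical helper; MAP_NONE assumed to be -1, as elsewhere in the parallel module):

MAP_NONE = -1

def infer_tensor_map(input_strategy):
    # input_strategy is one tuple of the op strategy, e.g. (8, 1)
    return [0 if j == 0 and s > 1 else MAP_NONE
            for j, s in enumerate(input_strategy)]

assert infer_tensor_map((8, 1)) == [0, MAP_NONE]         # batch dim split
assert infer_tensor_map((1, 1)) == [MAP_NONE, MAP_NONE]  # nothing split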
@@ -133,6 +124,8 @@ Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_id) {
   StrategyPtr sp;
   Strategys strategy;
+  ComputeBatchSplitFlagList();
+
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape temp(inputs_shape_[i].size(), 1);
     if (split_flag_list_[i]) {
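The generator now refreshes the per-input split flags before building the strategy. A hedged sketch of the resulting behavior (assumption: the elided branch in the hunk sets the first element of temp to the stage device number when the input's split flag is set):

def generate_batch_strategy(input_shapes, split_flags, stage_device_num):
    strategy = []
    for shape, split in zip(input_shapes, split_flags):
        dims = [1] * len(shape)          # Shape temp(inputs_shape_[i].size(), 1)
        if split:
            dims[0] = stage_device_num   # split only the batch dimension
        strategy.append(tuple(dims))
    return tuple(strategy)

# CheckValid sets split_flag_list_ = [true, false], giving ((8, 1), (1,)) on 8 devices.
assert generate_batch_strategy([(16, 4), (3,)], [True, False], 8) == ((8, 1), (1,))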
@@ -146,12 +139,6 @@ std::vector<StrategyPtr> BatchParallelInfo::GenerateOpStrategies(int64_t stage_i
   return sp_vector;
 }
-
-void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
-  for (size_t i = 0; i < inputs_shape_.size(); i++) {
-    split_flag_list_[i] = true;
-  }
-}

 Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
@@ -181,5 +168,38 @@ void BatchParallelInfo::ReplaceNodeInputOrAttrs() {
   AnfNodePtr val = NewValueNode(replace_shape);
   (void)manager->Replace(cnode->input(1), val);
 }
+
+void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
+  for (size_t i = 0; i < inputs_shape_.size(); i++) {
+    split_flag_list_[i] = true;
+  }
+}
+
+Status CheckValidInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
+    return FAILED;
+  }
+
+  Strategys stra = strategy->GetInputDim();
+  if (stra[0][1] != 1) {
+    MS_LOG(ERROR) << name_ << ": The second dimension of the first input can not be split, but got " << stra[0][1];
+    return FAILED;
+  }
+
+  if (stra[1][0] != 1) {
+    MS_LOG(ERROR) << name_ << ": The second input can not be split, but got " << stra[1][0];
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status CheckValidInfo::InferDevMatrixShape() {
+  Strategys stra = strategy_->GetInputDim();
+  dev_matrix_shape_.push_back(stra[0][0]);
+  return SUCCESS;
+}
+
+void CheckValidInfo::ReComputeBatchSplitFlagList() { split_flag_list_[0] = true; }
 }  // namespace parallel
 }  // namespace mindspore
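A minimal sketch (hypothetical helper, not MindSpore API) of the constraint that CheckValidInfo::CheckStrategy enforces above: for CheckValid(bboxes, img_metas), only the batch dimension of bboxes (shape (N, 4)) may be split; the coordinate dimension and the img_metas input (shape (3,)) may not.

def check_valid_strategy_ok(strategy):
    bbox_stra, img_stra = strategy
    if bbox_stra[1] != 1:    # second dimension of bboxes must stay whole
        return False
    if img_stra[0] != 1:     # img_metas must stay whole
        return False
    return True

assert check_valid_strategy_ok(((8, 1), (1,)))       # pure data parallel: accepted
assert not check_valid_strategy_ok(((2, 2), (1,)))   # splits bbox dim 1: rejected
assert not check_valid_strategy_ok(((2, 1), (4,)))   # splits img_metas: rejected

These three cases correspond one-to-one with the test file added later in this commit.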
@@ -31,10 +31,10 @@ class BatchParallelInfo : public OperatorInfo {
  public:
   BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs, OperatorCostPtr cost)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
   BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}

   ~BatchParallelInfo() override = default;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
@@ -50,7 +50,6 @@ class BatchParallelInfo : public OperatorInfo {
   Status InferAsLossDivisor() override;

  private:
-  int64_t dev_num_ = 1;
   bool need_replace_input_ = false;
   Shape replace_shape_;
 };
@@ -64,6 +63,20 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
+
+// For the CheckValid operator, only the first dimension of the first input can be split.
+class CheckValidInfo : public BatchParallelInfo {
+ public:
+  CheckValidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs)
+      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
+  ~CheckValidInfo() override = default;
+  void ReComputeBatchSplitFlagList() override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferDevMatrixShape() override;
+};
 }  // namespace parallel
 }  // namespace mindspore
@@ -469,6 +469,9 @@ constexpr char SELU[] = "SeLU";
 constexpr char SOFT_SHRINK[] = "SoftShrink";
 constexpr char XLOGY[] = "Xlogy";
 constexpr char XDIVY[] = "Xdivy";
+constexpr char CHECK_VALID[] = "CheckValid";
+constexpr char INVERT[] = "Invert";
+constexpr char POPULATION_COUNT[] = "PopulationCount";
 constexpr char BITWISE_AND[] = "BitwiseAnd";
 constexpr char BITWISE_OR[] = "BitwiseOr";
 constexpr char BITWISE_XOR[] = "BitwiseXor";
@@ -179,7 +179,8 @@ bool IsSplittableOperator(const std::string &op_name) {
     RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
     ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
     BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
-    CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE};
+    CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
+    POPULATION_COUNT};
   // clang-format on

   auto iter = splittable_op.find(op_name);
@@ -0,0 +1,139 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.train import Model


class Net(Cell):
    def __init__(self, weight, strategy):
        super().__init__()
        self.check_valid = P.CheckValid().shard(strategy)
        self.mul = P.Mul()
        cast_strategy = None
        if strategy:
            cast_strategy = (strategy[0],)
        self.cast = P.Cast().shard(cast_strategy)
        self.relu = P.ReLU()
        self.weight = Parameter(weight, "w1")

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.check_valid(out, b)
        out = self.cast(out, ms.float32)
        out = self.relu(out)
        return out


_x = Tensor(np.ones([16, 4]), dtype=ms.float32)
_w = Tensor(np.ones([16, 4]), dtype=ms.float32)
_b = Tensor(np.ones([3]), dtype=ms.float32)


def compile_net(net):
    # Compiling the predict graph is sufficient to validate the shard strategy.
    model = Model(net)
    model.predict(_x, _b)
    context.reset_auto_parallel_context()


def test_check_valid_data_parallel():
    """
    Feature: test check valid data parallel
    Description: shard the batch dimension of the first input across all 8 devices
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((8, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)


def test_check_valid_repeated_calc():
    """
    Feature: test check valid repeated calculation
    Description: split the batch dimension by 2 on 8 devices, so each shard is computed repeatedly
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)


def test_check_valid_no_shard():
    """
    Feature: test check valid no shard
    Description: a strategy of all ones leaves both inputs unsplit
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((1, 1), (1,))
    net = Net(_w, strategy)
    compile_net(net)


def test_check_valid_strategy_none():
    """
    Feature: test check valid strategy none
    Description: generate the batch parallel strategy when none is given
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(_w, strategy)
    compile_net(net)


def test_check_valid_auto_parallel():
    """
    Feature: test check valid auto parallel
    Description: let the strategy-searching algorithm pick the strategy
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(_w, strategy)
    compile_net(net)


def test_check_valid_shard_img():
    """
    Feature: test check valid shard img
    Description: splitting the img_metas input is not allowed
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 1), (4,))
    net = Net(_w, strategy)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_check_valid_shard_bbox_second_dimension():
    """
    Feature: test check valid shard bbox second dimension
    Description: splitting the second dimension of the bboxes input is not allowed
    Expectation: compile failed
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 2), (1,))
    net = Net(_w, strategy)
    with pytest.raises(RuntimeError):
        compile_net(net)
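One note on test_check_valid_repeated_calc above: with 8 devices but a batch split of only 2, the remaining device-matrix capacity repeats the computation. A rough arithmetic sketch (illustrative only, not framework code):

device_num = 8
batch_split = 2                       # from strategy ((2, 1), (1,))
repeat = device_num // batch_split    # each shard is computed by 4 device groups
assert repeat == 4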
@@ -0,0 +1,72 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.train import Model


class Net(Cell):
    def __init__(self, strategy):
        super().__init__()
        self.invert = P.Invert().shard(strategy)
        self.pop = P.PopulationCount().shard(strategy)
        self.cast = P.Cast().shard(strategy)
        self.relu = P.ReLU().shard(strategy)

    def construct(self, x, b):
        # b is unused in the graph; it only matches compile_net's
        # two-argument predict call.
        out = self.invert(x)
        out = self.pop(out)
        out = self.cast(out, ms.float32)
        out = self.relu(out)
        return out


_x = Tensor(np.ones([16, 16]), dtype=ms.int16)
_w = Tensor(np.ones([16, 16]), dtype=ms.int16)
_b = Tensor(np.ones([16, 16]), dtype=ms.int16)


def compile_net(net):
    model = Model(net)
    model.predict(_x, _b)
    context.reset_auto_parallel_context()


def test_invert_population_count_semi():
    """
    Feature: semi auto parallel
    Description: shard each elementwise op with strategy ((2, 4),)
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    strategy = ((2, 4),)
    net = Net(strategy)
    compile_net(net)


def test_invert_population_count_auto():
    """
    Feature: auto parallel
    Description: let the strategy-searching algorithm pick the strategies
    Expectation: compile success
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
    strategy = None
    net = Net(strategy)
    compile_net(net)
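For reference, a small numpy sketch of what the network above computes before the cast (assuming P.Invert is elementwise bitwise NOT and P.PopulationCount counts set bits per element):

import numpy as np

x = np.ones([4, 4], dtype=np.int16)
inv = np.invert(x)                                    # bitwise NOT: 1 -> -2
popcount = np.vectorize(lambda v: bin(v & 0xFFFF).count("1"))
print(popcount(inv))                                  # 15 set bits per element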