!45083 Fix some bugs in element-wise and math parallel operators.

Merge pull request !45083 from liuluobin/parallel_ops
i-robot 2022-11-07 02:00:44 +00:00 committed by Gitee
commit 179341710f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
22 changed files with 226 additions and 177 deletions

View File

@ -667,6 +667,8 @@ using BitwiseOrCost = SubCost;
using BitwiseXorCost = SubCost;
using AddNCost = SubCost;
using InplaceAddCost = SubCost;
using InplaceSubCost = InplaceAddCost;
using InplaceUpdateCost = InplaceAddCost;
using MaskedFillCost = SubCost;
class MulCost : public SubCost {

View File

@ -252,7 +252,7 @@ Status CumOpBase::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
auto axis_split = input_strategy[LongToSize(axis_)];
if (axis_split > 1) {
if (axis_split != NO_SPLIT_STRATEGY) {
MS_LOG(ERROR) << "Currently, CumSum does not support the sharding strategies which split the axis.";
return FAILED;
}

View File

@ -31,6 +31,8 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
Strategies strategies = strategy->GetInputDim();
for (size_t i = 1; i < strategies.size(); ++i) {
if (strategies[i] != strategies[0]) {
MS_LOG(ERROR) << name_ << ": The strategy for each input must be equal to strategies[0]: " << strategies[0]
<< ", but got strategies[" << i << "]: " << strategies[i];
return FAILED;
}
}

View File

@ -25,6 +25,8 @@
namespace mindspore {
namespace parallel {
int64_t GammaInfo::SEED_NUM = 1;
Status GammaInfo::InferAttrs() {
if (infer_attrs_completed_) {
return SUCCESS;
@ -136,12 +138,17 @@ void GammaInfo::ReplaceNodeInputOrAttrs() {
// Update seed according to rank_id
int64_t rank_id = g_device_manager->rank_index_in_stage();
int64_t seed_bias;
int64_t seed_bias = 0;
// When seed and seed2 are both 0, ensure that the 0th card in each group has the same result
if (seed_ == 0 && seed2_ == 0) {
seed_bias += SEED_NUM;
++SEED_NUM;
}
if (repeated_num_in_dev_matrix_right_) {
seed_bias = rank_id / repeated_calc_num_;
seed_bias += rank_id / repeated_calc_num_;
} else {
int64_t device_num = stage_device_size_;
seed_bias = rank_id % (device_num / repeated_calc_num_);
seed_bias += rank_id % (device_num / repeated_calc_num_);
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
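After this change seed_bias starts at 0 and accumulates two contributions: a SEED_NUM offset taken when both seeds are 0, plus a rank-dependent offset derived from the device matrix. A minimal standalone C++ sketch of the resulting computation (UniformRealInfo::ReplaceNodeInputOrAttrs below receives the same fix; seed_num stands in for the static SEED_NUM counter):

#include <cstdint>

// Sketch only; names follow the diff above.
int64_t ComputeSeedBias(int64_t rank_id, int64_t seed, int64_t seed2, int64_t repeated_calc_num,
                        int64_t stage_device_size, bool repeated_num_in_dev_matrix_right,
                        int64_t *seed_num) {
  int64_t seed_bias = 0;
  if (seed == 0 && seed2 == 0) {
    // Both seeds are 0: add the shared counter so every replaced node gets a distinct,
    // deterministic base while the 0th card of each group still agrees.
    seed_bias += *seed_num;
    ++(*seed_num);
  }
  if (repeated_num_in_dev_matrix_right) {
    seed_bias += rank_id / repeated_calc_num;
  } else {
    seed_bias += rank_id % (stage_device_size / repeated_calc_num);
  }
  return seed_bias;
}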
@ -160,7 +167,7 @@ void GammaInfo::ResetInputsShape() {
}
bool GammaInfo::IsNotSplittableStrategy(const Dimensions &strategy) const {
return std::all_of(strategy.cbegin(), strategy.cend(), [](int64_t val) { return val == 1; });
return std::all_of(strategy.cbegin(), strategy.cend(), [](int64_t val) { return val == NO_SPLIT_STRATEGY; });
}
void GammaInfo::ReComputeBatchSplitFlagList() {

View File

@ -57,6 +57,7 @@ class GammaInfo : public OperatorInfo {
int64_t seed_ = 0;
int64_t seed2_ = 0;
static int64_t SEED_NUM;
};
} // namespace parallel
} // namespace mindspore

View File

@ -18,12 +18,12 @@
#include <algorithm>
#include <functional>
#include "frontend/parallel/ops_info/inplace_add_info.h"
#include "frontend/parallel/ops_info/inplace_op_info.h"
#include "frontend/parallel/dynamic_creator.h"
namespace mindspore {
namespace parallel {
Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
Status InplaceOpBase::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
@ -44,7 +44,7 @@ Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
Status InplaceAddInfo::InferDevMatrixShape() {
Status InplaceOpBase::InferDevMatrixShape() {
dev_matrix_shape_.clear();
auto strategies = strategy_->GetInputDim();
@ -55,7 +55,7 @@ Status InplaceAddInfo::InferDevMatrixShape() {
return SUCCESS;
}
Status InplaceAddInfo::InferTensorMap() {
Status InplaceOpBase::InferTensorMap() {
inputs_tensor_map_.clear();
outputs_tensor_map_.clear();
@ -77,7 +77,7 @@ Status InplaceAddInfo::InferTensorMap() {
return SUCCESS;
}
std::vector<StrategyPtr> InplaceAddInfo::GenerateOpStrategies(int64_t stage_id) {
std::vector<StrategyPtr> InplaceOpBase::GenerateOpStrategies(int64_t stage_id) {
Shapes splittable_inputs = inputs_shape_;
for (size_t i = 0; i < splittable_inputs.size(); ++i) {
for (size_t j = 0; j < splittable_inputs[i].size(); ++j) {
@ -91,12 +91,13 @@ std::vector<StrategyPtr> InplaceAddInfo::GenerateOpStrategies(int64_t stage_id)
return sp_vector;
}
void InplaceAddInfo::ReComputeBatchSplitFlagList() {
void InplaceOpBase::ReComputeBatchSplitFlagList() {
split_flag_list_[0] = false;
split_flag_list_[1] = false;
}
REGISTER(InplaceAddInfo);
REGISTER(InplaceSubInfo);
REGISTER(InplaceUpdateInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
#include <memory>
#include <string>
@ -27,12 +27,12 @@
namespace mindspore {
namespace parallel {
class InplaceAddInfo : public OperatorInfo {
class InplaceOpBase : public OperatorInfo {
public:
InplaceAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceAddCost>()) {}
~InplaceAddInfo() override = default;
InplaceOpBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, const OperatorCostPtr &cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
~InplaceOpBase() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
void ReComputeBatchSplitFlagList() override;
@ -48,14 +48,30 @@ class InplaceAddInfo : public OperatorInfo {
Status InferForwardCommunication() override { return SUCCESS; }
};
class InplaceSubInfo : public InplaceAddInfo {
class InplaceAddInfo : public InplaceOpBase {
public:
InplaceAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceAddCost>()) {}
~InplaceAddInfo() override = default;
};
class InplaceSubInfo : public InplaceOpBase {
public:
InplaceSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: InplaceAddInfo(name, inputs_shape, outputs_shape, attrs) {}
~InplaceSubInfo() = default;
: InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceSubCost>()) {}
~InplaceSubInfo() override = default;
};
class InplaceUpdateInfo : public InplaceOpBase {
public:
InplaceUpdateInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceUpdateCost>()) {}
~InplaceUpdateInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
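With InplaceOpBase in place, supporting a further in-place operator only takes a cost alias, a thin derived class, and a REGISTER call. A hedged sketch for a hypothetical InplaceDivInfo (not part of this commit):

// Hypothetical example of the pattern; InplaceDivInfo does not exist in the codebase.
using InplaceDivCost = InplaceAddCost;  // cost alias, placed next to InplaceSubCost/InplaceUpdateCost

class InplaceDivInfo : public InplaceOpBase {
 public:
  InplaceDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceDivCost>()) {}
  ~InplaceDivInfo() override = default;
};

// and in inplace_op_info.cc:
// REGISTER(InplaceDivInfo);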

View File

@ -41,6 +41,7 @@ Status KLDivLossInfo::CheckStrategy(const StrategyPtr &strategy) {
<< StrategyToString(strategies);
return FAILED;
}
batch_split_num_ = strategies[0][0];
return SUCCESS;
}
@ -128,6 +129,11 @@ Status KLDivLossInfo::InferForwardCommunication() {
forward_op_ = CreateReduceMeanForwardOp(group_list_, element_type);
} else {
(void)forward_op_.emplace_back(CreateAllReduceOp(REDUCE_OP_SUM, group_list_[0].name()));
if (reduction_ == BATCH_MEAN) {
// Divided by the number of devices in the Batch dimension
auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
(void)forward_op_.emplace_back(CreateDivOpWithType(SizeToFloat(batch_split_num_), element_type));
}
}
MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list_[0].name();
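Why a single division by batch_split_num_ restores "batchmean": each device applies the local KLDivLoss to a shard holding batch / batch_split_num_ samples, so its partial result already carries that smaller denominator. A worked example, assuming a ((4, 2), (4, 2)) strategy on 8 devices:

  per-device partial     = shard_sum / (batch / 4)
  after AllReduce(sum)   = total_sum / (batch / 4)
  after RealDiv by 4     = total_sum / batch    (the unsharded batchmean)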

View File

@ -48,6 +48,7 @@ class KLDivLossInfo : public OperatorInfo {
private:
Status InferGroup();
std::size_t batch_split_num_;
std::string reduction_;
std::vector<Group> group_list_;
};

View File

@ -38,16 +38,21 @@ Status LinSpaceInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
auto split_num = strategies[0][0];
if (output_size_ % split_num != 0) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
<< output_size_ << " cannot be divisible by strategy value " << split_num;
split_num_ = strategies[0][0];
if (split_num_ <= 0) {
MS_LOG(ERROR) << name_ << ": Each element in strategy must be a positive integer, but got "
<< StrategyToString(strategies);
return FAILED;
}
if (stage_device_size_ % split_num != 0) {
if (output_size_ % split_num_ != 0) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
<< output_size_ << " cannot be divisible by strategy value " << split_num_;
return FAILED;
}
if (stage_device_size_ % split_num_ != 0) {
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies)
<< ", the device size in this stage is " << stage_device_size_
<< " cannot be divisible by the strategy value " << split_num;
<< " cannot be divisible by the strategy value " << split_num_;
return FAILED;
}
return SUCCESS;
@ -97,31 +102,15 @@ std::shared_ptr<Strategies> LinSpaceInfo::GenerateBatchStrategies() {
return std::make_shared<Strategies>(strategies);
}
int64_t LinSpaceInfo::GetSplitNum() {
auto strategies = strategy_->GetInputDim();
auto split_strategy = strategies.at(0);
auto split_num = split_strategy.at(0);
return split_num;
}
ReplaceGraphPtr LinSpaceInfo::replace_graph(const CNodePtr &cnode) {
auto split_num = GetSplitNum();
if (split_num > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
if (split_num_ > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto split_num = GetSplitNum();
if (split_num == 0) {
MS_LOG(ERROR) << name_ << ": split num is 1, no need to replace graph";
return FAILED;
}
if (output_size_ % split_num != 0) {
return FAILED;
}
if (split_num == 1) {
if (split_num_ == 1) {
MS_LOG(INFO) << name_ << ": split num is 1, no need to replace graph";
return SUCCESS;
}
@ -132,7 +121,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return FAILED;
}
int64_t slice_output_size = output_size_ / split_num;
int64_t slice_output_size = output_size_ / split_num_;
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), sub});
AnfNodePtr interval = nullptr;
@ -146,6 +135,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
// new_start = start + slice_id * slice_output_size * interval
// new_end = new_start + (slice_output_size - 1) * interval
// new_x = slice_output_size
InferSliceId();
auto offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_id_ * slice_output_size), dtype});
auto start_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, offset_size});
auto new_start = gen_g.PushBack({gen_g.NewOpInst(ADD), gen_g.virtual_input_node(), start_offset});
@ -161,7 +151,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS;
}
Status LinSpaceInfo::InferSliceId() {
void LinSpaceInfo::InferSliceId() {
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->rank_index_in_stage();
slice_id_ = rank;
@ -172,7 +162,6 @@ Status LinSpaceInfo::InferSliceId() {
slice_id_ %= dev_matrix_shape_.front();
}
}
return SUCCESS;
}
REGISTER(LinSpaceInfo);
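The replace graph realizes the arithmetic in the comments above on each device; a minimal standalone C++ sketch of the values it produces, assuming the usual LinSpace step interval = (end - start) / (output_size - 1), which this hunk does not spell out:

#include <cstdint>

// Sketch only: the per-device (start, end, num) handed to the local LinSpace.
struct LocalLinSpace { double start; double end; int64_t num; };

LocalLinSpace SliceArgs(double start, double end, int64_t output_size, int64_t split_num, int64_t slice_id) {
  double interval = (end - start) / static_cast<double>(output_size - 1);  // assumed LinSpace step
  int64_t slice_output_size = output_size / split_num;  // divisibility is checked in CheckStrategy
  double new_start = start + static_cast<double>(slice_id * slice_output_size) * interval;
  double new_end = new_start + static_cast<double>(slice_output_size - 1) * interval;
  return {new_start, new_end, slice_output_size};
}

For example, LinSpace(0, 7, 8) split 4 ways gives the device with slice_id = 2 the local call LinSpace(4, 5, 2).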

View File

@ -49,11 +49,12 @@ class LinSpaceInfo : public OperatorInfo {
Status InferForwardCommunication() override { return SUCCESS; }
private:
Status InferSliceId();
void InferSliceId();
Status ComputeReplaceGraph(const CNodePtr &cnode);
int64_t GetSplitNum();
int64_t slice_id_ = 0;
int64_t split_num_ = 0;
int64_t output_size_ = 0;
};
} // namespace parallel

View File

@ -362,10 +362,10 @@ Operator CreateDivOp(float scale) {
OperatorName operator_name = REAL_DIV;
OperatorAttrs operator_attrs;
OperatorParams operator_param;
size_t parameter_pos = 2;
constexpr size_t parameter_pos = 2;
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scale);
ValuePtr scale_value = MakeValue(tensor_ptr);
operator_param.push_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
(void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_arg);
@ -2322,6 +2322,18 @@ AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple
return tensor_node->cast<AnfNodePtr>();
}
Operator CreateDivOpWithType(float divisor, const TypePtr &dtype) {
OperatorName operator1_name = REAL_DIV;
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
ValuePtr op1_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("divisor", op1_param_value);
OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
OperatorAttrs operator1_attrs;
OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
Operator div_op = std::make_pair(operator1_name, operator1_args);
return div_op;
}
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
// Create AllReduceSum op
Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
@ -2329,21 +2341,13 @@ ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, con
MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
// Create RealDiv op
OperatorName operator1_name = REAL_DIV;
std::vector<Device> device_list = forward_group[0].GetDevicesList();
auto divisor = static_cast<float>(device_list.size());
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
ValuePtr op1_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("divisor", op1_param_value);
OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
OperatorAttrs operator1_attrs;
OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
Operator op1 = std::make_pair(operator1_name, operator1_args);
ForwardOp forward_op = {op0, op1};
auto divisor = SizeToFloat(device_list.size());
Operator op1 = CreateDivOpWithType(divisor, dtype);
std::string dtype_name = dtype->ToString();
MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
return forward_op;
return {op0, op1};
}
} // namespace parallel
} // namespace mindspore
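CreateDivOpWithType factors out the RealDiv construction that CreateReduceMeanForwardOp previously built inline, and it is the same helper KLDivLossInfo uses above for the batchmean division. A sketch of composing an all-reduce mean from the two helpers (group and dtype assumed to be in scope):

// Sketch: AllReduce-sum followed by division by the group size yields the mean.
Operator sum_op = CreateAllReduceOp(REDUCE_OP_SUM, group.name());
Operator div_op = CreateDivOpWithType(SizeToFloat(group.GetDevicesList().size()), dtype);
ForwardOp mean_op = {sum_op, div_op};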

View File

@ -392,6 +392,7 @@ AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
Operator CreateDivOpWithType(float divisor, const TypePtr &dtype);
} // namespace parallel
} // namespace mindspore

View File

@ -270,6 +270,7 @@ constexpr char ROI_END_MODE[] = "roi_end_mode";
constexpr char REDUCTION[] = "reduction";
constexpr char MEAN[] = "mean";
constexpr char ATTR_NONE[] = "none";
constexpr char BATCH_MEAN[] = "batchmean";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

View File

@ -29,6 +29,7 @@ Status RandomChoiceWithMaskInfo::GetAttrs() {
if (attrs_.find(SEED2) != attrs_.end()) {
seed2_ = GetValue<int64_t>(attrs_[SEED2]);
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}
@ -85,39 +86,15 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() {
}
void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() {
// seed and seed2 cannot both be 0.
if (seed_ != 0 || seed2_ != 0) {
return;
}
if (cnode_->HasAttr(SEED)) {
cnode_->EraseAttr(SEED);
}
if (cnode_->HasAttr(SEED2)) {
cnode_->EraseAttr(SEED2);
}
cnode_->AddAttr(SEED, MakeValue(SEED_NUM));
cnode_->AddAttr(SEED2, MakeValue(SEED_NUM));
++SEED_NUM;
}
void RandomChoiceWithMaskInfo::CheckGPUBackend() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend != kGPUDevice) {
MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << " , only support on GPU backend now.";
}
}
Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
CheckGPUBackend();
return OperatorInfo::Init(in_strategy, out_strategy);
}
Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) {
CheckGPUBackend();
return OperatorInfo::InitForCostModel(strategy, out_strategy);
}
void RandomChoiceWithMaskInfo::ReComputeBatchSplitFlagList() { split_flag_list_[0] = false; }
REGISTER(RandomChoiceWithMaskInfo);

View File

@ -35,8 +35,6 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
seed_(0),
seed2_(0) {}
~RandomChoiceWithMaskInfo() = default;
Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
void ReplaceNodeInputOrAttrs() override;
@ -51,8 +49,6 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
Status InferAsLossDivisor() override;
private:
void CheckGPUBackend();
int64_t seed_;
int64_t seed2_;
static int64_t SEED_NUM;

View File

@ -29,6 +29,8 @@
namespace mindspore {
namespace parallel {
int64_t UniformRealInfo::SEED_NUM = 1;
Status UniformRealInfo::InferAttrs() {
if (infer_attrs_completed_) {
return SUCCESS;
@ -53,6 +55,7 @@ Status UniformRealInfo::GetAttrs() {
MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, but got " << seed2_;
return FAILED;
}
infer_strategy_mode_ = INDEPENDENT_MODE;
return SUCCESS;
}
@ -139,12 +142,17 @@ void UniformRealInfo::ReplaceNodeInputOrAttrs() {
// Update seed according to rank_id
int64_t rank_id = g_device_manager->rank_index_in_stage();
int64_t seed_bias;
int64_t seed_bias = 0;
// When seed and seed2 are both 0, ensure that the 0th card in each group has the same result
if (seed_ == 0 && seed2_ == 0) {
seed_bias += SEED_NUM;
++SEED_NUM;
}
if (repeated_num_in_dev_matrix_right_) {
seed_bias = rank_id / repeated_calc_num_;
seed_bias += rank_id / repeated_calc_num_;
} else {
int64_t device_num = stage_device_size_;
seed_bias = rank_id % (device_num / repeated_calc_num_);
seed_bias += rank_id % (device_num / repeated_calc_num_);
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));

View File

@ -54,6 +54,7 @@ class UniformRealInfo : public OperatorInfo {
int64_t seed_ = 0;
int64_t seed2_ = 0;
static int64_t SEED_NUM;
};
} // namespace parallel
} // namespace mindspore

View File

@ -731,11 +731,10 @@ bool IsSplittableOperator(const std::string &op_name) {
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,
RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD, POPULATION_COUNT,
IDENTITY};
BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, INPLACE_UPDATE,
L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
POPULATION_COUNT, IDENTITY};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -0,0 +1,87 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator, compile_net
def setup_function():
context.set_auto_parallel_context(dataset_strategy="full_batch")
SEED_ = 1
SEED2_ = 1
class Net(Cell):
def __init__(self, seed, seed2, strategy=None):
super(Net, self).__init__()
self.uniform_real = P.Gamma(seed, seed2).shard(strategy)
def construct(self, shape, alpha, beta):
out = self.uniform_real(shape, alpha, beta)
return out
def test_uniform_real_auto_parallel():
"""
Features: test UniformReal auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(SEED_, SEED2_)
shape = (4, 4, 4)
alpha = Tensor(np.array([1.0]), ms.float32)
beta = Tensor(np.array([1.0]), ms.float32)
compile_net(net, shape, alpha, beta)
def test_uniform_real_data_parallel():
"""
Features: test UniformReal data parallel
Description: data parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
net = Net(SEED_, SEED2_)
shape = (8, 8)
alpha = Tensor(np.array([1.0]), ms.float32)
beta = Tensor(np.array([1.0]), ms.float32)
phase = compile_net(net, shape, alpha, beta)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
def test_uniform_real_model_parallel():
"""
Features: test UniformReal model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
shape = (8, 8)
alpha = Tensor(np.array([1.0]), ms.float32)
beta = Tensor(np.array([1.0]), ms.float32)
strategy = ((2, 2), (1,), (1,))
net = Net(SEED_, SEED2_, strategy)
phase = compile_net(net, shape, alpha, beta)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})

View File

@ -48,42 +48,55 @@ class InplaceSubNet(Cell):
return self.inplace_sub(x, input_v)
def test_inplace_add_auto_parallel():
class InplaceUpdateNet(Cell):
def __init__(self, indices, strategy=None):
super(InplaceUpdateNet, self).__init__()
self.inplace_update = P.InplaceUpdate(indices).shard(strategy)
def construct(self, x, input_v):
return self.inplace_update(x, input_v)
@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
def test_inplace_add_auto_parallel(network):
"""
Feature: test InplaceAdd auto parallel
Feature: test InplaceOp auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = InplaceAddNet(indices_)
net = network(indices_)
compile_net(net, x_, input_v_)
def test_inplace_add_model_parallel():
@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
def test_inplace_op_model_parallel(network):
"""
Feature: test InplaceAdd model parallel
Feature: test InplaceOp model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 4, 2), (1, 4, 2))
net = InplaceAddNet(indices_, strategy)
net = network(indices_, strategy)
compile_net(net, x_, input_v_)
def test_inplace_add_model_parallel_with_repeated_cal():
@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
def test_inplace_add_model_parallel_with_repeated_cal(network):
"""
Feature: test InplaceAdd model parallel with repeated calculation
Feature: test InplaceOp model parallel with repeated calculation
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 2, 2), (1, 2, 2))
net = InplaceAddNet(indices_, strategy)
net = network(indices_, strategy)
compile_net(net, x_, input_v_)
def test_inplace_add_strategy_error():
@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
def test_inplace_add_strategy_error(network):
"""
Feature: test invalid strategy for InplaceAdd
Description: illegal strategy
@ -91,54 +104,6 @@ def test_inplace_add_strategy_error():
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 4, 2), (1, 2, 4))
net = InplaceAddNet(indices_, strategy)
with pytest.raises(RuntimeError):
compile_net(net, x_, input_v_)
def test_inplace_sub_auto_parallel():
"""
Feature: test InplaceSub auto parallel
Description: auto parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = InplaceSubNet(indices_)
compile_net(net, x_, input_v_)
def test_inplace_sub_model_parallel():
"""
Feature: test InplaceSub model parallel
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 4, 2), (1, 4, 2))
net = InplaceSubNet(indices_, strategy)
compile_net(net, x_, input_v_)
def test_inplace_sub_model_parallel_with_repeated_cal():
"""
Feature: test InplaceSub model parallel with repeated calculation
Description: model parallel
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 2, 2), (1, 2, 2))
net = InplaceSubNet(indices_, strategy)
compile_net(net, x_, input_v_)
def test_inplace_sub_strategy_error():
"""
Feature: test invalid strategy for InplaceSub
Description: illegal strategy
Expectation: raise RuntimeError
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((1, 4, 2), (1, 2, 4))
net = InplaceSubNet(indices_, strategy)
net = network(indices_, strategy)
with pytest.raises(RuntimeError):
compile_net(net, x_, input_v_)

View File

@ -54,7 +54,6 @@ def test_auto_parallel_random_choice_with_mask():
Description: auto parallel
Expectation: compile success
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(dataset_strategy="full_batch")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net()
@ -67,24 +66,9 @@ def test_random_choice_with_mask_wrong_strategy():
Description: illegal strategy
Expectation: raise RuntimeError
"""
context.set_context(device_target="GPU")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy = ((8, 1),)
net = Net(strategy)
with pytest.raises(RuntimeError):
compile_net(net, _input_x)
context.reset_auto_parallel_context()
def test_random_choice_with_mask_not_gpu():
"""
Feature: RandomChoiceWithMask
Description: not compile with gpu backend
Expectation: raise RuntimeError
"""
context.set_context(device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net()
with pytest.raises(RuntimeError):
compile_net(net, _input_x)
context.reset_auto_parallel_context()