!45083 Fix some bugs in element-wise and math parallel operators.
Merge pull request !45083 from liuluobin/parallel_ops
Commit: 179341710f
@@ -667,6 +667,8 @@ using BitwiseOrCost = SubCost;
 using BitwiseXorCost = SubCost;
 using AddNCost = SubCost;
 using InplaceAddCost = SubCost;
+using InplaceSubCost = InplaceAddCost;
+using InplaceUpdateCost = InplaceAddCost;
 using MaskedFillCost = SubCost;
 
 class MulCost : public SubCost {
@@ -252,7 +252,7 @@ Status CumOpBase::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
   auto axis_split = input_strategy[LongToSize(axis_)];
-  if (axis_split > 1) {
+  if (axis_split != NO_SPLIT_STRATEGY) {
     MS_LOG(ERROR) << "Currently, CumSum does not support the sharding strategies which splits axis.";
     return FAILED;
   }
@@ -31,6 +31,8 @@ Status AddNInfo::CheckStrategy(const StrategyPtr &strategy) {
   Strategies strategies = strategy->GetInputDim();
   for (size_t i = 1; i < strategies.size(); ++i) {
     if (strategies[i] != strategies[0]) {
+      MS_LOG(ERROR) << name_ << ": The strategy for each input must be equal to strategies[0]: " << strategies[0]
+                    << ", but got strategies[" << i << "]: " << strategies[i];
       return FAILED;
     }
   }
@@ -25,6 +25,8 @@
 
 namespace mindspore {
 namespace parallel {
+int64_t GammaInfo::SEED_NUM = 1;
+
 Status GammaInfo::InferAttrs() {
   if (infer_attrs_completed_) {
     return SUCCESS;
@@ -136,12 +138,17 @@ void GammaInfo::ReplaceNodeInputOrAttrs() {
 
   // Update seed according rank_id
   int64_t rank_id = g_device_manager->rank_index_in_stage();
-  int64_t seed_bias;
+  int64_t seed_bias = 0;
+  // When seed and seed2 are both 0, ensure that the 0th card in each group has the same result
+  if (seed_ == 0 && seed2_ == 0) {
+    seed_bias += SEED_NUM;
+    ++SEED_NUM;
+  }
   if (repeated_num_in_dev_matrix_right_) {
-    seed_bias = rank_id / repeated_calc_num_;
+    seed_bias += rank_id / repeated_calc_num_;
   } else {
     int64_t device_num = stage_device_size_;
-    seed_bias = rank_id % (device_num / repeated_calc_num_);
+    seed_bias += rank_id % (device_num / repeated_calc_num_);
  }
 
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
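For orientation, here is a minimal Python sketch (an illustration, not MindSpore code) of the rank-based seed bias computed in the hunk above. It assumes the non-zero-seed path (so the SEED_NUM branch is skipped), and the helper name seed_bias is hypothetical; the idea is that ranks holding only a repeated copy of the same tensor slice share a bias, while ranks holding different slices get different biases.

def seed_bias(rank_id, repeated_calc_num, stage_device_size, repeated_on_right=True):
    # Mirrors the two branches above: divide when the repeated dimension sits on the
    # right of the device matrix, otherwise take the remainder within the stage.
    if repeated_on_right:
        return rank_id // repeated_calc_num
    return rank_id % (stage_device_size // repeated_calc_num)


if __name__ == "__main__":
    # 8 devices in the stage, each slice repeated on 2 consecutive ranks.
    print([seed_bias(r, repeated_calc_num=2, stage_device_size=8) for r in range(8)])
    # -> [0, 0, 1, 1, 2, 2, 3, 3]: repeated ranks share a bias, distinct slices differ.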
@@ -160,7 +167,7 @@ void GammaInfo::ResetInputsShape() {
 }
 
 bool GammaInfo::IsNotSplittableStrategy(const Dimensions &strategy) const {
-  return std::all_of(strategy.cbegin(), strategy.cend(), [](int64_t val) { return val == 1; });
+  return std::all_of(strategy.cbegin(), strategy.cend(), [](int64_t val) { return val == NO_SPLIT_STRATEGY; });
 }
 
 void GammaInfo::ReComputeBatchSplitFlagList() {
@@ -57,6 +57,7 @@ class GammaInfo : public OperatorInfo {
 
   int64_t seed_ = 0;
   int64_t seed2_ = 0;
+  static int64_t SEED_NUM;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -18,12 +18,12 @@
 #include <algorithm>
 #include <functional>
 
-#include "frontend/parallel/ops_info/inplace_add_info.h"
+#include "frontend/parallel/ops_info/inplace_op_info.h"
 #include "frontend/parallel/dynamic_creator.h"
 
 namespace mindspore {
 namespace parallel {
-Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
+Status InplaceOpBase::CheckStrategy(const StrategyPtr &strategy) {
   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Invalid strategy";
     return FAILED;
@@ -44,7 +44,7 @@ Status InplaceAddInfo::CheckStrategy(const StrategyPtr &strategy) {
   return SUCCESS;
 }
 
-Status InplaceAddInfo::InferDevMatrixShape() {
+Status InplaceOpBase::InferDevMatrixShape() {
   dev_matrix_shape_.clear();
 
   auto strategies = strategy_->GetInputDim();
@@ -55,7 +55,7 @@ Status InplaceAddInfo::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-Status InplaceAddInfo::InferTensorMap() {
+Status InplaceOpBase::InferTensorMap() {
   inputs_tensor_map_.clear();
   outputs_tensor_map_.clear();
 
@@ -77,7 +77,7 @@ Status InplaceAddInfo::InferTensorMap() {
   return SUCCESS;
 }
 
-std::vector<StrategyPtr> InplaceAddInfo::GenerateOpStrategies(int64_t stage_id) {
+std::vector<StrategyPtr> InplaceOpBase::GenerateOpStrategies(int64_t stage_id) {
   Shapes splittable_inputs = inputs_shape_;
   for (size_t i = 0; i < splittable_inputs.size(); ++i) {
     for (size_t j = 0; j < splittable_inputs[i].size(); ++j) {
@@ -91,12 +91,13 @@ std::vector<StrategyPtr> InplaceAddInfo::GenerateOpStrategies(int64_t stage_id)
   return sp_vector;
 }
 
-void InplaceAddInfo::ReComputeBatchSplitFlagList() {
+void InplaceOpBase::ReComputeBatchSplitFlagList() {
   split_flag_list_[0] = false;
   split_flag_list_[1] = false;
 }
 
 REGISTER(InplaceAddInfo);
 REGISTER(InplaceSubInfo);
+REGISTER(InplaceUpdateInfo);
 } // namespace parallel
 } // namespace mindspore
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
-#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
 
 #include <memory>
 #include <string>
@@ -27,12 +27,12 @@
 
 namespace mindspore {
 namespace parallel {
-class InplaceAddInfo : public OperatorInfo {
+class InplaceOpBase : public OperatorInfo {
  public:
-  InplaceAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceAddCost>()) {}
-  ~InplaceAddInfo() override = default;
+  InplaceOpBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                const PrimitiveAttrs &attrs, const OperatorCostPtr &cost)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
+  ~InplaceOpBase() override = default;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
   void ReComputeBatchSplitFlagList() override;
@@ -48,14 +48,30 @@ class InplaceAddInfo : public OperatorInfo {
   Status InferForwardCommunication() override { return SUCCESS; }
 };
 
-class InplaceSubInfo : public InplaceAddInfo {
+class InplaceAddInfo : public InplaceOpBase {
  public:
+  InplaceAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs)
+      : InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceAddCost>()) {}
+  ~InplaceAddInfo() override = default;
+};
+
+class InplaceSubInfo : public InplaceOpBase {
+ public:
   InplaceSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : InplaceAddInfo(name, inputs_shape, outputs_shape, attrs) {}
-  ~InplaceSubInfo() = default;
+      : InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceSubCost>()) {}
+  ~InplaceSubInfo() override = default;
 };
+
+class InplaceUpdateInfo : public InplaceOpBase {
+ public:
+  InplaceUpdateInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                    const PrimitiveAttrs &attrs)
+      : InplaceOpBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<InplaceUpdateCost>()) {}
+  ~InplaceUpdateInfo() override = default;
+};
 } // namespace parallel
 } // namespace mindspore
 
-#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_ADD_INFO_H_
+#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_INPLACE_OP_INFO_H_
@@ -41,6 +41,7 @@ Status KLDivLossInfo::CheckStrategy(const StrategyPtr &strategy) {
                   << StrategyToString(strategies);
     return FAILED;
   }
+  batch_split_num_ = strategies[0][0];
   return SUCCESS;
 }
 
@@ -128,6 +129,11 @@ Status KLDivLossInfo::InferForwardCommunication() {
     forward_op_ = CreateReduceMeanForwardOp(group_list_, element_type);
   } else {
     (void)forward_op_.emplace_back(CreateAllReduceOp(REDUCE_OP_SUM, group_list_[0].name()));
+    if (reduction_ == BATCH_MEAN) {
+      // Divided by the number of devices in the Batch dimension
+      auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
+      (void)forward_op_.emplace_back(CreateDivOpWithType(SizeToFloat(batch_split_num_), element_type));
+    }
   }
   MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list_[0].name();
 
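To see why the single extra RealDiv is enough for reduction="batchmean", here is a small NumPy sketch (an illustration under simplifying assumptions, not MindSpore code; batchmean_loss is just a stand-in loss): each device computes the batchmean over its local batch shard, the AllReduce above sums those partial values, and dividing by the number of batch shards (batch_split_num_) recovers the single-device result.

import numpy as np

def batchmean_loss(x, y):
    # Stand-in for an element-wise loss with reduction="batchmean" semantics:
    # sum over every element divided by the batch size (first dimension).
    return np.sum((x - y) ** 2) / x.shape[0]

rng = np.random.default_rng(0)
x, y = rng.random((8, 4)), rng.random((8, 4))
single_device = batchmean_loss(x, y)

batch_split_num = 4  # batch dimension sharded across 4 devices
partial = [batchmean_loss(xs, ys)
           for xs, ys in zip(np.split(x, batch_split_num), np.split(y, batch_split_num))]
all_reduce_sum = sum(partial)          # what the AllReduce produces
assert np.isclose(all_reduce_sum / batch_split_num, single_device)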
@@ -48,6 +48,7 @@ class KLDivLossInfo : public OperatorInfo {
  private:
   Status InferGroup();
 
+  std::size_t batch_split_num_;
   std::string reduction_;
   std::vector<Group> group_list_;
 };
@@ -38,16 +38,21 @@ Status LinSpaceInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  auto split_num = strategies[0][0];
-  if (output_size_ % split_num != 0) {
-    MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
-                  << output_size_ << " cannot be divisible by strategy value " << split_num;
+  split_num_ = strategies[0][0];
+  if (split_num_ <= 0) {
+    MS_LOG(ERROR) << name_ << ": Each element in strategy must be a positive integer, but got "
+                  << StrategyToString(strategies);
     return FAILED;
   }
 
-  if (stage_device_size_ % split_num != 0) {
+  if (output_size_ % split_num_ != 0) {
+    MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies) << ", output size is "
+                  << output_size_ << " cannot be divisible by strategy value " << split_num_;
+    return FAILED;
+  }
+  if (stage_device_size_ % split_num_ != 0) {
     MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategies)
                   << ", the device size in this stage is " << stage_device_size_
-                  << " cannot be divisible by the strategy value " << split_num;
+                  << " cannot be divisible by the strategy value " << split_num_;
     return FAILED;
   }
   return SUCCESS;
@@ -97,31 +102,15 @@ std::shared_ptr<Strategies> LinSpaceInfo::GenerateBatchStrategies() {
   return std::make_shared<Strategies>(strategies);
 }
 
-int64_t LinSpaceInfo::GetSplitNum() {
-  auto strategies = strategy_->GetInputDim();
-  auto split_strategy = strategies.at(0);
-  auto split_num = split_strategy.at(0);
-  return split_num;
-}
-
 ReplaceGraphPtr LinSpaceInfo::replace_graph(const CNodePtr &cnode) {
-  auto split_num = GetSplitNum();
-  if (split_num > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
+  if (split_num_ > 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
   }
   return replace_graph_;
 }
 
 Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
-  auto split_num = GetSplitNum();
-  if (split_num == 0) {
-    MS_LOG(ERROR) << name_ << ": split num is 1, no need to replace graph";
-    return FAILED;
-  }
-  if (output_size_ % split_num != 0) {
-    return FAILED;
-  }
-  if (split_num == 1) {
+  if (split_num_ == 1) {
     MS_LOG(INFO) << name_ << ": split num is 1, no need to replace graph";
     return SUCCESS;
   }
|
@ -132,7 +121,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t slice_output_size = output_size_ / split_num;
|
||||
int64_t slice_output_size = output_size_ / split_num_;
|
||||
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
|
||||
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), sub});
|
||||
AnfNodePtr interval = nullptr;
|
||||
|
@ -146,6 +135,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
// new_start = start + slice_id * slice_output_size * interval
|
||||
// new_end = new_start + (slice_output_size - 1) * interval
|
||||
// new_x = slice_output_size
|
||||
InferSliceId();
|
||||
auto offset_size = gen_g.PushBack({gen_g.NewOpInst(CAST), CreateInt32Tensor(slice_id_ * slice_output_size), dtype});
|
||||
auto start_offset = gen_g.PushBack({gen_g.NewOpInst(MUL), interval, offset_size});
|
||||
auto new_start = gen_g.PushBack({gen_g.NewOpInst(ADD), gen_g.virtual_input_node(), start_offset});
|
||||
|
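As a sanity check on the slice arithmetic in the comments above, a short NumPy sketch (illustration only; linspace_shard is a hypothetical helper, not the generated replace graph): each shard produces slice_output_size consecutive points starting at its own new_start, and concatenating the shards over slice_id reproduces the full LinSpace output.

import numpy as np

def linspace_shard(start, stop, output_size, split_num, slice_id):
    # interval is the global step; each shard covers slice_output_size consecutive points.
    interval = (stop - start) / (output_size - 1)
    slice_output_size = output_size // split_num
    new_start = start + slice_id * slice_output_size * interval
    new_end = new_start + (slice_output_size - 1) * interval
    return np.linspace(new_start, new_end, slice_output_size)

start, stop, output_size, split_num = 0.0, 10.0, 8, 4
full = np.linspace(start, stop, output_size)
sharded = np.concatenate([linspace_shard(start, stop, output_size, split_num, i)
                          for i in range(split_num)])
assert np.allclose(full, sharded)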
@@ -161,7 +151,7 @@ Status LinSpaceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   return SUCCESS;
 }
 
-Status LinSpaceInfo::InferSliceId() {
+void LinSpaceInfo::InferSliceId() {
   CheckGlobalDeviceManager();
   int64_t rank = g_device_manager->rank_index_in_stage();
   slice_id_ = rank;
@@ -172,7 +162,6 @@ Status LinSpaceInfo::InferSliceId() {
       slice_id_ %= dev_matrix_shape_.front();
     }
   }
-  return SUCCESS;
 }
 
 REGISTER(LinSpaceInfo);
@@ -49,11 +49,12 @@ class LinSpaceInfo : public OperatorInfo {
   Status InferForwardCommunication() override { return SUCCESS; }
 
  private:
-  Status InferSliceId();
+  void InferSliceId();
   Status ComputeReplaceGraph(const CNodePtr &cnode);
-  int64_t GetSplitNum();
 
   int64_t slice_id_ = 0;
+  int64_t split_num_ = 0;
   int64_t output_size_ = 0;
 };
 } // namespace parallel
@@ -362,10 +362,10 @@ Operator CreateDivOp(float scale) {
   OperatorName operator_name = REAL_DIV;
   OperatorAttrs operator_attrs;
   OperatorParams operator_param;
-  size_t parameter_pos = 2;
+  constexpr size_t parameter_pos = 2;
   mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scale);
   ValuePtr scale_value = MakeValue(tensor_ptr);
-  operator_param.push_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
+  (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
 
   Operator op = std::make_pair(operator_name, operator_arg);
@@ -2322,6 +2322,18 @@ AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple
   return tensor_node->cast<AnfNodePtr>();
 }
 
+Operator CreateDivOpWithType(float divisor, const TypePtr &dtype) {
+  OperatorName operator1_name = REAL_DIV;
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
+  ValuePtr op1_param_value = MakeValue(tensor_ptr);
+  Attr op1_param = std::make_pair("divisor", op1_param_value);
+  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
+  OperatorAttrs operator1_attrs;
+  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
+  Operator div_op = std::make_pair(operator1_name, operator1_args);
+  return div_op;
+}
+
 ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
   // Create AllReduceSum op
   Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
@@ -2329,21 +2341,13 @@ ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, con
   MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
 
   // Create RealDiv op
-  OperatorName operator1_name = REAL_DIV;
   std::vector<Device> device_list = forward_group[0].GetDevicesList();
-  auto divisor = static_cast<float>(device_list.size());
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
-  ValuePtr op1_param_value = MakeValue(tensor_ptr);
-  Attr op1_param = std::make_pair("divisor", op1_param_value);
-  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
-  OperatorAttrs operator1_attrs;
-  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
-  Operator op1 = std::make_pair(operator1_name, operator1_args);
-  ForwardOp forward_op = {op0, op1};
-
+  auto divisor = SizeToFloat(device_list.size());
+  Operator op1 = CreateDivOpWithType(divisor, dtype);
   std::string dtype_name = dtype->ToString();
   MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
-  return forward_op;
-
+  return {op0, op1};
 }
 } // namespace parallel
 } // namespace mindspore
@@ -392,6 +392,7 @@ AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
 AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple);
 
 ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype);
+Operator CreateDivOpWithType(float divisor, const TypePtr &dtype);
 } // namespace parallel
 } // namespace mindspore
 
@@ -270,6 +270,7 @@ constexpr char ROI_END_MODE[] = "roi_end_mode";
 constexpr char REDUCTION[] = "reduction";
 constexpr char MEAN[] = "mean";
 constexpr char ATTR_NONE[] = "none";
+constexpr char BATCH_MEAN[] = "batchmean";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
@@ -29,6 +29,7 @@ Status RandomChoiceWithMaskInfo::GetAttrs() {
   if (attrs_.find(SEED2) != attrs_.end()) {
     seed2_ = GetValue<int64_t>(attrs_[SEED2]);
   }
+  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }
 
@@ -85,39 +86,15 @@ Status RandomChoiceWithMaskInfo::InferAsLossDivisor() {
 }
 
 void RandomChoiceWithMaskInfo::ReplaceNodeInputOrAttrs() {
   // seed and seed2 cannot both be 0.
   if (seed_ != 0 || seed2_ != 0) {
     return;
   }
   if (cnode_->HasAttr(SEED)) {
     cnode_->EraseAttr(SEED);
   }
   if (cnode_->HasAttr(SEED2)) {
     cnode_->EraseAttr(SEED2);
   }
   cnode_->AddAttr(SEED, MakeValue(SEED_NUM));
   cnode_->AddAttr(SEED2, MakeValue(SEED_NUM));
   ++SEED_NUM;
 }
 
-void RandomChoiceWithMaskInfo::CheckGPUBackend() {
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (backend != kGPUDevice) {
-    MS_LOG(EXCEPTION) << name_ << ": The backend is " << backend << " , only support on GPU backend now.";
-  }
-}
-
-Status RandomChoiceWithMaskInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
-  CheckGPUBackend();
-  return OperatorInfo::Init(in_strategy, out_strategy);
-}
-
-Status RandomChoiceWithMaskInfo::InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) {
-  CheckGPUBackend();
-  return OperatorInfo::InitForCostModel(strategy, out_strategy);
-}
-
 void RandomChoiceWithMaskInfo::ReComputeBatchSplitFlagList() { split_flag_list_[0] = false; }
 
 REGISTER(RandomChoiceWithMaskInfo);
@@ -35,8 +35,6 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
         seed_(0),
         seed2_(0) {}
   ~RandomChoiceWithMaskInfo() = default;
-  Status Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
-  Status InitForCostModel(const StrategyPtr &strategy, const StrategyPtr &out_strategy) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override { return SetCostUnderStrategyBase(strategy); }
   void ReplaceNodeInputOrAttrs() override;
@@ -51,8 +49,6 @@ class RandomChoiceWithMaskInfo : public OperatorInfo {
   Status InferAsLossDivisor() override;
 
  private:
-  void CheckGPUBackend();
-
   int64_t seed_;
   int64_t seed2_;
   static int64_t SEED_NUM;
@@ -29,6 +29,8 @@
 
 namespace mindspore {
 namespace parallel {
+int64_t UniformRealInfo::SEED_NUM = 1;
+
 Status UniformRealInfo::InferAttrs() {
   if (infer_attrs_completed_) {
     return SUCCESS;
@@ -53,6 +55,7 @@ Status UniformRealInfo::GetAttrs() {
     MS_LOG(ERROR) << name_ << ": Seed2 must be greater or equal to zero, bug got " << seed2_;
     return FAILED;
   }
+  infer_strategy_mode_ = INDEPENDENT_MODE;
   return SUCCESS;
 }
 
@@ -139,12 +142,17 @@ void UniformRealInfo::ReplaceNodeInputOrAttrs() {
 
   // Update seed according rank_id
   int64_t rank_id = g_device_manager->rank_index_in_stage();
-  int64_t seed_bias;
+  int64_t seed_bias = 0;
+  // When seed and seed2 are both 0, ensure that the 0th card in each group has the same result
+  if (seed_ == 0 && seed2_ == 0) {
+    seed_bias += SEED_NUM;
+    ++SEED_NUM;
+  }
   if (repeated_num_in_dev_matrix_right_) {
-    seed_bias = rank_id / repeated_calc_num_;
+    seed_bias += rank_id / repeated_calc_num_;
   } else {
     int64_t device_num = stage_device_size_;
-    seed_bias = rank_id % (device_num / repeated_calc_num_);
+    seed_bias += rank_id % (device_num / repeated_calc_num_);
   }
 
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -54,6 +54,7 @@ class UniformRealInfo : public OperatorInfo {
 
   int64_t seed_ = 0;
   int64_t seed2_ = 0;
+  static int64_t SEED_NUM;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -731,11 +731,10 @@ bool IsSplittableOperator(const std::string &op_name) {
     UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,
     RESIZE_NEAREST_NEIGHBOR, CUM_SUM, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE,
     ROI_ALIGN, IS_FINITE, RINT, HSHRINK, HSIGMOID, MISH, SELU, SOFT_SHRINK, XLOGY, XDIVY, CUM_PROD, BITWISE_AND,
-    BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, L2_LOSS, LERP, ADDN,
-    CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE, CHECK_VALID, INVERT,
-    SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD, POPULATION_COUNT,
-    IDENTITY};
-
+    BITWISE_OR, BITWISE_XOR, MUL_NO_NAN, TRUNCATE_DIV, TRUNCATE_MOD, INPLACE_ADD, INPLACE_SUB, INPLACE_UPDATE,
+    L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
+    CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
+    POPULATION_COUNT, IDENTITY};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+
+from parallel.utils.utils import ParallelValidator, compile_net
+
+
+def setup_function():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
+
+SEED_ = 1
+SEED2_ = 1
+
+
+class Net(Cell):
+    def __init__(self, seed, seed2, strategy=None):
+        super(Net, self).__init__()
+        self.uniform_real = P.Gamma(seed, seed2).shard(strategy)
+
+    def construct(self, shape, alpha, beta):
+        out = self.uniform_real(shape, alpha, beta)
+        return out
+
+
+def test_uniform_real_auto_parallel():
+    """
+    Features: test UniformReal auto parallel
+    Description: auto parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = Net(SEED_, SEED2_)
+    shape = (4, 4, 4)
+    alpha = Tensor(np.array([1.0]), ms.float32)
+    beta = Tensor(np.array([1.0]), ms.float32)
+    compile_net(net, shape, alpha, beta)
+
+
+def test_uniform_real_data_parallel():
+    """
+    Features: test UniformReal data parallel
+    Description: data parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=1)
+    net = Net(SEED_, SEED2_)
+    shape = (8, 8)
+    alpha = Tensor(np.array([1.0]), ms.float32)
+    beta = Tensor(np.array([1.0]), ms.float32)
+    phase = compile_net(net, shape, alpha, beta)
+
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("Gamma-0", {"seed": 2, "seed2": 2})
+
+
+def test_uniform_real_model_parallel():
+    """
+    Features: test UniformReal model parallel
+    Description: model parallel
+    Expectation: compile success
+    """
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=5)
+    shape = (8, 8)
+    alpha = Tensor(np.array([1.0]), ms.float32)
+    beta = Tensor(np.array([1.0]), ms.float32)
+    strategy = ((2, 2), (1,), (1,))
+    net = Net(SEED_, SEED2_, strategy)
+    phase = compile_net(net, shape, alpha, beta)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_attrs("Gamma-0", {"seed": 3, "seed2": 3})
@@ -48,42 +48,55 @@ class InplaceSubNet(Cell):
         return self.inplace_sub(x, input_v)
 
 
-def test_inplace_add_auto_parallel():
+class InplaceUpdateNet(Cell):
+    def __init__(self, indices, strategy=None):
+        super(InplaceUpdateNet, self).__init__()
+        self.inplace_update = P.InplaceUpdate(indices).shard(strategy)
+
+    def construct(self, x, input_v):
+        return self.inplace_update(x, input_v)
+
+
+@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
+def test_inplace_add_auto_parallel(network):
     """
-    Feature: test InplaceAdd auto parallel
+    Feature: test InplaceOp auto parallel
     Description: auto parallel
    Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = InplaceAddNet(indices_)
+    net = network(indices_)
     compile_net(net, x_, input_v_)
 
 
-def test_inplace_add_model_parallel():
+@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
+def test_inplace_op_model_parallel(network):
     """
-    Feature: test InplaceAdd model parallel
+    Feature: test InplaceOp model parallel
     Description: model parallel
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 4, 2), (1, 4, 2))
-    net = InplaceAddNet(indices_, strategy)
+    net = network(indices_, strategy)
     compile_net(net, x_, input_v_)
 
 
-def test_inplace_add_model_parallel_with_repeated_cal():
+@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
+def test_inplace_add_model_parallel_with_repeated_cal(network):
     """
-    Feature: test InplaceAdd model parallel with repeated calculation
+    Feature: test InplaceOp model parallel with repeated calculation
     Description: model parallel
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (1, 2, 2))
-    net = InplaceAddNet(indices_, strategy)
+    net = network(indices_, strategy)
     compile_net(net, x_, input_v_)
 
 
-def test_inplace_add_strategy_error():
+@pytest.mark.parametrize("network", [InplaceAddNet, InplaceSubNet, InplaceUpdateNet])
+def test_inplace_add_strategy_error(network):
     """
     Feature: test invalid strategy for InplaceAdd
     Description: illegal strategy
@@ -91,54 +104,6 @@ def test_inplace_add_strategy_error():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 4, 2), (1, 2, 4))
-    net = InplaceAddNet(indices_, strategy)
-    with pytest.raises(RuntimeError):
-        compile_net(net, x_, input_v_)
-
-
-def test_inplace_sub_auto_parallel():
-    """
-    Feature: test InplaceSub auto parallel
-    Description: auto parallel
-    Expectation: compile success
-    """
-    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = InplaceSubNet(indices_)
-    compile_net(net, x_, input_v_)
-
-
-def test_inplace_sub_model_parallel():
-    """
-    Feature: test InplaceSub model parallel
-    Description: model parallel
-    Expectation: compile success
-    """
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((1, 4, 2), (1, 4, 2))
-    net = InplaceSubNet(indices_, strategy)
-    compile_net(net, x_, input_v_)
-
-
-def test_inplace_sub_model_parallel_with_repeated_cal():
-    """
-    Feature: test InplaceSub model parallel with repeated calculation
-    Description: model parallel
-    Expectation: compile success
-    """
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((1, 2, 2), (1, 2, 2))
-    net = InplaceSubNet(indices_, strategy)
-    compile_net(net, x_, input_v_)
-
-
-def test_inplace_sub_strategy_error():
-    """
-    Feature: test invalid strategy for InplaceSub
-    Description: illegal strategy
-    Expectation: raise RuntimeError
-    """
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((1, 4, 2), (1, 2, 4))
-    net = InplaceSubNet(indices_, strategy)
+    net = network(indices_, strategy)
     with pytest.raises(RuntimeError):
         compile_net(net, x_, input_v_)
@@ -54,7 +54,6 @@ def test_auto_parallel_random_choice_with_mask():
     Description: auto parallel
     Expectation: compile success
     """
-    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(dataset_strategy="full_batch")
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
     net = Net()
@@ -67,24 +66,9 @@ def test_random_choice_with_mask_wrong_strategy():
     Description: illegal strategy
     Expectation: raise RuntimeError
     """
-    context.set_context(device_target="GPU")
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((8, 1),)
     net = Net(strategy)
     with pytest.raises(RuntimeError):
         compile_net(net, _input_x)
     context.reset_auto_parallel_context()
-
-
-def test_random_choice_with_mask_not_gpu():
-    """
-    Feature: RandomChoiceWithMask
-    Description: not compile with gpu backend
-    Expectation: raise RuntimeError
-    """
-    context.set_context(device_target="Ascend")
-    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = Net()
-    with pytest.raises(RuntimeError):
-        compile_net(net, _input_x)
-    context.reset_auto_parallel_context()