forked from mindspore-Ecosystem/mindspore
!31040 Produce parallel operators for Argmin/max, SquareSumAll and UnsortedSegmentProd
Merge pull request !31040 from Bert0108/reduce_operators_arg
This commit is contained in:
commit
c2a5cc1486
|
@ -753,6 +753,7 @@ class ReduceSumCost : public OperatorCost {
|
|||
};
|
||||
using ReduceMethodCost = ReduceSumCost;
|
||||
using ReduceProdCost = ReduceSumCost;
|
||||
using SquareSumAllCost = ReduceSumCost;
|
||||
|
||||
class ReduceMeanCost : public ReduceSumCost {
|
||||
public:
|
||||
|
@ -784,6 +785,7 @@ class ArgMaxWithValueCost : public ReduceSumCost {
|
|||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using ArgMinWithValueCost = ArgMaxWithValueCost;
|
||||
using ArgmaxCost = ArgMaxWithValueCost;
|
||||
|
||||
class GetNextCost : public OperatorCost {
|
||||
public:
|
||||
|
@ -919,6 +921,7 @@ class UnsortedSegmentSumCost : public OperatorCost {
|
|||
// Taking account of input
|
||||
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
|
||||
};
|
||||
using UnsortedSegmentProdCost = UnsortedSegmentSumCost;
|
||||
|
||||
class UnsortedSegmentMinCost : public OperatorCost {
|
||||
public:
|
||||
|
|
|
@ -217,6 +217,10 @@ REGISTER(CropAndResizeInfo);
|
|||
REGISTER(ROIAlignInfo);
|
||||
REGISTER(ReduceProdInfo);
|
||||
REGISTER(ReduceAllInfo);
|
||||
REGISTER(ArgmaxInfo);
|
||||
REGISTER(ArgminInfo);
|
||||
REGISTER(UnsortedSegmentProdInfo);
|
||||
REGISTER(SquareSumAllInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -111,6 +111,7 @@ constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
|
|||
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
|
||||
constexpr char GET_OP_FUNCTION[] = "_get_python_op";
|
||||
constexpr char KEEP_DIMS[] = "keep_dims";
|
||||
constexpr char OUTPUT_TYPE[] = "output_type";
|
||||
constexpr char CROSS_BATCH[] = "cross_batch";
|
||||
constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
|
||||
constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
|
||||
|
@ -335,6 +336,8 @@ constexpr char REDUCE_ALL[] = "ReduceAll";
|
|||
constexpr char REDUCE_ANY[] = "ReduceAny";
|
||||
constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
|
||||
constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
|
||||
constexpr char ARGMAX[] = "Argmax";
|
||||
constexpr char ARGMIN[] = "Argmin";
|
||||
constexpr char CONV2D[] = "Conv2D";
|
||||
constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
|
||||
constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
|
||||
|
@ -432,6 +435,7 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
|
|||
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
|
||||
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
|
||||
constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
|
||||
constexpr char UNSORTED_SEGMENT_PROD[] = "UnsortedSegmentProd";
|
||||
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
|
||||
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
|
||||
constexpr char DROPOUT[] = "Dropout";
|
||||
|
@ -450,6 +454,7 @@ constexpr char RANDOM_CHOICE_WITH_MASK[] = "RandomChoiceWithMask";
|
|||
constexpr char CROP_AND_RESIZE[] = "CropAndResize";
|
||||
constexpr char MASKED_FILL[] = "MaskedFill";
|
||||
constexpr char ROI_ALIGN[] = "ROIAlign";
|
||||
constexpr char SQUARE_SUM_ALL[] = "SquareSumAll";
|
||||
|
||||
// pipeline
|
||||
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
|
||||
|
@ -483,6 +488,7 @@ constexpr char MAKE_RECORD[] = "make_record";
|
|||
constexpr char LIST_GETITEM[] = "list_getitem";
|
||||
constexpr char ARRAY_GETITEM[] = "array_getitem";
|
||||
constexpr char TUPLE_SETITEM[] = "tuple_setitem";
|
||||
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
|
||||
constexpr char LIST_SETITEM[] = "list_setitem";
|
||||
constexpr char ARRAY_SETITEM[] = "array_setitem";
|
||||
constexpr char DICT_GETITEM[] = "dict_getitem";
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
@ -578,15 +579,321 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_inputs = {input0_split};
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
|
||||
std::vector<int64_t> ArgmaxInfo::reduce_dim() {
|
||||
// get axis from attribution
|
||||
std::vector<int64_t> dim_list;
|
||||
auto iter = attrs_.find(AXIS);
|
||||
if (iter == attrs_.end()) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Don't have attribution axis.";
|
||||
}
|
||||
|
||||
return sp_vector;
|
||||
MS_ASSERT(inputs_shape_.size() == 1);
|
||||
auto input_dim = inputs_shape_.at(0).size();
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
if (iter->second->isa<ValueTuple>()) {
|
||||
auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
|
||||
if (attr_axis.empty()) {
|
||||
for (size_t i = 0; i < input_dim; ++i) {
|
||||
dim_list.push_back(SizeToLong(i));
|
||||
}
|
||||
} else {
|
||||
for (auto &axis : attr_axis) {
|
||||
axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
|
||||
}
|
||||
}
|
||||
} else if (iter->second->isa<Int64Imm>()) {
|
||||
int64_t axis = GetValue<int64_t>(iter->second);
|
||||
axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Axis type is invalid.";
|
||||
}
|
||||
|
||||
return dim_list;
|
||||
}
|
||||
|
||||
// Reads the attributes used by Argmax/Argmin:
//   - keep_dims is always set to false;
//   - output_type, when present, is only checked for a non-null value;
//   - cross_batch, when present, must be a bool and is stored in cross_batch_.
// The resulting cross_batch_ value is then forwarded to the ReduceMethodCost cost model.
// Returns FAILED if cross_batch has a non-bool type or the cost object is not a ReduceMethodCost.
Status ArgmaxInfo::GetAttrs() {
  keepdims_ = false;  // keep_dims is False by default for this operator

  const auto output_type_iter = attrs_.find(OUTPUT_TYPE);
  if (output_type_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(output_type_iter->second);
  }

  const auto cross_batch_iter = attrs_.find(CROSS_BATCH);
  const bool has_cross_batch = (cross_batch_iter != attrs_.end());
  if (has_cross_batch) {
    MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
    if (cross_batch_iter->second->isa<BoolImm>()) {
      cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
      return FAILED;
    }
  }

  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
  if (!reducemethodcost) {
    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
    return FAILED;
  }
  reducemethodcost->set_cross_batch(cross_batch_);
  return SUCCESS;
}
|
||||
|
||||
// Validates a strategy for Argmax/Argmin.
// Delegates the actual validation to ReduceMethod::CheckStrategy. Afterwards it only
// warns (and still returns SUCCESS) when the reduced axis is split, i.e. when the
// strategy value at the reduced axis is not 1.
Status ArgmaxInfo::CheckStrategy(const StrategyPtr &strategy) {
  const Status parent_status = ReduceMethod::CheckStrategy(strategy);
  if (parent_status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed";
    return FAILED;
  }

  std::vector<int64_t> dim_list = reduce_dim();
  MS_ASSERT(dim_list.size() == 1);

  Strategys stra = strategy->GetInputDim();
  MS_ASSERT(stra.size() == 1);
  Shape input_strategy = stra.at(0);
  MS_ASSERT(dim_list.at(0) < input_strategy.size());

  // Nothing more to report when the reduced axis is not partitioned.
  if (input_strategy.at(LongToSize(dim_list.at(0))) == 1) {
    return SUCCESS;
  }

  MS_LOG(WARNING) << name_
                  << " CheckStrategy for Argmax/Argmin, the strategy corresponding to axis is not one, real strategy "
                     "is "
                  << input_strategy.at(LongToSize(dim_list.at(0)))
                  << ", the output index may be not compatible with the stand alone Primitive";
  return SUCCESS;
}
|
||||
|
||||
// Infers the mirror operators by delegating to OperatorInfo::InferMirrorOps;
// a failure of the base implementation is logged and reported as FAILED.
Status ArgmaxInfo::InferMirrorOps() {
  const Status base_status = OperatorInfo::InferMirrorOps();
  if (base_status == SUCCESS) {
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed";
  return FAILED;
}
|
||||
|
||||
std::vector<int64_t> SquareSumAllInfo::reduce_dim() {
|
||||
std::vector<int64_t> dim_list;
|
||||
auto input_dim = inputs_shape_.at(0).size();
|
||||
// reduce all dim
|
||||
for (size_t i = 0; i < input_dim; ++i) {
|
||||
dim_list.push_back(SizeToLong(i));
|
||||
}
|
||||
|
||||
return dim_list;
|
||||
}
|
||||
|
||||
// Reads the attributes used by SquareSumAll:
//   - keep_dims is always set to false;
//   - cross_batch, when present, must be a bool and is stored in cross_batch_.
// The resulting cross_batch_ value is then forwarded to the ReduceMethodCost cost model.
// Returns FAILED if cross_batch has a non-bool type or the cost object is not a ReduceMethodCost.
Status SquareSumAllInfo::GetAttrs() {
  keepdims_ = false;  // keep_dims is False by default for this operator

  const auto cross_batch_iter = attrs_.find(CROSS_BATCH);
  const bool has_cross_batch = (cross_batch_iter != attrs_.end());
  if (has_cross_batch) {
    MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
    if (cross_batch_iter->second->isa<BoolImm>()) {
      cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
      return FAILED;
    }
  }

  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
  if (!reducemethodcost) {
    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
    return FAILED;
  }
  reducemethodcost->set_cross_batch(cross_batch_);
  return SUCCESS;
}
|
||||
|
||||
// Validates a strategy for the two-input SquareSumAll operator.
// After the generic CheckStrategyValue check, every dimension is inspected: a dimension
// is rejected when the two inputs use different split factors AND their shapes differ
// in that dimension.
// The size preconditions are enforced at runtime (in addition to MS_ASSERT) because the
// loop below indexes both strategies and both shapes without bounds checks.
// Returns SUCCESS for an accepted strategy, FAILED otherwise.
Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    return FAILED;
  }

  Strategys stra = strategy->GetInputDim();
  // Two inputs are required; fail cleanly instead of letting at(1) throw.
  if (stra.size() < 2 || inputs_shape_.size() < 2) {
    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    return FAILED;
  }

  Dimensions sub_a_strategy = stra.at(0);
  Dimensions sub_b_strategy = stra.at(1);
  Shape input_a_shape = inputs_shape_.at(0);
  Shape input_b_shape = inputs_shape_.at(1);
  MS_ASSERT(input_a_shape.size() == input_b_shape.size());
  MS_ASSERT(sub_a_strategy.size() == sub_b_strategy.size());

  // Runtime counterpart of the assertions above: all four vectors are indexed with the
  // same index, so a size mismatch would read out of bounds.
  const size_t rank = input_a_shape.size();
  if (input_b_shape.size() != rank || sub_a_strategy.size() != sub_b_strategy.size() ||
      sub_a_strategy.size() < rank) {
    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    return FAILED;
  }

  // NOTE(review): with equally shaped inputs this condition can never fire, so differing
  // strategies for the two inputs are accepted here — confirm that this is intended.
  for (size_t i = 0; i < rank; ++i) {
    if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != input_b_shape[i])) {
      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
      return FAILED;
    }
  }
  return SUCCESS;
}
|
||||
|
||||
// The device matrix is a copy of the strategy of the first input; the strategy of the
// second input is not consulted here. Always returns SUCCESS.
Status SquareSumAllInfo::InferDevMatrixShape() {
  const Strategys all_inputs_strategy = strategy_->GetInputDim();
  const Dimensions first_input_strategy = all_inputs_strategy.at(0);

  Shape dev_shape;
  for (const auto &split_factor : first_input_strategy) {
    dev_shape.push_back(split_factor);
  }
  dev_matrix_shape_ = dev_shape;

  return SUCCESS;
}
|
||||
|
||||
// Lets ReduceMethod::InferTensorMap compute the maps, then appends a copy of the first
// input map and of the first output map, so that the second input and the second output
// use the same tensor maps as the first ones.
Status SquareSumAllInfo::InferTensorMap() {
  const Status parent_status = ReduceMethod::InferTensorMap();
  if (parent_status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed";
    return FAILED;
  }
  MS_ASSERT(outputs_tensor_map_.size() == 1);

  // Copy first, then append, so the appended value never aliases the growing vector.
  const auto first_input_map = inputs_tensor_map_[0];
  inputs_tensor_map_.push_back(first_input_map);

  const auto first_output_map = outputs_tensor_map_[0];
  outputs_tensor_map_.push_back(first_output_map);
  return SUCCESS;
}
|
||||
|
||||
// Builds the TensorInfo entries for SquareSumAll.
// One input TensorInfo (from input 0) and one output TensorInfo (from output 0) are
// created and each is registered twice, once per input and once per output.
// Returns FAILED when slice-shape inference or layout initialization fails.
Status SquareSumAllInfo::InferTensorInfo() {
  // Full tensor shapes of the first input and the first output.
  Shape full_input_shape = inputs_shape_.at(0);
  Shape full_output_shape = outputs_shape_.at(0);

  // Slice shapes: both outputs share the strategy returned by InferOutputStrategy().
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy = InferOutputStrategy();
  Strategys outputs_strategy = {output_strategy, output_strategy};

  Shapes inputs_slice_shape;
  Shapes outputs_slice_shape;
  const Status slice_status =
    InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape);
  if (slice_status != SUCCESS) {
    return FAILED;
  }
  Shape sliced_input_shape = inputs_slice_shape.at(0);
  Shape sliced_output_shape = outputs_slice_shape.at(0);

  // Layouts are derived from the device matrix and the first input/output tensor map.
  TensorLayout input_layout;
  TensorLayout output_layout;
  if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], full_input_shape) != SUCCESS) {
    return FAILED;
  }
  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], full_output_shape) != SUCCESS) {
    return FAILED;
  }

  std::vector<int64_t> reduced_axes = reduce_dim();
  TensorInfo input_info(input_layout, full_input_shape, sliced_input_shape);
  TensorInfo output_info(output_layout, full_output_shape, sliced_output_shape);
  input_info.set_reduce_dim(reduced_axes);

  // Two inputs and two outputs, each described by the same info object.
  for (int copy = 0; copy < 2; ++copy) {
    inputs_tensor_info_.push_back(input_info);
  }
  for (int copy = 0; copy < 2; ++copy) {
    outputs_tensor_info_.push_back(output_info);
  }
  return SUCCESS;
}
|
||||
|
||||
// Creates the communication group (stored in group_) from a tensor map built as follows:
//   - every input dimension is added to the map unless it is both a reduced dimension
//     and split by the strategy (strategy value != 1);
//   - the extra device-matrix dimension introduced by repeated calculation is handled
//     separately, depending on whether it sits on the left or on the right.
// forward_op_ is cleared as a side effect. Returns FAILED if group creation fails.
Status SquareSumAllInfo::InferGroup() {
  Dimensions stra = strategy_->GetInputDim().at(0);
  forward_op_.clear();
  std::vector<int64_t> dim_list = reduce_dim();
  const size_t size = stra.size();
  Shape group_creat_map;

  // Repeated calculation with the repeat number placed in the first dimension of the
  // device matrix: the first dimension of the map has to be handled.
  const bool repeat_on_left = (dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_;
  if (repeat_on_left) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }

  // Decide for each dimension whether the reduced dimension is partitioned.
  for (size_t index = 0; index < size; ++index) {
    const bool is_reduced =
      std::find(dim_list.begin(), dim_list.end(), SizeToLong(index)) != dim_list.end();
    const bool is_split_reduce_dim = is_reduced && stra[index] != 1;
    if (!is_split_reduce_dim) {
      group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
    }
  }

  // Repeated calculation with the repeat number placed in the last dimension of the
  // device matrix: shift every mapped entry by one and append 0 as the last entry.
  if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
    for (auto &ele : group_creat_map) {
      if (ele != MAP_NONE) {
        ele += 1;
      }
    }
    group_creat_map.push_back(0);
  }

  if (CreateGroupByTensorMap(group_creat_map, &group_) != SUCCESS) {
    ReportError(name_ + ": Create group failed.");
    return FAILED;
  }

  return SUCCESS;
}
|
||||
|
||||
// Builds the sub-graph that replaces a sharded SquareSumAll node and stores it in
// replace_graph_:
//   SquareSumAll(x, y) -> tuple_getitem(0) -> AllReduce(sum, group) --+
//                      -> tuple_getitem(1) -> AllReduce(sum, group) --+-> make_list
// Both inputs of the original node are wired to inputs 1 and 2 of the new SquareSumAll.
// Returns FAILED, without touching replace_graph_, when graph generation or group
// inference fails or when no communication group is available.
Status SquareSumAllInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;  // do not keep building on an uninitialized generator
  }
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed";
    return FAILED;  // group_ is not valid, the AllReduce below cannot be configured
  }
  // group_[0] is dereferenced below; indexing an empty vector is undefined behavior.
  // NOTE(review): if an empty group legitimately means "no reduced axis is split", the
  // AllReduce nodes could be skipped instead of failing — confirm against
  // CreateGroupByTensorMap.
  if (group_.empty()) {
    MS_LOG(ERROR) << name_ << ": The communication group is empty";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage();
  auto square_sum_all =
    gen_g.PushBack({gen_g.NewOpInst(SQUARE_SUM_ALL), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
  auto get_item0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(0)});
  auto get_item1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(1)});

  // Each of the two outputs is summed across the group.
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0].name()));
  OperatorAttrs attrs = {attr_op, attr_group};
  auto allreduce_op0 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item0});
  auto allreduce_op1 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item1});

  auto make_list = gen_g.PushBack({gen_g.NewOpInst(MAKE_LIST), allreduce_op0, allreduce_op1});

  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(square_sum_all, 1),
                                                             std::make_pair(square_sum_all, 2)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, make_list));
  return SUCCESS;
}
|
||||
|
||||
// Computes the replacement graph for `cnode` and returns it.
// Raises through MS_LOG(EXCEPTION) when ComputeReplaceGraph does not return SUCCESS.
ReplaceGraphPtr SquareSumAllInfo::replace_graph(const CNodePtr &cnode) {
  const Status build_status = ComputeReplaceGraph(cnode);
  if (build_status != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
  }
  return replace_graph_;
}
|
||||
|
||||
// Infers the mirror operators by delegating to OperatorInfo::InferMirrorOps;
// a failure of the base implementation is logged and reported as FAILED.
Status SquareSumAllInfo::InferMirrorOps() {
  const Status base_status = OperatorInfo::InferMirrorOps();
  if (base_status == SUCCESS) {
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed";
  return FAILED;
}
|
||||
|
||||
// Computes as_loss_divisor_ from the first output only (the operator has two outputs).
//   - empty first output map: the divisor is stage_device_size_;
//   - otherwise: the divisor is ComputeRepeatDeviceNumByTensorMap(dev matrix, output map).
// Returns FAILED when no output tensor map exists at all.
Status SquareSumAllInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
  const auto &first_output_map = outputs_tensor_map_[0];

  if (!first_output_map.empty()) {
    as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, first_output_map);

    std::string dev_matrix_text = ShapeToString(dev_matrix_shape_);
    std::string output_map_text = ShapeToString(first_output_map);
    MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_text
                 << ", " << output_map_text << ", " << as_loss_divisor_;
    return SUCCESS;
  }

  // The first output map is empty: use the number of devices in the stage.
  as_loss_divisor_ = stage_device_size_;
  MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
  return SUCCESS;
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,8 +76,6 @@ class ArgMaxWithValueInfo : public ReduceMethod {
|
|||
|
||||
~ArgMaxWithValueInfo() override = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> reduce_dim() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
|
@ -167,6 +165,63 @@ class ReduceAllInfo : public ReduceAnyInfo {
|
|||
|
||||
~ReduceAllInfo() override = default;
|
||||
};
|
||||
|
||||
class ArgmaxInfo : public ReduceMethod {
|
||||
public:
|
||||
ArgmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgmaxCost>()) {
|
||||
reduce_method_ = REDUCE_OP_MAX;
|
||||
}
|
||||
|
||||
~ArgmaxInfo() override = default;
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> reduce_dim() override;
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
};
|
||||
|
||||
class ArgminInfo : public ArgmaxInfo {
|
||||
public:
|
||||
ArgminInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ArgmaxInfo(name, inputs_shape, outputs_shape, attrs) {
|
||||
reduce_method_ = REDUCE_OP_MIN;
|
||||
}
|
||||
|
||||
~ArgminInfo() override = default;
|
||||
};
|
||||
|
||||
class SquareSumAllInfo : public ReduceMethod {
|
||||
public:
|
||||
SquareSumAllInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareSumAllCost>()) {
|
||||
reduce_method_ = REDUCE_OP_SUM;
|
||||
}
|
||||
~SquareSumAllInfo() override = default;
|
||||
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> reduce_dim() override;
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status InferTensorInfo() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferMirrorOps() override;
|
||||
Status InferAsLossDivisor() override;
|
||||
|
||||
private:
|
||||
Status InferGroup();
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
|
||||
std::vector<Group> group_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_
|
||||
|
|
|
@ -205,7 +205,7 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
|
||||
Operator op = CreateAllReduceOp(reduce_method_, group_list[0].name());
|
||||
forward_op_.push_back(op);
|
||||
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
|
||||
return SUCCESS;
|
||||
|
|
|
@ -50,6 +50,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
|
|||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
||||
|
||||
protected:
|
||||
std::string reduce_method_;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override;
|
||||
Status InferMirrorOps() override;
|
||||
|
@ -65,15 +66,29 @@ class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {
|
|||
public:
|
||||
UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {}
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {
|
||||
reduce_method_ = REDUCE_OP_SUM;
|
||||
}
|
||||
~UnsortedSegmentSumInfo() override = default;
|
||||
};
|
||||
|
||||
class UnsortedSegmentProdInfo : public UnsortedSegmentOpInfo {
|
||||
public:
|
||||
UnsortedSegmentProdInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentProdCost>()) {
|
||||
reduce_method_ = REDUCE_OP_PROD;
|
||||
}
|
||||
~UnsortedSegmentProdInfo() override = default;
|
||||
};
|
||||
|
||||
class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
|
||||
public:
|
||||
UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {
|
||||
reduce_method_ = REDUCE_OP_MIN;
|
||||
}
|
||||
~UnsortedSegmentMinInfo() override = default;
|
||||
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
|
@ -86,7 +101,9 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
|
|||
public:
|
||||
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {}
|
||||
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {
|
||||
reduce_method_ = REDUCE_OP_MAX;
|
||||
}
|
||||
~UnsortedSegmentMaxInfo() override = default;
|
||||
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
|
|
|
@ -174,7 +174,8 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, CUMSUM, FAST_GELU, IOU,
|
||||
BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL};
|
||||
BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL,
|
||||
ARGMAX, ARGMIN, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL};
|
||||
// clang-format on
|
||||
|
||||
auto iter = splittable_op.find(op_name);
|
||||
|
|
|
@ -415,6 +415,29 @@ class ArgMinWithValueNet(nn.Cell):
|
|||
out = self.mul2(out, b)
|
||||
return out
|
||||
|
||||
class ArgMaxNet(nn.Cell):
    """Mul followed by Argmax over the last axis, each with its own shard strategy."""

    def __init__(self, strategy1, strategy2):
        super().__init__()
        # strategy1 shards the Mul, strategy2 shards the Argmax.
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_max = P.Argmax(axis=-1).shard(strategy2)

    def construct(self, x, y):
        product = self.mul1(x, y)
        return self.arg_max(product)
|
||||
|
||||
|
||||
class ArgMinNet(nn.Cell):
    """Mul followed by Argmin over the last axis, each with its own shard strategy."""

    def __init__(self, strategy1, strategy2):
        super().__init__()
        # strategy1 shards the Mul, strategy2 shards the Argmin.
        self.mul1 = P.Mul().shard(strategy1)
        self.arg_min = P.Argmin(axis=-1).shard(strategy2)

    def construct(self, x, y):
        product = self.mul1(x, y)
        return self.arg_min(product)
|
||||
|
||||
|
||||
def gen_inputs_and_compile_net(net):
|
||||
x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
|
||||
|
@ -514,6 +537,90 @@ def test_arg_min_with_value_mul_auto():
|
|||
gen_inputs_and_compile_net(net)
|
||||
|
||||
|
||||
def test_arg_max_semi_axis_parallel():
    """
    Feature: test Argmax semi parallel strategy
    Description: the reduced (last) axis of the Argmax input is partitioned
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    argmax_strategy = ((4, 1, 2),)
    sharded_net = ArgMaxNet(mul_strategy, argmax_strategy)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_arg_max_mul_semi():
    """
    Feature: test Argmax model parallel strategy
    Description: only the non-reduced axes of the Argmax input are partitioned
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    argmax_strategy = ((4, 2, 1),)
    sharded_net = ArgMaxNet(mul_strategy, argmax_strategy)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_arg_max_mul_auto():
    """
    Feature: test Argmax auto parallel strategy
    Description: no strategy is set for either operator
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # Both operators are left unsharded.
    sharded_net = ArgMaxNet(None, None)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_arg_min_semi_axis_parallel():
    """
    Feature: test Argmin semi parallel strategy
    Description: the reduced (last) axis of the Argmin input is partitioned
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    argmin_strategy = ((4, 1, 2),)
    sharded_net = ArgMinNet(mul_strategy, argmin_strategy)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_arg_min_mul_semi():
    """
    Feature: test Argmin model parallel strategy
    Description: only the non-reduced axes of the Argmin input are partitioned
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    argmin_strategy = ((4, 2, 1),)
    sharded_net = ArgMinNet(mul_strategy, argmin_strategy)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_arg_min_mul_auto():
    """
    Feature: test Argmin auto parallel strategy
    Description: no strategy is set for either operator
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # Both operators are left unsharded.
    sharded_net = ArgMinNet(None, None)
    net = GradWrapNoBias(NetWithLossNoBias(sharded_net))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
class ArgMinWithValueNet2(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2, strategy3):
|
||||
super(ArgMinWithValueNet2, self).__init__()
|
||||
|
@ -915,3 +1022,59 @@ def test_prod_mul_auto():
|
|||
net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
gen_inputs_and_compile_net_no_bias(net)
|
||||
|
||||
|
||||
def test_square_sum_all_mul():
    """
    Feature: test SquareSumAll model parallel strategy
    Description: the first two axes of both SquareSumAll inputs are partitioned (2 x 4)
    Expectation: compile success
    """
    class Net(nn.Cell):
        """Mul whose result is fed to both inputs of SquareSumAll."""

        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.square_sum_all = P.SquareSumAll().shard(strategy2)

        def construct(self, x, y):
            product = self.mul1(x, y)
            return self.square_sum_all(product, product)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 1, 8), (1, 1, 8))
    square_sum_strategy = ((2, 4, 1), (2, 4, 1))
    net = Net(mul_strategy, square_sum_strategy)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    input_shape = [128, 32, 64]
    x = Tensor(np.ones(input_shape), dtype=ms.float32)
    y = Tensor(np.ones(input_shape), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
|
||||
|
||||
|
||||
def test_square_sum_all_mul2():
    """
    Feature: test SquareSumAll model parallel strategy
    Description: only the first axis of both SquareSumAll inputs is partitioned (8 ways)
    Expectation: compile success
    """
    class Net(nn.Cell):
        """Mul whose result is fed to both inputs of SquareSumAll."""

        def __init__(self, stra_mul, stra_prod):
            super().__init__()
            self.mul = P.Mul().shard(stra_mul)
            self.square_sum_all = P.SquareSumAll().shard(stra_prod)

        def construct(self, x, y):
            product = self.mul(x, y)
            return self.square_sum_all(product, product)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    mul_strategy = ((1, 1, 8), (1, 1, 8))
    square_sum_strategy = ((8, 1, 1), (8, 1, 1))
    net = Net(mul_strategy, square_sum_strategy)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    input_shape = [128, 32, 64]
    x = Tensor(np.ones(input_shape), dtype=ms.float32)
    y = Tensor(np.ones(input_shape), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
# Executed at import time: set the MindSpore execution mode to GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
# Shared GradOperation built with get_all=True; GradWrap.construct applies it to the wrapped network.
grad_all = C.GradOperation(get_all=True)
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell that runs a sharded UnsortedSegmentProd with a fixed number of segments."""

    def __init__(self, strategy1, strategy2, num_segments):
        super().__init__()
        # shard() receives the two strategies as one pair: (strategy1, strategy2).
        shard_strategy = (strategy1, strategy2)
        self.merge_op = P.UnsortedSegmentProd().shard(shard_strategy)
        self.num_segments = num_segments

    def construct(self, vectors, segment_ids):
        return self.merge_op(vectors, segment_ids, self.num_segments)
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
    """Cell whose output is the result of grad_all applied to the wrapped network."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
    """Cell that feeds the wrapped network's output into a VirtualLoss."""

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x, y):
        return self.loss(self.network(x, y))
|
||||
|
||||
|
||||
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
    """
    Build GradWrap(NetWithLoss(Net(...))) and compile it for inputs x and y.

    Args:
        x: first input passed to the network and to the compiler.
        y: second input passed to the network and to the compiler.
        segments: forwarded to Net as num_segments.
        strategy1: first shard strategy forwarded to Net.
        strategy2: second shard strategy forwarded to Net.
        auto: when true, use "auto_parallel" mode; otherwise "semi_auto_parallel".
    """
    mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=mode)

    forward_with_loss = NetWithLoss(Net(strategy1, strategy2, segments))
    net = GradWrap(forward_with_loss)
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_slice_1d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: 1-D input and 1-D segment ids, both with strategy (8,), 16 segments, 8 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, (8,), (8,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_no_slice_1d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: 1-D input and 1-D segment ids, both with the unsplit strategy (1,), 16 segments, 8 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, (1,), (1,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_index_slice_2d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: (4, 8) input with strategy (4, 1) and 4 segment ids with strategy (4,), 4 segments, 4 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.arange(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (4, 1), (4,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_index_slice_3d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: unsorted_segment_prod net with a (4, 4, 8) input sharded (2, 2, 1) and (4, 4) segment ids
                 sharded (2, 2), 16 segments, 4 devices.
    Expectation: compile raises ValueError.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    x = Tensor(np.ones((4, 4, 8)), ms.float32)
    y = Tensor(np.ones((4, 4)), ms.int32)
    num_segments = 16
    strategy1 = (2, 2, 1)
    strategy2 = (2, 2)
    # NOTE(review): the origin of the ValueError is not visible in this file; presumably the
    # operator rejects this input/strategy combination -- confirm against the operator info.
    with pytest.raises(ValueError):
        compile_graph(x, y, num_segments, strategy1, strategy2)
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_vector_slice_2d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: (4, 8) input with strategy (1, 4) and 4 segment ids with strategy (1,), 4 segments, 4 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (1, 4), (1,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_vector_slice_3d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: (4, 8, 8) input with strategy (1, 2, 2) and 4 segment ids with strategy (1,),
                 4 segments, 4 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (1, 2, 2), (1,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_index_vector_slice_2d():
    """
    Feature: distributed operator UnsortedSegmentProd.
    Description: (4, 8) input with strategy (2, 2) and 4 segment ids with strategy (2,), 4 segments, 4 devices.
    Expectation: the network compiles without raising.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (2, 2), (2,))
|
||||
|
||||
|
||||
def test_unsortedsegmentprod_model_parallel_index_vector_slice_3d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: unsorted_segment_prod net with a (4, 4, 8) input sharded (2, 1, 2) and (4, 4) segment ids
                 sharded (2, 1), 16 segments, 4 devices.
    Expectation: compile raises ValueError.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    x = Tensor(np.ones((4, 4, 8)), ms.float32)
    y = Tensor(np.ones((4, 4)), ms.int32)
    num_segments = 16
    strategy1 = (2, 1, 2)
    strategy2 = (2, 1)
    # NOTE(review): the origin of the ValueError is not visible in this file; presumably the
    # operator rejects this input/strategy combination -- confirm against the operator info.
    with pytest.raises(ValueError):
        compile_graph(x, y, num_segments, strategy1, strategy2)
|
Loading…
Reference in New Issue