!31040 Add parallel operators for Argmin/Argmax, SquareSumAll and UnsortedSegmentProd

Merge pull request !31040 from Bert0108/reduce_operators_arg
i-robot 2022-03-11 06:15:56 +00:00 committed by Gitee
commit c2a5cc1486
10 changed files with 765 additions and 14 deletions
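For orientation, here is a minimal usage sketch of the operators this commit makes shardable under semi auto parallel. It only reuses the MindSpore APIs exercised by the tests in this commit (P.Argmax, P.Argmin, P.SquareSumAll, P.UnsortedSegmentProd, Primitive.shard); the strategies shown are illustrative, not prescribed by the change.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        # keep strategy 1 on the reduced axis (-1) so the argmax index stays consistent
        self.arg_max = P.Argmax(axis=-1).shard(((4, 2, 1),))
        # SquareSumAll takes two inputs, so its strategy holds two tuples
        self.square_sum_all = P.SquareSumAll().shard(((2, 4, 1), (2, 4, 1)))

    def construct(self, x, y):
        idx = self.arg_max(x)
        sums = self.square_sum_all(x, y)
        return idx, sums

net = Net()
x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
# compilation follows the same pattern as the tests below (_cell_graph_executor.compile)

The UnsortedSegmentProd sharding is exercised the same way in the new test file at the end of this diff.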

View File

@ -753,6 +753,7 @@ class ReduceSumCost : public OperatorCost {
};
using ReduceMethodCost = ReduceSumCost;
using ReduceProdCost = ReduceSumCost;
using SquareSumAllCost = ReduceSumCost;
class ReduceMeanCost : public ReduceSumCost {
public:
@ -784,6 +785,7 @@ class ArgMaxWithValueCost : public ReduceSumCost {
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using ArgMinWithValueCost = ArgMaxWithValueCost;
using ArgmaxCost = ArgMaxWithValueCost;
class GetNextCost : public OperatorCost {
public:
@ -919,6 +921,7 @@ class UnsortedSegmentSumCost : public OperatorCost {
// Taking account of input
void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using UnsortedSegmentProdCost = UnsortedSegmentSumCost;
class UnsortedSegmentMinCost : public OperatorCost {
public:

View File

@ -217,6 +217,10 @@ REGISTER(CropAndResizeInfo);
REGISTER(ROIAlignInfo);
REGISTER(ReduceProdInfo);
REGISTER(ReduceAllInfo);
REGISTER(ArgmaxInfo);
REGISTER(ArgminInfo);
REGISTER(UnsortedSegmentProdInfo);
REGISTER(SquareSumAllInfo);
} // namespace parallel
} // namespace mindspore

View File

@ -111,6 +111,7 @@ constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
constexpr char GET_OP_FUNCTION[] = "_get_python_op";
constexpr char KEEP_DIMS[] = "keep_dims";
constexpr char OUTPUT_TYPE[] = "output_type";
constexpr char CROSS_BATCH[] = "cross_batch";
constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
@ -335,6 +336,8 @@ constexpr char REDUCE_ALL[] = "ReduceAll";
constexpr char REDUCE_ANY[] = "ReduceAny";
constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue";
constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
constexpr char ARGMAX[] = "Argmax";
constexpr char ARGMIN[] = "Argmin";
constexpr char CONV2D[] = "Conv2D";
constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
@ -432,6 +435,7 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax";
constexpr char UNSORTED_SEGMENT_PROD[] = "UnsortedSegmentProd";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char DROPOUT[] = "Dropout";
@ -450,6 +454,7 @@ constexpr char RANDOM_CHOICE_WITH_MASK[] = "RandomChoiceWithMask";
constexpr char CROP_AND_RESIZE[] = "CropAndResize";
constexpr char MASKED_FILL[] = "MaskedFill";
constexpr char ROI_ALIGN[] = "ROIAlign";
constexpr char SQUARE_SUM_ALL[] = "SquareSumAll";
// pipeline
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
@ -483,6 +488,7 @@ constexpr char MAKE_RECORD[] = "make_record";
constexpr char LIST_GETITEM[] = "list_getitem";
constexpr char ARRAY_GETITEM[] = "array_getitem";
constexpr char TUPLE_SETITEM[] = "tuple_setitem";
constexpr char TUPLE_GETITEM[] = "tuple_getitem";
constexpr char LIST_SETITEM[] = "list_setitem";
constexpr char ARRAY_SETITEM[] = "array_setitem";
constexpr char DICT_GETITEM[] = "dict_getitem";

View File

@ -26,6 +26,7 @@
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "utils/log_adapter.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
@ -578,15 +579,321 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() {
return SUCCESS;
}
std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
std::vector<int64_t> ArgmaxInfo::reduce_dim() {
// get the axis from the attributes
std::vector<int64_t> dim_list;
auto iter = attrs_.find(AXIS);
if (iter == attrs_.end()) {
MS_LOG(EXCEPTION) << name_ << ": The axis attribute is missing.";
}
return sp_vector;
MS_ASSERT(inputs_shape_.size() == 1);
auto input_dim = inputs_shape_.at(0).size();
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<ValueTuple>()) {
auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
if (attr_axis.empty()) {
for (size_t i = 0; i < input_dim; ++i) {
dim_list.push_back(SizeToLong(i));
}
} else {
for (auto &axis : attr_axis) {
axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
}
}
} else if (iter->second->isa<Int64Imm>()) {
int64_t axis = GetValue<int64_t>(iter->second);
axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
} else {
MS_LOG(EXCEPTION) << "Axis type is invalid.";
}
return dim_list;
}
Status ArgmaxInfo::GetAttrs() {
// set keep_dims to false by default
keepdims_ = false;
// get attr output_type and cross_batch
auto output_type_iter = attrs_.find(OUTPUT_TYPE);
if (output_type_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(output_type_iter->second);
}
auto cross_batch_iter = attrs_.find(CROSS_BATCH);
if (cross_batch_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
if (!cross_batch_iter->second->isa<BoolImm>()) {
MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
return FAILED;
}
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
}
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
if (reducemethodcost == nullptr) {
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
return FAILED;
}
reducemethodcost->set_cross_batch(cross_batch_);
return SUCCESS;
}
Status ArgmaxInfo::CheckStrategy(const StrategyPtr &strategy) {
if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed";
return FAILED;
}
std::vector<int64_t> dim_list = reduce_dim();
MS_ASSERT(dim_list.size() == 1);
Strategys stra = strategy->GetInputDim();
MS_ASSERT(stra.size() == 1);
Shape input_strategy = stra.at(0);
MS_ASSERT(dim_list.at(0) < input_strategy.size());
if (input_strategy.at(LongToSize(dim_list.at(0))) != 1) {
MS_LOG(WARNING) << name_
<< ": CheckStrategy for Argmax/Argmin, the strategy corresponding to the axis is not 1, the real strategy is "
<< input_strategy.at(LongToSize(dim_list.at(0)))
<< ", so the output index may not be compatible with the standalone primitive";
}
return SUCCESS;
}
Status ArgmaxInfo::InferMirrorOps() {
if (OperatorInfo::InferMirrorOps() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed";
return FAILED;
}
return SUCCESS;
}
std::vector<int64_t> SquareSumAllInfo::reduce_dim() {
std::vector<int64_t> dim_list;
auto input_dim = inputs_shape_.at(0).size();
// reduce all dimensions
for (size_t i = 0; i < input_dim; ++i) {
dim_list.push_back(SizeToLong(i));
}
return dim_list;
}
Status SquareSumAllInfo::GetAttrs() {
// set keep_dims to false by default
keepdims_ = false;
// get attr cross_batch
auto cross_batch_iter = attrs_.find(CROSS_BATCH);
if (cross_batch_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
if (!cross_batch_iter->second->isa<BoolImm>()) {
MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
return FAILED;
}
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
}
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
if (reducemethodcost == nullptr) {
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
return FAILED;
}
reducemethodcost->set_cross_batch(cross_batch_);
return SUCCESS;
}
Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED;
}
Strategys stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1);
Shape input_a_shape = inputs_shape_.at(0);
Shape input_b_shape = inputs_shape_.at(1);
MS_ASSERT(input_a_shape.size() == input_b_shape.size());
MS_ASSERT(sub_a_strategy.size() == sub_b_strategy.size());
for (size_t i = 0; i < input_a_shape.size(); ++i) {
if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != input_b_shape[i])) {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED;
}
}
return SUCCESS;
}
Status SquareSumAllInfo::InferDevMatrixShape() {
Strategys strategy = strategy_->GetInputDim();
Dimensions sub_a_strategy = strategy.at(0);
Shape dev_shape;
for (size_t i = 0; i < sub_a_strategy.size(); ++i) {
dev_shape.push_back(sub_a_strategy[i]);
}
dev_matrix_shape_ = dev_shape;
return SUCCESS;
}
Status SquareSumAllInfo::InferTensorMap() {
if (ReduceMethod::InferTensorMap() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed";
return FAILED;
}
MS_ASSERT(outputs_tensor_map_.size() == 1);
inputs_tensor_map_.push_back(inputs_tensor_map_[0]);
outputs_tensor_map_.push_back(outputs_tensor_map_[0]);
return SUCCESS;
}
Status SquareSumAllInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy = InferOutputStrategy();
Strategys outputs_strategy = {output_strategy, output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
Shape output_slice_shape = outputs_slice_shape.at(0);
TensorLayout input_tensor_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
return FAILED;
}
std::vector<int64_t> dim_list = reduce_dim();
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
input_tensor_info.set_reduce_dim(dim_list);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status SquareSumAllInfo::InferGroup() {
Dimensions stra = strategy_->GetInputDim().at(0);
forward_op_.clear();
std::vector<int64_t> dim_list = reduce_dim();
size_t size = stra.size();
// check whether the reduce dim is partitioned.
Shape group_creat_map;
// if repeated calculation is used and repeated_calc_num_ is inserted as the first dimension of the dev matrix,
// the first dimension of the map needs to be handled.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
}
for (size_t index = 0; index < size; ++index) {
auto pos =
std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
if (pos != dim_list.end() && stra[index] != 1) {
continue;
}
group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
}
// if repeated calculation is used and repeated_calc_num_ is inserted as the last dimension of the dev matrix,
// shift the existing entries of group_creat_map and append 0 as its last dimension.
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
for (auto &ele : group_creat_map) {
if (ele == MAP_NONE) {
continue;
}
ele += 1;
}
group_creat_map.push_back(0);
}
if (CreateGroupByTensorMap(group_creat_map, &group_) != SUCCESS) {
ReportError(name_ + ": Create group failed.");
return FAILED;
}
return SUCCESS;
}
Status SquareSumAllInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
return FAILED;
}
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferGroup failed";
return FAILED;
}
MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage();
auto square_sum_all =
gen_g.PushBack({gen_g.NewOpInst(SQUARE_SUM_ALL), gen_g.virtual_input_node(), gen_g.virtual_input_node()});
auto get_item0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(0)});
auto get_item1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(1)});
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0].name()));
OperatorAttrs attrs = {attr_op, attr_group};
auto allreduce_op0 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item0});
auto allreduce_op1 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item1});
auto make_list = gen_g.PushBack({gen_g.NewOpInst(MAKE_LIST), allreduce_op0, allreduce_op1});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(square_sum_all, 1),
std::make_pair(square_sum_all, 2)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, make_list));
return SUCCESS;
}
ReplaceGraphPtr SquareSumAllInfo::replace_graph(const CNodePtr &cnode) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
Status SquareSumAllInfo::InferMirrorOps() {
if (OperatorInfo::InferMirrorOps() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed";
return FAILED;
}
return SUCCESS;
}
Status SquareSumAllInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
return FAILED;
}
MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
if (outputs_tensor_map_[0].empty()) {
as_loss_divisor_ = stage_device_size_;
MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor.";
return SUCCESS;
}
as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_);
std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]);
MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_shape_str
<< ", " << output_tensor_map_str << ", " << as_loss_divisor_;
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
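To make the intent of SquareSumAllInfo::ComputeReplaceGraph above easier to follow: each device runs SquareSumAll on its local slice, and both scalar outputs are then AllReduce-summed over the group computed by InferGroup. The following is a rough numpy sketch of that semantics only; the helper names are illustrative, not framework APIs.

import numpy as np

def local_square_sum_all(x_slice, y_slice):
    # the local SquareSumAll: sum of squares of each input slice
    return np.sum(x_slice ** 2), np.sum(y_slice ** 2)

def all_reduce_sum(partials):
    # stands in for AllReduce(op=sum) over the reduction group
    return sum(partials)

x = np.arange(8.0)
y = np.arange(8.0) + 1.0
# two "devices", each holding half of the data along the reduced dimension
slices = [(x[:4], y[:4]), (x[4:], y[4:])]

partials = [local_square_sum_all(xs, ys) for xs, ys in slices]
out0 = all_reduce_sum([p[0] for p in partials])   # equals np.sum(x ** 2)
out1 = all_reduce_sum([p[1] for p in partials])   # equals np.sum(y ** 2)
assert np.isclose(out0, np.sum(x ** 2)) and np.isclose(out1, np.sum(y ** 2))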

View File

@ -76,8 +76,6 @@ class ArgMaxWithValueInfo : public ReduceMethod {
~ArgMaxWithValueInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
protected:
std::vector<int64_t> reduce_dim() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
@ -167,6 +165,63 @@ class ReduceAllInfo : public ReduceAnyInfo {
~ReduceAllInfo() override = default;
};
class ArgmaxInfo : public ReduceMethod {
public:
ArgmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgmaxCost>()) {
reduce_method_ = REDUCE_OP_MAX;
}
~ArgmaxInfo() override = default;
protected:
std::vector<int64_t> reduce_dim() override;
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
};
class ArgminInfo : public ArgmaxInfo {
public:
ArgminInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArgmaxInfo(name, inputs_shape, outputs_shape, attrs) {
reduce_method_ = REDUCE_OP_MIN;
}
~ArgminInfo() override = default;
};
class SquareSumAllInfo : public ReduceMethod {
public:
SquareSumAllInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareSumAllCost>()) {
reduce_method_ = REDUCE_OP_SUM;
}
~SquareSumAllInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
std::vector<int64_t> reduce_dim() override;
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferMirrorOps() override;
Status InferAsLossDivisor() override;
private:
Status InferGroup();
Status ComputeReplaceGraph(const CNodePtr &cnode);
std::vector<Group> group_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_

View File

@ -205,7 +205,7 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() {
return SUCCESS;
}
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
Operator op = CreateAllReduceOp(reduce_method_, group_list[0].name());
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
return SUCCESS;
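The change from CreateAllReduceOp(REDUCE_OP_SUM, ...) to CreateAllReduceOp(reduce_method_, ...) lets each UnsortedSegment* variant combine per-device partial results with its own reduction; for UnsortedSegmentProd that is an elementwise product. A rough numpy sketch of why the product is the right combiner when the segmented axis is split (the helper below is illustrative, not the framework code):

import numpy as np

def unsorted_segment_prod(data, segment_ids, num_segments):
    # initialize with 1, the identity for products, so segments a device
    # does not own contribute nothing to the combined result
    out = np.ones((num_segments,) + data.shape[1:], dtype=data.dtype)
    for value, seg in zip(data, segment_ids):
        out[seg] *= value
    return out

data = np.array([2.0, 3.0, 4.0, 5.0])
ids = np.array([0, 1, 0, 1])
full = unsorted_segment_prod(data, ids, 2)            # [8., 15.]

# split the segmented axis across two "devices"
part0 = unsorted_segment_prod(data[:2], ids[:2], 2)   # [2., 3.]
part1 = unsorted_segment_prod(data[2:], ids[2:], 2)   # [4., 5.]
combined = part0 * part1                              # AllReduce(prod)
assert np.allclose(combined, full)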

View File

@ -50,6 +50,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected:
std::string reduce_method_;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferMirrorOps() override;
@ -65,15 +66,29 @@ class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {}
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {
reduce_method_ = REDUCE_OP_SUM;
}
~UnsortedSegmentSumInfo() override = default;
};
class UnsortedSegmentProdInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentProdInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentProdCost>()) {
reduce_method_ = REDUCE_OP_PROD;
}
~UnsortedSegmentProdInfo() override = default;
};
class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {
reduce_method_ = REDUCE_OP_MIN;
}
~UnsortedSegmentMinInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
@ -86,7 +101,9 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {}
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {
reduce_method_ = REDUCE_OP_MAX;
}
~UnsortedSegmentMaxInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;

View File

@ -174,7 +174,8 @@ bool IsSplittableOperator(const std::string &op_name) {
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, CUMSUM, FAST_GELU, IOU,
BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL};
BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL,
ARGMAX, ARGMIN, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -415,6 +415,29 @@ class ArgMinWithValueNet(nn.Cell):
out = self.mul2(out, b)
return out
class ArgMaxNet(nn.Cell):
def __init__(self, strategy1, strategy2):
super(ArgMaxNet, self).__init__()
self.mul1 = P.Mul().shard(strategy1)
self.arg_max = P.Argmax(axis=-1).shard(strategy2)
def construct(self, x, y):
out = self.mul1(x, y)
out = self.arg_max(out)
return out
class ArgMinNet(nn.Cell):
def __init__(self, strategy1, strategy2):
super(ArgMinNet, self).__init__()
self.mul1 = P.Mul().shard(strategy1)
self.arg_min = P.Argmin(axis=-1).shard(strategy2)
def construct(self, x, y):
out = self.mul1(x, y)
out = self.arg_min(out)
return out
def gen_inputs_and_compile_net(net):
x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
@ -514,6 +537,90 @@ def test_arg_min_with_value_mul_auto():
gen_inputs_and_compile_net(net)
def test_arg_max_semi_axis_parallel():
"""
Feature: test Argmax semi parallel strategy
Description: partition the reduced axes
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 4, 2), (1, 4, 2))
strategy2 = ((4, 1, 2),)
net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_arg_max_mul_semi():
"""
Feature: test Argmax model parallel strategy
Description: partition the non-reduced axes
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 4, 2), (1, 4, 2))
strategy2 = ((4, 2, 1),)
net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_arg_max_mul_auto():
"""
Feature: test Argmax auto parallel strategy
Description: don't set the strategy
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = None
strategy2 = None
net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_semi_axis_parallel():
"""
Feature: test Argmin semi parallel strategy
Description: partition the reduced axes
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 4, 2), (1, 4, 2))
strategy2 = ((4, 1, 2),)
net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_mul_semi():
"""
Feature: test Argmin model parallel strategy
Description: partition the non-reduced axes
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 4, 2), (1, 4, 2))
strategy2 = ((4, 2, 1),)
net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_mul_auto():
"""
Feature: test Argmin auto parallel strategy
Description: don't set the strategy
Expectation: compile success
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = None
strategy2 = None
net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
class ArgMinWithValueNet2(nn.Cell):
def __init__(self, strategy1, strategy2, strategy3):
super(ArgMinWithValueNet2, self).__init__()
@ -915,3 +1022,59 @@ def test_prod_mul_auto():
net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
gen_inputs_and_compile_net_no_bias(net)
def test_square_sum_all_mul():
"""
Feature: test SquareSumAll model parallel strategy
Description: partition the reduced axes
Expectation: compile success
"""
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super(Net, self).__init__()
self.mul1 = P.Mul().shard(strategy1)
self.square_sum_all = P.SquareSumAll().shard(strategy2)
def construct(self, x, y):
out = self.mul1(x, y)
out = self.square_sum_all(out, out)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 1, 8), (1, 1, 8))
strategy2 = ((2, 4, 1), (2, 4, 1))
net = Net(strategy1, strategy2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
compile_net_no_bias(net, x, y)
def test_square_sum_all_mul2():
"""
Feature: test SquareSumAll model parallel strategy
Description: partition the reduced axes
Expectation: compile success
"""
class Net(nn.Cell):
def __init__(self, stra_mul, stra_prod):
super(Net, self).__init__()
self.mul = P.Mul().shard(stra_mul)
self.square_sum_all = P.SquareSumAll().shard(stra_prod)
def construct(self, x, y):
out = self.mul(x, y)
out = self.square_sum_all(out, out)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((1, 1, 8), (1, 1, 8))
strategy2 = ((8, 1, 1), (8, 1, 1))
net = Net(strategy1, strategy2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
compile_net_no_bias(net, x, y)

View File

@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, num_segments):
super(Net, self).__init__()
self.merge_op = P.UnsortedSegmentProd().shard((strategy1, strategy2))
self.num_segments = num_segments
def construct(self, vectors, segment_ids):
predict = self.merge_op(vectors, segment_ids, self.num_segments)
return predict
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.network = network
self.loss = VirtualLoss()
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x, y)
def test_unsortedsegmentprod_model_parallel_slice_1d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with model parallel strategy, slice 1d.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (8,)
strategy2 = (8,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_no_slice_1d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with a no-slice strategy in semi auto parallel, 1d input.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (1,)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_index_slice_2d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with model parallel strategy, slice the index dimension of a 2d input.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.arange(4), ms.int32)
num_segments = 4
strategy1 = (4, 1)
strategy2 = (4,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_index_slice_3d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with model parallel strategy, slice 3d.
Expectation: raise ValueError due to the invalid strategy.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4, 4)), ms.int32)
num_segments = 16
strategy1 = (2, 2, 1)
strategy2 = (2, 2)
with pytest.raises(ValueError):
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_vector_slice_2d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with model parallel strategy, slice the non-index (vector) dimension of a 2d input.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 4)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_vector_slice_3d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with model parallel, slice 3d.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 2, 2)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_index_vector_slice_2d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with strategy in semi auto parallel, slice 2d.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (2, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_index_vector_slice_3d():
"""
Feature: distribute operator unsorted_segment_prod in auto parallel.
Description: unsorted_segment_prod net with strategy in semi auto parallel, slice 3d.
Expectation: raise ValueError due to the invalid strategy.
"""
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4, 4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2, 1)
with pytest.raises(ValueError):
compile_graph(x, y, num_segments, strategy1, strategy2)