Add UnsortedSegmentSum and UnosrtedSemenMin For Oparallel Implements

This commit is contained in:
huangxinjing 2020-10-15 21:24:52 +08:00
parent d4142d682d
commit bf5d21770a
13 changed files with 1046 additions and 2 deletions

View File

@ -947,5 +947,122 @@ double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &
return result;
}
// The forward communication is determined by whether the slice is column split or row split
// The number of segments is actually the shape[0] of the output, which is the cost of the AllReduce
double UnsortedSegmentSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
TensorInfo input0 = inputs[0];
TensorInfo input1 = inputs[1];
TensorInfo output0 = outputs[0];
Shape input0_shape = input0.shape();
Shape input0_slice_shape = inputs[0].slice_shape();
double result = 0.0;
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
}
// If the shape b is not the same as the shape a, we regard it as column slice
for (size_t i = 0; i < input1.shape().size(); ++i) {
if (input0_shape[i] != input0_slice_shape[i]) {
result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
return result;
}
}
return result;
}
double UnsortedSegmentSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
TensorInfo input0 = inputs[0];
TensorInfo input1 = inputs[1];
TensorInfo output0 = outputs[0];
Shape input0_shape = input0.shape();
Shape input0_slice_shape = inputs[0].slice_shape();
double result = 0.0;
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for UnsortedSegmentSum cost";
}
if (is_parameter_[0]) {
// If the forward process has a AllReduce, then the backward also needs one.
for (size_t i = 0; i < input1.shape().size(); ++i) {
if (input0_shape[i] != input0_slice_shape[i]) {
result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
return result;
}
}
}
return result;
}
double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t) const {
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
Shape output_slice_shape = outputs[0].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);
return result;
}
double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
TensorInfo input0 = inputs[0];
TensorInfo input1 = inputs[1];
TensorInfo output0 = outputs[0];
Shape input0_shape = input0.shape();
Shape input0_slice_shape = inputs[0].slice_shape();
double result = 0.0;
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
<< " for UnsortedSegmentMinCost cost";
}
// If the shape b is not the same as the shape a, we regard it as column slice
// The cost is a AllGather operation, the shape is the same as the output of UnsortedSegmentMin.
for (size_t i = 0; i < input1.shape().size(); ++i) {
if (input0_shape[i] != input0_slice_shape[i]) {
result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
return result;
}
}
return result;
}
double UnsortedSegmentMinCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
TensorInfo input0 = inputs[0];
TensorInfo input1 = inputs[1];
TensorInfo output0 = outputs[0];
Shape input0_shape = input0.shape();
Shape input0_slice_shape = inputs[0].slice_shape();
double result = 0.0;
if (inputs_type_lengths_.size() != inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size()
<< " for UnsortedSegmentMinCost cost";
}
if (is_parameter_[0]) {
// If the forward process has a AllGather, then the backward also needs one ReduceScatter.
for (size_t i = 0; i < input1.shape().size(); ++i) {
if (input0_shape[i] != input0_slice_shape[i]) {
result = ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
return result;
}
}
}
return result;
}
double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int32_t) const {
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
Shape output_slice_shape = outputs[0].slice_shape();
// The forward operation is UnsortedSegmentMin + ReudceMin
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]) +
ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]); // ReduceMin
return result;
}
} // namespace parallel
} // namespace mindspore

View File

@ -578,6 +578,58 @@ class DropOutCost : public OperatorCost {
using DropOutCostPtr = std::shared_ptr<DropOutCost>;
class UnsortedSegmentSumCost : public OperatorCost {
public:
explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UnsortedSegmentSumCost() : OperatorCost(true) {}
~UnsortedSegmentSumCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int32_t) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int32_t) const override {
return 0.0;
}
};
using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>;
class UnsortedSegmentMinCost : public OperatorCost {
public:
explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
UnsortedSegmentMinCost() : OperatorCost(true) {}
~UnsortedSegmentMinCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t stage_id) const override {
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
}
double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
}
double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int32_t) const override;
double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
int32_t) const override {
return 0.0;
}
};
using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>;
class LayerNormCost : public OperatorCost {
public:
explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}

View File

@ -173,6 +173,8 @@ REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo);
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
REGISTER(SquareInfo);
REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo);
REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);

View File

@ -35,6 +35,7 @@
#include "frontend/parallel/ops_info/reduce_method_info.h"
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/ops_info/transpose_info.h"
#include "frontend/parallel/ops_info/unsorted_segment_op_info.h"
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
#include "frontend/parallel/ops_info/tile_info.h"

View File

@ -283,6 +283,8 @@ constexpr char IN_TOPK[] = "InTopK";
constexpr char GATHER_ND[] = "GatherNd";
constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum";
constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";

View File

@ -0,0 +1,313 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/unsorted_segment_op_info.h"
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/strategy.h"
#include "ir/tensor.h"
#include "ir/value.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parallel {
// The operator UnsortedSegment accepts three inputs:
// input0 : vector, the shape is x1,x2,x3,...,xr
// input1 : segment id, the shape is x1,x2,..,xn
// input2 : value, the number of the segments
// For Sum: r >= n
// For Min: r >=n, n=1
Status UnsortedSegmentOpInfo::GetAttrs() {
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != UNSORTEDSEGMENTOP_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
return FAILED;
}
if (input_value_.at(2) == nullptr) {
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
return FAILED;
}
if (inputs_shape_.at(0).empty()) {
MS_LOG(ERROR) << name_ << ": input can not be a scalar!";
return FAILED;
}
int num_segments = GetValue<int>(input_value_.at(2));
if (num_segments < 0) {
MS_LOG(ERROR) << name_ << ": the number of segments should be non negative value.";
return FAILED;
}
return SUCCESS;
}
Status UnsortedSegmentOpInfo::CheckStrategy(const StrategyPtr &strategy) {
// Check size
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != UNSORTEDSEGMENTOP_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << UNSORTEDSEGMENTOP_OUTPUTS_SIZE << ", but is "
<< outputs_shape_.size();
return FAILED;
}
// The strategy of the first and the second input should be set.
if (CheckStrategyValue(strategy, {inputs_shape_.at(0), inputs_shape_.at(1)}) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}
Strategys stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1);
Shape input_a_shape = inputs_shape_.at(0);
Shape input_b_shape = inputs_shape_.at(1);
// The size of the input b must be equal or smaller than input a
for (size_t i = 0; i < input_b_shape.size(); ++i) {
if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != input_b_shape[i])) {
MS_LOG(ERROR) << name_
<< " : Invalid strategy. The shape and the strategy of the input0 and input1 "
"should be same before the front size of the input[1]";
return FAILED;
}
}
return SUCCESS;
}
Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra.at(0);
return SUCCESS;
}
// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr
// the dimension x1,x2,x3,..,xn is eliminated
// suppose the strategy of the inputs is (a,b,c,d), (a,b)
// the tensor map of the input vector is (3,2,1,0), id:(1, 0)
// the output vector is (-1, 1, 0)
Status UnsortedSegmentOpInfo::InferTensorMap() {
Shape tensor_map_in;
Shape tensor_map_in_index;
Shape tensor_map_out;
size_t input0_size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < input0_size; ++i) {
tensor_map_in.push_back(SizeToInt(input0_size - i - 1));
tensor_map_in_index.push_back(SizeToInt(input0_size - i - 1));
tensor_map_out.push_back(SizeToInt(input0_size - i - 1));
}
(void)tensor_map_out.erase(tensor_map_out.begin(), tensor_map_out.begin() + inputs_shape_.at(1).size() - 1);
// A special case: the input vector (a,) id (a,) or input vector (a,b,c), id(a,b,c)
// The output vector will be a 1-dim vector,
// These two kinds of situations as row slice.
tensor_map_out[0] = -1;
(void)tensor_map_in_index.erase(tensor_map_in_index.begin() + inputs_shape_.at(1).size(), tensor_map_in_index.end());
if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size()
<< " output size is " << outputs_shape_.at(0).size();
return FAILED;
}
inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
return SUCCESS;
}
Status UnsortedSegmentOpInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo input_index_info(input_index_layout);
TensorInfo output_tensor_info(output_tensor_layout);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_index_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status UnsortedSegmentOpInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status UnsortedSegmentOpInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
// Set the default strategy
Status UnsortedSegmentOpInfo::GenerateStrategies(int32_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
SUCCESS) {
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
return FAILED;
}
for (auto &sp : sp_vector) {
Strategys tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0];
Dimensions second_input_strategy;
for (size_t i = 0; i < inputs_shape_[1].size(); ++i) {
second_input_strategy.push_back(first_input_strategy[i]);
}
tmp_strategy.push_back(first_input_strategy);
tmp_strategy.push_back(second_input_strategy);
sp->ResetInputs(tmp_strategy);
}
size_t success = 0;
for (auto &sp : sp_vector) {
PrintStrategy(sp);
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
PrintStrategy(sp);
}
}
return SUCCESS;
}
// if the dimension of the input b is split, we regarded it as the row slice, thus requires a AllReduce
// otherwise it is column slice,
Status UnsortedSegmentOpInfo::InferForwardCommunication() {
forward_op_.clear();
std::vector<Group> group_list;
Shape tmp_group_tensor_map = outputs_tensor_map_.at(0);
if (repeated_calc_num_ > 1) {
for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) {
tmp_group_tensor_map[i] += 1;
}
tmp_group_tensor_map.push_back(0);
}
if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed.";
return FAILED;
} else if (group_list.empty()) {
MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
return SUCCESS;
}
Operator op;
op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
return SUCCESS;
}
Status UnsortedSegmentOpInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}
std::shared_ptr<Strategys> UnsortedSegmentOpInfo::GenerateBatchStrategies() {
if (inputs_shape_.size() != UNSORTEDSEGMENTOP_INPUTS_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << UNSORTEDSEGMENTOP_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << "GetAttrs failed!";
}
Dimensions strategy_a;
Dimensions strategy_b;
strategy_a.push_back(SizeToInt(dev_num));
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
strategy_a.push_back(1);
}
strategy_b.push_back(SizeToInt(dev_num));
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
strategy_b.push_back(1);
}
Strategys strategy_v = {strategy_a, strategy_b};
return std::make_shared<Strategys>(strategy_v);
}
// When the index is splited, the graph should be replaced
// a special case is when the shape input equals the shape of ids, we regard it as column slice,
// thus there is no need for repalce graphs
ReplaceGraphPtr UnsortedSegmentMinInfo::replace_graph(const CNodePtr &cnode) {
auto input_id_strategy = strategy_->GetInputDim().at(1);
// 1. the two input shapes are same, and the strategy is not all ones
if (std::any_of(input_id_strategy.begin(), input_id_strategy.end(), [](const int32_t &shard) { return shard > 1; })) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
}
return replace_graph_;
}
Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph();
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
// Get the attributes of the UnsortedSegmentMin
auto num_segments = GetValue<int>(input_value_.at(2));
// Step1: Output branch
auto segment_min = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MIN), gen_g.virtual_input_node(),
gen_g.virtual_input_node(), CreatInt32Imm(num_segments)});
auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_min, CreatInt32Imm(0)});
auto all_gather_output = gen_g.PushBack({gen_g.NewOpInst(ALL_GATHER), expandim_output});
auto final_output = gen_g.PushBack({gen_g.NewOpInst(REDUCE_MIN), all_gather_output, CreatInt32Imm(0)});
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(segment_min, 1),
std::make_pair(segment_min, 2)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
std::make_pair(input_nodes, final_output));
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "ir/value.h"
namespace mindspore {
namespace parallel {
constexpr size_t UNSORTEDSEGMENTOP_INPUTS_SIZE = 2;
constexpr size_t UNSORTEDSEGMENTOP_OUTPUTS_SIZE = 1;
class UnsortedSegmentOpInfo : public OperatorInfo {
public:
UnsortedSegmentOpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
~UnsortedSegmentOpInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
private:
Status ComputeReplaceGraph(const CNodePtr &cnode);
};
class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {}
~UnsortedSegmentSumInfo() override = default;
};
class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo {
public:
UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
~UnsortedSegmentMinInfo() override = default;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
Status InferForwardCommunication() override { return SUCCESS; }
protected:
Status ComputeReplaceGraph(const CNodePtr &cnode);
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_UNSORTEDSEGMENTOP_INFO_H_

View File

@ -312,7 +312,8 @@ bool IsSplittableOperator(const std::string &op_name) {
EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH,
EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE,
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE};
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN};
// clang-format on
auto iter = splittable_op.find(op_name);

View File

@ -740,9 +740,25 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
}
// Sovle the input order
// For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
// The Original code here will bind the all operations to the first inputs of theses operatos
// However, the segment_sum operation needs two inputs, To sovle this
// We maintain a dict to count the times of the same operations,
// and bind the inputs according to the times of the op appears.
static std::unordered_map<AnfNodePtr, int> input_map = {};
static int appear_count = 0;
for (auto &replace_input : replace_graph->first) {
auto pre_node = node->input(IntToSize(replace_input.second));
manager->SetEdge(replace_input.first, 1, pre_node);
auto it = input_map.find(replace_input.first);
if (it != input_map.end()) {
appear_count = 1 + it->second;
} else {
appear_count = 1;
}
input_map[replace_input.first] = appear_count;
manager->SetEdge(replace_input.first, appear_count, pre_node);
}
// "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
auto replace_output = replace_graph->second;

View File

@ -0,0 +1,71 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
import mindspore.ops as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, vectors, index):
predict = self.network(vectors, index)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, vectors, index):
return grad_all(self.network)(vectors, index)
def test_auto_parallel_unsortedsegmentmin():
class Net(nn.Cell):
def __init__(self, num_segments):
super().__init__()
self.merge_op = P.UnsortedSegmentMin()
self.num_segments = num_segments
def construct(self, vectors, index):
out = self.merge_op(vectors, index, self.num_segments)
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32)
indices = Tensor(np.random.randint(16, size=(16,)), ms.int32)
net = GradWrap(NetWithLoss(Net(16)))
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, indices)

View File

@ -0,0 +1,71 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
import mindspore.ops as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, vectors, index):
predict = self.network(vectors, index)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, vectors, index):
return grad_all(self.network)(vectors, index)
def test_auto_parallel_unsortedsegmentsum():
class Net(nn.Cell):
def __init__(self, num_segments):
super().__init__()
self.merge_op = P.UnsortedSegmentSum()
self.num_segments = num_segments
def construct(self, vectors, index):
out = self.merge_op(vectors, index, self.num_segments)
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
x = Tensor(np.random.rand(16, 16, 32, 64), dtype=ms.float32)
indices = Tensor(np.random.randint(16, size=(16, 16)))
net = GradWrap(NetWithLoss(Net(16)))
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, indices)

View File

@ -0,0 +1,161 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, num_segments):
super(Net, self).__init__()
self.virtual_dataset = _VirtualDataset()
self.merge_op = P.UnsortedSegmentMin().shard((strategy1, strategy2))
self.num_segments = num_segments
def construct(self, vectors, segment_ids):
predict = self.merge_op(vectors, segment_ids, self.num_segments)
return predict
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
_executor.compile(net, x, y)
def test_unsortedsegmentmin_model_parallel_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (8,)
strategy2 = (8,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_no_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (1,)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_index_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.arange(4), ms.int32)
num_segments = 4
strategy1 = (4, 1)
strategy2 = (4,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 4)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 2, 2)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_index_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (2, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_index_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_float16():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float16)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentmin_model_parallel_int32():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.int32)
y = Tensor(np.ones((4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)

View File

@ -0,0 +1,153 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from tests.ut.python.ops.test_math_ops import VirtualLoss
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, num_segments):
super(Net, self).__init__()
self.virtual_dataset = _VirtualDataset()
self.merge_op = P.UnsortedSegmentSum().shard((strategy1, strategy2))
self.num_segments = num_segments
def construct(self, vectors, segment_ids):
predict = self.merge_op(vectors, segment_ids, self.num_segments)
return predict
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
net = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
net.set_auto_parallel()
net.set_train()
if auto:
context.set_auto_parallel_context(parallel_mode="auto_parallel")
else:
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
_executor.compile(net, x, y)
def test_unsortedsegmentsum_model_parallel_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (8,)
strategy2 = (8,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_no_slice_1d():
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones(8), ms.float32)
y = Tensor(np.ones(8), ms.int32)
num_segments = 16
strategy1 = (1,)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_index_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.arange(4), ms.int32)
num_segments = 4
strategy1 = (4, 1)
strategy2 = (4,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_index_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4, 4)), ms.int32)
num_segments = 16
strategy1 = (2, 2, 1)
strategy2 = (2, 2)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 4)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (1, 2, 2)
strategy2 = (1,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_index_vector_slice_2d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 8)), ms.float32)
y = Tensor(np.ones(4), ms.int32)
num_segments = 4
strategy1 = (2, 2)
strategy2 = (2,)
compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentsum_model_parallel_index_vector_slice_3d():
context.set_auto_parallel_context(device_num=4, global_rank=0)
x = Tensor(np.ones((4, 4, 8)), ms.float32)
y = Tensor(np.ones((4, 4)), ms.int32)
num_segments = 16
strategy1 = (2, 1, 2)
strategy2 = (2, 1)
compile_graph(x, y, num_segments, strategy1, strategy2)