diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
index 1f420e8797f..960e13281cf 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
@@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
 }
+
+// Return the per-device communication cost in the forward phase.
+double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                        const int32_t&) const {
+  // GatherV2 does not need communication in the forward phase.
+  return 0.0;
+}
+
+// Return the per-device communication cost in the backward phase.
+double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                         const int32_t&) const {
+  // GatherV2 does not need communication in the backward phase.
+  return 0.0;
+}
+
+double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                               const int32_t&) const {
+  // In the forward phase, the computation cost = slice(A) + slice(B), i.e. the byte sizes of the local
+  // input slice and the local index slice.
+  Shape input0_slice_shape = inputs[0].slice_shape();
+  Shape input1_slice_shape = inputs[1].slice_shape();
+  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
+                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  return result;
+}
+
+double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                                const int32_t&) const {
+  return 0.0;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
index b642ada0d95..685cb259c31 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
@@ -81,6 +81,8 @@ class OperatorCost {
   std::vector<size_t> outputs_type_lengths_;
 };
 
+using OperatorCostPtr = std::shared_ptr<OperatorCost>;
+
 class MatMulCost : public OperatorCost {
  public:
   MatMulCost() = default;
@@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
 };
 
 using DropOutCostPtr = std::shared_ptr<DropOutCost>;
+
+class GatherV2Cost : public OperatorCost {
+ public:
+  GatherV2Cost() = default;
+  ~GatherV2Cost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                     const int32_t& stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override;
+  double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                             const int32_t& stage_id) const override;
+  double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                   const int32_t& stage_id) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                    const int32_t&) const override;
+};
+
+using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
index 793452b8ad2..b1d9b8b60e8 100644
--- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc
@@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
   }
 }
 
-void GatherV2Info::ReComputeBatchSplitFlagList() {
-  MS_ASSERT(inputs_shape_.size() == 2);
-  MS_ASSERT(input_value_.size() == 3);
-  MS_ASSERT(input_value_[0] == nullptr);
-  // the second input is the index tensor
-  MS_ASSERT(input_value_[1] != nullptr);
-  // the third input is the axis
-  MS_ASSERT(input_value_[2] != nullptr);
-  int axis = GetValue<int>(input_value_[2]);
-  MS_ASSERT(axis < inputs_shape_[0].size() && axis >= 0 - inputs_shape_[0].size());
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  split_flag_list_[0] = true;
-  // if the gather axis is 0, the index's strategy is equal to the device number
-  if (axis == 0) {
-    split_flag_list_[1] = true;
-  }
-}
-
 Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
index fae96dcab5b..093bfb8fad7 100644
--- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h
@@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
-
-class GatherV2Info : public BatchParallelInfo {
- public:
-  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-               const PrimitiveAttrs& attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
-  ~GatherV2Info() override = default;
-  void ReComputeBatchSplitFlagList() override;
-};
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc
new file mode 100644
index 00000000000..2010d1ed46a
--- /dev/null
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc
@@ -0,0 +1,350 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "parallel/ops_info/gather_v2_info.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/meta_tensor.h"
+#include "ir/value.h"
+#include "parallel/auto_parallel/costmodel.h"
+#include "parallel/device_matrix.h"
+#include "parallel/graph_util/generate_graph.h"
+#include "parallel/strategy.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status GatherV2Info::GetAttrs() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
+    return FAILED;
+  }
+  if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
+    MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size();
+    return FAILED;
+  }
+  // The second input is the index tensor.
+
+  // The third input is the axis and must be a ValueNode.
+  if (input_value_.at(2) == nullptr) {
+    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, so it is not a ValueNode!";
+    return FAILED;
+  }
+
+  if (inputs_shape_.at(0).size() == 0) {
+    MS_LOG(ERROR) << name_ << ": the input cannot be a scalar!";
+    return FAILED;
+  }
+  int axis = GetValue<int>(input_value_.at(2));
+  if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) {
+    MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
+                  << inputs_shape_.at(0).size() << ").";
+    return FAILED;
+  }
+  if (axis < 0) {
+    axis += SizeToInt(inputs_shape_[0].size());
+  }
+  axis_ = axis;
+
+  index_size_ = inputs_shape_.at(1).size();
+
+  return SUCCESS;
+}
+
+Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  // Only the strategy of the first input should be set.
+  if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    }
+    return FAILED;
+  }
+  axis_strategy_ = strategy->GetInputDim().at(0).at(axis_);
+  if (index_size_ != 1 && axis_strategy_ != 1) {
+    MS_LOG(ERROR) << name_
+                  << ": Invalid strategy. If the index is a scalar or a tensor of more than one dimension, the "
+                     "strategy corresponding to axis must be 1, but it is "
+                  << axis_strategy_;
+    return FAILED;
+  }
+  if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
+    MS_LOG(ERROR) << name_
+                  << ": Invalid strategy. The first dimension of the index cannot be divided evenly by the strategy "
+                     "corresponding to axis. The first dimension of the index is "
+                  << inputs_shape_.at(1).at(0) << ", and the strategy corresponding to axis is " << axis_strategy_;
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status GatherV2Info::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  dev_matrix_shape_ = stra.at(0);
+  return SUCCESS;
+}
+
+// If the index is a scalar, the output dimension is the input dimension minus 1;
+// if the index is an n-dimension tensor, the output dimension is the input dimension plus (n - 1).
+// Each tensor map dimension is equal to the corresponding input and output dimension.
+// If the index's dimension is more than 1, we insert -1 into the output tensor map.
+// For example (hypothetical shapes): input [A, B, C] (tensor map [2, 1, 0]), axis 1 and a 2-D index
+// yield the output tensor map [2, -1, 1, 0].
+Status GatherV2Info::InferTensorMap() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  std::vector<int32_t> tensor_map_in;
+  std::vector<int32_t> tensor_map_out;
+  size_t size = inputs_shape_.at(0).size();
+  // e.g. size = 4: tensor map [3, 2, 1, 0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_in.push_back(SizeToInt(size - i - 1));
+    tensor_map_out.push_back(SizeToInt(size - i - 1));
+  }
+
+  if (index_size_ == 0) {
+    (void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
+  } else if (index_size_ > 1) {
+    (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
+  }
+  if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
+    MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is "
+                  << tensor_map_out.size() << ", output size is " << outputs_shape_.at(0).size();
+    return FAILED;
+  }
+
+  std::vector<int32_t> tensor_map_in_index;
+  if (index_size_ >= 1) {
+    tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1));
+  }
+  for (size_t i = 1; i < index_size_; ++i) {
+    tensor_map_in_index.push_back(-1);
+  }
+  inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
+  inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
+  outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
+  return SUCCESS;
+}
+
+Status GatherV2Info::InferTensorInfo() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_tensor_map_.size();
+    return FAILED;
+  }
+  if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_tensor_map_.size();
+    return FAILED;
+  }
+  // Infer tensor shapes.
+  Shape input_shape = inputs_shape_.at(0);
+  Shape input_index_shape = inputs_shape_.at(1);
+  Shape output_shape = outputs_shape_.at(0);
+
+  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
+  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
+      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
+      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout);
+  TensorInfo input_index_info(input_index_layout);
+  TensorInfo output_tensor_info(output_tensor_layout);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  inputs_tensor_info_.push_back(input_index_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+OperatorVector CreateSubOp(int32_t sub_value) {
+  OperatorVector ops;
+  OperatorName operator_name = SUB;
+  OperatorAttrs operator_attrs;
+
+  py::tuple tuple = py::make_tuple(sub_value);
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, kInt32);
+  ValuePtr op_param_value = MakeValue(tensor_ptr);
+
+  Attr op1_param = std::make_pair("", op_param_value);
+  OperatorParams operator_param = {std::make_pair(op1_param, 2)};
+
+  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
+  Operator op = std::make_pair(operator_name, operator_args);
+  ops.push_back(op);
+  return ops;
+}
+
+Status GatherV2Info::InferTensorSubOps() {
+  sub_ops_.clear();
+  if ((index_size_ == 0) || (axis_strategy_ == 1)) {
+    return SUCCESS;
+  }
+  // mod_n is the number of devices that vary without changing the slice along axis.
+  int32_t mod_n = 1;
+  for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
+    mod_n *= dev_matrix_shape_.at(i);
+  }
+  if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) {
+    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
+    return FAILED;
+  }
+  int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
+  int32_t rank = g_device_manager->global_rank();
+  int32_t mod_rank = rank % mod_p;
+  mod_rank = static_cast<int32_t>(mod_rank / mod_n);
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) {
+    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
+    return FAILED;
+  }
+  // Subtract the global start of this rank's slice along axis from the index tensor.
+  int32_t sub_value = static_cast<int32_t>(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank;
+
+  OperatorVector sub_op;
+  sub_ops_.emplace_back(std::move(sub_op));
+  sub_op = CreateSubOp(sub_value);
+  sub_ops_.emplace_back(std::move(sub_op));
+  return SUCCESS;
+}
+
+Status GatherV2Info::Init(const StrategyPtr& strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+  Status status = InferTensorSubOps();
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
+    return status;
+  }
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+
+Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
+  if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size()); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto& sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (GetAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << "GetAttrs failed!"; + } + + Dimensions strategy; + if (index_size_ != 1) { + strategy.push_back(1); + } else { + strategy.push_back(SizeToInt(dev_num)); + } + for (size_t i = 1; i < inputs_shape_[0].size(); i++) { + strategy.push_back(1); + } + std::vector strategy_v = {strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h new file mode 100644 index 00000000000..773d46f4294 --- /dev/null +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "parallel/auto_parallel/operator_costmodel.h" +#include "parallel/ops_info/operator_info.h" +#include "parallel/strategy.h" + +namespace mindspore { +namespace parallel { +constexpr size_t GATHER_V2_INPUTS_SIZE = 2; +constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; +constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; +// We now supported limited parallel strategies. +// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of +// the input. +// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. 
+class GatherV2Info : public OperatorInfo {
+ public:
+  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+               const PrimitiveAttrs& attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
+        axis_(-1),
+        index_size_(0),
+        axis_strategy_(1) {}
+  ~GatherV2Info() override = default;
+  Status Init(const StrategyPtr& strategy) override;
+  Status InitForCostModel(const StrategyPtr& strategy) override;
+
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
+  std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr& strategy) override;
+  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status GetAttrs() override;
+
+ private:
+  Status InferTensorSubOps();
+
+  int32_t axis_;
+  size_t index_size_;
+  int32_t axis_strategy_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
index 68c73bc548e..42755b3ec36 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
@@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
   dev_matrix_shape_.clear();
   forward_op_.clear();
   mirror_ops_.clear();
+  sub_ops_.clear();
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index 8fcae8ad330..248172fa4c0 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -41,6 +41,7 @@ namespace mindspore {
 namespace parallel {
 using ForwardOp = OperatorVector;
 using MirrorOps = std::vector<OperatorVector>;
+using Ops = std::vector<OperatorVector>;
 using VirtualDivOp = OperatorVector;
 using TensorMaps = std::vector<std::vector<int32_t>>;
 using TensorLayouts = std::vector<TensorLayout>;
@@ -99,6 +100,7 @@ class OperatorInfo {
   OutPutInfoVector replace_op_info() const { return replace_op_info_; }
   virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; }
   MirrorOps mirror_ops() const { return mirror_ops_; }
+  Ops sub_ops() const { return sub_ops_; }
   VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
@@ -190,6 +192,7 @@ class OperatorInfo {
   TensorMaps inputs_tensor_map_;
   TensorMaps outputs_tensor_map_;
   ForwardOp forward_op_;
+  Ops sub_ops_;
   ForwardOp replace_op_;
   OutPutInfoVector replace_op_info_;
   ReplaceGraphPtr replace_graph_;
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
index 1681c8f7962..27b434ecca8 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h
@@ -24,6 +24,7 @@
 #include "parallel/ops_info/comparison_function_info.h"
 #include "parallel/ops_info/dropout_do_mask_info.h"
 #include "parallel/ops_info/elementary_function_info.h"
+#include "parallel/ops_info/gather_v2_info.h"
 #include "parallel/ops_info/get_next_info.h"
"parallel/ops_info/l2_normalize_info.h" #include "parallel/ops_info/loss_info.h" diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 0a6d0b0bef8..eab5443481c 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { MS_EXCEPTION_IF_NULL(func_graph); Operator op = CreateGetTensorSliceOp(tensor_layout); InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR); + if (!op_info->sub_ops().empty()) { + auto sub_ops = op_info->sub_ops(); + for (size_t i = 0; i < sub_ops.size(); i++) { + if (!sub_ops.at(i).empty()) { + InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB); + } + } + } } void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) { diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py index c623595b53c..3ea0795e9c4 100644 --- a/tests/ut/python/parallel/test_gather_v2_primitive.py +++ b/tests/ut/python/parallel/test_gather_v2_primitive.py @@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell from mindspore import context context.set_context(mode=context.GRAPH_MODE) +device_number = 32 +batch_size_per_device = 128 class Dataset(): @@ -57,15 +59,22 @@ class Dataset(): class GatherV2(_Loss): - def __init__(self, batchsize): + def __init__(self, index_dim, strategy, index_size=16): super(GatherV2, self).__init__() self.pow = P.Pow() - emb_list = list(range(batchsize)) - emb1_list = emb_list[0::2] - emb2_list = emb_list[1::2] + emb1_list = 21 + emb2_list = 2 + if index_dim == 1: + emb_list = list(range(index_size)) + emb1_list = emb_list[0::2] + emb2_list = emb_list[1::2] + if index_dim == 2: + emb_list = np.arange(index_size*16) + emb1_list = np.reshape(emb_list[0::2], (int(index_size/2), 16)) + emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16)) self.emb1_param = Tensor(emb1_list, dtype=mstype.int32) self.emb2_param = Tensor(emb2_list, dtype=mstype.int32) - self.gatherv2 = P.GatherV2() + self.gatherv2 = P.GatherV2().set_strategy(strategy) def construct(self, nembeddings): emb1 = self.gatherv2(nembeddings, self.emb1_param, 0) @@ -73,10 +82,6 @@ class GatherV2(_Loss): return self.pow((emb1 - emb2), 2.0) -def get_loss(batchsize): - return GatherV2(batchsize) - - def fc_with_initialize(input_channels, out_channels): return Dense(input_channels, out_channels) @@ -114,26 +119,23 @@ class TrainOneStepCell(Cell): return F.depend(loss, self.optimizer(grads)) -def test_trains(): +def net_trains(gather_v2_strategy, criterion, rank): init() lr = 0.1 momentum = 0.9 max_epoch = 20 - device_number = 32 - batch_size_per_device = 128 input_channels = 256 out_channels = 512 - + context.set_context(mode=context.GRAPH_MODE, save_graphs=False) context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number) + context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number, + global_rank=rank) predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32) dataset = Dataset(predict, 4) network = fc_with_initialize(input_channels, out_channels) network.set_train() - criterion = get_loss(batch_size_per_device * device_number) - train_network = BuildTrainNetwork(network, criterion) train_network.set_train() opt = 
     opt = Momentum(train_network.trainable_params(), lr, momentum)
@@ -143,5 +145,90 @@ def net_trains(gather_v2_strategy, criterion, rank):
     model.train(max_epoch, dataset, dataset_sink_mode=False)
     context.reset_auto_parallel_context()
 
-if __name__ == "__main__":
-    test_trains()
+
+def test_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_2d_index_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy2():
+    gather_v2_strategy = ((1, device_number),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy3():
+    gather_v2_strategy = ((8, 1),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+class GatherV2Axis1(_Loss):
+    def __init__(self, index_dim, strategy, index_size=16):
+        super(GatherV2Axis1, self).__init__()
+        self.pow = P.Pow()
+        emb1_list = 21
+        emb2_list = 2
+        if index_dim == 1:
+            emb_list = list(range(index_size))
+            emb1_list = emb_list[0::2]
+            emb2_list = emb_list[1::2]
+        if index_dim == 2:
+            emb_list = np.arange(index_size * index_size)
+            emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), index_size))
+            emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
+        self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
+        self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy)
+
+    def construct(self, nembeddings):
+        emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
+        emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
+        return self.pow((emb1 - emb2), 2.0)
+
+
+def test_axis1_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 17
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    net_trains(gather_v2_strategy, criterion, rank)
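
As a reviewer aid, here is a minimal standalone sketch (Python; the function name and arguments are illustrative, not MindSpore API) of the rank-to-offset arithmetic that GatherV2Info::InferTensorSubOps performs when the gather axis is split: each device subtracts the global start of its local slice from the index tensor via the generated Sub operator.

# Sketch of the sub_value computation in InferTensorSubOps (illustrative names).
def gather_sub_value(dev_matrix, axis, input_dim_at_axis, rank):
    mod_n = 1
    for d in dev_matrix[axis + 1:]:  # devices that vary without changing the axis slice
        mod_n *= d
    mod_p = mod_n * dev_matrix[axis]
    mod_rank = (rank % mod_p) // mod_n  # which slice of the axis this rank owns
    # Offset subtracted from the index: the global start of this rank's slice.
    return (input_dim_at_axis // dev_matrix[axis]) * mod_rank

# With a hypothetical device matrix (16, 2), axis 0 and first input dimension 64,
# rank 2 owns axis-slice 1, so its indices are shifted down by 64 / 16 * 1 = 4.
assert gather_sub_value([16, 2], 0, 64, 2) == 4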