From 19a24b86acb5e132564924a6b201af83d96f1ce2 Mon Sep 17 00:00:00 2001 From: lichenever Date: Thu, 30 Apr 2020 11:27:11 +0800 Subject: [PATCH] add gatherv2 distributed op --- .../auto_parallel/operator_costmodel.cc | 85 +++++ .../auto_parallel/operator_costmodel.h | 36 ++ mindspore/ccsrc/parallel/dynamic_creator.h | 1 + .../parallel/graph_util/generate_graph.h | 1 + .../parallel/ops_info/gather_v2_p_info.cc | 339 ++++++++++++++++++ .../parallel/ops_info/gather_v2_p_info.h | 67 ++++ .../ccsrc/parallel/ops_info/onehot_info.cc | 6 +- .../ccsrc/parallel/ops_info/operator_info.h | 2 +- .../parallel/ops_info/ops_info_head_files.h | 1 + mindspore/ccsrc/parallel/ops_info/ops_utils.h | 3 +- mindspore/ccsrc/parallel/step_parallel.cc | 33 +- tests/ut/python/parallel/test_gather_v2.py | 173 +++++++++ .../parallel/test_gather_v2_primitive.py | 2 +- 13 files changed, 728 insertions(+), 21 deletions(-) create mode 100644 mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc create mode 100644 mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h create mode 100644 tests/ut/python/parallel/test_gather_v2.py diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 8ad8b46f32b..8f47d30273d 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -787,5 +787,90 @@ double LayerNormCost::GetForwardComputationCost(const std::vector &i } return result; } + +double GatherV2PCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + if (outputs_type_lengths_.size() != outputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + return result; + } + + // split axis + auto param_shape = inputs[0].slice_shape(); + auto index_shape = inputs[1].slice_shape(); + Shape reducescatter_shape = index_shape; + if (param_shape.size() == 2) { + reducescatter_shape.push_back(param_shape.at(1 - axis_)); + } + result += ListProduct(reducescatter_shape) * static_cast(outputs_type_lengths_[0]); + return result; +} + +double GatherV2PCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + return result; +} + +double GatherV2PCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } else { + // split axis + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1; + } + + return result; +} + +double GatherV2PCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + double result = 0.0; + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape output0_slice_shape = outputs[0].slice_shape(); + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]); + } else { + // split axis + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3; + } + + return result; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index a243f8adfae..7d966845d1a 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -27,6 +27,10 @@ namespace parallel { #define MAXIMUM_INPUT_NUMBER 100 #define DEFAULT_DATA_TYPE_LENGTH 4 #define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory +#define GATHERV2_COST_WEIGHT0 3 +#define GATHERV2_COST_WEIGHT1 7 +#define GATHERV2_COST_WEIGHT2 2 +#define GATHERV2_COST_WEIGHT3 6 class OperatorCost; using OperatorCostPtr = std::shared_ptr; @@ -609,6 +613,38 @@ class GatherV2Cost : public OperatorCost { }; using GatherV2CostPtr = std::shared_ptr; + +class GatherV2PCost : public OperatorCost { + public: + explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GatherV2PCost() : OperatorCost(true) {} + ~GatherV2PCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; + void set_axis(int32_t axis) { axis_ = axis; } + void set_strategy(const Shape &strategy) { strategy_ = strategy; } + + protected: + int32_t axis_; + Shape strategy_; +}; + +using GatherV2PCostPtr = std::shared_ptr; } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index 42ba42cf8a4..7b09193dbd2 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -129,6 +129,7 @@ REGISTER(ExpandDimsInfo); REGISTER(SqueezeInfo); REGISTER(SigmoidCrossEntropyWithLogitsInfo); REGISTER(SquareInfo); +REGISTER(GatherV2PInfo); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h index d5535c7dc20..71227a6e7b6 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.h @@ -41,6 +41,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name AnfNodePtr CreatTypeInt(int32_t value); AnfNodePtr CreatInt32Imm(int32_t value); AnfNodePtr CreateInt32Tensor(int32_t value); +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); std::string HashInstanceName(const std::string &name); class GenerateGraph { diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc new file mode 100644 index 00000000000..15d34c6677c --- /dev/null +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc @@ -0,0 +1,339 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parallel/ops_info/gather_v2_p_info.h" + +#include +#include +#include +#include + +#include "parallel/device_matrix.h" +#include "parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status GatherV2PInfo::GetAttrs() { + // get axis, the third input is the axis, is a ValueNode + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + auto axis = GetValue(input_value_.at(2)); + // if axis is negative then convert it to positive + auto params_shape = inputs_shape_.at(0); + if (params_shape.size() == 0) { + MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; + return FAILED; + } + if (axis < 0) { + axis += SizeToInt(inputs_shape_[0].size()); + } + axis_ = axis; + + return SUCCESS; +} + +Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + // only support 1-dim and 2-dim param + if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { + MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); + return FAILED; + } + + // don't support scalar index + if (inputs_shape_.at(1).size() == 0) { + MS_LOG(ERROR) << name_ << ": Don't support scalar index."; + return FAILED; + } + + // axis=0, index_shape(0)%param_strategy(0) must be 0 + Shape index_shape = inputs_shape_.at(1); + auto param_strategy = strategy->GetInputDim().at(0); + if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { + MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; + return FAILED; + } + + // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 + Shape param_shape = inputs_shape_.at(0); + if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { + MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; + return FAILED; + } + + // Don't support repeated calc + auto params_strategy = strategy->GetInputDim().at(0); + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies()); + if (dev_num != IntToSize(product)) { + MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; + return FAILED; + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + out_dev_matrix_shape_.clear(); + // infer input dev_matrix_shape + auto params_strategy = strategy_->GetInputDim().at(0); + dev_matrix_shape_ = params_strategy; + + // infer out dev_matrix_shape + // axis!=0, split axis + if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) { + out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_))); + for (size_t i = 1; i < params_strategy.size(); ++i) { + if (i == IntToSize(axis_)) { + out_dev_matrix_shape_.push_back(1); + } else { + out_dev_matrix_shape_.push_back(params_strategy.at(i)); + } + } + } else { + out_dev_matrix_shape_ = params_strategy; + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorMap() { + // infer input tensor map + size_t param_size = inputs_shape_.at(0).size(); + size_t index_size = inputs_shape_.at(1).size(); + std::vector tensor_map_index(index_size, -1); + std::vector tensor_map_params; + for (size_t i = 0; i < param_size; ++i) { + tensor_map_params.push_back(SizeToInt(param_size - i - 1)); + } + + // infer output tensor map + std::vector tensor_map_out; + if (axis_ == 0) { + tensor_map_out.push_back(SizeToInt(param_size - 1)); + tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); + for (size_t i = 1; i < param_size; ++i) { + tensor_map_out.push_back(SizeToInt(param_size - i - 1)); + } + } else { + for (size_t i = 0; i < param_size; ++i) { + if (i == IntToSize(axis_)) { + tensor_map_out.insert(tensor_map_out.end(), index_size, -1); + } else { + tensor_map_out.push_back(SizeToInt(param_size - i - 1)); + } + } + } + + inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + // infer tensor layout + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != + SUCCESS)) { + return FAILED; + } + // infer tensor info + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status GatherV2PInfo::InferBias() { + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + auto input_shape = inputs_shape_.at(0); + auto params_strategy = strategy_->GetInputDim().at(0); + // params_size=1, axis=0 + if ((input_shape.size() == 1) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank * slice_size_; + return SUCCESS; + } + // params_size=2, axis=0 + if ((input_shape.size() == 2) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank / params_strategy.at(1) * slice_size_; + return SUCCESS; + } + // params_size=2, axis=1 + if ((input_shape.size() == 2) && (axis_ == 1)) { + slice_size_ = input_shape.at(1) / params_strategy.at(1); + bias_ = rank % params_strategy.at(1) * slice_size_; + return SUCCESS; + } + MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; + return FAILED; +} + +Status GatherV2PInfo::InferGroup() { + std::vector group_list; + if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group failed."; + return FAILED; + } + + group_ = group_list.at(0); + return SUCCESS; +} + +Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Bias failed."; + return FAILED; + } + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); + auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); + auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); + auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum}); + auto gather_v2 = + gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); + auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); + auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); + auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); + auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); + // don't need expandim,if param_size = 1, + if (inputs_shape_.at(0).size() == 1) { + mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); + } + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); + OperatorAttrs attrs = {attr_op, attr_group}; + auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); + std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1), + std::make_pair(equal, 2)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, reduce_scatter)); + + return SUCCESS; +} + +Status GatherV2PInfo::Init(const StrategyPtr &strategy) { + auto param_strategy = strategy->GetInputDim().at(0); + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + auto param_strategy = strategy_->GetInputDim().at(0); + // cost model set axis and strategy + auto gatherv2_2cost = std::dynamic_pointer_cast(operator_cost()); + gatherv2_2cost->set_axis(axis_); + gatherv2_2cost->set_strategy(param_strategy); + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2PInfo::GenerateBatchStrategies() { + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + Dimensions strategy; + strategy.push_back(SizeToInt(dev_num)); + for (size_t i = 1; i < inputs_shape_[0].size(); i++) { + strategy.push_back(1); + } + std::vector strategy_v = {strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h new file mode 100644 index 00000000000..03cbd70e8d6 --- /dev/null +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "parallel/auto_parallel/operator_costmodel.h" +#include "parallel/ops_info/operator_info.h" +#include "parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GatherV2PInfo : public OperatorInfo { + public: + GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~GatherV2PInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status ComputeReplaceGraph(const CNodePtr &cnode); + Status InferBias(); + Status InferGroup(); + + int32_t axis_; + int32_t bias_; + int32_t slice_size_; + Shape out_dev_matrix_shape_; + Group group_; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc index 2c06a1ace94..ea2d0451040 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc @@ -215,9 +215,9 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { OperatorAttrs attrs_onehot = {attr_onehot_axis}; auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_), cnode->input(3), cnode->input(4)}); - std::vector input_nodes = {floor_div, sub1}; - replace_graph_ = - std::make_shared, AnfNodePtr>>(std::make_pair(input_nodes, onehot)); + std::vector> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, onehot)); return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index de95bd84ad1..6bb766026fa 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -48,7 +48,7 @@ using TensorLayouts = std::vector; using different_type = std::vector::difference_type; using PrimitiveAttrs = std::unordered_map; using Strategys = std::vector; -using ReplaceGraphPtr = std::shared_ptr, AnfNodePtr>>; +using ReplaceGraphPtr = std::shared_ptr>, AnfNodePtr>>; class Edge; diff --git a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h index aec25f7f414..45b00aed30e 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h @@ -36,5 +36,6 @@ #include "parallel/ops_info/reshape_info.h" #include "parallel/ops_info/transpose_info.h" #include "parallel/ops_info/virtual_dataset_info.h" +#include "parallel/ops_info/gather_v2_p_info.h" #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index e0b62eb2337..9b7aceba862 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -114,7 +114,7 @@ constexpr char BE_CLONED_INDEX[] = "be_cloned_index"; constexpr char GROUP_RANKS[] = "group_ranks"; constexpr char IS_IN_FORWARD[] = "is_in_forward"; constexpr char DEFAULT_INPUT[] = "default_input"; -constexpr char DTYPE[] = "dtype"; +constexpr char DTYPE[] = "DType"; constexpr char DEV_NUM[] = "dev_num"; constexpr char MEAN_FLAG[] = "mean_flag"; constexpr char TYPES[] = "types"; @@ -124,6 +124,7 @@ constexpr char SHARED_NAME[] = "shared_name"; constexpr char MIRROR_OP[] = "mirror_op"; constexpr char FORWARD_OP[] = "forward_op"; constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; +constexpr char DARA_PARALLEL[] = "data_parallel"; // Operator constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index d445231e8eb..a344607362f 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -605,8 +605,7 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { return (prim->name() == name); } -void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>> &replace_graph, - const CNodePtr &node) { +void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(replace_graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(replace_graph->second); @@ -616,20 +615,10 @@ void StepReplaceGraph(const std::shared_ptr, A if (manager == nullptr) { MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; } - if (!IsSomePrimitive(node, ONEHOT)) { - MS_LOG(EXCEPTION) << "Failure:Only OneHot Primitive will enter StepReplaceGraph!"; - } - if (node->inputs().size() != 5) { - MS_LOG(EXCEPTION) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive!"; - } - auto pre_node = node->input(1); - if (replace_graph->first.size() != 2) { - MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!"; - } for (auto &replace_input : replace_graph->first) { - MS_EXCEPTION_IF_NULL(replace_input); - manager->SetEdge(replace_input, 1, pre_node); - CNodePtr replace_input_cnode = replace_input->cast(); + auto pre_node = node->input(IntToSize(replace_input.second)); + manager->SetEdge(replace_input.first, 1, pre_node); + auto replace_input_cnode = replace_input.first->cast(); MS_EXCEPTION_IF_NULL(replace_input_cnode); (void)replace_input_cnode->set_operator_info(node->operator_info()); replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node @@ -943,6 +932,20 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA MS_LOG(EXCEPTION) << "Length of name is zero!"; } std::string distribute_opname = GetDisOpName(name); + if (name == GATHERV2) { + distribute_opname = name + "PInfo"; + auto data_parallel_iter = attrs.find(DATA_PARALLEL); + if (data_parallel_iter != attrs.end()) { + MS_EXCEPTION_IF_NULL(data_parallel_iter->second); + if (!data_parallel_iter->second->isa()) { + MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; + } + bool data_parallel = data_parallel_iter->second->cast()->value(); + if (data_parallel) { + distribute_opname = name + "Info"; + } + } + } OperatorInfoPtr operator_ = (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); if (operator_ == nullptr) { diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py new file mode 100644 index 00000000000..a9e01b3c24f --- /dev/null +++ b/tests/ut/python/parallel/test_gather_v2.py @@ -0,0 +1,173 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +from mindspore import Tensor +from mindspore import context +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common import dtype as mstype +from mindspore.common.api import _executor +from tests.ut.python.ops.test_math_ops import VirtualLoss + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return C.grad_all(self.network)(x, y) + + +class Net(nn.Cell): + def __init__(self, axis=0, strategy1=None, strategy2=None, shape=[64, 64]): + super().__init__() + self.gatherv2 = P.GatherV2().set_strategy(strategy1) + self.mul = P.Mul().set_strategy(strategy2) + self.index = Tensor(np.ones(shape), dtype=ms.int32) + self.axis = axis + + def construct(self, x, y): + out = self.gatherv2(x, self.index, self.axis) + out = self.mul(out, y) + return out + + +def test_gatherv2_semi_auto0(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((1, 8), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto1(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((8, 1), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto2(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto3(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((1, 8), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto4(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((8, 1), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto5(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((2, 4), ) + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto6(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(0, None, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto7(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy2 = ((4, 2, 1), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(1, None, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_semi_auto8(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + strategy1 = ((8, ), ) + strategy2 = ((4, 2), (4, 2)) + net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2))) + net.set_auto_parallel() + + x = Tensor(np.ones([64]), dtype=ms.float32) + y = Tensor(np.ones([64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_auto0(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") + net = GradWrap(NetWithLoss(Net(0))) + net.set_auto_parallel() + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32) + _executor.compile(net, x, y) + +def test_gatherv2_auto1(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") + net = GradWrap(NetWithLoss(Net(1))) + net.set_auto_parallel() + x = Tensor(np.ones([64, 32]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + diff --git a/tests/ut/python/parallel/test_gather_v2_primitive.py b/tests/ut/python/parallel/test_gather_v2_primitive.py index 3ea0795e9c4..3812ccd819e 100644 --- a/tests/ut/python/parallel/test_gather_v2_primitive.py +++ b/tests/ut/python/parallel/test_gather_v2_primitive.py @@ -74,7 +74,7 @@ class GatherV2(_Loss): emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16)) self.emb1_param = Tensor(emb1_list, dtype=mstype.int32) self.emb2_param = Tensor(emb2_list, dtype=mstype.int32) - self.gatherv2 = P.GatherV2().set_strategy(strategy) + self.gatherv2 = P.GatherV2().set_strategy(strategy).add_prim_attr("data_parallel", True) def construct(self, nembeddings): emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)