From f9f5df368e99fd05b3898a481ee0d1cbe26f24cb Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Fri, 16 Apr 2021 14:23:45 +0800
Subject: [PATCH] add gathernd op

---
 .../auto_parallel/operator_costmodel.h        |   1 +
 .../ccsrc/frontend/parallel/dynamic_creator.h |   1 +
 .../parallel/ops_info/gathernd_info.cc        | 214 ++++++++++++++++++
 .../parallel/ops_info/gathernd_info.h         |  58 +++++
 .../parallel/ops_info/ops_info_head_files.h   |   1 +
 .../frontend/parallel/ops_info/ops_utils.h    |   1 +
 .../frontend/parallel/step_auto_parallel.cc   |  22 +-
 tests/ut/python/parallel/test_gathernd.py     | 110 +++++++++
 8 files changed, 395 insertions(+), 13 deletions(-)
 create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc
 create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h
 create mode 100644 tests/ut/python/parallel/test_gathernd.py

diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
index bf68b8dc5bc..da6787dabc6 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
@@ -608,6 +608,7 @@ using GreaterCost = SubCost;
 using GreaterEqualCost = SubCost;
 using LessCost = SubCost;
 using LessEqualCost = SubCost;
+using GatherNdCost = SubCost;
 
 class MulCost : public SubCost {
  public:
diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
index f6ec4927ed9..3fb7d20f620 100644
--- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -191,6 +191,7 @@ REGISTER(StackInfo);
 REGISTER(ConcatInfo);
 REGISTER(SplitInfo);
 REGISTER(UniqueInfo);
+REGISTER(GatherNdInfo);
 }  // namespace parallel
 }  // namespace mindspore
 
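For context on the op being parallelized in gathernd_info.cc below: GatherNd treats the last axis of indices as a coordinate tuple into the leading axes of input, so the output shape is indices.shape[:-1] + input.shape[indices.shape[-1]:]. A minimal NumPy sketch of these semantics (an illustrative reference only, not MindSpore code):

    import numpy as np

    def gather_nd(params, indices):
        # the last axis of `indices` indexes the leading axes of `params`
        out_shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
        flat = indices.reshape(-1, indices.shape[-1])
        return np.stack([params[tuple(i)] for i in flat]).reshape(out_shape)

    params = np.arange(60).reshape(3, 4, 5)   # input [3, 4, 5]
    indices = np.array([[0, 1], [2, 3]])      # indices [2, 2]
    assert gather_nd(params, indices).shape == (2, 5)

Because every output element needs a complete coordinate tuple, the implementation below replicates the input on all devices and only ever splits the leading dimensions of indices.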
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc
new file mode 100644
index 00000000000..3cf1e400e89
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc
@@ -0,0 +1,214 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/gathernd_info.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "pipeline/jit/resource.h"
+
+namespace mindspore {
+namespace parallel {
+// the input cannot be split, and the last dimension of indices cannot be split
+Status GatherNdInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  if (stra.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategies must be 2";
+    return FAILED;
+  }
+
+  int64_t input_split_size = std::accumulate(stra[0].begin(), stra[0].end(), 1, std::multiplies<int64_t>());
+  if (input_split_size != 1) {
+    MS_LOG(ERROR) << name_ << ": The input cannot be split";
+    return FAILED;
+  }
+
+  if (stra[1].empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy of indices cannot be empty";
+    return FAILED;
+  }
+
+  if (stra[1].back() != 1) {
+    MS_LOG(ERROR) << name_ << ": The last dimension of indices cannot be split";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+// the dev matrix is the indices' strategy
+Status GatherNdInfo::InferDevMatrixShape() {
+  MS_EXCEPTION_IF_NULL(strategy_);
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  if (stra.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategies must be 2";
+    return FAILED;
+  }
+
+  dev_matrix_shape_ = stra[1];
+  return SUCCESS;
+}
+
+// input shape: [x, y, z], indices shape: [a, b, c, 2], output shape: [a, b, c, z]
+// strategy: ((1, 1, 1), (m, n, o, 1))
+// dev-matrix: [m, n, o, 1]
+// input map: [-1, -1, -1], indices map: [3, 2, 1, 0], output map: [3, 2, 1, -1]
+Status GatherNdInfo::InferTensorMap() {
+  if (inputs_shape_.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of input shapes must be 2";
+    return FAILED;
+  }
+
+  if (outputs_shape_.empty() || outputs_shape_[0].size() < (inputs_shape_[1].size() - 1)) {
+    MS_LOG(ERROR) << name_ << ": Invalid shapes";
+    return FAILED;
+  }
+
+  TensorMap input_tensor_map(inputs_shape_[0].size(), MAP_NONE);  // the input cannot be split
+
+  // cannot use dev_matrix_shape_ to replace inputs_shape_[1], because it may not be fully split in all devices.
+  TensorMap indices_tensor_map;
+  int64_t size = SizeToLong(inputs_shape_[1].size());
+  for (int64_t i = 0; i < size; ++i) {
+    indices_tensor_map.push_back(size - i - 1);
+  }
+
+  TensorMap output_tensor_map(outputs_shape_[0].size(), MAP_NONE);
+  for (size_t i = 0; i < (inputs_shape_[1].size() - 1); ++i) {
+    output_tensor_map[i] = indices_tensor_map[i];
+  }
+
+  inputs_tensor_map_.push_back(input_tensor_map);
+  inputs_tensor_map_.push_back(indices_tensor_map);
+  outputs_tensor_map_.push_back(output_tensor_map);
+  return SUCCESS;
+}
+
+Status GatherNdInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": Invalid args";
+    return FAILED;
+  }
+
+  TensorLayout input_layout, output_layout;
+  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+    // infer tensor layout
+    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
+      return FAILED;
+    }
+    TensorInfo input_tensor_info(input_layout);
+    inputs_tensor_info_.push_back(input_tensor_info);
+  }
+
+  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
+    return FAILED;
+  }
+  TensorInfo output_tensor_info(output_layout);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+void GatherNdInfo::ReComputeBatchSplitFlagList() {
+  split_flag_list_[0] = false;
+  split_flag_list_[1] = true;
+}
+
+Status GatherNdInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
+
+Status GatherNdInfo::GenerateStrategies(int64_t stage_id) {
+  if (InferAttrs() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
+    return FAILED;
+  }
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // generate the strategy for indices; its last dimension must not be split
+  Shape input_split(inputs_shape_[1].size(), 1);
+  input_split.back() = 0;
+  Shapes splittable_input = {input_split};
+  Shapes tmp_inputs_shape = {inputs_shape_[1]};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Generate strategies failed";
+    return FAILED;
+  }
+
+  // the input's strategy is fixed to all-ones; pair it with each generated indices strategy
+  for (auto &sp : sp_vector) {
+    if ((sp == nullptr) || sp->GetInputDim().empty()) {
+      MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
+      return FAILED;
+    }
+    Strategys tmp_strategy;
+    Dimensions indices_strategy = sp->GetInputDim()[0];
+    Dimensions input_strategy(inputs_shape_[0].size(), 1);
+    tmp_strategy.push_back(input_strategy);
+    tmp_strategy.push_back(indices_strategy);
+    sp->ResetInputs(tmp_strategy);
+  }
+
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    PrintStrategy(sp);
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status GatherNdInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status GatherNdInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
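The shape example in InferTensorMap above can be checked numerically: with the input replicated and indices split along a leading dimension, concatenating the per-device GatherNd results reproduces the global output. A hedged NumPy simulation of that layout (illustrative only, not MindSpore internals; gather_nd is the reference sketch from earlier):

    import numpy as np

    def gather_nd(params, indices):
        out_shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
        flat = indices.reshape(-1, indices.shape[-1])
        return np.stack([params[tuple(i)] for i in flat]).reshape(out_shape)

    # global shapes from the comment: input [x, y, z], indices [a, b, c, 2]
    x, y, z, a, b, c = 4, 5, 6, 2, 3, 3
    m = 2  # strategy ((1, 1, 1), (m, 1, 1, 1)): split indices dim 0 over m devices
    inp = np.arange(x * y * z).reshape(x, y, z)
    idx = np.random.randint(0, [x, y], size=(a, b, c, 2))

    expected = gather_nd(inp, idx)  # global output, shape [a, b, c, z]
    # each device holds the full input plus one slice of the indices
    shards = [gather_nd(inp, part) for part in np.split(idx, m, axis=0)]
    assert np.array_equal(np.concatenate(shards, axis=0), expected)

Splitting the last axis of indices would scatter each coordinate tuple across devices, which is why CheckStrategy rejects stra[1].back() != 1.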
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h
new file mode 100644
index 00000000000..bd6e3b9b3e7
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class GatherNdInfo : public OperatorInfo {
+ public:
+  GatherNdInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+               const PrimitiveAttrs &attrs)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherNdCost>()) {}
+  ~GatherNdInfo() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+  Status GenerateStrategies(int64_t) override;
+  Status SetCostUnderStrategy(const StrategyPtr &) override;
+  void ReComputeBatchSplitFlagList() override;
+
+ protected:
+  Status GetAttrs() override { return SUCCESS; }
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+};
+
+using GatherNdInfoPtr = std::shared_ptr<GatherNdInfo>;
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHERND_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
index 7591145e1c7..8e16a1b2ad6 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -50,5 +50,6 @@
 #include "frontend/parallel/ops_info/unique_info.h"
 #include "frontend/parallel/ops_info/uniform_candidate_sampler_info.h"
 #include "frontend/parallel/ops_info/reluv2_info.h"
+#include "frontend/parallel/ops_info/gathernd_info.h"
 
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
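The include above is one of four registration points this patch touches: the GatherNdCost alias in operator_costmodel.h, REGISTER() in dynamic_creator.h, the header include in ops_info_head_files.h, and the GATHERND name constant in ops_utils.h just below. As a loose Python illustration of the name-to-constructor factory idea behind REGISTER() — the names here are hypothetical, not MindSpore's actual machinery:

    OP_INFO_REGISTRY = {}

    def register(name):
        def wrap(cls):
            OP_INFO_REGISTRY[name] = cls
            return cls
        return wrap

    @register("GatherNd")
    class GatherNdInfo:
        """Stand-in for the C++ operator-info class."""

    def create_op_info(prim_name):
        # step_auto_parallel resolves a primitive's name to its
        # OperatorInfo subclass in essentially this way
        return OP_INFO_REGISTRY[prim_name]()

    assert isinstance(create_op_info("GatherNd"), GatherNdInfo)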
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 3c4978442d8..0abf529b379 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -325,6 +325,7 @@ constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
 constexpr char DROPOUT[] = "Dropout";
 constexpr char KStridedSlice[] = "StridedSlice";
 constexpr char UNIQUE[] = "Unique";
+constexpr char GATHERND[] = "GatherNd";
 
 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index f97229ca563..6be0447d3b4 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -163,7 +163,8 @@ bool IsSplittableOperator(const std::string &op_name) {
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
-    UNSORTED_SEGMENT_MAX};
+    UNSORTED_SEGMENT_MAX, GATHERND};
+
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
@@ -492,10 +493,9 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   std::map loop_to_ops;
   // extract strategy from checkpoint for multi-train
   StrategyMap stra_map;
-  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
-    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-    }
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
+      StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
   std::vector<std::string> last_forward_node_ids;
   if (!root->has_flag(TRAINING)) {
@@ -505,8 +505,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
-    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
-    if (bool_result) {
+    if ((cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)))) {
       continue;
     }
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -551,9 +550,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
                          last_forward_node_ids.end();
     auto operator_info = CreateTheOperatorInfo(prim, cnode, is_last_nodes, &stra_map);
-    if (operator_info == nullptr) {
-      return FAILED;
-    }
+    MS_EXCEPTION_IF_NULL(operator_info);
+
     // Needed by rec_parser
     operator_info->set_type(prim->name());
     operator_info->set_last_node_flag(is_last_nodes);
@@ -627,8 +625,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
   MS_LOG(INFO) << "Constructing edges for cost graph begins.";
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
-    if (bool_result_cnode) {
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
     auto &inputs = cnode->inputs();
@@ -638,7 +635,6 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     }
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     size_t edge_count = 0;
-
     auto node_op_info = cnode->user_data<OperatorInfo>();
 
     for (size_t i = 1; i < inputs.size(); ++i) {
diff --git a/tests/ut/python/parallel/test_gathernd.py b/tests/ut/python/parallel/test_gathernd.py
new file mode 100644
index 00000000000..2dd16c9ba2f
--- /dev/null
+++ b/tests/ut/python/parallel/test_gathernd.py
@@ -0,0 +1,110 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, Momentum
+from mindspore.ops import operations as P
+from mindspore.train import Model
+from tests.dataset_mock import MindData
+
+
+class Dataset(MindData):
+    def __init__(self, predict, label, length=3):
+        super(Dataset, self).__init__(size=length)
+        self.predict = predict
+        self.label = label
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.predict, self.label
+
+    def reset(self):
+        self.index = 0
+
+
+class Net(Cell):
+    def __init__(self, w1, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.w1 = Parameter(w1, "w1")
+        self.indices = Tensor(np.ones([16, 2]), dtype=ms.int32)
+        self.gathernd = P.GatherNd().shard(strategy2)
+
+    def construct(self, x, b):
+        out = self.mul(x, self.w1)
+        out = self.gathernd(out, self.indices)
+        return out
+
+
+_x = Tensor(np.ones([16, 64]), dtype=ms.float32)
+_b = Tensor(np.ones([16, 64]), dtype=ms.float32)
+_w1 = Tensor(np.ones([16, 64]), dtype=ms.float32)
+
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    dataset = Dataset(_x, _b)
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, optimizer=opt)
+    model.train(epoch_size, dataset, dataset_sink_mode=False)
+    context.reset_auto_parallel_context()
+
+
+def test_gathernd_data_parallel():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1), (8, 1))
+    strategy2 = ((1, 1), (8, 1))
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_gathernd_model_parallel():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 4), (2, 4))
+    strategy2 = ((1, 1), (4, 1))
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
+
+
+def test_gathernd_auto_parallel():
+    context.set_auto_parallel_context(
+        parallel_mode="auto_parallel", device_num=8, global_rank=0)
+    net = Net(_w1)
+    compile_net(net)
+
+
+def test_gathernd_strategy_error():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1), (8, 1))
+    strategy2 = ((1, 1), (2, 4))
+    net = Net(_w1, strategy1, strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
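Tying the test strategies back to the layout logic: with 8 devices and strategy2 = ((1, 1), (8, 1)), every device replicates the [16, 64] product of the Mul and owns a [2, 2] slice of the indices, yielding a [2] slice of the output. A hedged helper for computing per-device shapes (shard_shape is hypothetical, not a MindSpore API):

    def shard_shape(shape, strategy):
        # per-device shape under a MindSpore-style strategy tuple:
        # each dimension is divided by its strategy factor
        assert len(shape) == len(strategy)
        assert all(s % p == 0 for s, p in zip(shape, strategy))
        return tuple(s // p for s, p in zip(shape, strategy))

    assert shard_shape((16, 64), (1, 1)) == (16, 64)  # input replicated
    assert shard_shape((16, 2), (8, 1)) == (2, 2)     # indices split on dim 0

The rejected strategy ((1, 1), (2, 4)) splits the last indices dimension, which CheckStrategy forbids, so compilation raises the RuntimeError that the final test expects.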