From bbaeb3203bebace07b48882f96eedb06e17bae22 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Fri, 4 Nov 2022 11:03:41 +0800 Subject: [PATCH] add scatter nd op --- .../auto_parallel/operator_costmodel.cc | 56 ++- .../auto_parallel/operator_costmodel.h | 28 +- .../frontend/parallel/ops_info/ops_utils.h | 23 + .../ops_info/scatter_math_ops_info.cc | 2 +- .../parallel/ops_info/scatter_nd_ops_info.cc | 407 ++++++++++++++++++ .../parallel/ops_info/scatter_nd_ops_info.h | 179 ++++++++ .../frontend/parallel/step_parallel_utils.cc | 4 +- .../ut/python/parallel/test_scatter_nd_ops.py | 245 +++++++++++ 8 files changed, 937 insertions(+), 7 deletions(-) create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.cc create mode 100644 mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.h create mode 100644 tests/ut/python/parallel/test_scatter_nd_ops.py diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc index bba7ac5712c..8e13a6f3f15 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -1895,15 +1895,65 @@ double ScatterMathOpsCost::GetForwardComputationCost(const std::vector(inputs_type_lengths_[0]) + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) + ListProduct(input2_slice_shape) * static_cast(inputs_type_lengths_[2]); } else { // split axis + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * input_coefficient_ + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * indices_coefficient_ + + ListProduct(input2_slice_shape) * static_cast(inputs_type_lengths_[1]) * updates_coefficient_; + } + + return result; +} + +double TensorScatterOpsCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int64_t stage_id) const { + double result = 0.0; + 
CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int64_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != LongToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[j]); + } + } + return result; +} + +double TensorScatterOpsCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int64_t) const { + double result = 0.0; + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape input2_slice_shape = inputs[2].slice_shape(); // equal to output shape + // brop func using 3 times input/out in average. 
+ // don't split axis + if (!is_split_axis_) { result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * 3 + - ListProduct(input2_slice_shape) * static_cast(inputs_type_lengths_[1]) * 2; + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) + + ListProduct(input2_slice_shape) * static_cast(inputs_type_lengths_[2]); + result *= 3; + + } else { + // split axis + result += + ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * (input_coefficient_ + 3) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * (indices_coefficient_ + 3) + + ListProduct(input2_slice_shape) * static_cast(inputs_type_lengths_[1]) * (updates_coefficient_ + 3); } return result; diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h index f4beab06afd..9eb7bd798aa 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -1226,10 +1226,34 @@ class ScatterMathOpsCost : public OperatorCost { is_inputs_should_in_memory_[0] = true; } - void set_strategy(Shape strategy) { strategy_ = strategy; } + void set_is_split_axis(bool is_split_axis) { is_split_axis_ = is_split_axis; } + void set_coefficient(int32_t input_coefficient, int32_t indices_coefficient, int32_t updates_coefficient) { + input_coefficient_ = input_coefficient; + indices_coefficient_ = indices_coefficient; + updates_coefficient_ = updates_coefficient; + } protected: - Shape strategy_; + bool is_split_axis_ = false; + int32_t input_coefficient_ = 1; + int32_t indices_coefficient_ = 3; + int32_t updates_coefficient_ = 2; +}; + +class ScatterNdOpsCost : public ScatterMathOpsCost { + public: + ScatterNdOpsCost() : ScatterMathOpsCost() {} + ~ScatterNdOpsCost() override = default; +}; + 
+class TensorScatterOpsCost : public ScatterMathOpsCost { + public: + TensorScatterOpsCost() : ScatterMathOpsCost() {} + ~TensorScatterOpsCost() override = default; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int64_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int64_t stage_id) const override; }; class CropAndResizeCost : public OperatorCost { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 4b43c05a8cc..2136b8a6a07 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -95,6 +95,20 @@ constexpr size_t BATCH_NORM_INPUTS_SIZE = 5; constexpr double EPS = 1e-6; constexpr double INF = 1e20; constexpr double COST_FACTOR = 2.0; +constexpr size_t SIZE_ZERO = 0; +constexpr size_t SIZE_ONE = 1; +constexpr size_t SIZE_TWO = 2; +constexpr size_t SIZE_THREE = 3; +constexpr size_t SIZE_FOUR = 4; +constexpr size_t SIZE_FIVE = 5; +constexpr size_t SIZE_SIX = 6; +constexpr size_t INDEX_ZERO = 0; +constexpr size_t INDEX_ONE = 1; +constexpr size_t INDEX_TWO = 2; +constexpr size_t INDEX_THREE = 3; +constexpr size_t INDEX_FOUR = 4; +constexpr size_t INDEX_FIVE = 5; +constexpr size_t INDEX_SIX = 6; constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; @@ -467,6 +481,15 @@ constexpr char SCATTER_MUL[] = "ScatterMul"; constexpr char SCATTER_DIV[] = "ScatterDiv"; constexpr char SCATTER_MAX[] = "ScatterMax"; constexpr char SCATTER_MIN[] = "ScatterMin"; +constexpr char SCATTER_ND_UPDATE[] = "ScatterNdUpdate"; +constexpr char SCATTER_ND_ADD[] = "ScatterNdAdd"; +constexpr char SCATTER_ND_SUB[] = "ScatterNdSub"; +constexpr char TENSOR_SCATTER_ADD[] = "TensorScatterAdd"; +constexpr char TENSOR_SCATTER_SUB[] = 
"TensorScatterSub"; +constexpr char TENSOR_SCATTER_MAX[] = "TensorScatterMax"; +constexpr char TENSOR_SCATTER_MIN[] = "TensorScatterMin"; +constexpr char TENSOR_SCATTER_MUL[] = "TensorScatterMul"; +constexpr char TENSOR_SCATTER_DIV[] = "TensorScatterDiv"; constexpr char GATHERD[] = "GatherD"; constexpr char DSD_MATMUL[] = "DSDMatmul"; constexpr char RESIZE_BILINEAR[] = "ResizeBilinear"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_math_ops_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_math_ops_info.cc index 5e3156bbea1..d0db66f21c3 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_math_ops_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_math_ops_info.cc @@ -262,7 +262,7 @@ Status ScatterMathOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, cons auto param_strategy = strategy_->GetInputDim().at(0); // cost model set axis and strategy auto scatter_ops_cost = std::dynamic_pointer_cast(operator_cost()); - scatter_ops_cost->set_strategy(param_strategy); + scatter_ops_cost->set_is_split_axis((param_strategy[0] > 1)); MS_LOG(INFO) << name_ << ": Init for cost model success."; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.cc new file mode 100644 index 00000000000..943b9757bec --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.cc @@ -0,0 +1,407 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/scatter_nd_ops_info.h" + +#include +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/dynamic_creator.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "pipeline/jit/resource.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B' +// here the 2 respect to the size of [A, B] +// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d) +// The dev matrix: (a, b, 1, 1) +Status ScatterNdOpsInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.size() != 3) { + MS_LOG(ERROR) << name_ << ": The size of strategy must be 3"; + return FAILED; + } + + auto indices_shape = inputs_shape_[1]; + gather_dims_size_ = indices_shape.back(); + if (CheckInputStrategy(stra[0]) != SUCCESS) { + return FAILED; + } + if (!stra[1].empty() && std::accumulate(stra[1].begin(), stra[1].end(), 1, std::multiplies()) != 1) { + MS_LOG(ERROR) << name_ << ": The indices can not be split"; + return FAILED; + } + + if (stra[2].empty()) { + MS_LOG(ERROR) << name_ << ": The strategy[2] is empty"; + return FAILED; + } + + if (std::accumulate(stra[2].begin(), stra[2].begin() + static_cast(stra[1].size() - 1), 1, + 
std::multiplies()) != 1) { + MS_LOG(ERROR) << name_ << ": The first " << (stra[1].size() - 1) << " dimensions of updates can not be split"; + return FAILED; + } + + for (size_t i = 0; i < stra[0].size() - gather_dims_size_; ++i) { + if (stra[0][i + gather_dims_size_] != stra[2][stra[1].size() - 1 + i]) { + MS_LOG(ERROR) + << name_ << ": updates.strategy must be equal to indices.strategy[:-1] + input.strategy[indices.strategy[-1]:]"; + return FAILED; + } + } + + for (size_t i = 0; i < gather_dims_size_; ++i) { + if (stra[0][i] > 1) { + do_replace_graph_ = true; + break; + } + } + + return SUCCESS; +} + +// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d) +Status ScatterNdUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) { + for (size_t i = 0; i < gather_dims_size_; ++i) { + if (strategy_item[i] != 1) { + MS_LOG(ERROR) << "For " << name_ + << ", " + "the input cannot be shard at gather dims input[:" + << gather_dims_size_ << "]."; + return FAILED; + } + } + return SUCCESS; +} + +// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d) +Status TensorScatterUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) { + for (size_t i = 0; i < gather_dims_size_; ++i) { + if (strategy_item[i] != 1) { + MS_LOG(ERROR) << "For ScatterNdUpdate/TensorScatterUpdate, " + "the input cannot be shard at gather dims input[:" + << gather_dims_size_ << "]."; + return FAILED; + } + } + return SUCCESS; +} + +Status ScatterNdOpsInfo::InferDevMatrixShape() { + MS_EXCEPTION_IF_NULL(strategy_); + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << "The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d), tensor_map: (3, 2, 1, 0) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1), tensor_map: (-1, -1, -1) +// The shape of updates: [Q, W, C, D], the 
strategy of updates: (1, 1, c, d), tensor_map: (-1, -1, 1, 0) +// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d), tensor_map: (3, 2, 1, 0) +// The dev matrix: (a, b, c, d) +Status ScatterNdOpsInfo::InferTensorMap() { + if (inputs_shape_.size() != 3) { + MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3"; + return FAILED; + } + + TensorMap input_tensor_map, updates_tensor_map; + TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE); + + // cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices. + int64_t size = SizeToLong(inputs_shape_[0].size()); + for (int64_t i = 0; i < size; ++i) { + input_tensor_map.push_back(size - i - 1); + } + + // updates_tensor_map = indices_tensor_map[:-1] + input_tensor_map[gather_dim_size_:] + updates_tensor_map = indices_tensor_map; + updates_tensor_map.pop_back(); + for (size_t i = gather_dims_size_; i < input_tensor_map.size(); ++i) { + updates_tensor_map.push_back(input_tensor_map[i]); + } + inputs_tensor_map_.push_back(input_tensor_map); // input + inputs_tensor_map_.push_back(indices_tensor_map); // indices + inputs_tensor_map_.push_back(updates_tensor_map); // updates + + outputs_tensor_map_.push_back(input_tensor_map); + return SUCCESS; +} + +void ScatterNdOpsInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); i++) { + split_flag_list_[i] = false; // when set no strategy, going stand alone mode. 
+ } +} + +// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, c, d)) +// return: ((A, B, C, D), (1, 1), (1, 1, C, D)) +// return: throw exception +Shapes ScatterNdOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) { + if (in_strategy.size() != SIZE_THREE) { + MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size(); + } + + if (in_strategy[INDEX_ZERO].empty()) { + MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] must be not empty, you should set the first input strategy."; + } + + Shape indices_strategy, updates_strategy; + indices_strategy = Shape(inputs_shape_[INDEX_ONE].size(), 1); + updates_strategy = Shape(inputs_shape_[INDEX_TWO].size(), 1); + for (size_t i = 0; i < in_strategy[0].size() - gather_dims_size_; ++i) { + updates_strategy[indices_strategy.size() - 1 + i] = in_strategy[INDEX_ZERO][i + gather_dims_size_]; + } + return Shapes({in_strategy[INDEX_ZERO], indices_strategy, updates_strategy}); +} + +ReplaceGraphPtr ScatterNdOpsInfo::replace_graph(const CNodePtr &cnode) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << " replace graph failed"; + } + return replace_graph_; +} + +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B' +// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +// when splitting [A, B], doing replace graph +Status ScatterNdOpsInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + if (!do_replace_graph_) { + return SUCCESS; + } + if (gen_g_.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + auto anf_node_list = PrepareReplaceGraph(cnode); + // {indices_sub, mul, div, dtype}; + auto indices_sub = anf_node_list[0]; + auto mul = anf_node_list[1]; + auto div = anf_node_list[2]; + 
auto dtype = anf_node_list[3]; + auto info_position = name_.find("Info"); + if (info_position == std::string::npos) { + MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'"; + } + auto node_name = name_.substr(0, info_position); + auto scatter_ops = gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, mul}); + + std::vector> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2), + std::make_pair(indices_sub, 2), std::make_pair(mul, 3), + std::make_pair(dtype, 3)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, scatter_ops)); + + return SUCCESS; +} + +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B' +// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +// when splitting [A, B], doing replace graph +Status ScatterNdMulDivBaseInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + if (!do_replace_graph_) { + return SUCCESS; + } + if (gen_g_.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + auto anf_node_list = PrepareReplaceGraph(cnode); + // {indices_sub, mul, div, dtype}; + auto indices_sub = anf_node_list[0]; + auto mul = anf_node_list[1]; + auto div = anf_node_list[2]; + auto dtype = anf_node_list[3]; + auto reshape_updates_mask = anf_node_list[4]; + auto reverse_sub = gen_g_.PushBack( + {gen_g_.NewOpInst(SUB), NewValueNode(std::make_shared(1.0, kFloat32)), reshape_updates_mask}); + auto add_mask = gen_g_.PushBack({gen_g_.NewOpInst(ADD), mul, reverse_sub}); + auto info_position = name_.find("Info"); + if (info_position == std::string::npos) { + MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'"; + } + auto node_name = name_.substr(0, info_position); + auto scatter_ops = 
gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, add_mask}); + + std::vector> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2), + std::make_pair(indices_sub, 2), std::make_pair(mul, 3), + std::make_pair(dtype, 3)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, scatter_ops)); + + return SUCCESS; +} + +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B' +// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +// when splitting [A, B], doing replace graph +std::vector ScatterNdOpsInfo::PrepareReplaceGraph(const CNodePtr &cnode) { + auto rank_in_stage = g_device_manager->rank_index_in_stage(); + MS_LOG(INFO) << name_ << ": The rank is " << rank_in_stage; + auto input_slice_shape = inputs_tensor_info_[0].slice_shape(); + Shape indices_slice_value; + std::copy(input_slice_shape.begin(), input_slice_shape.begin() + static_cast(gather_dims_size_), + std::back_inserter(indices_slice_value)); + auto indices_slice_value_tensor = std::make_shared(indices_slice_value, kInt32); + Shape indices_shape_size(inputs_shape_[1].size(), 1); + indices_shape_size[indices_shape_size.size() - 1] = -1; + auto reshape_indices_slice = + gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(indices_slice_value_tensor), + NewValueNode(MakeValue(indices_shape_size))}); + auto div = gen_g_.PushBack({gen_g_.NewOpInst(FLOORDIV), gen_g_.virtual_input_node(), reshape_indices_slice}); + Shape dev_accum_shape; + ShapeToAccumulateProductReverse(dev_matrix_shape_, &dev_accum_shape); + MS_LOG(INFO) << "The dev_matrix is :" << dev_matrix_shape_ << ", dev_accum_shape is :" << dev_accum_shape; + size_t gather_begin_position = 0; + if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) { + 
gather_begin_position = 1; + } + Shape accum_value; + std::copy(dev_accum_shape.begin() + static_cast(gather_begin_position), + dev_accum_shape.begin() + static_cast(gather_begin_position + gather_dims_size_), + std::back_inserter(accum_value)); + auto delta_value = + std::accumulate(dev_accum_shape.begin() + static_cast(gather_begin_position + gather_dims_size_), + dev_accum_shape.end(), 0, std::plus()); + auto accum_value_tensor = std::make_shared(accum_value, kInt32); + auto reshape_accum_value = gen_g_.PushBack( + {gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(accum_value_tensor), NewValueNode(MakeValue(indices_shape_size))}); + auto rank_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_accum_value}); + auto indices_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_indices_slice}); + auto indices_sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), gen_g_.virtual_input_node(), indices_mul}); + auto reduce_sum = gen_g_.PushBack({gen_g_.NewOpInst(REDUCE_SUM), rank_mul, CreateInt32Tensor(-1)}); + auto sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), CreateInt32Tensor(rank_in_stage), reduce_sum}); + auto relu = gen_g_.PushBack({gen_g_.NewOpInst(RELU), sub}); + auto minimum = gen_g_.PushBack({gen_g_.NewOpInst(MINIMUM), relu, CreateInt32Tensor(delta_value)}); + auto equal = gen_g_.PushBack({gen_g_.NewOpInst(EQUAL), sub, minimum}); + auto dtype = gen_g_.PushBack({gen_g_.NewOpInst(DTYPE), gen_g_.virtual_input_node()}); + auto cast = gen_g_.PushBack({gen_g_.NewOpInst(CAST), equal, dtype}); + Shape update_shapes(inputs_shape_[2]); + for (size_t i = inputs_shape_[1].size() - 1; i < update_shapes.size(); ++i) { + update_shapes[i] = 1; + } + auto reshape_updates_mask = + gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), cast, NewValueNode(MakeValue(update_shapes))}); + auto mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), gen_g_.virtual_input_node(), reshape_updates_mask}); + return {indices_sub, mul, div, dtype, reshape_updates_mask}; +} + +Status 
ScatterNdOpsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + return SetCostUnderStrategyBase(strategy); +} + +// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d) +// The dev matrix: (a, b, 1, 1) +std::vector ScatterNdOpsInfo::GenerateOpStrategies(int64_t stage_id) { + // to generate the first input's strategy + Shape input_split(inputs_shape_[0].size(), 1); + Shapes splittable_input = {input_split}; + Shapes tmp_inputs_shape = {inputs_shape_[0]}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed"; + } + + // the others strategies are equal to the first input's strategy + for (auto &sp : sp_vector) { + if ((sp == nullptr) || sp->GetInputDim().empty()) { + MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty"; + } + Strategies tmp_strategy; + Dimensions first_input_strategy = sp->GetInputDim()[0]; + Dimensions indices_strategy(inputs_shape_[1].size(), 1); + // updates_strategy = indices_strategy[:-1] + input_strategy[gather_dims_size_:] + Dimensions updates_strategy = indices_strategy; + updates_strategy.pop_back(); + for (size_t i = gather_dims_size_; i < first_input_strategy.size(); ++i) { + updates_strategy.push_back(first_input_strategy[i]); + } + + tmp_strategy.push_back(first_input_strategy); // input + tmp_strategy.push_back(indices_strategy); // indices + tmp_strategy.push_back(updates_strategy); // updates + sp->ResetInputs(tmp_strategy); + } + + return sp_vector; +} + +Status ScatterNdOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) { + if (InitForCostModelWithAutoRepeatCalc(in_strategy, 
out_strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + auto param_strategy = strategy_->GetInputDim().at(0); + // cost model set axis and strategy + auto scatter_nd_ops_cost = std::dynamic_pointer_cast(operator_cost()); + MS_EXCEPTION_IF_NULL(scatter_nd_ops_cost); + bool is_split_axis = std::find_if(param_strategy.begin(), param_strategy.begin() + gather_dims_size_, + [](auto dim) { return dim > 1; }) != param_strategy.end(); + scatter_nd_ops_cost->set_is_split_axis(is_split_axis); + scatter_nd_ops_cost->set_coefficient(1, 5, 2); + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +REGISTER(ScatterNdAddInfo); +REGISTER(ScatterNdSubInfo); +REGISTER(ScatterNdUpdateInfo); +REGISTER(TensorScatterUpdateInfo); +REGISTER(TensorScatterAddInfo); +REGISTER(TensorScatterSubInfo); +REGISTER(TensorScatterMulInfo); +REGISTER(TensorScatterDivInfo); +REGISTER(TensorScatterMaxInfo); +REGISTER(TensorScatterMinInfo); +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.h new file mode 100644 index 00000000000..123d94b629b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_nd_ops_info.h @@ -0,0 +1,179 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_ +#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_ + +#include +#include +#include + +#include "utils/hash_map.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +class ScatterNdOpsInfo : public OperatorInfo { + public: + ScatterNdOpsInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, const OperatorCostPtr &cost) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} + ~ScatterNdOpsInfo() override = default; + + std::vector GenerateOpStrategies(int64_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override; + + protected: + Status GetAttrs() override { return SUCCESS; } + Status CheckStrategy(const StrategyPtr &strategy) override; + virtual Status CheckInputStrategy(const Dimensions &strategy_item) { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + std::vector PrepareReplaceGraph(const CNodePtr &cnode); + virtual Status ComputeReplaceGraph(const CNodePtr &cnode); + Status InferBias(); + bool do_replace_graph_ = false; + int64_t bias_ = 0; + int64_t slice_size_ = 0; + size_t gather_dims_size_ = 1; + GenerateGraph gen_g_ = GenerateGraph(attrs_); + Shapes 
InferStrategyIndividualMode(const Shapes &in_strategy) override; +}; + +using ScatterNdOpsInfoPtr = std::shared_ptr; + +class ScatterNdAddInfo : public ScatterNdOpsInfo { + public: + ScatterNdAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_add only use in eval/predict + ~ScatterNdAddInfo() override = default; +}; + +class ScatterNdSubInfo : public ScatterNdOpsInfo { + public: + ScatterNdSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_sub only use in eval/predict + ~ScatterNdSubInfo() override = default; +}; + +class ScatterNdUpdateInfo : public ScatterNdOpsInfo { + public: + ScatterNdUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_update only use in eval/predict + ~ScatterNdUpdateInfo() override = default; + + protected: + Status CheckInputStrategy(const Dimensions &strategy_item) override; +}; + +class TensorScatterUpdateInfo : public ScatterNdOpsInfo { + public: + TensorScatterUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~TensorScatterUpdateInfo() override = default; + + protected: + Status CheckInputStrategy(const Dimensions &strategy_item) 
override; +}; + +class TensorScatterAddInfo : public ScatterNdOpsInfo { + public: + TensorScatterAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~TensorScatterAddInfo() override = default; +}; + +class TensorScatterSubInfo : public ScatterNdOpsInfo { + public: + TensorScatterSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~TensorScatterSubInfo() override = default; +}; + +class ScatterNdMulDivBaseInfo : public ScatterNdOpsInfo { + public: + ScatterNdMulDivBaseInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared()) {} + ~ScatterNdMulDivBaseInfo() override = default; + + protected: + Status ComputeReplaceGraph(const CNodePtr &cnode) override; +}; + +class TensorScatterMulInfo : public ScatterNdMulDivBaseInfo { + public: + TensorScatterMulInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {} + ~TensorScatterMulInfo() override = default; +}; + +class TensorScatterDivInfo : public ScatterNdMulDivBaseInfo { + public: + TensorScatterDivInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {} + ~TensorScatterDivInfo() override = default; +}; + +class TensorScatterMaxInfo : public TensorScatterUpdateInfo { + public: + TensorScatterMaxInfo(const std::string 
&operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {} + ~TensorScatterMaxInfo() override = default; +}; + +class TensorScatterMinInfo : public TensorScatterUpdateInfo { + public: + TensorScatterMinInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {} + ~TensorScatterMinInfo() override = default; +}; + +using ScatterAddInfoPtr = std::shared_ptr; +using ScatterSubInfoPtr = std::shared_ptr; +using ScatterUpdateInfoPtr = std::shared_ptr; +using TensorScatterUpdateInfoPtr = std::shared_ptr; +using TensorScatterAddInfoPtr = std::shared_ptr; +using TensorScatterSubInfoPtr = std::shared_ptr; +using TensorScatterMulInfoPtr = std::shared_ptr; +using TensorScatterDivInfoPtr = std::shared_ptr; +using TensorScatterMaxInfoPtr = std::shared_ptr; +using TensorScatterMinInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc index 54ff3342aef..d1fa7618adf 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc @@ -725,7 +725,9 @@ bool IsSplittableOperator(const std::string &op_name) { BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM, UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD, - UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, 
CONV2D_TRANSPOSE, + UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ND_ADD, SCATTER_ND_SUB, + TENSOR_SCATTER_UPDATE, TENSOR_SCATTER_ADD, TENSOR_SCATTER_SUB, TENSOR_SCATTER_MAX, TENSOR_SCATTER_MIN, + TENSOR_SCATTER_MUL, TENSOR_SCATTER_DIV, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE, diff --git a/tests/ut/python/parallel/test_scatter_nd_ops.py b/tests/ut/python/parallel/test_scatter_nd_ops.py new file mode 100644 index 00000000000..4b81d46ced6 --- /dev/null +++ b/tests/ut/python/parallel/test_scatter_nd_ops.py @@ -0,0 +1,245 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" test scatter nd ops """ +import numpy as np +import pytest +import mindspore.nn as nn +from mindspore import Tensor, Model, Parameter +from mindspore.ops import operations as P +from mindspore import context +from mindspore.common.api import _cell_graph_executor +from parallel.utils.utils import ParallelValidator + +scatter_nd_ops_map = {"Add": P.ScatterNdAdd(), "Update": P.ScatterNdUpdate(), "Sub": P.ScatterNdSub()} +tensor_scatter_ops_map = {"Add": P.TensorScatterAdd(), "Update": P.TensorScatterUpdate(), + "Sub": P.TensorScatterSub(), "Mul": P.TensorScatterMul(), "Div": P.TensorScatterDiv()} + + +# The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d) +# The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1) +# here the 2 corresponds to the number of indexed input dims [A, B] +# The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d) +# The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d) +class Net(nn.Cell): + """Net definition""" + def __init__(self, strategy1=None, strategy2=None, ops_type="Add"): + super(Net, self).__init__() + self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input") + self.indices = Tensor(np.ones([4, 2]).astype(np.int32)) + self.updates = Tensor(np.ones([4, 128]).astype(np.float32)) + self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1) + self.add = P.TensorAdd().shard(strategy2) + self.relu = P.ReLU() + + def construct(self, x): + out = self.scatter_ops(self.inputs, self.indices, self.updates) + out = self.add(x, out) + out = self.relu(out) + return out + + +class Net1(nn.Cell): + """Net definition""" + def __init__(self, strategy1=None, strategy2=None, ops_type="Add"): + super(Net1, self).__init__() + self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input") + self.indices = Tensor(np.ones([4, 3]).astype(np.int32)) + self.updates =
Tensor(np.ones([4]).astype(np.float32)) + self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1) + self.add = P.TensorAdd().shard(strategy2) + self.relu = P.ReLU() + + def construct(self, x): + out = self.scatter_ops(self.inputs, self.indices, self.updates) + out = self.add(x, out) + out = self.relu(out) + return out + + +class Net2(nn.Cell): + """Net definition""" + def __init__(self, strategy1=None, strategy2=None, ops_type="Add"): + super(Net2, self).__init__() + self.indices = Tensor(np.ones([4, 3]).astype(np.int32)) + self.updates = Tensor(np.ones([4]).astype(np.float32)) + self.scatter_ops = tensor_scatter_ops_map.get(ops_type).shard(strategy1) + self.add = P.TensorAdd().shard(strategy2) + self.relu = P.ReLU() + + def construct(self, inputs, x): + inputs = self.relu(inputs) + out = self.scatter_ops(inputs, self.indices, self.updates) + out = self.add(x, out) + out = self.relu(out) + return out + + +def compile_net(net, *inputs): + net.set_auto_parallel() + net.set_train(False) + phase, _ = _cell_graph_executor.compile(net, *inputs) + context.reset_auto_parallel_context() + return phase + + +def test_scatter_nd_add(): + """ + Feature: distribute operator scatter_nd_add in auto parallel. + Description: scatter_nd_add net with sharding updating strategy in semi auto parallel. + Expectation: assert ok. 
+ """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((1, 2, 4), (1, 1), (1, 4)) + strategy2 = ((1, 2, 4), (1, 2, 4)) + net = Net(strategy1, strategy2) + phase = compile_net(net, inputs) + validator = ParallelValidator(net, phase) + # check layout + inputs_expect_layout = ([2, 4], [-1, 1, 0], [32, 32, 32], 0, True, '') + assert validator.check_parameter_layout('input', inputs_expect_layout) + + +def test_scatter_nd_wrong_strategy(): + """ + Feature: distribute operator scatter_nd_add in auto parallel. + Description: scatter_nd_add net with wrong strategy in semi auto parallel. + Expectation: raise runtime error. + """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((1, 2, 4), (1, 1), (1, 2)) + strategy2 = ((1, 2, 4), (1, 2, 4)) + net = Net(strategy1, strategy2) + model = Model(net) + with pytest.raises(RuntimeError): + model.predict(inputs) + context.reset_auto_parallel_context() + + +def test_scatter_nd_sub(): + """ + Feature: distribute operator scatter_nd_sub in auto parallel. + Description: scatter_nd_sub net with sharding input gather dims in semi auto parallel. + Expectation: assert ok. 
+ """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((2, 2, 2), (1, 1), (1,)) + strategy2 = ((2, 2, 2), (2, 2, 2)) + net = Net1(strategy1, strategy2, ops_type="Sub") + phase = compile_net(net, inputs) + validator = ParallelValidator(net, phase) + # check layout + inputs_expect_layout = ([2, 2, 2], [2, 1, 0], [16, 32, 64], 0, True, '') + assert validator.check_parameter_layout('input', inputs_expect_layout) + # check sub_graph + sub_graph = { + 'ScatterNdSub-0': ['input', 'Sub-0', 'Mul-2'], + 'FloorDiv-0': ['_GetTensorSlice-0', 'Reshape-0'], + } + assert validator.check_graph_structure(sub_graph) + + +def test_scatter_nd_update(): + """ + Feature: distribute operator scatter_nd_update in auto parallel. + Description: scatter_nd_update net with sharding input gather dims and tables in semi auto parallel. + Expectation: raise runtime error. + """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((2, 2, 2), (1, 1), (1, 2)) + strategy2 = ((2, 2, 2), (2, 2, 2)) + net = Net(strategy1, strategy2, ops_type="Update") + model = Model(net) + with pytest.raises(RuntimeError): + model.predict(inputs) + context.reset_auto_parallel_context() + + +def test_tensor_scatter_add(): + """ + Feature: distribute operator tensor_scatter_add in auto parallel. + Description: tensor_scatter_add net with sharding input gather dims and tables in semi auto parallel. + Expectation: assert ok.
+ """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((2, 2, 2), (1, 1), (1,)) + strategy2 = ((1, 2, 2), (1, 2, 2)) + net = Net2(strategy1, strategy2, ops_type="Add") + phase = compile_net(net, input1, input2) + validator = ParallelValidator(net, phase) + # check sub_graph + sub_graph = { + 'TensorScatterAdd-0': ['Reshape-1', 'Sub-0', 'Mul-2'], + 'Equal-0': ['Sub-1', 'Minimum-0'], + 'AllGather-2': ['TensorScatterAdd-0'] + } + assert validator.check_graph_structure(sub_graph) + + +def test_tensor_scatter_mul(): + """ + Feature: distribute operator tensor_scatter_mul in auto parallel. + Description: tensor_scatter_mul net with sharding input gather dims and tables in semi auto parallel. + Expectation: assert ok. + """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) + input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((2, 2, 2), (1, 1), (1,)) + strategy2 = ((1, 2, 2), (1, 2, 2)) + net = Net2(strategy1, strategy2, ops_type="Mul") + phase = compile_net(net, input1, input2) + validator = ParallelValidator(net, phase) + # check sub_graph + sub_graph = { + 'TensorScatterMul-0': ['Reshape-1', 'Sub-0', 'Add-0'], + 'Equal-0': ['Sub-1', 'Minimum-0'], + 'AllGather-2': ['TensorScatterMul-0'] + } + assert validator.check_graph_structure(sub_graph) + + +def test_tensor_scatter_mul_auto_parallel(): + """ + Feature: distribute operator tensor_scatter_mul in auto parallel. + Description: tensor_scatter_mul net with sharding input gather dims and tables in auto parallel. + Expectation: assert ok.
+ """ + context.set_context(mode=context.GRAPH_MODE) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True) + input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32)) + strategy1 = ((2, 2, 2), (1, 1), (1,)) + strategy2 = None + net = Net2(strategy1, strategy2, ops_type="Mul") + phase = compile_net(net, input1, input2) + validator = ParallelValidator(net, phase) + # check sub_graph + sub_graph = { + 'TensorScatterMul-0': ['ReLU-0', 'Sub-0', 'Add-0'], + 'Equal-0': ['Sub-1', 'Minimum-0'] + } + assert validator.check_graph_structure(sub_graph)