add scatter nd op

yao_yf 2022-11-04 11:03:41 +08:00
parent dfd2239118
commit bbaeb3203b
8 changed files with 937 additions and 7 deletions

View File

@@ -1895,15 +1895,65 @@ double ScatterMathOpsCost::GetForwardComputationCost(const std::vector<TensorInf
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
}
// don't split axis
if (strategy_.at(0) == 1) {
if (!is_split_axis_) {
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
} else {
// split axis
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * input_coefficient_ +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * indices_coefficient_ +
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * updates_coefficient_;
}
return result;
}
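
Editor's note: in words, the forward cost is the byte size of the three input slices, each scaled by a per-tensor coefficient once the gather dims are split. A minimal Python sketch of the same arithmetic (the function name is illustrative; the default coefficients 1/3/2 mirror the header defaults shown further down):

from math import prod

def forward_compute_cost(input_slice, indices_slice, updates_slice, type_lengths,
                         is_split_axis, coeffs=(1, 3, 2)):
    # Illustrative mirror of ScatterMathOpsCost::GetForwardComputationCost.
    sizes = [prod(input_slice) * type_lengths[0],
             prod(indices_slice) * type_lengths[1],
             prod(updates_slice) * type_lengths[2]]
    if not is_split_axis:
        return float(sum(sizes))
    # split axis: weight each tensor by its coefficient
    return float(sum(s * c for s, c in zip(sizes, coeffs)))

# e.g. fp32 input slice [32, 64], int32 indices [4, 2], fp32 updates [4, 64]
print(forward_compute_cost([32, 64], [4, 2], [4, 64], [4, 4, 4], True))
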
double TensorScatterOpsCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
int64_t stage_id) const {
double result = 0.0;
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
for (size_t j = 0; j < inputs.size(); ++j) {
if (!is_parameter_[j]) {
continue;
}
TensorInfo input_a_tensor_info = inputs[j];
Shape input_a_shape = input_a_tensor_info.shape();
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
int64_t used_device_num = 1;
for (size_t i = 0; i < input_a_shape.size(); ++i) {
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
}
if (total_device_num != LongToSize(used_device_num)) {
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
}
}
return result;
}
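
Editor's note: the rule here is that a parameter only pays mirror communication when its sharding does not already occupy every device in the stage. A hedged Python sketch of the same check (the data layout is simplified; `(shape, slice_shape)` pairs stand in for TensorInfo):

from math import prod

def backward_comm_cost(inputs, is_parameter, type_lengths, total_device_num):
    # inputs: list of (shape, slice_shape) pairs, an illustrative stand-in for TensorInfo.
    result = 0.0
    for j, (shape, slice_shape) in enumerate(inputs):
        if not is_parameter[j]:
            continue
        # how many devices the sharding actually uses
        used_device_num = prod(s // t for s, t in zip(shape, slice_shape))
        if used_device_num != total_device_num:  # replicated somewhere: mirror traffic
            result += prod(slice_shape) * type_lengths[j]
    return result

# a [32, 64] fp32 parameter split over only 4 of 8 devices incurs one slice of traffic
print(backward_comm_cost([([32, 64], [16, 32])], [True], [4], 8))
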
double TensorScatterOpsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs, int64_t) const {
double result = 0.0;
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
Shape input2_slice_shape = inputs[2].slice_shape(); // equal to output shape
// The bprop function uses each input/output roughly 3 times on average.
// don't split axis
if (!is_split_axis_) {
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
result *= 3;
} else {
// split axis
result +=
ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * (input_coefficient_ + 3) +
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * (indices_coefficient_ + 3) +
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]) * (updates_coefficient_ + 3);
}
return result;
}
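
Editor's note: a back-of-envelope check of the no-split branch: for fp32 slices of 2048, 8, and 512 elements, the backward cost is the bytes-weighted sum tripled for the bprop reuse:

cost = 3 * 4 * (2048 + 8 + 512)  # 3 uses per tensor, 4 bytes per fp32 element
assert cost == 30816
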
View File

@@ -1226,10 +1226,34 @@ class ScatterMathOpsCost : public OperatorCost {
is_inputs_should_in_memory_[0] = true;
}
void set_strategy(Shape strategy) { strategy_ = strategy; }
void set_is_split_axis(bool is_split_axis) { is_split_axis_ = is_split_axis; }
void set_coefficient(int32_t input_coefficient, int32_t indices_coefficient, int32_t updates_coefficient) {
input_coefficient_ = input_coefficient;
indices_coefficient_ = indices_coefficient;
updates_coefficient_ = updates_coefficient;
}
protected:
Shape strategy_;
bool is_split_axis_ = false;
int32_t input_coefficient_ = 1;
int32_t indices_coefficient_ = 3;
int32_t updates_coefficient_ = 2;
};
class ScatterNdOpsCost : public ScatterMathOpsCost {
public:
ScatterNdOpsCost() : ScatterMathOpsCost() {}
~ScatterNdOpsCost() override = default;
};
class TensorScatterOpsCost : public ScatterMathOpsCost {
public:
TensorScatterOpsCost() : ScatterMathOpsCost() {}
~TensorScatterOpsCost() override = default;
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int64_t stage_id) const override;
};
class CropAndResizeCost : public OperatorCost {

View File

@@ -95,6 +95,20 @@ constexpr size_t BATCH_NORM_INPUTS_SIZE = 5;
constexpr double EPS = 1e-6;
constexpr double INF = 1e20;
constexpr double COST_FACTOR = 2.0;
constexpr size_t SIZE_ZERO = 0;
constexpr size_t SIZE_ONE = 1;
constexpr size_t SIZE_TWO = 2;
constexpr size_t SIZE_THREE = 3;
constexpr size_t SIZE_FOUR = 4;
constexpr size_t SIZE_FIVE = 5;
constexpr size_t SIZE_SIX = 6;
constexpr size_t INDEX_ZERO = 0;
constexpr size_t INDEX_ONE = 1;
constexpr size_t INDEX_TWO = 2;
constexpr size_t INDEX_THREE = 3;
constexpr size_t INDEX_FOUR = 4;
constexpr size_t INDEX_FIVE = 5;
constexpr size_t INDEX_SIX = 6;
constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
@@ -467,6 +481,15 @@ constexpr char SCATTER_MUL[] = "ScatterMul";
constexpr char SCATTER_DIV[] = "ScatterDiv";
constexpr char SCATTER_MAX[] = "ScatterMax";
constexpr char SCATTER_MIN[] = "ScatterMin";
constexpr char SCATTER_ND_UPDATE[] = "ScatterNdUpdate";
constexpr char SCATTER_ND_ADD[] = "ScatterNdAdd";
constexpr char SCATTER_ND_SUB[] = "ScatterNdSub";
constexpr char TENSOR_SCATTER_ADD[] = "TensorScatterAdd";
constexpr char TENSOR_SCATTER_SUB[] = "TensorScatterSub";
constexpr char TENSOR_SCATTER_MAX[] = "TensorScatterMax";
constexpr char TENSOR_SCATTER_MIN[] = "TensorScatterMin";
constexpr char TENSOR_SCATTER_MUL[] = "TensorScatterMul";
constexpr char TENSOR_SCATTER_DIV[] = "TensorScatterDiv";
constexpr char GATHERD[] = "GatherD";
constexpr char DSD_MATMUL[] = "DSDMatmul";
constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";

View File

@@ -262,7 +262,7 @@ Status ScatterMathOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, cons
auto param_strategy = strategy_->GetInputDim().at(0);
// cost model set axis and strategy
auto scatter_ops_cost = std::dynamic_pointer_cast<ScatterMathOpsCost>(operator_cost());
scatter_ops_cost->set_strategy(param_strategy);
scatter_ops_cost->set_is_split_axis((param_strategy[0] > 1));
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

View File

@@ -0,0 +1,407 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/scatter_nd_ops_info.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
namespace mindspore {
namespace parallel {
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
// the last dim of indices holds coordinates (x, y), which must satisfy '0 <= x < A, 0 <= y < B'
// here the 2 corresponds to the rank of [A, B]
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
// The dev matrix: (a, b, 1, 1)
Status ScatterNdOpsInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != 3) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 3";
return FAILED;
}
auto indices_shape = inputs_shape_[1];
gather_dims_size_ = indices_shape.back();
if (CheckInputStrategy(stra[0]) != SUCCESS) {
return FAILED;
}
if (!stra[1].empty() && std::accumulate(stra[1].begin(), stra[1].end(), 1, std::multiplies<int64_t>()) != 1) {
MS_LOG(ERROR) << name_ << ": The indices can not be split";
return FAILED;
}
if (stra[2].empty()) {
MS_LOG(ERROR) << name_ << ": The strategy[2] is empty";
return FAILED;
}
if (std::accumulate(stra[2].begin(), stra[2].begin() + static_cast<different_type>(stra[1].size() - 1), 1,
std::multiplies<int64_t>()) != 1) {
MS_LOG(ERROR) << name_ << ": The first " << (stra[1].size() - 1) << " dimensions of updates can not be split";
return FAILED;
}
for (size_t i = 0; i < stra[0].size() - gather_dims_size_; ++i) {
if (stra[0][i + gather_dims_size_] != stra[2][stra[1].size() - 1 + i]) {
MS_LOG(ERROR)
<< name_ << ": updates.strategy must be equal to indices.strategy[:-1] + input.strategy[indices.strategy[-1]:]";
return FAILED;
}
}
for (size_t i = 0; i < gather_dims_size_; ++i) {
if (stra[0][i] > 1) {
do_replace_graph_ = true;
break;
}
}
return SUCCESS;
}
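
Editor's note: concretely, the constraints above admit the sharding exercised by the tests at the end of this commit. With input [32, 64, 128] and indices [4, 2], the first two input dims are the gather dims, indices stay unsplit, and updates copy the input strategy past the gather dims. A sketch using the public shard API (shapes match the test file below):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class ScatterNdAddNet(nn.Cell):
    def __init__(self):
        super(ScatterNdAddNet, self).__init__()
        self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
        self.indices = Tensor(np.ones([4, 2]).astype(np.int32))
        self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
        # input (1, 2, 4): gather dim 1 is split, so the replace graph is generated;
        # indices (1, 1): never split; updates (1, 4) = indices[:-1] + input[2:]
        self.scatter = P.ScatterNdAdd().shard(((1, 2, 4), (1, 1), (1, 4)))

    def construct(self, x):
        return self.scatter(self.inputs, self.indices, self.updates) + x
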
// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d)
Status ScatterNdUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) {
for (size_t i = 0; i < gather_dims_size_; ++i) {
if (strategy_item[i] != 1) {
MS_LOG(ERROR) << "For " << name_
<< ", "
"the input cannot be shard at gather dims input[:"
<< gather_dims_size_ << "].";
return FAILED;
}
}
return SUCCESS;
}
// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d)
Status TensorScatterUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) {
for (size_t i = 0; i < gather_dims_size_; ++i) {
if (strategy_item[i] != 1) {
MS_LOG(ERROR) << "For ScatterNdUpdate/TensorScatterUpdate, "
"the input cannot be shard at gather dims input[:"
<< gather_dims_size_ << "].";
return FAILED;
}
}
return SUCCESS;
}
Status ScatterNdOpsInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << "The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d), tensor_map: (3, 2, 1, 0)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1), tensor_map: (-1, -1, -1)
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d), tensor_map: (-1, -1, 1, 0)
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d), tensor_map: (3, 2, 1, 0)
// The dev matrix: (a, b, c, d)
Status ScatterNdOpsInfo::InferTensorMap() {
if (inputs_shape_.size() != 3) {
MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3";
return FAILED;
}
TensorMap input_tensor_map, updates_tensor_map;
TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
// cannot use dev_matrix_shape_ to replace inputs_shape_[0], because it may not be fully split across all devices.
int64_t size = SizeToLong(inputs_shape_[0].size());
for (int64_t i = 0; i < size; ++i) {
input_tensor_map.push_back(size - i - 1);
}
// updates_tensor_map = indices_tensor_map[:-1] + input_tensor_map[gather_dims_size_:]
updates_tensor_map = indices_tensor_map;
updates_tensor_map.pop_back();
for (size_t i = gather_dims_size_; i < input_tensor_map.size(); ++i) {
updates_tensor_map.push_back(input_tensor_map[i]);
}
inputs_tensor_map_.push_back(input_tensor_map); // input
inputs_tensor_map_.push_back(indices_tensor_map); // indices
inputs_tensor_map_.push_back(updates_tensor_map); // updates
outputs_tensor_map_.push_back(input_tensor_map);
return SUCCESS;
}
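
Editor's note: the same mapping rule expressed in a few lines of Python (pure illustration; MAP_NONE is -1 in the parallel module):

MAP_NONE = -1

def infer_tensor_maps(input_rank, indices_rank, gather_dims_size):
    input_map = list(range(input_rank - 1, -1, -1))   # e.g. [3, 2, 1, 0]
    indices_map = [MAP_NONE] * indices_rank           # e.g. [-1, -1, -1]
    updates_map = indices_map[:-1] + input_map[gather_dims_size:]
    return input_map, indices_map, updates_map

# input [A, B, C, D], indices [Q, W, 2], updates [Q, W, C, D]
print(infer_tensor_maps(4, 3, 2))  # ([3, 2, 1, 0], [-1, -1, -1], [-1, -1, 1, 0])
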
void ScatterNdOpsInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = false;  // if no strategy is set, run in stand-alone mode.
}
}
// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, c, d))
// return: ((A, B, C, D), (1, 1), (1, 1, C, D))
// if in_strategy[0] is empty, an exception is thrown
Shapes ScatterNdOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
if (in_strategy.size() != SIZE_THREE) {
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
}
if (in_strategy[INDEX_ZERO].empty()) {
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] must be not empty, you should set the first input strategy.";
}
Shape indices_strategy, updates_strategy;
indices_strategy = Shape(inputs_shape_[INDEX_ONE].size(), 1);
updates_strategy = Shape(inputs_shape_[INDEX_TWO].size(), 1);
for (size_t i = 0; i < in_strategy[0].size() - gather_dims_size_; ++i) {
updates_strategy[indices_strategy.size() - 1 + i] = in_strategy[INDEX_ZERO][i + gather_dims_size_];
}
return Shapes({in_strategy[INDEX_ZERO], indices_strategy, updates_strategy});
}
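
Editor's note: a Python rendering of this inference, under the same assumption that only the first input's strategy is supplied:

def infer_individual_strategy(input_strategy, indices_rank, updates_rank, gather_dims_size):
    indices_strategy = [1] * indices_rank
    updates_strategy = [1] * updates_rank
    for i in range(len(input_strategy) - gather_dims_size):
        updates_strategy[indices_rank - 1 + i] = input_strategy[i + gather_dims_size]
    return input_strategy, indices_strategy, updates_strategy

# (a, b, c, d) with two gather dims, indices rank 3, updates rank 4:
print(infer_individual_strategy([2, 2, 2, 1], 3, 4, 2))
# -> ([2, 2, 2, 1], [1, 1, 1], [1, 1, 2, 1])
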
ReplaceGraphPtr ScatterNdOpsInfo::replace_graph(const CNodePtr &cnode) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " replace graph failed";
}
return replace_graph_;
}
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
// the last dim of indices holds coordinates (x, y), which must satisfy '0 <= x < A, 0 <= y < B'
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
// when [A, B] is split, the replace graph is applied
Status ScatterNdOpsInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
if (!do_replace_graph_) {
return SUCCESS;
}
if (gen_g_.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto anf_node_list = PrepareReplaceGraph(cnode);
// {indices_sub, mul, div, dtype};
auto indices_sub = anf_node_list[0];
auto mul = anf_node_list[1];
auto div = anf_node_list[2];
auto dtype = anf_node_list[3];
auto info_position = name_.find("Info");
if (info_position == std::string::npos) {
MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'";
}
auto node_name = name_.substr(0, info_position);
auto scatter_ops = gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, mul});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2),
std::make_pair(indices_sub, 2), std::make_pair(mul, 3),
std::make_pair(dtype, 3)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, scatter_ops));
return SUCCESS;
}
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
// the last dim of indices holds coordinates (x, y), which must satisfy '0 <= x < A, 0 <= y < B'
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
// when [A, B] is split, the replace graph is applied
Status ScatterNdMulDivBaseInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
if (!do_replace_graph_) {
return SUCCESS;
}
if (gen_g_.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto anf_node_list = PrepareReplaceGraph(cnode);
// {indices_sub, mul, div, dtype};
auto indices_sub = anf_node_list[0];
auto mul = anf_node_list[1];
auto div = anf_node_list[2];
auto dtype = anf_node_list[3];
auto reshape_updates_mask = anf_node_list[4];
auto reverse_sub = gen_g_.PushBack(
{gen_g_.NewOpInst(SUB), NewValueNode(std::make_shared<tensor::Tensor>(1.0, kFloat32)), reshape_updates_mask});
auto add_mask = gen_g_.PushBack({gen_g_.NewOpInst(ADD), mul, reverse_sub});
auto info_position = name_.find("Info");
if (info_position == std::string::npos) {
MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'";
}
auto node_name = name_.substr(0, info_position);
auto scatter_ops = gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, add_mask});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2),
std::make_pair(indices_sub, 2), std::make_pair(mul, 3),
std::make_pair(dtype, 3)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, scatter_ops));
return SUCCESS;
}
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
// the last dim of indices holds coordinates (x, y), which must satisfy '0 <= x < A, 0 <= y < B'
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
// when [A, B] is split, the replace graph is applied
std::vector<AnfNodePtr> ScatterNdOpsInfo::PrepareReplaceGraph(const CNodePtr &cnode) {
auto rank_in_stage = g_device_manager->rank_index_in_stage();
MS_LOG(INFO) << name_ << ": The rank is " << rank_in_stage;
auto input_slice_shape = inputs_tensor_info_[0].slice_shape();
Shape indices_slice_value;
std::copy(input_slice_shape.begin(), input_slice_shape.begin() + static_cast<different_type>(gather_dims_size_),
std::back_inserter(indices_slice_value));
auto indices_slice_value_tensor = std::make_shared<mindspore::tensor::Tensor>(indices_slice_value, kInt32);
Shape indices_shape_size(inputs_shape_[1].size(), 1);
indices_shape_size[indices_shape_size.size() - 1] = -1;
auto reshape_indices_slice =
gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(indices_slice_value_tensor),
NewValueNode(MakeValue(indices_shape_size))});
auto div = gen_g_.PushBack({gen_g_.NewOpInst(FLOORDIV), gen_g_.virtual_input_node(), reshape_indices_slice});
Shape dev_accum_shape;
ShapeToAccumulateProductReverse(dev_matrix_shape_, &dev_accum_shape);
MS_LOG(INFO) << "The dev_matrix is :" << dev_matrix_shape_ << ", dev_accum_shape is :" << dev_accum_shape;
size_t gather_begin_position = 0;
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
gather_begin_position = 1;
}
Shape accum_value;
std::copy(dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position),
dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position + gather_dims_size_),
std::back_inserter(accum_value));
auto delta_value =
std::accumulate(dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position + gather_dims_size_),
dev_accum_shape.end(), 0, std::plus<int64_t>());
auto accum_value_tensor = std::make_shared<mindspore::tensor::Tensor>(accum_value, kInt32);
auto reshape_accum_value = gen_g_.PushBack(
{gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(accum_value_tensor), NewValueNode(MakeValue(indices_shape_size))});
auto rank_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_accum_value});
auto indices_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_indices_slice});
auto indices_sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), gen_g_.virtual_input_node(), indices_mul});
auto reduce_sum = gen_g_.PushBack({gen_g_.NewOpInst(REDUCE_SUM), rank_mul, CreateInt32Tensor(-1)});
auto sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), CreateInt32Tensor(rank_in_stage), reduce_sum});
auto relu = gen_g_.PushBack({gen_g_.NewOpInst(RELU), sub});
auto minimum = gen_g_.PushBack({gen_g_.NewOpInst(MINIMUM), relu, CreateInt32Tensor(delta_value)});
auto equal = gen_g_.PushBack({gen_g_.NewOpInst(EQUAL), sub, minimum});
auto dtype = gen_g_.PushBack({gen_g_.NewOpInst(DTYPE), gen_g_.virtual_input_node()});
auto cast = gen_g_.PushBack({gen_g_.NewOpInst(CAST), equal, dtype});
Shape update_shapes(inputs_shape_[2]);
for (size_t i = inputs_shape_[1].size() - 1; i < update_shapes.size(); ++i) {
update_shapes[i] = 1;
}
auto reshape_updates_mask =
gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), cast, NewValueNode(MakeValue(update_shapes))});
auto mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), gen_g_.virtual_input_node(), reshape_updates_mask});
return {indices_sub, mul, div, dtype, reshape_updates_mask};
}
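
Editor's note: numerically, the generated nodes turn global indices into rank-local ones and build a 0/1 mask that zeroes out updates aimed at another rank's slice (for the mul/div variants, ScatterNdMulDivBaseInfo above additionally adds 1 - mask, so masked updates become the multiplicative identity). A NumPy sketch of that index arithmetic; this is a simplified model, not the emitted graph:

import numpy as np

def localize_indices(indices, slice_shape, rank_coord):
    # indices: [..., k] global coords over the k gather dims;
    # slice_shape: per-rank extent of each gather dim;
    # rank_coord: this rank's coordinate in the device matrix over the gather dims.
    owner = indices // slice_shape                  # which rank owns each row
    local = indices - owner * slice_shape           # rank-local coordinates
    mask = np.all(owner == rank_coord, axis=-1, keepdims=True).astype(indices.dtype)
    return local, mask                              # updates are multiplied by mask

indices = np.array([[2, 5], [3, 1]])                # global coords into a [4, 8] input
local, mask = localize_indices(indices, np.array([2, 4]), np.array([1, 1]))
print(local)           # [[0 1] [1 1]]
print(mask.ravel())    # [1 0]: the second row belongs to rank (1, 0)
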
Status ScatterNdOpsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
// The dev matrix: (a, b, 1, 1)
std::vector<StrategyPtr> ScatterNdOpsInfo::GenerateOpStrategies(int64_t stage_id) {
// to generate the first input's strategy
Shape input_split(inputs_shape_[0].size(), 1);
Shapes splittable_input = {input_split};
Shapes tmp_inputs_shape = {inputs_shape_[0]};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
}
// the others strategies are equal to the first input's strategy
for (auto &sp : sp_vector) {
if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
}
Strategies tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0];
Dimensions indices_strategy(inputs_shape_[1].size(), 1);
// updates_strategy = indices_strategy[:-1] + input_strategy[gather_dims_size_:]
Dimensions updates_strategy = indices_strategy;
updates_strategy.pop_back();
for (size_t i = gather_dims_size_; i < first_input_strategy.size(); ++i) {
updates_strategy.push_back(first_input_strategy[i]);
}
tmp_strategy.push_back(first_input_strategy); // input
tmp_strategy.push_back(indices_strategy); // indices
tmp_strategy.push_back(updates_strategy); // updates
sp->ResetInputs(tmp_strategy);
}
return sp_vector;
}
Status ScatterNdOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
} else {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
}
return FAILED;
}
auto param_strategy = strategy_->GetInputDim().at(0);
// cost model set axis and strategy
auto scatter_nd_ops_cost = std::dynamic_pointer_cast<ScatterMathOpsCost>(operator_cost());
MS_EXCEPTION_IF_NULL(scatter_nd_ops_cost);
bool is_split_axis = std::find_if(param_strategy.begin(), param_strategy.begin() + gather_dims_size_,
[](auto dim) { return dim > 1; }) != param_strategy.begin() + gather_dims_size_;
scatter_nd_ops_cost->set_is_split_axis(is_split_axis);
scatter_nd_ops_cost->set_coefficient(1, 5, 2);
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
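
Editor's note: the is_split_axis probe is just "is any gather dim of the parameter strategy split"; in Python terms:

def is_split_axis(param_strategy, gather_dims_size):
    return any(dim > 1 for dim in param_strategy[:gather_dims_size])

print(is_split_axis([2, 1, 4], 2))  # True: a gather dim is split
print(is_split_axis([1, 1, 4], 2))  # False: only non-gather dims are split
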
REGISTER(ScatterNdAddInfo);
REGISTER(ScatterNdSubInfo);
REGISTER(ScatterNdUpdateInfo);
REGISTER(TensorScatterUpdateInfo);
REGISTER(TensorScatterAddInfo);
REGISTER(TensorScatterSubInfo);
REGISTER(TensorScatterMulInfo);
REGISTER(TensorScatterDivInfo);
REGISTER(TensorScatterMaxInfo);
REGISTER(TensorScatterMinInfo);
} // namespace parallel
} // namespace mindspore

View File

@@ -0,0 +1,179 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_
#include <string>
#include <memory>
#include <vector>
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
class ScatterNdOpsInfo : public OperatorInfo {
public:
ScatterNdOpsInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, const OperatorCostPtr &cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ScatterNdOpsInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
protected:
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
virtual Status CheckInputStrategy(const Dimensions &strategy_item) { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
std::vector<AnfNodePtr> PrepareReplaceGraph(const CNodePtr &cnode);
virtual Status ComputeReplaceGraph(const CNodePtr &cnode);
Status InferBias();
bool do_replace_graph_ = false;
int64_t bias_ = 0;
int64_t slice_size_ = 0;
size_t gather_dims_size_ = 1;
GenerateGraph gen_g_ = GenerateGraph(attrs_);
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
};
using ScatterNdOpsInfoPtr = std::shared_ptr<ScatterNdOpsInfo>;
class ScatterNdAddInfo : public ScatterNdOpsInfo {
public:
ScatterNdAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
Status InferMirrorOps() override { return SUCCESS; } // ScatterNdAdd is only used in eval/predict
~ScatterNdAddInfo() override = default;
};
class ScatterNdSubInfo : public ScatterNdOpsInfo {
public:
ScatterNdSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
Status InferMirrorOps() override { return SUCCESS; } // ScatterNdSub is only used in eval/predict
~ScatterNdSubInfo() override = default;
};
class ScatterNdUpdateInfo : public ScatterNdOpsInfo {
public:
ScatterNdUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
Status InferMirrorOps() override { return SUCCESS; } // ScatterNdUpdate is only used in eval/predict
~ScatterNdUpdateInfo() override = default;
protected:
Status CheckInputStrategy(const Dimensions &strategy_item) override;
};
class TensorScatterUpdateInfo : public ScatterNdOpsInfo {
public:
TensorScatterUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
~TensorScatterUpdateInfo() override = default;
protected:
Status CheckInputStrategy(const Dimensions &strategy_item) override;
};
class TensorScatterAddInfo : public ScatterNdOpsInfo {
public:
TensorScatterAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
~TensorScatterAddInfo() override = default;
};
class TensorScatterSubInfo : public ScatterNdOpsInfo {
public:
TensorScatterSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
~TensorScatterSubInfo() override = default;
};
class ScatterNdMulDivBaseInfo : public ScatterNdOpsInfo {
public:
ScatterNdMulDivBaseInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
~ScatterNdMulDivBaseInfo() override = default;
protected:
Status ComputeReplaceGraph(const CNodePtr &cnode) override;
};
class TensorScatterMulInfo : public ScatterNdMulDivBaseInfo {
public:
TensorScatterMulInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
~TensorScatterMulInfo() override = default;
};
class TensorScatterDivInfo : public ScatterNdMulDivBaseInfo {
public:
TensorScatterDivInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
~TensorScatterDivInfo() override = default;
};
class TensorScatterMaxInfo : public TensorScatterUpdateInfo {
public:
TensorScatterMaxInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
~TensorScatterMaxInfo() override = default;
};
class TensorScatterMinInfo : public TensorScatterUpdateInfo {
public:
TensorScatterMinInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
~TensorScatterMinInfo() override = default;
};
using ScatterNdAddInfoPtr = std::shared_ptr<ScatterNdAddInfo>;
using ScatterNdSubInfoPtr = std::shared_ptr<ScatterNdSubInfo>;
using ScatterNdUpdateInfoPtr = std::shared_ptr<ScatterNdUpdateInfo>;
using TensorScatterUpdateInfoPtr = std::shared_ptr<TensorScatterUpdateInfo>;
using TensorScatterAddInfoPtr = std::shared_ptr<TensorScatterAddInfo>;
using TensorScatterSubInfoPtr = std::shared_ptr<TensorScatterSubInfo>;
using TensorScatterMulInfoPtr = std::shared_ptr<TensorScatterMulInfo>;
using TensorScatterDivInfoPtr = std::shared_ptr<TensorScatterDivInfo>;
using TensorScatterMaxInfoPtr = std::shared_ptr<TensorScatterMaxInfo>;
using TensorScatterMinInfoPtr = std::shared_ptr<TensorScatterMinInfo>;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_

View File

@@ -725,7 +725,9 @@ bool IsSplittableOperator(const std::string &op_name) {
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ND_ADD, SCATTER_ND_SUB,
TENSOR_SCATTER_UPDATE, TENSOR_SCATTER_ADD, TENSOR_SCATTER_SUB, TENSOR_SCATTER_MAX, TENSOR_SCATTER_MIN,
TENSOR_SCATTER_MUL, TENSOR_SCATTER_DIV, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE,
RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,

View File

@@ -0,0 +1,245 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test scatter update """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Model, Parameter
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from parallel.utils.utils import ParallelValidator
scatter_nd_ops_map = {"Add": P.ScatterNdAdd(), "Update": P.ScatterNdUpdate(), "Sub": P.ScatterNdSub()}
tensor_scatter_ops_map = {"Add": P.TensorScatterAdd(), "Update": P.TensorScatterUpdate(),
"Sub": P.TensorScatterSub(), "Mul": P.TensorScatterMul(), "Div": P.TensorScatterDiv()}
# The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
# The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
# here the 2 corresponds to the rank of [A, B]
# The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
# The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
class Net(nn.Cell):
"""Net definition"""
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
super(Net, self).__init__()
self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
self.indices = Tensor(np.ones([4, 2]).astype(np.int32))
self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
self.relu = P.ReLU()
def construct(self, x):
out = self.scatter_ops(self.inputs, self.indices, self.updates)
out = self.add(x, out)
out = self.relu(out)
return out
class Net1(nn.Cell):
"""Net definition"""
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
super(Net1, self).__init__()
self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
self.indices = Tensor(np.ones([4, 3]).astype(np.int32))
self.updates = Tensor(np.ones([4]).astype(np.float32))
self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
self.relu = P.ReLU()
def construct(self, x):
out = self.scatter_ops(self.inputs, self.indices, self.updates)
out = self.add(x, out)
out = self.relu(out)
return out
class Net2(nn.Cell):
"""Net definition"""
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
super(Net2, self).__init__()
self.indices = Tensor(np.ones([4, 3]).astype(np.int32))
self.updates = Tensor(np.ones([4]).astype(np.float32))
self.scatter_ops = tensor_scatter_ops_map.get(ops_type).shard(strategy1)
self.add = P.TensorAdd().shard(strategy2)
self.relu = P.ReLU()
def construct(self, inputs, x):
inputs = self.relu(inputs)
out = self.scatter_ops(inputs, self.indices, self.updates)
out = self.add(x, out)
out = self.relu(out)
return out
def compile_net(net, *inputs):
net.set_auto_parallel()
net.set_train(False)
phase, _ = _cell_graph_executor.compile(net, *inputs)
context.reset_auto_parallel_context()
return phase
def test_scatter_nd_add():
"""
Feature: distribute operator scatter_nd_add in auto parallel.
Description: scatter_nd_add net with a sharded updates strategy in semi auto parallel.
Expectation: assert ok.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((1, 2, 4), (1, 1), (1, 4))
strategy2 = ((1, 2, 4), (1, 2, 4))
net = Net(strategy1, strategy2)
phase = compile_net(net, inputs)
validator = ParallelValidator(net, phase)
# check layout
inputs_expect_layout = ([2, 4], [-1, 1, 0], [32, 32, 32], 0, True, '')
assert validator.check_parameter_layout('input', inputs_expect_layout)
def test_scatter_nd_wrong_strategy():
"""
Feature: distribute operator scatter_nd_add in auto parallel.
Description: scatter_nd_add net with a wrong strategy in semi auto parallel.
Expectation: raise runtime error.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((1, 2, 4), (1, 1), (1, 2))
strategy2 = ((1, 2, 4), (1, 2, 4))
net = Net(strategy1, strategy2)
model = Model(net)
with pytest.raises(RuntimeError):
model.predict(inputs)
context.reset_auto_parallel_context()
def test_scatter_nd_sub():
"""
Feature: distribute operator scatter_nd_sub in auto parallel.
Description: scatter_nd_sub net with sharded input gather dims in semi auto parallel.
Expectation: assert ok.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((2, 2, 2), (1, 1), (1,))
strategy2 = ((2, 2, 2), (2, 2, 2))
net = Net1(strategy1, strategy2, ops_type="Sub")
phase = compile_net(net, inputs)
validator = ParallelValidator(net, phase)
# check layout
inputs_expect_layout = ([2, 2, 2], [2, 1, 0], [16, 32, 64], 0, True, '')
assert validator.check_parameter_layout('input', inputs_expect_layout)
# check sub_graph
sub_graph = {
'ScatterNdSub-0': ['input', 'Sub-0', 'Mul-2'],
'FloorDiv-0': ['_GetTensorSlice-0', 'Reshape-0'],
}
assert validator.check_graph_structure(sub_graph)
def test_scatter_nd_update():
"""
Feature: distribute operator scatter_nd_update in auto parallel.
Description: scatter_nd_update net with sharded input gather dims and tables in semi auto parallel.
Expectation: raise runtime error.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((2, 2, 2), (1, 1), (1, 2))
strategy2 = ((2, 2, 2), (2, 2, 2))
net = Net(strategy1, strategy2, ops_type="Update")
model = Model(net)
with pytest.raises(RuntimeError):
model.predict(inputs)
context.reset_auto_parallel_context()
def test_tensor_scatter_add():
"""
Feature: distribute operator tensor_scatter_add in auto parallel.
Description: tensor_scatter_add net with sharded input gather dims and tables in semi auto parallel.
Expectation: assert ok.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((2, 2, 2), (1, 1), (1,))
strategy2 = ((1, 2, 2), (1, 2, 2))
net = Net2(strategy1, strategy2, ops_type="Add")
phase = compile_net(net, input1, input2)
validator = ParallelValidator(net, phase)
# check sub_graph
sub_graph = {
'TensorScatterAdd-0': ['Reshape-1', 'Sub-0', 'Mul-2'],
'Equal-0': ['Sub-1', 'Minimum-0'],
'AllGather-2': ['TensorScatterAdd-0']
}
assert validator.check_graph_structure(sub_graph)
def test_tensor_scatter_mul():
"""
Feature: distribute operator tensor_scatter_mul in auto parallel.
Description: tensor_scatter_mul net with sharded input gather dims and tables in semi auto parallel.
Expectation: assert ok.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((2, 2, 2), (1, 1), (1,))
strategy2 = ((1, 2, 2), (1, 2, 2))
net = Net2(strategy1, strategy2, ops_type="Mul")
phase = compile_net(net, input1, input2)
validator = ParallelValidator(net, phase)
# check sub_graph
sub_graph = {
'TensorScatterMul-0': ['Reshape-1', 'Sub-0', 'Add-0'],
'Equal-0': ['Sub-1', 'Minimum-0'],
'AllGather-2': ['TensorScatterMul-0']
}
assert validator.check_graph_structure(sub_graph)
def test_tensor_scatter_mul_auto_parallel():
"""
Feature: distribute operator tensor_scatter_mul in auto parallel.
Description: tensor_scatter_mul net with sharded input gather dims and tables in auto parallel.
Expectation: assert ok.
"""
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
strategy1 = ((2, 2, 2), (1, 1), (1,))
strategy2 = None
net = Net2(strategy1, strategy2, ops_type="Mul")
phase = compile_net(net, input1, input2)
validator = ParallelValidator(net, phase)
# check sub_graph
sub_graph = {
'TensorScatterMul-0': ['ReLU-0', 'Sub-0', 'Add-0'],
'Equal-0': ['Sub-1', 'Minimum-0']
}
assert validator.check_graph_structure(sub_graph)