add scatter nd op
This commit is contained in:
parent
dfd2239118
commit
bbaeb3203b
|
@ -1895,15 +1895,65 @@ double ScatterMathOpsCost::GetForwardComputationCost(const std::vector<TensorInf
|
|||
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
|
||||
}
|
||||
// don't split axis
|
||||
if (strategy_.at(0) == 1) {
|
||||
if (!is_split_axis_) {
|
||||
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
|
||||
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
|
||||
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
|
||||
} else {
|
||||
// split axis
|
||||
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * input_coefficient_ +
|
||||
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * indices_coefficient_ +
|
||||
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * updates_coefficient_;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
double TensorScatterOpsCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
|
||||
int64_t stage_id) const {
|
||||
double result = 0.0;
|
||||
CheckGlobalDeviceManager();
|
||||
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
||||
|
||||
for (size_t j = 0; j < inputs.size(); ++j) {
|
||||
if (!is_parameter_[j]) {
|
||||
continue;
|
||||
}
|
||||
TensorInfo input_a_tensor_info = inputs[j];
|
||||
Shape input_a_shape = input_a_tensor_info.shape();
|
||||
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
|
||||
int64_t used_device_num = 1;
|
||||
for (size_t i = 0; i < input_a_shape.size(); ++i) {
|
||||
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
|
||||
}
|
||||
if (total_device_num != LongToSize(used_device_num)) {
|
||||
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
double TensorScatterOpsCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||
const std::vector<TensorInfo> &outputs, int64_t) const {
|
||||
double result = 0.0;
|
||||
Shape input0_slice_shape = inputs[0].slice_shape();
|
||||
Shape input1_slice_shape = inputs[1].slice_shape();
|
||||
Shape input2_slice_shape = inputs[2].slice_shape(); // equal to output shape
|
||||
// brop func using 3 times input/out in average.
|
||||
// don't split axis
|
||||
if (!is_split_axis_) {
|
||||
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
|
||||
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * 3 +
|
||||
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * 2;
|
||||
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) +
|
||||
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[2]);
|
||||
result *= 3;
|
||||
|
||||
} else {
|
||||
// split axis
|
||||
result +=
|
||||
ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * (input_coefficient_ + 3) +
|
||||
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * (indices_coefficient_ + 3) +
|
||||
ListProduct(input2_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * (updates_coefficient_ + 3);
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
@ -1226,10 +1226,34 @@ class ScatterMathOpsCost : public OperatorCost {
|
|||
is_inputs_should_in_memory_[0] = true;
|
||||
}
|
||||
|
||||
void set_strategy(Shape strategy) { strategy_ = strategy; }
|
||||
void set_is_split_axis(bool is_split_axis) { is_split_axis_ = is_split_axis; }
|
||||
void set_coefficient(int32_t input_coefficient, int32_t indices_coefficient, int32_t updates_coefficient) {
|
||||
input_coefficient_ = input_coefficient;
|
||||
indices_coefficient_ = indices_coefficient;
|
||||
updates_coefficient_ = updates_coefficient;
|
||||
}
|
||||
|
||||
protected:
|
||||
Shape strategy_;
|
||||
bool is_split_axis_ = false;
|
||||
int32_t input_coefficient_ = 1;
|
||||
int32_t indices_coefficient_ = 3;
|
||||
int32_t updates_coefficient_ = 2;
|
||||
};
|
||||
|
||||
class ScatterNdOpsCost : public ScatterMathOpsCost {
|
||||
public:
|
||||
ScatterNdOpsCost() : ScatterMathOpsCost() {}
|
||||
~ScatterNdOpsCost() override = default;
|
||||
};
|
||||
|
||||
class TensorScatterOpsCost : public ScatterMathOpsCost {
|
||||
public:
|
||||
TensorScatterOpsCost() : ScatterMathOpsCost() {}
|
||||
~TensorScatterOpsCost() override = default;
|
||||
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const override;
|
||||
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||
int64_t stage_id) const override;
|
||||
};
|
||||
|
||||
class CropAndResizeCost : public OperatorCost {
|
||||
|
|
|
@ -95,6 +95,20 @@ constexpr size_t BATCH_NORM_INPUTS_SIZE = 5;
|
|||
constexpr double EPS = 1e-6;
|
||||
constexpr double INF = 1e20;
|
||||
constexpr double COST_FACTOR = 2.0;
|
||||
constexpr size_t SIZE_ZERO = 0;
|
||||
constexpr size_t SIZE_ONE = 1;
|
||||
constexpr size_t SIZE_TWO = 2;
|
||||
constexpr size_t SIZE_THREE = 3;
|
||||
constexpr size_t SIZE_FOUR = 4;
|
||||
constexpr size_t SIZE_FIVE = 5;
|
||||
constexpr size_t SIZE_SIX = 6;
|
||||
constexpr size_t INDEX_ZERO = 0;
|
||||
constexpr size_t INDEX_ONE = 1;
|
||||
constexpr size_t INDEX_TWO = 2;
|
||||
constexpr size_t INDEX_THREE = 3;
|
||||
constexpr size_t INDEX_FOUR = 4;
|
||||
constexpr size_t INDEX_FIVE = 5;
|
||||
constexpr size_t INDEX_SIX = 6;
|
||||
|
||||
constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
|
||||
constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
|
||||
|
@ -467,6 +481,15 @@ constexpr char SCATTER_MUL[] = "ScatterMul";
|
|||
constexpr char SCATTER_DIV[] = "ScatterDiv";
|
||||
constexpr char SCATTER_MAX[] = "ScatterMax";
|
||||
constexpr char SCATTER_MIN[] = "ScatterMin";
|
||||
constexpr char SCATTER_ND_UPDATE[] = "ScatterNdUpdate";
|
||||
constexpr char SCATTER_ND_ADD[] = "ScatterNdAdd";
|
||||
constexpr char SCATTER_ND_SUB[] = "ScatterNdSub";
|
||||
constexpr char TENSOR_SCATTER_ADD[] = "TensorScatterAdd";
|
||||
constexpr char TENSOR_SCATTER_SUB[] = "TensorScatterSub";
|
||||
constexpr char TENSOR_SCATTER_MAX[] = "TensorScatterMax";
|
||||
constexpr char TENSOR_SCATTER_MIN[] = "TensorScatterMin";
|
||||
constexpr char TENSOR_SCATTER_MUL[] = "TensorScatterMul";
|
||||
constexpr char TENSOR_SCATTER_DIV[] = "TensorScatterDiv";
|
||||
constexpr char GATHERD[] = "GatherD";
|
||||
constexpr char DSD_MATMUL[] = "DSDMatmul";
|
||||
constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
|
||||
|
|
|
@ -262,7 +262,7 @@ Status ScatterMathOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, cons
|
|||
auto param_strategy = strategy_->GetInputDim().at(0);
|
||||
// cost model set axis and strategy
|
||||
auto scatter_ops_cost = std::dynamic_pointer_cast<ScatterMathOpsCost>(operator_cost());
|
||||
scatter_ops_cost->set_strategy(param_strategy);
|
||||
scatter_ops_cost->set_is_split_axis((param_strategy[0] > 1));
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,407 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/scatter_nd_ops_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/dynamic_creator.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "frontend/parallel/tensor_layout/shape_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B'
|
||||
// here the 2 respect to the size of [A, B]
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
|
||||
// The dev matrix: (a, b, 1, 1)
|
||||
Status ScatterNdOpsInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
if (stra.size() != 3) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 3";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto indices_shape = inputs_shape_[1];
|
||||
gather_dims_size_ = indices_shape.back();
|
||||
if (CheckInputStrategy(stra[0]) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (!stra[1].empty() && std::accumulate(stra[1].begin(), stra[1].end(), 1, std::multiplies<int64_t>()) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The indices can not be split";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (stra[2].empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy[2] is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (std::accumulate(stra[2].begin(), stra[2].begin() + static_cast<different_type>(stra[1].size() - 1), 1,
|
||||
std::multiplies<int64_t>()) != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The first " << (stra[1].size() - 1) << " dimensions of updates can not be split";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < stra[0].size() - gather_dims_size_; ++i) {
|
||||
if (stra[0][i + gather_dims_size_] != stra[2][stra[1].size() - 1 + i]) {
|
||||
MS_LOG(ERROR)
|
||||
<< name_ << ": updates.strategy must be equal to indices.strategy[:-1] + input.strategy[indices.strategy[-1]:]";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < gather_dims_size_; ++i) {
|
||||
if (stra[0][i] > 1) {
|
||||
do_replace_graph_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d)
|
||||
Status ScatterNdUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) {
|
||||
for (size_t i = 0; i < gather_dims_size_; ++i) {
|
||||
if (strategy_item[i] != 1) {
|
||||
MS_LOG(ERROR) << "For " << name_
|
||||
<< ", "
|
||||
"the input cannot be shard at gather dims input[:"
|
||||
<< gather_dims_size_ << "].";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B.., C, D], the strategy of input: (1, 1.., c, d)
|
||||
Status TensorScatterUpdateInfo::CheckInputStrategy(const Dimensions &strategy_item) {
|
||||
for (size_t i = 0; i < gather_dims_size_; ++i) {
|
||||
if (strategy_item[i] != 1) {
|
||||
MS_LOG(ERROR) << "For ScatterNdUpdate/TensorScatterUpdate, "
|
||||
"the input cannot be shard at gather dims input[:"
|
||||
<< gather_dims_size_ << "].";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ScatterNdOpsInfo::InferDevMatrixShape() {
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d), tensor_map: (3, 2, 1, 0)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1), tensor_map: (-1, -1, -1)
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d), tensor_map: (-1, -1, 1, 0)
|
||||
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d), tensor_map: (3, 2, 1, 0)
|
||||
// The dev matrix: (a, b, c, d)
|
||||
Status ScatterNdOpsInfo::InferTensorMap() {
|
||||
if (inputs_shape_.size() != 3) {
|
||||
MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
TensorMap input_tensor_map, updates_tensor_map;
|
||||
TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
|
||||
|
||||
// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
|
||||
int64_t size = SizeToLong(inputs_shape_[0].size());
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
input_tensor_map.push_back(size - i - 1);
|
||||
}
|
||||
|
||||
// updates_tensor_map = indices_tensor_map[:-1] + input_tensor_map[gather_dim_size_:]
|
||||
updates_tensor_map = indices_tensor_map;
|
||||
updates_tensor_map.pop_back();
|
||||
for (size_t i = gather_dims_size_; i < input_tensor_map.size(); ++i) {
|
||||
updates_tensor_map.push_back(input_tensor_map[i]);
|
||||
}
|
||||
inputs_tensor_map_.push_back(input_tensor_map); // input
|
||||
inputs_tensor_map_.push_back(indices_tensor_map); // indices
|
||||
inputs_tensor_map_.push_back(updates_tensor_map); // updates
|
||||
|
||||
outputs_tensor_map_.push_back(input_tensor_map);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ScatterNdOpsInfo::ReComputeBatchSplitFlagList() {
|
||||
for (size_t i = 0; i < inputs_shape_.size(); i++) {
|
||||
split_flag_list_[i] = false; // when set no strategy, going stand alone mode.
|
||||
}
|
||||
}
|
||||
|
||||
// in_strategy: ((A, B, C, D), (), ()), Shapes: ((a, b, c, d), (e, f), (e, f, c, d))
|
||||
// return: ((A, B, C, D), (1, 1), (1, 1, C, D))
|
||||
// return: throw exception
|
||||
Shapes ScatterNdOpsInfo::InferStrategyIndividualMode(const Shapes &in_strategy) {
|
||||
if (in_strategy.size() != SIZE_THREE) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The size of in_strategy must be 3, but got " << in_strategy.size();
|
||||
}
|
||||
|
||||
if (in_strategy[INDEX_ZERO].empty()) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The in_strategy[0] must be not empty, you should set the first input strategy.";
|
||||
}
|
||||
|
||||
Shape indices_strategy, updates_strategy;
|
||||
indices_strategy = Shape(inputs_shape_[INDEX_ONE].size(), 1);
|
||||
updates_strategy = Shape(inputs_shape_[INDEX_TWO].size(), 1);
|
||||
for (size_t i = 0; i < in_strategy[0].size() - gather_dims_size_; ++i) {
|
||||
updates_strategy[indices_strategy.size() - 1 + i] = in_strategy[INDEX_ZERO][i + gather_dims_size_];
|
||||
}
|
||||
return Shapes({in_strategy[INDEX_ZERO], indices_strategy, updates_strategy});
|
||||
}
|
||||
|
||||
ReplaceGraphPtr ScatterNdOpsInfo::replace_graph(const CNodePtr &cnode) {
|
||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << " replace graph failed";
|
||||
}
|
||||
return replace_graph_;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B'
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
// when splitting [A, B], doing replace graph
|
||||
Status ScatterNdOpsInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
if (!do_replace_graph_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
if (gen_g_.Init(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||
return FAILED;
|
||||
}
|
||||
auto anf_node_list = PrepareReplaceGraph(cnode);
|
||||
// {indices_sub, mul, div, dtype};
|
||||
auto indices_sub = anf_node_list[0];
|
||||
auto mul = anf_node_list[1];
|
||||
auto div = anf_node_list[2];
|
||||
auto dtype = anf_node_list[3];
|
||||
auto info_position = name_.find("Info");
|
||||
if (info_position == std::string::npos) {
|
||||
MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'";
|
||||
}
|
||||
auto node_name = name_.substr(0, info_position);
|
||||
auto scatter_ops = gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, mul});
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2),
|
||||
std::make_pair(indices_sub, 2), std::make_pair(mul, 3),
|
||||
std::make_pair(dtype, 3)};
|
||||
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
|
||||
std::make_pair(input_nodes, scatter_ops));
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B'
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
// when splitting [A, B], doing replace graph
|
||||
Status ScatterNdMulDivBaseInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
if (!do_replace_graph_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
if (gen_g_.Init(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||
return FAILED;
|
||||
}
|
||||
auto anf_node_list = PrepareReplaceGraph(cnode);
|
||||
// {indices_sub, mul, div, dtype};
|
||||
auto indices_sub = anf_node_list[0];
|
||||
auto mul = anf_node_list[1];
|
||||
auto div = anf_node_list[2];
|
||||
auto dtype = anf_node_list[3];
|
||||
auto reshape_updates_mask = anf_node_list[4];
|
||||
auto reverse_sub = gen_g_.PushBack(
|
||||
{gen_g_.NewOpInst(SUB), NewValueNode(std::make_shared<tensor::Tensor>(1.0, kFloat32)), reshape_updates_mask});
|
||||
auto add_mask = gen_g_.PushBack({gen_g_.NewOpInst(ADD), mul, reverse_sub});
|
||||
auto info_position = name_.find("Info");
|
||||
if (info_position == std::string::npos) {
|
||||
MS_LOG(EXCEPTION) << "The name " << name_ << " dose not contain 'Info'";
|
||||
}
|
||||
auto node_name = name_.substr(0, info_position);
|
||||
auto scatter_ops = gen_g_.PushBack({gen_g_.NewOpInst(node_name), gen_g_.virtual_input_node(), indices_sub, add_mask});
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(scatter_ops, 1), std::make_pair(div, 2),
|
||||
std::make_pair(indices_sub, 2), std::make_pair(mul, 3),
|
||||
std::make_pair(dtype, 3)};
|
||||
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
|
||||
std::make_pair(input_nodes, scatter_ops));
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
// the look dim of indices whose value (x, y) should satisfy '0 <= x < A, 0 <= y < B'
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
// when splitting [A, B], doing replace graph
|
||||
std::vector<AnfNodePtr> ScatterNdOpsInfo::PrepareReplaceGraph(const CNodePtr &cnode) {
|
||||
auto rank_in_stage = g_device_manager->rank_index_in_stage();
|
||||
MS_LOG(INFO) << name_ << ": The rank is " << rank_in_stage;
|
||||
auto input_slice_shape = inputs_tensor_info_[0].slice_shape();
|
||||
Shape indices_slice_value;
|
||||
std::copy(input_slice_shape.begin(), input_slice_shape.begin() + static_cast<different_type>(gather_dims_size_),
|
||||
std::back_inserter(indices_slice_value));
|
||||
auto indices_slice_value_tensor = std::make_shared<mindspore::tensor::Tensor>(indices_slice_value, kInt32);
|
||||
Shape indices_shape_size(inputs_shape_[1].size(), 1);
|
||||
indices_shape_size[indices_shape_size.size() - 1] = -1;
|
||||
auto reshape_indices_slice =
|
||||
gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(indices_slice_value_tensor),
|
||||
NewValueNode(MakeValue(indices_shape_size))});
|
||||
auto div = gen_g_.PushBack({gen_g_.NewOpInst(FLOORDIV), gen_g_.virtual_input_node(), reshape_indices_slice});
|
||||
Shape dev_accum_shape;
|
||||
ShapeToAccumulateProductReverse(dev_matrix_shape_, &dev_accum_shape);
|
||||
MS_LOG(INFO) << "The dev_matrix is :" << dev_matrix_shape_ << ", dev_accum_shape is :" << dev_accum_shape;
|
||||
size_t gather_begin_position = 0;
|
||||
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
|
||||
gather_begin_position = 1;
|
||||
}
|
||||
Shape accum_value;
|
||||
std::copy(dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position),
|
||||
dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position + gather_dims_size_),
|
||||
std::back_inserter(accum_value));
|
||||
auto delta_value =
|
||||
std::accumulate(dev_accum_shape.begin() + static_cast<different_type>(gather_begin_position + gather_dims_size_),
|
||||
dev_accum_shape.end(), 0, std::plus<int64_t>());
|
||||
auto accum_value_tensor = std::make_shared<mindspore::tensor::Tensor>(accum_value, kInt32);
|
||||
auto reshape_accum_value = gen_g_.PushBack(
|
||||
{gen_g_.NewOpInst(RESHAPE), ValuePtrToAnfNodePtr(accum_value_tensor), NewValueNode(MakeValue(indices_shape_size))});
|
||||
auto rank_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_accum_value});
|
||||
auto indices_mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), div, reshape_indices_slice});
|
||||
auto indices_sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), gen_g_.virtual_input_node(), indices_mul});
|
||||
auto reduce_sum = gen_g_.PushBack({gen_g_.NewOpInst(REDUCE_SUM), rank_mul, CreateInt32Tensor(-1)});
|
||||
auto sub = gen_g_.PushBack({gen_g_.NewOpInst(SUB), CreateInt32Tensor(rank_in_stage), reduce_sum});
|
||||
auto relu = gen_g_.PushBack({gen_g_.NewOpInst(RELU), sub});
|
||||
auto minimum = gen_g_.PushBack({gen_g_.NewOpInst(MINIMUM), relu, CreateInt32Tensor(delta_value)});
|
||||
auto equal = gen_g_.PushBack({gen_g_.NewOpInst(EQUAL), sub, minimum});
|
||||
auto dtype = gen_g_.PushBack({gen_g_.NewOpInst(DTYPE), gen_g_.virtual_input_node()});
|
||||
auto cast = gen_g_.PushBack({gen_g_.NewOpInst(CAST), equal, dtype});
|
||||
Shape update_shapes(inputs_shape_[2]);
|
||||
for (size_t i = inputs_shape_[1].size() - 1; i < update_shapes.size(); ++i) {
|
||||
update_shapes[i] = 1;
|
||||
}
|
||||
auto reshape_updates_mask =
|
||||
gen_g_.PushBack({gen_g_.NewOpInst(RESHAPE), cast, NewValueNode(MakeValue(update_shapes))});
|
||||
auto mul = gen_g_.PushBack({gen_g_.NewOpInst(MUL), gen_g_.virtual_input_node(), reshape_updates_mask});
|
||||
return {indices_sub, mul, div, dtype, reshape_updates_mask};
|
||||
}
|
||||
|
||||
Status ScatterNdOpsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
return SetCostUnderStrategyBase(strategy);
|
||||
}
|
||||
|
||||
// The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
// The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
// The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
// The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
|
||||
// The dev matrix: (a, b, 1, 1)
|
||||
std::vector<StrategyPtr> ScatterNdOpsInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
// to generate the first input's strategy
|
||||
Shape input_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_input = {input_split};
|
||||
Shapes tmp_inputs_shape = {inputs_shape_[0]};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
|
||||
}
|
||||
|
||||
// the others strategies are equal to the first input's strategy
|
||||
for (auto &sp : sp_vector) {
|
||||
if ((sp == nullptr) || sp->GetInputDim().empty()) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The strategy is null or empty";
|
||||
}
|
||||
Strategies tmp_strategy;
|
||||
Dimensions first_input_strategy = sp->GetInputDim()[0];
|
||||
Dimensions indices_strategy(inputs_shape_[1].size(), 1);
|
||||
// updates_strategy = indices_strategy[:-1] + input_strategy[gather_dims_size_:]
|
||||
Dimensions updates_strategy = indices_strategy;
|
||||
updates_strategy.pop_back();
|
||||
for (size_t i = gather_dims_size_; i < first_input_strategy.size(); ++i) {
|
||||
updates_strategy.push_back(first_input_strategy[i]);
|
||||
}
|
||||
|
||||
tmp_strategy.push_back(first_input_strategy); // input
|
||||
tmp_strategy.push_back(indices_strategy); // indices
|
||||
tmp_strategy.push_back(updates_strategy); // updates
|
||||
sp->ResetInputs(tmp_strategy);
|
||||
}
|
||||
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
Status ScatterNdOpsInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
}
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto param_strategy = strategy_->GetInputDim().at(0);
|
||||
// cost model set axis and strategy
|
||||
auto scatter_nd_ops_cost = std::dynamic_pointer_cast<ScatterMathOpsCost>(operator_cost());
|
||||
MS_EXCEPTION_IF_NULL(scatter_nd_ops_cost);
|
||||
bool is_split_axis = std::find_if(param_strategy.begin(), param_strategy.begin() + gather_dims_size_,
|
||||
[](auto dim) { return dim > 1; }) != param_strategy.end();
|
||||
scatter_nd_ops_cost->set_is_split_axis(is_split_axis);
|
||||
scatter_nd_ops_cost->set_coefficient(1, 5, 2);
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
REGISTER(ScatterNdAddInfo);
|
||||
REGISTER(ScatterNdSubInfo);
|
||||
REGISTER(ScatterNdUpdateInfo);
|
||||
REGISTER(TensorScatterUpdateInfo);
|
||||
REGISTER(TensorScatterAddInfo);
|
||||
REGISTER(TensorScatterSubInfo);
|
||||
REGISTER(TensorScatterMulInfo);
|
||||
REGISTER(TensorScatterDivInfo);
|
||||
REGISTER(TensorScatterMaxInfo);
|
||||
REGISTER(TensorScatterMinInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,179 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class ScatterNdOpsInfo : public OperatorInfo {
|
||||
public:
|
||||
ScatterNdOpsInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const OperatorCostPtr &cost)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
|
||||
~ScatterNdOpsInfo() override = default;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
Status InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override { return SUCCESS; }
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
virtual Status CheckInputStrategy(const Dimensions &strategy_item) { return SUCCESS; }
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
std::vector<AnfNodePtr> PrepareReplaceGraph(const CNodePtr &cnode);
|
||||
virtual Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
Status InferBias();
|
||||
bool do_replace_graph_ = false;
|
||||
int64_t bias_ = 0;
|
||||
int64_t slice_size_ = 0;
|
||||
size_t gather_dims_size_ = 1;
|
||||
GenerateGraph gen_g_ = GenerateGraph(attrs_);
|
||||
Shapes InferStrategyIndividualMode(const Shapes &in_strategy) override;
|
||||
};
|
||||
|
||||
using ScatterNdOpsInfoPtr = std::shared_ptr<ScatterNdOpsInfo>;
|
||||
|
||||
class ScatterNdAddInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
ScatterNdAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
|
||||
Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_add only use in eval/predict
|
||||
~ScatterNdAddInfo() override = default;
|
||||
};
|
||||
|
||||
class ScatterNdSubInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
ScatterNdSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
|
||||
Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_sub only use in eval/predict
|
||||
~ScatterNdSubInfo() override = default;
|
||||
};
|
||||
|
||||
class ScatterNdUpdateInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
ScatterNdUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterNdOpsCost>()) {}
|
||||
Status InferMirrorOps() override { return SUCCESS; } // the scatter_nd_update only use in eval/predict
|
||||
~ScatterNdUpdateInfo() override = default;
|
||||
|
||||
protected:
|
||||
Status CheckInputStrategy(const Dimensions &strategy_item) override;
|
||||
};
|
||||
|
||||
class TensorScatterUpdateInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
TensorScatterUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
|
||||
~TensorScatterUpdateInfo() override = default;
|
||||
|
||||
protected:
|
||||
Status CheckInputStrategy(const Dimensions &strategy_item) override;
|
||||
};
|
||||
|
||||
class TensorScatterAddInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
TensorScatterAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
|
||||
~TensorScatterAddInfo() override = default;
|
||||
};
|
||||
|
||||
class TensorScatterSubInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
TensorScatterSubInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
|
||||
~TensorScatterSubInfo() override = default;
|
||||
};
|
||||
|
||||
class ScatterNdMulDivBaseInfo : public ScatterNdOpsInfo {
|
||||
public:
|
||||
ScatterNdMulDivBaseInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdOpsInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorScatterOpsCost>()) {}
|
||||
~ScatterNdMulDivBaseInfo() override = default;
|
||||
|
||||
protected:
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode) override;
|
||||
};
|
||||
|
||||
class TensorScatterMulInfo : public ScatterNdMulDivBaseInfo {
|
||||
public:
|
||||
TensorScatterMulInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
||||
~TensorScatterMulInfo() override = default;
|
||||
};
|
||||
|
||||
class TensorScatterDivInfo : public ScatterNdMulDivBaseInfo {
|
||||
public:
|
||||
TensorScatterDivInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: ScatterNdMulDivBaseInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
||||
~TensorScatterDivInfo() override = default;
|
||||
};
|
||||
|
||||
class TensorScatterMaxInfo : public TensorScatterUpdateInfo {
|
||||
public:
|
||||
TensorScatterMaxInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
||||
~TensorScatterMaxInfo() override = default;
|
||||
};
|
||||
|
||||
class TensorScatterMinInfo : public TensorScatterUpdateInfo {
|
||||
public:
|
||||
TensorScatterMinInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: TensorScatterUpdateInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
||||
~TensorScatterMinInfo() override = default;
|
||||
};
|
||||
|
||||
using ScatterAddInfoPtr = std::shared_ptr<ScatterNdAddInfo>;
|
||||
using ScatterSubInfoPtr = std::shared_ptr<ScatterNdSubInfo>;
|
||||
using ScatterUpdateInfoPtr = std::shared_ptr<ScatterNdUpdateInfo>;
|
||||
using TensorScatterUpdateInfoPtr = std::shared_ptr<TensorScatterUpdateInfo>;
|
||||
using TensorScatterAddInfoPtr = std::shared_ptr<TensorScatterAddInfo>;
|
||||
using TensorScatterSubInfoPtr = std::shared_ptr<TensorScatterSubInfo>;
|
||||
using TensorScatterMulInfoPtr = std::shared_ptr<TensorScatterMulInfo>;
|
||||
using TensorScatterDivInfoPtr = std::shared_ptr<TensorScatterDivInfo>;
|
||||
using TensorScatterMaxInfoPtr = std::shared_ptr<TensorScatterMaxInfo>;
|
||||
using TensorScatterMinInfoPtr = std::shared_ptr<TensorScatterMinInfo>;
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_ND_OPS_INFO_H_
|
|
@ -725,7 +725,9 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
|
||||
SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
|
||||
UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, SCATTER_ND_UPDATE, SCATTER_ND_ADD, SCATTER_ND_SUB,
|
||||
TENSOR_SCATTER_UPDATE, TENSOR_SCATTER_ADD, TENSOR_SCATTER_SUB, TENSOR_SCATTER_MAX, TENSOR_SCATTER_MIN,
|
||||
TENSOR_SCATTER_MUL, TENSOR_SCATTER_DIV, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE,
|
||||
MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, FAST_GELU, IOU, BOUNDING_BOX_ENCODE,
|
||||
RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, ARGMAX, ARGMIN, ARGMINV2,
|
||||
UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL, MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, UNIQUE_CONSECUTIVE,
|
||||
|
|
|
@ -0,0 +1,245 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test scatter update """
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Model, Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from parallel.utils.utils import ParallelValidator
|
||||
|
||||
scatter_nd_ops_map = {"Add": P.ScatterNdAdd(), "Update": P.ScatterNdUpdate(), "Sub": P.ScatterNdSub()}
|
||||
tensor_scatter_ops_map = {"Add": P.TensorScatterAdd(), "Update": P.TensorScatterUpdate(),
|
||||
"Sub": P.TensorScatterSub(), "Mul": P.TensorScatterMul(), "Div": P.TensorScatterDiv()}
|
||||
|
||||
|
||||
# The shape of input: [A, B, C, D], the strategy of input: (a, b, c, d)
|
||||
# The shape of indices: [Q, W, 2], the strategy of indices: (1, 1, 1)
|
||||
# here the 2 respect to the size of [A, B]
|
||||
# The shape of updates: [Q, W, C, D], the strategy of updates: (1, 1, c, d)
|
||||
# The shape of output: [A, B, C, D], the strategy of output: (a, b, c, d)
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
|
||||
super(Net, self).__init__()
|
||||
self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
|
||||
self.indices = Tensor(np.ones([4, 2]).astype(np.int32))
|
||||
self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
|
||||
self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1)
|
||||
self.add = P.TensorAdd().shard(strategy2)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.scatter_ops(self.inputs, self.indices, self.updates)
|
||||
out = self.add(x, out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Net1(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
|
||||
super(Net1, self).__init__()
|
||||
self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
|
||||
self.indices = Tensor(np.ones([4, 3]).astype(np.int32))
|
||||
self.updates = Tensor(np.ones([4]).astype(np.float32))
|
||||
self.scatter_ops = scatter_nd_ops_map.get(ops_type).shard(strategy1)
|
||||
self.add = P.TensorAdd().shard(strategy2)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.scatter_ops(self.inputs, self.indices, self.updates)
|
||||
out = self.add(x, out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(nn.Cell):
|
||||
"""Net definition"""
|
||||
def __init__(self, strategy1=None, strategy2=None, ops_type="Add"):
|
||||
super(Net2, self).__init__()
|
||||
self.indices = Tensor(np.ones([4, 3]).astype(np.int32))
|
||||
self.updates = Tensor(np.ones([4]).astype(np.float32))
|
||||
self.scatter_ops = tensor_scatter_ops_map.get(ops_type).shard(strategy1)
|
||||
self.add = P.TensorAdd().shard(strategy2)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, inputs, x):
|
||||
inputs = self.relu(inputs)
|
||||
out = self.scatter_ops(inputs, self.indices, self.updates)
|
||||
out = self.add(x, out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def compile_net(net, *inputs):
|
||||
net.set_auto_parallel()
|
||||
net.set_train(False)
|
||||
phase, _ = _cell_graph_executor.compile(net, *inputs)
|
||||
context.reset_auto_parallel_context()
|
||||
return phase
|
||||
|
||||
|
||||
def test_scatter_nd_add():
|
||||
"""
|
||||
Feature: distribute operator scatter_nd_add in auto parallel.
|
||||
Description: scatter_nd_add net with sharding updating strategy in semi auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((1, 2, 4), (1, 1), (1, 4))
|
||||
strategy2 = ((1, 2, 4), (1, 2, 4))
|
||||
net = Net(strategy1, strategy2)
|
||||
phase = compile_net(net, inputs)
|
||||
validator = ParallelValidator(net, phase)
|
||||
# check layout
|
||||
inputs_expect_layout = ([2, 4], [-1, 1, 0], [32, 32, 32], 0, True, '')
|
||||
assert validator.check_parameter_layout('input', inputs_expect_layout)
|
||||
|
||||
|
||||
def test_scatter_nd_wrong_strategy():
|
||||
"""
|
||||
Feature: distribute operator scatter_nd_add in auto parallel.
|
||||
Description: scatter_nd_add net with wrong strategy in semi auto parallel.
|
||||
Expectation: raise runtime error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((1, 2, 4), (1, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 4), (1, 2, 4))
|
||||
net = Net(strategy1, strategy2)
|
||||
model = Model(net)
|
||||
with pytest.raises(RuntimeError):
|
||||
model.predict(inputs)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_scatter_nd_sub():
|
||||
"""
|
||||
Feature: distribute operator scatter_nd_sub in auto parallel.
|
||||
Description: scatter_nd_sub net with sharding input gather dims in semi auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((2, 2, 2), (1, 1), (1,))
|
||||
strategy2 = ((2, 2, 2), (2, 2, 2))
|
||||
net = Net1(strategy1, strategy2, ops_type="Sub")
|
||||
phase = compile_net(net, inputs)
|
||||
validator = ParallelValidator(net, phase)
|
||||
# check layout
|
||||
inputs_expect_layout = ([2, 2, 2], [2, 1, 0], [16, 32, 64], 0, True, '')
|
||||
assert validator.check_parameter_layout('input', inputs_expect_layout)
|
||||
# check sub_graph
|
||||
sub_graph = {
|
||||
'ScatterNdSub-0': ['input', 'Sub-0', 'Mul-2'],
|
||||
'FloorDiv-0': ['_GetTensorSlice-0', 'Reshape-0'],
|
||||
}
|
||||
assert validator.check_graph_structure(sub_graph)
|
||||
|
||||
|
||||
def test_scatter_nd_update():
|
||||
"""
|
||||
Feature: distribute operator scatter_nd_update in auto parallel.
|
||||
Description: scatter_nd_update net with sharding input gather dims and tables in semi auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((2, 2, 2), (1, 1), (1, 2))
|
||||
strategy2 = ((2, 2, 2), (2, 2, 2))
|
||||
net = Net(strategy1, strategy2, ops_type="Update")
|
||||
model = Model(net)
|
||||
with pytest.raises(RuntimeError):
|
||||
model.predict(inputs)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_tensor_scatter_add():
|
||||
"""
|
||||
Feature: distribute operator tensor_scatter_add in auto parallel.
|
||||
Description: tensor_scatter_update net with sharding input gather dims and tables in semi auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((2, 2, 2), (1, 1), (1,))
|
||||
strategy2 = ((1, 2, 2), (1, 2, 2))
|
||||
net = Net2(strategy1, strategy2, ops_type="Add")
|
||||
phase = compile_net(net, input1, input2)
|
||||
validator = ParallelValidator(net, phase)
|
||||
# check sub_graph
|
||||
sub_graph = {
|
||||
'TensorScatterAdd-0': ['Reshape-1', 'Sub-0', 'Mul-2'],
|
||||
'Equal-0': ['Sub-1', 'Minimum-0'],
|
||||
'AllGather-2': ['TensorScatterAdd-0']
|
||||
}
|
||||
assert validator.check_graph_structure(sub_graph)
|
||||
|
||||
|
||||
def test_tensor_scatter_mul():
|
||||
"""
|
||||
Feature: distribute operator tensor_scatter_mul in auto parallel.
|
||||
Description: tensor_scatter_update net with sharding input gather dims and tables in semi auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
|
||||
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((2, 2, 2), (1, 1), (1,))
|
||||
strategy2 = ((1, 2, 2), (1, 2, 2))
|
||||
net = Net2(strategy1, strategy2, ops_type="Mul")
|
||||
phase = compile_net(net, input1, input2)
|
||||
validator = ParallelValidator(net, phase)
|
||||
# check sub_graph
|
||||
sub_graph = {
|
||||
'TensorScatterMul-0': ['Reshape-1', 'Sub-0', 'Add-0'],
|
||||
'Equal-0': ['Sub-1', 'Minimum-0'],
|
||||
'AllGather-2': ['TensorScatterMul-0']
|
||||
}
|
||||
assert validator.check_graph_structure(sub_graph)
|
||||
|
||||
|
||||
def test_tensor_scatter_mul_auto_parallel():
|
||||
"""
|
||||
Feature: distribute operator tensor_scatter_mul in auto parallel.
|
||||
Description: tensor_scatter_update net with sharding input gather dims and tables in auto parallel.
|
||||
Expectation: assert ok.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
|
||||
input1 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
input2 = Tensor(np.ones([32, 64, 128]).astype(np.float32))
|
||||
strategy1 = ((2, 2, 2), (1, 1), (1,))
|
||||
strategy2 = None
|
||||
net = Net2(strategy1, strategy2, ops_type="Mul")
|
||||
phase = compile_net(net, input1, input2)
|
||||
validator = ParallelValidator(net, phase)
|
||||
# check sub_graph
|
||||
sub_graph = {
|
||||
'TensorScatterMul-0': ['ReLU-0', 'Sub-0', 'Add-0'],
|
||||
'Equal-0': ['Sub-1', 'Minimum-0']
|
||||
}
|
||||
assert validator.check_graph_structure(sub_graph)
|
Loading…
Reference in New Issue