!23544 remove deprecated gather op

Merge pull request !23544 from zhuyuxiao/master
This commit is contained in:
i-robot 2021-09-24 01:20:53 +00:00 committed by Gitee
commit d37fccc56f
8 changed files with 47 additions and 694 deletions

View File

@ -141,7 +141,6 @@ REGISTER(ReLU6Info);
REGISTER(ReLUV2Info); REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo); REGISTER(SoftplusInfo);
REGISTER(SoftsignInfo); REGISTER(SoftsignInfo);
REGISTER(GatherInfo);
REGISTER(SparseGatherV2Info); REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo); REGISTER(SqrtInfo);
REGISTER(SigmoidInfo); REGISTER(SigmoidInfo);
@ -181,7 +180,7 @@ REGISTER(UniformCandidateSamplerInfo);
REGISTER(UnsortedSegmentSumInfo); REGISTER(UnsortedSegmentSumInfo);
REGISTER(UnsortedSegmentMinInfo); REGISTER(UnsortedSegmentMinInfo);
REGISTER(UnsortedSegmentMaxInfo); REGISTER(UnsortedSegmentMaxInfo);
REGISTER(GatherPInfo); REGISTER(GatherInfo);
REGISTER(EmbeddingLookupInfo); REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo); REGISTER(TileInfo);
REGISTER(BroadcastToInfo); REGISTER(BroadcastToInfo);

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "frontend/parallel/ops_info/gather_v2_p_info.h" #include "frontend/parallel/ops_info/gather_info.h"
#include <vector> #include <vector>
#include <numeric> #include <numeric>
@ -32,7 +32,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status GatherPInfo::GetManualSplitWithoutOffsetAttr() { Status GatherInfo::GetManualSplitWithoutOffsetAttr() {
auto manual_split_without_offset_iter = attrs_.find("manual_split"); auto manual_split_without_offset_iter = attrs_.find("manual_split");
if (manual_split_without_offset_iter != attrs_.end()) { if (manual_split_without_offset_iter != attrs_.end()) {
manual_split_ = true; manual_split_ = true;
@ -68,7 +68,7 @@ Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::GetManualSplitAttr() { Status GatherInfo::GetManualSplitAttr() {
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset"); auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
if (manual_split_with_offset_iter != attrs_.end()) { if (manual_split_with_offset_iter != attrs_.end()) {
manual_split_ = true; manual_split_ = true;
@ -118,7 +118,7 @@ Status GatherPInfo::GetManualSplitAttr() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::GetAttrs() { Status GatherInfo::GetAttrs() {
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
if (target_ != CPU) { if (target_ != CPU) {
if (input_value_.at(2) == nullptr) { if (input_value_.at(2) == nullptr) {
@ -170,7 +170,7 @@ Status GatherPInfo::GetAttrs() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::CheckManualSplit(const Strategys &strategy) { Status GatherInfo::CheckManualSplit(const Strategys &strategy) {
if (strategy.size() != 2) { if (strategy.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size(); MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
return FAILED; return FAILED;
@ -226,7 +226,7 @@ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { Status GatherInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0); auto param_strategy = strategy->GetInputDim().at(0);
auto index_strategy = strategy->GetInputDim().at(1); auto index_strategy = strategy->GetInputDim().at(1);
// param_strategy(axis) != 1, index can't be split // param_strategy(axis) != 1, index can't be split
@ -255,7 +255,7 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
// otherwise return false // otherwise return false
bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const { bool GatherInfo::ShardBatchAndAxis(const Strategys &strategy) const {
if (axis_ != 0) { if (axis_ != 0) {
return false; return false;
} }
@ -285,7 +285,7 @@ bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const {
return true; return true;
} }
void GatherPInfo::SetAttribute(const StrategyPtr &strategy) { void GatherInfo::SetAttribute(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0); auto param_strategy = strategy->GetInputDim().at(0);
// axis=0, index_shape(0)%param_strategy(0) must be 0 // axis=0, index_shape(0)%param_strategy(0) must be 0
Shape index_shape = inputs_shape_.at(1); Shape index_shape = inputs_shape_.at(1);
@ -312,7 +312,7 @@ void GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_; MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
} }
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -373,7 +373,7 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InferMirrorOps() { Status GatherInfo::InferMirrorOps() {
// There is no mirror operators for manual split // There is no mirror operators for manual split
if (manual_split_) { if (manual_split_) {
return SUCCESS; return SUCCESS;
@ -403,7 +403,7 @@ Status GatherPInfo::InferMirrorOps() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InferDevMatrixShape() { Status GatherInfo::InferDevMatrixShape() {
dev_matrix_shape_.clear(); dev_matrix_shape_.clear();
out_dev_matrix_shape_.clear(); out_dev_matrix_shape_.clear();
// infer input dev_matrix_shape // infer input dev_matrix_shape
@ -460,7 +460,7 @@ Status GatherPInfo::InferDevMatrixShape() {
return SUCCESS; return SUCCESS;
} }
void GatherPInfo::InferInputsTensorMap() { void GatherInfo::InferInputsTensorMap() {
// infer input tensor map // infer input tensor map
// param_strategy(axis) is not 1 // param_strategy(axis) is not 1
size_t param_size = inputs_shape_.at(0).size(); size_t param_size = inputs_shape_.at(0).size();
@ -487,7 +487,7 @@ void GatherPInfo::InferInputsTensorMap() {
inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
} }
void GatherPInfo::InferOutputsTensorMap() { void GatherInfo::InferOutputsTensorMap() {
// infer output tensor map // infer output tensor map
size_t param_size = inputs_shape_.at(0).size(); size_t param_size = inputs_shape_.at(0).size();
size_t index_size = inputs_shape_.at(1).size(); size_t index_size = inputs_shape_.at(1).size();
@ -534,7 +534,7 @@ void GatherPInfo::InferOutputsTensorMap() {
(void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
} }
Status GatherPInfo::InferTensorMap() { Status GatherInfo::InferTensorMap() {
if (manual_split_) { if (manual_split_) {
Shape param_map = {1, 0}; Shape param_map = {1, 0};
Shape indices_map = {-1, 1}; Shape indices_map = {-1, 1};
@ -560,7 +560,7 @@ Status GatherPInfo::InferTensorMap() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InferTensorInfo() { Status GatherInfo::InferTensorInfo() {
// infer tensor shape // infer tensor shape
Shape input_shape = inputs_shape_.at(0); Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1); Shape input_index_shape = inputs_shape_.at(1);
@ -593,7 +593,7 @@ Status GatherPInfo::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InferBias() { Status GatherInfo::InferBias() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
int64_t rank = g_device_manager->rank_index_in_stage(); int64_t rank = g_device_manager->rank_index_in_stage();
auto input_shape = inputs_shape_.at(0); auto input_shape = inputs_shape_.at(0);
@ -656,7 +656,7 @@ Status GatherPInfo::InferBias() {
return FAILED; return FAILED;
} }
Status GatherPInfo::InferOffset() { Status GatherInfo::InferOffset() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
size_t rank = LongToSize(g_device_manager->rank_index_in_stage()); size_t rank = LongToSize(g_device_manager->rank_index_in_stage());
@ -677,7 +677,7 @@ Status GatherPInfo::InferOffset() {
return FAILED; return FAILED;
} }
Status GatherPInfo::InferGroup() { Status GatherInfo::InferGroup() {
size_t dim = LongToSize(axis_); size_t dim = LongToSize(axis_);
int64_t rank = g_device_manager->global_rank(); int64_t rank = g_device_manager->global_rank();
@ -708,7 +708,7 @@ Status GatherPInfo::InferGroup() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InferForwardCommunication() { Status GatherInfo::InferForwardCommunication() {
if (manual_split_) { if (manual_split_) {
return SUCCESS; return SUCCESS;
} }
@ -745,7 +745,7 @@ Status GatherPInfo::InferForwardCommunication() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) { Status GatherInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(attrs_); GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) { if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed"; MS_LOG(ERROR) << "GenerateGraph Init failed";
@ -805,7 +805,7 @@ Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS; return SUCCESS;
} }
ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr GatherInfo::replace_graph(const CNodePtr &cnode) {
if (manual_split_ && target_ != CPU) { if (manual_split_ && target_ != CPU) {
if (ComputeReplaceGraph(cnode) != SUCCESS) { if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
@ -824,7 +824,7 @@ ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_; return replace_graph_;
} }
Status GatherPInfo::ComputeReplaceOp() { Status GatherInfo::ComputeReplaceOp() {
int64_t bias = 0; int64_t bias = 0;
if (manual_split_) { if (manual_split_) {
if (InferOffset() != SUCCESS) { if (InferOffset() != SUCCESS) {
@ -851,7 +851,7 @@ Status GatherPInfo::ComputeReplaceOp() {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::Init(const StrategyPtr &strategy) { Status GatherInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed."; MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED; return FAILED;
@ -864,7 +864,7 @@ Status GatherPInfo::Init(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) { Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
@ -882,9 +882,9 @@ Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) { std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
if (manual_split_) { if (manual_split_) {
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy"; MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy";
} }
@ -900,7 +900,7 @@ std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector; return sp_vector;
} }
std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -29,17 +29,17 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
class GatherPInfo : public OperatorInfo { class GatherInfo : public OperatorInfo {
public: public:
GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2) const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
axis_(0), axis_(0),
bias_(0), bias_(0),
index_offset_(0), index_offset_(0),
slice_size_(0), slice_size_(0),
replace_op_name_(replace_op_name) {} replace_op_name_(replace_op_name) {}
~GatherPInfo() override = default; ~GatherInfo() override = default;
Status Init(const StrategyPtr &strategy) override; Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override;
@ -88,21 +88,21 @@ class GatherPInfo : public OperatorInfo {
std::vector<int64_t> index_offsets_; std::vector<int64_t> index_offsets_;
}; };
class SparseGatherV2Info : public GatherPInfo { class SparseGatherV2Info : public GatherInfo {
public: public:
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2) const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
: GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {} : GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
~SparseGatherV2Info() override = default; ~SparseGatherV2Info() override = default;
}; };
class EmbeddingLookupInfo : public GatherPInfo { class EmbeddingLookupInfo : public GatherInfo {
public: public:
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: GatherPInfo(name, inputs_shape, outputs_shape, attrs) {} : GatherInfo(name, inputs_shape, outputs_shape, attrs) {}
~EmbeddingLookupInfo() override = default; ~EmbeddingLookupInfo() override = default;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_

View File

@ -1,318 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/gather_v2_info.h"
#include <memory>
#include <utility>
#include <vector>
#include "ir/tensor.h"
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/strategy.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace parallel {
Status GatherInfo::GetAttrs() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
return FAILED;
}
if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size();
return FAILED;
}
// the second input is the index tensor
// the third input is the axis, is a ValueNode
if (input_value_.at(2) == nullptr) {
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
return FAILED;
}
if (inputs_shape_.at(0).size() == 0) {
MS_LOG(ERROR) << name_ << ": input can not be a scalar!";
return FAILED;
}
int64_t axis = GetValue<int64_t>(input_value_.at(2));
if (axis >= SizeToLong(inputs_shape_.at(0).size()) || axis < -SizeToLong(inputs_shape_.at(0).size())) {
MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
<< inputs_shape_.at(0).size() << ").";
}
if (axis < 0) {
axis += SizeToLong(inputs_shape_[0].size());
}
axis_ = axis;
index_size_ = inputs_shape_.at(1).size();
return SUCCESS;
}
Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
<< outputs_shape_.size();
return FAILED;
}
// Only strategy of the first input should be set.
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}
axis_strategy_ = strategy->GetInputDim().at(0).at(LongToSize(axis_));
if (index_size_ != 1 && axis_strategy_ != 1) {
MS_LOG(ERROR) << name_
<< ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy "
"corresponding to axis must be 1, but is "
<< axis_strategy_;
return FAILED;
}
if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
MS_LOG(ERROR) << name_
<< ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to "
"axis. The first dimension of index is "
<< inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_;
return FAILED;
}
return SUCCESS;
}
Status GatherInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra.at(0);
return SUCCESS;
}
// If index is a scalar, output dimension is input dimension minus 1;
// If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
// Tensor map dimension is equal to the corresponding input and output dimension.
// If index's dimension is more than 1, we insert -1 for the output tensor map.
Status GatherInfo::InferTensorMap() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
<< outputs_shape_.size();
return FAILED;
}
Shape tensor_map_in;
Shape tensor_map_out;
size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) {
tensor_map_in.push_back(SizeToLong(size - i - 1));
tensor_map_out.push_back(SizeToLong(size - i - 1));
}
if (index_size_ == 0) {
(void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
} else if (index_size_ > 1) {
(void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
}
if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size()
<< " output size is " << outputs_shape_.at(0).size();
return FAILED;
}
Shape tensor_map_in_index;
if (index_size_ >= 1) {
tensor_map_in_index.push_back(SizeToLong(size) - axis_ - 1);
}
for (size_t i = 1; i < index_size_; ++i) {
tensor_map_in_index.push_back(-1);
}
inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
return SUCCESS;
}
Status GatherInfo::InferTensorInfo() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
<< outputs_shape_.size();
return FAILED;
}
if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_tensor_map_.size();
return FAILED;
}
if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
<< outputs_tensor_map_.size();
return FAILED;
}
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo input_index_info(input_index_layout);
TensorInfo output_tensor_info(output_tensor_layout);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_index_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
OperatorVector CreateSubOp(int64_t sub_value) {
OperatorVector ops;
OperatorName operator_name = SUB;
OperatorAttrs operator_attrs;
std::vector<int64_t> tensor_data = {sub_value};
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32);
ValuePtr op_param_value = MakeValue(tensor_ptr);
Attr op1_param = std::make_pair("", op_param_value);
OperatorParams operator_param = {std::make_pair(op1_param, 2)};
OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_args);
ops.push_back(op);
return ops;
}
Status GatherInfo::InferTensorSubOps() {
sub_ops_.clear();
if ((index_size_ == 0) || (axis_strategy_ == 1)) {
return SUCCESS;
}
int64_t mod_n = 1;
for (size_t i = LongToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
mod_n *= dev_matrix_shape_.at(i);
}
if ((axis_ >= SizeToLong(dev_matrix_shape_.size())) || axis_ < 0) {
MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
}
int64_t mod_p = mod_n * dev_matrix_shape_.at(LongToSize(axis_));
int64_t rank = g_device_manager->rank_index_in_stage();
int64_t mod_rank = rank % mod_p;
mod_rank = static_cast<int64_t>(mod_rank / mod_n);
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
return FAILED;
}
if ((axis_ >= SizeToLong(inputs_shape_.at(0).size())) || axis_ < 0) {
MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
}
int64_t sub_value = inputs_shape_[0][LongToSize(axis_)] / dev_matrix_shape_[LongToSize(axis_)] * mod_rank;
OperatorVector sub_op;
sub_ops_.emplace_back(std::move(sub_op));
sub_op = CreateSubOp(sub_value);
sub_ops_.emplace_back(std::move(sub_op));
return SUCCESS;
}
Status GatherInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
Status status = InferTensorSubOps();
if (status != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
return status;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
<< outputs_shape_.size() << "is wrong.";
}
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
SUCCESS) {
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
}
return sp_vector;
}
Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size();
}
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << "GetAttrs failed!";
}
Dimensions strategy;
if (index_size_ != 1) {
strategy.push_back(1);
} else {
strategy.push_back(stage_device_size_);
}
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
strategy.push_back(1);
}
Strategys strategy_v = {strategy};
return std::make_shared<Strategys>(strategy_v);
}
} // namespace parallel
} // namespace mindspore

View File

@ -1,73 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
constexpr size_t GATHER_V2_INPUTS_SIZE = 2;
constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1;
constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
// We now supported limited parallel strategies.
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
// the input.
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
class GatherInfo : public OperatorInfo {
public:
GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
axis_(-1),
index_size_(0),
axis_strategy_(1) {}
~GatherInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status GetAttrs() override;
private:
Status InferTensorSubOps();
int64_t axis_;
size_t index_size_;
int64_t axis_strategy_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_

View File

@ -24,7 +24,6 @@
#include "frontend/parallel/ops_info/comparison_function_info.h" #include "frontend/parallel/ops_info/comparison_function_info.h"
#include "frontend/parallel/ops_info/dropout_do_mask_info.h" #include "frontend/parallel/ops_info/dropout_do_mask_info.h"
#include "frontend/parallel/ops_info/elementary_function_info.h" #include "frontend/parallel/ops_info/elementary_function_info.h"
#include "frontend/parallel/ops_info/gather_v2_info.h"
#include "frontend/parallel/ops_info/get_next_info.h" #include "frontend/parallel/ops_info/get_next_info.h"
#include "frontend/parallel/ops_info/l2_normalize_info.h" #include "frontend/parallel/ops_info/l2_normalize_info.h"
#include "frontend/parallel/ops_info/layer_norm_info.h" #include "frontend/parallel/ops_info/layer_norm_info.h"
@ -37,7 +36,7 @@
#include "frontend/parallel/ops_info/transpose_info.h" #include "frontend/parallel/ops_info/transpose_info.h"
#include "frontend/parallel/ops_info/unsorted_segment_op_info.h" #include "frontend/parallel/ops_info/unsorted_segment_op_info.h"
#include "frontend/parallel/ops_info/virtual_dataset_info.h" #include "frontend/parallel/ops_info/virtual_dataset_info.h"
#include "frontend/parallel/ops_info/gather_v2_p_info.h" #include "frontend/parallel/ops_info/gather_info.h"
#include "frontend/parallel/ops_info/tile_info.h" #include "frontend/parallel/ops_info/tile_info.h"
#include "frontend/parallel/ops_info/strided_slice_info.h" #include "frontend/parallel/ops_info/strided_slice_info.h"
#include "frontend/parallel/ops_info/slice_info.h" #include "frontend/parallel/ops_info/slice_info.h"

View File

@ -1252,20 +1252,6 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
MS_LOG(EXCEPTION) << "Length of name is zero!"; MS_LOG(EXCEPTION) << "Length of name is zero!";
} }
std::string distribute_opname = GetDisOpName(name); std::string distribute_opname = GetDisOpName(name);
if (name == GATHERV2) {
distribute_opname = name + "PInfo";
auto data_parallel_iter = attrs.find(DATA_PARALLEL);
if (data_parallel_iter != attrs.end()) {
MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
if (!data_parallel_iter->second->isa<BoolImm>()) {
MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool.";
}
bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
if (data_parallel) {
distribute_opname = name + "Info";
}
}
}
OperatorInfoPtr operator_ = OperatorInfoPtr operator_ =
(OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
if (operator_ == nullptr) { if (operator_ == nullptr) {
@ -2632,9 +2618,9 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
return param_names; return param_names;
} }
bool IsGatherPInfo(const std::string &name) { bool IsGatherInfo(const std::string &name) {
std::vector<std::string> gather_p_info_names = {"GatherPInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"}; std::vector<std::string> gather_info_names = {"GatherInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
for (std::string info_name : gather_p_info_names) { for (std::string info_name : gather_info_names) {
if (name.find(info_name) != std::string::npos) { if (name.find(info_name) != std::string::npos) {
return true; return true;
} }
@ -2669,10 +2655,10 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
for (auto param_name_pair : param_names) { for (auto param_name_pair : param_names) {
tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>(); tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
} }
if (IsGatherPInfo(operator_info->name())) { if (IsGatherInfo(operator_info->name())) {
auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info); auto gather_info = std::dynamic_pointer_cast<GatherInfo>(operator_info);
auto param_split_shapes = gatherv2_info->param_split_shapes(); auto param_split_shapes = gather_info->param_split_shapes();
auto index_offsets = gatherv2_info->index_offsets(); auto index_offsets = gather_info->index_offsets();
if (param_split_shapes.size() != index_offsets.size()) { if (param_split_shapes.size() != index_offsets.size()) {
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same."; MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
} }

View File

@ -1,240 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple
from mindspore.communication.management import init
from mindspore.nn import Dense, Cell
from mindspore.nn.loss.loss import LossBase
from mindspore.nn.optim import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
context.set_context(mode=context.GRAPH_MODE)
device_number = 32
batch_size_per_device = 128
class Dataset():
def __init__(self, predict, length=3):
self.predict = predict
self.index = 0
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return (self.predict,)
def reset(self):
self.index = 0
def get_dataset_size(self):
return 128
def get_repeat_count(self):
return 1
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self
class GatherV2(LossBase):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2, self).__init__()
self.pow = P.Pow()
emb1_list = 21
emb2_list = 2
if index_dim == 1:
emb_list = list(range(index_size))
emb1_list = emb_list[0::2]
emb2_list = emb_list[1::2]
if index_dim == 2:
emb_list = np.arange(index_size * 16)
emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), 16))
emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), 16))
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
self.gatherv2 = P.Gather().shard(strategy).add_prim_attr("data_parallel", True)
def construct(self, nembeddings):
emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
emb2 = self.gatherv2(nembeddings, self.emb2_param, 0)
return self.pow((emb1 - emb2), 2.0)
def fc_with_initialize(input_channels, out_channels):
return Dense(input_channels, out_channels)
class BuildTrainNetwork(nn.Cell):
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data):
embeddings = self.network(input_data)
loss = self.criterion(embeddings)
return loss
class TrainOneStepCell(Cell):
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = sens
def construct(self, data):
weights = self.weights
loss = self.network(data)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, sens)
self.optimizer(grads)
return loss
def net_trains(criterion, rank):
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
lr = 0.1
momentum = 0.9
max_epoch = 20
input_channels = 256
out_channels = 512
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
global_rank=rank)
predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
dataset = Dataset(predict, 4)
network = fc_with_initialize(input_channels, out_channels)
network.set_train()
train_network = BuildTrainNetwork(network, criterion)
train_network.set_train()
opt = Momentum(train_network.trainable_params(), lr, momentum)
train_net = TrainOneStepCell(train_network, opt).set_train()
model = Model(train_net)
model.train(max_epoch, dataset, dataset_sink_mode=False)
context.reset_auto_parallel_context()
def test_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(criterion, rank)
def test_2d_index_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(criterion, rank)
def test_batch_parallel():
gather_v2_strategy = ((device_number, 1),)
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
rank = 2
net_trains(criterion, rank)
def test_strategy1():
gather_v2_strategy = ((16, 2),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(criterion, rank)
def test_strategy2():
gather_v2_strategy = ((1, device_number),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(criterion, rank)
def test_strategy3():
gather_v2_strategy = ((8, 1),)
rank = 2
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
net_trains(criterion, rank)
class GatherV2Axis1(LossBase):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2Axis1, self).__init__()
self.pow = P.Pow()
emb1_list = 21
emb2_list = 2
if index_dim == 1:
emb_list = list(range(index_size))
emb1_list = emb_list[0::2]
emb2_list = emb_list[1::2]
if index_dim == 2:
emb_list = np.arange(index_size * index_size)
emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), index_size))
emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
self.gatherv2 = P.Gather().shard(strategy)
def construct(self, nembeddings):
emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
return self.pow((emb1 - emb2), 2.0)
def test_axis1_auto_batch_parallel():
gather_v2_strategy = None
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
rank = 2
net_trains(criterion, rank)
def test_axis1_batch_parallel():
gather_v2_strategy = ((device_number, 1), (1,))
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
rank = 2
net_trains(criterion, rank)
def test_axis1_strategy1():
gather_v2_strategy = ((16, 2), (1,))
rank = 17
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
net_trains(criterion, rank)