forked from mindspore-Ecosystem/mindspore
rename gather
This commit is contained in:
parent
343b28da34
commit
79d99323a2
|
@ -140,7 +140,6 @@ REGISTER(ReLU6Info);
|
|||
REGISTER(ReLUV2Info);
|
||||
REGISTER(SoftplusInfo);
|
||||
REGISTER(SoftsignInfo);
|
||||
REGISTER(GatherInfo);
|
||||
REGISTER(SparseGatherV2Info);
|
||||
REGISTER(SqrtInfo);
|
||||
REGISTER(SigmoidInfo);
|
||||
|
@ -180,7 +179,7 @@ REGISTER(UniformCandidateSamplerInfo);
|
|||
REGISTER(UnsortedSegmentSumInfo);
|
||||
REGISTER(UnsortedSegmentMinInfo);
|
||||
REGISTER(UnsortedSegmentMaxInfo);
|
||||
REGISTER(GatherPInfo);
|
||||
REGISTER(GatherInfo);
|
||||
REGISTER(EmbeddingLookupInfo);
|
||||
REGISTER(TileInfo);
|
||||
REGISTER(BroadcastToInfo);
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
|
||||
#include "frontend/parallel/ops_info/gather_info.h"
|
||||
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
@ -32,7 +32,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
|
||||
Status GatherInfo::GetManualSplitWithoutOffsetAttr() {
|
||||
auto manual_split_without_offset_iter = attrs_.find("manual_split");
|
||||
if (manual_split_without_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
|
@ -68,7 +68,7 @@ Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::GetManualSplitAttr() {
|
||||
Status GatherInfo::GetManualSplitAttr() {
|
||||
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
|
||||
if (manual_split_with_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
|
@ -118,7 +118,7 @@ Status GatherPInfo::GetManualSplitAttr() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::GetAttrs() {
|
||||
Status GatherInfo::GetAttrs() {
|
||||
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||
if (target_ != CPU) {
|
||||
if (input_value_.at(2) == nullptr) {
|
||||
|
@ -170,7 +170,7 @@ Status GatherPInfo::GetAttrs() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
|
||||
Status GatherInfo::CheckManualSplit(const Strategys &strategy) {
|
||||
if (strategy.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
|
||||
return FAILED;
|
||||
|
@ -226,7 +226,7 @@ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
|
||||
auto param_strategy = strategy->GetInputDim().at(0);
|
||||
auto index_strategy = strategy->GetInputDim().at(1);
|
||||
// param_strategy(axis) != 1, index can't be split
|
||||
|
@ -255,7 +255,7 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
|
|||
|
||||
// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
|
||||
// otherwise return false
|
||||
bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const {
|
||||
bool GatherInfo::ShardBatchAndAxis(const Strategys &strategy) const {
|
||||
if (axis_ != 0) {
|
||||
return false;
|
||||
}
|
||||
|
@ -285,7 +285,7 @@ bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
void GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
|
||||
void GatherInfo::SetAttribute(const StrategyPtr &strategy) {
|
||||
auto param_strategy = strategy->GetInputDim().at(0);
|
||||
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
||||
Shape index_shape = inputs_shape_.at(1);
|
||||
|
@ -312,7 +312,7 @@ void GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
|
|||
MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
|
||||
}
|
||||
|
||||
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -373,7 +373,7 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferMirrorOps() {
|
||||
Status GatherInfo::InferMirrorOps() {
|
||||
// There is no mirror operators for manual split
|
||||
if (manual_split_) {
|
||||
return SUCCESS;
|
||||
|
@ -403,7 +403,7 @@ Status GatherPInfo::InferMirrorOps() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferDevMatrixShape() {
|
||||
Status GatherInfo::InferDevMatrixShape() {
|
||||
dev_matrix_shape_.clear();
|
||||
out_dev_matrix_shape_.clear();
|
||||
// infer input dev_matrix_shape
|
||||
|
@ -460,7 +460,7 @@ Status GatherPInfo::InferDevMatrixShape() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void GatherPInfo::InferInputsTensorMap() {
|
||||
void GatherInfo::InferInputsTensorMap() {
|
||||
// infer input tensor map
|
||||
// param_strategy(axis) is not 1
|
||||
size_t param_size = inputs_shape_.at(0).size();
|
||||
|
@ -487,7 +487,7 @@ void GatherPInfo::InferInputsTensorMap() {
|
|||
inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
|
||||
}
|
||||
|
||||
void GatherPInfo::InferOutputsTensorMap() {
|
||||
void GatherInfo::InferOutputsTensorMap() {
|
||||
// infer output tensor map
|
||||
size_t param_size = inputs_shape_.at(0).size();
|
||||
size_t index_size = inputs_shape_.at(1).size();
|
||||
|
@ -534,7 +534,7 @@ void GatherPInfo::InferOutputsTensorMap() {
|
|||
(void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferTensorMap() {
|
||||
Status GatherInfo::InferTensorMap() {
|
||||
if (manual_split_) {
|
||||
Shape param_map = {1, 0};
|
||||
Shape indices_map = {-1, 1};
|
||||
|
@ -560,7 +560,7 @@ Status GatherPInfo::InferTensorMap() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferTensorInfo() {
|
||||
Status GatherInfo::InferTensorInfo() {
|
||||
// infer tensor shape
|
||||
Shape input_shape = inputs_shape_.at(0);
|
||||
Shape input_index_shape = inputs_shape_.at(1);
|
||||
|
@ -593,7 +593,7 @@ Status GatherPInfo::InferTensorInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferBias() {
|
||||
Status GatherInfo::InferBias() {
|
||||
CheckGlobalDeviceManager();
|
||||
int64_t rank = g_device_manager->rank_index_in_stage();
|
||||
auto input_shape = inputs_shape_.at(0);
|
||||
|
@ -656,7 +656,7 @@ Status GatherPInfo::InferBias() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferOffset() {
|
||||
Status GatherInfo::InferOffset() {
|
||||
CheckGlobalDeviceManager();
|
||||
size_t rank = LongToSize(g_device_manager->rank_index_in_stage());
|
||||
|
||||
|
@ -677,7 +677,7 @@ Status GatherPInfo::InferOffset() {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferGroup() {
|
||||
Status GatherInfo::InferGroup() {
|
||||
size_t dim = LongToSize(axis_);
|
||||
|
||||
int64_t rank = g_device_manager->global_rank();
|
||||
|
@ -708,7 +708,7 @@ Status GatherPInfo::InferGroup() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InferForwardCommunication() {
|
||||
Status GatherInfo::InferForwardCommunication() {
|
||||
if (manual_split_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -745,7 +745,7 @@ Status GatherPInfo::InferForwardCommunication() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
Status GatherInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||
GenerateGraph gen_g = GenerateGraph(attrs_);
|
||||
if (gen_g.Init(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||
|
@ -805,7 +805,7 @@ Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
|
||||
ReplaceGraphPtr GatherInfo::replace_graph(const CNodePtr &cnode) {
|
||||
if (manual_split_ && target_ != CPU) {
|
||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||
|
@ -824,7 +824,7 @@ ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
|
|||
return replace_graph_;
|
||||
}
|
||||
|
||||
Status GatherPInfo::ComputeReplaceOp() {
|
||||
Status GatherInfo::ComputeReplaceOp() {
|
||||
int64_t bias = 0;
|
||||
if (manual_split_) {
|
||||
if (InferOffset() != SUCCESS) {
|
||||
|
@ -851,7 +851,7 @@ Status GatherPInfo::ComputeReplaceOp() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::Init(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
|
@ -864,7 +864,7 @@ Status GatherPInfo::Init(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
if (is_auto_parallel_) {
|
||||
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
|
||||
|
@ -882,9 +882,9 @@ Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
if (manual_split_) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy";
|
||||
}
|
||||
|
@ -900,7 +900,7 @@ std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) {
|
|||
return sp_vector;
|
||||
}
|
||||
|
||||
std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
|
||||
std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
||||
}
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -29,17 +29,17 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class GatherPInfo : public OperatorInfo {
|
||||
class GatherInfo : public OperatorInfo {
|
||||
public:
|
||||
GatherPInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
|
||||
GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = GATHERV2)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
|
||||
axis_(0),
|
||||
bias_(0),
|
||||
index_offset_(0),
|
||||
slice_size_(0),
|
||||
replace_op_name_(replace_op_name) {}
|
||||
~GatherPInfo() override = default;
|
||||
~GatherInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
|
||||
|
@ -88,21 +88,21 @@ class GatherPInfo : public OperatorInfo {
|
|||
std::vector<int64_t> index_offsets_;
|
||||
};
|
||||
|
||||
class SparseGatherV2Info : public GatherPInfo {
|
||||
class SparseGatherV2Info : public GatherInfo {
|
||||
public:
|
||||
SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs, const std::string &replace_op_name = SPARSE_GATHERV2)
|
||||
: GatherPInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
|
||||
: GatherInfo(name, inputs_shape, outputs_shape, attrs, replace_op_name) {}
|
||||
~SparseGatherV2Info() override = default;
|
||||
};
|
||||
|
||||
class EmbeddingLookupInfo : public GatherPInfo {
|
||||
class EmbeddingLookupInfo : public GatherInfo {
|
||||
public:
|
||||
EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: GatherPInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
: GatherInfo(name, inputs_shape, outputs_shape, attrs) {}
|
||||
~EmbeddingLookupInfo() override = default;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_INFO_H_
|
|
@ -1,318 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/gather_v2_info.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/costmodel.h"
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherInfo::GetAttrs() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size();
|
||||
return FAILED;
|
||||
}
|
||||
// the second input is the index tensor
|
||||
|
||||
// the third input is the axis, is a ValueNode
|
||||
if (input_value_.at(2) == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_shape_.at(0).size() == 0) {
|
||||
MS_LOG(ERROR) << name_ << ": input can not be a scalar!";
|
||||
return FAILED;
|
||||
}
|
||||
int64_t axis = GetValue<int64_t>(input_value_.at(2));
|
||||
if (axis >= SizeToLong(inputs_shape_.at(0).size()) || axis < -SizeToLong(inputs_shape_.at(0).size())) {
|
||||
MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
|
||||
<< inputs_shape_.at(0).size() << ").";
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += SizeToLong(inputs_shape_[0].size());
|
||||
}
|
||||
axis_ = axis;
|
||||
|
||||
index_size_ = inputs_shape_.at(1).size();
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
|
||||
<< outputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
// Only strategy of the first input should be set.
|
||||
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
axis_strategy_ = strategy->GetInputDim().at(0).at(LongToSize(axis_));
|
||||
if (index_size_ != 1 && axis_strategy_ != 1) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy "
|
||||
"corresponding to axis must be 1, but is "
|
||||
<< axis_strategy_;
|
||||
return FAILED;
|
||||
}
|
||||
if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to "
|
||||
"axis. The first dimension of index is "
|
||||
<< inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_;
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherInfo::InferDevMatrixShape() {
|
||||
Strategys stra = strategy_->GetInputDim();
|
||||
dev_matrix_shape_ = stra.at(0);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// If index is a scalar, output dimension is input dimension minus 1;
|
||||
// If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
|
||||
// Tensor map dimension is equal to the corresponding input and output dimension.
|
||||
// If index's dimension is more than 1, we insert -1 for the output tensor map.
|
||||
Status GatherInfo::InferTensorMap() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
|
||||
<< outputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
Shape tensor_map_in;
|
||||
Shape tensor_map_out;
|
||||
size_t size = inputs_shape_.at(0).size();
|
||||
// such as 4: tensor_map_index [3,2,1,0]
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
tensor_map_in.push_back(SizeToLong(size - i - 1));
|
||||
tensor_map_out.push_back(SizeToLong(size - i - 1));
|
||||
}
|
||||
|
||||
if (index_size_ == 0) {
|
||||
(void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
|
||||
} else if (index_size_ > 1) {
|
||||
(void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
|
||||
}
|
||||
if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
|
||||
MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size()
|
||||
<< " output size is " << outputs_shape_.at(0).size();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Shape tensor_map_in_index;
|
||||
if (index_size_ >= 1) {
|
||||
tensor_map_in_index.push_back(SizeToLong(size) - axis_ - 1);
|
||||
}
|
||||
for (size_t i = 1; i < index_size_; ++i) {
|
||||
tensor_map_in_index.push_back(-1);
|
||||
}
|
||||
inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
|
||||
inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
|
||||
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherInfo::InferTensorInfo() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
|
||||
<< outputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_tensor_map_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
|
||||
<< outputs_tensor_map_.size();
|
||||
return FAILED;
|
||||
}
|
||||
// infer tensor shape
|
||||
Shape input_shape = inputs_shape_.at(0);
|
||||
Shape input_index_shape = inputs_shape_.at(1);
|
||||
Shape output_shape = outputs_shape_.at(0);
|
||||
|
||||
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
|
||||
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
|
||||
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
|
||||
(output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
TensorInfo input_tensor_info(input_tensor_layout);
|
||||
TensorInfo input_index_info(input_index_layout);
|
||||
TensorInfo output_tensor_info(output_tensor_layout);
|
||||
|
||||
inputs_tensor_info_.push_back(input_tensor_info);
|
||||
inputs_tensor_info_.push_back(input_index_info);
|
||||
outputs_tensor_info_.push_back(output_tensor_info);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
OperatorVector CreateSubOp(int64_t sub_value) {
|
||||
OperatorVector ops;
|
||||
OperatorName operator_name = SUB;
|
||||
OperatorAttrs operator_attrs;
|
||||
|
||||
std::vector<int64_t> tensor_data = {sub_value};
|
||||
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32);
|
||||
ValuePtr op_param_value = MakeValue(tensor_ptr);
|
||||
|
||||
Attr op1_param = std::make_pair("", op_param_value);
|
||||
OperatorParams operator_param = {std::make_pair(op1_param, 2)};
|
||||
|
||||
OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
|
||||
Operator op = std::make_pair(operator_name, operator_args);
|
||||
ops.push_back(op);
|
||||
return ops;
|
||||
}
|
||||
|
||||
Status GatherInfo::InferTensorSubOps() {
|
||||
sub_ops_.clear();
|
||||
if ((index_size_ == 0) || (axis_strategy_ == 1)) {
|
||||
return SUCCESS;
|
||||
}
|
||||
int64_t mod_n = 1;
|
||||
for (size_t i = LongToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
|
||||
mod_n *= dev_matrix_shape_.at(i);
|
||||
}
|
||||
if ((axis_ >= SizeToLong(dev_matrix_shape_.size())) || axis_ < 0) {
|
||||
MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
|
||||
}
|
||||
int64_t mod_p = mod_n * dev_matrix_shape_.at(LongToSize(axis_));
|
||||
int64_t rank = g_device_manager->rank_index_in_stage();
|
||||
int64_t mod_rank = rank % mod_p;
|
||||
mod_rank = static_cast<int64_t>(mod_rank / mod_n);
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
return FAILED;
|
||||
}
|
||||
if ((axis_ >= SizeToLong(inputs_shape_.at(0).size())) || axis_ < 0) {
|
||||
MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
|
||||
}
|
||||
int64_t sub_value = inputs_shape_[0][LongToSize(axis_)] / dev_matrix_shape_[LongToSize(axis_)] * mod_rank;
|
||||
|
||||
OperatorVector sub_op;
|
||||
sub_ops_.emplace_back(std::move(sub_op));
|
||||
sub_op = CreateSubOp(sub_value);
|
||||
sub_ops_.emplace_back(std::move(sub_op));
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
}
|
||||
Status status = InferTensorSubOps();
|
||||
if (status != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
|
||||
return status;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
|
||||
<< outputs_shape_.size() << "is wrong.";
|
||||
}
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shapes splittable_inputs = {input0_split};
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
|
||||
SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
|
||||
}
|
||||
return sp_vector;
|
||||
}
|
||||
|
||||
Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
|
||||
|
||||
std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
|
||||
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
|
||||
<< inputs_shape_.size();
|
||||
}
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "GetAttrs failed!";
|
||||
}
|
||||
|
||||
Dimensions strategy;
|
||||
if (index_size_ != 1) {
|
||||
strategy.push_back(1);
|
||||
} else {
|
||||
strategy.push_back(stage_device_size_);
|
||||
}
|
||||
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
|
||||
strategy.push_back(1);
|
||||
}
|
||||
Strategys strategy_v = {strategy};
|
||||
return std::make_shared<Strategys>(strategy_v);
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -1,73 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
constexpr size_t GATHER_V2_INPUTS_SIZE = 2;
|
||||
constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1;
|
||||
constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
|
||||
// We now supported limited parallel strategies.
|
||||
// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of
|
||||
// the input.
|
||||
// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1.
|
||||
class GatherInfo : public OperatorInfo {
|
||||
public:
|
||||
GatherInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
|
||||
axis_(-1),
|
||||
index_size_(0),
|
||||
axis_strategy_(1) {}
|
||||
~GatherInfo() override = default;
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
|
||||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
||||
|
||||
protected:
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override { return SUCCESS; }
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
Status GetAttrs() override;
|
||||
|
||||
private:
|
||||
Status InferTensorSubOps();
|
||||
|
||||
int64_t axis_;
|
||||
size_t index_size_;
|
||||
int64_t axis_strategy_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
|
|
@ -24,7 +24,6 @@
|
|||
#include "frontend/parallel/ops_info/comparison_function_info.h"
|
||||
#include "frontend/parallel/ops_info/dropout_do_mask_info.h"
|
||||
#include "frontend/parallel/ops_info/elementary_function_info.h"
|
||||
#include "frontend/parallel/ops_info/gather_v2_info.h"
|
||||
#include "frontend/parallel/ops_info/get_next_info.h"
|
||||
#include "frontend/parallel/ops_info/l2_normalize_info.h"
|
||||
#include "frontend/parallel/ops_info/layer_norm_info.h"
|
||||
|
@ -37,7 +36,7 @@
|
|||
#include "frontend/parallel/ops_info/transpose_info.h"
|
||||
#include "frontend/parallel/ops_info/unsorted_segment_op_info.h"
|
||||
#include "frontend/parallel/ops_info/virtual_dataset_info.h"
|
||||
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
|
||||
#include "frontend/parallel/ops_info/gather_info.h"
|
||||
#include "frontend/parallel/ops_info/tile_info.h"
|
||||
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
||||
#include "frontend/parallel/ops_info/slice_info.h"
|
||||
|
|
|
@ -1248,20 +1248,6 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
|
|||
MS_LOG(EXCEPTION) << "Length of name is zero!";
|
||||
}
|
||||
std::string distribute_opname = GetDisOpName(name);
|
||||
if (name == GATHERV2) {
|
||||
distribute_opname = name + "PInfo";
|
||||
auto data_parallel_iter = attrs.find(DATA_PARALLEL);
|
||||
if (data_parallel_iter != attrs.end()) {
|
||||
MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
|
||||
if (!data_parallel_iter->second->isa<BoolImm>()) {
|
||||
MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool.";
|
||||
}
|
||||
bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
|
||||
if (data_parallel) {
|
||||
distribute_opname = name + "Info";
|
||||
}
|
||||
}
|
||||
}
|
||||
OperatorInfoPtr operator_ =
|
||||
(OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
||||
if (operator_ == nullptr) {
|
||||
|
@ -2623,9 +2609,9 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
|
|||
return param_names;
|
||||
}
|
||||
|
||||
bool IsGatherPInfo(const std::string &name) {
|
||||
std::vector<std::string> gather_p_info_names = {"GatherPInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
|
||||
for (std::string info_name : gather_p_info_names) {
|
||||
bool IsGatherInfo(const std::string &name) {
|
||||
std::vector<std::string> gather_info_names = {"GatherInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
|
||||
for (std::string info_name : gather_info_names) {
|
||||
if (name.find(info_name) != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
|
@ -2661,10 +2647,10 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
|
|||
for (auto param_name_pair : param_names) {
|
||||
tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
|
||||
}
|
||||
if (IsGatherPInfo(operator_info->name())) {
|
||||
auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
|
||||
auto param_split_shapes = gatherv2_info->param_split_shapes();
|
||||
auto index_offsets = gatherv2_info->index_offsets();
|
||||
if (IsGatherInfo(operator_info->name())) {
|
||||
auto gather_info = std::dynamic_pointer_cast<GatherInfo>(operator_info);
|
||||
auto param_split_shapes = gather_info->param_split_shapes();
|
||||
auto index_offsets = gather_info->index_offsets();
|
||||
if (param_split_shapes.size() != index_offsets.size()) {
|
||||
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
|
||||
}
|
||||
|
|
|
@ -1,240 +0,0 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.nn import Dense, Cell
|
||||
from mindspore.nn.loss.loss import LossBase
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
device_number = 32
|
||||
batch_size_per_device = 128
|
||||
|
||||
|
||||
class Dataset():
|
||||
def __init__(self, predict, length=3):
|
||||
self.predict = predict
|
||||
self.index = 0
|
||||
self.length = length
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= self.length:
|
||||
raise StopIteration
|
||||
self.index += 1
|
||||
return (self.predict,)
|
||||
|
||||
def reset(self):
|
||||
self.index = 0
|
||||
|
||||
def get_dataset_size(self):
|
||||
return 128
|
||||
|
||||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
|
||||
return self
|
||||
|
||||
|
||||
class GatherV2(LossBase):
|
||||
def __init__(self, index_dim, strategy, index_size=16):
|
||||
super(GatherV2, self).__init__()
|
||||
self.pow = P.Pow()
|
||||
emb1_list = 21
|
||||
emb2_list = 2
|
||||
if index_dim == 1:
|
||||
emb_list = list(range(index_size))
|
||||
emb1_list = emb_list[0::2]
|
||||
emb2_list = emb_list[1::2]
|
||||
if index_dim == 2:
|
||||
emb_list = np.arange(index_size * 16)
|
||||
emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), 16))
|
||||
emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), 16))
|
||||
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
|
||||
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
|
||||
self.gatherv2 = P.Gather().shard(strategy).add_prim_attr("data_parallel", True)
|
||||
|
||||
def construct(self, nembeddings):
|
||||
emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
|
||||
emb2 = self.gatherv2(nembeddings, self.emb2_param, 0)
|
||||
return self.pow((emb1 - emb2), 2.0)
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
return Dense(input_channels, out_channels)
|
||||
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
|
||||
def construct(self, input_data):
|
||||
embeddings = self.network(input_data)
|
||||
loss = self.criterion(embeddings)
|
||||
return loss
|
||||
|
||||
|
||||
class TrainOneStepCell(Cell):
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
def construct(self, data):
|
||||
weights = self.weights
|
||||
loss = self.network(data)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(data, sens)
|
||||
|
||||
self.optimizer(grads)
|
||||
return loss
|
||||
|
||||
|
||||
def net_trains(criterion, rank):
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
lr = 0.1
|
||||
momentum = 0.9
|
||||
max_epoch = 20
|
||||
input_channels = 256
|
||||
out_channels = 512
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
|
||||
global_rank=rank)
|
||||
predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
|
||||
dataset = Dataset(predict, 4)
|
||||
|
||||
network = fc_with_initialize(input_channels, out_channels)
|
||||
network.set_train()
|
||||
|
||||
train_network = BuildTrainNetwork(network, criterion)
|
||||
train_network.set_train()
|
||||
opt = Momentum(train_network.trainable_params(), lr, momentum)
|
||||
train_net = TrainOneStepCell(train_network, opt).set_train()
|
||||
|
||||
model = Model(train_net)
|
||||
model.train(max_epoch, dataset, dataset_sink_mode=False)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_auto_batch_parallel():
|
||||
gather_v2_strategy = None
|
||||
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
rank = 2
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_2d_index_auto_batch_parallel():
|
||||
gather_v2_strategy = None
|
||||
criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
rank = 2
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_batch_parallel():
|
||||
gather_v2_strategy = ((device_number, 1),)
|
||||
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
rank = 2
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_strategy1():
|
||||
gather_v2_strategy = ((16, 2),)
|
||||
rank = 2
|
||||
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_strategy2():
|
||||
gather_v2_strategy = ((1, device_number),)
|
||||
rank = 2
|
||||
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_strategy3():
|
||||
gather_v2_strategy = ((8, 1),)
|
||||
rank = 2
|
||||
criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
class GatherV2Axis1(LossBase):
|
||||
def __init__(self, index_dim, strategy, index_size=16):
|
||||
super(GatherV2Axis1, self).__init__()
|
||||
self.pow = P.Pow()
|
||||
emb1_list = 21
|
||||
emb2_list = 2
|
||||
if index_dim == 1:
|
||||
emb_list = list(range(index_size))
|
||||
emb1_list = emb_list[0::2]
|
||||
emb2_list = emb_list[1::2]
|
||||
if index_dim == 2:
|
||||
emb_list = np.arange(index_size * index_size)
|
||||
emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), index_size))
|
||||
emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
|
||||
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
|
||||
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
|
||||
self.gatherv2 = P.Gather().shard(strategy)
|
||||
|
||||
def construct(self, nembeddings):
|
||||
emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
|
||||
emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
|
||||
return self.pow((emb1 - emb2), 2.0)
|
||||
|
||||
|
||||
def test_axis1_auto_batch_parallel():
|
||||
gather_v2_strategy = None
|
||||
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
|
||||
rank = 2
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_axis1_batch_parallel():
|
||||
gather_v2_strategy = ((device_number, 1), (1,))
|
||||
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
|
||||
rank = 2
|
||||
net_trains(criterion, rank)
|
||||
|
||||
|
||||
def test_axis1_strategy1():
|
||||
gather_v2_strategy = ((16, 2), (1,))
|
||||
rank = 17
|
||||
criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
|
||||
net_trains(criterion, rank)
|
Loading…
Reference in New Issue