forked from mindspore-Ecosystem/mindspore
!221 add distributed GatherV2 operator for auto parallel
Merge pull request !221 from chentingting/support_distributd_GatherV2_operator
Commit: b2eb656f91
@@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
 }
+
+// return the per device communication cost in the forward phase.
+double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                        const int32_t&) const {
+  // GatherV2Cost does not need communication in the forward phase
+  return 0.0;
+}
+
+// return the per device communication cost in the backward phase.
+double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                         const int32_t&) const {
+  // GatherV2Cost does not need communication in the backward phase
+  return 0.0;
+}
+
+double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                               const int32_t&) const {
+  // In the forward phase, the computation cost = slice(A) + slice(B)
+  Shape input0_slice_shape = inputs[0].slice_shape();
+  Shape input1_slice_shape = inputs[1].slice_shape();
+  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
+                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  return result;
+}
+
+double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                                const int32_t&) const {
+  return 0.0;
+}
 }  // namespace parallel
 }  // namespace mindspore
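For intuition, the forward computation cost above is just the byte volume of the two input slices. A minimal Python sketch of the same arithmetic (a hedged illustration; names and sizes are not from the commit):

from functools import reduce

def list_product(shape):
    # Product of the dimensions of a slice shape, e.g. [128, 64] -> 8192.
    return reduce(lambda acc, dim: acc * dim, shape, 1)

def gatherv2_forward_computation_cost(input_slice, index_slice,
                                      input_bytes=4, index_bytes=4):
    # cost = slice(A) + slice(B), each term weighted by its per-element byte length.
    return (list_product(input_slice) * input_bytes
            + list_product(index_slice) * index_bytes)

# A [1024, 64] float32 parameter sliced over 8 devices along dim 0,
# gathered with 256 int32 indices per device:
print(gatherv2_forward_computation_cost([128, 64], [256]))  # 33792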
@@ -81,6 +81,8 @@ class OperatorCost {
   std::vector<size_t> outputs_type_lengths_;
 };
 
+using OperatorCostPtr = std::shared_ptr<OperatorCost>;
+
 class MatMulCost : public OperatorCost {
  public:
   MatMulCost() = default;

@@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
 };
 
 using DropOutCostPtr = std::shared_ptr<DropOutCost>;
 
+class GatherV2Cost : public OperatorCost {
+ public:
+  GatherV2Cost() = default;
+  ~GatherV2Cost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                     const int32_t& stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override;
+  double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                             const int32_t& stage_id) const override;
+  double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                   const int32_t& stage_id) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                    const int32_t&) const override;
+};
+
+using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
@@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
   }
 }
 
-void GatherV2Info::ReComputeBatchSplitFlagList() {
-  MS_ASSERT(inputs_shape_.size() == 2);
-  MS_ASSERT(input_value_.size() == 3);
-  MS_ASSERT(input_value_[0] == nullptr);
-  // the second input is the index tensor
-  MS_ASSERT(input_value_[1] != nullptr);
-  // the third input is the axis
-  MS_ASSERT(input_value_[2] != nullptr);
-  int axis = GetValue<int>(input_value_[2]);
-  MS_ASSERT(axis < inputs_shape_[0].size() && axis >= 0 - inputs_shape_[0].size());
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  split_flag_list_[0] = true;
-  // if gather axis is 0, the index's strategy is equal to device number
-  if (axis == 0) {
-    split_flag_list_[1] = true;
-  }
-}
-
 Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
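The removed ReComputeBatchSplitFlagList above (its role is taken over by GatherV2Info::GenerateBatchStrategies in the new gather_v2_info.cc below) boiled down to: normalize the axis, always split the parameter, and split the index only for axis-0 gathers. A small Python rendering of that rule (illustrative only, not MindSpore code):

def batch_split_flags(input_rank, axis):
    # Normalize a possibly negative gather axis into [0, input_rank).
    if axis < 0:
        axis += input_rank
    # The parameter is always batch-splittable; the index only for an axis-0 gather.
    return [True, axis == 0]

print(batch_split_flags(3, -3))  # [True, True]
print(batch_split_flags(3, 1))   # [True, False]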
@@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
 
-class GatherV2Info : public BatchParallelInfo {
- public:
-  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-               const PrimitiveAttrs& attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
-  ~GatherV2Info() override = default;
-  void ReComputeBatchSplitFlagList() override;
-};
-
 }  // namespace parallel
 }  // namespace mindspore
@@ -0,0 +1,350 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parallel/ops_info/gather_v2_info.h"

#include <memory>
#include <utility>
#include <vector>

#include "ir/meta_tensor.h"
#include "ir/value.h"
#include "parallel/auto_parallel/costmodel.h"
#include "parallel/device_matrix.h"
#include "parallel/graph_util/generate_graph.h"
#include "parallel/strategy.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
Status GatherV2Info::GetAttrs() {
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
    return FAILED;
  }
  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
    return FAILED;
  }
  if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
    MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size();
    return FAILED;
  }
  // The second input is the index tensor.

  // The third input is the axis and must be a ValueNode.
  if (input_value_.at(2) == nullptr) {
    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, so it is not a ValueNode!";
    return FAILED;
  }

  if (inputs_shape_.at(0).size() == 0) {
    MS_LOG(ERROR) << name_ << ": the input cannot be a scalar!";
    return FAILED;
  }
  int axis = GetValue<int>(input_value_.at(2));
  if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) {
    MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
                  << inputs_shape_.at(0).size() << ").";
    return FAILED;
  }
  if (axis < 0) {
    axis += SizeToInt(inputs_shape_[0].size());
  }
  axis_ = axis;

  index_size_ = inputs_shape_.at(1).size();

  return SUCCESS;
}

Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) {
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                  << inputs_shape_.size();
    return FAILED;
  }
  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
                  << outputs_shape_.size();
    return FAILED;
  }
  // Only the strategy of the first input should be set.
  if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
    } else {
      MS_LOG(ERROR) << name_ << ": Invalid strategy.";
    }
    return FAILED;
  }
  axis_strategy_ = strategy->GetInputDim().at(0).at(axis_);
  if (index_size_ != 1 && axis_strategy_ != 1) {
    MS_LOG(ERROR) << name_
                  << ": Invalid strategy. If the index is a scalar or a tensor with more than one dimension, the "
                     "strategy corresponding to the axis must be 1, but it is "
                  << axis_strategy_;
    return FAILED;
  }
  if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
    MS_LOG(ERROR) << name_
                  << ": Invalid strategy. The first dimension of the index cannot be divided evenly by the strategy "
                     "corresponding to the axis. The first dimension of the index is "
                  << inputs_shape_.at(1).at(0) << ", and the strategy corresponding to the axis is " << axis_strategy_;
    return FAILED;
  }
  return SUCCESS;
}

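// Worked example (editorial illustration, not in the original commit): for a
// [1024, 64] parameter on 32 devices with axis_ == 0, the strategy ((16, 2),)
// passes only when the index is 1-D and its first dimension divides evenly by
// 16; a scalar or 2-D index instead requires the strategy along the axis to be 1.
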
Status GatherV2Info::InferDevMatrixShape() {
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  dev_matrix_shape_ = stra.at(0);
  return SUCCESS;
}

// If the index is a scalar, the output dimension is the input dimension minus 1;
// if the index is an n-dimensional tensor, the output dimension is the input dimension plus (n - 1).
// Each tensor map dimension equals the corresponding input and output dimension.
// If the index has more than one dimension, -1 is inserted into the output tensor map.
Status GatherV2Info::InferTensorMap() {
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                  << inputs_shape_.size();
    return FAILED;
  }
  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
                  << outputs_shape_.size();
    return FAILED;
  }
  std::vector<int32_t> tensor_map_in;
  std::vector<int32_t> tensor_map_out;
  size_t size = inputs_shape_.at(0).size();
  // For example, size 4 gives the tensor map [3, 2, 1, 0].
  for (size_t i = 0; i < size; ++i) {
    tensor_map_in.push_back(SizeToInt(size - i - 1));
    tensor_map_out.push_back(SizeToInt(size - i - 1));
  }

  if (index_size_ == 0) {
    (void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
  } else if (index_size_ > 1) {
    (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
  }
  if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
    MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size()
                  << ", output size is " << outputs_shape_.at(0).size();
    return FAILED;
  }

  std::vector<int32_t> tensor_map_in_index;
  if (index_size_ >= 1) {
    tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1));
  }
  for (size_t i = 1; i < index_size_; ++i) {
    tensor_map_in_index.push_back(-1);
  }
  inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
  inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
  outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
  return SUCCESS;
}

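// Worked example (editorial illustration, not in the original commit): with an
// input of rank 4 and axis_ == 1, tensor_map_in is [3, 2, 1, 0]. A scalar index
// (index_size_ == 0) erases position 1, giving the output map [3, 1, 0]; a 3-D
// index inserts two -1 entries, giving [3, -1, -1, 2, 1, 0], and the index map
// becomes [2, -1, -1].
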
Status GatherV2Info::InferTensorInfo() {
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                  << inputs_shape_.size();
    return FAILED;
  }
  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
                  << outputs_shape_.size();
    return FAILED;
  }
  if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                  << inputs_tensor_map_.size();
    return FAILED;
  }
  if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
                  << outputs_tensor_map_.size();
    return FAILED;
  }
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape input_index_shape = inputs_shape_.at(1);
  Shape output_shape = outputs_shape_.at(0);

  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
    return FAILED;
  }

  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

OperatorVector CreateSubOp(int32_t sub_value) {
  OperatorVector ops;
  OperatorName operator_name = SUB;
  OperatorAttrs operator_attrs;

  py::tuple tuple = py::make_tuple(sub_value);
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, kInt32);
  ValuePtr op_param_value = MakeValue(tensor_ptr);

  Attr op1_param = std::make_pair("", op_param_value);
  OperatorParams operator_param = {std::make_pair(op1_param, 2)};

  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
  Operator op = std::make_pair(operator_name, operator_args);
  ops.push_back(op);
  return ops;
}

Status GatherV2Info::InferTensorSubOps() {
  sub_ops_.clear();
  if ((index_size_ == 0) || (axis_strategy_ == 1)) {
    return SUCCESS;
  }
  int32_t mod_n = 1;
  for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
    mod_n *= dev_matrix_shape_.at(i);
  }
  if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) {
    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
    return FAILED;
  }
  int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
  int32_t rank = g_device_manager->global_rank();
  int32_t mod_rank = rank % mod_p;
  mod_rank = static_cast<int32_t>(mod_rank / mod_n);
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                  << inputs_shape_.size();
    return FAILED;
  }
  if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) {
    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
    return FAILED;
  }
  int32_t sub_value = static_cast<int32_t>(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank;

  OperatorVector sub_op;
  sub_ops_.emplace_back(std::move(sub_op));
  sub_op = CreateSubOp(sub_value);
  sub_ops_.emplace_back(std::move(sub_op));
  return SUCCESS;
}

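// Worked example (editorial illustration, not in the original commit): with
// dev_matrix_shape_ [8, 4], axis_ == 0, and an input whose dimension 0 is 1024,
// mod_n = 4 and mod_p = 32. Global rank 13 gives mod_rank = (13 % 32) / 4 = 3,
// so sub_value = (1024 / 8) * 3 = 384: that rank holds rows [384, 512) and the
// generated Sub op rebases incoming indices by 384 before the local gather.
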
Status GatherV2Info::Init(const StrategyPtr& strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  Status status = InferTensorSubOps();
  if (status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
    return status;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
    } else {
      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    }
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
  if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
                  << outputs_shape_.size() << ") is wrong.";
    return FAILED;
  }

  is_auto_parallel_ = true;
  Shape input0_split(inputs_shape_[0].size());
  Shapes splittable_inputs = {input0_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
      SUCCESS) {
    MS_LOG(ERROR) << name_ << " : GenerateStrategiesForIndependentInputs failed.";
    return FAILED;
  }
  size_t success = 0;
  for (auto& sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategies.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}

Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) {
  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
    } else {
      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
    }
    return FAILED;
  }
  return SUCCESS;
}

std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2Info::GenerateBatchStrategies() {
  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
    MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
                      << inputs_shape_.size();
  }
  CheckGlobalDeviceManager();
  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
  if (GetAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << "GetAttrs failed!";
  }

  Dimensions strategy;
  if (index_size_ != 1) {
    strategy.push_back(1);
  } else {
    strategy.push_back(SizeToInt(dev_num));
  }
  for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
    strategy.push_back(1);
  }
  std::vector<Dimensions> strategy_v = {strategy};
  return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
}
}  // namespace parallel
}  // namespace mindspore
@@ -0,0 +1,73 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ir/value.h"
#include "parallel/auto_parallel/operator_costmodel.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"

namespace mindspore {
namespace parallel {
constexpr size_t GATHER_V2_INPUTS_SIZE = 2;
constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1;
constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
// Only limited parallel strategies are supported for now.
// If the strategy corresponding to the axis is more than 1, the index must be evenly distributed across the
// axis dimension of the input.
// If the index is a scalar or an n-dimensional tensor (n > 1), the strategy corresponding to the axis must be 1.
class GatherV2Info : public OperatorInfo {
 public:
  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
        axis_(-1),
        index_size_(0),
        axis_strategy_(1) {}
  ~GatherV2Info() override = default;
  Status Init(const StrategyPtr& strategy) override;
  Status InitForCostModel(const StrategyPtr& strategy) override;

  Status GenerateStrategies(int32_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
  std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;

 protected:
  Status CheckStrategy(const StrategyPtr& strategy) override;
  Status InferMirrorOps() override { return SUCCESS; }
  Status InferForwardCommunication() override { return SUCCESS; }
  Status InferTensorInfo() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;

 private:
  Status InferTensorSubOps();

  int32_t axis_;
  size_t index_size_;
  int32_t axis_strategy_;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
@@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
   dev_matrix_shape_.clear();
   forward_op_.clear();
   mirror_ops_.clear();
+  sub_ops_.clear();
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
@@ -41,6 +41,7 @@ namespace mindspore {
 namespace parallel {
 using ForwardOp = OperatorVector;
 using MirrorOps = std::vector<OperatorVector>;
+using Ops = std::vector<OperatorVector>;
 using VirtualDivOp = OperatorVector;
 using TensorMaps = std::vector<std::vector<int32_t>>;
 using TensorLayouts = std::vector<TensorLayout>;
@@ -99,6 +100,7 @@ class OperatorInfo {
   OutPutInfoVector replace_op_info() const { return replace_op_info_; }
   virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; }
   MirrorOps mirror_ops() const { return mirror_ops_; }
+  Ops sub_ops() const { return sub_ops_; }
   VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
@@ -190,6 +192,7 @@ class OperatorInfo {
   TensorMaps inputs_tensor_map_;
   TensorMaps outputs_tensor_map_;
   ForwardOp forward_op_;
+  Ops sub_ops_;
   ForwardOp replace_op_;
   OutPutInfoVector replace_op_info_;
   ReplaceGraphPtr replace_graph_;
@@ -24,6 +24,7 @@
 #include "parallel/ops_info/comparison_function_info.h"
 #include "parallel/ops_info/dropout_do_mask_info.h"
 #include "parallel/ops_info/elementary_function_info.h"
+#include "parallel/ops_info/gather_v2_info.h"
 #include "parallel/ops_info/get_next_info.h"
 #include "parallel/ops_info/l2_normalize_info.h"
 #include "parallel/ops_info/loss_info.h"
@@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) {
   MS_EXCEPTION_IF_NULL(func_graph);
   Operator op = CreateGetTensorSliceOp(tensor_layout);
   InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
+  if (!op_info->sub_ops().empty()) {
+    auto sub_ops = op_info->sub_ops();
+    for (size_t i = 0; i < sub_ops.size(); i++) {
+      if (!sub_ops.at(i).empty()) {
+        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
+      }
+    }
+  }
 }
 
 void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) {
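The inserted sub operators rebase global indices into each device's shard before the local gather runs. A hedged Python sketch of the resulting per-device semantics (it assumes every index falls inside the local shard, which is what the even-distribution constraint in CheckStrategy is meant to guarantee; all names are illustrative):

import numpy as np

def local_gather(shard, global_indices, offset):
    # Each device holds rows [offset, offset + len(shard)) of the full table
    # and subtracts its offset (the generated Sub op) before gathering locally.
    return shard[np.asarray(global_indices) - offset]

# Device holding rows [384, 512) of a 1024-row table (sub_value = 384):
full_table = np.arange(1024)
shard = full_table[384:512]
print(local_gather(shard, [384, 400, 511], 384))  # [384 400 511]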
@@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell
 from mindspore import context
 
 context.set_context(mode=context.GRAPH_MODE)
+device_number = 32
+batch_size_per_device = 128
 
 
 class Dataset():
@@ -57,15 +59,22 @@ class Dataset():
 
 
 class GatherV2(_Loss):
-    def __init__(self, batchsize):
+    def __init__(self, index_dim, strategy, index_size=16):
         super(GatherV2, self).__init__()
         self.pow = P.Pow()
-        emb_list = list(range(batchsize))
-        emb1_list = emb_list[0::2]
-        emb2_list = emb_list[1::2]
+        emb1_list = 21
+        emb2_list = 2
+        if index_dim == 1:
+            emb_list = list(range(index_size))
+            emb1_list = emb_list[0::2]
+            emb2_list = emb_list[1::2]
+        if index_dim == 2:
+            emb_list = np.arange(index_size*16)
+            emb1_list = np.reshape(emb_list[0::2], (int(index_size/2), 16))
+            emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16))
         self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
         self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.GatherV2().set_strategy(strategy)
 
     def construct(self, nembeddings):
         emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
@@ -73,10 +82,6 @@
         return self.pow((emb1 - emb2), 2.0)
 
 
-def get_loss(batchsize):
-    return GatherV2(batchsize)
-
-
 def fc_with_initialize(input_channels, out_channels):
     return Dense(input_channels, out_channels)
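For readers unfamiliar with set_strategy, the single tuple used by these tests partitions only the first input (the embedding table). A hedged sketch of how a ((16, 2),) strategy carves a [512, 64] parameter across 32 devices (shard_shape is an illustrative helper, not a MindSpore API):

def shard_shape(param_shape, strategy):
    # Each dimension is cut by the corresponding strategy element.
    return [dim // cut for dim, cut in zip(param_shape, strategy)]

print(shard_shape([512, 64], (16, 2)))  # [32, 32] tile per device (16 * 2 = 32 devices)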
@@ -114,26 +119,23 @@
         return F.depend(loss, self.optimizer(grads))
 
 
-def test_trains():
+def net_trains(gather_v2_strategy, criterion, rank):
     init()
     lr = 0.1
     momentum = 0.9
     max_epoch = 20
-    device_number = 32
-    batch_size_per_device = 128
     input_channels = 256
     out_channels = 512
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
     context.reset_auto_parallel_context()
-    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number)
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
+                                      global_rank=rank)
     predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
     dataset = Dataset(predict, 4)
 
     network = fc_with_initialize(input_channels, out_channels)
     network.set_train()
 
-    criterion = get_loss(batch_size_per_device * device_number)
-
     train_network = BuildTrainNetwork(network, criterion)
     train_network.set_train()
     opt = Momentum(train_network.trainable_params(), lr, momentum)
@@ -143,5 +145,90 @@
     model.train(max_epoch, dataset, dataset_sink_mode=False)
     context.reset_auto_parallel_context()
 
-if __name__ == "__main__":
-    test_trains()
+
+def test_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_2d_index_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy2():
+    gather_v2_strategy = ((1, device_number),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy3():
+    gather_v2_strategy = ((8, 1),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+class GatherV2Axis1(_Loss):
+    def __init__(self, index_dim, strategy, index_size=16):
+        super(GatherV2Axis1, self).__init__()
+        self.pow = P.Pow()
+        emb1_list = 21
+        emb2_list = 2
+        if index_dim == 1:
+            emb_list = list(range(index_size))
+            emb1_list = emb_list[0::2]
+            emb2_list = emb_list[1::2]
+        if index_dim == 2:
+            emb_list = np.arange(index_size*index_size)
+            emb1_list = np.reshape(emb_list[0::2], (int(index_size/2), index_size))
+            emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), index_size))
+        self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
+        self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy)
+
+    def construct(self, nembeddings):
+        emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
+        emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
+        return self.pow((emb1 - emb2), 2.0)
+
+
+def test_axis1_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 17
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    net_trains(gather_v2_strategy, criterion, rank)