forked from mindspore-Ecosystem/mindspore
!1023 add_gatherv2_distributed_op
Merge pull request !1023 from lichen/add_gatherv2_distributed_op
This commit is contained in:
commit
dd2062bf8d
|
@ -787,5 +787,90 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double GatherV2PCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const {
|
||||||
|
double result = 0.0;
|
||||||
|
if (outputs_type_lengths_.size() != outputs.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
|
||||||
|
}
|
||||||
|
// don't split axis
|
||||||
|
if (strategy_.at(IntToSize(axis_)) == 1) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// split axis
|
||||||
|
auto param_shape = inputs[0].slice_shape();
|
||||||
|
auto index_shape = inputs[1].slice_shape();
|
||||||
|
Shape reducescatter_shape = index_shape;
|
||||||
|
if (param_shape.size() == 2) {
|
||||||
|
reducescatter_shape.push_back(param_shape.at(1 - axis_));
|
||||||
|
}
|
||||||
|
result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const {
|
||||||
|
double result = 0.0;
|
||||||
|
CheckGlobalDeviceManager();
|
||||||
|
MS_EXCEPTION_IF_NULL(g_device_manager);
|
||||||
|
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
|
||||||
|
|
||||||
|
for (size_t j = 0; j < inputs.size(); ++j) {
|
||||||
|
if (!is_parameter_[j]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
TensorInfo input_a_tensor_info = inputs[j];
|
||||||
|
Shape input_a_shape = input_a_tensor_info.shape();
|
||||||
|
Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
|
||||||
|
int32_t used_device_num = 1;
|
||||||
|
for (size_t i = 0; i < input_a_shape.size(); ++i) {
|
||||||
|
used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
|
||||||
|
}
|
||||||
|
if (total_device_num != IntToSize(used_device_num)) {
|
||||||
|
result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||||
|
const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
|
||||||
|
double result = 0.0;
|
||||||
|
Shape input0_slice_shape = inputs[0].slice_shape();
|
||||||
|
Shape input1_slice_shape = inputs[1].slice_shape();
|
||||||
|
if (inputs_type_lengths_.size() != inputs.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
|
||||||
|
}
|
||||||
|
// don't split axis
|
||||||
|
if (strategy_.at(IntToSize(axis_)) == 1) {
|
||||||
|
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
|
||||||
|
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
|
||||||
|
} else {
|
||||||
|
// split axis
|
||||||
|
result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
|
||||||
|
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
|
||||||
|
const std::vector<TensorInfo> &outputs, int32_t) const {
|
||||||
|
double result = 0.0;
|
||||||
|
Shape input1_slice_shape = inputs[1].slice_shape();
|
||||||
|
Shape output0_slice_shape = outputs[0].slice_shape();
|
||||||
|
// don't split axis
|
||||||
|
if (strategy_.at(IntToSize(axis_)) == 1) {
|
||||||
|
result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
|
||||||
|
} else {
|
||||||
|
// split axis
|
||||||
|
result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
|
||||||
|
ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -27,6 +27,10 @@ namespace parallel {
|
||||||
#define MAXIMUM_INPUT_NUMBER 100
|
#define MAXIMUM_INPUT_NUMBER 100
|
||||||
#define DEFAULT_DATA_TYPE_LENGTH 4
|
#define DEFAULT_DATA_TYPE_LENGTH 4
|
||||||
#define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory
|
#define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory
|
||||||
|
#define GATHERV2_COST_WEIGHT0 3
|
||||||
|
#define GATHERV2_COST_WEIGHT1 7
|
||||||
|
#define GATHERV2_COST_WEIGHT2 2
|
||||||
|
#define GATHERV2_COST_WEIGHT3 6
|
||||||
|
|
||||||
class OperatorCost;
|
class OperatorCost;
|
||||||
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
|
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
|
||||||
|
@ -609,6 +613,38 @@ class GatherV2Cost : public OperatorCost {
|
||||||
};
|
};
|
||||||
|
|
||||||
using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
|
using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
|
||||||
|
|
||||||
|
class GatherV2PCost : public OperatorCost {
|
||||||
|
public:
|
||||||
|
explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
|
||||||
|
GatherV2PCost() : OperatorCost(true) {}
|
||||||
|
~GatherV2PCost() override = default;
|
||||||
|
|
||||||
|
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const override {
|
||||||
|
return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
|
||||||
|
}
|
||||||
|
double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const override;
|
||||||
|
double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const override;
|
||||||
|
double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const override {
|
||||||
|
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
|
||||||
|
}
|
||||||
|
double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t stage_id) const override;
|
||||||
|
double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
|
||||||
|
int32_t) const override;
|
||||||
|
void set_axis(int32_t axis) { axis_ = axis; }
|
||||||
|
void set_strategy(const Shape &strategy) { strategy_ = strategy; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int32_t axis_;
|
||||||
|
Shape strategy_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>;
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
|
#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
|
||||||
|
|
|
@ -129,6 +129,7 @@ REGISTER(ExpandDimsInfo);
|
||||||
REGISTER(SqueezeInfo);
|
REGISTER(SqueezeInfo);
|
||||||
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
|
||||||
REGISTER(SquareInfo);
|
REGISTER(SquareInfo);
|
||||||
|
REGISTER(GatherV2PInfo);
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name
|
||||||
AnfNodePtr CreatTypeInt(int32_t value);
|
AnfNodePtr CreatTypeInt(int32_t value);
|
||||||
AnfNodePtr CreatInt32Imm(int32_t value);
|
AnfNodePtr CreatInt32Imm(int32_t value);
|
||||||
AnfNodePtr CreateInt32Tensor(int32_t value);
|
AnfNodePtr CreateInt32Tensor(int32_t value);
|
||||||
|
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
|
||||||
std::string HashInstanceName(const std::string &name);
|
std::string HashInstanceName(const std::string &name);
|
||||||
|
|
||||||
class GenerateGraph {
|
class GenerateGraph {
|
||||||
|
|
|
@ -0,0 +1,339 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "parallel/ops_info/gather_v2_p_info.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <numeric>
|
||||||
|
#include <functional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "parallel/device_matrix.h"
|
||||||
|
#include "parallel/graph_util/generate_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
Status GatherV2PInfo::GetAttrs() {
|
||||||
|
// get axis, the third input is the axis, is a ValueNode
|
||||||
|
if (input_value_.at(2) == nullptr) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto axis = GetValue<int>(input_value_.at(2));
|
||||||
|
// if axis is negative then convert it to positive
|
||||||
|
auto params_shape = inputs_shape_.at(0);
|
||||||
|
if (params_shape.size() == 0) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += SizeToInt(inputs_shape_[0].size());
|
||||||
|
}
|
||||||
|
axis_ = axis;
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
|
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
|
||||||
|
if (is_auto_parallel_) {
|
||||||
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
|
||||||
|
}
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// only support 1-dim and 2-dim param
|
||||||
|
if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't support scalar index
|
||||||
|
if (inputs_shape_.at(1).size() == 0) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
||||||
|
Shape index_shape = inputs_shape_.at(1);
|
||||||
|
auto param_strategy = strategy->GetInputDim().at(0);
|
||||||
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
|
||||||
|
Shape param_shape = inputs_shape_.at(0);
|
||||||
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't support repeated calc
|
||||||
|
auto params_strategy = strategy->GetInputDim().at(0);
|
||||||
|
CheckGlobalDeviceManager();
|
||||||
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||||
|
auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies<int>());
|
||||||
|
if (dev_num != IntToSize(product)) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InferDevMatrixShape() {
|
||||||
|
dev_matrix_shape_.clear();
|
||||||
|
out_dev_matrix_shape_.clear();
|
||||||
|
// infer input dev_matrix_shape
|
||||||
|
auto params_strategy = strategy_->GetInputDim().at(0);
|
||||||
|
dev_matrix_shape_ = params_strategy;
|
||||||
|
|
||||||
|
// infer out dev_matrix_shape
|
||||||
|
// axis!=0, split axis
|
||||||
|
if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) {
|
||||||
|
out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_)));
|
||||||
|
for (size_t i = 1; i < params_strategy.size(); ++i) {
|
||||||
|
if (i == IntToSize(axis_)) {
|
||||||
|
out_dev_matrix_shape_.push_back(1);
|
||||||
|
} else {
|
||||||
|
out_dev_matrix_shape_.push_back(params_strategy.at(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_dev_matrix_shape_ = params_strategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InferTensorMap() {
|
||||||
|
// infer input tensor map
|
||||||
|
size_t param_size = inputs_shape_.at(0).size();
|
||||||
|
size_t index_size = inputs_shape_.at(1).size();
|
||||||
|
std::vector<int32_t> tensor_map_index(index_size, -1);
|
||||||
|
std::vector<int32_t> tensor_map_params;
|
||||||
|
for (size_t i = 0; i < param_size; ++i) {
|
||||||
|
tensor_map_params.push_back(SizeToInt(param_size - i - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// infer output tensor map
|
||||||
|
std::vector<int32_t> tensor_map_out;
|
||||||
|
if (axis_ == 0) {
|
||||||
|
tensor_map_out.push_back(SizeToInt(param_size - 1));
|
||||||
|
tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1);
|
||||||
|
for (size_t i = 1; i < param_size; ++i) {
|
||||||
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < param_size; ++i) {
|
||||||
|
if (i == IntToSize(axis_)) {
|
||||||
|
tensor_map_out.insert(tensor_map_out.end(), index_size, -1);
|
||||||
|
} else {
|
||||||
|
tensor_map_out.push_back(SizeToInt(param_size - i - 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
|
||||||
|
inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
|
||||||
|
outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InferTensorInfo() {
|
||||||
|
// infer tensor shape
|
||||||
|
Shape input_shape = inputs_shape_.at(0);
|
||||||
|
Shape input_index_shape = inputs_shape_.at(1);
|
||||||
|
Shape output_shape = outputs_shape_.at(0);
|
||||||
|
// infer tensor layout
|
||||||
|
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
|
||||||
|
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
|
||||||
|
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
|
||||||
|
(output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
|
||||||
|
SUCCESS)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
// infer tensor info
|
||||||
|
TensorInfo input_tensor_info(input_tensor_layout);
|
||||||
|
TensorInfo input_index_info(input_index_layout);
|
||||||
|
TensorInfo output_tensor_info(output_tensor_layout);
|
||||||
|
|
||||||
|
inputs_tensor_info_.push_back(input_tensor_info);
|
||||||
|
inputs_tensor_info_.push_back(input_index_info);
|
||||||
|
outputs_tensor_info_.push_back(output_tensor_info);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InferBias() {
|
||||||
|
CheckGlobalDeviceManager();
|
||||||
|
int32_t rank = g_device_manager->global_rank();
|
||||||
|
auto input_shape = inputs_shape_.at(0);
|
||||||
|
auto params_strategy = strategy_->GetInputDim().at(0);
|
||||||
|
// params_size=1, axis=0
|
||||||
|
if ((input_shape.size() == 1) && (axis_ == 0)) {
|
||||||
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
||||||
|
bias_ = rank * slice_size_;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
// params_size=2, axis=0
|
||||||
|
if ((input_shape.size() == 2) && (axis_ == 0)) {
|
||||||
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
||||||
|
bias_ = rank / params_strategy.at(1) * slice_size_;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
// params_size=2, axis=1
|
||||||
|
if ((input_shape.size() == 2) && (axis_ == 1)) {
|
||||||
|
slice_size_ = input_shape.at(1) / params_strategy.at(1);
|
||||||
|
bias_ = rank % params_strategy.at(1) * slice_size_;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InferGroup() {
|
||||||
|
std::vector<Group> group_list;
|
||||||
|
if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Create group failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
group_ = group_list.at(0);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
|
GenerateGraph gen_g = GenerateGraph();
|
||||||
|
if (gen_g.Init(cnode) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (InferBias() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
|
||||||
|
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
|
||||||
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
|
||||||
|
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum});
|
||||||
|
auto gather_v2 =
|
||||||
|
gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)});
|
||||||
|
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
|
||||||
|
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
|
||||||
|
auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)});
|
||||||
|
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
|
||||||
|
// don't need expandim,if param_size = 1,
|
||||||
|
if (inputs_shape_.at(0).size() == 1) {
|
||||||
|
mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
|
||||||
|
}
|
||||||
|
if (InferGroup() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
||||||
|
Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
||||||
|
OperatorAttrs attrs = {attr_op, attr_group};
|
||||||
|
auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
|
||||||
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1),
|
||||||
|
std::make_pair(equal, 2)};
|
||||||
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
|
||||||
|
std::make_pair(input_nodes, reduce_scatter));
|
||||||
|
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
||||||
|
auto param_strategy = strategy->GetInputDim().at(0);
|
||||||
|
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << name_ << ": Init success.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||||
|
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||||
|
if (is_auto_parallel_) {
|
||||||
|
MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||||
|
}
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
||||||
|
// cost model set axis and strategy
|
||||||
|
auto gatherv2_2cost = std::dynamic_pointer_cast<GatherV2PCost>(operator_cost());
|
||||||
|
gatherv2_2cost->set_axis(axis_);
|
||||||
|
gatherv2_2cost->set_strategy(param_strategy);
|
||||||
|
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||||
|
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||||
|
if (is_auto_parallel_) {
|
||||||
|
MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||||
|
}
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||||
|
is_auto_parallel_ = true;
|
||||||
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||||
|
Shapes splittable_inputs = {input0_split};
|
||||||
|
|
||||||
|
std::vector<StrategyPtr> sp_vector;
|
||||||
|
if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
|
||||||
|
SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
size_t success = 0;
|
||||||
|
for (auto &sp : sp_vector) {
|
||||||
|
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||||
|
success++;
|
||||||
|
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
|
||||||
|
PrintStrategy(sp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() {
|
||||||
|
CheckGlobalDeviceManager();
|
||||||
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||||
|
Dimensions strategy;
|
||||||
|
strategy.push_back(SizeToInt(dev_num));
|
||||||
|
for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
|
||||||
|
strategy.push_back(1);
|
||||||
|
}
|
||||||
|
std::vector<Dimensions> strategy_v = {strategy};
|
||||||
|
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
|
||||||
|
}
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||||
|
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ir/value.h"
|
||||||
|
#include "parallel/auto_parallel/operator_costmodel.h"
|
||||||
|
#include "parallel/ops_info/operator_info.h"
|
||||||
|
#include "parallel/strategy.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
class GatherV2PInfo : public OperatorInfo {
|
||||||
|
public:
|
||||||
|
GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||||
|
const PrimitiveAttrs &attrs)
|
||||||
|
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {}
|
||||||
|
~GatherV2PInfo() override = default;
|
||||||
|
Status Init(const StrategyPtr &strategy) override;
|
||||||
|
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||||
|
|
||||||
|
Status GenerateStrategies(int32_t stage_id) override;
|
||||||
|
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
|
||||||
|
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||||
|
Status InferMirrorOps() override { return SUCCESS; }
|
||||||
|
Status InferForwardCommunication() override { return SUCCESS; }
|
||||||
|
Status InferTensorInfo() override;
|
||||||
|
Status InferDevMatrixShape() override;
|
||||||
|
Status InferTensorMap() override;
|
||||||
|
Status GetAttrs() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
|
Status InferBias();
|
||||||
|
Status InferGroup();
|
||||||
|
|
||||||
|
int32_t axis_;
|
||||||
|
int32_t bias_;
|
||||||
|
int32_t slice_size_;
|
||||||
|
Shape out_dev_matrix_shape_;
|
||||||
|
Group group_;
|
||||||
|
};
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_
|
|
@ -215,9 +215,9 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
OperatorAttrs attrs_onehot = {attr_onehot_axis};
|
OperatorAttrs attrs_onehot = {attr_onehot_axis};
|
||||||
auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_),
|
auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_),
|
||||||
cnode->input(3), cnode->input(4)});
|
cnode->input(3), cnode->input(4)});
|
||||||
std::vector<AnfNodePtr> input_nodes = {floor_div, sub1};
|
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)};
|
||||||
replace_graph_ =
|
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
|
||||||
std::make_shared<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>>(std::make_pair(input_nodes, onehot));
|
std::make_pair(input_nodes, onehot));
|
||||||
|
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ using TensorLayouts = std::vector<TensorLayout>;
|
||||||
using different_type = std::vector<int32_t>::difference_type;
|
using different_type = std::vector<int32_t>::difference_type;
|
||||||
using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>;
|
using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>;
|
||||||
using Strategys = std::vector<Dimensions>;
|
using Strategys = std::vector<Dimensions>;
|
||||||
using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>>;
|
using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
|
||||||
|
|
||||||
class Edge;
|
class Edge;
|
||||||
|
|
||||||
|
|
|
@ -36,5 +36,6 @@
|
||||||
#include "parallel/ops_info/reshape_info.h"
|
#include "parallel/ops_info/reshape_info.h"
|
||||||
#include "parallel/ops_info/transpose_info.h"
|
#include "parallel/ops_info/transpose_info.h"
|
||||||
#include "parallel/ops_info/virtual_dataset_info.h"
|
#include "parallel/ops_info/virtual_dataset_info.h"
|
||||||
|
#include "parallel/ops_info/gather_v2_p_info.h"
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||||
|
|
|
@ -114,7 +114,7 @@ constexpr char BE_CLONED_INDEX[] = "be_cloned_index";
|
||||||
constexpr char GROUP_RANKS[] = "group_ranks";
|
constexpr char GROUP_RANKS[] = "group_ranks";
|
||||||
constexpr char IS_IN_FORWARD[] = "is_in_forward";
|
constexpr char IS_IN_FORWARD[] = "is_in_forward";
|
||||||
constexpr char DEFAULT_INPUT[] = "default_input";
|
constexpr char DEFAULT_INPUT[] = "default_input";
|
||||||
constexpr char DTYPE[] = "dtype";
|
constexpr char DTYPE[] = "DType";
|
||||||
constexpr char DEV_NUM[] = "dev_num";
|
constexpr char DEV_NUM[] = "dev_num";
|
||||||
constexpr char MEAN_FLAG[] = "mean_flag";
|
constexpr char MEAN_FLAG[] = "mean_flag";
|
||||||
constexpr char TYPES[] = "types";
|
constexpr char TYPES[] = "types";
|
||||||
|
@ -124,6 +124,7 @@ constexpr char SHARED_NAME[] = "shared_name";
|
||||||
constexpr char MIRROR_OP[] = "mirror_op";
|
constexpr char MIRROR_OP[] = "mirror_op";
|
||||||
constexpr char FORWARD_OP[] = "forward_op";
|
constexpr char FORWARD_OP[] = "forward_op";
|
||||||
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
||||||
|
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||||
|
|
||||||
// Operator
|
// Operator
|
||||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||||
|
|
|
@ -605,8 +605,7 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
|
||||||
return (prim->name() == name);
|
return (prim->name() == name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void StepReplaceGraph(const std::shared_ptr<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>> &replace_graph,
|
void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
|
||||||
const CNodePtr &node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(replace_graph);
|
MS_EXCEPTION_IF_NULL(replace_graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(replace_graph->second);
|
MS_EXCEPTION_IF_NULL(replace_graph->second);
|
||||||
|
@ -616,20 +615,10 @@ void StepReplaceGraph(const std::shared_ptr<std::pair<std::vector<AnfNodePtr>, A
|
||||||
if (manager == nullptr) {
|
if (manager == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
|
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
|
||||||
}
|
}
|
||||||
if (!IsSomePrimitive(node, ONEHOT)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failure:Only OneHot Primitive will enter StepReplaceGraph!";
|
|
||||||
}
|
|
||||||
if (node->inputs().size() != 5) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive!";
|
|
||||||
}
|
|
||||||
auto pre_node = node->input(1);
|
|
||||||
if (replace_graph->first.size() != 2) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!";
|
|
||||||
}
|
|
||||||
for (auto &replace_input : replace_graph->first) {
|
for (auto &replace_input : replace_graph->first) {
|
||||||
MS_EXCEPTION_IF_NULL(replace_input);
|
auto pre_node = node->input(IntToSize(replace_input.second));
|
||||||
manager->SetEdge(replace_input, 1, pre_node);
|
manager->SetEdge(replace_input.first, 1, pre_node);
|
||||||
CNodePtr replace_input_cnode = replace_input->cast<CNodePtr>();
|
auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(replace_input_cnode);
|
MS_EXCEPTION_IF_NULL(replace_input_cnode);
|
||||||
(void)replace_input_cnode->set_operator_info(node->operator_info());
|
(void)replace_input_cnode->set_operator_info(node->operator_info());
|
||||||
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node
|
||||||
|
@ -943,6 +932,20 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
|
||||||
MS_LOG(EXCEPTION) << "Length of name is zero!";
|
MS_LOG(EXCEPTION) << "Length of name is zero!";
|
||||||
}
|
}
|
||||||
std::string distribute_opname = GetDisOpName(name);
|
std::string distribute_opname = GetDisOpName(name);
|
||||||
|
if (name == GATHERV2) {
|
||||||
|
distribute_opname = name + "PInfo";
|
||||||
|
auto data_parallel_iter = attrs.find(DATA_PARALLEL);
|
||||||
|
if (data_parallel_iter != attrs.end()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
|
||||||
|
if (!data_parallel_iter->second->isa<BoolImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool.";
|
||||||
|
}
|
||||||
|
bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
|
||||||
|
if (data_parallel) {
|
||||||
|
distribute_opname = name + "Info";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
OperatorInfoPtr operator_ =
|
OperatorInfoPtr operator_ =
|
||||||
(OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
(OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
||||||
if (operator_ == nullptr) {
|
if (operator_ == nullptr) {
|
||||||
|
|
|
@ -0,0 +1,173 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.common.api import _executor
|
||||||
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
|
||||||
|
|
||||||
|
class NetWithLoss(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(NetWithLoss, self).__init__()
|
||||||
|
self.loss = VirtualLoss()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
predict = self.network(x, y)
|
||||||
|
return self.loss(predict)
|
||||||
|
|
||||||
|
|
||||||
|
class GradWrap(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(GradWrap, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
return C.grad_all(self.network)(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, axis=0, strategy1=None, strategy2=None, shape=[64, 64]):
|
||||||
|
super().__init__()
|
||||||
|
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
||||||
|
self.mul = P.Mul().set_strategy(strategy2)
|
||||||
|
self.index = Tensor(np.ones(shape), dtype=ms.int32)
|
||||||
|
self.axis = axis
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.gatherv2(x, self.index, self.axis)
|
||||||
|
out = self.mul(out, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto0():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((1, 8), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto1():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((8, 1), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto2():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((2, 4), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto3():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((1, 8), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto4():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((8, 1), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto5():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((2, 4), )
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto6():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto7():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy2 = ((4, 2, 1), (4, 2, 1))
|
||||||
|
net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_semi_auto8():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
|
||||||
|
strategy1 = ((8, ), )
|
||||||
|
strategy2 = ((4, 2), (4, 2))
|
||||||
|
net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
|
||||||
|
x = Tensor(np.ones([64]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_auto0():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||||
|
net = GradWrap(NetWithLoss(Net(0)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
||||||
|
def test_gatherv2_auto1():
|
||||||
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
|
||||||
|
net = GradWrap(NetWithLoss(Net(1)))
|
||||||
|
net.set_auto_parallel()
|
||||||
|
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
|
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
|
||||||
|
_executor.compile(net, x, y)
|
||||||
|
|
|
@ -74,7 +74,7 @@ class GatherV2(_Loss):
|
||||||
emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16))
|
emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16))
|
||||||
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
|
self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
|
||||||
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
|
self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
|
||||||
self.gatherv2 = P.GatherV2().set_strategy(strategy)
|
self.gatherv2 = P.GatherV2().set_strategy(strategy).add_prim_attr("data_parallel", True)
|
||||||
|
|
||||||
def construct(self, nembeddings):
|
def construct(self, nembeddings):
|
||||||
emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
|
emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
|
||||||
|
|
Loading…
Reference in New Issue