forked from mindspore-Ecosystem/mindspore
!4068 Add parallel operator for Concat
Merge pull request !4068 from yangzhenzhang/add-concat-op
This commit is contained in:
commit
fe0b8d6272
|
@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
|
|||
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
|
||||
using TileCost = SoftmaxCost;
|
||||
using TileCostPtr = std::shared_ptr<TileCost>;
|
||||
using ConcatCost = TileCost;
|
||||
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
|
||||
|
||||
class TmpIdentityCost : public OperatorCost {
|
||||
public:
|
||||
|
|
|
@ -136,6 +136,7 @@ REGISTER(EmbeddingLookupInfo);
|
|||
REGISTER(TileInfo);
|
||||
REGISTER(StridedSliceInfo);
|
||||
REGISTER(DropoutInfo);
|
||||
REGISTER(ConcatInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
|
||||
MAKE_TUPLE,
|
||||
J,
|
||||
LIST_GETITEM,
|
||||
ARRAY_GETITEM,
|
||||
|
|
|
@ -0,0 +1,268 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/concat_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status ConcatInfo::GetAttrs() {
|
||||
int axis = 0;
|
||||
auto axis_iter = attrs_.find(AXIS);
|
||||
if (axis_iter != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(axis_iter->second);
|
||||
if (axis_iter->second->isa<Int32Imm>()) {
|
||||
axis = axis_iter->second->cast<Int32ImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": The value of axis is not int";
|
||||
return FAILED;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
int dim = SizeToInt(inputs_shape_[0].size());
|
||||
|
||||
if (axis < 0) {
|
||||
axis = axis + dim;
|
||||
}
|
||||
|
||||
axis_ = SizeToInt(axis);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid strategy";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<Dimensions> stra = strategy->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (stra.size() != inputs_shape_.size()) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be equal to the size of inputs shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < stra.size(); ++i) {
|
||||
auto strategy_ele = stra[i];
|
||||
auto input_shape_ele = inputs_shape_[i];
|
||||
if (strategy_ele.size() != input_shape_ele.size()) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy element must be equal to the size of input shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (axis_ >= strategy_ele.size()) {
|
||||
MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (strategy_ele[axis_] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The axis can not be split";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < strategy_ele.size(); ++j) {
|
||||
if (strategy_ele[j] != stra[0][j]) {
|
||||
MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::InferDevMatrixShape() {
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
std::vector<Dimensions> stra = strategy_->GetInputDim();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The strategy is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
dev_matrix_shape_ = stra[0];
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::InferTensorMap() {
|
||||
TensorMap tensor_map;
|
||||
if (inputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << "The inputs shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
|
||||
int32_t size = SizeToInt(inputs_shape_[0].size());
|
||||
for (int i = 0; i < size; ++i) {
|
||||
tensor_map.push_back(size - i - 1);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
inputs_tensor_map_.push_back(tensor_map);
|
||||
}
|
||||
outputs_tensor_map_.push_back(tensor_map);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::InferMirrorOps() {
|
||||
mirror_ops_.clear();
|
||||
if (inputs_tensor_map_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Shape input_tensor_map = inputs_tensor_map_[0];
|
||||
std::vector<Group> group;
|
||||
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (group.empty()) {
|
||||
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
OperatorVector input_op;
|
||||
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
mirror_ops_.push_back(input_op);
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::InferTensorInfo() {
|
||||
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": Invalid args";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
TensorLayout input_layout, output_layout;
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
// infer tensor layout
|
||||
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
|
||||
return FAILED;
|
||||
}
|
||||
TensorInfo input_tensor_info(input_layout);
|
||||
inputs_tensor_info_.push_back(input_tensor_info);
|
||||
}
|
||||
|
||||
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
|
||||
return FAILED;
|
||||
}
|
||||
TensorInfo output_tensor_info(output_layout);
|
||||
outputs_tensor_info_.push_back(output_tensor_info);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ConcatInfo::ReComputeBatchSplitFlagList() {
|
||||
for (size_t i = 0; i < inputs_shape_.size(); i++) {
|
||||
split_flag_list_[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
||||
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (InferAttrs() != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (inputs_shape_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
|
||||
return FAILED;
|
||||
}
|
||||
Shape input_split;
|
||||
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||
if (i == axis_) {
|
||||
input_split.push_back(0);
|
||||
} else {
|
||||
input_split.push_back(1);
|
||||
}
|
||||
}
|
||||
Shapes splittable_inputs;
|
||||
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
|
||||
splittable_inputs.push_back(input_split);
|
||||
}
|
||||
|
||||
std::vector<StrategyPtr> sp_vector;
|
||||
is_auto_parallel_ = true;
|
||||
if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
size_t success = 0;
|
||||
for (auto &sp : sp_vector) {
|
||||
PrintStrategy(sp);
|
||||
if (SetCostUnderStrategy(sp) == SUCCESS) {
|
||||
success++;
|
||||
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
|
||||
PrintStrategy(sp);
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::Init(const StrategyPtr &strategy) {
|
||||
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init failed.";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) {
|
||||
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": Init for cost model success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class ConcatInfo : public OperatorInfo {
|
||||
public:
|
||||
ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
|
||||
~ConcatInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
Status GenerateStrategies(int32_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
|
||||
private:
|
||||
size_t axis_ = 0;
|
||||
};
|
||||
|
||||
using ConcatInfoPtr = std::shared_ptr<ConcatInfo>;
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
|
@ -39,5 +39,6 @@
|
|||
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
|
||||
#include "frontend/parallel/ops_info/tile_info.h"
|
||||
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
||||
#include "frontend/parallel/ops_info/concat_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -118,6 +118,9 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
|||
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
||||
std::vector<bool> is_parameter;
|
||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||
}
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
auto input = node_inputs[i];
|
||||
|
||||
|
@ -192,6 +195,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|||
std::vector<size_t> inputs_type_len;
|
||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||
|
||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||
}
|
||||
|
||||
// extract input element length
|
||||
for (auto &input : node_inputs) {
|
||||
if (IsValueNode<RefKey>(input)) {
|
||||
|
@ -255,7 +262,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
|
||||
// clang-format on
|
||||
|
@ -275,7 +282,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
|
||||
if (bool_result) {
|
||||
if (bool_result && (prim->name() != MAKE_TUPLE)) {
|
||||
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
||||
} else if (prim->name() == CAST) {
|
||||
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
|
||||
|
|
|
@ -267,6 +267,33 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
|
|||
return tensorinfo_in.tensor_layout();
|
||||
}
|
||||
|
||||
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == prim_name) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string GetPrimName(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsValueNode<Primitive>(node->input(0))) {
|
||||
MS_LOG(EXCEPTION) << "The node is not a primitive";
|
||||
}
|
||||
auto value_node = node->input(0)->cast<ValueNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->name();
|
||||
}
|
||||
|
||||
OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsParallelCareNode(node)) {
|
||||
|
@ -274,7 +301,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
|||
}
|
||||
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
|
||||
MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
|
||||
}
|
||||
return distribute_operator;
|
||||
}
|
||||
|
@ -423,6 +450,11 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodeIndexSet node_set = manager->node_users()[node];
|
||||
CNodePtr insert_node_new;
|
||||
|
||||
if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
||||
return;
|
||||
}
|
||||
if (IsValueNode<Primitive>(node->input(0))) {
|
||||
auto current_value = node->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(current_value);
|
||||
|
@ -875,9 +907,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
||||
return;
|
||||
}
|
||||
|
||||
if (mirror_ops.size() != node_size - 1) {
|
||||
MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size()
|
||||
<< ", node_size is " << node_size;
|
||||
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
|
||||
<< node_size - 1;
|
||||
}
|
||||
for (size_t index = 1; index < node_size; ++index) {
|
||||
OperatorVector backward_op = mirror_ops[index - 1];
|
||||
|
@ -993,7 +1031,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
|
|||
const std::vector<Shapes> &shape_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
|
||||
if (operator_ == nullptr) {
|
||||
if ((operator_ == nullptr) && (prim->name() != MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
|
||||
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
|
||||
MS_EXCEPTION_IF_NULL(operator_);
|
||||
|
@ -1177,7 +1215,12 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
|
|||
continue;
|
||||
}
|
||||
if (input_shapes.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed";
|
||||
if (inputs_size == 2) { // like concat
|
||||
shape_inputs = input_shapes;
|
||||
break;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
|
||||
}
|
||||
}
|
||||
shape_inputs.push_back(input_shapes[0]);
|
||||
}
|
||||
|
@ -1269,8 +1312,8 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|||
}
|
||||
TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
|
||||
Shape slice_shape = tensorinfo_in.slice_shape();
|
||||
MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
|
||||
<< MakeValue(slice_shape)->ToString();
|
||||
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
|
||||
<< MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
// Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
|
||||
|
@ -1450,6 +1493,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
SetVirtualDatasetStrategy(cnode);
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
if (prim->name() == MAKE_TUPLE) {
|
||||
continue;
|
||||
}
|
||||
auto attrs = prim->attrs();
|
||||
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
|
||||
if (IsParallelCareNode(cnode)) {
|
||||
|
@ -2045,13 +2091,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
// the make_tuple is parallel care node, but it may have not operator info
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
if (distribute_operator == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
|
||||
// insert forward ops
|
||||
InsertForwardOps(distribute_operator, cnode);
|
||||
|
@ -2074,13 +2120,12 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
if (distribute_operator == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
// StepReplace
|
||||
StepReplace(distribute_operator, cnode);
|
||||
}
|
||||
|
@ -2330,6 +2375,44 @@ Status ParallelInit() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!cnode->in_forward_flag()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto make_tuple_user = manager->node_users()[cnode];
|
||||
if (make_tuple_user.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
|
||||
}
|
||||
CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
|
||||
|
||||
std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode);
|
||||
if (!IsParallelCareNode(make_tuple_next_cnode)) {
|
||||
MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info";
|
||||
continue;
|
||||
}
|
||||
if (make_tuple_next_cnode->inputs().size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got "
|
||||
<< make_tuple_next_cnode->inputs().size() - 1;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name;
|
||||
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
cnode->set_user_data<OperatorInfo>(op_info);
|
||||
}
|
||||
}
|
||||
|
||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
|
@ -2383,6 +2466,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
ExtractInformation(all_nodes);
|
||||
ReshapeInit(all_nodes);
|
||||
}
|
||||
|
||||
HandleForwardMakeTuple(all_nodes);
|
||||
|
||||
// save strategy as checkpoint for multi-train
|
||||
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
||||
CheckpointStrategy(root);
|
||||
|
|
|
@ -149,6 +149,8 @@ Status ParallelInit();
|
|||
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
|
||||
|
||||
std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
|
||||
|
||||
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -222,9 +222,17 @@ def get_bprop_virtual_div_operator(self):
|
|||
dtype = P.DType()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
if F.issubclass_(F.dtype(dout), mstype.bool_):
|
||||
return (dout,)
|
||||
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
|
||||
if F.issubclass_(F.typeof(dout), mstype.tensor):
|
||||
if F.issubclass_(F.dtype(dout), mstype.bool_):
|
||||
return (dout,)
|
||||
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
|
||||
return (dx,)
|
||||
|
||||
dx = ()
|
||||
input_nums = F.tuple_len(dout)
|
||||
for i in range(input_nums):
|
||||
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
|
||||
dx = dx + (ele_grad,)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
|
||||
super().__init__()
|
||||
self.concat = P.Concat(axis=0).set_strategy(strategy1)
|
||||
if is_parameter:
|
||||
self.weight = Parameter(weight, "w1")
|
||||
else:
|
||||
self.weight = weight
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
self.weight2 = Parameter(weight2, "w2")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.concat((self.weight, self.weight2))
|
||||
out = self.mul(x, out)
|
||||
return out
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().set_strategy(strategy1)
|
||||
self.concat = P.Concat(axis=axis).set_strategy(strategy2)
|
||||
self.weight = Parameter(weight, "w")
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.mul(x, b)
|
||||
out = self.concat((out, self.weight))
|
||||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
|
||||
_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
context.set_context(save_graphs=True)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_concat_parameter():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 2), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_parameter_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 2, 2), (1, 2, 2))
|
||||
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_tensor_and_parameter():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 2, 2), (1, 2, 2))
|
||||
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_output():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 4, 2), (1, 4, 2))
|
||||
net = Net2(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_output_no_full_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 2, 2), (1, 2, 2))
|
||||
net = Net2(_w1, strategy1, strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_no_strategy():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = None
|
||||
net = Net2(_w3, strategy1, strategy2, axis=1)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_auto_parallel():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_concat_auto_parallel2():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = None
|
||||
strategy2 = None
|
||||
net = Net2(_w3, strategy1, strategy2, axis=1)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue