add_standalone_info

lichen 2023-01-18 09:34:52 +08:00
parent d222e5a707
commit 0a054e8ca0
8 changed files with 342 additions and 18 deletions

View File

@@ -108,6 +108,7 @@ class OperatorInfo {
  virtual void ReComputeBatchSplitFlagList();
  std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
  void ComputeBatchSplitFlagList();
  Shapes inputs_shape() const { return inputs_shape_; }
  double GetForwardMemoryCostFromCNode();
  // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy

View File

@@ -151,6 +151,7 @@ constexpr char DIVISOR[] = "divisor";
constexpr char NONE[] = "None";
constexpr char DEPEND[] = "Depend";
constexpr char BATCH_PARALLEL[] = "BatchParallel";
constexpr char STAND_ALONE[] = "StandAlone";
constexpr char ACTIVATION_TYPE[] = "activation_type";
constexpr char TARGET[] = "primitive_target";

View File

@@ -0,0 +1,110 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/stand_alone_info.h"
#include <memory>
#include <utility>
#include "ir/value.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
Status StandAloneInfo::CheckStrategy(const StrategyPtr &strategy) { return SUCCESS; }

Status StandAloneInfo::InferDevMatrixShape() {
  dev_matrix_shape_.push_back(stage_device_size_);
  return SUCCESS;
}

Status StandAloneInfo::InferForwardCommunication() { return SUCCESS; }

Status StandAloneInfo::GetAttrs() { return SUCCESS; }

Status StandAloneInfo::InferTensorMap() {
  // input tensor map, all -1
  for (size_t i = 0; i < inputs_shape_.size(); i++) {
    Shape tensor_map_index;
    for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
      tensor_map_index.push_back(MAP_NONE);
    }
    inputs_tensor_map_.push_back(tensor_map_index);
  }

  // output tensor map, all -1
  for (size_t i = 0; i < outputs_shape_.size(); i++) {
    Shape tensor_map_index;
    for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
      tensor_map_index.push_back(MAP_NONE);
    }
    outputs_tensor_map_.push_back(tensor_map_index);
  }
  return SUCCESS;
}

Status StandAloneInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  // infer input TensorInfo
  size_t temp = 0;
  for (size_t i = 0; i < input_value_.size(); ++i) {
    if (!input_value_[i]) {
      TensorLayout input_layout;
      if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[temp], inputs_shape_[temp]) != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
        return FAILED;
      }
      temp += 1;
      TensorInfo input_tensor_info(input_layout);
      inputs_tensor_info_.push_back(input_tensor_info);
    } else {
      TensorInfo empty_tensor_info;
      inputs_tensor_info_.push_back(empty_tensor_info);
    }
  }

  // infer output TensorInfo
  for (size_t j = 0; j < outputs_shape_.size(); ++j) {
    TensorLayout output_layout;
    if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[j], outputs_shape_[j]) != SUCCESS) {
      return FAILED;
    }
    TensorInfo out_tensor_info(output_layout);
    outputs_tensor_info_.push_back(out_tensor_info);
  }
  return SUCCESS;
}

Status StandAloneInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SUCCESS; }

std::vector<StrategyPtr> StandAloneInfo::GenerateOpStrategies(int64_t stage_id) {
  std::vector<StrategyPtr> sp_vector;
  return sp_vector;
}

Status StandAloneInfo::InferAsLossDivisor() {
  as_loss_divisor_ = stage_device_size_;
  return SUCCESS;
}

REGISTER(StandAloneInfo);
} // namespace parallel
} // namespace mindspore
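For orientation, a minimal Python sketch (not MindSpore API; the shapes and device count are made-up example values) of the layout StandAloneInfo derives: the device matrix is a single axis of size stage_device_size_, and every dimension of every tensor maps to -1, i.e. nothing is sharded.

MAP_NONE = -1

def stand_alone_layout(inputs_shape, outputs_shape, stage_device_size):
    # InferDevMatrixShape: one axis covering all devices in the stage
    dev_matrix = [stage_device_size]
    # InferTensorMap: every dimension of every tensor maps to -1 (unsplit)
    in_maps = [[MAP_NONE] * len(s) for s in inputs_shape]
    out_maps = [[MAP_NONE] * len(s) for s in outputs_shape]
    return dev_matrix, in_maps, out_maps

# Hypothetical MatMul-like op on an 8-device stage: nothing is partitioned.
print(stand_alone_layout([[1, 32], [32, 128]], [[1, 128]], 8))
# -> ([8], [[-1, -1], [-1, -1]], [[-1, -1]])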

View File

@@ -0,0 +1,50 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_
#include <memory>
#include <string>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class StandAloneInfo : public OperatorInfo {
 public:
  StandAloneInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, nullptr) {}
  ~StandAloneInfo() override = default;

  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;

 protected:
  Status InferTensorInfo() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  Status InferAsLossDivisor() override;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_

View File

@@ -1550,18 +1550,23 @@ static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &pr
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
in_strategy = GenerateBatchParallelStrategy(op_info, prim);
} else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
} else if (StrategyFound(attrs)) {
in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
if (!prim->HasAttr(STAND_ALONE)) {
if (((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) ||
prim->HasAttr(BATCH_PARALLEL)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
in_strategy = GenerateBatchParallelStrategy(op_info, prim);
} else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
} else if (StrategyFound(attrs)) {
in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
} else {
in_strategy = stra_map[strategy_key_name];
}
} else {
in_strategy = stra_map[strategy_key_name];
in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
}
MS_EXCEPTION_IF_NULL(in_strategy);

View File

@@ -859,14 +859,28 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
const std::vector<Shapes> &shape_list) {
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
if (operator_ == nullptr) {
if (IsInBatchParallelBlackList(prim)) {
MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
}
MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
MS_EXCEPTION_IF_NULL(operator_);
if (operator_) {
return operator_;
}
if (IsInBatchParallelBlackList(prim)) {
operator_ = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
MS_LOG(INFO) << "Operator " << prim->name() << " is not supported yet in auto parallel mode. Use Stand Alone";
return operator_;
}
auto input_shape = shape_list.at(0);
MS_EXCEPTION_IF_NULL(g_device_manager);
auto device_num = g_device_manager->stage_device_num();
MS_EXCEPTION_IF_ZERO("device_num", device_num);
if (input_shape[0].empty() || input_shape[0][0] % device_num != 0) {
MS_LOG(INFO) << "Operator " << prim->name() << " use Stand Alone";
operator_ = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
return operator_;
}
MS_LOG(INFO) << "Operator " << prim->name() << " use Batch Parallel";
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
prim->AddAttr(BATCH_PARALLEL, MakeValue<bool>(true));
return operator_;
}
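Taken together, the new OperatorInstance flow reads as the following decision order: a dedicated OperatorInfo if one can be created; StandAlone for blacklisted primitives; StandAlone when the first input has no batch dimension or its batch dimension is not divisible by the stage device count; BatchParallel otherwise. A rough Python sketch (assumed helper names, not the C++ code above):

def choose_fallback(has_dedicated_info, is_blacklisted, first_input_shape, device_num):
    # Mirrors the branch order above; names are illustrative only.
    if has_dedicated_info:
        return "op-specific"
    if is_blacklisted:
        return "StandAlone"
    if not first_input_shape or first_input_shape[0] % device_num != 0:
        return "StandAlone"
    return "BatchParallel"

print(choose_fallback(False, False, [1, 32], 8))    # StandAlone: 1 % 8 != 0
print(choose_fallback(False, False, [32, 128], 8))  # BatchParallel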
@@ -1426,6 +1440,23 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
  return strategyPtr;
}

StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape) {
  Strategies strategy_v;
  for (size_t i = 0; i != inputs_shape.size(); i++) {
    if (inputs_shape[i].empty()) {
      MS_LOG(INFO) << "Elements of shapes is empty.";
      Dimensions empty_element;
      strategy_v.push_back(empty_element);
    } else {
      Dimensions element(inputs_shape[i].size(), 1);
      strategy_v.push_back(element);
    }
  }
  auto stage_id = g_device_manager->stage_id();
  auto stra_ptr = NewStrategy(stage_id, strategy_v);
  return stra_ptr;
}

bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  auto comm_info = GetCommInfo();
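GenerateStandAloneStrategy above builds an all-ones strategy, one entry per input with that input's rank; empty shapes get an empty entry. A rough Python analogue (illustrative only, not the actual helper):

def generate_stand_alone_strategy(inputs_shape):
    # One strategy tuple per input, all dimensions kept whole (factor 1).
    return tuple(tuple(1 for _ in shape) for shape in inputs_shape)

print(generate_stand_alone_strategy([[1, 32], [32, 128], []]))
# -> ((1, 1), (1, 1), ())   empty/scalar shapes produce an empty entry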

View File

@@ -103,6 +103,7 @@ Status ParallelInit();
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
size_t max_depth);
void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter);
StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsInsertVirtualOutput(const FuncGraphPtr &root);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);

View File

@@ -0,0 +1,125 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register, PrimitiveWithInfer


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class VirtualNodeGrad(PrimitiveWithInfer):
    """ VirtualLossGrad definition """
    @prim_attr_register
    def __init__(self):
        """init VirtualLossGrad"""

    def __call__(self, x, out, dout):
        raise NotImplementedError

    def infer_shape(self, x_shape, out_shape, dout_shape):
        return x_shape

    def infer_dtype(self, x_dtype, out_dtype, dout_dtype):
        return x_dtype


class VirtualNode(PrimitiveWithInfer):
    """ VirtualLoss definition """
    @prim_attr_register
    def __init__(self):
        """init VirtualLoss"""

    def __call__(self, x):
        raise NotImplementedError

    def get_bprop(self):
        loss_grad = VirtualNodeGrad()

        def bprop(x, out, dout):
            dx = loss_grad(x, out, dout)
            return (dx,)
        return bprop

    def infer_shape(self, x_shape):
        return [1]

    def infer_dtype(self, x_dtype):
        return x_dtype


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualNode()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def test_two_matmul():
    '''
    Feature: test StandAloneInfo
    Description: In SemiAuto mode, if there is no strategy and can't use BatchParallelInfo, use StandAloneInfo
    Expectation: success
    '''
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    strategy1 = ((1, 2), (2, 2))
    strategy2 = ((1, 2), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([1, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    b = Tensor(np.ones([128, 128]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)
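In this test the two MatMul ops carry explicit shard strategies, while the custom VirtualNode loss has neither a strategy nor a dedicated OperatorInfo, and its input batch dimension is 1, which is not divisible by device_num=8; by the OperatorInstance change above it should therefore, presumably, be instantiated as StandAloneInfo rather than BatchParallel, which is the path this compilation exercises.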