forked from mindspore-Ecosystem/mindspore

add_standalone_info

parent d222e5a707
commit 0a054e8ca0
@@ -108,6 +108,7 @@ class OperatorInfo {
  virtual void ReComputeBatchSplitFlagList();
  std::shared_ptr<Strategies> GenerateBatchStrategiesWithCheck();
  void ComputeBatchSplitFlagList();
  Shapes inputs_shape() const { return inputs_shape_; }

  double GetForwardMemoryCostFromCNode();
  // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy
@@ -151,6 +151,7 @@ constexpr char DIVISOR[] = "divisor";
constexpr char NONE[] = "None";
constexpr char DEPEND[] = "Depend";
constexpr char BATCH_PARALLEL[] = "BatchParallel";
constexpr char STAND_ALONE[] = "StandAlone";

constexpr char ACTIVATION_TYPE[] = "activation_type";
constexpr char TARGET[] = "primitive_target";
@@ -0,0 +1,110 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/stand_alone_info.h"

#include <memory>
#include <utility>

#include "ir/value.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace parallel {
Status StandAloneInfo::CheckStrategy(const StrategyPtr &strategy) { return SUCCESS; }

Status StandAloneInfo::InferDevMatrixShape() {
  dev_matrix_shape_.push_back(stage_device_size_);
  return SUCCESS;
}

Status StandAloneInfo::InferForwardCommunication() { return SUCCESS; }

Status StandAloneInfo::GetAttrs() { return SUCCESS; }

Status StandAloneInfo::InferTensorMap() {
  // input tensor map, all -1
  for (size_t i = 0; i < inputs_shape_.size(); i++) {
    Shape tensor_map_index;
    for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
      tensor_map_index.push_back(MAP_NONE);
    }
    inputs_tensor_map_.push_back(tensor_map_index);
  }
  // output tensor map, all -1
  for (size_t i = 0; i < outputs_shape_.size(); i++) {
    Shape tensor_map_index;
    for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
      tensor_map_index.push_back(MAP_NONE);
    }
    outputs_tensor_map_.push_back(tensor_map_index);
  }
  return SUCCESS;
}

Status StandAloneInfo::InferTensorInfo() {
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }
  // infer input TensorInfo
  size_t temp = 0;
  for (size_t i = 0; i < input_value_.size(); ++i) {
    if (!input_value_[i]) {
      TensorLayout input_layout;
      if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[temp], inputs_shape_[temp]) != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
        return FAILED;
      }
      temp += 1;
      TensorInfo input_tensor_info(input_layout);
      inputs_tensor_info_.push_back(input_tensor_info);
    } else {
      TensorInfo empty_tensor_info;
      inputs_tensor_info_.push_back(empty_tensor_info);
    }
  }
  // infer output TensorInfo
  for (size_t j = 0; j < outputs_shape_.size(); ++j) {
    TensorLayout output_layout;
    if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[j], outputs_shape_[j]) != SUCCESS) {
      return FAILED;
    }
    TensorInfo out_tensor_info(output_layout);
    outputs_tensor_info_.push_back(out_tensor_info);
  }

  return SUCCESS;
}

Status StandAloneInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SUCCESS; }

std::vector<StrategyPtr> StandAloneInfo::GenerateOpStrategies(int64_t stage_id) {
  std::vector<StrategyPtr> sp_vector;
  return sp_vector;
}

Status StandAloneInfo::InferAsLossDivisor() {
  as_loss_divisor_ = stage_device_size_;
  return SUCCESS;
}

REGISTER(StandAloneInfo);
}  // namespace parallel
}  // namespace mindspore
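StandAloneInfo maps every tensor dimension to MAP_NONE (-1), i.e. to no dimension of the device matrix, so every device keeps each tensor whole and unsplit. A minimal standalone sketch of that tensor-map construction, using plain std::vector stand-ins for the Shape/Shapes aliases and assuming MAP_NONE is -1 (the real types and constant live in the ops_info headers):

// stand_alone_tensor_map_sketch.cc -- hypothetical illustration, not part of the commit.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;   // assumed stand-in for parallel::Shape
using Shapes = std::vector<Shape>;
constexpr int64_t MAP_NONE = -1;      // assumed value of parallel::MAP_NONE

// Mirrors the idea of StandAloneInfo::InferTensorMap: every dimension maps to -1 (not split).
Shapes BuildStandAloneTensorMaps(const Shapes &shapes) {
  Shapes tensor_maps;
  for (const auto &shape : shapes) {
    tensor_maps.emplace_back(shape.size(), MAP_NONE);
  }
  return tensor_maps;
}

int main() {
  Shapes inputs = {{1, 32}, {32, 128}};
  for (const auto &map : BuildStandAloneTensorMaps(inputs)) {
    for (auto d : map) std::cout << d << ' ';
    std::cout << '\n';  // prints "-1 -1" for each input
  }
  return 0;
}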
@@ -0,0 +1,50 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_

#include <memory>
#include <string>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
class StandAloneInfo : public OperatorInfo {
 public:
  StandAloneInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, nullptr) {}

  ~StandAloneInfo() override = default;
  std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;

 protected:
  Status InferTensorInfo() override;
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferForwardCommunication() override;
  Status InferDevMatrixShape() override;
  Status InferTensorMap() override;
  Status GetAttrs() override;
  Status InferAsLossDivisor() override;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_STAND_ALONE_INFO_H_
@@ -1550,18 +1550,23 @@ static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &pr
  }
  bool load_strategy_from_ckpt =
    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
  if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) {
    MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
                 << " is empty, using batch parallel";
    in_strategy = GenerateBatchParallelStrategy(op_info, prim);
  } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
    in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
    out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
  } else if (StrategyFound(attrs)) {
    in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
    out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
  if (!prim->HasAttr(STAND_ALONE)) {
    if (((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) ||
        prim->HasAttr(BATCH_PARALLEL)) {
      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
                   << " is empty, using batch parallel";
      in_strategy = GenerateBatchParallelStrategy(op_info, prim);
    } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
      in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
      out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
    } else if (StrategyFound(attrs)) {
      in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
      out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
    } else {
      in_strategy = stra_map[strategy_key_name];
    }
  } else {
    in_strategy = stra_map[strategy_key_name];
    in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
  }

  MS_EXCEPTION_IF_NULL(in_strategy);
@@ -859,14 +859,28 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
                                 const std::vector<Shapes> &shape_list) {
  MS_EXCEPTION_IF_NULL(prim);
  OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
  if (operator_ == nullptr) {
    if (IsInBatchParallelBlackList(prim)) {
      MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
    }
    MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
    operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
    MS_EXCEPTION_IF_NULL(operator_);
  if (operator_) {
    return operator_;
  }
  if (IsInBatchParallelBlackList(prim)) {
    operator_ = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
    prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
    MS_LOG(INFO) << "Operator " << prim->name() << " is not supported yet in auto parallel mode. Use Stand Alone";
    return operator_;
  }
  auto input_shape = shape_list.at(0);
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto device_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("device_num", device_num);
  if (input_shape[0].empty() || input_shape[0][0] % device_num != 0) {
    MS_LOG(INFO) << "Operator " << prim->name() << " use Stand Alone";
    operator_ = OperatorInstanceByName(STAND_ALONE, attrs, shape_list);
    prim->AddAttr(STAND_ALONE, MakeValue<bool>(true));
    return operator_;
  }
  MS_LOG(INFO) << "Operator " << prim->name() << " use Batch Parallel";
  operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
  prim->AddAttr(BATCH_PARALLEL, MakeValue<bool>(true));
  return operator_;
}
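The new tail of OperatorInstance can be read as a fallback ladder: use the operator's own OperatorInfo if it exists; blacklisted operators go straight to StandAlone; otherwise StandAlone is chosen when the first input has no batch dimension or its batch dimension does not divide evenly across the stage's devices, and BatchParallel only when it does. A small self-contained sketch of that last decision, with plain types and a hypothetical function name (the real code works on OperatorInfoPtr and PrimitivePtr):

// fallback_choice_sketch.cc -- hypothetical illustration of the fallback order.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

// Assumes the dedicated OperatorInfo could not be created and the op is not blacklisted.
std::string ChooseFallback(const Shape &first_input_shape, int64_t stage_device_num) {
  if (first_input_shape.empty() || first_input_shape[0] % stage_device_num != 0) {
    return "StandAlone";    // batch dimension missing or not evenly divisible
  }
  return "BatchParallel";   // batch dimension divides evenly across devices
}

int main() {
  std::cout << ChooseFallback({1, 32}, 8) << '\n';   // StandAlone (1 % 8 != 0)
  std::cout << ChooseFallback({64, 32}, 8) << '\n';  // BatchParallel (64 % 8 == 0)
  std::cout << ChooseFallback({}, 8) << '\n';        // StandAlone (scalar input)
  return 0;
}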
@@ -1426,6 +1440,23 @@ StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const
  return strategyPtr;
}

StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape) {
  Strategies strategy_v;
  for (size_t i = 0; i != inputs_shape.size(); i++) {
    if (inputs_shape[i].empty()) {
      MS_LOG(INFO) << "Elements of shapes is empty.";
      Dimensions empty_element;
      strategy_v.push_back(empty_element);
    } else {
      Dimensions element(inputs_shape[i].size(), 1);
      strategy_v.push_back(element);
    }
  }
  auto stage_id = g_device_manager->stage_id();
  auto stra_ptr = NewStrategy(stage_id, strategy_v);
  return stra_ptr;
}

bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  auto comm_info = GetCommInfo();
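GenerateStandAloneStrategy simply emits an all-ones strategy: one slice per dimension for every input, and an empty tuple for scalar inputs. A standalone sketch of the same loop, with the Strategies/Dimensions aliases replaced by plain std::vector stand-ins for illustration:

// stand_alone_strategy_sketch.cc -- hypothetical illustration of the all-ones strategy.
#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;  // assumed stand-ins for the parallel aliases
using Shape = std::vector<int64_t>;
using Strategies = std::vector<Dimensions>;
using Shapes = std::vector<Shape>;

Strategies BuildStandAloneStrategy(const Shapes &inputs_shape) {
  Strategies strategy;
  for (const auto &shape : inputs_shape) {
    // A scalar input yields an empty element; every other input yields all ones (no splitting).
    strategy.emplace_back(shape.size(), 1);
  }
  return strategy;
}

int main() {
  Shapes inputs = {{1, 32}, {32, 128}, {}};
  for (const auto &dims : BuildStandAloneStrategy(inputs)) {
    std::cout << '(';
    for (auto d : dims) std::cout << d << ',';
    std::cout << ")\n";  // prints (1,1,) (1,1,) ()
  }
  return 0;
}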
@@ -103,6 +103,7 @@ Status ParallelInit();
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth);
void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr &parameter);
StrategyPtr GenerateStandAloneStrategy(const Shapes &inputs_shape);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsInsertVirtualOutput(const FuncGraphPtr &root);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
@@ -0,0 +1,125 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import prim_attr_register, PrimitiveWithInfer


def setup_function():
    context.set_auto_parallel_context(dataset_strategy="full_batch")


grad_all = C.GradOperation(get_all=True)


class VirtualNodeGrad(PrimitiveWithInfer):
    """ VirtualLossGrad definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLossGrad"""

    def __call__(self, x, out, dout):
        raise NotImplementedError

    def infer_shape(self, x_shape, out_shape, dout_shape):
        return x_shape

    def infer_dtype(self, x_dtype, out_dtype, dout_dtype):
        return x_dtype


class VirtualNode(PrimitiveWithInfer):
    """ VirtualLoss definition """

    @prim_attr_register
    def __init__(self):
        """init VirtualLoss"""

    def __call__(self, x):
        raise NotImplementedError

    def get_bprop(self):
        loss_grad = VirtualNodeGrad()

        def bprop(x, out, dout):
            dx = loss_grad(x, out, dout)
            return (dx,)

        return bprop

    def infer_shape(self, x_shape):
        return [1]

    def infer_dtype(self, x_dtype):
        return x_dtype


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualNode()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def test_two_matmul():
    '''
    Feature: test StandAloneInfo
    Description: in semi_auto_parallel mode, if an operator has no strategy and BatchParallelInfo cannot be used, fall back to StandAloneInfo
    Expectation: success
    '''
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True)
    strategy1 = ((1, 2), (2, 2))
    strategy2 = ((1, 2), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([1, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    b = Tensor(np.ones([128, 128]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)