forked from mindspore-Ecosystem/mindspore

overwrite strategies for star graph structure

parent 639e7ca47e
commit fba2bfeb54
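This change renames the cost-model flag TRIANGLE_STRATEGY_OVERWRITE to TRIANGLE_STAR_STRATEGY_OVERWRITE and extends its effect from triangle elimination to star elimination: when the flag is set, the DP recovery phase overwrites the strategies and costs of a star's non-merged successor operators instead of only checking them, and CreateStarEliminationSubCostList folds those successors' costs into each candidate cost. The change also clarifies the exception messages raised when inconsistent user-configured strategies make an elimination infeasible, tidies include order, and adds star-structure strategy-consistency UT cases.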
@@ -201,7 +201,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       right_edge->set_selected_cost(decision->right_edge_cost_);
       // 'left_node' recovers the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      if (TRIANGLE_STRATEGY_OVERWRITE) {
+      if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
         // 'right_node' recovers the strategy.
         MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
         right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
@@ -225,10 +225,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       MS_EXCEPTION_IF_NULL(succ_nodes[0]);
       MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
       MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
-      // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
+      // Star is eliminated into 'succ_nodes[0]'
       succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
       for (size_t k = 1; k < succ_nodes.size(); ++k) {
-        succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
+        if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
+          // 'succ_nodes[k]' is overwritten strategy and cost.
+          succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]);
+        } else {
+          // In this case, 'succ_nodes[k]' is NOT overwritten strategy and cost, however, it checks the strategy.
+          succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
+        }
       }
       MS_LOG(INFO) << "Recover starElimination succeeded.";
     } else {
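For orientation, the recovery logic above reduces to a minimal Python sketch (illustrative names only, not MindSpore source): the star was merged into succ_nodes[0], which always takes the strategy and cost recorded in the elimination decision; the remaining successors are either forcibly overwritten or merely validated, depending on the flag.

# Schematic sketch of the recovery control flow; names are illustrative.
def recover_star(succ_nodes, decision, overwrite):
    # The star was eliminated into succ_nodes[0]; it always takes the
    # strategy and cost recorded in the elimination decision.
    succ_nodes[0].set_selected(decision.stra_list[0], decision.cost_list[0])
    for k in range(1, len(succ_nodes)):
        if overwrite:
            # Flag on: force the recorded strategy and cost onto the node.
            succ_nodes[k].set_selected(decision.stra_list[k], decision.cost_list[k])
        else:
            # Flag off: keep the node's own strategy, only verify it matches.
            succ_nodes[k].check_selected(decision.stra_list[k])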
@@ -20,9 +20,9 @@
 #include <memory>
 #include <utility>
 #include <vector>
-#include "ir/value.h"
 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "ir/value.h"
 
 namespace mindspore {
 namespace parallel {
@@ -22,11 +22,11 @@
 #include <string>
 #include <utility>
 #include <vector>
-#include "utils/ms_utils.h"
 #include "frontend/parallel/auto_parallel/costmodel.h"
 #include "frontend/parallel/ops_info/operator_info.h"
 #include "frontend/parallel/tensor_layout/tensor_info.h"
 #include "frontend/parallel/tensor_layout/tensor_layout.h"
+#include "utils/ms_utils.h"
 
 namespace mindspore {
 namespace parallel {
@@ -40,7 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
-bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
+bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
 
 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -155,12 +155,12 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
     MS_LOG(INFO) << "multi_subgraphs: false.";
   }
 
-  auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite();
-  TRIANGLE_STRATEGY_OVERWRITE = overwrite;
-  if (TRIANGLE_STRATEGY_OVERWRITE) {
-    MS_LOG(INFO) << "triangle_strategy_overwrite: true.";
+  auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
+  TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite;
+  if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
+    MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
   } else {
-    MS_LOG(INFO) << "triangle_strategy_overwrite: false.";
+    MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
   }
 
   // RUN_PHASE
@@ -1303,7 +1303,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
     left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
 
-  if (TRIANGLE_STRATEGY_OVERWRITE) {
+  if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
     new_computation += right_op_cost->computation_cost_;
     new_memory += right_op_cost->memory_with_reuse_;
     new_commu_cost += right_op_cost->communication_cost_;
@@ -1399,7 +1399,9 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
   }
 
   if (!valid) {
-    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed.";
+    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
+                      << " failed. It may be caused by "
+                         "configuring inconsistent strategies for operators.";
   }
   elimi_op->SetNotAlive();
   MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
@@ -1440,6 +1442,13 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     commu_cost += succ_edges_costs[i]->communication_cost_;
     commu_forward += succ_edges_costs[i]->communication_forward_;
     commu_without += succ_edges_costs[i]->communication_without_parameter_;
+    if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
+      computation_cost += succ_nodes_costs[i]->computation_cost_;
+      memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
+      commu_cost += succ_nodes_costs[i]->communication_cost_;
+      commu_forward += succ_nodes_costs[i]->communication_forward_;
+      commu_without += succ_nodes_costs[i]->communication_without_parameter_;
+    }
   }
 }
 
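The point of the extra accumulation: if a successor's strategy will be overwritten by this decision, that successor's own cost becomes a consequence of the decision, so it must be counted when candidates are compared; otherwise the DP would rank candidates on incomplete totals. A small sketch with hypothetical dict-based costs (not MindSpore source):

# Edge costs are always accumulated; successor-node costs are folded in
# only when the overwrite flag is set, mirroring the C++ above.
def star_sub_cost(succ_edge_costs, succ_node_costs, overwrite):
    total = {"computation": 0.0, "memory": 0.0, "communication": 0.0}
    for edge_cost, node_cost in zip(succ_edge_costs, succ_node_costs):
        for key in total:
            total[key] += edge_cost[key]
            if overwrite:
                # The successor's strategy is decided here, so its own
                # cost belongs to this candidate as well.
                total[key] += node_cost[key]
    return total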
@@ -1544,7 +1553,9 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
   }
 
   if (!valid) {
-    MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed.";
+    MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name()
+                      << " failed. It may be caused by "
+                         "configuring inconsistent strategies for operators.";
   }
 
   merged_op->SetNotAlive();
@@ -22,11 +22,11 @@
 #include <string>
 #include <utility>
 #include <vector>
-#include "utils/ms_utils.h"
 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
 #include "frontend/parallel/costmodel_context.h"
 #include "frontend/parallel/ops_info/operator_info.h"
 #include "frontend/parallel/ops_info/tmp_identity_info.h"
+#include "utils/ms_utils.h"
 
 namespace mindspore {
 namespace parallel {
@@ -46,7 +46,7 @@ extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
 extern bool MULTI_SUBGRAPHS;
 extern int32_t RUN_PHASE;
-extern bool TRIANGLE_STRATEGY_OVERWRITE;
+extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE;
 
 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
@@ -16,8 +16,8 @@
 
 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
 
-#include <random>
 #include <algorithm>
+#include <random>
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
 
@@ -19,9 +19,9 @@
 #include <algorithm>
 #include <cstdint>
 #include <functional>
+#include <map>
 #include <memory>
 #include <utility>
-#include <map>
 
 #include "frontend/parallel/device_manager.h"
 
@@ -18,18 +18,18 @@
 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
 
 #include <cstdint>
-#include <memory>
 #include <map>
+#include <memory>
 #include <string>
 #include <vector>
 
+#include "abstract/abstract_value.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/status.h"
-#include "utils/convert_utils.h"
 #include "ir/anf.h"
 #include "ir/func_graph.h"
+#include "utils/convert_utils.h"
 #include "utils/info.h"
-#include "abstract/abstract_value.h"
 
 namespace mindspore {
 namespace parallel {
@@ -64,7 +64,7 @@ void CostModelContext::ResetAlgoParameters() {
   tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
   fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
-  triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
+  triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
 }
 
 void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
@@ -134,7 +134,9 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }
 
-void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; }
+void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) {
+  triangle_star_strategy_overwrite_ = overwrite;
+}
 
 void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
 
@@ -44,7 +44,7 @@ namespace parallel {
 #define DEFAULT_RUN_PHASE 0
 #define TRAINING_PHASE 0
 #define INFERENCE_PHASE 1
-#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true;
+#define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
 
 class CostModelContext {
  public:
@@ -135,8 +135,8 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
 
-  void set_triangle_strategy_overwrite(bool);
-  bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; }
+  void set_triangle_star_strategy_overwrite(bool);
+  bool triangle_star_strategy_overwrite() const { return triangle_star_strategy_overwrite_; }
 
   void set_run_phase(int32_t);
   int32_t run_phase() const { return run_phase_; }
@@ -172,9 +172,9 @@ class CostModelContext {
   // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;
 
-  // In the recovery phase of DP algorithm, when encountering triangle structure,
+  // In the recovery phase of DP algorithm, when encountering triangle structure and star structure,
   // whether overwrite the right-node strategy
-  bool triangle_strategy_overwrite_;
+  bool triangle_star_strategy_overwrite_;
 
   int32_t run_phase_;  // 0: 'training', 1: 'inference'
 
@@ -25,13 +25,13 @@
 #include <utility>
 #include <vector>
 
-#include "utils/ms_utils.h"
 #include "frontend/parallel/device.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/group_manager.h"
 #include "frontend/parallel/status.h"
 #include "frontend/parallel/strategy.h"
 #include "utils/convert_utils.h"
+#include "utils/ms_utils.h"
 
 namespace mindspore {
 namespace parallel {
@@ -17,8 +17,8 @@
 #include "frontend/parallel/group_manager.h"
 #include <algorithm>
 #include <vector>
-#include "frontend/parallel/device_manager.h"
 #include "backend/session/executor_manager.h"
+#include "frontend/parallel/device_manager.h"
 #include "utils/comm_manager.h"
 #include "utils/ms_context.h"
 
@@ -26,10 +26,8 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <unordered_set>
 
-#include "ir/anf.h"
-#include "ir/param_info.h"
-#include "ir/tensor.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
@@ -39,11 +37,14 @@
 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
 #include "frontend/parallel/context.h"
-#include "frontend/parallel/ops_info/tmp_identity_info.h"
-#include "frontend/parallel/ops_info/reshape_info.h"
 #include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/ops_info/reshape_info.h"
+#include "frontend/parallel/ops_info/tmp_identity_info.h"
 #include "frontend/parallel/step_parallel.h"
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
+#include "ir/anf.h"
+#include "ir/param_info.h"
+#include "ir/tensor.h"
 
 namespace mindspore {
 namespace parallel {
@@ -21,9 +21,9 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "ir/anf.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/parallel/status.h"
+#include "ir/anf.h"
 #include "pipeline/jit/pipeline.h"
 
 namespace mindspore {
@@ -27,8 +27,6 @@
 #include <unordered_map>
 #include <utility>
 
-#include "ir/tensor.h"
-#include "ir/param_info.h"
 #include "frontend/operator/ops.h"
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
@@ -41,9 +39,11 @@
 #include "frontend/parallel/node_check.h"
 #include "frontend/parallel/ops_info/matmul_info.h"
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
+#include "ir/param_info.h"
+#include "ir/tensor.h"
 #include "utils/comm_manager.h"
-#include "utils/symbolic.h"
 #include "utils/ms_context.h"
+#include "utils/symbolic.h"
 
 using mindspore::tensor::Tensor;
 
@@ -21,10 +21,10 @@
 
 #include <map>
 #include <memory>
+#include <set>
 #include <string>
 #include <unordered_map>
 #include <utility>
-#include <set>
 
 #include "frontend/optimizer/opt.h"
 #include "frontend/parallel/strategy.h"
@@ -23,8 +23,8 @@
 #include <utility>
 #include <vector>
 
-#include "frontend/parallel/status.h"
 #include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/status.h"
 
 namespace mindspore {
 namespace parallel {
@@ -0,0 +1,134 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from mindspore import context, Tensor, Parameter
from mindspore.parallel import set_algo_parameters
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWarp(nn.Cell):
    def __init__(self, network):
        super(GradWarp, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class Net(nn.Cell):
    def __init__(self, strategy_dict=None):
        super(Net, self).__init__()
        self.mul1 = P.Mul()
        self.mul2 = P.Mul()
        self.mul3 = P.Mul()
        self.mul4 = P.Mul()
        self.relu1 = P.ReLU()
        self.relu2 = P.ReLU()
        self.ba1 = P.BiasAdd()
        self.add = P.TensorAdd()
        self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
        self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias")

        if strategy_dict is not None:
            self.mul1.shard(strategy_dict["mul1"])
            self.mul2.shard(strategy_dict["mul2"])
            self.relu1.shard(strategy_dict["relu1"])
            self.relu2.shard(strategy_dict["relu2"])
            self.ba1.shard(strategy_dict["bias_add"])
            self.add.shard(strategy_dict["add"])

    def construct(self, inputs):
        x = self.mul1(inputs, self.weight)
        y = self.relu1(x)
        y = self.mul2(y, self.weight)
        z = self.mul3(x, self.weight)
        z = self.ba1(z, self.bias)
        x = self.add(y, z)
        x = self.mul4(x, self.weight)
        x = self.relu2(x)
        return x


def test_star_strategy_consistency1():
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    strategy_dict = {"mul1": ((2, 4), (2, 4)), "mul2": None, "relu1": ((4, 1),), "bias_add": ((8, 1), (1,)),
                     "relu2": ((2, 2),), "add": ((1, 8), (1, 8))}
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, x, phase='train')


def test_star_strategy_consistency2():
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    strategy_dict = {"mul1": None, "mul2": ((1, 4), (1, 4)), "relu1": ((2, 1),), "bias_add": ((4, 2), (2,)),
                     "relu2": ((2, 2),), "add": ((8, 1), (8, 1))}
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, x, phase='train')


def test_star_strategy_consistency3():
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    strategy_dict = {"mul1": None, "mul2": None, "relu1": ((8, 1),), "bias_add": ((1, 4), (4,)),
                     "relu2": ((4, 1),), "add": ((2, 2), (2, 2))}
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, x, phase='train')


def test_star_strategy_consistency4():
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    strategy_dict = {"mul1": ((1, 8), (1, 8)), "mul2": ((1, 4), (1, 4)), "relu1": None, "bias_add": None,
                     "relu2": None, "add": None}
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    with pytest.raises(RuntimeError):
        _executor.compile(net, x, phase='train')
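These tests exercise the flag's default value (true, per DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE): with overwrite on, the first three mixes of configured and searched strategies compile, while the fourth is expected to raise on inconsistent configured strategies. A sketch of how one might toggle the flag from Python follows; the module path and keyword are inferred from the C++ setter set_triangle_star_strategy_overwrite and are an assumption, not something this diff shows.

# Hypothetical toggle (assumed API; keyword inferred from the C++ setter).
# With the flag off, DP recovery only checks user-configured strategies on
# star/triangle nodes instead of overwriting them.
from mindspore.parallel._cost_model_context import set_cost_model_context

set_cost_model_context(triangle_star_strategy_overwrite=False)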