fix bug for allreduce fusion and add resnet unit test
This commit is contained in:
parent
0830ad932b
commit
ab917a734d
|
@ -359,7 +359,7 @@ Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
|
double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_;
|
||||||
double to_cost = allreduce_graph_.max() + FUSION_COST_EPS;
|
double to_cost = allreduce_graph_.max();
|
||||||
int32_t fusion = 1;
|
int32_t fusion = 1;
|
||||||
while (to_cost != 0) {
|
while (to_cost != 0) {
|
||||||
MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
|
MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size;
|
||||||
|
|
|
@ -38,7 +38,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER
|
||||||
constexpr char FUSION[] = "fusion";
|
constexpr char FUSION[] = "fusion";
|
||||||
constexpr char PARAMETER[] = "parameter";
|
constexpr char PARAMETER[] = "parameter";
|
||||||
const uint32_t MAX_RECURSIVE_CALL_TIMES = 100;
|
const uint32_t MAX_RECURSIVE_CALL_TIMES = 100;
|
||||||
const double FUSION_COST_EPS = 1e-7;
|
|
||||||
class AllreduceFusion {
|
class AllreduceFusion {
|
||||||
public:
|
public:
|
||||||
AllreduceFusion()
|
AllreduceFusion()
|
||||||
|
|
|
@ -24,7 +24,19 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
||||||
auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
AllreduceNodePtr arnode;
|
||||||
|
auto cnode_emplace_return = cnode_set_.emplace(node);
|
||||||
|
if (!cnode_emplace_return.second) {
|
||||||
|
MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
|
||||||
|
auto cnode_arnode_pair = cnode_arnode_map_.find(node);
|
||||||
|
if (cnode_arnode_pair == cnode_arnode_map_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
|
||||||
|
}
|
||||||
|
arnode = cnode_arnode_pair->second;
|
||||||
|
} else {
|
||||||
|
arnode = std::make_shared<AllreduceNode>(AllreduceNode());
|
||||||
|
}
|
||||||
|
|
||||||
if (arnode->Init(node) != SUCCESS) {
|
if (arnode->Init(node) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "AllreduceNode Init failed";
|
MS_LOG(ERROR) << "AllreduceNode Init failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -39,10 +51,6 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) {
|
||||||
if (!arnode_emplace_return.second) {
|
if (!arnode_emplace_return.second) {
|
||||||
MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
|
MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
|
||||||
}
|
}
|
||||||
auto cnode_emplace_return = cnode_set_.emplace(node);
|
|
||||||
if (!cnode_emplace_return.second) {
|
|
||||||
MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
|
|
||||||
}
|
|
||||||
cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
|
cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
|
||||||
if (!cnode_emplace_return.second) {
|
if (!cnode_emplace_return.second) {
|
||||||
MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
|
MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
|
||||||
|
@ -75,7 +83,7 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double
|
||||||
MS_LOG(ERROR) << "from_arnode AddNext failed";
|
MS_LOG(ERROR) << "from_arnode AddNext failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (to_arnode->AddPrev(from_arnode, dist) != SUCCESS) {
|
if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "to_arnode AddPrev failed";
|
MS_LOG(ERROR) << "to_arnode AddPrev failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
@ -110,7 +118,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou
|
||||||
double cur_para_size = 0;
|
double cur_para_size = 0;
|
||||||
double from = to;
|
double from = to;
|
||||||
for (auto& arnode : arnode_vec_) {
|
for (auto& arnode : arnode_vec_) {
|
||||||
if (arnode.depend_feat_size() >= to) {
|
if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
|
if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "parallel/allreduce_fusion/allreduce_node.h"
|
#include "parallel/allreduce_fusion/allreduce_node.h"
|
||||||
|
#include <queue>
|
||||||
#include "parallel/tensor_layout/tensor_layout.h"
|
#include "parallel/tensor_layout/tensor_layout.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
@ -29,7 +30,7 @@ Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist) {
|
Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max) {
|
||||||
if (prev_node == nullptr) {
|
if (prev_node == nullptr) {
|
||||||
MS_LOG(ERROR) << "next_node is nullptr!";
|
MS_LOG(ERROR) << "next_node is nullptr!";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -39,7 +40,26 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
prev_.emplace_back(prev_node);
|
prev_.emplace_back(prev_node);
|
||||||
depend_feat_size_ += prev_node->depend_feat_size() + dist;
|
double add_dist = prev_node->depend_feat_size() + dist;
|
||||||
|
depend_feat_size_ += add_dist;
|
||||||
|
if (depend_feat_size_ > *max) {
|
||||||
|
*max = depend_feat_size_;
|
||||||
|
}
|
||||||
|
std::queue<AllreduceNodePtr> next_queue;
|
||||||
|
for (auto& next : next_) {
|
||||||
|
next_queue.push(next);
|
||||||
|
}
|
||||||
|
while (!next_queue.empty()) {
|
||||||
|
auto ele = next_queue.front();
|
||||||
|
ele->AddDependFeatSize(add_dist);
|
||||||
|
if (ele->depend_feat_size() > *max) {
|
||||||
|
*max = ele->depend_feat_size();
|
||||||
|
}
|
||||||
|
for (auto& next : ele->next()) {
|
||||||
|
next_queue.push(next);
|
||||||
|
}
|
||||||
|
next_queue.pop();
|
||||||
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,14 @@ class AllreduceNode {
|
||||||
const std::unordered_set<AnfNodePtr>& paras() const { return paras_; }
|
const std::unordered_set<AnfNodePtr>& paras() const { return paras_; }
|
||||||
double curr_para_size() const { return curr_para_size_; }
|
double curr_para_size() const { return curr_para_size_; }
|
||||||
virtual ~AllreduceNode() = default;
|
virtual ~AllreduceNode() = default;
|
||||||
Status AddPrev(const AllreduceNodePtr& prev_node, double dist);
|
// Add previous node
|
||||||
|
// prev_node is the previous to be added
|
||||||
|
// max is the current max depend_feat_size of the AllreduceGraph
|
||||||
|
Status AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max);
|
||||||
Status AddNext(const AllreduceNodePtr& next_node);
|
Status AddNext(const AllreduceNodePtr& next_node);
|
||||||
double depend_feat_size() const { return depend_feat_size_; }
|
double depend_feat_size() const { return depend_feat_size_; }
|
||||||
|
void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; }
|
||||||
|
const std::vector<AllreduceNodePtr>& next() const { return next_; }
|
||||||
void ToString() const;
|
void ToString() const;
|
||||||
bool operator<(const AllreduceNode& node) const { return depend_feat_size_ < node.depend_feat_size(); }
|
bool operator<(const AllreduceNode& node) const { return depend_feat_size_ < node.depend_feat_size(); }
|
||||||
bool operator>(const AllreduceNode& node) const { return depend_feat_size_ > node.depend_feat_size(); }
|
bool operator>(const AllreduceNode& node) const { return depend_feat_size_ > node.depend_feat_size(); }
|
||||||
|
|
|
@ -275,7 +275,7 @@ def test_allreduce_fusion5():
|
||||||
expect_dict = {'backbone2.fc8.weight': 3,
|
expect_dict = {'backbone2.fc8.weight': 3,
|
||||||
'backbone2.fc7.weight': 3,
|
'backbone2.fc7.weight': 3,
|
||||||
'backbone2.fc6.weight': 3,
|
'backbone2.fc6.weight': 3,
|
||||||
'backbone2.fc5.weight': 2,
|
'backbone2.fc5.weight': 3,
|
||||||
'backbone2.fc4.weight': 2,
|
'backbone2.fc4.weight': 2,
|
||||||
'backbone2.fc3.weight': 2,
|
'backbone2.fc3.weight': 2,
|
||||||
'backbone2.fc2.weight': 1,
|
'backbone2.fc2.weight': 1,
|
||||||
|
@ -283,7 +283,7 @@ def test_allreduce_fusion5():
|
||||||
'backbone1.fc8.weight': 3,
|
'backbone1.fc8.weight': 3,
|
||||||
'backbone1.fc7.weight': 3,
|
'backbone1.fc7.weight': 3,
|
||||||
'backbone1.fc6.weight': 3,
|
'backbone1.fc6.weight': 3,
|
||||||
'backbone1.fc5.weight': 2,
|
'backbone1.fc5.weight': 3,
|
||||||
'backbone1.fc4.weight': 2,
|
'backbone1.fc4.weight': 2,
|
||||||
'backbone1.fc3.weight': 2,
|
'backbone1.fc3.weight': 2,
|
||||||
'backbone1.fc2.weight': 1,
|
'backbone1.fc2.weight': 1,
|
||||||
|
|
|
@ -273,13 +273,9 @@ class DatasetLenet():
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
||||||
def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
|
def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
|
||||||
dev_num = 8
|
dev_num = 8
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
|
||||||
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
|
|
||||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
|
||||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
|
||||||
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
|
|
||||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||||
resset_op_id()
|
resset_op_id()
|
||||||
np.random.seed(6)
|
np.random.seed(6)
|
||||||
|
@ -303,8 +299,16 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
assert v == [[dev_num, 1]]
|
assert v == [[dev_num, 1]]
|
||||||
|
|
||||||
allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)
|
allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)
|
||||||
|
|
||||||
print(allreduce_fusion_dict)
|
print(allreduce_fusion_dict)
|
||||||
|
return allreduce_fusion_dict
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
|
||||||
|
allreduce_fusion_dict = train_32k_8p(epoch_size, batch_size, num_classes)
|
||||||
expect_dict = {'end_point.bias': 2,
|
expect_dict = {'end_point.bias': 2,
|
||||||
'end_point.weight': 2,
|
'end_point.weight': 2,
|
||||||
'layer4.2.bn3.beta': 2,
|
'layer4.2.bn3.beta': 2,
|
||||||
|
@ -382,11 +386,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer3.1.bn1.beta': 2,
|
'layer3.1.bn1.beta': 2,
|
||||||
'layer3.1.bn1.gamma': 2,
|
'layer3.1.bn1.gamma': 2,
|
||||||
'layer3.1.conv1.weight': 2,
|
'layer3.1.conv1.weight': 2,
|
||||||
'layer3.0.bn_down_sample.beta': 1,
|
'layer3.0.bn_down_sample.beta': 2,
|
||||||
'layer3.0.bn_down_sample.gamma': 1,
|
'layer3.0.bn_down_sample.gamma': 2,
|
||||||
'layer3.0.conv_down_sample.weight': 2,
|
'layer3.0.conv_down_sample.weight': 2,
|
||||||
'layer3.0.bn3.beta': 1,
|
'layer3.0.bn3.beta': 2,
|
||||||
'layer3.0.bn3.gamma': 1,
|
'layer3.0.bn3.gamma': 2,
|
||||||
'layer3.0.conv3.weight': 2,
|
'layer3.0.conv3.weight': 2,
|
||||||
'layer3.0.bn2.beta': 2,
|
'layer3.0.bn2.beta': 2,
|
||||||
'layer3.0.bn2.gamma': 2,
|
'layer3.0.bn2.gamma': 2,
|
||||||
|
@ -412,8 +416,8 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer2.2.bn1.beta': 2,
|
'layer2.2.bn1.beta': 2,
|
||||||
'layer2.2.bn1.gamma': 2,
|
'layer2.2.bn1.gamma': 2,
|
||||||
'layer2.2.conv1.weight': 2,
|
'layer2.2.conv1.weight': 2,
|
||||||
'layer2.1.bn3.beta': 1,
|
'layer2.1.bn3.beta': 2,
|
||||||
'layer2.1.bn3.gamma': 1,
|
'layer2.1.bn3.gamma': 2,
|
||||||
'layer2.1.conv3.weight': 2,
|
'layer2.1.conv3.weight': 2,
|
||||||
'layer2.1.bn2.beta': 2,
|
'layer2.1.bn2.beta': 2,
|
||||||
'layer2.1.bn2.gamma': 2,
|
'layer2.1.bn2.gamma': 2,
|
||||||
|
@ -421,11 +425,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer2.1.bn1.beta': 2,
|
'layer2.1.bn1.beta': 2,
|
||||||
'layer2.1.bn1.gamma': 2,
|
'layer2.1.bn1.gamma': 2,
|
||||||
'layer2.1.conv1.weight': 2,
|
'layer2.1.conv1.weight': 2,
|
||||||
'layer2.0.bn_down_sample.beta': 1,
|
'layer2.0.bn_down_sample.beta': 2,
|
||||||
'layer2.0.bn_down_sample.gamma': 1,
|
'layer2.0.bn_down_sample.gamma': 2,
|
||||||
'layer2.0.conv_down_sample.weight': 2,
|
'layer2.0.conv_down_sample.weight': 2,
|
||||||
'layer2.0.bn3.beta': 1,
|
'layer2.0.bn3.beta': 2,
|
||||||
'layer2.0.bn3.gamma': 1,
|
'layer2.0.bn3.gamma': 2,
|
||||||
'layer2.0.conv3.weight': 2,
|
'layer2.0.conv3.weight': 2,
|
||||||
'layer2.0.bn2.beta': 2,
|
'layer2.0.bn2.beta': 2,
|
||||||
'layer2.0.bn2.gamma': 2,
|
'layer2.0.bn2.gamma': 2,
|
||||||
|
@ -442,8 +446,8 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer1.2.bn1.beta': 2,
|
'layer1.2.bn1.beta': 2,
|
||||||
'layer1.2.bn1.gamma': 2,
|
'layer1.2.bn1.gamma': 2,
|
||||||
'layer1.2.conv1.weight': 2,
|
'layer1.2.conv1.weight': 2,
|
||||||
'layer1.1.bn3.beta': 1,
|
'layer1.1.bn3.beta': 2,
|
||||||
'layer1.1.bn3.gamma': 1,
|
'layer1.1.bn3.gamma': 2,
|
||||||
'layer1.1.conv3.weight': 2,
|
'layer1.1.conv3.weight': 2,
|
||||||
'layer1.1.bn2.beta': 2,
|
'layer1.1.bn2.beta': 2,
|
||||||
'layer1.1.bn2.gamma': 2,
|
'layer1.1.bn2.gamma': 2,
|
||||||
|
@ -451,11 +455,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer1.1.bn1.beta': 2,
|
'layer1.1.bn1.beta': 2,
|
||||||
'layer1.1.bn1.gamma': 2,
|
'layer1.1.bn1.gamma': 2,
|
||||||
'layer1.1.conv1.weight': 2,
|
'layer1.1.conv1.weight': 2,
|
||||||
'layer1.0.bn_down_sample.beta': 1,
|
'layer1.0.bn_down_sample.beta': 2,
|
||||||
'layer1.0.bn_down_sample.gamma': 1,
|
'layer1.0.bn_down_sample.gamma': 2,
|
||||||
'layer1.0.conv_down_sample.weight': 2,
|
'layer1.0.conv_down_sample.weight': 2,
|
||||||
'layer1.0.bn3.beta': 1,
|
'layer1.0.bn3.beta': 2,
|
||||||
'layer1.0.bn3.gamma': 1,
|
'layer1.0.bn3.gamma': 2,
|
||||||
'layer1.0.conv3.weight': 2,
|
'layer1.0.conv3.weight': 2,
|
||||||
'layer1.0.bn2.beta': 2,
|
'layer1.0.bn2.beta': 2,
|
||||||
'layer1.0.bn2.gamma': 2,
|
'layer1.0.bn2.gamma': 2,
|
||||||
|
@ -465,7 +469,180 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
|
||||||
'layer1.0.conv1.weight': 2,
|
'layer1.0.conv1.weight': 2,
|
||||||
'bn1.beta': 1,
|
'bn1.beta': 1,
|
||||||
'bn1.gamma': 1,
|
'bn1.gamma': 1,
|
||||||
'conv1.weight': 2}
|
'conv1.weight': 1}
|
||||||
|
|
||||||
|
assert (allreduce_fusion_dict == expect_dict)
|
||||||
|
cost_model_context.reset_cost_model_context()
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
|
||||||
|
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
|
||||||
|
allreduce_fusion_dict = train_32k_8p(epoch_size, batch_size, num_classes)
|
||||||
|
expect_dict = {'end_point.bias': 2,
|
||||||
|
'end_point.weight': 2,
|
||||||
|
'layer4.2.bn3.beta': 2,
|
||||||
|
'layer4.2.bn3.gamma': 2,
|
||||||
|
'layer4.2.conv3.weight': 2,
|
||||||
|
'layer4.2.bn2.beta': 2,
|
||||||
|
'layer4.2.bn2.gamma': 2,
|
||||||
|
'layer4.2.conv2.weight': 2,
|
||||||
|
'layer4.2.bn1.beta': 2,
|
||||||
|
'layer4.2.bn1.gamma': 2,
|
||||||
|
'layer4.2.conv1.weight': 2,
|
||||||
|
'layer4.1.bn3.beta': 2,
|
||||||
|
'layer4.1.bn3.gamma': 2,
|
||||||
|
'layer4.1.conv3.weight': 2,
|
||||||
|
'layer4.1.bn2.beta': 2,
|
||||||
|
'layer4.1.bn2.gamma': 2,
|
||||||
|
'layer4.1.conv2.weight': 2,
|
||||||
|
'layer4.1.bn1.beta': 2,
|
||||||
|
'layer4.1.bn1.gamma': 2,
|
||||||
|
'layer4.1.conv1.weight': 2,
|
||||||
|
'layer4.0.bn_down_sample.beta': 2,
|
||||||
|
'layer4.0.bn_down_sample.gamma': 2,
|
||||||
|
'layer4.0.conv_down_sample.weight': 2,
|
||||||
|
'layer4.0.bn3.beta': 2,
|
||||||
|
'layer4.0.bn3.gamma': 2,
|
||||||
|
'layer4.0.conv3.weight': 2,
|
||||||
|
'layer4.0.bn2.beta': 2,
|
||||||
|
'layer4.0.bn2.gamma': 2,
|
||||||
|
'layer4.0.conv2.weight': 2,
|
||||||
|
'layer4.0.bn1.beta': 2,
|
||||||
|
'layer4.0.bn1.gamma': 2,
|
||||||
|
'layer4.0.conv1.weight': 2,
|
||||||
|
'layer3.5.bn3.beta': 2,
|
||||||
|
'layer3.5.bn3.gamma': 2,
|
||||||
|
'layer3.5.conv3.weight': 2,
|
||||||
|
'layer3.5.bn2.beta': 2,
|
||||||
|
'layer3.5.bn2.gamma': 2,
|
||||||
|
'layer3.5.conv2.weight': 2,
|
||||||
|
'layer3.5.bn1.beta': 2,
|
||||||
|
'layer3.5.bn1.gamma': 2,
|
||||||
|
'layer3.5.conv1.weight': 2,
|
||||||
|
'layer3.4.bn3.beta': 2,
|
||||||
|
'layer3.4.bn3.gamma': 2,
|
||||||
|
'layer3.4.conv3.weight': 2,
|
||||||
|
'layer3.4.bn2.beta': 2,
|
||||||
|
'layer3.4.bn2.gamma': 2,
|
||||||
|
'layer3.4.conv2.weight': 2,
|
||||||
|
'layer3.4.bn1.beta': 2,
|
||||||
|
'layer3.4.bn1.gamma': 2,
|
||||||
|
'layer3.4.conv1.weight': 2,
|
||||||
|
'layer3.3.bn3.beta': 2,
|
||||||
|
'layer3.3.bn3.gamma': 2,
|
||||||
|
'layer3.3.conv3.weight': 2,
|
||||||
|
'layer3.3.bn2.beta': 2,
|
||||||
|
'layer3.3.bn2.gamma': 2,
|
||||||
|
'layer3.3.conv2.weight': 2,
|
||||||
|
'layer3.3.bn1.beta': 2,
|
||||||
|
'layer3.3.bn1.gamma': 2,
|
||||||
|
'layer3.3.conv1.weight': 2,
|
||||||
|
'layer3.2.bn3.beta': 2,
|
||||||
|
'layer3.2.bn3.gamma': 2,
|
||||||
|
'layer3.2.conv3.weight': 2,
|
||||||
|
'layer3.2.bn2.beta': 2,
|
||||||
|
'layer3.2.bn2.gamma': 2,
|
||||||
|
'layer3.2.conv2.weight': 2,
|
||||||
|
'layer3.2.bn1.beta': 2,
|
||||||
|
'layer3.2.bn1.gamma': 2,
|
||||||
|
'layer3.2.conv1.weight': 2,
|
||||||
|
'layer3.1.bn3.beta': 2,
|
||||||
|
'layer3.1.bn3.gamma': 2,
|
||||||
|
'layer3.1.conv3.weight': 2,
|
||||||
|
'layer3.1.bn2.beta': 2,
|
||||||
|
'layer3.1.bn2.gamma': 2,
|
||||||
|
'layer3.1.conv2.weight': 2,
|
||||||
|
'layer3.1.bn1.beta': 2,
|
||||||
|
'layer3.1.bn1.gamma': 2,
|
||||||
|
'layer3.1.conv1.weight': 2,
|
||||||
|
'layer3.0.bn_down_sample.beta': 2,
|
||||||
|
'layer3.0.bn_down_sample.gamma': 2,
|
||||||
|
'layer3.0.conv_down_sample.weight': 2,
|
||||||
|
'layer3.0.bn3.beta': 2,
|
||||||
|
'layer3.0.bn3.gamma': 2,
|
||||||
|
'layer3.0.conv3.weight': 2,
|
||||||
|
'layer3.0.bn2.beta': 2,
|
||||||
|
'layer3.0.bn2.gamma': 2,
|
||||||
|
'layer3.0.conv2.weight': 2,
|
||||||
|
'layer3.0.bn1.beta': 2,
|
||||||
|
'layer3.0.bn1.gamma': 2,
|
||||||
|
'layer3.0.conv1.weight': 2,
|
||||||
|
'layer2.3.bn3.beta': 2,
|
||||||
|
'layer2.3.bn3.gamma': 2,
|
||||||
|
'layer2.3.conv3.weight': 2,
|
||||||
|
'layer2.3.bn2.beta': 2,
|
||||||
|
'layer2.3.bn2.gamma': 2,
|
||||||
|
'layer2.3.conv2.weight': 2,
|
||||||
|
'layer2.3.bn1.beta': 2,
|
||||||
|
'layer2.3.bn1.gamma': 2,
|
||||||
|
'layer2.3.conv1.weight': 2,
|
||||||
|
'layer2.2.bn3.beta': 2,
|
||||||
|
'layer2.2.bn3.gamma': 2,
|
||||||
|
'layer2.2.conv3.weight': 2,
|
||||||
|
'layer2.2.bn2.beta': 2,
|
||||||
|
'layer2.2.bn2.gamma': 2,
|
||||||
|
'layer2.2.conv2.weight': 2,
|
||||||
|
'layer2.2.bn1.beta': 2,
|
||||||
|
'layer2.2.bn1.gamma': 2,
|
||||||
|
'layer2.2.conv1.weight': 2,
|
||||||
|
'layer2.1.bn3.beta': 2,
|
||||||
|
'layer2.1.bn3.gamma': 2,
|
||||||
|
'layer2.1.conv3.weight': 2,
|
||||||
|
'layer2.1.bn2.beta': 2,
|
||||||
|
'layer2.1.bn2.gamma': 2,
|
||||||
|
'layer2.1.conv2.weight': 2,
|
||||||
|
'layer2.1.bn1.beta': 2,
|
||||||
|
'layer2.1.bn1.gamma': 2,
|
||||||
|
'layer2.1.conv1.weight': 2,
|
||||||
|
'layer2.0.bn_down_sample.beta': 2,
|
||||||
|
'layer2.0.bn_down_sample.gamma': 2,
|
||||||
|
'layer2.0.conv_down_sample.weight': 2,
|
||||||
|
'layer2.0.bn3.beta': 2,
|
||||||
|
'layer2.0.bn3.gamma': 2,
|
||||||
|
'layer2.0.conv3.weight': 2,
|
||||||
|
'layer2.0.bn2.beta': 2,
|
||||||
|
'layer2.0.bn2.gamma': 2,
|
||||||
|
'layer2.0.conv2.weight': 2,
|
||||||
|
'layer2.0.bn1.beta': 2,
|
||||||
|
'layer2.0.bn1.gamma': 2,
|
||||||
|
'layer2.0.conv1.weight': 2,
|
||||||
|
'layer1.2.bn3.beta': 2,
|
||||||
|
'layer1.2.bn3.gamma': 2,
|
||||||
|
'layer1.2.conv3.weight': 2,
|
||||||
|
'layer1.2.bn2.beta': 2,
|
||||||
|
'layer1.2.bn2.gamma': 2,
|
||||||
|
'layer1.2.conv2.weight': 2,
|
||||||
|
'layer1.2.bn1.beta': 2,
|
||||||
|
'layer1.2.bn1.gamma': 2,
|
||||||
|
'layer1.2.conv1.weight': 2,
|
||||||
|
'layer1.1.bn3.beta': 2,
|
||||||
|
'layer1.1.bn3.gamma': 2,
|
||||||
|
'layer1.1.conv3.weight': 2,
|
||||||
|
'layer1.1.bn2.beta': 2,
|
||||||
|
'layer1.1.bn2.gamma': 2,
|
||||||
|
'layer1.1.conv2.weight': 2,
|
||||||
|
'layer1.1.bn1.beta': 2,
|
||||||
|
'layer1.1.bn1.gamma': 2,
|
||||||
|
'layer1.1.conv1.weight': 2,
|
||||||
|
'layer1.0.bn_down_sample.beta': 2,
|
||||||
|
'layer1.0.bn_down_sample.gamma': 2,
|
||||||
|
'layer1.0.conv_down_sample.weight': 2,
|
||||||
|
'layer1.0.bn3.beta': 2,
|
||||||
|
'layer1.0.bn3.gamma': 2,
|
||||||
|
'layer1.0.conv3.weight': 2,
|
||||||
|
'layer1.0.bn2.beta': 2,
|
||||||
|
'layer1.0.bn2.gamma': 2,
|
||||||
|
'layer1.0.conv2.weight': 1,
|
||||||
|
'layer1.0.bn1.beta': 1,
|
||||||
|
'layer1.0.bn1.gamma': 1,
|
||||||
|
'layer1.0.conv1.weight': 1,
|
||||||
|
'bn1.beta': 1,
|
||||||
|
'bn1.gamma': 1,
|
||||||
|
'conv1.weight': 1}
|
||||||
|
|
||||||
assert (allreduce_fusion_dict == expect_dict)
|
assert (allreduce_fusion_dict == expect_dict)
|
||||||
cost_model_context.reset_cost_model_context()
|
cost_model_context.reset_cost_model_context()
|
||||||
|
|
Loading…
Reference in New Issue