forked from mindspore-Ecosystem/mindspore
!3564 [Auto parallel] Cost model for GPU
Merge pull request !3564 from Xiaoda/15-r0.6-add-new-gpu-costmodel
Commit 2a6884d97c
frontend/parallel/auto_parallel/graph_costmodel.h
@@ -34,7 +34,8 @@ namespace parallel {
 #define OPERATOR_TO_OPERATOR_CONNECTOR "-"
 #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
 #define DEFAULT_COST_MODEL_ALPHA 1.0
-#define DEFAULT_COST_MODEL_BETA 400.0
+#define DEFAULT_COST_MODEL_BETA_ASCEND 400.0 // for 'device_target = Ascend'
+#define DEFAULT_COST_MODEL_BETA_GPU 50.0 // for 'device_target = GPU'
 #define DEFAULT_COST_MODEL_GAMMA 0.001
 #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
 #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
@@ -73,7 +74,7 @@ class CostGraph {
   CostGraph() {
     dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
     costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
-    costmodel_beta_ = DEFAULT_COST_MODEL_BETA;
+    costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
   }
   ~CostGraph() = default;
   void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
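These constants amount to a latency-bandwidth (alpha-beta) communication model, and DEFAULT_DEVICE_MEMORY_CAPACITY works out to 16 GiB (1024.0^3 * 16 bytes). As an illustration only, assuming a transfer of s bytes is rated at roughly alpha + beta * s (the exact formula in the MindSpore cost model may differ):

# Illustrative alpha-beta communication cost, not the exact MindSpore formula.
# beta is the per-byte transfer cost, so the GPU default (50.0) rates the same
# transfer 8x cheaper than the Ascend default (400.0).
ALPHA = 1.0          # fixed per-message latency term
BETA_ASCEND = 400.0  # per-byte cost for 'device_target = Ascend'
BETA_GPU = 50.0      # per-byte cost for 'device_target = GPU'

def comm_cost(size_bytes, alpha, beta):
    return alpha + beta * size_bytes

print(comm_cost(2048.0, ALPHA, BETA_ASCEND))  # 819201.0
print(comm_cost(2048.0, ALPHA, BETA_GPU))     # 102401.0

A cheaper beta makes communication-heavy parallel strategies comparatively more attractive, which is why the GPU target gets its own value.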
frontend/parallel/costmodel_context.cc
@@ -20,6 +20,7 @@
 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace parallel {
@@ -41,7 +42,7 @@ CostModelContext::CostModelContext() {
 void CostModelContext::ResetCostModel() {
   device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
   costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
-  costmodel_beta_ = DEFAULT_COST_MODEL_BETA;
+  costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND;
   costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA;
   costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
@@ -66,6 +67,12 @@ void CostModelContext::ResetAlgoParameters() {
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 }
 
+void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
+  if (device_target == kGPUDevice) {
+    costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU;
+  }
+}
+
 void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
 
 void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }
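The dispatch is deliberately minimal: only an explicit GPU target overrides beta, while every other target keeps the Ascend-tuned default installed by ResetCostModel. A Python mirror of the logic, illustrative only:

# Python mirror of set_costmodel_context_for_device, illustrative only:
# the GPU target overrides beta; any other target keeps the Ascend default.
DEFAULT_COST_MODEL_BETA_ASCEND = 400.0
DEFAULT_COST_MODEL_BETA_GPU = 50.0

def costmodel_beta_for(device_target):
    if device_target == "GPU":
        return DEFAULT_COST_MODEL_BETA_GPU
    return DEFAULT_COST_MODEL_BETA_ASCEND

assert costmodel_beta_for("GPU") == 50.0
assert costmodel_beta_for("Ascend") == 400.0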
frontend/parallel/costmodel_context.h
@@ -35,6 +35,7 @@ class CostModelContext {
 
   static std::shared_ptr<CostModelContext> GetInstance();
 
+  void set_costmodel_context_for_device(const std::string &);
   // DEVICE_MEMORY_CAPACITY
   void set_device_memory_capacity(double);
   double device_memory_capacity() const { return device_memory_capacity_; }
utils/context/ms_context.cc
@@ -21,6 +21,7 @@
 #include "./common.h"
 #include "utils/convert_utils.h"
 #include "utils/tensorprint_utils.h"
+#include "frontend/parallel/costmodel_context.h"
 #ifndef NO_DLIB
 #include "tdt/tsd_client.h"
 #include "tdt/tdt_host_interface.h"
@@ -146,6 +147,7 @@ bool MsContext::set_device_target(const std::string &target) {
   } else {
     device_target_ = target;
   }
+  parallel::CostModelContext::GetInstance()->set_costmodel_context_for_device(device_target_);
   MS_LOG(INFO) << "ms set context device target:" << target;
   return true;
 }
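With this hook in place, choosing a device target from Python retunes the cost model automatically; a minimal usage sketch mirroring the setup in the new tests below:

# Minimal usage sketch mirroring the new tests. Setting device_target to "GPU"
# calls set_costmodel_context_for_device under the hood, switching the
# cost-model beta from 400.0 to 50.0 before the auto-parallel strategy search.
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)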
@@ -678,3 +678,56 @@ def test_train_64k_8p(batch_size=32, num_classes=65536): # 1048576 #131072 #327
             assert v == [[1, 1], [dev_num, 1]]
         elif re.search('ReduceSum-op', k) is not None:
             assert v == [[1, dev_num]]
+
+
+def test_train_8k_8p_gpu(batch_size=32, num_classes=8192):
+    dev_num = 8
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    reset_op_id()
+    np.random.seed(6)
+    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
+    label_np = np.zeros([batch_size]).astype(np.int32)
+    for i in range(0, batch_size):
+        label_np[i] = i % num_classes
+    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
+    net = resnet50(num_classes)
+    loss = SoftmaxCrossEntropyExpand(sparse=True)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
+    model = Model(net, loss_fn=loss, optimizer=opt)
+    model.train(5, dataset, dataset_sink_mode=False)
+    strategies = _executor._get_strategy(model._train_network)
+    for (k, v) in strategies.items():
+        if re.search('Conv2D-op', k) is not None:
+            assert v[0][0] == dev_num
+        elif re.search('MatMul-op', k) is not None:
+            assert v == [[1, 1], [dev_num, 1]]
+        elif re.search('ReduceSum-op', k) is not None:
+            assert v == [[1, dev_num]]
+
+def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
+    dev_num = 8
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    reset_op_id()
+    np.random.seed(6)
+    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
+    label_np = np.zeros([batch_size]).astype(np.int32)
+    for i in range(0, batch_size):
+        label_np[i] = i % num_classes
+    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
+    net = resnet50(num_classes)
+    loss = SoftmaxCrossEntropyExpand(sparse=True)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
+    model = Model(net, loss_fn=loss, optimizer=opt)
+    model.train(5, dataset, dataset_sink_mode=False)
+    strategies = _executor._get_strategy(model._train_network)
+    for (k, v) in strategies.items():
+        if re.search('Conv2D-op', k) is not None:
+            assert v[0][0] == dev_num
+        elif re.search('MatMul-op', k) is not None:
+            assert v == [[dev_num, 1], [1, 1]]
+        elif re.search('ReduceSum-op', k) is not None:
+            assert v == [[dev_num, 1]]
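The two GPU variants pin down different expected strategies: with 8192 output classes the searched MatMul strategy is [[1, 1], [dev_num, 1]] and ReduceSum [[1, dev_num]], while with 4096 classes they flip to [[dev_num, 1], [1, 1]] and [[dev_num, 1]], so the asserts cover both partitionings the search produces at these classifier sizes. The shared assertion pattern could be distilled into a hypothetical helper (not part of this change):

import re

def check_strategies(strategies, expected):
    # Hypothetical helper, not in the MindSpore test suite: assert that every
    # searched strategy whose operator name matches a pattern equals the
    # expected strategy.
    for k, v in strategies.items():
        for pattern, want in expected.items():
            if re.search(pattern, k) is not None:
                assert v == want, (k, v, want)

# e.g. for test_train_8k_8p_gpu with dev_num = 8:
# check_strategies(strategies, {'MatMul-op': [[1, 1], [8, 1]],
#                               'ReduceSum-op': [[1, 8]]})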