fix getting nullptr when using graph manager

lizhenyu 2020-05-29 20:53:50 +08:00
parent 92b54922cb
commit 9df3a8613c
4 changed files with 50 additions and 11 deletions
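What the change addresses (inferred from the hunks below): the FuncGraphManager attached to a kernel graph was previously owned only by the local shared_ptr created in SessionBasic::ConstructKernelGraph, so once construction finished the manager could be destroyed, and a later kernel_graph->manager() call, such as the one in MemSwapManager::SaveUserKernelTopoOrder, came back as nullptr. The fix makes the session Context keep a strong reference to every manager (AddManager) and guards the lookup with MS_EXCEPTION_IF_NULL. A minimal sketch of the failure mode, assuming the graph holds its manager through a std::weak_ptr (simplified stand-in types, not MindSpore's real classes):

#include <cassert>
#include <memory>

// Simplified stand-ins for FuncGraphManager and KernelGraph.
struct Manager {};

struct Graph {
  std::weak_ptr<Manager> manager_;  // non-owning back reference
  void set_manager(const std::shared_ptr<Manager> &m) { manager_ = m; }
  std::shared_ptr<Manager> manager() const { return manager_.lock(); }
};

int main() {
  Graph graph;
  {
    auto manager = std::make_shared<Manager>();  // the only strong reference
    graph.set_manager(manager);
    assert(graph.manager() != nullptr);  // fine while `manager` is in scope
  }  // last shared_ptr destroyed here
  assert(graph.manager() == nullptr);  // lock() now yields nullptr: the reported bug
  return 0;
}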


@@ -65,6 +65,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
 void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   FuncGraphManagerPtr manager = kernel_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
   NodeUsersMap user_map = manager->node_users();
   for (const auto &kernel : execution_order_) {
     auto iter = user_map.find(kernel);
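The added MS_EXCEPTION_IF_NULL(manager) turns a silent null dereference in node_users() into an immediate, diagnosable failure. As a rough illustration only (not MindSpore's actual definition, which reports through its own logging machinery), such a guard macro could look like:

#include <sstream>
#include <stdexcept>

// Hypothetical, simplified stand-in for an exception-on-null guard:
// throw a descriptive error instead of dereferencing a null pointer.
#define EXCEPTION_IF_NULL(ptr)                                       \
  do {                                                               \
    if ((ptr) == nullptr) {                                          \
      std::ostringstream oss;                                        \
      oss << "The pointer [" #ptr "] is null at " << __FILE__ << ":" \
          << __LINE__;                                               \
      throw std::runtime_error(oss.str());                           \
    }                                                                \
  } while (false)

// Usage mirrors the hunk above:
//   FuncGraphManagerPtr manager = kernel_graph->manager();
//   EXCEPTION_IF_NULL(manager);  // fail fast instead of crashing in node_users()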


@@ -532,6 +532,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
   graph->set_output(ConstructOutput(outputs, graph));
   MS_EXCEPTION_IF_NULL(context_);
   FuncGraphManagerPtr manager = MakeManager({graph});
+  context_->AddManager(manager);
   if (manager) {
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);


@@ -37,8 +37,10 @@ class Context : public pipeline::ResourceBase {
   uint32_t device_id() const { return device_id_; }
   static std::shared_ptr<Context> GetInstance();
+  void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); }
  private:
+  std::vector<FuncGraphManagerPtr> manager_list_;
   std::string target_;
   uint32_t device_id_;
 };
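Taken together, the two C++ hunks above make the Context own every manager the session creates: AddManager appends the shared_ptr to manager_list_, so the graph's non-owning back reference can always be locked afterwards. A sketch of the resulting ownership pattern, again with simplified stand-in types rather than MindSpore's real classes:

#include <cassert>
#include <memory>
#include <vector>

struct Manager {};
using ManagerPtr = std::shared_ptr<Manager>;

// Simplified stand-in for the session Context: owns every manager it is given.
struct Context {
  void AddManager(const ManagerPtr &m) { manager_list_.push_back(m); }
  std::vector<ManagerPtr> manager_list_;
};

struct Graph {
  std::weak_ptr<Manager> manager_;
  void set_manager(const ManagerPtr &m) { manager_ = m; }
  ManagerPtr manager() const { return manager_.lock(); }
};

int main() {
  Context context;
  Graph graph;
  {
    ManagerPtr manager = std::make_shared<Manager>();
    context.AddManager(manager);  // the fix: a strong reference that outlives this scope
    graph.set_manager(manager);
  }
  assert(graph.manager() != nullptr);  // manager is kept alive by the context
  return 0;
}

The cost of this pattern is that managers now live as long as the context itself, which seems acceptable for a session-scoped singleton.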


@@ -179,7 +179,8 @@ class ResidualBlockWithDown(Cell):
         self.relu = P.ReLU()
         self.downSample = down_sample
-        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
+        self.conv_down_sample = conv1x1(
+            in_channels, out_channels, stride=stride, padding=0)
         self.bn_down_sample = bn_with_initialize(out_channels)
         self.add = TensorAdd()
@@ -210,7 +211,8 @@ class MakeLayer0(Cell):
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer0, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=1, down_sample=True)
         self.b = block(out_channels, out_channels, stride=stride)
         self.c = block(out_channels, out_channels, stride=1)
@@ -226,7 +228,8 @@ class MakeLayer1(Cell):
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer1, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
         self.d = block(out_channels, out_channels, stride=1)
@@ -244,7 +247,8 @@ class MakeLayer2(Cell):
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer2, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
         self.d = block(out_channels, out_channels, stride=1)
@@ -266,7 +270,8 @@ class MakeLayer3(Cell):
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer3, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
@@ -330,14 +335,41 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1):
     net = resnet50(num_classes)
     lr = 0.1
     momentum = 0.9
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
     criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     net_with_criterion = WithLossCell(net, criterion)
-    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
+    train_network = TrainOneStepCell(
+        net_with_criterion, optimizer)  # optimizer
     train_network.set_train()
     losses = []
     for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
         loss = train_network(data, label)
         losses.append(loss)
     assert (losses[-1].asnumpy() < 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=170):
+    net = resnet50(num_classes)
+    lr = 0.1
+    momentum = 0.9
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(
+        net_with_criterion, optimizer)  # optimizer
+    train_network.set_train()
+    losses = []
+    for i in range(0, epoch):
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
+        label = Tensor(np.ones([batch_size]).astype(np.int32))
+        loss = train_network(data, label)
+        losses.append(loss)
@@ -351,13 +383,16 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
     net = resnet50(num_classes)
     lr = 0.1
     momentum = 0.9
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
     criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    train_network = amp.build_train_network(net, optimizer, criterion, level="O2")
+    train_network = amp.build_train_network(
+        net, optimizer, criterion, level="O2")
     train_network.set_train()
     losses = []
     for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
         loss = train_network(data, label)
         losses.append(loss)