forked from mindspore-Ecosystem/mindspore
fix getting nullptr when using the graph manager
parent 92b54922cb
commit 9df3a8613c
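The change set addresses a nullptr failure when the kernel graph's FuncGraphManager is retrieved: MemSwapManager::SaveUserKernelTopoOrder now checks the manager for null, and SessionBasic::ConstructKernelGraph registers every manager it creates with the session Context, which keeps them in a new manager_list_. A plausible reading (not stated in the diff) is that the graph only holds a weak reference to its manager, so kernel_graph->manager() returns nullptr once the last owning shared_ptr is destroyed; keeping an owning pointer in the Context prevents that. The standalone C++ sketch below illustrates this lifetime pattern with stand-in types (Graph, Manager, and Context here are illustrative, not the real MindSpore classes):

```cpp
// Minimal lifetime sketch, assuming the graph only holds a weak reference to
// its manager (an assumption; Graph/Manager/Context are stand-ins, not the
// real MindSpore types).
#include <iostream>
#include <memory>
#include <vector>

struct Manager {};

struct Graph {
  std::weak_ptr<Manager> manager_;  // non-owning back-reference
  std::shared_ptr<Manager> manager() const { return manager_.lock(); }
};

struct Context {
  // The fix keeps an owning pointer somewhere long-lived.
  std::vector<std::shared_ptr<Manager>> manager_list_;
  void AddManager(const std::shared_ptr<Manager> &m) { manager_list_.push_back(m); }
};

int main() {
  Graph graph;
  Context context;
  {
    auto manager = std::make_shared<Manager>();
    graph.manager_ = manager;
    context.AddManager(manager);  // comment this out and manager() below returns nullptr
  }
  std::cout << (graph.manager() ? "manager alive" : "manager is null") << "\n";
  return 0;
}
```

Compiled as-is, the sketch prints "manager alive"; removing the AddManager call makes it print "manager is null", which mirrors the failure this commit fixes.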
@@ -65,6 +65,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
 void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   FuncGraphManagerPtr manager = kernel_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
   NodeUsersMap user_map = manager->node_users();
   for (const auto &kernel : execution_order_) {
     auto iter = user_map.find(kernel);
@@ -532,6 +532,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
   graph->set_output(ConstructOutput(outputs, graph));
   MS_EXCEPTION_IF_NULL(context_);
   FuncGraphManagerPtr manager = MakeManager({graph});
+  context_->AddManager(manager);
   if (manager) {
     manager->AddFuncGraph(graph);
     graph->set_manager(manager);
@@ -37,8 +37,10 @@ class Context : public pipeline::ResourceBase {
 
   uint32_t device_id() const { return device_id_; }
   static std::shared_ptr<Context> GetInstance();
+  void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); }
 
  private:
+  std::vector<FuncGraphManagerPtr> manager_list_;
   std::string target_;
   uint32_t device_id_;
 };
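If the weak-reference reading above is correct, manager_list_ ties each manager's lifetime to the Context singleton, so later lookups such as kernel_graph->manager() in MemSwapManager no longer observe an expired pointer; the added MS_EXCEPTION_IF_NULL(manager) then acts as a guard rather than the primary fix.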
@@ -179,7 +179,8 @@ class ResidualBlockWithDown(Cell):
         self.relu = P.ReLU()
         self.downSample = down_sample
 
-        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
+        self.conv_down_sample = conv1x1(
+            in_channels, out_channels, stride=stride, padding=0)
         self.bn_down_sample = bn_with_initialize(out_channels)
         self.add = TensorAdd()
 
@@ -210,7 +211,8 @@ class MakeLayer0(Cell):
 
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer0, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=1, down_sample=True)
         self.b = block(out_channels, out_channels, stride=stride)
         self.c = block(out_channels, out_channels, stride=1)
 
@@ -226,7 +228,8 @@ class MakeLayer1(Cell):
 
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer1, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
         self.d = block(out_channels, out_channels, stride=1)
@@ -244,7 +247,8 @@ class MakeLayer2(Cell):
 
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer2, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
         self.d = block(out_channels, out_channels, stride=1)
@@ -266,7 +270,8 @@ class MakeLayer3(Cell):
 
     def __init__(self, block, layer_num, in_channels, out_channels, stride):
         super(MakeLayer3, self).__init__()
-        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
+        self.a = ResidualBlockWithDown(
+            in_channels, out_channels, stride=stride, down_sample=True)
         self.b = block(out_channels, out_channels, stride=1)
         self.c = block(out_channels, out_channels, stride=1)
 
@@ -330,14 +335,41 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1):
     net = resnet50(num_classes)
     lr = 0.1
     momentum = 0.9
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
     criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     net_with_criterion = WithLossCell(net, criterion)
-    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
+    train_network = TrainOneStepCell(
+        net_with_criterion, optimizer)  # optimizer
     train_network.set_train()
     losses = []
     for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
+        label = Tensor(np.ones([batch_size]).astype(np.int32))
+        loss = train_network(data, label)
+        losses.append(loss)
+    assert (losses[-1].asnumpy() < 1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=170):
+    net = resnet50(num_classes)
+    lr = 0.1
+    momentum = 0.9
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(
+        net_with_criterion, optimizer)  # optimizer
+    train_network.set_train()
+    losses = []
+    for i in range(0, epoch):
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
         loss = train_network(data, label)
         losses.append(loss)
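The new test_trainTensor_big_batchSize case mirrors test_trainTensor but with batch_size=170; presumably the larger batch is meant to drive GPU memory pressure high enough to exercise the memory-swap path touched by this commit, though the diff itself does not say so.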
@@ -351,13 +383,16 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
     net = resnet50(num_classes)
     lr = 0.1
     momentum = 0.9
-    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                net.get_parameters()), lr, momentum)
     criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
-    train_network = amp.build_train_network(net, optimizer, criterion, level="O2")
+    train_network = amp.build_train_network(
+        net, optimizer, criterion, level="O2")
     train_network.set_train()
     losses = []
     for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01)
+        data = Tensor(np.ones([batch_size, 3, 224, 224]
+                              ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
         loss = train_network(data, label)
         losses.append(loss)