# mindspore/tests/ut/python/pynative_mode/test_hook.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C
# Every case in this file runs in PyNative (eager) mode on the GPU backend.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# Flags flipped by cell_hook_function, var_hook_function and Block.bprop;
# test_hook asserts on them to prove those callbacks actually ran.
cell_hook_done = var_hook_done = cell_bprop_done = False

# Shared gradient operator configured with get_all=True; reused by
# test_custom_bprop and test_grad_all.
grad_all = C.GradOperation('get_all', get_all=True)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free ``nn.Conv2d`` layer using the "valid" pad mode.

    The weight initializer comes from ``weight_variable()``.
    """
    init = weight_variable()
    layer = nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      pad_mode="valid",
                      has_bias=False,
                      weight_init=init)
    return layer
def fc_with_initialize(input_channels, out_channels):
    """Build an ``nn.Dense`` layer whose weight and bias both use ``weight_variable()``."""
    return nn.Dense(input_channels, out_channels,
                    weight_variable(),   # weight initializer
                    weight_variable())   # bias initializer
def weight_variable():
    """Return a fresh ``TruncatedNormal`` initializer built with 0.02."""
    sigma = 0.02  # passed positionally as TruncatedNormal's sigma
    return TruncatedNormal(sigma)
def cell_hook_function(cell_id, grad_input, grad_output):
    """Backward hook registered on ``LeNet5.conv2``.

    Prints the cell id, records that the hook fired via the module flag
    ``cell_hook_done``, and checks the gradient shapes for a batch of 32
    1x32x32 images.
    """
    global cell_hook_done
    print(cell_id)
    cell_hook_done = True

    # conv1 (k=5, valid) maps 32x32 -> 28x28 and max-pool halves it to 14x14,
    # so conv2's input is (32, 6, 14, 14); conv2 (k=5, valid) outputs
    # (32, 16, 10, 10).
    expected_grad_output = (32, 6, 14, 14)  # shape of conv2's input
    expected_grad_input = (32, 16, 10, 10)  # shape of conv2's output
    assert grad_output[0].asnumpy().shape == expected_grad_output
    assert grad_input[0].asnumpy().shape == expected_grad_input
def var_hook_function(grad_out):
    """Callback for ``P.HookBackward`` applied to ``fc1``'s output in ``LeNet5``.

    Prints the incoming gradients, records that it ran via the module flag
    ``var_hook_done``, and checks the first gradient has fc1's output shape
    (batch of 32, 120 features).
    """
    global var_hook_done
    print("grad:", grad_out)
    var_hook_done = True
    first_grad_shape = grad_out[0].asnumpy().shape
    assert first_grad_shape == (32, 120)
class Block(nn.Cell):
    """ReLU cell with a hand-written ``bprop`` used to exercise custom backprop.

    ``bprop`` records that it ran via the module flag ``cell_bprop_done`` and
    returns ``out * dout`` (computed in NumPy) as the gradient for ``x``.

    NOTE(review): ``out * dout`` is not the mathematical ReLU derivative (that
    would be ``(x > 0) * dout``); test_hook only checks that this bprop runs,
    not its numeric result.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        global cell_bprop_done
        cell_bprop_done = True
        # Element-wise product on host arrays, wrapped back into a Tensor.
        return (Tensor(out.asnumpy() * dout.asnumpy()),)
class LeNet5(nn.Cell):
    """
    LeNet-5 network instrumented with backward hooks for testing.

    Instrumentation:
        * ``conv2`` carries the cell backward hook ``cell_hook_function``;
        * ``block`` (a ``Block`` with a custom ``bprop``) follows ``conv2``;
        * ``hook`` (``P.HookBackward(var_hook_function)``) wraps ``fc1``'s output.

    The layer sizes are fixed for inputs of shape ``(32, 1, 32, 32)``:
    ``construct`` reshapes to the hard-coded ``batch_size`` (32) and ``fc1``
    expects ``16 * 5 * 5`` flattened features.

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor of shape ``(32, num_class)``.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.block = Block()
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        # Block stands in for the usual ReLU after conv2 so its custom bprop runs.
        x = self.block(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        # Identity in the forward pass; var_hook_function receives fc1's gradient.
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
class GradWrap(nn.Cell):
    """Wrap ``network`` and return gradients w.r.t. its trainable parameters.

    Calling the wrapper with ``(x, label)`` returns the gradients of
    ``network(x, label)`` with respect to every parameter whose
    ``requires_grad`` is True.
    """

    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
        # Build the gradient operator once rather than on every construct()
        # call, mirroring the module-level grad_all.
        self.grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

    def construct(self, x, label):
        weights = self.weights
        return self.grad_by_list(self.network, weights)(x, label)
def test_hook():
    """Run one training step of LeNet5 and check every hook / custom bprop fired."""
    global cell_hook_done, var_hook_done, cell_bprop_done
    # Reset the flags so the assertions below reflect this run only.
    cell_hook_done = False
    var_hook_done = False
    cell_bprop_done = False

    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()

    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    # input_data is already a Tensor; pass it directly instead of re-wrapping.
    output = net(input_data)
    loss_output = criterion(output, label)
    # Computing gradients is what should trigger the backward hooks and Block.bprop.
    grads = train_network(input_data, label)
    optimizer(grads)

    assert cell_hook_done
    assert var_hook_done
    assert cell_bprop_done

    print(loss_output.asnumpy())
# Set to True by MulAdd.bprop; test_custom_bprop asserts on it.
bprop_debug = False


class MulAdd(nn.Cell):
    """Computes ``2 * x * x + y * y`` and supplies its own ``bprop``."""

    def __init__(self):
        super().__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    def bprop(self, x, y, out, dout):
        global bprop_debug
        bprop_debug = True
        # NOTE(review): these are not the analytic gradients of construct
        # (4 * x * dout and 2 * y * dout); presumably deliberate, since the
        # test only checks that this bprop was invoked.
        grad_x = dout
        grad_y = 2 * y
        return grad_x, grad_y
def test_custom_bprop():
    """grad_all on a cell that defines ``bprop`` must invoke that bprop."""
    x = Tensor(np.array([1, 2, 3], dtype=np.int32))
    y = Tensor(np.array([2, 3, 4], dtype=np.int32))
    mul_add = MulAdd()
    # Cell attribute, distinct from the module-level flag of the same name.
    mul_add.bprop_debug = True
    grad_all(mul_add)(x, y)
    assert bprop_debug
class Net(nn.Cell):
    """Computes ``2 * x * x + y * y``; no custom bprop, so autodiff is used."""

    def __init__(self):
        super().__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y
def test_grad_all():
    """grad_all on Net returns the analytic gradients 4*x and 2*y."""
    net = Net()
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
    res = grad_all(net)(x, y)
    print(res)
    # With the default all-ones sensitivity:
    #   d(2*x*x + y*y)/dx = 4*x = [4, 8, 12]
    #   d(2*x*x + y*y)/dy = 2*y = [4, 6, 8]
    # Integer inputs make the comparison exact.
    assert np.array_equal(res[0].asnumpy(), np.array([4, 8, 12]))
    assert np.array_equal(res[1].asnumpy(), np.array([4, 6, 8]))
def test_check_input():
    """Calling a Cell with raw NumPy arrays instead of Tensors raises TypeError."""
    net = Net()
    raw_args = (np.array([1, 2, 3]), np.array([2, 3, 4]))
    with pytest.raises(TypeError):
        net(*raw_args)