add switch_simplify pass to a2

This commit is contained in:
chenfei 2021-07-09 14:41:59 +08:00
parent 0768df5144
commit ff066a0df3
2 changed files with 21 additions and 21 deletions

View File

@ -285,6 +285,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = GetOptPassA1(irpass);
opt::OptPassConfig a_2 = opt::OptPassConfig(
{
irpass.switch_simplify_,
irpass.cast_eliminate_,
irpass.specialize_transform_,
irpass.merge_addn_,

View File

@ -25,7 +25,6 @@ from mindspore.train.model import Model
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple
context.set_context(mode=context.GRAPH_MODE)
@ -87,11 +86,11 @@ def _count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_me)*rtol)
greater = np.greater(error, atol + np.abs(data_me) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count/total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
format(data_expected[greater], data_me[greater], error[greater])
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
@ -115,10 +114,10 @@ class ControlGraphSupportNotEqual(Cell):
else:
out2 = input_data / input_data
if x == z:
out3_f = (lambda a: a+a)
out3_f = (lambda a: a + a)
out3 = out3_f(input_data)
else:
out3_f = (lambda a: a+a+a)
out3_f = (lambda a: a + a + a)
out3 = out3_f(input_data)
return out, out2, out3
@ -175,15 +174,15 @@ class ControlBprop(Cell):
else:
out2 = input_data / input_data
if x == z:
out3_f = (lambda a: a+a)
out3_f = (lambda a: a + a)
out3 = out3_f(input_data)
else:
out3_f = (lambda a: a+a+a)
out3_f = (lambda a: a + a + a)
out3 = out3_f(input_data)
return out, out2, out3
def bprop(self, x, y, z, input_data, out, dout):
return x*2, y*3, z, input_data*5.1
return x * 2, y * 3, z, input_data * 5.1
@pytest.mark.level1
@ -199,10 +198,10 @@ def test_ctrl_if_while_bprop_true():
grad_net = GradOfAllInputs(net, sens_param=False)
grad_net.set_train()
grads = grad_net(Tensor(x), Tensor(y), Tensor(x), Tensor(input_data))
allclose_nparray(x*2, grads[0].asnumpy(), 0.0000, 0.0000)
allclose_nparray(y*3, grads[1].asnumpy(), 0.0000, 0.0000)
allclose_nparray(x * 2, grads[0].asnumpy(), 0.0000, 0.0000)
allclose_nparray(y * 3, grads[1].asnumpy(), 0.0000, 0.0000)
allclose_nparray(x, grads[2].asnumpy(), 0.0000, 0.0000)
allclose_nparray(input_data*5.1, grads[3].asnumpy(), 0.0000, 0.0000)
allclose_nparray(input_data * 5.1, grads[3].asnumpy(), 0.0000, 0.0000)
class TwoInput(Cell):
@ -234,7 +233,7 @@ class InlineBpropTwoInput1(Cell):
grads = self.grad(x, y)
else:
grads = self.grad(x, y)
return grads[0]*2, grads[1]*2
return grads[0] * 2, grads[1] * 2
@pytest.mark.level1
@ -248,8 +247,8 @@ def test_ctrl_if_while_bprop_inlinebprop_twoinput():
grad_net = GradOfAllInputs(net, sens_param=False)
grad_net.set_train()
grads = grad_net(input1, input2)
allclose_nparray(input1.asnumpy()*2, grads[1].asnumpy(), 0, 0)
allclose_nparray(input2.asnumpy()*2, grads[0].asnumpy(), 0, 0)
allclose_nparray(input1.asnumpy() * 2, grads[1].asnumpy(), 0, 0)
allclose_nparray(input2.asnumpy() * 2, grads[0].asnumpy(), 0, 0)
class ControlOneIfOneParaOneAddn(Cell):
@ -467,7 +466,7 @@ class SideEffectPrintInHighOrdeAddnNet(Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_side_effect_high_order_print_in_high_order_net():
print_file = os.getcwd()+"/test_side_effect_high_order_print_in_high_order_net.data"
print_file = os.getcwd() + "/test_side_effect_high_order_print_in_high_order_net.data"
context.set_context(print_file_path=print_file)
net = SideEffectPrintInHighOrdeAddnNet()
out1 = net(Tensor([9.0], ms.float32))
@ -617,7 +616,7 @@ class HighGrad(Cell):
def __init__(self, network, grad_list, sens_param=False, real_inputs_count=None):
super().__init__()
self.grads = [network]
for i in range(len(grad_list)-1):
for i in range(len(grad_list) - 1):
_grad = grad_list[i](self.grads[i], sens_param=False)
self.grads.append(_grad)
self.final_grad = grad_list[-1](self.grads[-1],
@ -676,10 +675,10 @@ class SideEffectControlFlowAssignDependWhileNet(Cell):
return grad_out
@pytest.mark.level1
# Now the case can't pass because the GPU RT problem, so only run on Ascend current time.
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_side_effect_grad_control_flow_assign_depend_while_net():
context.set_context(mode=context.GRAPH_MODE)