fix cpu pynative coredump

This commit is contained in:
chujinjin 2020-12-11 14:52:55 +08:00
parent 638fae9677
commit a0680113c4
2 changed files with 8 additions and 1 deletions

View File

@ -1718,6 +1718,11 @@ void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &ro
}
void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors) {
auto ms_context = MsContext::GetInstance();
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == "CPU") {
return;
}
for (const auto &tensor : *useless_tensors) {
tensor->set_device_address(nullptr);
}

View File

@ -131,9 +131,11 @@ class GradWrap(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ascend_pynative_lenet():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
context.set_context(mode=context.PYNATIVE_MODE)
epoch_size = 20
batch_size = 32