fix permission bug when remove read_only_file on windows

This commit is contained in:
buxue 2021-12-07 19:36:39 +08:00
parent 6a5584589d
commit ae256a362b
3 changed files with 38 additions and 14 deletions

View File

@ -39,12 +39,12 @@ py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::di
const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(new_value);
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
"The parameter '"
<< param_name.cast<std::string>() << "' has shape " << old_value->shape()
<< " and dtype " << TypeIdLabel(old_value->data_type())
<< ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
<< TypeIdLabel(new_value->data_type()) << ".";
MS_EXCEPTION(ValueError)
<< "Only support update parameter by Tensor or Parameter with same shape and dtype as it. "
"The parameter '"
<< param_name.cast<std::string>() << "' has shape " << old_value->shape() << " and dtype "
<< TypeIdLabel(old_value->data_type()) << ", but got the update value with shape " << new_value->shape()
<< " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
}
new_param = fn(*new_value);
} else {

View File

@ -14,7 +14,9 @@
# ============================================================================
"""test get and init GraphCell parameters"""
import os
import stat
import numpy as np
import pytest
@ -72,8 +74,10 @@ def get_and_init_graph_cell_parameters():
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
if os.path.isfile(mindir_name):
os.chmod(mindir_name, stat.S_IWUSR)
os.remove(mindir_name)
if os.path.isfile(ckpt_name):
os.chmod(ckpt_name, stat.S_IWUSR)
os.remove(ckpt_name)

View File

@ -57,7 +57,27 @@ def remove_generated_file(file_name):
def test_init_graph_cell_parameters_with_wrong_type():
"""
Description: load mind ir and update parameters with wrong type.
Expectation: raise a ValueError indicating the params type error.
Expectation: raise a ValueError indicating the params_init type error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
mindir_name = "net_0.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
new_params = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)
with pytest.raises(TypeError) as err:
graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "The 'params_init' must be a dict, but got" in str(err.value)
remove_generated_file(mindir_name)
def test_init_graph_cell_parameters_with_wrong_value_type():
"""
Description: load mind ir and update parameters with wrong value type.
Expectation: raise a ValueError indicating the params value type error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
@ -74,10 +94,10 @@ def test_init_graph_cell_parameters_with_wrong_type():
remove_generated_file(mindir_name)
def test_init_graph_cell_parameters_with_wrong_shape():
def test_init_graph_cell_parameters_with_wrong_value_shape():
"""
Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the tensor shape error.
Expectation: raise a ValueError indicating the update value shape error.
"""
context.set_context(mode=context.PYNATIVE_MODE)
net = Net()
@ -90,25 +110,25 @@ def test_init_graph_cell_parameters_with_wrong_shape():
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)
def test_init_graph_cell_parameters_with_wrong_dtype():
def test_init_graph_cell_parameters_with_wrong_value_dtype():
"""
Description: load mind ir and update parameters with wrong tensor dtype.
Expectation: raise a ValueError indicating the tensor dtype error.
Expectation: raise a ValueError indicating the update value dtype error.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net()
mindir_name = "net_3.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
new_params = {"weight": Tensor(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
with pytest.raises(ValueError) as err:
graph = load(mindir_name)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)