forked from mindspore-Ecosystem/mindspore
fix permission bug when remove read_only_file on windows
This commit is contained in:
parent
6a5584589d
commit
ae256a362b
|
@ -39,12 +39,12 @@ py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::di
|
|||
const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_value);
|
||||
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
|
||||
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
|
||||
"The parameter '"
|
||||
<< param_name.cast<std::string>() << "' has shape " << old_value->shape()
|
||||
<< " and dtype " << TypeIdLabel(old_value->data_type())
|
||||
<< ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
|
||||
<< TypeIdLabel(new_value->data_type()) << ".";
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "Only support update parameter by Tensor or Parameter with same shape and dtype as it. "
|
||||
"The parameter '"
|
||||
<< param_name.cast<std::string>() << "' has shape " << old_value->shape() << " and dtype "
|
||||
<< TypeIdLabel(old_value->data_type()) << ", but got the update value with shape " << new_value->shape()
|
||||
<< " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
|
||||
}
|
||||
new_param = fn(*new_value);
|
||||
} else {
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
# ============================================================================
|
||||
|
||||
"""test get and init GraphCell parameters"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -72,8 +74,10 @@ def get_and_init_graph_cell_parameters():
|
|||
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
|
||||
|
||||
if os.path.isfile(mindir_name):
|
||||
os.chmod(mindir_name, stat.S_IWUSR)
|
||||
os.remove(mindir_name)
|
||||
if os.path.isfile(ckpt_name):
|
||||
os.chmod(ckpt_name, stat.S_IWUSR)
|
||||
os.remove(ckpt_name)
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,27 @@ def remove_generated_file(file_name):
|
|||
def test_init_graph_cell_parameters_with_wrong_type():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong type.
|
||||
Expectation: raise a ValueError indicating the params type error.
|
||||
Expectation: raise a ValueError indicating the params_init type error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_0.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
new_params = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)
|
||||
with pytest.raises(TypeError) as err:
|
||||
graph = load(mindir_name)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "The 'params_init' must be a dict, but got" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
||||
|
||||
|
||||
def test_init_graph_cell_parameters_with_wrong_value_type():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong value type.
|
||||
Expectation: raise a ValueError indicating the params value type error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
|
@ -74,10 +94,10 @@ def test_init_graph_cell_parameters_with_wrong_type():
|
|||
remove_generated_file(mindir_name)
|
||||
|
||||
|
||||
def test_init_graph_cell_parameters_with_wrong_shape():
|
||||
def test_init_graph_cell_parameters_with_wrong_value_shape():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor shape.
|
||||
Expectation: raise a ValueError indicating the tensor shape error.
|
||||
Expectation: raise a ValueError indicating the update value shape error.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = Net()
|
||||
|
@ -90,25 +110,25 @@ def test_init_graph_cell_parameters_with_wrong_shape():
|
|||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
||||
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
||||
|
||||
|
||||
def test_init_graph_cell_parameters_with_wrong_dtype():
|
||||
def test_init_graph_cell_parameters_with_wrong_value_dtype():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor dtype.
|
||||
Expectation: raise a ValueError indicating the tensor dtype error.
|
||||
Expectation: raise a ValueError indicating the update value dtype error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_3.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
||||
new_params = {"weight": Tensor(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
||||
with pytest.raises(ValueError) as err:
|
||||
graph = load(mindir_name)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
||||
assert "Only support update parameter by Tensor or Parameter with same shape and dtype as it" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
||||
|
|
Loading…
Reference in New Issue