forked from mindspore-Ecosystem/mindspore
!35639 modify ckpt file chmod
Merge pull request !35639 from changzherui/mod_chmod_file
This commit is contained in:
commit
8a90d74860
|
@ -2352,7 +2352,6 @@ void DfGraphConvertor::ConvertTile(const FuncGraphPtr anf_graph) {
|
|||
|
||||
std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value ptr is nullptr.";
|
||||
return {};
|
||||
}
|
||||
std::vector<int64_t> cur_value = {};
|
||||
|
|
|
@ -189,6 +189,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|||
else:
|
||||
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
||||
|
||||
|
||||
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
||||
"""Execute the process of saving checkpoint into file."""
|
||||
try:
|
||||
|
@ -234,7 +235,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|||
cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode)
|
||||
f.write(cipher_data)
|
||||
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
|
||||
except BaseException as e:
|
||||
logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
|
||||
|
|
|
@ -82,7 +82,8 @@ def setup_module():
|
|||
|
||||
def test_save_graph():
|
||||
""" test_exec_save_graph """
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
class Net1(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net1, self).__init__()
|
||||
|
@ -108,7 +109,7 @@ def test_save_graph():
|
|||
|
||||
def test_save_checkpoint_for_list():
|
||||
""" test save_checkpoint for list"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
parameter_list = []
|
||||
one_param = {}
|
||||
param1 = {}
|
||||
|
@ -137,7 +138,7 @@ def test_load_checkpoint_error_filename():
|
|||
Description: Load checkpoint with error filename.
|
||||
Expectation: Raise value error for error filename.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
ckpt_file_name = 1
|
||||
with pytest.raises(TypeError):
|
||||
load_checkpoint(ckpt_file_name)
|
||||
|
@ -218,7 +219,7 @@ def test_save_checkpoint_for_list_append_info_and_load_checkpoint():
|
|||
Description: Save checkpoint for list append info and load checkpoint with list append info.
|
||||
Expectation: Checkpoint for list append info can be saved and reloaded.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
parameter_list = []
|
||||
one_param = {}
|
||||
param1 = {}
|
||||
|
@ -243,9 +244,9 @@ def test_save_checkpoint_for_list_append_info_and_load_checkpoint():
|
|||
par_dict = load_checkpoint(ckpt_file_name)
|
||||
|
||||
assert len(par_dict) == 9
|
||||
assert par_dict['param_test'].name == 'param_test'
|
||||
assert par_dict['param_test'].data.dtype == mstype.float32
|
||||
assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
|
||||
assert par_dict.get('param_test').name == 'param_test'
|
||||
assert par_dict.get('param_test').data.dtype == mstype.float32
|
||||
assert par_dict.get('param_test').data.shape == (1, 3, 224, 224)
|
||||
|
||||
assert par_dict.get('par_string') == "string_test"
|
||||
assert par_dict.get('par_param').name == 'par_param'
|
||||
|
@ -258,7 +259,7 @@ def test_save_checkpoint_for_list_append_info_and_load_checkpoint():
|
|||
|
||||
def test_checkpoint_manager():
|
||||
""" test_checkpoint_manager """
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
ckp_mgr = _CheckpointManager()
|
||||
|
||||
ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
|
||||
|
@ -307,7 +308,7 @@ def test_checkpoint_manager():
|
|||
|
||||
|
||||
def test_load_param_into_net_error_net():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
parameter_dict = {}
|
||||
one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
|
||||
name="conv1.weight")
|
||||
|
@ -317,14 +318,14 @@ def test_load_param_into_net_error_net():
|
|||
|
||||
|
||||
def test_load_param_into_net_error_dict():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
with pytest.raises(TypeError):
|
||||
load_param_into_net(net, '')
|
||||
|
||||
|
||||
def test_load_param_into_net_erro_dict_param():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -338,7 +339,7 @@ def test_load_param_into_net_erro_dict_param():
|
|||
|
||||
def test_load_param_into_net_has_more_param():
|
||||
""" test_load_param_into_net_has_more_param """
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -355,7 +356,7 @@ def test_load_param_into_net_has_more_param():
|
|||
|
||||
|
||||
def test_load_param_into_net_param_type_and_shape_error():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -368,7 +369,7 @@ def test_load_param_into_net_param_type_and_shape_error():
|
|||
|
||||
|
||||
def test_load_param_into_net_param_type_error():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -382,7 +383,7 @@ def test_load_param_into_net_param_type_error():
|
|||
|
||||
|
||||
def test_load_param_into_net_param_shape_error():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -396,7 +397,7 @@ def test_load_param_into_net_param_shape_error():
|
|||
|
||||
|
||||
def test_load_param_into_net():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
net.init_parameters_data()
|
||||
assert net.conv1.weight.data.asnumpy()[0][0][0][0] == 0
|
||||
|
@ -411,7 +412,7 @@ def test_load_param_into_net():
|
|||
|
||||
def test_save_checkpoint_for_network():
|
||||
""" test save_checkpoint for network"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
|
||||
|
@ -424,7 +425,7 @@ def test_save_checkpoint_for_network():
|
|||
|
||||
|
||||
def test_load_checkpoint_empty_file():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
os.mknod("empty.ckpt")
|
||||
with pytest.raises(ValueError):
|
||||
load_checkpoint("empty.ckpt")
|
||||
|
@ -524,7 +525,7 @@ def test_load_checkpoint_specify_filter_prefix():
|
|||
|
||||
def test_save_and_load_checkpoint_for_network_with_encryption():
|
||||
""" test save and checkpoint for network with encryption"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
|
||||
|
@ -569,7 +570,7 @@ class MYNET(nn.Cell):
|
|||
|
||||
@non_graph_engine
|
||||
def test_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MYNET()
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -578,7 +579,7 @@ def test_export():
|
|||
|
||||
@non_graph_engine
|
||||
def test_mindir_export():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MYNET()
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
|
||||
export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")
|
||||
|
@ -586,7 +587,7 @@ def test_mindir_export():
|
|||
|
||||
@non_graph_engine
|
||||
def test_mindir_export_and_load_with_encryption():
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MYNET()
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
|
||||
key = secrets.token_bytes(16)
|
||||
|
@ -594,7 +595,6 @@ def test_mindir_export_and_load_with_encryption():
|
|||
load("./me_cipher_binary_export.mindir", dec_key=key)
|
||||
|
||||
|
||||
|
||||
class PrintNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PrintNet, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue