support freeze param

This commit is contained in:
jiangzhenguang 2024-05-14 09:59:19 +08:00
parent 9c3cf20bab
commit f37306010e
2 changed files with 27 additions and 1 deletions

View File

@ -859,7 +859,9 @@ AnfNodePtr RefParameterToActualParameter(const AnfNodePtr &node) {
auto new_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
bool filter = IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) ||
IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
IsPrimitiveCNode(cnode, prim::kPrimCast);
IsPrimitiveCNode(cnode, prim::kPrimCast) ||
(IsPrimitiveCNode(cnode, prim::kPrimAllGather) &&
GetCNodePrimitive(cnode)->instance_name().find(PARALLEL_OPTIMIZER) != std::string::npos);
return std::make_pair(filter, 1);
});
return RefParameterToActualParameter(new_cnode);

View File

@ -90,3 +90,27 @@ def test_pipeline_split_stage1():
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_stage1_save_stra():
    '''
    Feature: pipeline + grad_freeze + stage1 + opt_shard + save_strategy
    Description: In pipeline mode, stage1's param's requires_grad = False and the strategy file is saved with
                 only_trainable_params=False, expected success
    Expectation: success, and "cell.block.1.param" is present in the saved strategy file
    '''
    import os  # local import: only this test needs file-system cleanup

    strategy_file = "./strategy_freeze_stage1.ckpt"
    # Drop any leftover from an earlier run so the final assertion can only pass
    # on a strategy file that was written by this run.
    if os.path.exists(strategy_file):
        os.remove(strategy_file)
    try:
        context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2,
                                          enable_parallel_optimizer=True,
                                          parallel_optimizer_config={"parallel_optimizer_threshold": 1},
                                          strategy_ckpt_config={"save_file": strategy_file,
                                                                "only_trainable_params": False})
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        data = Tensor(np.ones([32, 64]), dtype=ms.float32)
        label = Tensor(np.ones([64, 64]), dtype=ms.float32)
        strategy1 = ((16, 1), (1, 1))
        strategy2 = ((8, 1), (1, 1))
        net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
        params = net.trainable_params()
        dataset = DatasetLenet(data, label, 3)
        optimizer = nn.Lamb(params, learning_rate=0.01)
        model = Model(net, optimizer=optimizer)
        model.train(2, dataset, dataset_sink_mode=False)
        # The strategy was saved with only_trainable_params=False, so the parameter
        # checked below must be recorded in the file.
        stra = ms.build_searched_strategy(strategy_file)
        assert "cell.block.1.param" in stra
    finally:
        # Do not leave the generated strategy file behind in the working directory.
        if os.path.exists(strategy_file):
            os.remove(strategy_file)