support freeze param
This commit is contained in:
parent
9c3cf20bab
commit
f37306010e
|
@ -859,7 +859,9 @@ AnfNodePtr RefParameterToActualParameter(const AnfNodePtr &node) {
|
|||
auto new_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
|
||||
bool filter = IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimCast);
|
||||
IsPrimitiveCNode(cnode, prim::kPrimCast) ||
|
||||
(IsPrimitiveCNode(cnode, prim::kPrimAllGather) &&
|
||||
GetCNodePrimitive(cnode)->instance_name().find(PARALLEL_OPTIMIZER) != std::string::npos);
|
||||
return std::make_pair(filter, 1);
|
||||
});
|
||||
return RefParameterToActualParameter(new_cnode);
|
||||
|
|
|
@ -90,3 +90,27 @@ def test_pipeline_split_stage1():
|
|||
optimizer = nn.Lamb(params, learning_rate=0.01)
|
||||
model = Model(net, optimizer=optimizer)
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
||||
def test_pipeline_split_stage1_save_stra():
|
||||
'''
|
||||
Feature: pipeline + grad_freeze + stage1 + opt_shard + save_strategy
|
||||
Description: In pipeline mode, stage1's param's requires_grad = False, expected success
|
||||
Expectation: success
|
||||
'''
|
||||
context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2, enable_parallel_optimizer=True,
|
||||
parallel_optimizer_config={"parallel_optimizer_threshold": 1},
|
||||
strategy_ckpt_config={"save_file": "./strategy_freeze_stage1.ckpt",
|
||||
"only_trainable_params": False})
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
strategy1 = ((16, 1), (1, 1))
|
||||
strategy2 = ((8, 1), (1, 1))
|
||||
net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
|
||||
params = net.trainable_params()
|
||||
dataset = DatasetLenet(data, label, 3)
|
||||
optimizer = nn.Lamb(params, learning_rate=0.01)
|
||||
model = Model(net, optimizer=optimizer)
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
stra = ms.build_searched_strategy("./strategy_freeze_stage1.ckpt")
|
||||
assert "cell.block.1.param" in stra
|
||||
|
|
Loading…
Reference in New Issue