forked from mindspore-Ecosystem/mindspore
!46496 adapt jit config for cell wrappers
Merge pull request !46496 from changzherui/mod_jit_config
This commit is contained in: e5556217c7
@@ -16,5 +16,10 @@ mindspore.JitConfig
     - "O2": Manual optimization combined with graph-kernel fusion.
     - "O3": Performance optimization; generalization is not guaranteed.
 
-    - **task_sink** (bool) - Whether data sinks directly to the processor for execution. Default: True.
+    - **exc_mode** (str) - Sets the execution mode. Supports ["auto", "sink", "no_sink"]. Default: "auto".
+
+      - "auto": Automatic strategy.
+      - "sink": Graph sink strategy.
+      - "no_sink": No graph sink strategy.
 
     - **kwargs** (dict) - A dictionary of keyword arguments.
@@ -31,7 +31,12 @@ class JitConfig:
             - "O2": Manual optimization and graph computation fusion.
             - "O3": Performance optimization, no generalization guaranteed.
 
-        task_sink (bool): Determines whether to pass the data through the dataset channel. Default: True.
+        exc_mode (str): Execution mode of the network. Supports ["auto", "sink", "no_sink"]. Default: "auto".
+
+            - "auto": Automatic strategy.
+            - "sink": Build computational graphs with the sink mode.
+            - "no_sink": Build computational graphs without the sink mode.
 
         **kwargs (dict): A dictionary of keyword arguments that the class needs.
 
     Examples:
@@ -42,13 +47,13 @@ class JitConfig:
         >>>
         >>> net.set_jit_config(jitconfig)
         """
-    def __init__(self, jit_level="O1", task_sink=True, **kwargs):
+    def __init__(self, jit_level="O1", exc_mode="auto", **kwargs):
         if jit_level not in ["O0", "O1", "O2", "O3"]:
             raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2', 'O3'].")
-        if not isinstance(task_sink, bool):
-            raise TypeError("For 'task_sink' must be bool.")
+        if exc_mode not in ['auto', 'sink', 'no_sink']:
+            raise ValueError("For 'exc_mode' must be one of ['auto', 'sink', 'no_sink'].")
         self.jit_config_dict = dict()
         self.jit_config_dict["jit_level"] = jit_level
-        self.jit_config_dict["task_sink"] = str(int(task_sink))
+        self.jit_config_dict["exc_mode"] = exc_mode
         for key, value in kwargs.items():
             self.jit_config_dict[key] = value
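For orientation, a minimal usage sketch of the reworked interface (the nn.Dense backbone is an assumed stand-in; only JitConfig and set_jit_config come from this change):

import mindspore as ms
from mindspore import nn

# The old task_sink=True corresponds to exc_mode="sink"; the new default
# "auto" leaves the sink decision to the backend.
jitconfig = ms.JitConfig(jit_level="O1", exc_mode="sink")
net = nn.Dense(128, 768)
net.set_jit_config(jitconfig)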
@@ -110,7 +110,7 @@ class WithLossCell(Cell):
         super(WithLossCell, self).__init__(auto_prefix=False)
         self._backbone = backbone
         self._loss_fn = loss_fn
-        if backbone.jit_config_dict:
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
             self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, data, label):
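The same isinstance guard recurs in every wrapper cell touched below: the wrapped object's jit config is copied up only when that object really is a Cell, since a backbone or loss_fn may be a plain callable with no jit_config_dict attribute. The idiom in isolation (a sketch; the wrapper name is hypothetical):

from mindspore import nn

class _Wrapper(nn.Cell):
    """Hypothetical wrapper showing the jit-config propagation idiom."""
    def __init__(self, backbone):
        super(_Wrapper, self).__init__(auto_prefix=False)
        self._backbone = backbone
        # Guard both ways: plain callables lack jit_config_dict entirely,
        # and an empty dict (no config ever set) is not worth propagating.
        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
            self._jit_config_dict = backbone.jit_config_dict

    def construct(self, *inputs):
        return self._backbone(*inputs)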
@@ -182,6 +182,8 @@ class WithGradCell(Cell):
         else:
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         weights = self.weights
@@ -282,6 +284,8 @@ class ForwardValueAndGrad(Cell):
         self.get_by_list = get_by_list
         self.sens_param = sens_param
         self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         grad_inputs = inputs
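ForwardValueAndGrad assembles its GradOperation from the three flags above. A compact sketch of the underlying primitive (the Dense network and input are assumptions for illustration):

import numpy as np
import mindspore as ms
from mindspore import nn, ops

net = nn.Dense(4, 1)
weights = ms.ParameterTuple(net.trainable_params())
grad_op = ops.GradOperation(get_by_list=True)  # differentiate w.r.t. weights

x = ms.Tensor(np.ones((2, 4), np.float32))
value = net(x)                    # forward value
grads = grad_op(net, weights)(x)  # gradients for each parameter in 'weights'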
@@ -375,7 +379,7 @@ class TrainOneStepCell(Cell):
                                                          group=server_group_name)
             else:
                 self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
-        if network.jit_config_dict:
+        if isinstance(network, Cell) and network.jit_config_dict:
             self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
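Taken together, the guards mean a config set once on the backbone survives the usual wrapping chain. A sketch under assumed layer sizes and optimizer settings:

import mindspore as ms
from mindspore import nn

backbone = nn.Dense(128, 768)
backbone.set_jit_config(ms.JitConfig(jit_level="O2"))

loss_net = nn.WithLossCell(backbone, nn.MSELoss())
opt = nn.Momentum(backbone.trainable_params(), learning_rate=0.01, momentum=0.9)
train_net = nn.TrainOneStepCell(loss_net, opt)
# Each wrapper copies jit_config_dict from the cell it wraps, so train_net
# still compiles with the backbone's jit_level="O2".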
@@ -455,6 +459,8 @@ class _VirtualDatasetCell(Cell):
         super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
         self._backbone = backbone
         self._virtual_dataset = _VirtualDataset()
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, *inputs):
         output = self._virtual_dataset(*inputs)
@@ -547,6 +553,8 @@ class MicroBatchInterleaved(Cell):
             interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
             interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
             self.interleave_inputs.append(interleave_data)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         output = 0.0
@@ -585,6 +593,8 @@ class PipelineCell(Cell):
             self.micro_inputs.append(micro_input)
             self.add = P.Add().add_prim_attr("pipeline_end", i)
             self.add_list.append(self.add)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         ret = None
@@ -613,6 +623,8 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell):
         self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
         self.hyper_map = ops.HyperMap()
         self.opt_shard = _get_enable_parallel_optimizer()
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, *inputs):
         weights = self.weights
@@ -654,6 +666,8 @@ class VirtualDatasetCellTriple(Cell):
         super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
         logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
         self._backbone = backbone
+        if isinstance(backbone, Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, a, b, c):
         return self._backbone(a, b, c)
@@ -696,6 +710,8 @@ class WithEvalCell(Cell):
         self._network = network
         self._loss_fn = loss_fn
         self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self, data, label):
         outputs = self._network(data)
@@ -43,6 +43,8 @@ class SymbolTreeBuilder:
         network_str = inspect.getsource(type(network))
         self._ast_root: ast.Module = ast.parse(network_str)
         self._root_tree: Optional[SymbolTree] = None
+        if isinstance(network, Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     @staticmethod
     def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree):
@@ -70,22 +70,23 @@ AMP_BLACK_LIST = (
 
 class _OutputTo16(nn.Cell):
     """Wrap cell for amp. Cast network output back to float16."""
 
-    def __init__(self, op):
+    def __init__(self, backbone):
         super(_OutputTo16, self).__init__(auto_prefix=False)
-        self._op = op
+        self._backbone = backbone
+        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, x):
-        return F.cast(self._op(x), mstype.float16)
+        return F.cast(self._backbone(x), mstype.float16)
 
 
 class _OutputTo32(nn.Cell):
-    "Wrap loss for amp. Cast network output back to float32"
-
+    """Wrap loss for amp. Cast network output back to float32."""
     def __init__(self, backbone):
         super(_OutputTo32, self).__init__(auto_prefix=False)
         self._backbone = backbone
-        self._jit_config_dict = backbone._jit_config_dict
+        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
+            self._jit_config_dict = backbone.jit_config_dict
 
     def construct(self, *inputs):
         out = self._backbone(*inputs)
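A minimal sketch of what these amp wrappers accomplish, assuming a toy float16 backbone (class and variable names here are illustrative, not the module's own):

import numpy as np
import mindspore as ms
from mindspore import nn, ops

class OutputTo32(nn.Cell):
    """Run the wrapped cell, then cast its output back to float32."""
    def __init__(self, backbone):
        super(OutputTo32, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self.cast = ops.Cast()
        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
            self._jit_config_dict = backbone.jit_config_dict

    def construct(self, *inputs):
        return self.cast(self._backbone(*inputs), ms.float32)

net = OutputTo32(nn.Dense(4, 2).to_float(ms.float16))
out = net(ms.Tensor(np.ones((1, 4), np.float32)))
print(out.dtype)  # Float32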
@@ -331,13 +332,12 @@ def _add_loss_network(network, loss_fn, cast_model_type):
     """Add loss network."""
 
     class WithLossCell(nn.Cell):
-        "Wrap loss for amp. Cast network output back to float32"
-
+        """Wrap loss for amp. Cast network output back to float32."""
         def __init__(self, backbone, loss_fn):
             super(WithLossCell, self).__init__(auto_prefix=False)
             self._backbone = backbone
             self._loss_fn = loss_fn
-            if backbone.jit_config_dict:
+            if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
                 self._jit_config_dict = backbone.jit_config_dict
 
         def construct(self, data, label):
@@ -95,6 +95,8 @@ class _DataWrapper(nn.Cell):
         self.get_next = P.GetNext(
             dataset_types, dataset_shapes, len(dataset_types), queue_name)
         self.network = network
+        if isinstance(network, nn.Cell) and network.jit_config_dict:
+            self._jit_config_dict = network.jit_config_dict
 
     def construct(self):
         outputs = self.get_next()
@@ -83,7 +83,7 @@ def test_sink():
                              dataset_strategy="data_parallel", device_num=8)
     data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
     dataset = ds.NumpySlicesDataset(data=data)
-    jitconfig = JitConfig(jit_level="O1", task_sink=True)
+    jitconfig = JitConfig(jit_level="O1", exc_mode='auto')
     sink_process = ms.train.data_sink(dense_func, dataset, sink_size=4, jit_config=jitconfig)
     _ = sink_process()
@@ -110,7 +110,7 @@ def test_sink_with_grad():
                              dataset_strategy="data_parallel", device_num=8)
     data = {"input": np.ones([16, 32, 128]).astype(np.float32), "label": np.zeros([16, 32, 768]).astype(np.float32)}
     dataset = ds.NumpySlicesDataset(data=data)
-    jitconfig = JitConfig(jit_level="O1", task_sink=True)
+    jitconfig = JitConfig(jit_level="O1", exc_mode='no_sink')
     sink_process = ms.train.data_sink(train_step, dataset, sink_size=4, jit_config=jitconfig)
     _ = sink_process()
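For completeness, a self-contained sketch of the data_sink pattern both tests exercise (dense_func here is an assumed stand-in for the tests' network):

import numpy as np
import mindspore as ms
from mindspore import nn
import mindspore.dataset as ds

dense = nn.Dense(128, 768)

def dense_func(inputs, label):
    # Forward pass plus a scalar loss, mirroring the tests' setup.
    return ((dense(inputs) - label) ** 2).mean()

data = {"input": np.ones([16, 32, 128]).astype(np.float32),
        "label": np.zeros([16, 32, 768]).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data=data)

jitconfig = ms.JitConfig(jit_level="O1", exc_mode="auto")
sink_process = ms.train.data_sink(dense_func, dataset, sink_size=4, jit_config=jitconfig)
_ = sink_process()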