fix chinese docs problem
This commit is contained in:
parent
c48385abfb
commit
b21baa366d
|
@ -3,25 +3,84 @@ mindspore.dataset.WaitedDSCallback
|
||||||
|
|
||||||
.. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)
|
.. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)
|
||||||
|
|
||||||
用于自定义与训练回调同步的数据集回调类的抽象基类。
|
数据集自定义回调类的抽象基类,用于与训练回调类(`mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_)的同步。
|
||||||
|
|
||||||
此类可用于自定义在step或epoch结束后执行的回调方法。
|
可用于在每个step或epoch开始前执行自定义的回调方法,注意,第二个step或epoch开始时才会触发该调用。
|
||||||
例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。
|
例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。
|
||||||
|
|
||||||
|
用户可通过 `train_run_context` 获取模型相关信息。如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等,详见 `mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ 。
|
||||||
|
|
||||||
|
用户可通过 `ds_run_context` 获取数据处理管道相关信息。包括 `cur_epoch_num` (当前epoch数)、 `cur_step_num_in_epoch` (当前epoch的step数)、 `cur_step_num` (当前step数)。
|
||||||
|
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
- **step_size** (int, optional) - 每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
|
- **step_size** (int, optional) - 每个step包含的数据行数。通常step_size与batch_size一致,默认值:1。
|
||||||
|
|
||||||
**样例:**
|
**样例:**
|
||||||
|
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
>>> from mindspore.dataset import WaitedDSCallback
|
>>> from mindspore.dataset import WaitedDSCallback
|
||||||
|
>>> from mindspore import context
|
||||||
|
>>> from mindspore.train import Model
|
||||||
|
>>> from mindspore.train.callback import Callback
|
||||||
>>>
|
>>>
|
||||||
>>> my_cb = WaitedDSCallback(32)
|
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
>>> # dataset为任意数据集实例
|
>>>
|
||||||
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
|
>>> # 自定义用于数据处理管道同步数据的回调类
|
||||||
>>> data = data.batch(32)
|
>>> class MyWaitedCallback(WaitedDSCallback):
|
||||||
>>> # 定义网络
|
... def __init__(self, events, step_size=1):
|
||||||
>>> model.train(epochs, data, callbacks=[my_cb])
|
... super().__init__(step_size)
|
||||||
|
... self.events = events
|
||||||
|
...
|
||||||
|
... # epoch开始前数据处理管道要执行的回调函数
|
||||||
|
... def sync_epoch_begin(self, train_run_context, ds_run_context):
|
||||||
|
... event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
|
||||||
|
... self.events.append(event)
|
||||||
|
...
|
||||||
|
... # step开始前数据处理管道要执行的回调函数
|
||||||
|
... def sync_step_begin(self, train_run_context, ds_run_context):
|
||||||
|
... event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
|
||||||
|
... self.events.append(event)
|
||||||
|
>>>
|
||||||
|
>>> # 自定义用于网络训练时同步数据的回调类
|
||||||
|
>>> class MyMSCallback(Callback):
|
||||||
|
... def __init__(self, events):
|
||||||
|
... self.events = events
|
||||||
|
...
|
||||||
|
... # epoch结束网络训练要执行的回调函数
|
||||||
|
... def epoch_end(self, run_context):
|
||||||
|
... cb_params = run_context.original_args()
|
||||||
|
... event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
|
||||||
|
... self.events.append(event)
|
||||||
|
...
|
||||||
|
... # step结束网络训练要执行的回调函数
|
||||||
|
... def step_end(self, run_context):
|
||||||
|
... cb_params = run_context.original_args()
|
||||||
|
... event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
|
||||||
|
... self.events.append(event)
|
||||||
|
>>>
|
||||||
|
>>> # 自定义网络
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
... def construct(self, x, y):
|
||||||
|
... return x
|
||||||
|
>>>
|
||||||
|
>>> # 声明一个网络训练与数据处理同步的数据
|
||||||
|
>>> events = []
|
||||||
|
>>>
|
||||||
|
>>> # 声明数据处理管道和网络训练的回调类
|
||||||
|
>>> my_cb1 = MyWaitedCallback(events, 1)
|
||||||
|
>>> my_cb2 = MyMSCallback(events)
|
||||||
|
>>> arr = [1, 2, 3, 4]
|
||||||
|
>>> # 构建数据处理管道
|
||||||
|
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
|
||||||
|
>>> # 将数据处理管道的回调类加入到map中
|
||||||
|
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
|
||||||
|
>>>
|
||||||
|
>>> net = Net()
|
||||||
|
>>> model = Model(net)
|
||||||
|
>>>
|
||||||
|
>>> # 将数据处理管道和网络训练的回调类加入到模型训练的回调列表中
|
||||||
|
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
|
||||||
|
|
||||||
.. py:method:: begin(run_context)
|
.. py:method:: begin(run_context)
|
||||||
|
|
||||||
|
|
|
@ -10,16 +10,16 @@ mindspore.dataset.deserialize
|
||||||
|
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
- **input_dict** (dict) - 包含序列化数据集图的Python字典。
|
- **input_dict** (dict) - 以Python字典存储的数据处理管道。默认值:None。
|
||||||
- **json_filepath** (str) - JSON文件的路径,用户可通过 `mindspore.dataset.serialize()` 接口生成。
|
- **json_filepath** (str) - 数据处理管道JSON文件的路径,该文件以通用JSON格式存储了数据处理管道信息,用户可通过 `mindspore.dataset.serialize()` 接口生成。默认值:None。
|
||||||
|
|
||||||
**返回:**
|
**返回:**
|
||||||
|
|
||||||
成功时,返回Dataset对象;失败时,则返回None。
|
当反序列化成功时,将返回Dataset对象;当无法被反序列化时,deserialize将会失败,且返回None。
|
||||||
|
|
||||||
**异常:**
|
**异常:**
|
||||||
|
|
||||||
**OSError:** 无法打开JSON文件。
|
- **OSError:** - `json_filepath` 不为None且JSON文件解析失败时。
|
||||||
|
|
||||||
**样例:**
|
**样例:**
|
||||||
|
|
||||||
|
@ -28,9 +28,8 @@ mindspore.dataset.deserialize
|
||||||
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
|
||||||
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
|
||||||
>>> # 用例1:序列化/反序列化 JSON文件
|
>>> # 用例1:序列化/反序列化 JSON文件
|
||||||
>>> ds.engine.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
>>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||||
>>> dataset = ds.engine.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
>>> dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
|
||||||
>>> # 用例2:序列化/反序列化 Python字典
|
>>> # 用例2:序列化/反序列化 Python字典
|
||||||
>>> serialized_data = ds.engine.serialize(dataset)
|
>>> serialized_data = ds.serialize(dataset)
|
||||||
>>> dataset = ds.engine.deserialize(input_dict=serialized_data)
|
>>> dataset = ds.deserialize(input_dict=serialized_data)
|
||||||
|
|
||||||
|
|
|
@ -7,22 +7,51 @@
|
||||||
|
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
- **image** (ndarray): 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
|
- **image** (numpy.ndarray) - 待绘制的图像,shape为(C, H, W)或(H, W, C),通道顺序为RGB。
|
||||||
- **bboxes** (ndarray): 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)。
|
- **bboxes** (numpy.ndarray) - 边界框(包含类别置信度),shape为(N, 4)或(N, 5),格式为(N,X,Y,W,H)。
|
||||||
- **labels** (ndarray): 边界框的类别,shape为(N, 1)。
|
- **labels** (numpy.ndarray) - 边界框的类别,shape为(N, 1)。
|
||||||
- **segm** (ndarray): 图像分割掩码,shape为(M, H, W),M表示类别总数(默认值None,不绘制掩码)。
|
- **segm** (numpy.ndarray) - 图像分割掩码,shape为(M, H, W),M表示类别总数,默认值:None,不绘制掩码。
|
||||||
- **class_names** (list[str], dict): 类别索引到类别名的映射表(默认值None,仅显示类别索引)。
|
- **class_names** (list[str], dict) - 类别索引到类别名的映射表,默认值:None,仅显示类别索引。
|
||||||
- **score_threshold** (float): 绘制边界框的类别置信度阈值(默认值0,绘制所有边界框)。
|
- **score_threshold** (float) - 绘制边界框的类别置信度阈值,默认值:0,绘制所有边界框。
|
||||||
- **bbox_color** (tuple(int)): 指定绘制边界框时线条的颜色,顺序为BGR(默认值(0,255,0),表示'green')。
|
- **bbox_color** (tuple(int)) - 指定绘制边界框时线条的颜色,顺序为BGR,默认值:(0,255,0),表示绿色。
|
||||||
- **text_color** (tuple(int)):指定类别文本的显示颜色,顺序为BGR(默认值(203, 192, 255),表示'pink')。
|
- **text_color** (tuple(int)) - 指定类别文本的显示颜色,顺序为BGR,默认值:(203, 192, 255),表示粉色。
|
||||||
- **mask_color** (tuple(int)):指定掩码的显示颜色,顺序为BGR(默认值(128, 0, 128),表示'purple')。
|
- **mask_color** (tuple(int)) - 指定掩码的显示颜色,顺序为BGR,默认值:(128, 0, 128),表示紫色。
|
||||||
- **thickness** (int): 指定边界框和类别文本的线条粗细(默认值2)。
|
- **thickness** (int) - 指定边界框和类别文本的线条粗细,默认值:2。
|
||||||
- **font_size** (int, float): 指定类别文本字体大小(默认值0.8)。
|
- **font_size** (int, float) - 指定类别文本字体大小,默认值:0.8。
|
||||||
- **show** (bool): 是否显示图像(默认值为True)。
|
- **show** (bool) - 是否显示图像,默认值:True。
|
||||||
- **win_name** (str): 指定窗口名称(默认值"win")。
|
- **win_name** (str) - 指定窗口名称,默认值:"win"。
|
||||||
- **wait_time** (int): 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔(默认值2000,表示间隔为2000ms)。
|
- **wait_time** (int) - 指定cv2.waitKey的时延,单位为ms,即图像显示的自动切换间隔,默认值:2000,表示间隔为2000ms。
|
||||||
- **out_file** (str, optional): 输出图像的文件名,用于在绘制后将结果存储到本地(默认值None,不保存)。
|
- **out_file** (str, optional) - 输出图像的文件路径,用于在绘制后将结果存储到本地,默认值:None,不保存。
|
||||||
|
|
||||||
**返回:**
|
**返回:**
|
||||||
|
|
||||||
ndarray,带边界框和类别置信度的图像。
|
numpy.ndarray,带边界框和类别置信度的图像。
|
||||||
|
|
||||||
|
**样例:**
|
||||||
|
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> from mindspore.dataset.utils.browse_dataset import imshow_det_bbox
|
||||||
|
>>>
|
||||||
|
>>> # 读取VOC数据集.
|
||||||
|
>>> voc_dataset_dir = "/path/to/voc_dataset_directory"
|
||||||
|
>>> dataset = ds.VOCDataset(voc_dataset_dir, task="Detection", shuffle=False, decode=True, num_samples=5)
|
||||||
|
>>> dataset_iter = dataset.create_dict_iterator(output_numpy=True, num_epochs=1)
|
||||||
|
>>>
|
||||||
|
>>> # 调用imshow_det_bbox自动标注图像
|
||||||
|
>>> for index, data in enumerate(dataset_iter):
|
||||||
|
... image = data["image"]
|
||||||
|
... bbox = data["bbox"]
|
||||||
|
... label = data["label"]
|
||||||
|
... # draw image with bboxes
|
||||||
|
... imshow_det_bbox(image, bbox, label,
|
||||||
|
... class_names=['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
|
||||||
|
... 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
|
||||||
|
... 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'],
|
||||||
|
... win_name="my_window",
|
||||||
|
... wait_time=5000,
|
||||||
|
... show=True,
|
||||||
|
... out_file="voc_dataset_{}.jpg".format(str(index)))
|
||||||
|
|
||||||
|
**`imshow_det_bbox` 在VOC2012数据集的使用图示:**
|
||||||
|
|
||||||
|
.. image:: api_img/browse_dataset.png
|
||||||
|
|
Loading…
Reference in New Issue