!6101 delete DataWrapper

Merge pull request !6101 from wangnan39/delete_datawrapper
This commit is contained in:
mindspore-ci-bot 2020-09-14 11:54:50 +08:00 committed by Gitee
commit b4581b2c54
3 changed files with 2 additions and 42 deletions

View File

@ -17,7 +17,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
@ -27,7 +27,6 @@ __all__ = [
"WithLossCell",
"WithGradCell",
"WithEvalCell",
"DataWrapper",
"GetNextSingleOp",
"TrainOneStepWithLossScaleCell",
"DistributedGradReducer",

View File

@ -205,45 +205,6 @@ class TrainOneStepCell(Cell):
return F.depend(loss, self.optimizer(grads))
class DataWrapper(Cell):
"""
Network training package class for dataset.
DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
function from the dataset channel 'queue_name' and does forward computation in the construct function.
Args:
network (Cell): The training network for dataset.
dataset_types (list): The type of dataset. The list contains the types of the inputs.
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe
the shape of the inputs.
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
data for the network.
Outputs:
Tensor, network output whose shape depends on the network.
Examples:
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_dataset()
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
>>> net = Net()
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
def construct(self):
outputs = self.get_next()
return self.network(*outputs)
class GetNextSingleOp(Cell):
"""
Cell to run for getting the next operation.

View File

@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer):
Note:
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
it can't be used directly as a single operation.
For details, please refer to `nn.DataWrapper` source code.
For details, please refer to `connect_network_with_dataset` source code.
Args:
types (list[:class:`mindspore.dtype`]): The type of the outputs.