!6101 delete DataWrapper

Merge pull request !6101 from wangnan39/delete_datawrapper
This commit is contained in:
mindspore-ci-bot 2020-09-14 11:54:50 +08:00 committed by Gitee
commit b4581b2c54
3 changed files with 2 additions and 42 deletions

View File

@ -17,7 +17,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps. Use the Wrapper to combine the loss or build the training steps.
""" """
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \ from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer from .grad_reducer import DistributedGradReducer
@ -27,7 +27,6 @@ __all__ = [
"WithLossCell", "WithLossCell",
"WithGradCell", "WithGradCell",
"WithEvalCell", "WithEvalCell",
"DataWrapper",
"GetNextSingleOp", "GetNextSingleOp",
"TrainOneStepWithLossScaleCell", "TrainOneStepWithLossScaleCell",
"DistributedGradReducer", "DistributedGradReducer",

View File

@ -205,45 +205,6 @@ class TrainOneStepCell(Cell):
return F.depend(loss, self.optimizer(grads)) return F.depend(loss, self.optimizer(grads))
class DataWrapper(Cell):
"""
Network training package class for dataset.
DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
function from the dataset channel 'queue_name' and does forward computation in the construct function.
Args:
network (Cell): The training network for dataset.
dataset_types (list): The type of dataset. The list contains the types of the inputs.
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe
the shape of the inputs.
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
data for the network.
Outputs:
Tensor, network output whose shape depends on the network.
Examples:
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_dataset()
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
>>> net = Net()
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
def construct(self):
outputs = self.get_next()
return self.network(*outputs)
class GetNextSingleOp(Cell): class GetNextSingleOp(Cell):
""" """
Cell to run for getting the next operation. Cell to run for getting the next operation.

View File

@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer):
Note: Note:
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface, The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
it can't be used directly as a single operation. it can't be used directly as a single operation.
For details, please refer to `nn.DataWrapper` source code. For details, please refer to `connect_network_with_dataset` source code.
Args: Args:
types (list[:class:`mindspore.dtype`]): The type of the outputs. types (list[:class:`mindspore.dtype`]): The type of the outputs.