diff --git a/mindspore/nn/wrap/__init__.py b/mindspore/nn/wrap/__init__.py index 813c8bf7663..76840c5603c 100644 --- a/mindspore/nn/wrap/__init__.py +++ b/mindspore/nn/wrap/__init__.py @@ -17,7 +17,7 @@ Wrap cells for networks. Use the Wrapper to combine the loss or build the training steps. """ -from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \ +from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \ ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell from .grad_reducer import DistributedGradReducer @@ -27,7 +27,6 @@ __all__ = [ "WithLossCell", "WithGradCell", "WithEvalCell", - "DataWrapper", "GetNextSingleOp", "TrainOneStepWithLossScaleCell", "DistributedGradReducer", diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 86ab01375ef..6d25daf41e4 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -205,45 +205,6 @@ class TrainOneStepCell(Cell): return F.depend(loss, self.optimizer(grads)) -class DataWrapper(Cell): - """ - Network training package class for dataset. - - DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext' - function from the dataset channel 'queue_name' and does forward computation in the construct function. - - Args: - network (Cell): The training network for dataset. - dataset_types (list): The type of dataset. The list contains the types of the inputs. - dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe - the shape of the inputs. - queue_name (str): The identification of dataset channel which specifies the dataset channel to supply - data for the network. - - Outputs: - Tensor, network output whose shape depends on the network. - - Examples: - >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset - >>> train_dataset = create_dataset() - >>> dataset_helper = mindspore.DatasetHelper(train_dataset) - >>> net = Net() - >>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name) - """ - - def __init__(self, network, dataset_types, dataset_shapes, queue_name): - super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) - # Also copy the flag in `network` construct - flags = getattr(network.__class__.construct, "_mindspore_flags", {}) - self.add_flags(**flags) - self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) - self.network = network - - def construct(self): - outputs = self.get_next() - return self.network(*outputs) - - class GetNextSingleOp(Cell): """ Cell to run for getting the next operation. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 5d9b0f57552..0c573968382 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer): Note: The GetNext operation needs to be associated with network and it also depends on the init_dataset interface, it can't be used directly as a single operation. - For details, please refer to `nn.DataWrapper` source code. + For details, please refer to `connect_network_with_dataset` source code. Args: types (list[:class:`mindspore.dtype`]): The type of the outputs.