forked from mindspore-Ecosystem/mindspore
!6101 delete DataWrapper
Merge pull request !6101 from wangnan39/delete_datawrapper
This commit is contained in:
commit
b4581b2c54
|
@ -17,7 +17,7 @@ Wrap cells for networks.
|
|||
|
||||
Use the Wrapper to combine the loss or build the training steps.
|
||||
"""
|
||||
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
|
||||
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
|
||||
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
|
||||
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
|
||||
from .grad_reducer import DistributedGradReducer
|
||||
|
@ -27,7 +27,6 @@ __all__ = [
|
|||
"WithLossCell",
|
||||
"WithGradCell",
|
||||
"WithEvalCell",
|
||||
"DataWrapper",
|
||||
"GetNextSingleOp",
|
||||
"TrainOneStepWithLossScaleCell",
|
||||
"DistributedGradReducer",
|
||||
|
|
|
@ -205,45 +205,6 @@ class TrainOneStepCell(Cell):
|
|||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class DataWrapper(Cell):
|
||||
"""
|
||||
Network training package class for dataset.
|
||||
|
||||
DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
|
||||
function from the dataset channel 'queue_name' and does forward computation in the construct function.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network for dataset.
|
||||
dataset_types (list): The type of dataset. The list contains the types of the inputs.
|
||||
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe
|
||||
the shape of the inputs.
|
||||
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
|
||||
data for the network.
|
||||
|
||||
Outputs:
|
||||
Tensor, network output whose shape depends on the network.
|
||||
|
||||
Examples:
|
||||
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
|
||||
>>> train_dataset = create_dataset()
|
||||
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
|
||||
>>> net = Net()
|
||||
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
|
||||
"""
|
||||
|
||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||
# Also copy the flag in `network` construct
|
||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||
self.add_flags(**flags)
|
||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
self.network = network
|
||||
|
||||
def construct(self):
|
||||
outputs = self.get_next()
|
||||
return self.network(*outputs)
|
||||
|
||||
|
||||
class GetNextSingleOp(Cell):
|
||||
"""
|
||||
Cell to run for getting the next operation.
|
||||
|
|
|
@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer):
|
|||
Note:
|
||||
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
|
||||
it can't be used directly as a single operation.
|
||||
For details, please refer to `nn.DataWrapper` source code.
|
||||
For details, please refer to `connect_network_with_dataset` source code.
|
||||
|
||||
Args:
|
||||
types (list[:class:`mindspore.dtype`]): The type of the outputs.
|
||||
|
|
Loading…
Reference in New Issue