forked from mindspore-Ecosystem/mindspore
!6101 delete DataWrapper
Merge pull request !6101 from wangnan39/delete_datawrapper
This commit is contained in:
commit
b4581b2c54
|
@ -17,7 +17,7 @@ Wrap cells for networks.
|
||||||
|
|
||||||
Use the Wrapper to combine the loss or build the training steps.
|
Use the Wrapper to combine the loss or build the training steps.
|
||||||
"""
|
"""
|
||||||
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
|
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
|
||||||
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
|
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
|
||||||
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
|
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
|
||||||
from .grad_reducer import DistributedGradReducer
|
from .grad_reducer import DistributedGradReducer
|
||||||
|
@ -27,7 +27,6 @@ __all__ = [
|
||||||
"WithLossCell",
|
"WithLossCell",
|
||||||
"WithGradCell",
|
"WithGradCell",
|
||||||
"WithEvalCell",
|
"WithEvalCell",
|
||||||
"DataWrapper",
|
|
||||||
"GetNextSingleOp",
|
"GetNextSingleOp",
|
||||||
"TrainOneStepWithLossScaleCell",
|
"TrainOneStepWithLossScaleCell",
|
||||||
"DistributedGradReducer",
|
"DistributedGradReducer",
|
||||||
|
|
|
@ -205,45 +205,6 @@ class TrainOneStepCell(Cell):
|
||||||
return F.depend(loss, self.optimizer(grads))
|
return F.depend(loss, self.optimizer(grads))
|
||||||
|
|
||||||
|
|
||||||
class DataWrapper(Cell):
|
|
||||||
"""
|
|
||||||
Network training package class for dataset.
|
|
||||||
|
|
||||||
DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
|
|
||||||
function from the dataset channel 'queue_name' and does forward computation in the construct function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
network (Cell): The training network for dataset.
|
|
||||||
dataset_types (list): The type of dataset. The list contains the types of the inputs.
|
|
||||||
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe
|
|
||||||
the shape of the inputs.
|
|
||||||
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
|
|
||||||
data for the network.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, network output whose shape depends on the network.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
|
|
||||||
>>> train_dataset = create_dataset()
|
|
||||||
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
|
|
||||||
>>> net = Net()
|
|
||||||
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
|
||||||
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
|
||||||
# Also copy the flag in `network` construct
|
|
||||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
|
||||||
self.add_flags(**flags)
|
|
||||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
|
||||||
self.network = network
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
outputs = self.get_next()
|
|
||||||
return self.network(*outputs)
|
|
||||||
|
|
||||||
|
|
||||||
class GetNextSingleOp(Cell):
|
class GetNextSingleOp(Cell):
|
||||||
"""
|
"""
|
||||||
Cell to run for getting the next operation.
|
Cell to run for getting the next operation.
|
||||||
|
|
|
@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer):
|
||||||
Note:
|
Note:
|
||||||
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
|
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
|
||||||
it can't be used directly as a single operation.
|
it can't be used directly as a single operation.
|
||||||
For details, please refer to `nn.DataWrapper` source code.
|
For details, please refer to `connect_network_with_dataset` source code.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
types (list[:class:`mindspore.dtype`]): The type of the outputs.
|
types (list[:class:`mindspore.dtype`]): The type of the outputs.
|
||||||
|
|
Loading…
Reference in New Issue