forked from mindspore-Ecosystem/mindspore

add dynamic shape support
commit 48e688c166 (parent 68cb63d7f6)
@@ -76,6 +76,12 @@ PYBIND_REGISTER(
         THROW_IF_ERROR(de.GetOutputTypes(&out));
         return out;
       })
+    .def("GetDataInfo",
+         [](DEPipeline &de) {
+           py::list types, shapes;
+           THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
+           return py::make_tuple(types, shapes);
+         })
     .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
     .def("GetBatchSize", &DEPipeline::GetBatchSize)
     .def("GetNumClasses", &DEPipeline::GetNumClasses)
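Note: the new `GetDataInfo` binding returns a `(types, shapes)` pair describing the batch most recently pushed to the device. A minimal consumption sketch from the Python side, assuming `depipeline` is an existing `DEPipeline` handle as used by the iterator code later in this commit:

    # Sketch only, not part of this commit.
    types, shapes = depipeline.GetDataInfo()  # blocks until the next batch's info is available
    for np_type, shape in zip(types, shapes):
        print(np_type, list(shape))           # per-column NumPy dtype and shape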
@@ -241,6 +241,30 @@ Status DEPipeline::GetNextAsList(py::list *output) {
   return Status::OK();
 }

+Status DEPipeline::GetDataInfo(py::list *types, py::list *shapes) {
+  Status s;
+  DATA_INFO data_info;
+  // tree_.root() must be DeviceQueueOp
+  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
+  if (op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo only supported by DeviceQueueOp");
+  }
+  {
+    py::gil_scoped_release gil_release;
+    s = op->GetDataInfo(&data_info);
+  }
+  RETURN_IF_NOT_OK(s);
+  for (auto el : data_info) {
+    types->append(el.first.AsNumpyType());
+    py::list shape;
+    for (auto dim : el.second.AsVector()) {
+      shape.append(dim);
+    }
+    shapes->append(shape);
+  }
+  return Status::OK();
+}
+
 Status DEPipeline::GetOutputShapes(py::list *output) {
   std::vector<TensorShape> shapes;
   Status s;
@@ -1070,6 +1094,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
       (void)builder->SetSendEpochEnd(ToBool(value));
     } else if (key == "total_batch") {
       (void)builder->SetTotalBatch(ToInt(value));
+    } else if (key == "create_data_info_queue") {
+      (void)builder->SetCreateDataInfoQueue(ToBool(value));
     }
   }
 }
@@ -111,6 +111,8 @@ class DEPipeline {

   Status GetOutputTypes(py::list *output);

+  Status GetDataInfo(py::list *types, py::list *shapes);
+
   Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);

   int GetDatasetSize() const;
@@ -33,7 +33,7 @@
 namespace mindspore {
 namespace dataset {
 DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
-                             bool send_epoch_end, int32_t total_batch)
+                             bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
     : PipelineOp(1),
       channel_name_(channel_name),
       device_type_(device_type),
@@ -41,7 +41,8 @@ DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, i
       prefetch_size_(prefetch_size),
       send_epoch_end_(send_epoch_end),
       stop_send_(false),
-      total_batch_(total_batch) {
+      total_batch_(total_batch),
+      create_data_info_queue_(create_data_info_queue) {
 #ifdef ENABLE_TDTQUE
   ascend_keep_waiting_ = true;
 #endif
@@ -87,6 +88,10 @@ Status DeviceQueueOp::operator()() {

   if (device_type_ == DeviceType::Ascend) {
 #ifdef ENABLE_TDTQUE
+    if (create_data_info_queue_) {
+      data_info_queue_ptr_ = std::make_unique<DATA_INFO_QUEUE>(kDataInfoQueueCapacity);
+      RETURN_IF_NOT_OK(data_info_queue_ptr_->Register(tree_->AllTasks()));
+    }
     RETURN_IF_NOT_OK(SendDataToAscend());
 #endif
   } else if (device_type_ == DeviceType::GPU) {
@@ -142,6 +147,13 @@ Status DeviceQueueOp::SendDataToAscend() {
         return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
       }
     }
+    if (create_data_info_queue_) {
+      DATA_INFO data_info;
+      (void)std::transform(
+        currRow.begin(), currRow.end(), std::back_inserter(data_info),
+        [](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
+      RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
+    }

     if (isProfilingEnable) {
       end_time = ProfilingTime::GetCurMilliSecond();
@@ -157,6 +169,7 @@ Status DeviceQueueOp::SendDataToAscend() {
         profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch + 1, connector_size);
       }
       send_batch++;
+
       if (total_batch_ > 0 && send_batch >= total_batch_) {
         is_break_loop = true;
         break;
@@ -196,6 +209,21 @@ Status DeviceQueueOp::SendDataToAscend() {

   return Status::OK();
 }
+
+#endif
+
+#ifdef ENABLE_TDTQUE
+Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
+  if (!create_data_info_queue_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
+  }
+  RETURN_IF_NOT_OK(data_info_queue_ptr_->PopFront(data_info));
+  return Status::OK();
+}
+#else
+Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
+  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
+}
 #endif

 #ifdef ENABLE_GPUQUE
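Note: `GetDataInfo` blocks in `PopFront` until `SendDataToAscend` has pushed the `(type, shape)` pairs for the batch it just sent; the Python binding above releases the GIL around this call so Python threads keep running. A rough Python analogue of the hand-off, with `queue.Queue` standing in for `DATA_INFO_QUEUE` (an assumption: `Add`/`PopFront` behave like a bounded, blocking FIFO):

    import queue

    data_info_queue = queue.Queue(maxsize=128)       # kDataInfoQueueCapacity

    def send_batch(batch):                           # plays the role of SendDataToAscend
        info = [(t.dtype, t.shape) for t in batch]   # one (type, shape) pair per tensor
        data_info_queue.put(info)                    # Add(): blocks when the queue is full

    def get_data_info():                             # plays the role of GetDataInfo
        return data_info_queue.get()                 # PopFront(): blocks until info arrives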
@@ -18,6 +18,7 @@

 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
@@ -25,6 +26,7 @@
 #include "minddata/dataset/util/status.h"

 #ifdef ENABLE_TDTQUE
+#include "minddata/dataset/util/queue.h"
 #include "minddata/dataset/engine/tdt/tdt_plugin.h"
 #endif

@@ -37,6 +39,10 @@ using mindspore::device::GpuBufferMgr;

 namespace mindspore {
 namespace dataset {
+
+using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
+using DATA_INFO_QUEUE = Queue<DATA_INFO>;
+const int kDataInfoQueueCapacity = 128;
 class DeviceQueueOp : public PipelineOp {
  public:
   static const uint32_t INVALID_HANDLE = 0xffffffffUL;
@@ -91,13 +97,18 @@ class DeviceQueueOp : public PipelineOp {
       return *this;
     }

+    Builder &SetCreateDataInfoQueue(bool create_data_info_queue) {
+      builder_create_data_info_queue_ = create_data_info_queue;
+      return *this;
+    }
+
     // Name: Build()
     // Description: The final step for building a DeviceQueueOp via the Builder is
     // to call this Build() method. It will instantiate the DeviceQueueOp
     // and return it to caller as a shared pointer.
     Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
       *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
-                                             builder_prefetch_size_, builder_send_epoch_end_, builder_total_batch_);
+                                             builder_prefetch_size_, builder_send_epoch_end_, builder_total_batch_,
+                                             builder_create_data_info_queue_);
       return Status::OK();
     }

@@ -108,12 +119,13 @@ class DeviceQueueOp : public PipelineOp {
     std::string builder_channel_name_;
     bool builder_send_epoch_end_;
     int32_t builder_total_batch_;
+    bool builder_create_data_info_queue_;
   };

   // Name: constructor
   // Description
   DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
-                bool send_epoch_end, int32_t total_batch);
+                bool send_epoch_end, int32_t total_batch, bool create_data_info_queue);

   // Name: destructor
   // Description
@@ -138,6 +150,8 @@ class DeviceQueueOp : public PipelineOp {
   void StopWaiting() { ascend_keep_waiting_ = false; }
 #endif

+  Status GetDataInfo(DATA_INFO *data_info);
+
   // Name: Print()
   // Description: A function that prints info about the node
   void Print(std::ostream &out,  // In: The output stream to print to
@@ -170,6 +184,7 @@ class DeviceQueueOp : public PipelineOp {
 #ifdef ENABLE_TDTQUE
   Status SendDataToAscend();
   bool ascend_keep_waiting_;
+
 #endif

 #ifdef ENABLE_GPUQUE
@@ -190,6 +205,8 @@ class DeviceQueueOp : public PipelineOp {
   const bool send_epoch_end_;
   bool stop_send_;
   int32_t total_batch_;
+  bool create_data_info_queue_;
+  std::unique_ptr<DATA_INFO_QUEUE> data_info_queue_ptr_;

 #ifdef ENABLE_TDTQUE
   std::shared_ptr<TdtPlugin> tdtInstancePtr;
@@ -62,9 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
   } else if (device_type_ == "Ascend") {
     type = DeviceQueueOp::DeviceType::Ascend;
   }
-  node_ops.push_back(
-    std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
+  node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
+                                                     total_batch_, false));
   return node_ops;
 }

@@ -1005,7 +1005,7 @@ class Dataset:
         return dataset

     @check_device_send
-    def device_que(self, prefetch_size=None, send_epoch_end=True):
+    def device_que(self, prefetch_size=None, send_epoch_end=True, create_data_info_queue=False):
         """
         Return a transferred Dataset that transfers data through a device.

@@ -1013,6 +1013,8 @@ class Dataset:
             prefetch_size (int, optional): Prefetch number of records ahead of the
                 user's request (default=None).
             send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
+            create_data_info_queue (bool, optional): Whether to create a queue that stores
+                the types and shapes of the data (default=False).

         Note:
             If device is Ascend, features of data will be transferred one by one. The limitation
@@ -1021,15 +1023,17 @@ class Dataset:
         Return:
             TransferDataset, dataset for transferring.
         """
-        return self.to_device(send_epoch_end=send_epoch_end)
+        return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)

     @check_device_send
-    def to_device(self, send_epoch_end=True):
+    def to_device(self, send_epoch_end=True, create_data_info_queue=False):
         """
         Transfer data through CPU, GPU or Ascend devices.

         Args:
             send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
+            create_data_info_queue (bool, optional): Whether to create a queue that stores
+                the types and shapes of the data (default=False).

         Note:
             If device is Ascend, features of data will be transferred one by one. The limitation
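Note: with the flag threaded through `device_que` and `to_device`, a sink-mode user can opt in to per-batch shape reporting. A hedged usage sketch (`ds` is assumed to be an existing MindSpore dataset; dataset construction is elided):

    # Usage sketch, not part of this commit.
    transfer = ds.device_que(send_epoch_end=False, create_data_info_queue=True)
    transfer.send()                            # start pushing batches to the device
    types, shapes = transfer.get_data_info()   # info for the batch just sent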
@@ -1078,7 +1082,7 @@ class Dataset:

         distribution_path, device_id = get_distribution(self)
         if distribution_path == "":
-            return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
+            return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue)
         try:
             with open(distribution_path, 'r') as distribution_f:
                 dist = json.load(distribution_f)
@@ -1088,7 +1092,7 @@ class Dataset:
         except Exception:
             raise RuntimeError("Failed to read Distribution file.")

-        return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
+        return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue)

     @check_save
     def save(self, file_name, num_files=1, file_type='mindrecord'):
@@ -2640,9 +2644,12 @@ class TransferDataset(DatasetOp):
         device_id (int): ID of device.
         device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
         send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
+        create_data_info_queue (bool, optional): Whether to create a queue that stores
+            the types and shapes of the data (default=False).
     """

-    def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
+    def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True,
+                 create_data_info_queue=False):
         super().__init__()
         self.children.append(input_dataset)
         input_dataset.parent.append(self)
@@ -2652,6 +2659,7 @@ class TransferDataset(DatasetOp):
         self._device_id = device_id
         self._send_epoch_end = send_epoch_end
         self.iterator = None
+        self._create_data_info_queue = create_data_info_queue

     def get_args(self):
         args = super().get_args()
@@ -2661,6 +2669,7 @@ class TransferDataset(DatasetOp):
         args["send_epoch_end"] = self._send_epoch_end
         if hasattr(self.children[0], "__total_batch__"):
             args["total_batch"] = self.children[0].__total_batch__
+        args["create_data_info_queue"] = self._create_data_info_queue
         return args

     def create_dict_iterator(self, num_epochs=-1, output_numpy=False):
@@ -2692,6 +2701,9 @@ class TransferDataset(DatasetOp):
     def continue_send(self):
         self.iterator.depipeline.ContinueSend()

+    def get_data_info(self):
+        return self.iterator.depipeline.GetDataInfo()
+

 class RangeDataset(MappableDataset):
     """
@@ -50,7 +50,7 @@ def _get_types_and_shapes(dataset):
     return dataset_types, dataset_shapes


-def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
+def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
     """Initialize and execute the dataset graph."""
     batch_size = exec_dataset.get_batch_size()
     input_indexs = exec_dataset.input_indexs
@@ -58,7 +58,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
     send_epoch_end = bool(dataset_size == -1)
-    exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)
+    exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)

     _executor.init_dataset(exec_dataset.queue_name,
                            dataset_size,
@@ -17,6 +17,7 @@ import math
 import os

 from mindspore._checkparam import Validator
+from mindspore.common.dtype import pytype_to_dtype
 from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
 from ..nn.wrap import GetNextSingleOp
@@ -31,6 +32,7 @@ def _send_data(dataset, epoch_num):
     exec_dataset.send(epoch_num)
     dataset.__has_sent__ = True

+
 def _send_data_no_flag(dataset, epoch_num):
     """Engine dataset to write data to tdt queue directly."""
     exec_dataset = dataset.__transfer_dataset__
@@ -70,6 +72,7 @@ def connect_network_with_dataset(network, dataset_helper):
     Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
     dataset channel 'queue_name' and performs the forward computation.
     """
+
     def __init__(self, network, dataset_types, dataset_shapes, queue_name):
         super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
         # Also copy the flag in `network` construct
@@ -88,9 +91,30 @@ def connect_network_with_dataset(network, dataset_helper):
     if isinstance(dataset_iter, _DatasetIterNormal):
         raise RuntimeError("Dataset should be connected with network only in sink mode.")

-    if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend"
-                                                  or context.get_context("device_target") == "GPU") and not \
-            context.get_context("enable_ge"):
+    if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
+            and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
+            and context.get_context("device_target") == "Ascend":
+
+        if not hasattr(dataset, '__network__'):
+            dataset.__network__ = network
+        network = dataset.__network__
+
+        dataset_types, dataset_shapes = dataset_helper.get_data_info()
+        dataset_types = [pytype_to_dtype(x) for x in dataset_types]
+
+        key = str(dataset_types) + str(dataset_shapes)
+        if hasattr(dataset, '__network_manage__') and key in dataset.__network_manage__:
+            network = dataset.__network_manage__[key]
+        else:
+            network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
+            dataset.__network_manage__ = dataset.__network_manage__ if hasattr(
+                dataset, '__network_manage__') else dict()
+            dataset.__network_manage__[key] = network
+
+        return network
+
+    if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend" or \
+            context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"):
         dataset.__me_inited__ = True

         dataset_types, dataset_shapes = dataset_helper.types_shapes()
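Note: the rewritten branch keeps one compiled `_DataWrapper` per distinct `(types, shapes)` signature, keyed by `str(dataset_types) + str(dataset_shapes)`, so a recurring shape reuses its graph while a new shape triggers exactly one new wrap. The caching idea in isolation (sketch; `make_wrapper` is a hypothetical stand-in for the `_DataWrapper` construction):

    def wrapper_for(dataset, network, dataset_types, dataset_shapes, make_wrapper):
        key = str(dataset_types) + str(dataset_shapes)          # shape signature
        cache = getattr(dataset, '__network_manage__', {})
        if key not in cache:
            cache[key] = make_wrapper(network, dataset_types, dataset_shapes)
        dataset.__network_manage__ = cache
        return cache[key]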
@@ -99,7 +123,6 @@ def connect_network_with_dataset(network, dataset_helper):
         network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
     return network

-
 class DatasetHelper:
     """
     DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
@@ -171,18 +194,25 @@ class DatasetHelper:
         """continue send data to device at the beginning of epoch."""
         self.iter.continue_send()

+    def get_data_info(self):
+        return self.iter.get_data_info()
+

 class _DatasetIter:
     """Base iter for dataset helper"""

     def __init__(self, dataset, sink_size, epoch_num):
         self.dataset = dataset
         self.sink_size = sink_size
-        self.sink_count = 1
+        self.sink_count = self.get_sink_count(dataset)

         if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
                 self.sink_size = dataset.__loop_size__
-            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)
+            create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
+                "device_target") == "Ascend")
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
+                                                           create_data_info_queue=create_data_info_queue)

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset, epoch_num)
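Note: the data-info queue is enabled only in the configuration this commit targets: one step per sink call (`sink_size == 1`), a single sink count, and an Ascend target, so the shapes of each batch can be fetched before each step. Downstream, `get_data_info` is a plain delegation chain, sketched here (constructor arguments passed positionally, matching the `DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)` call seen in Model below):

    # Delegation sketch: DatasetHelper.get_data_info -> _DatasetIter.get_data_info
    # -> TransferDataset.get_data_info -> DEPipeline.GetDataInfo -> DATA_INFO_QUEUE.
    helper = DatasetHelper(dataset, True, 1, 1)
    types, shapes = helper.get_data_info()   # blocks until the next batch is sent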
@@ -191,6 +221,7 @@ class _DatasetIter:

         self.stop_send = dataset.__transfer_dataset__.stop_send
         self.continue_send = dataset.__transfer_dataset__.continue_send
+        self.get_data_info = dataset.__transfer_dataset__.get_data_info
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

     def __iter__(self):
@@ -223,7 +254,7 @@ class _DatasetIter:
             sink_size = self.dataset.__loop_size__
         else:
             if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \
                     or context.get_context("device_target") == "GPU":
                 if self.sink_size > 0:
                     sink_size = self.sink_size
                 else:
@@ -233,6 +264,7 @@ class _DatasetIter:

 class _DatasetIterGE(_DatasetIter):
     """Iter for GE."""

     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
@@ -249,6 +281,7 @@ class _DatasetIterGE(_DatasetIter):

 class _DatasetIterMSLoopSink(_DatasetIter):
     """Iter for context (device_target=Ascend)"""

     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
@@ -270,6 +303,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):

 class _DatasetIterMS(_DatasetIter):
     """Iter for MS(enable_loop_sink=False)."""

     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         if sink_size > 0:
@@ -283,11 +317,13 @@ class _DatasetIterMS(_DatasetIter):

 class _DatasetIterPSLite(_DatasetIter):
     """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""

     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = 1
         self.sink_size = 1
         self.op = None

         def op():
             return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
         self.op = op
@@ -250,11 +250,14 @@ class Model:
             scaling_sens /= self._device_number
         return scaling_sens

-    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
+    def _exec_preprocess(self, network, is_train, phase, dataset,
+                         dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
         """Initializes dataset."""
         if dataset_sink_mode and not is_train:
             dataset.__loop_size__ = 1
-        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
+
+        if dataset_helper is None:
+            dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

         if dataset_sink_mode:
             network = connect_network_with_dataset(network, dataset_helper)
@@ -405,15 +408,6 @@ class Model:
             epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
             train_dataset.__total_batch__ = epoch * sink_size

-        dataset_helper, train_network = self._exec_preprocess(self._train_network,
-                                                              is_train=True,
-                                                              phase='train',
-                                                              dataset=train_dataset,
-                                                              dataset_sink_mode=True,
-                                                              sink_size=sink_size,
-                                                              epoch_num=epoch_num)
-        self._train_network = train_network
-        cb_params.train_network = self._train_network
         cb_params.cur_step_num = 0

         run_context = RunContext(cb_params)
@@ -421,9 +415,21 @@ class Model:

         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
+        dataset_helper = None
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
             list_callback.epoch_begin(run_context)
+            dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                                  is_train=True,
+                                                                  phase='train',
+                                                                  dataset=train_dataset,
+                                                                  dataset_sink_mode=True,
+                                                                  sink_size=sink_size,
+                                                                  epoch_num=epoch_num,
+                                                                  dataset_helper=dataset_helper)
+
+            self._train_network = train_network
+            cb_params.train_network = self._train_network

             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
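Note: moving `_exec_preprocess` into the epoch loop lets `connect_network_with_dataset` re-query `get_data_info` and swap in the right per-shape network each epoch, while passing `dataset_helper=dataset_helper` back in ensures the helper (and the underlying data transfer) is built only once. Condensed control flow (sketch; `_exec_preprocess` is internal API, shown for illustration only):

    dataset_helper = None
    for i in range(epoch):
        # First iteration builds the DatasetHelper and starts the send;
        # later iterations reuse it and only refresh the network wrapping.
        dataset_helper, train_network = self._exec_preprocess(
            self._train_network, is_train=True, phase='train',
            dataset=train_dataset, dataset_sink_mode=True,
            sink_size=sink_size, epoch_num=epoch_num,
            dataset_helper=dataset_helper)
        for inputs in dataset_helper:
            ...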
@@ -133,7 +133,7 @@ def tokenize_lambada(file_path):
     with open(file_path, 'r', encoding='utf-8') as f:
         for line in f.readlines():
             para = json.loads(line)['text'].replace(
-                "“", '""').replace("”", '"').strip().strip(".")
+                "“", '"').replace("”", '"').strip().strip(".")
             tokenized_text = tokenizer.tokenize(para)
             content += tokenizer.convert_tokens_to_ids(tokenized_text) + [EOT]
     for chunk in chunks(content, SEQ_LEN):
@@ -50,7 +50,7 @@ class MindData:
     def input_indexs(self):
         return self._input_indexs

-    def device_que(self, send_epoch_end=True):
+    def device_que(self, send_epoch_end=True, create_data_info_queue=False):
         self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
         self.send_epoch_end = send_epoch_end
         return self
@@ -61,6 +61,9 @@ class MindData:
     def send(self, num_epochs=-1):
         pass

+    def get_data_info(self):
+        pass
+
     def stop_send(self):
         pass
