Modify the export_data interface default description.

This commit is contained in:
Margaret_wangrui 2022-11-23 17:48:04 +08:00
parent 293e3fb6dc
commit e95be8f057
9 changed files with 62 additions and 20 deletions

View File

@ -68,9 +68,10 @@ void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
map_tensor->Update(data); map_tensor->Update(data);
} }
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor, bool full) { std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
bool incremental) {
MS_EXCEPTION_IF_NULL(map_tensor); MS_EXCEPTION_IF_NULL(map_tensor);
auto data = map_tensor->Export(full); auto data = map_tensor->Export(incremental);
return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor), return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
TensorPy::AsNumpy(*data.status_tensor)); TensorPy::AsNumpy(*data.status_tensor));
} }
@ -120,14 +121,12 @@ void RegMapTensor(py::module *m) {
const py::object &permit_filter_obj, const py::object &evict_filter_obj) { const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor); auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor); auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
auto key_dtype_id = key_tensor_ptr->Dtype()->type_id(); auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
auto value_dtype = key_tensor_ptr->Dtype(); auto value_dtype = key_tensor_ptr->Dtype();
auto value_dtype_id = value_tensor_ptr->Dtype()->type_id();
auto value_shape = value_tensor_ptr->shape();
ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype); ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj); ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj); ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value, return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
permit_filter_value, evict_filter_value); permit_filter_value, evict_filter_value);
}), }),
py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"), py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),

View File

@ -729,7 +729,7 @@ bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &ma
MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString(); MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
return false; return false;
} }
tensor::MapTensor::ExportData export_data = map_tensor->Export(!this->incremental_); tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
// key_tensor // key_tensor
auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor(); auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
MS_EXCEPTION_IF_NULL(key_tensor_proto); MS_EXCEPTION_IF_NULL(key_tensor_proto);

View File

@ -142,14 +142,39 @@ MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &devic
return {key_tensor(), value_tensor(), status_tensor()}; return {key_tensor(), value_tensor(), status_tensor()};
} }
MapTensor::ExportData MapTensor::Export(bool full) { // If the data on the host side is valid, the data on the host side will be exported.
MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor"; bool MapTensor::CheckData() {
// check key
if (key_tensor()->shape().size() != 1 || key_tensor()->shape()[0] < 1) {
MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor()->shape());
return false;
}
// check value
bool check_value =
std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
if (check_value) {
MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
return false;
}
// check status
if (status_tensor()->shape().size() != 1 || status_tensor()->shape()[0] < 1) {
MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(status_tensor()->shape());
return false;
}
return true;
}
MapTensor::ExportData MapTensor::Export(bool incremental) {
MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
// Check device // Check device
DeviceSyncPtr device_sync = device_address(); DeviceSyncPtr device_sync = device_address();
if (device_sync != nullptr) { if (device_sync != nullptr) {
return ExportDataFromDevice(device_sync); return ExportDataFromDevice(device_sync);
} }
if (CheckData()) {
return {key_tensor(), value_tensor(), status_tensor()};
}
// Note: this is fake implementation. // Note: this is fake implementation.
ShapeVector key_shape = {1}; ShapeVector key_shape = {1};
ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape()); ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());

View File

@ -169,9 +169,9 @@ class MS_CORE_API MapTensor final : public Tensor {
/// \brief Exported MapTensor data. /// \brief Exported MapTensor data.
/// ///
/// \param[in] full [bool] True for full export, false for incremental export. /// \param[in] incremental [bool] False for incremental export, true for full export.
/// \return The exported data. /// \return The exported data.
ExportData Export(bool full = false); ExportData Export(bool incremental = false);
/// \brief Exported MapTensor data from device. /// \brief Exported MapTensor data from device.
/// ///
@ -206,6 +206,8 @@ class MS_CORE_API MapTensor final : public Tensor {
void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; } void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
bool CheckData();
private: private:
// Data type of the keys. // Data type of the keys.
TypeId key_dtype_; TypeId key_dtype_;

View File

@ -241,17 +241,20 @@ class MapParameter(Parameter):
self._map_tensor.erase(key_tensor) self._map_tensor.erase(key_tensor)
return self return self
def export_data(self, full=False): def export_data(self, incremental=False):
""" """
Export data from this map parameter. Export data from this map parameter.
Args: Args:
full (bool): True for full export, otherwise for incremental export. Default: False. incremental (bool): False for full export, otherwise for incremental export. Default: False.
When exporting data incrementally, the value_array does not contain erased data, so the length of the
key_array and the length of the value_array may be inconsistent.The length of the key_array and the length
of the status_array are consistent.
Returns: Returns:
Tuple(key_array, value_array, status_array), The exported data as a tuple. Tuple(key_array, value_array, status_array), The exported data as a tuple.
""" """
return self._map_tensor.export_data(full) return self._map_tensor.export_data(incremental)
def import_data(self, data): def import_data(self, data):
""" """

View File

@ -849,4 +849,5 @@ def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
Outputs: Outputs:
MapTensor, the map tensor be updated. MapTensor, the map tensor be updated.
""" """
return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor) _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
return map_tensor

View File

@ -1426,7 +1426,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:] map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
if map_param_name in net_dict.keys(): if map_param_name in net_dict.keys():
map_parameter = net_dict[map_param_name] map_parameter = net_dict[map_param_name]
key_nparr, value_nparr, status_nparr = map_parameter.export_data(not incremental) key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
map_param_proto.key_tensor.raw_data = key_nparr.tobytes() map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
map_param_proto.value_tensor.raw_data = value_nparr.tobytes() map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
map_param_proto.status_tensor.raw_data = status_nparr.tobytes() map_param_proto.status_tensor.raw_data = status_nparr.tobytes()

View File

@ -68,7 +68,7 @@ def test_maptensor_put_get_export(ms_type):
self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32) self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
def construct(self, ms_type): def construct(self, ms_type):
self.m.put(self.keys, self.values) self.m[self.keys] = self.values
key1 = Tensor([3], dtype=ms_type) key1 = Tensor([3], dtype=ms_type)
value1 = self.m.get(key1, True) value1 = self.m.get(key1, True)
key2 = Tensor([4], dtype=ms_type) key2 = Tensor([4], dtype=ms_type)
@ -84,3 +84,5 @@ def test_maptensor_put_get_export(ms_type):
print("out3:", out3) print("out3:", out3)
data = net.m.export_data() data = net.m.export_data()
print("data:", data) print("data:", data)
assert len(data) == 3
assert len(data[0]) == 4

View File

@ -100,9 +100,19 @@ def test_export_update_api():
Description: Test export update api for MapParameter. Description: Test export update api for MapParameter.
Expectation: Export update api works as expected. Expectation: Export update api works as expected.
""" """
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,)) m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data = m.export_data(full=True) data1 = m1.export_data(incremental=False)
m.import_data(data) print("data1:", data1)
m1.import_data(data1)
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
data2 = m2.export_data(incremental=False)
print("data2:", data2)
m1.import_data(data2)
new_data1 = m1.export_data(incremental=False)
print("new_data1:", new_data1)
def test_map_parameter_clone(): def test_map_parameter_clone():