Modify the export_data interface default description.

This commit is contained in:
Margaret_wangrui 2022-11-23 17:48:04 +08:00
parent 293e3fb6dc
commit e95be8f057
9 changed files with 62 additions and 20 deletions

View File

@ -68,9 +68,10 @@ void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
map_tensor->Update(data);
}
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor, bool full) {
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
bool incremental) {
MS_EXCEPTION_IF_NULL(map_tensor);
auto data = map_tensor->Export(full);
auto data = map_tensor->Export(incremental);
return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
TensorPy::AsNumpy(*data.status_tensor));
}
@ -120,14 +121,12 @@ void RegMapTensor(py::module *m) {
const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
auto key_dtype_id = key_tensor_ptr->Dtype()->type_id();
auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
auto value_dtype = key_tensor_ptr->Dtype();
auto value_dtype_id = value_tensor_ptr->Dtype()->type_id();
auto value_shape = value_tensor_ptr->shape();
ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value,
return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
permit_filter_value, evict_filter_value);
}),
py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),

View File

@ -729,7 +729,7 @@ bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &ma
MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
return false;
}
tensor::MapTensor::ExportData export_data = map_tensor->Export(!this->incremental_);
tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
// key_tensor
auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
MS_EXCEPTION_IF_NULL(key_tensor_proto);

View File

@ -142,14 +142,39 @@ MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &devic
return {key_tensor(), value_tensor(), status_tensor()};
}
MapTensor::ExportData MapTensor::Export(bool full) {
MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
// If the data on the host side is valid, the data on the host side will be exported.
bool MapTensor::CheckData() {
// check key
if (key_tensor()->shape().size() != 1 || key_tensor()->shape()[0] < 1) {
MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor()->shape());
return false;
}
// check value
bool check_value =
std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
if (check_value) {
MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
return false;
}
// check status
if (status_tensor()->shape().size() != 1 || status_tensor()->shape()[0] < 1) {
MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(status_tensor()->shape());
return false;
}
return true;
}
MapTensor::ExportData MapTensor::Export(bool incremental) {
MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
// Check device
DeviceSyncPtr device_sync = device_address();
if (device_sync != nullptr) {
return ExportDataFromDevice(device_sync);
}
if (CheckData()) {
return {key_tensor(), value_tensor(), status_tensor()};
}
// Note: this is fake implementation.
ShapeVector key_shape = {1};
ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());

View File

@ -169,9 +169,9 @@ class MS_CORE_API MapTensor final : public Tensor {
/// \brief Exported MapTensor data.
///
/// \param[in] full [bool] True for full export, false for incremental export.
/// \param[in] incremental [bool] False for incremental export, true for full export.
/// \return The exported data.
ExportData Export(bool full = false);
ExportData Export(bool incremental = false);
/// \brief Exported MapTensor data from device.
///
@ -206,6 +206,8 @@ class MS_CORE_API MapTensor final : public Tensor {
void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
bool CheckData();
private:
// Data type of the keys.
TypeId key_dtype_;

View File

@ -241,17 +241,20 @@ class MapParameter(Parameter):
self._map_tensor.erase(key_tensor)
return self
def export_data(self, full=False):
def export_data(self, incremental=False):
"""
Export data from this map parameter.
Args:
full (bool): True for full export, otherwise for incremental export. Default: False.
incremental (bool): False for full export, otherwise for incremental export. Default: False.
When exporting data incrementally, the value_array does not contain erased data, so the length of the
key_array and the length of the value_array may be inconsistent.The length of the key_array and the length
of the status_array are consistent.
Returns:
Tuple(key_array, value_array, status_array), The exported data as a tuple.
"""
return self._map_tensor.export_data(full)
return self._map_tensor.export_data(incremental)
def import_data(self, data):
"""

View File

@ -849,4 +849,5 @@ def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
Outputs:
MapTensor, the map tensor be updated.
"""
return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
_map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
return map_tensor

View File

@ -1426,7 +1426,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
if map_param_name in net_dict.keys():
map_parameter = net_dict[map_param_name]
key_nparr, value_nparr, status_nparr = map_parameter.export_data(not incremental)
key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
map_param_proto.status_tensor.raw_data = status_nparr.tobytes()

View File

@ -68,7 +68,7 @@ def test_maptensor_put_get_export(ms_type):
self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
def construct(self, ms_type):
self.m.put(self.keys, self.values)
self.m[self.keys] = self.values
key1 = Tensor([3], dtype=ms_type)
value1 = self.m.get(key1, True)
key2 = Tensor([4], dtype=ms_type)
@ -84,3 +84,5 @@ def test_maptensor_put_get_export(ms_type):
print("out3:", out3)
data = net.m.export_data()
print("data:", data)
assert len(data) == 3
assert len(data[0]) == 4

View File

@ -100,9 +100,19 @@ def test_export_update_api():
Description: Test export update api for MapParameter.
Expectation: Export update api works as expected.
"""
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data = m.export_data(full=True)
m.import_data(data)
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data1 = m1.export_data(incremental=False)
print("data1:", data1)
m1.import_data(data1)
keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
data2 = m2.export_data(incremental=False)
print("data2:", data2)
m1.import_data(data2)
new_data1 = m1.export_data(incremental=False)
print("new_data1:", new_data1)
def test_map_parameter_clone():