forked from mindspore-Ecosystem/mindspore
Modify the export_data interface default description.
parent 293e3fb6dc
commit e95be8f057

@@ -68,9 +68,10 @@ void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
   map_tensor->Update(data);
 }
 
-std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor, bool full) {
+std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
+                                                                       bool incremental) {
   MS_EXCEPTION_IF_NULL(map_tensor);
-  auto data = map_tensor->Export(full);
+  auto data = map_tensor->Export(incremental);
   return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
                          TensorPy::AsNumpy(*data.status_tensor));
 }

@@ -120,14 +121,12 @@ void RegMapTensor(py::module *m) {
                      const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
             auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
             auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
-            auto key_dtype_id = key_tensor_ptr->Dtype()->type_id();
+            auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
             auto value_dtype = key_tensor_ptr->Dtype();
-            auto value_dtype_id = value_tensor_ptr->Dtype()->type_id();
-            auto value_shape = value_tensor_ptr->shape();
             ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
             ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
             ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
-            return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value,
+            return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
                                                permit_filter_value, evict_filter_value);
           }),
           py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),

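With this change the binding builds the MapTensor from the key and value tensors directly and allocates an int status tensor shaped like the keys, instead of deriving dtypes and shapes itself. A hedged sketch of the Python-side construction path this serves, mirroring the updated test at the bottom of this commit; the import path is an assumption and may differ between MindSpore versions:

```python
import mindspore as ms
from mindspore import Tensor
from mindspore.experimental import MapParameter  # assumed import path

keys = Tensor([1, 2], dtype=ms.int32)
values = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=ms.float32)
# Build from existing tensors: key/value dtypes and the per-entry value shape are
# taken from the tensors; a status tensor the same length as `keys` is created internally.
m = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
```
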
@@ -729,7 +729,7 @@ bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &ma
     MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
     return false;
   }
-  tensor::MapTensor::ExportData export_data = map_tensor->Export(!this->incremental_);
+  tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
   // key_tensor
   auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
   MS_EXCEPTION_IF_NULL(key_tensor_proto);

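This call-site edit (and the matching one in `_save_mindir_together` further down) is not a behaviour change: the old `full` flag and the new `incremental` flag are logical opposites, so dropping the negation compensates for the renamed parameter. A minimal sketch of the equivalence, in plain Python with stand-in functions:

```python
def export_old(full):
    """Old signature: True requests a full export."""
    return "full" if full else "incremental"

def export_new(incremental):
    """New signature: True requests an incremental export."""
    return "incremental" if incremental else "full"

# The rewritten call sites pass the flag straight through instead of negating it,
# so both spellings request the same kind of export for either flag value.
for incremental_ in (False, True):
    assert export_old(not incremental_) == export_new(incremental_)
```
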
@@ -142,14 +142,39 @@ MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &devic
   return {key_tensor(), value_tensor(), status_tensor()};
 }
 
-MapTensor::ExportData MapTensor::Export(bool full) {
-  MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
+// If the data on the host side is valid, the data on the host side will be exported.
+bool MapTensor::CheckData() {
+  // check key
+  if (key_tensor()->shape().size() != 1 || key_tensor()->shape()[0] < 1) {
+    MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor()->shape());
+    return false;
+  }
+  // check value
+  bool check_value =
+    std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
+  if (check_value) {
+    MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
+    return false;
+  }
+  // check status
+  if (status_tensor()->shape().size() != 1 || status_tensor()->shape()[0] < 1) {
+    MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(status_tensor()->shape());
+    return false;
+  }
+  return true;
+}
+
+MapTensor::ExportData MapTensor::Export(bool incremental) {
+  MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
 
   // Check device
   DeviceSyncPtr device_sync = device_address();
   if (device_sync != nullptr) {
     return ExportDataFromDevice(device_sync);
   }
+  if (CheckData()) {
+    return {key_tensor(), value_tensor(), status_tensor()};
+  }
   // Note: this is fake implementation.
   ShapeVector key_shape = {1};
   ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());

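The new export order is: device data first, then host data that passes `CheckData`, and only then the placeholder ("fake") export. A loose Python mirror of the shape rules `CheckData` enforces, assuming `value_shape` is the per-entry value shape; the helper name and signature are illustrative, not MindSpore API:

```python
import numpy as np

def host_data_is_exportable(key_array, value_shape, status_array):
    """Sketch of the CheckData rules: keys and status must be non-empty 1-D arrays,
    and every dimension of the per-entry value shape must be at least 1."""
    if key_array.ndim != 1 or key_array.shape[0] < 1:
        return False  # invalid key tensor shape
    if any(dim < 1 for dim in value_shape):
        return False  # invalid value tensor shape
    if status_array.ndim != 1 or status_array.shape[0] < 1:
        return False  # invalid status tensor shape
    return True

# Two keys with 3-element values and a matching status vector pass the checks;
# an empty key tensor falls through to the placeholder export instead.
assert host_data_is_exportable(np.array([1, 2]), (3,), np.zeros(2, dtype=np.int32))
assert not host_data_is_exportable(np.array([], dtype=np.int64), (3,), np.zeros(0, dtype=np.int32))
```
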
@@ -169,9 +169,9 @@ class MS_CORE_API MapTensor final : public Tensor {
 
   /// \brief Exported MapTensor data.
   ///
-  /// \param[in] full [bool] True for full export, false for incremental export.
+  /// \param[in] incremental [bool] False for incremental export, true for full export.
   /// \return The exported data.
-  ExportData Export(bool full = false);
+  ExportData Export(bool incremental = false);
 
   /// \brief Exported MapTensor data from device.
   ///

@@ -206,6 +206,8 @@ class MS_CORE_API MapTensor final : public Tensor {
 
   void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
 
+  bool CheckData();
+
  private:
   // Data type of the keys.
   TypeId key_dtype_;

@@ -241,17 +241,20 @@ class MapParameter(Parameter):
         self._map_tensor.erase(key_tensor)
         return self
 
-    def export_data(self, full=False):
+    def export_data(self, incremental=False):
         """
         Export data from this map parameter.
 
         Args:
-            full (bool): True for full export, otherwise for incremental export. Default: False.
+            incremental (bool): False for full export, otherwise for incremental export. Default: False.
+                When exporting data incrementally, the value_array does not contain erased data, so the length of the
+                key_array and the length of the value_array may be inconsistent.The length of the key_array and the length
+                of the status_array are consistent.
 
         Returns:
             Tuple(key_array, value_array, status_array), The exported data as a tuple.
         """
-        return self._map_tensor.export_data(full)
+        return self._map_tensor.export_data(incremental)
 
     def import_data(self, data):
         """

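A hedged usage sketch of the renamed keyword. The import path is an assumption and may differ between MindSpore versions; the constructor arguments follow the tests in this commit, and the final assertion is the length guarantee stated in the docstring above:

```python
import mindspore as ms
from mindspore.experimental import MapParameter  # assumed import path

m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))

# Default (incremental=False): full export, all three arrays describe the same entries.
key_array, value_array, status_array = m.export_data()

# incremental=True: erased entries are dropped from value_array, so only the
# key and status lengths are guaranteed to match.
inc_keys, inc_values, inc_status = m.export_data(incremental=True)
assert len(inc_keys) == len(inc_status)
```
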
@@ -849,4 +849,5 @@ def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
     Outputs:
         MapTensor, the map tensor be updated.
     """
-    return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
+    _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
+    return map_tensor

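A small stand-in illustration of the changed contract: the helper now calls the put operator for its side effect and returns the map tensor itself rather than the operator's output, which is what the `self.m[self.keys] = self.values` form in the updated test below relies on. Plain Python objects, not MindSpore ones:

```python
def setitem_sketch(map_tensor, key, value, put):
    put(map_tensor, key, value)  # update the container in place
    return map_tensor            # hand the container back to the caller

store = {}
result = setitem_sketch(store, "k", 1, lambda m, k, v: m.__setitem__(k, v))
assert result is store and store["k"] == 1  # the caller keeps working with the same map
```
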
@@ -1426,7 +1426,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
         map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
         if map_param_name in net_dict.keys():
             map_parameter = net_dict[map_param_name]
-            key_nparr, value_nparr, status_nparr = map_parameter.export_data(not incremental)
+            key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
             map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
             map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
             map_param_proto.status_tensor.raw_data = status_nparr.tobytes()

@@ -68,7 +68,7 @@ def test_maptensor_put_get_export(ms_type):
             self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
 
         def construct(self, ms_type):
-            self.m.put(self.keys, self.values)
+            self.m[self.keys] = self.values
             key1 = Tensor([3], dtype=ms_type)
             value1 = self.m.get(key1, True)
             key2 = Tensor([4], dtype=ms_type)

@@ -84,3 +84,5 @@ def test_maptensor_put_get_export(ms_type):
     print("out3:", out3)
     data = net.m.export_data()
     print("data:", data)
+    assert len(data) == 3
+    assert len(data[0]) == 4

@@ -100,9 +100,19 @@ def test_export_update_api():
     Description: Test export update api for MapParameter.
     Expectation: Export update api works as expected.
     """
-    m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
-    data = m.export_data(full=True)
-    m.import_data(data)
+    m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
+    data1 = m1.export_data(incremental=False)
+    print("data1:", data1)
+    m1.import_data(data1)
+
+    keys = Tensor([1, 2], dtype=ms.int32)
+    values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
+    m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
+    data2 = m2.export_data(incremental=False)
+    print("data2:", data2)
+    m1.import_data(data2)
+    new_data1 = m1.export_data(incremental=False)
+    print("new_data1:", new_data1)
 
 
 def test_map_parameter_clone():