forked from mindspore-Ecosystem/mindspore
Modify the export_data interface default description.
This commit is contained in:
parent
293e3fb6dc
commit
e95be8f057
|
@ -68,9 +68,10 @@ void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
|
|||
map_tensor->Update(data);
|
||||
}
|
||||
|
||||
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor, bool full) {
|
||||
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
|
||||
bool incremental) {
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
auto data = map_tensor->Export(full);
|
||||
auto data = map_tensor->Export(incremental);
|
||||
return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
|
||||
TensorPy::AsNumpy(*data.status_tensor));
|
||||
}
|
||||
|
@ -120,14 +121,12 @@ void RegMapTensor(py::module *m) {
|
|||
const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
|
||||
auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
|
||||
auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
|
||||
auto key_dtype_id = key_tensor_ptr->Dtype()->type_id();
|
||||
auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
|
||||
auto value_dtype = key_tensor_ptr->Dtype();
|
||||
auto value_dtype_id = value_tensor_ptr->Dtype()->type_id();
|
||||
auto value_shape = value_tensor_ptr->shape();
|
||||
ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
|
||||
ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
|
||||
ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
|
||||
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value,
|
||||
return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
|
||||
permit_filter_value, evict_filter_value);
|
||||
}),
|
||||
py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),
|
||||
|
|
|
@ -729,7 +729,7 @@ bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &ma
|
|||
MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
|
||||
return false;
|
||||
}
|
||||
tensor::MapTensor::ExportData export_data = map_tensor->Export(!this->incremental_);
|
||||
tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
|
||||
// key_tensor
|
||||
auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
|
||||
MS_EXCEPTION_IF_NULL(key_tensor_proto);
|
||||
|
|
|
@ -142,14 +142,39 @@ MapTensor::ExportData MapTensor::ExportDataFromDevice(const DeviceSyncPtr &devic
|
|||
return {key_tensor(), value_tensor(), status_tensor()};
|
||||
}
|
||||
|
||||
MapTensor::ExportData MapTensor::Export(bool full) {
|
||||
MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
|
||||
// If the data on the host side is valid, the data on the host side will be exported.
|
||||
bool MapTensor::CheckData() {
|
||||
// check key
|
||||
if (key_tensor()->shape().size() != 1 || key_tensor()->shape()[0] < 1) {
|
||||
MS_LOG(WARNING) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor()->shape());
|
||||
return false;
|
||||
}
|
||||
// check value
|
||||
bool check_value =
|
||||
std::any_of(value_shape().cbegin(), value_shape().cend(), [](const ShapeValueDType &shape) { return shape < 1; });
|
||||
if (check_value) {
|
||||
MS_LOG(WARNING) << "Invalid value tensor shape: " << tensor::ShapeToString(value_shape());
|
||||
return false;
|
||||
}
|
||||
// check status
|
||||
if (status_tensor()->shape().size() != 1 || status_tensor()->shape()[0] < 1) {
|
||||
MS_LOG(WARNING) << "Invalid status tensor shape: " << tensor::ShapeToString(status_tensor()->shape());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MapTensor::ExportData MapTensor::Export(bool incremental) {
|
||||
MS_LOG(DEBUG) << (incremental ? "Incremental" : "Full") << " export MapTensor";
|
||||
|
||||
// Check device
|
||||
DeviceSyncPtr device_sync = device_address();
|
||||
if (device_sync != nullptr) {
|
||||
return ExportDataFromDevice(device_sync);
|
||||
}
|
||||
if (CheckData()) {
|
||||
return {key_tensor(), value_tensor(), status_tensor()};
|
||||
}
|
||||
// Note: this is fake implementation.
|
||||
ShapeVector key_shape = {1};
|
||||
ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());
|
||||
|
|
|
@ -169,9 +169,9 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
|
||||
/// \brief Exported MapTensor data.
|
||||
///
|
||||
/// \param[in] full [bool] True for full export, false for incremental export.
|
||||
/// \param[in] incremental [bool] False for incremental export, true for full export.
|
||||
/// \return The exported data.
|
||||
ExportData Export(bool full = false);
|
||||
ExportData Export(bool incremental = false);
|
||||
|
||||
/// \brief Exported MapTensor data from device.
|
||||
///
|
||||
|
@ -206,6 +206,8 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
|
||||
void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
|
||||
|
||||
bool CheckData();
|
||||
|
||||
private:
|
||||
// Data type of the keys.
|
||||
TypeId key_dtype_;
|
||||
|
|
|
@ -241,17 +241,20 @@ class MapParameter(Parameter):
|
|||
self._map_tensor.erase(key_tensor)
|
||||
return self
|
||||
|
||||
def export_data(self, full=False):
|
||||
def export_data(self, incremental=False):
|
||||
"""
|
||||
Export data from this map parameter.
|
||||
|
||||
Args:
|
||||
full (bool): True for full export, otherwise for incremental export. Default: False.
|
||||
incremental (bool): False for full export, otherwise for incremental export. Default: False.
|
||||
When exporting data incrementally, the value_array does not contain erased data, so the length of the
|
||||
key_array and the length of the value_array may be inconsistent.The length of the key_array and the length
|
||||
of the status_array are consistent.
|
||||
|
||||
Returns:
|
||||
Tuple(key_array, value_array, status_array), The exported data as a tuple.
|
||||
"""
|
||||
return self._map_tensor.export_data(full)
|
||||
return self._map_tensor.export_data(incremental)
|
||||
|
||||
def import_data(self, data):
|
||||
"""
|
||||
|
|
|
@ -849,4 +849,5 @@ def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
|
|||
Outputs:
|
||||
MapTensor, the map tensor be updated.
|
||||
"""
|
||||
return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
|
||||
_map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
|
||||
return map_tensor
|
||||
|
|
|
@ -1426,7 +1426,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|||
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
|
||||
if map_param_name in net_dict.keys():
|
||||
map_parameter = net_dict[map_param_name]
|
||||
key_nparr, value_nparr, status_nparr = map_parameter.export_data(not incremental)
|
||||
key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
|
||||
map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
|
||||
map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
|
||||
map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
|
||||
|
|
|
@ -68,7 +68,7 @@ def test_maptensor_put_get_export(ms_type):
|
|||
self.values = Tensor([[11, 11, 11], [22, 22, 22]], dtype=ms.float32)
|
||||
|
||||
def construct(self, ms_type):
|
||||
self.m.put(self.keys, self.values)
|
||||
self.m[self.keys] = self.values
|
||||
key1 = Tensor([3], dtype=ms_type)
|
||||
value1 = self.m.get(key1, True)
|
||||
key2 = Tensor([4], dtype=ms_type)
|
||||
|
@ -84,3 +84,5 @@ def test_maptensor_put_get_export(ms_type):
|
|||
print("out3:", out3)
|
||||
data = net.m.export_data()
|
||||
print("data:", data)
|
||||
assert len(data) == 3
|
||||
assert len(data[0]) == 4
|
||||
|
|
|
@ -100,9 +100,19 @@ def test_export_update_api():
|
|||
Description: Test export update api for MapParameter.
|
||||
Expectation: Export update api works as expected.
|
||||
"""
|
||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
data = m.export_data(full=True)
|
||||
m.import_data(data)
|
||||
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
data1 = m1.export_data(incremental=False)
|
||||
print("data1:", data1)
|
||||
m1.import_data(data1)
|
||||
|
||||
keys = Tensor([1, 2], dtype=ms.int32)
|
||||
values = Tensor([[1, 2], [1, 2]], dtype=ms.float32)
|
||||
m2 = MapParameter(key_tensor=keys, value_tensor=values, default_value='zeros')
|
||||
data2 = m2.export_data(incremental=False)
|
||||
print("data2:", data2)
|
||||
m1.import_data(data2)
|
||||
new_data1 = m1.export_data(incremental=False)
|
||||
print("new_data1:", new_data1)
|
||||
|
||||
|
||||
def test_map_parameter_clone():
|
||||
|
|
Loading…
Reference in New Issue