!42873 Add Export/Update API for MapParameter

Merge pull request !42873 from hewei/map_param1
This commit is contained in:
i-robot 2022-09-27 01:23:21 +00:00 committed by Gitee
commit 9d3ffd2dfc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 85 additions and 9 deletions

View File

@ -58,7 +58,14 @@ void RegMapTensor(py::module *m) {
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"))
.def_property_readonly("key_dtype", &MapTensor::KeyDtype)
.def_property_readonly("value_dtype", &MapTensor::ValueDtype)
.def_property_readonly("value_shape", &MapTensor::value_shape);
.def_property_readonly("value_shape", &MapTensor::value_shape)
.def("get", &MapTensor::Get)
.def("put", &MapTensor::Put)
.def("erase", &MapTensor::Erase)
.def("export", &MapTensorPy::ExportAsNumpy)
.def("update", &MapTensorPy::UpdateFromNumpy)
.def("__str__", &MapTensor::ToString)
.def("__repr__", &MapTensor::ToString);
}
} // namespace tensor
} // namespace mindspore

View File

@ -16,12 +16,20 @@
#include "ir/map_tensor.h"
#include "abstract/abstract_value.h"
#include "ir/tensor.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils_secure.h"
namespace mindspore {
using tensor::Tensor;
using tensor::TensorPtr;
static ShapeVector ConcatShape(const ShapeVector &a, const ShapeVector &b) {
ShapeVector result_shape = a;
result_shape.insert(result_shape.end(), b.begin(), b.end());
return result_shape;
}
std::size_t MapTensor::hash() const { return static_cast<std::size_t>(tid()); }
bool MapTensor::operator==(const MapTensor &other) const { return this == &other; }
@ -40,7 +48,27 @@ abstract::AbstractBasePtr MapTensor::ToAbstract() {
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_value) {
MS_EXCEPTION_IF_NULL(key_tensor);
MS_EXCEPTION_IF_NULL(default_value);
return nullptr;
// Check input.
if (key_tensor->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor->shape());
}
// Result shape = key_tensor.shape + value_shape.
ShapeVector result_shape = ConcatShape(key_tensor->shape(), value_shape());
// Make the result tensor.
TensorPtr result_tensor = std::make_shared<Tensor>(value_dtype(), result_shape);
// Note: this is the fake implementation that fill result tensor by copy default values.
const size_t num_of_rows = static_cast<size_t>(result_shape[0]);
const size_t default_value_bytes = static_cast<size_t>(default_value->data().nbytes());
const uint8_t *default_value_data = static_cast<const uint8_t *>(default_value->data_c());
auto data_ptr = static_cast<uint8_t *>(result_tensor->data_c());
for (size_t i = 0; i < num_of_rows; ++i) {
auto ret = common::huge_memcpy(data_ptr, default_value_bytes, default_value_data, default_value_bytes);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Copy tensor data failed!";
}
data_ptr += default_value_bytes;
}
return result_tensor;
}
void MapTensor::Put(const TensorPtr &key_tensor, const TensorPtr &value_tensor) {
@ -57,6 +85,12 @@ void MapTensor::Update(const MapTensor::ExportData &data) {
MapTensor::ExportData MapTensor::Export(bool full) {
MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
return {nullptr, nullptr, nullptr};
// Note: this is fake implementation.
ShapeVector key_shape = {1};
ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());
auto key_tensor = std::make_shared<Tensor>(key_dtype(), key_shape);
auto value_tensor = std::make_shared<Tensor>(value_dtype(), values_shape);
auto status_tensor = std::make_shared<Tensor>(kNumberTypeUInt8, key_shape);
return {key_tensor, value_tensor, status_tensor};
}
} // namespace mindspore

View File

@ -19,7 +19,7 @@ __all__ = ['MapParameter']
import numbers
import mindspore as ms
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import MapTensor_
@ -82,7 +82,8 @@ class MapParameter(Parameter):
obj.key_dtype = key_dtype
obj.value_dtype = value_dtype
obj.value_shape = value_shape
obj.default_value = default_value
obj.default_value = default_value if isinstance(default_value, Tensor) else \
initializer(default_value, shape=value_shape, dtype=value_dtype).init_data()
return obj
def __init__(self, name=None, requires_grad=True, **kwargs):
@ -95,15 +96,15 @@ class MapParameter(Parameter):
Args:
key_tensor (Tensor): The key tensor.
default_value (Union[Tensor, str]): The default value or initializer. Default: None
default_value (Tensor): The default value tensor. Default: None
Returns:
Tensor, the value tensor for the key tensor.
"""
if default_value is None:
default_value = self.default_value
result = initializer(default_value, shape=(key_tensor.shape + self.value_shape), dtype=self.value_dtype)
return result.init_data()
result_tensor = self._map_tensor.get(key_tensor, default_value)
return Tensor(result_tensor, internal=True)
def put(self, key_tensor, value_tensor):
"""
@ -116,6 +117,7 @@ class MapParameter(Parameter):
Returns:
MapParameter, the MapParameter object itself.
"""
self._map_tensor.put(key_tensor, value_tensor)
return self
def erase(self, key_tensor):
@ -128,4 +130,26 @@ class MapParameter(Parameter):
Returns:
MapParameter, the MapParameter object itself.
"""
self._map_tensor.erase(key_tensor)
return self
def export(self, full=False):
"""
Export data from this map parameter.
Args:
full (bool): True for full export, otherwise for incremental export. Default: False.
Returns:
Tuple(key_array, value_array, status_array), The exported data as a tuple.
"""
return self._map_tensor.export(full)
def update(self, data):
"""
Update this map parameter from exported data.
Args:
data (Tuple): The data tuple with key_array, value_array and status_array.
"""
self._map_tensor.update(data)

View File

@ -35,7 +35,7 @@ def test_basic_operations():
assert t.shape == (3, 2)
assert np.allclose(t.asnumpy(), 0)
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 'ones')
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), Tensor([1, 1], dtype=ms.float32))
assert t.dtype == ms.float32
assert t.shape == (3, 2)
assert np.allclose(t.asnumpy(), 1)
@ -71,3 +71,14 @@ def test_simple_graph_compile():
out = net(t)
print(out)
assert out.shape == (2, 3)
def test_export_update_api():
"""
Feature: MapParameter
Description: Test export update api for MapParameter.
Expectation: Export update api works as expected.
"""
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data = m.export(full=True)
m.update(data)