forked from mindspore-Ecosystem/mindspore
!42873 Add Export/Update API for MapParameter
Merge pull request !42873 from hewei/map_param1
commit 9d3ffd2dfc

@@ -58,7 +58,14 @@ void RegMapTensor(py::module *m) {
     .def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"))
     .def_property_readonly("key_dtype", &MapTensor::KeyDtype)
     .def_property_readonly("value_dtype", &MapTensor::ValueDtype)
-    .def_property_readonly("value_shape", &MapTensor::value_shape);
+    .def_property_readonly("value_shape", &MapTensor::value_shape)
+    .def("get", &MapTensor::Get)
+    .def("put", &MapTensor::Put)
+    .def("erase", &MapTensor::Erase)
+    .def("export", &MapTensorPy::ExportAsNumpy)
+    .def("update", &MapTensorPy::UpdateFromNumpy)
+    .def("__str__", &MapTensor::ToString)
+    .def("__repr__", &MapTensor::ToString);
 }
 } // namespace tensor
 } // namespace mindspore

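For orientation, a minimal sketch of the Python surface these bindings expose, assuming MapTensor_ is reachable from mindspore._c_expression (as the Python changes below import it) and that MindSpore dtype objects are accepted for the dtype arguments:

import mindspore as ms
from mindspore._c_expression import MapTensor_

# Construct via MapTensorPy::MakeMapTensor; keyword names follow the py::arg declarations above.
mt = MapTensor_(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
print(mt)                # __str__ / __repr__ both route to MapTensor::ToString
data = mt.export(True)   # ExportAsNumpy: full export
mt.update(data)          # UpdateFromNumpy: feed exported data back in
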
@@ -16,12 +16,20 @@
 #include "ir/map_tensor.h"
 #include "abstract/abstract_value.h"
 #include "ir/tensor.h"
 #include "utils/log_adapter.h"
+#include "utils/ms_utils_secure.h"
 
 namespace mindspore {
+using tensor::Tensor;
 using tensor::TensorPtr;
 
+static ShapeVector ConcatShape(const ShapeVector &a, const ShapeVector &b) {
+  ShapeVector result_shape = a;
+  result_shape.insert(result_shape.end(), b.begin(), b.end());
+  return result_shape;
+}
+
 std::size_t MapTensor::hash() const { return static_cast<std::size_t>(tid()); }
 
 bool MapTensor::operator==(const MapTensor &other) const { return this == &other; }

@@ -40,7 +48,27 @@ abstract::AbstractBasePtr MapTensor::ToAbstract() {
 TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_value) {
   MS_EXCEPTION_IF_NULL(key_tensor);
   MS_EXCEPTION_IF_NULL(default_value);
-  return nullptr;
+  // Check input.
+  if (key_tensor->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Invalid key tensor shape: " << tensor::ShapeToString(key_tensor->shape());
+  }
+  // Result shape = key_tensor.shape + value_shape.
+  ShapeVector result_shape = ConcatShape(key_tensor->shape(), value_shape());
+  // Make the result tensor.
+  TensorPtr result_tensor = std::make_shared<Tensor>(value_dtype(), result_shape);
+  // Note: this is a fake implementation that fills the result tensor with copies of the default value.
+  const size_t num_of_rows = static_cast<size_t>(result_shape[0]);
+  const size_t default_value_bytes = static_cast<size_t>(default_value->data().nbytes());
+  const uint8_t *default_value_data = static_cast<const uint8_t *>(default_value->data_c());
+  auto data_ptr = static_cast<uint8_t *>(result_tensor->data_c());
+  for (size_t i = 0; i < num_of_rows; ++i) {
+    auto ret = common::huge_memcpy(data_ptr, default_value_bytes, default_value_data, default_value_bytes);
+    if (ret != EOK) {
+      MS_LOG(EXCEPTION) << "Copy tensor data failed!";
+    }
+    data_ptr += default_value_bytes;
+  }
+  return result_tensor;
 }
 
 void MapTensor::Put(const TensorPtr &key_tensor, const TensorPtr &value_tensor) {

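The placeholder Get above builds a result of shape key_tensor.shape + value_shape and fills every row with a copy of the default value. In NumPy terms, an illustrative sketch (not part of the commit):

import numpy as np

keys = np.array([1, 2, 3], dtype=np.int32)        # key tensor of shape (3,)
default_value = np.ones((2,), dtype=np.float32)   # value_shape == (2,)
# Result shape = key shape + value_shape; each row is a copy of the default value.
result = np.broadcast_to(default_value, keys.shape + default_value.shape).copy()
assert result.shape == (3, 2) and np.allclose(result, 1)
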
@@ -57,6 +85,12 @@ void MapTensor::Update(const MapTensor::ExportData &data) {
 
 MapTensor::ExportData MapTensor::Export(bool full) {
   MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
-  return {nullptr, nullptr, nullptr};
+  // Note: this is a fake implementation.
+  ShapeVector key_shape = {1};
+  ShapeVector values_shape = ConcatShape(ShapeVector{1}, value_shape());
+  auto key_tensor = std::make_shared<Tensor>(key_dtype(), key_shape);
+  auto value_tensor = std::make_shared<Tensor>(value_dtype(), values_shape);
+  auto status_tensor = std::make_shared<Tensor>(kNumberTypeUInt8, key_shape);
+  return {key_tensor, value_tensor, status_tensor};
 }
 } // namespace mindspore

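The exported tuple carries one row per entry (a single placeholder row for now). A rough NumPy sketch of the layout this implies; dtypes follow the map's key_dtype/value_dtype and are assumed here, and value_shape == (3,) matches the new test below:

import numpy as np

value_shape = (3,)
key_array = np.zeros((1,), dtype=np.int32)                    # one key per exported row
value_array = np.zeros((1,) + value_shape, dtype=np.float32)  # row i holds the value for key i
status_array = np.zeros((1,), dtype=np.uint8)                 # per-row status (meaning not specified in this commit)
exported = (key_array, value_array, status_array)
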
@@ -19,7 +19,7 @@ __all__ = ['MapParameter']
 
 import numbers
 import mindspore as ms
-from mindspore.common.parameter import Parameter
+from mindspore.common.parameter import Tensor, Parameter
 from mindspore.common.initializer import initializer
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import MapTensor_

@@ -82,7 +82,8 @@ class MapParameter(Parameter):
         obj.key_dtype = key_dtype
         obj.value_dtype = value_dtype
         obj.value_shape = value_shape
-        obj.default_value = default_value
+        obj.default_value = default_value if isinstance(default_value, Tensor) else \
+            initializer(default_value, shape=value_shape, dtype=value_dtype).init_data()
         return obj
 
     def __init__(self, name=None, requires_grad=True, **kwargs):

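The branch above accepts either a ready Tensor or an initializer spec for default_value. A hedged construction sketch; the MapParameter import path is not shown in this diff and is omitted, and the argument values are illustrative:

import mindspore as ms
from mindspore import Tensor

# MapParameter: import path not shown in this diff.
# default_value given as a Tensor: stored as-is.
m1 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2,),
                  default_value=Tensor([1.0, 1.0], dtype=ms.float32))
# default_value given as an initializer name: materialized via initializer(...).init_data().
m2 = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2,),
                  default_value='zeros')
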
@@ -95,15 +96,15 @@ class MapParameter(Parameter):
 
         Args:
             key_tensor (Tensor): The key tensor.
-            default_value (Union[Tensor, str]): The default value or initializer. Default: None
+            default_value (Tensor): The default value tensor. Default: None
 
         Returns:
             Tensor, the value tensor for the key tensor.
         """
         if default_value is None:
             default_value = self.default_value
-        result = initializer(default_value, shape=(key_tensor.shape + self.value_shape), dtype=self.value_dtype)
-        return result.init_data()
+        result_tensor = self._map_tensor.get(key_tensor, default_value)
+        return Tensor(result_tensor, internal=True)
 
     def put(self, key_tensor, value_tensor):
         """

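get() now delegates to the C++ MapTensor binding and wraps the returned tensor; when default_value is None, the instance's default_value set in __new__ is used. A minimal usage sketch mirroring the updated test below (MapParameter import path again omitted as it is not part of this diff):

import mindspore as ms
from mindspore import Tensor

# MapParameter: import path not shown in this diff.
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2,))
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), Tensor([1, 1], dtype=ms.float32))
assert t.dtype == ms.float32 and t.shape == (3, 2)
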
@@ -116,6 +117,7 @@ class MapParameter(Parameter):
         Returns:
             MapParameter, the MapParameter object itself.
         """
+        self._map_tensor.put(key_tensor, value_tensor)
         return self
 
     def erase(self, key_tensor):

@@ -128,4 +130,26 @@ class MapParameter(Parameter):
         Returns:
             MapParameter, the MapParameter object itself.
         """
+        self._map_tensor.erase(key_tensor)
         return self
+
+    def export(self, full=False):
+        """
+        Export data from this map parameter.
+
+        Args:
+            full (bool): True for a full export, False for an incremental export. Default: False.
+
+        Returns:
+            Tuple(key_array, value_array, status_array), the exported data as a tuple.
+        """
+        return self._map_tensor.export(full)
+
+    def update(self, data):
+        """
+        Update this map parameter from exported data.
+
+        Args:
+            data (Tuple): The data tuple with key_array, value_array and status_array.
+        """
+        self._map_tensor.update(data)

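export() and update() are meant to round-trip: the tuple returned by export can be handed straight back to update, or persisted in between. A hedged sketch of that flow; the checkpoint file name is illustrative, and the arrays are assumed to be NumPy arrays per the ExportAsNumpy/UpdateFromNumpy bindings:

import numpy as np

# m: a MapParameter instance as constructed in the tests below.
data = m.export(full=True)                     # (key_array, value_array, status_array)
key_array, value_array, status_array = data
np.savez("map_param_ckpt.npz", keys=key_array, values=value_array, status=status_array)

restored = np.load("map_param_ckpt.npz")
m.update((restored["keys"], restored["values"], restored["status"]))
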
@@ -35,7 +35,7 @@ def test_basic_operations():
     assert t.shape == (3, 2)
     assert np.allclose(t.asnumpy(), 0)
 
-    t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 'ones')
+    t = m.get(Tensor([1, 2, 3], dtype=ms.int32), Tensor([1, 1], dtype=ms.float32))
     assert t.dtype == ms.float32
     assert t.shape == (3, 2)
     assert np.allclose(t.asnumpy(), 1)

@@ -71,3 +71,14 @@ def test_simple_graph_compile():
     out = net(t)
     print(out)
     assert out.shape == (2, 3)
+
+
+def test_export_update_api():
+    """
+    Feature: MapParameter
+    Description: Test the export/update API of MapParameter.
+    Expectation: The export/update API works as expected.
+    """
+    m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
+    data = m.export(full=True)
+    m.update(data)