Implement MapParameter default value and clone()

This commit is contained in:
He Wei 2022-09-26 16:25:38 +08:00
parent c4baddd410
commit 9c03020659
18 changed files with 265 additions and 71 deletions

View File

@ -405,6 +405,10 @@ BuiltInTypeMap &GetAttrMap() {
{"ndim", std::string("sparse_ndim_")}, // C.sparse_ndim_
{"itemsize", std::string("itemsize_")}, // C.itemsize_
}},
{kObjectTypeMapTensorType,
{
{"default_value", prim::kPrimMapTensorGetDefaultValue}, // F.map_tensor_get_default_value
}},
};
return attr_map;
}

View File

@ -16,19 +16,41 @@
#include "pybind_api/ir/map_tensor_py.h"
#include <memory>
#include <string>
#include "pybind11/pytypes.h"
#include "pybind_api/ir/tensor_py.h"
#include "include/common/pybind_api/api_register.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/parse/parse_base.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"
namespace mindspore {
using tensor::TensorPy;
// Convert a python default-value object into a Value for a map tensor.
// A python string must name one of the supported initializers ("zeros",
// "ones", "normal"); any other object is converted to a Value with the
// map's value dtype. Raises ValueError for unsupported names or when the
// conversion fails.
static ValuePtr ConvertMapTensorDefaultValue(const py::object &default_value_obj, const TypePtr &value_dtype) {
  static const mindspore::HashSet<std::string> support_init_names = {"zeros", "ones", "normal"};
  if (py::isinstance<py::str>(default_value_obj)) {
    auto init_name = py::cast<std::string>(default_value_obj);
    if (support_init_names.find(init_name) == support_init_names.end()) {
      MS_EXCEPTION(ValueError) << "Unsupported init name for map parameter: " << init_name;
    }
    return std::make_shared<StringImm>(init_name);
  }
  // Non-string default: convert the python object to a Value of the value dtype.
  ValuePtr converted_value = nullptr;
  const bool convert_ok = parse::ConvertData(default_value_obj, &converted_value, false, value_dtype, false);
  if (!convert_ok || converted_value == nullptr) {
    MS_EXCEPTION(ValueError) << "Incorrect default value for map parameter: " << py::str(default_value_obj);
  }
  return converted_value;
}
MapTensorPtr MapTensorPy::MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
const ShapeVector &value_shape) {
const ShapeVector &value_shape, const py::object &default_value_obj) {
TypeId key_dtype_id = ((key_dtype != nullptr) ? key_dtype->type_id() : TypeId::kNumberTypeInt32);
TypeId value_dtype_id = ((value_dtype != nullptr) ? value_dtype->type_id() : TypeId::kNumberTypeFloat32);
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape);
ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value);
}
void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
@ -51,15 +73,26 @@ std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const Map
TensorPy::AsNumpy(*data.status_tensor));
}
// Python wrapper for MapTensor::Get.
// When default_value_obj is None, falls back to the map tensor's stored
// default value; otherwise the python object is converted using the map's
// value dtype before the lookup.
static tensor::TensorPtr PyMapTensorGet(const MapTensorPtr &map_tensor, const tensor::TensorPtr &key_tensor,
                                        const py::object &default_value_obj) {
  MS_EXCEPTION_IF_NULL(map_tensor);
  ValuePtr default_value;
  if (default_value_obj.is_none()) {
    default_value = map_tensor->default_value();
  } else {
    default_value = ConvertMapTensorDefaultValue(default_value_obj, map_tensor->ValueDtype());
  }
  return map_tensor->Get(key_tensor, default_value);
}
namespace tensor {
void RegMapTensor(py::module *m) {
// Define python MapTensor class.
(void)py::class_<MapTensor, MapTensorPtr>(*m, "MapTensor_")
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"))
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"),
py::arg("default_value"))
.def_property_readonly("key_dtype", &MapTensor::KeyDtype)
.def_property_readonly("value_dtype", &MapTensor::ValueDtype)
.def_property_readonly("value_shape", &MapTensor::value_shape)
.def("get", &MapTensor::Get)
.def("get", &PyMapTensorGet)
.def("put", &MapTensor::Put)
.def("erase", &MapTensor::Erase)
.def("export", &MapTensorPy::ExportAsNumpy)

View File

@ -30,7 +30,7 @@ namespace mindspore {
class MapTensorPy {
public:
static MapTensorPtr MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
const ShapeVector &value_shape);
const ShapeVector &value_shape, const py::object &default_value_obj);
static void UpdateFromNumpy(const MapTensorPtr &map_tensor,
const std::tuple<py::array, py::array, py::array> &numpy_data);

View File

@ -1668,20 +1668,23 @@ const AbstractTensorPtr AbstractCSRTensor::values() const {
AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor)
: AbstractBase(map_tensor, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
std::make_shared<Shape>(map_tensor->value_shape())),
ref_key_value_(kAnyValue) {}
ref_key_value_(kAnyValue),
default_value_(map_tensor->default_value()) {}
AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value)
: AbstractBase(kAnyValue, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
std::make_shared<Shape>(map_tensor->value_shape())),
ref_key_value_(ref_key_value) {}
ref_key_value_(ref_key_value),
default_value_(map_tensor->default_value()) {}
AbstractMapTensor::AbstractMapTensor(const AbstractMapTensor &other)
: AbstractBase(other.GetValueTrack(), other.GetTypeTrack(), other.GetShapeTrack()),
ref_key_value_(other.ref_key_value_) {}
ref_key_value_(other.ref_key_value_),
default_value_(other.default_value_) {}
AbstractMapTensor::AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
const ValuePtr &ref_key_value)
: AbstractBase(value, type, value_shape), ref_key_value_(ref_key_value) {}
const ValuePtr &ref_key_value, const ValuePtr &default_value)
: AbstractBase(value, type, value_shape), ref_key_value_(ref_key_value), default_value_(default_value) {}
AbstractBasePtr AbstractMapTensor::Clone() const { return std::make_shared<AbstractMapTensor>(*this); }
@ -1716,7 +1719,15 @@ AbstractBasePtr AbstractMapTensor::Join(const AbstractBasePtr &other) {
// Join the ref_key_value.
auto joined_ref_key = ValueJoin(ref_key_value_, other_abs->ref_key_value_);
return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key);
// Join the default_value.
auto joined_default_value = ValueJoin(default_value_, other_abs->default_value_);
if (joined_default_value == kAnyValue) {
MS_EXCEPTION(ValueError) << "Join default value failed for MapTensor. " << default_value_->ToString()
<< " != " << other_abs->default_value_->ToString();
}
return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key,
joined_default_value);
}
bool AbstractMapTensor::operator==(const AbstractBase &other) const {
@ -1743,16 +1754,19 @@ bool AbstractMapTensor::operator==(const AbstractMapTensor &other) const {
return false;
}
return common::IsEqual(GetTypeTrack(), other.GetTypeTrack()) &&
common::IsEqual(GetShapeTrack(), other.GetShapeTrack());
common::IsEqual(GetShapeTrack(), other.GetShapeTrack()) &&
common::IsEqual(default_value(), other.default_value());
}
std::size_t AbstractMapTensor::hash() const {
const auto &type = GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
std::size_t hash_value = hash_combine(tid(), type->hash());
const auto &value_shape = GetShapeTrack();
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(value_shape);
return hash_combine(hash_value, value_shape->hash());
MS_EXCEPTION_IF_NULL(default_value_);
std::size_t hash_value = hash_combine(tid(), type->hash());
hash_value = hash_combine(hash_value, value_shape->hash());
return hash_combine(hash_value, default_value_->hash());
}
std::string AbstractMapTensor::ToString() const {

View File

@ -1560,7 +1560,7 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
explicit AbstractMapTensor(const MapTensorPtr &map_tensor);
AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value);
AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
const ValuePtr &ref_key_value);
const ValuePtr &ref_key_value, const ValuePtr &default_value);
AbstractMapTensor(const AbstractMapTensor &other);
~AbstractMapTensor() override = default;
@ -1569,6 +1569,7 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
MapTensorTypePtr map_tensor_type() const { return dyn_cast<MapTensorType>(GetTypeTrack()); }
ShapePtr value_shape() const { return dyn_cast<Shape>(GetShapeTrack()); }
const ValuePtr &ref_key_value() const { return ref_key_value_; }
const ValuePtr &default_value() const { return default_value_; }
TypePtr BuildType() const override { return GetTypeTrack(); }
BaseShapePtr BuildShape() const override { return GetShapeTrack(); };
@ -1582,6 +1583,8 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
private:
// The reference key value, can be a string value or kAnyValue.
ValuePtr ref_key_value_;
// The default value, a scalar or string with initializer name.
ValuePtr default_value_;
};
using AbstractMapTensorPtr = std::shared_ptr<AbstractMapTensor>;

View File

@ -242,6 +242,8 @@ AbstractBasePtr InferImplAdamApplyOne(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {

View File

@ -17,6 +17,7 @@
#include <string>
#include "ir/dtype.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "abstract/param_validator.h"
#include "abstract/ops/infer_functions.h"
@ -525,5 +526,17 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
AbstractBasePtrList rets = {add1, add0, sub0};
return std::make_shared<AbstractTuple>(rets);
}
// Infer for MapTensor.default_value.
// Expects a single MapTensor abstract argument and exposes its default
// value as a scalar abstract; raises TypeError for any other argument.
AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list) {
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  const auto &map_tensor_arg = args_spec_list[0];
  MS_EXCEPTION_IF_NULL(map_tensor_arg);
  auto abs_map_tensor = map_tensor_arg->cast_ptr<abstract::AbstractMapTensor>();
  if (abs_map_tensor == nullptr) {
    MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << map_tensor_arg->ToString();
  }
  return std::make_shared<AbstractScalar>(abs_map_tensor->default_value());
}
} // namespace abstract
} // namespace mindspore

View File

@ -339,6 +339,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRowTensorGetIndices, R{InferImplRowTensorGetIndices, nullptr, true}},
{prim::kPrimRowTensorGetDenseShape, R{InferImplRowTensorGetDenseShape, nullptr, true}},
{prim::kPrimRowTensorAdd, R{InferImplRowTensorAdd, nullptr, false}},
// MapTensor
{prim::kPrimMapTensorGetDefaultValue, R{InferImplMapTensorGetDefaultValue, nullptr, true}},
// Comm Ops
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},

View File

@ -45,7 +45,16 @@ abstract::AbstractBasePtr MapTensor::ToAbstract() {
}
}
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_value) {
// Human-readable representation with key/value dtypes, value shape and
// the default value; null members are rendered as "<null>".
std::string MapTensor::ToString() const {
  auto key_dtype = KeyDtype();
  auto value_dtype = ValueDtype();
  // Fixed typo in the emitted text: "deault_value" -> "default_value".
  return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
         ", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
         ", value_shape=" + tensor::ShapeToString(value_shape_) +
         ", default_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) + ")";
}
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const ValuePtr &default_value) {
MS_EXCEPTION_IF_NULL(key_tensor);
MS_EXCEPTION_IF_NULL(default_value);
// Check input.
@ -56,18 +65,10 @@ TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_v
ShapeVector result_shape = ConcatShape(key_tensor->shape(), value_shape());
// Make the result tensor.
TensorPtr result_tensor = std::make_shared<Tensor>(value_dtype(), result_shape);
// Note: this is the fake implementation that fill result tensor by copy default values.
const size_t num_of_rows = static_cast<size_t>(result_shape[0]);
const size_t default_value_bytes = static_cast<size_t>(default_value->data().nbytes());
const uint8_t *default_value_data = static_cast<const uint8_t *>(default_value->data_c());
// Note: this is the fake implementation that fill result tensor with zeros.
size_t nbytes = static_cast<size_t>(result_tensor->data().nbytes());
auto data_ptr = static_cast<uint8_t *>(result_tensor->data_c());
for (size_t i = 0; i < num_of_rows; ++i) {
auto ret = common::huge_memcpy(data_ptr, default_value_bytes, default_value_data, default_value_bytes);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Copy tensor data failed!";
}
data_ptr += default_value_bytes;
}
(void)std::fill(data_ptr, data_ptr + nbytes, 0);
return result_tensor;
}

View File

@ -53,8 +53,9 @@ class MS_CORE_API MapTensor final : public Value {
/// \param[in] key_dtype [TypeId] The key data type id.
/// \param[in] value_dtype [TypeId] The value data type id.
/// \param[in] value_shape [TypeId] The value shape.
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape)
: key_dtype_(key_dtype), value_dtype_(value_dtype), value_shape_(value_shape) {}
/// \param[in] default_value [ValuePtr] the default value.
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value)
: key_dtype_(key_dtype), value_dtype_(value_dtype), value_shape_(value_shape), default_value_(default_value) {}
~MapTensor() override = default;
@ -81,19 +82,15 @@ class MS_CORE_API MapTensor final : public Value {
const ShapeVector &value_shape() const { return value_shape_; }
const ValuePtr &default_value() const { return default_value_; }
TypePtr KeyDtype() const { return TypeIdToType(key_dtype_); }
TypePtr ValueDtype() const { return TypeIdToType(value_dtype_); }
abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override {
auto key_dtype = KeyDtype();
auto value_dtype = ValueDtype();
return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
", value_shape=" + tensor::ShapeToString(value_shape_) + ")";
}
std::string ToString() const override;
/// \brief Get tensor's param_info info.
///
@ -113,9 +110,9 @@ class MS_CORE_API MapTensor final : public Value {
/// \brief Get or create values.
///
/// \param[in] key_tensor [Tensor] The key tensor.
/// \param[in] default_value [Tensor] The default value tensor.
/// \param[in] default_value [Value] The default value.
/// \return The value tensor according the key tensor.
TensorPtr Get(const TensorPtr &key_tensor, const TensorPtr &default_value);
TensorPtr Get(const TensorPtr &key_tensor, const ValuePtr &default_value);
/// \brief Put or insert key value pairs.
///
@ -149,6 +146,9 @@ class MS_CORE_API MapTensor final : public Value {
// Shape of the value.
ShapeVector value_shape_;
// Default value: a scalar used as the initial value, or a string naming the initializer.
ValuePtr default_value_;
// Parameter information.
ParamInfoPtr param_info_;
};

View File

@ -320,6 +320,7 @@ constexpr auto kMapTensorGet = "MapTensorGet";
constexpr auto kMapTensorPut = "MapTensorPut";
constexpr auto kMapTensorErase = "MapTensorErase";
constexpr auto kMapTensorPutBprop = "MapTensorPutBprop";
constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue";
// COOTensor
constexpr auto kMakeCOOTensor = "MakeCOOTensor";
@ -1036,6 +1037,7 @@ GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensor
GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut));
GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase));
GVAR_DEF(PrimitivePtr, kPrimMapTensorPutBprop, std::make_shared<Primitive>(kMapTensorPutBprop));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue));
// Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));

View File

@ -53,17 +53,10 @@ AbstractBasePtr MapTensorGetInfer(const abstract::AnalysisEnginePtr &, const Pri
<< " but got " << key_tensor_shape->ToString() << ".";
}
// Check 'default_value' dtype and shape.
auto default_value_dtype = CheckAndConvertUtils::GetTensorInputType(kNameMapTensorGet, input_args, kInputIndex2);
if (!common::IsEqual(value_dtype, default_value_dtype)) {
MS_EXCEPTION(ValueError) << kNameMapTensorGet << " - required default_value dtype " << value_dtype->ToString()
<< " but got " << default_value_dtype->ToString() << ".";
}
auto default_value_shape = CheckAndConvertUtils::GetTensorInputShape(kNameMapTensorGet, input_args, kInputIndex2);
if (!common::IsEqual(value_shape, default_value_shape)) {
MS_EXCEPTION(ValueError) << kNameMapTensorGet << " - required default_value shape " << value_shape->ToString()
<< " but got " << default_value_shape->ToString() << ".";
}
// Check 'default_value'.
auto default_value_scalar =
CheckAndConvertUtils::CheckArgs<abstract::AbstractScalar>(kNameMapTensorGet, input_args, kInputIndex2);
MS_EXCEPTION_IF_NULL(default_value_scalar);
// Concatenate key shape and value shape as the result shape.
ShapeVector shape_vec = key_tensor_shape->shape();

View File

@ -3227,10 +3227,12 @@ def cholesky_inverse(input_x, upper=False):
return F.cholesky_inverse(input_x, upper=upper)
def map_tensor_get(map_tensor, key_tensor, default_value=None):
    r"""
    Get or create a value according to the key tensor from a map tensor.

    Args:
        map_tensor (MapTensor): The map tensor to look up.
        key_tensor (Tensor): The key tensor.
        default_value (Union[numbers.Number, str], optional): Value used for
            missing keys. Falls back to the map tensor's own default value
            when None. Default: None.

    Returns:
        Tensor, the value tensor according to the key tensor.
    """
    # Removed the stale pre-change signature line left by the diff; only the
    # defaulted-argument version remains.
    if default_value is None:
        default_value = map_tensor.default_value
    return _map_tensor_ops.get(map_tensor, key_tensor, default_value)

View File

@ -17,10 +17,10 @@ from __future__ import absolute_import
__all__ = ['MapParameter']
from copy import copy
import numbers
import mindspore as ms
from mindspore.common.parameter import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import MapTensor_
@ -39,7 +39,7 @@ class MapParameter(Parameter):
be defined in `mindspore.dtype`. Default: float32.
value_shape (Union[tuple, list, int]): Used to indicate the shape of the value Tensor. The argument should be
a list of integers, a tuple of integers or an integer. Default: 1.
default_value (Union[Tensor, str]): The default value Tensor or initializer name. Default: 'zeros'.
default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: 'normal'.
name (str): Name of the map parameter. Default: None.
requires_grad (bool): True if the parameter requires gradient. Default: True.
@ -60,15 +60,15 @@ class MapParameter(Parameter):
[[1. 1. 1.]
[2. 2. 2.]
[0. 0. 0.]]
>>> m.del(Tensor([2, 3], dtype=ms.int32))
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
>>> m.erase(Tensor([2, 3], dtype=ms.int32))
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 3)
>>> print(t)
[[1. 1. 1.]
[0. 0. 0.]
[0. 0. 0.]]
[3. 3. 3.]
[3. 3. 3.]]
"""
def __new__(cls, key_dtype=ms.int32, value_dtype=ms.float32, value_shape=1, default_value='zeros', **kwargs):
def __new__(cls, key_dtype=ms.int32, value_dtype=ms.float32, value_shape=1, default_value='normal', **kwargs):
if isinstance(value_shape, numbers.Number):
value_shape = (value_shape,)
data = Tensor_(value_dtype, value_shape)
@ -82,13 +82,53 @@ class MapParameter(Parameter):
obj.key_dtype = key_dtype
obj.value_dtype = value_dtype
obj.value_shape = value_shape
obj.default_value = default_value if isinstance(default_value, Tensor) else \
initializer(default_value, shape=value_shape, dtype=value_dtype).init_data()
obj.default_value = default_value
return obj
def __init__(self, name=None, requires_grad=True, **kwargs):
    """Initialize the Parameter base and create the backing MapTensor_.

    Args:
        name (str): Name of the map parameter. Default: None.
        requires_grad (bool): True if the parameter requires gradient. Default: True.
    """
    Parameter.__init__(self, self, name=name, requires_grad=requires_grad)
    # Removed the stale pre-change construction line (without default_value)
    # left by the diff; the C++ MapTensor_ now receives the default value.
    self._map_tensor = MapTensor_(self.key_dtype, self.value_dtype, self.value_shape, self.default_value)
def __getitem__(self, key_tensor):
    # Indexing sugar: m[keys] delegates to get(), which uses the stored
    # default value for missing keys.
    return self.get(key_tensor)
def __setitem__(self, key_tensor, value_tensor):
    # Assignment sugar: m[keys] = values delegates to put().
    return self.put(key_tensor, value_tensor)
def __str__(self):
    # Wraps the underlying map tensor's string form, e.g.
    # 'MapParameter(MapTensor(...))'.
    return 'MapParameter(' + str(self._map_tensor) + ')'
def __copy__(self):
    # Shallow copy: a new instance of the same class sharing this object's
    # attribute dict (including the underlying _map_tensor); clone() rebuilds
    # _map_tensor afterwards when a different default value is requested.
    x = type(self)()
    x.__dict__.update(self.__dict__)
    return x
def clone(self, init='same'):
    """
    Clone the MapParameter.

    Args:
        init (Union[str, numbers.Number]): Initialize the default value of the new map parameter.
            If `init` is a `numbers.Number`, clone a new map parameter with the same key value shape
            and dtype, and the default value of the new map parameter will be set according to `init`.
            If `init` is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
            If `init` is 'same', clone a new map parameter with the same default value. Default: 'same'.

    Returns:
        MapParameter, the new map parameter.
    """
    x = copy(self)
    x.param_info = self.param_info.clone()
    # Register the clone on this parameter's info so the framework can track
    # all cloned objects of this parameter.
    info = self.param_info
    if hasattr(info, "cloned_obj"):
        info.cloned_obj.append(x)
    else:
        info.cloned_obj = [x]
    self.param_info = info
    if init != 'same':
        # A different default value was requested: replace it and rebuild the
        # backing C++ map tensor so it carries the new default.
        x.default_value = init  # pylint: disable=W0201
        x._map_tensor = MapTensor_(x.key_dtype, x.value_dtype, x.value_shape, x.default_value)  # pylint: disable=W0212
    return x
def get(self, key_tensor, default_value=None):
"""
@ -96,7 +136,7 @@ class MapParameter(Parameter):
Args:
key_tensor (Tensor): The key tensor.
default_value (Tensor): The default value tensor. Default: None
default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: None
Returns:
Tensor, the value tensor for the key tensor.

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
"""Implementation for internal polymorphism `getitem` operations."""
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _map_tensor_ops
from mindspore.ops.composite.multitype_ops import _compile_utils as compile_utils
from mindspore.ops.composite import base
from mindspore.ops import functional as F
@ -307,3 +308,19 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Tensor, element type is the same as the element type of data.
"""
return compile_utils.tensor_index_by_tuple(data, tuple_index)
@getitem.register("MapTensor", "Tensor")
def _map_tensor_getitem(map_tensor, key_tensor):
    """
    Get the value tensor from a map tensor by a key tensor.

    Inputs:
        map_tensor (MapTensor): A map tensor.
        key_tensor (Tensor): The key tensor.

    Outputs:
        Tensor, the value tensor according to the key tensor.
    """
    # Missing keys are filled with the map tensor's own default value.
    return _map_tensor_ops.get(map_tensor, key_tensor, map_tensor.default_value)

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@
from mindspore.ops.composite.multitype_ops import _compile_utils as compile_utils
from mindspore.ops import functional as F
from mindspore.ops.operations._inner_ops import SliceGetItem
from mindspore.ops.operations import _map_tensor_ops
from mindspore.ops.composite import base
from mindspore.common import Tensor
@ -46,6 +47,7 @@ class _ListSliceSetItem(base.ListSliceSetItem_):
def __call__(self, *args):
pass
_list_slice_set_item = _ListSliceSetItem('list_slice_set_item')
"""_list_slice_set_item is a MetaFuncGraph object which assign a list will slice."""
@ -876,3 +878,19 @@ def _tensor_setitem_by_list_with_list(data, index, value):
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)
@setitem.register("MapTensor", "Tensor", "Tensor")
def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
    """
    Update or insert to map tensor by key tensor and value tensor.

    Inputs:
        map_tensor (MapTensor): A map tensor.
        key_tensor (Tensor): The key tensor.
        value_tensor (Tensor): The value tensor.

    Outputs:
        MapTensor, the map tensor be updated.
    """
    # Delegates to the MapTensorPut primitive; existing keys are overwritten.
    return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)

View File

@ -15,6 +15,7 @@
*/
#include <memory>
#include "common/common_test.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/map_tensor.h"
@ -32,7 +33,12 @@ class TestMapTensor : public UT::Common {
/// Description: test MapTensor API.
/// Expectation: MapTensor API work as expected.
TEST_F(TestMapTensor, TestApi) {
  // Removed the stale pre-change construction of `m` (without default_value)
  // left by the diff; only the four-argument constructor is exercised.
  auto default_value = std::make_shared<StringImm>("zeros");
  auto m = std::make_shared<MapTensor>(kNumberTypeInt32, kNumberTypeFloat32, ShapeVector{4}, default_value);
  ASSERT_TRUE(m != nullptr);
  ASSERT_EQ(m->key_dtype(), kNumberTypeInt32);
  ASSERT_EQ(m->value_dtype(), kNumberTypeFloat32);
  ASSERT_EQ(m->value_shape(), ShapeVector{4});
  ASSERT_EQ(m->default_value(), default_value);
}
} // namespace mindspore

View File

@ -15,7 +15,7 @@
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore.experimental import MapParameter
from mindspore.common.initializer import initializer
@ -35,14 +35,22 @@ def test_basic_operations():
assert t.shape == (3, 2)
assert np.allclose(t.asnumpy(), 0)
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), Tensor([1, 1], dtype=ms.float32))
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 0)
assert t.dtype == ms.float32
assert t.shape == (3, 2)
assert np.allclose(t.asnumpy(), 1)
assert np.allclose(t.asnumpy(), 0)
t = m[Tensor([1, 2, 3], dtype=ms.int32)]
assert t.dtype == ms.float32
assert t.shape == (3, 2)
assert np.allclose(t.asnumpy(), 0)
m.put(Tensor([1, 2, 3], dtype=ms.int32), Tensor([[1, 1], [2, 2], [3, 3]], dtype=ms.float32))
m[Tensor([1, 2, 3], dtype=ms.int32)] = Tensor([[11, 11], [22, 22], [33, 33]], dtype=ms.float32)
m.erase(Tensor([1, 2, 3], dtype=ms.int32))
print(m)
def test_simple_graph_compile():
"""
@ -56,13 +64,16 @@ def test_simple_graph_compile():
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
self.key = Tensor([1, 2], dtype=ms.int32)
self.default_value = Tensor([3.0, 3.0, 3.0], dtype=ms.float32)
def construct(self, x):
self.m.put(self.key, x)
value = self.m.get(self.key, self.default_value)
value1 = self.m.get(self.key, 0.1)
value2 = self.m.get(self.key, 'zeros')
value3 = self.m.get(self.key)
value4 = self.m[self.key]
self.m[self.key] = value4
self.m.erase(self.key)
return self.p + value
return self.p + value1 + value2 + value3 + value4
context.set_context(mode=context.GRAPH_MODE)
net = MyNet()
@ -82,3 +93,36 @@ def test_export_update_api():
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
data = m.export(full=True)
m.update(data)
def test_map_parameter_clone():
    """
    Feature: MapParameter
    Description: Test MapParameter clone() method.
    Expectation: MapParameter cloned as expected.
    """
    map_param = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,), name="map")
    plain_param = Parameter(Tensor(1), name="param")
    cloned_params = ParameterTuple([map_param, plain_param]).clone(prefix="cloned", init='zeros')
    cloned_map = cloned_params[0]

    # The clone keeps dtypes/shape but takes the requested default value.
    assert isinstance(cloned_map, MapParameter)
    assert cloned_map.name == 'cloned.map'
    assert cloned_map.key_dtype == map_param.key_dtype
    assert cloned_map.value_dtype == map_param.value_dtype
    assert cloned_map.value_shape == map_param.value_shape
    assert cloned_map.default_value == 'zeros'

    # The backing C++ map tensor is rebuilt (distinct object, same metadata).
    original_tensor = map_param._map_tensor  # pylint: disable=W0212
    cloned_tensor = cloned_map._map_tensor  # pylint: disable=W0212
    assert cloned_tensor != original_tensor
    assert cloned_tensor.key_dtype == original_tensor.key_dtype
    assert cloned_tensor.value_dtype == original_tensor.value_dtype
    assert cloned_tensor.value_shape == original_tensor.value_shape

    # Cloning with init='same' preserves the (already replaced) default value.
    clone_same = cloned_map.clone(init='same')
    assert clone_same.key_dtype == map_param.key_dtype
    assert clone_same.value_dtype == map_param.value_dtype
    assert clone_same.value_shape == map_param.value_shape
    assert clone_same.default_value == 'zeros'