forked from mindspore-Ecosystem/mindspore
Implement MapParameter default value and clone()
This commit is contained in:
parent
c4baddd410
commit
9c03020659
|
@ -405,6 +405,10 @@ BuiltInTypeMap &GetAttrMap() {
|
|||
{"ndim", std::string("sparse_ndim_")}, // C.sparse_ndim_
|
||||
{"itemsize", std::string("itemsize_")}, // C.itemsize_
|
||||
}},
|
||||
{kObjectTypeMapTensorType,
|
||||
{
|
||||
{"default_value", prim::kPrimMapTensorGetDefaultValue}, // F.map_tensor_get_default_value
|
||||
}},
|
||||
};
|
||||
return attr_map;
|
||||
}
|
||||
|
|
|
@ -16,19 +16,41 @@
|
|||
|
||||
#include "pybind_api/ir/map_tensor_py.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind_api/ir/tensor_py.h"
|
||||
#include "include/common/pybind_api/api_register.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
using tensor::TensorPy;
|
||||
|
||||
static ValuePtr ConvertMapTensorDefaultValue(const py::object &default_value_obj, const TypePtr &value_dtype) {
|
||||
static const mindspore::HashSet<std::string> support_init_names = {"zeros", "ones", "normal"};
|
||||
if (py::isinstance<py::str>(default_value_obj)) {
|
||||
std::string init_name = py::cast<std::string>(default_value_obj);
|
||||
if (support_init_names.find(init_name) == support_init_names.end()) {
|
||||
MS_EXCEPTION(ValueError) << "Unsupported init name for map parameter: " << init_name;
|
||||
}
|
||||
return std::make_shared<StringImm>(init_name);
|
||||
}
|
||||
ValuePtr default_value;
|
||||
bool convert_ok = parse::ConvertData(default_value_obj, &default_value, false, value_dtype, false);
|
||||
if (!convert_ok || default_value == nullptr) {
|
||||
MS_EXCEPTION(ValueError) << "Incorrect default value for map parameter: " << py::str(default_value_obj);
|
||||
}
|
||||
return default_value;
|
||||
}
|
||||
|
||||
MapTensorPtr MapTensorPy::MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
|
||||
const ShapeVector &value_shape) {
|
||||
const ShapeVector &value_shape, const py::object &default_value_obj) {
|
||||
TypeId key_dtype_id = ((key_dtype != nullptr) ? key_dtype->type_id() : TypeId::kNumberTypeInt32);
|
||||
TypeId value_dtype_id = ((value_dtype != nullptr) ? value_dtype->type_id() : TypeId::kNumberTypeFloat32);
|
||||
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape);
|
||||
ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
|
||||
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value);
|
||||
}
|
||||
|
||||
void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
|
||||
|
@ -51,15 +73,26 @@ std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const Map
|
|||
TensorPy::AsNumpy(*data.status_tensor));
|
||||
}
|
||||
|
||||
// Python wrapper for MapTensor::Get.
|
||||
static tensor::TensorPtr PyMapTensorGet(const MapTensorPtr &map_tensor, const tensor::TensorPtr &key_tensor,
|
||||
const py::object &default_value_obj) {
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
ValuePtr default_value =
|
||||
(default_value_obj.is_none() ? map_tensor->default_value()
|
||||
: ConvertMapTensorDefaultValue(default_value_obj, map_tensor->ValueDtype()));
|
||||
return map_tensor->Get(key_tensor, default_value);
|
||||
}
|
||||
|
||||
namespace tensor {
|
||||
void RegMapTensor(py::module *m) {
|
||||
// Define python MapTensor class.
|
||||
(void)py::class_<MapTensor, MapTensorPtr>(*m, "MapTensor_")
|
||||
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"))
|
||||
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"),
|
||||
py::arg("default_value"))
|
||||
.def_property_readonly("key_dtype", &MapTensor::KeyDtype)
|
||||
.def_property_readonly("value_dtype", &MapTensor::ValueDtype)
|
||||
.def_property_readonly("value_shape", &MapTensor::value_shape)
|
||||
.def("get", &MapTensor::Get)
|
||||
.def("get", &PyMapTensorGet)
|
||||
.def("put", &MapTensor::Put)
|
||||
.def("erase", &MapTensor::Erase)
|
||||
.def("export", &MapTensorPy::ExportAsNumpy)
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace mindspore {
|
|||
class MapTensorPy {
|
||||
public:
|
||||
static MapTensorPtr MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
|
||||
const ShapeVector &value_shape);
|
||||
const ShapeVector &value_shape, const py::object &default_value_obj);
|
||||
|
||||
static void UpdateFromNumpy(const MapTensorPtr &map_tensor,
|
||||
const std::tuple<py::array, py::array, py::array> &numpy_data);
|
||||
|
|
|
@ -1668,20 +1668,23 @@ const AbstractTensorPtr AbstractCSRTensor::values() const {
|
|||
AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor)
|
||||
: AbstractBase(map_tensor, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
|
||||
std::make_shared<Shape>(map_tensor->value_shape())),
|
||||
ref_key_value_(kAnyValue) {}
|
||||
ref_key_value_(kAnyValue),
|
||||
default_value_(map_tensor->default_value()) {}
|
||||
|
||||
AbstractMapTensor::AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value)
|
||||
: AbstractBase(kAnyValue, std::make_shared<MapTensorType>(map_tensor->KeyDtype(), map_tensor->ValueDtype()),
|
||||
std::make_shared<Shape>(map_tensor->value_shape())),
|
||||
ref_key_value_(ref_key_value) {}
|
||||
ref_key_value_(ref_key_value),
|
||||
default_value_(map_tensor->default_value()) {}
|
||||
|
||||
AbstractMapTensor::AbstractMapTensor(const AbstractMapTensor &other)
|
||||
: AbstractBase(other.GetValueTrack(), other.GetTypeTrack(), other.GetShapeTrack()),
|
||||
ref_key_value_(other.ref_key_value_) {}
|
||||
ref_key_value_(other.ref_key_value_),
|
||||
default_value_(other.default_value_) {}
|
||||
|
||||
AbstractMapTensor::AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
|
||||
const ValuePtr &ref_key_value)
|
||||
: AbstractBase(value, type, value_shape), ref_key_value_(ref_key_value) {}
|
||||
const ValuePtr &ref_key_value, const ValuePtr &default_value)
|
||||
: AbstractBase(value, type, value_shape), ref_key_value_(ref_key_value), default_value_(default_value) {}
|
||||
|
||||
AbstractBasePtr AbstractMapTensor::Clone() const { return std::make_shared<AbstractMapTensor>(*this); }
|
||||
|
||||
|
@ -1716,7 +1719,15 @@ AbstractBasePtr AbstractMapTensor::Join(const AbstractBasePtr &other) {
|
|||
// Join the ref_key_value.
|
||||
auto joined_ref_key = ValueJoin(ref_key_value_, other_abs->ref_key_value_);
|
||||
|
||||
return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key);
|
||||
// Join the ref_key_value.
|
||||
auto joined_default_value = ValueJoin(default_value_, other_abs->default_value_);
|
||||
if (joined_default_value == kAnyValue) {
|
||||
MS_EXCEPTION(ValueError) << "Join default value failed for MapTensor. " << default_value_->ToString()
|
||||
<< " != " << other_abs->default_value_->ToString();
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractMapTensor>(joined_type, joined_shape, joined_value, joined_ref_key,
|
||||
joined_default_value);
|
||||
}
|
||||
|
||||
bool AbstractMapTensor::operator==(const AbstractBase &other) const {
|
||||
|
@ -1743,16 +1754,19 @@ bool AbstractMapTensor::operator==(const AbstractMapTensor &other) const {
|
|||
return false;
|
||||
}
|
||||
return common::IsEqual(GetTypeTrack(), other.GetTypeTrack()) &&
|
||||
common::IsEqual(GetShapeTrack(), other.GetShapeTrack());
|
||||
common::IsEqual(GetShapeTrack(), other.GetShapeTrack()) &&
|
||||
common::IsEqual(default_value(), other.default_value());
|
||||
}
|
||||
|
||||
std::size_t AbstractMapTensor::hash() const {
|
||||
const auto &type = GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
std::size_t hash_value = hash_combine(tid(), type->hash());
|
||||
const auto &value_shape = GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(value_shape);
|
||||
return hash_combine(hash_value, value_shape->hash());
|
||||
MS_EXCEPTION_IF_NULL(default_value_);
|
||||
std::size_t hash_value = hash_combine(tid(), type->hash());
|
||||
hash_value = hash_combine(hash_value, value_shape->hash());
|
||||
return hash_combine(hash_value, default_value_->hash());
|
||||
}
|
||||
|
||||
std::string AbstractMapTensor::ToString() const {
|
||||
|
|
|
@ -1560,7 +1560,7 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
|
|||
explicit AbstractMapTensor(const MapTensorPtr &map_tensor);
|
||||
AbstractMapTensor(const MapTensorPtr &map_tensor, const ValuePtr &ref_key_value);
|
||||
AbstractMapTensor(const TypePtr &type, const ShapePtr &value_shape, const ValuePtr &value,
|
||||
const ValuePtr &ref_key_value);
|
||||
const ValuePtr &ref_key_value, const ValuePtr &default_value);
|
||||
AbstractMapTensor(const AbstractMapTensor &other);
|
||||
~AbstractMapTensor() override = default;
|
||||
|
||||
|
@ -1569,6 +1569,7 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
|
|||
MapTensorTypePtr map_tensor_type() const { return dyn_cast<MapTensorType>(GetTypeTrack()); }
|
||||
ShapePtr value_shape() const { return dyn_cast<Shape>(GetShapeTrack()); }
|
||||
const ValuePtr &ref_key_value() const { return ref_key_value_; }
|
||||
const ValuePtr &default_value() const { return default_value_; }
|
||||
TypePtr BuildType() const override { return GetTypeTrack(); }
|
||||
BaseShapePtr BuildShape() const override { return GetShapeTrack(); };
|
||||
|
||||
|
@ -1582,6 +1583,8 @@ class MS_CORE_API AbstractMapTensor final : public AbstractBase {
|
|||
private:
|
||||
// The reference key value, can be a string value or kAnyValue.
|
||||
ValuePtr ref_key_value_;
|
||||
// The default value, a scalar or string with initializer name.
|
||||
ValuePtr default_value_;
|
||||
};
|
||||
using AbstractMapTensorPtr = std::shared_ptr<AbstractMapTensor>;
|
||||
|
||||
|
|
|
@ -242,6 +242,8 @@ AbstractBasePtr InferImplAdamApplyOne(const AnalysisEnginePtr &, const Primitive
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "ir/dtype.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/infer_functions.h"
|
||||
|
@ -525,5 +526,17 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
|
|||
AbstractBasePtrList rets = {add1, add0, sub0};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
// Infer for MapTensor.default_value.
|
||||
AbstractBasePtr InferImplMapTensorGetDefaultValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
const auto &arg = args_spec_list[0];
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
auto abs_map_tensor = arg->cast_ptr<abstract::AbstractMapTensor>();
|
||||
if (abs_map_tensor == nullptr) {
|
||||
MS_EXCEPTION(TypeError) << "Expect MapTensor, but got " << arg->ToString();
|
||||
}
|
||||
return std::make_shared<AbstractScalar>(abs_map_tensor->default_value());
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -339,6 +339,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimRowTensorGetIndices, R{InferImplRowTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, R{InferImplRowTensorGetDenseShape, nullptr, true}},
|
||||
{prim::kPrimRowTensorAdd, R{InferImplRowTensorAdd, nullptr, false}},
|
||||
// MapTensor
|
||||
{prim::kPrimMapTensorGetDefaultValue, R{InferImplMapTensorGetDefaultValue, nullptr, true}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
|
||||
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},
|
||||
|
|
|
@ -45,7 +45,16 @@ abstract::AbstractBasePtr MapTensor::ToAbstract() {
|
|||
}
|
||||
}
|
||||
|
||||
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_value) {
|
||||
std::string MapTensor::ToString() const {
|
||||
auto key_dtype = KeyDtype();
|
||||
auto value_dtype = ValueDtype();
|
||||
return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
|
||||
", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
|
||||
", value_shape=" + tensor::ShapeToString(value_shape_) +
|
||||
", deault_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) + ")";
|
||||
}
|
||||
|
||||
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const ValuePtr &default_value) {
|
||||
MS_EXCEPTION_IF_NULL(key_tensor);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
// Check input.
|
||||
|
@ -56,18 +65,10 @@ TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_v
|
|||
ShapeVector result_shape = ConcatShape(key_tensor->shape(), value_shape());
|
||||
// Make the result tensor.
|
||||
TensorPtr result_tensor = std::make_shared<Tensor>(value_dtype(), result_shape);
|
||||
// Note: this is the fake implementation that fill result tensor by copy default values.
|
||||
const size_t num_of_rows = static_cast<size_t>(result_shape[0]);
|
||||
const size_t default_value_bytes = static_cast<size_t>(default_value->data().nbytes());
|
||||
const uint8_t *default_value_data = static_cast<const uint8_t *>(default_value->data_c());
|
||||
// Note: this is the fake implementation that fill result tensor with zeros.
|
||||
size_t nbytes = static_cast<size_t>(result_tensor->data().nbytes());
|
||||
auto data_ptr = static_cast<uint8_t *>(result_tensor->data_c());
|
||||
for (size_t i = 0; i < num_of_rows; ++i) {
|
||||
auto ret = common::huge_memcpy(data_ptr, default_value_bytes, default_value_data, default_value_bytes);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Copy tensor data failed!";
|
||||
}
|
||||
data_ptr += default_value_bytes;
|
||||
}
|
||||
(void)std::fill(data_ptr, data_ptr + nbytes, 0);
|
||||
return result_tensor;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,8 +53,9 @@ class MS_CORE_API MapTensor final : public Value {
|
|||
/// \param[in] key_dtype [TypeId] The key data type id.
|
||||
/// \param[in] value_dtype [TypeId] The value data type id.
|
||||
/// \param[in] value_shape [TypeId] The value shape.
|
||||
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape)
|
||||
: key_dtype_(key_dtype), value_dtype_(value_dtype), value_shape_(value_shape) {}
|
||||
/// \param[in] default_value [ValuePtr] the default value.
|
||||
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value)
|
||||
: key_dtype_(key_dtype), value_dtype_(value_dtype), value_shape_(value_shape), default_value_(default_value) {}
|
||||
|
||||
~MapTensor() override = default;
|
||||
|
||||
|
@ -81,19 +82,15 @@ class MS_CORE_API MapTensor final : public Value {
|
|||
|
||||
const ShapeVector &value_shape() const { return value_shape_; }
|
||||
|
||||
const ValuePtr &default_value() const { return default_value_; }
|
||||
|
||||
TypePtr KeyDtype() const { return TypeIdToType(key_dtype_); }
|
||||
|
||||
TypePtr ValueDtype() const { return TypeIdToType(value_dtype_); }
|
||||
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
std::string ToString() const override {
|
||||
auto key_dtype = KeyDtype();
|
||||
auto value_dtype = ValueDtype();
|
||||
return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
|
||||
", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
|
||||
", value_shape=" + tensor::ShapeToString(value_shape_) + ")";
|
||||
}
|
||||
std::string ToString() const override;
|
||||
|
||||
/// \brief Get tensor's param_info info.
|
||||
///
|
||||
|
@ -113,9 +110,9 @@ class MS_CORE_API MapTensor final : public Value {
|
|||
/// \brief Get or create values.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
/// \param[in] default_value [Tensor] The default value tensor.
|
||||
/// \param[in] default_value [Value] The default value.
|
||||
/// \return The value tensor according the key tensor.
|
||||
TensorPtr Get(const TensorPtr &key_tensor, const TensorPtr &default_value);
|
||||
TensorPtr Get(const TensorPtr &key_tensor, const ValuePtr &default_value);
|
||||
|
||||
/// \brief Put or insert key value pairs.
|
||||
///
|
||||
|
@ -149,6 +146,9 @@ class MS_CORE_API MapTensor final : public Value {
|
|||
// Shape of the value.
|
||||
ShapeVector value_shape_;
|
||||
|
||||
// Default value. should be a scalar as the initial value or a string as the initializer name.
|
||||
ValuePtr default_value_;
|
||||
|
||||
// Parameter information.
|
||||
ParamInfoPtr param_info_;
|
||||
};
|
||||
|
|
|
@ -320,6 +320,7 @@ constexpr auto kMapTensorGet = "MapTensorGet";
|
|||
constexpr auto kMapTensorPut = "MapTensorPut";
|
||||
constexpr auto kMapTensorErase = "MapTensorErase";
|
||||
constexpr auto kMapTensorPutBprop = "MapTensorPutBprop";
|
||||
constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue";
|
||||
|
||||
// COOTensor
|
||||
constexpr auto kMakeCOOTensor = "MakeCOOTensor";
|
||||
|
@ -1036,6 +1037,7 @@ GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensor
|
|||
GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorPutBprop, std::make_shared<Primitive>(kMapTensorPutBprop));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue));
|
||||
|
||||
// Sparse ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
|
||||
|
|
|
@ -53,17 +53,10 @@ AbstractBasePtr MapTensorGetInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
<< " but got " << key_tensor_shape->ToString() << ".";
|
||||
}
|
||||
|
||||
// Check 'default_value' dtype and shape.
|
||||
auto default_value_dtype = CheckAndConvertUtils::GetTensorInputType(kNameMapTensorGet, input_args, kInputIndex2);
|
||||
if (!common::IsEqual(value_dtype, default_value_dtype)) {
|
||||
MS_EXCEPTION(ValueError) << kNameMapTensorGet << " - required default_value dtype " << value_dtype->ToString()
|
||||
<< " but got " << default_value_dtype->ToString() << ".";
|
||||
}
|
||||
auto default_value_shape = CheckAndConvertUtils::GetTensorInputShape(kNameMapTensorGet, input_args, kInputIndex2);
|
||||
if (!common::IsEqual(value_shape, default_value_shape)) {
|
||||
MS_EXCEPTION(ValueError) << kNameMapTensorGet << " - required default_value shape " << value_shape->ToString()
|
||||
<< " but got " << default_value_shape->ToString() << ".";
|
||||
}
|
||||
// Check 'default_value'.
|
||||
auto default_value_scalar =
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractScalar>(kNameMapTensorGet, input_args, kInputIndex2);
|
||||
MS_EXCEPTION_IF_NULL(default_value_scalar);
|
||||
|
||||
// Concate key shape and value shape as the result shape.
|
||||
ShapeVector shape_vec = key_tensor_shape->shape();
|
||||
|
|
|
@ -3227,10 +3227,12 @@ def cholesky_inverse(input_x, upper=False):
|
|||
return F.cholesky_inverse(input_x, upper=upper)
|
||||
|
||||
|
||||
def map_tensor_get(map_tensor, key_tensor, default_value):
|
||||
def map_tensor_get(map_tensor, key_tensor, default_value=None):
|
||||
r"""
|
||||
Get or create value according the key tensor from a map tensor.
|
||||
"""
|
||||
if default_value is None:
|
||||
default_value = map_tensor.default_value
|
||||
return _map_tensor_ops.get(map_tensor, key_tensor, default_value)
|
||||
|
||||
|
||||
|
|
|
@ -17,10 +17,10 @@ from __future__ import absolute_import
|
|||
|
||||
__all__ = ['MapParameter']
|
||||
|
||||
from copy import copy
|
||||
import numbers
|
||||
import mindspore as ms
|
||||
from mindspore.common.parameter import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
from mindspore._c_expression import MapTensor_
|
||||
|
||||
|
@ -39,7 +39,7 @@ class MapParameter(Parameter):
|
|||
be defined in `mindspore.dtype`. Default: float32.
|
||||
value_shape (Union[tuple, list, int]): Used to indicate the shape of the value Tensor. The argument should be
|
||||
a list of integers, a tuple of integers or an integer. Default: 1.
|
||||
default_value (Union[Tensor, str]): The default value Tensor or initializer name. Default: 'zeros'.
|
||||
default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: 'normal'.
|
||||
name (str): Name of the map parameter. Default: None.
|
||||
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
||||
|
||||
|
@ -60,15 +60,15 @@ class MapParameter(Parameter):
|
|||
[[1. 1. 1.]
|
||||
[2. 2. 2.]
|
||||
[0. 0. 0.]]
|
||||
>>> m.del(Tensor([2, 3], dtype=ms.int32))
|
||||
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
>>> m.erase(Tensor([2, 3], dtype=ms.int32))
|
||||
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 3)
|
||||
>>> print(t)
|
||||
[[1. 1. 1.]
|
||||
[0. 0. 0.]
|
||||
[0. 0. 0.]]
|
||||
[3. 3. 3.]
|
||||
[3. 3. 3.]]
|
||||
"""
|
||||
|
||||
def __new__(cls, key_dtype=ms.int32, value_dtype=ms.float32, value_shape=1, default_value='zeros', **kwargs):
|
||||
def __new__(cls, key_dtype=ms.int32, value_dtype=ms.float32, value_shape=1, default_value='normal', **kwargs):
|
||||
if isinstance(value_shape, numbers.Number):
|
||||
value_shape = (value_shape,)
|
||||
data = Tensor_(value_dtype, value_shape)
|
||||
|
@ -82,13 +82,53 @@ class MapParameter(Parameter):
|
|||
obj.key_dtype = key_dtype
|
||||
obj.value_dtype = value_dtype
|
||||
obj.value_shape = value_shape
|
||||
obj.default_value = default_value if isinstance(default_value, Tensor) else \
|
||||
initializer(default_value, shape=value_shape, dtype=value_dtype).init_data()
|
||||
obj.default_value = default_value
|
||||
return obj
|
||||
|
||||
def __init__(self, name=None, requires_grad=True, **kwargs):
|
||||
Parameter.__init__(self, self, name=name, requires_grad=requires_grad)
|
||||
self._map_tensor = MapTensor_(self.key_dtype, self.value_dtype, self.value_shape)
|
||||
self._map_tensor = MapTensor_(self.key_dtype, self.value_dtype, self.value_shape, self.default_value)
|
||||
|
||||
def __getitem__(self, key_tensor):
|
||||
return self.get(key_tensor)
|
||||
|
||||
def __setitem__(self, key_tensor, value_tensor):
|
||||
return self.put(key_tensor, value_tensor)
|
||||
|
||||
def __str__(self):
|
||||
return 'MapParameter(' + str(self._map_tensor) + ')'
|
||||
|
||||
def __copy__(self):
|
||||
x = type(self)()
|
||||
x.__dict__.update(self.__dict__)
|
||||
return x
|
||||
|
||||
def clone(self, init='same'):
|
||||
"""
|
||||
Clone the MapParameter.
|
||||
|
||||
Args:
|
||||
init (Union[str, numbers.Number]): Initialize the default value of the new map parameter.
|
||||
If `init` is a `numbers.Number`, clone a new map parameter with the same key value shape
|
||||
and dtype, and the default value of the new map parameter will be set according to `init`.
|
||||
If `init` is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
|
||||
If `init` is 'same', clone a new map parameter with the same default value. Default: 'same'.
|
||||
|
||||
Returns:
|
||||
MapParameter, the new map parameter.
|
||||
"""
|
||||
x = copy(self)
|
||||
x.param_info = self.param_info.clone()
|
||||
info = self.param_info
|
||||
if hasattr(info, "cloned_obj"):
|
||||
info.cloned_obj.append(x)
|
||||
else:
|
||||
info.cloned_obj = [x]
|
||||
self.param_info = info
|
||||
if init != 'same':
|
||||
x.default_value = init # pylint: disable=W0201
|
||||
x._map_tensor = MapTensor_(x.key_dtype, x.value_dtype, x.value_shape, x.default_value) # pylint: disable=W0212
|
||||
return x
|
||||
|
||||
def get(self, key_tensor, default_value=None):
|
||||
"""
|
||||
|
@ -96,7 +136,7 @@ class MapParameter(Parameter):
|
|||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
default_value (Tensor): The default value tensor. Default: None
|
||||
default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: None
|
||||
|
||||
Returns:
|
||||
Tensor, the value tensor for the key tensor.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Implementation for internal polymorphism `getitem` operations."""
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _map_tensor_ops
|
||||
from mindspore.ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -307,3 +308,19 @@ def _tensor_getitem_by_tuple(data, tuple_index):
|
|||
Tensor, element type is the same as the element type of data.
|
||||
"""
|
||||
return compile_utils.tensor_index_by_tuple(data, tuple_index)
|
||||
|
||||
|
||||
@getitem.register("MapTensor", "Tensor")
|
||||
def _map_tensor_getitem(map_tensor, key_tensor):
|
||||
"""
|
||||
Getting value tensor from map tensor by key tensor.
|
||||
|
||||
Inputs:
|
||||
map_tensor (MapTensor): A map tensor.
|
||||
key_tensor (Tensor): The key tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, value tensor according the key tensor.
|
||||
"""
|
||||
default_value = map_tensor.default_value
|
||||
return _map_tensor_ops.get(map_tensor, key_tensor, default_value)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,7 @@
|
|||
from mindspore.ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations._inner_ops import SliceGetItem
|
||||
from mindspore.ops.operations import _map_tensor_ops
|
||||
from mindspore.ops.composite import base
|
||||
from mindspore.common import Tensor
|
||||
|
||||
|
@ -46,6 +47,7 @@ class _ListSliceSetItem(base.ListSliceSetItem_):
|
|||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_list_slice_set_item = _ListSliceSetItem('list_slice_set_item')
|
||||
"""_list_slice_set_item is a MetaFuncGraph object which assign a list will slice."""
|
||||
|
||||
|
@ -876,3 +878,19 @@ def _tensor_setitem_by_list_with_list(data, index, value):
|
|||
if isinstance(index, Tensor):
|
||||
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
|
||||
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("MapTensor", "Tensor", "Tensor")
|
||||
def _map_tensor_setitem(map_tensor, key_tensor, value_tensor):
|
||||
"""
|
||||
Update or insert to map tensor by key tensor and value tensor.
|
||||
|
||||
Inputs:
|
||||
map_tensor (MapTensor): A map tensor.
|
||||
key_tensor (Tensor): The key tensor.
|
||||
value_tensor (Tensor): The value tensor.
|
||||
|
||||
Outputs:
|
||||
MapTensor, the map tensor be updated.
|
||||
"""
|
||||
return _map_tensor_ops.put(map_tensor, key_tensor, value_tensor)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "ir/value.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/map_tensor.h"
|
||||
|
||||
|
@ -32,7 +33,12 @@ class TestMapTensor : public UT::Common {
|
|||
/// Description: test MapTensor API.
|
||||
/// Expectation: MapTensor API work as expected.
|
||||
TEST_F(TestMapTensor, TestApi) {
|
||||
auto m = std::make_shared<MapTensor>(kNumberTypeInt32, kNumberTypeFloat32, ShapeVector{4});
|
||||
auto default_value = std::make_shared<StringImm>("zeros");
|
||||
auto m = std::make_shared<MapTensor>(kNumberTypeInt32, kNumberTypeFloat32, ShapeVector{4}, default_value);
|
||||
ASSERT_TRUE(m != nullptr);
|
||||
ASSERT_EQ(m->key_dtype(), kNumberTypeInt32);
|
||||
ASSERT_EQ(m->value_dtype(), kNumberTypeFloat32);
|
||||
ASSERT_EQ(m->value_shape(), ShapeVector{4});
|
||||
ASSERT_EQ(m->default_value(), default_value);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter, context
|
||||
from mindspore import context, Tensor, Parameter, ParameterTuple
|
||||
from mindspore.experimental import MapParameter
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
|
@ -35,14 +35,22 @@ def test_basic_operations():
|
|||
assert t.shape == (3, 2)
|
||||
assert np.allclose(t.asnumpy(), 0)
|
||||
|
||||
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), Tensor([1, 1], dtype=ms.float32))
|
||||
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 0)
|
||||
assert t.dtype == ms.float32
|
||||
assert t.shape == (3, 2)
|
||||
assert np.allclose(t.asnumpy(), 1)
|
||||
assert np.allclose(t.asnumpy(), 0)
|
||||
|
||||
t = m[Tensor([1, 2, 3], dtype=ms.int32)]
|
||||
assert t.dtype == ms.float32
|
||||
assert t.shape == (3, 2)
|
||||
assert np.allclose(t.asnumpy(), 0)
|
||||
|
||||
m.put(Tensor([1, 2, 3], dtype=ms.int32), Tensor([[1, 1], [2, 2], [3, 3]], dtype=ms.float32))
|
||||
m[Tensor([1, 2, 3], dtype=ms.int32)] = Tensor([[11, 11], [22, 22], [33, 33]], dtype=ms.float32)
|
||||
m.erase(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
|
||||
print(m)
|
||||
|
||||
|
||||
def test_simple_graph_compile():
|
||||
"""
|
||||
|
@ -56,13 +64,16 @@ def test_simple_graph_compile():
|
|||
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
|
||||
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
self.key = Tensor([1, 2], dtype=ms.int32)
|
||||
self.default_value = Tensor([3.0, 3.0, 3.0], dtype=ms.float32)
|
||||
|
||||
def construct(self, x):
|
||||
self.m.put(self.key, x)
|
||||
value = self.m.get(self.key, self.default_value)
|
||||
value1 = self.m.get(self.key, 0.1)
|
||||
value2 = self.m.get(self.key, 'zeros')
|
||||
value3 = self.m.get(self.key)
|
||||
value4 = self.m[self.key]
|
||||
self.m[self.key] = value4
|
||||
self.m.erase(self.key)
|
||||
return self.p + value
|
||||
return self.p + value1 + value2 + value3 + value4
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MyNet()
|
||||
|
@ -82,3 +93,36 @@ def test_export_update_api():
|
|||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
data = m.export(full=True)
|
||||
m.update(data)
|
||||
|
||||
|
||||
def test_map_parameter_clone():
|
||||
"""
|
||||
Feature: MapParameter
|
||||
Description: Test MapParameter clone() method.
|
||||
Expectation: MapParameter cloned as expected.
|
||||
"""
|
||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,), name="map")
|
||||
p = Parameter(Tensor(1), name="param")
|
||||
params = ParameterTuple([m, p])
|
||||
cloned_params = params.clone(prefix="cloned", init='zeros')
|
||||
|
||||
cloned_map = cloned_params[0]
|
||||
assert isinstance(cloned_map, MapParameter)
|
||||
assert cloned_map.name == 'cloned.map'
|
||||
assert cloned_map.key_dtype == m.key_dtype
|
||||
assert cloned_map.value_dtype == m.value_dtype
|
||||
assert cloned_map.value_shape == m.value_shape
|
||||
assert cloned_map.default_value == 'zeros'
|
||||
|
||||
old_map_tensor = m._map_tensor # pylint: disable=W0212
|
||||
new_map_tensor = cloned_map._map_tensor # pylint: disable=W0212
|
||||
assert new_map_tensor != old_map_tensor
|
||||
assert new_map_tensor.key_dtype == old_map_tensor.key_dtype
|
||||
assert new_map_tensor.value_dtype == old_map_tensor.value_dtype
|
||||
assert new_map_tensor.value_shape == old_map_tensor.value_shape
|
||||
|
||||
clone_same = cloned_map.clone(init='same')
|
||||
assert clone_same.key_dtype == m.key_dtype
|
||||
assert clone_same.value_dtype == m.value_dtype
|
||||
assert clone_same.value_shape == m.value_shape
|
||||
assert clone_same.default_value == 'zeros'
|
||||
|
|
Loading…
Reference in New Issue