forked from mindspore-Ecosystem/mindspore
Add experimental API for MapParameter
This commit is contained in:
parent
3689955521
commit
352571d7be
|
@ -309,6 +309,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/experimental
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
|
|
@ -189,6 +189,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/experimental
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
|
|
@ -243,6 +243,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/experimental
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
|
|
@ -53,6 +53,7 @@ void RegMetaTensor(py::module *m);
|
|||
void RegCSRTensor(py::module *m);
|
||||
void RegCOOTensor(py::module *m);
|
||||
void RegRowTensor(py::module *m);
|
||||
void RegMapTensor(py::module *m);
|
||||
} // namespace tensor
|
||||
|
||||
namespace opt {
|
||||
|
|
|
@ -104,6 +104,7 @@ void RegModule(py::module *m) {
|
|||
mindspore::tensor::RegCSRTensor(m);
|
||||
mindspore::tensor::RegCOOTensor(m);
|
||||
mindspore::tensor::RegRowTensor(m);
|
||||
mindspore::tensor::RegMapTensor(m);
|
||||
RegValues(m);
|
||||
mindspore::initializer::RegRandomNormal(m);
|
||||
RegMsContext(m);
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind_api/ir/map_tensor_py.h"
|
||||
#include <memory>
|
||||
#include "pybind_api/ir/tensor_py.h"
|
||||
#include "include/common/pybind_api/api_register.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
using tensor::TensorPy;
|
||||
|
||||
MapTensorPtr MapTensorPy::MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
|
||||
const ShapeVector &value_shape) {
|
||||
TypeId key_dtype_id = ((key_dtype != nullptr) ? key_dtype->type_id() : TypeId::kNumberTypeInt32);
|
||||
TypeId value_dtype_id = ((value_dtype != nullptr) ? value_dtype->type_id() : TypeId::kNumberTypeFloat32);
|
||||
return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape);
|
||||
}
|
||||
|
||||
void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
|
||||
const std::tuple<py::array, py::array, py::array> &numpy_data) {
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
MapTensor::ExportData data;
|
||||
constexpr size_t key_index = 0;
|
||||
constexpr size_t value_index = 1;
|
||||
constexpr size_t status_index = 2;
|
||||
data.key_tensor = TensorPy::MakeTensorOfNumpy(std::get<key_index>(numpy_data));
|
||||
data.value_tensor = TensorPy::MakeTensorOfNumpy(std::get<value_index>(numpy_data));
|
||||
data.status_tensor = TensorPy::MakeTensorOfNumpy(std::get<status_index>(numpy_data));
|
||||
map_tensor->Update(data);
|
||||
}
|
||||
|
||||
std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor, bool full) {
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
auto data = map_tensor->Export(full);
|
||||
return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
|
||||
TensorPy::AsNumpy(*data.status_tensor));
|
||||
}
|
||||
|
||||
namespace tensor {
|
||||
void RegMapTensor(py::module *m) {
|
||||
// Define python MapTensor class.
|
||||
(void)py::class_<MapTensor, MapTensorPtr>(*m, "MapTensor_")
|
||||
.def(py::init(&MapTensorPy::MakeMapTensor), py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"))
|
||||
.def_property_readonly("key_dtype", &MapTensor::KeyDtype)
|
||||
.def_property_readonly("value_dtype", &MapTensor::ValueDtype)
|
||||
.def_property_readonly("value_shape", &MapTensor::value_shape);
|
||||
}
|
||||
} // namespace tensor
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_MAP_TENSOR_PY_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_MAP_TENSOR_PY_H_
|
||||
|
||||
#include <tuple>
|
||||
#include "pybind11/numpy.h"
|
||||
#include "ir/map_tensor.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
//
|
||||
// MapTensor python adapter class.
|
||||
//
|
||||
class MapTensorPy {
|
||||
public:
|
||||
static MapTensorPtr MakeMapTensor(const TypePtr &key_dtype, const TypePtr &value_dtype,
|
||||
const ShapeVector &value_shape);
|
||||
|
||||
static void UpdateFromNumpy(const MapTensorPtr &map_tensor,
|
||||
const std::tuple<py::array, py::array, py::array> &numpy_data);
|
||||
|
||||
static std::tuple<py::array, py::array, py::array> ExportAsNumpy(const MapTensorPtr &map_tensor, bool full = false);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_UTILS_MAP_TENSOR_PY_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/map_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
using tensor::Tensor;
|
||||
using tensor::TensorPtr;
|
||||
|
||||
std::size_t MapTensor::hash() const { return static_cast<std::size_t>(tid()); }
|
||||
|
||||
bool MapTensor::operator==(const MapTensor &other) const { return this == &other; }
|
||||
|
||||
TensorPtr MapTensor::Get(const TensorPtr &key_tensor, const TensorPtr &default_value) {
|
||||
MS_EXCEPTION_IF_NULL(key_tensor);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void MapTensor::Put(const TensorPtr &key_tensor, const TensorPtr &value_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(key_tensor);
|
||||
MS_EXCEPTION_IF_NULL(value_tensor);
|
||||
}
|
||||
|
||||
void MapTensor::Erase(const TensorPtr &key_tensor) { MS_EXCEPTION_IF_NULL(key_tensor); }
|
||||
|
||||
void MapTensor::Update(const MapTensor::ExportData &data) {
|
||||
MS_EXCEPTION_IF_NULL(data.key_tensor);
|
||||
MS_EXCEPTION_IF_NULL(data.value_tensor);
|
||||
}
|
||||
|
||||
MapTensor::ExportData MapTensor::Export(bool full) {
|
||||
MS_LOG(DEBUG) << (full ? "Full" : "Incremental") << " export MapTensor";
|
||||
return {nullptr, nullptr, nullptr};
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_IR_MAP_TENSOR_H_
|
||||
#define MINDSPORE_CORE_IR_MAP_TENSOR_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/macros.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
///
|
||||
/// \brief MapTensor is a dynamic tensor with map like index functions.
|
||||
///
|
||||
class MS_CORE_API MapTensor final : public Value {
|
||||
public:
|
||||
using Tensor = tensor::Tensor;
|
||||
using TensorPtr = tensor::TensorPtr;
|
||||
|
||||
struct ExportData {
|
||||
TensorPtr key_tensor;
|
||||
TensorPtr value_tensor;
|
||||
TensorPtr status_tensor;
|
||||
};
|
||||
|
||||
enum class Status {
|
||||
kUnchanged = 0,
|
||||
kModified = 1,
|
||||
kErased = 2,
|
||||
};
|
||||
|
||||
/// \brief Create a empty MapTensor.
|
||||
///
|
||||
/// \param[in] key_dtype [TypeId] The key data type id.
|
||||
/// \param[in] value_dtype [TypeId] The value data type id.
|
||||
/// \param[in] value_shape [TypeId] The value shape.
|
||||
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape)
|
||||
: key_dtype_(key_dtype), value_dtype_(value_dtype), value_shape_(value_shape) {}
|
||||
|
||||
~MapTensor() override = default;
|
||||
|
||||
MS_DECLARE_PARENT(MapTensor, Value)
|
||||
|
||||
std::size_t hash() const override;
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (this == &other) {
|
||||
return true;
|
||||
}
|
||||
if (!other.isa<MapTensor>()) {
|
||||
return false;
|
||||
}
|
||||
auto other_ = static_cast<const MapTensor &>(other);
|
||||
return *this == other_;
|
||||
}
|
||||
|
||||
bool operator==(const MapTensor &other) const;
|
||||
|
||||
TypeId key_dtype() const { return key_dtype_; }
|
||||
|
||||
TypeId value_dtype() const { return value_dtype_; }
|
||||
|
||||
const ShapeVector &value_shape() const { return value_shape_; }
|
||||
|
||||
TypePtr KeyDtype() const { return TypeIdToType(key_dtype_); }
|
||||
|
||||
TypePtr ValueDtype() const { return TypeIdToType(value_dtype_); }
|
||||
|
||||
/// \brief Get or create values.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
/// \param[in] default_value [Tensor] The default value tensor.
|
||||
/// \return The value tensor according the key tensor.
|
||||
TensorPtr Get(const TensorPtr &key_tensor, const TensorPtr &default_value);
|
||||
|
||||
/// \brief Put or insert key value pairs.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
/// \param[in] value_tensor [Tensor] The value tensor.
|
||||
void Put(const TensorPtr &key_tensor, const TensorPtr &value_tensor);
|
||||
|
||||
/// \brief Remove items with the given keys.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
void Erase(const TensorPtr &key_tensor);
|
||||
|
||||
/// \brief Update MapTensor from exported data.
|
||||
///
|
||||
/// \param[in] data [ExportData] The data.
|
||||
void Update(const ExportData &data);
|
||||
|
||||
/// \brief Update MapTensor from exported data.
|
||||
///
|
||||
/// \param[in] full [bool] True for full export, false for incremental export.
|
||||
/// \return The exported data.
|
||||
ExportData Export(bool full = false);
|
||||
|
||||
private:
|
||||
// Data type of the key.
|
||||
TypeId key_dtype_;
|
||||
|
||||
// Data type of the value.
|
||||
TypeId value_dtype_;
|
||||
|
||||
// Shape of the value.
|
||||
ShapeVector value_shape_;
|
||||
};
|
||||
|
||||
// Smart pointer for MapTensor.
|
||||
using MapTensorPtr = std::shared_ptr<MapTensor>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_MAP_TENSOR_H_
|
|
@ -1,2 +1,2 @@
|
|||
Note: This is the mindspore Lite inference framework size threshold. Offline review is required before modify this value!!!
|
||||
1083704
|
||||
1085704
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Experimental module."""
|
||||
from __future__ import absolute_import
|
||||
from mindspore.experimental.map_parameter import MapParameter
|
||||
|
||||
# Symbols from experimental module.
|
||||
__all__ = ["MapParameter"]
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MapParameter implementation."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
__all__ = ['MapParameter']
|
||||
|
||||
import numbers
|
||||
import mindspore as ms
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._c_expression import Tensor as Tensor_
|
||||
from mindspore._c_expression import MapTensor_
|
||||
|
||||
|
||||
class MapParameter(Parameter):
|
||||
"""
|
||||
MapParameter is a parameter that stores a map like data structure.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Args:
|
||||
key_dtype (:class:`mindspore.dtype`): The data type of the key. The argument should be defined in
|
||||
`mindspore.dtype`, currently only integer types are supported. Default: int32.
|
||||
value_dtype (:class:`mindspore.dtype`): The data type of the value Tensor. The argument should
|
||||
be defined in `mindspore.dtype`. Default: float32.
|
||||
value_shape (Union[tuple, list, int]): Used to indicate the shape of the value Tensor. The argument should be
|
||||
a list of integers, a tuple of integers or an integer. Default: 1.
|
||||
default_value (Union[Tensor, str]): The default value Tensor or initializer name. Default: 'zeros'.
|
||||
name (str): Name of the map parameter. Default: None.
|
||||
requires_grad (bool): True if the parameter requires gradient. Default: True.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.experimental import MapParameter
|
||||
>>>
|
||||
>>> m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3), default_value='zeros')
|
||||
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
[[0. 0. 0.]
|
||||
[0. 0. 0.]
|
||||
[0. 0. 0.]]
|
||||
>>> m.put(Tensor([1, 2], dtype=ms.int32), Tensor([[1, 1, 1], [2, 2, 2]], dtype=np.float32))
|
||||
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
>>> print(t)
|
||||
[[1. 1. 1.]
|
||||
[2. 2. 2.]
|
||||
[0. 0. 0.]]
|
||||
>>> m.del(Tensor([2, 3], dtype=ms.int32))
|
||||
>>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
>>> print(t)
|
||||
[[1. 1. 1.]
|
||||
[0. 0. 0.]
|
||||
[0. 0. 0.]]
|
||||
"""
|
||||
|
||||
def __new__(cls, key_dtype=ms.int32, value_dtype=ms.float32, value_shape=1, default_value='zeros', **kwargs):
|
||||
if isinstance(value_shape, numbers.Number):
|
||||
value_shape = (value_shape,)
|
||||
data = Tensor_(value_dtype, value_shape)
|
||||
obj = Tensor_.__new__(cls)
|
||||
Tensor_.__init__(obj, data)
|
||||
obj.key_dtype = key_dtype
|
||||
obj.value_dtype = value_dtype
|
||||
obj.value_shape = value_shape
|
||||
obj.default_value = default_value
|
||||
return obj
|
||||
|
||||
def __init__(self, name=None, requires_grad=True, **kwargs):
|
||||
Parameter.__init__(self, self, name=name, requires_grad=requires_grad)
|
||||
self.map_tensor_ = MapTensor_(self.key_dtype, self.value_dtype, self.value_shape)
|
||||
|
||||
def get(self, key_tensor, default_value=None):
|
||||
"""
|
||||
Get value tensor according the key tensor, fill and return the default value if key is not existed.
|
||||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
default_value (Union[Tensor, str]): The default value or initializer. Default: None
|
||||
|
||||
Returns:
|
||||
Tensor, the value tensor for the key tensor.
|
||||
"""
|
||||
if default_value is None:
|
||||
default_value = self.default_value
|
||||
result = initializer(default_value, shape=(key_tensor.shape + self.value_shape), dtype=self.value_dtype)
|
||||
return result.init_data()
|
||||
|
||||
def put(self, key_tensor, value_tensor):
|
||||
"""
|
||||
Insert or update records according the given key tensor and value tensor.
|
||||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
value_tensor (Tensor): The value tensor.
|
||||
|
||||
Returns:
|
||||
MapParameter, the MapParameter object itself.
|
||||
"""
|
||||
return self
|
||||
|
||||
def erase(self, key_tensor):
|
||||
"""
|
||||
Remove records according the given key tensor.
|
||||
|
||||
Args:
|
||||
key_tensor (Tensor): The key tensor.
|
||||
|
||||
Returns:
|
||||
MapParameter, the MapParameter object itself.
|
||||
"""
|
||||
return self
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/map_tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
using tensor::Tensor;
|
||||
using tensor::TensorPtr;
|
||||
|
||||
class TestMapTensor : public UT::Common {
|
||||
public:
|
||||
TestMapTensor() = default;
|
||||
~TestMapTensor() = default;
|
||||
};
|
||||
|
||||
/// Feature: MapTensor
|
||||
/// Description: test MapTensor API.
|
||||
/// Expectation: MapTensor API work as expected.
|
||||
TEST_F(TestMapTensor, TestApi) {
|
||||
auto m = std::make_shared<MapTensor>(kNumberTypeInt32, kNumberTypeFloat32, ShapeVector{4});
|
||||
ASSERT_TRUE(m != nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.experimental import MapParameter
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
"""
|
||||
Feature: MapParameter
|
||||
Description: Test MapParameter basic operations.
|
||||
Expectation: MapParameter works as expected.
|
||||
"""
|
||||
m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(2), default_value='zeros', name='my_map')
|
||||
assert m.name == 'my_map'
|
||||
assert m.requires_grad
|
||||
|
||||
t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
|
||||
assert t.dtype == ms.float32
|
||||
assert t.shape == (3, 2)
|
||||
assert np.allclose(t.asnumpy(), 0)
|
||||
|
||||
t = m.get(Tensor([1, 2, 3], dtype=ms.int32), 'ones')
|
||||
assert t.dtype == ms.float32
|
||||
assert t.shape == (3, 2)
|
||||
assert np.allclose(t.asnumpy(), 1)
|
||||
|
||||
m.put(Tensor([1, 2, 3], dtype=ms.int32), Tensor([[1, 1], [2, 2], [3, 3]], dtype=ms.float32))
|
||||
m.erase(Tensor([1, 2, 3], dtype=ms.int32))
|
Loading…
Reference in New Issue