!44030 Implement backward graph for MapParameter

Merge pull request !44030 from hewei/map_param
This commit is contained in:
i-robot 2022-10-18 06:50:25 +00:00 committed by Gitee
commit 4253d35dc0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 498 additions and 28 deletions
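In practice, this change lets automatic differentiation run over graphs that read from a MapParameter. A minimal end-to-end sketch, adapted from the test case added in this commit (GRAPH_MODE assumed; eager-mode gradients are out of scope here):

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.experimental import MapParameter
from mindspore.ops import composite as C

class MyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
        self.key = Tensor([1, 2], dtype=ms.int32)

    def construct(self, x):
        # MapTensorGet now has a registered bprop (MapTensorGetGrad), so this is differentiable.
        return x * self.m.get(self.key, 0.1)

context.set_context(mode=context.GRAPH_MODE)
net = MyNet()
grad_fn = C.GradOperation(get_by_list=True)(net, ParameterTuple(net.trainable_params()))
grads = grad_fn(initializer('ones', (2, 3), ms.float32).init_data())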

View File

@@ -362,9 +362,11 @@ BuiltInTypeMap &GetMethodMap() {
}},
{kObjectTypeMapTensorType,
{
{"get", std::string("map_tensor_get")}, // C.map_tensor_get
{"put", std::string("map_tensor_put")}, // C.map_tensor_put
{"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
{"get", std::string("map_tensor_get")}, // C.map_tensor_get
{"put", std::string("map_tensor_put")}, // C.map_tensor_put
{"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
{"get_keys", std::string("map_tensor_get_keys")}, // C.map_tensor_get_keys
{"get_values", std::string("map_tensor_get_values")}, // C.map_tensor_get_values
}},
{kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}},
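With the two new method-map entries above, graph-mode code can call the accessors directly on a MapParameter and the compiler resolves them to the corresponding primitives. A sketch of the resolution (method names as registered above):

# Inside a Cell's construct(), compiled in GRAPH_MODE:
#   keys = self.m.get_keys()      # -> C.map_tensor_get_keys   -> MapTensorGetKeys primitive
#   values = self.m.get_values()  # -> C.map_tensor_get_values -> MapTensorGetValues primitive
#   self.m.put(keys, values)      # -> C.map_tensor_put        -> MapTensorPut primitive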

View File

@@ -1848,8 +1848,8 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(eval_result);
AbstractBasePtr abs = eval_result->abstract();
MS_EXCEPTION_IF_NULL(abs);
auto ref_abs = abs->cast_ptr<AbstractRefTensor>();
if (ref_abs == nullptr) {
auto ref_key_value = abstract::GetRefKeyValue(abs);
if (ref_key_value == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr;
}
@@ -1865,11 +1865,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(param);
ifEmbedIsWeight = param->has_default();
}
auto refkey = ref_abs->ref_key_value()->cast_ptr<StringImm>();
auto refkey = ref_key_value->cast_ptr<StringImm>();
if (refkey == nullptr || !ifEmbedIsWeight) {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
}
@@ -1884,8 +1882,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
return nullptr;
}
AbstractBasePtr x = ref_abs->ref();
x = SensitivityTransform(x);
AbstractBasePtr x = SensitivityTransform(abs);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());

View File

@@ -1822,5 +1822,17 @@ bool AbstractIOMonad::operator==(const AbstractBase &other) const {
}
return other.isa<AbstractIOMonad>();
}
ValuePtr GetRefKeyValue(const AbstractBasePtr &abs) {
auto abs_ref = abs->cast_ptr<AbstractRefTensor>();
if (abs_ref != nullptr) {
return abs_ref->ref_key_value();
}
auto abs_map_tensor = abs->cast_ptr<AbstractMapTensor>();
if (abs_map_tensor != nullptr) {
return abs_map_tensor->ref_key_value();
}
return nullptr;
}
} // namespace abstract
} // namespace mindspore

View File

@@ -1592,6 +1592,8 @@ MS_CORE_API std::string ExtractLoggingInfo(const std::string &info);
MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
const AbstractSequencePtr &rhs_sequence);
MS_CORE_API ValuePtr GetRefKeyValue(const AbstractBasePtr &abs);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_

View File

@@ -17,6 +17,7 @@
#include "ir/dtype.h"
#include <cstdlib>
#include <algorithm>
#include "mindapi/base/type_id.h"
#include "utils/log_adapter.h"
#include "abstract/abstract_value.h"
@@ -39,7 +40,8 @@ std::string GetExcptionTypeString(TypeId id) {
{kNumberTypeInt4, "Int4"},
{kNumberTypeGLUInt, "GLUInt"},
{kObjectTypeMonad, "Monad"},
{kObjectTypeCSRTensorType, "CSRTensor"}};
{kObjectTypeCSRTensorType, "CSRTensor"},
{kObjectTypeMapTensorType, "MapTensor"}};
auto it = type_id_to_string.find(id);
std::string type = "";

View File

@@ -325,8 +325,10 @@ constexpr auto kIsCSRFunc = "IsCSRFunc";
constexpr auto kMapTensorGet = "MapTensorGet";
constexpr auto kMapTensorPut = "MapTensorPut";
constexpr auto kMapTensorErase = "MapTensorErase";
constexpr auto kMapTensorPutBprop = "MapTensorPutBprop";
constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue";
constexpr auto kMapTensorGetKeys = "MapTensorGetKeys";
constexpr auto kMapTensorGetValues = "MapTensorGetValues";
constexpr auto kMapTensorGetGrad = "MapTensorGetGrad";
// COOTensor
constexpr auto kMakeCOOTensor = "MakeCOOTensor";
@@ -1060,8 +1062,10 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(
GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensorGet));
GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut));
GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase));
GVAR_DEF(PrimitivePtr, kPrimMapTensorPutBprop, std::make_shared<Primitive>(kMapTensorPutBprop));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetKeys, std::make_shared<Primitive>(kMapTensorGetKeys));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetValues, std::make_shared<Primitive>(kMapTensorGetValues));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetGrad, std::make_shared<Primitive>(kMapTensorGetGrad));
// Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_grad.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetGrad, BaseOperator);
AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 4;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetGrad);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetGrad, input_args, kInputIndex0);
// We skip checking the other arguments, because grad operations are generated
// by the compiler, so we can assume that they are always correct.
// The grad map tensor has the same abstract as the input map tensor.
return abs_map_tensor->Broaden();
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetGrad, prim::kPrimMapTensorGetGrad, MapTensorGetGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetGrad = "MapTensorGetGrad";
/// \brief Grad operator for MapTensorGet.
/// Refer to Python API @ref mindspore.ops.MapTensorGetGrad for more details.
class MIND_API MapTensorGetGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetGrad);
/// \brief Constructor.
MapTensorGetGrad() : BaseOperator(kNameMapTensorGetGrad) {
InitIOName({"map_tensor", "key_tensor", "default_value", "grad"}, {"output"});
}
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_keys.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetKeys, BaseOperator);
AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetKeys);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetKeys, input_args, kInputIndex0);
auto map_tensor_type = abs_map_tensor->map_tensor_type();
MS_EXCEPTION_IF_NULL(map_tensor_type);
const auto &key_dtype = map_tensor_type->key_dtype();
// We don't know the map size at compile time.
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
return std::make_shared<abstract::AbstractTensor>(key_dtype, shape_vec);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetKeys, prim::kPrimMapTensorGetKeys, MapTensorGetKeysInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetKeys = "MapTensorGetKeys";
/// \brief Get all keys as a tensor from a MapTensor.
/// Refer to Python API @ref mindspore.ops.MapTensorGetKeys for more details.
class MIND_API MapTensorGetKeys : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetKeys);
/// \brief Constructor.
MapTensorGetKeys() : BaseOperator(kNameMapTensorGetKeys) { InitIOName({"map_tensor"}, {"output"}); }
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_values.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetValues, BaseOperator);
AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetValues);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetValues, input_args, kInputIndex0);
auto map_tensor_type = abs_map_tensor->map_tensor_type();
MS_EXCEPTION_IF_NULL(map_tensor_type);
const auto &value_dtype = map_tensor_type->value_dtype();
auto value_shape_ptr = abs_map_tensor->value_shape();
MS_EXCEPTION_IF_NULL(value_shape_ptr);
const auto &value_shape = value_shape_ptr->shape();
// We don't know the map size at compile time.
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
(void)shape_vec.insert(shape_vec.end(), value_shape.begin(), value_shape.end());
return std::make_shared<abstract::AbstractTensor>(value_dtype, shape_vec);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetValues, prim::kPrimMapTensorGetValues, MapTensorGetValuesInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
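Together with MapTensorGetKeysInfer above, the inferred output shapes can be pictured from the Python side; the leading dimension is kShapeDimAny (-1) because the number of stored keys is only known at runtime (a sketch, not verbatim framework output):

# With m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,)),
# the compile-time abstracts inside a graph are:
#   m.get_keys()   -> Tensor[int32],   shape (-1,)    # flat tensor of keys
#   m.get_values() -> Tensor[float32], shape (-1, 3)  # (key count, *value_shape)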

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetValues = "MapTensorGetValues";
/// \brief Get all values as a tensor from a MapTensor.
/// Refer to Python API @ref mindspore.ops.MapTensorGetValues for more details.
class MIND_API MapTensorGetValues : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetValues);
/// \brief Constructor.
MapTensorGetValues() : BaseOperator(kNameMapTensorGetValues) { InitIOName({"map_tensor"}, {"output"}); }
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_

View File

@@ -3289,6 +3289,20 @@ def map_tensor_erase(map_tensor, key_tensor):
return _map_tensor_ops.erase(map_tensor, key_tensor)
def map_tensor_get_keys(map_tensor):
r"""
Get all keys as a tensor.
"""
return _map_tensor_ops.get_keys(map_tensor)
def map_tensor_get_values(map_tensor):
r"""
Get all values as a tensor.
"""
return _map_tensor_ops.get_values(map_tensor)
def conj(input):
r"""
Computes complex conjugate of the input element-wise.

View File

@@ -146,6 +146,24 @@ class MapParameter(Parameter):
result_tensor = self._map_tensor.get(key_tensor, default_value)
return Tensor(result_tensor, internal=True)
def get_keys(self):
"""
Get all keys as a tensor.
Returns:
Tensor, a tensor containing all keys.
"""
return None
def get_values(self):
"""
Get all values as a tensor.
Returns:
Tensor, a tensor containing all values.
"""
return None
def put(self, key_tensor, value_tensor):
"""
Insert or update records according to the given key tensor and value tensor.

View File

@@ -28,6 +28,7 @@ from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseMul
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseDiv
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
from mindspore.ops.operations import _map_tensor_ops
from mindspore.common import dtype as mstype
from mindspore import Tensor
from mindspore.ops.primitive import constexpr
@@ -293,3 +294,14 @@ def get_bprop_sparse_reorder(self):
return None, gather_op(dout[1], inverted_permutation, axis), None
return bprop
@bprop_getters.register(_map_tensor_ops.MapTensorGet)
def get_bprop_map_tensor_get(self):
"""Grad definition for `MapTensorGet` operation."""
grad_op = G.MapTensorGetGrad()
def bprop(map_tensor, key_tensor, default_value, out, dout):
grad_map_tensor = grad_op(map_tensor, key_tensor, default_value, dout)
return grad_map_tensor, zeros_like(key_tensor), zeros_like(default_value)
return bprop
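For orientation, the closure follows the usual bprop contract: it receives the forward inputs plus the forward output out and the incoming gradient dout, and returns one gradient per forward input. Only the map tensor gets a real gradient here; keys and the default value are not differentiable, so they receive zeros_like placeholders:

# bprop(map_tensor, key_tensor, default_value, out, dout) returns, in order:
#   MapTensorGetGrad(map_tensor, key_tensor, default_value, dout)  # grad w.r.t. map_tensor
#   zeros_like(key_tensor)                                         # keys carry no gradient
#   zeros_like(default_value)                                      # neither does the default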

View File

@@ -21,7 +21,8 @@ Pre-defined combination of operators.
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
core, env_get, tail, zip_operation, _Vmap, _TaylorOperation
core, tail, zip_operation, _Vmap, _TaylorOperation
from mindspore.ops.composite.env_ops import env_get
from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like

View File

@@ -32,8 +32,6 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops
from mindspore.ops import operations as P
from mindspore.ops import signature as sig
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
@@ -1108,15 +1106,3 @@ class _ZipOperation(ZipOperation_):
zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""
env_get = MultitypeFuncGraph("env_get")
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))

View File

@@ -0,0 +1,41 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Env related operations."""
from __future__ import absolute_import
from mindspore.ops.composite.base import MultitypeFuncGraph
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops
from mindspore.ops import operations as P
env_get = MultitypeFuncGraph("env_get")
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
tensor_zeros_like = P.ZerosLike()
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return environ_get(env, ref_to_embed(parameter), tensor_zeros_like(parameter))
@env_get.register("EnvType", "MapTensor")
def _map_tensor_env_get(env, map_parameter):
"""Used to get env for map parameter."""
return environ_get(env, ref_to_embed(map_parameter), zeros_like(map_parameter))
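Since env_get is a MultitypeFuncGraph, registration dispatches on the abstract types of the arguments; the new ("EnvType", "MapTensor") overload is what lets gradient extraction fetch gradients for MapParameters. Conceptually (these calls are emitted by the autodiff pipeline, not written by users):

# env_get(env, p: Tensor)    -> environ_get(env, ref_to_embed(p), tensor_zeros_like(p))
# env_get(env, m: MapTensor) -> environ_get(env, ref_to_embed(m), zeros_like(m))
# where zeros_like(m) for a MapTensor is the placeholder registered in
# zeros_like_impl (it currently returns the map tensor itself).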

View File

@@ -513,4 +513,20 @@ def _add_tuple_cootensor(x, y):
x = COOTensor(x[0], x[1], y.shape)
return _add_cootensor_cootensor(x, y)
@_add_backward.register("MapTensor", "MapTensor")
def _map_tensor_add_backward(x, y):
"""
Adds MapTensors for backward.
Args:
x (MapTensor): The first map tensor.
y (MapTensor): The second map tensor.
Returns:
MapTensor.
"""
return x
hyper_add = base.HyperMap(_add_backward)

View File

@@ -73,6 +73,12 @@ def _zeros_like_csr_tensor(x):
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
@zeros_like_leaf.register("MapTensor")
def _zeros_like_map_tensor(x):
"""Returns a map tensor with the same shape and dtype as x and all elements are 0."""
return x
@zeros_like_leaf.register("TypeType")
def _zeros_like_type_type(x):
"""Returns x because x is a type. This is usually used in backprop progress."""

View File

@@ -3838,3 +3838,23 @@ class CholeskyGrad(Primitive):
def __init__(self):
"""Initialize CholeskyGrad"""
self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
class MapTensorGetGrad(Primitive):
"""
Computes gradients for the MapTensorGet operation.
Inputs:
- **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
- **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
- **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
- **grad** (Tensor) - The gradient value corresponding to the output of the forward operator MapTensorGet.
Outputs:
- **output** (MapTensor) - MapTensor with grad values.
"""
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetGrad"""
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)

View File

@@ -65,6 +65,35 @@ class MapTensorErase(Primitive):
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
class MapTensorGetKeys(Primitive):
"""
Get all keys as a tensor.
"""
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetKeys"""
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
class MapTensorGetValues(Primitive):
"""
Get all values as a tensor.
"""
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetValues"""
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
get = MapTensorGet()
put = MapTensorPut()
erase = MapTensorErase()
get_keys = MapTensorGetKeys()
get_values = MapTensorGetValues()

View File

@@ -18,6 +18,7 @@ import mindspore.nn as nn
from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore.experimental import MapParameter
from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
def test_basic_operations():
@@ -73,6 +74,9 @@ def test_simple_graph_compile():
value4 = self.m[self.key]
self.m[self.key] = value4
self.m.erase(self.key)
keys = self.m.get_keys()
values = self.m.get_values()
self.m.put(keys, values)
return self.p + value1 + value2 + value3 + value4
context.set_context(mode=context.GRAPH_MODE)
@@ -126,3 +130,40 @@ def test_map_parameter_clone():
assert clone_same.value_dtype == m.value_dtype
assert clone_same.value_shape == m.value_shape
assert clone_same.default_value == 'zeros'
def test_grad_net():
"""
Feature: MapParameter
Description: Test grad graph compiled with MapParameter.
Expectation: Grad graph for MapParameter created without exceptions.
"""
class MyNet(nn.Cell):
def __init__(self):
nn.Cell.__init__(self)
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
self.key = Tensor([1, 2], dtype=ms.int32)
def construct(self, x):
a = self.m.get(self.key, 0.1)
self.m.erase(self.key)
return x * a
class GradNet(nn.Cell):
def __init__(self, network):
super(GradNet, self).__init__()
self.grad_by_list = C.GradOperation(get_by_list=True)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
def construct(self, *inputs):
gout = self.grad_by_list(self.network, self.weights)(*inputs)
return gout
context.set_context(mode=context.GRAPH_MODE)
net = MyNet()
grad = GradNet(net)
t = initializer('ones', (2, 3), ms.float32)
t = t.init_data()
grad(t)