!44030 Implement backward graph for MapParameter

Merge pull request !44030 from hewei/map_param
This commit is contained in:
i-robot 2022-10-18 06:50:25 +00:00 committed by Gitee
commit 4253d35dc0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 498 additions and 28 deletions

View File

@ -362,9 +362,11 @@ BuiltInTypeMap &GetMethodMap() {
}}, }},
{kObjectTypeMapTensorType, {kObjectTypeMapTensorType,
{ {
{"get", std::string("map_tensor_get")}, // C.map_tensor_get {"get", std::string("map_tensor_get")}, // C.map_tensor_get
{"put", std::string("map_tensor_put")}, // C.map_tensor_put {"put", std::string("map_tensor_put")}, // C.map_tensor_put
{"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase {"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
{"get_keys", std::string("map_tensor_get_keys")}, // C.map_tensor_get_keys
{"get_values", std::string("map_tensor_get_values")}, // C.map_tensor_get_values
}}, }},
{kObjectTypeJTagged, {}}, {kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}}, {kObjectTypeSymbolicKeyType, {}},

View File

@ -1848,8 +1848,8 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(eval_result); MS_EXCEPTION_IF_NULL(eval_result);
AbstractBasePtr abs = eval_result->abstract(); AbstractBasePtr abs = eval_result->abstract();
MS_EXCEPTION_IF_NULL(abs); MS_EXCEPTION_IF_NULL(abs);
auto ref_abs = abs->cast_ptr<AbstractRefTensor>(); auto ref_key_value = abstract::GetRefKeyValue(abs);
if (ref_abs == nullptr) { if (ref_key_value == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr; return nullptr;
} }
@ -1865,11 +1865,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
ifEmbedIsWeight = param->has_default(); ifEmbedIsWeight = param->has_default();
} }
auto refkey = ref_abs->ref_key_value()->cast_ptr<StringImm>(); auto refkey = ref_key_value->cast_ptr<StringImm>();
if (refkey == nullptr || !ifEmbedIsWeight) { if (refkey == nullptr || !ifEmbedIsWeight) {
auto ret = std::make_shared<AbstractScalar>(type); auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
} }
@ -1884,8 +1882,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph."; MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
return nullptr; return nullptr;
} }
AbstractBasePtr x = ref_abs->ref(); AbstractBasePtr x = SensitivityTransform(abs);
x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());

View File

@ -1822,5 +1822,17 @@ bool AbstractIOMonad::operator==(const AbstractBase &other) const {
} }
return other.isa<AbstractIOMonad>(); return other.isa<AbstractIOMonad>();
} }
ValuePtr GetRefKeyValue(const AbstractBasePtr &abs) {
auto abs_ref = abs->cast_ptr<AbstractRefTensor>();
if (abs_ref != nullptr) {
return abs_ref->ref_key_value();
}
auto abs_map_tensor = abs->cast_ptr<AbstractMapTensor>();
if (abs_map_tensor != nullptr) {
return abs_map_tensor->ref_key_value();
}
return nullptr;
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -1592,6 +1592,8 @@ MS_CORE_API std::string ExtractLoggingInfo(const std::string &info);
MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence, MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
const AbstractSequencePtr &rhs_sequence); const AbstractSequencePtr &rhs_sequence);
MS_CORE_API ValuePtr GetRefKeyValue(const AbstractBasePtr &abs);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_

View File

@ -17,6 +17,7 @@
#include "ir/dtype.h" #include "ir/dtype.h"
#include <cstdlib> #include <cstdlib>
#include <algorithm> #include <algorithm>
#include "mindapi/base/type_id.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
@ -39,7 +40,8 @@ std::string GetExcptionTypeString(TypeId id) {
{kNumberTypeInt4, "Int4"}, {kNumberTypeInt4, "Int4"},
{kNumberTypeGLUInt, "GLUInt"}, {kNumberTypeGLUInt, "GLUInt"},
{kObjectTypeMonad, "Monad"}, {kObjectTypeMonad, "Monad"},
{kObjectTypeCSRTensorType, "CSRTensor"}}; {kObjectTypeCSRTensorType, "CSRTensor"},
{kObjectTypeMapTensorType, "MapTensor"}};
auto it = type_id_to_string.find(id); auto it = type_id_to_string.find(id);
std::string type = ""; std::string type = "";

View File

@ -325,8 +325,10 @@ constexpr auto kIsCSRFunc = "IsCSRFunc";
constexpr auto kMapTensorGet = "MapTensorGet"; constexpr auto kMapTensorGet = "MapTensorGet";
constexpr auto kMapTensorPut = "MapTensorPut"; constexpr auto kMapTensorPut = "MapTensorPut";
constexpr auto kMapTensorErase = "MapTensorErase"; constexpr auto kMapTensorErase = "MapTensorErase";
constexpr auto kMapTensorPutBprop = "MapTensorPutBprop";
constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue"; constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue";
constexpr auto kMapTensorGetKeys = "MapTensorGetKeys";
constexpr auto kMapTensorGetValues = "MapTensorGetValues";
constexpr auto kMapTensorGetGrad = "MapTensorGetGrad";
// COOTensor // COOTensor
constexpr auto kMakeCOOTensor = "MakeCOOTensor"; constexpr auto kMakeCOOTensor = "MakeCOOTensor";
@ -1060,8 +1062,10 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(
GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensorGet)); GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensorGet));
GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut)); GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut));
GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase)); GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase));
GVAR_DEF(PrimitivePtr, kPrimMapTensorPutBprop, std::make_shared<Primitive>(kMapTensorPutBprop));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue)); GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetKeys, std::make_shared<Primitive>(kMapTensorGetKeys));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetValues, std::make_shared<Primitive>(kMapTensorGetValues));
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetGrad, std::make_shared<Primitive>(kMapTensorGetGrad));
// Sparse ops // Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul)); GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_grad.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetGrad, BaseOperator);
AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 4;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetGrad);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetGrad, input_args, kInputIndex0);
// We skip check other arguments, because grad operations are generated by compiler
// so we can assume that their are always correct.
// Grad map tensor has same abstract with the input map tensor.
return abs_map_tensor->Broaden();
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetGrad, prim::kPrimMapTensorGetGrad, MapTensorGetGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetGrad = "MapTensorGetGrad";
/// \brief Grad operator MapTensorGet.
/// Refer to Python API @ref mindspore.ops.MapTensorGetGrad for more details.
class MIND_API MapTensorGetGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetGrad);
/// \brief Constructor.
MapTensorGetGrad() : BaseOperator(kNameMapTensorGetGrad) {
InitIOName({"map_tensor", "key_tensor", "default_value", "grad"}, {"output"});
}
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_keys.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetKeys, BaseOperator);
AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetKeys);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetKeys, input_args, kInputIndex0);
auto map_tensor_type = abs_map_tensor->map_tensor_type();
MS_EXCEPTION_IF_NULL(map_tensor_type);
const auto &key_dtype = map_tensor_type->key_dtype();
// We don't know the map size in compile time.
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
return std::make_shared<abstract::AbstractTensor>(key_dtype, shape_vec);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetKeys, prim::kPrimMapTensorGetKeys, MapTensorGetKeysInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetKeys = "MapTensorGetKeys";
/// \brief Get all keys as a tensor from a MapTensor.
/// Refer to Python API @ref mindspore.ops.MapTensorGetKeys for more details.
class MIND_API MapTensorGetKeys : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetKeys);
/// \brief Constructor.
MapTensorGetKeys() : BaseOperator(kNameMapTensorGetKeys) { InitIOName({"map_tensor"}, {"output"}); }
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/map_tensor_get_values.h"
#include <vector>
#include <memory>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(MapTensorGetValues, BaseOperator);
AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
// Check number of arguments.
constexpr int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetValues);
// Check argument abstracts.
auto abs_map_tensor =
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetValues, input_args, kInputIndex0);
auto map_tensor_type = abs_map_tensor->map_tensor_type();
MS_EXCEPTION_IF_NULL(map_tensor_type);
const auto &value_dtype = map_tensor_type->value_dtype();
auto value_shape_ptr = abs_map_tensor->value_shape();
MS_EXCEPTION_IF_NULL(value_shape_ptr);
const auto &value_shape = value_shape_ptr->shape();
// We don't know the map size in compile time.
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
(void)shape_vec.insert(shape_vec.end(), value_shape.begin(), value_shape.end());
return std::make_shared<abstract::AbstractTensor>(value_dtype, shape_vec);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetValues, prim::kPrimMapTensorGetValues, MapTensorGetValuesInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMapTensorGetValues = "MapTensorGetValues";
/// \brief Get all values as a tensor from a MapTensor.
/// Refer to Python API @ref mindspore.ops.MapTensorGetValues for more details.
class MIND_API MapTensorGetValues : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MapTensorGetValues);
/// \brief Constructor.
MapTensorGetValues() : BaseOperator(kNameMapTensorGetValues) { InitIOName({"map_tensor"}, {"output"}); }
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_

View File

@ -3289,6 +3289,20 @@ def map_tensor_erase(map_tensor, key_tensor):
return _map_tensor_ops.erase(map_tensor, key_tensor) return _map_tensor_ops.erase(map_tensor, key_tensor)
def map_tensor_get_keys(map_tensor):
r"""
Get all keys as a tensor.
"""
return _map_tensor_ops.get_keys(map_tensor)
def map_tensor_get_values(map_tensor):
r"""
Get all values as a tensor.
"""
return _map_tensor_ops.get_values(map_tensor)
def conj(input): def conj(input):
r""" r"""
Computes complex conjugate of the input element-wise. Computes complex conjugate of the input element-wise.

View File

@ -146,6 +146,24 @@ class MapParameter(Parameter):
result_tensor = self._map_tensor.get(key_tensor, default_value) result_tensor = self._map_tensor.get(key_tensor, default_value)
return Tensor(result_tensor, internal=True) return Tensor(result_tensor, internal=True)
def get_keys(self):
"""
Get all keys as a tensor.
Returns:
Tensor, the tensor contains all keys.
"""
return None
def get_values(self):
"""
Get all values as a tensor.
Returns:
Tensor, the tensor contains all keys.
"""
return None
def put(self, key_tensor, value_tensor): def put(self, key_tensor, value_tensor):
""" """
Insert or update records according the given key tensor and value tensor. Insert or update records according the given key tensor and value tensor.

View File

@ -28,6 +28,7 @@ from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseMul from mindspore.ops.operations.sparse_ops import SparseDenseCwiseMul
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseDiv from mindspore.ops.operations.sparse_ops import SparseDenseCwiseDiv
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
from mindspore.ops.operations import _map_tensor_ops
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
@ -293,3 +294,14 @@ def get_bprop_sparse_reorder(self):
return None, gather_op(dout[1], inverted_permutation, axis), None return None, gather_op(dout[1], inverted_permutation, axis), None
return bprop return bprop
@bprop_getters.register(_map_tensor_ops.MapTensorGet)
def get_bprop_map_tensor_get(self):
"""Grad definition for `MapTensorGet` operation."""
grad_op = G.MapTensorGetGrad()
def bprop(map_tensor, key_tensor, default_value, out, dout):
grad_map_tensor = grad_op(map_tensor, key_tensor, default_value, dout)
return grad_map_tensor, zeros_like(key_tensor), zeros_like(default_value)
return bprop

View File

@ -21,7 +21,8 @@ Pre-defined combination of operators.
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \ from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
core, env_get, tail, zip_operation, _Vmap, _TaylorOperation core, tail, zip_operation, _Vmap, _TaylorOperation
from mindspore.ops.composite.env_ops import env_get
from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like

View File

@ -32,8 +32,6 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
from mindspore.ops.primitive import Primitive from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops
from mindspore.ops import operations as P
from mindspore.ops import signature as sig from mindspore.ops import signature as sig
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_] __all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
@ -1108,15 +1106,3 @@ class _ZipOperation(ZipOperation_):
zip_operation = _ZipOperation('zip_operation') zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs.""" """`zip_operation` will generate a tuple of zip iterations of inputs."""
env_get = MultitypeFuncGraph("env_get")
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))

View File

@ -0,0 +1,41 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Env related operations."""
from __future__ import absolute_import
from mindspore.ops.composite.base import MultitypeFuncGraph
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops
from mindspore.ops import operations as P
env_get = MultitypeFuncGraph("env_get")
environ_get = Primitive('EnvironGet')
ref_to_embed = _grad_ops.RefToEmbed()
tensor_zeros_like = P.ZerosLike()
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return environ_get(env, ref_to_embed(parameter), tensor_zeros_like(parameter))
@env_get.register("EnvType", "MapTensor")
def _map_tensor_env_get(env, map_parameter):
"""Used to get env for map parameter."""
return environ_get(env, ref_to_embed(map_parameter), zeros_like(map_parameter))

View File

@ -513,4 +513,20 @@ def _add_tuple_cootensor(x, y):
x = COOTensor(x[0], x[1], y.shape) x = COOTensor(x[0], x[1], y.shape)
return _add_cootensor_cootensor(x, y) return _add_cootensor_cootensor(x, y)
@_add_backward.register("MapTensor", "MapTensor")
def _map_tensor_add_backward(x, y):
"""
Adds MapTensors for backward.
Args:
x (MapTensor): x
y (MapTensor): y
Returns:
MapTensor.
"""
return x
hyper_add = base.HyperMap(_add_backward) hyper_add = base.HyperMap(_add_backward)

View File

@ -73,6 +73,12 @@ def _zeros_like_csr_tensor(x):
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape) return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
@zeros_like_leaf.register("MapTensor")
def _zeros_like_map_tensor(x):
"""Returns a map tensor with the same shape and dtype as x and all elements are 0."""
return x
@zeros_like_leaf.register("TypeType") @zeros_like_leaf.register("TypeType")
def _zeros_like_type_type(x): def _zeros_like_type_type(x):
"""Returns x because x is a type. This is usually used in backprop progress.""" """Returns x because x is a type. This is usually used in backprop progress."""

View File

@ -3838,3 +3838,23 @@ class CholeskyGrad(Primitive):
def __init__(self): def __init__(self):
"""Initialize CholeskyGrad""" """Initialize CholeskyGrad"""
self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
class MapTensorGetGrad(Primitive):
"""
Computes gradients for MapTensorGet operation.
Inputs:
- **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
- **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
- **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
- **grad** (Tensor) - The grad value according the forward operator MapTensorGet.
Outputs:
- **output** (MapTensor) - MapTensor with grad values.
"""
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetGrad"""
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)

View File

@ -65,6 +65,35 @@ class MapTensorErase(Primitive):
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor'], outputs=['output']) self.init_prim_io_names(inputs=['map_tensor', 'key_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
class MapTensorGetKeys(Primitive):
"""
Get all keys as a tensor.
"""
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetKeys"""
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
class MapTensorGetValues(Primitive):
"""
Get all keys as a tensor.
"""
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
@prim_attr_register
def __init__(self):
"""Initialize MapTensorGetValues"""
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
self.add_prim_attr('side_effect_mem', True)
get = MapTensorGet() get = MapTensorGet()
put = MapTensorPut() put = MapTensorPut()
erase = MapTensorErase() erase = MapTensorErase()
get_keys = MapTensorGetKeys()
get_values = MapTensorGetValues()

View File

@ -18,6 +18,7 @@ import mindspore.nn as nn
from mindspore import context, Tensor, Parameter, ParameterTuple from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore.experimental import MapParameter from mindspore.experimental import MapParameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
def test_basic_operations(): def test_basic_operations():
@ -73,6 +74,9 @@ def test_simple_graph_compile():
value4 = self.m[self.key] value4 = self.m[self.key]
self.m[self.key] = value4 self.m[self.key] = value4
self.m.erase(self.key) self.m.erase(self.key)
keys = self.m.get_keys()
values = self.m.get_values()
self.m.put(keys, values)
return self.p + value1 + value2 + value3 + value4 return self.p + value1 + value2 + value3 + value4
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -126,3 +130,40 @@ def test_map_parameter_clone():
assert clone_same.value_dtype == m.value_dtype assert clone_same.value_dtype == m.value_dtype
assert clone_same.value_shape == m.value_shape assert clone_same.value_shape == m.value_shape
assert clone_same.default_value == 'zeros' assert clone_same.default_value == 'zeros'
def test_grad_net():
"""
Feature: MapParameter
Description: Test grad graph compiled with MapParameter.
Expectation: Grad graph for MapParameter created without exceptions.
"""
class MyNet(nn.Cell):
def __init__(self):
nn.Cell.__init__(self)
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
self.key = Tensor([1, 2], dtype=ms.int32)
def construct(self, x):
a = self.m.get(self.key, 0.1)
self.m.erase(self.key)
return x * a
class GradNet(nn.Cell):
def __init__(self, network):
super(GradNet, self).__init__()
self.grad_by_list = C.GradOperation(get_by_list=True)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
def construct(self, *inputs):
gout = self.grad_by_list(self.network, self.weights)(*inputs)
return gout
context.set_context(mode=context.GRAPH_MODE)
net = MyNet()
grad = GradNet(net)
t = initializer('ones', (2, 3), ms.float32)
t = t.init_data()
grad(t)