!44030 Implement backward graph for MapParameter
Merge pull request !44030 from hewei/map_param
This commit is contained in:
commit
4253d35dc0
|
@ -362,9 +362,11 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
}},
|
||||
{kObjectTypeMapTensorType,
|
||||
{
|
||||
{"get", std::string("map_tensor_get")}, // C.map_tensor_get
|
||||
{"put", std::string("map_tensor_put")}, // C.map_tensor_put
|
||||
{"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
|
||||
{"get", std::string("map_tensor_get")}, // C.map_tensor_get
|
||||
{"put", std::string("map_tensor_put")}, // C.map_tensor_put
|
||||
{"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
|
||||
{"get_keys", std::string("map_tensor_get_keys")}, // C.map_tensor_get_keys
|
||||
{"get_values", std::string("map_tensor_get_values")}, // C.map_tensor_get_values
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
|
|
|
@ -1848,8 +1848,8 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
AbstractBasePtr abs = eval_result->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
auto ref_abs = abs->cast_ptr<AbstractRefTensor>();
|
||||
if (ref_abs == nullptr) {
|
||||
auto ref_key_value = abstract::GetRefKeyValue(abs);
|
||||
if (ref_key_value == nullptr) {
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1865,11 +1865,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(param);
|
||||
ifEmbedIsWeight = param->has_default();
|
||||
}
|
||||
auto refkey = ref_abs->ref_key_value()->cast_ptr<StringImm>();
|
||||
auto refkey = ref_key_value->cast_ptr<StringImm>();
|
||||
if (refkey == nullptr || !ifEmbedIsWeight) {
|
||||
auto ret = std::make_shared<AbstractScalar>(type);
|
||||
auto ref_value = ref_abs->ref();
|
||||
MS_EXCEPTION_IF_NULL(ref_value);
|
||||
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
|
@ -1884,8 +1882,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
|
||||
return nullptr;
|
||||
}
|
||||
AbstractBasePtr x = ref_abs->ref();
|
||||
x = SensitivityTransform(x);
|
||||
AbstractBasePtr x = SensitivityTransform(abs);
|
||||
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
|
||||
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
|
||||
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
|
||||
|
|
|
@ -1822,5 +1822,17 @@ bool AbstractIOMonad::operator==(const AbstractBase &other) const {
|
|||
}
|
||||
return other.isa<AbstractIOMonad>();
|
||||
}
|
||||
|
||||
ValuePtr GetRefKeyValue(const AbstractBasePtr &abs) {
|
||||
auto abs_ref = abs->cast_ptr<AbstractRefTensor>();
|
||||
if (abs_ref != nullptr) {
|
||||
return abs_ref->ref_key_value();
|
||||
}
|
||||
auto abs_map_tensor = abs->cast_ptr<AbstractMapTensor>();
|
||||
if (abs_map_tensor != nullptr) {
|
||||
return abs_map_tensor->ref_key_value();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1592,6 +1592,8 @@ MS_CORE_API std::string ExtractLoggingInfo(const std::string &info);
|
|||
|
||||
MS_CORE_API void SynchronizeSequenceElementsUseFlagsRecursively(const AbstractSequencePtr &lhs_sequence,
|
||||
const AbstractSequencePtr &rhs_sequence);
|
||||
|
||||
MS_CORE_API ValuePtr GetRefKeyValue(const AbstractBasePtr &abs);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "ir/dtype.h"
|
||||
#include <cstdlib>
|
||||
#include <algorithm>
|
||||
#include "mindapi/base/type_id.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
|
@ -39,7 +40,8 @@ std::string GetExcptionTypeString(TypeId id) {
|
|||
{kNumberTypeInt4, "Int4"},
|
||||
{kNumberTypeGLUInt, "GLUInt"},
|
||||
{kObjectTypeMonad, "Monad"},
|
||||
{kObjectTypeCSRTensorType, "CSRTensor"}};
|
||||
{kObjectTypeCSRTensorType, "CSRTensor"},
|
||||
{kObjectTypeMapTensorType, "MapTensor"}};
|
||||
|
||||
auto it = type_id_to_string.find(id);
|
||||
std::string type = "";
|
||||
|
|
|
@ -325,8 +325,10 @@ constexpr auto kIsCSRFunc = "IsCSRFunc";
|
|||
constexpr auto kMapTensorGet = "MapTensorGet";
|
||||
constexpr auto kMapTensorPut = "MapTensorPut";
|
||||
constexpr auto kMapTensorErase = "MapTensorErase";
|
||||
constexpr auto kMapTensorPutBprop = "MapTensorPutBprop";
|
||||
constexpr auto kMapTensorGetDefaultValue = "MapTensorGetDefaultValue";
|
||||
constexpr auto kMapTensorGetKeys = "MapTensorGetKeys";
|
||||
constexpr auto kMapTensorGetValues = "MapTensorGetValues";
|
||||
constexpr auto kMapTensorGetGrad = "MapTensorGetGrad";
|
||||
|
||||
// COOTensor
|
||||
constexpr auto kMakeCOOTensor = "MakeCOOTensor";
|
||||
|
@ -1060,8 +1062,10 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(
|
|||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGet, std::make_shared<Primitive>(kMapTensorGet));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorPut, std::make_shared<Primitive>(kMapTensorPut));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorErase, std::make_shared<Primitive>(kMapTensorErase));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorPutBprop, std::make_shared<Primitive>(kMapTensorPutBprop));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetDefaultValue, std::make_shared<Primitive>(kMapTensorGetDefaultValue));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetKeys, std::make_shared<Primitive>(kMapTensorGetKeys));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetValues, std::make_shared<Primitive>(kMapTensorGetValues));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMapTensorGetGrad, std::make_shared<Primitive>(kMapTensorGetGrad));
|
||||
|
||||
// Sparse ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/map_tensor_get_grad.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(MapTensorGetGrad, BaseOperator);
|
||||
AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// Check number of arguments.
|
||||
constexpr int64_t input_num = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetGrad);
|
||||
// Check argument abstracts.
|
||||
auto abs_map_tensor =
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetGrad, input_args, kInputIndex0);
|
||||
|
||||
// We skip check other arguments, because grad operations are generated by compiler
|
||||
// so we can assume that their are always correct.
|
||||
|
||||
// Grad map tensor has same abstract with the input map tensor.
|
||||
return abs_map_tensor->Broaden();
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetGrad, prim::kPrimMapTensorGetGrad, MapTensorGetGradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMapTensorGetGrad = "MapTensorGetGrad";
|
||||
/// \brief Grad operator MapTensorGet.
|
||||
/// Refer to Python API @ref mindspore.ops.MapTensorGetGrad for more details.
|
||||
class MIND_API MapTensorGetGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MapTensorGetGrad);
|
||||
/// \brief Constructor.
|
||||
MapTensorGetGrad() : BaseOperator(kNameMapTensorGetGrad) {
|
||||
InitIOName({"map_tensor", "key_tensor", "default_value", "grad"}, {"output"});
|
||||
}
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr MapTensorGetGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_GRAD_H_
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/map_tensor_get_keys.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(MapTensorGetKeys, BaseOperator);
|
||||
AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// Check number of arguments.
|
||||
constexpr int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetKeys);
|
||||
// Check argument abstracts.
|
||||
auto abs_map_tensor =
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetKeys, input_args, kInputIndex0);
|
||||
auto map_tensor_type = abs_map_tensor->map_tensor_type();
|
||||
MS_EXCEPTION_IF_NULL(map_tensor_type);
|
||||
const auto &key_dtype = map_tensor_type->key_dtype();
|
||||
// We don't know the map size in compile time.
|
||||
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
|
||||
return std::make_shared<abstract::AbstractTensor>(key_dtype, shape_vec);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetKeys, prim::kPrimMapTensorGetKeys, MapTensorGetKeysInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
|
||||
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMapTensorGetKeys = "MapTensorGetKeys";
|
||||
/// \brief Get all keys as a tensor from a MapTensor.
|
||||
/// Refer to Python API @ref mindspore.ops.MapTensorGetKeys for more details.
|
||||
class MIND_API MapTensorGetKeys : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MapTensorGetKeys);
|
||||
/// \brief Constructor.
|
||||
MapTensorGetKeys() : BaseOperator(kNameMapTensorGetKeys) { InitIOName({"map_tensor"}, {"output"}); }
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr MapTensorGetKeysInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_KEYS_H_
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/map_tensor_get_values.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(MapTensorGetValues, BaseOperator);
|
||||
AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// Check number of arguments.
|
||||
constexpr int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, kNameMapTensorGetValues);
|
||||
// Check argument abstracts.
|
||||
auto abs_map_tensor =
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractMapTensor>(kNameMapTensorGetValues, input_args, kInputIndex0);
|
||||
auto map_tensor_type = abs_map_tensor->map_tensor_type();
|
||||
MS_EXCEPTION_IF_NULL(map_tensor_type);
|
||||
const auto &value_dtype = map_tensor_type->value_dtype();
|
||||
auto value_shape_ptr = abs_map_tensor->value_shape();
|
||||
MS_EXCEPTION_IF_NULL(value_shape_ptr);
|
||||
const auto &value_shape = value_shape_ptr->shape();
|
||||
// We don't know the map size in compile time.
|
||||
ShapeVector shape_vec = {abstract::Shape::kShapeDimAny};
|
||||
(void)shape_vec.insert(shape_vec.end(), value_shape.begin(), value_shape.end());
|
||||
return std::make_shared<abstract::AbstractTensor>(value_dtype, shape_vec);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MapTensorGetValues, prim::kPrimMapTensorGetValues, MapTensorGetValuesInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
|
||||
#define MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMapTensorGetValues = "MapTensorGetValues";
|
||||
/// \brief Get all values as a tensor from a MapTensor.
|
||||
/// Refer to Python API @ref mindspore.ops.MapTensorGetValues for more details.
|
||||
class MIND_API MapTensorGetValues : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MapTensorGetValues);
|
||||
/// \brief Constructor.
|
||||
MapTensorGetValues() : BaseOperator(kNameMapTensorGetValues) { InitIOName({"map_tensor"}, {"output"}); }
|
||||
/// \brief Init.
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr MapTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MAP_TENSOR_GET_VALUES_H_
|
|
@ -3289,6 +3289,20 @@ def map_tensor_erase(map_tensor, key_tensor):
|
|||
return _map_tensor_ops.erase(map_tensor, key_tensor)
|
||||
|
||||
|
||||
def map_tensor_get_keys(map_tensor):
|
||||
r"""
|
||||
Get all keys as a tensor.
|
||||
"""
|
||||
return _map_tensor_ops.get_keys(map_tensor)
|
||||
|
||||
|
||||
def map_tensor_get_values(map_tensor):
|
||||
r"""
|
||||
Get all values as a tensor.
|
||||
"""
|
||||
return _map_tensor_ops.get_values(map_tensor)
|
||||
|
||||
|
||||
def conj(input):
|
||||
r"""
|
||||
Computes complex conjugate of the input element-wise.
|
||||
|
|
|
@ -146,6 +146,24 @@ class MapParameter(Parameter):
|
|||
result_tensor = self._map_tensor.get(key_tensor, default_value)
|
||||
return Tensor(result_tensor, internal=True)
|
||||
|
||||
def get_keys(self):
|
||||
"""
|
||||
Get all keys as a tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, the tensor contains all keys.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_values(self):
|
||||
"""
|
||||
Get all values as a tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, the tensor contains all keys.
|
||||
"""
|
||||
return None
|
||||
|
||||
def put(self, key_tensor, value_tensor):
|
||||
"""
|
||||
Insert or update records according the given key tensor and value tensor.
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
|
|||
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseMul
|
||||
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseDiv
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
|
||||
from mindspore.ops.operations import _map_tensor_ops
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
@ -293,3 +294,14 @@ def get_bprop_sparse_reorder(self):
|
|||
return None, gather_op(dout[1], inverted_permutation, axis), None
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_map_tensor_ops.MapTensorGet)
|
||||
def get_bprop_map_tensor_get(self):
|
||||
"""Grad definition for `MapTensorGet` operation."""
|
||||
grad_op = G.MapTensorGetGrad()
|
||||
|
||||
def bprop(map_tensor, key_tensor, default_value, out, dout):
|
||||
grad_map_tensor = grad_op(map_tensor, key_tensor, default_value, dout)
|
||||
return grad_map_tensor, zeros_like(key_tensor), zeros_like(default_value)
|
||||
return bprop
|
||||
|
|
|
@ -21,7 +21,8 @@ Pre-defined combination of operators.
|
|||
|
||||
|
||||
from mindspore.ops.composite.base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
|
||||
core, env_get, tail, zip_operation, _Vmap, _TaylorOperation
|
||||
core, tail, zip_operation, _Vmap, _TaylorOperation
|
||||
from mindspore.ops.composite.env_ops import env_get
|
||||
from mindspore.ops.composite.clip_ops import clip_by_value, clip_by_global_norm
|
||||
from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
||||
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
|
||||
|
|
|
@ -32,8 +32,6 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.operations import _grad_ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import signature as sig
|
||||
|
||||
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
|
||||
|
@ -1108,15 +1106,3 @@ class _ZipOperation(ZipOperation_):
|
|||
|
||||
zip_operation = _ZipOperation('zip_operation')
|
||||
"""`zip_operation` will generate a tuple of zip iterations of inputs."""
|
||||
|
||||
env_get = MultitypeFuncGraph("env_get")
|
||||
|
||||
environ_get = Primitive('EnvironGet')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
zeros_like = P.ZerosLike()
|
||||
|
||||
|
||||
@env_get.register("EnvType", "Tensor")
|
||||
def _tensor_env_get(env, parameter):
|
||||
"""Used to get env."""
|
||||
return environ_get(env, ref_to_embed(parameter), zeros_like(parameter))
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Env related operations."""
|
||||
from __future__ import absolute_import
|
||||
from mindspore.ops.composite.base import MultitypeFuncGraph
|
||||
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from mindspore.ops.primitive import Primitive
|
||||
from mindspore.ops.operations import _grad_ops
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
env_get = MultitypeFuncGraph("env_get")
|
||||
environ_get = Primitive('EnvironGet')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
tensor_zeros_like = P.ZerosLike()
|
||||
|
||||
|
||||
@env_get.register("EnvType", "Tensor")
|
||||
def _tensor_env_get(env, parameter):
|
||||
"""Used to get env."""
|
||||
return environ_get(env, ref_to_embed(parameter), tensor_zeros_like(parameter))
|
||||
|
||||
|
||||
@env_get.register("EnvType", "MapTensor")
|
||||
def _map_tensor_env_get(env, map_parameter):
|
||||
"""Used to get env for map parameter."""
|
||||
return environ_get(env, ref_to_embed(map_parameter), zeros_like(map_parameter))
|
|
@ -513,4 +513,20 @@ def _add_tuple_cootensor(x, y):
|
|||
x = COOTensor(x[0], x[1], y.shape)
|
||||
return _add_cootensor_cootensor(x, y)
|
||||
|
||||
|
||||
@_add_backward.register("MapTensor", "MapTensor")
|
||||
def _map_tensor_add_backward(x, y):
|
||||
"""
|
||||
Adds MapTensors for backward.
|
||||
|
||||
Args:
|
||||
x (MapTensor): x
|
||||
y (MapTensor): y
|
||||
|
||||
Returns:
|
||||
MapTensor.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
hyper_add = base.HyperMap(_add_backward)
|
||||
|
|
|
@ -73,6 +73,12 @@ def _zeros_like_csr_tensor(x):
|
|||
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("MapTensor")
|
||||
def _zeros_like_map_tensor(x):
|
||||
"""Returns a map tensor with the same shape and dtype as x and all elements are 0."""
|
||||
return x
|
||||
|
||||
|
||||
@zeros_like_leaf.register("TypeType")
|
||||
def _zeros_like_type_type(x):
|
||||
"""Returns x because x is a type. This is usually used in backprop progress."""
|
||||
|
|
|
@ -3838,3 +3838,23 @@ class CholeskyGrad(Primitive):
|
|||
def __init__(self):
|
||||
"""Initialize CholeskyGrad"""
|
||||
self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
|
||||
|
||||
|
||||
class MapTensorGetGrad(Primitive):
|
||||
"""
|
||||
Computes gradients for MapTensorGet operation.
|
||||
|
||||
Inputs:
|
||||
- **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
|
||||
- **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
|
||||
- **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
|
||||
- **grad** (Tensor) - The grad value according the forward operator MapTensorGet.
|
||||
|
||||
Outputs:
|
||||
- **output** (MapTensor) - MapTensor with grad values.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize MapTensorGetGrad"""
|
||||
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
|
|
@ -65,6 +65,35 @@ class MapTensorErase(Primitive):
|
|||
self.init_prim_io_names(inputs=['map_tensor', 'key_tensor'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class MapTensorGetKeys(Primitive):
|
||||
"""
|
||||
Get all keys as a tensor.
|
||||
"""
|
||||
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize MapTensorGetKeys"""
|
||||
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class MapTensorGetValues(Primitive):
|
||||
"""
|
||||
Get all keys as a tensor.
|
||||
"""
|
||||
__mindspore_signature__ = (sig.make_sig('map_tensor'),)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize MapTensorGetValues"""
|
||||
self.init_prim_io_names(inputs=['map_tensor'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
get = MapTensorGet()
|
||||
put = MapTensorPut()
|
||||
erase = MapTensorErase()
|
||||
get_keys = MapTensorGetKeys()
|
||||
get_values = MapTensorGetValues()
|
||||
|
|
|
@ -18,6 +18,7 @@ import mindspore.nn as nn
|
|||
from mindspore import context, Tensor, Parameter, ParameterTuple
|
||||
from mindspore.experimental import MapParameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
|
@ -73,6 +74,9 @@ def test_simple_graph_compile():
|
|||
value4 = self.m[self.key]
|
||||
self.m[self.key] = value4
|
||||
self.m.erase(self.key)
|
||||
keys = self.m.get_keys()
|
||||
values = self.m.get_values()
|
||||
self.m.put(keys, values)
|
||||
return self.p + value1 + value2 + value3 + value4
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -126,3 +130,40 @@ def test_map_parameter_clone():
|
|||
assert clone_same.value_dtype == m.value_dtype
|
||||
assert clone_same.value_shape == m.value_shape
|
||||
assert clone_same.default_value == 'zeros'
|
||||
|
||||
|
||||
def test_grad_net():
|
||||
"""
|
||||
Feature: MapParameter
|
||||
Description: Test grad graph compiled with MapParameter.
|
||||
Expectation: Grad graph for MapParameter created without exceptions.
|
||||
"""
|
||||
class MyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
nn.Cell.__init__(self)
|
||||
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
|
||||
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
self.key = Tensor([1, 2], dtype=ms.int32)
|
||||
|
||||
def construct(self, x):
|
||||
a = self.m.get(self.key, 0.1)
|
||||
self.m.erase(self.key)
|
||||
return x * a
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradNet, self).__init__()
|
||||
self.grad_by_list = C.GradOperation(get_by_list=True)
|
||||
self.network = network
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
gout = self.grad_by_list(self.network, self.weights)(*inputs)
|
||||
return gout
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MyNet()
|
||||
grad = GradNet(net)
|
||||
t = initializer('ones', (2, 3), ms.float32)
|
||||
t = t.init_data()
|
||||
grad(t)
|
||||
|
|
Loading…
Reference in New Issue