forked from mindspore-Ecosystem/mindspore

!26325 [assistant][ops] add Addcdiv and Addcmul
Merge pull request !26325 from 孟权令/Addcdiv

commit 372065c821

Changed paths:
  mindspore/core
  mindspore/python/mindspore/ops/_grad_experimental
  mindspore/python/mindspore/ops/_op_impl/tbe
  mindspore/python/mindspore/ops/operations
  tests/ut/python/ops
@@ -34,6 +34,8 @@ inline const mindspore::HashMap<std::string, ValuePtr> kSideEffectPropagate = {
 constexpr auto kGetNext = "GetNext";
 constexpr auto kGather = "Gather";
+constexpr auto kAddcdiv = "Addcdiv";
+constexpr auto kAddcmul = "Addcmul";
 constexpr auto kCdist = "Cdist";
 constexpr auto kCdistGrad = "CdistGrad";
 // Arithmetic
@@ -525,6 +527,8 @@ inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger");
 inline const PrimitivePtr kPrimCeil = std::make_shared<Primitive>("Ceil");
 inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
 inline const PrimitivePtr kPrimAdd = std::make_shared<Primitive>(kAdd);
+inline const PrimitivePtr kPrimAddcdiv = std::make_shared<Primitive>(kAddcdiv);
+inline const PrimitivePtr kPrimAddcmul = std::make_shared<Primitive>(kAddcmul);
 inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
 inline const PrimitivePtr kPrimMatMulV2 = std::make_shared<Primitive>("MatMulV2");
 inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag");
@@ -0,0 +1,87 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <set>
#include <string>
#include <algorithm>
#include "ops/addcdiv.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto input_data_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
  auto input_data = input_data_map[kShape];
  auto x1_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
  auto x1_shape = x1_shape_map[kShape];
  auto x2_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
  auto x2_shape = x2_shape_map[kShape];
  auto value_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape());
  auto value_shape = value_shape_map[kShape];
  auto broadcast_shape = CalBroadCastShape(x1_shape, x2_shape, op_name, "x1", "x2");
  if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {
    CalBroadCastShape(x1_shape, value_shape, op_name, "x1", "value");
    CalBroadCastShape(x2_shape, value_shape, op_name, "x2", "value");
    broadcast_shape = CalBroadCastShape(broadcast_shape, value_shape, op_name);
  }
  broadcast_shape = CalBroadCastShape(broadcast_shape, input_data, op_name);
  return std::make_shared<abstract::Shape>(broadcast_shape);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto op_name = prim->name();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  auto input_data_type = input_args[kInputIndex0]->BuildType();
  auto x1_type = input_args[kInputIndex1]->BuildType();
  auto x2_type = input_args[kInputIndex2]->BuildType();
  auto value_type = input_args[kInputIndex3]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_data", input_data_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x1", x1_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2", x2_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("value", value_type, valid_types, op_name);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("input_data", input_data_type);
  (void)types.emplace("x1", x1_type);
  (void)types.emplace("x2", x2_type);
  (void)types.emplace("value", value_type);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
  return input_data_type;
}
}  // namespace
AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 4;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto infer_shape = InferShape(primitive, input_args);
  auto infer_type = InferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Addcdiv, prim::kPrimAddcdiv, AddcdivInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
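The InferShape routine above resolves the output shape by folding pairwise broadcasts: x1 against x2, then (when value is a tensor) value, then input_data. As a rough cross-check of that ordering outside the C++ utilities, here is a minimal NumPy sketch; it is illustrative only, the function name is invented for this note and is not part of the commit or the MindSpore API.

```python
import numpy as np

def addcdiv_infer_shape(input_data_shape, x1_shape, x2_shape, value_shape, value_is_tensor=True):
    """Mirror the pairwise broadcast checks: x1 vs x2, then value, then input_data."""
    out = np.broadcast_shapes(x1_shape, x2_shape)        # x1 / x2
    if value_is_tensor:
        np.broadcast_shapes(x1_shape, value_shape)       # compatibility checks only
        np.broadcast_shapes(x2_shape, value_shape)
        out = np.broadcast_shapes(out, value_shape)      # fold value into the result
    out = np.broadcast_shapes(out, input_data_shape)     # fold input_data into the result
    return out

# x1: (3, 1), x2: (1, 4), value: (1,), input_data: (3, 4) -> output shape (3, 4)
print(addcdiv_infer_shape((3, 4), (3, 1), (1, 4), (1,)))
```

Like CalBroadCastShape, np.broadcast_shapes raises an error for incompatible shapes, so an invalid combination fails at the same step it would in the C++ code.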
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_ADDCDIV_H_
#define MINDSPORE_CORE_OPS_ADDCDIV_H_
#include <memory>
#include <vector>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameAddcdiv = "Addcdiv";
class Addcdiv : public PrimitiveC {
 public:
  Addcdiv() : PrimitiveC(kNameAddcdiv) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
  ~Addcdiv() = default;
  MS_DECLARE_PARENT(Addcdiv, PrimitiveC);
};

AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args);
using PrimAddcdivPtr = std::shared_ptr<Addcdiv>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_ADDCDIV_H_
@@ -0,0 +1,87 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <set>
#include <string>
#include <algorithm>
#include "ops/addcmul.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto input_data_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape());
  auto input_data = input_data_map[kShape];
  auto x1_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
  auto x1_shape = x1_shape_map[kShape];
  auto x2_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());
  auto x2_shape = x2_shape_map[kShape];
  auto value_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape());
  auto value_shape = value_shape_map[kShape];
  auto broadcast_shape = CalBroadCastShape(x1_shape, x2_shape, op_name, "x1", "x2");
  if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>()) {
    CalBroadCastShape(x1_shape, value_shape, op_name, "x1", "value");
    CalBroadCastShape(x2_shape, value_shape, op_name, "x2", "value");
    broadcast_shape = CalBroadCastShape(broadcast_shape, value_shape, op_name);
  }
  broadcast_shape = CalBroadCastShape(broadcast_shape, input_data, op_name);
  return std::make_shared<abstract::Shape>(broadcast_shape);
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto op_name = prim->name();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
  auto input_data_type = input_args[kInputIndex0]->BuildType();
  auto x1_type = input_args[kInputIndex1]->BuildType();
  auto x2_type = input_args[kInputIndex2]->BuildType();
  auto value_type = input_args[kInputIndex3]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_data", input_data_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x1", x1_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2", x2_type, valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("value", value_type, valid_types, op_name);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("input_data", input_data_type);
  (void)types.emplace("x1", x1_type);
  (void)types.emplace("x2", x2_type);
  (void)types.emplace("value", value_type);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
  return input_data_type;
}
}  // namespace
AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputsNum = 4;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
  auto infer_shape = InferShape(primitive, input_args);
  auto infer_type = InferType(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Addcmul, prim::kPrimAddcmul, AddcmulInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
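Addcmul's InferType is identical to Addcdiv's except that int32 is also accepted. The dtype discipline it enforces, every input must come from the valid set and all four inputs must share one dtype, can be sketched in a few lines of Python; this is only an illustration of the rule, the helper name and set below are invented for this note and are not the MindSpore utilities.

```python
import numpy as np

VALID_ADDCMUL_DTYPES = {np.float16, np.float32, np.int32}  # Addcdiv would omit int32

def check_addcmul_dtypes(input_data, x1, x2, value):
    """Reject unsupported dtypes and mixed-dtype inputs; return the common dtype."""
    named = {"input_data": input_data, "x1": x1, "x2": x2, "value": value}
    for name, tensor in named.items():
        if tensor.dtype.type not in VALID_ADDCMUL_DTYPES:
            raise TypeError(f"{name} has unsupported dtype {tensor.dtype}")
    dtypes = {t.dtype for t in named.values()}
    if len(dtypes) != 1:
        raise TypeError(f"inputs must share one dtype, got {dtypes}")
    return dtypes.pop()

print(check_addcmul_dtypes(*(np.ones(3, np.float32) for _ in range(4))))  # float32
```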
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_ADDCMUL_H_
#define MINDSPORE_CORE_OPS_ADDCMUL_H_
#include <memory>
#include <vector>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameAddcmul = "Addcmul";
class Addcmul : public PrimitiveC {
 public:
  Addcmul() : PrimitiveC(kNameAddcmul) { InitIOName({"input_data", "x1", "x2", "value"}, {"output"}); }
  ~Addcmul() = default;
  MS_DECLARE_PARENT(Addcmul, PrimitiveC);
};

AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args);
using PrimAddcmulPtr = std::shared_ptr<Addcmul>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_ADDCMUL_H_
@@ -81,6 +81,37 @@ def get_bprop_index_lerp(self):
     return bprop
 
 
+@bprop_getters.register(P.Addcdiv)
+def get_bprop_index_addcdiv(self):
+    """Generate bprop for Addcdiv"""
+    mul_op = P.Mul()
+    div_op = P.Div()
+    pow_op = P.Pow()
+    neg_op = P.Neg()
+
+    def bprop(input_data, x1, x2, value, out, dout):
+        dx1 = mul_op(dout, div_op(value, x2))
+        dx2 = neg_op(mul_op(mul_op(mul_op(x1, value), pow_op(x2, -2)), dout))
+        dvalue = mul_op(dout, div_op(x1, x2))
+        return dout, dx1, dx2, dvalue
+
+    return bprop
+
+
+@bprop_getters.register(P.Addcmul)
+def get_bprop_index_addcmul(self):
+    """Generate bprop for Addcmul"""
+    mul_op = P.Mul()
+
+    def bprop(input_data, x1, x2, value, out, dout):
+        dx1 = mul_op(dout, mul_op(value, x2))
+        dx2 = mul_op(dout, mul_op(value, x1))
+        dvalue = mul_op(dout, mul_op(x1, x2))
+        return dout, dx1, dx2, dvalue
+
+    return bprop
+
+
 @bprop_getters.register(P.LpNorm)
 def get_bprop_lp_norm(self):
     """Grad definition for `LpNorm` operation."""
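For y = input_data + value * x1 / x2, the Addcdiv bprop above encodes dy/dx1 = value / x2, dy/dx2 = -value * x1 / x2^2 and dy/dvalue = x1 / x2 (Addcmul uses the analogous products). As a sanity sketch outside MindSpore, these formulas can be checked against central finite differences with plain NumPy; the code below is only such a check, not part of the commit.

```python
import numpy as np

def addcdiv(input_data, x1, x2, value):
    return input_data + value * x1 / x2

rng = np.random.default_rng(0)
input_data, x1, x2, value = (rng.uniform(1.0, 2.0, size=4) for _ in range(4))
dout = np.ones(4)  # upstream gradient of ones, as in the bprop signature

# analytic gradients, matching get_bprop_index_addcdiv
dx1 = dout * value / x2
dx2 = -(x1 * value * x2 ** -2) * dout
dvalue = dout * x1 / x2

eps = 1e-6
num_dx1 = (addcdiv(input_data, x1 + eps, x2, value) - addcdiv(input_data, x1 - eps, x2, value)) / (2 * eps)
num_dx2 = (addcdiv(input_data, x1, x2 + eps, value) - addcdiv(input_data, x1, x2 - eps, value)) / (2 * eps)
num_dvalue = (addcdiv(input_data, x1, x2, value + eps) - addcdiv(input_data, x1, x2, value - eps)) / (2 * eps)

assert np.allclose(dx1, num_dx1) and np.allclose(dx2, num_dx2) and np.allclose(dvalue, num_dvalue)
```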
@@ -35,6 +35,8 @@ from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
 from .apply_centered_rms_prop_ds import _apply_centered_rms_prop_ds_tbe
 from .add_n import _add_n_tbe
 from .add_n_ds import _add_n_ds_tbe
+from .addcdiv import _addcdiv_tbe
+from .addcmul import _addcmul_tbe
 from .accumulate_n_v2 import _accumulate_n_v2_tbe
 from .accumulate_n_v2_ds import _accumulate_n_v2_ds_tbe
 from .apply_ftrl import _apply_ftrl_tbe
@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Addcdiv op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

addcdiv_op_info = TBERegOp("Addcdiv") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("addcdiv.so") \
    .compute_cost(10) \
    .kernel_name("addcdiv") \
    .partial_flag(True) \
    .input(0, "input_data", False, "required", "all") \
    .input(1, "x1", False, "required", "all") \
    .input(2, "x2", False, "required", "all") \
    .input(3, "value", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(addcdiv_op_info)
def _addcdiv_tbe():
    """Addcdiv TBE register"""
    return
@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Addcmul op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

addcmul_op_info = TBERegOp("Addcmul") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("addcmul.so") \
    .compute_cost(10) \
    .kernel_name("addcmul") \
    .partial_flag(True) \
    .input(0, "input_data", False, "required", "all") \
    .input(1, "x1", False, "required", "all") \
    .input(2, "x2", False, "required", "all") \
    .input(3, "value", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(addcmul_op_info)
def _addcmul_tbe():
    """Addcmul TBE register"""
    return
@@ -57,7 +57,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
-                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
+                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Addcdiv, Addcmul,
                        Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
                        MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag, Complex, Trunc, IsClose)
@@ -294,6 +294,8 @@ __all__ = [
     'IsNan',
     'IsFinite',
     'IsInf',
+    'Addcdiv',
+    'Addcmul',
     'FloatStatus',
     'Reciprocal',
     'SmoothL1Loss',
@@ -240,6 +240,100 @@ class Add(_MathBinaryOp):
         return None
 
 
+class Addcdiv(Primitive):
+    """
+    Performs the element-wise division of tensor x1 by tensor x2,
+    multiplies the result by the scalar value and adds it to input_data.
+
+    .. math::
+        y[i] = input_data[i] + value[i] * (x1[i] / x2[i])
+
+    Inputs:
+        - **input_data** (Tensor) - The tensor to be added, with data type float16 or float32.
+        - **x1** (Tensor) - The numerator tensor, with data type float16 or float32.
+        - **x2** (Tensor) - The denominator tensor, with data type float16 or float32.
+        - **value** (Tensor) - The multiplier for tensor x1/x2, with data type float16 or float32.
+
+    Outputs:
+        Tensor y, has the same shape and dtype as x1/x2.
+
+    Raises:
+        TypeError: If `x1`, `x2`, `value` or `input_data` is not a Tensor.
+        TypeError: If dtype of `input_data` is not one of: float32, float16.
+        TypeError: If dtype of `x1` or `x2` is not one of: float32, float16.
+        TypeError: If dtype of `value` is not one of: float32, float16.
+        ValueError: If `x1` could not be broadcast to a tensor with shape of `x2`.
+        ValueError: If `value` could not be broadcast to tensors with shapes of `x1/x2`.
+        ValueError: If `input_data` could not be broadcast to tensors with shapes of `value*(x1/x2)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> input_data = Tensor(np.array([1, 1, 1, 1]), mindspore.float32)
+        >>> x1 = Tensor(np.array([1, 2, 3, 4]), mindspore.float32)
+        >>> x2 = Tensor(np.array([4, 3, 2, 1]), mindspore.float32)
+        >>> value = Tensor([1], mindspore.float32)
+        >>> addcdiv = ops.Addcdiv()
+        >>> y = addcdiv(input_data, x1, x2, value)
+        >>> print(y)
+        [1.25 1.6666667 2.5 5. ]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Addcdiv"""
+        self.init_prim_io_names(inputs=['input_data', 'x1', 'x2', 'value'], outputs=['y'])
+
+
+class Addcmul(Primitive):
+    """
+    Performs the element-wise product of tensor x1 and tensor x2,
+    multiplies the result by the scalar value and adds it to input_data.
+
+    .. math::
+        output[i] = input_data[i] + value[i] * (x1[i] * x2[i])
+
+    Inputs:
+        - **input_data** (Tensor) - The tensor to be added, with data type float16, float32 or int32.
+        - **x1** (Tensor) - The tensor to be multiplied, with data type float16, float32 or int32.
+        - **x2** (Tensor) - The tensor to be multiplied, with data type float16, float32 or int32.
+        - **value** (Tensor) - The multiplier for tensor x1*x2, with data type float16, float32 or int32.
+
+    Outputs:
+        Tensor, has the same shape and dtype as x1*x2.
+
+    Raises:
+        TypeError: If `x1`, `x2`, `value` or `input_data` is not a Tensor.
+        TypeError: If dtype of `input_data` is not one of: float32, float16, int32.
+        TypeError: If dtype of `x1` or `x2` is not one of: float32, float16, int32.
+        TypeError: If dtype of `value` is not one of: float32, float16, int32.
+        ValueError: If `x1` could not be broadcast to a tensor with shape of `x2`.
+        ValueError: If `value` could not be broadcast to tensors with shapes of `x1` * `x2`.
+        ValueError: If `input_data` could not be broadcast to tensors with shapes of `value*(x1*x2)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> input_data = Tensor(np.array([1, 1, 1]), mindspore.float32)
+        >>> x1 = Tensor(np.array([[1], [2], [3]]), mindspore.float32)
+        >>> x2 = Tensor(np.array([[1, 2, 3]]), mindspore.float32)
+        >>> value = Tensor([1], mindspore.float32)
+        >>> addcmul = ops.Addcmul()
+        >>> y = addcmul(input_data, x1, x2, value)
+        >>> print(y)
+        [[ 2.  3.  4.]
+         [ 3.  5.  7.]
+         [ 4.  7. 10.]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Addcmul"""
+        self.init_prim_io_names(inputs=['input_data', 'x1', 'x2', 'value'], outputs=['y'])
+
+
 class TensorAdd(_MathBinaryOp):
     """
     Same as operator Add. TensorAdd will be deprecated in the future.
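The example outputs in the two docstrings can be reproduced with plain NumPy, which also makes the broadcasting behaviour visible (the column vector x1 against the row vector x2 in the Addcmul case). This is only an independent cross-check of the arithmetic, not MindSpore code.

```python
import numpy as np

# Addcdiv docstring example: 1 + 1 * ([1, 2, 3, 4] / [4, 3, 2, 1])
y_div = np.float32(1) + np.float32(1) * (np.array([1, 2, 3, 4], np.float32) / np.array([4, 3, 2, 1], np.float32))
print(y_div)  # approximately [1.25, 1.6667, 2.5, 5.], matching the docstring output

# Addcmul docstring example: 1 + 1 * ([[1], [2], [3]] * [[1, 2, 3]]) broadcasts to a 3x3 result
y_mul = np.ones(3, np.float32) + 1 * (np.array([[1], [2], [3]], np.float32) * np.array([[1, 2, 3]], np.float32))
print(y_mul)
# [[ 2.  3.  4.]
#  [ 3.  5.  7.]
#  [ 4.  7. 10.]]
```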
@@ -1784,6 +1784,20 @@ test_case_math_ops = [
         'desc_inputs': (Tensor(np.array([0, 1, 2]).astype(np.int32)),
                         Tensor(np.array([[0.5, 1.0, 1.5], [1.0, 1.5, 2.0], [2.0, 2.5, 3.0]]).astype(np.float32))),
         'desc_bprop': [Tensor(np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).astype(np.float32))]}),
+    ('Addcdiv', {
+        'block': P.Addcdiv(),
+        'desc_inputs': [Tensor(np.array([1, 1, 1, 1]).astype(np.float32)),
+                        Tensor(np.array([1, 2, 3, 4]).astype(np.float32)),
+                        Tensor(np.array([4, 3, 2, 1]).astype(np.float32)),
+                        Tensor(np.array(1).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([1, 1, 1, 1]).astype(np.float32))]}),
+    ('Addcmul', {
+        'block': P.Addcmul(),
+        'desc_inputs': [Tensor(np.array([1, 1, 1]).astype(np.float32)),
+                        Tensor(np.array([[1], [2], [3]]).astype(np.float32)),
+                        Tensor(np.array([[1, 2, 3]]).astype(np.float32)),
+                        Tensor(np.array(1).astype(np.float32))],
+        'desc_bprop': [Tensor(np.array([[2, 3, 4], [3, 5, 7], [4, 7, 10]]).astype(np.float32))]}),
     ('Real', {
         'block': P.Real(),
         'desc_inputs': [[2, 2]],
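Beyond the unit-test table, the new primitives can be exercised end to end with the same calls the docstrings show. A minimal smoke-test sketch, assuming a MindSpore build that includes this commit and a backend with the addcdiv/addcmul TBE kernels (Ascend):

```python
import numpy as np
import mindspore
from mindspore import Tensor, ops

# Mirrors the docstring examples; requires the Addcdiv/Addcmul kernels to be available.
input_data = Tensor(np.array([1, 1, 1, 1]), mindspore.float32)
x1 = Tensor(np.array([1, 2, 3, 4]), mindspore.float32)
x2 = Tensor(np.array([4, 3, 2, 1]), mindspore.float32)
value = Tensor([1], mindspore.float32)

print(ops.Addcdiv()(input_data, x1, x2, value))   # expected: [1.25 1.6666667 2.5 5. ]
print(ops.Addcmul()(Tensor(np.array([1, 1, 1]), mindspore.float32),
                    Tensor(np.array([[1], [2], [3]]), mindspore.float32),
                    Tensor(np.array([[1, 2, 3]]), mindspore.float32),
                    value))                       # expected: the 3x3 result from the Addcmul docstring
```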