forked from mindspore-Ecosystem/mindspore
!17169 convert python to C++ in Exp operator.
From: @shen_jingxing Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
9c4d063198
|
@ -21,6 +21,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ops/exp.h"
|
||||
#include "ops/real_div.h"
|
||||
#include "ops/add.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
|
@ -185,6 +186,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
|
||||
{prim::kPrimCast, {InferImplCast, nullptr, true}},
|
||||
{prim::kPrimExp, {ops::ExpInfer, nullptr, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
|
||||
{prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
|
||||
{prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},
|
||||
|
|
|
@ -45,6 +45,7 @@ constexpr auto kScalarTrunc = "ScalarTrunc";
|
|||
constexpr auto kScalarFloor = "ScalarFloor";
|
||||
constexpr auto kScalarUadd = "ScalarUadd";
|
||||
constexpr auto kScalarUsub = "ScalarUsub";
|
||||
constexpr auto kExp = "Exp";
|
||||
constexpr auto kSub = "Sub";
|
||||
constexpr auto kMul = "Mul";
|
||||
constexpr auto kRealDiv = "RealDiv";
|
||||
|
@ -430,7 +431,7 @@ inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandD
|
|||
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
|
||||
inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
|
||||
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
|
||||
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
|
||||
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>(kExp);
|
||||
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
|
||||
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
|
||||
inline const PrimitivePtr kPrimRsqrtGrad = std::make_shared<Primitive>("RsqrtGrad");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,32 +13,48 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/exp.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include "ops/exp.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
if (min_shape.size() != 0 && max_shape.size() != 0) {
|
||||
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,10 +26,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameExp = "Exp";
|
||||
constexpr auto kNameExp = prim::kExp;
|
||||
class Exp : public PrimitiveC {
|
||||
public:
|
||||
Exp() : PrimitiveC(kNameExp) { InitIOName({"x"}, {"y"}); }
|
||||
Exp() : PrimitiveC(prim::kPrimExp->name()) { InitIOName({"x"}, {"y"}); }
|
||||
explicit Exp(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"y"}); }
|
||||
~Exp() = default;
|
||||
MS_DECLARE_PARENT(Exp, PrimitiveC);
|
||||
|
|
Loading…
Reference in New Issue