!17169 convert python to C++ in Exp operator.

From: @shen_jingxing
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-31 17:27:36 +08:00 committed by Gitee
commit 9c4d063198
4 changed files with 30 additions and 11 deletions

View File

@ -21,6 +21,7 @@
#include <map>
#include <string>
#include <vector>
#include "ops/exp.h"
#include "ops/real_div.h"
#include "ops/add.h"
#include "abstract/abstract_function.h"
@ -185,6 +186,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
{prim::kPrimCast, {InferImplCast, nullptr, true}},
{prim::kPrimExp, {ops::ExpInfer, nullptr, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
{prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
{prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},

View File

@ -45,6 +45,7 @@ constexpr auto kScalarTrunc = "ScalarTrunc";
constexpr auto kScalarFloor = "ScalarFloor";
constexpr auto kScalarUadd = "ScalarUadd";
constexpr auto kScalarUsub = "ScalarUsub";
constexpr auto kExp = "Exp";
constexpr auto kSub = "Sub";
constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv";
@ -430,7 +431,7 @@ inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandD
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
inline const PrimitivePtr kPrimRint = std::make_shared<Primitive>("Rint");
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>(kExp);
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
inline const PrimitivePtr kPrimRsqrtGrad = std::make_shared<Primitive>("RsqrtGrad");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,32 +13,48 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/exp.h"
#include <map>
#include <string>
#include <vector>
#include <set>
#include <memory>
#include "ops/exp.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x = input_args[0]->BuildShape();
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto in_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
if (min_shape.size() != 0 && max_shape.size() != 0) {
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}
return std::make_shared<abstract::Shape>(in_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
std::set<TypePtr> valid_params_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,10 +26,10 @@
namespace mindspore {
namespace ops {
constexpr auto kNameExp = "Exp";
constexpr auto kNameExp = prim::kExp;
class Exp : public PrimitiveC {
public:
Exp() : PrimitiveC(kNameExp) { InitIOName({"x"}, {"y"}); }
Exp() : PrimitiveC(prim::kPrimExp->name()) { InitIOName({"x"}, {"y"}); }
explicit Exp(const std::string k_name) : PrimitiveC(k_name) { InitIOName({"x"}, {"y"}); }
~Exp() = default;
MS_DECLARE_PARENT(Exp, PrimitiveC);