From: @shen_jingxing
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-05-17 16:50:20 +08:00 committed by Gitee
commit a38e2ffa8d
4 changed files with 16 additions and 8 deletions

View File

@ -22,6 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ops/real_div.h"
#include "abstract/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/infer_functions.h" #include "abstract/infer_functions.h"
@ -202,7 +203,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimPad, {InferImplPad, nullptr, true}}, {prim::kPrimPad, {InferImplPad, nullptr, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}}, {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
{prim::kPrimDiv, {InferImplDiv, nullptr, true}}, {prim::kPrimDiv, {InferImplDiv, nullptr, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}}, {prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
{prim::kPrimShape, {InferImplShape, nullptr, false}}, {prim::kPrimShape, {InferImplShape, nullptr, false}},
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
{prim::kPrimReshape, {InferImplReshape, nullptr, true}}, {prim::kPrimReshape, {InferImplReshape, nullptr, true}},

View File

@ -47,6 +47,7 @@ constexpr auto kScalarUadd = "ScalarUadd";
constexpr auto kScalarUsub = "ScalarUsub"; constexpr auto kScalarUsub = "ScalarUsub";
constexpr auto kSub = "Sub"; constexpr auto kSub = "Sub";
constexpr auto kMul = "Mul"; constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv";
// Arrays // Arrays
constexpr auto kStack = "Stack"; constexpr auto kStack = "Stack";
@ -401,7 +402,7 @@ inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("Inplace
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power"); inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>(kRealDiv);
inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv"); inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv");
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad");

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,16 +28,21 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto prim_name = primitive->name();
return BroadCastInferShape(op_name, input_args); CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return BroadCastInferShape(prim_name, input_args);
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto op_name = prim->name();
CheckAndConvertUtils::CheckInteger("RealDiv infer", input_args.size(), kGreaterEqual, 2, op_name);
std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType()); types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType()); types.emplace("y", input_args[1]->BuildType());

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CORE_OPS_REAL_DIV_H_ #ifndef MINDSPORE_CORE_OPS_REAL_DIV_H_
#define MINDSPORE_CORE_OPS_REAL_DIV_H_ #define MINDSPORE_CORE_OPS_REAL_DIV_H_
#include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
@ -24,7 +25,7 @@
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {
constexpr auto kNameRealDiv = "RealDiv"; constexpr auto kNameRealDiv = prim::kRealDiv;
class RealDiv : public PrimitiveC { class RealDiv : public PrimitiveC {
public: public:
RealDiv() : PrimitiveC(kNameRealDiv) { InitIOName({"x", "y"}, {"output"}); } RealDiv() : PrimitiveC(kNameRealDiv) { InitIOName({"x", "y"}, {"output"}); }