forked from mindspore-Ecosystem/mindspore
commit
a38e2ffa8d
|
@ -22,6 +22,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ops/real_div.h"
|
||||||
#include "abstract/abstract_function.h"
|
#include "abstract/abstract_function.h"
|
||||||
#include "abstract/infer_functions.h"
|
#include "abstract/infer_functions.h"
|
||||||
|
|
||||||
|
@ -202,7 +203,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||||
{prim::kPrimPad, {InferImplPad, nullptr, true}},
|
{prim::kPrimPad, {InferImplPad, nullptr, true}},
|
||||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
|
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
|
||||||
{prim::kPrimDiv, {InferImplDiv, nullptr, true}},
|
{prim::kPrimDiv, {InferImplDiv, nullptr, true}},
|
||||||
{prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}},
|
{prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
|
||||||
{prim::kPrimShape, {InferImplShape, nullptr, false}},
|
{prim::kPrimShape, {InferImplShape, nullptr, false}},
|
||||||
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
|
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
|
||||||
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
|
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
|
||||||
|
|
|
@ -47,6 +47,7 @@ constexpr auto kScalarUadd = "ScalarUadd";
|
||||||
constexpr auto kScalarUsub = "ScalarUsub";
|
constexpr auto kScalarUsub = "ScalarUsub";
|
||||||
constexpr auto kSub = "Sub";
|
constexpr auto kSub = "Sub";
|
||||||
constexpr auto kMul = "Mul";
|
constexpr auto kMul = "Mul";
|
||||||
|
constexpr auto kRealDiv = "RealDiv";
|
||||||
|
|
||||||
// Arrays
|
// Arrays
|
||||||
constexpr auto kStack = "Stack";
|
constexpr auto kStack = "Stack";
|
||||||
|
@ -401,7 +402,7 @@ inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("Inplace
|
||||||
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
|
||||||
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
|
||||||
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");
|
inline const PrimitivePtr kPrimPower = std::make_shared<Primitive>("Power");
|
||||||
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>(kRealDiv);
|
||||||
inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv");
|
inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv");
|
||||||
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||||
inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad");
|
inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad");
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -28,16 +28,21 @@ namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto op_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
return BroadCastInferShape(op_name, input_args);
|
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
|
||||||
|
for (const auto &item : input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
|
}
|
||||||
|
return BroadCastInferShape(prim_name, input_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
|
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
auto op_name = prim->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("RealDiv infer", input_args.size(), kGreaterEqual, 2, op_name);
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
types.emplace("x", input_args[0]->BuildType());
|
types.emplace("x", input_args[0]->BuildType());
|
||||||
types.emplace("y", input_args[1]->BuildType());
|
types.emplace("y", input_args[1]->BuildType());
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_OPS_REAL_DIV_H_
|
#ifndef MINDSPORE_CORE_OPS_REAL_DIV_H_
|
||||||
#define MINDSPORE_CORE_OPS_REAL_DIV_H_
|
#define MINDSPORE_CORE_OPS_REAL_DIV_H_
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
|
@ -24,7 +25,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameRealDiv = "RealDiv";
|
constexpr auto kNameRealDiv = prim::kRealDiv;
|
||||||
class RealDiv : public PrimitiveC {
|
class RealDiv : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
RealDiv() : PrimitiveC(kNameRealDiv) { InitIOName({"x", "y"}, {"output"}); }
|
RealDiv() : PrimitiveC(kNameRealDiv) { InitIOName({"x", "y"}, {"output"}); }
|
||||||
|
|
Loading…
Reference in New Issue