forked from mindspore-Ecosystem/mindspore
optimize build
This commit is contained in:
parent
c4843c4085
commit
e0d339534e
|
@ -13,21 +13,23 @@ if(NOT(BUILD_LITE))
|
|||
add_subdirectory(mindrt)
|
||||
endif()
|
||||
|
||||
include(${TOP_DIR}/mindspore/lite/cmake/merge.cmake)
|
||||
if(ENABLE_SECURITY)
|
||||
merge_files(${CMAKE_CURRENT_SOURCE_DIR}/ops/ ${CMAKE_BINARY_DIR}/merge/mindspore/core/ops_merge.cc "_summary.cc$")
|
||||
else()
|
||||
merge_files(${CMAKE_CURRENT_SOURCE_DIR}/ops/ ${CMAKE_BINARY_DIR}/merge/mindspore/core/ops_merge.cc "")
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"abstract/*.cc"
|
||||
"base/*.cc"
|
||||
"ops/*.cc"
|
||||
"${CMAKE_BINARY_DIR}/merge/mindspore/core/ops_merge.cc"
|
||||
"ir/*.cc"
|
||||
"utils/*.cc"
|
||||
"load_mindir/*.cc"
|
||||
"mindapi/src/*.cc"
|
||||
)
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE _INFER_SUMMARY_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ops/*_summary.cc")
|
||||
list(REMOVE_ITEM CORE_SRC_LIST ${_INFER_SUMMARY_FILES})
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE PROTO_FILE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "proto/*.proto")
|
||||
if(NOT(BUILD_LITE))
|
||||
ms_protobuf_generate_py(PROTO_SRCS PY_HDRS PY_PYS ${PROTO_FILE})
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "LayerNormBetaGammaBackprop.h"
|
||||
#include "ops/LayerNormBetaGammaBackprop.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "LayerNormXBackprop.h"
|
||||
#include "ops/LayerNormXBackprop.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t input_num = 1;
|
||||
|
||||
abstract::ShapePtr ACosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -44,6 +42,7 @@ AbstractBasePtr ACosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
auto types = ACosInferType(primitive, input_args);
|
||||
|
|
|
@ -19,15 +19,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t input_num = 1;
|
||||
const int64_t max_dim = 8;
|
||||
|
||||
abstract::ShapePtr AcoshInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto x = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
const int64_t max_dim = 8;
|
||||
(void)CheckAndConvertUtils::CheckInteger("The dimension of Acosh input", SizeToLong(in_shape.size()), kLessThan,
|
||||
max_dim, prim_name);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
@ -48,6 +46,7 @@ AbstractBasePtr AcoshInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = AcoshInferType(primitive, input_args);
|
||||
auto shapes = AcoshInferShape(primitive, input_args);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AddcdivInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -50,7 +50,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(broadcast_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AddcdivInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -78,8 +78,8 @@ AbstractBasePtr AddcdivInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = AddcdivInferShape(primitive, input_args);
|
||||
auto infer_type = AddcdivInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Addcdiv, prim::kPrimAddcdiv, AddcdivInfer, nullptr, true);
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AddcmulInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -50,7 +50,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(broadcast_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AddcmulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -78,8 +78,8 @@ AbstractBasePtr AddcmulInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = AddcmulInferShape(primitive, input_args);
|
||||
auto infer_type = AddcmulInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Addcmul, prim::kPrimAddcmul, AddcmulInfer, nullptr, true);
|
||||
|
|
|
@ -28,7 +28,8 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr ApplyAdagradDAInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 8;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
|
@ -54,7 +55,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
std::vector<abstract::BaseShapePtr>{var_shape, gradient_accumulator_shape, gradient_squared_accumulator_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr ApplyAdagradDAInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 8;
|
||||
|
@ -100,7 +101,8 @@ TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr>
|
|||
AbstractBasePtr ApplyAdagradDAInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(ApplyAdagradDAInferShape(primitive, input_args),
|
||||
ApplyAdagradDAInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdagradDA, prim::kPrimApplyAdagradDA, ApplyAdagradDAInfer, nullptr, true);
|
||||
|
|
|
@ -27,7 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr ApplyAdamWithAmsgradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -61,7 +62,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr ApplyAdamWithAmsgradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
// get all input_args' shape
|
||||
|
@ -95,8 +96,8 @@ AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, c
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 8;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = ApplyAdamWithAmsgradInferType(primitive, input_args);
|
||||
auto infer_shape = ApplyAdamWithAmsgradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdamWithAmsgrad, prim::kPrimApplyAdamWithAmsgrad, ApplyAdamWithAmsgradInfer, nullptr,
|
||||
|
|
|
@ -26,7 +26,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr ApplyKerasMomentumInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
|
@ -54,7 +55,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, accum_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr ApplyKerasMomentumInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
|
@ -83,7 +84,8 @@ TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr>
|
|||
AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(ApplyKerasMomentumInferShape(primitive, input_args),
|
||||
ApplyKerasMomentumInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyKerasMomentum, prim::kPrimApplyKerasMomentum, ApplyKerasMomentumInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -28,13 +28,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const int64_t kInputNum = 7;
|
||||
const int64_t kApplyPowerSignDInputNum = 7;
|
||||
abstract::TupleShapePtr ApplyPowerSignDInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
kApplyPowerSignDInputNum, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -82,8 +82,8 @@ abstract::TupleShapePtr ApplyPowerSignDInferShape(const PrimitivePtr &primitive,
|
|||
TuplePtr ApplyPowerSignDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
|
||||
prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
kApplyPowerSignDInputNum, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ArgMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
|
@ -36,7 +36,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr ArgMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -62,8 +62,8 @@ TypeId ArgMax::get_output_type() const {
|
|||
|
||||
AbstractBasePtr ArgMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(ArgMaxInferType(primitive, input_args),
|
||||
ArgMaxInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameArgMax, ArgMax);
|
||||
} // namespace ops
|
||||
|
|
|
@ -23,8 +23,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t InputNum = 1;
|
||||
|
||||
abstract::ShapePtr AsinInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -48,6 +46,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t InputNum = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
auto types = AsinInferType(primitive, input_args);
|
||||
|
|
|
@ -19,15 +19,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t InputNum = 1;
|
||||
const int64_t MaxDim = 8;
|
||||
|
||||
abstract::ShapePtr AsinhInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto x = input_args[kInputIndex0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
const int64_t MaxDim = 8;
|
||||
(void)CheckAndConvertUtils::CheckInteger("The dimension of Asinh input", SizeToLong(in_shape.size()), kLessThan,
|
||||
MaxDim, prim_name);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
@ -48,6 +46,7 @@ AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t InputNum = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
|
||||
auto types = AsinhInferType(primitive, input_args);
|
||||
auto shapes = AsinhInferShape(primitive, input_args);
|
||||
|
|
|
@ -79,7 +79,7 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AvgPoolInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
|
@ -124,7 +124,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
|
||||
TypePtr AvgPoolInferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -132,7 +132,8 @@ AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(AvgPoolInferType(input_args),
|
||||
AvgPoolInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,13 +26,12 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t k5DInputDims = 5;
|
||||
constexpr size_t kKernelDims = 3;
|
||||
constexpr size_t kStridesDims = 3;
|
||||
constexpr size_t kPadDims = 6;
|
||||
constexpr size_t kAvgPool3DPadDims = 6;
|
||||
|
||||
void GetAttrs(const PrimitivePtr &primitive, std::vector<int64_t> *kernel_size, std::vector<int64_t> *strides,
|
||||
int64_t *pad_mode, std::vector<int64_t> *pad_list, bool *ceil_mode, bool *count_include_pad) {
|
||||
constexpr size_t kKernelDims = 3;
|
||||
constexpr size_t kStridesDims = 3;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// attr kernel size
|
||||
*kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
|
@ -66,11 +65,14 @@ std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_
|
|||
int64_t out_d = 0;
|
||||
int64_t out_h = 0;
|
||||
int64_t out_w = 0;
|
||||
if (stride_d == 0 || stride_h == 0 || stride_w == 0) {
|
||||
MS_LOG(EXCEPTION) << "stride_d or stride_h or stride_w must be non-zero";
|
||||
}
|
||||
if (ceil_mode) {
|
||||
out_d =
|
||||
static_cast<int64_t>(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d + stride_d - 1) / stride_d + 1));
|
||||
out_h =
|
||||
static_cast<int64_t>(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h + stride_h - 1) / stride_h + 1));
|
||||
out_h = static_cast<int64_t>(
|
||||
std::floor((in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h + stride_h - 1) / stride_h + 1));
|
||||
out_w =
|
||||
static_cast<int64_t>(std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w + stride_w - 1) / stride_w + 1));
|
||||
if ((out_d - 1) * stride_d >= in_d + pad_list[0]) {
|
||||
|
@ -85,7 +87,8 @@ std::vector<int64_t> GetOutputShape(const std::vector<int64_t> &in_shape, int64_
|
|||
} else {
|
||||
out_d = static_cast<int64_t>(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1));
|
||||
out_h = static_cast<int64_t>(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1));
|
||||
out_w = static_cast<int64_t>(std::floor((in_w + pad_list[4] + pad_list[5] - kernel_w) / stride_w + 1));
|
||||
out_w = static_cast<int64_t>(
|
||||
std::floor((in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1));
|
||||
}
|
||||
std::vector<int64_t> output_shape = {in_shape[0], in_shape[1], out_d, out_h, out_w};
|
||||
return output_shape;
|
||||
|
@ -95,7 +98,7 @@ void GetPadsByPadding(int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d
|
|||
int64_t stride_d, int64_t stride_h, int64_t stride_w, const int64_t &pad_mode,
|
||||
const std::vector<int64_t> &padding, std::vector<int64_t> *pad_list) {
|
||||
if (pad_mode == PadMode::VALID) {
|
||||
(void)pad_list->insert(pad_list->begin(), kPadDims, 0);
|
||||
(void)pad_list->insert(pad_list->begin(), kAvgPool3DPadDims, 0);
|
||||
} else if (pad_mode == PadMode::SAME) {
|
||||
if (stride_d == 0 || stride_h == 0 || stride_w == 0) {
|
||||
MS_LOG(EXCEPTION) << "stride_d or stride_h or stride_w must be non-zero";
|
||||
|
@ -106,18 +109,20 @@ void GetPadsByPadding(int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d
|
|||
int64_t pad_d = std::max((tail_d > 0 ? kernel_d - tail_d : kernel_d - stride_d), (int64_t)0);
|
||||
int64_t pad_h = std::max((tail_h > 0 ? kernel_h - tail_h : kernel_h - stride_h), (int64_t)0);
|
||||
int64_t pad_w = std::max((tail_w > 0 ? kernel_w - tail_w : kernel_w - stride_w), (int64_t)0);
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_d / 2)));
|
||||
constexpr int twice = 2;
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_d / twice)));
|
||||
pad_list->push_back(pad_d - pad_list->at(0));
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_h / 2)));
|
||||
pad_list->push_back(pad_h - pad_list->at(2));
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_w / 2)));
|
||||
pad_list->push_back(pad_w - pad_list->at(4));
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_h / twice)));
|
||||
pad_list->push_back(pad_h - pad_list->at(kInputIndex2));
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_w / twice)));
|
||||
pad_list->push_back(pad_w - pad_list->at(kInputIndex4));
|
||||
} else if (pad_mode == PadMode::PAD) {
|
||||
pad_list->assign(padding.begin(), padding.end());
|
||||
}
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AvgPool3DInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr int64_t k5DInputDims = 5;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
|
@ -146,7 +151,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
std::vector<int64_t> new_pad_list;
|
||||
GetPadsByPadding(in_d, in_h, in_w, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_mode, pad_list,
|
||||
&new_pad_list);
|
||||
if (new_pad_list.size() != kPadDims) {
|
||||
if (new_pad_list.size() != kAvgPool3DPadDims) {
|
||||
MS_LOG(EXCEPTION) << "pad_list size must be 6.";
|
||||
}
|
||||
primitive->set_attr(kPadList, MakeValue(new_pad_list));
|
||||
|
@ -159,7 +164,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AvgPool3DInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
|
@ -174,7 +179,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(AvgPool3DInferShape(primitive, input_args), AvgPool3DInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3D, prim::kPrimAvgPool3D, AvgPool3DInfer, nullptr, true);
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -73,7 +73,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -105,8 +105,8 @@ Format BiasAdd::get_format() const {
|
|||
void BiasAdd::Init(const Format &format) { this->set_format(format); }
|
||||
AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = BiasAddInferType(primitive, input_args);
|
||||
auto infershape = BiasAddInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer, nullptr, true);
|
||||
|
|
|
@ -28,12 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kIndex1 = 1;
|
||||
constexpr auto kIndex3 = 3;
|
||||
constexpr auto kInputDim = 4;
|
||||
constexpr auto kInputNum = 1;
|
||||
|
||||
int64_t GetAndCheckFormat(const ValuePtr &value) {
|
||||
int64_t BNTrainingReduceGetAndCheckFormat(const ValuePtr &value) {
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
|
||||
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
|
||||
|
@ -48,13 +43,14 @@ abstract::TupleShapePtr BNTrainingReduceInferShape(const PrimitivePtr &primitive
|
|||
auto min_shape = input_shape[kMinShape];
|
||||
auto max_shape = input_shape[kMaxShape];
|
||||
|
||||
constexpr auto kInputDim = 4;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_dim", SizeToLong(shape.size()), kEqual, kInputDim, primitive->name());
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format = GetAndCheckFormat(data_format_ptr);
|
||||
size_t c_axis = kIndex1;
|
||||
int64_t data_format = BNTrainingReduceGetAndCheckFormat(data_format_ptr);
|
||||
size_t c_axis = kInputIndex1;
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = kIndex3;
|
||||
c_axis = kInputIndex3;
|
||||
}
|
||||
ShapeVector batch = {shape[c_axis]};
|
||||
abstract::ShapePtr sum_shape;
|
||||
|
@ -83,6 +79,7 @@ TypePtr BNTrainingReduceInferType(const PrimitivePtr &primitive, const std::vect
|
|||
AbstractBasePtr BNTrainingReduceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr int64_t kInputNum = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
auto infer_type = BNTrainingReduceInferType(primitive, input_args);
|
||||
auto infer_shape = BNTrainingReduceInferShape(primitive, input_args);
|
||||
|
|
|
@ -26,11 +26,9 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kIndex1 = 1;
|
||||
constexpr auto kIndex3 = 3;
|
||||
constexpr auto kInputNum = 7;
|
||||
constexpr auto kBNTrainingUpdateInputNum = 7;
|
||||
|
||||
int64_t GetAndCheckFormat(const ValuePtr &value) {
|
||||
int64_t BNTrainingUpdateGetAndCheckFormat(const ValuePtr &value) {
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
|
||||
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
|
||||
|
@ -43,7 +41,7 @@ abstract::TupleShapePtr BNTrainingUpdateInferShape(const PrimitivePtr &primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kBNTrainingUpdateInputNum, prim_name);
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto sum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto square_sum_shape =
|
||||
|
@ -54,10 +52,10 @@ abstract::TupleShapePtr BNTrainingUpdateInferShape(const PrimitivePtr &primitive
|
|||
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format = GetAndCheckFormat(data_format_ptr);
|
||||
size_t c_axis = kIndex1;
|
||||
int64_t data_format = BNTrainingUpdateGetAndCheckFormat(data_format_ptr);
|
||||
size_t c_axis = kInputIndex1;
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = kIndex3;
|
||||
c_axis = kInputIndex3;
|
||||
}
|
||||
// input_x rank should be equal with 4
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", input_x_shape.size(), kEqual, 4, prim_name);
|
||||
|
@ -95,7 +93,7 @@ abstract::TupleShapePtr BNTrainingUpdateInferShape(const PrimitivePtr &primitive
|
|||
TuplePtr BNTrainingUpdateInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kBNTrainingUpdateInputNum, prim_name);
|
||||
auto input_x_type = input_args[kInputIndex0]->BuildType();
|
||||
auto sum_type = input_args[kInputIndex1]->BuildType();
|
||||
auto square_sum_type = input_args[kInputIndex2]->BuildType();
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr CdistInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -43,7 +43,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr CdistInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -60,8 +60,8 @@ AbstractBasePtr CdistInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = CdistInferType(primitive, input_args);
|
||||
auto infer_shape = CdistInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Cdist, prim::kPrimCdist, CdistInfer, nullptr, true);
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr CeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 1, prim_name);
|
||||
|
@ -36,7 +36,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr CeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("CeLU input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
|
@ -51,8 +51,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr CeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = CeLUInferType(primitive, input_args);
|
||||
auto shape = CeLUInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ConstantOfShapeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input args size", SizeToLong(input_args.size()), kEqual, 1,
|
||||
"ConstantOfShape");
|
||||
|
@ -30,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive) {
|
||||
TypePtr ConstantOfShapeInferType(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto data_type = TypeId(GetValue<int64_t>(primitive->GetAttr(kDataType)));
|
||||
return TypeIdToType(data_type);
|
||||
|
@ -57,7 +58,8 @@ std::vector<float> ConstantOfShape::get_value() const {
|
|||
}
|
||||
AbstractBasePtr ConstantOfShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive), InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(ConstantOfShapeInferType(primitive),
|
||||
ConstantOfShapeInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameConstantOfShape, ConstantOfShape);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,12 +27,12 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr size_t kLenTarget = 2;
|
||||
constexpr int64_t kMulti = 2;
|
||||
constexpr int64_t kInputSize = 4;
|
||||
abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr size_t kLenTarget = 2;
|
||||
constexpr int64_t kMulti = 2;
|
||||
constexpr int64_t kInputSize = 4;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
|
|
|
@ -26,11 +26,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr int64_t kInputSize = 7;
|
||||
constexpr size_t kIdx2 = 2;
|
||||
abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t kLenLogProbs = 3;
|
||||
constexpr int64_t kInputSize = 7;
|
||||
constexpr size_t kIdx2 = 2;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
|
|
|
@ -66,7 +66,8 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
|
|||
}
|
||||
}
|
||||
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr CTCLossInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
CheckCTCLossInputs(input_args, op_name);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
|
@ -90,7 +91,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr CTCLossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[kInputIndex1]->BuildType(), {kInt64},
|
||||
op_name);
|
||||
|
@ -108,8 +109,8 @@ TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBase
|
|||
AbstractBasePtr CTCLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto types = InferType(primitive, input_args);
|
||||
auto shapes = InferShape(primitive, input_args);
|
||||
auto types = CTCLossInferType(primitive, input_args);
|
||||
auto shapes = CTCLossInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CTCLoss, prim::kPrimCTCLoss, CTCLossInfer, nullptr, true);
|
||||
|
|
|
@ -25,7 +25,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr CummaxInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_shape = input_args[0]->BuildShape();
|
||||
auto x_shape_value = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape)[kShape];
|
||||
auto dim = GetValue<int64_t>(primitive->GetAttr("dim"));
|
||||
|
@ -40,7 +41,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape, x_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr CummaxInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kUInt8, kUInt32, kFloat16, kFloat32};
|
||||
auto y_type = CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
|
||||
|
@ -53,8 +54,8 @@ AbstractBasePtr CummaxInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto types = InferType(primitive, input_args);
|
||||
auto shapes = InferShape(primitive, input_args);
|
||||
auto types = CummaxInferType(primitive, input_args);
|
||||
auto shapes = CummaxInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr CumminInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
|
@ -37,7 +38,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{y_shape, y_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr CumminInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
@ -54,8 +55,8 @@ AbstractBasePtr CumminInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = CumminInferType(primitive, input_args);
|
||||
auto shape = CumminInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Cummin, prim::kPrimCummin, CumminInfer, nullptr, true);
|
||||
|
|
|
@ -44,7 +44,8 @@ void DepthToSpace::Init(const int64_t block_size, const Format &format) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DepthToSpaceInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
|
@ -100,7 +101,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DepthToSpaceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -113,8 +114,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = DepthToSpaceInferType(primitive, input_args);
|
||||
auto infer_shape = DepthToSpaceInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DepthToSpace, prim::kPrimDepthToSpace, DepthToSpaceInfer, nullptr, true);
|
||||
|
|
|
@ -24,9 +24,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kScaleNum = 2;
|
||||
|
||||
abstract::ShapePtr DiagPartInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr size_t kScaleNum = 2;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
|
|
|
@ -26,13 +26,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DivInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
return BroadCastInferShape(prim_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DivInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -45,8 +45,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr DivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(DivInferType(primitive, input_args),
|
||||
DivInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameDiv, Div);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,9 +27,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kX1Index = 0; // x1
|
||||
constexpr auto kX2Index = 1; // x2
|
||||
|
||||
template <typename T>
|
||||
void DivNoNanImpl(void *x1, void *x2, void *result, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(x1);
|
||||
|
@ -50,7 +47,7 @@ void DivNoNanImpl(void *x1, void *x2, void *result, size_t size) {
|
|||
}
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DivNoNanInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t valid_size = 2;
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
@ -63,7 +60,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return BroadCastInferShape(prim_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DivNoNanInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -74,9 +71,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
return input_args[0]->BuildType();
|
||||
}
|
||||
|
||||
ValuePtr InferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto result_type = InferType(prim, input_args);
|
||||
auto result_shape = InferShape(prim, input_args)->cast<abstract::ShapePtr>();
|
||||
ValuePtr DivNoNanInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr auto kX1Index = 0;
|
||||
constexpr auto kX2Index = 1;
|
||||
auto result_type = DivNoNanInferType(prim, input_args);
|
||||
auto result_shape = DivNoNanInferShape(prim, input_args)->cast<abstract::ShapePtr>();
|
||||
auto x1 = input_args[kX1Index]->BuildValue();
|
||||
auto x2 = input_args[kX2Index]->BuildValue();
|
||||
if (x1 == nullptr || x2 == nullptr) {
|
||||
|
@ -160,10 +159,10 @@ ValuePtr InferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr>
|
|||
|
||||
AbstractBasePtr DivNoNanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = DivNoNanInferShape(primitive, input_args);
|
||||
auto infer_type = DivNoNanInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DivNoNan, prim::kPrimDivNoNan, DivNoNanInfer, InferValue, true);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DivNoNan, prim::kPrimDivNoNan, DivNoNanInfer, DivNoNanInferValue, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,7 +40,8 @@ T GetAndCheckKeepProp(const tensor::TensorPtr &keep_prop) {
|
|||
return *value;
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DropoutDoMaskInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
||||
auto mask_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
|
||||
|
@ -72,7 +73,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return x_shape;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DropoutDoMaskInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
auto keep_prop = input_args[kInputIndex2];
|
||||
MS_EXCEPTION_IF_NULL(keep_prop);
|
||||
|
@ -121,7 +122,8 @@ AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const int64_t input_num = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(DropoutDoMaskInferShape(primitive, input_args),
|
||||
DropoutDoMaskInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const int64_t mask_convert_len = 128;
|
||||
const int64_t kDropoutGenMaskMaskConvertLen = 128;
|
||||
ShapeVector CalDynamicOutputShape(const ValuePtrList value_list) {
|
||||
int64_t count = 1;
|
||||
size_t x_rank = value_list.size();
|
||||
|
@ -53,8 +53,8 @@ ShapeVector CalDynamicOutputShape(const ValuePtrList value_list) {
|
|||
}
|
||||
|
||||
// convert to bytes(8 bits) mask, using round up
|
||||
int64_t n128s = count / mask_convert_len;
|
||||
if ((count % mask_convert_len) != 0) {
|
||||
int64_t n128s = count / kDropoutGenMaskMaskConvertLen;
|
||||
if ((count % kDropoutGenMaskMaskConvertLen) != 0) {
|
||||
n128s++;
|
||||
}
|
||||
int64_t bytes_count = n128s * 16;
|
||||
|
@ -87,8 +87,8 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
|
|||
count = count * value;
|
||||
}
|
||||
// convert to bytes(8 bits) mask, using round up
|
||||
int64_t n128s = count / mask_convert_len;
|
||||
if ((count % mask_convert_len) != 0) {
|
||||
int64_t n128s = count / kDropoutGenMaskMaskConvertLen;
|
||||
if ((count % kDropoutGenMaskMaskConvertLen) != 0) {
|
||||
n128s++;
|
||||
}
|
||||
int64_t bytes_count = n128s * 16;
|
||||
|
@ -97,7 +97,8 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
|
|||
return shape;
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DropoutGenMaskInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
|
@ -152,7 +153,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
out_shape = CalOutputShape(x_shape_data);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DropoutGenMaskInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat32, kFloat16};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[1]->BuildType(), valid_types, op_name);
|
||||
|
@ -163,7 +164,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
AbstractBasePtr DropoutGenMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(DropoutGenMaskInferShape(primitive, input_args),
|
||||
DropoutGenMaskInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGenMask, prim::kPrimDropoutGenMask, DropoutGenMaskInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DynamicBroadcastToInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
|
@ -71,7 +72,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION(TypeError) << "For BroadcastTo, input args must be tensor or tuple.";
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DynamicBroadcastToInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -84,7 +85,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(DynamicBroadcastToInferShape(primitive, input_args),
|
||||
DynamicBroadcastToInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastTo, prim::kPrimDynamicBroadcastTo, DynamicBroadcastToInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,7 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr DynamicResizeNearestNeighborInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
|
@ -109,7 +110,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr DynamicResizeNearestNeighborInferType(const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto valid_types = common_valid_types;
|
||||
valid_types.insert(kComplex128);
|
||||
valid_types.insert(kComplex64);
|
||||
|
@ -122,7 +124,8 @@ AbstractBasePtr DynamicResizeNearestNeighborInfer(const abstract::AnalysisEngine
|
|||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||
kEqual, input_num, prim_name);
|
||||
auto res = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
auto res = abstract::MakeAbstract(DynamicResizeNearestNeighborInferShape(primitive, input_args),
|
||||
DynamicResizeNearestNeighborInferType(primitive, input_args));
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicResizeNearestNeighbor, prim::kPrimDynamicResizeNearestNeighbor,
|
||||
|
|
|
@ -27,15 +27,15 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int ELL_VAL = 52;
|
||||
constexpr int LABEL_NUM = 52;
|
||||
constexpr int ELL_LEN = 3;
|
||||
constexpr int BIG_C_BEGIN = 26;
|
||||
constexpr int kEinsumEllVal = 52;
|
||||
constexpr int kEinsumLableNum = 52;
|
||||
constexpr int kEinsumEllLen = 3;
|
||||
static int64_t char_to_index(char cur_char) {
|
||||
if (cur_char <= 'z' && cur_char >= 'a') {
|
||||
return static_cast<int64_t>(cur_char - 'a');
|
||||
}
|
||||
return static_cast<int64_t>(cur_char - 'A' + BIG_C_BEGIN);
|
||||
constexpr int kBigCBegin = 26;
|
||||
return static_cast<int64_t>(cur_char - 'A' + kBigCBegin);
|
||||
}
|
||||
|
||||
static void seg_left_equation(const std::string &left_equation, const std::string &prim_name,
|
||||
|
@ -54,14 +54,14 @@ static void seg_left_equation(const std::string &left_equation, const std::strin
|
|||
<< "For " << prim_name
|
||||
<< ", each operand can contain contain only one ellipsis, but it has been found again.";
|
||||
}
|
||||
if (idx + ELL_LEN - 1 >= left_equation.length() || left_equation[idx + 1] != label ||
|
||||
left_equation[idx + ELL_LEN - 1] != label) {
|
||||
if (idx + kEinsumEllLen - 1 >= left_equation.length() || left_equation[idx + 1] != label ||
|
||||
left_equation[idx + kEinsumEllLen - 1] != label) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name
|
||||
<< ", An ellipsis in the equation should consist three \'.\', but got less than 3.";
|
||||
}
|
||||
idx += (ELL_LEN - 1);
|
||||
idx += (kEinsumEllLen - 1);
|
||||
found_ell = true;
|
||||
(*left_elements)[cur_element].emplace_back(ELL_VAL);
|
||||
(*left_elements)[cur_element].emplace_back(kEinsumEllVal);
|
||||
} else if (label == ',') {
|
||||
if ((found_ell && (*left_elements)[cur_element].size() > input_shapes[cur_element].size() + 1) ||
|
||||
(!found_ell && (*left_elements)[cur_element].size() != input_shapes[cur_element].size())) {
|
||||
|
@ -87,7 +87,7 @@ static void seg_left_equation(const std::string &left_equation, const std::strin
|
|||
<< ", the number of inputs should be equal to the number of inputs and equation's operand, but it does not.";
|
||||
}
|
||||
for (size_t i = 0; i < (*left_elements).size(); ++i) {
|
||||
auto it = std::find((*left_elements)[i].begin(), (*left_elements)[i].end(), ELL_VAL);
|
||||
auto it = std::find((*left_elements)[i].begin(), (*left_elements)[i].end(), kEinsumEllVal);
|
||||
if ((*left_elements)[i].size() != input_shapes[i].size() && it == (*left_elements)[i].end()) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", The number of subscript in " << i
|
||||
<< " operand in the eqaution should match inputs[" << i << "].dim(), but it does not.";
|
||||
|
@ -104,7 +104,7 @@ static void seg_right_equation_with_arrow(const std::string &left_equation, cons
|
|||
out_shape->emplace_back(1);
|
||||
return;
|
||||
}
|
||||
std::vector<bool> exit_flag(LABEL_NUM, false);
|
||||
std::vector<bool> exit_flag(kEinsumLableNum, false);
|
||||
for (size_t idx = 0; idx < right_equation.length(); ++idx) {
|
||||
if (left_equation.find(right_equation[idx]) == std::string::npos) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
|
@ -119,14 +119,15 @@ static void seg_right_equation_with_arrow(const std::string &left_equation, cons
|
|||
<< "For " << prim_name
|
||||
<< ", each operand can contain contain only one ellipsis, but it has been found again.";
|
||||
}
|
||||
if ((idx + ELL_LEN - 1 >= right_equation.length()) ||
|
||||
(right_equation[idx + 1] != '.' || right_equation[idx + ELL_LEN - 1] != '.')) {
|
||||
if ((idx + kEinsumEllLen - 1 >= right_equation.length()) ||
|
||||
(right_equation[idx + 1] != '.' || right_equation[idx + kEinsumEllLen - 1] != '.')) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name
|
||||
<< ", An ellipsis in the equation should consist three \'.\', but got less than 3.";
|
||||
}
|
||||
idx += (ELL_LEN - 1);
|
||||
idx += (kEinsumEllLen - 1);
|
||||
found_ell = true;
|
||||
out_shape->insert(out_shape->end(), (*element_shape_map)[ELL_VAL].begin(), (*element_shape_map)[ELL_VAL].end());
|
||||
out_shape->insert(out_shape->end(), (*element_shape_map)[kEinsumEllVal].begin(),
|
||||
(*element_shape_map)[kEinsumEllVal].end());
|
||||
} else if (isalpha(right_equation[idx])) {
|
||||
auto val = char_to_index(right_equation[idx]);
|
||||
if (exit_flag[val]) {
|
||||
|
@ -149,7 +150,8 @@ static void seg_right_equation_without_arrow(const std::string &left_equation,
|
|||
const std::vector<int64_t> &element_count,
|
||||
std::vector<int64_t> *out_shape) {
|
||||
if (left_equation.find('.') != std::string::npos) {
|
||||
out_shape->insert(out_shape->begin(), (*element_shape_map)[ELL_VAL].begin(), (*element_shape_map)[ELL_VAL].end());
|
||||
out_shape->insert(out_shape->begin(), (*element_shape_map)[kEinsumEllVal].begin(),
|
||||
(*element_shape_map)[kEinsumEllVal].end());
|
||||
}
|
||||
for (size_t idx = 0; idx < element_count.size(); ++idx) {
|
||||
if (element_count[idx] == 1) {
|
||||
|
@ -167,7 +169,7 @@ static void element_map_shape(const std::string &prim_name, const std::vector<st
|
|||
for (size_t idx_input = 0; idx_input < input_shapes.size(); ++idx_input) {
|
||||
auto cur_shape = input_shapes[idx_input];
|
||||
size_t idx_left = 0;
|
||||
while (idx_left < left_elements[idx_input].size() && left_elements[idx_input][idx_left] != ELL_VAL) {
|
||||
while (idx_left < left_elements[idx_input].size() && left_elements[idx_input][idx_left] != kEinsumEllVal) {
|
||||
auto cur_element = left_elements[idx_input][idx_left];
|
||||
if (element_shape_map->find(cur_element) != element_shape_map->end()) {
|
||||
if ((*element_shape_map)[cur_element][0] != input_shapes[idx_input][idx_left]) {
|
||||
|
@ -185,7 +187,7 @@ static void element_map_shape(const std::string &prim_name, const std::vector<st
|
|||
if (idx_left != left_elements[idx_input].size()) {
|
||||
auto idx_element_right = left_elements[idx_input].size() - 1;
|
||||
auto idx_shape_right = input_shapes[idx_input].size() - 1;
|
||||
while (idx_element_right > idx_left && left_elements[idx_input][idx_element_right] != ELL_VAL) {
|
||||
while (idx_element_right > idx_left && left_elements[idx_input][idx_element_right] != kEinsumEllVal) {
|
||||
auto cur_element = left_elements[idx_input][idx_element_right];
|
||||
if (element_shape_map->find(cur_element) != element_shape_map->end()) {
|
||||
if ((*element_shape_map)[cur_element][0] != input_shapes[idx_input][idx_shape_right]) {
|
||||
|
@ -202,14 +204,14 @@ static void element_map_shape(const std::string &prim_name, const std::vector<st
|
|||
}
|
||||
std::vector<int64_t> temp_vec(input_shapes[idx_input].begin() + idx_left,
|
||||
input_shapes[idx_input].begin() + idx_shape_right + 1);
|
||||
if (element_shape_map->find(ELL_VAL) != element_shape_map->end()) {
|
||||
if ((*element_shape_map)[ELL_VAL] != temp_vec) {
|
||||
if (element_shape_map->find(kEinsumEllVal) != element_shape_map->end()) {
|
||||
if ((*element_shape_map)[kEinsumEllVal] != temp_vec) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For " << prim_name
|
||||
<< ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
|
||||
}
|
||||
} else {
|
||||
(*element_shape_map)[ELL_VAL] = temp_vec;
|
||||
(*element_shape_map)[kEinsumEllVal] = temp_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -225,7 +227,7 @@ std::string Einsum::get_equation() const {
|
|||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr EinsumInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto equation = GetValue<std::string>(primitive->GetAttr(kEquation));
|
||||
equation.erase(std::remove(equation.begin(), equation.end(), ' '), equation.end());
|
||||
|
@ -270,7 +272,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
const auto left_equation = equation.substr(0, seg_pos);
|
||||
std::vector<std::vector<int64_t>> left_elements(input_shapes.size());
|
||||
std::vector<int64_t> element_count(LABEL_NUM, 0);
|
||||
std::vector<int64_t> element_count(kEinsumLableNum, 0);
|
||||
std::unordered_map<int64_t, std::vector<int64_t>> element_shape_map;
|
||||
std::vector<int64_t> out_shape;
|
||||
seg_left_equation(left_equation, prim_name, input_shapes, &left_elements, &element_count);
|
||||
|
@ -284,7 +286,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr EinsumInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
|
||||
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
|
||||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||
|
@ -298,8 +300,8 @@ AbstractBasePtr EinsumInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto res =
|
||||
std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(primitive, input_args));
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(EinsumInferType(primitive, input_args),
|
||||
EinsumInferShape(primitive, input_args));
|
||||
return res;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Einsum, prim::kPrimEinsum, EinsumInfer, nullptr, true);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ExpInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
|
@ -41,7 +41,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_ptr;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr ExpInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -56,7 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr ExpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(ExpInferShape(primitive, input_args), ExpInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameExp, Exp);
|
||||
} // namespace ops
|
||||
|
|
|
@ -26,7 +26,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr FakeQuantWithMinMaxVarsInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
|
@ -45,7 +46,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr FakeQuantWithMinMaxVarsInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
|
@ -81,8 +82,8 @@ int64_t FakeQuantWithMinMaxVars::get_num_bits() const {
|
|||
}
|
||||
AbstractBasePtr FakeQuantWithMinMaxVarsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(FakeQuantWithMinMaxVarsInferType(primitive, input_args),
|
||||
FakeQuantWithMinMaxVarsInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameFakeQuantWithMinMaxVars, FakeQuantWithMinMaxVars);
|
||||
} // namespace ops
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr FftImagInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto name = primitive->name();
|
||||
MS_LOG(DEBUG) << "Infer shape for " << name;
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
@ -36,7 +36,7 @@ AbstractBasePtr FftImagInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(kFloat32, InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(kFloat32, FftImagInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameFftImag, FftImag);
|
||||
} // namespace ops
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr FlattenInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input args size", SizeToLong(input_args.size()), kGreaterEqual, 1,
|
||||
|
@ -58,7 +58,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr FlattenInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -71,8 +71,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr FlattenInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_type = InferShape(primitive, input_args);
|
||||
auto infer_shape = InferType(primitive, input_args);
|
||||
auto infer_type = FlattenInferShape(primitive, input_args);
|
||||
auto infer_shape = FlattenInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_type, infer_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Flatten, prim::kPrimFlatten, FlattenInfer, nullptr, true);
|
||||
|
|
|
@ -36,13 +36,13 @@ ActivationType AddFusion::get_activation_type() const {
|
|||
void AddFusion::Init(const ActivationType activation_type) { this->set_activation_type(activation_type); }
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AddFusionInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AddFusionInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
@ -55,8 +55,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr AddFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(AddFusionInferType(primitive, input_args),
|
||||
AddFusionInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -51,7 +51,8 @@ ActivationType AvgPoolFusion::get_activation_type() const {
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AvgPoolFusionInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -96,7 +97,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AvgPoolFusionInferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -106,7 +107,8 @@ TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
|||
|
||||
AbstractBasePtr AvgPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(AvgPoolFusionInferType(input_args),
|
||||
AvgPoolFusionInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPoolFusion, AvgPoolFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -51,7 +51,8 @@ ActivationType MaxPoolFusion::get_activation_type() const {
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MaxPoolFusionInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto op_name = primitive->name();
|
||||
|
@ -93,7 +94,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
|
||||
TypePtr MaxPoolFusionInferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -102,7 +103,8 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
for (auto item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(MaxPoolFusionInferType(input_args),
|
||||
MaxPoolFusionInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMaxPoolFusion, MaxPoolFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -35,13 +35,13 @@ float PowFusion::get_scale() const { return GetValue<float>(GetAttr(kScale)); }
|
|||
float PowFusion::get_shift() const { return GetValue<float>(GetAttr(kShift)); }
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr PowFusionInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr PowFusionInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -59,8 +59,8 @@ AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(PowFusionInferType(primitive, input_args),
|
||||
PowFusionInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePowFusion, PowFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
@ -35,7 +35,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
return shape_ptr;
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
||||
TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
@ -48,8 +49,8 @@ AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = GeLUInferType(primitive, input_args);
|
||||
auto infer_shape = GeLUInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr GerInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -40,7 +40,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr GerInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
|
@ -58,8 +58,8 @@ AbstractBasePtr GerInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = GerInferType(primitive, input_args);
|
||||
auto shape = GerInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Ger, prim::kPrimGer, GerInfer, nullptr, true);
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AbsGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -39,7 +39,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AbsGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -57,8 +57,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr AbsGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = AbsGradInferType(primitive, input_args);
|
||||
auto shape = AbsGradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t input_num = 2;
|
||||
|
||||
abstract::ShapePtr ACosGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -46,6 +44,7 @@ AbstractBasePtr ACosGradInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t input_num = 2;
|
||||
|
||||
abstract::ShapePtr AcoshGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -46,6 +44,7 @@ AbstractBasePtr AcoshGradInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = AcoshGradInferType(primitive, input_args);
|
||||
auto shapes = AcoshGradInferShape(primitive, input_args);
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t InputNum = 2;
|
||||
|
||||
abstract::ShapePtr AsinGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -46,6 +44,7 @@ AbstractBasePtr AsinGradInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t InputNum = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
|
||||
|
|
|
@ -19,8 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t InputNum = 2;
|
||||
|
||||
abstract::ShapePtr AsinhGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
|
@ -46,6 +44,7 @@ AbstractBasePtr AsinhGradInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const size_t InputNum = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
|
||||
auto types = AsinhGradInferType(primitive, input_args);
|
||||
auto shapes = AsinhGradInferShape(primitive, input_args);
|
||||
|
|
|
@ -23,9 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t k5DInputDims = 5;
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr AvgPool3DGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -34,6 +33,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
constexpr int64_t k5DInputDims = 5;
|
||||
(void)CheckAndConvertUtils::CheckInteger("grad_rank", SizeToLong(grad_shape.size()), kEqual, k5DInputDims, op_name);
|
||||
std::vector<int64_t> origin_input_size;
|
||||
if (input_args[0]->isa<abstract::AbstractTuple>()) { // origin_size is tuple
|
||||
|
@ -44,7 +44,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(origin_input_size);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr AvgPool3DGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -60,8 +60,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
auto res = std::make_shared<abstract::AbstractTensor>(AvgPool3DGradInferType(primitive, input_args),
|
||||
AvgPool3DGradInferShape(primitive, input_args)->shape());
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,8 @@ std::vector<int64_t> GetFormatShape(const int64_t &format, const std::vector<int
|
|||
}
|
||||
return output_shape;
|
||||
}
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr BiasAddGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
|
@ -57,7 +58,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
return std::make_shared<abstract::Shape>(input_shape_);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr BiasAddGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("BiasAddGrad infer", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
|
@ -72,7 +73,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(BiasAddGradInferShape(primitive, input_args),
|
||||
BiasAddGradInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -26,13 +26,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kInputNum = 4;
|
||||
constexpr auto kBNTrainingUpdateGradInputNum = 4;
|
||||
|
||||
abstract::TupleShapePtr BNTrainingUpdateGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kBNTrainingUpdateGradInputNum, prim_name);
|
||||
auto batch_mean_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
auto batch_variance_shape_ptr = input_args[kInputIndex3]->BuildShape();
|
||||
return std::make_shared<abstract::TupleShape>(
|
||||
|
@ -42,7 +42,7 @@ abstract::TupleShapePtr BNTrainingUpdateGradInferShape(const PrimitivePtr &primi
|
|||
TuplePtr BNTrainingUpdateGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kBNTrainingUpdateGradInputNum, prim_name);
|
||||
auto batch_mean_type_ptr = input_args[kInputIndex2]->BuildType();
|
||||
auto batch_variance_type_ptr = input_args[kInputIndex3]->BuildType();
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{batch_mean_type_ptr, batch_variance_type_ptr});
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr CdistGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -44,7 +44,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr CdistGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -63,8 +63,8 @@ AbstractBasePtr CdistGradInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = CdistGradInferType(primitive, input_args);
|
||||
auto infer_shape = CdistGradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CdistGrad, prim::kPrimCdistGrad, CdistGradInfer, nullptr, true);
|
||||
|
|
|
@ -24,16 +24,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kDoutIndex = 0;
|
||||
constexpr size_t kInputIndex = 1;
|
||||
constexpr size_t kFilterSizeIndex = 2;
|
||||
constexpr size_t kStride2dSize = 2;
|
||||
constexpr size_t kStride4dSize = 4;
|
||||
constexpr size_t kConv2DBackpropFilterDoutIndex = 0;
|
||||
constexpr size_t kConv2DBackpropFilterInputIndex = 1;
|
||||
|
||||
void TransStrideTo4D(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex);
|
||||
auto dout_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kDoutIndex);
|
||||
auto x_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kConv2DBackpropFilterInputIndex);
|
||||
auto dout_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kConv2DBackpropFilterDoutIndex);
|
||||
if (!x_shape->IsDynamic() && !dout_shape->IsDynamic()) {
|
||||
return;
|
||||
}
|
||||
|
@ -41,6 +38,7 @@ void TransStrideTo4D(const PrimitivePtr &primitive, const std::vector<AbstractBa
|
|||
auto stride = primitive->GetAttr(kStride);
|
||||
MS_EXCEPTION_IF_NULL(stride);
|
||||
auto stride_value = GetValue<std::vector<int64_t>>(stride);
|
||||
constexpr size_t kStride2dSize = 2;
|
||||
if (stride_value.size() == kStride2dSize) {
|
||||
std::vector<int64_t> stride_value_4d(stride_value);
|
||||
(void)stride_value_4d.insert(stride_value_4d.begin(), 1);
|
||||
|
@ -56,6 +54,7 @@ abstract::ShapePtr Conv2DBackpropFilterInferShape(const PrimitivePtr &primitive,
|
|||
std::vector<int64_t> out_shape;
|
||||
abstract::ShapePtr ret_shape;
|
||||
TransStrideTo4D(primitive, input_args);
|
||||
constexpr size_t kFilterSizeIndex = 2;
|
||||
auto filter_size = input_args[kFilterSizeIndex];
|
||||
auto filter_size_v = filter_size->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(filter_size_v);
|
||||
|
@ -115,8 +114,8 @@ TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vecto
|
|||
auto prim_name = prim->name();
|
||||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[kInputIndex]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[kDoutIndex]->BuildType());
|
||||
(void)types.emplace("x", input_args[kConv2DBackpropFilterInputIndex]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[kConv2DBackpropFilterDoutIndex]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
|
@ -131,6 +130,7 @@ void Conv2DBackpropFilter::Init(const int64_t out_channel, const std::vector<int
|
|||
set_pad_mode(pad_mode);
|
||||
set_pad_list(pad_list);
|
||||
set_mode(mode);
|
||||
constexpr size_t kStride4dSize = 4;
|
||||
if (stride.size() == kStride4dSize) {
|
||||
set_stride({stride[2], stride[3]});
|
||||
} else {
|
||||
|
|
|
@ -25,9 +25,9 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kDoutIndex = 0;
|
||||
constexpr size_t kInputIndex = 1;
|
||||
constexpr size_t kSizeIndex = 2;
|
||||
constexpr size_t kConv2DBackpropInputDoutIndex = 0;
|
||||
constexpr size_t kConv2DBackpropInputInputIndex = 1;
|
||||
constexpr size_t kConv2DBackpropInputSizeIndex = 2;
|
||||
|
||||
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
|
||||
const std::vector<int64_t> &x_size_v) {
|
||||
|
@ -86,7 +86,7 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
|||
auto prim_name = primitive->name();
|
||||
std::vector<int64_t> out_shape;
|
||||
abstract::ShapePtr ret_shape;
|
||||
auto input_size = input_args[kSizeIndex];
|
||||
auto input_size = input_args[kConv2DBackpropInputSizeIndex];
|
||||
auto input_size_v = input_size->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_size_v);
|
||||
|
||||
|
@ -95,7 +95,7 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
|||
out_shape = CheckAndConvertUtils::CheckTensorIntValue("input x size", input_size_v, prim_name);
|
||||
ret_shape = std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kSizeIndex);
|
||||
auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kConv2DBackpropInputSizeIndex);
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto shape_shape = shape_ptr->shape();
|
||||
if (shape_shape.size() != 1) {
|
||||
|
@ -137,7 +137,8 @@ abstract::ShapePtr Conv2DBackpropInputInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input[x size] must be a tuple or Tensor, "
|
||||
<< "but got " << size_type->ToString();
|
||||
}
|
||||
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDoutIndex]->BuildShape())[kShape];
|
||||
auto dout_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kConv2DBackpropInputDoutIndex]->BuildShape())[kShape];
|
||||
|
||||
auto format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat));
|
||||
ShapeVector tmp_shape = {dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]};
|
||||
|
@ -152,8 +153,8 @@ TypePtr Conv2DBackpropInputInferType(const PrimitivePtr &prim, const std::vector
|
|||
// check
|
||||
std::map<std::string, TypePtr> types;
|
||||
// todo: check input_sizes
|
||||
(void)types.emplace("x", input_args[kInputIndex]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[kDoutIndex]->BuildType());
|
||||
(void)types.emplace("x", input_args[kConv2DBackpropInputInputIndex]->BuildType());
|
||||
(void)types.emplace("doutput", input_args[kConv2DBackpropInputDoutIndex]->BuildType());
|
||||
std::set<TypePtr> valid_x_type = {kInt8, kInt32, kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_x_type, prim_name);
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_GRID_SAMPLER_3D_H_
|
||||
#define MINDSPORE_CORE_OPS_GRID_SAMPLER_3D_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_GRAD_GRID_SAMPLER_3D_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_GRAD_GRID_SAMPLER_3D_GRAD_H_
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -41,4 +41,4 @@ using PrimGridSampler3DGrad = std::shared_ptr<GridSampler3DGrad>;
|
|||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_GRID_SAMPLER_3D_H_
|
||||
#endif // MINDSPORE_CORE_OPS_GRAD_GRID_SAMPLER_3D_GRAD_H_
|
||||
|
|
|
@ -30,7 +30,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr HSigmoidGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
|
||||
|
@ -45,7 +46,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(grads_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr HSigmoidGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
|
||||
|
@ -59,8 +60,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr HSigmoidGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(HSigmoidGradInferType(primitive, input_args),
|
||||
HSigmoidGradInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr InvGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
@ -32,7 +32,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr InvGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -57,8 +57,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr InvGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = InvGradInferType(primitive, input_args);
|
||||
auto shape = InvGradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(InvGrad, prim::kPrimInvGrad, InvGradInfer, nullptr, true);
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int k2Directions = 2;
|
||||
AbstractBasePtr LstmGradInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -87,6 +86,7 @@ void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const i
|
|||
this->set_dropout(dropout);
|
||||
this->set_bidirectional(bidirectional);
|
||||
if (bidirectional) {
|
||||
constexpr int k2Directions = 2;
|
||||
this->set_num_directions(k2Directions);
|
||||
} else {
|
||||
this->set_num_directions(1);
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int k2Directions = 2;
|
||||
AbstractBasePtr LstmGradDataInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -91,6 +90,7 @@ void LSTMGradData::Init(const int64_t input_size, const int64_t hidden_size, con
|
|||
this->set_dropout(dropout);
|
||||
this->set_bidirectional(bidirectional);
|
||||
if (bidirectional) {
|
||||
constexpr int k2Directions = 2;
|
||||
this->set_num_directions(k2Directions);
|
||||
} else {
|
||||
this->set_num_directions(1);
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int k2Directions = 2;
|
||||
AbstractBasePtr LstmGradWeightInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -91,6 +90,7 @@ void LSTMGradWeight::Init(const int64_t input_size, const int64_t hidden_size, c
|
|||
this->set_dropout(dropout);
|
||||
this->set_bidirectional(bidirectional);
|
||||
if (bidirectional) {
|
||||
constexpr int k2Directions = 2;
|
||||
this->set_num_directions(k2Directions);
|
||||
} else {
|
||||
this->set_num_directions(1);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ReLUGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -45,7 +45,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr ReLUGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -65,8 +65,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = ReLUGradInferType(primitive, input_args);
|
||||
auto shape = ReLUGradInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);
|
||||
|
|
|
@ -27,14 +27,14 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ReLUGradV2InferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr ReLUGradV2InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type_map = input_args[0]->BuildType();
|
||||
|
@ -55,7 +55,7 @@ AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return abstract::MakeAbstract(InferShape(input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(ReLUGradV2InferShape(input_args), ReLUGradV2InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,7 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr ResizeNearestNeighborGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto grad_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
|
@ -55,7 +56,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(ret_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr ResizeNearestNeighborGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -65,7 +66,8 @@ AbstractBasePtr ResizeNearestNeighborGradInfer(const abstract::AnalysisEnginePtr
|
|||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(CheckAndConvertUtils::GetRemoveMonadAbsNum(input_args)),
|
||||
kEqual, input_num, prim_name);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(ResizeNearestNeighborGradInferShape(primitive, input_args),
|
||||
ResizeNearestNeighborGradInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeNearestNeighborGrad, prim::kPrimResizeNearestNeighborGrad,
|
||||
ResizeNearestNeighborGradInfer, nullptr, true);
|
||||
|
|
|
@ -23,12 +23,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t kInputSize = 3;
|
||||
constexpr int64_t kSoftMarginLossGradInputSize = 3;
|
||||
abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual,
|
||||
kSoftMarginLossGradInputSize, op_name);
|
||||
auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto dout = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
|
@ -41,7 +42,8 @@ abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
|
|||
|
||||
TypePtr SoftMarginLossGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual,
|
||||
kSoftMarginLossGradInputSize, op_name);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("logits", input_args[kInputIndex0]->BuildType());
|
||||
|
|
|
@ -20,17 +20,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const size_t kZero = 0;
|
||||
const size_t kOne = 1;
|
||||
const size_t kTwo = 2;
|
||||
const size_t kThree = 3;
|
||||
const size_t kFour = 4;
|
||||
const size_t kFive = 5;
|
||||
|
||||
abstract::ShapePtr GridSampler3DInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kZero]->BuildShape())[kShape];
|
||||
auto grid_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kOne]->BuildShape())[kShape];
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto grid_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
const size_t kFive = 5;
|
||||
if (input_x_shape.size() != kFive) {
|
||||
MS_EXCEPTION(ValueError) << "Input_x must be a 5-dimensional tensor, but got "
|
||||
<< std::to_string(input_x_shape.size()) << "-dimensional tensor.";
|
||||
|
@ -39,24 +33,25 @@ abstract::ShapePtr GridSampler3DInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION(ValueError) << "Grid must be a 5-dimensional tensor, but got " << std::to_string(grid_shape.size())
|
||||
<< "-dimensional tensor.";
|
||||
}
|
||||
if (input_x_shape[kZero] != grid_shape[kZero]) {
|
||||
MS_EXCEPTION(ValueError) << "The shape of grid is " << input_args[kOne]->BuildShape()->ToString()
|
||||
<< " , but the shape of input_x is " << input_args[kZero]->BuildShape()->ToString()
|
||||
if (input_x_shape[kInputIndex0] != grid_shape[kInputIndex0]) {
|
||||
MS_EXCEPTION(ValueError) << "The shape of grid is " << input_args[kInputIndex1]->BuildShape()->ToString()
|
||||
<< " , but the shape of input_x is " << input_args[kInputIndex0]->BuildShape()->ToString()
|
||||
<< " . The first dimension of grid and input_x must be equal.";
|
||||
}
|
||||
if (grid_shape[kFour] != kThree) {
|
||||
MS_EXCEPTION(ValueError) << "The last dimension of grid must be 3, but got " << std::to_string(grid_shape[kFour]);
|
||||
if (grid_shape[kInputIndex4] != kInputIndex3) {
|
||||
MS_EXCEPTION(ValueError) << "The last dimension of grid must be 3, but got "
|
||||
<< std::to_string(grid_shape[kInputIndex4]);
|
||||
}
|
||||
std::vector<int64_t> output_shape = {input_x_shape[kZero], input_x_shape[kOne], grid_shape[kOne], grid_shape[kTwo],
|
||||
grid_shape[kThree]};
|
||||
std::vector<int64_t> output_shape = {input_x_shape[kInputIndex0], input_x_shape[kInputIndex1],
|
||||
grid_shape[kInputIndex1], grid_shape[kInputIndex2], grid_shape[kInputIndex3]};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
TypePtr GridSampler3DInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
std::set<TypePtr> valid_types = {kFloat32, kFloat64};
|
||||
TypePtr input_x_type = input_args[kZero]->BuildType();
|
||||
TypePtr grid_type = input_args[kOne]->BuildType();
|
||||
TypePtr input_x_type = input_args[kInputIndex0]->BuildType();
|
||||
TypePtr grid_type = input_args[kInputIndex1]->BuildType();
|
||||
(void)types.emplace("input_x", input_x_type);
|
||||
(void)types.emplace("grid", grid_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
|
@ -67,7 +62,7 @@ TypePtr GridSampler3DInferType(const PrimitivePtr &primitive, const std::vector<
|
|||
AbstractBasePtr GridSampler3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = kTwo;
|
||||
const int64_t input_num = kInputIndex2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = GridSampler3DInferType(primitive, input_args);
|
||||
auto infer_shape = GridSampler3DInferShape(primitive, input_args);
|
||||
|
|
|
@ -26,13 +26,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr HShrinkInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1L, primitive->name());
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr HShrinkInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1, primitive->name());
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
|
@ -43,8 +43,8 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
|
||||
AbstractBasePtr HShrinkInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(HShrinkInferType(primitive, input_args),
|
||||
HShrinkInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(HShrink, prim::kPrimHShrink, HShrinkInfer, nullptr, true);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr HSigmoidInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr HSigmoidInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
@ -42,8 +42,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr HSigmoidInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(HSigmoidInferType(primitive, input_args),
|
||||
HSigmoidInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr HSVToRGBInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
const int64_t kNumDims = 4;
|
||||
const int64_t kLastDim = 3;
|
||||
|
@ -37,7 +37,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(input_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr HSVToRGBInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto input_dtype = input_args[0]->BuildType();
|
||||
const std::set<TypePtr> input_valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", input_dtype, input_valid_types, kNameHSVToRGB);
|
||||
|
@ -50,8 +50,8 @@ AbstractBasePtr HSVToRGBInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto types = InferType(primitive, input_args);
|
||||
auto shapes = InferShape(primitive, input_args);
|
||||
auto types = HSVToRGBInferType(primitive, input_args);
|
||||
auto shapes = HSVToRGBInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_REAL_H_
|
||||
#define MINDSPORE_CORE_OPS_REAL_H_
|
||||
#ifndef MINDSPORE_CORE_OPS_IMAG_H_
|
||||
#define MINDSPORE_CORE_OPS_IMAG_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
|
@ -40,4 +40,4 @@ class MS_CORE_API Imag : public PrimitiveC {
|
|||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_REAL_H_
|
||||
#endif // MINDSPORE_CORE_OPS_IMAG_H_
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr IOUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, prim_name);
|
||||
|
@ -67,7 +67,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(ret_shape, ret_min_shape, ret_max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr IOUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
|
@ -77,8 +77,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr IOUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = IOUInferType(primitive, input_args);
|
||||
auto shape = IOUInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(IOU, prim::kPrimIOU, IOUInfer, nullptr, true);
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr IsInfInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
|
@ -36,7 +36,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr IsInfInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -50,8 +50,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr IsInfInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = IsInfInferType(primitive, input_args);
|
||||
auto infershape = IsInfInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(IsInf, prim::kPrimIsInf, IsInfInfer, nullptr, true);
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr IsNanInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -34,7 +34,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr IsNanInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -51,8 +51,8 @@ AbstractBasePtr IsNanInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = IsNanInferType(primitive, input_args);
|
||||
auto infershape = IsNanInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(IsNan, prim::kPrimIsNan, IsNanInfer, nullptr, true);
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LeakyReluInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
@ -27,7 +27,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LeakyReluInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -47,8 +47,8 @@ float LeakyRelu::get_negative_slope() const { return GetValue<float>(GetAttr(kNe
|
|||
|
||||
AbstractBasePtr LeakyReluInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(LeakyReluInferType(primitive, input_args),
|
||||
LeakyReluInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLeakyRelu, LeakyRelu);
|
||||
} // namespace ops
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LerpInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 3;
|
||||
|
@ -48,7 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(broadcast_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LerpInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -69,8 +69,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LerpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(LerpInferType(primitive, input_args),
|
||||
LerpInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Lerp, prim::kPrimLerp, LerpInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -25,13 +25,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LessInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LessInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -45,8 +45,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = LessInferShape(primitive, input_args);
|
||||
auto type = LessInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Less, prim::kPrimLess, LessInfer, nullptr, true);
|
||||
|
|
|
@ -24,13 +24,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LessEqualInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LessEqualInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int64_t input_num = 2;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim->name());
|
||||
|
@ -44,8 +44,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LessEqualInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = LessEqualInferShape(primitive, input_args);
|
||||
auto infer_type = LessEqualInferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LessEqual, prim::kPrimLessEqual, LessEqualInfer, nullptr, true);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LogInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
|
@ -44,7 +44,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LogInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
|
@ -61,7 +61,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LogInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(LogInferShape(primitive, input_args), LogInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLog, Log);
|
||||
} // namespace ops
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::TupleShapePtr LogMatrixDeterminantInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
|
@ -39,7 +40,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
|
||||
}
|
||||
|
||||
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr LogMatrixDeterminantInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat32};
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
|
@ -52,8 +53,8 @@ AbstractBasePtr LogMatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, c
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = LogMatrixDeterminantInferType(primitive, input_args);
|
||||
auto infershape = LogMatrixDeterminantInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LogMatrixDeterminant, prim::kPrimLogMatrixDeterminant, LogMatrixDeterminantInfer, nullptr,
|
||||
|
|
|
@ -23,13 +23,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LogicalXorInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LogicalXorInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
|
@ -43,8 +43,8 @@ AbstractBasePtr LogicalXorInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = LogicalXorInferType(primitive, input_args);
|
||||
auto infer_shape = LogicalXorInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LogicalXor, prim::kPrimLogicalXor, LogicalXorInfer, nullptr, true);
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
for (const auto &item : input_args) {
|
||||
|
@ -86,7 +86,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LpNormInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -103,8 +103,8 @@ AbstractBasePtr LpNormInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = LpNormInferType(primitive, input_args);
|
||||
auto infer_shape = LpNormInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LpNorm, prim::kPrimLpNorm, LpNormInfer, nullptr, true);
|
||||
|
|
|
@ -75,7 +75,7 @@ void LRN::Init(const int64_t depth_radius, const float bias, const float alpha,
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LRNInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t x_size = 4;
|
||||
|
@ -85,7 +85,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LRNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
|
@ -99,8 +99,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(LRNInferType(primitive, input_args),
|
||||
LRNInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLRN, LRN);
|
||||
} // namespace ops
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t type_size = 4;
|
||||
AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -47,7 +47,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr LuSolveInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t kDimNum = 2;
|
||||
|
@ -128,7 +128,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr LuSolveInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t kDimNum = 2;
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
|
@ -152,8 +152,8 @@ AbstractBasePtr LuSolveInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = LuSolveInferType(primitive, input_args);
|
||||
auto infer_shape = LuSolveInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MaskedFillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 3;
|
||||
|
@ -48,7 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(broadcast_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MaskedFillInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -71,8 +71,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr MaskedFillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(MaskedFillInferType(primitive, input_args),
|
||||
MaskedFillInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaskedFill, prim::kPrimMaskedFill, MaskedFillInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MatrixDeterminantInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
|
@ -38,7 +39,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MatrixDeterminantInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat32};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
|
||||
|
@ -51,8 +52,8 @@ AbstractBasePtr MatrixDeterminantInfer(const abstract::AnalysisEnginePtr &, cons
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = MatrixDeterminantInferType(primitive, input_args);
|
||||
auto infershape = MatrixDeterminantInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDeterminant, prim::kPrimMatrixDeterminant, MatrixDeterminantInfer, nullptr, true);
|
||||
|
|
|
@ -23,21 +23,22 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
const constexpr int64_t kShape2 = 2;
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MatrixDiagPartInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = input_args[0]->BuildShape();
|
||||
auto shape_element = input_shape->cast<abstract::ShapePtr>();
|
||||
ShapeVector shape = shape_element->shape();
|
||||
ShapeVector min_shape = shape_element->shape();
|
||||
ShapeVector max_shape = shape_element->shape();
|
||||
const constexpr int64_t kShape2 = 2;
|
||||
max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1;
|
||||
min_shape[shape.size() - 1] = 1;
|
||||
shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
|
||||
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MatrixDiagPartInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -51,7 +52,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(MatrixDiagPartInferShape(primitive, input_args),
|
||||
MatrixDiagPartInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MatrixInverseInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
|
@ -37,7 +38,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MatrixInverseInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, valid_types, prim->name());
|
||||
|
@ -50,8 +51,8 @@ AbstractBasePtr MatrixInverseInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infertype = InferType(primitive, input_args);
|
||||
auto infershape = InferShape(primitive, input_args);
|
||||
auto infertype = MatrixInverseInferType(primitive, input_args);
|
||||
auto infershape = MatrixInverseInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixInverse, prim::kPrimMatrixInverse, MatrixInverseInfer, nullptr, true);
|
||||
|
|
|
@ -79,7 +79,7 @@ void MaxPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
|
|||
}
|
||||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MaxPoolInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
@ -124,7 +124,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MaxPoolInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
@ -142,8 +142,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr MaxPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(MaxPoolInferType(primitive, input_args),
|
||||
MaxPoolInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMaxPool, MaxPool);
|
||||
} // namespace ops
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MfccInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input0_size = 3;
|
||||
|
@ -38,7 +38,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MfccInferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -90,7 +90,8 @@ int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctC
|
|||
|
||||
AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(MfccInferType(input_args),
|
||||
MfccInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMfcc, Mfcc);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,13 +27,13 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MinimumInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MinimumInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -50,8 +50,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr MinimumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(MinimumInferType(primitive, input_args),
|
||||
MinimumInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMinimum, Minimum);
|
||||
} // namespace ops
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr MulInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
|
@ -37,7 +37,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return BroadCastInferShape(prim_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr MulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(MulInferShape(primitive, input_args), MulInferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMul, Mul);
|
||||
} // namespace ops
|
||||
|
|
|
@ -38,7 +38,7 @@ void ImpleNeg(void *origin, void *target, size_t size) {
|
|||
target_data[i] = -origin_data[i];
|
||||
}
|
||||
}
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr NegInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_ptr = x->cast<abstract::ShapePtr>();
|
||||
|
@ -46,7 +46,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_ptr;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr NegInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kUInt8, kInt8, kInt16, kInt32, kInt64,
|
||||
kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
|
||||
|
@ -69,7 +69,7 @@ ValuePtr NegInferValue(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
|
||||
auto data_size = x_tensor->DataSize();
|
||||
auto dtype = x_tensor->data_type();
|
||||
auto shape = InferShape(prim, input_args)->shape();
|
||||
auto shape = NegInferShape(prim, input_args)->shape();
|
||||
auto result_tensor = std::make_shared<tensor::Tensor>(dtype, shape); // same shape and dtype
|
||||
auto x_datac = x_tensor->data_c();
|
||||
auto result_datac = result_tensor->data_c();
|
||||
|
@ -139,8 +139,8 @@ AbstractBasePtr NegInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infer_type = InferType(primitive, input_args);
|
||||
auto infer_shape = InferShape(primitive, input_args);
|
||||
auto infer_type = NegInferType(primitive, input_args);
|
||||
auto infer_shape = NegInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Neg, prim::kPrimNeg, NegInfer, NegInferValue, true);
|
||||
|
|
|
@ -22,12 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kRecvShapes = "recv_shapes";
|
||||
constexpr auto kRecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kRecvType = "recv_type";
|
||||
constexpr auto kSendShapes = "send_shapes";
|
||||
constexpr auto kSendRankIds = "send_rank_ids";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr auto kNeighborExchangeRecvShapes = "recv_shapes";
|
||||
constexpr auto kNeighborExchangeRecvType = "recv_type";
|
||||
|
||||
inline std::string GetShapeStr(const std::vector<int64_t> &shape) {
|
||||
std::string shape_str = "[";
|
||||
|
@ -78,16 +74,20 @@ void CheckAttr(const PrimitivePtr &primitive, const std::string &shape_attr_name
|
|||
}
|
||||
}
|
||||
|
||||
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
void NeighborExchangeCheck(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAttr(primitive, kRecvShapes, kRecvRankIds);
|
||||
constexpr auto kSendShapes = "send_shapes";
|
||||
constexpr auto kRecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kSendRankIds = "send_rank_ids";
|
||||
constexpr auto kGroup = "group";
|
||||
CheckAttr(primitive, kNeighborExchangeRecvShapes, kRecvRankIds);
|
||||
CheckAttr(primitive, kSendShapes, kSendRankIds);
|
||||
// check recv type
|
||||
auto recv_type_attr = primitive->GetAttr(kRecvType);
|
||||
auto recv_type_attr = primitive->GetAttr(kNeighborExchangeRecvType);
|
||||
MS_EXCEPTION_IF_NULL(recv_type_attr);
|
||||
if (!recv_type_attr->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "Attr " << kRecvType << " should be a mindspore data type.";
|
||||
MS_EXCEPTION(TypeError) << "Attr " << kNeighborExchangeRecvType << " should be a mindspore data type.";
|
||||
}
|
||||
// check group
|
||||
auto group_attr = primitive->GetAttr(kGroup);
|
||||
|
@ -141,9 +141,9 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
}
|
||||
}
|
||||
|
||||
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) {
|
||||
abstract::BaseShapePtr NeighborExchangeInferShape(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto recv_shapes = primitive->GetAttr(kRecvShapes);
|
||||
auto recv_shapes = primitive->GetAttr(kNeighborExchangeRecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
auto shapes_seq = recv_shapes->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shapes_seq);
|
||||
|
@ -163,15 +163,15 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive) {
|
|||
return std::make_shared<abstract::TupleShape>(base_shape_list);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive) {
|
||||
TypePtr NeighborExchangeInferType(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto recv_shapes = primitive->GetAttr(kRecvShapes);
|
||||
auto recv_shapes = primitive->GetAttr(kNeighborExchangeRecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
auto shapes_seq = recv_shapes->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shapes_seq);
|
||||
auto shapes_value = shapes_seq->value();
|
||||
auto out_num = shapes_value.size();
|
||||
auto recv_type = primitive->GetAttr(kRecvType)->cast<TypePtr>();
|
||||
auto recv_type = primitive->GetAttr(kNeighborExchangeRecvType)->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(recv_type);
|
||||
std::vector<TypePtr> type_vec(out_num, recv_type);
|
||||
if (type_vec.empty()) {
|
||||
|
@ -182,9 +182,9 @@ TypePtr InferType(const PrimitivePtr &primitive) {
|
|||
} // namespace
|
||||
AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
Check(primitive, input_args);
|
||||
auto type = InferType(primitive);
|
||||
auto shape = InferShape(primitive);
|
||||
NeighborExchangeCheck(primitive, input_args);
|
||||
auto type = NeighborExchangeInferType(primitive);
|
||||
auto shape = NeighborExchangeInferShape(primitive);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true);
|
||||
|
|
|
@ -23,26 +23,17 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr auto kSendRankIds = "send_rank_ids";
|
||||
constexpr auto kSendLens = "send_lens";
|
||||
constexpr auto kRecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kRecvLens = "recv_lens";
|
||||
constexpr auto kDataFormat = "format";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr size_t kRankIdsSize = 8;
|
||||
constexpr size_t kLensSize = 4;
|
||||
constexpr size_t kInputSize = 4;
|
||||
constexpr size_t kHDim = 2;
|
||||
constexpr size_t kWDim = 3;
|
||||
constexpr int64_t kInvalidIds = -1;
|
||||
constexpr size_t kIdx0 = 0;
|
||||
constexpr size_t kIdx1 = 1;
|
||||
constexpr size_t kIdx2 = 2;
|
||||
constexpr size_t kIdx3 = 3;
|
||||
constexpr size_t kIdx4 = 4;
|
||||
constexpr size_t kIdx5 = 5;
|
||||
constexpr size_t kIdx6 = 6;
|
||||
constexpr size_t kIdx7 = 7;
|
||||
constexpr auto kNeighborExchangeV2RecvRankIds = "recv_rank_ids";
|
||||
constexpr auto kNeighborExchangeV2RecvLens = "recv_lens";
|
||||
constexpr int64_t kNeighborExchangeV2InvalidIds = -1;
|
||||
constexpr size_t kNeighborExchangeV2Idx0 = 0;
|
||||
constexpr size_t kNeighborExchangeV2Idx1 = 1;
|
||||
constexpr size_t kNeighborExchangeV2Idx2 = 2;
|
||||
constexpr size_t kNeighborExchangeV2Idx3 = 3;
|
||||
constexpr size_t kNeighborExchangeV2Idx4 = 4;
|
||||
constexpr size_t kNeighborExchangeV2Idx5 = 5;
|
||||
constexpr size_t kNeighborExchangeV2Idx6 = 6;
|
||||
constexpr size_t kNeighborExchangeV2Idx7 = 7;
|
||||
|
||||
std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::string &attr_name,
|
||||
const size_t attr_size) {
|
||||
|
@ -68,14 +59,14 @@ std::vector<int64_t> CheckAttrSize(const PrimitivePtr &primitive, const std::str
|
|||
}
|
||||
|
||||
void CheckRecvCorner(std::vector<int64_t> recv_rank_ids, int64_t idx1, int64_t idx2, int64_t idx_corner) {
|
||||
if (recv_rank_ids[idx1] != kInvalidIds && recv_rank_ids[idx2] != kInvalidIds &&
|
||||
recv_rank_ids[idx_corner] == kInvalidIds) {
|
||||
if (recv_rank_ids[idx1] != kNeighborExchangeV2InvalidIds && recv_rank_ids[idx2] != kNeighborExchangeV2InvalidIds &&
|
||||
recv_rank_ids[idx_corner] == kNeighborExchangeV2InvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
}
|
||||
if ((recv_rank_ids[idx1] == kInvalidIds || recv_rank_ids[idx2] == kInvalidIds) &&
|
||||
recv_rank_ids[idx_corner] != kInvalidIds) {
|
||||
if ((recv_rank_ids[idx1] == kNeighborExchangeV2InvalidIds || recv_rank_ids[idx2] == kNeighborExchangeV2InvalidIds) &&
|
||||
recv_rank_ids[idx_corner] != kNeighborExchangeV2InvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid recv_rank_ids, as recv_rank_ids[" << idx1 << "] = " << recv_rank_ids[idx1]
|
||||
<< ", recv_rank_ids[" << idx2 << "] = " << recv_rank_ids[idx2] << ", and recv_rank_ids["
|
||||
<< idx_corner << "] = " << recv_rank_ids[idx_corner] << ".";
|
||||
|
@ -86,7 +77,7 @@ void CheckIdsValue(std::vector<int64_t> rank_ids) {
|
|||
// check repeat & invalid value
|
||||
std::set<int64_t> ids_count;
|
||||
for (auto id : rank_ids) {
|
||||
if (id < 0 && id != kInvalidIds) {
|
||||
if (id < 0 && id != kNeighborExchangeV2InvalidIds) {
|
||||
MS_EXCEPTION(ValueError) << "Invalid send_rank_ids or recv_rank_ids: " << id
|
||||
<< ", all the rank id should be >= 0 or -1.";
|
||||
}
|
||||
|
@ -106,17 +97,21 @@ void CheckLensValue(std::vector<int64_t> lens) {
|
|||
}
|
||||
}
|
||||
|
||||
void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
void NeighborExchangeV2Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
|
||||
|
||||
// check size of send_rank_ids, recv_rank_ids, send_lens, recv_lens
|
||||
constexpr size_t kRankIdsSize = 8;
|
||||
constexpr size_t kLensSize = 4;
|
||||
constexpr auto kSendRankIds = "send_rank_ids";
|
||||
constexpr auto kSendLens = "send_lens";
|
||||
auto send_rank_ids = CheckAttrSize(primitive, kSendRankIds, kRankIdsSize);
|
||||
auto recv_rank_ids = CheckAttrSize(primitive, kRecvRankIds, kRankIdsSize);
|
||||
auto recv_rank_ids = CheckAttrSize(primitive, kNeighborExchangeV2RecvRankIds, kRankIdsSize);
|
||||
auto send_lens = CheckAttrSize(primitive, kSendLens, kLensSize);
|
||||
auto recv_lens = CheckAttrSize(primitive, kRecvLens, kLensSize);
|
||||
auto recv_lens = CheckAttrSize(primitive, kNeighborExchangeV2RecvLens, kLensSize);
|
||||
|
||||
// check rank_ids value
|
||||
CheckIdsValue(send_rank_ids);
|
||||
|
@ -126,12 +121,13 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
CheckLensValue(recv_lens);
|
||||
|
||||
// check recv rankids invalid cond
|
||||
CheckRecvCorner(recv_rank_ids, kIdx0, kIdx2, kIdx1);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx2, kIdx4, kIdx3);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx4, kIdx6, kIdx5);
|
||||
CheckRecvCorner(recv_rank_ids, kIdx6, kIdx0, kIdx7);
|
||||
CheckRecvCorner(recv_rank_ids, kNeighborExchangeV2Idx0, kNeighborExchangeV2Idx2, kNeighborExchangeV2Idx1);
|
||||
CheckRecvCorner(recv_rank_ids, kNeighborExchangeV2Idx2, kNeighborExchangeV2Idx4, kNeighborExchangeV2Idx3);
|
||||
CheckRecvCorner(recv_rank_ids, kNeighborExchangeV2Idx4, kNeighborExchangeV2Idx6, kNeighborExchangeV2Idx5);
|
||||
CheckRecvCorner(recv_rank_ids, kNeighborExchangeV2Idx6, kNeighborExchangeV2Idx0, kNeighborExchangeV2Idx7);
|
||||
|
||||
// check data_format is NCHW
|
||||
constexpr auto kDataFormat = "format";
|
||||
auto format_attr = primitive->GetAttr(kDataFormat);
|
||||
string format = "";
|
||||
try {
|
||||
|
@ -147,27 +143,31 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
// check if send_lens > input_lens
|
||||
std::vector<int64_t> input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
constexpr size_t kInputSize = 4;
|
||||
if (input_shape.size() != kInputSize) {
|
||||
MS_EXCEPTION(ValueError) << "Input size is not 4, only support NCHW now.";
|
||||
}
|
||||
if (send_lens[kIdx0] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[0]: " << send_lens[kIdx0]
|
||||
constexpr size_t kHDim = 2;
|
||||
constexpr size_t kWDim = 3;
|
||||
if (send_lens[kNeighborExchangeV2Idx0] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[0]: " << send_lens[kNeighborExchangeV2Idx0]
|
||||
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx1] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[1]: " << send_lens[kIdx1]
|
||||
if (send_lens[kNeighborExchangeV2Idx1] > input_shape[kHDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[1]: " << send_lens[kNeighborExchangeV2Idx1]
|
||||
<< " is larger than input size in H dim: " << input_shape[kHDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx2] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[2]: " << send_lens[kIdx2]
|
||||
if (send_lens[kNeighborExchangeV2Idx2] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[2]: " << send_lens[kNeighborExchangeV2Idx2]
|
||||
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
|
||||
}
|
||||
if (send_lens[kIdx3] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[3]: " << send_lens[kIdx3]
|
||||
if (send_lens[kNeighborExchangeV2Idx3] > input_shape[kWDim]) {
|
||||
MS_EXCEPTION(ValueError) << "Attr send_lens[3]: " << send_lens[kNeighborExchangeV2Idx3]
|
||||
<< " is larger than input size in W dim: " << input_shape[kWDim] << ".";
|
||||
}
|
||||
|
||||
// check group
|
||||
constexpr auto kGroup = "group";
|
||||
auto group_attr = primitive->GetAttr(kGroup);
|
||||
try {
|
||||
MS_EXCEPTION_IF_NULL(group_attr);
|
||||
|
@ -177,14 +177,15 @@ void Check(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &in
|
|||
}
|
||||
}
|
||||
|
||||
abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::BaseShapePtr NeighborExchangeV2InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto recv_rank_ids = primitive->GetAttr(kRecvRankIds);
|
||||
auto recv_rank_ids = primitive->GetAttr(kNeighborExchangeV2RecvRankIds);
|
||||
MS_EXCEPTION_IF_NULL(recv_rank_ids);
|
||||
auto recv_rank_ids_value = recv_rank_ids->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(recv_rank_ids_value);
|
||||
std::vector<int64_t> recv_rank_ids_v = GetValue<std::vector<int64_t>>(recv_rank_ids_value);
|
||||
auto recv_lens = primitive->GetAttr(kRecvLens);
|
||||
auto recv_lens = primitive->GetAttr(kNeighborExchangeV2RecvLens);
|
||||
MS_EXCEPTION_IF_NULL(recv_lens);
|
||||
auto recv_lens_value = recv_lens->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(recv_lens_value);
|
||||
|
@ -192,17 +193,17 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vect
|
|||
|
||||
std::vector<int64_t> input_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (recv_rank_ids_v[kIdx0] != kInvalidIds) {
|
||||
input_shape[kIdx2] += recv_lens_v[kIdx0];
|
||||
if (recv_rank_ids_v[kNeighborExchangeV2Idx0] != kNeighborExchangeV2InvalidIds) {
|
||||
input_shape[kNeighborExchangeV2Idx2] += recv_lens_v[kNeighborExchangeV2Idx0];
|
||||
}
|
||||
if (recv_rank_ids_v[kIdx4] != kInvalidIds) {
|
||||
input_shape[kIdx2] += recv_lens_v[kIdx1];
|
||||
if (recv_rank_ids_v[kNeighborExchangeV2Idx4] != kNeighborExchangeV2InvalidIds) {
|
||||
input_shape[kNeighborExchangeV2Idx2] += recv_lens_v[kNeighborExchangeV2Idx1];
|
||||
}
|
||||
if (recv_rank_ids_v[kIdx6] != kInvalidIds) {
|
||||
input_shape[kIdx3] += recv_lens_v[kIdx2];
|
||||
if (recv_rank_ids_v[kNeighborExchangeV2Idx6] != kNeighborExchangeV2InvalidIds) {
|
||||
input_shape[kNeighborExchangeV2Idx3] += recv_lens_v[kNeighborExchangeV2Idx2];
|
||||
}
|
||||
if (recv_rank_ids_v[kIdx2] != kInvalidIds) {
|
||||
input_shape[kIdx3] += recv_lens_v[kIdx3];
|
||||
if (recv_rank_ids_v[kNeighborExchangeV2Idx2] != kNeighborExchangeV2InvalidIds) {
|
||||
input_shape[kNeighborExchangeV2Idx3] += recv_lens_v[kNeighborExchangeV2Idx3];
|
||||
}
|
||||
BaseShapePtr output_shape = std::make_shared<abstract::Shape>(input_shape);
|
||||
if (input_shape.empty()) {
|
||||
|
@ -211,7 +212,7 @@ abstract::BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vect
|
|||
return output_shape;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr NeighborExchangeV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
// recv type
|
||||
TypePtr recv_type = input_args[0]->BuildType();
|
||||
|
@ -223,9 +224,9 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
} // namespace
|
||||
AbstractBasePtr NeighborExchangeV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
Check(primitive, input_args);
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
NeighborExchangeV2Check(primitive, input_args);
|
||||
auto type = NeighborExchangeV2InferType(primitive, input_args);
|
||||
auto shape = NeighborExchangeV2InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchangeV2, prim::kPrimNeighborExchangeV2, NeighborExchangeV2Infer, nullptr, true);
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr NonMaxSuppressionV3InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int input_num = 5;
|
||||
|
@ -72,7 +73,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
selected_indices_max_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr NonMaxSuppressionV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int input_num = 5;
|
||||
|
@ -107,7 +108,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
return abstract::MakeAbstract(NonMaxSuppressionV3InferShape(primitive, input_args),
|
||||
NonMaxSuppressionV3InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NonMaxSuppressionV3, prim::kPrimNonMaxSuppressionV3, NonMaxSuppressionV3Infer, nullptr,
|
||||
true);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr PadInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
||||
|
@ -44,7 +44,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr PadInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -62,8 +62,8 @@ std::vector<std::vector<int64_t>> Pad::get_paddings() const {
|
|||
}
|
||||
AbstractBasePtr PadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
return std::make_shared<abstract::AbstractTensor>(PadInferType(primitive, input_args),
|
||||
PadInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePad, Pad);
|
||||
} // namespace ops
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr PReLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto w = input_args[1]->BuildShape();
|
||||
|
@ -40,7 +40,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TypePtr PReLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<string, TypePtr> check_map = {{"input_x", input_args[0]->BuildType()},
|
||||
{"weight", input_args[1]->BuildType()}};
|
||||
|
@ -52,8 +52,8 @@ AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args));
|
||||
return std::make_shared<abstract::AbstractTensor>(PReLUInferType(primitive, input_args),
|
||||
PReLUInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
|
||||
} // namespace ops
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue