!12688 using cpp infer firstly

From: @lianliguang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-09 20:29:32 +08:00 committed by Gitee
commit 04e23927ef
25 changed files with 302 additions and 143 deletions

View File

@ -366,7 +366,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore mindspore::pybind11_module) target_link_libraries(mindspore mindspore::pybind11_module)
target_link_libraries(mindspore mindspore_gvar) target_link_libraries(mindspore mindspore_gvar)
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
else() else()
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
target_link_libraries(mindspore proto_input mindspore::protobuf target_link_libraries(mindspore proto_input mindspore::protobuf
@ -376,7 +376,8 @@ else()
target_link_libraries(mindspore ibverbs rdmacm) target_link_libraries(mindspore ibverbs rdmacm)
endif() endif()
endif() endif()
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive) target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core
proto_input -Wl,--no-whole-archive)
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
target_link_libraries(_c_expression PRIVATE mindspore_gvar) target_link_libraries(_c_expression PRIVATE mindspore_gvar)
if(ENABLE_D) if(ENABLE_D)

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/const_input_to_attr_registry.h"
#include <utility> #include <utility>

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver {
::mindspore::opt::ConstInputToAttrInfoRegister(op_name) ::mindspore::opt::ConstInputToAttrInfoRegister(op_name)
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_

View File

@ -31,6 +31,8 @@
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -700,6 +702,92 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
} }
return CreateCNodeWithGraph(input_nodes, graph); return CreateCNodeWithGraph(input_nodes, graph);
} }
// rectify absttract if the input has been converted to the attr
AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive);
opt::ConstInputToAttrInfoRegister reg;
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
return input_abstract;
}
if (AnfAlgo::HasDynamicShapeFlag(primitive) ||
DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
return input_abstract;
}
auto convert_input_list = reg.GetConstInputAttrInfo();
auto input_names = primitive->GetAttr(kAttrInputNames);
if (input_names == nullptr) {
return input_abstract;
}
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
AbstractBasePtrList rectify_abs_list;
size_t ori_index = 0;
rectify_abs_list.resize(input_names_vec.size());
for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
// if convert input list find the index it means the input has been converted to the attr
if (convert_input_list.find(index) != convert_input_list.end()) {
AbstractBasePtr rectify_abs = nullptr;
auto input_name = input_names_vec[index];
auto attr = primitive->GetAttr(input_name);
if (attr != nullptr) {
rectify_abs = attr->ToAbstract();
} else {
MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
<< " input name :" << input_name << "has not been converted to the attr";
rectify_abs = input_abstract[ori_index++];
}
rectify_abs_list[index] = rectify_abs;
continue;
}
if (ori_index > input_abstract.size()) {
MS_LOG(EXCEPTION) << "index is out of range input abstract size " << input_abstract.size()
<< " get index :" << ori_index;
}
rectify_abs_list[index] = input_abstract[ori_index++];
}
return rectify_abs_list;
}
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
const AbstractBasePtrList &input_abstract) {
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) {
return input_abstract;
}
AbstractBasePtrList rectifyed_abs_list;
const int kNotDynamicFlag = -1;
auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
size_t input_index = 0;
for (auto item : dynamic_inputs_index) {
if (item == kNotDynamicFlag) {
if (input_index >= input_abstract.size()) {
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " << input_abstract.size();
}
rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
} else {
if (item < 0) {
MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got "
<< item;
}
AbstractBasePtrList dynamic_inputs_abs;
for (auto index = item; index > 0; --index) {
if (input_index >= input_abstract.size()) {
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract "
<< input_abstract.size();
}
dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
}
rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
}
}
return rectifyed_abs_list;
}
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
}
} // namespace } // namespace
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
@ -835,5 +923,24 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
} }
} }
} }
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(prim);
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(prim);
if (ret != prim_eval_implement_map.end()) {
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
return ret->second.impl_(nullptr, prim, infer_spec_list);
} else {
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
auto ret_backend = prim_backend_eval_impl_map.find(prim);
if (ret_backend != prim_backend_eval_impl_map.end()) {
return ret_backend->second.impl_(nullptr, prim, args_spec_list);
}
}
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
<< " primitive type:" << prim->type_name();
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -212,6 +212,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
// Transfer depend or control_depend to the new node // Transfer depend or control_depend to the new node
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_

View File

@ -27,7 +27,7 @@
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"

View File

@ -19,7 +19,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"

View File

@ -1534,6 +1534,18 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
return AnfAlgo::GetNodeAttr<bool>(node, attr); return AnfAlgo::GetNodeAttr<bool>(node, attr);
} }
bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
MS_EXCEPTION_IF_NULL(primitive);
if (!primitive->HasAttr(attr_name)) {
return false;
}
return GetValue<bool>(primitive->GetAttr(attr_name));
};
return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
get_bool_attr(prim, kAttrIsDynamicShape);
}
bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) || return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
GetBooleanAttr(node, kAttrIsDynamicShape); GetBooleanAttr(node, kAttrIsDynamicShape);
@ -1805,7 +1817,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
args_spec_list.emplace_back(real_input->abstract()); args_spec_list.emplace_back(real_input->abstract());
} }
} }
auto eval_result = abstract::CppInferShape(primitive, args_spec_list); auto eval_result = opt::CppInferShape(primitive, args_spec_list);
node->set_abstract(eval_result); node->set_abstract(eval_result);
} }
} // namespace session } // namespace session

View File

@ -230,6 +230,7 @@ class AnfRuntimeAlgorithm {
// get fix output precision from prev node, input_idx is the input index of current node related to prev node. // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
static bool IsDynamicShape(const AnfNodePtr &node); static bool IsDynamicShape(const AnfNodePtr &node);
static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
static bool IsCondControlKernel(const CNodePtr &node); static bool IsCondControlKernel(const CNodePtr &node);
static bool IsIndependentNode(const CNodePtr &node); static bool IsIndependentNode(const CNodePtr &node);
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);

View File

@ -1311,15 +1311,6 @@ bool IsInWhiteList(const PrimitivePtr &primitive) {
return false; return false;
} }
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr;
}
return iter->second.impl_;
}
PrimEvaluatorMap &GetPrimEvaluatorConstructors() { PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
if (!constructor.empty()) { if (!constructor.empty()) {

View File

@ -112,7 +112,6 @@ class MixedPrecisionCastEvaluator : public Evaluator {
}; };
bool IsInWhiteList(const PrimitivePtr &primitive); bool IsInWhiteList(const PrimitivePtr &primitive);
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
using ValuePtrList = std::vector<ValuePtr>; using ValuePtrList = std::vector<ValuePtr>;
using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);

View File

@ -357,6 +357,13 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
return std::make_shared<MixedPrecisionCastEvaluator>(prim); return std::make_shared<MixedPrecisionCastEvaluator>(prim);
} }
// find prim infer function in the prim function map return a standard evaluator
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) {
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}
// use python infer function if the infer function not founded in the map return a python evaluator
EvaluatorPtr evaluator = nullptr; EvaluatorPtr evaluator = nullptr;
if (prim->HasPyEvaluator()) { if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim); auto prim_py = dyn_cast<PrimitivePy>(prim);
@ -376,17 +383,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
} }
if (prim->isa<PrimitivePy>() || prim->HasAttr()) { // return a default evaluator
if (engine == nullptr) {
(void)GetPrimEvaluatorConstructors();
}
// If a primitive may have attr, try to create a new evaluator.
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) {
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}
}
if (engine == nullptr) { if (engine == nullptr) {
// If engine is nullptr, get constructor from default. // If engine is nullptr, get constructor from default.
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
@ -778,16 +775,5 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
return eval_result; return eval_result;
} }
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(prim);
auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(prim);
if (ret == prim_eval_implement_map.end()) {
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
<< " primitive type:" << prim->type_name();
}
return ret->second.impl_(nullptr, prim, args_spec_list);
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -331,8 +331,6 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
} }
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -44,7 +44,7 @@
#include "pipeline/jit/static_analysis/prim.h" #include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/auto_monad.h" #include "pipeline/jit/static_analysis/auto_monad.h"
#include "backend/session/session_factory.h" #include "backend/session/session_factory.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "pipeline/jit/action.h" #include "pipeline/jit/action.h"
@ -807,21 +807,13 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
} }
} }
// get output dynamic shape info // get output dynamic shape info
auto py_abstract = op_exec_info->abstract; auto abstract = op_exec_info->abstract;
MS_EXCEPTION_IF_NULL(py_abstract); MS_EXCEPTION_IF_NULL(abstract);
auto py_shape = py_abstract->BuildShape(); auto shape = abstract->BuildShape();
MS_EXCEPTION_IF_NULL(py_shape); MS_EXCEPTION_IF_NULL(shape);
auto py_shape_info = py_shape->ToString(); auto shape_info = shape->ToString();
if (py_shape_info.find("-1") != string::npos) { if (shape_info.find("-1") != string::npos) {
auto c_abstract = abstract::CppInferShape(prim, args_spec_list); op_exec_info->is_dynamic_shape = true;
MS_EXCEPTION_IF_NULL(c_abstract);
auto c_shape = c_abstract->BuildShape();
MS_EXCEPTION_IF_NULL(c_shape);
auto c_shape_info = c_shape->ToString();
MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
if (c_shape_info.find("-1") != string::npos) {
op_exec_info->is_dynamic_shape = true;
}
} }
} }

View File

@ -123,7 +123,7 @@ void DynamicKernel::InferShape() {
} }
} }
auto eval_result = abstract::CppInferShape(primitive, args_spec_list); auto eval_result = opt::CppInferShape(primitive, args_spec_list);
cnode_ptr_->set_abstract(eval_result); cnode_ptr_->set_abstract(eval_result);
} }

View File

@ -1041,6 +1041,9 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name(); const std::string &op_name = primitive->name();
if (args_spec_list.size() == 1) {
return args_spec_list[0]->Broaden();
}
CheckArgsSize(op_name, args_spec_list, 3); CheckArgsSize(op_name, args_spec_list, 3);
AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);

View File

@ -292,24 +292,6 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
return std::make_shared<AbstractTuple>(rets); return std::make_shared<AbstractTuple>(rets);
} }
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
CheckArgsSize(primitive->name(), args_spec_list, 5);
auto dx = args_spec_list[1]->Broaden();
auto dscale = args_spec_list[2]->Broaden();
auto dbias = args_spec_list[3]->Broaden();
auto reserve_1 = args_spec_list[4]->Broaden();
auto reserve_2 = args_spec_list[5]->Broaden();
AbstractBasePtrList rets = {dx, dscale, dbias, reserve_1, reserve_2};
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors(y_backprop, x). // Inputs: two tensors(y_backprop, x).
@ -468,20 +450,6 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
return std::make_shared<AbstractTensor>(x_type, output_shape_ptr); return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
} }
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(doutput, input, filters).
CheckRequiredArgsSize(primitive->name(), args_spec_list, 3);
return args_spec_list[1]->Broaden();
}
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(inputs, filter, doutput).
CheckArgsSize(primitive->name(), args_spec_list, 3);
return args_spec_list[2]->Broaden();
}
AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();

View File

@ -17,6 +17,11 @@
*/ */
#include "abstract/primitive_infer_map.h" #include "abstract/primitive_infer_map.h"
#include <map>
#include <string>
#include <vector>
#include "abstract/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/infer_functions.h" #include "abstract/infer_functions.h"
@ -59,40 +64,21 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimNotInDict, {InferImplNotInDict, true}}, {prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}}, {prim::kPrimIsConsant, {InferImplIsConstant, true}},
// Maths // Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMul, {InferImplMul, true}},
{prim::kPrimAdd, {InferImplAdd, true}},
{prim::kPrimSquare, {InferImplSquare, true}}, {prim::kPrimSquare, {InferImplSquare, true}},
{prim::kPrimSqrt, {InferImplSqrt, true}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
{prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
{prim::kPrimAddN, {InferImplAddN, true}},
{prim::kPrimMatMul, {InferImplMatMul, true}}, {prim::kPrimMatMul, {InferImplMatMul, true}},
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
{prim::kPrimLess, {InferImplLess, true}}, {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimSqrt, {InferImplSqrt, true}},
// Array // Array
{prim::kPrimRange, {InferImplRange, true}},
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimStack, {InferImplStack, true}},
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimUnique, {InferImplUnique, true}}, {prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGather, {InferImplGatherV2, true}}, {prim::kPrimGather, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
@ -104,18 +90,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, {prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
{prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, {prim::kPrimPadAndShift, {InferImplPadAndShift, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, true}}, {prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}}, {prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
{prim::kPrimConcat, {InferImplConcat, true}},
{prim::kPrimRange, {InferImplRange, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
// Structure // Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}},
@ -139,14 +117,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimPooling, {InferImplPooling, true}}, {prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
{prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}}, {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}}, {prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2D, {InferImplConv2D, true}}, {prim::kPrimConv2D, {InferImplConv2D, true}},
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}}, {prim::kPrimRelu, {InferImplRelu, true}},
@ -192,18 +166,60 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}},
// Comm Ops // Comm Ops
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
};
return prim_eval_implement_map;
}
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
{prim::kPrimMul, {InferImplMul, true}},
{prim::kPrimAdd, {InferImplAdd, true}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
{prim::kPrimCast, {InferImplCast, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimAllReduce, {InferImplAllReduce, true}}, {prim::kPrimAllReduce, {InferImplAllReduce, true}},
{prim::kPrimBroadcast, {InferImplBroadcast, true}}, {prim::kPrimBroadcast, {InferImplBroadcast, true}},
{prim::kPrimAllGather, {InferImplAllGather, true}}, {prim::kPrimAllGather, {InferImplAllGather, true}},
{prim::kPrimAllSwap, {InferImplAllSwap, true}}, {prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, {prim::kPrimLinSpace, {InferImplLinSpace, true}},
{prim::kPrimCast, {InferImplCast, true}}, {prim::kPrimAddN, {InferImplAddN, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}}, {prim::kPrimLess, {InferImplLess, true}},
{prim::kPrimDType, {InferImplDType, true}}, {prim::kPrimStack, {InferImplStack, true}},
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimConcat, {InferImplConcat, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
}; };
return prim_eval_implement_map; return prim_backend_eval_implement_map;
}
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr;
}
return iter->second.impl_;
} }
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {

View File

@ -18,6 +18,7 @@
#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ #ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "ir/primitive.h" #include "ir/primitive.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
@ -37,6 +38,10 @@ using PrimitiveEvalImplMap =
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode); std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg); void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);

View File

@ -104,6 +104,5 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolFusion, prim::kPrimMaxPool, MaxPoolFusionInfer);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -31,8 +31,6 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
auto element = tensor_type->element(); auto element = tensor_type->element();
return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape); return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer);
REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad); REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -58,8 +58,6 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
return std::make_shared<abstract::AbstractTensor>(intype, inshape); return std::make_shared<abstract::AbstractTensor>(intype, inshape);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer);
REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad); REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -31,8 +31,6 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
auto element = tensor_type->element(); auto element = tensor_type->element();
return std::make_shared<abstract::AbstractTensor>(element, x1_shape); return std::make_shared<abstract::AbstractTensor>(element, x1_shape);
} }
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer);
REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad); REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -102,7 +102,6 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape()); InferShape(primitive, input_args)->shape());
} }
REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer);
REGISTER_PRIMITIVE_C(kNameLRN, LRN); REGISTER_PRIMITIVE_C(kNameLRN, LRN);
} // namespace ops } // namespace ops
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <vector>
#include <string>
#include "ir/primitive.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "abstract/primitive_infer_map.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "backend/optimizer/common/helper.h"
#include "common/common_test.h"
namespace mindspore {
namespace opt {
constexpr auto kAttrConvertTestName = "attr_convert_test";
constexpr auto kDynamicInputTestName = "dynamic_input_test";
inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName);
inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test");
AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
EXPECT_EQ(args_spec_list.size(), 3);
EXPECT_NE(args_spec_list[1], nullptr);
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest);
AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
EXPECT_EQ(args_spec_list.size(), 3);
EXPECT_NE(args_spec_list[1], nullptr);
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest);
class TestAttrAndDynamicBackendInfer : public UT::Common {
public:
TestAttrAndDynamicBackendInfer() {}
void SetUp() override {}
void TearDown() override {}
};
TEST_F(TestAttrAndDynamicBackendInfer, test_attr_and_dynamic_input_infer) {
// Register Attr for ut
ConstInputToAttrInfoRegistry &reg = ConstInputToAttrInfoRegistry::Instance();
reg.Register(kAttrConvertTestName, {1});
// construct primitive
PrimitivePtr prim_attr_test = std::make_shared<Primitive>(kAttrConvertTestName);
PrimitivePtr prim_dynamic_input_test = std::make_shared<Primitive>(kDynamicInputTestName);
// set primtive attr
auto input_names = std::vector<std::string>{"a", "b", "c"};
auto attr_name = "b";
auto attr = MakeValue(std::vector<int>{1, 2, 3});
prim_attr_test->AddAttr(kAttrInputNames, MakeValue(input_names));
prim_attr_test->AddAttr(attr_name, attr);
// set dynameic input list for primtive
std::vector<int64_t> dynamic_input_list = {-1, 2, -1};
prim_dynamic_input_test->AddAttr(kAttrDynInputSizes, MakeValue(dynamic_input_list));
// construct Abstract list
auto abs_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto abs_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto attr_infer_result = CppInferShape(prim_attr_test, {abs_a, abs_c});
auto abs_dynamic_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto abs_dynamic_b = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto abs_dynamic_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto abs_dynamic_d = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
auto dynamic_infer_result =
CppInferShape(prim_dynamic_input_test, {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d});
}
} // namespace opt
} // namespace mindspore