forked from mindspore-Ecosystem/mindspore
!12688 using cpp infer firstly
From: @lianliguang Reviewed-by: Signed-off-by:
This commit is contained in:
commit
04e23927ef
|
@ -366,7 +366,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||
target_link_libraries(mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load)
|
||||
else()
|
||||
if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
target_link_libraries(mindspore proto_input mindspore::protobuf
|
||||
|
@ -376,7 +376,8 @@ else()
|
|||
target_link_libraries(mindspore ibverbs rdmacm)
|
||||
endif()
|
||||
endif()
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core
|
||||
proto_input -Wl,--no-whole-archive)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore_gvar)
|
||||
if(ENABLE_D)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
|
||||
#include <utility>
|
||||
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
@ -75,4 +75,4 @@ struct ConstInputToAttrInfoReceiver {
|
|||
::mindspore::opt::ConstInputToAttrInfoRegister(op_name)
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_CONST_INPUT_TO_ATTR_REGISTRY_H_
|
|
@ -31,6 +31,8 @@
|
|||
#include "utils/ms_utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -700,6 +702,92 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
|
|||
}
|
||||
return CreateCNodeWithGraph(input_nodes, graph);
|
||||
}
|
||||
|
||||
// rectify absttract if the input has been converted to the attr
|
||||
AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &input_abstract) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
opt::ConstInputToAttrInfoRegister reg;
|
||||
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), ®)) {
|
||||
return input_abstract;
|
||||
}
|
||||
if (AnfAlgo::HasDynamicShapeFlag(primitive) ||
|
||||
DynamicShapeConstInputToAttr.find(primitive->name()) != DynamicShapeConstInputToAttr.end()) {
|
||||
return input_abstract;
|
||||
}
|
||||
auto convert_input_list = reg.GetConstInputAttrInfo();
|
||||
auto input_names = primitive->GetAttr(kAttrInputNames);
|
||||
if (input_names == nullptr) {
|
||||
return input_abstract;
|
||||
}
|
||||
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
|
||||
AbstractBasePtrList rectify_abs_list;
|
||||
size_t ori_index = 0;
|
||||
rectify_abs_list.resize(input_names_vec.size());
|
||||
for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
|
||||
// if convert input list find the index it means the input has been converted to the attr
|
||||
if (convert_input_list.find(index) != convert_input_list.end()) {
|
||||
AbstractBasePtr rectify_abs = nullptr;
|
||||
auto input_name = input_names_vec[index];
|
||||
auto attr = primitive->GetAttr(input_name);
|
||||
if (attr != nullptr) {
|
||||
rectify_abs = attr->ToAbstract();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
|
||||
<< " input name :" << input_name << "has not been converted to the attr";
|
||||
rectify_abs = input_abstract[ori_index++];
|
||||
}
|
||||
rectify_abs_list[index] = rectify_abs;
|
||||
continue;
|
||||
}
|
||||
if (ori_index > input_abstract.size()) {
|
||||
MS_LOG(EXCEPTION) << "index is out of range input abstract size " << input_abstract.size()
|
||||
<< " get index :" << ori_index;
|
||||
}
|
||||
rectify_abs_list[index] = input_abstract[ori_index++];
|
||||
}
|
||||
return rectify_abs_list;
|
||||
}
|
||||
|
||||
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &input_abstract) {
|
||||
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
|
||||
if (dynamic_inputs_list == nullptr) {
|
||||
return input_abstract;
|
||||
}
|
||||
AbstractBasePtrList rectifyed_abs_list;
|
||||
const int kNotDynamicFlag = -1;
|
||||
auto dynamic_inputs_index = GetValue<std::vector<int64_t>>(dynamic_inputs_list);
|
||||
size_t input_index = 0;
|
||||
for (auto item : dynamic_inputs_index) {
|
||||
if (item == kNotDynamicFlag) {
|
||||
if (input_index >= input_abstract.size()) {
|
||||
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " << input_abstract.size();
|
||||
}
|
||||
rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
|
||||
} else {
|
||||
if (item < 0) {
|
||||
MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got "
|
||||
<< item;
|
||||
}
|
||||
AbstractBasePtrList dynamic_inputs_abs;
|
||||
for (auto index = item; index > 0; --index) {
|
||||
if (input_index >= input_abstract.size()) {
|
||||
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract "
|
||||
<< input_abstract.size();
|
||||
}
|
||||
dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
|
||||
}
|
||||
rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
|
||||
}
|
||||
}
|
||||
return rectifyed_abs_list;
|
||||
}
|
||||
|
||||
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
|
||||
auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
|
||||
return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||
|
@ -835,5 +923,24 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
|
|||
}
|
||||
}
|
||||
}
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
|
||||
auto ret = prim_eval_implement_map.find(prim);
|
||||
if (ret != prim_eval_implement_map.end()) {
|
||||
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
|
||||
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
|
||||
return ret->second.impl_(nullptr, prim, infer_spec_list);
|
||||
} else {
|
||||
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
|
||||
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
|
||||
auto ret_backend = prim_backend_eval_impl_map.find(prim);
|
||||
if (ret_backend != prim_backend_eval_impl_map.end()) {
|
||||
return ret_backend->second.impl_(nullptr, prim, args_spec_list);
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
|
||||
<< " primitive type:" << prim->type_name();
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -212,6 +212,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
|
|||
|
||||
// Transfer depend or control_depend to the new node
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
|
|
@ -1534,6 +1534,18 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
|
|||
return AnfAlgo::GetNodeAttr<bool>(node, attr);
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
|
||||
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (!primitive->HasAttr(attr_name)) {
|
||||
return false;
|
||||
}
|
||||
return GetValue<bool>(primitive->GetAttr(attr_name));
|
||||
};
|
||||
return get_bool_attr(prim, kAttrInputIsDynamicShape) || get_bool_attr(prim, kAttrOutputIsDynamicShape) ||
|
||||
get_bool_attr(prim, kAttrIsDynamicShape);
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
|
||||
return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape) ||
|
||||
GetBooleanAttr(node, kAttrIsDynamicShape);
|
||||
|
@ -1805,7 +1817,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
|
|||
args_spec_list.emplace_back(real_input->abstract());
|
||||
}
|
||||
}
|
||||
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
|
||||
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
||||
node->set_abstract(eval_result);
|
||||
}
|
||||
} // namespace session
|
||||
|
|
|
@ -230,6 +230,7 @@ class AnfRuntimeAlgorithm {
|
|||
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
|
||||
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
|
||||
static bool IsDynamicShape(const AnfNodePtr &node);
|
||||
static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
|
||||
static bool IsCondControlKernel(const CNodePtr &node);
|
||||
static bool IsIndependentNode(const CNodePtr &node);
|
||||
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
|
||||
|
|
|
@ -1311,15 +1311,6 @@ bool IsInWhiteList(const PrimitivePtr &primitive) {
|
|||
return false;
|
||||
}
|
||||
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
||||
if (iter == GetPrimitiveToEvalImplMap().end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return iter->second.impl_;
|
||||
}
|
||||
|
||||
PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
|
||||
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
|
||||
if (!constructor.empty()) {
|
||||
|
|
|
@ -112,7 +112,6 @@ class MixedPrecisionCastEvaluator : public Evaluator {
|
|||
};
|
||||
|
||||
bool IsInWhiteList(const PrimitivePtr &primitive);
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
|
||||
using ValuePtrList = std::vector<ValuePtr>;
|
||||
using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);
|
||||
|
|
|
@ -357,6 +357,13 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
|
||||
}
|
||||
|
||||
// find prim infer function in the prim function map return a standard evaluator
|
||||
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
|
||||
if (eval_impl != nullptr) {
|
||||
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
|
||||
}
|
||||
|
||||
// use python infer function if the infer function not founded in the map return a python evaluator
|
||||
EvaluatorPtr evaluator = nullptr;
|
||||
if (prim->HasPyEvaluator()) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
|
@ -376,17 +383,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
|
||||
}
|
||||
|
||||
if (prim->isa<PrimitivePy>() || prim->HasAttr()) {
|
||||
if (engine == nullptr) {
|
||||
(void)GetPrimEvaluatorConstructors();
|
||||
}
|
||||
// If a primitive may have attr, try to create a new evaluator.
|
||||
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
|
||||
if (eval_impl != nullptr) {
|
||||
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
|
||||
}
|
||||
}
|
||||
|
||||
// return a default evaluator
|
||||
if (engine == nullptr) {
|
||||
// If engine is nullptr, get constructor from default.
|
||||
const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors();
|
||||
|
@ -778,16 +775,5 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
|
|||
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap();
|
||||
auto ret = prim_eval_implement_map.find(prim);
|
||||
if (ret == prim_eval_implement_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
|
||||
<< " primitive type:" << prim->type_name();
|
||||
}
|
||||
return ret->second.impl_(nullptr, prim, args_spec_list);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -331,8 +331,6 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
|
|||
}
|
||||
|
||||
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@
|
|||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
|
||||
|
@ -807,21 +807,13 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|||
}
|
||||
}
|
||||
// get output dynamic shape info
|
||||
auto py_abstract = op_exec_info->abstract;
|
||||
MS_EXCEPTION_IF_NULL(py_abstract);
|
||||
auto py_shape = py_abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(py_shape);
|
||||
auto py_shape_info = py_shape->ToString();
|
||||
if (py_shape_info.find("-1") != string::npos) {
|
||||
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(c_abstract);
|
||||
auto c_shape = c_abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(c_shape);
|
||||
auto c_shape_info = c_shape->ToString();
|
||||
MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
|
||||
if (c_shape_info.find("-1") != string::npos) {
|
||||
op_exec_info->is_dynamic_shape = true;
|
||||
}
|
||||
auto abstract = op_exec_info->abstract;
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto shape = abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto shape_info = shape->ToString();
|
||||
if (shape_info.find("-1") != string::npos) {
|
||||
op_exec_info->is_dynamic_shape = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ void DynamicKernel::InferShape() {
|
|||
}
|
||||
}
|
||||
|
||||
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
|
||||
auto eval_result = opt::CppInferShape(primitive, args_spec_list);
|
||||
cnode_ptr_->set_abstract(eval_result);
|
||||
}
|
||||
|
||||
|
|
|
@ -1041,6 +1041,9 @@ AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
if (args_spec_list.size() == 1) {
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
|
|
|
@ -292,24 +292,6 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
|
|||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
|
||||
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 5);
|
||||
auto dx = args_spec_list[1]->Broaden();
|
||||
auto dscale = args_spec_list[2]->Broaden();
|
||||
auto dbias = args_spec_list[3]->Broaden();
|
||||
auto reserve_1 = args_spec_list[4]->Broaden();
|
||||
auto reserve_2 = args_spec_list[5]->Broaden();
|
||||
|
||||
AbstractBasePtrList rets = {dx, dscale, dbias, reserve_1, reserve_2};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors(y_backprop, x).
|
||||
|
@ -468,20 +450,6 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three tensors(doutput, input, filters).
|
||||
CheckRequiredArgsSize(primitive->name(), args_spec_list, 3);
|
||||
return args_spec_list[1]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three tensors(inputs, filter, doutput).
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||
return args_spec_list[2]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
|
|
@ -17,6 +17,11 @@
|
|||
*/
|
||||
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
|
||||
|
@ -59,40 +64,21 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
|
||||
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
|
||||
// Maths
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMul, {InferImplMul, true}},
|
||||
{prim::kPrimAdd, {InferImplAdd, true}},
|
||||
{prim::kPrimSquare, {InferImplSquare, true}},
|
||||
{prim::kPrimSqrt, {InferImplSqrt, true}},
|
||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
|
||||
{prim::kPrimSub, {InferImplSub, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
|
||||
{prim::kPrimAddN, {InferImplAddN, true}},
|
||||
{prim::kPrimMatMul, {InferImplMatMul, true}},
|
||||
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
|
||||
{prim::kPrimLess, {InferImplLess, true}},
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimSqrt, {InferImplSqrt, true}},
|
||||
// Array
|
||||
{prim::kPrimRange, {InferImplRange, true}},
|
||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
||||
{prim::kPrimStack, {InferImplStack, true}},
|
||||
{prim::kPrimPad, {InferImplPad, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGather, {InferImplGatherV2, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
|
||||
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
|
||||
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
|
||||
|
@ -104,18 +90,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
|
||||
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
|
||||
{prim::kPrimPadAndShift, {InferImplPadAndShift, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
{prim::kPrimShape, {InferImplShape, false}},
|
||||
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, true}},
|
||||
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
|
||||
{prim::kPrimSplit, {InferImplSplit, true}},
|
||||
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, true}},
|
||||
{prim::kPrimRange, {InferImplRange, true}},
|
||||
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
||||
|
@ -139,14 +117,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimPooling, {InferImplPooling, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
|
||||
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
|
||||
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
|
||||
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
|
||||
{prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
|
||||
{prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
|
||||
{prim::kPrimConv2D, {InferImplConv2D, true}},
|
||||
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
|
||||
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
|
||||
{prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, true}},
|
||||
|
@ -192,18 +166,60 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
|
||||
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||
static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
|
||||
{prim::kPrimMul, {InferImplMul, true}},
|
||||
{prim::kPrimAdd, {InferImplAdd, true}},
|
||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
|
||||
{prim::kPrimSub, {InferImplSub, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
|
||||
{prim::kPrimCast, {InferImplCast, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
|
||||
{prim::kPrimAllReduce, {InferImplAllReduce, true}},
|
||||
{prim::kPrimBroadcast, {InferImplBroadcast, true}},
|
||||
{prim::kPrimAllGather, {InferImplAllGather, true}},
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
|
||||
{prim::kPrimCast, {InferImplCast, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
|
||||
{prim::kPrimSparseSoftmaxCrossEntropyWithLogits, {InferImplSparseSoftmaxCrossEntropyWithLogits, true}},
|
||||
{prim::kPrimDType, {InferImplDType, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
|
||||
{prim::kPrimAddN, {InferImplAddN, true}},
|
||||
|
||||
{prim::kPrimLess, {InferImplLess, true}},
|
||||
{prim::kPrimStack, {InferImplStack, true}},
|
||||
{prim::kPrimPad, {InferImplPad, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
{prim::kPrimShape, {InferImplShape, false}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, true}},
|
||||
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
|
||||
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
return prim_backend_eval_implement_map;
|
||||
}
|
||||
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
||||
if (iter == GetPrimitiveToEvalImplMap().end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return iter->second.impl_;
|
||||
}
|
||||
|
||||
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "ir/primitive.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
@ -37,6 +38,10 @@ using PrimitiveEvalImplMap =
|
|||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
|
||||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
|
||||
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
|
||||
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
|
||||
|
||||
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
|
||||
|
|
|
@ -104,6 +104,5 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolFusion, prim::kPrimMaxPool, MaxPoolFusionInfer);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,8 +31,6 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
auto element = tensor_type->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -58,8 +58,6 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
|
||||
return std::make_shared<abstract::AbstractTensor>(intype, inshape);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,8 +31,6 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
auto element = tensor_type->element();
|
||||
return std::make_shared<abstract::AbstractTensor>(element, x1_shape);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -102,7 +102,6 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameLRN, LRN);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "backend/optimizer/common/const_input_to_attr_registry.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "common/common_test.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr auto kAttrConvertTestName = "attr_convert_test";
|
||||
constexpr auto kDynamicInputTestName = "dynamic_input_test";
|
||||
inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName);
|
||||
inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test");
|
||||
AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
EXPECT_EQ(args_spec_list.size(), 3);
|
||||
EXPECT_NE(args_spec_list[1], nullptr);
|
||||
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest);
|
||||
AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
EXPECT_EQ(args_spec_list.size(), 3);
|
||||
EXPECT_NE(args_spec_list[1], nullptr);
|
||||
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
|
||||
auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest);
|
||||
class TestAttrAndDynamicBackendInfer : public UT::Common {
|
||||
public:
|
||||
TestAttrAndDynamicBackendInfer() {}
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestAttrAndDynamicBackendInfer, test_attr_and_dynamic_input_infer) {
|
||||
// Register Attr for ut
|
||||
ConstInputToAttrInfoRegistry ® = ConstInputToAttrInfoRegistry::Instance();
|
||||
reg.Register(kAttrConvertTestName, {1});
|
||||
// construct primitive
|
||||
PrimitivePtr prim_attr_test = std::make_shared<Primitive>(kAttrConvertTestName);
|
||||
PrimitivePtr prim_dynamic_input_test = std::make_shared<Primitive>(kDynamicInputTestName);
|
||||
// set primtive attr
|
||||
auto input_names = std::vector<std::string>{"a", "b", "c"};
|
||||
auto attr_name = "b";
|
||||
auto attr = MakeValue(std::vector<int>{1, 2, 3});
|
||||
prim_attr_test->AddAttr(kAttrInputNames, MakeValue(input_names));
|
||||
prim_attr_test->AddAttr(attr_name, attr);
|
||||
// set dynameic input list for primtive
|
||||
std::vector<int64_t> dynamic_input_list = {-1, 2, -1};
|
||||
prim_dynamic_input_test->AddAttr(kAttrDynInputSizes, MakeValue(dynamic_input_list));
|
||||
// construct Abstract list
|
||||
auto abs_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto abs_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto attr_infer_result = CppInferShape(prim_attr_test, {abs_a, abs_c});
|
||||
auto abs_dynamic_a = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto abs_dynamic_b = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto abs_dynamic_c = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto abs_dynamic_d = std::make_shared<abstract::AbstractTensor>(kFloat32, std::vector<int64_t>{2, 2, 2, 2});
|
||||
auto dynamic_infer_result =
|
||||
CppInferShape(prim_dynamic_input_test, {abs_dynamic_a, abs_dynamic_b, abs_dynamic_c, abs_dynamic_d});
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue