forked from mindspore-Ecosystem/mindspore
!46931 remove cache for dtype in abstract
Merge pull request !46931 from chengbin/infer_shape_only
commit 6ef1661f88
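The diff below removes the DynamicShapeDtypeManager dtype cache: the DynamicShapeDtypeRecord pass and its header are deleted, and dynamic-shape inference now reads each input's dtype directly, treating ValueNode inputs specially because their KernelInfo dtype is kTypeUnknown. As a minimal, self-contained sketch of the resulting host-dtype selection (toy types and names for illustration only, not the MindSpore API; the real logic is in the GetDependValueTensor hunk further down):

#include <iostream>

// Toy stand-ins for MindSpore's TypeId and node kinds; only the selection logic
// mirrors the new GetDependValueTensor behaviour shown in the hunks below.
enum class TypeId { kFloat32, kFloat16, kTypeUnknown };

struct Node {
  bool is_value_node;    // constant embedded in the graph
  TypeId inferred_type;  // dtype from front-end type inference (the abstract)
  TypeId device_type;    // dtype chosen for the selected device kernel
};

// After this commit there is no per-node cache lookup: ValueNodes use the inferred
// dtype (their kernel build info reports kTypeUnknown); all other nodes use the
// device dtype recorded in their kernel build info.
TypeId SelectHostType(const Node &input) {
  return input.is_value_node ? input.inferred_type : input.device_type;
}

int main() {
  Node value_node{true, TypeId::kFloat32, TypeId::kTypeUnknown};
  Node kernel_node{false, TypeId::kFloat32, TypeId::kFloat16};
  std::cout << (SelectHostType(value_node) == TypeId::kFloat32) << '\n';   // prints 1
  std::cout << (SelectHostType(kernel_node) == TypeId::kFloat16) << '\n';  // prints 1
  return 0;
}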
@@ -38,7 +38,6 @@
 #include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
 #include "backend/common/pass/convert_unused_tuple_para_to_make_tuple.h"
 #include "backend/common/pass/convert_dynamic_broadcast_to.h"
-#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "utils/ms_context.h"
 #include "include/common/debug/anf_ir_dump.h"
 
@@ -207,7 +206,6 @@ void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel
   auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertCustomOp>());
   dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
-  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::DynamicShapeDtypeRecord>());
   optimizer->AddPassManager(dynamic_shape_convert_pm);
   (void)optimizer->Optimize(kernel_graph);
 #ifdef ENABLE_DUMP_IR
@@ -1,98 +0,0 @@
-/**
- * Copyright 2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
-#include "include/common/utils/anfalgo.h"
-#include "backend/common/optimizer/helper.h"
-#include "utils/anf_utils.h"
-
-namespace mindspore {
-namespace opt::dynamic_shape {
-DynamicShapeDtypeManager &DynamicShapeDtypeManager::GetInstance() {
-  static DynamicShapeDtypeManager instance{};
-  return instance;
-}
-
-void DynamicShapeDtypeManager::Register(const AnfNodePtr &node, const TypePtrList &device_abs) {
-  if (device_type_recorder_.find(node) == device_type_recorder_.end()) {
-    (void)device_type_recorder_.emplace(node, device_abs);
-  }
-}
-
-bool DynamicShapeDtypeManager::CheckDeviceType(const AnfNodePtr &node) const {
-  return (device_type_recorder_.find(node) != device_type_recorder_.end());
-}
-
-TypePtrList DynamicShapeDtypeManager::GetDeviceType(const AnfNodePtr &node) {
-  auto iter = device_type_recorder_.find(node);
-  if (iter != device_type_recorder_.end()) {
-    return iter->second;
-  }
-  return {};
-}
-
-bool DynamicShapeDtypeRecord::Run(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto nodes = TopoSort(func_graph->get_return());
-  for (const auto &node : nodes) {
-    CNodePtr cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr || !AnfUtils::IsRealKernel(cnode)) {
-      continue;
-    }
-
-    auto kernel_info = node->kernel_info();
-    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
-      continue;
-    }
-
-    auto out_num = AnfAlgo::GetOutputTensorNum(node);
-    auto node_abs = node->abstract();
-    if (node_abs->isa<abstract::AbstractTensor>()) {
-      if (out_num != 1) {
-        continue;
-      }
-      auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
-      auto device_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
-      if (infer_type != device_type) {
-        TypePtrList new_abstract = {TypeIdToType(device_type)};
-        DynamicShapeDtypeManager::GetInstance().Register(node, new_abstract);
-      }
-    } else if (node_abs->isa<abstract::AbstractTuple>()) {
-      auto abstract_tuple = node_abs->cast<abstract::AbstractTuplePtr>();
-      MS_EXCEPTION_IF_NULL(abstract_tuple);
-      TypePtrList abstract_list;
-      bool find_diff_element = false;
-      for (size_t output_index = 0; output_index < out_num; ++output_index) {
-        auto cur_element = abstract_tuple->elements()[output_index];
-        MS_EXCEPTION_IF_NULL(cur_element);
-        auto infer_type = common::AnfAlgo::GetOutputInferDataType(node, output_index);
-        auto device_type = AnfAlgo::GetOutputDeviceDataType(node, output_index);
-        if (infer_type != device_type) {
-          find_diff_element = true;
-        }
-        abstract_list.push_back(TypeIdToType(device_type));
-      }
-      if (find_diff_element) {
-        DynamicShapeDtypeManager::GetInstance().Register(node, abstract_list);
-      }
-    }
-  }
-
-  return true;
-}
-} // namespace opt::dynamic_shape
-} // namespace mindspore
@@ -1,49 +0,0 @@
-/**
- * Copyright 2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
-#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
-
-#include <map>
-#include <vector>
-#include "ir/anf.h"
-#include "backend/common/optimizer/optimizer.h"
-
-namespace mindspore::opt::dynamic_shape {
-class DynamicShapeDtypeManager {
- public:
-  static DynamicShapeDtypeManager &GetInstance();
-  void Register(const AnfNodePtr &node, const std::vector<TypePtr> &device_abs);
-  bool CheckDeviceType(const AnfNodePtr &node) const;
-  TypePtrList GetDeviceType(const AnfNodePtr &node);
-
- private:
-  DynamicShapeDtypeManager() = default;
-  ~DynamicShapeDtypeManager() = default;
-  DISABLE_COPY_AND_ASSIGN(DynamicShapeDtypeManager);
-
-  std::map<AnfNodePtr, TypePtrList> device_type_recorder_;
-};
-
-// If the data type of abstract is not same with the one of device, it will replace with device data type.
-class DynamicShapeDtypeRecord : public Pass {
- public:
-  DynamicShapeDtypeRecord() : Pass("dynamic_shape_dtype_record") {}
-  ~DynamicShapeDtypeRecord() override = default;
-  bool Run(const FuncGraphPtr &func_graph) override;
-};
-} // namespace mindspore::opt::dynamic_shape
-#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DTYPE_RECORD_H_
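For reference, the two deleted files above implemented a process-wide singleton that cached, per node, the device dtypes that differed from the inferred dtypes; the DynamicShapeDtypeRecord pass filled it and InferShape later consulted it. A simplified, runnable model of that cache is sketched below, using plain std::string keys and dtype names instead of the real AnfNodePtr/TypePtrList types, which are substitutions made purely for illustration:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy model (not the deleted MindSpore classes) of the removed dtype cache.
class DtypeCache {
 public:
  static DtypeCache &GetInstance() {
    static DtypeCache instance;
    return instance;
  }
  // Record the device dtypes for a node; like the original Register, the first
  // registration wins because emplace does not overwrite an existing key.
  void Register(const std::string &node, const std::vector<std::string> &device_types) {
    (void)recorder_.emplace(node, device_types);
  }
  bool Has(const std::string &node) const { return recorder_.count(node) != 0; }
  std::vector<std::string> Get(const std::string &node) const {
    auto it = recorder_.find(node);
    return it == recorder_.end() ? std::vector<std::string>{} : it->second;
  }

 private:
  DtypeCache() = default;
  std::map<std::string, std::vector<std::string>> recorder_;
};

int main() {
  // Pass side: record a node whose device dtype (float16) differs from the inferred one.
  DtypeCache::GetInstance().Register("MatMul-op1", {"float16"});
  // InferShape side: consult the cache before building the input abstract.
  if (DtypeCache::GetInstance().Has("MatMul-op1")) {
    std::cout << DtypeCache::GetInstance().Get("MatMul-op1")[0] << "\n";  // float16
  }
  return 0;
}

After this commit both sides of that exchange are gone; InferShape works from the node's own abstract, as the remaining hunks show.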
@@ -24,7 +24,6 @@
 #include <map>
 #include <utility>
 #include "backend/common/session/anf_runtime_algorithm.h"
-#include "backend/common/optimizer/dynamic_shape/dynamic_shape_dtype_record.h"
 #include "runtime/device/ms_device_shape_transfer.h"
 #include "include/common/utils/anfalgo.h"
 #include "include/common/utils/utils.h"
@@ -98,18 +97,17 @@ bool InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
 
 tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
                                        const std::pair<AnfNodePtr, size_t> &input_node_with_index, bool skip_nop_node,
-                                       void *args, bool abstract_in_cache) {
+                                       void *args) {
   auto real_input = input_node_with_index.first;
   MS_EXCEPTION_IF_NULL(real_input);
   auto real_input_index = input_node_with_index.second;
   auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
   TypeId host_type;
-  if (abstract_in_cache) {
-    // for cnode in the cache, we use device type as there is a mismatch
-    host_type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
-  } else {
-    // for cnode not in the cache, valuenodes and other nodes, we use inferred type
+  if (real_input->isa<ValueNode>()) {
+    // the type of ValueNode in KernelInfo is kTypeUnknown
     host_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
+  } else {
+    host_type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
   }
   auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
 
@@ -167,40 +165,14 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
     auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
     auto real_input = input_node_with_index.first;
     auto real_input_index = input_node_with_index.second;
 
-    bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
-    AbstractBasePtr cached_abstract;
     AbstractBasePtr real_input_abs = real_input->abstract();
-
-    if (abstract_in_cache) {
-      auto cached_type_list = DynamicShapeDtypeManager::GetInstance().GetDeviceType(real_input);
-      if (real_input_abs->isa<abstract::AbstractTensor>()) {
-        auto shape_ptr = real_input_abs->BuildShape();
-        cached_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[0], shape_ptr);
-      } else if (real_input_abs->isa<abstract::AbstractTuple>()) {
-        auto abstract_tuple = real_input_abs->cast<abstract::AbstractTuplePtr>();
-        MS_EXCEPTION_IF_NULL(abstract_tuple);
-        AbstractBasePtrList abstract_list;
-
-        for (size_t output_index = 0; output_index < cached_type_list.size(); ++output_index) {
-          auto cur_element = abstract_tuple->elements()[output_index];
-          MS_EXCEPTION_IF_NULL(cur_element);
-          auto shape_ptr = cur_element->BuildShape();
-          auto new_abstract = std::make_shared<abstract::AbstractTensor>(cached_type_list[output_index], shape_ptr);
-          abstract_list.push_back(new_abstract);
-        }
-        cached_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
-      } else {
-        MS_LOG(EXCEPTION) << "Output of " << real_input->fullname_with_scope()
-                          << " is neither a Tensor nor a Tuple of Tensor, but " << real_input_abs->ToString();
-      }
-    }
     MS_EXCEPTION_IF_NULL(real_input);
     if (skip_nop_node) {
       InferShapeForNopNode(real_input);
     }
     if (depend_list.find(i) != depend_list.end()) {
-      auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args, abstract_in_cache);
+      auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args);
       auto ret2 = depend_tensor_map->try_emplace(i, out_tensor);
       if (!ret2.second) {
         MS_LOG(EXCEPTION) << "Insert map failed.";
@@ -208,12 +180,7 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
 
       // cppcheck-suppress unreadVariable
       auto lock = AnfUtils::GetAbstractLock(real_input.get());
-      AbstractBasePtr real_abs;
-      if (abstract_in_cache) {
-        real_abs = cached_abstract;
-      } else {
-        real_abs = real_input->abstract();
-      }
+      auto real_abs = real_input->abstract();
       if (real_abs->isa<abstract::AbstractTensor>()) {
         real_abs->set_value(out_tensor);
       } else if (real_abs->isa<abstract::AbstractTuple>()) {
@@ -224,19 +191,8 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
         tuple_elements->set_value(out_tensor);
       }
     }
-    if (abstract_in_cache) {
-      if (cached_abstract->isa<abstract::AbstractTuple>()) {
-        auto abs_tuple = cached_abstract->Clone()->cast<abstract::AbstractTuplePtr>();
-        MS_EXCEPTION_IF_NULL(abs_tuple);
-        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_tuple->elements().size()), "Index is out of range.");
-        auto abs_index = abs_tuple->elements()[real_input_index];
-        (void)args_spec_list.emplace_back(abs_index);
-      } else {
-        (void)args_spec_list.emplace_back(cached_abstract->Clone());
-      }
-    } else {
-      common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
-    }
+    common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
   }
 
   // Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old
@@ -796,58 +796,6 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
   return CreateCNodeWithGraph(input_nodes, graph);
 }
 
-// rectify absttract if the input has been converted to the attr
-AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
-                                               const AbstractBasePtrList &input_abstract) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  opt::ConstInputToAttrInfoRegister reg;
-  if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
-    return input_abstract;
-  }
-  if (common::AnfAlgo::HasDynamicShapeFlag(primitive)) {
-    return input_abstract;
-  }
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (device == kGPUDevice) {
-    if (IsOneOfDynamicShapeConstInputToAttrGPU(primitive->name())) {
-      return input_abstract;
-    }
-  }
-  auto convert_input_list = reg.GetConstInputAttrInfo();
-  auto input_names = primitive->GetAttr(kAttrInputNames);
-  if (input_names == nullptr) {
-    return input_abstract;
-  }
-  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
-  AbstractBasePtrList rectify_abs_list;
-  size_t ori_index = 0;
-  rectify_abs_list.resize(input_names_vec.size());
-  for (size_t index = 0; index < rectify_abs_list.size(); ++index) {
-    // if convert input list find the index it means the input has been converted to the attr
-    if (convert_input_list.find(index) != convert_input_list.end()) {
-      AbstractBasePtr rectify_abs = nullptr;
-      auto input_name = input_names_vec[index];
-      auto attr = primitive->GetAttr(input_name);
-      if (attr != nullptr) {
-        rectify_abs = attr->ToAbstract();
-      } else {
-        MS_LOG(DEBUG) << "the node prim name :" << primitive->name() << "input index :" << index
-                      << " input name :" << input_name << "has not been converted to the attr";
-        rectify_abs = input_abstract[ori_index++];
-      }
-      rectify_abs_list[index] = rectify_abs;
-      continue;
-    }
-    if (ori_index > input_abstract.size()) {
-      MS_LOG(EXCEPTION) << "Index " << ori_index << " is out of range in input abstract size " << input_abstract.size();
-    }
-    rectify_abs_list[index] = input_abstract[ori_index++];
-  }
-  return rectify_abs_list;
-}
-
 AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
                                                     const AbstractBasePtrList &input_abstract) {
   MS_EXCEPTION_IF_NULL(prim);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
|
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &primitive, const AbstractBasePtrList &input_abstract) {
|
||||||
auto rectify_abs_list = RectifyAbstractFromRegAttr(primitive, input_abstract);
|
return RectifyAbstractFromDynamicInput(primitive, input_abstract);
|
||||||
return RectifyAbstractFromDynamicInput(primitive, rectify_abs_list);
|
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@@ -45,9 +45,10 @@ inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAt
 inline const PrimitivePtr kPrimDynamicInputTest = std::make_shared<Primitive>("dynamic_input_test");
 AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
-  EXPECT_EQ(args_spec_list.size(), 3);
+  // CppInferShapeAndType does not convert attr to input
+  EXPECT_EQ(args_spec_list.size(), 2);
   EXPECT_NE(args_spec_list[1], nullptr);
-  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
+  EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTensor>(), true);
   return args_spec_list[0];
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true);