!46904 enable ascend dynamic shape

Merge pull request !46904 from caifubi/master-pynative-dynamic-shape-enable
i-robot 2022-12-21 01:59:29 +00:00 committed by Gitee
commit a3770cf22d
11 changed files with 47 additions and 21 deletions
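
The first two hunks below drop the `device_target() == kAscendDevice` exclusion, so PyNative dynamic-shape processing is no longer disabled on Ascend. As a rough, purely illustrative sketch (not taken from this commit), the user-level scenario this enables is a PyNative cell fed inputs of varying shapes on an Ascend device:

# Illustrative only: assumes MindSpore is installed and an Ascend device is available.
import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)

net = Net()
# Run the same cell with inputs of different shapes; each call executes
# eagerly in PyNative mode, exercising the dynamic-shape path.
for shape in [(2, 3), (4, 5), (8, 1)]:
    x = Tensor(np.random.randn(*shape).astype(np.float32))
    print(net(x).shape)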

View File

@@ -196,8 +196,7 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
   const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
   // Used for async run
   op_run_info->grad_flag = grad()->grad_flag();
-  op_run_info->base_op_run_info.use_dynamic_shape_process =
-    !(device_target() == kAscendDevice) && grad()->use_dynamic_shape_process();
+  op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
   op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
   op_run_info->base_op_run_info.lazy_build = lazy_build_;
   PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);

View File

@@ -1035,8 +1035,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
   need_renormalize_ = false;
   bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
   bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
-  bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
-  bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
+  bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process_);
   bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
   return bprop_graph;
 }

View File

@@ -71,6 +71,9 @@ static const std::unordered_set<std::string> kAclKernelSet = {kConv2DOpName,
 const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
   {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};
 
+mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName, kMatMulOpName, kBatchMatMulOpName,
+                                                    kConv2DBackpropInputOpName, kConv2DBackpropFilterOpName};
+
 bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   // Check input data type
@@ -789,6 +792,27 @@ void SetRaiseOrReduceFlag(const CNodePtr &kernel_node, KernelSelectStatus status
   }
 }
 
+void UpdateInputForHighPrecisionOp(const CNodePtr &kernel_node,
+                                   const std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
+  auto input_dtypes = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
+  auto output_dtypes = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
+  auto has_fp32 = std::any_of(output_dtypes.begin(), output_dtypes.end(),
+                              [](TypeId type) { return type == TypeId::kNumberTypeFloat32; });
+  if (has_fp32) {
+    std::vector<TypeId> new_input_types;
+    for (auto type : input_dtypes) {
+      if (type == TypeId::kNumberTypeFloat16) {
+        new_input_types.push_back(TypeId::kNumberTypeFloat32);
+      } else {
+        new_input_types.push_back(type);
+      }
+    }
+    builder->SetInputsDeviceType(new_input_types);
+    MS_LOG(INFO) << "Update data type for " << kernel_node->fullname_with_scope() << " from " << input_dtypes << " to "
+                 << new_input_types;
+  }
+}
+
 void SetAclKernelInfo(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   if (!common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, kernel_node)) {
@@ -821,6 +845,13 @@ void SetAclKernelInfo(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(new_builder);
   new_builder->SetKernelType(ACL_KERNEL);
   MS_LOG(INFO) << "SUCCESS SET ACL KERNEL FOR" << kernel_node->DebugString();
+
+  // For high precision op
+  auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
+  if (kHighPrecisionOp.count(op_name) != 0) {
+    UpdateInputForHighPrecisionOp(kernel_node, new_builder);
+  }
   AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
 }

View File

@@ -326,6 +326,11 @@ bool AscendKernelExecutor::MemoryCopyAsync(const CNodePtr &node, const vector<Ad
 bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
                                                std::vector<AddressPtr> *real_inputs) const {
+  if (AnfAlgo::GetKernelType(kernel) == KernelType::ACL_KERNEL) {
+    *real_inputs = inputs;
+    return true;
+  }
+
   auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
   if (input_num != inputs.size()) {
     MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();

View File

@@ -122,7 +122,7 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
     return false;
   }
 
-  MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope();
+  MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope() << " op_type_:" << op_type_;
   bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
                                     op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
                                     op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),

View File

@@ -57,8 +57,7 @@ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name,
   auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kImplyTBE, is_dynamic_shape);
   // If have no dynamic shape op, get static shape op
   if (op_info != nullptr && !op_info->dynamic_shape_support() && is_dynamic_shape) {
-    MS_LOG(ERROR) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
-    return nullptr;
+    MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
   }
   return op_info;
 }

View File

@@ -70,7 +70,7 @@ const AnfNodePtr BatchNormGradUnifyMindIR::Process(const FuncGraphPtr &func_grap
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode)) {
+  if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode) || func_graph->has_flag(kAttrMutableKernel)) {
     return nullptr;
   }
   return CreateNewBatchNormGrad(func_graph, cnode);

View File

@@ -27,12 +27,6 @@ OUTPUT_MAP(Pad) = {{0, OUTPUT_DESC(y)}};
 REG_ADPT_DESC(Pad, kPadOpName, ADPT_DESC(Pad))
 REG_ADPT_DESC(PadD, kPadDOpName, ADPT_DESC(Pad))
 
-// BroadcastToD
-INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastToD))
-
 // BroadcastTo
 INPUT_MAP(BroadcastTo) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
 ATTR_INPUT_MAP(BroadcastTo) = {{"shape", "shape"}};
@@ -40,6 +34,7 @@ OUTPUT_MAP(BroadcastTo) = {{0, OUTPUT_DESC(y)}};
 ATTR_MAP(BroadcastTo) = EMPTY_ATTR_MAP;
 REG_ADPT_DESC(BroadcastToD, kNameBroadcastToD, ADPT_DESC(BroadcastTo))
 REG_ADPT_DESC(DynamicBroadcastTo, kDynamicBroadcastToOpName, ADPT_DESC(BroadcastTo))
+REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastTo))
 
 // Diag
 INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};

View File

@@ -28,9 +28,6 @@ DECLARE_OP_USE_OUTPUT(PadD)
 DECLARE_OP_ADAPTER(Pad)
 DECLARE_OP_USE_OUTPUT(Pad)
 
-DECLARE_OP_ADAPTER(BroadcastToD)
-DECLARE_OP_USE_OUTPUT(BroadcastToD)
-
 DECLARE_OP_ADAPTER(BroadcastTo)
 DECLARE_OP_USE_OUTPUT(BroadcastTo)

View File

@@ -23,14 +23,14 @@ INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int64_t>())},
                     {"output_num", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
 DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
-REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(SplitD))
+REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(SplitD))
 
 // Split
 INPUT_MAP(Split) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(split_dim)}};
 ATTR_INPUT_MAP(Split) = {{"axis", "split_dim"}};
 ATTR_MAP(Split) = {{"num_split", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
 DYN_OUTPUT_MAP(Split) = {{0, DYN_OUTPUT_DESC(y)}};
-REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(Split))
+REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(Split))
 
 // Pack
 INPUT_MAP(Pack) = EMPTY_INPUT_MAP;
@@ -66,7 +66,8 @@ DYN_INPUT_MAP(Concat) = {{1, DYN_INPUT_DESC(x)}};
 ATTR_INPUT_MAP(Concat) = {{"axis", "concat_dim"}};
 ATTR_MAP(Concat) = {{"inputNums", ATTR_DESC(N, AnyTraits<int64_t>())}};
 OUTPUT_MAP(Concat) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(Concat))
+// Rollback to ConcatD for the support of dynamic input scene is incomplete.
+REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(ConcatD))
 
 // ConcatV2 Inference for tf
 DYN_INPUT_MAP(ConcatV2) = {{1, DYN_INPUT_DESC(x)}};

View File

@@ -353,7 +353,7 @@ def test_pynative_forward_hook():
     assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
 
-@pytest.mark.level0
+@pytest.mark.level2
 @pytest.mark.platform_x86_cpu
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training