forked from mindspore-Ecosystem/mindspore
!46904 enable ascend dynamic shape
Merge pull request !46904 from caifubi/master-pynative-dynamic-shape-enable
This commit is contained in:
commit
a3770cf22d
|
@ -196,8 +196,7 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
|
|||
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
|
||||
// Used for async run
|
||||
op_run_info->grad_flag = grad()->grad_flag();
|
||||
op_run_info->base_op_run_info.use_dynamic_shape_process =
|
||||
!(device_target() == kAscendDevice) && grad()->use_dynamic_shape_process();
|
||||
op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
|
||||
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
||||
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
||||
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
||||
|
|
|
@ -1035,8 +1035,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
|
|||
need_renormalize_ = false;
|
||||
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
||||
bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
|
||||
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
|
||||
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process_);
|
||||
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
||||
return bprop_graph;
|
||||
}
|
||||
|
|
|
@ -71,6 +71,9 @@ static const std::unordered_set<std::string> kAclKernelSet = {kConv2DOpName,
|
|||
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
|
||||
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};
|
||||
|
||||
mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName, kMatMulOpName, kBatchMatMulOpName,
|
||||
kConv2DBackpropInputOpName, kConv2DBackpropFilterOpName};
|
||||
|
||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check input data type
|
||||
|
@ -789,6 +792,27 @@ void SetRaiseOrReduceFlag(const CNodePtr &kernel_node, KernelSelectStatus status
|
|||
}
|
||||
}
|
||||
|
||||
void UpdateInputForHighPrecisionOp(const CNodePtr &kernel_node,
|
||||
const std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
||||
auto input_dtypes = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||
auto output_dtypes = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||
auto has_fp32 = std::any_of(output_dtypes.begin(), output_dtypes.end(),
|
||||
[](TypeId type) { return type == TypeId::kNumberTypeFloat32; });
|
||||
if (has_fp32) {
|
||||
std::vector<TypeId> new_input_types;
|
||||
for (auto type : input_dtypes) {
|
||||
if (type == TypeId::kNumberTypeFloat16) {
|
||||
new_input_types.push_back(TypeId::kNumberTypeFloat32);
|
||||
} else {
|
||||
new_input_types.push_back(type);
|
||||
}
|
||||
}
|
||||
builder->SetInputsDeviceType(new_input_types);
|
||||
MS_LOG(INFO) << "Update data type for " << kernel_node->fullname_with_scope() << " from " << input_dtypes << " to "
|
||||
<< new_input_types;
|
||||
}
|
||||
}
|
||||
|
||||
void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, kernel_node)) {
|
||||
|
@ -821,6 +845,13 @@ void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
|||
MS_EXCEPTION_IF_NULL(new_builder);
|
||||
new_builder->SetKernelType(ACL_KERNEL);
|
||||
MS_LOG(INFO) << "SUCCESS SET ACL KERNEL FOR" << kernel_node->DebugString();
|
||||
|
||||
// For high precision op
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kHighPrecisionOp.count(op_name) != 0) {
|
||||
UpdateInputForHighPrecisionOp(kernel_node, new_builder);
|
||||
}
|
||||
|
||||
AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -326,6 +326,11 @@ bool AscendKernelExecutor::MemoryCopyAsync(const CNodePtr &node, const vector<Ad
|
|||
|
||||
bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
||||
std::vector<AddressPtr> *real_inputs) const {
|
||||
if (AnfAlgo::GetKernelType(kernel) == KernelType::ACL_KERNEL) {
|
||||
*real_inputs = inputs;
|
||||
return true;
|
||||
}
|
||||
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
|
||||
if (input_num != inputs.size()) {
|
||||
MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();
|
||||
|
|
|
@ -122,7 +122,7 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
|||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope();
|
||||
MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope() << " op_type_:" << op_type_;
|
||||
bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
|
||||
op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
|
||||
op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),
|
||||
|
|
|
@ -57,8 +57,7 @@ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name,
|
|||
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kImplyTBE, is_dynamic_shape);
|
||||
// If there is no dynamic shape op, get the static shape op
|
||||
if (op_info != nullptr && !op_info->dynamic_shape_support() && is_dynamic_shape) {
|
||||
MS_LOG(ERROR) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
|
||||
return nullptr;
|
||||
MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
|
||||
}
|
||||
return op_info;
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ const AnfNodePtr BatchNormGradUnifyMindIR::Process(const FuncGraphPtr &func_grap
|
|||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode)) {
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode) || func_graph->has_flag(kAttrMutableKernel)) {
|
||||
return nullptr;
|
||||
}
|
||||
return CreateNewBatchNormGrad(func_graph, cnode);
|
||||
|
|
|
@ -27,12 +27,6 @@ OUTPUT_MAP(Pad) = {{0, OUTPUT_DESC(y)}};
|
|||
REG_ADPT_DESC(Pad, kPadOpName, ADPT_DESC(Pad))
|
||||
REG_ADPT_DESC(PadD, kPadDOpName, ADPT_DESC(Pad))
|
||||
|
||||
// BroadcastToD
|
||||
INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastToD))
|
||||
|
||||
// BroadcastTo
|
||||
INPUT_MAP(BroadcastTo) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
||||
ATTR_INPUT_MAP(BroadcastTo) = {{"shape", "shape"}};
|
||||
|
@ -40,6 +34,7 @@ OUTPUT_MAP(BroadcastTo) = {{0, OUTPUT_DESC(y)}};
|
|||
ATTR_MAP(BroadcastTo) = EMPTY_ATTR_MAP;
|
||||
REG_ADPT_DESC(BroadcastToD, kNameBroadcastToD, ADPT_DESC(BroadcastTo))
|
||||
REG_ADPT_DESC(DynamicBroadcastTo, kDynamicBroadcastToOpName, ADPT_DESC(BroadcastTo))
|
||||
REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastTo))
|
||||
|
||||
// Diag
|
||||
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||
|
|
|
@ -28,9 +28,6 @@ DECLARE_OP_USE_OUTPUT(PadD)
|
|||
DECLARE_OP_ADAPTER(Pad)
|
||||
DECLARE_OP_USE_OUTPUT(Pad)
|
||||
|
||||
DECLARE_OP_ADAPTER(BroadcastToD)
|
||||
DECLARE_OP_USE_OUTPUT(BroadcastToD)
|
||||
|
||||
DECLARE_OP_ADAPTER(BroadcastTo)
|
||||
DECLARE_OP_USE_OUTPUT(BroadcastTo)
|
||||
|
||||
|
|
|
@ -23,14 +23,14 @@ INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int64_t>())},
|
||||
{"output_num", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
||||
DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(SplitD))
|
||||
REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(SplitD))
|
||||
|
||||
// Split
|
||||
INPUT_MAP(Split) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(split_dim)}};
|
||||
ATTR_INPUT_MAP(Split) = {{"axis", "split_dim"}};
|
||||
ATTR_MAP(Split) = {{"num_split", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
||||
DYN_OUTPUT_MAP(Split) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(Split))
|
||||
REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(Split))
|
||||
|
||||
// Pack
|
||||
INPUT_MAP(Pack) = EMPTY_INPUT_MAP;
|
||||
|
@ -66,7 +66,8 @@ DYN_INPUT_MAP(Concat) = {{1, DYN_INPUT_DESC(x)}};
|
|||
ATTR_INPUT_MAP(Concat) = {{"axis", "concat_dim"}};
|
||||
ATTR_MAP(Concat) = {{"inputNums", ATTR_DESC(N, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(Concat) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(Concat))
|
||||
// Roll back to ConcatD because support for the dynamic-input scenario is incomplete.
|
||||
REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(ConcatD))
|
||||
|
||||
// ConcatV2 Inference for tf
|
||||
DYN_INPUT_MAP(ConcatV2) = {{1, DYN_INPUT_DESC(x)}};
|
||||
|
|
|
@ -353,7 +353,7 @@ def test_pynative_forward_hook():
|
|||
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue