forked from mindspore-Ecosystem/mindspore
!46904 enable ascend dynamic shape
Merge pull request !46904 from caifubi/master-pynative-dynamic-shape-enable
This commit is contained in:
commit
a3770cf22d
|
@ -196,8 +196,7 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
|
||||||
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
|
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
|
||||||
// Used for async run
|
// Used for async run
|
||||||
op_run_info->grad_flag = grad()->grad_flag();
|
op_run_info->grad_flag = grad()->grad_flag();
|
||||||
op_run_info->base_op_run_info.use_dynamic_shape_process =
|
op_run_info->base_op_run_info.use_dynamic_shape_process = grad()->use_dynamic_shape_process();
|
||||||
!(device_target() == kAscendDevice) && grad()->use_dynamic_shape_process();
|
|
||||||
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
||||||
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
||||||
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
||||||
|
|
|
@ -1035,8 +1035,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
|
||||||
need_renormalize_ = false;
|
need_renormalize_ = false;
|
||||||
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
|
||||||
bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
|
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process_);
|
||||||
bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
|
|
||||||
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
|
||||||
return bprop_graph;
|
return bprop_graph;
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,6 +71,9 @@ static const std::unordered_set<std::string> kAclKernelSet = {kConv2DOpName,
|
||||||
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
|
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
|
||||||
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};
|
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};
|
||||||
|
|
||||||
|
mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName, kMatMulOpName, kBatchMatMulOpName,
|
||||||
|
kConv2DBackpropInputOpName, kConv2DBackpropFilterOpName};
|
||||||
|
|
||||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
// Check input data type
|
// Check input data type
|
||||||
|
@ -789,6 +792,27 @@ void SetRaiseOrReduceFlag(const CNodePtr &kernel_node, KernelSelectStatus status
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UpdateInputForHighPrecisionOp(const CNodePtr &kernel_node,
|
||||||
|
const std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
|
||||||
|
auto input_dtypes = AnfAlgo::GetAllInputDeviceTypes(kernel_node);
|
||||||
|
auto output_dtypes = AnfAlgo::GetAllOutputDeviceTypes(kernel_node);
|
||||||
|
auto has_fp32 = std::any_of(output_dtypes.begin(), output_dtypes.end(),
|
||||||
|
[](TypeId type) { return type == TypeId::kNumberTypeFloat32; });
|
||||||
|
if (has_fp32) {
|
||||||
|
std::vector<TypeId> new_input_types;
|
||||||
|
for (auto type : input_dtypes) {
|
||||||
|
if (type == TypeId::kNumberTypeFloat16) {
|
||||||
|
new_input_types.push_back(TypeId::kNumberTypeFloat32);
|
||||||
|
} else {
|
||||||
|
new_input_types.push_back(type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder->SetInputsDeviceType(new_input_types);
|
||||||
|
MS_LOG(INFO) << "Update data type for " << kernel_node->fullname_with_scope() << " from " << input_dtypes << " to "
|
||||||
|
<< new_input_types;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
if (!common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, kernel_node)) {
|
if (!common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, kernel_node)) {
|
||||||
|
@ -821,6 +845,13 @@ void SetAclKernelInfo(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(new_builder);
|
MS_EXCEPTION_IF_NULL(new_builder);
|
||||||
new_builder->SetKernelType(ACL_KERNEL);
|
new_builder->SetKernelType(ACL_KERNEL);
|
||||||
MS_LOG(INFO) << "SUCCESS SET ACL KERNEL FOR" << kernel_node->DebugString();
|
MS_LOG(INFO) << "SUCCESS SET ACL KERNEL FOR" << kernel_node->DebugString();
|
||||||
|
|
||||||
|
// For high precision op
|
||||||
|
auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||||
|
if (kHighPrecisionOp.count(op_name) != 0) {
|
||||||
|
UpdateInputForHighPrecisionOp(kernel_node, new_builder);
|
||||||
|
}
|
||||||
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(new_builder->Build(), kernel_node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -326,6 +326,11 @@ bool AscendKernelExecutor::MemoryCopyAsync(const CNodePtr &node, const vector<Ad
|
||||||
|
|
||||||
bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vector<AddressPtr> &inputs,
|
||||||
std::vector<AddressPtr> *real_inputs) const {
|
std::vector<AddressPtr> *real_inputs) const {
|
||||||
|
if (AnfAlgo::GetKernelType(kernel) == KernelType::ACL_KERNEL) {
|
||||||
|
*real_inputs = inputs;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
|
auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
|
||||||
if (input_num != inputs.size()) {
|
if (input_num != inputs.size()) {
|
||||||
MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();
|
MS_LOG(ERROR) << "Input num is " << input_num << " but input address num is " << inputs.size();
|
||||||
|
|
|
@ -122,7 +122,7 @@ bool AclKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope();
|
MS_LOG(INFO) << "Start aclopCompileAndExecute of node: " << node->fullname_with_scope() << " op_type_:" << op_type_;
|
||||||
bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
|
bool ret = aclopCompileAndExecute(const_cast<char *>(op_type_.c_str()), op_desc_ptr->input_tensor_desc().size(),
|
||||||
op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
|
op_desc_ptr->input_tensor_desc().data(), op_desc_ptr->input_tensor_data().data(),
|
||||||
op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),
|
op_desc_ptr->output_tensor_desc().size(), op_desc_ptr->output_tensor_desc().data(),
|
||||||
|
|
|
@ -57,8 +57,7 @@ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name,
|
||||||
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kImplyTBE, is_dynamic_shape);
|
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kImplyTBE, is_dynamic_shape);
|
||||||
// If have no dynamic shape op, get static shape op
|
// If have no dynamic shape op, get static shape op
|
||||||
if (op_info != nullptr && !op_info->dynamic_shape_support() && is_dynamic_shape) {
|
if (op_info != nullptr && !op_info->dynamic_shape_support() && is_dynamic_shape) {
|
||||||
MS_LOG(ERROR) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
|
MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") not support dynamic shape:" << cnode->DebugString();
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
return op_info;
|
return op_info;
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ const AnfNodePtr BatchNormGradUnifyMindIR::Process(const FuncGraphPtr &func_grap
|
||||||
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode)) {
|
if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode) || func_graph->has_flag(kAttrMutableKernel)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return CreateNewBatchNormGrad(func_graph, cnode);
|
return CreateNewBatchNormGrad(func_graph, cnode);
|
||||||
|
|
|
@ -27,12 +27,6 @@ OUTPUT_MAP(Pad) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(Pad, kPadOpName, ADPT_DESC(Pad))
|
REG_ADPT_DESC(Pad, kPadOpName, ADPT_DESC(Pad))
|
||||||
REG_ADPT_DESC(PadD, kPadDOpName, ADPT_DESC(Pad))
|
REG_ADPT_DESC(PadD, kPadDOpName, ADPT_DESC(Pad))
|
||||||
|
|
||||||
// BroadcastToD
|
|
||||||
INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
|
|
||||||
ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
|
|
||||||
OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
|
|
||||||
REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastToD))
|
|
||||||
|
|
||||||
// BroadcastTo
|
// BroadcastTo
|
||||||
INPUT_MAP(BroadcastTo) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
INPUT_MAP(BroadcastTo) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
||||||
ATTR_INPUT_MAP(BroadcastTo) = {{"shape", "shape"}};
|
ATTR_INPUT_MAP(BroadcastTo) = {{"shape", "shape"}};
|
||||||
|
@ -40,6 +34,7 @@ OUTPUT_MAP(BroadcastTo) = {{0, OUTPUT_DESC(y)}};
|
||||||
ATTR_MAP(BroadcastTo) = EMPTY_ATTR_MAP;
|
ATTR_MAP(BroadcastTo) = EMPTY_ATTR_MAP;
|
||||||
REG_ADPT_DESC(BroadcastToD, kNameBroadcastToD, ADPT_DESC(BroadcastTo))
|
REG_ADPT_DESC(BroadcastToD, kNameBroadcastToD, ADPT_DESC(BroadcastTo))
|
||||||
REG_ADPT_DESC(DynamicBroadcastTo, kDynamicBroadcastToOpName, ADPT_DESC(BroadcastTo))
|
REG_ADPT_DESC(DynamicBroadcastTo, kDynamicBroadcastToOpName, ADPT_DESC(BroadcastTo))
|
||||||
|
REG_ADPT_DESC(BroadcastTo, kNameBroadcastTo, ADPT_DESC(BroadcastTo))
|
||||||
|
|
||||||
// Diag
|
// Diag
|
||||||
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||||
|
|
|
@ -28,9 +28,6 @@ DECLARE_OP_USE_OUTPUT(PadD)
|
||||||
DECLARE_OP_ADAPTER(Pad)
|
DECLARE_OP_ADAPTER(Pad)
|
||||||
DECLARE_OP_USE_OUTPUT(Pad)
|
DECLARE_OP_USE_OUTPUT(Pad)
|
||||||
|
|
||||||
DECLARE_OP_ADAPTER(BroadcastToD)
|
|
||||||
DECLARE_OP_USE_OUTPUT(BroadcastToD)
|
|
||||||
|
|
||||||
DECLARE_OP_ADAPTER(BroadcastTo)
|
DECLARE_OP_ADAPTER(BroadcastTo)
|
||||||
DECLARE_OP_USE_OUTPUT(BroadcastTo)
|
DECLARE_OP_USE_OUTPUT(BroadcastTo)
|
||||||
|
|
||||||
|
|
|
@ -23,14 +23,14 @@ INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int64_t>())},
|
ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int64_t>())},
|
||||||
{"output_num", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
{"output_num", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
||||||
DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
|
DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(SplitD))
|
REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(SplitD))
|
||||||
|
|
||||||
// Split
|
// Split
|
||||||
INPUT_MAP(Split) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(split_dim)}};
|
INPUT_MAP(Split) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(split_dim)}};
|
||||||
ATTR_INPUT_MAP(Split) = {{"axis", "split_dim"}};
|
ATTR_INPUT_MAP(Split) = {{"axis", "split_dim"}};
|
||||||
ATTR_MAP(Split) = {{"num_split", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
ATTR_MAP(Split) = {{"num_split", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
|
||||||
DYN_OUTPUT_MAP(Split) = {{0, DYN_OUTPUT_DESC(y)}};
|
DYN_OUTPUT_MAP(Split) = {{0, DYN_OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(SplitD, kSplitDOpName, ADPT_DESC(Split))
|
REG_ADPT_DESC(Split, kNameSplit, ADPT_DESC(Split))
|
||||||
|
|
||||||
// Pack
|
// Pack
|
||||||
INPUT_MAP(Pack) = EMPTY_INPUT_MAP;
|
INPUT_MAP(Pack) = EMPTY_INPUT_MAP;
|
||||||
|
@ -66,7 +66,8 @@ DYN_INPUT_MAP(Concat) = {{1, DYN_INPUT_DESC(x)}};
|
||||||
ATTR_INPUT_MAP(Concat) = {{"axis", "concat_dim"}};
|
ATTR_INPUT_MAP(Concat) = {{"axis", "concat_dim"}};
|
||||||
ATTR_MAP(Concat) = {{"inputNums", ATTR_DESC(N, AnyTraits<int64_t>())}};
|
ATTR_MAP(Concat) = {{"inputNums", ATTR_DESC(N, AnyTraits<int64_t>())}};
|
||||||
OUTPUT_MAP(Concat) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Concat) = {{0, OUTPUT_DESC(y)}};
|
||||||
REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(Concat))
|
// Rollback to ConcatD for the support of dynamic input scene is incomplete.
|
||||||
|
REG_ADPT_DESC(Concat, prim::kPrimConcatD->name(), ADPT_DESC(ConcatD))
|
||||||
|
|
||||||
// ConcatV2 Inference for tf
|
// ConcatV2 Inference for tf
|
||||||
DYN_INPUT_MAP(ConcatV2) = {{1, DYN_INPUT_DESC(x)}};
|
DYN_INPUT_MAP(ConcatV2) = {{1, DYN_INPUT_DESC(x)}};
|
||||||
|
|
|
@ -353,7 +353,7 @@ def test_pynative_forward_hook():
|
||||||
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
assert np.allclose(grad[1][0].asnumpy(), expect_grad[1][0].asnumpy(), 0.000001, 0.000001)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level2
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
@pytest.mark.platform_arm_ascend_training
|
@pytest.mark.platform_arm_ascend_training
|
||||||
@pytest.mark.platform_x86_ascend_training
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
|
Loading…
Reference in New Issue