forked from mindspore-Ecosystem/mindspore
add dynamic rnn grad pass for pynative
This commit is contained in:
parent
9d51fbfe26
commit
cae4abb5e0
|
@ -195,6 +195,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
|
|||
auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
|
||||
data_layout_pm->AddPass(std::make_shared<ChangeAxisOfReduceKernel>());
|
||||
data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
|
||||
data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
|
||||
data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
|
||||
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
|
@ -338,7 +339,9 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DynamicGRUV2GradFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
|
||||
optimizer->AddPassManager(ir_fusion_pm);
|
||||
|
|
|
@ -529,13 +529,6 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
|
|||
}
|
||||
|
||||
bool TransFormat(const FormatArgs &args, void *result) {
|
||||
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
||||
const std::map<std::string, FormatTransfer> format_trans_map{
|
||||
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
||||
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
||||
|
||||
MS_LOG(DEBUG) << "Start trans format.";
|
||||
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
||||
MS_LOG(ERROR) << "Invalid datatype..";
|
||||
|
@ -544,15 +537,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|||
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
|
||||
return NchwTo4D(args, result);
|
||||
}
|
||||
auto iter = format_trans_map.find(args.device_format);
|
||||
if (iter == format_trans_map.end()) {
|
||||
auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
|
||||
if (iter == kTransFormatMapOfHostToDevice.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
|
||||
}
|
||||
return iter->second(args, result);
|
||||
}
|
||||
|
||||
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
||||
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
||||
const std::map<std::string, FormatTransfer> format_trans_map{
|
||||
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
||||
|
|
|
@ -76,6 +76,13 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
|
|||
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
|
||||
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
|
||||
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
|
||||
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
||||
const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
|
||||
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
||||
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
||||
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "common/trans.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -382,14 +383,15 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|||
continue;
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
|
||||
std::vector<std::string> output_format = {refresh_format};
|
||||
// if not find in host convert format map means the host has not registered the convert function of this format
|
||||
if (trans::kTransFormatMapOfHostToDevice.find(refresh_format) == trans::kTransFormatMapOfHostToDevice.end() &&
|
||||
refresh_format != kOpFormat_DEFAULT) {
|
||||
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
}
|
||||
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
||||
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
|
||||
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
||||
}
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
|
@ -397,11 +399,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D ||
|
||||
selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
|
||||
output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
||||
}
|
||||
builder->SetOutputsFormat(output_format);
|
||||
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
||||
builder->SetOutputsDeviceType(output_type);
|
||||
|
|
Loading…
Reference in New Issue