fix dropout unify_mindir pass

jjfeing 2020-12-15 16:21:23 +08:00 committed by yuchaojie
parent b41d83a7df
commit 389da54525
5 changed files with 302 additions and 146 deletions


@ -15,7 +15,9 @@
*/
#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
#include <ops/all_ops.h>
#include <vector>
#include <string>
#include <memory>
#include <numeric>
#include <algorithm>
@ -23,45 +25,69 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/log_adapter.h"
/*
DropoutGenMask
attr: seed0, seed1;
input: 1. shape <>;
2. keep_prob: type is based on the input x type; if x is float/float16, use that type, else use float16;
output: shape: (count + 127) / 128 * 16
*/
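// A worked example with assumed values (not from this commit): for an input of shape {32, 128},
// count = 32 * 128 = 4096 elements, so the mask output size is
//   (4096 + 127) / 128 * 16 = 32 * 16 = 512
// i.e. a uint8 tensor of shape {512}: one mask bit per element, padded to 128-bit alignment.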
namespace mindspore::opt {
namespace {
constexpr auto kKeepProb = "keep_prob";
constexpr auto kSeed0 = "Seed0";
constexpr auto kSeed1 = "Seed1";
constexpr auto kUint8BitSize = 8;
namespace mindspore::opt {
constexpr int64_t kMaskAlignNum = 128;
constexpr int64_t kMaskMultiNum = 16;
constexpr size_t kFloat16Len = 2; // size of float16
namespace {
AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) {
MS_LOG(INFO) << "GetDropoutNodeInfo start.";
constexpr size_t kInt64Len = 8; // size of int64
TypeId GetInputXDataType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(keep_prob);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) ||
!AnfAlgo::HasNodeAttr(kSeed1, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1.";
auto dropout_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
if (dropout_input_type != kNumberTypeFloat32 && dropout_input_type != kNumberTypeFloat &&
dropout_input_type != kNumberTypeFloat16) {
dropout_input_type = kNumberTypeFloat16;
}
*keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb);
MS_LOG(INFO) << "keep_prob: " << *keep_prob;
// return the dropout input, which may be a tensor or the output of a previous cnode
return cnode->input(1);
MS_LOG(INFO) << "Dropout input data type: " << TypeIdLabel(dropout_input_type);
return dropout_input_type;
}
ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) {
MS_LOG(INFO) << "CreateKeepPorbValueNode start.";
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> shapes;
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
return shapes;
}
ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, TypeId type_id) {
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Step1: get keep_prob
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does not have attr: keep_prob.";
}
if (AnfAlgo::GetCNodePrimitive(cnode)->ToString() == kDropoutOpName) {
if (!AnfAlgo::HasNodeAttr(kSeed0, cnode) || !AnfAlgo::HasNodeAttr(kSeed1, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does not have attr: seed0 or seed1.";
}
}
auto keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb);
MS_LOG(INFO) << "Keep_prob value: " << keep_prob;
std::vector<int64_t> keep_prob_shape = {};
ShapeVector shape = {};
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), keep_prob_shape);
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(type_id, keep_prob_shape);
MS_EXCEPTION_IF_NULL(keep_prob_tensor);
auto data_ptr = keep_prob_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
// keep_prob's data type is the same as the input data's
if (dtype->type_id() == kNumberTypeFloat16) {
float16 half_data = float16(keep_prob);
auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len);
if (type_id == kNumberTypeFloat16) {
auto half_data = float16(keep_prob);
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(keep_prob_tensor->data().nbytes()), &half_data, kFloat16Len);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
}
@ -69,59 +95,65 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float
auto *val = reinterpret_cast<float *>(data_ptr);
*val = keep_prob;
}
auto abstract = std::make_shared<abstract::AbstractTensor>(dtype, shape);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), keep_prob_shape);
auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor);
MS_EXCEPTION_IF_NULL(keep_prob_value);
kernel_graph->AddValueNodeToGraph(keep_prob_value);
return keep_prob_value;
}
std::vector<int64_t> GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) {
MS_LOG(INFO) << "GetInputShape start.";
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(dropout_input);
std::vector<int64_t> shapes;
if (dropout_input->isa<Parameter>()) {
MS_LOG(INFO) << "Dropout input from parameter node.";
// single test case
auto dropout_input_value = dropout_input->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(dropout_input_value);
MS_EXCEPTION_IF_NULL(dropout_input_value->Shape());
auto shape = dropout_input_value->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
return shape->shape();
} else if (dropout_input->isa<CNode>()) {
MS_LOG(INFO) << "Dropout input from cnode.";
auto dropout_input_node = dropout_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_input_node);
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
return shapes;
} else {
MS_LOG(ERROR) << "Dropout input is not parameter or cnode.";
return {};
}
}
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape) {
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
bool is_pynative = false) {
MS_LOG(INFO) << "CreateShapeValueNode start.";
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<ValuePtr> dim_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &dim : shape) {
dim_values.push_back(MakeValue(dim));
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
ValuePtr shape_value = nullptr;
AbstractBasePtr abstract = nullptr;
if (is_pynative) {
// pynative mode needs to create a tensor for the shape input
int64_t shape_dim = SizeToLong(shape.size());
std::vector<int64_t> shape_vec_shape = {shape_dim};
auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
MS_EXCEPTION_IF_NULL(shape_tensor);
auto data_ptr = shape_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
auto elem_num = shape.size() * kInt64Len;
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
}
shape_value = shape_tensor;
abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
} else {
std::vector<ValuePtr> dim_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &dim : shape) {
dim_values.push_back(MakeValue(dim));
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
}
shape_value = std::make_shared<ValueTuple>(dim_values);
abstract = std::make_shared<abstract::AbstractTuple>(abs);
}
auto shape_value_tuple = std::make_shared<ValueTuple>(dim_values);
MS_EXCEPTION_IF_NULL(shape_value_tuple);
auto abstract = std::make_shared<abstract::AbstractTuple>(abs);
MS_EXCEPTION_IF_NULL(abstract);
auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple);
MS_EXCEPTION_IF_NULL(shape_value);
kernel_graph->AddValueNodeToGraph(shape_value);
return shape_value;
MS_EXCEPTION_IF_NULL(abstract);
auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
MS_EXCEPTION_IF_NULL(shape_value_node);
kernel_graph->AddValueNodeToGraph(shape_value_node);
return shape_value_node;
}
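// A minimal sketch of the two shape encodings produced above, using an assumed shape {2, 3}
// (illustrative only): in graph mode the dims become a ValueTuple (2, 3) with an AbstractTuple
// abstract, while in pynative mode they are copied into a 1-D int64 Tensor of shape {2} holding
// [2, 3] with an AbstractTensor abstract, which is what the pynative pass feeds DropoutGenMask
// as its shape input.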
std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) {
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
auto output_count = output_size / kMaskAlignNum;
if (output_size % kMaskAlignNum != 0) {
output_count++;
}
auto ret = output_count * kMaskMultiNum;
MS_LOG(INFO) << "Output_size: " << ret;
return {ret};
}
} // namespace
@ -141,34 +173,34 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con
MS_EXCEPTION_IF_NULL(tuple_cnode);
auto dropout_node = tuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);
auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
// CreateDropoutGenMask
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(node->scope());
// CreateDropoutDoMask
MS_EXCEPTION_IF_NULL(dropout_node);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
ShapeVector dropout_do_mask_output = shape;
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask_output);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());
@ -178,8 +210,6 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
MS_EXCEPTION_IF_NULL(Y);
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
auto tuple_getitem_prim = prim::kPrimTupleGetItem;
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
@ -194,58 +224,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad);
auto tuple_getitem = dropout_grad->input(2);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
auto dropout_node = tuple_getitem_cnode->input(1);
auto dropout_grad_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);
auto getitem1_node = dropout_grad_cnode->input(2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
// CreateDropoutGenMask
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(dropout_node->scope());
// AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
dropout_gen_mask->set_scope(node->scope());
// CreateDropoutDoMask-forward
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(dropout_node);
CNodePtr dropout_do_mask1 = nullptr;
if (iter != node_users.end()) {
for (auto &node_index : iter->second) {
// Dropout has two outputs, so output node is tuple_getitem
auto tuple_getitem_cnode2 = node_index.first->cast<CNodePtr>();
// check if Dropout's first output, which is used by forward, is used.
auto getitem_index = GetValue<int64_t>(tuple_getitem_cnode2->input(2)->cast<ValueNodePtr>()->value());
if (getitem_index == 0) {
std::vector<AnfNodePtr> dropout_do_mask1_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
ShapeVector dropout_do_mask1_output = shape;
auto do_mask_abstract1 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask1_output);
dropout_do_mask1->set_abstract(do_mask_abstract1);
dropout_do_mask1->set_scope(dropout_node->scope());
(void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1);
break;
auto used_node = node_index.first;
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) {
// check if Dropout's first output, which is used by forward, is used
if (AnfAlgo::GetTupleGetItemOutIndex(used_node->cast<CNodePtr>()) == 0) {
// if Dropout's first output is used, create forward DropoutDoMask
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask1_inputs{
NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask,
keep_prob_value};
dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
auto do_mask_abstract1 =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask1->set_abstract(do_mask_abstract1);
dropout_do_mask1->set_scope(dropout_node->scope());
(void)manager->Replace(used_node, dropout_do_mask1);
break;
}
}
}
}
if (dropout_do_mask1 != nullptr) {
// Dropout may be used by a ControlDepend in some situations; the ControlDepend needs to be replaced as well.
auto &users = manager->node_users();
iter = users.find(dropout_node);
if (iter != users.end()) {
for (auto &node_index : iter->second) {
auto used_node = node_index.first;
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) {
(void)manager->Replace(used_node, dropout_do_mask1);
break;
}
}
}
}
@ -254,16 +300,112 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
if (equiv->find(grad_input_) == equiv->end()) {
MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern.";
}
auto grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask2_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask2);
ShapeVector dropout_do_mask2_output = shape;
auto do_mask_abstract2 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask2_output);
dropout_do_mask2->set_abstract(do_mask_abstract2);
dropout_do_mask2->set_scope(node->scope());
auto dropout_grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());
return dropout_do_mask2;
return dropout_do_mask;
}
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VarPtr Z = std::make_shared<Var>();
auto dropout = VectorRef({prim::kPrimDropout, X});
auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y});
auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z});
auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1});
return maketuple;
}
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto maketuple_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple_cnode);
auto getitem0_node = maketuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(getitem0_node);
auto getitem1_node = maketuple_cnode->input(2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape, true);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);
// CreateDropoutGenMask
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(node->scope());
// CreateDropoutDoMask
MS_EXCEPTION_IF_NULL(dropout_node);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());
// replace the tuple_getitem outputs with DropoutDoMask and DropoutGenMask
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(void)manager->Replace(getitem0_node, dropout_do_mask);
(void)manager->Replace(getitem1_node, dropout_gen_mask);
return node;
}
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
return VectorRef({dropout_grad_prim, X, Y});
}
const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);
auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode);
auto grad_input_shape = GetInputXShape(dropout_grad_cnode);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
// CreateDropoutDoMask
auto grad_input = dropout_grad_cnode->input(1);
auto mask_input = dropout_grad_cnode->input(2);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());
return dropout_do_mask;
}
} // namespace mindspore::opt


@ -42,6 +42,24 @@ class DropoutGradUnifyMindIR : public PatternProcessPass {
private:
VarPtr grad_input_;
};
class DropoutUnifyMindIRPynative : public PatternProcessPass {
public:
explicit DropoutUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {}
~DropoutUnifyMindIRPynative() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class DropoutGradUnifyMindIRPynative : public PatternProcessPass {
public:
explicit DropoutGradUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {}
~DropoutGradUnifyMindIRPynative() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_


@ -444,6 +444,15 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
} else {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
}
optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(graph);


@ -1633,7 +1633,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
UnifyMindIR(graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
UnifyMindIR(graph);
}
return graph;
}


@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator
from mindspore import context
from ..cell import Cell
from .activation import get_activation
@ -146,33 +145,17 @@ class Dropout(Cell):
seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0
self.seed1 = seed1
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_ascend = context.get_context('device_target') in ["Ascend"]
self.dropout = P.Dropout(keep_prob)
self.dropout = P.Dropout(keep_prob, seed0, seed1)
def construct(self, x):
if not self.training:
return x
if not self.is_ascend:
out, _ = self.dropout(x)
return out
if self.keep_prob == 1:
return x
shape = self.get_shape(x)
dtype = P.DType()(x)
if _is_float_dtype(dtype):
keep_prob = self.cast(self.keep_prob, dtype)
else:
keep_prob = self.cast(self.keep_prob, mstype.float16)
output = self.dropout_gen_mask(shape, keep_prob)
return self.dropout_do_mask(x, output, keep_prob)
out, _ = self.dropout(x)
return out
def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)