forked from mindspore-Ecosystem/mindspore
fix dropout unify_mindir pass
parent b41d83a7df
commit 389da54525
@@ -15,7 +15,9 @@
*/

#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
#include <ops/all_ops.h>
#include <vector>
#include <string>
#include <memory>
#include <numeric>
#include <algorithm>

@@ -23,45 +25,69 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/log_adapter.h"

/*
DropoutGenMask:
attr: seed0, seed1;
input: 1. shape <>;
2. keep_prob: type based on input x type; if x is float/float16, use that type, else use float16;
output: shape: (count + 127) / 128 * 16
*/
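// Worked example (added for illustration, not part of the committed source): for an input of
// shape (32, 128), count = 32 * 128 = 4096 elements, so the mask buffer is
// ceil(4096 / 128) * 16 = 512 uint8 bytes; CalDropoutGenMaskOutput below returns this as the
// single-element shape {512}.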
namespace mindspore::opt {
namespace {
constexpr auto kKeepProb = "keep_prob";
constexpr auto kSeed0 = "Seed0";
constexpr auto kSeed1 = "Seed1";
constexpr auto kUint8BitSize = 8;

namespace mindspore::opt {
constexpr int64_t kMaskAlignNum = 128;
constexpr int64_t kMaskMultiNum = 16;
constexpr size_t kFloat16Len = 2; // size of float16
namespace {
AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) {
MS_LOG(INFO) << "GetDropoutNodeInfo start.";
constexpr size_t kInt64Len = 8; // size of int64

TypeId GetInputXDataType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(keep_prob);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) ||
!AnfAlgo::HasNodeAttr(kSeed1, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does not have attr: keep_prob or seed0 or seed1.";
auto dropout_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
if (dropout_input_type != kNumberTypeFloat32 && dropout_input_type != kNumberTypeFloat &&
dropout_input_type != kNumberTypeFloat16) {
dropout_input_type = kNumberTypeFloat16;
}
*keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb);
MS_LOG(INFO) << "keep_prob: " << *keep_prob;
// return dropout input. maybe tensor or pre cnode output
return cnode->input(1);
MS_LOG(INFO) << "Dropout input data type: " << TypeIdLabel(dropout_input_type);
return dropout_input_type;
}

ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) {
MS_LOG(INFO) << "CreateKeepPorbValueNode start.";
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int64_t> shapes;
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
return shapes;
}

ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, TypeId type_id) {
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Step1: get keep_prob
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does not have attr: keep_prob.";
}
if (AnfAlgo::GetCNodePrimitive(cnode)->ToString() == kDropoutOpName) {
if (!AnfAlgo::HasNodeAttr(kSeed0, cnode) || !AnfAlgo::HasNodeAttr(kSeed1, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does not have attr: seed0 or seed1.";
}
}
auto keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb);
MS_LOG(INFO) << "Keep_prob value: " << keep_prob;

std::vector<int64_t> keep_prob_shape = {};
ShapeVector shape = {};
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), keep_prob_shape);
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(type_id, keep_prob_shape);
MS_EXCEPTION_IF_NULL(keep_prob_tensor);
auto data_ptr = keep_prob_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
// keep_prob's datatype is same with input data
if (dtype->type_id() == kNumberTypeFloat16) {
float16 half_data = float16(keep_prob);
auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len);
if (type_id == kNumberTypeFloat16) {
auto half_data = float16(keep_prob);
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(keep_prob_tensor->data().nbytes()), &half_data, kFloat16Len);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
}
@@ -69,59 +95,65 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float
auto *val = reinterpret_cast<float *>(data_ptr);
*val = keep_prob;
}
auto abstract = std::make_shared<abstract::AbstractTensor>(dtype, shape);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), keep_prob_shape);
auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor);
MS_EXCEPTION_IF_NULL(keep_prob_value);
kernel_graph->AddValueNodeToGraph(keep_prob_value);
return keep_prob_value;
}
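// Note (illustrative, not from the committed source): with keep_prob = 0.9 and a float16 input,
// CreateKeepPorbValueNode builds a 0-d float16 tensor holding 0.9; a float32 input keeps the value
// in float32, and non-floating inputs fall back to float16 via GetInputXDataType, so the keep_prob
// operand always matches the input x dtype described in the file comment above.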

std::vector<int64_t> GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) {
MS_LOG(INFO) << "GetInputShape start.";
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(dropout_input);
std::vector<int64_t> shapes;
if (dropout_input->isa<Parameter>()) {
MS_LOG(INFO) << "Dropout input from parameter node.";
// single test case
auto dropout_input_value = dropout_input->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(dropout_input_value);
MS_EXCEPTION_IF_NULL(dropout_input_value->Shape());
auto shape = dropout_input_value->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
return shape->shape();
} else if (dropout_input->isa<CNode>()) {
MS_LOG(INFO) << "Dropout input from cnode.";
auto dropout_input_node = dropout_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_input_node);
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
return shapes;
} else {
MS_LOG(ERROR) << "Dropout input is not parameter or cnode.";
return {};
}
}

ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape) {
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape,
bool is_pynative = false) {
MS_LOG(INFO) << "CreateShapeValueNode start.";
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<ValuePtr> dim_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &dim : shape) {
dim_values.push_back(MakeValue(dim));
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
ValuePtr shape_value = nullptr;
AbstractBasePtr abstract = nullptr;
if (is_pynative) {
// pynative mode need to create tensor
int64_t shape_dim = SizeToLong(shape.size());
std::vector<int64_t> shape_vec_shape = {shape_dim};
auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape);
MS_EXCEPTION_IF_NULL(shape_tensor);
auto data_ptr = shape_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
auto elem_num = shape.size() * kInt64Len;
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
}
shape_value = shape_tensor;
abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape);
} else {
std::vector<ValuePtr> dim_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &dim : shape) {
dim_values.push_back(MakeValue(dim));
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
}
shape_value = std::make_shared<ValueTuple>(dim_values);
abstract = std::make_shared<abstract::AbstractTuple>(abs);
}
auto shape_value_tuple = std::make_shared<ValueTuple>(dim_values);
MS_EXCEPTION_IF_NULL(shape_value_tuple);
auto abstract = std::make_shared<abstract::AbstractTuple>(abs);
MS_EXCEPTION_IF_NULL(abstract);
auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple);
MS_EXCEPTION_IF_NULL(shape_value);
kernel_graph->AddValueNodeToGraph(shape_value);
return shape_value;
MS_EXCEPTION_IF_NULL(abstract);
auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value);
MS_EXCEPTION_IF_NULL(shape_value_node);
kernel_graph->AddValueNodeToGraph(shape_value_node);
return shape_value_node;
}

std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) {
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
auto output_count = output_size / kMaskAlignNum;
if (output_size % kMaskAlignNum != 0) {
output_count++;
}
auto ret = output_count * kMaskMultiNum;
MS_LOG(INFO) << "Output_size: " << ret;
return {ret};
}
} // namespace

@@ -141,34 +173,34 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con
MS_EXCEPTION_IF_NULL(tuple_cnode);
auto dropout_node = tuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);

auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);

// CreateDropoutGenMask
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(node->scope());

// CreateDropoutDoMask
MS_EXCEPTION_IF_NULL(dropout_node);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
ShapeVector dropout_do_mask_output = shape;
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask_output);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());

@@ -178,8 +210,6 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
MS_EXCEPTION_IF_NULL(Y);
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
auto tuple_getitem_prim = prim::kPrimTupleGetItem;
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);

@@ -194,58 +224,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad);
auto tuple_getitem = dropout_grad->input(2);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
auto dropout_node = tuple_getitem_cnode->input(1);
auto dropout_grad_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);
auto getitem1_node = dropout_grad_cnode->input(2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);

auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);

// CreateDropoutGenMask
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(dropout_node->scope());
// AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
dropout_gen_mask->set_scope(node->scope());

// CreateDropoutDoMask-forward
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(dropout_node);
CNodePtr dropout_do_mask1 = nullptr;
if (iter != node_users.end()) {
for (auto &node_index : iter->second) {
// Dropout has two outputs, so output node is tuple_getitem
auto tuple_getitem_cnode2 = node_index.first->cast<CNodePtr>();
// check if Dropout's first output, which is used by forward, is used.
auto getitem_index = GetValue<int64_t>(tuple_getitem_cnode2->input(2)->cast<ValueNodePtr>()->value());
if (getitem_index == 0) {
std::vector<AnfNodePtr> dropout_do_mask1_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
ShapeVector dropout_do_mask1_output = shape;
auto do_mask_abstract1 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask1_output);
dropout_do_mask1->set_abstract(do_mask_abstract1);
dropout_do_mask1->set_scope(dropout_node->scope());
(void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1);
break;
auto used_node = node_index.first;
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) {
// check if Dropout's first output, which is used by forward, is used
if (AnfAlgo::GetTupleGetItemOutIndex(used_node->cast<CNodePtr>()) == 0) {
// if Dropout's first output is used, create forward DropoutDoMask
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask1_inputs{
NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask,
keep_prob_value};
dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
auto do_mask_abstract1 =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask1->set_abstract(do_mask_abstract1);
dropout_do_mask1->set_scope(dropout_node->scope());
(void)manager->Replace(used_node, dropout_do_mask1);
break;
}
}
}
}
if (dropout_do_mask1 != nullptr) {
// Dropout is used by ControlDepend in some situation, need to replace ControlDepend.
auto &users = manager->node_users();
iter = users.find(dropout_node);
if (iter != users.end()) {
for (auto &node_index : iter->second) {
auto used_node = node_index.first;
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) {
(void)manager->Replace(used_node, dropout_do_mask1);
break;
}
}
}
}

@@ -254,16 +300,112 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph,
if (equiv->find(grad_input_) == equiv->end()) {
MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern.";
}
auto grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask2_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask2);
ShapeVector dropout_do_mask2_output = shape;
auto do_mask_abstract2 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask2_output);
dropout_do_mask2->set_abstract(do_mask_abstract2);
dropout_do_mask2->set_scope(node->scope());
auto dropout_grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());

return dropout_do_mask2;
return dropout_do_mask;
}

const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VarPtr Z = std::make_shared<Var>();
auto dropout = VectorRef({prim::kPrimDropout, X});
auto getitem0 = VectorRef({prim::kPrimTupleGetItem, dropout, Y});
auto getitem1 = VectorRef({prim::kPrimTupleGetItem, dropout, Z});
auto maketuple = VectorRef({prim::kPrimMakeTuple, getitem0, getitem1});
return maketuple;
}

const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto maketuple_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple_cnode);
auto getitem0_node = maketuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(getitem0_node);
auto getitem1_node = maketuple_cnode->input(2);
MS_EXCEPTION_IF_NULL(getitem1_node);
auto getitem1_cnode = getitem1_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1_cnode);
auto dropout_node = getitem1_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);

auto inputx_type_id = GetInputXDataType(dropout_node);
auto inputx_shape = GetInputXShape(dropout_node);
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape, true);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);

// CreateDropoutGenMask
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(node->scope());

// CreateDropoutDoMask
MS_EXCEPTION_IF_NULL(dropout_node);
auto dropout_cnode = dropout_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_cnode);
auto dropout_input = dropout_cnode->input(1);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());

// replace genmask and domask
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(void)manager->Replace(getitem0_node, dropout_do_mask);
(void)manager->Replace(getitem1_node, dropout_gen_mask);

return node;
}

const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
return VectorRef({dropout_grad_prim, X, Y});
}

const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad_cnode);

auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode);
auto grad_input_shape = GetInputXShape(dropout_grad_cnode);
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);

// CreateDropoutDoMask
auto grad_input = dropout_grad_cnode->input(1);
auto mask_input = dropout_grad_cnode->input(2);
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, mask_input, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
auto do_mask_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());
return dropout_do_mask;
}
} // namespace mindspore::opt

@@ -42,6 +42,24 @@ class DropoutGradUnifyMindIR : public PatternProcessPass {
private:
VarPtr grad_input_;
};

class DropoutUnifyMindIRPynative : public PatternProcessPass {
public:
explicit DropoutUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_unify_mindir_pynative", multigraph) {}
~DropoutUnifyMindIRPynative() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class DropoutGradUnifyMindIRPynative : public PatternProcessPass {
public:
explicit DropoutGradUnifyMindIRPynative(bool multigraph = true)
: PatternProcessPass("dropout_grad_unify_mindir_pynative", multigraph) {}
~DropoutGradUnifyMindIRPynative() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_

@@ -444,6 +444,15 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
} else {
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIRPynative>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIRPynative>());
}

optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(graph);

@@ -1633,7 +1633,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
UnifyMindIR(graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
UnifyMindIR(graph);
}
return graph;
}

@@ -29,7 +29,6 @@ from mindspore.ops.primitive import constexpr, Primitive
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator
from mindspore import context
from ..cell import Cell
from .activation import get_activation

@@ -146,33 +145,17 @@ class Dropout(Cell):
seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0
self.seed1 = seed1
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_ascend = context.get_context('device_target') in ["Ascend"]
self.dropout = P.Dropout(keep_prob)
self.dropout = P.Dropout(keep_prob, seed0, seed1)

def construct(self, x):
if not self.training:
return x

if not self.is_ascend:
out, _ = self.dropout(x)
return out

if self.keep_prob == 1:
return x

shape = self.get_shape(x)
dtype = P.DType()(x)
if _is_float_dtype(dtype):
keep_prob = self.cast(self.keep_prob, dtype)
else:
keep_prob = self.cast(self.keep_prob, mstype.float16)
output = self.dropout_gen_mask(shape, keep_prob)
return self.dropout_do_mask(x, output, keep_prob)
out, _ = self.dropout(x)
return out

def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
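
A minimal usage sketch of the simplified cell after this change (illustrative only; the shape, keep_prob value, and context settings below are arbitrary assumptions, not part of the commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Graph mode on Ascend: the backend UnifyMindIR pass rewrites the single Dropout primitive
# into DropoutGenMask + DropoutDoMask, as shown in the C++ changes above.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

net = nn.Dropout(keep_prob=0.9)
net.set_train()  # Dropout acts as identity unless the cell is in training mode
x = Tensor(np.ones((2, 4)).astype(np.float32))
out = net(x)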