tile's input multiples

new mul node

new mul node

convert const to attr
This commit is contained in:
hwjiaorui 2021-01-06 15:43:24 +08:00
parent eb973172d2
commit 908e9a526b
1 changed files with 39 additions and 20 deletions

View File

@ -50,7 +50,8 @@ ValueNodePtr CreateValueNode(const ValuePtr &value_ptr, TypeId output_type) {
return new_node;
}
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, bool is_pynative = false) {
CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node,
bool is_convert_const_to_attr = false) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
@ -80,7 +81,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> one_hot_inputs;
if (is_pynative) {
if (is_convert_const_to_attr) {
one_hot_inputs = {NewValueNode(one_hot_primitive), sparse_softmax_node->input(2), value_on_node, value_off_node};
} else {
auto depth_node = NewValueNode(depth);
@ -97,7 +98,7 @@ CNodePtr CreateOneHot(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_
std::vector<size_t> labels_shape = AnfAlgo ::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
labels_shape.emplace_back(depth);
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {labels_shape}, one_hot_node.get());
if (is_pynative) {
if (is_convert_const_to_attr) {
AnfAlgo::SetNodeAttr(kAttrDepth, MakeValue(depth), one_hot_node);
}
return one_hot_node;
@ -252,7 +253,7 @@ CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &rea
}
CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node,
bool is_pynative = false) {
bool is_convert_const_to_attr = false) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(mul_node);
@ -268,6 +269,9 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
std::vector<int64_t> multiple_value;
std::transform(labels_shape.begin(), labels_shape.end(), std::back_inserter(multiple_value),
[](size_t label) { return static_cast<int64_t>(label); });
if (std::all_of(multiple_value.begin(), multiple_value.end(), [](int64_t value) { return value == 1; })) {
return nullptr;
}
auto multiples = MakeValue(multiple_value);
auto multiples_node = CreateValueNode(multiples, kNumberTypeInt64);
MS_EXCEPTION_IF_NULL(multiples_node);
@ -279,7 +283,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));
std::vector<AnfNodePtr> tile_inputs;
if (is_pynative) {
if (is_convert_const_to_attr) {
tile_inputs = {NewValueNode(tile_primitive), mul_node->input(2)};
} else {
auto kernel_graph = graph->cast<KernelGraphPtr>();
@ -292,7 +296,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
tile_node->set_scope(mul_node->scope());
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1)}, {labels_shape},
tile_node.get());
if (is_pynative) {
if (is_convert_const_to_attr) {
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), tile_node);
}
// feature map set
@ -302,7 +306,7 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
return tile_node;
}
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const CNodePtr &tile_node) {
CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(tile_node);
@ -464,16 +468,24 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2));
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
}
auto expand_dims_node = CreateExpandDims(graph, real_div_node);
mul_node->set_input(1, softmax_node_outputs[1]);
mul_node->set_input(2, expand_dims_node);
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
softmax_node_outputs[1], expand_dims_node};
auto new_mul_node = graph->NewCNode(new_mul_inputs);
MS_EXCEPTION_IF_NULL(new_mul_node);
new_mul_node->set_scope(mul_node->scope());
new_mul_node->set_abstract(mul_node->abstract());
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(sparse_softmax_node_grad, softmax_node_outputs[1]);
return mul_node;
return new_mul_node;
}
const BaseRef GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
@ -563,19 +575,26 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
}
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad, true);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node_grad, one_hot_node);
std::vector<AnfNodePtr> softmax_node_outputs;
CreateMultipleOutputsOfAnfNode(graph, softmax_node, kSoftmaxCrossEntropyWithLogitsOutputNum, &softmax_node_outputs);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node, true);
auto real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
auto tile_node = CreateTile(graph, sparse_softmax_node_grad, mul_node);
CNodePtr real_div_node;
if (tile_node == nullptr) {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, mul_node->input(2));
} else {
real_div_node = CreateRealDiv(graph, sparse_softmax_node_grad, tile_node);
}
auto expand_dims_node = CreateExpandDimsPynative(graph, real_div_node);
mul_node->set_input(1, softmax_node_outputs[1]);
mul_node->set_input(2, expand_dims_node);
return mul_node;
std::vector<AnfNodePtr> new_mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)),
softmax_node_outputs[1], expand_dims_node};
auto new_mul_node = graph->NewCNode(new_mul_inputs);
MS_EXCEPTION_IF_NULL(new_mul_node);
new_mul_node->set_scope(mul_node->scope());
new_mul_node->set_abstract(mul_node->abstract());
return new_mul_node;
}
} // namespace opt
} // namespace mindspore