add recompute nodes

This commit is contained in:
yujianfeng 2020-12-02 09:27:04 +08:00
parent 237faca57e
commit 7b412d7cb2
6 changed files with 580 additions and 6 deletions

View File

@ -38,9 +38,6 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne
if (value->isa<Scalar>()) {
tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
} else if (value->isa<ValueTuple>()) {
if (!AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
} else {
MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
@ -89,7 +86,11 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
MS_EXCEPTION_IF_NULL(func_graph);
auto new_cnode = func_graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
new_cnode->set_abstract(new_inputs[1]->abstract());
} else {
new_cnode->set_abstract(cnode->abstract());
}
new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
if (kernel_graph != nullptr) {
@ -123,7 +124,8 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
return nullptr;
}
if (!node->isa<CNode>()) {

View File

@ -0,0 +1,442 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/recompute.h"
#include <memory>
#include <queue>
#include <utility>
#include <list>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>
#include "ir/func_graph.h"
#include "mindspore/core/base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
constexpr auto kAttrRecomputed = "recomputed";
constexpr auto kAttrNoRecomputed = "no_recomputed";
bool IsTargetNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
return node->fullname_with_scope().find(kGradientsFlag) == 0;
}
bool HasNoRecomputedAttr(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim != nullptr) {
auto no_recompute_val = prim->GetAttr(kAttrNoRecomputed);
if (no_recompute_val != nullptr && no_recompute_val->isa<BoolImm>()) {
return GetValue<bool>(no_recompute_val);
}
}
return false;
}
bool WithRecomputedScope(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
return node->fullname_with_scope().find(kAttrRecomputed) == 0;
}
bool IsSetRecomputed(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim != nullptr) {
auto recompute_val = prim->GetAttr(kAttrRecomputed);
if (recompute_val != nullptr && recompute_val->isa<BoolImm>()) {
return GetValue<bool>(recompute_val);
}
}
return false;
}
bool IsCandidateRecomputedNode(const CNodePtr &node) { return !IsTargetNode(node) && IsSetRecomputed(node); }
std::vector<CNodePtr> FindCandidateRecomputedNodes(const FuncGraphManagerPtr &mng,
const std::vector<CNodePtr> &cnodes) {
MS_EXCEPTION_IF_NULL(mng);
std::vector<CNodePtr> candidate_recomputed_nodes;
for (const auto &cnode : cnodes) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsCandidateRecomputedNode(cnode)) {
continue;
}
// Check outputs.
const auto &node_users = mng->node_users();
auto output_set_iter = node_users.find(cnode);
if (output_set_iter == node_users.end()) {
continue;
}
const auto &node_index_set = output_set_iter->second;
if (!std::any_of(node_index_set.begin(), node_index_set.end(),
[](const std::pair<AnfNodePtr, int> &node_index) { return IsTargetNode(node_index.first); })) {
continue;
}
// Check inputs.
const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [](const AnfNodePtr &node) { return IsTargetNode(node); })) {
continue;
}
candidate_recomputed_nodes.emplace_back(cnode);
}
return candidate_recomputed_nodes;
}
void GetMaxSubGraph(const FuncGraphManagerPtr &mng, std::unordered_set<CNodePtr> *recomputed_nodes, bool get_inputs,
bool get_outputs) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(recomputed_nodes);
std::queue<CNodePtr> nodes_to_visit;
for (const auto &node : *recomputed_nodes) {
nodes_to_visit.push(node);
}
recomputed_nodes->clear();
while (!nodes_to_visit.empty()) {
auto current_node = nodes_to_visit.front();
nodes_to_visit.pop();
recomputed_nodes->insert(current_node);
if (get_inputs) {
for (const auto &input : current_node->inputs()) {
MS_EXCEPTION_IF_NULL(input);
if (input->isa<CNode>()) {
auto input_cnode = input->cast<CNodePtr>();
if (recomputed_nodes->find(input_cnode) == recomputed_nodes->end() &&
IsCandidateRecomputedNode(input_cnode)) {
nodes_to_visit.push(input_cnode);
}
}
}
}
if (get_outputs) {
const auto &node_users = mng->node_users();
auto output_set_iter = node_users.find(current_node);
if (output_set_iter == node_users.end()) {
continue;
}
for (const auto &node_index_set : output_set_iter->second) {
auto output_node = node_index_set.first;
MS_EXCEPTION_IF_NULL(output_node);
if (output_node->isa<CNode>()) {
auto output_cnode = output_node->cast<CNodePtr>();
if (recomputed_nodes->find(output_cnode) == recomputed_nodes->end() &&
IsCandidateRecomputedNode(output_cnode)) {
nodes_to_visit.push(output_cnode);
}
}
}
}
}
}
void GetOriginRecomputeAndTargetNodes(const FuncGraphManagerPtr &mng,
const std::unordered_set<CNodePtr> &max_recomputed_sub_graph,
std::unordered_set<CNodePtr> *recompute_nodes,
std::unordered_set<CNodePtr> *target_nodes) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(recompute_nodes);
MS_EXCEPTION_IF_NULL(target_nodes);
const auto &node_users = mng->node_users();
for (const auto &node : max_recomputed_sub_graph) {
bool inserted = false;
auto output_set_iter = node_users.find(node);
if (output_set_iter == node_users.end()) {
continue;
}
for (const auto &node_index_set : output_set_iter->second) {
auto output_node = node_index_set.first;
MS_EXCEPTION_IF_NULL(output_node);
if (!IsTargetNode(output_node)) {
continue;
}
target_nodes->insert(output_node->cast<CNodePtr>());
if (!inserted) {
recompute_nodes->insert(node);
inserted = true;
}
}
}
}
std::vector<AnfNodePtr> GetFirstTargetInputs(const std::vector<CNodePtr> &origin_nodes_topological,
const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
const std::unordered_set<CNodePtr> &target_nodes) {
std::vector<AnfNodePtr> first_target_inputs;
for (const auto &node : origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(node);
if (target_nodes.find(node) != target_nodes.end()) {
for (size_t i = 1; i < node->size(); ++i) {
auto input = node->input(i);
if (!input->isa<CNode>()) {
continue;
}
MS_EXCEPTION_IF_NULL(input);
if (recomputed_origin_nodes.find(input->cast<CNodePtr>()) != recomputed_origin_nodes.end()) {
continue;
}
first_target_inputs.emplace_back(input);
}
break;
}
}
return first_target_inputs;
}
bool HasGradInputs(const AnfNodePtr &node, std::unordered_map<AnfNodePtr, bool> *has_grad_inputs_map) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(has_grad_inputs_map);
if (has_grad_inputs_map->find(node) != has_grad_inputs_map->end()) {
return has_grad_inputs_map->find(node)->second;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
has_grad_inputs_map->insert(std::make_pair(node, false));
return false;
}
const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [&has_grad_inputs_map](const AnfNodePtr &input) {
return IsTargetNode(input) || HasGradInputs(input, has_grad_inputs_map);
})) {
has_grad_inputs_map->insert(std::make_pair(node, true));
return true;
}
has_grad_inputs_map->insert(std::make_pair(node, false));
return false;
}
bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(mng);
const auto &node_users = mng->node_users();
auto output_set_iter = node_users.find(node);
if (output_set_iter == node_users.end()) {
return false;
}
for (const auto &node_index_set : output_set_iter->second) {
if (!IsTargetNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
return true;
}
}
return false;
}
void GetTupleGetItemOutputNodes(const FuncGraphManagerPtr &mng, const AnfNodePtr &node,
std::vector<AnfNodePtr> *tuple_getitem_output_nodes) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(tuple_getitem_output_nodes);
const auto &node_users = mng->node_users();
auto output_set_iter = node_users.find(node);
if (output_set_iter == node_users.end()) {
return;
}
for (const auto &node_index_set : output_set_iter->second) {
if (IsPrimitiveCNode(node_index_set.first, prim::kPrimTupleGetItem)) {
tuple_getitem_output_nodes->emplace_back(node_index_set.first);
}
}
}
// Set 'recomputed' attr for the nodes according to its scope.
// A node set 'recomputed' attr can be the candidate recomputed node.
void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(graph);
auto mng = graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::unordered_map<AnfNodePtr, bool> has_grad_inputs_map;
for (const auto &node : origin_nodes_topological) {
MS_EXCEPTION_IF_NULL(node);
if (!WithRecomputedScope(node) || HasNoRecomputedAttr(node)) {
continue;
}
auto prim = GetCNodePrimitive(node);
if (prim == nullptr || prim->name() == prim::kPrimTupleGetItem->name() ||
prim->name() == prim::kPrimAllGather->name()) {
continue;
}
if (!HasForwardOutput(mng, node) || HasGradInputs(node, &has_grad_inputs_map)) {
continue;
}
// Make a new primitive to set attr because some nodes share the same primitive probably.
auto new_prim = std::make_shared<Primitive>(prim->name());
new_prim->SetAttrs(prim->attrs());
new_prim->set_prim_type(prim->prim_type());
new_prim->set_attr(kAttrRecomputed, MakeValue(true));
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
const auto &origin_inputs = node->inputs();
std::copy(origin_inputs.begin() + 1, origin_inputs.end(), std::back_inserter(new_inputs));
auto new_node = graph->NewCNode(new_inputs);
new_node->set_abstract(node->abstract());
new_node->set_scope(node->scope());
mng->Replace(node, new_node);
// Set attr for the tuple_getitem outputs.
std::vector<AnfNodePtr> tuple_getitem_output_nodes;
GetTupleGetItemOutputNodes(mng, new_node, &tuple_getitem_output_nodes);
for (const auto &output_node : tuple_getitem_output_nodes) {
auto new_output_prim = std::make_shared<Primitive>(prim::kPrimTupleGetItem->name());
new_output_prim->set_attr(kAttrRecomputed, MakeValue(true));
std::vector<AnfNodePtr> new_tuple_getitem_inputs{NewValueNode(new_output_prim)};
auto origin_tuple_getitem_inputs = output_node->cast<CNodePtr>()->inputs();
std::copy(origin_tuple_getitem_inputs.begin() + 1, origin_tuple_getitem_inputs.end(),
std::back_inserter(new_tuple_getitem_inputs));
auto new_tuple_getitem = graph->NewCNode(new_tuple_getitem_inputs);
new_tuple_getitem->set_abstract(output_node->abstract());
mng->Replace(output_node, new_tuple_getitem);
}
}
}
CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const std::vector<AnfNodePtr> &first_target_inputs,
const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(origin_to_recomputed_nodes);
auto iter = origin_to_recomputed_nodes->find(origin_node);
if (iter != origin_to_recomputed_nodes->end()) {
return iter->second;
}
MS_LOG(DEBUG) << "Begin to Duplicating origin recomputed node: " << origin_node->DebugString();
auto prim = GetCNodePrimitive(origin_node);
MS_EXCEPTION_IF_NULL(prim);
auto new_prim = std::make_shared<Primitive>(prim->name());
new_prim->SetAttrs(prim->attrs());
new_prim->set_attr("duplicated", MakeValue(true));
new_prim->set_prim_type(prim->prim_type());
std::vector<AnfNodePtr> new_inputs{NewValueNode(new_prim)};
bool has_recomputed_inputs = false;
for (size_t i = 1; i < origin_node->size(); ++i) {
auto input = origin_node->input(i);
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) {
new_inputs.emplace_back(input);
continue;
}
auto input_cnode = input->cast<CNodePtr>();
if (recomputed_origin_nodes.find(input_cnode) == recomputed_origin_nodes.end()) {
new_inputs.emplace_back(input);
} else {
has_recomputed_inputs = true;
new_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs, recomputed_origin_nodes,
origin_to_recomputed_nodes));
}
}
// Add the execution dependency.
if (!has_recomputed_inputs && new_inputs.size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
std::copy(first_target_inputs.begin(), first_target_inputs.end(), std::back_inserter(make_tuple_inputs));
auto first_input = new_inputs[1];
MS_EXCEPTION_IF_NULL(first_input);
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), first_input,
graph->NewCNode(make_tuple_inputs)};
auto depend_node = graph->NewCNode(depend_inputs);
MS_EXCEPTION_IF_NULL(depend_node);
depend_node->set_abstract(first_input->abstract());
new_inputs[1] = depend_node;
}
auto recomputed_node = graph->NewCNode(new_inputs);
recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope());
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
return recomputed_node;
}
void DuplicateRecomputedNodes(const FuncGraphPtr &graph, const std::unordered_set<CNodePtr> &target_nodes,
const std::unordered_set<CNodePtr> &origin_recomputed_nodes,
const std::vector<AnfNodePtr> &first_target_inputs,
std::unordered_map<CNodePtr, CNodePtr> *origin_to_recomputed_nodes) {
MS_EXCEPTION_IF_NULL(graph);
auto mng = graph->manager();
MS_EXCEPTION_IF_NULL(mng);
for (const auto &target_node : target_nodes) {
MS_EXCEPTION_IF_NULL(target_node);
MS_LOG(DEBUG) << "Rebuild a new target_node " << target_node->DebugString() << " with the new recomputed input";
auto target_cnode = target_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(target_cnode);
auto prim = GetCNodePrimitive(target_cnode);
if (prim != nullptr) {
prim->set_attr("target_grad", MakeValue(true));
}
std::vector<AnfNodePtr> new_target_inputs;
for (const auto &input : target_cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) {
new_target_inputs.emplace_back(input);
} else {
auto input_cnode = input->cast<CNodePtr>();
if (origin_recomputed_nodes.find(input_cnode) != origin_recomputed_nodes.end()) {
new_target_inputs.emplace_back(NewRecomputedNode(graph, input_cnode, first_target_inputs,
origin_recomputed_nodes, origin_to_recomputed_nodes));
} else {
new_target_inputs.emplace_back(input_cnode);
}
}
}
auto new_target_node = graph->NewCNode(new_target_inputs);
new_target_node->set_abstract(target_node->abstract());
new_target_node->set_scope(target_node->scope());
mng->Replace(target_node, new_target_node);
}
}
} // namespace
void InsertRecomputedNodes(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto mng = graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::list<CNodePtr> old_orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> old_nodes_topological(old_orders.begin(), old_orders.end());
SetRecomputedAttr(graph, old_nodes_topological);
std::list<CNodePtr> new_orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> origin_nodes_topological(new_orders.begin(), new_orders.end());
// Get candidate origin recomputed nodes which have no grad inputs and output to at least one grad node directly.
std::vector<CNodePtr> candidate_recomputed_nodes = FindCandidateRecomputedNodes(mng, origin_nodes_topological);
std::unordered_set<CNodePtr> visited_nodes;
for (const auto &candidate_recomputed_node : candidate_recomputed_nodes) {
if (visited_nodes.find(candidate_recomputed_node) != visited_nodes.end()) {
continue;
}
std::unordered_set<CNodePtr> max_recomputed_sub_graph = {candidate_recomputed_node};
// Get max continuous recomputed sub-graph.
GetMaxSubGraph(mng, &max_recomputed_sub_graph, true, true);
visited_nodes.insert(max_recomputed_sub_graph.begin(), max_recomputed_sub_graph.end());
// Get the origin recomputed nodes which directly output to the grad nodes.
std::unordered_set<CNodePtr> origin_recomputed_nodes;
std::unordered_set<CNodePtr> target_nodes;
GetOriginRecomputeAndTargetNodes(mng, max_recomputed_sub_graph, &origin_recomputed_nodes, &target_nodes);
// Also get the inputs of origin recomputed nodes which eventually output to the grad nodes.
GetMaxSubGraph(mng, &origin_recomputed_nodes, true, false);
// Get the inputs of the first target node in the topological sequence. The duplicated recomputed nodes should
// not be executed until these inputs are ready.
std::vector<AnfNodePtr> first_target_inputs =
GetFirstTargetInputs(origin_nodes_topological, origin_recomputed_nodes, target_nodes);
std::unordered_map<CNodePtr, CNodePtr> origin_to_recomputed_nodes;
// Begin duplicate origin recomputed nodes with each target node.
DuplicateRecomputedNodes(graph, target_nodes, origin_recomputed_nodes, first_target_inputs,
&origin_to_recomputed_nodes);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_
#include "ir/anf.h"
namespace mindspore {
namespace opt {
// Automatically insert duplicated recomputed nodes.
void InsertRecomputedNodes(const FuncGraphPtr &graph);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_RECOMPUTE_H_

View File

@ -37,6 +37,7 @@
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
@ -383,6 +384,12 @@ bool AddControlDependPass(const ResourcePtr &res) {
return true;
}
bool AddRecomputationPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
opt::InsertRecomputedNodes(res->func_graph());
return true;
}
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -474,7 +481,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
{"add_control_depend", AddControlDependPass},
{"add_recomputation", AddRecomputationPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},

View File

@ -913,6 +913,8 @@ class Cell(Cell_):
"""Sets the name on the first time."""
if self._scope is None:
self._scope = name
elif self._scope == 'recomputed':
self._scope = self._scope + "_" + name
def _children_scope_recursive(self, parent_prefix='Default'):
"""Generates the scope of each layer of the network recursively."""
@ -1093,6 +1095,15 @@ class Cell(Cell_):
param.comm_fusion = fusion_type
return self
def recompute(self):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
node and is set recomputed, we will compute it again for the grad node after the forward computation.
"""
self._set_scope('recomputed')
for cell in self.cells():
cell.recompute()
class GraphKernel(Cell):
"""

View File

@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.pool.recompute()
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
def train(net, data, label):
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
print("+++++++++Loss+++++++++++++")
print(res)
print("+++++++++++++++++++++++++++")
diff = res.asnumpy() - 2.302585
assert np.all(diff < 1.e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lenet():
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
train(net, data, label)