!4922 Transform tuple parameter to multiple parameters
Merge pull request !4922 from amongo/TupleTransform
This commit is contained in:
commit
c2fddb56c8
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/graph_transform.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "ir/graph_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
// check cnode input values, whether it is tuple input
|
||||
bool CNodeHasTupleInput(const CNodePtr &cnode) {
|
||||
auto &inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (IsValueNode<FuncGraph>(inputs[i])) {
|
||||
continue;
|
||||
}
|
||||
if (IsValueNode<Primitive>(inputs[i])) {
|
||||
// unexpected high order primitvie as cnode input when transform graph
|
||||
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
auto abs = inputs[i]->abstract();
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
|
||||
auto ¶ms = fg->parameters();
|
||||
for (auto ¶m : params) {
|
||||
if (param->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
|
||||
const abstract::AbstractTuplePtr &abs) {
|
||||
auto &elements = abs->elements();
|
||||
std::vector<AnfNodePtr> tuple_node_expanded;
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToInt(i))});
|
||||
elem_node->set_abstract(elements[i]);
|
||||
if (elements[i]->isa<abstract::AbstractTuple>()) {
|
||||
auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
|
||||
tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
|
||||
} else {
|
||||
tuple_node_expanded.push_back(elem_node);
|
||||
}
|
||||
}
|
||||
return tuple_node_expanded;
|
||||
}
|
||||
|
||||
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
|
||||
auto &cinputs = cnode->inputs();
|
||||
auto fg = cnode->func_graph();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(trans_fg));
|
||||
for (size_t i = 1; i < cinputs.size(); i++) {
|
||||
auto abs = cinputs[i]->abstract();
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "TransformCallGraph:Node abstract should not be nullptr" << cinputs[i]->DebugString();
|
||||
}
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
|
||||
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
|
||||
} else {
|
||||
inputs.push_back(cinputs[i]);
|
||||
}
|
||||
}
|
||||
auto new_node = fg->NewCNode(inputs);
|
||||
new_node->set_abstract(cnode->abstract());
|
||||
return new_node;
|
||||
}
|
||||
|
||||
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
|
||||
auto &cinputs = cnode->inputs();
|
||||
auto fg = cnode->func_graph();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimPartial));
|
||||
inputs.push_back(NewValueNode(trans_fg));
|
||||
for (size_t i = 2; i < cinputs.size(); i++) {
|
||||
auto abs = cinputs[i]->abstract();
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "TransformPartial:Node abstract should not be nullptr" << cinputs[i]->DebugString();
|
||||
}
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
|
||||
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
|
||||
} else {
|
||||
inputs.push_back(cinputs[i]);
|
||||
}
|
||||
}
|
||||
auto new_node = fg->NewCNode(inputs);
|
||||
new_node->set_abstract(cnode->abstract());
|
||||
return new_node;
|
||||
}
|
||||
|
||||
AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode) {
|
||||
auto &cinputs = cnode->inputs();
|
||||
auto fg = cnode->func_graph();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(swtich_node);
|
||||
for (size_t i = 1; i < cinputs.size(); i++) {
|
||||
auto abs = cinputs[i]->abstract();
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "TransformSwitchCall:Node abstract should not be nullptr" << cinputs[i]->DebugString();
|
||||
}
|
||||
if (abs->isa<abstract::AbstractTuple>()) {
|
||||
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
|
||||
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
|
||||
} else {
|
||||
inputs.push_back(cinputs[i]);
|
||||
}
|
||||
}
|
||||
auto new_node = fg->NewCNode(inputs);
|
||||
new_node->set_abstract(cnode->abstract());
|
||||
return new_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
bool CNodeHasTupleInput(const CNodePtr &cnode);
|
||||
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg);
|
||||
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
|
||||
const abstract::AbstractTuplePtr &abs);
|
||||
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
|
||||
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
|
||||
AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode);
|
||||
|
||||
class GraphTupleParamTransform {
|
||||
public:
|
||||
GraphTupleParamTransform() : cache_() {}
|
||||
~GraphTupleParamTransform() { cache_.clear(); }
|
||||
FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
|
||||
if (cache_.find(fg) != cache_.end()) {
|
||||
return cache_[fg];
|
||||
}
|
||||
auto new_fg = TransformGraphParam(fg, mng);
|
||||
cache_[fg] = new_fg;
|
||||
return new_fg;
|
||||
}
|
||||
|
||||
AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg,
|
||||
std::vector<AnfNodePtr> *params) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
auto &elements = tuple_abs->elements();
|
||||
for (auto &item : elements) {
|
||||
if (item->isa<abstract::AbstractTuple>()) {
|
||||
inputs.push_back(GenerateTupleParams(item->cast<abstract::AbstractTuplePtr>(), fg, params));
|
||||
} else {
|
||||
auto p = std::make_shared<Parameter>(fg);
|
||||
p->set_abstract(item);
|
||||
params->push_back(p);
|
||||
inputs.push_back(params->back());
|
||||
}
|
||||
}
|
||||
auto node = fg->NewCNode(inputs);
|
||||
node->set_abstract(tuple_abs);
|
||||
return node;
|
||||
}
|
||||
|
||||
FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
|
||||
Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>());
|
||||
auto new_fg = cloner[fg];
|
||||
auto ¶ms = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> repl;
|
||||
for (auto ¶m : params) {
|
||||
auto abs = param->abstract();
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
|
||||
std::vector<AnfNodePtr> tuple_params;
|
||||
repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params));
|
||||
std::transform(tuple_params.begin(), tuple_params.end(), std::back_inserter(new_params),
|
||||
[](AnfNodePtr p) { return p; });
|
||||
} else {
|
||||
new_params.push_back(param);
|
||||
}
|
||||
}
|
||||
auto tmp_mng = mindspore::Manage(new_fg, false);
|
||||
auto tr = tmp_mng->Transact();
|
||||
for (auto &item : repl) {
|
||||
bool ret = tr.Replace(item.first, item.second);
|
||||
if (ret == false) {
|
||||
MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" << item.second->DebugString(2);
|
||||
}
|
||||
}
|
||||
tr.SetParameters(new_fg, new_params);
|
||||
tr.Commit();
|
||||
mng->AddFuncGraph(new_fg);
|
||||
return new_fg;
|
||||
}
|
||||
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cache_;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
|
|
@ -44,6 +44,7 @@
|
|||
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
|
||||
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
unused_output_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// tuple parameter graph transform
|
||||
call_graph_tuple_transform_ =
|
||||
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
|
||||
|
||||
// AddN eliminate
|
||||
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
|
|
|
@ -103,6 +103,9 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr unused_parameter_eliminate_;
|
||||
SubstitutionPtr unused_output_eliminate_;
|
||||
|
||||
// tuple parameter graph transform
|
||||
SubstitutionPtr call_graph_tuple_transform_;
|
||||
|
||||
// AddN eliminate
|
||||
SubstitutionPtr addn_eliminate_;
|
||||
|
||||
|
|
|
@ -0,0 +1,246 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "frontend/optimizer/optimizer_caller.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/graph_transform.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {G, Xs}-->transform graph call tuple inputs to flat inputs.
|
||||
class GraphCallTupleTransform : public AnfVisitor {
|
||||
public:
|
||||
explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
|
||||
~GraphCallTupleTransform() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
if (fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!CNodeHasTupleInput(node->cast<CNodePtr>())) {
|
||||
return nullptr;
|
||||
}
|
||||
FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager());
|
||||
auto new_node = TransformCallGraph(transformed_fg, node->cast<CNodePtr>());
|
||||
return new_node;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphTupleParamTransform &graph_transform_;
|
||||
};
|
||||
|
||||
// {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs.
|
||||
class SwitchCallTupleTransform : public AnfVisitor {
|
||||
public:
|
||||
explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
|
||||
~SwitchCallTupleTransform() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto switch_call_cnode = node->cast<CNodePtr>();
|
||||
auto call_inputs = switch_call_cnode->inputs();
|
||||
if (call_inputs.size() < 1) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto swich_cnode = call_inputs[0]->cast<CNodePtr>();
|
||||
auto switch_inputs = swich_cnode->inputs();
|
||||
if (switch_inputs.size() != 4) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr transformed = nullptr;
|
||||
bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed);
|
||||
if (true_br_changed) {
|
||||
switch_inputs[2] = transformed;
|
||||
}
|
||||
bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed);
|
||||
if (false_br_changed) {
|
||||
switch_inputs[3] = transformed;
|
||||
}
|
||||
if (true_br_changed || false_br_changed) {
|
||||
call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs);
|
||||
}
|
||||
if (CNodeHasTupleInput(switch_call_cnode)) {
|
||||
return TransformSwitchCall(call_inputs[0], switch_call_cnode);
|
||||
}
|
||||
if (true_br_changed || false_br_changed) {
|
||||
return switch_call_cnode->func_graph()->NewCNode(call_inputs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
|
||||
if (FuncGraphHasTupleInput(fg)) {
|
||||
FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
|
||||
*trans_node = NewValueNode(transformed_fg);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
|
||||
auto partial_inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (IsValueNode<FuncGraph>(partial_inputs[1])) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]);
|
||||
if (FuncGraphHasTupleInput(fg)) {
|
||||
fg = graph_transform_(fg, mng);
|
||||
}
|
||||
if (CNodeHasTupleInput(node->cast<CNodePtr>())) {
|
||||
*trans_node = TransformPartial(fg, node->cast<CNodePtr>());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphTupleParamTransform &graph_transform_;
|
||||
};
|
||||
|
||||
// {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} ->
|
||||
// transform switch layer graph call tuple inputs to flat inputs.
|
||||
class SwitchLayerCallTupleTransform : public AnfVisitor {
|
||||
public:
|
||||
explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
|
||||
~SwitchLayerCallTupleTransform() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto switch_layer_call_cnode = node->cast<CNodePtr>();
|
||||
auto call_inputs = switch_layer_call_cnode->inputs();
|
||||
if (call_inputs.size() < 1) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto swich_layer_cnode = call_inputs[0]->cast<CNodePtr>();
|
||||
auto switch_layer_inputs = swich_layer_cnode->inputs();
|
||||
if (switch_layer_inputs.size() != 3) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr transformed = nullptr;
|
||||
bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed);
|
||||
if (layer_changed) {
|
||||
switch_layer_inputs[2] = transformed;
|
||||
call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs);
|
||||
}
|
||||
if (CNodeHasTupleInput(switch_layer_call_cnode)) {
|
||||
return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode);
|
||||
}
|
||||
if (layer_changed) {
|
||||
return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple";
|
||||
return false;
|
||||
}
|
||||
auto tuple_inputs = node->cast<CNodePtr>()->inputs();
|
||||
bool changed = false;
|
||||
for (size_t i = 1; i < tuple_inputs.size(); i++) {
|
||||
if (!IsValueNode<FuncGraph>(tuple_inputs[i])) {
|
||||
MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph";
|
||||
return false;
|
||||
}
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
|
||||
if (FuncGraphHasTupleInput(fg)) {
|
||||
FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
|
||||
tuple_inputs[i] = NewValueNode(transformed_fg);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
*trans_node = node->func_graph()->NewCNode(tuple_inputs);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphTupleParamTransform &graph_transform_;
|
||||
};
|
||||
|
||||
class CallGraphTupleTransform : public OptimizerCaller {
|
||||
public:
|
||||
CallGraphTupleTransform()
|
||||
: graph_transformer_(),
|
||||
graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)),
|
||||
switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)),
|
||||
switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) {
|
||||
transformers_.emplace_back(graph_call_transform_);
|
||||
transformers_.emplace_back(switch_call_transform_);
|
||||
transformers_.emplace_back(switch_layer_call_transform_);
|
||||
}
|
||||
~CallGraphTupleTransform() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
AnfNodePtr new_node;
|
||||
for (auto &transform : transformers_) {
|
||||
new_node = (*transform)(optimizer, node);
|
||||
if (new_node != nullptr) {
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphTupleParamTransform graph_transformer_;
|
||||
OptimizerCallerPtr graph_call_transform_;
|
||||
OptimizerCallerPtr switch_call_transform_;
|
||||
OptimizerCallerPtr switch_layer_call_transform_;
|
||||
std::vector<OptimizerCallerPtr> transformers_{};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
|
|
@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->DumpFuncGraph(fg_name);
|
||||
DumpIR(fg_name + ".ir", func_graph);
|
||||
ExportIR(fg_name + ".dat", "", func_graph);
|
||||
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
|
||||
}
|
||||
counter++;
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "frontend/optimizer/clean.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/control_depend.h"
|
||||
#include "frontend/optimizer/graph_transform.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
#include "frontend/parallel/step_auto_parallel.h"
|
||||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
|
@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig c_1 = opt::OptPassConfig({
|
||||
// Safe inlining
|
||||
// Safe inlining,
|
||||
irpass.inline_,
|
||||
irpass.partial_eliminate_,
|
||||
});
|
||||
|
||||
OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
OptPassGroupMap map_a({{"c_1", c_1},
|
||||
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
|
||||
return map_a;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining
|
||||
irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_});
|
||||
|
||||
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
|
||||
return map_a;
|
||||
}
|
||||
|
@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) {
|
|||
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
|
||||
g_pass_opts["opt_after_cconv"] =
|
||||
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
|
||||
g_pass_opts["opt_trans_graph"] =
|
||||
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
|
||||
g_pass_opts["opt_graph_kernel_a"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
|
||||
g_pass_opts["opt_graph_kernel_b"] =
|
||||
|
@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
|
|||
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
|
||||
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
||||
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
|
||||
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
||||
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
|
||||
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||
|
@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TransformTopGraphPass(const ResourcePtr &res) {
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Transform top graph error.";
|
||||
}
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
if (opt::FuncGraphHasTupleInput(func_graph)) {
|
||||
opt::GraphTupleParamTransform graph_trans;
|
||||
func_graph = graph_trans(func_graph, res->manager());
|
||||
res->set_func_graph(func_graph);
|
||||
AbstractBasePtrList abs_spec_list;
|
||||
auto ¶ms = func_graph->parameters();
|
||||
std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
|
||||
[](AnfNodePtr node) { return node->abstract(); });
|
||||
res->set_args_spec(abs_spec_list);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ValidatePass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
|
@ -388,6 +421,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"cconv", CconvPass},
|
||||
{"opt_after_cconv", OptPassAfterCconvGroup},
|
||||
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
||||
{"tuple_transform", OptPassTransformGraphGroup},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_control_depend", AddControlDependPass}};
|
||||
|
@ -401,6 +435,10 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"opt_prepare", PrepareGroup},
|
||||
{"cconv", CconvPass}};
|
||||
|
||||
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
|
||||
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"cconv", CconvPass},
|
||||
{"transform_top", TransformTopGraphPass},
|
||||
{"transform_graph", OptPassTransformGraphGroup}};
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1387,9 +1387,46 @@ void PynativeExecutor::ClearRes() {
|
|||
resource_.reset();
|
||||
}
|
||||
|
||||
size_t GetTupleSize(const py::tuple &args) {
|
||||
size_t count = 0;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (py::isinstance<py::tuple>(args[i])) {
|
||||
count += GetTupleSize(args[i]);
|
||||
} else {
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
|
||||
for (size_t i = 0; i < arg.size(); i++) {
|
||||
if (py::isinstance<py::tuple>(arg[i])) {
|
||||
ConvertTupleArg(res, index, arg[i]);
|
||||
} else {
|
||||
(*res)[(*index)++] = arg[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::tuple ConvertArgs(const py::tuple &args) {
|
||||
size_t tuple_size = GetTupleSize(args);
|
||||
py::tuple res(tuple_size);
|
||||
size_t index = 0;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (py::isinstance<py::tuple>(args[i])) {
|
||||
ConvertTupleArg(&res, &index, args[i]);
|
||||
} else {
|
||||
res[index++] = args[i];
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
|
||||
VectorRef arg_list;
|
||||
pipeline::ProcessVmArgInner(args, resource_, &arg_list);
|
||||
py::tuple converted_args = ConvertArgs(args);
|
||||
pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list);
|
||||
if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
|
||||
!resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
|
||||
MS_LOG(EXCEPTION) << "Can't find run graph func for ";
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import RowTensor
|
||||
from mindspore import context, nn, Tensor, ParameterTuple
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
def setup_module():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
|
||||
|
||||
|
||||
class _Grad(nn.Cell):
|
||||
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
|
||||
super().__init__()
|
||||
self.network = network
|
||||
self.grad = grad
|
||||
self.sens_param = self.grad.sens_param
|
||||
self.wrt_params = wrt_params
|
||||
self.real_inputs_count = real_inputs_count
|
||||
if self.wrt_params:
|
||||
self.params = ParameterTuple(self.network.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
if self.wrt_params:
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network, self.params)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
|
||||
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network)(*real_inputs, sense_param_inputs)
|
||||
|
||||
|
||||
class GradOfFirstInput(_Grad):
|
||||
"""
|
||||
get grad of first input
|
||||
"""
|
||||
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=C.GradOperation(sens_param=sens_param),
|
||||
network=network, real_inputs_count=real_inputs_count)
|
||||
|
||||
|
||||
class GradOfAllInputs(_Grad):
|
||||
"""
|
||||
get grad of first input
|
||||
"""
|
||||
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
|
||||
network=network, real_inputs_count=real_inputs_count)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_row_tensor_in_while():
|
||||
class RowTensorValuesDouble(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
class RowTensorValuesAdd2(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
class RowTensorWithControlWhile(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super().__init__()
|
||||
self.op1 = RowTensorValuesDouble()
|
||||
self.op2 = RowTensorValuesAdd2()
|
||||
self.dense_shape = dense_shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, a, b, indices, values):
|
||||
x = RowTensor(indices, values, self.dense_shape)
|
||||
x = self.op2(x)
|
||||
while a > b:
|
||||
x = self.op1(x)
|
||||
b = b + 1
|
||||
return x.indices, x.values, x.dense_shape
|
||||
a = Tensor(np.array(3).astype(np.int32))
|
||||
b = Tensor(np.array(0).astype(np.int32))
|
||||
indices = Tensor(np.array([0, 2]).astype(np.int32))
|
||||
values = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
dense_shape = (5, 2)
|
||||
net = RowTensorWithControlWhile(dense_shape)
|
||||
out = net(a, b, indices, values)
|
||||
assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0)
|
||||
assert np.allclose(values.asnumpy()*24, out[1].asnumpy(), .0, .0)
|
||||
assert dense_shape == out[2]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parser_switch_layer_inputs_tuple():
|
||||
class Add(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
y = self.op(x[0], x[1])
|
||||
return self.op(x[0], y)
|
||||
|
||||
class Mul(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
y = self.op(x[0], x[1])
|
||||
return self.op(x[0], y)
|
||||
|
||||
class MulTwoInput(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
|
||||
@ms_function
|
||||
def construct(self, x, y):
|
||||
y = self.op(x, y)
|
||||
return self.op(x, y)
|
||||
|
||||
class TwoInputTupleFinalNet(nn.Cell):
|
||||
def __init__(self, funcs):
|
||||
super().__init__()
|
||||
self.funcs = funcs
|
||||
|
||||
@ms_function
|
||||
def construct(self, i, inputa, inputb):
|
||||
inputs = (inputa, inputb)
|
||||
x = self.funcs[i](inputs)
|
||||
return x
|
||||
|
||||
func1 = Add()
|
||||
func2 = Mul()
|
||||
|
||||
funcs = (func1, func2)
|
||||
net = TwoInputTupleFinalNet(funcs)
|
||||
|
||||
input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(1, mstype.int32)
|
||||
netout = net(i, input_data, input2)
|
||||
net_good = MulTwoInput()
|
||||
goodout = net_good(input_data, input2)
|
||||
assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_imagenet():
|
||||
class ImageGradients(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.imagegradients = nn.ImageGradients()
|
||||
|
||||
def construct(self, inputs):
|
||||
return self.imagegradients(inputs)
|
||||
|
||||
net = ImageGradients()
|
||||
net_me = GradOfFirstInput(net, real_inputs_count=1)
|
||||
net_me.set_train()
|
||||
input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)
|
||||
output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32),
|
||||
Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32))
|
||||
net_me(input_data, *output_grad)
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
from mindspore import RowTensor
|
||||
from mindspore import context, nn, Tensor, ParameterTuple
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
def setup_module():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
|
||||
|
||||
|
||||
class _Grad(nn.Cell):
|
||||
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
|
||||
super().__init__()
|
||||
self.network = network
|
||||
self.grad = grad
|
||||
self.sens_param = self.grad.sens_param
|
||||
self.wrt_params = wrt_params
|
||||
self.real_inputs_count = real_inputs_count
|
||||
if self.wrt_params:
|
||||
self.params = ParameterTuple(self.network.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
if self.wrt_params:
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network, self.params)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
|
||||
|
||||
if self.real_inputs_count is None or self.sens_param is False:
|
||||
return self.grad(self.network)(*inputs)
|
||||
real_inputs = inputs[:self.real_inputs_count]
|
||||
sense_param_inputs = inputs[self.real_inputs_count:]
|
||||
return self.grad(self.network)(*real_inputs, sense_param_inputs)
|
||||
|
||||
|
||||
class GradOfFirstInput(_Grad):
|
||||
"""
|
||||
get grad of first input
|
||||
"""
|
||||
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=C.GradOperation(sens_param=sens_param),
|
||||
network=network, real_inputs_count=real_inputs_count)
|
||||
|
||||
|
||||
class GradOfAllInputs(_Grad):
|
||||
"""
|
||||
get grad of first input
|
||||
"""
|
||||
|
||||
def __init__(self, network, sens_param=True, real_inputs_count=None):
|
||||
super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
|
||||
network=network, real_inputs_count=real_inputs_count)
|
||||
|
||||
|
||||
def test_row_tensor_in_while():
|
||||
class RowTensorValuesDouble(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
class RowTensorValuesAdd2(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
dense_shape = x.dense_shape
|
||||
return RowTensor(indices, values, dense_shape)
|
||||
|
||||
class RowTensorWithControlWhile(nn.Cell):
|
||||
def __init__(self, dense_shape):
|
||||
super().__init__()
|
||||
self.op1 = RowTensorValuesDouble()
|
||||
self.op2 = RowTensorValuesAdd2()
|
||||
self.dense_shape = dense_shape
|
||||
|
||||
@ms_function
|
||||
def construct(self, a, b, indices, values):
|
||||
x = RowTensor(indices, values, self.dense_shape)
|
||||
x = self.op2(x)
|
||||
while (a > b):
|
||||
x = self.op1(x)
|
||||
b = b + 1
|
||||
return x.indices, x.values, x.dense_shape
|
||||
a = Tensor(np.array(3).astype(np.int32))
|
||||
b = Tensor(np.array(0).astype(np.int32))
|
||||
indices = Tensor(np.array([0, 2]).astype(np.int32))
|
||||
values = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
dense_shape = (5, 2)
|
||||
|
||||
net = RowTensorWithControlWhile(dense_shape)
|
||||
net(a, b, indices, values)
|
||||
|
||||
|
||||
def test_multi_out_sens():
|
||||
class ImageGradients(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
resa = x * y
|
||||
resb = y * z
|
||||
resc = x * z
|
||||
return resa, (resb, resc)
|
||||
|
||||
net = ImageGradients()
|
||||
net_me = GradOfAllInputs(net, real_inputs_count=3)
|
||||
net_me.set_train()
|
||||
input_data = Tensor(np.ones([32]), dtype=mstype.float32)
|
||||
output_grad = (Tensor(np.ones([32]), dtype=mstype.float32),
|
||||
(Tensor(np.ones([32]), dtype=mstype.float32), Tensor(np.ones([32]), dtype=mstype.float32)))
|
||||
net_me(input_data, input_data, input_data, *output_grad)
|
Loading…
Reference in New Issue