!30253 Support more ops for dump flag
Merge pull request !30253 from huanghui/enhance-dump-flag
This commit is contained in:
commit
12d8906e2e
|
@ -12,9 +12,8 @@ mindspore.set_dump
|
|||
|
||||
.. Note::
|
||||
- 此API只在Ascend后端的图模式有效。
|
||||
- 当target是一个Cell且enabled设置为True时,Cell实例及其子Cell实例的Primitive将递归启用Dump。如果算子不是Cell实例的成员,则不会为该算子启用Dump(例如,在construct方法中直接使用的 `functional 算子 <https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.ops.html#functional>`_ )。要使此API生效,请在Cell的__init__方法中使用self.some_op = SomeOp()的写法。
|
||||
- 使用set_dump(Cell, True)后,Cell正向计算中的算子会被Dump,大多数反向计算(梯度运算产生的计算)不会被Dump。然而,由于图的优化,一些反向计算的数据仍然会被Dump。可以忽略文件名中包含“Gradients”的反向计算数据。
|
||||
- 此API只支持训练开始前调用。如果在训练过程中调用这个API,可能不会有效果。
|
||||
- 使用set_dump(Cell, True)后,Cell正向计算和反向计算(梯度运算产生的计算)中的算子会被Dump。
|
||||
- 对于 `nn.SoftmaxCrossEntropyWithLogits 层 <https://www.mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.SoftmaxCrossEntropyWithLogits.html#mindspore.nn.SoftmaxCrossEntropyWithLogits>`_ ,正向计算和反向计算使用同一组算子。因此,只能看到反向计算中的Dump数据。请注意,当使用sparse=True和reduce=“mean”初始化时,nn.SoftmaxCrossEntropyWithLogits层也将在内部使用这些算子。
|
||||
|
||||
**参数:**
|
||||
|
|
|
@ -586,9 +586,9 @@ void DumpCNode(const CNodePtr &node, const FuncGraphPtr &sub_graph, OrderedMap<A
|
|||
DumpKernelInfo(node, gsub);
|
||||
|
||||
if (dump_full_name) {
|
||||
gsub->buffer << " : # fullname_with_scope: (" << node->fullname_with_scope() << ")" << std::endl;
|
||||
gsub->buffer << " # fullname_with_scope: (" << node->fullname_with_scope() << ")" << std::endl;
|
||||
} else {
|
||||
gsub->buffer << " : # scope: (" << node->scope()->name() << ")" << std::endl;
|
||||
gsub->buffer << " # scope: (" << node->scope()->name() << ")" << std::endl;
|
||||
}
|
||||
|
||||
// Print debug info
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "mindspore/core/load_mindir/load_model.h"
|
||||
#include "utils/system/sha256.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
|
@ -522,6 +523,16 @@ std::vector<NodeDebugInfoPtr> GeneratePrimalDebugInfo(const ValueNodePtr &value_
|
|||
return primal_debug_infos;
|
||||
}
|
||||
|
||||
void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
|
||||
if (prim == nullptr || bprop_fg == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto attr = prim->GetAttr(kAttrDump);
|
||||
if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) {
|
||||
bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true);
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
|
||||
const pipeline::ResourceBasePtr &resources) {
|
||||
if (!IsValueNode<Primitive>(value_node)) {
|
||||
|
@ -552,6 +563,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
bprop_fg = GetPrimBprop(prim, value_node, resources);
|
||||
}
|
||||
|
||||
SetDumpFlag(prim, bprop_fg);
|
||||
AdjustForAutoMonad(prim, bprop_fg);
|
||||
mindspore::HashMap<std::string, ValuePtr> primal_attrs;
|
||||
std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_EXPAND_DUMP_FLAG_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_EXPAND_DUMP_FLAG_H_
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
|
||||
namespace mindspore::opt::irpass {
|
||||
const PrimitiveSet dump_skipped_prim_set = {prim::kPrimReturn, prim::kPrimDepend, prim::kPrimMakeTuple,
|
||||
prim::kPrimTupleGetItem, prim::kPrimUpdateState, prim::kPrimLoad,
|
||||
prim::kPrimPrint, prim::kPrimPartial};
|
||||
|
||||
// Expand dump flag to all of cnodes if parent graph has dump flag.
|
||||
class ExpandDumpFlag {
|
||||
public:
|
||||
bool operator()(const FuncGraphPtr &, const OptimizerPtr &optimizer) {
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
auto manager = optimizer->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::set<FuncGraphPtr> seen;
|
||||
auto graph_filter = [&seen](const FuncGraphPtr &graph) {
|
||||
if (seen.find(graph) != seen.end() ||
|
||||
(graph->has_attr(FUNC_GRAPH_FLAG_DUMP) && !graph->has_flag(FUNC_GRAPH_FLAG_DUMP))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
for (auto &func_graph : manager->func_graphs()) {
|
||||
if (!func_graph->has_flag(FUNC_GRAPH_FLAG_DUMP) || seen.find(func_graph) != seen.end()) {
|
||||
continue;
|
||||
}
|
||||
std::set<FuncGraphPtr> traverse_graphs;
|
||||
|
||||
SuccFunc succ_func = std::bind(SuccWithFilter, graph_filter, std::placeholders::_1);
|
||||
auto nodes = TopoSort(func_graph->get_return(), succ_func);
|
||||
for (const auto &node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_graph = node->func_graph();
|
||||
if (seen.find(node_graph) != seen.end()) {
|
||||
continue;
|
||||
}
|
||||
traverse_graphs.insert(node_graph);
|
||||
// If the node need be ignored or the dump flag is set by false, do not set true.
|
||||
if (!node->isa<CNode>() || IsOneOfPrimitiveCNode(node, dump_skipped_prim_set) ||
|
||||
(AnfUtils::HasDumpFlag(node) && !AnfUtils::GetDumpFlag(node))) {
|
||||
continue;
|
||||
}
|
||||
AnfUtils::SetDumpFlag(node);
|
||||
}
|
||||
for (auto graph : traverse_graphs) {
|
||||
if (graph != nullptr && graph->has_attr(FUNC_GRAPH_FLAG_DUMP)) {
|
||||
graph->erase_flag(FUNC_GRAPH_FLAG_DUMP);
|
||||
}
|
||||
}
|
||||
seen.insert(traverse_graphs.begin(), traverse_graphs.end());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
} // namespace mindspore::opt::irpass
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_EXPAND_DUMP_FLAG_H_
|
|
@ -2582,6 +2582,9 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|||
if (py::hasattr(cell, "construct")) {
|
||||
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
|
||||
}
|
||||
if (current_graph->has_flag(FUNC_GRAPH_FLAG_DUMP) && func_graph->has_flag(FUNC_GRAPH_FLAG_DUMP)) {
|
||||
current_graph->erase_flag(FUNC_GRAPH_FLAG_DUMP);
|
||||
}
|
||||
|
||||
auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
|
||||
MS_EXCEPTION_IF_NULL(current_graph->get_return());
|
||||
|
|
|
@ -55,6 +55,7 @@
|
|||
#include "frontend/optimizer/irpass/taylor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/parameter_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/expand_dump_flag.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#include "ps/util.h"
|
||||
#include "ps/ps_context.h"
|
||||
|
@ -355,7 +356,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
opt::OptPassConfig recompute_prepare = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
|
||||
|
||||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||
OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
|
||||
OptPassGroupMap map_a({{"expand_dump_flag", opt::OptPassConfig(opt::irpass::ExpandDumpFlag())},
|
||||
{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
|
||||
{"a_1", a_1},
|
||||
{"recompute_prepare", recompute_prepare},
|
||||
{"updatestate_depend_eliminate", updatestate_depend_eliminate},
|
||||
|
@ -385,7 +387,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
|
||||
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
auto opt_a = GetOptPassesA(irpass);
|
||||
constexpr auto a1_a2_len = 7;
|
||||
constexpr auto a1_a2_len = 9;
|
||||
OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
|
||||
return a1_a2;
|
||||
}
|
||||
|
|
|
@ -388,6 +388,10 @@ bool IrExportBuilder::BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_i
|
|||
MS_EXCEPTION_IF_NULL(graph_proto);
|
||||
for (auto attr : func_graph->attrs()) {
|
||||
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
|
||||
auto iter = g_export_attr_blacklist.find(attr.first);
|
||||
if (iter != g_export_attr_blacklist.end()) {
|
||||
continue;
|
||||
}
|
||||
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
|
||||
attr_proto->set_name(attr.first);
|
||||
if (!SetValueToAttributeProto(attr.second, attr_proto)) {
|
||||
|
|
|
@ -90,6 +90,7 @@ const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
|||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
|
||||
const char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";
|
||||
const char FUNC_GRAPH_FLAG_DUMP[] = "dump";
|
||||
|
||||
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
|
||||
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
|
||||
|
|
|
@ -256,6 +256,31 @@ std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &
|
|||
return vecs;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
if (node == nullptr) {
|
||||
return vecs;
|
||||
}
|
||||
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto graph = GetValueNode<FuncGraphPtr>(node);
|
||||
if (graph_filter != nullptr && graph_filter(graph)) {
|
||||
return vecs;
|
||||
}
|
||||
|
||||
auto &ret = graph->return_node();
|
||||
if (ret != nullptr) {
|
||||
vecs.push_back(ret);
|
||||
}
|
||||
return vecs;
|
||||
} else {
|
||||
if (node->isa<CNode>()) {
|
||||
PushSuccessors(node->cast<CNodePtr>(), &vecs);
|
||||
}
|
||||
return vecs;
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
|
||||
static std::vector<AnfNodePtr> empty_inputs;
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <set>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <deque>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
|
@ -41,6 +42,7 @@ enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };
|
|||
|
||||
using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
|
||||
using FilterFunc = std::function<bool(const AnfNodePtr &)>;
|
||||
using GraphFilterFunc = std::function<bool(const FuncGraphPtr &)>;
|
||||
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
|
||||
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
|
||||
using MatchFunc = std::function<bool(const CNodePtr &)>;
|
||||
|
@ -50,6 +52,7 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node);
|
|||
MS_CORE_API std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node);
|
||||
MS_CORE_API std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node);
|
||||
std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node);
|
||||
MS_CORE_API std::vector<AnfNodePtr> SuccWithFilter(const GraphFilterFunc &graph_filter, const AnfNodePtr &node);
|
||||
|
||||
MS_CORE_API const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -424,6 +424,17 @@ bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool AnfUtils::HasDumpFlag(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
if (prim != nullptr) {
|
||||
return prim->HasAttr(kAttrDump);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AnfUtils::IsCustomActorNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->has_user_data<CustomActorInfo>();
|
||||
|
|
|
@ -73,6 +73,8 @@ class MS_CORE_API AnfUtils {
|
|||
static void SetDumpFlag(const AnfNodePtr &node);
|
||||
// Get dump flag from CNode's primitive.
|
||||
static bool GetDumpFlag(const AnfNodePtr &node);
|
||||
// Check whether the node has dump flag or not.
|
||||
static bool HasDumpFlag(const AnfNodePtr &node);
|
||||
static AbstractScope GetAbstractLock(const AnfNode *node);
|
||||
static void OpenAbstractLock();
|
||||
static void CloseAbstractLock();
|
||||
|
|
|
@ -35,23 +35,12 @@ def set_dump(target, enabled=True):
|
|||
|
||||
Note:
|
||||
1. This API is only effective for GRAPH_MODE with Ascend backend.
|
||||
2. When target is a cell and enabled is True, this API will enable
|
||||
dump for the primitive operator members of the cell instance and
|
||||
its child cell instances recursively. If an operator is not a
|
||||
member of the cell instance, the dump flag will not be set for
|
||||
this operator (e.g. `functional operators
|
||||
<https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.ops.html#functional>`_ used directly in
|
||||
construct method). To make this API effective, please use
|
||||
self.some_op = SomeOp() in your cell's __init__ method.
|
||||
3. After using set_dump(cell, True), operators in forward computation
|
||||
of the cell will be dumped. Most backward computation (computation
|
||||
generated by the grad operations) will not be dumped by design.
|
||||
However, due to the graph optimization, a few backward computation
|
||||
data will still be dumped. You can ignore the backward computation
|
||||
data which contains "Gradients" in their filenames.
|
||||
4. This API only supports being called before training starts.
|
||||
2. This API only supports being called before training starts.
|
||||
If you call this API during training, it may not be effective.
|
||||
5. For `nn.SparseSoftmaxCrossEntropyWithLogits
|
||||
3. After using set_dump(cell, True), operators in forward and backward
|
||||
computation (computation generated by the grad operations) of the
|
||||
cell will be dumped.
|
||||
4. For `nn.SparseSoftmaxCrossEntropyWithLogits
|
||||
<https://www.mindspore.cn/docs/api/en/master/api_python/nn/
|
||||
mindspore.nn.SoftmaxCrossEntropyWithLogits.html#mindspore.nn
|
||||
.SoftmaxCrossEntropyWithLogits>`_ layer, the forward
|
||||
|
@ -132,16 +121,15 @@ def set_dump(target, enabled=True):
|
|||
"before calling set_dump.")
|
||||
|
||||
# The actual set dump logic.
|
||||
mode = "true" if enabled else "false"
|
||||
if isinstance(target, nn.Cell):
|
||||
primitives = getattr(target, "_primitives", {})
|
||||
for value in primitives.values():
|
||||
if value:
|
||||
value.add_prim_attr("dump", mode)
|
||||
target.add_flags(dump=enabled)
|
||||
for cell in target.cells():
|
||||
set_dump(cell, enabled)
|
||||
return
|
||||
|
||||
primitives = getattr(target, "_primitives", {})
|
||||
for value in primitives.values():
|
||||
if value and "dump" in value.attrs:
|
||||
set_dump(value, enabled)
|
||||
|
||||
if isinstance(target, Primitive):
|
||||
target.add_prim_attr("dump", mode)
|
||||
return
|
||||
target.add_prim_attr("dump", "true" if enabled else "false")
|
||||
|
|
|
@ -22,7 +22,7 @@ import glob
|
|||
from enum import Enum
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor, set_dump
|
||||
from mindspore import Tensor, set_dump, ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import Dense
|
||||
|
@ -33,11 +33,13 @@ from mindspore.nn import WithLossCell
|
|||
from dump_test_utils import generate_cell_dump_json, check_dump_structure
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
class IsDump(Enum):
    """Selector used by the dump tests.

    NOTE(review): judging by the member names, these stand for "call set_dump
    with True", "call set_dump with False" and "do not call set_dump" -- confirm
    against the tests that consume this enum.
    """

    SET_DUMP_TRUE = 1
    SET_DUMP_FALSE = 2
    SET_NONE = 3
|
||||
|
||||
|
||||
class ReluReduceMeanDenseRelu(Cell):
|
||||
def __init__(self, kernel, bias, in_channel, num_class):
|
||||
super().__init__()
|
||||
|
@ -101,12 +103,18 @@ def test_ascend_cell_dump():
|
|||
check_dump_structure(dump_path, dump_config_path, 1, 1, 1)
|
||||
|
||||
# make sure 2 relu dump files are generated with correct name prefix
|
||||
assert len(os.listdir(dump_file_path)) == 2
|
||||
assert len(os.listdir(dump_file_path)) == 3
|
||||
relu_file_name = "ReLU.Default_network-WithLossCell__backbone-ReluReduceMeanDenseRelu_ReLU-op*.*.*.*"
|
||||
relu_file1 = glob.glob(os.path.join(dump_file_path, relu_file_name))[0]
|
||||
relu_file2 = glob.glob(os.path.join(dump_file_path, relu_file_name))[1]
|
||||
assert relu_file1
|
||||
assert relu_file2
|
||||
|
||||
# make sure 1 ReluGrad dump files are generated with correct name prefix
|
||||
relu_grad_file_name = "ReluGrad.Gradients_Default_network-WithLossCell__backbone" \
|
||||
"-ReluReduceMeanDenseRelu_gradReLU_ReluGrad-op*.*.*.*"
|
||||
relu_grad_file1 = glob.glob(os.path.join(dump_file_path, relu_grad_file_name))[0]
|
||||
assert relu_grad_file1
|
||||
del os.environ['MINDSPORE_DUMP_CONFIG']
|
||||
|
||||
|
||||
|
@ -170,6 +178,7 @@ def test_ascend_cell_empty_dump():
|
|||
assert not os.path.exists(dump_file_path)
|
||||
del os.environ['MINDSPORE_DUMP_CONFIG']
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -203,3 +212,48 @@ def test_ascend_cell_dump_set_enable_false():
|
|||
mean_file = glob.glob(os.path.join(dump_file_path, mean_file_name))[0]
|
||||
assert mean_file
|
||||
del os.environ['MINDSPORE_DUMP_CONFIG']
|
||||
|
||||
|
||||
class OperateSymbolNet(Cell):
    """Cell whose construct mixes an explicit ops.Add call with the `-` and `/` symbols."""

    def construct(self, x):
        # One explicit primitive call followed by two symbol-expressed operations.
        added = ops.Add()(x, 1)
        subtracted = added - 1
        divided = subtracted / 1
        return divided
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_cell_dump_with_operate_symbol():
    """
    Feature: Cell Dump
    Description: Test cell dump
    Expectation: Operators which are expressed by symbols will be dumped
    """
    if sys.platform != 'linux':
        return
    with tempfile.TemporaryDirectory(dir='/tmp') as tmp_dir:
        dump_path = os.path.join(tmp_dir, 'cell_dump')
        dump_config_path = os.path.join(tmp_dir, 'cell_dump.json')
        generate_cell_dump_json(dump_path, dump_config_path, 'test_async_dump', 2)
        os.environ['MINDSPORE_DUMP_CONFIG'] = dump_config_path
        # Always restore the environment, even when an assertion below fails,
        # so the dump configuration does not leak into other tests.
        try:
            if os.path.isdir(dump_path):
                shutil.rmtree(dump_path)

            net = OperateSymbolNet()
            x = Tensor(np.ones((1000,)).astype(np.float32))
            set_dump(net)
            net(x)

            dump_file_path = os.path.join(dump_path, 'rank_0', 'Net', '0', '0')
            # Give the dump up to five seconds to show up; stop waiting as soon as it exists.
            for _ in range(5):
                if os.path.exists(dump_file_path):
                    break
                time.sleep(1)
            check_dump_structure(dump_path, dump_config_path, 1, 1, 1)

            # make sure directory has dumped files with enabled=True
            # (presumably one file per operator used in OperateSymbolNet.construct)
            assert len(os.listdir(dump_file_path)) == 3
        finally:
            os.environ.pop('MINDSPORE_DUMP_CONFIG', None)
|
||||
|
|
|
@ -44,7 +44,7 @@ def test_set_dump_on_cell():
|
|||
net = MyNet()
|
||||
set_dump(net.relu1)
|
||||
|
||||
assert net.relu1.relu.attrs["dump"] == "true"
|
||||
assert net.relu1.get_flags()["dump"] is True
|
||||
|
||||
|
||||
def test_set_dump_on_primitive():
|
||||
|
@ -84,3 +84,51 @@ def test_set_dump_warning():
|
|||
set_dump(op)
|
||||
assert "Only Ascend device target is supported" in str(w[-2].message)
|
||||
assert "Only GRAPH_MODE is supported" in str(w[-1].message)
|
||||
|
||||
|
||||
def test_set_dump_on_cell_with_false():
    """
    Feature: Python API set_dump on cell with False.
    Description: Enable dump on a child cell, then call set_dump with False on its parent cell.
    Expectation: Success.
    """

    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu1 = nn.ReLU()

        def construct(self, x):
            return self.relu1(x)

    model = MyNet()

    # Enabling on the child cell sets its "dump" flag to True.
    set_dump(model.relu1)
    assert model.relu1.get_flags()["dump"] is True

    # Disabling on the parent must reach the child cell as well.
    set_dump(model, False)
    assert model.relu1.get_flags()["dump"] is False
|
||||
|
||||
|
||||
def test_set_dump_on_primitive_with_false():
    """
    Feature: Python API set_dump on primitive with False.
    Description: Enable dump on a primitive member, then call set_dump with False on the owning cell.
    Expectation: Success.
    """

    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu1 = ops.ReLU()

        def construct(self, x):
            return self.relu1(x)

    model = MyNet()

    # Enabling on the primitive records the string attribute "true".
    set_dump(model.relu1)
    assert model.relu1.attrs.get("dump") == "true"

    # Disabling on the owning cell flips the primitive's attribute to "false".
    set_dump(model, False)
    assert model.relu1.attrs.get("dump") == "false"
|
||||
|
|
Loading…
Reference in New Issue