fix summary isolation
This commit is contained in:
parent
78b6fd17d6
commit
c2bd061889
|
@ -362,6 +362,7 @@ void MemReuseUtil::SetReuseRefCount() {
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
void MemReuseUtil::SetSummaryNodesRefCount() {
|
||||
bool summary_exist = graph_->summary_node_exist();
|
||||
if (!summary_exist) {
|
||||
|
@ -393,6 +394,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
|
|||
#endif
|
||||
MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
|
||||
}
|
||||
#endif
|
||||
|
||||
void MemReuseUtil::SetRefNodesInputRefCount() {
|
||||
size_t total_size = 0;
|
||||
|
@ -457,7 +459,9 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
|
|||
}
|
||||
SetKernelDefMap();
|
||||
SetReuseRefCount();
|
||||
#ifndef ENABLE_SECURITY
|
||||
SetSummaryNodesRefCount();
|
||||
#endif
|
||||
SetRefNodesInputRefCount();
|
||||
SetWorkSpaceList();
|
||||
#ifdef MEM_REUSE_DEBUG
|
||||
|
|
|
@ -63,7 +63,9 @@ class MemReuseUtil {
|
|||
void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
|
||||
void SetKernelDefInputs();
|
||||
void SetReuseRefCount();
|
||||
#ifndef ENABLE_SECURITY
|
||||
void SetSummaryNodesRefCount();
|
||||
#endif
|
||||
void SetRefNodesInputRefCount();
|
||||
// Set the reference count of graph output specially.
|
||||
void SetGraphOutputRefCount();
|
||||
|
|
|
@ -340,7 +340,9 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
InitBasicInfo(graph);
|
||||
IndependentNodeOutputProcess(graph);
|
||||
#ifndef ENABLE_SECURITY
|
||||
SummaryInputProcess(graph);
|
||||
#endif
|
||||
RefNodeProcess(graph);
|
||||
NonTaskSplitProcess(graph);
|
||||
UnReuseNodeProcess(graph);
|
||||
|
@ -743,6 +745,7 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "Special Tensor total size: Independent Node output " << total_size;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool summary_exist = graph->summary_node_exist();
|
||||
|
@ -782,6 +785,7 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
|
|||
|
||||
MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
|
||||
}
|
||||
#endif
|
||||
|
||||
void Somas::RefNodeProcess(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -115,7 +115,9 @@ class Somas {
|
|||
void InitSomasInputTensors(const session::KernelGraph *graph);
|
||||
void GetNextOutputProcess(const session::KernelGraph *graph);
|
||||
void IndependentNodeOutputProcess(const session::KernelGraph *graph);
|
||||
#ifndef ENABLE_SECURITY
|
||||
void SummaryInputProcess(const session::KernelGraph *graph);
|
||||
#endif
|
||||
void RefNodeProcess(const session::KernelGraph *graph);
|
||||
void NonTaskSplitProcess(const session::KernelGraph *graph);
|
||||
void UnReuseNodeProcess(const session::KernelGraph *graph);
|
||||
|
|
|
@ -58,10 +58,16 @@ bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
|
|||
}
|
||||
|
||||
bool IsRealKernelCNode(const CNodePtr &cnode) {
|
||||
#ifndef ENABLE_SECURITY
|
||||
static const PrimitiveSet virtual_prims = {
|
||||
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
|
||||
prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn,
|
||||
prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
|
||||
#else
|
||||
static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
|
||||
prim::kPrimReturn, prim::kPrimPartial, prim::kPrimDepend,
|
||||
prim::kPrimUpdateState, prim::kPrimLoad};
|
||||
#endif
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << cnode->DebugString();
|
||||
}
|
||||
|
|
|
@ -607,7 +607,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
|
|||
device::KernelAdjust::GetInstance().InsertOverflowCheckOperations(NOT_NULL(root_graph));
|
||||
// build kernel
|
||||
BuildKernel(root_graph);
|
||||
#ifndef ENABLE_SECURITY
|
||||
SetSummaryNodes(root_graph.get());
|
||||
#endif
|
||||
// Alloc memory for child graph's inputs
|
||||
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
memo.clear();
|
||||
|
@ -634,6 +636,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
|
|||
return graph_id;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
|
||||
|
@ -649,6 +652,7 @@ void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph>
|
|||
}
|
||||
kernel_graph->set_summary_node_exist(false);
|
||||
}
|
||||
#endif
|
||||
|
||||
void AscendSession::BuildGraphImpl(GraphId graph_id) {
|
||||
MS_LOG(INFO) << "Start";
|
||||
|
@ -760,7 +764,9 @@ void AscendSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_g
|
|||
void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &, VectorRef *const) {
|
||||
// summary
|
||||
#ifndef ENABLE_SECURITY
|
||||
Summary(kernel_graph.get());
|
||||
#endif
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// load tensor from device for debugger
|
||||
if (debugger_ && debugger_->debugger_enabled()) {
|
||||
|
@ -1603,6 +1609,7 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
|
|||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph,
|
||||
std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -1640,6 +1647,7 @@ void AscendSession::SetSummaryNodes(KernelGraph *graph) {
|
|||
graph->set_summary_nodes(summary);
|
||||
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
|
||||
}
|
||||
#endif
|
||||
|
||||
void AscendSession::MergeGraphExecOrder() {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
|
|
|
@ -89,8 +89,10 @@ class AscendSession : public SessionBasic {
|
|||
private:
|
||||
// compile child graph when session have multiple child graphs
|
||||
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
||||
#ifndef ENABLE_SECURITY
|
||||
void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
||||
void SetSummaryNodes(KernelGraph *graph) override;
|
||||
#endif
|
||||
void InitRuntimeResource();
|
||||
void SelectKernel(const KernelGraph &kernel_graph) const;
|
||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
@ -128,7 +130,9 @@ class AscendSession : public SessionBasic {
|
|||
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
|
||||
// sync initial tensors' data to device
|
||||
void SyncInitialTenosrToDevice();
|
||||
#ifndef ENABLE_SECURITY
|
||||
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
#endif
|
||||
// create parameter to receive data from multiple branch output
|
||||
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
void SelectKernel(NotNull<KernelGraphPtr> root_graph);
|
||||
|
|
|
@ -146,7 +146,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
|
|||
MS_LOG(INFO) << "Assign kernel address";
|
||||
runtime_.AssignKernelAddress(graph.get());
|
||||
// set summary node
|
||||
#ifndef ENABLE_SECURITY
|
||||
SetSummaryNodes(graph.get());
|
||||
#endif
|
||||
runtime_.IncreaseSummaryRefCount(graph->summary_nodes());
|
||||
DumpGraph(graph);
|
||||
return graph_id;
|
||||
|
@ -205,7 +207,9 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
|
||||
void CPUSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &, VectorRef *const) {
|
||||
#ifndef ENABLE_SECURITY
|
||||
Summary(kernel_graph.get());
|
||||
#endif
|
||||
}
|
||||
|
||||
void CPUSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
|
|
|
@ -455,8 +455,10 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
|
|||
std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
|
||||
(void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
|
||||
#endif
|
||||
#ifndef ENABLE_SECURITY
|
||||
// Get summary nodes.
|
||||
SetSummaryNodes(graph.get());
|
||||
#endif
|
||||
// Dump .pb graph after graph optimization
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (save_graphs) {
|
||||
|
@ -523,9 +525,11 @@ void GPUSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
// Summary
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
|
||||
Summary(kernel_graph.get());
|
||||
}
|
||||
#endif
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
if (debugger_ && debugger_->DebuggerBackendEnabled()) {
|
||||
debugger_->LoadParametersAndConst(kernel_graph);
|
||||
|
|
|
@ -176,8 +176,10 @@ class KernelGraph : public FuncGraph {
|
|||
bool executable() const { return executable_; }
|
||||
// set executable of graph
|
||||
void set_executable(bool executable) { executable_ = executable; }
|
||||
#ifndef ENABLE_SECURITY
|
||||
// set summary_node of graph
|
||||
void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
|
||||
#endif
|
||||
// check whether exist summary node in graph
|
||||
bool summary_node_exist() const { return summary_node_exist_; }
|
||||
// set invalid inputs for control sink
|
||||
|
|
|
@ -355,6 +355,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
|
|||
MS_LOG(INFO) << tab_str;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
bool ExistSummaryNode(const KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto ret = graph->get_return();
|
||||
|
@ -368,6 +369,7 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
|
@ -1153,9 +1155,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
graph->set_manager(manager);
|
||||
}
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (ExistSummaryNode(graph.get())) {
|
||||
graph->set_summary_node_exist(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
UnifyMindIR(graph);
|
||||
// Update Graph Dynamic Shape Attr
|
||||
|
@ -1564,9 +1569,13 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
graph->SetInputNodes();
|
||||
SetInputNodeUsage(graph, manager);
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (ExistSummaryNode(graph.get())) {
|
||||
graph->set_summary_node_exist(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
all_out_graph->push_back(graph);
|
||||
return graph;
|
||||
}
|
||||
|
@ -1795,6 +1804,7 @@ void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::Te
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
|
||||
MS_EXCEPTION_IF_NULL(callback);
|
||||
summary_callback_ = callback;
|
||||
|
@ -1874,6 +1884,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
|
|||
// call callback function here
|
||||
summary_callback_(0, params_list);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
|
||||
|
|
|
@ -112,7 +112,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
#endif
|
||||
|
||||
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
|
||||
|
||||
|
@ -239,7 +241,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
|
||||
#ifndef ENABLE_SECURITY
|
||||
virtual void SetSummaryNodes(KernelGraph *graph);
|
||||
#endif
|
||||
|
||||
void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
|
@ -259,7 +263,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
|
||||
void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
|
||||
#ifndef ENABLE_SECURITY
|
||||
void Summary(KernelGraph *graph);
|
||||
#endif
|
||||
// create graph output for RunOp
|
||||
void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
|
||||
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
|
||||
|
|
|
@ -53,6 +53,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
// node because it is attribute or ge specific reason.
|
||||
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
|
||||
// converted to switch guarded.
|
||||
#ifndef ENABLE_SECURITY
|
||||
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
|
||||
{prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}},
|
||||
|
@ -78,6 +79,20 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
{prim::kPrimTile, {2}},
|
||||
{prim::kPrimExpandDims, {2}},
|
||||
{prim::kPrimHistogramSummary, {1}}});
|
||||
#else
|
||||
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
|
||||
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
|
||||
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
|
||||
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
|
||||
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
|
||||
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
|
||||
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
|
||||
{prim::kPrimGather, {3}}, {prim::kPrimReshape, {2}},
|
||||
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
|
||||
{prim::kPrimAssignSub, {1}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
|
||||
{prim::kPrimCumSum, {2}}, {prim::kPrimTile, {2}},
|
||||
{prim::kPrimExpandDims, {2}}});
|
||||
#endif
|
||||
for (auto &item : white_list) {
|
||||
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
|
||||
return IsPrimitiveCNode(node, item.first) && idx == index;
|
||||
|
|
|
@ -22,7 +22,9 @@
|
|||
#include "utils/symbolic.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "utils/summary/event_writer.h"
|
||||
#endif
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/mpi/mpi_config.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
@ -48,7 +50,9 @@ using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
|
|||
using Pipeline = mindspore::pipeline::Pipeline;
|
||||
using PrimitivePy = mindspore::PrimitivePy;
|
||||
using MetaFuncGraph = mindspore::MetaFuncGraph;
|
||||
#ifndef ENABLE_SECURITY
|
||||
using EventWriter = mindspore::summary::EventWriter;
|
||||
#endif // ENABLE_SECURITY
|
||||
using OpLib = mindspore::kernel::OpLib;
|
||||
using ParallelContext = mindspore::parallel::ParallelContext;
|
||||
using CostModelContext = mindspore::parallel::CostModelContext;
|
||||
|
@ -311,6 +315,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
}
|
||||
}});
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
(void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
|
||||
.def(py::init<const std::string &>())
|
||||
.def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
|
||||
|
@ -320,6 +325,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("Flush", &EventWriter::Flush, "Flush the event.")
|
||||
.def("Close", &EventWriter::Close, "Close the write.")
|
||||
.def("Shut", &EventWriter::Shut, "Final close the write.");
|
||||
#endif // ENABLE_SECURITY
|
||||
|
||||
(void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
|
||||
.def(py::init())
|
||||
|
|
|
@ -592,8 +592,10 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
|
|||
mem_reuse_util_ptr->SetReuseRefCount();
|
||||
// Can't free the device address of graph output, so set the reference count of graph output specially.
|
||||
mem_reuse_util_ptr->SetGraphOutputRefCount();
|
||||
#ifndef ENABLE_SECURITY
|
||||
// Can't free the device address of summary nodes, so set the reference count of summary nodes specially.
|
||||
mem_reuse_util_ptr->SetSummaryNodesRefCount();
|
||||
#endif
|
||||
auto graph_id = graph->graph_id();
|
||||
mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
|
||||
}
|
||||
|
|
|
@ -353,8 +353,9 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
|||
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->InitAllBucket(graph, device_context);
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
session_->SetSummaryNodes(graph.get());
|
||||
#endif
|
||||
SetSummaryNodesRefCount(graph.get());
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
// Dump .pb graph after graph optimization.
|
||||
|
@ -525,13 +526,17 @@ const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId g
|
|||
|
||||
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
#ifndef ENABLE_SECURITY
|
||||
session_->RegisterSummaryCallBackFunc(callback);
|
||||
#endif
|
||||
}
|
||||
|
||||
void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
for (const auto &graph : graphs) {
|
||||
#ifndef ENABLE_SECURITY
|
||||
session_->Summary(graph.get());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,10 +32,12 @@ REG_ADPT_DESC(Constant, kNameConst, ADPT_DESC(Constant, Const))
|
|||
// ScalarSummary
|
||||
INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Summary) = EMPTY_ATTR_MAP;
|
||||
#ifndef ENABLE_SECURITY
|
||||
REG_ADPT_DESC(ScalarSummary, prim::kPrimScalarSummary->name(), ADPT_DESC(Summary))
|
||||
REG_ADPT_DESC(ImageSummary, prim::kPrimImageSummary->name(), ADPT_DESC(Summary))
|
||||
REG_ADPT_DESC(TensorSummary, prim::kPrimTensorSummary->name(), ADPT_DESC(Summary))
|
||||
REG_ADPT_DESC(HistogramSummary, prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary))
|
||||
#endif
|
||||
REG_ADPT_DESC(Debug, prim::kPrimDebug->name(), ADPT_DESC(Summary))
|
||||
|
||||
// Data
|
||||
|
|
|
@ -12,6 +12,10 @@ if(NOT ENABLE_D AND NOT ENABLE_TESTCASES)
|
|||
file(GLOB_RECURSE _UTILS_D_SRC_FILES ./runtime_error_codes.cc)
|
||||
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_D_SRC_FILES})
|
||||
endif()
|
||||
if(ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE _UTILS_SUMMARY_FILES ./summary/event_writer.cc)
|
||||
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_SUMMARY_FILES})
|
||||
endif()
|
||||
|
||||
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
|
||||
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})
|
||||
|
|
|
@ -301,7 +301,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
|
|||
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
||||
}
|
||||
target_sess_->Init(device_id);
|
||||
#ifndef ENABLE_SECURITY
|
||||
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||
#endif
|
||||
target_device_ = target;
|
||||
}
|
||||
|
||||
|
@ -317,7 +319,9 @@ void MsBackend::CreateOtherSession(const std::string &target) {
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
other_sess_->Init(device_id);
|
||||
#ifndef ENABLE_SECURITY
|
||||
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||
#endif
|
||||
other_device_ = target;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,11 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"load_mindir/*.cc"
|
||||
)
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE _INFER_SUMMARY_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ops/*_summary.cc")
|
||||
list(REMOVE_ITEM CORE_SRC_LIST ${_INFER_SUMMARY_FILES})
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE PROTO_FILE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "proto/*.proto")
|
||||
if(NOT(BUILD_LITE))
|
||||
ms_protobuf_generate_py(PROTO_SRCS PY_HDRS PY_PYS ${PROTO_FILE})
|
||||
|
|
|
@ -548,10 +548,12 @@ inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive
|
|||
|
||||
// Debug ops
|
||||
inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert");
|
||||
#ifndef ENABLE_SECURITY
|
||||
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
|
||||
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
|
||||
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
|
||||
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
|
||||
#endif
|
||||
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
|
||||
|
||||
// Dynamic shape testing
|
||||
|
|
|
@ -392,12 +392,16 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
|
||||
if (inputs.size() > 1) {
|
||||
return GetOriginNodeTarget(inputs[1]);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
return kTargetUnDefined;
|
||||
}
|
||||
#endif
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
const size_t node_inputs_num = 3;
|
||||
if (inputs.size() >= node_inputs_num) {
|
||||
size_t use_index = 1;
|
||||
|
@ -521,6 +525,7 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
|
|||
if (target != kTargetUnDefined) {
|
||||
return target;
|
||||
}
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
|
||||
|
@ -528,6 +533,13 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
|
|||
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return GetVirtualNodeTarget(node);
|
||||
}
|
||||
#else
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return GetVirtualNodeTarget(node);
|
||||
}
|
||||
#endif
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
|
|
|
@ -41,6 +41,7 @@ std::string CNode::fullname_with_scope() {
|
|||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
|
||||
IsApply(prim::kPrimHistogramSummary)) {
|
||||
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
|
||||
|
@ -55,40 +56,40 @@ std::string CNode::fullname_with_scope() {
|
|||
name = tag + "[:Tensor]";
|
||||
}
|
||||
fullname_with_scope_ = name;
|
||||
} else {
|
||||
// cnode input 0 should be primitive ptr or funcgraph ptr
|
||||
auto value_ptr = input(0)->cast<ValueNodePtr>();
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
auto input_value = value_ptr->value();
|
||||
if (input_value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
auto prim = input_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(scope());
|
||||
fullname_with_scope_ = scope()->name() + "/";
|
||||
if (prim != nullptr) {
|
||||
fullname_with_scope_ += prim->name();
|
||||
} else {
|
||||
auto func_graph = input_value->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_flag != nullptr) {
|
||||
auto fg_name = GetValue<std::string>(fg_flag);
|
||||
fullname_with_scope_ += "GraphKernel_" + fg_name;
|
||||
} else {
|
||||
fullname_with_scope_ += func_graph->ToString();
|
||||
}
|
||||
}
|
||||
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
#endif
|
||||
// cnode input 0 should be primitive ptr or funcgraph ptr
|
||||
auto value_ptr = input(0)->cast<ValueNodePtr>();
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
auto input_value = value_ptr->value();
|
||||
if (input_value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
|
||||
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
auto prim = input_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(scope());
|
||||
fullname_with_scope_ = scope()->name() + "/";
|
||||
if (prim != nullptr) {
|
||||
fullname_with_scope_ += prim->name();
|
||||
} else {
|
||||
auto func_graph = input_value->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_flag != nullptr) {
|
||||
auto fg_name = GetValue<std::string>(fg_flag);
|
||||
fullname_with_scope_ += "GraphKernel_" + fg_name;
|
||||
} else {
|
||||
fullname_with_scope_ += func_graph->ToString();
|
||||
}
|
||||
}
|
||||
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
|
||||
return fullname_with_scope_;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
// clang-format off
|
||||
#ifndef ENABLE_SECURITY
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
|
@ -31,6 +32,16 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
||||
#else
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
|
||||
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "Return", "env_getitem",
|
||||
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
|
||||
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
||||
#endif
|
||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
|
||||
prim::kPrimMicroStepAllGather};
|
||||
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};
|
||||
|
|
|
@ -158,12 +158,18 @@ bool IsRealKernel(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
auto input = cnode->inputs()[0];
|
||||
#ifndef ENABLE_SECURITY
|
||||
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
|
||||
IsPrimitive(input, prim::kPrimTensorSummary) ||
|
||||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
|
||||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
|
||||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
|
||||
IsPrimitive(input, prim::kPrimPartial);
|
||||
#else
|
||||
bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
|
||||
IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
|
||||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
|
||||
#endif
|
||||
return !is_virtual_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""debug_ops"""
|
||||
from types import FunctionType, MethodType
|
||||
|
||||
|
@ -23,6 +22,7 @@ from ..._checkparam import Rel
|
|||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
|
||||
|
||||
|
||||
def _check_mode(class_name):
|
||||
"""Check for PyNative mode."""
|
||||
mode = context.get_context('mode')
|
||||
|
@ -87,6 +87,10 @@ class ScalarSummary(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ScalarSummary."""
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
|
||||
|
@ -126,6 +130,10 @@ class ImageSummary(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize ImageSummary."""
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
def __infer__(self, name, value):
|
||||
|
@ -178,6 +186,10 @@ class TensorSummary(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize TensorSummary."""
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
|
||||
|
@ -218,6 +230,10 @@ class HistogramSummary(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize HistogramSummary."""
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
self.add_prim_attr("side_effect_io", True)
|
||||
|
||||
def __infer__(self, name, value):
|
||||
|
|
|
@ -36,6 +36,7 @@ from mindspore.train.callback._dataset_graph import DatasetGraph
|
|||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.loss.loss import LossBase
|
||||
from mindspore.train._utils import check_value_type
|
||||
from ..._c_expression import security
|
||||
|
||||
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
|
||||
HYPER_CONFIG_LEN_LIMIT = 100000
|
||||
|
@ -194,6 +195,10 @@ class SummaryCollector(Callback):
|
|||
collect_tensor_freq=None,
|
||||
max_file_size=None,
|
||||
export_options=None):
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
super(SummaryCollector, self).__init__()
|
||||
|
||||
self._summary_dir = self._process_summary_dir(summary_dir)
|
||||
|
|
|
@ -23,7 +23,7 @@ from collections import defaultdict
|
|||
from mindspore import log as logger
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from ..._c_expression import Tensor
|
||||
from ..._c_expression import Tensor, security
|
||||
from ..._checkparam import Validator
|
||||
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event
|
||||
|
@ -146,6 +146,9 @@ class SummaryRecord:
|
|||
def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
|
||||
network=None, max_file_size=None, raise_exception=False, export_options=None):
|
||||
|
||||
if security.enable_security():
|
||||
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
|
||||
|
||||
self._event_writer = None
|
||||
self._mode, self._data_pool = 'train', defaultdict(list)
|
||||
self._status = {
|
||||
|
|
|
@ -23,8 +23,11 @@ import numpy as np
|
|||
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
|
||||
|
||||
from .._utils import _make_directory
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_init_event
|
||||
from ..._c_expression import security
|
||||
if not security.enable_security():
|
||||
from ..._c_expression import EventWriter_
|
||||
|
||||
|
||||
FREE_DISK_SPACE_TIMES = 32
|
||||
FILE_MODE = 0o400
|
||||
|
@ -41,7 +44,7 @@ class BaseWriter:
|
|||
"""Write some metadata etc."""
|
||||
|
||||
@property
|
||||
def writer(self) -> EventWriter_:
|
||||
def writer(self):
|
||||
"""Get the writer."""
|
||||
if self._writer is not None:
|
||||
return self._writer
|
||||
|
|
|
@ -384,6 +384,7 @@ def train_summary_record(test_writer, steps):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summary():
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
steps = 2
|
||||
|
|
|
@ -62,6 +62,7 @@ def train_summary_record(test_writer, steps):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summary_step2_summary_record1():
|
||||
"""Test record 10 step summary."""
|
||||
if platform.system() == "Windows":
|
||||
|
|
|
@ -126,6 +126,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summary_with_sink_mode_false(self):
|
||||
"""Test summary with sink mode false, and num samples is 64."""
|
||||
summary_dir = self._run_network(num_samples=10)
|
||||
|
@ -148,6 +149,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summary_with_sink_mode_true(self):
|
||||
"""Test summary with sink mode true, and num samples is 64."""
|
||||
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)
|
||||
|
@ -167,6 +169,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summarycollector_user_defind(self):
|
||||
"""Test SummaryCollector with user-defined."""
|
||||
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,
|
||||
|
|
|
@ -93,6 +93,7 @@ class TestSummaryOps:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
def test_summary_ops(self):
|
||||
"""Test summary operators."""
|
||||
ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)
|
||||
|
|
|
@ -167,6 +167,12 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
|
||||
)
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE _INFER_SUMMARY_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"../../../mindspore/core/ops/*_summary.cc"
|
||||
)
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST ${_INFER_SUMMARY_FILES})
|
||||
endif()
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
|
||||
|
|
|
@ -530,6 +530,7 @@ TEST_F(TestConvert, TestSquareOps) {
|
|||
ASSERT_TRUE(ret);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
TEST_F(TestConvert, TestScalarSummaryOps) {
|
||||
auto prim = prim::kPrimScalarSummary;
|
||||
// should have only 1 input.
|
||||
|
@ -548,6 +549,7 @@ TEST_F(TestConvert, TestHistogramSummaryOps) {
|
|||
bool ret = MakeDfGraph(prim, 2);
|
||||
ASSERT_TRUE(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(TestConvert, TestGreaterOps) {
|
||||
auto prim = std::make_shared<Primitive>("Greater");
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test nn ops """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
|
@ -25,6 +26,8 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
|
||||
from mindspore._c_expression import security
|
||||
from tests.security_utils import security_off_wrap
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
|
@ -33,6 +36,7 @@ from ....mindspore_test_framework.pipeline.forward.verify_exception \
|
|||
import pipeline_for_verify_exception_for_case_by_case_config
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1, padding=1):
|
||||
"""3x3 convolution """
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
|
@ -520,18 +524,10 @@ test_cases = [
|
|||
'desc_inputs': [[128, 32, 32, 64]],
|
||||
'desc_bprop': [[128, 32, 32, 64]],
|
||||
}),
|
||||
('ScalarSummary', {
|
||||
'block': ScalarSummaryNet(),
|
||||
'desc_inputs': [Tensor(2.2)],
|
||||
}),
|
||||
('L2Normalize', {
|
||||
'block': L2NormalizeNet(),
|
||||
'desc_inputs': [Tensor(np.array([[1.0, 2, 3], [4.0, 5, 6], [7.0, 8, 9]]), mindspore.float32)],
|
||||
}),
|
||||
('HistogramSummary', {
|
||||
'block': HistogramSummaryNet(),
|
||||
'desc_inputs': [[1, 2, 3]],
|
||||
}),
|
||||
('FusedBatchNormGrad', {
|
||||
'block': FusedBatchNormGrad(nn.BatchNorm2d(num_features=512, eps=1e-5, momentum=0.1)),
|
||||
'desc_inputs': [[64, 512, 7, 7], [64, 512, 7, 7]],
|
||||
|
@ -753,6 +749,33 @@ test_cases_for_verify_exception = [
|
|||
]
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
@non_graph_engine
|
||||
@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
|
||||
def test_summary_nn_ops():
|
||||
if security.enable_security():
|
||||
return []
|
||||
test_cases_for_summary_ops = [
|
||||
('ScalarSummary', {
|
||||
'block': ScalarSummaryNet(),
|
||||
'desc_inputs': [Tensor(2.2)],
|
||||
}),
|
||||
('HistogramSummary', {
|
||||
'block': HistogramSummaryNet(),
|
||||
'desc_inputs': [[1, 2, 3]],
|
||||
}),
|
||||
]
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
return test_cases_for_summary_ops
|
||||
|
||||
|
||||
def test_summary_nn_ops_security_on():
|
||||
if security.enable_security():
|
||||
with pytest.raises(ValueError) as exc:
|
||||
ScalarSummaryNet()
|
||||
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||
def test_compile():
|
||||
|
|
|
@ -31,6 +31,8 @@ from mindspore.ops.operations import _inner_ops as inner
|
|||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.nn.layer import normalization
|
||||
from mindspore._c_expression import security
|
||||
from tests.security_utils import security_off_wrap
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
|
@ -2866,16 +2868,6 @@ test_case_other_ops = [
|
|||
'block': P.IOU(),
|
||||
'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
|
||||
'desc_bprop': [convert([128, 256], np.float16)]}),
|
||||
('Summary', {
|
||||
'block': SummaryNet(),
|
||||
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
|
||||
Tensor(np.array([1.2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('HistogramSummary', {
|
||||
'block': HistogramSummaryNet(),
|
||||
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
|
||||
Tensor(np.array([1.2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('PopulationCount', {
|
||||
'block': P.PopulationCount(),
|
||||
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
|
||||
|
@ -3046,6 +3038,38 @@ def test_backward_exec():
|
|||
return test_backward_exec_case
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
@non_graph_engine
|
||||
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
|
||||
def test_summary_ops():
|
||||
if security.enable_security():
|
||||
return []
|
||||
test_cases_for_summary_ops = [
|
||||
('Summary', {
|
||||
'block': SummaryNet(),
|
||||
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
|
||||
Tensor(np.array([1.2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
('HistogramSummary', {
|
||||
'block': HistogramSummaryNet(),
|
||||
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
|
||||
Tensor(np.array([1.2]).astype(np.float32))],
|
||||
'skip': ['backward']}),
|
||||
]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
return test_cases_for_summary_ops
|
||||
|
||||
|
||||
def test_summary_ops_security_on():
|
||||
if security.enable_security():
|
||||
with pytest.raises(ValueError) as exc:
|
||||
SummaryNet()
|
||||
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
|
||||
with pytest.raises(ValueError) as exc:
|
||||
HistogramSummaryNet()
|
||||
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
|
||||
|
||||
|
||||
raise_set = [
|
||||
('Cast_Error', {
|
||||
'block': (P.Cast(), {'exception': TypeError}),
|
||||
|
|
|
@ -166,7 +166,6 @@ def test_ops():
|
|||
if [1, 2, 3] is not None:
|
||||
if self.str_a + self.str_b == "helloworld":
|
||||
if q == 86:
|
||||
print("hello world")
|
||||
return ret
|
||||
return x
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
from tests.summary_utils import SummaryReader
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
|
||||
|
@ -48,6 +49,7 @@ def _wrap_test_data(input_data: Tensor):
|
|||
}]
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary():
|
||||
"""Test histogram summary."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -61,6 +63,7 @@ def test_histogram_summary():
|
|||
assert event.summary.value[0].histogram.count == 6
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_multi_summary():
|
||||
"""Test histogram multiple step."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -83,6 +86,8 @@ def test_histogram_multi_summary():
|
|||
event = reader.read_event()
|
||||
assert event.summary.value[0].histogram.count == size
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary_empty_tensor():
|
||||
"""Test histogram summary, input is an empty tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -97,6 +102,7 @@ def test_histogram_summary_empty_tensor():
|
|||
assert event.summary.value[0].histogram.count == 0
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary_same_value():
|
||||
"""Test histogram summary, input is an ones tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -116,6 +122,7 @@ def test_histogram_summary_same_value():
|
|||
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary_high_dims():
|
||||
"""Test histogram summary, input is a 4-dimension tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -136,6 +143,7 @@ def test_histogram_summary_high_dims():
|
|||
assert event.summary.value[0].histogram.count == tensor_data.size
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary_nan_inf():
|
||||
"""Test histogram summary, input tensor has nan."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
@ -160,6 +168,7 @@ def test_histogram_summary_nan_inf():
|
|||
assert event.summary.value[0].histogram.nan_count == 1
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_histogram_summary_all_nan_inf():
|
||||
"""Test histogram summary, input tensor has no valid number."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
|
|
@ -23,6 +23,7 @@ from mindspore import Tensor
|
|||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
from mindspore.train.callback import Callback
|
||||
from tests.security_utils import security_off_wrap
|
||||
from .....dataset_mock import MindData
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
|
@ -62,6 +63,7 @@ def get_test_data(step):
|
|||
|
||||
|
||||
# Test: call method on parse graph code
|
||||
@security_off_wrap
|
||||
def test_image_summary_sample():
|
||||
""" test_image_summary_sample """
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
|
||||
|
@ -157,6 +159,7 @@ class ImageSummaryCallback(Callback):
|
|||
self._summary_record.flush()
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_image_summary_train():
|
||||
""" test_image_summary_train """
|
||||
dataset = get_dataset()
|
||||
|
@ -166,6 +169,7 @@ def test_image_summary_train():
|
|||
model.train(2, dataset, callbacks=[callback])
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_image_summary_data():
|
||||
""" test_image_summary_data """
|
||||
dataset = get_dataset()
|
||||
|
|
|
@ -32,7 +32,8 @@ from mindspore.train.summary.summary_record import _DEFAULT_EXPORT_OPTIONS
|
|||
from mindspore.nn import Cell
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.ops.operations import Add
|
||||
|
||||
from mindspore._c_expression import security
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
_VALUE_CACHE = list()
|
||||
|
@ -84,6 +85,7 @@ class TestSummaryCollector:
|
|||
"""Run after each test function."""
|
||||
get_value()
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
|
||||
def test_params_with_summary_dir_value_error(self, summary_dir):
|
||||
"""Test the exception scenario for summary dir."""
|
||||
|
@ -97,6 +99,7 @@ class TestSummaryCollector:
|
|||
SummaryCollector(summary_dir=summary_dir)
|
||||
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_params_with_summary_dir_not_dir(self):
|
||||
"""Test the given summary dir parameter is not a directory."""
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
|
@ -107,6 +110,7 @@ class TestSummaryCollector:
|
|||
with pytest.raises(NotADirectoryError):
|
||||
SummaryCollector(summary_dir=summary_file)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
|
||||
def test_params_with_collect_freq_exception(self, collect_freq):
|
||||
"""Test the exception scenario for collect freq."""
|
||||
|
@ -123,6 +127,7 @@ class TestSummaryCollector:
|
|||
f'but got {type(collect_freq).__name__}.'
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("action", [None, 123, '', '123'])
|
||||
def test_params_with_action_exception(self, action):
|
||||
"""Test the exception scenario for action."""
|
||||
|
@ -133,6 +138,7 @@ class TestSummaryCollector:
|
|||
f"but got {type(action).__name__}."
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("collect_specified_data", [123])
|
||||
def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
|
||||
"""Test type error scenario for collect specified data param."""
|
||||
|
@ -145,6 +151,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("export_options", [
|
||||
{
|
||||
"tensor_format": "npz"
|
||||
|
@ -163,6 +170,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("export_options", [123])
|
||||
def test_params_with_export_options_type_error(self, export_options):
|
||||
"""Test type error scenario for collect specified data param."""
|
||||
|
@ -175,6 +183,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("collect_specified_data", [
|
||||
{
|
||||
123: 123
|
||||
|
@ -194,6 +203,7 @@ class TestSummaryCollector:
|
|||
f"but got {type(param_name).__name__}."
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("collect_specified_data", [
|
||||
{
|
||||
'collect_metric': None
|
||||
|
@ -219,6 +229,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_params_with_histogram_regular_value_error(self):
|
||||
"""Test histogram regular."""
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
|
@ -227,6 +238,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_params_with_collect_specified_data_unexpected_key(self):
|
||||
"""Test the collect_specified_data parameter with unexpected key."""
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
|
@ -236,6 +248,7 @@ class TestSummaryCollector:
|
|||
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
|
||||
assert expected_msg in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_params_with_export_options_unexpected_key(self):
|
||||
"""Test the export_options parameter with unexpected key."""
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
|
@ -245,6 +258,7 @@ class TestSummaryCollector:
|
|||
expected_msg = f"For `export_options` the keys {set(data)} are unsupported"
|
||||
assert expected_msg in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("custom_lineage_data", [
|
||||
123,
|
||||
{
|
||||
|
@ -280,6 +294,7 @@ class TestSummaryCollector:
|
|||
|
||||
assert expected_msg == str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_check_callback_with_multi_instances(self):
|
||||
"""Use multi SummaryCollector instances to test check_callback function."""
|
||||
cb_params = _InternalCallbackParam()
|
||||
|
@ -291,6 +306,7 @@ class TestSummaryCollector:
|
|||
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
|
||||
assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_collect_input_data_with_train_dataset_element_invalid(self):
|
||||
"""Test the param 'train_dataset_element' in cb_params is invalid."""
|
||||
cb_params = _InternalCallbackParam()
|
||||
|
@ -300,6 +316,7 @@ class TestSummaryCollector:
|
|||
summary_collector._collect_input_data(cb_params)
|
||||
assert not summary_collector._collect_specified_data['collect_input_data']
|
||||
|
||||
@security_off_wrap
|
||||
@mock.patch.object(SummaryRecord, 'add_value')
|
||||
def test_collect_input_data_success(self, mock_add_value):
|
||||
"""Mock a image data, and collect image data success."""
|
||||
|
@ -311,6 +328,7 @@ class TestSummaryCollector:
|
|||
summary_collector._collect_input_data(cb_params)
|
||||
# Note Here need to assert the result and expected data
|
||||
|
||||
@security_off_wrap
|
||||
@mock.patch.object(SummaryRecord, 'add_value')
|
||||
def test_collect_dataset_graph_success(self, mock_add_value):
|
||||
"""Test collect dataset graph."""
|
||||
|
@ -325,6 +343,7 @@ class TestSummaryCollector:
|
|||
assert plugin == 'dataset_graph'
|
||||
assert name == 'train_dataset'
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("net_output, expected_loss", [
|
||||
(None, None),
|
||||
(1, Tensor(1)),
|
||||
|
@ -349,6 +368,7 @@ class TestSummaryCollector:
|
|||
else:
|
||||
assert summary_collector._is_parse_loss_success
|
||||
|
||||
@security_off_wrap
|
||||
def test_get_optimizer_from_cb_params_success(self):
|
||||
"""Test get optimizer success from cb params."""
|
||||
cb_params = _InternalCallbackParam()
|
||||
|
@ -360,6 +380,7 @@ class TestSummaryCollector:
|
|||
# Test get optimizer again
|
||||
assert summary_collector._get_optimizer(cb_params) == cb_params.optimizer
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize('mode', [ModeEnum.TRAIN.value, ModeEnum.EVAL.value])
|
||||
def test_get_optimizer_from_network(self, mode):
|
||||
"""Get optimizer from train network"""
|
||||
|
@ -374,6 +395,7 @@ class TestSummaryCollector:
|
|||
optimizer = summary_collector._get_optimizer(cb_params)
|
||||
assert isinstance(optimizer, Optimizer)
|
||||
|
||||
@security_off_wrap
|
||||
def test_get_optimizer_failed(self):
|
||||
"""Test get optimizer failed."""
|
||||
class Net(Cell):
|
||||
|
@ -399,6 +421,7 @@ class TestSummaryCollector:
|
|||
assert optimizer is None
|
||||
assert summary_collector._temp_optimizer == 'Failed'
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("histogram_regular, expected_names", [
|
||||
(
|
||||
'conv1|conv2',
|
||||
|
@ -430,6 +453,7 @@ class TestSummaryCollector:
|
|||
assert PluginEnum.HISTOGRAM.value == result[0][0]
|
||||
assert expected_names == [data[1] for data in result]
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("specified_data, action, expected_result", [
|
||||
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
|
||||
(None, False, {}),
|
||||
|
@ -446,3 +470,11 @@ class TestSummaryCollector:
|
|||
keep_default_action=action)
|
||||
|
||||
assert summary_collector._collect_specified_data == expected_result
|
||||
|
||||
@pytest.mark.parametrize("summary_dir", './')
|
||||
def test_summary_collector_security_on(self, summary_dir):
|
||||
"""Test the summary collector when set security on."""
|
||||
if security.enable_security():
|
||||
with pytest.raises(ValueError) as exc:
|
||||
SummaryCollector(summary_dir=summary_dir)
|
||||
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
|
||||
|
|
|
@ -22,6 +22,8 @@ import pytest
|
|||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from mindspore._c_expression import security
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
def get_test_data(step):
|
||||
|
@ -57,12 +59,14 @@ class TestSummaryRecord:
|
|||
if os.path.exists(self.base_summary_dir):
|
||||
shutil.rmtree(self.base_summary_dir)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("log_dir", ["", None, 32])
|
||||
def test_log_dir_with_type_error(self, log_dir):
|
||||
with pytest.raises(TypeError):
|
||||
with SummaryRecord(log_dir):
|
||||
pass
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("raise_exception", ["", None, 32])
|
||||
def test_raise_exception_with_type_error(self, raise_exception):
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
|
@ -72,9 +76,18 @@ class TestSummaryRecord:
|
|||
|
||||
assert "raise_exception" in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
@pytest.mark.parametrize("step", ["str"])
|
||||
def test_step_of_record_with_type_error(self, step):
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
with pytest.raises(TypeError):
|
||||
with SummaryRecord(summary_dir) as sr:
|
||||
sr.record(step)
|
||||
|
||||
@pytest.mark.parametrize("log_dir", './')
|
||||
def test_summary_collector_security_on(self, log_dir):
|
||||
"""Test the summary collector when set security on."""
|
||||
if security.enable_security():
|
||||
with pytest.raises(ValueError) as exc:
|
||||
SummaryRecord(log_dir=log_dir)
|
||||
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
|
||||
|
|
Loading…
Reference in New Issue