fix summary isolation

This commit is contained in:
jiangshuqiang 2021-09-18 10:59:14 +08:00
parent 78b6fd17d6
commit c2bd061889
42 changed files with 335 additions and 58 deletions

View File

@ -362,6 +362,7 @@ void MemReuseUtil::SetReuseRefCount() {
}
}
#ifndef ENABLE_SECURITY
void MemReuseUtil::SetSummaryNodesRefCount() {
bool summary_exist = graph_->summary_node_exist();
if (!summary_exist) {
@ -393,6 +394,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
#endif
MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
}
#endif
void MemReuseUtil::SetRefNodesInputRefCount() {
size_t total_size = 0;
@ -457,7 +459,9 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
}
SetKernelDefMap();
SetReuseRefCount();
#ifndef ENABLE_SECURITY
SetSummaryNodesRefCount();
#endif
SetRefNodesInputRefCount();
SetWorkSpaceList();
#ifdef MEM_REUSE_DEBUG

View File

@ -63,7 +63,9 @@ class MemReuseUtil {
void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
void SetKernelDefInputs();
void SetReuseRefCount();
#ifndef ENABLE_SECURITY
void SetSummaryNodesRefCount();
#endif
void SetRefNodesInputRefCount();
// Set the reference count of graph output specially.
void SetGraphOutputRefCount();

View File

@ -340,7 +340,9 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
InitBasicInfo(graph);
IndependentNodeOutputProcess(graph);
#ifndef ENABLE_SECURITY
SummaryInputProcess(graph);
#endif
RefNodeProcess(graph);
NonTaskSplitProcess(graph);
UnReuseNodeProcess(graph);
@ -743,6 +745,7 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
MS_LOG(INFO) << "Special Tensor total size: Independent Node output " << total_size;
}
#ifndef ENABLE_SECURITY
void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
bool summary_exist = graph->summary_node_exist();
@ -782,6 +785,7 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
}
#endif
void Somas::RefNodeProcess(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);

View File

@ -115,7 +115,9 @@ class Somas {
void InitSomasInputTensors(const session::KernelGraph *graph);
void GetNextOutputProcess(const session::KernelGraph *graph);
void IndependentNodeOutputProcess(const session::KernelGraph *graph);
#ifndef ENABLE_SECURITY
void SummaryInputProcess(const session::KernelGraph *graph);
#endif
void RefNodeProcess(const session::KernelGraph *graph);
void NonTaskSplitProcess(const session::KernelGraph *graph);
void UnReuseNodeProcess(const session::KernelGraph *graph);

View File

@ -58,10 +58,16 @@ bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
}
bool IsRealKernelCNode(const CNodePtr &cnode) {
#ifndef ENABLE_SECURITY
static const PrimitiveSet virtual_prims = {
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn,
prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
#else
static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
prim::kPrimReturn, prim::kPrimPartial, prim::kPrimDepend,
prim::kPrimUpdateState, prim::kPrimLoad};
#endif
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << cnode->DebugString();
}

View File

@ -607,7 +607,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
device::KernelAdjust::GetInstance().InsertOverflowCheckOperations(NOT_NULL(root_graph));
// build kernel
BuildKernel(root_graph);
#ifndef ENABLE_SECURITY
SetSummaryNodes(root_graph.get());
#endif
// Alloc memory for child graph's inputs
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
@ -634,6 +636,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
return graph_id;
}
#ifndef ENABLE_SECURITY
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
@ -649,6 +652,7 @@ void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph>
}
kernel_graph->set_summary_node_exist(false);
}
#endif
void AscendSession::BuildGraphImpl(GraphId graph_id) {
MS_LOG(INFO) << "Start";
@ -760,7 +764,9 @@ void AscendSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_g
void AscendSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &, VectorRef *const) {
// summary
#ifndef ENABLE_SECURITY
Summary(kernel_graph.get());
#endif
#ifdef ENABLE_DEBUGGER
// load tensor from device for debugger
if (debugger_ && debugger_->debugger_enabled()) {
@ -1603,6 +1609,7 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
MS_LOG(INFO) << "Finish!";
}
#ifndef ENABLE_SECURITY
void AscendSession::RecurseSetSummaryNodes(KernelGraph *graph,
std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_EXCEPTION_IF_NULL(graph);
@ -1640,6 +1647,7 @@ void AscendSession::SetSummaryNodes(KernelGraph *graph) {
graph->set_summary_nodes(summary);
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
#endif
void AscendSession::MergeGraphExecOrder() {
MS_LOG(INFO) << "Start!";

View File

@ -89,8 +89,10 @@ class AscendSession : public SessionBasic {
private:
// compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
#ifndef ENABLE_SECURITY
void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
void SetSummaryNodes(KernelGraph *graph) override;
#endif
void InitRuntimeResource();
void SelectKernel(const KernelGraph &kernel_graph) const;
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@ -128,7 +130,9 @@ class AscendSession : public SessionBasic {
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
#ifndef ENABLE_SECURITY
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
#endif
// create parameter to receive data from multiple branch output
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void SelectKernel(NotNull<KernelGraphPtr> root_graph);

View File

@ -146,7 +146,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
MS_LOG(INFO) << "Assign kernel address";
runtime_.AssignKernelAddress(graph.get());
// set summary node
#ifndef ENABLE_SECURITY
SetSummaryNodes(graph.get());
#endif
runtime_.IncreaseSummaryRefCount(graph->summary_nodes());
DumpGraph(graph);
return graph_id;
@ -205,7 +207,9 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap
void CPUSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &, VectorRef *const) {
#ifndef ENABLE_SECURITY
Summary(kernel_graph.get());
#endif
}
void CPUSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {

View File

@ -455,8 +455,10 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
(void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
#endif
#ifndef ENABLE_SECURITY
// Get summary nodes.
SetSummaryNodes(graph.get());
#endif
// Dump .pb graph after graph optimization
#ifdef ENABLE_DUMP_IR
if (save_graphs) {
@ -523,9 +525,11 @@ void GPUSession::PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_gra
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
#ifndef ENABLE_SECURITY
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY)) {
Summary(kernel_graph.get());
}
#endif
#ifdef ENABLE_DEBUGGER
if (debugger_ && debugger_->DebuggerBackendEnabled()) {
debugger_->LoadParametersAndConst(kernel_graph);

View File

@ -176,8 +176,10 @@ class KernelGraph : public FuncGraph {
bool executable() const { return executable_; }
// set executable of graph
void set_executable(bool executable) { executable_ = executable; }
#ifndef ENABLE_SECURITY
// set summary_node of graph
void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
#endif
// check whether exist summary node in graph
bool summary_node_exist() const { return summary_node_exist_; }
// set invalid inputs for control sink

View File

@ -355,6 +355,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
MS_LOG(INFO) << tab_str;
}
#ifndef ENABLE_SECURITY
bool ExistSummaryNode(const KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto ret = graph->get_return();
@ -368,6 +369,7 @@ bool ExistSummaryNode(const KernelGraph *graph) {
}
return false;
}
#endif
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
@ -1153,9 +1155,12 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
#ifndef ENABLE_SECURITY
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
#endif
UnifyMindIR(graph);
// Update Graph Dynamic Shape Attr
@ -1564,9 +1569,13 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
graph->SetInputNodes();
SetInputNodeUsage(graph, manager);
graph->SetExecOrderByDefault();
#ifndef ENABLE_SECURITY
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
#endif
all_out_graph->push_back(graph);
return graph;
}
@ -1795,6 +1804,7 @@ void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::Te
}
}
#ifndef ENABLE_SECURITY
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
MS_EXCEPTION_IF_NULL(callback);
summary_callback_ = callback;
@ -1874,6 +1884,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
// call callback function here
summary_callback_(0, params_list);
}
#endif
namespace {
bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {

View File

@ -112,7 +112,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<int64_t> &tensors_mask);
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
#ifndef ENABLE_SECURITY
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
#endif
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
@ -239,7 +241,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
#ifndef ENABLE_SECURITY
virtual void SetSummaryNodes(KernelGraph *graph);
#endif
void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
auto kernel_graph = GetGraph(graph_id);
@ -259,7 +263,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<tensor::TensorPtr> &input_tensors,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
#ifndef ENABLE_SECURITY
void Summary(KernelGraph *graph);
#endif
// create graph output for RunOp
void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);

View File

@ -53,6 +53,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
// node because it is attribute or ge specific reason.
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
// converted to switch guarded.
#ifndef ENABLE_SECURITY
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
{prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}},
@ -78,6 +79,20 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}});
#else
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
{prim::kPrimGather, {3}}, {prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
{prim::kPrimAssignSub, {1}}, {prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}}, {prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}}});
#endif
for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
return IsPrimitiveCNode(node, item.first) && idx == index;

View File

@ -22,7 +22,9 @@
#include "utils/symbolic.h"
#include "pybind_api/api_register.h"
#include "pipeline/jit/parse/python_adapter.h"
#ifndef ENABLE_SECURITY
#include "utils/summary/event_writer.h"
#endif
#include "utils/config_manager.h"
#include "utils/mpi/mpi_config.h"
#include "utils/ms_utils.h"
@ -48,7 +50,9 @@ using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
using Pipeline = mindspore::pipeline::Pipeline;
using PrimitivePy = mindspore::PrimitivePy;
using MetaFuncGraph = mindspore::MetaFuncGraph;
#ifndef ENABLE_SECURITY
using EventWriter = mindspore::summary::EventWriter;
#endif // ENABLE_SECURITY
using OpLib = mindspore::kernel::OpLib;
using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
@ -311,6 +315,7 @@ PYBIND11_MODULE(_c_expression, m) {
}
}});
#ifndef ENABLE_SECURITY
(void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
.def(py::init<const std::string &>())
.def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
@ -320,6 +325,7 @@ PYBIND11_MODULE(_c_expression, m) {
.def("Flush", &EventWriter::Flush, "Flush the event.")
.def("Close", &EventWriter::Close, "Close the write.")
.def("Shut", &EventWriter::Shut, "Final close the write.");
#endif // ENABLE_SECURITY
(void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
.def(py::init())

View File

@ -592,8 +592,10 @@ void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
mem_reuse_util_ptr->SetReuseRefCount();
// Can't free the device address of graph output, so set the reference count of graph output specially.
mem_reuse_util_ptr->SetGraphOutputRefCount();
#ifndef ENABLE_SECURITY
// Can't free the device address of summary nodes, so set the reference count of summary nodes specially.
mem_reuse_util_ptr->SetSummaryNodesRefCount();
#endif
auto graph_id = graph->graph_id();
mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
}

View File

@ -353,8 +353,9 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context);
#ifndef ENABLE_SECURITY
session_->SetSummaryNodes(graph.get());
#endif
SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DUMP_IR
// Dump .pb graph after graph optimization.
@ -525,13 +526,17 @@ const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId g
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
session_->RegisterSummaryCallBackFunc(callback);
#endif
}
void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
MS_EXCEPTION_IF_NULL(session_);
for (const auto &graph : graphs) {
#ifndef ENABLE_SECURITY
session_->Summary(graph.get());
#endif
}
}

View File

@ -32,10 +32,12 @@ REG_ADPT_DESC(Constant, kNameConst, ADPT_DESC(Constant, Const))
// ScalarSummary
INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
ATTR_MAP(Summary) = EMPTY_ATTR_MAP;
#ifndef ENABLE_SECURITY
REG_ADPT_DESC(ScalarSummary, prim::kPrimScalarSummary->name(), ADPT_DESC(Summary))
REG_ADPT_DESC(ImageSummary, prim::kPrimImageSummary->name(), ADPT_DESC(Summary))
REG_ADPT_DESC(TensorSummary, prim::kPrimTensorSummary->name(), ADPT_DESC(Summary))
REG_ADPT_DESC(HistogramSummary, prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary))
#endif
REG_ADPT_DESC(Debug, prim::kPrimDebug->name(), ADPT_DESC(Summary))
// Data

View File

@ -12,6 +12,10 @@ if(NOT ENABLE_D AND NOT ENABLE_TESTCASES)
file(GLOB_RECURSE _UTILS_D_SRC_FILES ./runtime_error_codes.cc)
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_D_SRC_FILES})
endif()
if(ENABLE_SECURITY)
file(GLOB_RECURSE _UTILS_SUMMARY_FILES ./summary/event_writer.cc)
list(REMOVE_ITEM _UTILS_SRC_LIST ${_UTILS_SUMMARY_FILES})
endif()
set_property(SOURCE ${_UTILS_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_UTILS)
add_library(_mindspore_utils_obj OBJECT ${_UTILS_SRC_LIST})

View File

@ -301,7 +301,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
}
target_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
target_device_ = target;
}
@ -317,7 +319,9 @@ void MsBackend::CreateOtherSession(const std::string &target) {
MS_EXCEPTION_IF_NULL(context_ptr);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
other_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
other_device_ = target;
}

View File

@ -25,6 +25,11 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"load_mindir/*.cc"
)
if(ENABLE_SECURITY)
file(GLOB_RECURSE _INFER_SUMMARY_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ops/*_summary.cc")
list(REMOVE_ITEM CORE_SRC_LIST ${_INFER_SUMMARY_FILES})
endif()
file(GLOB_RECURSE PROTO_FILE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "proto/*.proto")
if(NOT(BUILD_LITE))
ms_protobuf_generate_py(PROTO_SRCS PY_HDRS PY_PYS ${PROTO_FILE})

View File

@ -548,10 +548,12 @@ inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive
// Debug ops
inline const PrimitivePtr kPrimAssert = std::make_shared<Primitive>("Assert");
#ifndef ENABLE_SECURITY
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
#endif
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// Dynamic shape testing

View File

@ -392,12 +392,16 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
#ifndef ENABLE_SECURITY
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
if (inputs.size() > 1) {
return GetOriginNodeTarget(inputs[1]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
return kTargetUnDefined;
}
#endif
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
const size_t node_inputs_num = 3;
if (inputs.size() >= node_inputs_num) {
size_t use_index = 1;
@ -521,6 +525,7 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
if (target != kTargetUnDefined) {
return target;
}
#ifndef ENABLE_SECURITY
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
@ -528,6 +533,13 @@ std::string GetOriginNodeTarget(const AnfNodePtr &node) {
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetVirtualNodeTarget(node);
}
#else
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetVirtualNodeTarget(node);
}
#endif
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);

View File

@ -41,6 +41,7 @@ std::string CNode::fullname_with_scope() {
return fullname_with_scope_;
}
#ifndef ENABLE_SECURITY
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
IsApply(prim::kPrimHistogramSummary)) {
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
@ -55,40 +56,40 @@ std::string CNode::fullname_with_scope() {
name = tag + "[:Tensor]";
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr or funcgraph ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto prim = input_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(scope());
fullname_with_scope_ = scope()->name() + "/";
if (prim != nullptr) {
fullname_with_scope_ += prim->name();
} else {
auto func_graph = input_value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
fullname_with_scope_ += "GraphKernel_" + fg_name;
} else {
fullname_with_scope_ += func_graph->ToString();
}
}
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
#endif
// cnode input 0 should be primitive ptr or funcgraph ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto prim = input_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(scope());
fullname_with_scope_ = scope()->name() + "/";
if (prim != nullptr) {
fullname_with_scope_ += prim->name();
} else {
auto func_graph = input_value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
fullname_with_scope_ += "GraphKernel_" + fg_name;
} else {
fullname_with_scope_ += func_graph->ToString();
}
}
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}

View File

@ -22,6 +22,7 @@
namespace mindspore {
// clang-format off
#ifndef ENABLE_SECURITY
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
@ -31,6 +32,16 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
#else
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
"list_append", "list_map", "list_reduce", "tuple_reversed", "tile_shape", "tuple_div", "tuple_to_array",
"make_dict", "make_slice", "make_record", "string_equal", "VirtualLoss", "Return", "env_getitem",
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
#endif
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
prim::kPrimMicroStepAllGather};
static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend};

View File

@ -158,12 +158,18 @@ bool IsRealKernel(const AnfNodePtr &node) {
return false;
}
auto input = cnode->inputs()[0];
#ifndef ENABLE_SECURITY
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
IsPrimitive(input, prim::kPrimTensorSummary) ||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
IsPrimitive(input, prim::kPrimPartial);
#else
bool is_virtual_node = IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimStateSetItem) ||
IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) ||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
#endif
return !is_virtual_node;
}

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""debug_ops"""
from types import FunctionType, MethodType
@ -23,6 +22,7 @@ from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
def _check_mode(class_name):
"""Check for PyNative mode."""
mode = context.get_context('mode')
@ -87,6 +87,10 @@ class ScalarSummary(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize ScalarSummary."""
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
self.add_prim_attr("side_effect_io", True)
@ -126,6 +130,10 @@ class ImageSummary(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize ImageSummary."""
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
self.add_prim_attr("side_effect_io", True)
def __infer__(self, name, value):
@ -178,6 +186,10 @@ class TensorSummary(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize TensorSummary."""
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
self.add_prim_attr("side_effect_io", True)
@ -218,6 +230,10 @@ class HistogramSummary(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize HistogramSummary."""
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
self.add_prim_attr("side_effect_io", True)
def __infer__(self, name, value):

View File

@ -36,6 +36,7 @@ from mindspore.train.callback._dataset_graph import DatasetGraph
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import LossBase
from mindspore.train._utils import check_value_type
from ..._c_expression import security
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
HYPER_CONFIG_LEN_LIMIT = 100000
@ -194,6 +195,10 @@ class SummaryCollector(Callback):
collect_tensor_freq=None,
max_file_size=None,
export_options=None):
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
super(SummaryCollector, self).__init__()
self._summary_dir = self._process_summary_dir(summary_dir)

View File

@ -23,7 +23,7 @@ from collections import defaultdict
from mindspore import log as logger
from mindspore.nn import Cell
from ..._c_expression import Tensor
from ..._c_expression import Tensor, security
from ..._checkparam import Validator
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
from ._summary_adapter import get_event_file_name, package_graph_event
@ -146,6 +146,9 @@ class SummaryRecord:
def __init__(self, log_dir, file_prefix="events", file_suffix="_MS",
network=None, max_file_size=None, raise_exception=False, export_options=None):
if security.enable_security():
raise ValueError('The Summary is not supported, please without `-s on` and recompile source.')
self._event_writer = None
self._mode, self._data_pool = 'train', defaultdict(list)
self._status = {

View File

@ -23,8 +23,11 @@ import numpy as np
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
from .._utils import _make_directory
from ..._c_expression import EventWriter_
from ._summary_adapter import package_init_event
from ..._c_expression import security
if not security.enable_security():
from ..._c_expression import EventWriter_
FREE_DISK_SPACE_TIMES = 32
FILE_MODE = 0o400
@ -41,7 +44,7 @@ class BaseWriter:
"""Write some metadata etc."""
@property
def writer(self) -> EventWriter_:
def writer(self):
"""Get the writer."""
if self._writer is not None:
return self._writer

View File

@ -384,6 +384,7 @@ def train_summary_record(test_writer, steps):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summary():
with tempfile.TemporaryDirectory() as tmp_dir:
steps = 2

View File

@ -62,6 +62,7 @@ def train_summary_record(test_writer, steps):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summary_step2_summary_record1():
"""Test record 10 step summary."""
if platform.system() == "Windows":

View File

@ -126,6 +126,7 @@ class TestSummary:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summary_with_sink_mode_false(self):
"""Test summary with sink mode false, and num samples is 64."""
summary_dir = self._run_network(num_samples=10)
@ -148,6 +149,7 @@ class TestSummary:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summary_with_sink_mode_true(self):
"""Test summary with sink mode true, and num samples is 64."""
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)
@ -167,6 +169,7 @@ class TestSummary:
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summarycollector_user_defind(self):
"""Test SummaryCollector with user-defined."""
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,

View File

@ -93,6 +93,7 @@ class TestSummaryOps:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
def test_summary_ops(self):
"""Test summary operators."""
ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)

View File

@ -167,6 +167,12 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
)
if(ENABLE_SECURITY)
file(GLOB_RECURSE _INFER_SUMMARY_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/ops/*_summary.cc"
)
list(REMOVE_ITEM MINDSPORE_SRC_LIST ${_INFER_SUMMARY_FILES})
endif()
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")

View File

@ -530,6 +530,7 @@ TEST_F(TestConvert, TestSquareOps) {
ASSERT_TRUE(ret);
}
#ifndef ENABLE_SECURITY
TEST_F(TestConvert, TestScalarSummaryOps) {
auto prim = prim::kPrimScalarSummary;
// should have only 1 input.
@ -548,6 +549,7 @@ TEST_F(TestConvert, TestHistogramSummaryOps) {
bool ret = MakeDfGraph(prim, 2);
ASSERT_TRUE(ret);
}
#endif
TEST_F(TestConvert, TestGreaterOps) {
auto prim = std::make_shared<Primitive>("Greater");

View File

@ -14,6 +14,7 @@
# ============================================================================
""" test nn ops """
import numpy as np
import pytest
import mindspore
import mindspore.context as context
@ -25,6 +26,8 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -33,6 +36,7 @@ from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)
def conv3x3(in_channels, out_channels, stride=1, padding=1):
"""3x3 convolution """
return nn.Conv2d(in_channels, out_channels,
@ -520,18 +524,10 @@ test_cases = [
'desc_inputs': [[128, 32, 32, 64]],
'desc_bprop': [[128, 32, 32, 64]],
}),
('ScalarSummary', {
'block': ScalarSummaryNet(),
'desc_inputs': [Tensor(2.2)],
}),
('L2Normalize', {
'block': L2NormalizeNet(),
'desc_inputs': [Tensor(np.array([[1.0, 2, 3], [4.0, 5, 6], [7.0, 8, 9]]), mindspore.float32)],
}),
('HistogramSummary', {
'block': HistogramSummaryNet(),
'desc_inputs': [[1, 2, 3]],
}),
('FusedBatchNormGrad', {
'block': FusedBatchNormGrad(nn.BatchNorm2d(num_features=512, eps=1e-5, momentum=0.1)),
'desc_inputs': [[64, 512, 7, 7], [64, 512, 7, 7]],
@ -753,6 +749,33 @@ test_cases_for_verify_exception = [
]
@security_off_wrap
@non_graph_engine
@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_summary_nn_ops():
if security.enable_security():
return []
test_cases_for_summary_ops = [
('ScalarSummary', {
'block': ScalarSummaryNet(),
'desc_inputs': [Tensor(2.2)],
}),
('HistogramSummary', {
'block': HistogramSummaryNet(),
'desc_inputs': [[1, 2, 3]],
}),
]
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
return test_cases_for_summary_ops
def test_summary_nn_ops_security_on():
if security.enable_security():
with pytest.raises(ValueError) as exc:
ScalarSummaryNet()
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():

View File

@ -31,6 +31,8 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations import nn_ops as nps
from mindspore.nn.layer import normalization
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -2866,16 +2868,6 @@ test_case_other_ops = [
'block': P.IOU(),
'desc_inputs': [Tensor(np.ones((256, 4), np.float16)), Tensor(np.ones((128, 4), np.float16))],
'desc_bprop': [convert([128, 256], np.float16)]}),
('Summary', {
'block': SummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}),
('HistogramSummary', {
'block': HistogramSummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}),
('PopulationCount', {
'block': P.PopulationCount(),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
@ -3046,6 +3038,38 @@ def test_backward_exec():
return test_backward_exec_case
@security_off_wrap
@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_summary_ops():
if security.enable_security():
return []
test_cases_for_summary_ops = [
('Summary', {
'block': SummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}),
('HistogramSummary', {
'block': HistogramSummaryNet(),
'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)),
Tensor(np.array([1.2]).astype(np.float32))],
'skip': ['backward']}),
]
context.set_context(mode=context.GRAPH_MODE)
return test_cases_for_summary_ops
def test_summary_ops_security_on():
if security.enable_security():
with pytest.raises(ValueError) as exc:
SummaryNet()
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
with pytest.raises(ValueError) as exc:
HistogramSummaryNet()
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
raise_set = [
('Cast_Error', {
'block': (P.Cast(), {'exception': TypeError}),

View File

@ -166,7 +166,6 @@ def test_ops():
if [1, 2, 3] is not None:
if self.str_a + self.str_b == "helloworld":
if q == 86:
print("hello world")
return ret
return x

View File

@ -23,6 +23,7 @@ from mindspore.common.tensor import Tensor
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap
CUR_DIR = os.getcwd()
SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
@ -48,6 +49,7 @@ def _wrap_test_data(input_data: Tensor):
}]
@security_off_wrap
def test_histogram_summary():
"""Test histogram summary."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -61,6 +63,7 @@ def test_histogram_summary():
assert event.summary.value[0].histogram.count == 6
@security_off_wrap
def test_histogram_multi_summary():
"""Test histogram multiple step."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -83,6 +86,8 @@ def test_histogram_multi_summary():
event = reader.read_event()
assert event.summary.value[0].histogram.count == size
@security_off_wrap
def test_histogram_summary_empty_tensor():
"""Test histogram summary, input is an empty tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -97,6 +102,7 @@ def test_histogram_summary_empty_tensor():
assert event.summary.value[0].histogram.count == 0
@security_off_wrap
def test_histogram_summary_same_value():
"""Test histogram summary, input is an ones tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -116,6 +122,7 @@ def test_histogram_summary_same_value():
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
@security_off_wrap
def test_histogram_summary_high_dims():
"""Test histogram summary, input is a 4-dimension tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -136,6 +143,7 @@ def test_histogram_summary_high_dims():
assert event.summary.value[0].histogram.count == tensor_data.size
@security_off_wrap
def test_histogram_summary_nan_inf():
"""Test histogram summary, input tensor has nan."""
with tempfile.TemporaryDirectory() as tmp_dir:
@ -160,6 +168,7 @@ def test_histogram_summary_nan_inf():
assert event.summary.value[0].histogram.nan_count == 1
@security_off_wrap
def test_histogram_summary_all_nan_inf():
"""Test histogram summary, input tensor has no valid number."""
with tempfile.TemporaryDirectory() as tmp_dir:

View File

@ -23,6 +23,7 @@ from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from mindspore.train.callback import Callback
from tests.security_utils import security_off_wrap
from .....dataset_mock import MindData
CUR_DIR = os.getcwd()
@ -62,6 +63,7 @@ def get_test_data(step):
# Test: call method on parse graph code
@security_off_wrap
def test_image_summary_sample():
""" test_image_summary_sample """
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
@ -157,6 +159,7 @@ class ImageSummaryCallback(Callback):
self._summary_record.flush()
@security_off_wrap
def test_image_summary_train():
""" test_image_summary_train """
dataset = get_dataset()
@ -166,6 +169,7 @@ def test_image_summary_train():
model.train(2, dataset, callbacks=[callback])
@security_off_wrap
def test_image_summary_data():
""" test_image_summary_data """
dataset = get_dataset()

View File

@ -32,7 +32,8 @@ from mindspore.train.summary.summary_record import _DEFAULT_EXPORT_OPTIONS
from mindspore.nn import Cell
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops.operations import Add
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
_VALUE_CACHE = list()
@ -84,6 +85,7 @@ class TestSummaryCollector:
"""Run after each test function."""
get_value()
@security_off_wrap
@pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
def test_params_with_summary_dir_value_error(self, summary_dir):
"""Test the exception scenario for summary dir."""
@ -97,6 +99,7 @@ class TestSummaryCollector:
SummaryCollector(summary_dir=summary_dir)
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
@security_off_wrap
def test_params_with_summary_dir_not_dir(self):
"""Test the given summary dir parameter is not a directory."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -107,6 +110,7 @@ class TestSummaryCollector:
with pytest.raises(NotADirectoryError):
SummaryCollector(summary_dir=summary_file)
@security_off_wrap
@pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
def test_params_with_collect_freq_exception(self, collect_freq):
"""Test the exception scenario for collect freq."""
@ -123,6 +127,7 @@ class TestSummaryCollector:
f'but got {type(collect_freq).__name__}.'
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("action", [None, 123, '', '123'])
def test_params_with_action_exception(self, action):
"""Test the exception scenario for action."""
@ -133,6 +138,7 @@ class TestSummaryCollector:
f"but got {type(action).__name__}."
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [123])
def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
"""Test type error scenario for collect specified data param."""
@ -145,6 +151,7 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("export_options", [
{
"tensor_format": "npz"
@ -163,6 +170,7 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("export_options", [123])
def test_params_with_export_options_type_error(self, export_options):
"""Test type error scenario for collect specified data param."""
@ -175,6 +183,7 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [
{
123: 123
@ -194,6 +203,7 @@ class TestSummaryCollector:
f"but got {type(param_name).__name__}."
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_metric': None
@ -219,6 +229,7 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
@security_off_wrap
def test_params_with_histogram_regular_value_error(self):
"""Test histogram regular."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -227,6 +238,7 @@ class TestSummaryCollector:
assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(exc.value)
@security_off_wrap
def test_params_with_collect_specified_data_unexpected_key(self):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -236,6 +248,7 @@ class TestSummaryCollector:
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
assert expected_msg in str(exc.value)
@security_off_wrap
def test_params_with_export_options_unexpected_key(self):
"""Test the export_options parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -245,6 +258,7 @@ class TestSummaryCollector:
expected_msg = f"For `export_options` the keys {set(data)} are unsupported"
assert expected_msg in str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("custom_lineage_data", [
123,
{
@ -280,6 +294,7 @@ class TestSummaryCollector:
assert expected_msg == str(exc.value)
@security_off_wrap
def test_check_callback_with_multi_instances(self):
"""Use multi SummaryCollector instances to test check_callback function."""
cb_params = _InternalCallbackParam()
@ -291,6 +306,7 @@ class TestSummaryCollector:
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
@security_off_wrap
def test_collect_input_data_with_train_dataset_element_invalid(self):
"""Test the param 'train_dataset_element' in cb_params is invalid."""
cb_params = _InternalCallbackParam()
@ -300,6 +316,7 @@ class TestSummaryCollector:
summary_collector._collect_input_data(cb_params)
assert not summary_collector._collect_specified_data['collect_input_data']
@security_off_wrap
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_input_data_success(self, mock_add_value):
"""Mock a image data, and collect image data success."""
@ -311,6 +328,7 @@ class TestSummaryCollector:
summary_collector._collect_input_data(cb_params)
# Note Here need to assert the result and expected data
@security_off_wrap
@mock.patch.object(SummaryRecord, 'add_value')
def test_collect_dataset_graph_success(self, mock_add_value):
"""Test collect dataset graph."""
@ -325,6 +343,7 @@ class TestSummaryCollector:
assert plugin == 'dataset_graph'
assert name == 'train_dataset'
@security_off_wrap
@pytest.mark.parametrize("net_output, expected_loss", [
(None, None),
(1, Tensor(1)),
@ -349,6 +368,7 @@ class TestSummaryCollector:
else:
assert summary_collector._is_parse_loss_success
@security_off_wrap
def test_get_optimizer_from_cb_params_success(self):
"""Test get optimizer success from cb params."""
cb_params = _InternalCallbackParam()
@ -360,6 +380,7 @@ class TestSummaryCollector:
# Test get optimizer again
assert summary_collector._get_optimizer(cb_params) == cb_params.optimizer
@security_off_wrap
@pytest.mark.parametrize('mode', [ModeEnum.TRAIN.value, ModeEnum.EVAL.value])
def test_get_optimizer_from_network(self, mode):
"""Get optimizer from train network"""
@ -374,6 +395,7 @@ class TestSummaryCollector:
optimizer = summary_collector._get_optimizer(cb_params)
assert isinstance(optimizer, Optimizer)
@security_off_wrap
def test_get_optimizer_failed(self):
"""Test get optimizer failed."""
class Net(Cell):
@ -399,6 +421,7 @@ class TestSummaryCollector:
assert optimizer is None
assert summary_collector._temp_optimizer == 'Failed'
@security_off_wrap
@pytest.mark.parametrize("histogram_regular, expected_names", [
(
'conv1|conv2',
@ -430,6 +453,7 @@ class TestSummaryCollector:
assert PluginEnum.HISTOGRAM.value == result[0][0]
assert expected_names == [data[1] for data in result]
@security_off_wrap
@pytest.mark.parametrize("specified_data, action, expected_result", [
(None, True, SummaryCollector._DEFAULT_SPECIFIED_DATA),
(None, False, {}),
@ -446,3 +470,11 @@ class TestSummaryCollector:
keep_default_action=action)
assert summary_collector._collect_specified_data == expected_result
@pytest.mark.parametrize("summary_dir", './')
def test_summary_collector_security_on(self, summary_dir):
"""Test the summary collector when set security on."""
if security.enable_security():
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'

View File

@ -22,6 +22,8 @@ import pytest
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
def get_test_data(step):
@ -57,12 +59,14 @@ class TestSummaryRecord:
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
@security_off_wrap
@pytest.mark.parametrize("log_dir", ["", None, 32])
def test_log_dir_with_type_error(self, log_dir):
with pytest.raises(TypeError):
with SummaryRecord(log_dir):
pass
@security_off_wrap
@pytest.mark.parametrize("raise_exception", ["", None, 32])
def test_raise_exception_with_type_error(self, raise_exception):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
@ -72,9 +76,18 @@ class TestSummaryRecord:
assert "raise_exception" in str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("step", ["str"])
def test_step_of_record_with_type_error(self, step):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError):
with SummaryRecord(summary_dir) as sr:
sr.record(step)
@pytest.mark.parametrize("log_dir", './')
def test_summary_collector_security_on(self, log_dir):
"""Test the summary collector when set security on."""
if security.enable_security():
with pytest.raises(ValueError) as exc:
SummaryRecord(log_dir=log_dir)
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'