fix codecheck

This commit is contained in:
xulei 2022-02-24 20:28:55 +08:00
parent f8ab74ef29
commit 4cf320cdbd
43 changed files with 95 additions and 77 deletions

View File

@ -18,7 +18,6 @@
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include "backend/common/optimizer/pass_manager.h"

View File

@ -33,7 +33,7 @@ class BackendCSE : public CSE {
BackendCSE() = default;
~BackendCSE() override = default;
virtual bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const override;
virtual bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) const override;

View File

@ -15,7 +15,6 @@
*/
#include "backend/common/pass/erase_visit_attr.h"
#include <vector>
#include <memory>
#include "kernel/common_utils.h"
#include "backend/common/session/anf_runtime_algorithm.h"

View File

@ -135,8 +135,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
if (!runtime_.Init()) {
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
}
MS_LOG(INFO) << "Assign kernel address";
runtime_.AssignKernelAddress(graph.get());
MS_LOG(INFO) << "Assign kernel graph address";
runtime_.AssignKernelGraphAddress(graph.get());
// set summary node
#ifndef ENABLE_SECURITY
SetSummaryNodes(graph.get());
@ -285,7 +285,7 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
if (!runtime_.Init()) {
MS_LOG(EXCEPTION) << "Kernel runtime init error.";
}
runtime_.AssignKernelAddress(kernel_graph.get());
runtime_.AssignKernelGraphAddress(kernel_graph.get());
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node);
runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs);

View File

@ -205,7 +205,7 @@ class KernelGraph : public FuncGraph {
// set stream label of graph
void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
// get stream label of graph
uint32_t stream_distinction_label() { return stream_distinction_label_; }
uint32_t stream_distinction_label() const { return stream_distinction_label_; }
// refresh execute kernel stream label
void UpdateExecuteKernelStreamLabel();
// calculate the leaf graph order of root graph
@ -406,7 +406,7 @@ class KernelGraph : public FuncGraph {
// The interface to set/get the graph GIL flag.
void set_is_need_gil(bool flag) { is_need_gil_ = flag; }
bool is_need_gil() { return is_need_gil_; }
bool is_need_gil() const { return is_need_gil_; }
bool IsDatasetGraph() const;

View File

@ -251,9 +251,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); }
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
virtual KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
virtual KernelGraphPtr BuildOpImpl(const OpRunInfo & /* op_run_info */, const GraphInfo & /* graph_info */,
const std::vector<tensor::TensorPtr> & /* input_tensors */,
const std::vector<int64_t> & /* tensors_mask */) {
return nullptr;
}
virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,

View File

@ -1838,7 +1838,7 @@ uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const
return ptr;
}
#ifndef ENABLE_SECURITY
void Somas::ConvertToProfilingNode(uint32_t graph_id) {
void Somas::ConvertToProfilingNode(uint32_t graph_id) const {
#ifdef ENABLE_D
auto graph_node = MemoryProfiling::GetInstance().GetGraphMemoryNode(graph_id);
if (graph_node == nullptr) {

View File

@ -37,6 +37,9 @@ namespace mindspore {
namespace somas {
class Somas {
public:
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
using SomasNodePtr = std::shared_ptr<SomasNode>;
// Constructors/Destructors
Somas() = default;
Somas(const Somas &) = delete;
@ -56,7 +59,7 @@ class Somas {
static bool NodeSort(const SomasNodePtr &node1, const SomasNodePtr &node2);
#ifndef ENABLE_SECURITY
void ConvertToProfilingNode(uint32_t graph_id);
void ConvertToProfilingNode(uint32_t graph_id) const;
#endif
private:

View File

@ -35,11 +35,10 @@ class SomasTensor;
enum NodeType { kCommonNode, kCommunicationNode };
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
class SomasNode {
public:
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
using SomasNodePtr = std::shared_ptr<SomasNode>;
// Public attributes (mutated in code)
std::string scope_full_name_;
@ -56,7 +55,7 @@ class SomasNode {
mindspore::HashMap<int64_t, size_t> anc_stream_max_order_;
// Constructors/Destructors
SomasNode(size_t id, NodeType type, SomasStreamPtr stream) : id_(id), stream_(stream), type_(type) {}
SomasNode(size_t id, NodeType type, const SomasStreamPtr &stream) : id_(id), stream_(stream), type_(type) {}
SomasNode(const SomasNode &) = delete;
SomasNode &operator=(const SomasNode &) = delete;
~SomasNode() = default;

View File

@ -137,7 +137,7 @@ bool FootPrint::findOffset(const std::vector<DynamicBitSet> *constraints, const
// transform constrained tensors in non eligible intervals
if (block.Alone()) {
if (m_algorithm_ == kManyObjects && m_starts_.size() > 0 && m_starts_[0]->Alone() &&
if (m_algorithm_ == static_cast<uint32_t>(kManyObjects) && m_starts_.size() > 0 && m_starts_[0]->Alone() &&
(*constraints)[block.m_start_tensor_->index_].IsBitTrue(m_starts_[0]->m_start_tensor_->index_) == false) {
return false;
}

View File

@ -71,6 +71,9 @@ class Interval {
bool contains(size_t width) { return (m_b_ - m_a_) >= width; }
bool contains(const Interval &a) { return ((a.m_a_ >= m_a_) && (a.m_b_ <= m_b_)); }
Interval &operator=(const Interval &in) {
if (this == &in) {
return *this;
}
m_a_ = in.m_a_;
m_b_ = in.m_b_;
return *this;
@ -101,6 +104,9 @@ class BlockTensor {
~BlockTensor() = default;
BlockTensor &operator=(const BlockTensor &bt) {
if (this == &bt) {
return *this;
}
m_bre_allocate_ = bt.m_bre_allocate_;
m_current_sol_ = 0;
m_start_tensor_ = bt.m_start_tensor_;

View File

@ -102,7 +102,10 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph, TensorsDescMap
Status ret = SUCCESS;
try {
TensorsDescMap &tensors = *ptensors;
size_t total_sol = kNumSortingTypes * kNumFittingTypes * kNumAlgorithmTypes;
constexpr size_t numSortingTypes = static_cast<size_t>(kNumSortingTypes);
constexpr size_t numFittingTypes = static_cast<size_t>(kNumFittingTypes);
constexpr size_t numAlgorithmTypes = static_cast<size_t>(kNumAlgorithmTypes);
constexpr size_t total_sol = numSortingTypes * numFittingTypes * numAlgorithmTypes;
size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
bool isMultiThreadPermit = ball && process_num >= total_sol && total_sol > 1;
bool isMultiThreadValid = isMultiThreadPermit && (total_sol > kSolNumThresholdMultiThread ||
@ -116,9 +119,9 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph, TensorsDescMap
return FAILED;
}
auto start = std::chrono::system_clock::now();
for (size_t algorithm_strategy = 0, sol = 0; algorithm_strategy < kNumAlgorithmTypes; algorithm_strategy++) {
for (size_t sort_strategy = 0; sort_strategy < kNumSortingTypes; sort_strategy++) {
for (size_t branching_strategy = 0; branching_strategy < kNumFittingTypes; branching_strategy++) {
for (size_t algorithm_strategy = 0, sol = 0; algorithm_strategy < numAlgorithmTypes; algorithm_strategy++) {
for (size_t sort_strategy = 0; sort_strategy < numSortingTypes; sort_strategy++) {
for (size_t branching_strategy = 0; branching_strategy < numFittingTypes; branching_strategy++) {
std::shared_ptr<SomasSolverCore> pSolver =
std::make_shared<SomasSolverCore>(vecTensorsMap[sol], pConstraints, sol);
pSolver->SetAlgorithmStrategy(AlgorithmType(algorithm));

View File

@ -169,7 +169,7 @@ struct SomasSolverTensorDesc {
out << n->index_ << " " << n->size_ << " " << n->offset_ << "\n";
return out;
}
friend std::istream &operator>>(std::istream &in, SomasSolverTensorDescPtr n) {
friend std::istream &operator>>(std::istream &in, const SomasSolverTensorDescPtr &n) {
in >> n->index_ >> n->size_ >> n->offset_;
return in;
}

View File

@ -29,12 +29,11 @@ namespace somas {
class SomasNode;
class SomasTensor;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
class SomasStream {
public:
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasNodePtr = std::shared_ptr<SomasNode>;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
using SomasStreamPtr = std::shared_ptr<SomasStream>;
// Attributes mutated in code
std::vector<SomasTensorPtr> tensors_; // vector needed for same-stream loop in ConflictComputing()

View File

@ -60,11 +60,10 @@ enum LifeLongType {
kLifeLongGraphEnd // life time is from tensor start to graph end
};
using SomasNodePtr = std::shared_ptr<SomasNode>;
using SomasStreamPtr = std::shared_ptr<SomasStream>;
class SomasTensor {
public:
using SomasNodePtr = std::shared_ptr<SomasNode>;
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasTensorPtr = std::shared_ptr<SomasTensor>;
size_t aligned_size_{0};
@ -90,20 +89,20 @@ class SomasTensor {
~SomasTensor() = default;
// Accessors
const size_t &GetId() { return id_; }
const size_t &GetId() const { return id_; }
SomasNodePtr GetSourceNode() const { return source_node_; }
SomasStreamPtr GetSourceStream() const { return source_stream_; }
const size_t &GetOriginalSize() { return original_size_; }
const size_t &GetAlignedSize() { return aligned_size_; }
const size_t &GetNumConstraints() { return num_constraints_; }
bool IsLifelong() { return lifelong_value_ == kLifeLongGraphAll; }
bool IsWorkspace() { return type_ == kWorkspace; }
bool IsOutputOnly() { return type_ == kOutputOnly; }
size_t GetOffset() { return offset_; }
bool IsBetweenStreams() { return between_streams_; }
bool IsSemiLifelongStart() { return lifelong_value_ == kLifeLongGraphStart; }
bool IsSemiLifelongEnd() { return lifelong_value_ == kLifeLongGraphEnd; }
bool IsRefOverlap() { return ref_overlap_; }
const size_t &GetOriginalSize() const { return original_size_; }
const size_t &GetAlignedSize() const { return aligned_size_; }
const size_t &GetNumConstraints() const { return num_constraints_; }
bool IsLifelong() const { return lifelong_value_ == kLifeLongGraphAll; }
bool IsWorkspace() const { return type_ == kWorkspace; }
bool IsOutputOnly() const { return type_ == kOutputOnly; }
size_t GetOffset() const { return offset_; }
bool IsBetweenStreams() const { return between_streams_; }
bool IsSemiLifelongStart() const { return lifelong_value_ == kLifeLongGraphStart; }
bool IsSemiLifelongEnd() const { return lifelong_value_ == kLifeLongGraphEnd; }
bool IsRefOverlap() const { return ref_overlap_; }
// Computing functions
void SetOffset() {

View File

@ -85,7 +85,7 @@ class KernelDef {
void set_scope_full_name(const std::string &scop_name) { scop_full_name_ = scop_name; }
std::string scope_full_name() const { return scop_full_name_; }
void InsertInputKernel(const std::shared_ptr<KernelDef> &input_kernel) { input_kernels_.insert(input_kernel); }
const std::set<std::shared_ptr<KernelDef>> &input_kernels() { return input_kernels_; }
const std::set<std::shared_ptr<KernelDef>> &input_kernels() const { return input_kernels_; }
private:
std::string scop_full_name_;

View File

@ -1003,9 +1003,9 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t) {
case kNumberTypeFloat64:
bytes = sizeof(int64_t);
break;
case kNumberTypeInt4:
default:
MS_LOG(EXCEPTION) << "Invalid types " << t;
break;
}
return bytes;

View File

@ -212,9 +212,9 @@ class KernelMod {
void set_inputs_addr(const std::vector<AddressPtr> &addr) { inputs_addr_ = addr; }
void set_workspaces_addr(const std::vector<AddressPtr> &addr) { workspaces_addr_ = addr; }
void set_outputs_addr(const std::vector<AddressPtr> &addr) { outputs_addr_ = addr; }
const std::vector<AddressPtr> &GetInputsAddr() { return inputs_addr_; }
const std::vector<AddressPtr> &GetWorkSpacesAddr() { return workspaces_addr_; }
const std::vector<AddressPtr> &GetOutputsAddr() { return outputs_addr_; }
const std::vector<AddressPtr> &GetInputsAddr() const { return inputs_addr_; }
const std::vector<AddressPtr> &GetWorkSpacesAddr() const { return workspaces_addr_; }
const std::vector<AddressPtr> &GetOutputsAddr() const { return outputs_addr_; }
void set_stream(StreamType stream) { stream_ = stream; }
StreamType stream() const { return stream_; }
void SetAtomicCleanNodes(const std::vector<CNodePtr> &atomic_clean_node);

View File

@ -49,7 +49,7 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const override;
const std::string &format) const override;
bool AsyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
const std::string &format) const override;
bool SyncDeviceToDevice(const DeviceSync *src_device_addr) const override;

View File

@ -50,7 +50,7 @@ class AscendMemoryManager : public MemoryManager {
uint64_t GetMsUsedHbmSize();
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) override;
uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
};
} // namespace ascend

View File

@ -39,6 +39,10 @@
#include "debug/debugger/debugger.h"
#endif
namespace mindspore {
namespace device {
namespace ascend {
namespace {
static constexpr uint32_t kAicpuLoadFlag = 1;
static constexpr uint32_t kAicpuUnloadFlag = 0;
static constexpr uint32_t kTupleTaskId = 0;
@ -61,10 +65,8 @@ constexpr const char *kOpTypeOpDebug = "Opdebug";
static constexpr auto kCurLoopCountName = "current_loop_count";
static constexpr auto kCurEpochCountName = "current_epoch_count";
static constexpr auto kConstLoopNumInEpochName = "const_loop_num_in_epoch";
} // namespace
namespace mindspore {
namespace device {
namespace ascend {
DataDumper::~DataDumper() {
kernel_graph_ = nullptr;
ReleaseDevMem(&dev_load_mem_);

View File

@ -28,6 +28,8 @@
#include "transform/graph_ir/util.h"
#include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
namespace mindspore::hccl {
namespace {
static constexpr char kGeOpNameHcclSend[] = "HcomSend";
static constexpr char kGeOpNameHcclReceive[] = "HcomReceive";
static constexpr char kGeOpNameHcclAllRudece[] = "HcomAllReduce";
@ -81,8 +83,8 @@ struct IsVector<std::vector<int64_t>> {
// cppcheck-suppress unusedStructMember
static constexpr bool value = true;
};
} // namespace
namespace mindspore::hccl {
template <class T>
static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name,
const std::string &ge_attr_name) {

View File

@ -35,7 +35,7 @@ static constexpr const char *kHcclAlgoEnv = "HCCL_ALGO";
static constexpr const char *kHcclAlgoOption = "HCCL_algorithm";
#define CHECK_SYMBOL_NULL(symbol) \
if (symbol == nullptr) { \
if ((symbol) == nullptr) { \
MS_LOG(WARNING) << #symbol << " is null, hccl has not been inited, do nothing."; \
return HcclResult::HCCL_E_RESERVED; \
}

View File

@ -37,7 +37,7 @@ MemCpyAsyncKernel::MemCpyAsyncKernel() {}
MemCpyAsyncKernel::~MemCpyAsyncKernel() {}
bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (inputs.size() != 1) {
MS_LOG(ERROR) << "inputs size is not one";

View File

@ -40,7 +40,7 @@ RtKernel::RtKernel() {}
RtKernel::~RtKernel() {}
bool RtKernel::Init(const mindspore::AnfNodePtr & /*anf_node*/) { return true; }
bool RtKernel::Init(const mindspore::AnfNodePtr & /* anf_node */) { return true; }
void RtKernel::SetInputSizeList(const std::vector<size_t> &size_list) { mutable_input_size_list_ = size_list; }
void RtKernel::SetOutputSizeList(const std::vector<size_t> &size_list) { mutable_output_size_list_ = size_list; }

View File

@ -16,7 +16,6 @@
#include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/common/optimizer/helper.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"

View File

@ -15,7 +15,6 @@
*/
#include "plugin/device/ascend/optimizer/format_type/insert_transdata_for_runop.h"
#include <memory>
#include "utils/utils.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "backend/common/session/anf_runtime_algorithm.h"

View File

@ -142,7 +142,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d
std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
float val = 1.0 / (kd * kh * kw);
if (divisor_override) {
if (divisor_override != 0) {
val = 1.0 / divisor_override;
} else if (!IsZeroPads(pad_list) || ceil_mode) {
val = 1.0;

View File

@ -107,7 +107,7 @@ AnfNodePtr ConstructFilter(const FuncGraphPtr &func_graph, const std::vector<int
std::vector<int64_t> assist_shape = {c1 * kd * kh * kw, 1, kC0, kC0}; // frac_z_3d
std::vector<size_t> infer_shape = {IntToSize(1), LongToSize(fc), LongToSize(kd), LongToSize(kh), LongToSize(kw)};
float val = 1.0;
if (divisor_override) {
if (divisor_override != 0) {
val = 1.0 / divisor_override;
} else if (IsZeroPads(pad_list) && !ceil_mode) {
val = 1.0 / (kd * kh * kw);

View File

@ -42,7 +42,7 @@ class CPUDeviceAddress : public DeviceAddress {
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const override;
const std::string &format) const override;
bool SyncDeviceToDevice(const DeviceSync *src_device_addr) const override;
bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,

View File

@ -59,7 +59,7 @@ bool CPUKernelRuntime::Init() {
}
const size_t INIT_NODE_REF = 1;
void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
void CPUKernelRuntime::AssignKernelGraphAddress(session::KernelGraph *kernel_graph) {
AssignValueNodeAddress(kernel_graph);
AssignInputNodeAddress(kernel_graph);
auto context_ptr = MsContext::GetInstance();

View File

@ -37,7 +37,7 @@ class CPUKernelRuntime : public KernelRuntime {
bool Init();
bool Run(const session::KernelGraph &graph, bool is_task_sink) override;
void AssignKernelAddress(session::KernelGraph *kernel_graph);
void AssignKernelGraphAddress(session::KernelGraph *kernel_graph);
void CreateOutputTensors(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs, std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,

View File

@ -55,7 +55,7 @@ class CPUMemoryManager : public MemoryManager {
}
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) override;
uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
private:

View File

@ -48,8 +48,7 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
TypeId dtype = kTypeUnknown;
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
TypeId dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
output_types->emplace_back(dtype);
}
}

View File

@ -179,7 +179,7 @@ class MKLCpuKernelMod : public NativeCpuKernelMod {
dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const;
dnnl::memory::desc GetDefaultMemDesc(const std::vector<size_t> &shape) const;
void ExecutePrimitive();
inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) {
inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) const {
MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::memory::desc";
auto desc = dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout};
MS_LOG(DEBUG) << "end to invoke constructor of dnnl::memory::desc";

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/pyfunc/py_func_cpu_kernel.h"
#include <memory>
@ -135,8 +136,9 @@ void ScalarToRawMemory(const py::object &obj, const TypeId &type, const AddressP
void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
if (static_cast<unsigned int>(array.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) {
const py::buffer_info &buf_info = array.request();
CHECK_RET_WITH_EXCEPT(memcpy_s(address->addr, address->size, buf_info.ptr, buf_info.size * buf_info.itemsize), EOK,
"memcpy failed.");
CHECK_RET_WITH_EXCEPT(
memcpy_s(address->addr, address->size, buf_info.ptr, LongToSize(buf_info.size * buf_info.itemsize)), EOK,
"memcpy failed.");
} else {
// Transform numpy array to row major buffer.
Py_buffer pybuf;

View File

@ -25,7 +25,7 @@ namespace mindspore {
namespace opt {
class InsertCastCPU : public Pass {
public:
explicit InsertCastCPU(const std::string &name) : Pass("insert_cast_cpu") {}
explicit InsertCastCPU(const std::string & /* name */) : Pass("insert_cast_cpu") {}
~InsertCastCPU() override = default;
bool Run(const FuncGraphPtr &graph) override;
};

View File

@ -45,7 +45,7 @@ class GPUDeviceAddress : public DeviceAddress {
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const override;
const std::string &format) const override;
bool SyncDeviceToDevice(const DeviceSync *src_device_addr) const override;
void ClearDeviceMemory() override;

View File

@ -36,7 +36,7 @@ class GPUMemoryManager : public MemoryManager {
std::vector<size_t> size_list) override;
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) override;
};
} // namespace gpu
} // namespace device

View File

@ -52,7 +52,7 @@ void DynamicKernel::Initialize() {
MS_LOG(INFO) << "Init End";
}
int DynamicKernel::GetKernelType() const { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); }
int DynamicKernel::GetKernelType() const { return static_cast<int>(AnfAlgo::GetKernelType(cnode_ptr_.lock())); }
void DynamicKernel::InferShape() {
auto cnode = cnode_ptr_.lock();

View File

@ -1035,7 +1035,7 @@ DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr
}
}
ShapeVector shape = {1, SizeToLong(tensor_size)};
if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value_string.data())) {
if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value_string.data(), "DefaultFormat")) {
MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
}
return address;

View File

@ -46,8 +46,10 @@ class MemoryManager : public MemHandler {
uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
const DeviceAddressPtr &address, bool comm_mem);
uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address,
uint32_t graph_id = kInvalidGraphId);
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id);
virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
return MallocMem(type, size, address, kInvalidGraphId);
}
// param address is the address type of each device
// param from_persistent_mem shows whether the tensor is a parameter in Pynative mode
@ -109,7 +111,10 @@ class MemoryManager : public MemHandler {
}
protected:
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) = 0;
virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem) {
return MallocStaticMem(size, communication_mem, kInvalidGraphId);
}
virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
SomasPtr somas_reuse_util_ptr_{nullptr};
std::map<size_t, std::queue<void *>> cached_host_mem_;

View File

@ -38,7 +38,10 @@ class DeviceSync {
// Used to sync data between host tensor and device address, additional need the data shape and data type.
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0;
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format = "DefaultFormat") const = 0;
const std::string &format) const = 0;
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const {
return SyncHostToDevice(shape, size, type, host_ptr, "DefaultFormat");
}
virtual bool SyncDeviceToDevice(const DeviceSync *) const { return true; }
virtual bool AsyncDeviceToDevice(const ShapeVector &, size_t, TypeId type, const void *, const std::string &) const {
return true;