clean code of several graph kernel source files

fix compiling
This commit is contained in:
looop5 2021-04-29 20:00:26 +08:00
parent dfae126152
commit 1c62823cfe
7 changed files with 30 additions and 29 deletions

View File

@ -80,7 +80,7 @@ class AbstractShapeCreator {
return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]};
}
static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) {
if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) {
if (device_shape.size() == 1 && (device_shape[0] == 1 || static_cast<size_t>(device_shape[0]) % kCubeSize == 0)) {
return device_shape;
}
if (device_shape.size() < 4) {
@ -126,7 +126,7 @@ class CNodeDecoder {
}
private:
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) const {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
@ -204,7 +204,6 @@ class CNodeDecoder {
bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
AbstractBasePtr abstract(nullptr);
if (output_descs.empty()) {
MS_LOG(ERROR) << "No outputs found.";
return false;
@ -288,7 +287,7 @@ class CNodeDecoder {
return primitive;
}
tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) {
tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) const {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
@ -435,7 +434,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
return DecodeFusedNodes(kernel_json);
}
StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
@ -451,7 +450,8 @@ StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json
return info;
}
void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info,
const CNodePtr &node) const {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];

View File

@ -44,8 +44,8 @@ class AkgKernelJsonDecoder {
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json);
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node);
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const;
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const;
std::map<std::string, AnfNodePtr> nodes_map_;
};
} // namespace kernel

View File

@ -23,15 +23,15 @@
namespace mindspore {
namespace opt {
int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
bool AxisNormalizer::IsReduce(const AnfNodePtr &node) {
bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
return std::any_of(node_with_axis.begin(), node_with_axis.end(),
[&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
}
bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
@ -48,8 +48,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
bool diff = false;
ShapeVector axis_vec;
if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
int64_t v1 = GetValue<int64_t>(axis);
int64_t v2 = NormAxis(v1, rank);
auto v1 = GetValue<int64_t>(axis);
auto v2 = NormAxis(v1, rank);
axis_vec.push_back(v2);
diff = diff || (v1 != v2);
} else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
@ -61,8 +61,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
}
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
for (auto v : vec) {
int64_t v1 = GetValue<int64_t>(v);
int64_t v2 = NormAxis(v1, rank);
auto v1 = GetValue<int64_t>(v);
auto v2 = NormAxis(v1, rank);
axis_vec.push_back(v2);
diff = diff || (v1 != v2);
}

View File

@ -29,9 +29,9 @@ class AxisNormalizer : public Pass {
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
int64_t NormAxis(int64_t x, size_t rank);
bool IsReduce(const AnfNodePtr &node);
bool Process(const FuncGraphPtr &func_graph) const;
int64_t NormAxis(int64_t x, size_t rank) const;
bool IsReduce(const AnfNodePtr &node) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -60,7 +60,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
}
circle_nodes->clear();
auto InputEdges = [&depend_prior](CNodePtr cnode) {
auto InputEdges = [&depend_prior](const CNodePtr &cnode) {
std::set<AnfNodePtr> edges;
auto range = depend_prior.equal_range(cnode);
for (auto iter = range.first; iter != range.second; ++iter) {

View File

@ -30,7 +30,7 @@ std::string CommonDimInfo::ToString() {
return buffer.str();
}
int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
nlohmann::json json_desc;
AnfNodePtrList nodes = {node};
DumpOption dump_option;
@ -47,7 +47,8 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
return py::cast<int>(ret);
}
std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(
const AnfNodePtrList &nodes) const {
nlohmann::json json_desc;
std::vector<AnfNodePtrList> graphs;
std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
@ -80,7 +81,7 @@ std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFu
return std::make_tuple(dim_infos, benefit, fusion_info);
}
FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) {
FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const {
if (!py::isinstance<py::str>(fusion_type)) {
MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!";
}

View File

@ -37,7 +37,7 @@ namespace opt {
class DimInfo {
public:
DimInfo() = default;
~DimInfo() {}
virtual ~DimInfo() {}
virtual std::string ToString() = 0;
};
@ -60,7 +60,7 @@ class FusionInfo {
public:
FusionInfo() = default;
explicit FusionInfo(const std::string &type) : fusion_type_(type) {}
~FusionInfo() = default;
virtual ~FusionInfo() = default;
std::string FusionType() { return fusion_type_; }
virtual bool ExistTypeInfo() { return false; }
@ -72,7 +72,7 @@ class BlockFusionInfo : public FusionInfo {
public:
BlockFusionInfo() : FusionInfo("block_fusion") {}
~BlockFusionInfo() = default;
bool ExistTypeInfo() { return false; }
bool ExistTypeInfo() override { return false; }
};
class BlockPipelineFusionInfo : public FusionInfo {
@ -80,7 +80,7 @@ class BlockPipelineFusionInfo : public FusionInfo {
explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids)
: FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {}
~BlockPipelineFusionInfo() = default;
bool ExistTypeInfo() { return true; }
bool ExistTypeInfo() override { return true; }
std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; }
private:
@ -95,11 +95,11 @@ class ParallelCostModel {
public:
ParallelCostModel() {}
~ParallelCostModel() {}
int GetNodeCalAmount(const AnfNodePtr &node);
std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes);
int GetNodeCalAmount(const AnfNodePtr &node) const;
std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const;
private:
FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info);
FusionInfoPtr ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const;
};
using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;