forked from mindspore-Ecosystem/mindspore
Clean up several graph kernel source files; fix compilation
This commit is contained in:
parent dfae126152
commit 1c62823cfe
@@ -80,7 +80,7 @@ class AbstractShapeCreator {
     return {device_shape[0], device_shape[3], device_shape[1], device_shape[2]};
   }
   static ShapeVector FractalNzAbstractShape(const ShapeVector &device_shape) {
-    if (device_shape.size() == 1 && (device_shape[0] == 1 || device_shape[0] % kCubeSize == 0)) {
+    if (device_shape.size() == 1 && (device_shape[0] == 1 || static_cast<size_t>(device_shape[0]) % kCubeSize == 0)) {
       return device_shape;
     }
     if (device_shape.size() < 4) {
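The cast above resolves a signed/unsigned mismatch: ShapeVector holds int64_t dimensions while kCubeSize is unsigned, so the bare `%` implicitly converted the signed operand. A minimal standalone sketch of the pattern (the constant's value and the helper name here are illustrative, not from the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

using ShapeVector = std::vector<int64_t>;  // dims are signed, as in MindSpore
constexpr size_t kCubeSize = 16;           // illustrative value

bool CubeAligned(const ShapeVector &shape) {
  // Without the cast, `shape[0] % kCubeSize` silently converts the signed dim
  // to unsigned (flagged by -Wsign-conversion, and wrong for negative values);
  // the explicit cast records the assumption that the dim is non-negative.
  return shape[0] == 1 || static_cast<size_t>(shape[0]) % kCubeSize == 0;
}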
@@ -126,7 +126,7 @@ class CNodeDecoder {
   }

  private:
-  ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
+  ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) const {
     if (type == "str") {
       std::string value = attr_json[kJsonKeyValue];
       return MakeValue(value);
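ParseValue only reads its inputs, so it can be const-qualified; the same change recurs throughout this commit. As a quick reminder of what the qualifier buys (a generic sketch, not the decoder itself):

#include <string>

class Widget {
 public:
  // const member function: callable through a `const Widget &`, and the
  // compiler verifies it leaves the object unmodified.
  std::string Name() const { return name_; }

 private:
  std::string name_ = "w";
};

// void Inspect(const Widget &w) { w.Name(); }  // legal only because Name() is const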
@@ -204,7 +204,6 @@ class CNodeDecoder {

   bool DecodeOutputDesc(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph) {
     std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
-    AbstractBasePtr abstract(nullptr);
     if (output_descs.empty()) {
       MS_LOG(ERROR) << "No outputs found.";
       return false;
@@ -288,7 +287,7 @@ class CNodeDecoder {
     return primitive;
   }

-  tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) {
+  tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) const {
     auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
     switch (type_id) {
       case kNumberTypeFloat16:
@@ -435,7 +434,7 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
   return DecodeFusedNodes(kernel_json);
 }

-StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
+StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const {
   StitchInfo info;
   if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
     nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
@@ -451,7 +450,8 @@ StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json
   return info;
 }

-void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
+void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info,
+                                         const CNodePtr &node) const {
   std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
   if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
   std::string tensor_name = output_descs[0][kJsonKeyTensorName];
@@ -44,8 +44,8 @@ class AkgKernelJsonDecoder {
   ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
   CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
   AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
-  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json);
-  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node);
+  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const;
+  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const;
   std::map<std::string, AnfNodePtr> nodes_map_;
 };
 }  // namespace kernel
@@ -23,15 +23,15 @@

 namespace mindspore {
 namespace opt {
-int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); }
+int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }

-bool AxisNormalizer::IsReduce(const AnfNodePtr &node) {
+bool AxisNormalizer::IsReduce(const AnfNodePtr &node) const {
   std::vector<PrimitivePtr> node_with_axis = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
   return std::any_of(node_with_axis.begin(), node_with_axis.end(),
                      [&node](PrimitivePtr &p) { return IsPrimitiveCNode(node, p); });
 }

-bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
+bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
   for (auto node : todos) {
@@ -48,8 +48,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
       bool diff = false;
       ShapeVector axis_vec;
       if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
-        int64_t v1 = GetValue<int64_t>(axis);
-        int64_t v2 = NormAxis(v1, rank);
+        auto v1 = GetValue<int64_t>(axis);
+        auto v2 = NormAxis(v1, rank);
         axis_vec.push_back(v2);
         diff = diff || (v1 != v2);
       } else if (axis->isa<ValueList>() || axis->isa<ValueTuple>()) {
@@ -61,8 +61,8 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) {
         }
       } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
         for (auto v : vec) {
-          int64_t v1 = GetValue<int64_t>(v);
-          int64_t v2 = NormAxis(v1, rank);
+          auto v1 = GetValue<int64_t>(v);
+          auto v2 = NormAxis(v1, rank);
           axis_vec.push_back(v2);
           diff = diff || (v1 != v2);
         }
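NormAxis maps Python-style negative axes into [0, rank). A worked check of the one-liner above as a standalone program (only the free-function form differs from the member function in the patch):

#include <cstddef>
#include <cstdint>
#include <iostream>

// Same formula as AxisNormalizer::NormAxis.
int64_t NormAxis(int64_t x, size_t rank) { return x >= 0 ? x : x + static_cast<int64_t>(rank); }

int main() {
  std::cout << NormAxis(-1, 4) << '\n';  // 3: the last axis of a rank-4 tensor
  std::cout << NormAxis(0, 4) << '\n';   // 0: non-negative axes pass through
}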
@@ -29,9 +29,9 @@ class AxisNormalizer : public Pass {
   bool Run(const FuncGraphPtr &func_graph) override;

  private:
-  bool Process(const FuncGraphPtr &func_graph);
-  int64_t NormAxis(int64_t x, size_t rank);
-  bool IsReduce(const AnfNodePtr &node);
+  bool Process(const FuncGraphPtr &func_graph) const;
+  int64_t NormAxis(int64_t x, size_t rank) const;
+  bool IsReduce(const AnfNodePtr &node) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -60,7 +60,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
   }
   circle_nodes->clear();

-  auto InputEdges = [&depend_prior](CNodePtr cnode) {
+  auto InputEdges = [&depend_prior](const CNodePtr &cnode) {
     std::set<AnfNodePtr> edges;
     auto range = depend_prior.equal_range(cnode);
     for (auto iter = range.first; iter != range.second; ++iter) {
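Taking the lambda parameter by const reference skips a shared_ptr copy (an atomic reference-count round trip) on every call, since CNodePtr is a std::shared_ptr alias. A minimal illustration with a stand-in node type:

#include <memory>

struct CNode {};
using CNodePtr = std::shared_ptr<CNode>;

// By value: each invocation copies the shared_ptr, bumping the atomic count.
auto by_value = [](CNodePtr cnode) { return cnode != nullptr; };

// By const reference: no copy; the callee can still only read the pointer.
auto by_cref = [](const CNodePtr &cnode) { return cnode != nullptr; };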
@@ -30,7 +30,7 @@ std::string CommonDimInfo::ToString() {
   return buffer.str();
 }

-int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
+int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
   nlohmann::json json_desc;
   AnfNodePtrList nodes = {node};
   DumpOption dump_option;
@@ -47,7 +47,8 @@ int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
   return py::cast<int>(ret);
 }

-std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
+std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(
+  const AnfNodePtrList &nodes) const {
   nlohmann::json json_desc;
   std::vector<AnfNodePtrList> graphs;
   std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
@@ -80,7 +81,7 @@ std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFu
   return std::make_tuple(dim_infos, benefit, fusion_info);
 }

-FusionInfoPtr ParallelCostModel::ProcessFusionInfo(py::object fusion_type, py::object type_info) {
+FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const {
   if (!py::isinstance<py::str>(fusion_type)) {
     MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!";
   }
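py::object is a reference-counted handle, so passing it by value copies the handle and touches the Python reference count on every call; const references avoid that, as the new signature does. A minimal sketch assuming pybind11, which this file already uses (py::object, py::isinstance):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// const py::object &: no incref/decref on entry and exit, read-only access.
bool IsStr(const py::object &obj) { return py::isinstance<py::str>(obj); }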
@@ -37,7 +37,7 @@ namespace opt {
 class DimInfo {
  public:
   DimInfo() = default;
-  ~DimInfo() {}
+  virtual ~DimInfo() {}
   virtual std::string ToString() = 0;
 };

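Making ~DimInfo() virtual matters because DimInfo is an abstract base (ToString() is pure virtual) and objects are held through base-class smart pointers; deleting a derived object through a base pointer with a non-virtual destructor is undefined behavior. A condensed, self-contained sketch using the class names from this header:

#include <memory>
#include <string>

class DimInfo {
 public:
  DimInfo() = default;
  virtual ~DimInfo() {}                // virtual: enables safe polymorphic delete
  virtual std::string ToString() = 0;
};

class CommonDimInfo : public DimInfo {
 public:
  std::string ToString() override { return "common"; }
};

int main() {
  // With a non-virtual ~DimInfo, this delete-through-base would be undefined
  // behavior and ~CommonDimInfo would never run.
  std::unique_ptr<DimInfo> d = std::make_unique<CommonDimInfo>();
}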
@@ -60,7 +60,7 @@ class FusionInfo {
  public:
   FusionInfo() = default;
   explicit FusionInfo(const std::string &type) : fusion_type_(type) {}
-  ~FusionInfo() = default;
+  virtual ~FusionInfo() = default;
   std::string FusionType() { return fusion_type_; }
   virtual bool ExistTypeInfo() { return false; }

@@ -72,7 +72,7 @@ class BlockFusionInfo : public FusionInfo {
  public:
   BlockFusionInfo() : FusionInfo("block_fusion") {}
   ~BlockFusionInfo() = default;
-  bool ExistTypeInfo() { return false; }
+  bool ExistTypeInfo() override { return false; }
 };

 class BlockPipelineFusionInfo : public FusionInfo {
@@ -80,7 +80,7 @@ class BlockPipelineFusionInfo : public FusionInfo {
   explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids)
       : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {}
   ~BlockPipelineFusionInfo() = default;
-  bool ExistTypeInfo() { return true; }
+  bool ExistTypeInfo() override { return true; }
   std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; }

  private:
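Adding override asks the compiler to verify that ExistTypeInfo() really overrides the virtual declared in FusionInfo, so a base-signature change or a spelling slip becomes a hard error instead of a silent, unrelated new function. A condensed sketch:

class FusionInfo {
 public:
  virtual ~FusionInfo() = default;
  virtual bool ExistTypeInfo() { return false; }
};

class BlockPipelineFusionInfo : public FusionInfo {
 public:
  // override: fails to compile unless it matches a base virtual exactly.
  bool ExistTypeInfo() override { return true; }
  // bool ExistsTypeInfo() override { return true; }  // typo -> compile error
};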
@@ -95,11 +95,11 @@ class ParallelCostModel {
  public:
   ParallelCostModel() {}
   ~ParallelCostModel() {}
-  int GetNodeCalAmount(const AnfNodePtr &node);
-  std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes);
+  int GetNodeCalAmount(const AnfNodePtr &node) const;
+  std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const;

  private:
-  FusionInfoPtr ProcessFusionInfo(py::object fusion_type, py::object type_info);
+  FusionInfoPtr ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const;
 };

 using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;