[GraphKernel] clean code.

This commit is contained in:
chenlei_autodiff 2022-06-24 10:18:28 +08:00
parent b1d99b0ed3
commit 872e26f119
4 changed files with 16 additions and 15 deletions

View File

@ -82,7 +82,7 @@ std::vector<AreaWithRelation> Area::users_with_relation() const {
int64_t Area::compute_size() const {
auto op = dom();
MS_EXCEPTION_IF_NULL(op);
return op->tensor_size();
return SizeToLong(op->tensor_size());
}
std::string Area::ToString() const {

View File

@ -157,7 +157,7 @@ bool TensorInplace::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = TopoSort(func_graph->get_return());
bool changed = false;
bool tensor_inplace_changed = false;
for (auto &node : todos) {
if (common::AnfAlgo::IsGraphKernel(node)) {
auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
@ -176,7 +176,7 @@ bool TensorInplace::Run(const FuncGraphPtr &func_graph) {
return CheckShapeType(cnode->input(i), node.first);
});
if (candidate != outs.end()) {
changed = true;
tensor_inplace_changed = true;
InplaceAssignerInfo new_op_info; // output info
new_op_info.op_node = candidate->first->cast<CNodePtr>();
new_op_info.real_output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
@ -192,10 +192,10 @@ bool TensorInplace::Run(const FuncGraphPtr &func_graph) {
}
}
}
if (changed) {
if (tensor_inplace_changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
return tensor_inplace_changed;
}
} // namespace mindspore::graphkernel

View File

@ -202,7 +202,7 @@ void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo
}
if (IsValueNode<Primitive>(op)) {
PrimitivePtr primitive = GetValueNode<PrimitivePtr>(op);
auto primitive = GetValueNode<PrimitivePtr>(op);
if (!primitive->instance_name().empty()) {
gsub->dumpbuf << " {";
gsub->dumpbuf << "instance name"
@ -223,11 +223,11 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
if (op == nullptr || gsub == nullptr) {
return;
}
if (op->attrs().empty()) {
auto &attrs = op->attrs();
if (attrs.empty()) {
return;
}
auto attrs = op->attrs();
gsub->dumpbuf << " cnode_attrs: {";
DumpAttrs(attrs, gsub);
gsub->dumpbuf << "}";
@ -336,11 +336,11 @@ void DumpIRInSubgraph(const std::vector<AnfNodePtr> &nodes, OrderedMap<AnfNodePt
gsub->local_var = 0;
(*sub_graphs)[sub_graph] = gsub;
}
std::vector<AnfNodePtr> parameters = sub_graph->parameters();
for (size_t idx = 0; idx < parameters.size(); idx++) {
MS_EXCEPTION_IF_NULL(parameters[idx]);
if ((*para_map).count(parameters[idx]) == 0) {
(*para_map)[parameters[idx]] = total_para++;
auto &param = sub_graph->parameters();
for (size_t idx = 0; idx < param.size(); idx++) {
MS_EXCEPTION_IF_NULL(param[idx]);
if ((*para_map).count(param[idx]) == 0) {
(*para_map)[param[idx]] = total_para++;
}
}
if (!nd->isa<Parameter>()) {
@ -419,7 +419,7 @@ std::optional<std::string> CreatePrefixPath(const std::string &input_path) {
} else {
auto pwd_path = FileUtils::GetRealPath("./");
if (!pwd_path.has_value()) {
MS_LOG(ERROR) << "Cannot get pwd path";
MS_LOG(ERROR) << "Can not get pwd path";
return std::nullopt;
}
prefix_path_str = pwd_path.value();

View File

@ -74,6 +74,7 @@ parse_device()
if [[ "X$ENABLE_AKG" == "Xon" && "X$ENABLE_D" != "Xon" ]]; then
# check llvm version for akg
export USE_LLVM=`bash ${BASEPATH}/scripts/build/akg_find_llvm.sh`
HAS_LLVM=`bash ${BASEPATH}/scripts/build/akg_find_llvm.sh`
export USE_LLVM=$HAS_LLVM
fi
}