forked from mindspore-Ecosystem/mindspore
!16378 Export bprop func_graph to mindir
Merge pull request !16378 from YuJianfeng/bprop_mindir
This commit is contained in:
commit
eb0bad4ad7
|
@ -311,4 +311,11 @@ bool Common::SaveStringToFile(const std::string filename, const std::string stri
|
||||||
ChangeFileMode(real_path.value(), S_IRUSR);
|
ChangeFileMode(real_path.value(), S_IRUSR);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Common::FileExists(const std::string &filepath) {
|
||||||
|
std::ifstream f(filepath);
|
||||||
|
bool cache_file_existed = f.good();
|
||||||
|
f.close();
|
||||||
|
return cache_file_existed;
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -39,6 +39,7 @@ class Common {
|
||||||
|
|
||||||
static std::string AddId(const std::string &filename, const std::string &suffix);
|
static std::string AddId(const std::string &filename, const std::string &suffix);
|
||||||
static bool SaveStringToFile(const std::string filename, const std::string string_info);
|
static bool SaveStringToFile(const std::string filename, const std::string string_info);
|
||||||
|
static bool FileExists(const std::string &filepath);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool IsEveryFilenameValid(const std::string &path, size_t length_limit, const std::string &error_message);
|
static bool IsEveryFilenameValid(const std::string &path, size_t length_limit, const std::string &error_message);
|
||||||
|
|
|
@ -141,7 +141,7 @@ class KPrim {
|
||||||
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
|
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
|
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr);
|
||||||
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
|
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
|
||||||
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
||||||
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "ir/manager.h"
|
#include "ir/manager.h"
|
||||||
#include "pipeline/jit/resource.h"
|
#include "pipeline/jit/resource.h"
|
||||||
#include "pipeline/jit/parse/parse.h"
|
#include "pipeline/jit/parse/parse.h"
|
||||||
|
#include "pipeline/jit/parse/resolve.h"
|
||||||
#include "frontend/optimizer/ad/dfunctor.h"
|
#include "frontend/optimizer/ad/dfunctor.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
#include "frontend/operator/composite/composite.h"
|
#include "frontend/operator/composite/composite.h"
|
||||||
|
@ -35,12 +36,102 @@
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "utils/info.h"
|
#include "utils/info.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
|
#include "debug/common.h"
|
||||||
|
#include "debug/dump_proto.h"
|
||||||
|
#include "mindspore/core/load_mindir/load_model.h"
|
||||||
|
#include "utils/system/sha256.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ad {
|
namespace ad {
|
||||||
KPrim g_k_prims;
|
KPrim g_k_prims;
|
||||||
|
|
||||||
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
namespace {
|
||||||
|
constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
|
||||||
|
constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
|
||||||
|
constexpr char kGenerateMindirEnv[] = "GENERATE_MINDIR";
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
bool IsSerializableBprop(const PrimitivePtr &prim) {
|
||||||
|
static std::unordered_set<PrimitivePtr> serializable_bprop_list{prim::kPrimRelu, prim::kPrimIdentity};
|
||||||
|
|
||||||
|
return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(),
|
||||||
|
[&prim](const PrimitivePtr &serializable_bprop_prim) {
|
||||||
|
auto str1 = prim->name();
|
||||||
|
auto str2 = serializable_bprop_prim->name();
|
||||||
|
transform(str1.begin(), str1.end(), str1.begin(), ::tolower);
|
||||||
|
transform(str2.begin(), str2.end(), str2.begin(), ::tolower);
|
||||||
|
return str1 == str2;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetBpropDir() {
|
||||||
|
static std::string bprop_dir;
|
||||||
|
if (bprop_dir.empty()) {
|
||||||
|
py::module mod = py::module::import("mindspore.ops._grad");
|
||||||
|
auto grad_file_path = mod.attr("__file__").cast<std::string>();
|
||||||
|
bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/'));
|
||||||
|
}
|
||||||
|
return bprop_dir;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetBpropHash() {
|
||||||
|
static std::string bprop_hash;
|
||||||
|
if (bprop_hash.empty()) {
|
||||||
|
auto bprop_dir = GetBpropDir();
|
||||||
|
auto realpath = Common::GetRealPath(bprop_dir);
|
||||||
|
if (!realpath.has_value()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir;
|
||||||
|
}
|
||||||
|
bprop_hash = system::sha256::GetHashFromDir(realpath.value());
|
||||||
|
}
|
||||||
|
return bprop_hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
std::string bprop_dir = GetBpropDir();
|
||||||
|
auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
|
||||||
|
std::optional<std::string> bprop_mindir_realpath =
|
||||||
|
Common::GetRealPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix);
|
||||||
|
bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value());
|
||||||
|
if (!bprop_cache_file_exists) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto bprop_fg = LoadMindIR(bprop_mindir_realpath.value());
|
||||||
|
if (bprop_fg != nullptr && bprop_fg->bprop_hash() != GetBpropHash()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path
|
||||||
|
<< "generate_mindir.py to generate new mindir files.\n"
|
||||||
|
<< "bprop_fg hash: " << bprop_fg->bprop_hash() << "\n"
|
||||||
|
<< "bprop hash: " << GetBpropHash();
|
||||||
|
}
|
||||||
|
return bprop_fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
std::string bprop_dir = GetBpropDir();
|
||||||
|
func_graph->set_bprop_hash(GetBpropHash());
|
||||||
|
auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
|
||||||
|
std::optional<std::string> bprop_mindir_realpath =
|
||||||
|
Common::GetRealPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix);
|
||||||
|
if (!bprop_mindir_realpath.has_value()) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name()
|
||||||
|
<< kBpropMindIRSuffix;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::ofstream fout(bprop_mindir_realpath.value());
|
||||||
|
mind_ir::ModelProto fg_model = GetBinaryProto(func_graph);
|
||||||
|
if (!fg_model.SerializeToOstream(&fout)) {
|
||||||
|
MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
|
||||||
|
<< bprop_mindir_realpath.value() << "\".";
|
||||||
|
}
|
||||||
|
fout.close();
|
||||||
|
ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
|
||||||
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
|
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
|
||||||
// and add "Gradients" to the front.
|
// and add "Gradients" to the front.
|
||||||
static const std::string gradients_scope = "Gradients/";
|
static const std::string gradients_scope = "Gradients/";
|
||||||
|
@ -50,6 +141,17 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
||||||
grad_op_child_scope_prefix + prim->name());
|
grad_op_child_scope_prefix + prim->name());
|
||||||
ScopeGuard scope_guard(scope);
|
ScopeGuard scope_guard(scope);
|
||||||
|
|
||||||
|
// Firstly we get bprop from mindir. If failed, parse the python function registered.
|
||||||
|
FuncGraphPtr func_graph = nullptr;
|
||||||
|
#ifndef _WIN32
|
||||||
|
bool serializable = IsSerializableBprop(prim);
|
||||||
|
if (serializable && common::GetEnv(kGenerateMindirEnv) != "1") {
|
||||||
|
func_graph = ImportBpropFromMindIR(prim);
|
||||||
|
if (func_graph != nullptr) {
|
||||||
|
return func_graph;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
py::function fn;
|
py::function fn;
|
||||||
if (prim->is_base()) {
|
if (prim->is_base()) {
|
||||||
fn = GetBpropFunction(prim->name());
|
fn = GetBpropFunction(prim->name());
|
||||||
|
@ -63,7 +165,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
||||||
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
|
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
|
func_graph = parse::ParsePythonCode(fn);
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
|
MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -72,7 +174,14 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
||||||
if (bprop_flag) {
|
if (bprop_flag) {
|
||||||
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
|
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
|
||||||
}
|
}
|
||||||
|
pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
|
||||||
|
parse::ResolveFuncGraph(func_graph, res);
|
||||||
|
#ifndef _WIN32
|
||||||
|
// Check whether the bprop needs to be exported.
|
||||||
|
if (serializable) {
|
||||||
|
ExportBpropToMindIR(prim, func_graph);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +296,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
||||||
}
|
}
|
||||||
|
|
||||||
if (bprop_fg == nullptr) {
|
if (bprop_fg == nullptr) {
|
||||||
bprop_fg = GetBprop(prim);
|
bprop_fg = GetBprop(prim, resources);
|
||||||
if (bprop_fg != nullptr) {
|
if (bprop_fg != nullptr) {
|
||||||
// Set bprop_g graph cache
|
// Set bprop_g graph cache
|
||||||
bprop_registry_[prim] = bprop_fg;
|
bprop_registry_[prim] = bprop_fg;
|
||||||
|
|
|
@ -114,11 +114,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
|
||||||
irpass.pynative_eliminate_,
|
irpass.pynative_eliminate_,
|
||||||
});
|
});
|
||||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||||
opt::OptPassConfig resolver_prim = opt::OptPassConfig({
|
|
||||||
resolve_irpass.resolver_resolve_and_getattr_,
|
|
||||||
resolve_irpass.resolver_resolve_,
|
|
||||||
resolve_irpass.resolver_getattr_,
|
|
||||||
});
|
|
||||||
|
|
||||||
opt::OptPassConfig switch_simplify = opt::OptPassConfig({
|
opt::OptPassConfig switch_simplify = opt::OptPassConfig({
|
||||||
irpass.switch_simplify_,
|
irpass.switch_simplify_,
|
||||||
|
@ -133,7 +128,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
|
||||||
});
|
});
|
||||||
|
|
||||||
OptPassGroupMap map({{"ad_eliminate", pynative_eliminate},
|
OptPassGroupMap map({{"ad_eliminate", pynative_eliminate},
|
||||||
{"ad_resolver_prim", resolver_prim},
|
|
||||||
{"ad_inline", inline_opt},
|
{"ad_inline", inline_opt},
|
||||||
{"bool_scalar_eliminate", bool_scalar_eliminate},
|
{"bool_scalar_eliminate", bool_scalar_eliminate},
|
||||||
{"ad_switch_simplify", switch_simplify}});
|
{"ad_switch_simplify", switch_simplify}});
|
||||||
|
@ -321,9 +315,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||||
opt::irpass::ResolveIRPassLib resolve_irpass;
|
opt::irpass::ResolveIRPassLib resolve_irpass;
|
||||||
|
|
||||||
opt::OptPassConfig resolve_pass =
|
opt::OptPassConfig after_resolve_pass =
|
||||||
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_,
|
opt::OptPassConfig({irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
|
||||||
irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
|
|
||||||
|
|
||||||
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
||||||
OptPassGroupMap map_a({{"a_1", a_1},
|
OptPassGroupMap map_a({{"a_1", a_1},
|
||||||
|
@ -336,7 +329,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
{"virtual_dataset", virtual_dataset},
|
{"virtual_dataset", virtual_dataset},
|
||||||
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
|
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
|
||||||
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
|
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
|
||||||
{"resolve", resolve_pass},
|
{"after_resolve", after_resolve_pass},
|
||||||
{"a_after_grad", a_after_grad},
|
{"a_after_grad", a_after_grad},
|
||||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||||
{"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
|
{"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
|
||||||
|
|
|
@ -176,6 +176,7 @@ void IrExportBuilder::BuildModelInfo() {
|
||||||
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
|
||||||
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
|
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
|
||||||
graph_proto->set_name(func_graph->ToString());
|
graph_proto->set_name(func_graph->ToString());
|
||||||
|
graph_proto->set_bprop_hash(func_graph->bprop_hash());
|
||||||
ResetNodeIndex();
|
ResetNodeIndex();
|
||||||
todo_.clear();
|
todo_.clear();
|
||||||
todo_.push_back(func_graph);
|
todo_.push_back(func_graph);
|
||||||
|
@ -247,6 +248,9 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
||||||
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
|
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
|
||||||
const TypePtr &type = node->Type();
|
const TypePtr &type = node->Type();
|
||||||
const BaseShapePtr &shape = node->Shape();
|
const BaseShapePtr &shape = node->Shape();
|
||||||
|
if (type == nullptr || shape == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
||||||
auto tensor = type->cast<TensorTypePtr>();
|
auto tensor = type->cast<TensorTypePtr>();
|
||||||
auto elem_type = tensor->element();
|
auto elem_type = tensor->element();
|
||||||
|
@ -404,9 +408,10 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
|
||||||
// 3. save tuple string in ref_attr_name
|
// 3. save tuple string in ref_attr_name
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto type = node->Type();
|
auto type = node->Type();
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
|
||||||
auto shape = node->Shape();
|
auto shape = node->Shape();
|
||||||
MS_EXCEPTION_IF_NULL(shape);
|
if (type == nullptr || shape == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
ResetTupleIndex();
|
ResetTupleIndex();
|
||||||
std::string seq_string = "shape:";
|
std::string seq_string = "shape:";
|
||||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
|
|
|
@ -60,6 +60,29 @@ py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
|
||||||
return py_args;
|
return py_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::function GetComputeFunctionWithoutPyObj(const std::string &name) {
|
||||||
|
static const std::string module = "tests.vm_impl.vm_impl_function";
|
||||||
|
py::module mod = py::module::import(common::SafeCStr(module));
|
||||||
|
if (!py::hasattr(mod, common::SafeCStr(name))) {
|
||||||
|
return py::none();
|
||||||
|
}
|
||||||
|
py::object fn = mod.attr(common::SafeCStr(name));
|
||||||
|
return fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) {
|
||||||
|
auto func = GetComputeFunctionWithoutPyObj(prim->name());
|
||||||
|
if (py::isinstance<py::none>(func)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto py_args = ConvertDatatoPyTuple(args);
|
||||||
|
py::object obj = func(*py_args);
|
||||||
|
if (py::isinstance<py::none>(obj)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return std::make_shared<PyObjectRef>(obj);
|
||||||
|
}
|
||||||
|
|
||||||
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
|
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
|
||||||
auto func = GetComputeFunction(prim->name());
|
auto func = GetComputeFunction(prim->name());
|
||||||
if (py::isinstance<py::none>(func)) {
|
if (py::isinstance<py::none>(func)) {
|
||||||
|
|
|
@ -33,6 +33,10 @@ py::function GetComputeFunction(std::string name);
|
||||||
|
|
||||||
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
|
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
|
||||||
|
|
||||||
|
py::function GetComputeFunctionWithoutPyObj(const std::string &name);
|
||||||
|
|
||||||
|
BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args);
|
||||||
|
|
||||||
py::tuple ConvertDatatoPyTuple(const VectorRef &args);
|
py::tuple ConvertDatatoPyTuple(const VectorRef &args);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,201 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "utils/system/sha256.h"
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <fstream>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
#include "securec/include/securec.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace system {
|
||||||
|
namespace sha256 {
|
||||||
|
constexpr int kBitNumber = 8;
|
||||||
|
constexpr int kDigestSize = 8;
|
||||||
|
constexpr int kIterationNumber = 64;
|
||||||
|
constexpr int kMessageBlockLength = 64;
|
||||||
|
const uint32_t constant[64] = {
|
||||||
|
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
|
||||||
|
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
|
||||||
|
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
|
||||||
|
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
|
||||||
|
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
|
||||||
|
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
|
||||||
|
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
|
||||||
|
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
|
||||||
|
|
||||||
|
std::string LoadFilePath(const std::string &path) {
|
||||||
|
char real_path[PATH_MAX] = {0};
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
if (path.size() >= PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (path.size() >= PATH_MAX || realpath(path.c_str(), real_path) == nullptr) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
std::ifstream bin_stream(real_path, std::ios::binary);
|
||||||
|
if (!bin_stream.is_open()) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
std::string message((std::istreambuf_iterator<char>(bin_stream)), std::istreambuf_iterator<char>());
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Padding(std::string *message) {
|
||||||
|
uint64_t bits_message = message->size() * kBitNumber;
|
||||||
|
const int remains = message->size() % kMessageBlockLength;
|
||||||
|
// The length of the message needs to be stored in 8 bytes, supplemented at the end of the message.
|
||||||
|
const int size_append = 8;
|
||||||
|
const int size_required = kMessageBlockLength - size_append;
|
||||||
|
const int size_pad = size_required - remains + (size_required > remains ? 0 : kMessageBlockLength);
|
||||||
|
if (size_pad < 1 || size_pad > kMessageBlockLength) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
message->push_back(0x80);
|
||||||
|
for (int i = 1; i < size_pad; ++i) {
|
||||||
|
message->push_back(0x00);
|
||||||
|
}
|
||||||
|
for (int i = size_append - 1; i >= 0; --i) {
|
||||||
|
message->push_back(static_cast<uint8_t>((bits_message >> static_cast<uint32_t>(i * kBitNumber)) & 0xff));
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) {
|
||||||
|
if (digest_size != 8) { // The number of digests is fixed at 8
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
uint32_t w[kIterationNumber] = {0};
|
||||||
|
for (int i = 0; i < 16; ++i) {
|
||||||
|
w[i] = (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4]) & 0xff) << 24) |
|
||||||
|
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 1]) & 0xff) << 16) |
|
||||||
|
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 2]) & 0xff) << 8) |
|
||||||
|
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 3]) & 0xff));
|
||||||
|
}
|
||||||
|
for (int i = 16; i < kIterationNumber; ++i) {
|
||||||
|
w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint32_t> hash(digest_size);
|
||||||
|
size_t mem_size = digest_size * sizeof(uint32_t);
|
||||||
|
auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size);
|
||||||
|
if (ret != EOK) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < kIterationNumber; ++i) {
|
||||||
|
uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]);
|
||||||
|
uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]);
|
||||||
|
for (int j = digest_size - 1; j >= 0; --j) {
|
||||||
|
if (j == 4) {
|
||||||
|
hash[j] = hash[j - 1] + t1;
|
||||||
|
} else if (j == 0) {
|
||||||
|
hash[j] = t1 + t2;
|
||||||
|
} else {
|
||||||
|
hash[j] = hash[j - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < digest_size; ++i) {
|
||||||
|
digest[i] += hash[i];
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ConvertToString(uint32_t *input, const int &size) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << std::hex;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
for (int j = static_cast<int>(sizeof(uint32_t) / sizeof(uint8_t)) - 1; j >= 0; --j) {
|
||||||
|
auto val = static_cast<uint8_t>((input[i] >> static_cast<uint32_t>(j * kBitNumber)) & 0xff);
|
||||||
|
oss << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Encrypt(const std::string &message) {
|
||||||
|
uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
|
||||||
|
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19};
|
||||||
|
for (int i = 0; i < static_cast<int>(message.size()); i += kMessageBlockLength) {
|
||||||
|
if (!ProcessInner(message, i, digest, kDigestSize)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ConvertToString(digest, kDigestSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetHashFromString(const std::string &data) {
|
||||||
|
std::string message = data;
|
||||||
|
if (message.empty() || !Padding(&message)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return Encrypt(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetHashFromFile(const std::string &path) {
|
||||||
|
std::string message = LoadFilePath(path);
|
||||||
|
if (message.empty() || !Padding(&message)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return Encrypt(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
std::string GetHashFromDir(const std::string &dir) {
|
||||||
|
if (dir.empty()) {
|
||||||
|
MS_LOG(ERROR) << "The directory path is empty.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
struct stat s {};
|
||||||
|
int ret = stat(dir.c_str(), &s);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(ERROR) << "stat dir \"" << dir << "\" failed, ret is : " << ret;
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
if (!S_ISDIR(s.st_mode)) {
|
||||||
|
MS_LOG(ERROR) << "The path \"" << dir << "\" is not a directory.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
DIR *open_dir = opendir(dir.c_str());
|
||||||
|
if (open_dir == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "open dir " << dir.c_str() << " failed";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
struct dirent *filename;
|
||||||
|
std::vector<std::string> file_hashes;
|
||||||
|
while ((filename = readdir(open_dir)) != nullptr) {
|
||||||
|
std::string d_name = std::string(filename->d_name);
|
||||||
|
if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
file_hashes.emplace_back(GetHashFromFile(std::string(dir) + "/" + filename->d_name));
|
||||||
|
}
|
||||||
|
closedir(open_dir);
|
||||||
|
std::sort(file_hashes.begin(), file_hashes.end());
|
||||||
|
auto dir_hash = std::accumulate(file_hashes.begin(), file_hashes.end(), std::string{});
|
||||||
|
return dir_hash;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace sha256
|
||||||
|
} // namespace system
|
||||||
|
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,28 +18,10 @@
|
||||||
#define MINDSPORE_CCSRC_UTILS_SYSTEM_SHA256_H_
|
#define MINDSPORE_CCSRC_UTILS_SYSTEM_SHA256_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <iomanip>
|
|
||||||
#include <fstream>
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace system {
|
namespace system {
|
||||||
namespace sha256 {
|
namespace sha256 {
|
||||||
constexpr int kBitNumber = 8;
|
|
||||||
constexpr int kDigestSize = 8;
|
|
||||||
constexpr int kIterationNumber = 64;
|
|
||||||
constexpr int kMessageBlockLength = 64;
|
|
||||||
const uint32_t constant[64] = {
|
|
||||||
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
|
|
||||||
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
|
|
||||||
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
|
|
||||||
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
|
|
||||||
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
|
|
||||||
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
|
|
||||||
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
|
|
||||||
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
|
|
||||||
|
|
||||||
inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); }
|
inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); }
|
||||||
inline uint32_t ma(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
|
inline uint32_t ma(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
|
||||||
inline uint32_t sigma0(uint32_t x) { return (x >> 2 | x << 30) ^ (x >> 13 | x << 19) ^ (x >> 22 | x << 10); }
|
inline uint32_t sigma0(uint32_t x) { return (x >> 2 | x << 30) ^ (x >> 13 | x << 19) ^ (x >> 22 | x << 10); }
|
||||||
|
@ -47,123 +29,23 @@ inline uint32_t sigma1(uint32_t x) { return (x >> 6 | x << 26) ^ (x >> 11 | x <<
|
||||||
inline uint32_t sigma2(uint32_t x) { return (x >> 7 | x << 25) ^ (x >> 18 | x << 14) ^ (x >> 3); }
|
inline uint32_t sigma2(uint32_t x) { return (x >> 7 | x << 25) ^ (x >> 18 | x << 14) ^ (x >> 3); }
|
||||||
inline uint32_t sigma3(uint32_t x) { return (x >> 17 | x << 15) ^ (x >> 19 | x << 13) ^ (x >> 10); }
|
inline uint32_t sigma3(uint32_t x) { return (x >> 17 | x << 15) ^ (x >> 19 | x << 13) ^ (x >> 10); }
|
||||||
|
|
||||||
std::string LoadFilePath(const std::string &path) {
|
std::string LoadFilePath(const std::string &path);
|
||||||
char real_path[PATH_MAX] = {0};
|
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
bool Padding(std::string *message);
|
||||||
if (path.size() >= PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
|
|
||||||
return "";
|
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size);
|
||||||
}
|
|
||||||
#else
|
std::string ConvertToString(uint32_t *input, const int &size);
|
||||||
if (path.size() >= PATH_MAX || realpath(path.c_str(), real_path) == nullptr) {
|
|
||||||
return "";
|
std::string Encrypt(const std::string &message);
|
||||||
}
|
|
||||||
|
std::string GetHashFromString(const std::string &data);
|
||||||
|
|
||||||
|
std::string GetHashFromFile(const std::string &path);
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
std::string GetHashFromDir(const std::string &dir);
|
||||||
#endif
|
#endif
|
||||||
std::ifstream bin_stream(real_path, std::ios::binary);
|
|
||||||
if (!bin_stream.is_open()) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
std::string message((std::istreambuf_iterator<char>(bin_stream)), std::istreambuf_iterator<char>());
|
|
||||||
return message;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Padding(std::string *message) {
|
|
||||||
uint64_t bits_message = message->size() * kBitNumber;
|
|
||||||
const int remains = message->size() % kMessageBlockLength;
|
|
||||||
// The length of the message needs to be stored in 8 bytes, supplemented at the end of the message.
|
|
||||||
const int size_append = 8;
|
|
||||||
const int size_required = kMessageBlockLength - size_append;
|
|
||||||
const int size_pad = size_required - remains + (size_required > remains ? 0 : kMessageBlockLength);
|
|
||||||
if (size_pad < 1 || size_pad > kMessageBlockLength) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
message->push_back(0x80);
|
|
||||||
for (int i = 1; i < size_pad; ++i) {
|
|
||||||
message->push_back(0x00);
|
|
||||||
}
|
|
||||||
for (int i = size_append - 1; i >= 0; --i) {
|
|
||||||
message->push_back(static_cast<uint8_t>((bits_message >> static_cast<uint32_t>(i * kBitNumber)) & 0xff));
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) {
|
|
||||||
if (digest_size != 8) { // The number of digests is fixed at 8
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
uint32_t w[kIterationNumber] = {0};
|
|
||||||
for (int i = 0; i < 16; ++i) {
|
|
||||||
w[i] = (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4]) & 0xff) << 24) |
|
|
||||||
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 1]) & 0xff) << 16) |
|
|
||||||
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 2]) & 0xff) << 8) |
|
|
||||||
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 3]) & 0xff));
|
|
||||||
}
|
|
||||||
for (int i = 16; i < kIterationNumber; ++i) {
|
|
||||||
w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16];
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint32_t> hash(digest_size);
|
|
||||||
size_t mem_size = digest_size * sizeof(uint32_t);
|
|
||||||
auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size);
|
|
||||||
if (ret != EOK) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < kIterationNumber; ++i) {
|
|
||||||
uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]);
|
|
||||||
uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]);
|
|
||||||
for (int j = digest_size - 1; j >= 0; --j) {
|
|
||||||
if (j == 4) {
|
|
||||||
hash[j] = hash[j - 1] + t1;
|
|
||||||
} else if (j == 0) {
|
|
||||||
hash[j] = t1 + t2;
|
|
||||||
} else {
|
|
||||||
hash[j] = hash[j - 1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < digest_size; ++i) {
|
|
||||||
digest[i] += hash[i];
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ConvertToString(uint32_t *input, const int &size) {
|
|
||||||
std::ostringstream oss;
|
|
||||||
oss << std::hex;
|
|
||||||
for (int i = 0; i < size; ++i) {
|
|
||||||
for (int j = static_cast<int>(sizeof(uint32_t) / sizeof(uint8_t)) - 1; j >= 0; --j) {
|
|
||||||
auto val = static_cast<uint8_t>((input[i] >> static_cast<uint32_t>(j * kBitNumber)) & 0xff);
|
|
||||||
oss << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Encrypt(const std::string &message) {
|
|
||||||
uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
|
|
||||||
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19};
|
|
||||||
for (int i = 0; i < static_cast<int>(message.size()); i += kMessageBlockLength) {
|
|
||||||
if (!ProcessInner(message, i, digest, kDigestSize)) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ConvertToString(digest, kDigestSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetHashFromString(const std::string &data) {
|
|
||||||
std::string message = data;
|
|
||||||
if (message.empty() || !Padding(&message)) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return Encrypt(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string GetHashFromFile(const std::string &path) {
|
|
||||||
std::string message = LoadFilePath(path);
|
|
||||||
if (message.empty() || !Padding(&message)) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return Encrypt(message);
|
|
||||||
}
|
|
||||||
} // namespace sha256
|
} // namespace sha256
|
||||||
} // namespace system
|
} // namespace system
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -437,6 +437,9 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
|
||||||
MS_LOG(DEBUG) << "operation start " << prim->name();
|
MS_LOG(DEBUG) << "operation start " << prim->name();
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto result = prim->RunComputeFunction(args);
|
auto result = prim->RunComputeFunction(args);
|
||||||
|
if (result.is_null()) {
|
||||||
|
result = RunComputeFunctionWithoutPyObj(prim, args);
|
||||||
|
}
|
||||||
if (result.is_null()) {
|
if (result.is_null()) {
|
||||||
return RunComputeFunction(prim, args);
|
return RunComputeFunction(prim, args);
|
||||||
}
|
}
|
||||||
|
|
|
@ -394,6 +394,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||||
bool dropped() const { return dropped_; }
|
bool dropped() const { return dropped_; }
|
||||||
void set_dropped(bool dropped) { dropped_ = dropped; }
|
void set_dropped(bool dropped) { dropped_ = dropped; }
|
||||||
|
|
||||||
|
std::string bprop_hash() const { return bprop_hash_; }
|
||||||
|
void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Only used for func_graph manager to control resource free.
|
// Only used for func_graph manager to control resource free.
|
||||||
int attached_mng_cnt() const { return attached_mng_cnt_; }
|
int attached_mng_cnt() const { return attached_mng_cnt_; }
|
||||||
|
@ -478,6 +481,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||||
// If the graph was changed, it should be dropped in cache data_converter::object_map_
|
// If the graph was changed, it should be dropped in cache data_converter::object_map_
|
||||||
// which used by ConvertToFuncGraph.
|
// which used by ConvertToFuncGraph.
|
||||||
bool dropped_ = false;
|
bool dropped_ = false;
|
||||||
|
// If the graph is a bprop graph, it should has a hash of the bprop directory.
|
||||||
|
std::string bprop_hash_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
|
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
|
||||||
|
|
|
@ -40,6 +40,8 @@ static constexpr char kConstantValueNode[] = "Constant";
|
||||||
static constexpr char kCNodeShapeAttr[] = "shape";
|
static constexpr char kCNodeShapeAttr[] = "shape";
|
||||||
static constexpr char kCNodeShape1Attr[] = "shape1";
|
static constexpr char kCNodeShape1Attr[] = "shape1";
|
||||||
static constexpr char kCNodeShape2Attr[] = "shape2";
|
static constexpr char kCNodeShape2Attr[] = "shape2";
|
||||||
|
static constexpr char kDoSignaturePrimitivePrefix[] = "S-Prim-";
|
||||||
|
|
||||||
enum ParseForm : int {
|
enum ParseForm : int {
|
||||||
FORM_PARSE_TYPE = 0,
|
FORM_PARSE_TYPE = 0,
|
||||||
FORM_PARSE_SCALAR = 1,
|
FORM_PARSE_SCALAR = 1,
|
||||||
|
@ -305,12 +307,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
||||||
node->set_debug_info(debug_info_ptr);
|
node->set_debug_info(debug_info_ptr);
|
||||||
node->set_name(debug_info_name);
|
node->set_name(debug_info_name);
|
||||||
|
|
||||||
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
if (value_proto.tensor_size() > 0) {
|
||||||
|
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
||||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||||
auto tensor_abstract = tensor_info->ToAbstract();
|
auto tensor_abstract = tensor_info->ToAbstract();
|
||||||
node->set_abstract(tensor_abstract);
|
node->set_abstract(tensor_abstract);
|
||||||
|
}
|
||||||
|
|
||||||
anfnode_build_map_[value_proto.name()] = node;
|
anfnode_build_map_[value_proto.name()] = node;
|
||||||
return true;
|
return true;
|
||||||
|
@ -768,8 +771,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
||||||
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
|
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
|
||||||
prim = op_primc_fns[node_type]();
|
prim = op_primc_fns[node_type]();
|
||||||
} else {
|
} else {
|
||||||
prim = std::make_shared<Primitive>(node_type);
|
if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
|
||||||
prim->set_instance_name(node_type);
|
auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
|
||||||
|
prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
|
||||||
|
prim->set_instance_name(op_name);
|
||||||
|
} else {
|
||||||
|
prim = std::make_shared<Primitive>(node_type);
|
||||||
|
prim->set_instance_name(node_type);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
|
||||||
|
@ -812,9 +821,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
||||||
} else {
|
} else {
|
||||||
AbstractBasePtrList elem;
|
AbstractBasePtrList elem;
|
||||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||||
elem.push_back(cnode_ptr->input(index)->abstract());
|
auto abs = cnode_ptr->input(index)->abstract();
|
||||||
|
if (abs != nullptr) {
|
||||||
|
elem.push_back(abs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!elem.empty()) {
|
||||||
|
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||||
}
|
}
|
||||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
|
||||||
}
|
}
|
||||||
} else if (kv.size() == 1) {
|
} else if (kv.size() == 1) {
|
||||||
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
|
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
|
||||||
|
@ -919,6 +933,9 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
|
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
|
||||||
}
|
}
|
||||||
|
if (importProto.has_bprop_hash()) {
|
||||||
|
outputFuncGraph->set_bprop_hash(importProto.bprop_hash());
|
||||||
|
}
|
||||||
|
|
||||||
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
|
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
|
||||||
MS_LOG(ERROR) << "import parameters for graph fail!";
|
MS_LOG(ERROR) << "import parameters for graph fail!";
|
||||||
|
|
|
@ -80,6 +80,7 @@ message GraphProto {
|
||||||
optional string doc_string = 4;
|
optional string doc_string = 4;
|
||||||
repeated ValueInfoProto input = 5;
|
repeated ValueInfoProto input = 5;
|
||||||
repeated ValueInfoProto output = 6;
|
repeated ValueInfoProto output = 6;
|
||||||
|
optional string bprop_hash = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
|
||||||
|
0.1.0 MindSpore*1.1.0:î
|
||||||
|
—
|
||||||
|
bprop.10:doutbprop.10:[CNode]12:2bprop.10:[CNode]11:1"S-Prim-MakeTuple:HGradients/Default/network-NetIdentity/gradIdentity/S-Prim-MakeTuple-op15bprop.10*
|
||||||
|
|
||||||
|
bprop.10:x*
|
||||||
|
bprop.10:out*
|
||||||
|
bprop.10:dout2
|
||||||
|
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
|
@ -0,0 +1,11 @@
|
||||||
|
|
||||||
|
0.1.0 MindSpore*1.1.0:å
|
||||||
|
ˆ
|
||||||
|
bprop.2:dout
|
||||||
|
bprop.2:outbprop.2:dx:1bprop.2:dx:1"S-Prim-ReluGrad:>Gradients/Default/network-NetRelu/gradReLU/S-Prim-ReluGrad-op5
|
||||||
|
‰
|
||||||
|
bprop.2:dx:1bprop.2:[CNode]4:3bprop.2:[CNode]3:2"S-Prim-MakeTuple:?Gradients/Default/network-NetRelu/gradReLU/S-Prim-MakeTuple-op6bprop.2*
|
||||||
|
bprop.2:x*
|
||||||
|
bprop.2:out*
|
||||||
|
bprop.2:dout2
|
||||||
|
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""bprop mindir."""
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Generate the mindir for bprop"""
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.ops as ops
|
||||||
|
import mindspore.ops._grad as g
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
os.environ['GENERATE_MINDIR'] = '1'
|
||||||
|
|
||||||
|
|
||||||
|
class NetRelu(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetRelu, self).__init__()
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class NetIdentity(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetIdentity, self).__init__()
|
||||||
|
self.identity = P.Identity()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.identity(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GradNet(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(GradNet, self).__init__()
|
||||||
|
self.grad = ops.GradOperation()
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
gout = self.grad(self.network)(x)
|
||||||
|
return gout
|
||||||
|
|
||||||
|
|
||||||
|
def test_relu():
|
||||||
|
x = Tensor(np.array([[[[-1, 1, 10],
|
||||||
|
[1, -1, 1],
|
||||||
|
[10, 1, -1]]]]).astype(np.float32))
|
||||||
|
relu = NetRelu()
|
||||||
|
grad = GradNet(relu)
|
||||||
|
grad(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity():
|
||||||
|
x = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
|
||||||
|
identity = NetIdentity()
|
||||||
|
grad = GradNet(identity)
|
||||||
|
grad(x)
|
||||||
|
|
||||||
|
|
||||||
|
test_relu()
|
||||||
|
test_identity()
|
||||||
|
# mindspore/ops/_grad/__init__.py
|
||||||
|
bprop_path = g.__file__
|
||||||
|
bprop_mindir_path = bprop_path[: bprop_path.rindex('/')] + "/../bprop_mindir/"
|
||||||
|
print("The new bprop mindir files has been generated in the path \"" + bprop_mindir_path +
|
||||||
|
"\", copy the *.mindir to your PYTHONPATH if necessary.")
|
1
setup.py
1
setup.py
|
@ -133,6 +133,7 @@ package_data = {
|
||||||
'lib/*.dylib*',
|
'lib/*.dylib*',
|
||||||
'.commit_id',
|
'.commit_id',
|
||||||
'config/*',
|
'config/*',
|
||||||
|
'ops/bprop_mindir/*',
|
||||||
'include/*',
|
'include/*',
|
||||||
'include/*/*',
|
'include/*/*',
|
||||||
'include/*/*/*',
|
'include/*/*/*',
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Generate vm_impl function for nn ops without python object"""
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from .vm_interface import vm
|
||||||
|
|
||||||
|
def ReluGrad(y_backprop, x):
|
||||||
|
x = x.asnumpy()
|
||||||
|
y_backprop = y_backprop.asnumpy()
|
||||||
|
y_backprop = vm.relu_grad(x.copy()) * y_backprop
|
||||||
|
return Tensor(y_backprop)
|
Loading…
Reference in New Issue