Add bprop cache

This commit is contained in:
yujianfeng 2021-04-25 17:25:33 +08:00
parent a0c5b56f5f
commit e88d058787
20 changed files with 554 additions and 162 deletions

View File

@ -311,4 +311,11 @@ bool Common::SaveStringToFile(const std::string filename, const std::string stri
ChangeFileMode(real_path.value(), S_IRUSR);
return true;
}
// Returns true when a file at `filepath` exists and can be opened for reading.
bool Common::FileExists(const std::string &filepath) {
  // The ifstream is closed automatically when the temporary is destroyed.
  return std::ifstream(filepath).good();
}
} // namespace mindspore

View File

@ -39,6 +39,7 @@ class Common {
static std::string AddId(const std::string &filename, const std::string &suffix);
static bool SaveStringToFile(const std::string filename, const std::string string_info);
static bool FileExists(const std::string &filepath);
private:
static bool IsEveryFilenameValid(const std::string &path, size_t length_limit, const std::string &error_message);

View File

@ -141,7 +141,7 @@ class KPrim {
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
private:
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr);
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);

View File

@ -26,6 +26,7 @@
#include "ir/manager.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
@ -35,12 +36,102 @@
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/trace.h"
#include "debug/common.h"
#include "debug/dump_proto.h"
#include "mindspore/core/load_mindir/load_model.h"
#include "utils/system/sha256.h"
namespace mindspore {
namespace ad {
KPrim g_k_prims;
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
namespace {
constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
constexpr char kGenerateMindirEnv[] = "GENERATE_MINDIR";
#ifndef _WIN32
// Returns true if the bprop of the given primitive is in the list of
// primitives whose bprop can be serialized to / loaded from a MindIR file.
// The name comparison is case-insensitive.
bool IsSerializableBprop(const PrimitivePtr &prim) {
  static std::unordered_set<PrimitivePtr> serializable_bprop_list{prim::kPrimRelu, prim::kPrimIdentity};
  return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(),
                     [&prim](const PrimitivePtr &serializable_bprop_prim) {
                       auto str1 = prim->name();
                       auto str2 = serializable_bprop_prim->name();
                       // Qualify std::transform and go through unsigned char:
                       // passing a plain (possibly negative) char to ::tolower
                       // is undefined behavior.
                       auto to_lower = [](unsigned char c) { return static_cast<char>(::tolower(c)); };
                       std::transform(str1.begin(), str1.end(), str1.begin(), to_lower);
                       std::transform(str2.begin(), str2.end(), str2.begin(), to_lower);
                       return str1 == str2;
                     });
}
// Returns the directory that holds the python bprop definitions
// (mindspore.ops._grad). Computed once; the C++11 magic-static guarantees
// thread-safe initialization, unlike the previous check-if-empty pattern
// which could race when called from multiple threads.
std::string GetBpropDir() {
  static const std::string bprop_dir = []() {
    py::module mod = py::module::import("mindspore.ops._grad");
    auto grad_file_path = mod.attr("__file__").cast<std::string>();
    // Keep everything before the last '/' — i.e. the directory part of the
    // module's __file__ path.
    return grad_file_path.substr(0, grad_file_path.find_last_of('/'));
  }();
  return bprop_dir;
}
// Returns the sha256 hash of the bprop python source directory, used to
// detect stale cached bprop MindIR files. Computed once via a magic static
// for thread-safe lazy initialization. Throws if the directory path cannot
// be resolved (in that case initialization is retried on the next call,
// matching the old check-if-empty behavior).
std::string GetBpropHash() {
  static const std::string bprop_hash = []() {
    auto bprop_dir = GetBpropDir();
    auto realpath = Common::GetRealPath(bprop_dir);
    if (!realpath.has_value()) {
      MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir;
    }
    return system::sha256::GetHashFromDir(realpath.value());
  }();
  return bprop_hash;
}
// Try to load the cached bprop of `prim` from its MindIR file.
// Returns nullptr when no cache file exists; raises an exception when the
// cached graph's recorded hash no longer matches the current bprop sources.
FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto mindir_dir = GetBpropDir() + kBpropMindIRDir;
  auto cache_realpath = Common::GetRealPath(mindir_dir + prim->name() + kBpropMindIRSuffix);
  if (!cache_realpath.has_value() || !Common::FileExists(cache_realpath.value())) {
    // No usable cache file for this primitive.
    return nullptr;
  }
  auto bprop_fg = LoadMindIR(cache_realpath.value());
  if (bprop_fg != nullptr && bprop_fg->bprop_hash() != GetBpropHash()) {
    MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date. Please run the " << mindir_dir
                      << "generate_mindir.py to generate new mindir files.\n"
                      << "bprop_fg hash: " << bprop_fg->bprop_hash() << "\n"
                      << "bprop hash: " << GetBpropHash();
  }
  return bprop_fg;
}
// Serialize the bprop `func_graph` of `prim` to its MindIR cache file and
// stamp it with the current bprop-directory hash so staleness can be
// detected on load. Failures are logged and the cache is simply not written.
void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::string bprop_dir = GetBpropDir();
  func_graph->set_bprop_hash(GetBpropHash());
  auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
  std::optional<std::string> bprop_mindir_realpath =
    Common::GetRealPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix);
  if (!bprop_mindir_realpath.has_value()) {
    MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name()
                  << kBpropMindIRSuffix;
    return;
  }
  std::ofstream fout(bprop_mindir_realpath.value());
  // Check the stream before serializing: the original code wrote into a
  // possibly-unopened stream and only noticed via SerializeToOstream.
  if (!fout.is_open()) {
    MS_LOG(WARNING) << "Failed to open the bprop mindir file \"" << bprop_mindir_realpath.value() << "\" for writing.";
    return;
  }
  mind_ir::ModelProto fg_model = GetBinaryProto(func_graph);
  if (!fg_model.SerializeToOstream(&fout)) {
    MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
                    << bprop_mindir_realpath.value() << "\".";
  }
  fout.close();
  // Restrict the cache file to owner read/write.
  ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
}
#endif
} // namespace
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
// and add "Gradients" to the front.
static const std::string gradients_scope = "Gradients/";
@ -50,6 +141,17 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
// Firstly we get bprop from mindir. If failed, parse the python function registered.
FuncGraphPtr func_graph = nullptr;
#ifndef _WIN32
bool serializable = IsSerializableBprop(prim);
if (serializable && common::GetEnv(kGenerateMindirEnv) != "1") {
func_graph = ImportBpropFromMindIR(prim);
if (func_graph != nullptr) {
return func_graph;
}
}
#endif
py::function fn;
if (prim->is_base()) {
fn = GetBpropFunction(prim->name());
@ -63,7 +165,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
return nullptr;
}
FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
func_graph = parse::ParsePythonCode(fn);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
return nullptr;
@ -72,7 +174,14 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
if (bprop_flag) {
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
}
pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
parse::ResolveFuncGraph(func_graph, res);
#ifndef _WIN32
// Check whether the bprop needs to be exported.
if (serializable) {
ExportBpropToMindIR(prim, func_graph);
}
#endif
return func_graph;
}
@ -187,7 +296,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}
if (bprop_fg == nullptr) {
bprop_fg = GetBprop(prim);
bprop_fg = GetBprop(prim, resources);
if (bprop_fg != nullptr) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;

View File

@ -114,11 +114,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
irpass.pynative_eliminate_,
});
opt::irpass::ResolveIRPassLib resolve_irpass;
opt::OptPassConfig resolver_prim = opt::OptPassConfig({
resolve_irpass.resolver_resolve_and_getattr_,
resolve_irpass.resolver_resolve_,
resolve_irpass.resolver_getattr_,
});
opt::OptPassConfig switch_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
@ -133,7 +128,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
});
OptPassGroupMap map({{"ad_eliminate", pynative_eliminate},
{"ad_resolver_prim", resolver_prim},
{"ad_inline", inline_opt},
{"bool_scalar_eliminate", bool_scalar_eliminate},
{"ad_switch_simplify", switch_simplify}});
@ -324,9 +318,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::irpass::ResolveIRPassLib resolve_irpass;
opt::OptPassConfig resolve_pass =
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_,
irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
opt::OptPassConfig after_resolve_pass =
opt::OptPassConfig({irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
OptPassGroupMap map_a({{"a_1", a_1},
@ -339,7 +332,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"virtual_dataset", virtual_dataset},
{"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
{"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
{"resolve", resolve_pass},
{"after_resolve", after_resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},

View File

@ -176,6 +176,7 @@ void IrExportBuilder::BuildModelInfo() {
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
mind_ir::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
graph_proto->set_bprop_hash(func_graph->bprop_hash());
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
@ -247,6 +248,9 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
const TypePtr &type = node->Type();
const BaseShapePtr &shape = node->Shape();
if (type == nullptr || shape == nullptr) {
return;
}
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
auto elem_type = tensor->element();
@ -404,9 +408,10 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
// 3. save tuple string in ref_attr_name
MS_EXCEPTION_IF_NULL(node);
auto type = node->Type();
MS_EXCEPTION_IF_NULL(type);
auto shape = node->Shape();
MS_EXCEPTION_IF_NULL(shape);
if (type == nullptr || shape == nullptr) {
return;
}
ResetTupleIndex();
std::string seq_string = "shape:";
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();

View File

@ -60,6 +60,29 @@ py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
return py_args;
}
// Look up a vm-impl compute function named `name` in the test module
// tests.vm_impl.vm_impl_function. Returns py::none() when it is not defined.
py::function GetComputeFunctionWithoutPyObj(const std::string &name) {
  static const std::string module = "tests.vm_impl.vm_impl_function";
  py::module vm_impl_mod = py::module::import(common::SafeCStr(module));
  if (!py::hasattr(vm_impl_mod, common::SafeCStr(name))) {
    return py::none();
  }
  return vm_impl_mod.attr(common::SafeCStr(name));
}
// Execute the primitive through a vm-impl function that works on plain data
// (no python primitive object). Returns a null BaseRef when no such function
// is registered or when it produced None.
BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) {
  auto compute_fn = GetComputeFunctionWithoutPyObj(prim->name());
  if (py::isinstance<py::none>(compute_fn)) {
    return nullptr;
  }
  auto tuple_args = ConvertDatatoPyTuple(args);
  py::object result = compute_fn(*tuple_args);
  if (py::isinstance<py::none>(result)) {
    return nullptr;
  }
  return std::make_shared<PyObjectRef>(result);
}
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
auto func = GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {

View File

@ -33,6 +33,10 @@ py::function GetComputeFunction(std::string name);
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
py::function GetComputeFunctionWithoutPyObj(const std::string &name);
BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args);
py::tuple ConvertDatatoPyTuple(const VectorRef &args);
} // namespace mindspore

View File

@ -0,0 +1,201 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/system/sha256.h"
#include <dirent.h>
#include <sys/stat.h>
#include <iomanip>
#include <fstream>
#include <vector>
#include <algorithm>
#include <numeric>
#include "securec/include/securec.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace system {
namespace sha256 {
constexpr int kBitNumber = 8;
constexpr int kDigestSize = 8;
constexpr int kIterationNumber = 64;
constexpr int kMessageBlockLength = 64;
const uint32_t constant[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
// Resolve `path` to an absolute real path and read the whole file in binary
// mode. Returns an empty string when the path is too long, cannot be
// resolved, or the file cannot be opened.
std::string LoadFilePath(const std::string &path) {
  if (path.size() >= PATH_MAX) {
    return "";
  }
  char resolved[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
  if (_fullpath(resolved, path.c_str(), PATH_MAX) == nullptr) {
    return "";
  }
#else
  if (realpath(path.c_str(), resolved) == nullptr) {
    return "";
  }
#endif
  std::ifstream bin_stream(resolved, std::ios::binary);
  if (!bin_stream.is_open()) {
    return "";
  }
  return std::string((std::istreambuf_iterator<char>(bin_stream)), std::istreambuf_iterator<char>());
}
// Pad `message` in place per the SHA-256 rules: append the 0x80 marker byte,
// then zero bytes, and finally the original length in bits as 8 big-endian
// bytes, so the total length becomes a multiple of 64. Returns false if the
// computed pad size falls outside [1, 64].
bool Padding(std::string *message) {
  const uint64_t total_bits = message->size() * kBitNumber;
  const int leftover = message->size() % kMessageBlockLength;
  // The 64-bit message length occupies the last 8 bytes of the final block.
  const int length_bytes = 8;
  const int threshold = kMessageBlockLength - length_bytes;
  const int pad_len = threshold - leftover + (threshold > leftover ? 0 : kMessageBlockLength);
  if (pad_len < 1 || pad_len > kMessageBlockLength) {
    return false;
  }
  message->push_back(0x80);
  // Zero fill up to the length field (pad_len counts the 0x80 byte too).
  message->append(static_cast<size_t>(pad_len - 1), '\0');
  for (int shift = length_bytes - 1; shift >= 0; --shift) {
    message->push_back(static_cast<uint8_t>((total_bits >> static_cast<uint32_t>(shift * kBitNumber)) & 0xff));
  }
  return true;
}
// One round of the SHA-256 compression function: consumes the 64-byte block
// of `message` starting at offset `bias` and folds it into the running
// `digest`. `digest` must hold exactly 8 words; returns false otherwise or
// when the working copy of the digest cannot be made.
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) {
  if (digest_size != 8) {  // The number of digests is fixed at 8
    return false;
  }
  // Message schedule: the first 16 words come straight from the block
  // (big-endian byte order); the remaining words are derived with the
  // sigma2/sigma3 mixing functions.
  uint32_t w[kIterationNumber] = {0};
  for (int i = 0; i < 16; ++i) {
    w[i] = (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4]) & 0xff) << 24) |
           (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 1]) & 0xff) << 16) |
           (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 2]) & 0xff) << 8) |
           (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 3]) & 0xff));
  }
  for (int i = 16; i < kIterationNumber; ++i) {
    w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16];
  }
  // Work on a copy of the digest; the caller's digest is only updated at the
  // very end, after all rounds succeed.
  std::vector<uint32_t> hash(digest_size);
  size_t mem_size = digest_size * sizeof(uint32_t);
  auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size);
  if (ret != EOK) {
    return false;
  }
  // The 64 compression rounds. t1/t2 are the standard SHA-256 temporaries;
  // the inner loop shifts the eight working variables down one slot
  // (hash[j] = hash[j-1]), with slot 4 absorbing t1 and slot 0 absorbing
  // t1 + t2, matching the usual a..h variable rotation.
  for (int i = 0; i < kIterationNumber; ++i) {
    uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]);
    uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]);
    for (int j = digest_size - 1; j >= 0; --j) {
      if (j == 4) {
        hash[j] = hash[j - 1] + t1;
      } else if (j == 0) {
        hash[j] = t1 + t2;
      } else {
        hash[j] = hash[j - 1];
      }
    }
  }
  // Add the compressed chunk into the running digest.
  for (int i = 0; i < digest_size; ++i) {
    digest[i] += hash[i];
  }
  return true;
}
// Render `size` 32-bit words as a lowercase hex string (8 hex digits per
// word, most significant byte first).
std::string ConvertToString(uint32_t *input, const int &size) {
  std::ostringstream stream;
  stream << std::hex << std::setfill('0');
  for (int idx = 0; idx < size; ++idx) {
    // Printing the whole word zero-padded to 8 digits is equivalent to
    // emitting its four bytes in big-endian order.
    stream << std::setw(8) << input[idx];
  }
  return stream.str();
}
// Run the SHA-256 compression over every 64-byte block of the (already
// padded) message and render the final digest as hex. Returns "" when a
// block fails to process.
std::string Encrypt(const std::string &message) {
  // Initial hash values from the SHA-256 specification.
  uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
                                  0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19};
  const int total = static_cast<int>(message.size());
  for (int offset = 0; offset < total; offset += kMessageBlockLength) {
    if (!ProcessInner(message, offset, digest, kDigestSize)) {
      return "";
    }
  }
  return ConvertToString(digest, kDigestSize);
}
// Compute the SHA-256 hex digest of `data`. Returns "" for empty input or
// when padding fails.
std::string GetHashFromString(const std::string &data) {
  std::string buffer(data);
  if (buffer.empty() || !Padding(&buffer)) {
    return "";
  }
  return Encrypt(buffer);
}
// Compute the SHA-256 hex digest of the file at `path`. Returns "" when the
// file cannot be read, is empty, or padding fails.
std::string GetHashFromFile(const std::string &path) {
  std::string buffer = LoadFilePath(path);
  if (buffer.empty() || !Padding(&buffer)) {
    return "";
  }
  return Encrypt(buffer);
}
#ifndef _WIN32
// Compute a combined hash for all regular files directly inside `dir`: hash
// each file, sort the hashes (so the result does not depend on readdir
// order), and concatenate them. Returns "" on any directory access error.
std::string GetHashFromDir(const std::string &dir) {
  if (dir.empty()) {
    MS_LOG(ERROR) << "The directory path is empty.";
    return "";
  }
  struct stat s {};
  int ret = stat(dir.c_str(), &s);
  if (ret != 0) {
    MS_LOG(ERROR) << "stat dir \"" << dir << "\" failed, ret is : " << ret;
    return "";
  }
  if (!S_ISDIR(s.st_mode)) {
    MS_LOG(ERROR) << "The path \"" << dir << "\" is not a directory.";
    return "";
  }
  DIR *open_dir = opendir(dir.c_str());
  if (open_dir == nullptr) {
    MS_LOG(ERROR) << "open dir " << dir.c_str() << " failed";
    return "";
  }
  struct dirent *filename;
  std::vector<std::string> file_hashes;
  while ((filename = readdir(open_dir)) != nullptr) {
    std::string d_name = std::string(filename->d_name);
    if (d_name == "." || d_name == "..") {
      continue;
    }
    auto file_path = std::string(dir) + "/" + d_name;
    // Some file systems do not fill in d_type (they report DT_UNKNOWN);
    // fall back to stat() instead of silently skipping those entries, which
    // would yield a wrong (empty) directory hash.
    bool is_regular = filename->d_type == DT_REG;
    if (filename->d_type == DT_UNKNOWN) {
      struct stat file_stat {};
      is_regular = stat(file_path.c_str(), &file_stat) == 0 && S_ISREG(file_stat.st_mode);
    }
    if (!is_regular) {
      continue;
    }
    file_hashes.emplace_back(GetHashFromFile(file_path));
  }
  closedir(open_dir);
  std::sort(file_hashes.begin(), file_hashes.end());
  auto dir_hash = std::accumulate(file_hashes.begin(), file_hashes.end(), std::string{});
  return dir_hash;
}
#endif
} // namespace sha256
} // namespace system
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,28 +18,10 @@
#define MINDSPORE_CCSRC_UTILS_SYSTEM_SHA256_H_
#include <string>
#include <iomanip>
#include <fstream>
#include <memory>
#include <vector>
namespace mindspore {
namespace system {
namespace sha256 {
constexpr int kBitNumber = 8;
constexpr int kDigestSize = 8;
constexpr int kIterationNumber = 64;
constexpr int kMessageBlockLength = 64;
const uint32_t constant[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); }
inline uint32_t ma(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
inline uint32_t sigma0(uint32_t x) { return (x >> 2 | x << 30) ^ (x >> 13 | x << 19) ^ (x >> 22 | x << 10); }
@ -47,123 +29,23 @@ inline uint32_t sigma1(uint32_t x) { return (x >> 6 | x << 26) ^ (x >> 11 | x <<
inline uint32_t sigma2(uint32_t x) { return (x >> 7 | x << 25) ^ (x >> 18 | x << 14) ^ (x >> 3); }
inline uint32_t sigma3(uint32_t x) { return (x >> 17 | x << 15) ^ (x >> 19 | x << 13) ^ (x >> 10); }
std::string LoadFilePath(const std::string &path) {
char real_path[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (path.size() >= PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) {
return "";
}
#else
if (path.size() >= PATH_MAX || realpath(path.c_str(), real_path) == nullptr) {
return "";
}
std::string LoadFilePath(const std::string &path);
bool Padding(std::string *message);
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size);
std::string ConvertToString(uint32_t *input, const int &size);
std::string Encrypt(const std::string &message);
std::string GetHashFromString(const std::string &data);
std::string GetHashFromFile(const std::string &path);
#ifndef _WIN32
std::string GetHashFromDir(const std::string &dir);
#endif
std::ifstream bin_stream(real_path, std::ios::binary);
if (!bin_stream.is_open()) {
return "";
}
std::string message((std::istreambuf_iterator<char>(bin_stream)), std::istreambuf_iterator<char>());
return message;
}
bool Padding(std::string *message) {
uint64_t bits_message = message->size() * kBitNumber;
const int remains = message->size() % kMessageBlockLength;
// The length of the message needs to be stored in 8 bytes, supplemented at the end of the message.
const int size_append = 8;
const int size_required = kMessageBlockLength - size_append;
const int size_pad = size_required - remains + (size_required > remains ? 0 : kMessageBlockLength);
if (size_pad < 1 || size_pad > kMessageBlockLength) {
return false;
}
message->push_back(0x80);
for (int i = 1; i < size_pad; ++i) {
message->push_back(0x00);
}
for (int i = size_append - 1; i >= 0; --i) {
message->push_back(static_cast<uint8_t>((bits_message >> static_cast<uint32_t>(i * kBitNumber)) & 0xff));
}
return true;
}
bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) {
if (digest_size != 8) { // The number of digests is fixed at 8
return false;
}
uint32_t w[kIterationNumber] = {0};
for (int i = 0; i < 16; ++i) {
w[i] = (static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4]) & 0xff) << 24) |
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 1]) & 0xff) << 16) |
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 2]) & 0xff) << 8) |
(static_cast<uint32_t>(static_cast<uint8_t>(message[bias + i * 4 + 3]) & 0xff));
}
for (int i = 16; i < kIterationNumber; ++i) {
w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16];
}
std::vector<uint32_t> hash(digest_size);
size_t mem_size = digest_size * sizeof(uint32_t);
auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size);
if (ret != EOK) {
return false;
}
for (int i = 0; i < kIterationNumber; ++i) {
uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]);
uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]);
for (int j = digest_size - 1; j >= 0; --j) {
if (j == 4) {
hash[j] = hash[j - 1] + t1;
} else if (j == 0) {
hash[j] = t1 + t2;
} else {
hash[j] = hash[j - 1];
}
}
}
for (int i = 0; i < digest_size; ++i) {
digest[i] += hash[i];
}
return true;
}
std::string ConvertToString(uint32_t *input, const int &size) {
std::ostringstream oss;
oss << std::hex;
for (int i = 0; i < size; ++i) {
for (int j = static_cast<int>(sizeof(uint32_t) / sizeof(uint8_t)) - 1; j >= 0; --j) {
auto val = static_cast<uint8_t>((input[i] >> static_cast<uint32_t>(j * kBitNumber)) & 0xff);
oss << std::setw(2) << std::setfill('0') << static_cast<unsigned int>(val);
}
}
return oss.str();
}
std::string Encrypt(const std::string &message) {
uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19};
for (int i = 0; i < static_cast<int>(message.size()); i += kMessageBlockLength) {
if (!ProcessInner(message, i, digest, kDigestSize)) {
return "";
}
}
return ConvertToString(digest, kDigestSize);
}
std::string GetHashFromString(const std::string &data) {
std::string message = data;
if (message.empty() || !Padding(&message)) {
return "";
}
return Encrypt(message);
}
std::string GetHashFromFile(const std::string &path) {
std::string message = LoadFilePath(path);
if (message.empty() || !Padding(&message)) {
return "";
}
return Encrypt(message);
}
} // namespace sha256
} // namespace system
} // namespace mindspore

View File

@ -437,6 +437,9 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
MS_LOG(DEBUG) << "operation start " << prim->name();
MS_EXCEPTION_IF_NULL(prim);
auto result = prim->RunComputeFunction(args);
if (result.is_null()) {
result = RunComputeFunctionWithoutPyObj(prim, args);
}
if (result.is_null()) {
return RunComputeFunction(prim, args);
}

View File

@ -393,6 +393,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
bool dropped() const { return dropped_; }
void set_dropped(bool dropped) { dropped_ = dropped; }
std::string bprop_hash() const { return bprop_hash_; }
void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
private:
// Only used for func_graph manager to control resource free.
int attached_mng_cnt() const { return attached_mng_cnt_; }
@ -477,6 +480,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// If the graph was changed, it should be dropped in cache data_converter::object_map_
// which used by ConvertToFuncGraph.
bool dropped_ = false;
// If the graph is a bprop graph, it should has a hash of the bprop directory.
std::string bprop_hash_;
};
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

View File

@ -40,6 +40,8 @@ static constexpr char kConstantValueNode[] = "Constant";
static constexpr char kCNodeShapeAttr[] = "shape";
static constexpr char kCNodeShape1Attr[] = "shape1";
static constexpr char kCNodeShape2Attr[] = "shape2";
static constexpr char kDoSignaturePrimitivePrefix[] = "S-Prim-";
enum ParseForm : int {
FORM_PARSE_TYPE = 0,
FORM_PARSE_SCALAR = 1,
@ -305,12 +307,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
node->set_debug_info(debug_info_ptr);
node->set_name(debug_info_name);
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
if (value_proto.tensor_size() > 0) {
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
MS_EXCEPTION_IF_NULL(tensor_info);
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
}
anfnode_build_map_[value_proto.name()] = node;
return true;
@ -768,8 +771,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
if (op_primc_fns.find(node_type) != op_primc_fns.end()) {
prim = op_primc_fns[node_type]();
} else {
prim = std::make_shared<Primitive>(node_type);
prim->set_instance_name(node_type);
if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) {
auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix));
prim = std::make_shared<prim::DoSignaturePrimitive>(op_name, std::make_shared<Primitive>(op_name));
prim->set_instance_name(op_name);
} else {
prim = std::make_shared<Primitive>(node_type);
prim->set_instance_name(node_type);
}
}
MS_EXCEPTION_IF_NULL(prim);
@ -812,9 +821,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
} else {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
elem.push_back(cnode_ptr->input(index)->abstract());
auto abs = cnode_ptr->input(index)->abstract();
if (abs != nullptr) {
elem.push_back(abs);
}
}
if (!elem.empty()) {
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
}
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
}
} else if (kv.size() == 1) {
std::unordered_map<std::string, abstract::AbstractBasePtr>::iterator iter = kv.begin();
@ -919,6 +933,9 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const
} else {
MS_LOG(ERROR) << "FuncGraph under converting has not name!";
}
if (importProto.has_bprop_hash()) {
outputFuncGraph->set_bprop_hash(importProto.bprop_hash());
}
if (!ImportParametersForGraph(outputFuncGraph, importProto)) {
MS_LOG(ERROR) << "import parameters for graph fail!";

View File

@ -80,6 +80,7 @@ message GraphProto {
optional string doc_string = 4;
repeated ValueInfoProto input = 5;
repeated ValueInfoProto output = 6;
optional string bprop_hash = 7;
}

View File

@ -0,0 +1,9 @@
0.1.0 MindSpore*1.1.0:î

bprop.10:doutbprop.10:[CNode]12:2bprop.10:[CNode]11:1"S-Prim-MakeTuple:HGradients/Default/network-NetIdentity/gradIdentity/S-Prim-MakeTuple-op15bprop.10*
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@ -0,0 +1,11 @@
0.1.0 MindSpore*1.1.0:å
ˆ
bprop.2:dout
bprop.2:out bprop.2:dx:1 bprop.2:dx:1"S-Prim-ReluGrad:>Gradients/Default/network-NetRelu/gradReLU/S-Prim-ReluGrad-op5

bprop.2:dx:1bprop.2:[CNode]4:3bprop.2:[CNode]3:2"S-Prim-MakeTuple:?Gradients/Default/network-NetRelu/gradReLU/S-Prim-MakeTuple-op6bprop.2*
bprop.2:x*
bprop.2:out*
bprop.2:dout2
bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

View File

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""bprop mindir."""

View File

@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate the mindir for bprop"""
import os
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.ops as ops
import mindspore.ops._grad as g
context.set_context(mode=context.GRAPH_MODE)
os.environ['GENERATE_MINDIR'] = '1'
class NetRelu(nn.Cell):
    """Minimal network applying element-wise ReLU, used to trigger its bprop."""

    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        """Forward pass: ReLU(x)."""
        return self.relu(x)
class NetIdentity(nn.Cell):
    """Minimal network applying Identity, used to trigger its bprop."""

    def __init__(self):
        super().__init__()
        self.identity = P.Identity()

    def construct(self, x):
        """Forward pass: Identity(x)."""
        return self.identity(x)
class GradNet(nn.Cell):
    """Wrapper computing the gradient of `network` w.r.t. its input."""

    def __init__(self, network):
        super().__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        """Return d(network(x))/dx."""
        return self.grad(self.network)(x)
def test_relu():
    """Run a backward pass through ReLU so its bprop MindIR is generated."""
    data = Tensor(np.array([[[[-1, 1, 10],
                              [1, -1, 1],
                              [10, 1, -1]]]]).astype(np.float32))
    GradNet(NetRelu())(data)
def test_identity():
    """Run a backward pass through Identity so its bprop MindIR is generated."""
    data = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
    GradNet(NetIdentity())(data)
# Run both networks once; with GENERATE_MINDIR=1 set above, this (re)generates
# the bprop MindIR cache files.
test_relu()
test_identity()

# mindspore/ops/_grad/__init__.py
bprop_path = g.__file__
# The mindir output directory sits next to the _grad package.
bprop_mindir_path = bprop_path[: bprop_path.rindex('/')] + "/../bprop_mindir/"
print("The new bprop mindir files has been generated in the path \"" + bprop_mindir_path +
      "\", copy the *.mindir to your PYTHONPATH if necessary.")

View File

@ -133,6 +133,7 @@ package_data = {
'lib/*.dylib*',
'.commit_id',
'config/*',
'ops/bprop_mindir/*',
'include/*',
'include/*/*',
'include/*/*/*',

View File

@ -0,0 +1,23 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for nn ops without python object"""
from mindspore.common.tensor import Tensor
from .vm_interface import vm
def ReluGrad(y_backprop, x):
    """Vm-impl of ReluGrad without a python primitive object.

    Masks the incoming gradient with the ReLU derivative of x.
    """
    incoming_grad = y_backprop.asnumpy()
    mask = vm.relu_grad(x.asnumpy().copy())
    return Tensor(mask * incoming_grad)