From e88d0587879a79926f67d38130860126b6c74125 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Sun, 25 Apr 2021 17:25:33 +0800 Subject: [PATCH] Add bprop cache --- mindspore/ccsrc/debug/common.cc | 7 + mindspore/ccsrc/debug/common.h | 1 + .../ccsrc/frontend/optimizer/ad/dfunctor.h | 2 +- .../ccsrc/frontend/optimizer/ad/kprim.cc | 117 +++++++++- mindspore/ccsrc/pipeline/jit/pass.cc | 13 +- .../transform/express_ir/mindir_exporter.cc | 9 +- mindspore/ccsrc/utils/primitive_utils.cc | 23 ++ mindspore/ccsrc/utils/primitive_utils.h | 4 + mindspore/ccsrc/utils/system/sha256.cc | 201 ++++++++++++++++++ mindspore/ccsrc/utils/system/sha256.h | 152 ++----------- mindspore/ccsrc/vm/vmimpl.cc | 3 + mindspore/core/ir/func_graph.h | 5 + .../core/load_mindir/anf_model_parser.cc | 37 +++- mindspore/core/proto/mind_ir.proto | 1 + .../ops/bprop_mindir/Identity_bprop.mindir | 9 + mindspore/ops/bprop_mindir/ReLU_bprop.mindir | 11 + mindspore/ops/bprop_mindir/__init__.py | 16 ++ mindspore/ops/bprop_mindir/generate_mindir.py | 81 +++++++ setup.py | 1 + tests/vm_impl/vm_impl_function.py | 23 ++ 20 files changed, 554 insertions(+), 162 deletions(-) create mode 100644 mindspore/ccsrc/utils/system/sha256.cc create mode 100644 mindspore/ops/bprop_mindir/Identity_bprop.mindir create mode 100644 mindspore/ops/bprop_mindir/ReLU_bprop.mindir create mode 100644 mindspore/ops/bprop_mindir/__init__.py create mode 100644 mindspore/ops/bprop_mindir/generate_mindir.py create mode 100644 tests/vm_impl/vm_impl_function.py diff --git a/mindspore/ccsrc/debug/common.cc b/mindspore/ccsrc/debug/common.cc index dfe54e22682..7d9a82912c7 100644 --- a/mindspore/ccsrc/debug/common.cc +++ b/mindspore/ccsrc/debug/common.cc @@ -311,4 +311,11 @@ bool Common::SaveStringToFile(const std::string filename, const std::string stri ChangeFileMode(real_path.value(), S_IRUSR); return true; } + +bool Common::FileExists(const std::string &filepath) { + std::ifstream f(filepath); + bool cache_file_existed = f.good(); + f.close(); + return cache_file_existed; +} } // namespace mindspore diff --git a/mindspore/ccsrc/debug/common.h b/mindspore/ccsrc/debug/common.h index 7737ceb5915..dd8f1d5c8c4 100644 --- a/mindspore/ccsrc/debug/common.h +++ b/mindspore/ccsrc/debug/common.h @@ -39,6 +39,7 @@ class Common { static std::string AddId(const std::string &filename, const std::string &suffix); static bool SaveStringToFile(const std::string filename, const std::string string_info); + static bool FileExists(const std::string &filepath); private: static bool IsEveryFilenameValid(const std::string &path, size_t length_limit, const std::string &error_message); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index fbda48adb62..5965027c7be 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -141,7 +141,7 @@ class KPrim { FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim); private: - FuncGraphPtr GetBprop(const PrimitivePtr &prim); + FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr); FuncGraphPtr GetFprop(const PrimitivePtr &prim); FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 6d90da0321c..438bc1c3416 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -26,6 +26,7 @@ #include "ir/manager.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/resolve.h" #include "frontend/optimizer/ad/dfunctor.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" @@ -35,12 +36,102 @@ #include "utils/ms_context.h" #include "utils/info.h" #include "debug/trace.h" +#include "debug/common.h" +#include "debug/dump_proto.h" +#include "mindspore/core/load_mindir/load_model.h" +#include "utils/system/sha256.h" namespace mindspore { namespace ad { KPrim g_k_prims; -FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { +namespace { +constexpr char kBpropMindIRSuffix[] = "_bprop.mindir"; +constexpr char kBpropMindIRDir[] = "/../bprop_mindir/"; +constexpr char kGenerateMindirEnv[] = "GENERATE_MINDIR"; + +#ifndef _WIN32 +bool IsSerializableBprop(const PrimitivePtr &prim) { + static std::unordered_set serializable_bprop_list{prim::kPrimRelu, prim::kPrimIdentity}; + + return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(), + [&prim](const PrimitivePtr &serializable_bprop_prim) { + auto str1 = prim->name(); + auto str2 = serializable_bprop_prim->name(); + transform(str1.begin(), str1.end(), str1.begin(), ::tolower); + transform(str2.begin(), str2.end(), str2.begin(), ::tolower); + return str1 == str2; + }); +} + +std::string GetBpropDir() { + static std::string bprop_dir; + if (bprop_dir.empty()) { + py::module mod = py::module::import("mindspore.ops._grad"); + auto grad_file_path = mod.attr("__file__").cast(); + bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/')); + } + return bprop_dir; +} + +std::string GetBpropHash() { + static std::string bprop_hash; + if (bprop_hash.empty()) { + auto bprop_dir = GetBpropDir(); + auto realpath = Common::GetRealPath(bprop_dir); + if (!realpath.has_value()) { + MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir; + } + bprop_hash = system::sha256::GetHashFromDir(realpath.value()); + } + return bprop_hash; +} + +FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(prim); + std::string bprop_dir = GetBpropDir(); + auto bprop_mindir_path = bprop_dir + kBpropMindIRDir; + std::optional bprop_mindir_realpath = + Common::GetRealPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix); + bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value()); + if (!bprop_cache_file_exists) { + return nullptr; + } + auto bprop_fg = LoadMindIR(bprop_mindir_realpath.value()); + if (bprop_fg != nullptr && bprop_fg->bprop_hash() != GetBpropHash()) { + MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path + << "generate_mindir.py to generate new mindir files.\n" + << "bprop_fg hash: " << bprop_fg->bprop_hash() << "\n" + << "bprop hash: " << GetBpropHash(); + } + return bprop_fg; +} + +void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(prim); + std::string bprop_dir = GetBpropDir(); + func_graph->set_bprop_hash(GetBpropHash()); + auto bprop_mindir_path = bprop_dir + kBpropMindIRDir; + std::optional bprop_mindir_realpath = + Common::GetRealPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix); + if (!bprop_mindir_realpath.has_value()) { + MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name() + << kBpropMindIRSuffix; + return; + } + std::ofstream fout(bprop_mindir_realpath.value()); + mind_ir::ModelProto fg_model = GetBinaryProto(func_graph); + if (!fg_model.SerializeToOstream(&fout)) { + MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \"" + << bprop_mindir_realpath.value() << "\"."; + } + fout.close(); + ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR); +} +#endif +} // namespace + +FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) { // Set a child scope named "grad'PrimitiveName'" for the bprop function, // and add "Gradients" to the front. static const std::string gradients_scope = "Gradients/"; @@ -50,6 +141,17 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); + // Firstly we get bprop from mindir. If failed, parse the python function registered. + FuncGraphPtr func_graph = nullptr; +#ifndef _WIN32 + bool serializable = IsSerializableBprop(prim); + if (serializable && common::GetEnv(kGenerateMindirEnv) != "1") { + func_graph = ImportBpropFromMindIR(prim); + if (func_graph != nullptr) { + return func_graph; + } + } +#endif py::function fn; if (prim->is_base()) { fn = GetBpropFunction(prim->name()); @@ -63,7 +165,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; return nullptr; } - FuncGraphPtr func_graph = parse::ParsePythonCode(fn); + func_graph = parse::ParsePythonCode(fn); if (func_graph == nullptr) { MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; return nullptr; @@ -72,7 +174,14 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { if (bprop_flag) { func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); } - + pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared(); + parse::ResolveFuncGraph(func_graph, res); +#ifndef _WIN32 + // Check whether the bprop needs to be exported. + if (serializable) { + ExportBpropToMindIR(prim, func_graph); + } +#endif return func_graph; } @@ -187,7 +296,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ } if (bprop_fg == nullptr) { - bprop_fg = GetBprop(prim); + bprop_fg = GetBprop(prim, resources); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 817089d4e23..3c051689218 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -114,11 +114,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co irpass.pynative_eliminate_, }); opt::irpass::ResolveIRPassLib resolve_irpass; - opt::OptPassConfig resolver_prim = opt::OptPassConfig({ - resolve_irpass.resolver_resolve_and_getattr_, - resolve_irpass.resolver_resolve_, - resolve_irpass.resolver_getattr_, - }); opt::OptPassConfig switch_simplify = opt::OptPassConfig({ irpass.switch_simplify_, @@ -133,7 +128,6 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co }); OptPassGroupMap map({{"ad_eliminate", pynative_eliminate}, - {"ad_resolver_prim", resolver_prim}, {"ad_inline", inline_opt}, {"bool_scalar_eliminate", bool_scalar_eliminate}, {"ad_switch_simplify", switch_simplify}}); @@ -324,9 +318,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::irpass::ResolveIRPassLib resolve_irpass; - opt::OptPassConfig resolve_pass = - opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, - irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); + opt::OptPassConfig after_resolve_pass = + opt::OptPassConfig({irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases(). OptPassGroupMap map_a({{"a_1", a_1}, @@ -339,7 +332,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { {"virtual_dataset", virtual_dataset}, {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})}, {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, - {"resolve", resolve_pass}, + {"after_resolve", after_resolve_pass}, {"a_after_grad", a_after_grad}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)}, diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index 8ba90d0d019..f9735ed2824 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -176,6 +176,7 @@ void IrExportBuilder::BuildModelInfo() { void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { mind_ir::GraphProto *graph_proto = model_.mutable_graph(); graph_proto->set_name(func_graph->ToString()); + graph_proto->set_bprop_hash(func_graph->bprop_hash()); ResetNodeIndex(); todo_.clear(); todo_.push_back(func_graph); @@ -247,6 +248,9 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); const TypePtr &type = node->Type(); const BaseShapePtr &shape = node->Shape(); + if (type == nullptr || shape == nullptr) { + return; + } if (type->isa() && shape->isa()) { auto tensor = type->cast(); auto elem_type = tensor->element(); @@ -404,9 +408,10 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro // 3. save tuple string in ref_attr_name MS_EXCEPTION_IF_NULL(node); auto type = node->Type(); - MS_EXCEPTION_IF_NULL(type); auto shape = node->Shape(); - MS_EXCEPTION_IF_NULL(shape); + if (type == nullptr || shape == nullptr) { + return; + } ResetTupleIndex(); std::string seq_string = "shape:"; mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index f3bf43ba5db..b2608136e09 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -60,6 +60,29 @@ py::tuple ConvertDatatoPyTuple(const VectorRef &args) { return py_args; } +py::function GetComputeFunctionWithoutPyObj(const std::string &name) { + static const std::string module = "tests.vm_impl.vm_impl_function"; + py::module mod = py::module::import(common::SafeCStr(module)); + if (!py::hasattr(mod, common::SafeCStr(name))) { + return py::none(); + } + py::object fn = mod.attr(common::SafeCStr(name)); + return fn; +} + +BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) { + auto func = GetComputeFunctionWithoutPyObj(prim->name()); + if (py::isinstance(func)) { + return nullptr; + } + auto py_args = ConvertDatatoPyTuple(args); + py::object obj = func(*py_args); + if (py::isinstance(obj)) { + return nullptr; + } + return std::make_shared(obj); +} + BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) { auto func = GetComputeFunction(prim->name()); if (py::isinstance(func)) { diff --git a/mindspore/ccsrc/utils/primitive_utils.h b/mindspore/ccsrc/utils/primitive_utils.h index cb23e535b02..cb3ad2c3b2d 100644 --- a/mindspore/ccsrc/utils/primitive_utils.h +++ b/mindspore/ccsrc/utils/primitive_utils.h @@ -33,6 +33,10 @@ py::function GetComputeFunction(std::string name); BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args); +py::function GetComputeFunctionWithoutPyObj(const std::string &name); + +BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args); + py::tuple ConvertDatatoPyTuple(const VectorRef &args); } // namespace mindspore diff --git a/mindspore/ccsrc/utils/system/sha256.cc b/mindspore/ccsrc/utils/system/sha256.cc new file mode 100644 index 00000000000..dc169550a98 --- /dev/null +++ b/mindspore/ccsrc/utils/system/sha256.cc @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/system/sha256.h" +#include +#include +#include +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace system { +namespace sha256 { +constexpr int kBitNumber = 8; +constexpr int kDigestSize = 8; +constexpr int kIterationNumber = 64; +constexpr int kMessageBlockLength = 64; +const uint32_t constant[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; + +std::string LoadFilePath(const std::string &path) { + char real_path[PATH_MAX] = {0}; +#if defined(_WIN32) || defined(_WIN64) + if (path.size() >= PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) { + return ""; + } +#else + if (path.size() >= PATH_MAX || realpath(path.c_str(), real_path) == nullptr) { + return ""; + } +#endif + std::ifstream bin_stream(real_path, std::ios::binary); + if (!bin_stream.is_open()) { + return ""; + } + std::string message((std::istreambuf_iterator(bin_stream)), std::istreambuf_iterator()); + return message; +} + +bool Padding(std::string *message) { + uint64_t bits_message = message->size() * kBitNumber; + const int remains = message->size() % kMessageBlockLength; + // The length of the message needs to be stored in 8 bytes, supplemented at the end of the message. + const int size_append = 8; + const int size_required = kMessageBlockLength - size_append; + const int size_pad = size_required - remains + (size_required > remains ? 0 : kMessageBlockLength); + if (size_pad < 1 || size_pad > kMessageBlockLength) { + return false; + } + message->push_back(0x80); + for (int i = 1; i < size_pad; ++i) { + message->push_back(0x00); + } + for (int i = size_append - 1; i >= 0; --i) { + message->push_back(static_cast((bits_message >> static_cast(i * kBitNumber)) & 0xff)); + } + return true; +} + +bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) { + if (digest_size != 8) { // The number of digests is fixed at 8 + return false; + } + uint32_t w[kIterationNumber] = {0}; + for (int i = 0; i < 16; ++i) { + w[i] = (static_cast(static_cast(message[bias + i * 4]) & 0xff) << 24) | + (static_cast(static_cast(message[bias + i * 4 + 1]) & 0xff) << 16) | + (static_cast(static_cast(message[bias + i * 4 + 2]) & 0xff) << 8) | + (static_cast(static_cast(message[bias + i * 4 + 3]) & 0xff)); + } + for (int i = 16; i < kIterationNumber; ++i) { + w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16]; + } + + std::vector hash(digest_size); + size_t mem_size = digest_size * sizeof(uint32_t); + auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size); + if (ret != EOK) { + return false; + } + for (int i = 0; i < kIterationNumber; ++i) { + uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]); + uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]); + for (int j = digest_size - 1; j >= 0; --j) { + if (j == 4) { + hash[j] = hash[j - 1] + t1; + } else if (j == 0) { + hash[j] = t1 + t2; + } else { + hash[j] = hash[j - 1]; + } + } + } + for (int i = 0; i < digest_size; ++i) { + digest[i] += hash[i]; + } + return true; +} + +std::string ConvertToString(uint32_t *input, const int &size) { + std::ostringstream oss; + oss << std::hex; + for (int i = 0; i < size; ++i) { + for (int j = static_cast(sizeof(uint32_t) / sizeof(uint8_t)) - 1; j >= 0; --j) { + auto val = static_cast((input[i] >> static_cast(j * kBitNumber)) & 0xff); + oss << std::setw(2) << std::setfill('0') << static_cast(val); + } + } + return oss.str(); +} + +std::string Encrypt(const std::string &message) { + uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19}; + for (int i = 0; i < static_cast(message.size()); i += kMessageBlockLength) { + if (!ProcessInner(message, i, digest, kDigestSize)) { + return ""; + } + } + return ConvertToString(digest, kDigestSize); +} + +std::string GetHashFromString(const std::string &data) { + std::string message = data; + if (message.empty() || !Padding(&message)) { + return ""; + } + return Encrypt(message); +} + +std::string GetHashFromFile(const std::string &path) { + std::string message = LoadFilePath(path); + if (message.empty() || !Padding(&message)) { + return ""; + } + return Encrypt(message); +} + +#ifndef _WIN32 +std::string GetHashFromDir(const std::string &dir) { + if (dir.empty()) { + MS_LOG(ERROR) << "The directory path is empty."; + return ""; + } + struct stat s {}; + int ret = stat(dir.c_str(), &s); + if (ret != 0) { + MS_LOG(ERROR) << "stat dir \"" << dir << "\" failed, ret is : " << ret; + return ""; + } + if (!S_ISDIR(s.st_mode)) { + MS_LOG(ERROR) << "The path \"" << dir << "\" is not a directory."; + return ""; + } + DIR *open_dir = opendir(dir.c_str()); + if (open_dir == nullptr) { + MS_LOG(ERROR) << "open dir " << dir.c_str() << " failed"; + return ""; + } + struct dirent *filename; + std::vector file_hashes; + while ((filename = readdir(open_dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) { + continue; + } + file_hashes.emplace_back(GetHashFromFile(std::string(dir) + "/" + filename->d_name)); + } + closedir(open_dir); + std::sort(file_hashes.begin(), file_hashes.end()); + auto dir_hash = std::accumulate(file_hashes.begin(), file_hashes.end(), std::string{}); + return dir_hash; +} +#endif +} // namespace sha256 +} // namespace system +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/system/sha256.h b/mindspore/ccsrc/utils/system/sha256.h index b0a295ddf30..9735ab872a8 100644 --- a/mindspore/ccsrc/utils/system/sha256.h +++ b/mindspore/ccsrc/utils/system/sha256.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,28 +18,10 @@ #define MINDSPORE_CCSRC_UTILS_SYSTEM_SHA256_H_ #include -#include -#include -#include -#include namespace mindspore { namespace system { namespace sha256 { -constexpr int kBitNumber = 8; -constexpr int kDigestSize = 8; -constexpr int kIterationNumber = 64; -constexpr int kMessageBlockLength = 64; -const uint32_t constant[64] = { - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; - inline uint32_t ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ ((~x) & z); } inline uint32_t ma(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); } inline uint32_t sigma0(uint32_t x) { return (x >> 2 | x << 30) ^ (x >> 13 | x << 19) ^ (x >> 22 | x << 10); } @@ -47,123 +29,23 @@ inline uint32_t sigma1(uint32_t x) { return (x >> 6 | x << 26) ^ (x >> 11 | x << inline uint32_t sigma2(uint32_t x) { return (x >> 7 | x << 25) ^ (x >> 18 | x << 14) ^ (x >> 3); } inline uint32_t sigma3(uint32_t x) { return (x >> 17 | x << 15) ^ (x >> 19 | x << 13) ^ (x >> 10); } -std::string LoadFilePath(const std::string &path) { - char real_path[PATH_MAX] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (path.size() >= PATH_MAX || _fullpath(real_path, path.c_str(), PATH_MAX) == nullptr) { - return ""; - } -#else - if (path.size() >= PATH_MAX || realpath(path.c_str(), real_path) == nullptr) { - return ""; - } +std::string LoadFilePath(const std::string &path); + +bool Padding(std::string *message); + +bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size); + +std::string ConvertToString(uint32_t *input, const int &size); + +std::string Encrypt(const std::string &message); + +std::string GetHashFromString(const std::string &data); + +std::string GetHashFromFile(const std::string &path); + +#ifndef _WIN32 +std::string GetHashFromDir(const std::string &dir); #endif - std::ifstream bin_stream(real_path, std::ios::binary); - if (!bin_stream.is_open()) { - return ""; - } - std::string message((std::istreambuf_iterator(bin_stream)), std::istreambuf_iterator()); - return message; -} - -bool Padding(std::string *message) { - uint64_t bits_message = message->size() * kBitNumber; - const int remains = message->size() % kMessageBlockLength; - // The length of the message needs to be stored in 8 bytes, supplemented at the end of the message. - const int size_append = 8; - const int size_required = kMessageBlockLength - size_append; - const int size_pad = size_required - remains + (size_required > remains ? 0 : kMessageBlockLength); - if (size_pad < 1 || size_pad > kMessageBlockLength) { - return false; - } - message->push_back(0x80); - for (int i = 1; i < size_pad; ++i) { - message->push_back(0x00); - } - for (int i = size_append - 1; i >= 0; --i) { - message->push_back(static_cast((bits_message >> static_cast(i * kBitNumber)) & 0xff)); - } - return true; -} - -bool ProcessInner(const std::string &message, const int &bias, uint32_t *digest, const int &digest_size) { - if (digest_size != 8) { // The number of digests is fixed at 8 - return false; - } - uint32_t w[kIterationNumber] = {0}; - for (int i = 0; i < 16; ++i) { - w[i] = (static_cast(static_cast(message[bias + i * 4]) & 0xff) << 24) | - (static_cast(static_cast(message[bias + i * 4 + 1]) & 0xff) << 16) | - (static_cast(static_cast(message[bias + i * 4 + 2]) & 0xff) << 8) | - (static_cast(static_cast(message[bias + i * 4 + 3]) & 0xff)); - } - for (int i = 16; i < kIterationNumber; ++i) { - w[i] = sigma3(w[i - 2]) + w[i - 7] + sigma2(w[i - 15]) + w[i - 16]; - } - - std::vector hash(digest_size); - size_t mem_size = digest_size * sizeof(uint32_t); - auto ret = memcpy_s(hash.data(), mem_size, digest, mem_size); - if (ret != EOK) { - return false; - } - for (int i = 0; i < kIterationNumber; ++i) { - uint32_t t1 = w[i] + constant[i] + hash[7] + sigma1(hash[4]) + ch(hash[4], hash[5], hash[6]); - uint32_t t2 = sigma0(hash[0]) + ma(hash[0], hash[1], hash[2]); - for (int j = digest_size - 1; j >= 0; --j) { - if (j == 4) { - hash[j] = hash[j - 1] + t1; - } else if (j == 0) { - hash[j] = t1 + t2; - } else { - hash[j] = hash[j - 1]; - } - } - } - for (int i = 0; i < digest_size; ++i) { - digest[i] += hash[i]; - } - return true; -} - -std::string ConvertToString(uint32_t *input, const int &size) { - std::ostringstream oss; - oss << std::hex; - for (int i = 0; i < size; ++i) { - for (int j = static_cast(sizeof(uint32_t) / sizeof(uint8_t)) - 1; j >= 0; --j) { - auto val = static_cast((input[i] >> static_cast(j * kBitNumber)) & 0xff); - oss << std::setw(2) << std::setfill('0') << static_cast(val); - } - } - return oss.str(); -} - -std::string Encrypt(const std::string &message) { - uint32_t digest[kDigestSize] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19}; - for (int i = 0; i < static_cast(message.size()); i += kMessageBlockLength) { - if (!ProcessInner(message, i, digest, kDigestSize)) { - return ""; - } - } - return ConvertToString(digest, kDigestSize); -} - -std::string GetHashFromString(const std::string &data) { - std::string message = data; - if (message.empty() || !Padding(&message)) { - return ""; - } - return Encrypt(message); -} - -std::string GetHashFromFile(const std::string &path) { - std::string message = LoadFilePath(path); - if (message.empty() || !Padding(&message)) { - return ""; - } - return Encrypt(message); -} } // namespace sha256 } // namespace system } // namespace mindspore diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index 7731a50794a..7b3c24edb04 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -437,6 +437,9 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { MS_LOG(DEBUG) << "operation start " << prim->name(); MS_EXCEPTION_IF_NULL(prim); auto result = prim->RunComputeFunction(args); + if (result.is_null()) { + result = RunComputeFunctionWithoutPyObj(prim, args); + } if (result.is_null()) { return RunComputeFunction(prim, args); } diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index f200d225f8b..5c61a6e3c02 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -393,6 +393,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { bool dropped() const { return dropped_; } void set_dropped(bool dropped) { dropped_ = dropped; } + std::string bprop_hash() const { return bprop_hash_; } + void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; } + private: // Only used for func_graph manager to control resource free. int attached_mng_cnt() const { return attached_mng_cnt_; } @@ -477,6 +480,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { // If the graph was changed, it should be dropped in cache data_converter::object_map_ // which used by ConvertToFuncGraph. bool dropped_ = false; + // If the graph is a bprop graph, it should has a hash of the bprop directory. + std::string bprop_hash_; }; inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index 00f88de4a1f..565a7bb7d4d 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -40,6 +40,8 @@ static constexpr char kConstantValueNode[] = "Constant"; static constexpr char kCNodeShapeAttr[] = "shape"; static constexpr char kCNodeShape1Attr[] = "shape1"; static constexpr char kCNodeShape2Attr[] = "shape2"; +static constexpr char kDoSignaturePrimitivePrefix[] = "S-Prim-"; + enum ParseForm : int { FORM_PARSE_TYPE = 0, FORM_PARSE_SCALAR = 1, @@ -305,12 +307,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi node->set_debug_info(debug_info_ptr); node->set_name(debug_info_name); - const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0); - - tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto); - MS_EXCEPTION_IF_NULL(tensor_info); - auto tensor_abstract = tensor_info->ToAbstract(); - node->set_abstract(tensor_abstract); + if (value_proto.tensor_size() > 0) { + const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0); + tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto); + MS_EXCEPTION_IF_NULL(tensor_info); + auto tensor_abstract = tensor_info->ToAbstract(); + node->set_abstract(tensor_abstract); + } anfnode_build_map_[value_proto.name()] = node; return true; @@ -768,8 +771,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc if (op_primc_fns.find(node_type) != op_primc_fns.end()) { prim = op_primc_fns[node_type](); } else { - prim = std::make_shared(node_type); - prim->set_instance_name(node_type); + if (node_type.compare(0, strlen(kDoSignaturePrimitivePrefix), kDoSignaturePrimitivePrefix) == 0) { + auto op_name = node_type.substr(strlen(kDoSignaturePrimitivePrefix)); + prim = std::make_shared(op_name, std::make_shared(op_name)); + prim->set_instance_name(op_name); + } else { + prim = std::make_shared(node_type); + prim->set_instance_name(node_type); + } } MS_EXCEPTION_IF_NULL(prim); @@ -812,9 +821,14 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc } else { AbstractBasePtrList elem; for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { - elem.push_back(cnode_ptr->input(index)->abstract()); + auto abs = cnode_ptr->input(index)->abstract(); + if (abs != nullptr) { + elem.push_back(abs); + } + } + if (!elem.empty()) { + cnode_ptr->set_abstract(std::make_shared(elem)); } - cnode_ptr->set_abstract(std::make_shared(elem)); } } else if (kv.size() == 1) { std::unordered_map::iterator iter = kv.begin(); @@ -919,6 +933,9 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const } else { MS_LOG(ERROR) << "FuncGraph under converting has not name!"; } + if (importProto.has_bprop_hash()) { + outputFuncGraph->set_bprop_hash(importProto.bprop_hash()); + } if (!ImportParametersForGraph(outputFuncGraph, importProto)) { MS_LOG(ERROR) << "import parameters for graph fail!"; diff --git a/mindspore/core/proto/mind_ir.proto b/mindspore/core/proto/mind_ir.proto index 2c38198ab0b..cd6182b9e15 100644 --- a/mindspore/core/proto/mind_ir.proto +++ b/mindspore/core/proto/mind_ir.proto @@ -80,6 +80,7 @@ message GraphProto { optional string doc_string = 4; repeated ValueInfoProto input = 5; repeated ValueInfoProto output = 6; + optional string bprop_hash = 7; } diff --git a/mindspore/ops/bprop_mindir/Identity_bprop.mindir b/mindspore/ops/bprop_mindir/Identity_bprop.mindir new file mode 100644 index 00000000000..79e334abd3d --- /dev/null +++ b/mindspore/ops/bprop_mindir/Identity_bprop.mindir @@ -0,0 +1,9 @@ + +0.1.0 MindSpore*1.1.0:î +— + bprop.10:doutbprop.10:[CNode]12:2bprop.10:[CNode]11:1"S-Prim-MakeTuple:HGradients/Default/network-NetIdentity/gradIdentity/S-Prim-MakeTuple-op15bprop.10* + +bprop.10:x* + bprop.10:out* + bprop.10:dout2 +bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file diff --git a/mindspore/ops/bprop_mindir/ReLU_bprop.mindir b/mindspore/ops/bprop_mindir/ReLU_bprop.mindir new file mode 100644 index 00000000000..e462e1cf443 --- /dev/null +++ b/mindspore/ops/bprop_mindir/ReLU_bprop.mindir @@ -0,0 +1,11 @@ + +0.1.0 MindSpore*1.1.0:å +ˆ + bprop.2:dout + bprop.2:out bprop.2:dx:1 bprop.2:dx:1"S-Prim-ReluGrad:>Gradients/Default/network-NetRelu/gradReLU/S-Prim-ReluGrad-op5 +‰ + bprop.2:dx:1bprop.2:[CNode]4:3bprop.2:[CNode]3:2"S-Prim-MakeTuple:?Gradients/Default/network-NetRelu/gradReLU/S-Prim-MakeTuple-op6bprop.2* + bprop.2:x* + bprop.2:out* + bprop.2:dout2 +bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d91fc4e05e490c2d17243f1a3b10584567ec050735053d6e4da31b0c0120e747522366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22602dc0172bb790a967e7e5ba7e35d54f6df1ae3014fea781a726693f4c945c0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c695827de181149f40d1a6397a5a742b3216423289a39974991e6c56c2a17a1dd6c22f5fa950412d9825b54837052ad6b1201185ed5a1a4de1d4fb4a60945bfb77e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca40d19b1aba47164a5f706a591d6282290a76b4abfa99f3b64d0b95ce409fac3c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6cdf493aa69bd2a1bfbbab074bc8076571a7a726e82612665807efc62736694a4f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file diff --git a/mindspore/ops/bprop_mindir/__init__.py b/mindspore/ops/bprop_mindir/__init__.py new file mode 100644 index 00000000000..6706b970640 --- /dev/null +++ b/mindspore/ops/bprop_mindir/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""bprop mindir.""" diff --git a/mindspore/ops/bprop_mindir/generate_mindir.py b/mindspore/ops/bprop_mindir/generate_mindir.py new file mode 100644 index 00000000000..af8466b5f78 --- /dev/null +++ b/mindspore/ops/bprop_mindir/generate_mindir.py @@ -0,0 +1,81 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Generate the mindir for bprop""" +import os +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.ops as ops +import mindspore.ops._grad as g + +context.set_context(mode=context.GRAPH_MODE) +os.environ['GENERATE_MINDIR'] = '1' + + +class NetRelu(nn.Cell): + def __init__(self): + super(NetRelu, self).__init__() + self.relu = P.ReLU() + + def construct(self, x): + return self.relu(x) + + +class NetIdentity(nn.Cell): + def __init__(self): + super(NetIdentity, self).__init__() + self.identity = P.Identity() + + def construct(self, x): + return self.identity(x) + + +class GradNet(nn.Cell): + def __init__(self, network): + super(GradNet, self).__init__() + self.grad = ops.GradOperation() + self.network = network + + def construct(self, x): + gout = self.grad(self.network)(x) + return gout + + +def test_relu(): + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.float32)) + relu = NetRelu() + grad = GradNet(relu) + grad(x) + + +def test_identity(): + x = Tensor(np.array([1, 2, 3, 4]).astype(np.int64)) + identity = NetIdentity() + grad = GradNet(identity) + grad(x) + + +test_relu() +test_identity() +# mindspore/ops/_grad/__init__.py +bprop_path = g.__file__ +bprop_mindir_path = bprop_path[: bprop_path.rindex('/')] + "/../bprop_mindir/" +print("The new bprop mindir files has been generated in the path \"" + bprop_mindir_path + + "\", copy the *.mindir to your PYTHONPATH if necessary.") diff --git a/setup.py b/setup.py index 9cd543c4545..24058b2f189 100644 --- a/setup.py +++ b/setup.py @@ -133,6 +133,7 @@ package_data = { 'lib/*.dylib*', '.commit_id', 'config/*', + 'ops/bprop_mindir/*', 'include/*', 'include/*/*', 'include/*/*/*', diff --git a/tests/vm_impl/vm_impl_function.py b/tests/vm_impl/vm_impl_function.py new file mode 100644 index 00000000000..3311ed2b3a5 --- /dev/null +++ b/tests/vm_impl/vm_impl_function.py @@ -0,0 +1,23 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Generate vm_impl function for nn ops without python object""" +from mindspore.common.tensor import Tensor +from .vm_interface import vm + +def ReluGrad(y_backprop, x): + x = x.asnumpy() + y_backprop = y_backprop.asnumpy() + y_backprop = vm.relu_grad(x.copy()) * y_backprop + return Tensor(y_backprop)