forked from mindspore-Ecosystem/mindspore
new_construct_bprop

move expander files: move the expander component to mindspore/core, and move the bprop expanders to mindspore/frontend

This commit is contained in:
parent 1809ae2820
commit 605c1a8479
@@ -872,7 +872,7 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
      }
    }
    input_tensor_info->input_tensors_mask.emplace_back(
-     (is_value_node && !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
+     (is_value_node || !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
  } else if (real_input->isa<Parameter>()) {
    tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
    input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
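The only functional change in the hunk above is the operator in the ternary condition (&& becomes ||), which widens the set of inputs tagged with the value-node mask. A small, self-contained truth-table check of the two conditions; the mask constants are hypothetical stand-ins, not the values used in MindSpore:

#include <iostream>

// Hypothetical stand-ins for the two mask constants used in the hunk above.
constexpr int kValueNodeTensorMask = 1;
constexpr int kParameterDataTensorMask = 0;

int main() {
  for (bool is_value_node : {false, true}) {
    for (bool is_forward_output : {false, true}) {
      int old_mask = (is_value_node && !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask;
      int new_mask = (is_value_node || !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask;
      std::cout << is_value_node << " " << is_forward_output
                << " old=" << old_mask << " new=" << new_mask << "\n";
    }
  }
  // Only two cases differ: (false, false) and (true, true). In both, the new
  // condition selects kValueNodeTensorMask where the old one selected
  // kParameterDataTensorMask.
  return 0;
}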
@@ -901,6 +901,9 @@ bool TopoCompareFuncGraphNode(const AnfNodePtr &node, const bool &is_first_sort,
  if (is_first_sort) {
    (void)new_graph_info->topo_node_list.emplace_back(std::move(node_info));
  } else {
    if (common::AnfAlgo::IsControlOpExecInBackend(cnode)) {
      return true;
    }
    if (topo_node_idx >= old_graph_info->topo_node_list.size()) {
      return true;
    }
@@ -149,7 +149,7 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
  // Save the mapping between cell id and actor info.
  mindspore::HashMap<std::string, ActorInfo> graph_actor_infos_;
- bool enable_backend_dynamic_detect_{false};
+ bool enable_backend_dynamic_detect_{true};
  FuncGraphPtr root_graph_;
  GraphPartitionPtr graph_partition_;
  std::shared_ptr<GraphCompiler> graph_compiler_;
@@ -0,0 +1,6 @@
+ approvers:
+ - gaoxiong1
+ - ckey_dou
+ - dayschan
+ - anyrenwei
+ - zichun_ye
@@ -13,14 +13,14 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- #include "common/graph_kernel/bprop/bprop.h"
+ #include "frontend/operator/bprop/bprop.h"

#include <algorithm>
#include <memory>
#include <queue>
#include <set>
#include <string>
- #include "common/graph_kernel/bprop/expander/infer.h"
+ #include "expander/infer.h"
#include "utils/anf_utils.h"
#include "include/common/debug/anf_ir_dump.h"
@@ -13,22 +13,22 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
- #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
+ #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_
+ #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_

#include <map>
#include <vector>
#include <utility>
#include "ir/anf.h"
- #include "common/graph_kernel/bprop/bprop_irbuilder.h"
+ #include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/visible.h"

namespace mindspore {
using DoutUserType = std::vector<std::pair<CNodePtr, int>>;
// deprecated
- COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
+ void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);

using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
- COMMON_EXPORT bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
+ bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
}  // namespace mindspore
- #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
+ #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_
@ -14,14 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace expander {
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
|
||||
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -22,8 +22,8 @@
|
|||
#include <map>
|
||||
#include <functional>
|
||||
|
||||
#include "common/graph_kernel/bprop/expander/node.h"
|
||||
#include "common/graph_kernel/bprop/expander/emitter.h"
|
||||
#include "expander/node.h"
|
||||
#include "expander/emitter.h"
|
||||
#include "utils/hash_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -112,4 +112,4 @@ class BpropIRBuilderRegHelper {
|
|||
} // namespace bprop
|
||||
} // namespace expander
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
|
@@ -13,19 +13,20 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
- #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
+ #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_
+ #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_

#include <cmath>
#include <vector>
#include <utility>
#include <set>
- #include "common/graph_kernel/bprop/expander/node.h"
- #include "common/graph_kernel/bprop/bprop_irbuilder.h"
+ #include "expander/node.h"
+ #include "frontend/operator/bprop/bprop_irbuilder.h"

namespace mindspore::expander::bprop {
- constexpr auto pi = acos(-1.0);
- constexpr auto log_2 = log(2.0);
- constexpr auto log_pi = log(pi);
+ inline const auto pi = std::acos(-1.0);
+ inline const auto log_2 = std::log(2.0);
+ inline const auto log_pi = std::log(pi);

std::vector<std::vector<int64_t>> BroadcastGradientArgs(const std::vector<int64_t> &x_shape,
                                                         const std::vector<int64_t> &y_shape);

@@ -83,4 +84,4 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
ShapeVector PoolToNHWC(const ShapeVector &v);
ShapeVector ConvToNHWC(const ShapeVector &v);
}  // namespace mindspore::expander::bprop
- #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
+ #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_
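A plausible reading of the constexpr-to-inline change above (the commit message does not say so explicitly): std::acos and std::log are not constexpr in standard C++, so the old initializers only compiled as a compiler extension, and a namespace-scope constant defined in a header needs to be an inline variable (C++17) to avoid multiple-definition issues when several translation units include it. A minimal standalone illustration, not part of the commit:

// common_constants.h (illustrative header, hypothetical file name)
#pragma once
#include <cmath>

// OK in C++17: one shared definition across all translation units.
inline const auto pi = std::acos(-1.0);

// Ill-formed in standard C++ (std::acos is not constexpr), even though
// some compilers accept it as an extension:
// constexpr auto pi_bad = std::acos(-1.0);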
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "ir/anf.h"
|
||||
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <set>
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -15,9 +15,9 @@
|
|||
*/
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
static NodePtr GetMatrixDiagAssist(const BpropIRBuilder *ib, const ShapeVector &x_shape, TypePtr x_dtype) {
|
|
@ -15,8 +15,8 @@
|
|||
*/
|
||||
|
||||
#include <map>
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -15,9 +15,9 @@
|
|||
*/
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
NodePtrList CheckBpropExpander(const BpropIRBuilder *ib) {
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
|
@@ -1030,7 +1030,6 @@ REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogits").SetBody([](const BpropI
  auto logits = ib->GetInput(kIndex0);
  auto dout = ib->GetInput(kIndex3);
  auto grad = ib->Emit(kSparseSoftmaxCrossEntropyWithLogitsOpName, {logits, labels}, {{kAttrIsGrad, MakeValue(true)}});
  grad = ib->Emit("Depend", {grad, out});
  grad = ib->Mul(grad, dout);
  return {grad, ib->ZerosLike(labels)};
});
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
|
@@ -52,7 +52,7 @@ REG_BPROP_BUILDER("SolveTriangular").SetBody([](const BpropIRBuilder *ib) -> Nod
  if (GetValue<bool>(unit_diagonal)) {
    auto fill = ib->Emit("Fill", {ib->EmitValue(ib->GetDtype(grad_a)), ib->Value<ShapeVector>(ShapeVector(1, row_size)),
                                  ib->Tensor(0, ib->GetDtype(grad_a))});
-   grad_a = ib->Emit("MatrixSetDiagV3", {grad_a, fill, ib->Fill(0L, {2}, TypeId::kNumberTypeInt32)},
+   grad_a = ib->Emit("MatrixSetDiagV3", {grad_a, fill, ib->Fill(int64_t(0), {2}, TypeId::kNumberTypeInt32)},
                      {{"align", MakeValue("RIGHT_LEFT")}, {"max_length", MakeValue<int64_t>(200000000)}});
  }
  return {grad_a, grad_b};
@@ -96,7 +96,7 @@ REG_BPROP_BUILDER("Eigh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
  // constants in the computation
  std::vector<int32_t> zero_tensor_value = {0, 0};
  ShapeVector zero_shape{2};
- auto zero_tensor = ib->Fill(0L, {2}, TypeId::kNumberTypeInt32);
+ auto zero_tensor = ib->Fill(int64_t(0), {2}, TypeId::kNumberTypeInt32);
  auto kValueNeg1 = ib->Value<int64_t>(-1);
  auto kValueNeg2 = ib->Value<int64_t>(-2);
@ -15,8 +15,8 @@
|
|||
*/
|
||||
|
||||
#include <tuple>
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "common/graph_kernel/bprop/expander/common_utils.h"
|
||||
#include "frontend/operator/bprop/bprop_irbuilder.h"
|
||||
#include "frontend/operator/bprop/grad/common_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
File diff suppressed because it is too large
@ -0,0 +1,228 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
struct GradAttr {
|
||||
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
|
||||
: grad_all_inputs(get_all),
|
||||
grad_weights(get_by_list),
|
||||
has_sens(sens_param),
|
||||
get_by_position(get_by_position),
|
||||
weight_param_is_tuple(weight_param_is_tuple) {}
|
||||
|
||||
bool grad_all_inputs;
|
||||
bool grad_weights;
|
||||
bool has_sens;
|
||||
bool get_by_position;
|
||||
bool weight_param_is_tuple;
|
||||
};
|
||||
|
||||
struct GradParam {
|
||||
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg = nullptr)
|
||||
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
|
||||
|
||||
// Primal CNode create by op forward process
|
||||
const CNodePtr &cnode;
|
||||
// Input value for cnode
|
||||
const ValuePtrList &op_args;
|
||||
// Output of op
|
||||
const ValuePtr &out;
|
||||
// Bprop func graph
|
||||
const FuncGraphPtr fprop_fg;
|
||||
// High order used this, which
|
||||
bool grad_by_value = true;
|
||||
};
|
||||
using GradParamPtr = std::shared_ptr<GradParam>;
|
||||
|
||||
class FunctionNode {
|
||||
public:
|
||||
FunctionNode(const FuncGraphPtr &tape, const AnfNodePtr &dout)
|
||||
: tape_(tape), accumulate_dout_(dout), fake_dout_(dout) {}
|
||||
void AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
|
||||
void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
|
||||
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &next_edges() const { return next_edges_; }
|
||||
AnfNodePtr RealDout() const { return accumulate_dout_; }
|
||||
void ReplaceEdges();
|
||||
const AnfNodePtr fake_dout() const { return fake_dout_; }
|
||||
|
||||
private:
|
||||
AnfNodePtr HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node);
|
||||
// Bprop func graph
|
||||
const FuncGraphPtr tape_;
|
||||
// Input of dout for this bprop function
|
||||
AnfNodePtr accumulate_dout_;
|
||||
// First we generate a fake dout
|
||||
const AnfNodePtr fake_dout_;
|
||||
// Represent where the dins backpropagate to: other bprop functions or variables
|
||||
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> next_edges_;
|
||||
// Replace next_edges where din == dout in bprop function
|
||||
std::vector<int> need_replace_edges_;
|
||||
};
|
||||
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
|
||||
|
||||
class VariableNode {
|
||||
public:
|
||||
VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
|
||||
|
||||
ValuePtr out_value() const { return out_value_; }
|
||||
FunctionNodePtr fn() const { return fn_; }
|
||||
bool is_need_grad() const { return is_need_grad_; }
|
||||
void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
|
||||
AnfNodePtr k_node() const { return k_node_; }
|
||||
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
|
||||
|
||||
private:
|
||||
// Abstract bprop function
|
||||
FunctionNodePtr fn_;
|
||||
ValuePtr out_value_;
|
||||
bool is_need_grad_{false};
|
||||
// k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
|
||||
AnfNodePtr k_node_{nullptr};
|
||||
};
|
||||
using VariableNodePtr = std::shared_ptr<VariableNode>;
|
||||
|
||||
class AutoGradCellImpl {
|
||||
public:
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
|
||||
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values);
|
||||
~AutoGradCellImpl() = default;
|
||||
// Reverse connect bprop of op
|
||||
bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
|
||||
// Reverse connect ms_function or higher order sub bprop funcgraph
|
||||
bool KPynativeWithFProp(const GradParamPtr &grad_param);
|
||||
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out,
|
||||
AnfNodePtr *const tape_dout);
|
||||
// Update top cell output, record last_node
|
||||
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sens_out);
|
||||
// Build a back propagate funcgraph, each cnode in primal funcgraph is replaced by value node or formal cnode, so it
|
||||
// can be grad again.
|
||||
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
|
||||
const GradAttr &grad_attr, bool build_formal_param);
|
||||
|
||||
private:
|
||||
// Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
|
||||
AnfNodePtr last_node_{nullptr};
|
||||
ValuePtr sens_value_{nullptr};
|
||||
// Bprop funcgraph
|
||||
FuncGraphPtr tape_;
|
||||
// Top cell inputs
|
||||
AnfNodePtrList cell_inputs_;
|
||||
// These weights need to calculate gradient.
|
||||
mindspore::HashSet<AnfNodePtr> need_grad_weights_;
|
||||
// Bprop dins of each variable or middle out
|
||||
OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
|
||||
AnfNodePtrList weights_;
|
||||
// Record cnode's input map for tape_
|
||||
UserType users_;
|
||||
// Flag for ms_function and high order
|
||||
bool has_fbprop_{false};
|
||||
|
||||
bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
|
||||
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
|
||||
|
||||
// construct input as cnode for expander
|
||||
CNodePtr ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
const AnfNodePtr &dout);
|
||||
// Back propagate for one node;
|
||||
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
|
||||
const ValuePtrList &op_args);
|
||||
void UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din,
|
||||
const ValuePtr &op_arg);
|
||||
|
||||
void BuildForwardLastNode();
|
||||
// Add parameter(weights) to anfnode_to_variable_adjoint_
|
||||
void AddParameterNode(const AnfNodePtr ¶meter, const ValuePtr &tensor);
|
||||
AnfNodePtr GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
|
||||
const AnfNodePtr &din);
|
||||
void BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
|
||||
void BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
|
||||
void BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
|
||||
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
|
||||
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
|
||||
// Set sens and weights parameter nodes by user input info
|
||||
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
|
||||
// get last reverse iterator
|
||||
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
|
||||
|
||||
void BackPropagate();
|
||||
// Set return node according to grad flag
|
||||
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, const GradAttr &grad_attr);
|
||||
AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
|
||||
AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
|
||||
AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
|
||||
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
|
||||
AnfNodePtr GenerateEmptyTupleValue();
|
||||
void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
|
||||
void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||
void ElimateTupleGetItem();
|
||||
void ClearDeviceAddress(const ValuePtr &out);
|
||||
|
||||
// Fbprop
|
||||
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
|
||||
AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint);
|
||||
AnfNodePtr BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, size_t input_index);
|
||||
};
|
||||
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
|
||||
|
||||
// Start building back propagate funcgraph for this cell.
|
||||
// cell_inputs: the input parameter list of this cell except the weights;
|
||||
AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
||||
const std::vector<ValuePtr> &input_param_values);
|
||||
|
||||
// Return the back propagate funcgraph for this cell.
|
||||
// weights: weights parameters used in this cell.
|
||||
// grad_inputs: return sensitivity for input parameters;
|
||||
// grad_weights: return sensitivity for weights;
|
||||
// has_sens_arg: caller will pass sens args;
|
||||
// return: the returned funcgraph will have prototype:
|
||||
// if has_sens_arg is true
|
||||
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...,
|
||||
// sens_out)
|
||||
// else:
|
||||
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
|
||||
// if build_formal_param is true
|
||||
// each cnode in primal funcgraph is replaced by formal cnode
|
||||
// else:
|
||||
// each cnode in primal funcgraph is replaced by value node
|
||||
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &k_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param = false);
|
||||
|
||||
// Grad for each operation.
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
// op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out);
|
||||
|
||||
// adjoint bprop from ms_function and high-order grad
|
||||
void GradPynativeFBprop(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
const FuncGraphPtr &fprop_fg);
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
|
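Read together, the comments in this new auto_grad.h describe the intended call sequence for building a PyNative bprop graph: begin a cell, feed it one record per executed op, set the top-cell output, then finish. The sketch below is not code from the commit; op_records is a hypothetical container of GradParam records, and the include path is inferred from the header guard:

#include <vector>
#include "frontend/optimizer/ad/auto_grad.h"  // path inferred from the include guard above

namespace mindspore {
namespace ad {
FuncGraphPtr BuildBpropGraphSketch(const AnfNodePtrList &cell_inputs,
                                   const std::vector<ValuePtr> &input_param_values,
                                   const std::vector<GradParamPtr> &op_records,  // hypothetical container
                                   const AnfNodePtr &output_node, const ValuePtr &sens_out,
                                   const AnfNodePtrList &weights) {
  // Start the tape for this cell.
  auto k_cell = GradPynativeCellBegin(cell_inputs, input_param_values);
  for (const auto &p : op_records) {
    // One call per executed primitive: reverse-connects that op's bprop into the tape.
    (void)GradPynativeOp(k_cell, p->cnode, p->op_args, p->out);
  }
  // Record the top cell's output; it later receives the sens input.
  k_cell->UpdateOutputNodeOfTopCell(output_node, sens_out);
  // grad_all_inputs, grad_weights, has_sens, get_by_position, weight_param_is_tuple
  GradAttr grad_attr(true, false, false, false, true);
  return GradPynativeCellEnd(k_cell, weights, /*grad_position=*/{}, grad_attr,
                             /*build_formal_param=*/false);
}
}  // namespace ad
}  // namespace mindspore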
File diff suppressed because it is too large
@ -1,113 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
class KPynativeCell {
|
||||
public:
|
||||
virtual ~KPynativeCell() = default;
|
||||
virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sen_out) = 0;
|
||||
// Grad for cell which may have user passed front propagate FuncGraph.
|
||||
// c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that
|
||||
// cell. op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
// fprop_fg: user defined back propagate cnode which output is the bprop_fg.
|
||||
// Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
|
||||
virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
|
||||
const FuncGraphPtr &fprop_fg) = 0;
|
||||
};
|
||||
|
||||
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
|
||||
|
||||
struct GradAttr {
|
||||
bool grad_all_inputs;
|
||||
bool grad_weights;
|
||||
bool has_sens;
|
||||
bool get_by_position;
|
||||
bool weight_param_is_tuple;
|
||||
|
||||
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
|
||||
: grad_all_inputs(get_all),
|
||||
grad_weights(get_by_list),
|
||||
has_sens(sens_param),
|
||||
get_by_position(get_by_position),
|
||||
weight_param_is_tuple(weight_param_is_tuple) {}
|
||||
};
|
||||
|
||||
// bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after
|
||||
// just parsed. will have prototype:
|
||||
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
// op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
// return: the returned funcgraph should have the same prototype.
|
||||
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out);
|
||||
|
||||
// Start building back propagate funcgraph for this cell.
|
||||
// cell_inputs: the input parameter list of this cell except the weights;
|
||||
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
||||
const std::vector<ValuePtr> &input_param_values);
|
||||
|
||||
// Return the back propagate funcgraph for this cell.
|
||||
// weights: weights parameters used in this cell.
|
||||
// grad_inputs: return sensitivity for input parameters;
|
||||
// grad_weights: return sensitivity for weights;
|
||||
// has_sens_arg: caller will pass sens args;
|
||||
// return: the returned funcgraph will have prototype:
|
||||
// if has_sens_arg is true
|
||||
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...,
|
||||
// sens_out)
|
||||
// else:
|
||||
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
|
||||
// if build_formal_param is true
|
||||
// each cnode in primal funcgraph is replaced by formal cnode
|
||||
// else:
|
||||
// each cnode in primal funcgraph is replaced by value node
|
||||
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
|
||||
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
|
||||
bool build_formal_param = false);
|
||||
|
||||
// Grad for each operation.
|
||||
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
|
||||
// op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
|
||||
const ValuePtr &out);
|
||||
|
||||
// Grad for cell which may have user defined back propagate function.
|
||||
// c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that
|
||||
// cell. op_args: the arguments list of each input parameters.
|
||||
// out: the op result.
|
||||
// bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed.
|
||||
// Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
|
||||
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
|
||||
const ValuePtr &out, const FuncGraphPtr &bprop_fg);
|
||||
|
||||
// Clear all static resources that used in grad process
|
||||
void ClearKPynativeCellStaticRes();
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_
|
|
@ -191,18 +191,6 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co
|
|||
}
|
||||
|
||||
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
MS_EXCEPTION_IF_NULL(resource->func_graph());
|
||||
(void)TransformTopGraphPass(resource);
|
||||
|
||||
auto func_graph = resource->func_graph();
|
||||
// PyNative dynamic shape need add those pass, like convert make_list to make_tuple.
|
||||
// Cannot execute those passes due to performance reasons if the graph is a dynamic structure graph.
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph->has_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE) || !func_graph->has_flag(kFlagIsDynamicStructure)) {
|
||||
(void)OptPassAGroup(resource);
|
||||
(void)CleanAfterOptAPass(resource);
|
||||
}
|
||||
opt::irpass::OptimizeIRPassLib irpass;
|
||||
opt::OptPassConfig bg_final_opt = opt::OptPassConfig({
|
||||
irpass.inline_,
|
||||
|
@ -210,6 +198,27 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
|
|||
irpass.tuple_list_get_item_eliminator_,
|
||||
irpass.tuple_list_set_item_eliminator_,
|
||||
irpass.depend_value_elim_,
|
||||
});
|
||||
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
|
||||
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto func_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
|
||||
func_graph = bprop_graph_final_opt->step(func_graph, true);
|
||||
};
|
||||
// Validate(func_graph);
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource) {
|
||||
opt::irpass::OptimizeIRPassLib irpass;
|
||||
opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
|
||||
irpass.inline_,
|
||||
irpass.tuple_list_get_set_item_eliminator_,
|
||||
irpass.tuple_list_get_item_eliminator_,
|
||||
irpass.tuple_list_set_item_eliminator_,
|
||||
irpass.depend_value_elim_,
|
||||
irpass.reshape_eliminate_,
|
||||
irpass.switch_simplify_,
|
||||
irpass.addn_zero_filter_,
|
||||
|
@ -217,28 +226,21 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
|
|||
});
|
||||
opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
|
||||
OptPassGroupMap map({
|
||||
{"ad_final_opt", bg_final_opt},
|
||||
{"ad_final_opt", grad_graph_opt},
|
||||
{"zeros_like", fill_zeros_like},
|
||||
});
|
||||
|
||||
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
|
||||
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
|
||||
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
|
||||
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
|
||||
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
|
||||
irpass.incorporate_call_,
|
||||
irpass.incorporate_call_switch_,
|
||||
});
|
||||
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
|
||||
}
|
||||
|
||||
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
|
||||
func_graph = resource->func_graph();
|
||||
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
|
||||
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
|
||||
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("grad_graph_opt", resource, map);
|
||||
auto func_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
|
||||
func_graph = bprop_graph_final_opt->step(func_graph, true);
|
||||
};
|
||||
func_graph = LiftingClone(func_graph);
|
||||
Validate(func_graph);
|
||||
// Validate(func_graph);
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
|
|||
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
|
||||
const std::vector<bool> &need_grad_flags);
|
||||
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource);
|
||||
FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource);
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1838,7 +1838,6 @@ void MemoryRecycle() {
|
|||
ReclaimOptimizer();
|
||||
session::ExecutorManager::Instance().ClearDoneTasks();
|
||||
ad::g_k_prims.clear();
|
||||
ad::ClearKPynativeCellStaticRes();
|
||||
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
|
||||
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
|
||||
abstract::AnalysisContext::ClearContext();
|
||||
|
@ -1850,6 +1849,7 @@ void MemoryRecycle() {
|
|||
parse::Parser::CleanParserResource();
|
||||
trace::ClearTraceStack();
|
||||
pynative::PyNativeExecutor::GetInstance()->ClearRes();
|
||||
pynative::PyNativeExecutor::GetInstance()->WorkerJoin();
|
||||
ConfigManager::GetInstance().ResetConfig();
|
||||
ScopeManager::GetInstance().ClearScope();
|
||||
FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphCache();
|
||||
|
@ -1881,7 +1881,6 @@ void ClearResPart1() {
|
|||
(void)distributed::collective::CollectiveManager::instance()->Finalize();
|
||||
PrimitivePy::ClearHookRes();
|
||||
ad::g_k_prims.clear();
|
||||
ad::ClearKPynativeCellStaticRes();
|
||||
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
|
||||
|
||||
abstract::ClearPrimEvaluatorMap();
|
||||
|
|
|
@ -60,6 +60,7 @@ struct FrontendOpRunInfo {
|
|||
bool output_get_by_infer_value = false;
|
||||
int mix_type{0};
|
||||
size_t input_size = 0;
|
||||
size_t custom_bprop_cell_count = 0;
|
||||
PrimitivePyPtr op_prim{nullptr};
|
||||
ValuePtr out_value{nullptr};
|
||||
std::string op_info;
|
||||
|
@ -86,12 +87,13 @@ struct InputArgsInfo {
|
|||
bool has_custom_bprop;
|
||||
size_t input_size;
|
||||
std::string obj_id;
|
||||
|
||||
bool has_sens{false};
|
||||
PrimitivePyPtr custom_bprp_prim{nullptr};
|
||||
ValuePtr out_value{nullptr};
|
||||
std::string cell_id;
|
||||
std::string input_args_id;
|
||||
size_t custom_bprop_cell_count = 0;
|
||||
size_t grad_order = 0;
|
||||
std::vector<std::string> input_arg_id_vec;
|
||||
std::vector<ValuePtr> input_arg_value_vec;
|
||||
};
|
||||
|
|
|
@ -266,6 +266,7 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
|
|||
constexpr auto input_size = 2;
|
||||
const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
|
||||
cast_run_info->grad_flag = op_run_info->grad_flag;
|
||||
cast_run_info->custom_bprop_cell_count = op_run_info->custom_bprop_cell_count;
|
||||
MS_EXCEPTION_IF_NULL(cast_prim_);
|
||||
cast_run_info->op_prim = cast_prim_;
|
||||
cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();
|
||||
|
|
|
@ -184,7 +184,11 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
|
|||
GetOutput(op_run_info);
|
||||
}
|
||||
// 4. Do op grad and record op info
|
||||
grad()->ProcessOpGradInfo(op_run_info);
|
||||
if (enable_async_) {
|
||||
grad()->AsyncProcessOpGradInfo(op_run_info);
|
||||
} else {
|
||||
grad()->ProcessOpGradInfo(op_run_info);
|
||||
}
|
||||
}
|
||||
|
||||
FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) const {
|
||||
|
@ -194,6 +198,7 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
|
|||
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
|
||||
// Used for async run
|
||||
op_run_info->grad_flag = grad()->grad_flag();
|
||||
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
|
||||
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
|
||||
op_run_info->base_op_run_info.lazy_build = lazy_build_;
|
||||
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
|
||||
|
@ -236,6 +241,7 @@ void ForwardExecutor::GetOutput(const FrontendOpRunInfoPtr &op_run_info) {
|
|||
op_run_info->out_value = result_v_list->value().front();
|
||||
}
|
||||
}
|
||||
|
||||
// Not use GetNext abs
|
||||
if (op_run_info->base_op_run_info.op_name != kGetNextOpName) {
|
||||
op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->out_value);
|
||||
|
|
|
@@ -38,7 +38,9 @@ using MindrtBackendMap = std::map<std::string, std::shared_ptr<compile::MindRTBa
class ForwardExecutor {
 public:
  ForwardExecutor()
-     : cast_operation_(std::make_shared<CastOperation>()), infer_operation_(std::make_shared<InferOperation>()) {}
+     : cast_operation_(std::make_shared<CastOperation>()),
+       infer_operation_(std::make_shared<InferOperation>()),
+       enable_async_(std::getenv("ENABLE_ASYNC")) {}
  ~ForwardExecutor() = default;

  void Init();

@@ -100,6 +102,7 @@ class ForwardExecutor {
  CastOperationPtr cast_operation_;
  InferOperationPtr infer_operation_;
  MindrtBackendMap mindrt_backends_;
+ bool enable_async_ = false;
};
}  // namespace pynative
}  // namespace mindspore
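The constructor change above, together with the ForwardExecutor::RunOpForward hunk earlier in this diff, gates the new asynchronous bprop-construction path behind an ENABLE_ASYNC environment variable: any non-null getenv result turns enable_async_ on. A minimal, self-contained sketch of the same idiom (generic names, not the commit's classes):

#include <cstdlib>
#include <iostream>

class Executor {
 public:
  // Any setting of ENABLE_ASYNC (even "0") makes getenv return non-null,
  // so the flag becomes true; this mirrors the constructor above.
  Executor() : enable_async_(std::getenv("ENABLE_ASYNC") != nullptr) {}

  void Run() const {
    if (enable_async_) {
      std::cout << "dispatching work to the async queue\n";
    } else {
      std::cout << "running work synchronously\n";
    }
  }

 private:
  bool enable_async_ = false;
};

int main() {
  Executor().Run();  // e.g. `ENABLE_ASYNC=1 ./a.out` selects the async branch
  return 0;
}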
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/pynative/grad/bprop_task.h"
|
||||
#include "utils/log_adapter.h"
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
void BpropTask::Run() {
|
||||
MS_LOG(DEBUG) << "run construct bprop task";
|
||||
run_task_();
|
||||
MS_LOG(DEBUG) << "finish construct bprop task";
|
||||
}
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
|
||||
|
||||
#include <functional>
|
||||
#include "runtime/pynative/async/task.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
class BpropTask : public AsyncTask {
|
||||
public:
|
||||
explicit BpropTask(const std::function<void(void)> &task) : AsyncTask(kBpropTask), run_task_(task) {}
|
||||
~BpropTask() = default;
|
||||
void Run() override;
|
||||
|
||||
private:
|
||||
std::function<void(void)> run_task_;
|
||||
};
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
|
|
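BpropTask is a thin adapter: it stores a std::function and runs it when the async executor schedules it, which is how the GradExecutor changes later in this commit use it (AsyncNewGraphImpl and AsyncEndGraphImpl wrap a lambda and push it). A standalone sketch of that shape, with a plain queue standing in for the real async executor, whose type is not shown in this diff:

#include <functional>
#include <iostream>
#include <memory>
#include <queue>

// Stand-ins for AsyncTask / BpropTask from the commit (simplified: no task-type tag).
class Task {
 public:
  virtual ~Task() = default;
  virtual void Run() = 0;
};

class BpropTask : public Task {
 public:
  explicit BpropTask(const std::function<void(void)> &task) : run_task_(task) {}
  void Run() override { run_task_(); }

 private:
  std::function<void(void)> run_task_;
};

int main() {
  std::queue<std::shared_ptr<Task>> executor;  // stand-in for async_executor_->Push(...)

  int graphs_built = 0;
  // Mirrors GradExecutor::AsyncNewGraphImpl: capture the work in a lambda, wrap it, push it.
  executor.push(std::make_shared<BpropTask>([&]() { ++graphs_built; }));

  // The real executor drains the queue on a worker thread; here we drain inline.
  while (!executor.empty()) {
    executor.front()->Run();
    executor.pop();
  }
  std::cout << "graphs built: " << graphs_built << "\n";
  return 0;
}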
@@ -139,12 +139,29 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
  }
}

+ FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) {
+   auto resource = std::make_shared<pipeline::Resource>();
+   resource->set_func_graph(bprop_graph);
+   auto manager = resource->manager();
+   MS_EXCEPTION_IF_NULL(manager);
+   manager->AddFuncGraph(bprop_graph);
+   auto after_opt_bg = pipeline::BpropGraphFinalOptPass(resource);
+   PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", after_opt_bg);
+   return after_opt_bg;
+ }
+
void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::ResourcePtr &res,
-                      VectorRef *const arg_list) {
+                      VectorRef *const arg_list, bool has_sens) {
  MS_EXCEPTION_IF_NULL(arg_list);
  // Set inputs values
-  for (auto v : input_vec) {
-    (void)arg_list->emplace_back(v);
+  size_t size = has_sens ? input_vec.size() - 1 : input_vec.size();
+  for (size_t i = 0; i < size; ++i) {
+    if (PyNativeAlgo::Common::IsTensor(input_vec[i])) {
+      (void)arg_list->emplace_back(input_vec[i]);
+    }
  }
+  if (has_sens) {
+    (void)arg_list->emplace_back(input_vec.back());
+  }
  MS_EXCEPTION_IF_NULL(res);
  auto graph = res->func_graph();
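The new SetGraphInputArgs above changes how the launch argument list is built: non-tensor forward arguments are filtered out, and when a sens value is present it is always appended last. A small self-contained model of that filtering logic (the Value struct is a stand-in for ValuePtr/VectorRef, not the real types):

#include <iostream>
#include <string>
#include <vector>

// Stand-in: a "value" is a tensor iff is_tensor is true; the real code calls
// PyNativeAlgo::Common::IsTensor on a ValuePtr and fills a VectorRef.
struct Value { std::string name; bool is_tensor; };

std::vector<Value> BuildArgList(const std::vector<Value> &input_vec, bool has_sens) {
  std::vector<Value> arg_list;
  size_t size = has_sens ? input_vec.size() - 1 : input_vec.size();
  for (size_t i = 0; i < size; ++i) {
    if (input_vec[i].is_tensor) {          // non-tensor forward args are dropped
      arg_list.emplace_back(input_vec[i]);
    }
  }
  if (has_sens) {
    arg_list.emplace_back(input_vec.back());  // sens always goes last, unfiltered
  }
  return arg_list;
}

int main() {
  // Forward args x, y, a non-tensor flag, then the sens value.
  std::vector<Value> inputs = {{"x", true}, {"y", true}, {"flag", false}, {"sens", true}};
  for (const auto &v : BuildArgList(inputs, /*has_sens=*/true)) std::cout << v.name << " ";
  std::cout << "\n";  // prints: x y sens
  return 0;
}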
@ -203,7 +220,7 @@ TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
|
|||
|
||||
void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
|
||||
input_args_info_stack_.push(input_args_info);
|
||||
++cell_order_;
|
||||
// ++cell_order_;
|
||||
}
|
||||
|
||||
void GradExecutor::PopInputArgsInfoStack() {
|
||||
|
@ -251,12 +268,12 @@ void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_
|
|||
new_param->set_abstract(param_i_abs);
|
||||
top_cell()->SetParamNodeMapInGraphInfoMap(input_args_info->input_arg_id_vec[i], new_param);
|
||||
}
|
||||
top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
|
||||
top_cell()->set_auto_grad_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
|
||||
}
|
||||
|
||||
void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_info) {
|
||||
MS_EXCEPTION_IF_NULL(input_args_info);
|
||||
if (input_args_info->is_grad_topest_cell || IsNestedGrad()) {
|
||||
if (input_args_info->is_grad_topest_cell || input_args_info->grad_order > 1) {
|
||||
if (input_args_info->is_grad_topest_cell && !grad_is_running_) {
|
||||
MS_LOG(DEBUG) << "Make new topest graph";
|
||||
MakeNewTopGraph(input_args_info);
|
||||
|
@ -268,7 +285,7 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
|
|||
top_cell()->SetGraphInfoMap(fg, graph_info_cg);
|
||||
HandleInputArgsForTopCell(input_args_info, true);
|
||||
bprop_grad_stack_.push(std::make_pair(input_args_info->cell_id, false));
|
||||
} else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
|
||||
} else if (grad_is_running_ && top_cell()->grad_order() != input_args_info->grad_order) {
|
||||
MS_LOG(DEBUG) << "Nested grad graph existed in custom bprop";
|
||||
MakeNewTopGraph(input_args_info);
|
||||
bprop_grad_stack_.push(std::make_pair(input_args_info->cell_id, true));
|
||||
|
@ -290,42 +307,59 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
|
|||
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
|
||||
const auto &input_args_info = GetInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
|
||||
PushInputArgsInfoStack(input_args_info);
|
||||
|
||||
if (input_args_info->has_custom_bprop) {
|
||||
custom_bprop_cell_count_ += 1;
|
||||
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
|
||||
}
|
||||
if (grad_order_ == 0) {
|
||||
IncreaseGradOrder();
|
||||
}
|
||||
input_args_info->grad_order = grad_order_;
|
||||
// May be can async here
|
||||
NewGraphImpl(input_args_info);
|
||||
if (enable_async_) {
|
||||
AsyncNewGraphImpl(input_args_info);
|
||||
} else {
|
||||
NewGraphImpl(input_args_info);
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||
MS_EXCEPTION_IF_NULL(input_args_info);
|
||||
++cell_order_;
|
||||
const auto &cell_id = input_args_info->cell_id;
|
||||
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
|
||||
<< ", input args info ptr " << input_args_info.get();
|
||||
// When the cell has custom bprop, in_custom_bprop_cell is larger than 0
|
||||
if (input_args_info->has_custom_bprop) {
|
||||
custom_bprop_cell_count_ += 1;
|
||||
}
|
||||
// Make top graph and init resource
|
||||
InitResourceAndDfBuilder(input_args_info);
|
||||
}
|
||||
|
||||
void GradExecutor::AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||
const auto fn = [this, input_args_info]() { this->NewGraphImpl(input_args_info); };
|
||||
auto task = std::make_shared<BpropTask>(fn);
|
||||
async_executor_->Push(task);
|
||||
}
|
||||
|
||||
void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
|
||||
MS_EXCEPTION_IF_NULL(input_args_info);
|
||||
// CheckAlready run first, grad_order_ will increase 1(highorder scenario)
|
||||
// If NetA.set_grad(), so come here first, CheckAlready run later, so grad_order_ need increase 1
|
||||
if (grad_order_ == 0) {
|
||||
IncreaseGradOrder();
|
||||
if (input_args_info->grad_order == 0) {
|
||||
input_args_info->grad_order++;
|
||||
}
|
||||
// Both set grad: NetA.set_grad(); NetB.set_grad();
|
||||
// Run forward: NetA(); NetB();
|
||||
// Grad(NetA()); Grad(NetB()). grad_order_ is disordered, so need reset.
|
||||
if (input_args_info->is_grad_topest_cell && IsNestedGrad()) {
|
||||
DecreaseGradOrder();
|
||||
if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
|
||||
input_args_info->grad_order--;
|
||||
}
|
||||
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
fg->debug_info()->set_name("pynative_forward_graph");
|
||||
auto resource = std::make_shared<pipeline::Resource>();
|
||||
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(grad_order_ == 0 ? 1 : grad_order_);
|
||||
top_cell_ = std::make_shared<TopCellInfo>(grad_order_, input_args_info->cell_id, already_run_cell_id, resource, fg);
|
||||
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order);
|
||||
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id,
|
||||
resource, fg);
|
||||
top_cell_->set_forward_already_run(true);
|
||||
top_cell_->set_input_args_id(input_args_info->input_args_id);
|
||||
PushHighOrderGraphStack(top_cell_);
|
||||
|
@ -349,10 +383,11 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
|
|||
output_node->set_abstract(v->ToAbstract()->Broaden());
|
||||
}
|
||||
// Set last output abstract and will be used for sens
|
||||
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
|
||||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
auto sens_v = ConvertOutputValueToTensor(v);
|
||||
k_pynative_cell_ptr->UpdateOutputNodeOfTopCell(output_node, sens_v);
|
||||
auto cloned_value = ShallowCopyTensorValue(sens_v);
|
||||
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||
}
|
||||
|
||||
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
|
||||
|
@ -366,8 +401,15 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
|
|||
}
|
||||
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
|
||||
PopInputArgsInfoStack();
|
||||
if (input_args_info->is_grad_topest_cell) {
|
||||
set_grad_flag(false);
|
||||
}
|
||||
// May be can async here
|
||||
EndGraphImpl(input_args_info);
|
||||
if (enable_async_) {
|
||||
AsyncEndGraphImpl(input_args_info);
|
||||
} else {
|
||||
EndGraphImpl(input_args_info);
|
||||
}
|
||||
}
|
||||
|
||||
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
||||
|
@ -402,7 +444,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
SetForwardLastNodeInfo(out_value, out_id);
|
||||
top_cell()->ClearCellHookOp();
|
||||
cell_order_ = 0;
|
||||
set_grad_flag(false);
|
||||
// set_grad_flag(false);
|
||||
}
|
||||
// Checkout whether need to compile graph when each top cell has ran finished
|
||||
if (is_top_cell_end) {
|
||||
|
@ -415,9 +457,15 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
|
|||
}
|
||||
}
|
||||
|
||||
void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info) {
  const auto fn = [this, input_args_info]() { this->EndGraphImpl(input_args_info); };
  auto task = std::make_shared<BpropTask>(fn);
  async_executor_->Push(task);
}

void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) {
  MS_EXCEPTION_IF_NULL(input_args_info);
  if (!input_args_info->has_custom_bprop || custom_bprop_cell_count_ != 0) {
  if (!input_args_info->has_custom_bprop || input_args_info->custom_bprop_cell_count != 0) {
    return;
  }
  MS_LOG(DEBUG) << "Do grad for custom bprop";

@@ -437,6 +485,7 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
                                      const InputArgsInfoPtr &input_args_info) {
  custom_bprop_cell_count_ -= 1;
  input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
  if (custom_bprop_cell_count_ != 0) {
    return;
  }

@@ -488,12 +537,17 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg

void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                                const py::object &grad_position, const py::args &args) {
  {
    py::gil_scoped_release gil_release;
    async_executor_->Wait();
  }
  MS_EXCEPTION_IF_NULL(grad);
  MS_EXCEPTION_IF_NULL(top_input_args_info_);
  MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
                << ", input args info ptr " << top_input_args_info_.get();
  if (grad->sens_param()) {
    MS_LOG(DEBUG) << "Get sens param";
    top_input_args_info_->has_sens = true;
    size_t forward_args_size = args.size() - 1;
    auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
    const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);

@@ -529,7 +583,10 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph, true);
  PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
  resource->SetBackendAsync([]() { return compile::CreateBackend(); });
  if (backends_.find(top_input_args_info_->obj_id) == backends_.end()) {
    backends_[top_input_args_info_->obj_id] = compile::CreateBackend();
  }
  resource->SetBackendAsync([&]() { return backends_[top_input_args_info_->obj_id]; });
  MS_LOG(DEBUG) << "Start task emit action";
  (void)TaskEmitAction(resource);
  MS_LOG(DEBUG) << "Start execute action";

@@ -653,9 +710,6 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
  if (param_node->abstract() != nullptr) {
    auto input_shape = input_abs->BuildShape()->ToString();
    auto param_tensor_abs = param_node->abstract();
    if (param_tensor_abs->isa<abstract::AbstractRefTensor>()) {
      param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
    }
    CheckParamShapeAndType(param, param_node, input_abs, param_tensor_abs, input_shape);
  }
  param_node->set_abstract(input_abs->Broaden());

@@ -672,14 +726,10 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
    build_formal_param = true;
    need_renormalize_ = true;
  }
  if (top_cell()->ms_function_flag()) {
    need_renormalize_ = true;
  }

  auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  FuncGraphPtr bprop_graph =
    ad::GradPynativeCellEnd(k_pynative_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
  auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
  FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
  MS_EXCEPTION_IF_NULL(bprop_graph);

  MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();

@@ -688,18 +738,13 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop_graph->debug_info()->set_name(ss.str());
  UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
  // Do opt for final bprop graph
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  resource->set_func_graph(bprop_graph);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(bprop_graph);
  auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
  if (top_cell()->ms_function_flag()) {
    bprop_graph = BpropGraphFinalOpt(bprop_graph);
  }
  if (top_input_args_info_->is_grad_topest_cell) {
    need_renormalize_ = false;
  }
  PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", optimized_bg);
  return optimized_bg;
  return bprop_graph;
}

void GradExecutor::SetGradOrder(const std::string &cell_id) {

@@ -749,25 +794,25 @@ py::object GradExecutor::RunGradGraph() {
  const auto &resource = top_cell()->resource();
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Run cell id " << top_input_args_info_->cell_id << ", resource ptr " << resource.get();
  std::vector<ValuePtr> flatten_v;
  PyNativeAlgo::DataConvert::FlattenArgs(top_input_args_info_->input_arg_value_vec, &flatten_v,
                                         top_input_args_info_->has_sens);
  bool has_sens = top_input_args_info_->has_sens;
  VectorRef arg_list;
  SetGraphInputArgs(flatten_v, resource, &arg_list);
  MS_LOG(DEBUG) << "Convert args size " << flatten_v.size() << ", graph param size " << arg_list.size();
  SetGraphInputArgs(top_input_args_info_->input_arg_value_vec, resource, &arg_list, has_sens);
  MS_LOG(DEBUG) << "Convert args size " << top_input_args_info_->input_arg_value_vec.size() << ", graph param size "
                << arg_list.size();
  compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
  MS_EXCEPTION_IF_NULL(run);

  const auto &backend = MsContext::GetInstance()->backend_policy();
  MS_LOG(DEBUG) << "Eval run " << backend;
  grad_is_running_ = true;
  top_cell()->set_k_pynative_cell_ptr(nullptr);
  top_cell()->set_auto_grad_cell_ptr(nullptr);
  BaseRef out_value = (*run)(arg_list);
  grad_is_running_ = false;
  MS_LOG(DEBUG) << "Eval run end " << out_value.ToString();
  const auto &cur_run_bprop_graph = resource->func_graph();
  const auto &out_abs = GetGradGraphOutputAbstract(cur_run_bprop_graph);
  MakeNestedCnode(top_input_args_info_->has_custom_bprop, flatten_v, cur_run_bprop_graph, out_value);
  MakeNestedCnode(top_input_args_info_->has_custom_bprop, top_input_args_info_->input_arg_value_vec,
                  cur_run_bprop_graph, out_value);
  return BaseRefToPyData(out_value, out_abs);
}

@@ -816,7 +861,8 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
    std::vector<ValuePtr> out_v{out_value};
    out_value = std::make_shared<ValueTuple>(out_v);
  }
  if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
  auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg);
  if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
  }
  need_renormalize_ = true;

@@ -918,6 +964,8 @@ void GradExecutor::ClearRes() {
  top_cell_ = nullptr;
  top_input_args_info_ = nullptr;
  bprop_cell_list_.clear();
  backends_.clear();
  async_executor_->Reset();
  std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
  std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
  std::stack<TopCellInfoPtr>().swap(high_order_stack_);

@@ -1000,18 +1048,6 @@ AnfNodePtr GradExecutor::GetOutputNodeAsInput(const std::string &obj_id) const {
  return CreateTupleGetItemNode(obj_id, it->second);
}

void GradExecutor::RecordGradNodeToGraphInfoMap(const FuncGraphPtr &fg, const CNodePtr &cnode,
                                                const std::string &obj_id, const ValuePtrList &input_args) const {
  top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode);
  // run ad for make tuple node
  if (grad_is_running_ && !bprop_grad_stack_.empty() && !bprop_grad_stack_.top().second) {
    MS_LOG(DEBUG) << "Running custom bprop, no need to do GradPynativeOp.";
  } else {
    (void)ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args,
                             std::make_shared<ValueTuple>(input_args));
  }
}

AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const {
  MS_EXCEPTION_IF_NULL(v);
  if (!v->isa<ValueSequence>()) {

@@ -1034,8 +1070,8 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
  }
  // Create make tuple node and record to graph info map.
  auto cnode = curr_g()->NewCNode(inputs);
  MS_LOG(DEBUG) << "Create make tuple node " << cnode->DebugString();
  RecordGradNodeToGraphInfoMap(curr_g(), cnode, obj_id, input_args);
  MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
  top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
  return cnode;
}

@@ -1094,7 +1130,7 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
    return;
  }
  // Do op grad and save node info. If the cell has a custom bprop, op grad is skipped; otherwise it is required.
  if (custom_bprop_cell_count_ <= 0) {
  if (op_run_info->custom_bprop_cell_count <= 0) {
    const auto &cnode = ConstructForwardGraph(op_run_info);
    MS_EXCEPTION_IF_NULL(cnode);
    cnode->set_abstract(op_run_info->base_op_run_info.abstract);
@@ -1103,6 +1139,12 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
  }
}

void GradExecutor::AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
  const auto fn = [this, op_run_info]() { this->ProcessOpGradInfo(op_run_info); };
  auto task = std::make_shared<BpropTask>(fn);
  async_executor_->Push(task);
}

void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
                                     const CNodePtr &cnode) const {
  MS_EXCEPTION_IF_NULL(cnode);

@@ -1126,7 +1168,19 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
    return;
  }
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, op_run_info->input_value, op_out)) {

  // Shallow-copy the op inputs and output before recording them, so the values captured by the tape bprop
  // cannot be modified afterwards.
  ValuePtrList cloned_op_args;
  (void)std::transform(op_run_info->input_value.begin(), op_run_info->input_value.end(),
                       std::back_inserter(cloned_op_args),
                       [](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
  ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
  std::vector<tensor::TensorPtr> tensors;
  TensorValueToTensor(cloned_out, &tensors);
  for (auto tensor : tensors) {
    tensor->set_is_forward_output(true);
  }
  if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
    MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
  }
}

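The hunks above route graph-end and op-grad bookkeeping through BpropTask objects pushed onto async_executor_, and GradNetInner calls async_executor_->Wait() before building the grad graph. As a rough illustration of that producer/worker pattern only, the following is a simplified sketch, not the actual AsyncQueue or BpropTask implementation; all names here are illustrative:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// Minimal single-worker task queue: the forward thread pushes closures,
// a background thread pops and runs them in FIFO order.
class SimpleTaskQueue {
 public:
  SimpleTaskQueue() : worker_([this]() { Run(); }) {}
  ~SimpleTaskQueue() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_one();
    worker_.join();
  }
  void Push(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }
  // Block the caller until every queued task has finished, analogous to async_executor_->Wait().
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    idle_cv_.wait(lock, [this]() { return tasks_.empty() && !busy_; });
  }

 private:
  void Run() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this]() { return stop_ || !tasks_.empty(); });
        if (stop_ && tasks_.empty()) {
          return;
        }
        task = std::move(tasks_.front());
        tasks_.pop();
        busy_ = true;
      }
      task();
      {
        std::lock_guard<std::mutex> lock(mu_);
        busy_ = false;
      }
      idle_cv_.notify_all();
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::condition_variable idle_cv_;
  std::queue<std::function<void()>> tasks_;
  bool stop_ = false;
  bool busy_ = false;
  std::thread worker_;
};

Waiting on the queue before finalizing the bprop graph matters for the same reason the sketch exposes Wait(): every recorded forward op must be on the tape before the grad graph is compiled.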
@@ -22,10 +22,13 @@
#include <utility>
#include <stack>
#include <vector>
#include <map>
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/grad/ms_function_grad.h"

#include "runtime/pynative/async/async_queue.h"
#include "pipeline/pynative/grad/bprop_task.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace pynative {
namespace py = pybind11;

@@ -38,7 +41,10 @@ class GradExecutor {
  GradExecutor() = default;
  ~GradExecutor() = default;
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)), ms_function_(std::make_shared<MsFunction>()) {}
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
        ms_function_(std::make_shared<MsFunction>()),
        async_executor_(std::make_unique<AsyncQueue>()),
        enable_async_(std::getenv("ENABLE_ASYNC")) {}

  std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));

@@ -70,6 +76,7 @@ class GradExecutor {
  // Construct grad graph for ms_function
  inline bool eliminate_forward() const { return eliminate_forward_; }
  inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
  inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
  void SetHookChanged(const py::object &cell) const;
  void GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                    const py::object &grad_position, const py::args &args);

@@ -77,11 +84,14 @@ class GradExecutor {
  CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
  void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
  void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
  AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
  void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
  AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
  void ClearRes();
  void WorkerJoin() { async_executor_->WorkerJoin(); }

 private:
  ForwardExecutorPtr forward() const;

@@ -126,6 +136,7 @@ class GradExecutor {
  bool IsBpropGraph(const std::string &cell_id) const;
  void NewGraphInner(const py::object &obj, const py::args &args);
  void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
  void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
  void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
  void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
                          const InputArgsInfoPtr &input_args_info);

@@ -145,8 +156,6 @@ class GradExecutor {
  AnfNodePtr GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const;
  AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
                                    const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
  void RecordGradNodeToGraphInfoMap(const FuncGraphPtr &fg, const CNodePtr &cnode, const std::string &obj_id,
                                    const ValuePtrList &input_args) const;

  bool grad_flag_{false};
  bool grad_is_running_{false};

@@ -168,6 +177,9 @@ class GradExecutor {
  std::stack<TopCellInfoPtr> high_order_stack_;
  ForwardExecutorWeakPtr forward_executor_;
  MsFunctionPtr ms_function_;
  std::unique_ptr<AsyncQueue> async_executor_;
  std::map<std::string, compile::BackendPtr> backends_;
  bool enable_async_ = false;
};
} // namespace pynative
} // namespace mindspore

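The new enable_async_ member is seeded from std::getenv("ENABLE_ASYNC"), so the asynchronous bprop path is opt-in per process. A minimal sketch of that gating pattern follows; it is illustrative only, and the dispatch helper is an assumption rather than the executor's actual code:

#include <cstdlib>
#include <functional>

// True when the ENABLE_ASYNC environment variable is set to any value,
// mirroring the enable_async_(std::getenv("ENABLE_ASYNC")) initializer above
// (a non-null char* converts to true).
inline bool AsyncEnabledFromEnv() { return std::getenv("ENABLE_ASYNC") != nullptr; }

// Run the work inline, or hand it to an asynchronous executor when the flag is on.
// `push_async` stands in for something like async_executor_->Push(...).
inline void RunMaybeAsync(bool enable_async, const std::function<void()> &work,
                          const std::function<void(std::function<void()>)> &push_async) {
  if (enable_async) {
    push_async(work);
  } else {
    work();
  }
}

Used this way, a caller would choose between the ProcessOpGradInfo and AsyncProcessOpGradInfo style entry points with a single flag check instead of duplicating the branching at every call site.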
@@ -238,10 +238,11 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
  top_cell->SetNodeMapInGraphInfoMap(out_id, ms_function_cnode);

  // Connect grad graph of ms_function to context.
  auto k_pynative_cell_ptr = top_cell->k_pynative_cell_ptr();
  MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
  if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, op_run_info->input_value, op_run_info->out_value,
                                               grad_graph)) {
  auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
  auto grad_param =
    std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
  if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
    MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
                      << ms_function_cnode->DebugString();
  }

@@ -315,7 +316,6 @@ py::object MsFunction::GradMsFunction(const py::object &out, const py::args &arg
  const auto &op_run_info = GetOpRunInfo(out, args, graph_phase_, &added_out_v);
  FuncGraphPtr grad_graph = executor->GetGradGraph(graph_phase_);
  PyNativeAlgo::Common::DumpGraphIR("ms_func_forward_graph.ir", ms_func_graph);
  PyNativeAlgo::Common::DumpGraphIR("ms_func_grad_graph.ir", grad_graph);
  GradMsFunctionInner(op_run_info, grad_executor.get(), added_out_v, ms_func_graph, grad_graph);
  SetMsFuncGraphParameters(ms_func_graph);
  graph_phase_.clear();

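Both call sites above now bundle the four arguments of KPynativeWithFProp into a single ad::GradParam object instead of passing them individually. The GradParam definition itself is not part of this diff, so the struct below is only a sketch of the parameter-object refactor, with stub types and assumed field names:

#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the real IR types; they are not MindSpore classes.
struct CNodeStub {};
struct ValueStub {};
struct FuncGraphStub {};
using CNodePtrStub = std::shared_ptr<CNodeStub>;
using ValuePtrStub = std::shared_ptr<ValueStub>;
using FuncGraphPtrStub = std::shared_ptr<FuncGraphStub>;

// Parameter object: one construction site gathers everything the adjoint
// builder needs, so KPynativeWithFProp keeps a single-argument signature.
struct GradParamSketch {
  GradParamSketch(CNodePtrStub cnode, std::vector<ValuePtrStub> inputs, ValuePtrStub out, FuncGraphPtrStub fprop_fg)
      : cnode(std::move(cnode)), inputs(std::move(inputs)), out(std::move(out)), fprop_fg(std::move(fprop_fg)) {}
  CNodePtrStub cnode;                // forward node being differentiated
  std::vector<ValuePtrStub> inputs;  // forward input values
  ValuePtrStub out;                  // forward output value
  FuncGraphPtrStub fprop_fg;         // fprop/grad graph to attach
};

// The real implementation records the node and graph on the tape; this only
// shows the single-parameter signature the refactor enables.
bool KPynativeWithFPropSketch(const std::shared_ptr<GradParamSketch> &param) { return param != nullptr; }

One practical consequence of this shape is that adding another field later touches only the parameter object and its consumers, not every call site.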
@@ -110,12 +110,15 @@ void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const Par
  }
}

void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index) const {
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index,
                                           bool save_flag) const {
  auto &graph_info = graph_info_map().at(fg());
  MS_EXCEPTION_IF_NULL(graph_info);
  graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
  // For example, set id of ((A,B),C) = {CNode, -1}
  SetMultipleOutputToGraphInfoMap(id, node);
  if (save_flag) {
    SetMultipleOutputToGraphInfoMap(id, node);
  }
}

void TopCellInfo::SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const {

@@ -32,7 +32,7 @@
#include "pybind11/pytypes.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/optimizer/ad/auto_grad.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"

@@ -90,16 +90,14 @@ class TopCellInfo {
    graph_info_map_[fg] = graph_info;
  }
  inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
  inline ad::KPynativeCellPtr k_pynative_cell_ptr() const {
    MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr_);
    return k_pynative_cell_ptr_;
  }
  inline void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
    k_pynative_cell_ptr_ = k_pynative_cell_ptr;
  inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { return auto_grad_cell_ptr_; }
  void set_auto_grad_cell_ptr(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr) {
    auto_grad_cell_ptr_ = auto_grad_cell_ptr;
  }
  void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
  void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
  void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1) const;
  void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
                                bool save_flag = true) const;
  void ClearDeviceMemory() const;

 private:

@@ -120,7 +118,7 @@ class TopCellInfo {
  std::string grad_operation_;
  pipeline::ResourcePtr resource_{nullptr};
  FuncGraphPtr fg_{nullptr};
  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
  ad::AutoGradCellImplPtr auto_grad_cell_ptr_{nullptr};
  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
  // Record `register hook` or `remove hook` function has been called by sub cell
  // The record range between the begin and end of top cell.

@@ -71,6 +71,7 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
  void Sync() const;
  void SetLazyBuild(bool enable) const;
  bool IsFirstCell() const;
  void WorkerJoin() { grad_executor_->WorkerJoin(); }

 private:
  PyNativeExecutor() = default;

@@ -36,6 +36,7 @@ namespace {
constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id";
constexpr auto kCustomOpBpropAttrName = "custom_op_bprop";
std::map<std::string, std::string> kOpAttrNameReplaceMap = {
  {"data_format", "format"},
};

@@ -359,6 +360,26 @@ BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
  }
}

BaseRef PrimitivePy::RunOpBpropFunction(const py::tuple &py_args) const {
  if (backward_hook_fn_.size() > 1) {
    MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
  }
  py::tuple grads;
  SyncData(py_args);
  py::tuple converted_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &converted_args);
  try {
    MS_LOG(DEBUG) << "start execute custom op bprop";
    for (const auto &elem : backward_hook_fn_) {
      py::object grads_obj = elem.second(*converted_args);
      grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
    }
    return std::make_shared<PyObjectRef>(grads);
  } catch (std::exception &bt) {
    std::rethrow_exception(std::current_exception());
  }
}

BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  // Get the gradient passed to current bprop cut op.
  const auto args_size = py_args.size();

@@ -426,6 +447,10 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  if (is_cell) {
    return RunCellHookFunction(py_args);
  }
  bool is_custom_op_bprop = this->HasAttr(kCustomOpBpropAttrName);
  if (is_custom_op_bprop) {
    return RunOpBpropFunction(py_args);
  }
  return RunVariableHookFunction(py_args);
}

@@ -59,6 +59,7 @@ class PrimitivePy : public Primitive {
  void RemoveBackwardHookFn(const int &key);
  BaseRef RunHookFunction(const VectorRef &args) const;
  BaseRef RunCellBpropFunction(const py::tuple &py_args) const;
  BaseRef RunOpBpropFunction(const py::tuple &py_args) const;
  BaseRef RunCellHookFunction(const py::tuple &py_args) const;
  BaseRef RunVariableHookFunction(const py::tuple &py_args) const;
  BaseRef RunComputeFunction(const VectorRef &args) const override;

@@ -23,6 +23,7 @@ enum TaskType {
  kUnknownTask = 0,
  kOpRunTask,
  kOpBuildTask,
  kBpropTask,
  kExitTask,
};
class AsyncTask {

@@ -38,6 +38,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "utils/*.cc"
  "load_mindir/*.cc"
  "mindapi/src/*.cc"
  "expander/*.cc"
  )

set(CORE_SRC_LIST ${CORE_SRC_LIST} ${CORE_OPS_LIST})

@@ -495,11 +495,11 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {

StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  if (iter == GetPrimitiveToEvalImplMap().end()) {
    return {nullptr, nullptr, false};
  const auto iter = GetPrimitiveToEvalImplMap().find(primitive);
  if (iter != GetPrimitiveToEvalImplMap().end()) {
    return iter->second;
  }
  return iter->second;
  return {nullptr, nullptr, false};
}

void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {

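The GetPrimitiveInferImpl change flips the lookup to a hit-first early return, leaving the fallback value as the single final statement. The same shape in a self-contained sketch, using a plain std::map and illustrative names rather than the real registry types:

#include <map>
#include <string>

struct InferImplSketch {
  const char *infer_shape = nullptr;
  const char *infer_value = nullptr;
  bool in_white_list = false;
};

// Hit-first lookup: return the registered entry if present, otherwise a
// default-constructed sentinel, analogous to {nullptr, nullptr, false} above.
InferImplSketch FindInferImpl(const std::map<std::string, InferImplSketch> &registry, const std::string &prim_name) {
  const auto iter = registry.find(prim_name);
  if (iter != registry.end()) {
    return iter->second;
  }
  return InferImplSketch{};
}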
@@ -0,0 +1,6 @@
approvers:
- gaoxiong1
- ckey_dou
- dayschan
- anyrenwei
- zichun_ye

@@ -14,7 +14,7 @@
 * limitations under the License.
 */

#include "common/graph_kernel/bprop/expander/emitter.h"
#include "expander/emitter.h"

#include <algorithm>
#include <functional>

@@ -146,13 +146,13 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
  return Emit(prim::kZerosLike, {node});
}

NodePtr Emitter::Fill(const double &value, const ShapeVector &shape, TypeId data_type) const {
NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) const {
  size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
  std::vector<double> data(data_num, value);
  return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeFloat64);
}

NodePtr Emitter::Fill(const int64_t &value, const ShapeVector &shape, TypeId data_type) const {
NodePtr Emitter::Fill(int64_t value, const ShapeVector &shape, TypeId data_type) const {
  size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
  std::vector<int64_t> data(data_num, value);
  return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeInt64);
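Both Fill overloads now take the scalar by value (a trivially cheap copy with no aliasing concerns) and compute the element count by folding the shape with std::accumulate. A standalone sketch of that buffer construction, using a plain std::vector in place of the Emitter/Tensor machinery:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Build a host-side buffer holding `value` repeated once per element of `shape`.
// Passing `value` by value mirrors the signature change from `const double &` to `double`.
std::vector<double> MakeFilledBuffer(double value, const std::vector<int64_t> &shape) {
  // Fold the dimensions; an empty shape yields 1, i.e. a scalar tensor.
  const int64_t element_count =
      std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  return std::vector<double>(static_cast<size_t>(element_count), value);
}

// Example: MakeFilledBuffer(0.5, {2, 3}) returns six elements, all equal to 0.5.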
@@ -25,12 +25,12 @@
#include "ir/func_graph.h"
#include "ops/core_ops.h"
#include "include/common/utils/utils.h"
#include "common/graph_kernel/bprop/expander/node.h"
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/node.h"
#include "expander/infer.h"

namespace mindspore {
namespace expander {
class Emitter {
class MS_CORE_API Emitter {
 public:
  Emitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
      : func_graph_(func_graph), infer_(infer), scope_(scope) {

@@ -119,8 +119,8 @@ class Emitter {
  NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const;

  NodePtr ZerosLike(const NodePtr &node) const;
  NodePtr Fill(const double &value, const ShapeVector &shape, TypeId data_type) const;
  NodePtr Fill(const int64_t &value, const ShapeVector &shape, TypeId data_type) const;
  NodePtr Fill(double value, const ShapeVector &shape, TypeId data_type) const;
  NodePtr Fill(int64_t value, const ShapeVector &shape, TypeId data_type) const;

  /// \brief Emit a value node
  template <typename T>

@@ -165,11 +165,11 @@ class Emitter {
};
using EmitterPtr = std::shared_ptr<Emitter>;

NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator-(const NodePtr &node);
MS_CORE_API NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator-(const NodePtr &node);
} // namespace expander
} // namespace mindspore
#endif // MINDSPORE_CORE_EXPANDER_EMITTER_H_

@@ -14,7 +14,7 @@
 * limitations under the License.
 */

#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/infer.h"

#include <algorithm>
#include "abstract/ops/primitive_infer_map.h"

@@ -17,12 +17,12 @@
#ifndef MINDSPORE_CORE_EXPANDER_INFER_H_
#define MINDSPORE_CORE_EXPANDER_INFER_H_
#include <memory>
#include "common/graph_kernel/bprop/expander/node.h"
#include "expander/node.h"

namespace mindspore {
namespace expander {
/// \brief ExpanderInfer is the adapter for the infer functions that are called in the emitter.
class ExpanderInfer {
class MS_CORE_API ExpanderInfer {
 public:
  /// \brief Infer shape and dtype for node
  virtual void Infer(const NodePtr &node) = 0;

@@ -33,7 +33,7 @@ class ExpanderInfer {
using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;

/// \brief CppInfer calls the InferShapeAndType interface of the frontend or backend map.
class CppInfer : public ExpanderInfer {
class MS_CORE_API CppInfer : public ExpanderInfer {
 public:
  void Infer(const NodePtr &node) override;
  BaseShapePtr GetShape(const NodePtr &node) override;
@@ -14,10 +14,10 @@
 * limitations under the License.
 */

#include "common/graph_kernel/bprop/expander/node.h"
#include "expander/node.h"
#include <algorithm>
#include "common/graph_kernel/bprop/expander/emitter.h"
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/emitter.h"
#include "expander/infer.h"

namespace mindspore {
namespace expander {

@@ -27,7 +27,7 @@ namespace expander {
class Emitter;
using DAttr = mindspore::HashMap<std::string, ValuePtr>;

class Node : public std::enable_shared_from_this<Node> {
class MS_CORE_API Node : public std::enable_shared_from_this<Node> {
 public:
  Node(const AnfNodePtr &node, const Emitter *emitter);
  ~Node() = default;
@@ -20,6 +20,7 @@ from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
from mindspore._c_expression import Tensor as CTensor


def ScalarAdd(x, y):

@@ -103,8 +104,11 @@ def zeros_like_tensor(x):

def OnesLike(x):
    """Implement `oneslike`."""
    x = x.asnumpy()
    value = Tensor(np.ones(x.shape).astype(x.dtype))
    if isinstance(x, (Tensor, CTensor)):
        x = x.asnumpy()
        value = Tensor(np.ones(x.shape).astype(x.dtype))
    else:
        value = Tensor(1.0)
    return value

@@ -492,9 +492,6 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x, w_size, dout):
        raise NotImplementedError

    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}

@@ -559,9 +556,6 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x_size, w, dout):
        raise NotImplementedError

    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)

@@ -664,9 +658,6 @@ class UniqueGrad(Primitive):
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])

    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class BNTrainingReduceGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

@@ -721,9 +712,6 @@ class NeighborExchangeV2Grad(PrimitiveWithInfer):
                'dtype': dy['dtype'],
                'value': None}

    def __call__(self, tensor):
        raise NotImplementedError


class GeLUGrad(Primitive):
    """Gradients of GeLU operation."""

@@ -1389,9 +1377,6 @@ class LayerNormGrad(Primitive):
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma):
        raise NotImplementedError


class LayerNormGradGrad(Primitive):
    """

@@ -1793,9 +1778,6 @@ class ReluGrad(Primitive):
        """Initialize ReluGrad"""
        self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])

    def __call__(self, y_backprop, x):
        raise NotImplementedError


class ReLU6Grad(Primitive):
    """Performs grad of ReLU6 operation."""

@@ -1804,9 +1786,6 @@ class ReLU6Grad(Primitive):
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])

    def __call__(self, y_grad, x):
        raise NotImplementedError


class ReluGradV2(Primitive):
    """Performs grad of ReLUV2 operation."""

@@ -1815,9 +1794,6 @@ class ReluGradV2(Primitive):
    def __init__(self):
        self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])

    def __call__(self, gradients, mask):
        raise NotImplementedError


class EluGrad(Primitive):
    """Performs grad of Elu operation."""

@@ -19,9 +19,11 @@ import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, dtype, Parameter
from mindspore.ops import operations as P
from mindspore import jit
from mindspore import jit, context
import mindspore.ops.functional as F

context.set_context(mode=context.GRAPH_MODE)


@case_register.level0
@case_register.target_gpu

@@ -18,8 +18,11 @@ from mindspore.nn import Cell
from mindspore.common import Tensor, dtype
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import context
import numpy as np

context.set_context(mode=context.GRAPH_MODE)


@case_register.level0
@case_register.target_gpu

@@ -148,7 +148,7 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
    assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())


@pytest.mark.level0
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_iteration_function_pynative():

@@ -16,12 +16,9 @@
import platform
import numpy as np
import pytest

import mindspore as ms
from mindspore import nn
from mindspore import ops
from mindspore import context, Tensor
from mindspore import jit


class Net(nn.Cell):

@@ -32,8 +29,6 @@ class Net(nn.Cell):
        self.addn = ops.AddN()
        self.relu = nn.ReLU()

    @jit(input_signature=(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
                          Tensor(shape=[2, 3, None, None], dtype=ms.float32)))
    def construct(self, x, y):
        x = self.addn((x, y))
        x = self.log(x)

@@ -58,8 +53,6 @@ class CmpNet(nn.Cell):
        return x


@jit(input_signature=(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
                      Tensor(shape=[2, 3, None, None], dtype=ms.float32)))
def func(x, y):
    x = ops.AddN()((x, y))
    x = ops.Log()(x)

@@ -1,123 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <unordered_map>

#include "frontend/optimizer/ad/kpynative.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/debug/anf_ir_utils.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace ad {
class TestKPynative : public UT::Common {
 public:
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();

 protected:
  AbstractBasePtr BuildArg() {
    std::vector<int64_t> shp = {2, 2};
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
    auto abstract = tensor->ToAbstract();
    return abstract;
  }

  FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    auto y = g->add_parameter();
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    c_node->set_abstract(BuildArg());
    g->set_output(c_node);
    return g;
  }

  // a = x * y
  // b = stop_gradient(a)
  // c = b * y
  // return c
  FuncGraphPtr BuildStopGradient(const std::string &testCase) {
    auto g = std::make_shared<FuncGraph>();
    auto x = g->add_parameter();
    x->debug_info()->set_name("x");
    auto y = g->add_parameter();
    y->debug_info()->set_name("y");
    x->set_abstract(BuildArg());
    y->set_abstract(BuildArg());
    auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
    a_node->set_abstract(BuildArg());
    auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
    b_node->set_abstract(BuildArg());
    auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
    c_node->set_abstract(BuildArg());
    auto d_node =
      g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
    d_node->set_abstract(BuildArg());
    g->set_output(d_node);
    return g;
  }

  FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
    auto input_params = primal_fg->parameters();
    std::vector<ValuePtr> input_param_values;
    std::for_each(input_params.begin(), input_params.end(),
                  [&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
    auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
    auto node_list = TopoSort(primal_fg->output());
    for (auto node : node_list) {
      if (node->isa<CNode>()) {
        auto c_node = node->cast<CNodePtr>();
        auto out = c_node->abstract()->GetValueTrack();
        ValuePtrList args;
        for (size_t i = 1; i < c_node->inputs().size(); ++i) {
          args.push_back(c_node->input(i)->abstract()->GetValueTrack());
        }
        GradPynativeOp(k_pynative_cell, c_node, args, out);
      }
    }
    GradAttr grad_attr(true, false, false, false, true);
    auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, grad_attr, true);
    return bprop_fg;
  }
};

TEST_F(TestKPynative, test_simple_add) {
  auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});
}

TEST_F(TestKPynative, test_stop_gradient) {
  auto primal_fg = BuildStopGradient("test_stop_gradient");
  resource->manager()->KeepRoots({primal_fg});

  auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  resource->manager()->KeepRoots({bprop_fg});
}
} // namespace ad
} // namespace mindspore

@@ -26,6 +26,26 @@ from .vm_interface import vm
# pylint: disable=unused-argument


@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeroslike(self):
    def vm_impl(x):
        x = x.asnumpy()
        out = np.zeros_like(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Log)
def vm_impl_log(self):
    def vm_impl(x):
        x = x.asnumpy()
        out = np.log(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Add)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for TensorAdd."""