new_construct_bprop

move expander files
move expander component to mindspore/core
move the bprop expanders to mindspore/frontend
luochao 2022-11-10 16:32:47 +08:00
parent 1809ae2820
commit 605c1a8479
63 changed files with 1748 additions and 1795 deletions
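For orientation: after this move, bprop expanders keep the same BpropIRBuilder registration interface; only the include paths and directory layout change. A minimal sketch of an expander under the new layout (the op and gradient body are hypothetical; the includes, namespace, and REG_BPROP_BUILDER/SetBody usage follow the hunks below):

// Illustrative sketch, not part of this commit. The input layout (x at kIndex0,
// dout at kIndex2 for a single-input op) is inferred from the expanders in this diff.
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"

namespace mindspore::expander::bprop {
REG_BPROP_BUILDER("Neg").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
  auto dout = ib->GetInput(kIndex2);
  // d(-x)/dx = -1, so dx = -dout.
  return {ib->Emit("Neg", {dout})};
});
}  // namespace mindspore::expander::bprop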


@ -872,7 +872,7 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
}
}
input_tensor_info->input_tensors_mask.emplace_back(
(is_value_node && !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
(is_value_node || !is_forward_output) ? kValueNodeTensorMask : kParameterDataTensorMask);
} else if (real_input->isa<Parameter>()) {
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask


@ -901,6 +901,9 @@ bool TopoCompareFuncGraphNode(const AnfNodePtr &node, const bool &is_first_sort,
if (is_first_sort) {
(void)new_graph_info->topo_node_list.emplace_back(std::move(node_info));
} else {
if (common::AnfAlgo::IsControlOpExecInBackend(cnode)) {
return true;
}
if (topo_node_idx >= old_graph_info->topo_node_list.size()) {
return true;
}


@ -149,7 +149,7 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
// Save the mapping between cell id and actor info.
mindspore::HashMap<std::string, ActorInfo> graph_actor_infos_;
bool enable_backend_dynamic_detect_{false};
bool enable_backend_dynamic_detect_{true};
FuncGraphPtr root_graph_;
GraphPartitionPtr graph_partition_;
std::shared_ptr<GraphCompiler> graph_compiler_;


@ -0,0 +1,6 @@
approvers:
- gaoxiong1
- ckey_dou
- dayschan
- anyrenwei
- zichun_ye


@ -13,14 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop.h"
#include "frontend/operator/bprop/bprop.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/infer.h"
#include "utils/anf_utils.h"
#include "include/common/debug/anf_ir_dump.h"


@ -13,22 +13,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_
#include <map>
#include <vector>
#include <utility>
#include "ir/anf.h"
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/visible.h"
namespace mindspore {
using DoutUserType = std::vector<std::pair<CNodePtr, int>>;
// deprecated
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
COMMON_EXPORT bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_H_


@ -14,14 +14,14 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include <algorithm>
#include <vector>
#include <limits>
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/grad/common_utils.h"
namespace mindspore {
namespace expander {


@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_
#include <memory>
#include <vector>
@ -22,8 +22,8 @@
#include <map>
#include <functional>
#include "common/graph_kernel/bprop/expander/node.h"
#include "common/graph_kernel/bprop/expander/emitter.h"
#include "expander/node.h"
#include "expander/emitter.h"
#include "utils/hash_map.h"
namespace mindspore {
@ -112,4 +112,4 @@ class BpropIRBuilderRegHelper {
} // namespace bprop
} // namespace expander
} // namespace mindspore
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_IRBUILDER_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_BPROP_IRBUILDER_H_


@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include <algorithm>
#include <memory>


@ -13,19 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_
#include <cmath>
#include <vector>
#include <utility>
#include <set>
#include "common/graph_kernel/bprop/expander/node.h"
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "expander/node.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
namespace mindspore::expander::bprop {
constexpr auto pi = acos(-1.0);
constexpr auto log_2 = log(2.0);
constexpr auto log_pi = log(pi);
inline const auto pi = std::acos(-1.0);
inline const auto log_2 = std::log(2.0);
inline const auto log_pi = std::log(pi);
std::vector<std::vector<int64_t>> BroadcastGradientArgs(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &y_shape);
@ -83,4 +84,4 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
ShapeVector PoolToNHWC(const ShapeVector &v);
ShapeVector ConvToNHWC(const ShapeVector &v);
} // namespace mindspore::expander::bprop
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_BPROP_GRAD_COMMON_UTILS_H_
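A note on the constexpr change above (my reading; the commit message does not explain it): the <cmath> functions are not constexpr in standard C++, so the old constexpr initializers compiled only as a compiler extension, while C++17 inline variables give a single shared definition for header-scope constants:

// Illustrative sketch, not part of this commit.
#include <cmath>

// constexpr auto pi = acos(-1.0);         // non-standard: std::acos is not constexpr
inline const auto pi = std::acos(-1.0);    // C++17: one definition shared across all TUs
inline const auto log_pi = std::log(pi);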


@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
#include "ir/anf.h"


@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -14,8 +14,8 @@
* limitations under the License.
*/
#include <set>
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"


@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -15,9 +15,9 @@
*/
#include <unordered_set>
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/grad/common_utils.h"
namespace mindspore::expander::bprop {
static NodePtr GetMatrixDiagAssist(const BpropIRBuilder *ib, const ShapeVector &x_shape, TypePtr x_dtype) {


@ -15,8 +15,8 @@
*/
#include <map>
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -15,9 +15,9 @@
*/
#include <unordered_set>
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/grad/common_utils.h"
namespace mindspore::expander::bprop {
NodePtrList CheckBpropExpander(const BpropIRBuilder *ib) {


@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
#include "utils/check_convert_utils.h"
@ -1030,7 +1030,6 @@ REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogits").SetBody([](const BpropI
auto logits = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex3);
auto grad = ib->Emit(kSparseSoftmaxCrossEntropyWithLogitsOpName, {logits, labels}, {{kAttrIsGrad, MakeValue(true)}});
grad = ib->Emit("Depend", {grad, out});
grad = ib->Mul(grad, dout);
return {grad, ib->ZerosLike(labels)};
});


@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {


@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"


@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "include/common/utils/utils.h"
namespace mindspore::expander::bprop {
@ -52,7 +52,7 @@ REG_BPROP_BUILDER("SolveTriangular").SetBody([](const BpropIRBuilder *ib) -> Nod
if (GetValue<bool>(unit_diagonal)) {
auto fill = ib->Emit("Fill", {ib->EmitValue(ib->GetDtype(grad_a)), ib->Value<ShapeVector>(ShapeVector(1, row_size)),
ib->Tensor(0, ib->GetDtype(grad_a))});
grad_a = ib->Emit("MatrixSetDiagV3", {grad_a, fill, ib->Fill(0L, {2}, TypeId::kNumberTypeInt32)},
grad_a = ib->Emit("MatrixSetDiagV3", {grad_a, fill, ib->Fill(int64_t(0), {2}, TypeId::kNumberTypeInt32)},
{{"align", MakeValue("RIGHT_LEFT")}, {"max_length", MakeValue<int64_t>(200000000)}});
}
return {grad_a, grad_b};
@ -96,7 +96,7 @@ REG_BPROP_BUILDER("Eigh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
// constants in the computation
std::vector<int32_t> zero_tensor_value = {0, 0};
ShapeVector zero_shape{2};
auto zero_tensor = ib->Fill(0L, {2}, TypeId::kNumberTypeInt32);
auto zero_tensor = ib->Fill(int64_t(0), {2}, TypeId::kNumberTypeInt32);
auto kValueNeg1 = ib->Value<int64_t>(-1);
auto kValueNeg2 = ib->Value<int64_t>(-2);


@ -15,8 +15,8 @@
*/
#include <tuple>
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "common/graph_kernel/bprop/expander/common_utils.h"
#include "frontend/operator/bprop/bprop_irbuilder.h"
#include "frontend/operator/bprop/grad/common_utils.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"

File diff suppressed because it is too large


@ -0,0 +1,228 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
#include <memory>
#include <utility>
#include <map>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace ad {
struct GradAttr {
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
: grad_all_inputs(get_all),
grad_weights(get_by_list),
has_sens(sens_param),
get_by_position(get_by_position),
weight_param_is_tuple(weight_param_is_tuple) {}
bool grad_all_inputs;
bool grad_weights;
bool has_sens;
bool get_by_position;
bool weight_param_is_tuple;
};
struct GradParam {
GradParam(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, FuncGraphPtr fprop_fg = nullptr)
: cnode(cnode), op_args(op_args), out(out), fprop_fg(std::move(fprop_fg)) {}
// Primal CNode created by the op forward process
const CNodePtr &cnode;
// Input value for cnode
const ValuePtrList &op_args;
// Output of op
const ValuePtr &out;
// Bprop func graph
const FuncGraphPtr fprop_fg;
// Used by high-order grad
bool grad_by_value = true;
};
using GradParamPtr = std::shared_ptr<GradParam>;
class FunctionNode {
public:
FunctionNode(const FuncGraphPtr &tape, const AnfNodePtr &dout)
: tape_(tape), accumulate_dout_(dout), fake_dout_(dout) {}
void AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &next_edges() const { return next_edges_; }
AnfNodePtr RealDout() const { return accumulate_dout_; }
void ReplaceEdges();
const AnfNodePtr fake_dout() const { return fake_dout_; }
private:
AnfNodePtr HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node);
// Bprop func graph
const FuncGraphPtr tape_;
// Input of dout for this bprop function
AnfNodePtr accumulate_dout_;
// First we generate a fake dout
const AnfNodePtr fake_dout_;
// Represents where the dins backpropagate to other bprop functions or variables
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> next_edges_;
// Replace next_edges where din == dout in the bprop function
std::vector<int> need_replace_edges_;
};
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
class VariableNode {
public:
VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
ValuePtr out_value() const { return out_value_; }
FunctionNodePtr fn() const { return fn_; }
bool is_need_grad() const { return is_need_grad_; }
void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
AnfNodePtr k_node() const { return k_node_; }
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
private:
// Abstract bprop function
FunctionNodePtr fn_;
ValuePtr out_value_;
bool is_need_grad_{false};
// k mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
using VariableNodePtr = std::shared_ptr<VariableNode>;
class AutoGradCellImpl {
public:
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values);
~AutoGradCellImpl() = default;
// Reverse connect bprop of op
bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
// Reverse connect ms_function or higher order sub bprop funcgraph
bool KPynativeWithFProp(const GradParamPtr &grad_param);
CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args, const ValuePtr &out,
AnfNodePtr *const tape_dout);
// Update top cell output, record last_node
void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sens_out);
// Build a back-propagation funcgraph; each cnode in the primal funcgraph is replaced by a value node or a formal
// cnode, so it can be grad again.
FuncGraphPtr Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr, bool build_formal_param);
private:
// Last cnode of this Cell, may be a primitive op or cell with user defined bprop.
AnfNodePtr last_node_{nullptr};
ValuePtr sens_value_{nullptr};
// Bprop funcgraph
FuncGraphPtr tape_;
// Top cell inputs
AnfNodePtrList cell_inputs_;
// These weights need to calculate gradient.
mindspore::HashSet<AnfNodePtr> need_grad_weights_;
// Bprop dins of each variable or middle out
OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_;
// Record cnode's input map for tape_
UserType users_;
// Flag for ms_function and high-order grad
bool has_fbprop_{false};
bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const AnfNodePtr &dout);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
void UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din,
const ValuePtr &op_arg);
void BuildForwardLastNode();
// Add parameter(weights) to anfnode_to_variable_adjoint_
void AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor);
AnfNodePtr GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
const AnfNodePtr &din);
void BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
// Set sens and weights parameter nodes by user input info
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
// get last reverse iterator
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
void BackPropagate();
// Set return node according to grad flag
void SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position, const GradAttr &grad_attr);
AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
AnfNodePtr GenerateEmptyTupleValue();
void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void ElimateTupleGetItem();
void ClearDeviceAddress(const ValuePtr &out);
// Fbprop
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const VariableNodePtr &adjoint);
AnfNodePtr BuildKNodeForCNodeInput(const ValuePtrList &op_args, const AnfNodePtr &input_node, size_t input_index);
};
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
// Start building back propagate funcgraph for this cell.
// cell_inputs: the input parameter list of this cell except the weights;
AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
const std::vector<ValuePtr> &input_param_values);
// Return the back propagate funcgraph for this cell.
// weights: weights parameters used in this cell.
// grad_inputs: return sensitivity for input parameters;
// grad_weights: return sensitivity for weights;
// has_sens_arg: caller will pass sens args;
// return: the returned funcgraph will have prototype:
// if has_sens_arg is true
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...,
// sens_out)
// else:
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
// if build_formal_param is true
// each cnode in primal funcgraph is replaced by formal cnode
// else:
// each cnode in primal funcgraph is replaced by value node
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &k_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool build_formal_param = false);
// Grad for each operation.
// c_node: CNode which contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameters.
// out: the op result.
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out);
// adjoint bprop from ms_function and higher-order grad
void GradPynativeFBprop(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &fprop_fg);
} // namespace ad
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_AUTO_GRAD_H_
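The comments above describe the intended calling convention; the GradExecutor hunks later in this diff drive AutoGradCellImpl in roughly this order (a condensed sketch with surrounding executor state omitted, not a verbatim excerpt):

// Illustrative sketch, not part of this commit; names follow the header above.
// cell_inputs, input_param_values, cnode, op_args, op_out, output_node, sens_value,
// weights and grad_position are provided by the PyNative executor.
// 1. Begin recording when a top cell starts running forward.
auto auto_grad_cell = ad::GradPynativeCellBegin(cell_inputs, input_param_values);
// 2. For each forward op, record the primal cnode, its input values and output.
(void)ad::GradPynativeOp(auto_grad_cell, cnode, op_args, op_out);
// 3. When the cell finishes, record the output node used as sens.
auto_grad_cell->UpdateOutputNodeOfTopCell(output_node, sens_value);
// 4. Build the bprop funcgraph for the requested inputs/weights.
ad::GradAttr grad_attr(/*get_all=*/true, /*get_by_list=*/false, /*sens_param=*/false,
                       /*get_by_position=*/false, /*weight_param_is_tuple=*/true);
FuncGraphPtr bprop_fg =
    ad::GradPynativeCellEnd(auto_grad_cell, weights, grad_position, grad_attr);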

File diff suppressed because it is too large


@ -1,113 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace ad {
class KPynativeCell {
public:
virtual ~KPynativeCell() = default;
virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &sen_out) = 0;
// Grad for cell which may have user passed front propagate FuncGraph.
// c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that
// cell. op_args: the arguments list of each input parameters.
// out: the op result.
// fprop_fg: user defined back propagate cnode which output is the bprop_fg.
// Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &fprop_fg) = 0;
};
using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
struct GradAttr {
bool grad_all_inputs;
bool grad_weights;
bool has_sens;
bool get_by_position;
bool weight_param_is_tuple;
GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
: grad_all_inputs(get_all),
grad_weights(get_by_list),
has_sens(sens_param),
get_by_position(get_by_position),
weight_param_is_tuple(weight_param_is_tuple) {}
};
// bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after
// just parsed. will have prototype:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameters.
// out: the op result.
// return: the returned funcgraph should have the same prototype.
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out);
// Start building back propagate funcgraph for this cell.
// cell_inputs: the input parameter list of this cell except the weights;
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
const std::vector<ValuePtr> &input_param_values);
// Return the back propagate funcgraph for this cell.
// weights: weights parameters used in this cell.
// grad_inputs: return sensitivity for input parameters;
// grad_weights: return sensitivity for weights;
// has_sens_arg: caller will pass sens args;
// return: the returned funcgraph will have prototype:
// if has_sens_arg is true
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...,
// sens_out)
// else:
// (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
// if build_formal_param is true
// each cnode in primal funcgraph is replaced by formal cnode
// else:
// each cnode in primal funcgraph is replaced by value node
FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights,
const std::vector<size_t> &grad_position, const GradAttr &grad_attr,
bool build_formal_param = false);
// Grad for each operation.
// c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
// op_args: the arguments list of each input parameters.
// out: the op result.
bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out);
// Grad for cell which may have user defined back propagate function.
// c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that
// cell. op_args: the arguments list of each input parameters.
// out: the op result.
// bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed.
// Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
const ValuePtr &out, const FuncGraphPtr &bprop_fg);
// Clear all static resources that used in grad process
void ClearKPynativeCellStaticRes();
} // namespace ad
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_


@ -191,18 +191,6 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co
}
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
MS_EXCEPTION_IF_NULL(resource->func_graph());
(void)TransformTopGraphPass(resource);
auto func_graph = resource->func_graph();
// PyNative dynamic shape need add those pass, like convert make_list to make_tuple.
// Cannot execute those pass due to performance reasons if the graph is a dynamic structure graph.
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_flag(FUNC_GRAPH_FLAG_DYNAMIC_SHAPE) || !func_graph->has_flag(kFlagIsDynamicStructure)) {
(void)OptPassAGroup(resource);
(void)CleanAfterOptAPass(resource);
}
opt::irpass::OptimizeIRPassLib irpass;
opt::OptPassConfig bg_final_opt = opt::OptPassConfig({
irpass.inline_,
@ -210,6 +198,27 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.depend_value_elim_,
});
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
MS_EXCEPTION_IF_NULL(resource);
auto func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
func_graph = bprop_graph_final_opt->step(func_graph, true);
};
// Validate(func_graph);
return func_graph;
}
FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource) {
opt::irpass::OptimizeIRPassLib irpass;
opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
irpass.inline_,
irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.depend_value_elim_,
irpass.reshape_eliminate_,
irpass.switch_simplify_,
irpass.addn_zero_filter_,
@ -217,28 +226,21 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
});
opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
OptPassGroupMap map({
{"ad_final_opt", bg_final_opt},
{"ad_final_opt", grad_graph_opt},
{"zeros_like", fill_zeros_like},
});
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
opt::OptPassConfig environ_eliminate = opt::OptPassConfig({
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
});
(void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate));
}
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", resource, map);
func_graph = resource->func_graph();
(void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
(void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
MS_EXCEPTION_IF_NULL(resource);
auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("grad_graph_opt", resource, map);
auto func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
func_graph = bprop_graph_final_opt->step(func_graph, true);
};
func_graph = LiftingClone(func_graph);
Validate(func_graph);
// Validate(func_graph);
return func_graph;
}


@ -55,6 +55,7 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
const std::vector<bool> &need_grad_flags);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource);
FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource);
} // namespace pipeline
} // namespace mindspore


@ -1838,7 +1838,6 @@ void MemoryRecycle() {
ReclaimOptimizer();
session::ExecutorManager::Instance().ClearDoneTasks();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
abstract::AnalysisResultCacheMgr::GetInstance().Clear();
abstract::AnalysisContext::ClearContext();
@ -1850,6 +1849,7 @@ void MemoryRecycle() {
parse::Parser::CleanParserResource();
trace::ClearTraceStack();
pynative::PyNativeExecutor::GetInstance()->ClearRes();
pynative::PyNativeExecutor::GetInstance()->WorkerJoin();
ConfigManager::GetInstance().ResetConfig();
ScopeManager::GetInstance().ClearScope();
FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphCache();
@ -1881,7 +1881,6 @@ void ClearResPart1() {
(void)distributed::collective::CollectiveManager::instance()->Finalize();
PrimitivePy::ClearHookRes();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();
ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
abstract::ClearPrimEvaluatorMap();


@ -60,6 +60,7 @@ struct FrontendOpRunInfo {
bool output_get_by_infer_value = false;
int mix_type{0};
size_t input_size = 0;
size_t custom_bprop_cell_count = 0;
PrimitivePyPtr op_prim{nullptr};
ValuePtr out_value{nullptr};
std::string op_info;
@ -86,12 +87,13 @@ struct InputArgsInfo {
bool has_custom_bprop;
size_t input_size;
std::string obj_id;
bool has_sens{false};
PrimitivePyPtr custom_bprp_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string input_args_id;
size_t custom_bprop_cell_count = 0;
size_t grad_order = 0;
std::vector<std::string> input_arg_id_vec;
std::vector<ValuePtr> input_arg_value_vec;
};


@ -266,6 +266,7 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
constexpr auto input_size = 2;
const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
cast_run_info->grad_flag = op_run_info->grad_flag;
cast_run_info->custom_bprop_cell_count = op_run_info->custom_bprop_cell_count;
MS_EXCEPTION_IF_NULL(cast_prim_);
cast_run_info->op_prim = cast_prim_;
cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();


@ -184,7 +184,11 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
GetOutput(op_run_info);
}
// 4. Do op grad and record op info
grad()->ProcessOpGradInfo(op_run_info);
if (enable_async_) {
grad()->AsyncProcessOpGradInfo(op_run_info);
} else {
grad()->ProcessOpGradInfo(op_run_info);
}
}
FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) const {
@ -194,6 +198,7 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
// Used for async run
op_run_info->grad_flag = grad()->grad_flag();
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
op_run_info->base_op_run_info.lazy_build = lazy_build_;
PyNativeAlgo::PyParser::SetPrim(op_run_info, args[static_cast<size_t>(RunOpArgsEnum::PY_PRIM)]);
@ -236,6 +241,7 @@ void ForwardExecutor::GetOutput(const FrontendOpRunInfoPtr &op_run_info) {
op_run_info->out_value = result_v_list->value().front();
}
}
// Do not use GetNext abs
if (op_run_info->base_op_run_info.op_name != kGetNextOpName) {
op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->out_value);


@ -38,7 +38,9 @@ using MindrtBackendMap = std::map<std::string, std::shared_ptr<compile::MindRTBa
class ForwardExecutor {
public:
ForwardExecutor()
: cast_operation_(std::make_shared<CastOperation>()), infer_operation_(std::make_shared<InferOperation>()) {}
: cast_operation_(std::make_shared<CastOperation>()),
infer_operation_(std::make_shared<InferOperation>()),
enable_async_(std::getenv("ENABLE_ASYNC")) {}
~ForwardExecutor() = default;
void Init();
@ -100,6 +102,7 @@ class ForwardExecutor {
CastOperationPtr cast_operation_;
InferOperationPtr infer_operation_;
MindrtBackendMap mindrt_backends_;
bool enable_async_ = false;
};
} // namespace pynative
} // namespace mindspore
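One detail worth noting about the enable_async_ member added here (and the matching one in GradExecutor below): std::getenv returns a non-null pointer whenever the variable is set, so the flag is effectively presence-based.

// Illustrative, not part of this commit: the initializer enable_async_(std::getenv("ENABLE_ASYNC"))
// converts the char* to bool, i.e. any value (even "0") enables the async path; unset disables it.
#include <cstdlib>
bool enable_async = (std::getenv("ENABLE_ASYNC") != nullptr);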


@ -0,0 +1,27 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/pynative/grad/bprop_task.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace pynative {
void BpropTask::Run() {
MS_LOG(DEBUG) << "run construct bprop task";
run_task_();
MS_LOG(DEBUG) << "finish construct bprop task";
}
} // namespace pynative
} // namespace mindspore


@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
#include <functional>
#include "runtime/pynative/async/task.h"
namespace mindspore {
namespace pynative {
class BpropTask : public AsyncTask {
public:
explicit BpropTask(const std::function<void(void)> &task) : AsyncTask(kBpropTask), run_task_(task) {}
~BpropTask() = default;
void Run() override;
private:
std::function<void(void)> run_task_;
};
} // namespace pynative
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_BPROP_TASK_H_
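For reference, the queueing pattern used with this task type elsewhere in the diff (e.g. AsyncNewGraphImpl, AsyncEndGraphImpl, AsyncProcessOpGradInfo) is simply:

// Pattern taken from the GradExecutor hunks below: wrap the synchronous step
// in a lambda, wrap that in a BpropTask, and push it onto the AsyncQueue.
const auto fn = [this, input_args_info]() { this->NewGraphImpl(input_args_info); };
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);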


@ -139,12 +139,29 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
}
}
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) {
auto resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(bprop_graph);
auto after_opt_bg = pipeline::BpropGraphFinalOptPass(resource);
PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", after_opt_bg);
return after_opt_bg;
}
void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::ResourcePtr &res,
VectorRef *const arg_list) {
VectorRef *const arg_list, bool has_sens) {
MS_EXCEPTION_IF_NULL(arg_list);
// Set inputs values
for (auto v : input_vec) {
(void)arg_list->emplace_back(v);
size_t size = has_sens ? input_vec.size() - 1 : input_vec.size();
for (size_t i = 0; i < size; ++i) {
if (PyNativeAlgo::Common::IsTensor(input_vec[i])) {
(void)arg_list->emplace_back(input_vec[i]);
}
}
if (has_sens) {
(void)arg_list->emplace_back(input_vec.back());
}
MS_EXCEPTION_IF_NULL(res);
auto graph = res->func_graph();
@ -203,7 +220,7 @@ TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
input_args_info_stack_.push(input_args_info);
++cell_order_;
// ++cell_order_;
}
void GradExecutor::PopInputArgsInfoStack() {
@ -251,12 +268,12 @@ void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_
new_param->set_abstract(param_i_abs);
top_cell()->SetParamNodeMapInGraphInfoMap(input_args_info->input_arg_id_vec[i], new_param);
}
top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
top_cell()->set_auto_grad_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
}
void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->is_grad_topest_cell || IsNestedGrad()) {
if (input_args_info->is_grad_topest_cell || input_args_info->grad_order > 1) {
if (input_args_info->is_grad_topest_cell && !grad_is_running_) {
MS_LOG(DEBUG) << "Make new topest graph";
MakeNewTopGraph(input_args_info);
@ -268,7 +285,7 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
top_cell()->SetGraphInfoMap(fg, graph_info_cg);
HandleInputArgsForTopCell(input_args_info, true);
bprop_grad_stack_.push(std::make_pair(input_args_info->cell_id, false));
} else if (grad_is_running_ && top_cell()->grad_order() != grad_order_) {
} else if (grad_is_running_ && top_cell()->grad_order() != input_args_info->grad_order) {
MS_LOG(DEBUG) << "Nested grad graph existed in custom bprop";
MakeNewTopGraph(input_args_info);
bprop_grad_stack_.push(std::make_pair(input_args_info->cell_id, true));
@ -290,42 +307,59 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
const auto &input_args_info = GetInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
PushInputArgsInfoStack(input_args_info);
if (input_args_info->has_custom_bprop) {
custom_bprop_cell_count_ += 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
}
if (grad_order_ == 0) {
IncreaseGradOrder();
}
input_args_info->grad_order = grad_order_;
// Maybe this can be made async here
NewGraphImpl(input_args_info);
if (enable_async_) {
AsyncNewGraphImpl(input_args_info);
} else {
NewGraphImpl(input_args_info);
}
}
void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
++cell_order_;
const auto &cell_id = input_args_info->cell_id;
MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
<< ", input args info ptr " << input_args_info.get();
// When the cell has a custom bprop, in_custom_bprop_cell is larger than 0
if (input_args_info->has_custom_bprop) {
custom_bprop_cell_count_ += 1;
}
// Make top graph and init resource
InitResourceAndDfBuilder(input_args_info);
}
void GradExecutor::AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info) {
const auto fn = [this, input_args_info]() { this->NewGraphImpl(input_args_info); };
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);
}
void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
// CheckAlreadyRun runs first, so grad_order_ will be increased by 1 (high-order scenario)
// If NetA.set_grad() is used, we come here first and CheckAlreadyRun runs later, so grad_order_ needs to increase by 1
if (grad_order_ == 0) {
IncreaseGradOrder();
if (input_args_info->grad_order == 0) {
input_args_info->grad_order++;
}
// Both set grad: NetA.set_grad(); NetB.set_grad();
// Run forward: NetA(); NetB();
// Grad(NetA()); Grad(NetB()). grad_order_ is disordered, so it needs to be reset.
if (input_args_info->is_grad_topest_cell && IsNestedGrad()) {
DecreaseGradOrder();
if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
input_args_info->grad_order--;
}
auto fg = std::make_shared<FuncGraph>();
fg->debug_info()->set_name("pynative_forward_graph");
auto resource = std::make_shared<pipeline::Resource>();
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(grad_order_ == 0 ? 1 : grad_order_);
top_cell_ = std::make_shared<TopCellInfo>(grad_order_, input_args_info->cell_id, already_run_cell_id, resource, fg);
const auto &already_run_cell_id = input_args_info->cell_id + std::to_string(input_args_info->grad_order);
top_cell_ = std::make_shared<TopCellInfo>(input_args_info->grad_order, input_args_info->cell_id, already_run_cell_id,
resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_input_args_id(input_args_info->input_args_id);
PushHighOrderGraphStack(top_cell_);
@ -349,10 +383,11 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
output_node->set_abstract(v->ToAbstract()->Broaden());
}
// Set last output abstract and will be used for sens
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
auto sens_v = ConvertOutputValueToTensor(v);
k_pynative_cell_ptr->UpdateOutputNodeOfTopCell(output_node, sens_v);
auto cloned_value = ShallowCopyTensorValue(sens_v);
auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
}
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
@ -366,8 +401,15 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
}
input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
PopInputArgsInfoStack();
if (input_args_info->is_grad_topest_cell) {
set_grad_flag(false);
}
// Maybe this can be made async here
EndGraphImpl(input_args_info);
if (enable_async_) {
AsyncEndGraphImpl(input_args_info);
} else {
EndGraphImpl(input_args_info);
}
}
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
@ -402,7 +444,7 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
SetForwardLastNodeInfo(out_value, out_id);
top_cell()->ClearCellHookOp();
cell_order_ = 0;
set_grad_flag(false);
// set_grad_flag(false);
}
// Checkout whether need to compile graph when each top cell has ran finished
if (is_top_cell_end) {
@ -415,9 +457,15 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
}
}
void GradExecutor::AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info) {
const auto fn = [this, input_args_info]() { this->EndGraphImpl(input_args_info); };
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);
}
void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (!input_args_info->has_custom_bprop || custom_bprop_cell_count_ != 0) {
if (!input_args_info->has_custom_bprop || input_args_info->custom_bprop_cell_count != 0) {
return;
}
MS_LOG(DEBUG) << "Do grad for custom bprop";
@ -437,6 +485,7 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info) {
custom_bprop_cell_count_ -= 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
if (custom_bprop_cell_count_ != 0) {
return;
}
@ -488,12 +537,17 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg
void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
const py::object &grad_position, const py::args &args) {
{
py::gil_scoped_release gil_release;
async_executor_->Wait();
}
MS_EXCEPTION_IF_NULL(grad);
MS_EXCEPTION_IF_NULL(top_input_args_info_);
MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
<< ", input args info ptr " << top_input_args_info_.get();
if (grad->sens_param()) {
MS_LOG(DEBUG) << "Get sens param";
top_input_args_info_->has_sens = true;
size_t forward_args_size = args.size() - 1;
auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
@ -529,7 +583,10 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(bprop_graph, true);
PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
resource->SetBackendAsync([]() { return compile::CreateBackend(); });
if (backends_.find(top_input_args_info_->obj_id) == backends_.end()) {
backends_[top_input_args_info_->obj_id] = compile::CreateBackend();
}
resource->SetBackendAsync([&]() { return backends_[top_input_args_info_->obj_id]; });
MS_LOG(DEBUG) << "Start task emit action";
(void)TaskEmitAction(resource);
MS_LOG(DEBUG) << "Start execute action";
@ -653,9 +710,6 @@ void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
if (param_node->abstract() != nullptr) {
auto input_shape = input_abs->BuildShape()->ToString();
auto param_tensor_abs = param_node->abstract();
if (param_tensor_abs->isa<abstract::AbstractRefTensor>()) {
param_tensor_abs = param_tensor_abs->cast<abstract::AbstractRefPtr>()->CloneAsTensor();
}
CheckParamShapeAndType(param, param_node, input_abs, param_tensor_abs, input_shape);
}
param_node->set_abstract(input_abs->Broaden());
@ -672,14 +726,10 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
build_formal_param = true;
need_renormalize_ = true;
}
if (top_cell()->ms_function_flag()) {
need_renormalize_ = true;
}
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
FuncGraphPtr bprop_graph =
ad::GradPynativeCellEnd(k_pynative_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr, build_formal_param);
MS_EXCEPTION_IF_NULL(bprop_graph);
MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
@ -688,18 +738,13 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
bprop_graph->debug_info()->set_name(ss.str());
UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
// Do opt for final bprop graph
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(bprop_graph);
auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
if (top_cell()->ms_function_flag()) {
bprop_graph = BpropGraphFinalOpt(bprop_graph);
}
if (top_input_args_info_->is_grad_topest_cell) {
need_renormalize_ = false;
}
PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", optimized_bg);
return optimized_bg;
return bprop_graph;
}
void GradExecutor::SetGradOrder(const std::string &cell_id) {
@ -749,25 +794,25 @@ py::object GradExecutor::RunGradGraph() {
const auto &resource = top_cell()->resource();
MS_EXCEPTION_IF_NULL(resource);
MS_LOG(DEBUG) << "Run cell id " << top_input_args_info_->cell_id << ", resource ptr " << resource.get();
std::vector<ValuePtr> flatten_v;
PyNativeAlgo::DataConvert::FlattenArgs(top_input_args_info_->input_arg_value_vec, &flatten_v,
top_input_args_info_->has_sens);
bool has_sens = top_input_args_info_->has_sens;
VectorRef arg_list;
SetGraphInputArgs(flatten_v, resource, &arg_list);
MS_LOG(DEBUG) << "Convert args size " << flatten_v.size() << ", graph param size " << arg_list.size();
SetGraphInputArgs(top_input_args_info_->input_arg_value_vec, resource, &arg_list, has_sens);
MS_LOG(DEBUG) << "Convert args size " << top_input_args_info_->input_arg_value_vec.size() << ", graph param size "
<< arg_list.size();
compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
MS_EXCEPTION_IF_NULL(run);
const auto &backend = MsContext::GetInstance()->backend_policy();
MS_LOG(DEBUG) << "Eval run " << backend;
grad_is_running_ = true;
top_cell()->set_k_pynative_cell_ptr(nullptr);
top_cell()->set_auto_grad_cell_ptr(nullptr);
BaseRef out_value = (*run)(arg_list);
grad_is_running_ = false;
MS_LOG(DEBUG) << "Eval run end " << out_value.ToString();
const auto &cur_run_bprop_graph = resource->func_graph();
const auto &out_abs = GetGradGraphOutputAbstract(cur_run_bprop_graph);
MakeNestedCnode(top_input_args_info_->has_custom_bprop, flatten_v, cur_run_bprop_graph, out_value);
MakeNestedCnode(top_input_args_info_->has_custom_bprop, top_input_args_info_->input_arg_value_vec,
cur_run_bprop_graph, out_value);
return BaseRefToPyData(out_value, out_abs);
}
@ -816,7 +861,8 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}
if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg);
if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
}
need_renormalize_ = true;
@ -918,6 +964,8 @@ void GradExecutor::ClearRes() {
top_cell_ = nullptr;
top_input_args_info_ = nullptr;
bprop_cell_list_.clear();
backends_.clear();
async_executor_->Reset();
std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
std::stack<std::pair<std::string, bool>>().swap(bprop_grad_stack_);
std::stack<TopCellInfoPtr>().swap(high_order_stack_);
@ -1000,18 +1048,6 @@ AnfNodePtr GradExecutor::GetOutputNodeAsInput(const std::string &obj_id) const {
return CreateTupleGetItemNode(obj_id, it->second);
}
void GradExecutor::RecordGradNodeToGraphInfoMap(const FuncGraphPtr &fg, const CNodePtr &cnode,
const std::string &obj_id, const ValuePtrList &input_args) const {
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode);
// run ad for make tuple node
if (grad_is_running_ && !bprop_grad_stack_.empty() && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Running custom bprop, no need to do GradPynativeOp.";
} else {
(void)ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, input_args,
std::make_shared<ValueTuple>(input_args));
}
}
AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const {
MS_EXCEPTION_IF_NULL(v);
if (!v->isa<ValueSequence>()) {
@ -1034,8 +1070,8 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
}
// Create make tuple node and record to graph info map.
auto cnode = curr_g()->NewCNode(inputs);
MS_LOG(DEBUG) << "Create make tuple node " << cnode->DebugString();
RecordGradNodeToGraphInfoMap(curr_g(), cnode, obj_id, input_args);
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
return cnode;
}
@ -1094,7 +1130,7 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
return;
}
// Do op grad and save node info. If cell have custom bprop, no need do op grad. Otherwise, need do.
if (custom_bprop_cell_count_ <= 0) {
if (op_run_info->custom_bprop_cell_count <= 0) {
const auto &cnode = ConstructForwardGraph(op_run_info);
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_abstract(op_run_info->base_op_run_info.abstract);
@ -1103,6 +1139,12 @@ void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) co
}
}
void GradExecutor::AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
const auto fn = [this, op_run_info]() { this->ProcessOpGradInfo(op_run_info); };
auto task = std::make_shared<BpropTask>(fn);
async_executor_->Push(task);
}
void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode);
@ -1126,7 +1168,19 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
return;
}
MS_EXCEPTION_IF_NULL(op_run_info);
if (!ad::GradPynativeOp(top_cell()->k_pynative_cell_ptr(), cnode, op_run_info->input_value, op_out)) {
// Clone op args and out so the values recorded in the tape bprop are not modified in place later.
ValuePtrList cloned_op_args;
(void)std::transform(op_run_info->input_value.begin(), op_run_info->input_value.end(),
std::back_inserter(cloned_op_args),
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(cloned_out, &tensors);
for (auto tensor : tensors) {
tensor->set_is_forward_output(true);
}
if (!ad::GradPynativeOp(top_cell()->auto_grad_cell_ptr(), cnode, cloned_op_args, cloned_out)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for op " << op_run_info->base_op_run_info.op_name;
}
}


@ -22,10 +22,13 @@
#include <utility>
#include <stack>
#include <vector>
#include <map>
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/grad/ms_function_grad.h"
#include "runtime/pynative/async/async_queue.h"
#include "pipeline/pynative/grad/bprop_task.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace pynative {
namespace py = pybind11;
@ -38,7 +41,10 @@ class GradExecutor {
GradExecutor() = default;
~GradExecutor() = default;
explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
: forward_executor_(ForwardExecutorWeakPtr(forward_executor)), ms_function_(std::make_shared<MsFunction>()) {}
: forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
ms_function_(std::make_shared<MsFunction>()),
async_executor_(std::make_unique<AsyncQueue>()),
enable_async_(std::getenv("ENABLE_ASYNC")) {}
std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
@ -70,6 +76,7 @@ class GradExecutor {
// Construct grad graph for ms_function
inline bool eliminate_forward() const { return eliminate_forward_; }
inline void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
void SetHookChanged(const py::object &cell) const;
void GradNetInner(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
const py::object &grad_position, const py::args &args);
@ -77,11 +84,14 @@ class GradExecutor {
CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::args &args);
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
void ClearRes();
void WorkerJoin() { async_executor_->WorkerJoin(); }
private:
ForwardExecutorPtr forward() const;
@ -126,6 +136,7 @@ class GradExecutor {
bool IsBpropGraph(const std::string &cell_id) const;
void NewGraphInner(const py::object &obj, const py::args &args);
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info);
@ -145,8 +156,6 @@ class GradExecutor {
AnfNodePtr GetValueSequenceInput(const ValuePtr &v, const std::string &obj_id) const;
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
void RecordGradNodeToGraphInfoMap(const FuncGraphPtr &fg, const CNodePtr &cnode, const std::string &obj_id,
const ValuePtrList &input_args) const;
bool grad_flag_{false};
bool grad_is_running_{false};
@ -168,6 +177,9 @@ class GradExecutor {
std::stack<TopCellInfoPtr> high_order_stack_;
ForwardExecutorWeakPtr forward_executor_;
MsFunctionPtr ms_function_;
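  // Asynchronous executor for bprop tasks; enable_async_ is set when the ENABLE_ASYNC environment variable is present.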
std::unique_ptr<AsyncQueue> async_executor_;
std::map<std::string, compile::BackendPtr> backends_;
bool enable_async_ = false;
};
} // namespace pynative
} // namespace mindspore

View File

@ -238,10 +238,11 @@ CNodePtr MsFunction::MakeAdjointForMsFunction(const FrontendOpRunInfoPtr &op_run
top_cell->SetNodeMapInGraphInfoMap(out_id, ms_function_cnode);
// Connect grad graph of ms_function to context.
auto k_pynative_cell_ptr = top_cell->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, op_run_info->input_value, op_run_info->out_value,
grad_graph)) {
auto auto_grad_cell_ptr = top_cell->auto_grad_cell_ptr();
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
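  // Package the ms_function cnode, its inputs, its output and the grad graph into a GradParam for the auto grad cell.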
auto grad_param =
std::make_shared<ad::GradParam>(ms_function_cnode, op_run_info->input_value, op_run_info->out_value, grad_graph);
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
<< ms_function_cnode->DebugString();
}
@ -315,7 +316,6 @@ py::object MsFunction::GradMsFunction(const py::object &out, const py::args &arg
const auto &op_run_info = GetOpRunInfo(out, args, graph_phase_, &added_out_v);
FuncGraphPtr grad_graph = executor->GetGradGraph(graph_phase_);
PyNativeAlgo::Common::DumpGraphIR("ms_func_forward_graph.ir", ms_func_graph);
PyNativeAlgo::Common::DumpGraphIR("ms_func_grad_graph.ir", grad_graph);
GradMsFunctionInner(op_run_info, grad_executor.get(), added_out_v, ms_func_graph, grad_graph);
SetMsFuncGraphParameters(ms_func_graph);
graph_phase_.clear();

View File

@ -110,12 +110,15 @@ void TopCellInfo::SetParamNodeMapInGraphInfoMap(const std::string &id, const Par
}
}
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index) const {
void TopCellInfo::SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index,
bool save_flag) const {
auto &graph_info = graph_info_map().at(fg());
MS_EXCEPTION_IF_NULL(graph_info);
graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
// For example, map the id of ((A, B), C) to {CNode, -1}; when save_flag is true, also record each nested output element.
SetMultipleOutputToGraphInfoMap(id, node);
if (save_flag) {
SetMultipleOutputToGraphInfoMap(id, node);
}
}
void TopCellInfo::SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const {

View File

@ -32,7 +32,7 @@
#include "pybind11/pytypes.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/optimizer/ad/auto_grad.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"
@ -90,16 +90,14 @@ class TopCellInfo {
graph_info_map_[fg] = graph_info;
}
inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
inline ad::KPynativeCellPtr k_pynative_cell_ptr() const {
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr_);
return k_pynative_cell_ptr_;
}
inline void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
k_pynative_cell_ptr_ = k_pynative_cell_ptr;
inline ad::AutoGradCellImplPtr auto_grad_cell_ptr() const { return auto_grad_cell_ptr_; }
void set_auto_grad_cell_ptr(const ad::AutoGradCellImplPtr &auto_grad_cell_ptr) {
auto_grad_cell_ptr_ = auto_grad_cell_ptr;
}
void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
bool save_flag = true) const;
void ClearDeviceMemory() const;
private:
@ -120,7 +118,7 @@ class TopCellInfo {
std::string grad_operation_;
pipeline::ResourcePtr resource_{nullptr};
FuncGraphPtr fg_{nullptr};
ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
ad::AutoGradCellImplPtr auto_grad_cell_ptr_{nullptr};
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
// Record `register hook` or `remove hook` function has been called by sub cell
// The record range between the begin and end of top cell.

View File

@ -71,6 +71,7 @@ class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
void Sync() const;
void SetLazyBuild(bool enable) const;
bool IsFirstCell() const;
void WorkerJoin() { grad_executor_->WorkerJoin(); }
private:
PyNativeExecutor() = default;

View File

@ -36,6 +36,7 @@ namespace {
constexpr auto kBpropAttrName = "bprop";
constexpr auto kCellHookAttrName = "cell_hook";
constexpr auto kCellIDAttrName = "cell_id";
constexpr auto kCustomOpBpropAttrName = "custom_op_bprop";
std::map<std::string, std::string> kOpAttrNameReplaceMap = {
{"data_format", "format"},
};
@ -359,6 +360,26 @@ BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
}
}
BaseRef PrimitivePy::RunOpBpropFunction(const py::tuple &py_args) const {
if (backward_hook_fn_.size() > 1) {
MS_LOG(EXCEPTION) << "Multiple registration of bprop function is not supported.";
}
py::tuple grads;
SyncData(py_args);
py::tuple converted_args(py_args.size());
ConvertCTensorToPyTensor(py_args, &converted_args);
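  // Convert the C++ tensors in the arguments to Python tensors before calling the registered bprop functions.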
try {
MS_LOG(DEBUG) << "start execute custom op bprop";
for (const auto &elem : backward_hook_fn_) {
py::object grads_obj = elem.second(*converted_args);
grads = check_bprop_out(grads_obj, py_args, bprop_cls_name_);
}
return std::make_shared<PyObjectRef>(grads);
} catch (std::exception &bt) {
std::rethrow_exception(std::current_exception());
}
}
BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
// Get the gradient passed to current bprop cut op.
const auto args_size = py_args.size();
@ -426,6 +447,10 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
if (is_cell) {
return RunCellHookFunction(py_args);
}
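  // Primitives flagged with the custom_op_bprop attribute run their registered op bprop function directly.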
bool is_custom_op_bprop = this->HasAttr(kCustomOpBpropAttrName);
if (is_custom_op_bprop) {
return RunOpBpropFunction(py_args);
}
return RunVariableHookFunction(py_args);
}

View File

@ -59,6 +59,7 @@ class PrimitivePy : public Primitive {
void RemoveBackwardHookFn(const int &key);
BaseRef RunHookFunction(const VectorRef &args) const;
BaseRef RunCellBpropFunction(const py::tuple &py_args) const;
BaseRef RunOpBpropFunction(const py::tuple &py_args) const;
BaseRef RunCellHookFunction(const py::tuple &py_args) const;
BaseRef RunVariableHookFunction(const py::tuple &py_args) const;
BaseRef RunComputeFunction(const VectorRef &args) const override;

View File

@ -23,6 +23,7 @@ enum TaskType {
kUnknownTask = 0,
kOpRunTask,
kOpBuildTask,
kBpropTask,
kExitTask,
};
class AsyncTask {

View File

@ -38,6 +38,7 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"utils/*.cc"
"load_mindir/*.cc"
"mindapi/src/*.cc"
"expander/*.cc"
)
set(CORE_SRC_LIST ${CORE_SRC_LIST} ${CORE_OPS_LIST})

View File

@ -495,11 +495,11 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return {nullptr, nullptr, false};
const auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter != GetPrimitiveToEvalImplMap().end()) {
return iter->second;
}
return iter->second;
return {nullptr, nullptr, false};
}
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {

View File

@ -0,0 +1,6 @@
approvers:
- gaoxiong1
- ckey_dou
- dayschan
- anyrenwei
- zichun_ye

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/expander/emitter.h"
#include "expander/emitter.h"
#include <algorithm>
#include <functional>
@ -146,13 +146,13 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return Emit(prim::kZerosLike, {node});
}
NodePtr Emitter::Fill(const double &value, const ShapeVector &shape, TypeId data_type) const {
NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) const {
size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
std::vector<double> data(data_num, value);
return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeFloat64);
}
NodePtr Emitter::Fill(const int64_t &value, const ShapeVector &shape, TypeId data_type) const {
NodePtr Emitter::Fill(int64_t value, const ShapeVector &shape, TypeId data_type) const {
size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
std::vector<int64_t> data(data_num, value);
return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeInt64);

View File

@ -25,12 +25,12 @@
#include "ir/func_graph.h"
#include "ops/core_ops.h"
#include "include/common/utils/utils.h"
#include "common/graph_kernel/bprop/expander/node.h"
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/node.h"
#include "expander/infer.h"
namespace mindspore {
namespace expander {
class Emitter {
class MS_CORE_API Emitter {
public:
Emitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
: func_graph_(func_graph), infer_(infer), scope_(scope) {
@ -119,8 +119,8 @@ class Emitter {
NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false) const;
NodePtr ZerosLike(const NodePtr &node) const;
NodePtr Fill(const double &value, const ShapeVector &shape, TypeId data_type) const;
NodePtr Fill(const int64_t &value, const ShapeVector &shape, TypeId data_type) const;
NodePtr Fill(double value, const ShapeVector &shape, TypeId data_type) const;
NodePtr Fill(int64_t value, const ShapeVector &shape, TypeId data_type) const;
/// \brief Emit a value node
template <typename T>
@ -165,11 +165,11 @@ class Emitter {
};
using EmitterPtr = std::shared_ptr<Emitter>;
NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
NodePtr operator-(const NodePtr &node);
MS_CORE_API NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
MS_CORE_API NodePtr operator-(const NodePtr &node);
} // namespace expander
} // namespace mindspore
#endif // MINDSPORE_CORE_EXPANDER_EMITTER_H_

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/infer.h"
#include <algorithm>
#include "abstract/ops/primitive_infer_map.h"

View File

@ -17,12 +17,12 @@
#ifndef MINDSPORE_CORE_EXPANDER_INFER_H_
#define MINDSPORE_CORE_EXPANDER_INFER_H_
#include <memory>
#include "common/graph_kernel/bprop/expander/node.h"
#include "expander/node.h"
namespace mindspore {
namespace expander {
/// \brief ExpanderInfer is the adapter for the infer functions that are called by the emitter.
class ExpanderInfer {
class MS_CORE_API ExpanderInfer {
public:
/// \brief Infer shape and dtype for node
virtual void Infer(const NodePtr &node) = 0;
@ -33,7 +33,7 @@ class ExpanderInfer {
using ExpanderInferPtr = std::shared_ptr<ExpanderInfer>;
/// \brief CppInfer calls the InferShapeAndType interface of the frontend or backend infer map.
class CppInfer : public ExpanderInfer {
class MS_CORE_API CppInfer : public ExpanderInfer {
public:
void Infer(const NodePtr &node) override;
BaseShapePtr GetShape(const NodePtr &node) override;

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#include "common/graph_kernel/bprop/expander/node.h"
#include "expander/node.h"
#include <algorithm>
#include "common/graph_kernel/bprop/expander/emitter.h"
#include "common/graph_kernel/bprop/expander/infer.h"
#include "expander/emitter.h"
#include "expander/infer.h"
namespace mindspore {
namespace expander {

View File

@ -27,7 +27,7 @@ namespace expander {
class Emitter;
using DAttr = mindspore::HashMap<std::string, ValuePtr>;
class Node : public std::enable_shared_from_this<Node> {
class MS_CORE_API Node : public std::enable_shared_from_this<Node> {
public:
Node(const AnfNodePtr &node, const Emitter *emitter);
~Node() = default;

View File

@ -20,6 +20,7 @@ from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
from mindspore._c_expression import Tensor as CTensor
def ScalarAdd(x, y):
@ -103,8 +104,11 @@ def zeros_like_tensor(x):
def OnesLike(x):
"""Implement `oneslike`."""
x = x.asnumpy()
value = Tensor(np.ones(x.shape).astype(x.dtype))
if isinstance(x, (Tensor, CTensor)):
x = x.asnumpy()
value = Tensor(np.ones(x.shape).astype(x.dtype))
else:
value = Tensor(1.0)
return value

View File

@ -492,9 +492,6 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
self.group = group
self.add_prim_attr('data_format', "NCHW")
def __call__(self, x, w_size, dout):
raise NotImplementedError
def __infer__(self, x, w_size, dout):
w_size_v = w_size['value']
args = {'x': x['dtype'], 'dout': dout['dtype']}
@ -559,9 +556,6 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
self.group = group
self.add_prim_attr('data_format', "NCHW")
def __call__(self, x_size, w, dout):
raise NotImplementedError
def __infer__(self, x_size, w, dout):
args = {'w': w['dtype'], 'dout': dout['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
@ -664,9 +658,6 @@ class UniqueGrad(Primitive):
def __init__(self):
self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
def __call__(self, dy, x, scale, save_mean, save_inv_variance):
raise NotImplementedError
class BNTrainingReduceGrad(Primitive):
"""Gradients of FusedBatchNorm operation."""
@ -721,9 +712,6 @@ class NeighborExchangeV2Grad(PrimitiveWithInfer):
'dtype': dy['dtype'],
'value': None}
def __call__(self, tensor):
raise NotImplementedError
class GeLUGrad(Primitive):
"""Gradients of GeLU operation."""
@ -1389,9 +1377,6 @@ class LayerNormGrad(Primitive):
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
def __call__(self, x, dy, variance, mean, gamma):
raise NotImplementedError
class LayerNormGradGrad(Primitive):
"""
@ -1793,9 +1778,6 @@ class ReluGrad(Primitive):
"""Initialize ReluGrad"""
self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
def __call__(self, y_backprop, x):
raise NotImplementedError
class ReLU6Grad(Primitive):
"""Performs grad of ReLU6 operation."""
@ -1804,9 +1786,6 @@ class ReLU6Grad(Primitive):
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def __call__(self, y_grad, x):
raise NotImplementedError
class ReluGradV2(Primitive):
"""Performs grad of ReLUV2 operation."""
@ -1815,9 +1794,6 @@ class ReluGradV2(Primitive):
def __init__(self):
self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
def __call__(self, gradients, mask):
raise NotImplementedError
class EluGrad(Primitive):
"""Performs grad of Elu operation."""

View File

@ -19,9 +19,11 @@ import numpy as np
from mindspore.nn import Cell
from mindspore.common import Tensor, dtype, Parameter
from mindspore.ops import operations as P
from mindspore import jit
from mindspore import jit, context
import mindspore.ops.functional as F
context.set_context(mode=context.GRAPH_MODE)
@case_register.level0
@case_register.target_gpu

View File

@ -18,8 +18,11 @@ from mindspore.nn import Cell
from mindspore.common import Tensor, dtype
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import context
import numpy as np
context.set_context(mode=context.GRAPH_MODE)
@case_register.level0
@case_register.target_gpu

View File

@ -148,7 +148,7 @@ def test_grad_multiple_inputs_multiple_outputs_cell_pynative():
assert np.allclose(real_grad[1].asnumpy(), expect_grad2.asnumpy())
@pytest.mark.level0
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad_iteration_function_pynative():

View File

@ -16,12 +16,9 @@
import platform
import numpy as np
import pytest
import mindspore as ms
from mindspore import nn
from mindspore import ops
from mindspore import context, Tensor
from mindspore import jit
class Net(nn.Cell):
@ -32,8 +29,6 @@ class Net(nn.Cell):
self.addn = ops.AddN()
self.relu = nn.ReLU()
@jit(input_signature=(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
Tensor(shape=[2, 3, None, None], dtype=ms.float32)))
def construct(self, x, y):
x = self.addn((x, y))
x = self.log(x)
@ -58,8 +53,6 @@ class CmpNet(nn.Cell):
return x
@jit(input_signature=(Tensor(shape=[2, 3, 6, None], dtype=ms.float32),
Tensor(shape=[2, 3, None, None], dtype=ms.float32)))
def func(x, y):
x = ops.AddN()((x, y))
x = ops.Log()(x)

View File

@ -1,123 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <unordered_map>
#include "frontend/optimizer/ad/kpynative.h"
#include "common/common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/func_graph_cloner.h"
#include "utils/log_adapter.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/debug/anf_ir_utils.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace ad {
class TestKPynative : public UT::Common {
public:
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
protected:
AbstractBasePtr BuildArg() {
std::vector<int64_t> shp = {2, 2};
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
auto abstract = tensor->ToAbstract();
return abstract;
}
FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
auto g = std::make_shared<FuncGraph>();
auto x = g->add_parameter();
auto y = g->add_parameter();
x->set_abstract(BuildArg());
y->set_abstract(BuildArg());
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
c_node->set_abstract(BuildArg());
g->set_output(c_node);
return g;
}
// a = x * y
// b = stop_gradient(a)
// c = b * y
// return c
FuncGraphPtr BuildStopGradient(const std::string &testCase) {
auto g = std::make_shared<FuncGraph>();
auto x = g->add_parameter();
x->debug_info()->set_name("x");
auto y = g->add_parameter();
y->debug_info()->set_name("y");
x->set_abstract(BuildArg());
y->set_abstract(BuildArg());
auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
a_node->set_abstract(BuildArg());
auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
b_node->set_abstract(BuildArg());
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
c_node->set_abstract(BuildArg());
auto d_node =
g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
d_node->set_abstract(BuildArg());
g->set_output(d_node);
return g;
}
FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
auto input_params = primal_fg->parameters();
std::vector<ValuePtr> input_param_values;
std::for_each(input_params.begin(), input_params.end(),
[&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
auto node_list = TopoSort(primal_fg->output());
for (auto node : node_list) {
if (node->isa<CNode>()) {
auto c_node = node->cast<CNodePtr>();
auto out = c_node->abstract()->GetValueTrack();
ValuePtrList args;
for (size_t i = 1; i < c_node->inputs().size(); ++i) {
args.push_back(c_node->input(i)->abstract()->GetValueTrack());
}
GradPynativeOp(k_pynative_cell, c_node, args, out);
}
}
GradAttr grad_attr(true, false, false, false, true);
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, grad_attr, true);
return bprop_fg;
}
};
TEST_F(TestKPynative, test_simple_add) {
auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
resource->manager()->KeepRoots({primal_fg});
auto bprop_fg = BuildBpropFuncGraph(primal_fg);
resource->manager()->KeepRoots({bprop_fg});
}
TEST_F(TestKPynative, test_stop_gradient) {
auto primal_fg = BuildStopGradient("test_stop_gradient");
resource->manager()->KeepRoots({primal_fg});
auto bprop_fg = BuildBpropFuncGraph(primal_fg);
resource->manager()->KeepRoots({bprop_fg});
}
} // namespace ad
} // namespace mindspore

View File

@ -26,6 +26,26 @@ from .vm_interface import vm
# pylint: disable=unused-argument
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeroslike(self):
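    """Generate vm_impl function for ZerosLike."""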
def vm_impl(x):
x = x.asnumpy()
out = np.zeros_like(x)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.Log)
def vm_impl_log(self):
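    """Generate vm_impl function for Log."""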
def vm_impl(x):
x = x.asnumpy()
out = np.log(x)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.Add)
def vm_impl_tensor_add(self):
"""Generate vm_impl function for TensorAdd."""