forked from mindspore-Ecosystem/mindspore
Set unused inputs for bprop expanders and clear the device address of forward ops
This commit is contained in:
parent
3b513daac7
commit
bd645cd660
|
@ -40,6 +40,14 @@
|
|||
"mindspore/mindspore/core/ops/max_pool.cc" "zerodivcond"
|
||||
"mindspore/core/utils/log_adapter.cc" "stlIfStrFind"
|
||||
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.cc" "knownConditionTrueFalse"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_array_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_image_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_inner_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_math_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_nn_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_quant_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_scipy_ops.cc" "internalAstError"
|
||||
"mindspore/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/grad_ops/grad_sparse_ops.cc" "internalAstError"
|
||||
|
||||
# MindData
|
||||
"mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
|
||||
|
|
|
@ -298,19 +298,18 @@ AnfNodePtr VariableAdjoint::RealDout() {
|
|||
return accumulate_dout;
|
||||
}
|
||||
|
||||
AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
|
||||
AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values,
|
||||
const AbstractBasePtrList &abs_list)
|
||||
: tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
|
||||
tape_->debug_info()->set_name("grad_top");
|
||||
MS_LOG(DEBUG) << "Start AutoGradCellImpl, cell_inputs size: " << cell_inputs.size();
|
||||
for (size_t i = 0; i < cell_inputs.size(); ++i) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
|
||||
auto parameter = tape_->add_parameter();
|
||||
parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
|
||||
parameter->set_abstract(abs_list[i]);
|
||||
auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
|
||||
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
|
||||
const auto &clone_value = ShallowCopyTensorValue(input_param_values[i]);
|
||||
ClearDeviceAddress(clone_value);
|
||||
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, clone_value);
|
||||
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
|
||||
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
|
||||
}
|
||||
}
|
||||
|
@ -349,8 +348,8 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
|
|||
BuildBPropCutCNode(input_node, prim, &outputs);
|
||||
} else {
|
||||
#ifndef ENABLE_TEST
|
||||
mindspore::BuildBprop(input_node, &outputs, &users_);
|
||||
if (outputs.empty()) {
|
||||
auto ret = BpropExpander(&outputs, &users_).Run(input_node);
|
||||
if (!ret || outputs.empty()) {
|
||||
MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << grad_param->cnode->DebugString();
|
||||
BuildCustomBpropCNode(input_node, prim, &outputs);
|
||||
}
|
||||
|
@ -460,7 +459,6 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
|
|||
MS_EXCEPTION_IF_NULL(sens_out);
|
||||
MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
|
||||
last_node_ = output_node;
|
||||
ClearDeviceAddress(sens_out);
|
||||
sens_value_ = sens_out;
|
||||
}
|
||||
|
||||
|
@ -1077,7 +1075,7 @@ void AutoGradCellImpl::AddUser(const AnfNodePtr &node, const CNodePtr &user, siz
|
|||
if (users_.find(node) == users_.end()) {
|
||||
users_[node] = {};
|
||||
}
|
||||
(void)users_[node].emplace_back(make_pair(user, index));
|
||||
(void)users_[node].emplace_back(user, index);
|
||||
}
|
||||
|
||||
void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||
|
@ -1086,7 +1084,10 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
|
|||
}
|
||||
auto &old_node_users = users_[old_node];
|
||||
for (const auto &pair_node : old_node_users) {
|
||||
auto cnode = pair_node.first;
|
||||
auto cnode = pair_node.first.lock();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
size_t index = pair_node.second;
|
||||
if (index >= cnode->size()) {
|
||||
MS_LOG(EXCEPTION) << "exception for index:" << index << "greater than cnode size:" << cnode->size();
|
||||
|
@ -1181,7 +1182,8 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
|
|||
}
|
||||
|
||||
AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
||||
const std::vector<ValuePtr> &input_param_values) {
|
||||
const std::vector<ValuePtr> &input_param_values,
|
||||
const AbstractBasePtrList &abs_list) {
|
||||
auto abstract_are_set = std::all_of(cell_inputs.cbegin(), cell_inputs.cend(),
|
||||
[](const AnfNodePtr &node) { return node->abstract() != nullptr; });
|
||||
if (!abstract_are_set) {
|
||||
|
@ -1191,7 +1193,7 @@ AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
|||
MS_LOG(EXCEPTION) << "The size of cell inputs " << cell_inputs.size()
|
||||
<< " is not equal to the size of input parameter values " << input_param_values.size();
|
||||
}
|
||||
return std::make_shared<AutoGradCellImpl>(cell_inputs, input_param_values);
|
||||
return std::make_shared<AutoGradCellImpl>(cell_inputs, input_param_values, abs_list);
|
||||
}
|
||||
|
||||
FuncGraphPtr GradPynativeCellEnd(const AutoGradCellImplPtr &auto_grad_cell, const AnfNodePtrList &weights,
|
||||
|
|
|
@ -127,8 +127,9 @@ using VariableAdjointPtr = std::shared_ptr<VariableAdjoint>;
|
|||
|
||||
class AutoGradCellImpl {
|
||||
public:
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
|
||||
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values);
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<std::weak_ptr<CNode>, int>>>;
|
||||
AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values,
|
||||
const AbstractBasePtrList &abs_list);
|
||||
~AutoGradCellImpl() = default;
|
||||
// Reverse connect bprop of op
|
||||
bool KPynativeOp(const GradParamPtr &grad_param);
|
||||
|
@ -214,7 +215,8 @@ using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
|
|||
// Start building back propagate funcgraph for this cell.
|
||||
// cell_inputs: the input parameter list of this cell except the weights;
|
||||
AutoGradCellImplPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
|
||||
const std::vector<ValuePtr> &input_param_values);
|
||||
const std::vector<ValuePtr> &input_param_values,
|
||||
const AbstractBasePtrList &abs_list);
|
||||
|
||||
// Return the back propagate funcgraph for this cell.
|
||||
// weights: weights parameters used in this cell.
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "expander/infer.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
|
@ -27,191 +25,191 @@
|
|||
namespace mindspore {
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
class BpropExpander {
|
||||
public:
|
||||
BpropExpander(CNodePtrList *outputs, DoutUserType *dout_user, UserType *users)
|
||||
: outputs_(outputs), dout_user_(dout_user), users_(users) {}
|
||||
~BpropExpander() = default;
|
||||
|
||||
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
|
||||
NodePtrList nodes;
|
||||
nodes.reserve(cnode->size());
|
||||
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(nodes),
|
||||
[ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); });
|
||||
return nodes;
|
||||
}
|
||||
|
||||
bool Run(const CNodePtr &cnode) {
|
||||
auto infer = std::make_shared<CppInfer>();
|
||||
auto name = AnfUtils::GetCNodeName(cnode);
|
||||
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
auto inputs = ExtractInputs(cnode, ir_builder.get());
|
||||
auto &attrs = GetCNodePrimitive(cnode)->attrs();
|
||||
auto ret = ir_builder->Run(inputs, attrs, outputs_);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
PostProcess(inputs);
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (dump_result) {
|
||||
DumpResult(name, inputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PostProcess(const NodePtrList &inputs) const {
|
||||
std::set<AnfNodePtr> visited;
|
||||
// do not visit the inputs again.
|
||||
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
AnfNodePtr dout = inputs.back()->get();
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
// record parameter's and dout's user
|
||||
if (dout_user_ != nullptr) {
|
||||
if (inp == dout) {
|
||||
(void)dout_user_->emplace_back(node, i);
|
||||
}
|
||||
} else { // users_ != nullptr
|
||||
if (inp == dout || inp->isa<Parameter>()) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
|
||||
auto getitem = inp->cast<CNodePtr>();
|
||||
auto real_input = getitem->input(kIndex1);
|
||||
// record the dout's successor getitem's users
|
||||
if (users_ != nullptr && real_input == dout) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
} else if (real_input->isa<ValueNode>()) {
|
||||
// eliminate redundant getitem
|
||||
auto real_input_value = real_input->cast<ValueNodePtr>()->value();
|
||||
if (real_input_value->isa<ValueSequence>()) {
|
||||
auto item_idx = GetValue<int64_t>(getitem->input(kIndex2)->cast<ValueNodePtr>()->value());
|
||||
auto newnode = NewValueNode((*(real_input_value->cast<ValueSequencePtr>()))[item_idx]);
|
||||
newnode->set_abstract(newnode->value()->ToAbstract());
|
||||
node->set_input(i, newnode);
|
||||
continue; // do not visit the getitem again from this node
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inp->isa<CNode>() && visited.count(inp) == 0) {
|
||||
(void)visited.insert(inp);
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DumpResult(const std::string &name, const NodePtrList &inputs) const {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::map<AnfNodePtr, AnfNodePtr> node_map;
|
||||
CNodePtrList newcnodes;
|
||||
for (auto &inp : inputs) {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->get()->abstract());
|
||||
node_map[inp->get()] = p;
|
||||
}
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
if (node_map.count(node)) {
|
||||
continue;
|
||||
}
|
||||
auto new_node = fg->NewCNode(node->inputs());
|
||||
new_node->CloneCNodeInfo(node);
|
||||
new_node->set_fullname_with_scope(node->fullname_with_scope());
|
||||
node_map[node] = new_node;
|
||||
newcnodes.push_back(new_node);
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cnode : newcnodes) {
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
if (node_map.count(cnode->input(i)) != 0) {
|
||||
cnode->set_input(i, node_map[cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outputs_->size() == 1) {
|
||||
fg->set_output(node_map[(*outputs_)[0]]);
|
||||
} else {
|
||||
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abs;
|
||||
(void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
|
||||
[&node_map, &abs](const CNodePtr &node) {
|
||||
abs.push_back(node->abstract());
|
||||
return node_map[node];
|
||||
});
|
||||
auto mt = fg->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg->set_output(mt);
|
||||
}
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);
|
||||
|
||||
if (dout_user_ != nullptr) {
|
||||
for (auto &iter : *dout_user_) {
|
||||
MS_LOG(INFO) << "Dout User: " << iter.first->fullname_with_scope() << " index: " << iter.second;
|
||||
}
|
||||
} else { // users_ != nullptr
|
||||
for (auto &uiter : *users_) {
|
||||
for (auto &iter : uiter.second) {
|
||||
MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << iter.first->fullname_with_scope()
|
||||
<< " index: " << iter.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CNodePtrList *outputs_;
|
||||
DoutUserType *dout_user_;
|
||||
UserType *users_;
|
||||
};
|
||||
} // namespace bprop
|
||||
} // namespace expander
|
||||
|
||||
// deprecated
|
||||
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user) {
|
||||
bool BpropExpander::Run(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(dout_user);
|
||||
expander::bprop::BpropExpander e(outputs, dout_user, nullptr);
|
||||
(void)e.Run(cnode);
|
||||
}
|
||||
|
||||
bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users) {
|
||||
MS_EXCEPTION_IF_NULL(outputs_);
|
||||
MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(users);
|
||||
bool ret = true;
|
||||
outputs_->clear();
|
||||
try {
|
||||
expander::bprop::BpropExpander e(outputs, nullptr, users);
|
||||
ret = e.Run(cnode);
|
||||
ret = RunBprop(cnode);
|
||||
} catch (const std::exception &e) {
|
||||
auto node_name = AnfUtils::GetCNodeName(cnode);
|
||||
MS_LOG(DEBUG) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
|
||||
MS_LOG(INFO) << "Python bprop will be used for \"" << node_name << "\"";
|
||||
outputs->clear();
|
||||
outputs_->clear();
|
||||
ret = false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
|
||||
const std::vector<size_t> &BpropExpander::GetUnusedInputs(const CNodePtr &cnode) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto name = AnfUtils::GetCNodeName(cnode);
|
||||
auto handle = GetBpropHandle(name);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
|
||||
static std::vector<size_t> empty{};
|
||||
return empty;
|
||||
}
|
||||
return handle->unused_inputs;
|
||||
}
|
||||
|
||||
/// \brief Wrap the inputs of `cnode` (everything after input 0) as expander Nodes.
/// \param cnode Node whose inputs are wrapped; input 0 is skipped.
/// \param ir_builder Builder handed to every created Node.
/// \return One Node per input of `cnode` from index 1 onwards, in order.
NodePtrList BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
  NodePtrList wrapped;
  wrapped.reserve(cnode->size());
  // Start one past the beginning: the first input is not wrapped.
  for (auto it = cnode->inputs().cbegin() + 1; it != cnode->inputs().cend(); ++it) {
    const AnfNodePtr &raw_input = *it;
    wrapped.push_back(std::make_shared<Node>(raw_input, ir_builder));
  }
  return wrapped;
}
|
||||
|
||||
bool BpropExpander::RunBprop(const CNodePtr &cnode) {
|
||||
auto infer = std::make_shared<CppInfer>();
|
||||
auto name = AnfUtils::GetCNodeName(cnode);
|
||||
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
auto inputs = ExtractInputs(cnode, ir_builder.get());
|
||||
auto &attrs = GetCNodePrimitive(cnode)->attrs();
|
||||
auto handle = GetBpropHandle(name);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
|
||||
return false;
|
||||
}
|
||||
auto output_nodes = ir_builder->Run(inputs, attrs, *handle);
|
||||
if (output_nodes.empty()) {
|
||||
MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
|
||||
return false;
|
||||
}
|
||||
outputs_->reserve(output_nodes.size());
|
||||
(void)std::transform(output_nodes.cbegin(), output_nodes.cend(), std::back_inserter(*outputs_),
|
||||
[](const NodePtr &node) {
|
||||
auto cnode = node->get<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return cnode;
|
||||
});
|
||||
PostProcess(inputs);
|
||||
DumpResult(name, inputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
void BpropExpander::PostProcess(const NodePtrList &inputs) const {
|
||||
std::set<AnfNodePtr> visited;
|
||||
// do not visit the inputs again.
|
||||
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
AnfNodePtr dout = inputs.back()->get();
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
// record parameter's and dout's user
|
||||
if (users_ != nullptr) {
|
||||
if (inp == dout || inp->isa<Parameter>()) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
|
||||
auto getitem = inp->cast<CNodePtr>();
|
||||
auto real_input = getitem->input(kIndex1);
|
||||
// record the dout's successor getitem's users
|
||||
if (users_ != nullptr && real_input == dout) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
} else if (real_input->isa<ValueNode>()) {
|
||||
// eliminate redundant getitem
|
||||
auto real_input_value = real_input->cast<ValueNodePtr>()->value();
|
||||
if (real_input_value->isa<ValueSequence>()) {
|
||||
auto item_idx = GetValue<int64_t>(getitem->input(kIndex2)->cast<ValueNodePtr>()->value());
|
||||
auto newnode = NewValueNode((*(real_input_value->cast<ValueSequencePtr>()))[item_idx]);
|
||||
newnode->set_abstract(newnode->value()->ToAbstract());
|
||||
node->set_input(i, newnode);
|
||||
continue; // do not visit the getitem again from this node
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inp->isa<CNode>() && visited.count(inp) == 0) {
|
||||
(void)visited.insert(inp);
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BpropExpander::DumpResult(const std::string &name, const NodePtrList &inputs) const {
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (!dump_result) {
|
||||
return;
|
||||
}
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::map<AnfNodePtr, AnfNodePtr> node_map;
|
||||
CNodePtrList newcnodes;
|
||||
for (auto &inp : inputs) {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->get()->abstract());
|
||||
node_map[inp->get()] = p;
|
||||
}
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
if (node_map.count(node)) {
|
||||
continue;
|
||||
}
|
||||
auto new_node = fg->NewCNode(node->inputs());
|
||||
new_node->CloneCNodeInfo(node);
|
||||
new_node->set_fullname_with_scope(node->fullname_with_scope());
|
||||
node_map[node] = new_node;
|
||||
newcnodes.push_back(new_node);
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cnode : newcnodes) {
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
if (node_map.count(cnode->input(i)) != 0) {
|
||||
cnode->set_input(i, node_map[cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outputs_->size() == 1) {
|
||||
fg->set_output(node_map[(*outputs_)[0]]);
|
||||
} else {
|
||||
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abs;
|
||||
(void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
|
||||
[&node_map, &abs](const CNodePtr &node) {
|
||||
abs.push_back(node->abstract());
|
||||
return node_map[node];
|
||||
});
|
||||
auto mt = fg->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg->set_output(mt);
|
||||
}
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);
|
||||
|
||||
if (users_ != nullptr) {
|
||||
for (auto &uiter : *users_) {
|
||||
for (auto &iter : uiter.second) {
|
||||
auto user = iter.first.lock();
|
||||
if (user == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << user->fullname_with_scope()
|
||||
<< " index: " << iter.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace bprop
|
||||
} // namespace expander
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,16 +19,39 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
using DoutUserType = std::vector<std::pair<CNodePtr, int>>;
|
||||
// deprecated
|
||||
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<std::weak_ptr<CNode>, int>>>;
|
||||
class BpropExpander {
|
||||
public:
|
||||
BpropExpander() {}
|
||||
BpropExpander(CNodePtrList *outputs, UserType *users) : outputs_(outputs), users_(users) {}
|
||||
~BpropExpander() = default;
|
||||
bool Run(const CNodePtr &cnode);
|
||||
const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const;
|
||||
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
|
||||
bool BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
|
||||
private:
|
||||
bool RunBprop(const CNodePtr &cnode);
|
||||
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
|
||||
const BpropHandle *GetBpropHandle(const std::string &name) const {
|
||||
return BpropIRBuilderFactory::Instance().GetBuilder(name);
|
||||
}
|
||||
void PostProcess(const NodePtrList &inputs) const;
|
||||
void DumpResult(const std::string &name, const NodePtrList &inputs) const;
|
||||
|
||||
private:
|
||||
CNodePtrList *outputs_{nullptr};
|
||||
UserType *users_{nullptr};
|
||||
};
|
||||
} // namespace bprop
|
||||
} // namespace expander
|
||||
|
||||
using expander::bprop::BpropExpander;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_BPROP_EXPANDER_BPROP_H_
|
||||
|
|
|
@ -26,27 +26,10 @@
|
|||
namespace mindspore {
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
namespace {
|
||||
constexpr size_t kMaxDims = 8;
|
||||
} // namespace
|
||||
|
||||
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!BpropIRBuilderFactory::Instance().HasOp(name())) {
|
||||
return false;
|
||||
}
|
||||
NodePtrList BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, const BpropHandle &handle) {
|
||||
inputs_ptr_ = &inputs;
|
||||
attrs_ptr_ = &attrs;
|
||||
auto func = BpropIRBuilderFactory::Instance().GetBuilder(name());
|
||||
auto output_nodes = func(this);
|
||||
outputs->reserve(output_nodes.size());
|
||||
(void)std::transform(output_nodes.cbegin(), output_nodes.cend(), std::back_inserter(*outputs),
|
||||
[](const NodePtr &node) {
|
||||
auto cnode = node->get<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return cnode;
|
||||
});
|
||||
return true;
|
||||
return handle.func(this);
|
||||
}
|
||||
|
||||
ValuePtr BpropIRBuilder::GetAttr(const std::string &attr) const {
|
||||
|
@ -86,6 +69,7 @@ std::string BpropIRBuilder::GetTargetFromContext() const {
|
|||
NodePtr BpropIRBuilder::TensorGetItem(const NodePtr &node, int64_t idx) const {
|
||||
auto data_shape = GetShape(node);
|
||||
auto n = data_shape.size();
|
||||
constexpr const size_t kMaxDims = 8;
|
||||
if (n < 1 || n > kMaxDims) {
|
||||
MS_EXCEPTION(ValueError) << "Expect Tensor to have dimension between 1 and " << kMaxDims << ", but got: " << n;
|
||||
}
|
||||
|
|
|
@ -29,13 +29,21 @@
|
|||
namespace mindspore {
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
class BpropIRBuilder;
|
||||
|
||||
using BpropIRBuilderFunc = std::function<NodePtrList(const BpropIRBuilder *)>;
|
||||
// Per-operator entry stored in BpropIRBuilderFactory's name -> handle table.
struct BpropHandle {
|
||||
// Callable that produces the bprop output nodes; it receives the active BpropIRBuilder.
BpropIRBuilderFunc func;
|
||||
// Index list supplied through RegUnusedInputs()/SetUnusedInputs(); empty unless registered.
// NOTE(review): presumably the positions of bprop inputs the builder never reads — confirm the indexing convention.
std::vector<size_t> unused_inputs;
|
||||
};
|
||||
|
||||
class BpropIRBuilder : public Emitter {
|
||||
public:
|
||||
BpropIRBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
|
||||
: Emitter(func_graph, infer, std::make_shared<Scope>(std::string("Bprop/grad") + name)), name_(name) {}
|
||||
|
||||
/// \brief Run irbuilder to generate a graph
|
||||
bool Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs);
|
||||
NodePtrList Run(const NodePtrList &inputs, const DAttr &attrs, const BpropHandle &handle);
|
||||
|
||||
ValuePtr GetAttr(const std::string &attr) const;
|
||||
template <typename S>
|
||||
|
@ -72,23 +80,28 @@ class BpropIRBuilder : public Emitter {
|
|||
const NodePtrList *inputs_ptr_{nullptr};
|
||||
const DAttr *attrs_ptr_{nullptr};
|
||||
};
|
||||
using BpropIRBuilderPtr = std::shared_ptr<BpropIRBuilder>;
|
||||
|
||||
using BpropIRBuilderFunc = std::function<NodePtrList(const BpropIRBuilder *)>;
|
||||
class BpropIRBuilderFactory {
|
||||
public:
|
||||
static BpropIRBuilderFactory &Instance() {
|
||||
static BpropIRBuilderFactory instance{};
|
||||
return instance;
|
||||
}
|
||||
const BpropIRBuilderFunc &GetBuilder(const std::string &name) { return builders()[name]; }
|
||||
void RegBuilder(const std::string &name, const BpropIRBuilderFunc &func) { builders()[name] = func; }
|
||||
bool HasOp(const std::string &name) const { return builders().count(name) != 0; }
|
||||
|
||||
/// \brief Find the handle registered under `name`.
/// \return Pointer to the stored handle, or nullptr when `name` has no entry. Lookup never inserts.
const BpropHandle *GetBuilder(const std::string &name) {
  auto &handles = registry();
  auto found = handles.find(name);
  if (found == handles.end()) {
    return nullptr;
  }
  return &(found->second);
}
|
||||
|
||||
void RegBuilder(const std::string &name, const BpropIRBuilderFunc &func) { registry()[name].func = func; }
|
||||
void RegUnusedInputs(const std::string &name, const std::vector<size_t> &unused) {
|
||||
registry()[name].unused_inputs = unused;
|
||||
}
|
||||
|
||||
private:
|
||||
HashMap<std::string, BpropIRBuilderFunc> &builders() const {
|
||||
static HashMap<std::string, BpropIRBuilderFunc> builder_map;
|
||||
return builder_map;
|
||||
/// \brief Access the name -> BpropHandle table shared by all registrations.
/// \return Reference to a function-local static map, created on first use.
HashMap<std::string, BpropHandle> &registry() const {
  static HashMap<std::string, BpropHandle> reg;
  return reg;
}
|
||||
};
|
||||
|
||||
|
@ -100,6 +113,10 @@ class BpropIRBuilderRegHelper {
|
|||
BpropIRBuilderFactory::Instance().RegBuilder(name_, func);
|
||||
return *this;
|
||||
}
|
||||
const BpropIRBuilderRegHelper &SetUnusedInputs(const std::initializer_list<size_t> &unused_inputs) const {
|
||||
BpropIRBuilderFactory::Instance().RegUnusedInputs(name_, unused_inputs);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
@ -109,6 +126,7 @@ class BpropIRBuilderRegHelper {
|
|||
#define BPROP_EXPANDER_UNIQUE_NAME(prefix, cnt) BPROP_EXPANDER_JOIN(prefix, cnt)
|
||||
#define REG_BPROP_BUILDER(name) \
|
||||
const BpropIRBuilderRegHelper BPROP_EXPANDER_UNIQUE_NAME(g_bprop, __COUNTER__) = BpropIRBuilderRegHelper(name)
|
||||
#define BODYFUNC(v) [](const BpropIRBuilder *v) -> NodePtrList
|
||||
} // namespace bprop
|
||||
} // namespace expander
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -332,34 +332,6 @@ NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type) {
|
|||
}
|
||||
}
|
||||
|
||||
NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto axis = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto orig_indices = indices;
|
||||
auto x_shp = ib->GetShape(x);
|
||||
auto out_shp = ib->GetShape(dout);
|
||||
auto ind_shp = ib->GetShape(indices);
|
||||
auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
|
||||
if (out_shp.empty()) {
|
||||
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
|
||||
}
|
||||
if (ind_shp.empty()) {
|
||||
indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
|
||||
ind_shp = ib->GetShape(indices);
|
||||
auto out_shp1 = RegenerateOutputShape(x_shp, ind_shp, axis_v);
|
||||
dout = ib->Reshape(dout, out_shp1);
|
||||
}
|
||||
out_shp = ib->GetShape(dout);
|
||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
|
||||
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
|
||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
|
||||
auto params_grad = ib->Transpose(tmp, perm_2);
|
||||
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
|
||||
}
|
||||
|
||||
std::vector<int64_t> GenerateInverseIndex(const std::vector<int64_t> &x_shp, int64_t axis_v) {
|
||||
int64_t x_rank = static_cast<int64_t>(x_shp.size());
|
||||
auto index = Range(x_rank);
|
||||
|
|
|
@ -24,6 +24,17 @@
|
|||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
constexpr size_t i0 = 0;
|
||||
constexpr size_t i1 = 1;
|
||||
constexpr size_t i2 = 2;
|
||||
constexpr size_t i3 = 3;
|
||||
constexpr size_t i4 = 4;
|
||||
constexpr size_t i5 = 5;
|
||||
constexpr size_t i6 = 6;
|
||||
constexpr size_t i7 = 7;
|
||||
constexpr size_t i8 = 8;
|
||||
constexpr size_t i9 = 9;
|
||||
constexpr size_t i10 = 10;
|
||||
inline const auto pi = std::acos(-1.0);
|
||||
inline const auto log_2 = std::log(2.0);
|
||||
inline const auto log_pi = std::log(pi);
|
||||
|
@ -62,7 +73,6 @@ std::vector<int64_t> GetIntList(const ValuePtr &value);
|
|||
std::vector<int64_t> GetIntList(const NodePtr &node);
|
||||
|
||||
NodePtr GetEps(const BpropIRBuilder *ib, const TypePtr &type);
|
||||
NodePtrList BinopGatherCommon(const BpropIRBuilder *ib);
|
||||
std::vector<int64_t> GenerateInverseIndex(const std::vector<int64_t> &x_shp, int64_t axis_v);
|
||||
std::vector<int64_t> GenerateShapeIndex(const std::vector<int64_t> &out_shp, const std::vector<int64_t> &ind_shp,
|
||||
int64_t axis_v);
|
||||
|
|
|
@ -69,7 +69,7 @@ NodePtrList UnsortedSegmentMinOrMaxGrad(const BpropIRBuilder *ib, const NodePtr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
REG_BPROP_BUILDER("GatherD").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GatherD").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dim = ib->GetInput(kIndex1);
|
||||
auto index = ib->GetInput(kIndex2);
|
||||
|
@ -78,7 +78,7 @@ REG_BPROP_BUILDER("GatherD").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, ib->ZerosLike(dim), ib->ZerosLike(index)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherDGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GatherDGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
|
||||
auto x_shp = GetValue<ShapeVector>(ib->GetAttr("shape"));
|
||||
auto index = ib->GetInput(kIndex0);
|
||||
|
@ -112,7 +112,7 @@ REG_BPROP_BUILDER("GatherDGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {ib->ZerosLike(index), dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherDGradV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GatherDGradV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto dim = GetValue<int64_t>(ib->GetAttr("dim"));
|
||||
auto index = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
|
@ -150,7 +150,7 @@ REG_BPROP_BUILDER("GatherDGradV2").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {ib->ZerosLike(index), dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseGatherV2").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto axis = ib->GetInput(kIndex2);
|
||||
|
@ -185,7 +185,7 @@ REG_BPROP_BUILDER("SparseGatherV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {params_grad, ib->ZerosLike(indices), ib->ZerosLike(axis)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Sort").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto axis = GetValue<int64_t>(ib->GetAttr("axis"));
|
||||
auto descending = GetValue<bool>(ib->GetAttr("descending"));
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
|
@ -243,19 +243,19 @@ REG_BPROP_BUILDER("Sort").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
// Bprop builder for "Identity": the single result is dout (input index 2), returned unchanged.
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body only reads index 2.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Identity").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Identity").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
return {dout};
|
||||
});
|
||||
|
||||
// Bprop builder for "Range": returns ZerosLike for each of start, limit and delta (inputs 0..2).
// SetUnusedInputs({i3, i4}) declares inputs 3 and 4 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Range").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Range").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto start = ib->GetInput(kIndex0);
|
||||
auto limit = ib->GetInput(kIndex1);
|
||||
auto delta = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pack").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Pack").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -263,7 +263,7 @@ REG_BPROP_BUILDER("Pack").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {ret};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Stack").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Stack").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -271,20 +271,20 @@ REG_BPROP_BUILDER("Stack").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {ret};
|
||||
});
|
||||
|
||||
// Bprop builder for "ReverseV2": emits ReverseV2 on dout (input index 2), reusing this op's "axis" attribute.
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ReverseV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("ReverseV2", {dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
// Bprop builder for "Unstack": emits Stack on dout (input index 2) along this op's "axis" attribute.
// SetUnusedInputs({i0}) declares input 0 as unused.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Unstack").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Unstack").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
// NOTE(review): the value read from index 1 is overwritten two statements later without being used,
// so index 1 could probably be listed as unused as well - confirm.
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
out = ib->Emit("Stack", {dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
return {out};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("StridedSlice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("StridedSlice").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto begin = ib->GetInput(kIndex1);
|
||||
auto end = ib->GetInput(kIndex2);
|
||||
|
@ -303,7 +303,7 @@ REG_BPROP_BUILDER("StridedSlice").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, dbegin, dend, dstrides};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("StridedSliceGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("StridedSliceGrad").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto shapex = ib->GetInput(kIndex1);
|
||||
auto begin = ib->GetInput(kIndex2);
|
||||
auto end = ib->GetInput(kIndex3);
|
||||
|
@ -318,14 +318,14 @@ REG_BPROP_BUILDER("StridedSliceGrad").SetBody([](const BpropIRBuilder *ib) -> No
|
|||
ib->ZerosLike(shapex), ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Eye": returns ZerosLike(n), ZerosLike(m) and the third input `t` itself.
// SetUnusedInputs({i3, i4}) declares inputs 3 and 4 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Eye").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Eye").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto n = ib->GetInput(kIndex0);
|
||||
auto m = ib->GetInput(kIndex1);
|
||||
auto t = ib->GetInput(kIndex2);
|
||||
// NOTE(review): `t` is returned as-is instead of ZerosLike(t), unlike the other two entries - confirm this is intended.
return {ib->ZerosLike(n), ib->ZerosLike(m), t};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Select").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Select").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto cond = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
auto y = ib->GetInput(kIndex2);
|
||||
|
@ -333,17 +333,17 @@ REG_BPROP_BUILDER("Select").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {ib->ZerosLike(cond), ib->Select(cond, dout, ib->ZerosLike(x)), ib->Select(cond, ib->ZerosLike(y), dout)};
|
||||
});
|
||||
|
||||
// Bprop builder for "OnesLike": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("OnesLike").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("OnesLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "ZerosLike": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ZerosLike").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ZerosLike").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
|
@ -356,7 +356,7 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetBody([](const BpropIRBuilder *ib)
|
|||
return {out};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GatherNd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -364,26 +364,26 @@ REG_BPROP_BUILDER("GatherNd").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterNd": reads indices (index 0), shape (index 2) and dout (index 4).
// Returns ZerosLike(indices), GatherNd(dout, indices) and ZerosLike(shape), in that order.
// SetUnusedInputs({i1, i3}) declares inputs 1 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterNd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterNd").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices}), ib->ZerosLike(shape)};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterNdUpdate": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and GatherNd(dout, indices), in that order.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterNdUpdate").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterNdUpdate").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {dout, ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices})};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterNonAliasingAdd": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and GatherNd(dout, indices), in that order.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterNonAliasingAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterNonAliasingAdd").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {dout, ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorScatterUpdate").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto update = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -392,14 +392,14 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {x_grad, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
// Bprop builder for "Flatten": reshapes dout (index 2) to the shape of x (index 0).
// Only the shape of x is queried here (via GetShape).
// SetUnusedInputs({i1}) declares input 1 as unused; the body never reads it.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Flatten").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Reshape(dout, ib->GetShape(x));
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Reshape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto shp = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -407,12 +407,12 @@ REG_BPROP_BUILDER("Reshape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {ib->Reshape(dout, shapex), ib->ZerosLike(shp)};
|
||||
});
|
||||
|
||||
// Bprop builder for "NonZero": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("NonZero").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("NonZero").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchMatMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchMatMul").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto ta = GetValue<bool>(ib->GetAttr("transpose_a"));
|
||||
auto tb = GetValue<bool>(ib->GetAttr("transpose_b"));
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
|
@ -436,41 +436,41 @@ REG_BPROP_BUILDER("BatchMatMul").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return BinopGradCommonWithShift(ib, x, w, dx, dw, 2);
|
||||
});
|
||||
|
||||
// Bprop builder for "Argmax": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Argmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Argmax").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Argmin": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Argmin").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Argmin").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Diag": the single result is DiagPart emitted on dout (input index 2).
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Diag").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Diag").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
return {ib->Emit("DiagPart", {dout})};
|
||||
});
|
||||
|
||||
// Bprop builder for "DiagPart": the single result is Diag emitted on dout (input index 2).
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("DiagPart").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DiagPart").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
return {ib->Emit("Diag", {dout})};
|
||||
});
|
||||
|
||||
// Bprop builder for "SpaceToBatch": emits BatchToSpace on dout (input index 2).
// The emitted op reuses this op's "block_size" attribute and receives this op's "paddings" attribute as "crops".
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("SpaceToBatch").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SpaceToBatch").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx =
|
||||
ib->Emit("BatchToSpace", {dout}, {{"block_size", ib->GetAttr("block_size")}, {"crops", ib->GetAttr("paddings")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
// Bprop builder for "BatchToSpace": emits SpaceToBatch on dout (input index 2).
// The emitted op reuses this op's "block_size" attribute and receives this op's "crops" attribute as "paddings".
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("BatchToSpace").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchToSpace").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx =
|
||||
ib->Emit("SpaceToBatch", {dout}, {{"block_size", ib->GetAttr("block_size")}, {"paddings", ib->GetAttr("crops")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReverseSequence").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ReverseSequence").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto seq_lengths = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dx = ib->Emit("ReverseSequence", {dout, seq_lengths},
|
||||
|
@ -478,14 +478,14 @@ REG_BPROP_BUILDER("ReverseSequence").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {dx, ib->ZerosLike(seq_lengths)};
|
||||
});
|
||||
|
||||
// Bprop builder for "TensorScatterAdd": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and GatherNd(dout, indices), in that order.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("TensorScatterAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterAdd").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto update_grad = ib->Emit("GatherNd", {dout, indices});
|
||||
return {dout, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Concat").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Concat").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto axis = ib->GetAttr<int64_t>(kAttrAxis);
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -537,14 +537,14 @@ REG_BPROP_BUILDER("Concat").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {ib->MakeTuple(res)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Mvlgamma": emits MvlgammaGrad on (dout, x), forwarding this op's "p" attribute.
// x is input index 0 and dout is input index 2.
// SetUnusedInputs({i1}) declares input 1 as unused; the body never reads it.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Mvlgamma").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Mvlgamma").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("MvlgammaGrad", {dout, x}, {{"p", ib->GetAttr("p")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorScatterDiv").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterDiv").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto update = ib->GetInput(kIndex2);
|
||||
|
@ -558,14 +558,14 @@ REG_BPROP_BUILDER("TensorScatterDiv").SetBody([](const BpropIRBuilder *ib) -> No
|
|||
return {in_grad, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
// Bprop builder for "TensorScatterSub": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and Neg(GatherNd(dout, indices)), in that order.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("TensorScatterSub").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterSub").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto update_grad = ib->Emit("Neg", {ib->Emit("GatherNd", {dout, indices})});
|
||||
return {dout, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorScatterMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterMul").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto update = ib->GetInput(kIndex2);
|
||||
|
@ -600,7 +600,7 @@ NodePtrList TensorScatterPossibleReplacement(const BpropIRBuilder *ib) {
|
|||
// "TensorScatterMax" and "TensorScatterMin" register the same body function,
// TensorScatterPossibleReplacement. No SetUnusedInputs call is made for either.
REG_BPROP_BUILDER("TensorScatterMax").SetBody(TensorScatterPossibleReplacement);
|
||||
REG_BPROP_BUILDER("TensorScatterMin").SetBody(TensorScatterPossibleReplacement);
|
||||
|
||||
REG_BPROP_BUILDER("IndexFill").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("IndexFill").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dim = ib->GetInput(kIndex1);
|
||||
auto indices = ib->GetInput(kIndex2);
|
||||
|
@ -618,7 +618,7 @@ REG_BPROP_BUILDER("IndexFill").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {x_grad, ib->ZerosLike(dim), ib->ZerosLike(indices), value_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UnsortedSegmentSum").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UnsortedSegmentSum").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto num_segments = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -626,7 +626,7 @@ REG_BPROP_BUILDER("UnsortedSegmentSum").SetBody([](const BpropIRBuilder *ib) ->
|
|||
ib->ZerosLike(num_segments)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UnsortedSegmentMin").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UnsortedSegmentMin").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto num_segments = ib->GetInput(kIndex2);
|
||||
|
@ -635,7 +635,7 @@ REG_BPROP_BUILDER("UnsortedSegmentMin").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return UnsortedSegmentMinOrMaxGrad(ib, x, segment_ids, num_segments, out, dout);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UnsortedSegmentMax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UnsortedSegmentMax").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto num_segments = ib->GetInput(kIndex2);
|
||||
|
@ -644,7 +644,7 @@ REG_BPROP_BUILDER("UnsortedSegmentMax").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return UnsortedSegmentMinOrMaxGrad(ib, x, segment_ids, num_segments, out, dout);
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UnsortedSegmentProd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UnsortedSegmentProd").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto num_segments = ib->GetInput(kIndex2);
|
||||
|
@ -698,21 +698,21 @@ REG_BPROP_BUILDER("UnsortedSegmentProd").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dx, ib->ZerosLike(segment_ids), ib->ZerosLike(num_segments)};
|
||||
});
|
||||
|
||||
// Bprop builder for "SpaceToBatchND": emits BatchToSpaceND on dout (input index 2).
// The emitted op reuses this op's "block_shape" attribute and receives this op's "paddings" attribute as "crops".
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("SpaceToBatchND").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SpaceToBatchND").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("BatchToSpaceND", {dout},
|
||||
{{"block_shape", ib->GetAttr("block_shape")}, {"crops", ib->GetAttr("paddings")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
// Bprop builder for "BatchToSpaceND": emits SpaceToBatchND on dout (input index 2).
// The emitted op reuses this op's "block_shape" attribute and receives this op's "crops" attribute as "paddings".
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("BatchToSpaceND").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchToSpaceND").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("SpaceToBatchND", {dout},
|
||||
{{"block_shape", ib->GetAttr("block_shape")}, {"paddings", ib->GetAttr("crops")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto broadcast_shape = ib->GetAttr<ShapeVector>("shape");
|
||||
|
@ -731,7 +731,7 @@ REG_BPROP_BUILDER("BroadcastTo").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SpaceToDepth").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SpaceToDepth").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
return {ib->Emit("DepthToSpace", {dout},
|
||||
{{"block_size", ib->GetAttr("block_size")},
|
||||
|
@ -739,7 +739,7 @@ REG_BPROP_BUILDER("SpaceToDepth").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
{"format", ib->GetAttr("format")}})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DepthToSpace").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DepthToSpace").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
return {ib->Emit("SpaceToDepth", {dout},
|
||||
{{"block_size", ib->GetAttr("block_size")},
|
||||
|
@ -747,31 +747,31 @@ REG_BPROP_BUILDER("DepthToSpace").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
{"format", ib->GetAttr("format")}})};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterMax": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and Gather(dout, indices, 0), where the
// third Gather operand is an int64 tensor holding 0.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterMax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterMax").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterMin": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and Gather(dout, indices, 0), where the
// third Gather operand is an int64 tensor holding 0.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterMin").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterMin").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})};
|
||||
});
|
||||
|
||||
// Bprop builder for "ScatterUpdate": reads indices (index 1) and dout (index 4).
// Returns dout unchanged, ZerosLike(indices) and Gather(dout, indices, 0), where the
// third Gather operand is an int64 tensor holding 0.
// SetUnusedInputs({i0, i2, i3}) declares inputs 0, 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("ScatterUpdate").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterUpdate").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
return {dout, ib->ZerosLike(indices), ib->Emit("Gather", {dout, indices, ib->Tensor(0, kInt64)})};
|
||||
});
|
||||
|
||||
// Bprop builder for "Fills": returns ZerosLike for x (index 0) and for value (index 1).
// SetUnusedInputs({i2, i3}) declares inputs 2 and 3 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Fills").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Fills").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto value = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Cast").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Cast").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto t = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -780,7 +780,7 @@ REG_BPROP_BUILDER("Cast").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx, ib->ZerosLike(t)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExpandDims").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ExpandDims").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto axis = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -788,14 +788,14 @@ REG_BPROP_BUILDER("ExpandDims").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {ib->Reshape(dout, shapex), ib->ZerosLike(axis)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Squeeze": reshapes dout (index 2) to the shape of x (index 0).
// Only the shape of x is queried here (via GetShape).
// SetUnusedInputs({i1}) declares input 1 as unused; the body never reads it.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Squeeze").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shapex = ib->GetShape(x);
|
||||
return {ib->Reshape(dout, shapex)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Padding").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Padding").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shp = ib->GetShape(x);
|
||||
|
@ -805,7 +805,7 @@ REG_BPROP_BUILDER("Padding").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Transpose").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Transpose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto perm = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto tmp_perm = GetIntList(perm);
|
||||
|
@ -816,7 +816,7 @@ REG_BPROP_BUILDER("Transpose").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {ib->Transpose(dout, res_perm), ib->ZerosLike(perm)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Slice").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto begin = ib->GetInput(kIndex1);
|
||||
auto size = ib->GetInput(kIndex2);
|
||||
|
@ -825,13 +825,13 @@ REG_BPROP_BUILDER("Slice").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx, ib->ZerosLike(begin), ib->ZerosLike(size)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Split": emits Concat on dout (input index 2) along this op's "axis" attribute.
// SetUnusedInputs({i0, i1}) declares inputs 0 and 1 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Split").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Split").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("Concat", {dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Tile").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Tile").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto input_multiples = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -858,18 +858,44 @@ REG_BPROP_BUILDER("Tile").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx, ib->ZerosLike(input_multiples)};
|
||||
});
|
||||
|
||||
// Registers the "Gather" bprop through a lambda whose only statement forwards to BinopGatherCommon(ib).
// NOTE(review): this looks like the pre-change line of a rendered diff; a direct registration of
// BinopGatherCommon for "Gather" also appears further down - confirm which one the file keeps.
REG_BPROP_BUILDER("Gather").SetBody([](const BpropIRBuilder *ib) -> NodePtrList { return BinopGatherCommon(ib); });
|
||||
// Shared bprop body for gather-style ops.
// Reads x (index 0), indices (index 1), axis (index 2) and dout (index 4); index 3 is not read.
// Returns three entries: the gradient for x, ZerosLike of the original indices, and ZerosLike(axis).
NodePtrList BinopGatherCommon(const BpropIRBuilder *ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto axis = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
// Keep the untouched indices node: `indices` may be replaced below, but the
// returned zero gradient is built from the original one.
auto orig_indices = indices;
|
||||
auto x_shp = ib->GetShape(x);
|
||||
auto out_shp = ib->GetShape(dout);
|
||||
auto ind_shp = ib->GetShape(indices);
|
||||
// NOTE(review): CheckRange presumably validates/normalizes the axis against the rank of x - confirm.
auto axis_v = CheckRange(GetIntValue(axis), SizeToLong(x_shp.size()));
|
||||
// dout with an empty shape gets an extra dimension via ExpandDims(-1).
if (out_shp.empty()) {
|
||||
dout = ib->Emit("ExpandDims", {dout, ib->Tensor(-1)});
|
||||
}
|
||||
// indices with an empty shape are expanded the same way, and dout is reshaped
// to the shape recomputed from x_shp, the new indices shape and the axis.
if (ind_shp.empty()) {
|
||||
indices = ib->Emit("ExpandDims", {indices, ib->Tensor(-1)});
|
||||
ind_shp = ib->GetShape(indices);
|
||||
auto out_shp1 = RegenerateOutputShape(x_shp, ind_shp, axis_v);
|
||||
dout = ib->Reshape(dout, out_shp1);
|
||||
}
|
||||
// Refresh the shape because dout may have been changed above.
out_shp = ib->GetShape(dout);
|
||||
auto perm_1 = GenerateShapeIndex(out_shp, ind_shp, axis_v);
|
||||
auto values_transpose = ib->Transpose(dout, perm_1);
|
||||
// Accumulate the transposed gradient per index; the segment count is the size of x along axis_v.
auto tmp = ib->Emit("UnsortedSegmentSum", {values_transpose, indices, ib->Value<int64_t>(x_shp[axis_v])});
|
||||
auto perm_2 = GenerateInverseIndex(x_shp, axis_v);
|
||||
// Transpose the accumulated result with perm_2 to form the gradient for x.
auto params_grad = ib->Transpose(tmp, perm_2);
|
||||
return {params_grad, ib->ZerosLike(orig_indices), ib->ZerosLike(axis)};
|
||||
}
|
||||
// "Gather" and "GatherV2" both register BinopGatherCommon directly as their body.
// SetUnusedInputs({i3}) declares input 3 as unused for both; BinopGatherCommon reads indices 0, 1, 2 and 4.
REG_BPROP_BUILDER("Gather").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
|
||||
REG_BPROP_BUILDER("GatherV2").SetUnusedInputs({i3}).SetBody(BinopGatherCommon);
|
||||
|
||||
// Registers the "GatherV2" bprop through a lambda whose only statement forwards to BinopGatherCommon(ib).
// NOTE(review): this looks like the pre-change line of a rendered diff; a direct registration of
// BinopGatherCommon for "GatherV2" appears just above - confirm which one the file keeps.
REG_BPROP_BUILDER("GatherV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList { return BinopGatherCommon(ib); });
|
||||
|
||||
// Bprop builder for "Fill": returns ZerosLike for dtype (index 0), dims (index 1) and x (index 2).
// SetUnusedInputs({i3, i4}) declares inputs 3 and 4 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Fill").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Fill").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto dtype = ib->GetInput(kIndex0);
|
||||
auto dims = ib->GetInput(kIndex1);
|
||||
auto x = ib->GetInput(kIndex2);
|
||||
return {ib->ZerosLike(dtype), ib->ZerosLike(dims), ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiagV3").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixDiagV3").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto k = ib->GetInput(kIndex1);
|
||||
auto num_rows = ib->GetInput(kIndex2);
|
||||
auto num_cols = ib->GetInput(kIndex3);
|
||||
|
@ -880,7 +906,7 @@ REG_BPROP_BUILDER("MatrixDiagV3").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {part, ib->ZerosLike(k), ib->ZerosLike(num_rows), ib->ZerosLike(num_cols), ib->ZerosLike(padding_value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiagPartV3").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixDiagPartV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto align = ib->GetAttr("align");
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto k = ib->GetInput(kIndex1);
|
||||
|
@ -895,7 +921,7 @@ REG_BPROP_BUILDER("MatrixDiagPartV3").SetBody([](const BpropIRBuilder *ib) -> No
|
|||
return {diag, ib->ZerosLike(k), ib->ZerosLike(padding_value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixSetDiagV3").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixSetDiagV3").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto align = ib->GetAttr("align");
|
||||
auto diagonal = ib->GetInput(kIndex1);
|
||||
auto k = ib->GetInput(kIndex2);
|
||||
|
@ -910,37 +936,37 @@ REG_BPROP_BUILDER("MatrixSetDiagV3").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {x_cal, diagonal_cal, ib->ZerosLike(k)};
|
||||
});
|
||||
|
||||
// Bprop builder for "LogNormalReverse": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("LogNormalReverse").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("LogNormalReverse").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_data = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(input_data)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Shape": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Shape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Shape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "Rank": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("Rank").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Rank").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "DynamicShape": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("DynamicShape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DynamicShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "TensorShape": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("TensorShape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Bprop builder for "DType": the single result is ZerosLike of input 0.
// SetUnusedInputs({i1, i2}) declares inputs 1 and 2 as unused; the body never reads them.
// NOTE(review): the two opening lines below look like the old/new pair of a rendered diff - confirm.
REG_BPROP_BUILDER("DType").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DType").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("StridedSliceV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto begin = ib->GetInput(kIndex1);
|
||||
auto end = ib->GetInput(kIndex2);
|
||||
|
@ -956,7 +982,7 @@ REG_BPROP_BUILDER("StridedSliceV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx, ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(strides)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaskedFill").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaskedFill").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto input_data = ib->GetInput(kIndex0);
|
||||
auto mask = ib->GetInput(kIndex1);
|
||||
auto value = ib->GetInput(kIndex2);
|
||||
|
@ -976,7 +1002,7 @@ REG_BPROP_BUILDER("MaskedFill").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {dinput, ib->ZerosLike(mask), dvalue};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Coalesce").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Coalesce").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto d1 = ib->TupleGetItem(dout, 0);
|
||||
auto d2 = ib->TupleGetItem(dout, 1);
|
||||
|
@ -984,7 +1010,7 @@ REG_BPROP_BUILDER("Coalesce").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {d1, d2, d3};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ConjugateTranspose").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ConjugateTranspose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto perm = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto tmp_perm = GetIntList(perm);
|
||||
|
@ -995,24 +1021,24 @@ REG_BPROP_BUILDER("ConjugateTranspose").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->Emit("ConjugateTranspose", {dout, ib->Value<ShapeVector>(res_perm)}), ib->ZerosLike(perm)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Triu").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for Triu: re-applies Triu, with the forward op's "diagonal" attribute, to input 2 (dout).
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("Triu").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto diag_offset = GetValue<int64_t>(ib->GetAttr("diagonal"));
  auto grad_out = ib->GetInput(kIndex2);
  return {ib->Emit("Triu", {grad_out}, {{"diagonal", MakeValue(diag_offset)}})};
});
|
||||
|
||||
REG_BPROP_BUILDER("CheckNumerics").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for CheckNumerics: emits CheckNumerics on input 2 (dout) and returns it as the gradient.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("CheckNumerics").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto grad_out = ib->GetInput(kIndex2);
  auto checked_grad = ib->Emit("CheckNumerics", {grad_out});
  return {checked_grad};
});
|
||||
|
||||
REG_BPROP_BUILDER("IdentityN").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for IdentityN: input 2 (dout) is returned unchanged as the gradient.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("IdentityN").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  return {ib->GetInput(kIndex2)};
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
|
||||
auto half_pixel_centers = GetValue<bool>(ib->GetAttr("half_pixel_centers"));
|
||||
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
|
@ -1031,14 +1057,14 @@ REG_BPROP_BUILDER("ResizeNearestNeighborV2").SetBody([](const BpropIRBuilder *ib
|
|||
return {dx, ib->ZerosLike(ib->Value<ShapeVector>(grad_in_size))};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Tril").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for Tril: re-applies Tril, with the forward op's "diagonal" attribute, to input 2 (dout).
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("Tril").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto diag_offset = GetValue<int64_t>(ib->GetAttr("diagonal"));
  auto grad_out = ib->GetInput(kIndex2);
  return {ib->Emit("Tril", {grad_out}, {{"diagonal", MakeValue(diag_offset)}})};
});
|
||||
|
||||
REG_BPROP_BUILDER("SegmentSum").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SegmentSum").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dout_type = ib->GetDtype(dout);
|
||||
|
@ -1052,7 +1078,7 @@ REG_BPROP_BUILDER("SegmentSum").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {ib->Cast(ib->Emit("Gather", {dout, segment_ids, ib->Tensor(0)}), dout_type), ib->ZerosLike(segment_ids)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("EmbeddingLookup").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("EmbeddingLookup").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto offset = ib->GetInput(kIndex2);
|
||||
|
@ -1078,7 +1104,7 @@ REG_BPROP_BUILDER("EmbeddingLookup").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
ib->ZerosLike(offset)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaskedSelect").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaskedSelect").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto mask = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1086,14 +1112,14 @@ REG_BPROP_BUILDER("MaskedSelect").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, ib->ZerosLike(mask)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SplitV").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for SplitV: emits Concat on input 2 (dout) with "axis" set to the forward "split_dim" attribute.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("SplitV").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto concat_axis = GetValue<int64_t>(ib->GetAttr("split_dim"));
  auto grad_out = ib->GetInput(kIndex2);
  return {ib->Emit("Concat", {grad_out}, {{"axis", MakeValue(concat_axis)}})};
});
|
||||
|
||||
REG_BPROP_BUILDER("Col2Im").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Col2Im").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto ksizes = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
|
||||
auto dilations = GetValue<std::vector<int64_t>>(ib->GetAttr("dilation"));
|
||||
auto strides = GetValue<std::vector<int64_t>>(ib->GetAttr("stride"));
|
||||
|
@ -1109,7 +1135,7 @@ REG_BPROP_BUILDER("Col2Im").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, ib->ZerosLike(output_size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExtractVolumePatches").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ExtractVolumePatches").SetBody(BODYFUNC(ib) {
|
||||
auto ksize = GetValue<std::vector<int64_t>>(ib->GetAttr("kernel_size"));
|
||||
auto ksize_d = ksize.at(2);
|
||||
auto ksize_h = ksize.at(3);
|
||||
|
@ -1159,7 +1185,7 @@ REG_BPROP_BUILDER("ExtractVolumePatches").SetBody([](const BpropIRBuilder *ib) -
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AffineGrid").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AffineGrid").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto align_corners = GetValue<bool>(ib->GetAttr("align_corners"));
|
||||
auto theta = ib->GetInput(kIndex0);
|
||||
auto output_size = GetIntList(ib->GetInput(kIndex1));
|
||||
|
@ -1268,7 +1294,7 @@ NodePtrList SegmentMinOrMaxGrad(const BpropIRBuilder *ib) {
|
|||
REG_BPROP_BUILDER("SegmentMax").SetBody(SegmentMinOrMaxGrad);
|
||||
REG_BPROP_BUILDER("SegmentMin").SetBody(SegmentMinOrMaxGrad);
|
||||
|
||||
REG_BPROP_BUILDER("TensorScatterElements").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorScatterElements").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto update = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -1279,7 +1305,7 @@ REG_BPROP_BUILDER("TensorScatterElements").SetBody([](const BpropIRBuilder *ib)
|
|||
return {x_grad, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ScatterAddWithAxis").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScatterAddWithAxis").SetUnusedInputs({i0, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto axis = ib->GetAttr("axis");
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -1302,7 +1328,7 @@ REG_BPROP_BUILDER("ScatterAddWithAxis").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dout, ib->ZerosLike(indices), update_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Expand").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Expand").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dout_shape = ib->GetShape(dout);
|
||||
|
@ -1324,7 +1350,7 @@ REG_BPROP_BUILDER("Expand").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, ib->ZerosLike(dout)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SegmentMean").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SegmentMean").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto segment_ids = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("ClipByNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ClipByNorm").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto clip_norm = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
|
|
@ -15,27 +15,28 @@
|
|||
*/
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("ScalarSummary").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for ScalarSummary: passes input 0 (tag) through and returns ZerosLike of input 1.
// Inputs i2 and i3 are declared unused by this builder.
REG_BPROP_BUILDER("ScalarSummary").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
  auto summary_tag = ib->GetInput(kIndex0);
  auto summary_value = ib->GetInput(kIndex1);
  auto value_grad = ib->ZerosLike(summary_value);
  return {summary_tag, value_grad};
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorSummary").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for TensorSummary: passes input 0 (tag) through and returns ZerosLike of input 1.
// Inputs i2 and i3 are declared unused by this builder.
REG_BPROP_BUILDER("TensorSummary").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
  auto summary_tag = ib->GetInput(kIndex0);
  auto summary_value = ib->GetInput(kIndex1);
  auto value_grad = ib->ZerosLike(summary_value);
  return {summary_tag, value_grad};
});
|
||||
|
||||
REG_BPROP_BUILDER("ImageSummary").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for ImageSummary: passes input 0 (tag) through and returns ZerosLike of input 1.
// Inputs i2 and i3 are declared unused by this builder.
REG_BPROP_BUILDER("ImageSummary").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
  auto summary_tag = ib->GetInput(kIndex0);
  auto summary_value = ib->GetInput(kIndex1);
  auto value_grad = ib->ZerosLike(summary_value);
  return {summary_tag, value_grad};
});
|
||||
|
||||
REG_BPROP_BUILDER("HistogramSummary").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("HistogramSummary").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto tag = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
return {tag, ib->ZerosLike(x)};
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("ResizeBicubic").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeBicubic").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto images = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -36,7 +36,7 @@ REG_BPROP_BUILDER("ResizeBicubic").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, ib->Emit("ZerosLike", {size})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CropAndResize").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CropAndResize").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
std::set<TypeId> allowed_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
|
||||
auto method = GetValue<std::string>(ib->GetAttr("method"));
|
||||
auto target = ib->GetTargetFromContext();
|
||||
|
@ -63,7 +63,7 @@ REG_BPROP_BUILDER("CropAndResize").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dimage, dbox, ib->ZerosLike(box_index), ib->ZerosLike(crop_size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ScaleAndTranslate").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ScaleAndTranslate").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto images = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto scale = ib->GetInput(kIndex2);
|
||||
|
@ -77,7 +77,7 @@ REG_BPROP_BUILDER("ScaleAndTranslate").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {grad0, ib->ZerosLike(size), ib->ZerosLike(scale), ib->ZerosLike(translation)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("RGBToHSV").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("RGBToHSV").SetBody(BODYFUNC(ib) {
|
||||
auto images = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("Load").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Load").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto u_monad = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
return {dout, ib->ZerosLike(u_monad)};
|
||||
|
|
|
@ -40,7 +40,7 @@ static NodePtr GetMatrixDiagPartAssist(const BpropIRBuilder *ib, const ShapeVect
|
|||
return ib->Reshape(tile, x_shape);
|
||||
}
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiag").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixDiag").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto shape = ib->GetShape(dout);
|
||||
|
@ -50,7 +50,7 @@ REG_BPROP_BUILDER("MatrixDiag").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {dx, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixDiagPart").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixDiagPart").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -65,7 +65,7 @@ REG_BPROP_BUILDER("MatrixDiagPart").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {ib->Emit("MatrixSetDiag", {ib->ZerosLike(x), dout, assist1}), ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatrixSetDiag").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatrixSetDiag").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto z = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -83,7 +83,7 @@ REG_BPROP_BUILDER("MatrixSetDiag").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, dy, dz};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DSDMatmul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DSDMatmul").SetBody(BODYFUNC(ib) {
|
||||
auto w1_gm = ib->GetInput(kIndex0);
|
||||
auto w2_gm = ib->GetInput(kIndex1);
|
||||
auto v_gm = ib->GetInput(kIndex2);
|
||||
|
@ -96,13 +96,12 @@ REG_BPROP_BUILDER("DSDMatmul").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {d_w1_gm, d_w2_gm, d_v_gm};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MatmulDDS").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MatmulDDS").SetUnusedInputs({i2, i3, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto q = ib->GetInput(kIndex0);
|
||||
auto k = ib->GetInput(kIndex1);
|
||||
auto local_mask = ib->GetInput(kIndex2);
|
||||
auto global_mask = ib->GetInput(kIndex3);
|
||||
auto out = ib->GetInput(kIndex4);
|
||||
auto d_out = ib->GetInput(kIndex5);
|
||||
auto lc = ib->TupleGetItem(out, kIndex0);
|
||||
auto gc = ib->TupleGetItem(out, kIndex1);
|
||||
auto d_lc = ib->TupleGetItem(out, kIndex0);
|
||||
|
@ -115,7 +114,7 @@ REG_BPROP_BUILDER("MatmulDDS").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dq, dk, ib->ZerosLike(local_mask), ib->ZerosLike(global_mask)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PsROIPooling").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("PsROIPooling").SetBody(BODYFUNC(ib) {
|
||||
auto pooled_height = GetValue<int64_t>(ib->GetAttr("pooled_height"));
|
||||
auto pooled_width = GetValue<int64_t>(ib->GetAttr("pooled_width"));
|
||||
auto spatial_scale = GetValue<float>(ib->GetAttr("spatial_scale"));
|
||||
|
@ -144,7 +143,7 @@ REG_BPROP_BUILDER("PsROIPooling").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, ib->ZerosLike(rois)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeBilinearV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeBilinearV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -154,18 +153,18 @@ REG_BPROP_BUILDER("ResizeBilinearV2").SetBody([](const BpropIRBuilder *ib) -> No
|
|||
return {dx, ib->ZerosLike(size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ConvertToDynamic").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for ConvertToDynamic: input 2 (dout) is returned unchanged as the gradient.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("ConvertToDynamic").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  return {ib->GetInput(kIndex2)};
});
|
||||
|
||||
REG_BPROP_BUILDER("_VirtualPipelineEnd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for _VirtualPipelineEnd: emits _VirtualPipelineEnd on input 2 (dout) and returns the result.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("_VirtualPipelineEnd").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto grad_out = ib->GetInput(kIndex2);
  return {ib->Emit("_VirtualPipelineEnd", {grad_out})};
});
|
||||
|
||||
REG_BPROP_BUILDER("FillV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FillV2").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto shape = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dout_typeptr = ib->GetDtype(dout);
|
||||
|
@ -182,7 +181,7 @@ REG_BPROP_BUILDER("FillV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {ib->ZerosLike(shape), ib->Cast(dvalue, dout_typeptr)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TensorCopySlices").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TensorCopySlices").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto update = ib->GetInput(kIndex1);
|
||||
auto begin = ib->GetInput(kIndex2);
|
||||
auto end = ib->GetInput(kIndex3);
|
||||
|
@ -203,14 +202,14 @@ REG_BPROP_BUILDER("TensorCopySlices").SetBody([](const BpropIRBuilder *ib) -> No
|
|||
return {x_grad, update_grad, ib->ZerosLike(begin), ib->ZerosLike(end), ib->ZerosLike(stride)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Roll").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for Roll: emits Roll on input 2 (dout) with the same "axis" attribute and every "shift" entry negated.
// Inputs i0 and i1 are declared unused by this builder.
REG_BPROP_BUILDER("Roll").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto grad_out = ib->GetInput(kIndex2);
  std::vector<int64_t> reversed_shift = GetIntList(ib->GetAttr("shift"));
  // Negate each shift in place so the emitted Roll moves elements in the opposite direction.
  for (auto &step : reversed_shift) {
    step = -step;
  }
  return {ib->Emit("Roll", {grad_out}, {{"axis", ib->GetAttr("axis")}, {"shift", MakeValue(reversed_shift)}})};
});
|
||||
|
||||
REG_BPROP_BUILDER("DynamicResizeNearestNeighbor").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DynamicResizeNearestNeighbor").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto inputs = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -224,7 +223,7 @@ REG_BPROP_BUILDER("DynamicResizeNearestNeighbor").SetBody([](const BpropIRBuilde
|
|||
ib->ZerosLike(size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ParallelResizeBilinear").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ParallelResizeBilinear").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -19,10 +19,9 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("Conv2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Conv2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
auto w_shape = ib->GetShape(w);
|
||||
|
@ -57,7 +56,7 @@ REG_BPROP_BUILDER("Conv2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
{"pad_list", ib->GetAttr("pad_list")}});
|
||||
return {dx, dw};
|
||||
});
|
||||
REG_BPROP_BUILDER("MaxPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPool").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -77,10 +76,7 @@ REG_BPROP_BUILDER("MaxPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BiasAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
REG_BPROP_BUILDER("BiasAdd").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto format = GetValue<std::string>(ib->GetAttr("data_format"));
|
||||
if (format == "NCDHW") {
|
||||
|
@ -90,47 +86,47 @@ REG_BPROP_BUILDER("BiasAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
ib->Emit(kBiasAddGradOpName, {dout}, {{"format", MakeValue(format)}, {"data_format", MakeValue(format)}})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for ReLU: emits the ReluGrad op on (dout, out), i.e. inputs 2 and 1.
// Input i0 is declared unused by this builder.
REG_BPROP_BUILDER("ReLU").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
  auto forward_out = ib->GetInput(kIndex1);
  auto grad_out = ib->GetInput(kIndex2);
  return {ib->Emit(kReluGradOpName, {grad_out, forward_out})};
});
|
||||
|
||||
REG_BPROP_BUILDER("TopK").SetBody([](const BpropIRBuilder *builder) -> NodePtrList {
|
||||
auto input_x = builder->GetInput(kIndex0);
|
||||
auto out = builder->GetInput(kIndex2);
|
||||
auto dout = builder->GetInput(kIndex3);
|
||||
// Bprop for TopK.
// Reads input 0 (input_x), input 2 (out, a tuple) and input 3 (dout, a tuple). Element 0 of dout is
// scattered, via ScatterNd, into a zero-filled flat buffer at the positions given by element 1 of
// out (the indices), and the buffer is reshaped back to the shape of input_x. Input 1 receives a
// ZerosLike gradient.
REG_BPROP_BUILDER("TopK").SetBody(BODYFUNC(ib) {
  auto input_x = ib->GetInput(kIndex0);
  auto out = ib->GetInput(kIndex2);
  auto dout = ib->GetInput(kIndex3);

  auto indices = ib->TupleGetItem(out, kIndex1);
  auto dout0 = ib->TupleGetItem(dout, kIndex0);

  auto in_shape = ib->GetShape(input_x);
  auto in_lastdim = in_shape.back();

  auto ind_shape = ib->GetShape(indices);
  auto ind_lastdim = ind_shape.back();

  // Collapse all leading dimensions of the indices so each row holds one slice along the last axis.
  auto ind_2d = ib->Reshape(indices, {-1, ind_lastdim});
  auto outerdim = ib->GetShape(ind_2d)[0];  // number of rows of ind_2d

  // Per-row offsets into the flattened input: [0, in_lastdim, 2*in_lastdim, ..., (outerdim-1)*in_lastdim]
  auto indices_dtype = ib->GetDtype(indices);
  std::vector<int64_t> range_flatten_index_vec(LongToSize(outerdim));
  for (int64_t i = 0; i < outerdim; i++) {
    range_flatten_index_vec[i] = i * in_lastdim;
  }
  auto range_flatten_index = ib->Tensor(range_flatten_index_vec, indices_dtype);
  auto ind = ib->Reshape(ind_2d + ib->Reshape(range_flatten_index, {-1, 1}), {-1, 1});
  // Seed the product with int64_t: std::accumulate carries the running value in the type of the
  // initial value, so an `int` literal would truncate the element count of large shapes.
  auto in_shape_1d =
    ShapeVector(1, std::accumulate(in_shape.begin(), in_shape.end(), int64_t{1}, std::multiplies<int64_t>()));
  auto out_grad = ib->Emit("ScatterNd", {ind, ib->Reshape(dout0, {-1}), ib->Value(in_shape_1d)});
  out_grad = ib->Reshape(out_grad, in_shape);

  auto grad_k = ib->ZerosLike(ib->GetInput(kIndex1));
  return {out_grad, grad_k};
});
|
||||
|
||||
REG_BPROP_BUILDER("PReLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("PReLU").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -140,7 +136,7 @@ REG_BPROP_BUILDER("PReLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -148,7 +144,7 @@ REG_BPROP_BUILDER("SigmoidCrossEntropyWithLogits").SetBody([](const BpropIRBuild
|
|||
return {dx, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Pad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto paddings = ib->GetAttr<std::vector<std::vector<int64_t>>>("paddings");
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -161,7 +157,7 @@ REG_BPROP_BUILDER("Pad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ROIAlign").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ROIAlign").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto inputs = ib->GetInput(kIndex0);
|
||||
auto rois = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -175,7 +171,7 @@ REG_BPROP_BUILDER("ROIAlign").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {dx, ib->ZerosLike(rois)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("LRN").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("LRN").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -187,7 +183,7 @@ REG_BPROP_BUILDER("LRN").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Dropout").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Dropout").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto mask = ib->TupleGetItem(out, 1);
|
||||
|
@ -196,7 +192,7 @@ REG_BPROP_BUILDER("Dropout").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BinaryCrossEntropy").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BinaryCrossEntropy").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto weight = ib->GetInput(kIndex2);
|
||||
|
@ -205,7 +201,7 @@ REG_BPROP_BUILDER("BinaryCrossEntropy").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dx, ib->ZerosLike(y), ib->ZerosLike(weight)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DropoutGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DropoutGrad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto mask = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto dy = dout;
|
||||
|
@ -213,7 +209,7 @@ REG_BPROP_BUILDER("DropoutGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx, ib->ZerosLike(mask)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DeformableOffsets").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto offsets = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -229,7 +225,7 @@ REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("LSTM").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("LSTM").SetBody(BODYFUNC(ib) {
|
||||
auto input_size = ib->GetAttr("input_size");
|
||||
auto hidden_size = ib->GetAttr("hidden_size");
|
||||
auto num_layers = ib->GetAttr("num_layers");
|
||||
|
@ -290,7 +286,7 @@ REG_BPROP_BUILDER("LSTM").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx, dhx, dcx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CudnnGRU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CudnnGRU").SetBody(BODYFUNC(ib) {
|
||||
auto input_size = ib->GetAttr("input_size");
|
||||
auto hidden_size = ib->GetAttr("hidden_size");
|
||||
auto num_layers = ib->GetAttr("num_layers");
|
||||
|
@ -326,14 +322,14 @@ REG_BPROP_BUILDER("CudnnGRU").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {dx, dhx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MirrorPad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Bprop for MirrorPad: emits MirrorPadGrad on (dout, paddings) with the forward op's mode attribute;
// the paddings input (input 1) receives a ZerosLike gradient.
// Inputs i0 and i2 are declared unused by this builder.
REG_BPROP_BUILDER("MirrorPad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
  auto pad_spec = ib->GetInput(kIndex1);
  auto grad_out = ib->GetInput(kIndex3);
  auto input_grad = ib->Emit("MirrorPadGrad", {grad_out, pad_spec}, {{kAttrMode, ib->GetAttr(kAttrMode)}});
  auto pad_grad = ib->ZerosLike(pad_spec);
  return {input_grad, pad_grad};
});
|
||||
|
||||
REG_BPROP_BUILDER("LayerNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("LayerNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex3);
|
||||
|
@ -347,7 +343,7 @@ REG_BPROP_BUILDER("LayerNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {d_x, d_gamma, d_beta};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("LayerNormGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("LayerNormGrad").SetUnusedInputs({i5}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dy = ib->GetInput(kIndex1);
|
||||
auto variance = ib->GetInput(kIndex2);
|
||||
|
@ -364,7 +360,7 @@ REG_BPROP_BUILDER("LayerNormGrad").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {d_x, d_dy, ib->ZerosLike(variance), ib->ZerosLike(mean), d_gamma};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("L2Normalize").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("L2Normalize").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -373,7 +369,7 @@ REG_BPROP_BUILDER("L2Normalize").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -382,7 +378,7 @@ REG_BPROP_BUILDER("SoftmaxCrossEntropyWithLogits").SetBody([](const BpropIRBuild
|
|||
return {grad, ib->ZerosLike(labels)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("NLLLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("NLLLoss").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto target = ib->GetInput(kIndex1);
|
||||
auto weight = ib->GetInput(kIndex2);
|
||||
|
@ -395,7 +391,7 @@ REG_BPROP_BUILDER("NLLLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeBilinear").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeBilinear").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit(
|
||||
|
@ -404,7 +400,7 @@ REG_BPROP_BUILDER("ResizeBilinear").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("OneHot").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("OneHot").SetUnusedInputs({i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto depth = ib->GetInput(kIndex1);
|
||||
auto on_value = ib->GetInput(kIndex2);
|
||||
|
@ -413,7 +409,7 @@ REG_BPROP_BUILDER("OneHot").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
ib->ZerosLike(off_value)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SmoothL1Loss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SmoothL1Loss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto prediction = ib->GetInput(kIndex0);
|
||||
auto target = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -424,14 +420,14 @@ REG_BPROP_BUILDER("SmoothL1Loss").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, dy};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("L2Loss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "L2Loss" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient: x * dout.
REG_BPROP_BUILDER("L2Loss").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Element-wise product built through the IR builder's Mul helper.
auto dx = ib->Mul(x, dout);
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("RNNTLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("RNNTLoss").SetUnusedInputs({i0, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto act_lens = ib->GetInput(kIndex2);
|
||||
auto label_lens = ib->GetInput(kIndex3);
|
||||
|
@ -440,7 +436,7 @@ REG_BPROP_BUILDER("RNNTLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {grad, ib->ZerosLike(labels), ib->ZerosLike(act_lens), ib->ZerosLike(label_lens)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Conv3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Conv3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -477,7 +473,7 @@ REG_BPROP_BUILDER("Conv3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Conv3DTranspose").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto strides = GetValue<std::vector<int64_t>>(ib->GetAttr("strides"));
|
||||
auto dilations = GetValue<std::vector<int64_t>>(ib->GetAttr("dilations"));
|
||||
std::vector<int64_t> stride = {strides.at(kIndex2), strides.at(kIndex3), strides.at(kIndex4)};
|
||||
|
@ -519,7 +515,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPoolWithArgmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPoolWithArgmax").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -530,7 +526,7 @@ REG_BPROP_BUILDER("MaxPoolWithArgmax").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPoolGradGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x1 = ib->GetInput(kIndex0);
|
||||
auto x2 = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -545,7 +541,7 @@ REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {dx1, dx2, dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPoolGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto device_target = ib->GetTargetFromContext();
|
||||
auto is_ascend = device_target == "Ascend";
|
||||
std::vector<int64_t> kernel_size;
|
||||
|
@ -601,7 +597,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx1, dx2, dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UpsampleNearest3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UpsampleNearest3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("UpsampleNearest3DGrad", {dout},
|
||||
|
@ -611,7 +607,7 @@ REG_BPROP_BUILDER("UpsampleNearest3D").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("UpsampleTrilinear3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("UpsampleTrilinear3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Emit("UpsampleTrilinear3DGrad", {dout},
|
||||
|
@ -640,7 +636,7 @@ NodePtrList Dropout2DBpropExpander(const BpropIRBuilder *ib) {
|
|||
// "Dropout2D" and "Dropout3D" are both registered with the same body function,
// Dropout2DBpropExpander.
REG_BPROP_BUILDER("Dropout2D").SetBody(Dropout2DBpropExpander);
|
||||
REG_BPROP_BUILDER("Dropout3D").SetBody(Dropout2DBpropExpander);
|
||||
|
||||
REG_BPROP_BUILDER("CTCLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CTCLoss").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto labels_indices = ib->GetInput(kIndex1);
|
||||
auto labels_values = ib->GetInput(kIndex2);
|
||||
auto sequence_length = ib->GetInput(kIndex3);
|
||||
|
@ -651,7 +647,7 @@ REG_BPROP_BUILDER("CTCLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {grad, ib->ZerosLike(labels_indices), ib->ZerosLike(labels_values), ib->ZerosLike(sequence_length)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPool3D").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -664,7 +660,7 @@ REG_BPROP_BUILDER("MaxPool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPool3DGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPool3DGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -676,7 +672,7 @@ REG_BPROP_BUILDER("MaxPool3DGrad").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {ib->ZerosLike(x), ib->ZerosLike(y), dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPool3DGradGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPool3DGradGrad").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -690,7 +686,7 @@ REG_BPROP_BUILDER("MaxPool3DGradGrad").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {ib->ZerosLike(x), ib->ZerosLike(y), dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AvgPool").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -710,7 +706,7 @@ REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AvgPool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
|
@ -727,7 +723,7 @@ REG_BPROP_BUILDER("AvgPool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Mish").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Mish").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx1 = ib->Emit("Tanh", {ib->Emit("Softplus", {x})});
|
||||
|
@ -736,7 +732,7 @@ REG_BPROP_BUILDER("Mish").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SeLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SeLU").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto scale = 1.0507009873554805;
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -745,14 +741,14 @@ REG_BPROP_BUILDER("SeLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReLU6").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "ReLU6" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient produced by the ReLU6Grad op.
REG_BPROP_BUILDER("ReLU6").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to ReLU6Grad is (dout, x).
auto dx = ib->Emit("ReLU6Grad", {dout, x});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReLUV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ReLUV2").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto mask = ib->TupleGetItem(out, 1);
|
||||
|
@ -760,7 +756,7 @@ REG_BPROP_BUILDER("ReLUV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BiasAddGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BiasAddGrad").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto data_format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -788,7 +784,7 @@ REG_BPROP_BUILDER("BiasAddGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {tiled_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ExtractImagePatches").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ExtractImagePatches").SetBody(BODYFUNC(ib) {
|
||||
auto ksizes_row = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[2];
|
||||
auto ksizes_col = GetValue<std::vector<int64_t>>(ib->GetAttr("ksizes"))[3];
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
|
@ -834,35 +830,35 @@ REG_BPROP_BUILDER("ExtractImagePatches").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("HSwish").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "HSwish" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient produced by the HSwishGrad op.
REG_BPROP_BUILDER("HSwish").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to HSwishGrad is (dout, x).
auto dx = ib->Emit("HSwishGrad", {dout, x});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("HSigmoid").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "HSigmoid" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient produced by the HSigmoidGrad op.
REG_BPROP_BUILDER("HSigmoid").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to HSigmoidGrad is (dout, x).
auto dx = ib->Emit("HSigmoidGrad", {dout, x});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Elu").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "Elu" op.
// i0 is declared as an unused input; the body reads input 1 (`out`) and input 2 (`dout`).
// Returns a single gradient produced by the EluGrad op.
REG_BPROP_BUILDER("Elu").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to EluGrad is (dout, out).
auto dx = ib->Emit("EluGrad", {dout, out});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Sigmoid").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "Sigmoid" op.
// i0 is declared as an unused input; the body reads input 1 (`out`) and input 2 (`dout`).
// Returns a single gradient produced by the SigmoidGrad op.
REG_BPROP_BUILDER("Sigmoid").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to SigmoidGrad is (out, dout).
auto dx = ib->Emit("SigmoidGrad", {out, dout});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SigmoidGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SigmoidGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex0);
|
||||
auto grad = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -872,21 +868,21 @@ REG_BPROP_BUILDER("SigmoidGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dy, dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("LogSoftmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "LogSoftmax" op.
// i0 is declared as an unused input; the body reads input 1 (`out`) and input 2 (`dout`).
// Returns a single gradient produced by the LogSoftmaxGrad op.
REG_BPROP_BUILDER("LogSoftmax").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operands are (out, dout); this op's "axis" attribute is copied onto the grad op.
auto dx = ib->Emit("LogSoftmaxGrad", {out, dout}, {{"axis", ib->GetAttr("axis")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Softplus").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "Softplus" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient produced by the SoftplusGrad op.
REG_BPROP_BUILDER("Softplus").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to SoftplusGrad is (dout, x).
auto dx = ib->Emit("SoftplusGrad", {dout, x});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Softsign").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Softsign").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx =
|
||||
|
@ -895,7 +891,7 @@ REG_BPROP_BUILDER("Softsign").SetBody([](const BpropIRBuilder *ib) -> NodePtrLis
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Tanh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Tanh").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -911,7 +907,7 @@ REG_BPROP_BUILDER("Tanh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("TanhGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("TanhGrad").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex0);
|
||||
auto grad = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -936,10 +932,10 @@ NodePtrList FastGeLUBpropExpander(const BpropIRBuilder *ib) {
|
|||
auto dx = ib->Emit("FastGeLUGrad", {dout, x});
|
||||
return {dx};
|
||||
}
|
||||
REG_BPROP_BUILDER("FastGeLU").SetBody(FastGeLUBpropExpander);
|
||||
REG_BPROP_BUILDER("FastGelu").SetBody(FastGeLUBpropExpander);
|
||||
// "FastGeLU" and "FastGelu" (two spellings of the op name) share the body function
// FastGeLUBpropExpander; both registrations declare i1 as an unused input.
REG_BPROP_BUILDER("FastGeLU").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
|
||||
REG_BPROP_BUILDER("FastGelu").SetUnusedInputs({i1}).SetBody(FastGeLUBpropExpander);
|
||||
|
||||
REG_BPROP_BUILDER("InstanceNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("InstanceNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex1);
|
||||
auto mean = ib->GetInput(kIndex3);
|
||||
|
@ -956,7 +952,7 @@ REG_BPROP_BUILDER("InstanceNorm").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto is_training = GetValue<bool>(ib->GetAttr("is_training"));
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto scale = ib->GetInput(kIndex1);
|
||||
|
@ -983,7 +979,7 @@ REG_BPROP_BUILDER("BatchNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNormGrad").SetUnusedInputs({i6}).SetBody(BODYFUNC(ib) {
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
auto scale = ib->GetInput(kIndex2);
|
||||
|
@ -1003,7 +999,7 @@ REG_BPROP_BUILDER("BatchNormGrad").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {ddy, dx, dscale, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(reserve)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Softmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Softmax").SetBody(BODYFUNC(ib) {
|
||||
auto axis = GetValue<std::vector<int64_t>>(ib->GetAttr("axis"));
|
||||
auto one_axis = axis[0];
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
|
@ -1019,7 +1015,7 @@ REG_BPROP_BUILDER("Softmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogits").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogits").SetBody(BODYFUNC(ib) {
|
||||
auto is_grad = ib->GetAttr<bool>(kAttrIsGrad);
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1034,7 +1030,7 @@ REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogits").SetBody([](const BpropI
|
|||
return {grad, ib->ZerosLike(labels)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DynamicRNN").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DynamicRNN").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto b = ib->GetInput(kIndex2);
|
||||
|
@ -1078,7 +1074,7 @@ REG_BPROP_BUILDER("DynamicRNN").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {dx, dw, db, ib->ZerosLike(ib->Tensor(zero)), dh_prev, dc_prev};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DynamicGRUV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DynamicGRUV2").SetUnusedInputs({i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto winput = ib->GetInput(kIndex1);
|
||||
auto whidden = ib->GetInput(kIndex2);
|
||||
|
@ -1114,7 +1110,7 @@ REG_BPROP_BUILDER("DynamicGRUV2").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, dw_input, dw_hidden, db_input, db_hidden, ib->ZerosLike(ib->Tensor(zero)), dh_prev};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveMaxPool2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AdaptiveMaxPool2D").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1157,10 +1153,10 @@ NodePtrList Conv2DTransposeBpropExpander(const BpropIRBuilder *ib) {
|
|||
{"pad_list", ib->GetAttr("pad_list")}});
|
||||
return {dx, dw, ib->ZerosLike(f_sizes)};
|
||||
}
|
||||
REG_BPROP_BUILDER("Conv2DTranspose").SetBody(Conv2DTransposeBpropExpander);
|
||||
REG_BPROP_BUILDER("Conv2DBackpropInput").SetBody(Conv2DTransposeBpropExpander);
|
||||
// "Conv2DTranspose" and "Conv2DBackpropInput" share the body function
// Conv2DTransposeBpropExpander; both registrations declare i2 and i3 as unused inputs.
// NOTE(review): the visible tail of the expander returns ib->ZerosLike(f_sizes); confirm
// that f_sizes is not one of the inputs declared unused here, or that ZerosLike does not
// need its value.
REG_BPROP_BUILDER("Conv2DTranspose").SetUnusedInputs({i2, i3}).SetBody(Conv2DTransposeBpropExpander);
|
||||
REG_BPROP_BUILDER("Conv2DBackpropInput").SetUnusedInputs({i2, i3}).SetBody(Conv2DTransposeBpropExpander);
|
||||
|
||||
REG_BPROP_BUILDER("Conv2DBackpropFilter").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Conv2DBackpropFilter").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto dy = ib->GetInput(kIndex0);
|
||||
auto x = ib->GetInput(kIndex1);
|
||||
auto filter_size = ib->GetInput(kIndex2);
|
||||
|
@ -1194,7 +1190,7 @@ REG_BPROP_BUILDER("Conv2DBackpropFilter").SetBody([](const BpropIRBuilder *ib) -
|
|||
return {dw_dy, dw_dx, ib->ZerosLike(filter_size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BCEWithLogitsLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BCEWithLogitsLoss").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto reduction = GetValue<std::string>(ib->GetAttr("reduction"));
|
||||
auto predict = ib->GetInput(kIndex0);
|
||||
auto target = ib->GetInput(kIndex1);
|
||||
|
@ -1220,7 +1216,7 @@ REG_BPROP_BUILDER("BCEWithLogitsLoss").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx, grad_target, ib->ZerosLike(weight), ib->ZerosLike(pos_weight)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("KLDivLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("KLDivLoss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto reduction = GetValue<std::string>(ib->GetAttr("reduction"));
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
|
@ -1235,21 +1231,21 @@ REG_BPROP_BUILDER("KLDivLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("HShrink").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "HShrink" op.
// i1 is declared as an unused input; the body reads input 0 (features) and input 2 (gradients).
// Returns a single gradient produced by the HShrinkGrad op.
REG_BPROP_BUILDER("HShrink").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto features = ib->GetInput(kIndex0);
|
||||
auto gradients = ib->GetInput(kIndex2);
|
||||
// Operands are (gradients, features); this op's "lambd" attribute is copied onto the grad op.
auto dx = ib->Emit("HShrinkGrad", {gradients, features}, {{"lambd", ib->GetAttr("lambd")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SoftShrink").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "SoftShrink" op.
// i1 is declared as an unused input; the body reads input 0 (input_x) and input 2 (`dout`).
// Returns a single gradient produced by the SoftShrinkGrad op.
REG_BPROP_BUILDER("SoftShrink").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operands are (dout, input_x); this op's "lambd" attribute is copied onto the grad op.
auto dx = ib->Emit("SoftShrinkGrad", {dout, input_x}, {{"lambd", ib->GetAttr("lambd")}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SoftMarginLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SoftMarginLoss").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto predict = ib->GetInput(kIndex0);
|
||||
auto label = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1258,7 +1254,7 @@ REG_BPROP_BUILDER("SoftMarginLoss").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx, dy};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MultilabelMarginLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MultilabelMarginLoss").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto target = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1268,7 +1264,7 @@ REG_BPROP_BUILDER("MultilabelMarginLoss").SetBody([](const BpropIRBuilder *ib) -
|
|||
return {dx, ib->ZerosLike(target)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Dilation2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Dilation2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto _filter = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1285,7 +1281,7 @@ REG_BPROP_BUILDER("Dilation2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {dx, dfilter};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CeLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CeLU").SetBody(BODYFUNC(ib) {
|
||||
auto alpha = GetValue<float>(ib->GetAttr("alpha"));
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_dtype = ib->GetDtype(x);
|
||||
|
@ -1302,7 +1298,7 @@ REG_BPROP_BUILDER("CeLU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pdist").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Pdist").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1310,7 +1306,7 @@ REG_BPROP_BUILDER("Pdist").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MultiMarginLoss").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MultiMarginLoss").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto target = ib->GetInput(kIndex1);
|
||||
auto weight = ib->GetInput(kIndex2);
|
||||
|
@ -1321,27 +1317,27 @@ REG_BPROP_BUILDER("MultiMarginLoss").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {dx, ib->ZerosLike(target), ib->ZerosLike(weight)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DropoutGenMask").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "DropoutGenMask" op.
// i2 and i3 are declared as unused inputs; the body reads input 0 (shape) and input 1 (keep_prob).
// No grad op is emitted: both returned gradients are zeros-like placeholders.
REG_BPROP_BUILDER("DropoutGenMask").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto shape = ib->GetInput(kIndex0);
|
||||
auto keep_prob = ib->GetInput(kIndex1);
|
||||
return {ib->ZerosLike(shape), ib->ZerosLike(keep_prob)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DropoutDoMask").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "DropoutDoMask" op.
// i0 and i3 are declared as unused inputs; the body reads input 1 (y), input 2 (keep_prob)
// and input 4 (`dout`).
// Returns three gradients: DropoutDoMask re-applied to dout, plus zeros-like placeholders
// for y and keep_prob.
REG_BPROP_BUILDER("DropoutDoMask").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto keep_prob = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
// The first gradient reuses the forward op itself, with operands (dout, y, keep_prob).
return {ib->Emit("DropoutDoMask", {dout, y, keep_prob}), ib->ZerosLike(y), ib->ZerosLike(keep_prob)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReluGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "ReluGrad" op.
// i0 and i2 are declared as unused inputs; the body reads input 1 (y) and input 3 (`dout`).
// Returns two gradients: {ReluGrad(dout, y), zeros-like of y}.
REG_BPROP_BUILDER("ReluGrad").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
// Gradient for the first input is ReluGrad applied again, with operands (dout, y).
auto dgrad = ib->Emit("ReluGrad", {dout, y});
|
||||
return {dgrad, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GridSampler3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GridSampler3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto grid = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1354,14 +1350,14 @@ REG_BPROP_BUILDER("GridSampler3D").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, dgrid};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ReLUV3").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "ReLUV3" op.
// i0 is declared as an unused input; the body reads input 1 (`out`) and input 2 (`dout`).
// Returns a single gradient produced by the ReluGrad op.
REG_BPROP_BUILDER("ReLUV3").SetUnusedInputs({i0}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// Operand order passed to ReluGrad is (dout, out).
auto dgrad = ib->Emit("ReluGrad", {dout, out});
|
||||
return {dgrad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("GridSampler2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GridSampler2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto grid = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1374,7 +1370,7 @@ REG_BPROP_BUILDER("GridSampler2D").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, dgrid};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ResizeLinear1D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ResizeLinear1D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto size = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1383,7 +1379,7 @@ REG_BPROP_BUILDER("ResizeLinear1D").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx, ib->ZerosLike(size)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPool3DWithArgmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPool3DWithArgmax").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1397,7 +1393,7 @@ REG_BPROP_BUILDER("MaxPool3DWithArgmax").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxUnpool2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxUnpool2D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto argmax = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1411,7 +1407,7 @@ REG_BPROP_BUILDER("MaxUnpool2D").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx, dargmax};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxUnpool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxUnpool3D").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto argmax = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1425,7 +1421,7 @@ REG_BPROP_BUILDER("MaxUnpool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
return {dx, dargmax};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("NthElement").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("NthElement").SetBody(BODYFUNC(ib) {
|
||||
auto input_x = ib->GetInput(kIndex0);
|
||||
auto n = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1436,7 +1432,7 @@ REG_BPROP_BUILDER("NthElement").SetBody([](const BpropIRBuilder *ib) -> NodePtrL
|
|||
return {ib->Mul(ib->Div(indicators, num_select), dout), ib->ZerosLike(n)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->Tensor(ib->GetShape(x));
|
||||
|
@ -1444,14 +1440,14 @@ REG_BPROP_BUILDER("AdaptiveAvgPool3D").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward rule registered for the "AdaptiveAvgPool2DV1" op.
// i1 is declared as an unused input; the body reads input 0 (x) and input 2 (`dout`).
// Returns a single gradient produced by the AdaptiveAvgPool2DGradV1 op.
REG_BPROP_BUILDER("AdaptiveAvgPool2DV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// x itself is not an operand of the grad op: only its shape is used, passed as the
// "orig_input_shape" attribute.
auto dx = ib->Emit("AdaptiveAvgPool2DGradV1", {dout}, {{"orig_input_shape", MakeValue(ib->GetShape(x))}});
|
||||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalMaxPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FractionalMaxPool").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1463,7 +1459,7 @@ REG_BPROP_BUILDER("FractionalMaxPool").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto random_samples = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1473,7 +1469,7 @@ REG_BPROP_BUILDER("FractionalMaxPool3DWithFixedKsize").SetBody([](const BpropIRB
|
|||
return {dx, ib->ZerosLike(random_samples)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalAvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FractionalAvgPool").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1487,7 +1483,7 @@ REG_BPROP_BUILDER("FractionalAvgPool").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PSROIPooling").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("PSROIPooling").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto spatial_scale = ib->GetAttr("spatial_scale");
|
||||
auto group_size = ib->GetAttr("group_size");
|
||||
auto output_dim = ib->GetAttr("output_dim");
|
||||
|
@ -1509,7 +1505,7 @@ REG_BPROP_BUILDER("PSROIPooling").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
return {dx, ib->ZerosLike(rois)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AvgPoolV1").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AvgPoolV1").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto orig_input_shape = ib->Value<ShapeVector>(ib->GetShape(x));
|
||||
|
@ -1523,7 +1519,7 @@ REG_BPROP_BUILDER("AvgPoolV1").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MaxPoolV1").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MaxPoolV1").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1537,7 +1533,7 @@ REG_BPROP_BUILDER("MaxPoolV1").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CTCLossV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CTCLossV2").SetBody(BODYFUNC(ib) {
|
||||
auto log_probs = ib->GetInput(kIndex0);
|
||||
auto targets = ib->GetInput(kIndex1);
|
||||
auto input_lengths = ib->GetInput(kIndex2);
|
||||
|
@ -1553,7 +1549,7 @@ REG_BPROP_BUILDER("CTCLossV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {grad, ib->ZerosLike(targets), ib->ZerosLike(input_lengths), ib->ZerosLike(target_lengths)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("InstanceNormV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("InstanceNormV2").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex1);
|
||||
auto mean = ib->GetInput(kIndex3);
|
||||
|
@ -1573,7 +1569,7 @@ REG_BPROP_BUILDER("InstanceNormV2").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx, dgamma, dbeta, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto random_samples = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1585,7 +1581,7 @@ REG_BPROP_BUILDER("FractionalMaxPoolWithFixedKsize").SetBody([](const BpropIRBui
|
|||
return {dx, ib->ZerosLike(random_samples)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
MS_LOG(WARNING) << "Bprop Expander under testing: " << ib->name();
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -1593,7 +1589,7 @@ REG_BPROP_BUILDER("AdaptiveAvgPool2D").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody(BODYFUNC(ib) {
|
||||
auto logits = ib->GetInput(kIndex0);
|
||||
auto labels = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
|
@ -1615,7 +1611,7 @@ REG_BPROP_BUILDER("SparseSoftmaxCrossEntropyWithLogitsV2").SetBody([](const Bpro
|
|||
return {grad, ib->ZerosLike(labels)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DepthwiseConv2dNative").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DepthwiseConv2dNative").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto w = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -1642,7 +1638,7 @@ REG_BPROP_BUILDER("DepthwiseConv2dNative").SetBody([](const BpropIRBuilder *ib)
|
|||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("PadV3").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("PadV3").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto paddings = ib->GetInput(kIndex1);
|
||||
auto constant_values = ib->GetInput(kIndex2);
|
||||
|
@ -1682,6 +1678,6 @@ NodePtrList CommonMaxMinGradBprop(const BpropIRBuilder *ib) {
|
|||
auto dz = ib->Add(ib->Mul(out0, ib->TupleGetItem(dout, 0)), ib->Mul(out1, ib->TupleGetItem(dout, 1)));
|
||||
return {ib->ZerosLike(x), ib->ZerosLike(y), dz};
|
||||
}
|
||||
// "MaximumGrad" and "MinimumGrad" both register the shared CommonMaxMinGradBprop body.
// NOTE(review): each op is registered twice in this listing (first without, then with
// SetUnusedInputs({i2})); this looks like the removed/added pairs of a diff view — confirm that only
// the SetUnusedInputs variants are live.
REG_BPROP_BUILDER("MaximumGrad").SetBody(CommonMaxMinGradBprop);
|
||||
REG_BPROP_BUILDER("MinimumGrad").SetBody(CommonMaxMinGradBprop);
|
||||
// Variants that additionally declare input i2 as unused by the shared body.
REG_BPROP_BUILDER("MaximumGrad").SetUnusedInputs({i2}).SetBody(CommonMaxMinGradBprop);
|
||||
REG_BPROP_BUILDER("MinimumGrad").SetUnusedInputs({i2}).SetBody(CommonMaxMinGradBprop);
|
||||
} // namespace mindspore::expander::bprop
|
||||
|
|
|
@ -16,26 +16,27 @@
|
|||
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
// Gradient builder registration for the "Assign" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads input 1 (y) and input 3 (dout); inputs 0 and 2 are never read, which matches
// SetUnusedInputs({i0, i2}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("Assign").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Assign").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
// dout is passed through unchanged as the first gradient; y gets a zero gradient.
return {dout, ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "InvertPermutation" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads only input 0 (x); inputs 1 and 2 are never read, which matches
// SetUnusedInputs({i1, i2}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("InvertPermutation").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("InvertPermutation").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
// The only gradient returned is a zero tensor shaped like x.
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "IOU" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (x) and 1 (y); inputs 2 and 3 are never read, which matches
// SetUnusedInputs({i2, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("IOU").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("IOU").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto y = ib->GetInput(kIndex1);
|
||||
// Both operands receive zero gradients.
return {ib->ZerosLike(x), ib->ZerosLike(y)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SyncBatchNorm").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SyncBatchNorm").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto scale = ib->GetInput(kIndex1);
|
||||
auto mean = ib->GetInput(kIndex3);
|
||||
|
@ -53,12 +54,12 @@ REG_BPROP_BUILDER("SyncBatchNorm").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, dscale, dbias, ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "GpuConvertToDynamicShape" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads only input 2 (dout); inputs 0 and 1 are never read, which matches
// SetUnusedInputs({i0, i1}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("GpuConvertToDynamicShape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("GpuConvertToDynamicShape").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
// The incoming gradient is returned unchanged.
return {dout};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("_DynamicLossScale").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("_DynamicLossScale").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto loss_scale = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto res = ib->Emit("Mul", {dout, loss_scale},
|
||||
|
|
|
@ -16,35 +16,36 @@
|
|||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
// Gradient builder registration for the "BNTrainingReduce" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads only input 0 (x); inputs 1 and 2 are never read, which matches
// SetUnusedInputs({i1, i2}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("BNTrainingReduce").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BNTrainingReduce").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
// The only gradient returned is a zero tensor shaped like x.
return {ib->ZerosLike(x)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "MinMaxUpdatePerLayer" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0..2 (x, x_min, x_max); inputs 3 and 4 are never read, which matches
// SetUnusedInputs({i3, i4}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerLayer").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
// All three operands receive zero gradients.
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "MinMaxUpdatePerChannel" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0..2 (x, x_min, x_max); inputs 3 and 4 are never read, which matches
// SetUnusedInputs({i3, i4}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MinMaxUpdatePerChannel").SetUnusedInputs({i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
// All three operands receive zero gradients.
return {ib->ZerosLike(x), ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "WtsARQ" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 1 (w_min), 2 (w_max) and 4 (dout); inputs 0 and 3 are never read, which
// matches SetUnusedInputs({i0, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("WtsARQ").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("WtsARQ").SetUnusedInputs({i0, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto w_min = ib->GetInput(kIndex1);
|
||||
auto w_max = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
// dout is passed through unchanged as the first gradient; w_min and w_max get zero gradients.
return {dout, ib->ZerosLike(w_min), ib->ZerosLike(w_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeQuantPerLayer").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeQuantPerLayer").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
|
@ -60,7 +61,7 @@ REG_BPROP_BUILDER("FakeQuantPerLayer").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeQuantWithMinMaxVars").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeQuantWithMinMaxVars").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
|
@ -70,7 +71,7 @@ REG_BPROP_BUILDER("FakeQuantWithMinMaxVars").SetBody([](const BpropIRBuilder *ib
|
|||
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeQuantWithMinMaxVarsPerChannel").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeQuantWithMinMaxVarsPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
|
@ -80,7 +81,7 @@ REG_BPROP_BUILDER("FakeQuantWithMinMaxVarsPerChannel").SetBody([](const BpropIRB
|
|||
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeQuantPerChannel").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeQuantPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_min = ib->GetInput(kIndex1);
|
||||
auto x_max = ib->GetInput(kIndex2);
|
||||
|
@ -94,7 +95,7 @@ REG_BPROP_BUILDER("FakeQuantPerChannel").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {dx, ib->ZerosLike(x_min), ib->ZerosLike(x_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFold").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNormFold").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto mean = ib->GetInput(kIndex1);
|
||||
auto variance = ib->GetInput(kIndex2);
|
||||
|
@ -110,7 +111,7 @@ REG_BPROP_BUILDER("BatchNormFold").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, ib->ZerosLike(mean), ib->ZerosLike(variance), ib->ZerosLike(global_step)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CorrectionMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CorrectionMul").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto batch_std = ib->GetInput(kIndex1);
|
||||
auto running_std = ib->GetInput(kIndex2);
|
||||
|
@ -126,7 +127,7 @@ REG_BPROP_BUILDER("CorrectionMul").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {dx, d_batch_std, ib->ZerosLike(running_std)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFold2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNormFold2").SetUnusedInputs({i1, i8}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex2);
|
||||
auto batch_std = ib->GetInput(kIndex3);
|
||||
|
@ -153,7 +154,7 @@ REG_BPROP_BUILDER("BatchNormFold2").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
ib->ZerosLike(global_step)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFoldD").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNormFoldD").SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_sum = ib->GetInput(kIndex1);
|
||||
auto x_square_sum = ib->GetInput(kIndex2);
|
||||
|
@ -170,7 +171,7 @@ REG_BPROP_BUILDER("BatchNormFoldD").SetBody([](const BpropIRBuilder *ib) -> Node
|
|||
return {dx, ib->ZerosLike(x_sum), ib->ZerosLike(x_square_sum), ib->ZerosLike(mean), ib->ZerosLike(variance)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("BatchNormFold2D").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("BatchNormFold2D").SetUnusedInputs({i1, i6}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto gamma = ib->GetInput(kIndex2);
|
||||
auto batch_std = ib->GetInput(kIndex3);
|
||||
|
@ -189,7 +190,7 @@ REG_BPROP_BUILDER("BatchNormFold2D").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {d_x, dout_reduce, d_gamma, d_batch_std, d_batch_mean, ib->ZerosLike(running_std)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("ActsULQ").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("ActsULQ").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto out = ib->GetInput(kIndex3);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto dout0 = ib->TupleGetItem(dout, kIndex0);
|
||||
|
@ -202,7 +203,7 @@ REG_BPROP_BUILDER("ActsULQ").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
return {dx, dx1, dx2};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeLearnedScaleQuantPerLayer").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeLearnedScaleQuantPerLayer").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_alpha = ib->GetInput(kIndex1);
|
||||
auto x_quant_max = ib->GetInput(kIndex2);
|
||||
|
@ -214,7 +215,7 @@ REG_BPROP_BUILDER("FakeLearnedScaleQuantPerLayer").SetBody([](const BpropIRBuild
|
|||
return {dx, dalpha, ib->ZerosLike(x_quant_max)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("FakeLearnedScaleQuantPerChannel").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("FakeLearnedScaleQuantPerChannel").SetUnusedInputs({i3}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto x_alpha = ib->GetInput(kIndex1);
|
||||
auto x_quant_max = ib->GetInput(kIndex2);
|
||||
|
|
|
@ -16,9 +16,10 @@
|
|||
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/grad_ops/common_utils.h"
|
||||
|
||||
namespace mindspore::expander::bprop {
|
||||
REG_BPROP_BUILDER("SolveTriangular").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SolveTriangular").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto reverse_perm = [](const ShapeVector &shape) -> ShapeVector {
|
||||
ShapeVector perm;
|
||||
for (int64_t i = shape.size() - 1; i >= 0; --i) {
|
||||
|
@ -58,7 +59,7 @@ REG_BPROP_BUILDER("SolveTriangular").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {grad_a, grad_b};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Eigh").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("Eigh").SetBody(BODYFUNC(ib) {
|
||||
auto is_compute_v = GetValue<bool>(ib->GetAttr("compute_eigenvectors"));
|
||||
auto is_lower = GetValue<bool>(ib->GetAttr("lower"));
|
||||
auto lower = static_cast<int64_t>(is_lower);
|
||||
|
|
|
@ -112,14 +112,14 @@ ShapeVector InferOutShape(const ShapeVector &sh1, const ShapeVector &sh2) {
|
|||
}
|
||||
}; // namespace
|
||||
|
||||
// Gradient builder registration for the "SparseToDense" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (indices), 2 (dense_shape) and 4 (dout); inputs 1 and 3 are never read,
// which matches SetUnusedInputs({i1, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("SparseToDense").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseToDense").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto dense_shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
// Middle gradient is GatherNd(dout, indices); indices and dense_shape get zero gradients.
return {ib->ZerosLike(indices), ib->Emit("GatherNd", {dout, indices}), ib->ZerosLike(dense_shape)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseToDenseV2").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseToDenseV2").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto output_shape = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
|
@ -128,7 +128,7 @@ REG_BPROP_BUILDER("SparseToDenseV2").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
return {ib->ZerosLike(indices), ib->ZerosLike(output_shape), sparse_values_grad, default_value_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto adj_s = ib->GetAttr<bool>("adjoint_st");
|
||||
auto adj_d = ib->GetAttr<bool>("adjoint_dt");
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
|
@ -175,7 +175,7 @@ REG_BPROP_BUILDER("SparseTensorDenseMatmul").SetBody([](const BpropIRBuilder *ib
|
|||
return {ib->ZerosLike(indices), values_grad, ib->ZerosLike(dense_shape), dense_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseAdd").SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_values = ib->GetInput(kIndex1);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
|
@ -198,7 +198,7 @@ REG_BPROP_BUILDER("SparseAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrLi
|
|||
return {ret0, ret1, ret2, ret3, ret4, ret5, ret6};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRReduceSum").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRReduceSum").SetUnusedInputs({i2, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto shape = ib->GetInput(kIndex3);
|
||||
|
@ -213,7 +213,7 @@ REG_BPROP_BUILDER("CSRReduceSum").SetBody([](const BpropIRBuilder *ib) -> NodePt
|
|||
ib->ZerosLike(axis)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRMV").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRMV").SetUnusedInputs({i5}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto values = ib->GetInput(kIndex2);
|
||||
|
@ -244,7 +244,7 @@ REG_BPROP_BUILDER("CSRMV").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
|||
return {ib->ZerosLike(indptr), ib->ZerosLike(indices), values_grad, ib->ZerosLike(zero), dense_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRMul").SetUnusedInputs({i5}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto values = ib->GetInput(kIndex2);
|
||||
|
@ -271,7 +271,7 @@ REG_BPROP_BUILDER("CSRMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
dense_grad};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRDiv").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRDiv").SetUnusedInputs({i2}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto shape_node = ib->GetInput(kIndex3);
|
||||
|
@ -320,26 +320,26 @@ REG_BPROP_BUILDER("CSRDiv").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
dense_grad};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "CSR2COO" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (indptr) and 1 (nnz); inputs 2 and 3 are never read, which matches
// SetUnusedInputs({i2, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("CSR2COO").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSR2COO").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto nnz = ib->GetInput(kIndex1);
|
||||
// Both operands receive zero gradients.
return {ib->ZerosLike(indptr), ib->ZerosLike(nnz)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "COO2CSR" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (row_indices) and 1 (height); inputs 2 and 3 are never read, which matches
// SetUnusedInputs({i2, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("COO2CSR").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("COO2CSR").SetUnusedInputs({i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto row_indices = ib->GetInput(kIndex0);
|
||||
auto height = ib->GetInput(kIndex1);
|
||||
// Both operands receive zero gradients.
return {ib->ZerosLike(row_indices), ib->ZerosLike(height)};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "MakeCOOTensor" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (indices) and 4 (dout); inputs 1..3 are never read, which matches
// SetUnusedInputs({i1, i2, i3}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("MakeCOOTensor").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MakeCOOTensor").SetUnusedInputs({i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
// dout is treated as a tuple; element 1 is taken as the gradient for the values operand.
auto dout_values = ib->TupleGetItem(dout, kIndex1);
|
||||
// Two gradients are returned: zeros for indices, then the extracted values gradient.
return {ib->ZerosLike(indices), dout_values};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("COOTensorGetIndices").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("COOTensorGetIndices").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto coo_tensor = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto coo_tensor_values = ib->TupleGetItem(coo_tensor, kIndex1);
|
||||
|
@ -347,7 +347,7 @@ REG_BPROP_BUILDER("COOTensorGetIndices").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->MakeTuple({dout, ib->ZerosLike(coo_tensor_values), coo_tensor_shape})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("COOTensorGetValues").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("COOTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto coo_tensor = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto coo_tensor_indices = ib->TupleGetItem(coo_tensor, kIndex0);
|
||||
|
@ -355,12 +355,12 @@ REG_BPROP_BUILDER("COOTensorGetValues").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->MakeTuple({ib->ZerosLike(coo_tensor_indices), dout, coo_tensor_shape})};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "COOTensorGetDenseShape" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads only input 0 (coo_tensor); inputs 1 and 2 are never read, which matches
// SetUnusedInputs({i1, i2}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("COOTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto coo_tensor = ib->GetInput(kIndex0);
|
||||
// The only gradient returned is ZerosLike of the whole coo_tensor input.
return {ib->ZerosLike(coo_tensor)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("MakeCSRTensor").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("MakeCSRTensor").SetUnusedInputs({i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto indptr = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
|
@ -369,7 +369,7 @@ REG_BPROP_BUILDER("MakeCSRTensor").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {ib->ZerosLike(indptr), ib->ZerosLike(indices), dout_values, dout_shape};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRTensorGetIndptr").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRTensorGetIndptr").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto csr_tensor = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto csr_tensor_indices = ib->TupleGetItem(csr_tensor, kIndex1);
|
||||
|
@ -378,7 +378,7 @@ REG_BPROP_BUILDER("CSRTensorGetIndptr").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->MakeTuple({dout, ib->ZerosLike(csr_tensor_indices), ib->ZerosLike(csr_tensor_values), csr_tensor_shape})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRTensorGetIndices").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRTensorGetIndices").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto csr_tensor = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto csr_tensor_indptr = ib->TupleGetItem(csr_tensor, kIndex0);
|
||||
|
@ -387,7 +387,7 @@ REG_BPROP_BUILDER("CSRTensorGetIndices").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->MakeTuple({ib->ZerosLike(csr_tensor_indptr), dout, ib->ZerosLike(csr_tensor_values), csr_tensor_shape})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRTensorGetValues").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRTensorGetValues").SetUnusedInputs({i1}).SetBody(BODYFUNC(ib) {
|
||||
auto csr_tensor = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto csr_tensor_indptr = ib->TupleGetItem(csr_tensor, kIndex0);
|
||||
|
@ -396,12 +396,12 @@ REG_BPROP_BUILDER("CSRTensorGetValues").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return {ib->MakeTuple({ib->ZerosLike(csr_tensor_indptr), ib->ZerosLike(csr_tensor_indices), dout, csr_tensor_shape})};
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "CSRTensorGetDenseShape" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads only input 0 (csr_tensor); inputs 1 and 2 are never read, which matches
// SetUnusedInputs({i1, i2}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRTensorGetDenseShape").SetUnusedInputs({i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
auto csr_tensor = ib->GetInput(kIndex0);
|
||||
// The only gradient returned is ZerosLike of the whole csr_tensor input.
return {ib->ZerosLike(csr_tensor)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetUnusedInputs({i5}).SetBody(BODYFUNC(ib) {
|
||||
auto shape = ib->GetInput(kIndex0);
|
||||
auto batch = ib->GetInput(kIndex1);
|
||||
auto indptr = ib->GetInput(kIndex2);
|
||||
|
@ -414,7 +414,7 @@ REG_BPROP_BUILDER("CSRSparseMatrixToDense").SetBody([](const BpropIRBuilder *ib)
|
|||
ib->TupleGetItem(res, kIndex3), ib->TupleGetItem(res, kIndex4)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto out = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
|
@ -446,7 +446,7 @@ REG_BPROP_BUILDER("DenseToCSRSparseMatrix").SetBody([](const BpropIRBuilder *ib)
|
|||
return {ib->Emit("CSRSparseMatrixToDense", {shape, batch_ptr, row_ptr, col_ind, dvalue}), ib->ZerosLike(indices)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSoftmax").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSoftmax").SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto values = ib->GetInput(kIndex1);
|
||||
auto shape = ib->GetInput(kIndex2);
|
||||
|
@ -463,7 +463,7 @@ REG_BPROP_BUILDER("SparseSoftmax").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return {ib->ZerosLike(indices), grad_x, ib->ZerosLike(shape)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseTensorToCSRSparseMatrix").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseTensorToCSRSparseMatrix").SetUnusedInputs({i0, i1, i2, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
auto dx = ib->Emit("CSRSparseMatrixToSparseTensor",
|
||||
{ib->TupleGetItem(dout, kIndex0), ib->TupleGetItem(dout, kIndex1), ib->TupleGetItem(dout, kIndex2),
|
||||
|
@ -471,7 +471,7 @@ REG_BPROP_BUILDER("SparseTensorToCSRSparseMatrix").SetBody([](const BpropIRBuild
|
|||
return {ib->TupleGetItem(dx, kIndex0), ib->TupleGetItem(dx, kIndex1), ib->TupleGetItem(dx, kIndex2)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("CSRSparseMatrixToSparseTensor").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("CSRSparseMatrixToSparseTensor").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex6);
|
||||
auto dx = ib->Emit("SparseTensorToCSRSparseMatrix", {ib->TupleGetItem(dout, kIndex0), ib->TupleGetItem(dout, kIndex1),
|
||||
ib->TupleGetItem(dout, kIndex2)});
|
||||
|
@ -518,40 +518,40 @@ NodePtrList CommonSparseSegmentBpropForCpu(const BpropIRBuilder *ib, bool with_s
|
|||
return result;
|
||||
}
|
||||
|
||||
// Gradient builder registration for the "SparseSegmentSqrtN" op; the body delegates entirely to
// CommonSparseSegmentBprop with grad op name "SparseSegmentSqrtNGrad" and the flag set to false.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// NOTE(review): SetUnusedInputs lists i0..i4, yet the helper receives ib and its GetInput usage is not
// visible here — verify the helper really reads none of these inputs.
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtN").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", false);
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "SparseSegmentSqrtNWithNumSegments" op; the body delegates
// entirely to CommonSparseSegmentBprop with grad op name "SparseSegmentSqrtNGrad" and the flag set to
// true.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// NOTE(review): SetUnusedInputs lists i0..i5, yet the helper receives ib and its GetInput usage is not
// visible here — verify the helper really reads none of these inputs.
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSegmentSqrtNWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSqrtNGrad", true);
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "SparseSegmentSum" op. The body picks one of two helpers based
// on the target reported by ib->GetTargetFromContext().
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// NOTE(review): SetUnusedInputs lists i0..i4, yet both helpers receive ib and their GetInput usage is
// not visible here — verify neither helper reads these inputs.
REG_BPROP_BUILDER("SparseSegmentSum").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSegmentSum").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
// When the target equals kGPUDevice, use the generic helper with the "SparseSegmentSumGrad" op name.
if (ib->GetTargetFromContext() == kGPUDevice) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", false);
|
||||
}
|
||||
// Every other target goes through the helper named for CPU.
return CommonSparseSegmentBpropForCpu(ib, false);
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "SparseSegmentSumWithNumSegments" op. The body picks one of
// two helpers based on the target reported by ib->GetTargetFromContext(), passing true as the flag.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// NOTE(review): SetUnusedInputs lists i0..i5, yet both helpers receive ib and their GetInput usage is
// not visible here — verify neither helper reads these inputs.
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseSegmentSumWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
|
||||
// When the target equals kGPUDevice, use the generic helper with the "SparseSegmentSumGrad" op name.
if (ib->GetTargetFromContext() == kGPUDevice) {
|
||||
return CommonSparseSegmentBprop(ib, "SparseSegmentSumGrad", true);
|
||||
}
|
||||
// Every other target goes through the helper named for CPU.
return CommonSparseSegmentBpropForCpu(ib, true);
|
||||
});
|
||||
|
||||
// Gradient builder registration for the "SparseTensorDenseAdd" op.
// NOTE(review): two registration headers appear back to back (plain-lambda SetBody form, then
// SetUnusedInputs + BODYFUNC form); this looks like the removed/added pair of a diff view — confirm
// which one is live.
// The body reads inputs 0 (x1_indices), 2 (x1_shape) and 5 (dout); inputs 1, 3 and 4 are never read,
// which matches SetUnusedInputs({i1, i3, i4}) assuming iN names input position N — TODO confirm.
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseTensorDenseAdd").SetUnusedInputs({i1, i3, i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
// Four gradients: zeros for x1_indices, GatherNd(dout, x1_indices), zeros for x1_shape, and dout
// itself passed through as the last entry.
return {ib->ZerosLike(x1_indices), ib->Emit("GatherNd", {dout, x1_indices}), ib->ZerosLike(x1_shape), dout};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseSegmentMeanWithNumSegments").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
// Backward builder for SparseSegmentMeanWithNumSegments: delegates to the shared
// sparse-segment helper, asking it to emit the "SparseSegmentMeanGrad" op.
REG_BPROP_BUILDER("SparseSegmentMeanWithNumSegments").SetUnusedInputs({i0, i1, i2, i3, i4, i5}).SetBody(BODYFUNC(ib) {
  // NOTE(review): the helper's third argument presumably selects the
  // "WithNumSegments" input layout; its signature is outside this view.
  const bool with_num_segments = true;
  return CommonSparseSegmentBprop(ib, "SparseSegmentMeanGrad", with_num_segments);
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseReorder").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseReorder").SetUnusedInputs({i1, i3}).SetBody(BODYFUNC(ib) {
|
||||
auto indices = ib->GetInput(kIndex0);
|
||||
auto shape = ib->GetInput(kIndex2);
|
||||
auto dout = ib->GetInput(kIndex4);
|
||||
|
@ -573,7 +573,7 @@ REG_BPROP_BUILDER("SparseReorder").SetBody([](const BpropIRBuilder *ib) -> NodeP
|
|||
return res;
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseDenseCwiseMul").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseDenseCwiseMul").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_values = ib->GetInput(kIndex1);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
|
@ -593,7 +593,7 @@ REG_BPROP_BUILDER("SparseDenseCwiseMul").SetBody([](const BpropIRBuilder *ib) ->
|
|||
return d_all;
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("SparseDenseCwiseDiv").SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
REG_BPROP_BUILDER("SparseDenseCwiseDiv").SetUnusedInputs({i4}).SetBody(BODYFUNC(ib) {
|
||||
auto x1_indices = ib->GetInput(kIndex0);
|
||||
auto x1_values = ib->GetInput(kIndex1);
|
||||
auto x1_shape = ib->GetInput(kIndex2);
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "frontend/optimizer/ad/grad.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
#include "pipeline/jit/pass.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
@ -35,6 +36,15 @@ namespace {
|
|||
const mindspore::HashSet<std::string> kHookOp = {"HookBackward", "CellBackwardHook"};
|
||||
const char kGrad[] = "grad";
|
||||
|
||||
void ClearDeviceAddress(const ValuePtr &value) {
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(value, &tensors);
|
||||
for (auto tensor : tensors) {
|
||||
tensor->set_device_address(nullptr);
|
||||
tensor->set_is_forward_output(false);
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetCellId(const py::object &obj, const py::args &args, const InputArgsInfoPtr &input_args_info) {
|
||||
auto cell_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
|
||||
auto fn = [&cell_id](const abstract::AbstractBasePtr &abs) {
|
||||
|
@ -443,20 +453,24 @@ void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_
|
|||
if (input_args_info->input_size != 0 && input_value.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Input value is empty";
|
||||
}
|
||||
AbstractBasePtrList abs_list;
|
||||
for (size_t i = 0; i < input_args_info->input_size; ++i) {
|
||||
const auto &v = input_value[i];
|
||||
if (!PyNativeAlgo::Common::IsTensor(v)) {
|
||||
continue;
|
||||
}
|
||||
auto new_param = curr_g()->add_parameter();
|
||||
(void)input_param_values.emplace_back(v);
|
||||
auto cloned_v = ShallowCopyTensorValue(v);
|
||||
ClearDeviceAddress(cloned_v);
|
||||
(void)input_param_values.emplace_back(cloned_v);
|
||||
auto param_i_abs = v->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(param_i_abs);
|
||||
param_i_abs = param_i_abs->Broaden();
|
||||
new_param->set_abstract(param_i_abs);
|
||||
(void)abs_list.emplace_back(param_i_abs);
|
||||
top_cell()->SetParamNodeMapInGraphInfoMap(input_args_info->input_arg_id_vec[i], new_param);
|
||||
}
|
||||
top_cell()->set_auto_grad_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values));
|
||||
top_cell()->set_auto_grad_cell_ptr(ad::GradPynativeCellBegin(curr_g()->parameters(), input_param_values, abs_list));
|
||||
}
|
||||
|
||||
void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_info) {
|
||||
|
@ -577,6 +591,7 @@ void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v, const std::string &
|
|||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
|
||||
auto cloned_value = ShallowCopyTensorValue(v);
|
||||
ClearDeviceAddress(cloned_value);
|
||||
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
|
||||
AsyncUpdateOutputNodeOfTopCell(output_node, cloned_value);
|
||||
} else {
|
||||
|
@ -1552,6 +1567,19 @@ void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNode
|
|||
std::back_inserter(cloned_op_args),
|
||||
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
|
||||
ValuePtr cloned_out = ShallowCopyTensorValue(op_out);
|
||||
const auto &unused_inputs = BpropExpander().GetUnusedInputs(cnode);
|
||||
auto is_unused_index = [&unused_inputs](size_t i) {
|
||||
return std::find(unused_inputs.begin(), unused_inputs.end(), i) != unused_inputs.end();
|
||||
};
|
||||
for (size_t i = 0; i < cloned_op_args.size(); i++) {
|
||||
if (is_unused_index(i)) {
|
||||
ClearDeviceAddress(cloned_op_args[i]);
|
||||
}
|
||||
}
|
||||
if (is_unused_index(cloned_op_args.size())) {
|
||||
ClearDeviceAddress(cloned_out);
|
||||
}
|
||||
|
||||
auto grad_param =
|
||||
std::make_shared<ad::GradParam>(cnode, cloned_op_args, cloned_out, nullptr, !top_cell()->is_high_order_top_cell());
|
||||
auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
|
||||
|
|
Loading…
Reference in New Issue