forked from mindspore-Ecosystem/mindspore
!30211 [MS]Add Higher Order Differentiation
Merge pull request !30211 from chenzhuo/taylor_v1
commit 7f9056ad69
@ -944,6 +944,80 @@ REGISTER_PYBIND_DEFINE(VmapOperation_, ([](const py::module *m) {
                         .def(py::init<std::string &>(), py::arg("fn"));
                       }));

TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
  // def Taylor(func:read):
  signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
}

FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) {
  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(k);
  MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
    inputs.push_back(k_child->add_parameter());
  }
  // Taylor(fn)(input params)
  auto k_app = k_child->NewCNodeInOrder(inputs);

  k_child->set_output(k_app);
  return k_child;
}

// Generate the graph to calculate higher order derivatives.
FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
                      << args_spec_list[0]->ToString();
  }

  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  FuncGraphPtr grad_fg = nullptr;
  MS_LOG(INFO) << "'TaylorOperation' forward_graph" << forward_graph->debug_info();
  grad_fg = std::make_shared<FuncGraph>();
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "taylorgrad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = grad_fg->add_parameter();

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimTaylor));
  inputs.push_back(param_graph);
  // Taylor(fn)
  auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
  }
  grad_fg->set_output(NewValueNode(k_child));
  // return Taylor(fn)(inputs)
  return grad_fg;
}

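Taken together, `GenerateFuncGraph` and `GetTaylorGrad` wrap the forward graph into a new graph that first marks it with `kPrimTaylor` and then calls the marked graph on the original forward parameters. A rough Python sketch of the wrapper being built (illustrative only; `taylor_mark` stands in for the `kPrimTaylor` primitive and is not a real API):

```python
# Illustrative sketch of the graph built above, not MindSpore API.
def taylorgrad(fn, taylor_mark):
    """Equivalent shape of the generated graph: lambda *args: Taylor(fn)(*args)."""
    marked_fn = taylor_mark(fn)      # Taylor(fn)

    def k_child(*args):              # one parameter per forward input
        return marked_fn(*args)      # Taylor(fn)(input params)

    return k_child
```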
REGISTER_PYBIND_DEFINE(TaylorOperation_, ([](const py::module *m) {
                         (void)py::class_<TaylorOperation, MetaFuncGraph, std::shared_ptr<TaylorOperation>>(
                           *m, "TaylorOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"));
                       }));

// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();

@ -166,6 +166,17 @@ class GradOperation : public MetaFuncGraph {
};
using GradOperationPtr = std::shared_ptr<GradOperation>;

class TaylorOperation : public MetaFuncGraph {
 public:
  explicit TaylorOperation(const std::string &name);
  ~TaylorOperation() override = default;
  MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph);
  FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params);

  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
using TaylorOperationPtr = std::shared_ptr<TaylorOperation>;

class ListMap {
 public:
  explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }

@ -738,6 +738,25 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
  return AbstractFunction::MakeAbstractFunction(jv);
}

AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
  // args: An object of AbstractFunction.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  MS_LOG(DEBUG) << "evaluate Taylor: " << args_spec_list[0]->ToString();

  AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(x);

  AbstractFuncAtomPtrList taylor_v;
  auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
    auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
    taylor_v.push_back(taylor_closure);
  };
  x->Visit(build_taylor_v);

  return AbstractFunction::MakeAbstractFunction(taylor_v);
}

AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: func, in_axes, out_axes, device, level.

@ -846,6 +865,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,

@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -21,6 +21,7 @@
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"

@ -22,6 +22,7 @@
#include "base/core_ops.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/vmap_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {

@ -35,8 +36,8 @@ class ExpandMetaFg {
    // to the implementation of `kPrimVmap`.
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandJPrim>());
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandVmapPrim>());
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandTaylorPrim>());
  }

  virtual ~ExpandMetaFg() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);

@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "frontend/optimizer/irpass/taylor_eliminate.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
namespace internal {
|
||||
// White list of ops with taylor rule.
|
||||
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
|
||||
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
|
||||
prim::kPrimExp->name()};
|
||||
// The ops below are excluded when considering taylor rules.
|
||||
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
|
||||
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
|
||||
|
||||
// Cache list of primitive ops which have been replaced by taylor rule.
|
||||
mindspore::HashMap<PrimitivePtr, FuncGraphPtr> taylor_ops_cache_;
|
||||
|
||||
FuncGraphPtr GetTaylorRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
|
||||
// Set a child scope named "grad'PrimitiveName'" for the taylor rule function,
|
||||
// and add "Gradients" to the front.
|
||||
static const std::string gradients_scope = "Gradients/";
|
||||
static const std::string grad_op_child_scope_prefix = "/grad";
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
|
||||
grad_op_child_scope_prefix + prim->name());
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
||||
// Firstly we get taylor rule from mindir. If failed, parse the python function registered.
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
py::function taylor_fn;
|
||||
if (prim->is_base()) {
|
||||
taylor_fn = GetTaylorRuleFunction(prim->name());
|
||||
} else {
|
||||
taylor_fn = prim->cast<PrimitivePyPtr>()->GetTaylorRuleFunction();
|
||||
if (py::isinstance<py::none>(taylor_fn)) {
|
||||
taylor_fn = GetTaylorRuleFunction(prim->name());
|
||||
}
|
||||
}
|
||||
if (!taylor_fn || py::isinstance<py::none>(taylor_fn)) {
|
||||
MS_LOG(INFO) << "Fail to find taylor rule function for " << prim->name() << ". taylor_fn: " << py::str(taylor_fn);
|
||||
return nullptr;
|
||||
}
|
||||
func_graph = parse::ParsePythonCode(taylor_fn);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Fail to parse taylor rule function for " << prim->name() << ".";
|
||||
return nullptr;
|
||||
}
|
||||
auto taylor_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
|
||||
if (taylor_rule_flag) {
|
||||
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
|
||||
}
|
||||
pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
|
||||
(void)parse::ResolveFuncGraph(func_graph, res);
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr GetTaylorPyObj(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
|
||||
auto fg = GetTaylorRule(prim, resources);
|
||||
return fg;
|
||||
}
|
||||
|
||||
FuncGraphPtr GetTaylorPrimitive(const AnfNodePtr &node, const pipeline::ResourceBasePtr &resources) {
|
||||
auto prim_node = GetValueNode<PrimitivePtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(prim_node);
|
||||
auto iter = taylor_ops_cache_.find(prim_node);
|
||||
if (iter != taylor_ops_cache_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
FuncGraphPtr primitive_taylor = GetTaylorPyObj(prim_node, resources);
|
||||
MS_EXCEPTION_IF_NULL(primitive_taylor);
|
||||
taylor_ops_cache_[prim_node] = primitive_taylor;
|
||||
return primitive_taylor;
|
||||
}
|
||||
|
||||
FuncGraphPtr TaylorFunctor(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
|
||||
const auto &value_nodes = func_graph->value_nodes();
|
||||
auto manager = resources->manager();
|
||||
manager->AddFuncGraph(func_graph);
|
||||
std::vector<AnfNodePtr> taylor_node_list;
|
||||
for (const auto &value_pair : value_nodes) {
|
||||
auto node = value_pair.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<Primitive>(node)) {
|
||||
auto prim_node = GetValueNode<PrimitivePtr>(node);
|
||||
if (taylor_ops.count(prim_node->name())) {
|
||||
taylor_node_list.push_back(node);
|
||||
} else if (!taylor_exception_ops.count(prim_node->name())) {
|
||||
MS_LOG(EXCEPTION) << "The operation " << prim_node->name()
|
||||
<< " is not supported in taylor higher order differentiation currently.";
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < taylor_node_list.size(); i++) {
|
||||
FuncGraphPtr taylor_node_graph = GetTaylorPrimitive(taylor_node_list[i], resources);
|
||||
manager->Replace(taylor_node_list[i], NewValueNode(taylor_node_graph));
|
||||
}
|
||||
taylor_ops_cache_.clear();
|
||||
MS_LOG(INFO) << "return replaced taylor node: " << func_graph->ToString() << " replace end.";
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
AnfNodePtr ExpandTaylor(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
|
||||
if (IsValueNode<FuncGraph>(vnode)) {
|
||||
ScopeGuard scope_guard(vnode->scope());
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandTaylor now";
|
||||
auto newfg = TaylorFunctor(func_graph, resource);
|
||||
return NewValueNode(newfg);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
bool ExpandTaylorPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
|
||||
// Search all taylor nodes.
|
||||
bool change = false;
|
||||
auto manager = optimizer->manager();
|
||||
for (auto &taylor_node : prim_nodes_) {
|
||||
auto taylor_fg_node = taylor_node->input(1);
|
||||
auto taylor_fg = GetValueNode<FuncGraphPtr>(taylor_fg_node);
|
||||
if (taylor_fg == nullptr) {
// taylor_fg is null here, so report the input node instead of dereferencing the null pointer.
MS_LOG(EXCEPTION) << "Unexpected Taylor node, input func graph should not be null, node: "
<< taylor_fg_node->DebugString();
}
|
||||
// Copy original forward graph in case of the influence of usage in other place.
|
||||
auto taylor_fg_copy = BasicClone(taylor_fg, true);
|
||||
manager->AddFuncGraph(taylor_fg_copy);
|
||||
auto taylor_fg_copy_node = NewValueNode(taylor_fg_copy);
|
||||
// Return expanded taylor graph.
|
||||
auto expanded_taylor = internal::ExpandTaylor(taylor_fg_copy_node->cast<ValueNodePtr>(), optimizer->resource());
|
||||
manager->Replace(taylor_node, expanded_taylor);
|
||||
change = true;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "include/common/utils/primitive_utils.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/ad/grad.h"
|
||||
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimTaylor, C}
|
||||
class ExpandTaylorPrim : public ExpandMetaFGPrim {
|
||||
public:
|
||||
ExpandTaylorPrim() { prim_ = prim::kPrimTaylor; }
|
||||
virtual ~ExpandTaylorPrim() = default;
|
||||
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
|
|
@ -353,6 +353,7 @@ constexpr char SIMPLE_MEAN[] = "SimpleMean";
|
|||
constexpr char FLATTEN[] = "Flatten";
|
||||
constexpr char J[] = "J";
|
||||
constexpr char SHARD[] = "Shard";
|
||||
constexpr char Taylor[] = "Taylor";
|
||||
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
|
||||
constexpr char COS[] = "Cos";
|
||||
constexpr char ACOS[] = "ACos";
|
||||
|
|
|
@ -30,6 +30,10 @@ COMMON_EXPORT py::function GetBpropFunctionByObj(const py::object &obj);
|
|||
|
||||
COMMON_EXPORT py::function GetBpropFunction(const std::string &name);
|
||||
|
||||
COMMON_EXPORT py::function GetTaylorRuleFunctionByObj(const py::object &obj);
|
||||
|
||||
COMMON_EXPORT py::function GetTaylorRuleFunction(const std::string &name);
|
||||
|
||||
COMMON_EXPORT py::function GetComputeFunction(const std::string &name);
|
||||
|
||||
COMMON_EXPORT BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
|
||||
|
|
|
@ -52,6 +52,7 @@
|
|||
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/shard_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/taylor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/parameter_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
|
|
|
@ -572,6 +572,27 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
|
||||
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result != nullptr) {
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
// Call the original evaluator, get the result: y = f(x)
|
||||
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
|
|
@ -372,6 +372,38 @@ class JEvaluator : public Evaluator {
|
|||
AbstractFunctionPtr orig_func_;
|
||||
};
|
||||
|
||||
class TaylorEvaluator : public Evaluator {
|
||||
public:
|
||||
TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
|
||||
: Evaluator("TaylorEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
|
||||
~TaylorEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(TaylorEvaluator, Evaluator);
|
||||
AnfNodePtr bound_node() const override {
|
||||
if (evaluator_ != nullptr) {
|
||||
return evaluator_->bound_node();
|
||||
}
|
||||
return bound_node_.lock();
|
||||
}
|
||||
|
||||
void set_bound_node(const AnfNodePtr &node) override {
|
||||
if (evaluator_ != nullptr) {
|
||||
evaluator_->set_bound_node(node);
|
||||
}
|
||||
bound_node_ = AnfNodeWeakPtr(node);
|
||||
}
|
||||
|
||||
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
|
||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||
}
|
||||
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||
|
||||
private:
|
||||
EvaluatorPtr evaluator_;
|
||||
AbstractFunctionPtr orig_func_;
|
||||
};
|
||||
|
||||
class ShardEvaluator : public Evaluator {
|
||||
public:
|
||||
ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
|
||||
|
|
|
@ -501,6 +501,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransfor
|
|||
return vmap_evaluator;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
AbstractFunctionPtr func_orig = func->fn();
|
||||
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
|
||||
auto taylorevaluator = std::make_shared<TaylorEvaluator>(evaluator_orig, func_orig);
|
||||
return taylorevaluator;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
AbstractFunctionPtr func_orig = func->fn();
|
||||
|
@ -548,6 +556,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|||
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
|
||||
} else if (func->isa<VmapTransformedAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<VmapTransformedAbstractClosure>>());
|
||||
} else if (func->isa<TaylorTransformedAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<TaylorTransformedAbstractClosure>>());
|
||||
} else if (func->isa<ShardTransformedAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
|
||||
} else if (func->isa<VirtualAbstractClosure>()) {
|
||||
|
|
|
@ -324,6 +324,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &fn);
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);
|
||||
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &fn);
|
||||
|
||||
|
|
|
@ -151,6 +151,18 @@ py::function PrimitivePy::GetBpropFunction() {
|
|||
auto fn = GetBpropFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
}
|
||||
auto fn = GetBpropFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetTaylorRuleFunction() {
|
||||
static const char *const get_taylor_rule_func_name = "get_taylor_rule";
|
||||
if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
|
||||
py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
|
||||
return fn;
|
||||
}
|
||||
auto fn = GetTaylorRuleFunctionByObj(python_obj_);
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
|
||||
|
|
|
@ -50,6 +50,7 @@ class PrimitivePy : public Primitive {
|
|||
const bool parse_info_ = true;
|
||||
py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0);
|
||||
py::function GetBpropFunction();
|
||||
py::function GetTaylorRuleFunction();
|
||||
void set_signatures(const std::vector<Signature> &signatures);
|
||||
const std::vector<Signature> &signatures() const { return signatures_; }
|
||||
const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; }
|
||||
|
|
|
@ -41,6 +41,18 @@ py::function GetBpropFunction(const std::string &name) {
|
|||
return fn;
|
||||
}
|
||||
|
||||
py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
|
||||
static const std::string get_taylor_fprop_fn = "get_taylor_fprop_fn";
|
||||
static const std::string ad_module = "mindspore.ops._grad";
|
||||
py::function fn = python_adapter::GetPyFn(ad_module, get_taylor_fprop_fn)(obj);
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::function GetTaylorRuleFunction(const std::string &name) {
|
||||
auto fn = GetTaylorRuleFunctionByObj(py::str(name));
|
||||
return fn;
|
||||
}
|
||||
|
||||
py::function GetComputeFunction(const std::string &name) {
|
||||
static const std::string module = "mindspore._extends.builtin_operations";
|
||||
py::module mod = py::module::import(common::SafeCStr(module));
|
||||
|
|
|
@ -297,6 +297,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
|
|||
return hash_value;
|
||||
}
|
||||
|
||||
bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<TaylorTransformedAbstractClosure>()) {
|
||||
return false;
|
||||
}
|
||||
auto other_transformed = static_cast<const TaylorTransformedAbstractClosure *>(&other);
|
||||
return fn_ == other_transformed->fn_;
|
||||
}
|
||||
|
||||
std::size_t TaylorTransformedAbstractClosure::hash() const {
|
||||
MS_EXCEPTION_IF_NULL(fn_);
|
||||
auto hash_value = hash_combine(tid(), fn_->hash());
|
||||
return hash_value;
|
||||
}
|
||||
|
||||
bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<ShardTransformedAbstractClosure>()) {
|
||||
return false;
|
||||
|
|
|
@ -339,6 +339,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
|
|||
AbstractFuncAtomPtr fn_;
|
||||
};
|
||||
|
||||
/// \brief TaylorTransformedAbstractClosure defines interface for abstract of Function
|
||||
/// transformed through the application of Taylor.
|
||||
class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom {
|
||||
public:
|
||||
/// \brief Constructor of TaylorTransformedAbstractClosure
|
||||
///
|
||||
/// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor.
|
||||
explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
|
||||
|
||||
/// \brief Destructor of TaylorTransformedAbstractClosure
|
||||
~TaylorTransformedAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
/// \brief Get the AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
|
||||
///
|
||||
/// \return The AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
|
||||
AbstractFuncAtomPtr fn() { return fn_; }
|
||||
|
||||
AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); }
|
||||
|
||||
bool operator==(const AbstractFunction &other) const override;
|
||||
|
||||
std::size_t hash() const override;
|
||||
|
||||
std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; }
|
||||
|
||||
private:
|
||||
AbstractFuncAtomPtr fn_;
|
||||
};
|
||||
|
||||
/// \brief ShardTransformedAbstractClosure defines interface for abstract of Function
|
||||
/// transformed through the application of Shard.
|
||||
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
|
||||
|
|
|
@ -171,6 +171,7 @@ constexpr auto kCSRDiv = "CSRDiv";
|
|||
// Meta Function Graph
|
||||
constexpr auto kJ = "J";
|
||||
constexpr auto kVmap = "Vmap";
|
||||
constexpr auto kTaylor = "Taylor";
|
||||
|
||||
// Others
|
||||
constexpr auto kMakeTuple = "MakeTuple";
|
||||
|
@ -828,6 +829,7 @@ GVAR_DEF(PrimitivePtr, kPrimStateSetItem, std::make_shared<Primitive>("state_set
|
|||
GVAR_DEF(PrimitivePtr, kPrimJ, std::make_shared<Primitive>(kJ, kSideEffectPropagate));
|
||||
GVAR_DEF(PrimitivePtr, kPrimVmap, std::make_shared<Primitive>(kVmap, kSideEffectPropagate));
|
||||
GVAR_DEF(PrimitivePtr, kPrimShard, std::make_shared<Primitive>("Shard", kSideEffectPropagate));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTaylor, std::make_shared<Primitive>(kTaylor));
|
||||
|
||||
// Used to build graph which have keyword arguments
|
||||
GVAR_DEF(PrimitivePtr, kPrimExtractKeywordArg, std::make_shared<Primitive>("extract_keyword_arg"));
|
||||
|
|
|
@ -313,7 +313,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
|
|||
}
|
||||
}
|
||||
|
||||
// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap.
|
||||
// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap and kPrimTaylor.
|
||||
bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
|
||||
MS_EXCEPTION_IF_NULL(meta_fg_prim_total_);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
@ -704,7 +704,8 @@ void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfN
|
|||
signals_->InvalidateComputer();
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
|
||||
fg->AddMetaFgPrimValueNode(input);
|
||||
}
|
||||
} else if (fg != nullptr && fg != input->func_graph()) {
|
||||
|
@ -725,7 +726,8 @@ void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const An
|
|||
signals_->InvalidateComputer();
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
|
||||
fg->DropMetaFgPrimValueNode(input);
|
||||
}
|
||||
} else if (fg != nullptr && fg != input->func_graph()) {
|
||||
|
@ -1096,7 +1098,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
|
|||
return false;
|
||||
}
|
||||
|
||||
// Check MetaFgPrim (J/Vmap) FuncGraph input.
|
||||
// Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
|
||||
const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
|
||||
if (!meta_fg_prim_values.empty()) {
|
||||
auto contains_meta_fg_prim =
|
||||
|
@ -1107,9 +1109,10 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
|
|||
return func_graph->seen_ != seen_num;
|
||||
}
|
||||
if (IsValueNode<Primitive>(iter.first)) {
|
||||
// Exclude the primitive of MetaFgPrim (J/Vmap) itself.
|
||||
// Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
|
||||
auto prim = GetValueNode<PrimitivePtr>(iter.first);
|
||||
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name());
|
||||
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
|
||||
prim->name() != prim::kPrimTaylor->name());
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
@ -1119,35 +1122,38 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
|
|||
}
|
||||
}
|
||||
|
||||
// Check MetaFgPrim (J/Vmap) CNode as FV.
|
||||
// Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
|
||||
const auto &fv_nodes = fg->free_variables();
|
||||
if (!fv_nodes.empty()) {
|
||||
auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
|
||||
// Check if the FV is a MetaFgPrim (J/Vmap) call CNode.
|
||||
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap)) {
|
||||
// Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
|
||||
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
|
||||
IsPrimitiveCNode(iter.first, prim::kPrimTaylor)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
|
||||
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap) ("
|
||||
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
|
||||
<< contains_meta_fg_prim_cnode->first->DebugString() << ")";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)
|
||||
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
|
||||
// Taylor(func_graph) or Taylor(Primitive).
|
||||
fg->seen_ = seen_num;
|
||||
for (auto &item : fg->func_graphs_used()) {
|
||||
auto used_g = item.first;
|
||||
if (SeekMetaFgPrim(used_g, seen_num)) {
|
||||
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
|
||||
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
|
||||
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
|
||||
<< "Taylor(func_graph) or Taylor(Primitive)";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << fg->ToString()
|
||||
<< " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
|
||||
MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
|
||||
<< "Taylor(func_graph) or Taylor(Primitive)";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@

"""grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops
from .grad_base import get_bprop_fn
    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule
from .grad_base import get_bprop_fn, get_taylor_fprop_fn

__all__ = ['get_bprop_fn']
__all__ = ['get_bprop_fn', 'get_taylor_fprop_fn']

@ -37,8 +37,28 @@ class BpropRegistry(Registry):
        return deco


class TaylorFpropRegistry(Registry):
    """Registry class for registry functions for taylor grad on Primitive or string."""

    def register(self, prim):
        """register the function."""

        def deco(fn):
            """Decorate the function."""
            if isinstance(prim, str):
                self[prim] = fn
            elif issubclass(prim, Primitive):
                self[id(prim)] = fn
                self[prim.__name__] = fn
            return fn

        return deco


bprop_getters = BpropRegistry()
bprops = BpropRegistry()
taylor_fprop_getters = TaylorFpropRegistry()
taylor_fprops = TaylorFpropRegistry()


def get_bprop_fn(prim):

@ -47,3 +67,11 @@ def get_bprop_fn(prim):
    if out:
        return out(prim)
    return bprops.get(prim, None)


def get_taylor_fprop_fn(prim):
    """get taylor function by primitive obj or prim name for c++"""
    out = taylor_fprop_getters.get(prim, None)
    if out:
        return out(prim)
    return taylor_fprops.get(prim, None)

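The new registry mirrors the existing bprop registry: a Python rule factory is registered with `taylor_fprop_getters.register`, and the C++ side later resolves it through `get_taylor_fprop_fn` (via `GetTaylorRuleFunctionByObj`). A minimal sketch of that round trip; the `Square` rule here is purely hypothetical and not part of this patch:

```python
# Hypothetical registration example; P.Square is used only for illustration.
from mindspore.ops import operations as P
from mindspore.ops._grad.grad_base import taylor_fprop_getters, get_taylor_fprop_fn

@taylor_fprop_getters.register(P.Square)
def taylor_square(self):
    """Return a series-propagation function for Square (placeholder body)."""
    def taylor_fprop_square(inputs):
        return inputs * inputs  # placeholder, not a real Taylor rule
    return taylor_fprop_square

rule = get_taylor_fprop_fn("Square")  # what the C++ GetTaylorRuleFunction ends up calling
```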
@ -0,0 +1,166 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Define the taylor rules of operations."""

from mindspore import nn
import mindspore as ms
from ..primitive import Primitive
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import taylor_fprop_getters


def _factorial(order):
    """Return [0!, 1!, 2!,..., order!]."""
    range_op = nn.Range(1, order + 1)
    ones_op = P.Ones()
    concat_op = P.Concat()
    factorial_zero = ones_op(1, ms.float32)
    factorial_positive = range_op().astype(ms.float32)
    for i in range(1, order):
        factorial_positive[i] *= factorial_positive[i - 1]
    factorial = concat_op((factorial_zero, factorial_positive))
    return factorial

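For reference, `_factorial(3)` evaluates to `[1., 1., 2., 6.]`, i.e. `[0!, 1!, 2!, 3!]`; the rules below divide raw derivatives by these factors to obtain the binomial-weighted sums written out further down. A plain-Python equivalent, shown only to make the convention explicit:

```python
# Plain-Python equivalent of _factorial, for illustration only.
import math

def factorial_table(order):
    """Return [0!, 1!, ..., order!] as a list of floats."""
    return [float(math.factorial(k)) for k in range(order + 1)]

assert factorial_table(3) == [1.0, 1.0, 2.0, 6.0]
```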
@taylor_fprop_getters.register(P.Add)
@taylor_fprop_getters.register(P.Sub)
def taylor_add_or_sub(self):
    """Higher order derivatives rule definition for `Add` or `Sub` operation."""
    if isinstance(self, str):
        prim = Primitive(self)
    else:
        prim = self
    def taylor_fprop_add_or_sub(input_x, input_y):
        series = prim(input_x, input_y)
        return series

    return taylor_fprop_add_or_sub


@taylor_fprop_getters.register(P.Mul)
def taylor_mul(self):
    """Higher order derivatives rule definition for `Mul` operation."""
    mul_func = P.Mul()

    def taylor_fprop_mul(input_x, input_y):
        primals = mul_func(input_x[0], input_y[0])
        series_num = len(input_x) - 1
        factorial = _factorial(series_num)
        series = zeros_like(input_x)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(0, k + 1):
                tmp = input_x[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] *= factorial[k]
        return series

    return taylor_fprop_mul

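The double loop is the Leibniz rule for the k-th derivative of a product. With the raw derivatives x^(i), y^(i) stored along the first axis and `factorial[i]` equal to i!, the update above computes:

```latex
(x\,y)^{(k)} \;=\; \sum_{i=0}^{k} \binom{k}{i}\, x^{(i)}\, y^{(k-i)}
\;=\; k! \sum_{i=0}^{k} \frac{x^{(i)}}{i!}\cdot\frac{y^{(k-i)}}{(k-i)!}
```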
@taylor_fprop_getters.register(P.RealDiv)
def taylor_realdiv(self):
    """Higher order derivatives rule definition for `RealDiv` operation."""
    div_op = P.Div()

    def taylor_fprop_realdiv(input_x, input_y):
        primals = div_op(input_x[0], input_y[0])
        series_num = len(input_x) - 1
        factorial = _factorial(series_num)
        series = zeros_like(input_x)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(0, k):
                tmp = series[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] = (input_x[k] - factorial[k] * series[k]) / input_y[0]
        return series

    return taylor_fprop_realdiv

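The division rule inverts the product rule: writing z = x / y as x = z * y and solving the Leibniz expansion for the highest-order term gives the recurrence implemented above,

```latex
z^{(k)} \;=\; \frac{1}{y}\left(x^{(k)} \;-\; \sum_{i=0}^{k-1} \binom{k}{i}\, z^{(i)}\, y^{(k-i)}\right)
```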
@taylor_fprop_getters.register(P.Exp)
def taylor_exp(self):
    """Higher order derivatives rule definition for `Exp` operation."""
    exp_ = P.Exp()

    def taylor_fprop_exp(inputs):
        primals = exp_(inputs[0])
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        series = zeros_like(inputs)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                tmp = i * inputs[i] * series[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] *= factorial[k - 1]
        return series

    return taylor_fprop_exp

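For `Exp` the recurrence follows from differentiating y = e^x once, y' = x'·y, and applying the Leibniz rule to the product x'·y; the loop above computes

```latex
y = e^{x},\qquad
y^{(k)} \;=\; \sum_{i=1}^{k} \binom{k-1}{i-1}\, x^{(i)}\, y^{(k-i)}
\;=\; (k-1)! \sum_{i=1}^{k} \frac{i\, x^{(i)}}{i!}\cdot\frac{y^{(k-i)}}{(k-i)!}
```

The `Sin` and `Cos` rules below apply the same recurrence to the coupled pair (sin x, cos x), since sin' = x'·cos and cos' = -x'·sin.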
@taylor_fprop_getters.register(P.Sin)
def taylor_sin(self):
    """Higher order derivatives rule definition for `Sin` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_sin(inputs):
        primal_sin = sin(inputs[0])
        primal_cos = cos(inputs[0])
        series_sin = zeros_like(inputs)
        series_cos = zeros_like(inputs)
        series_sin[0] = primal_sin
        series_cos[0] = primal_cos
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
                series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
            series_sin[k] *= factorial[k - 1]
            series_cos[k] *= factorial[k - 1]
        return series_sin

    return taylor_fprop_sin


@taylor_fprop_getters.register(P.Cos)
def taylor_cos(self):
    """Higher order derivatives rule definition for `Cos` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_cos(inputs):
        primal_cos = cos(inputs[0])
        primal_sin = sin(inputs[0])
        series_cos = zeros_like(inputs)
        series_sin = zeros_like(inputs)
        series_cos[0] = primal_cos
        series_sin[0] = primal_sin
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
                series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
            series_cos[k] *= factorial[k - 1]
            series_sin[k] *= factorial[k - 1]
        return series_cos

    return taylor_fprop_cos

@ -21,7 +21,7 @@ Pre-defined combination of operators.


from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    core, env_get, tail, zip_operation, Shard, _Vmap
    core, env_get, tail, zip_operation, Shard, _Vmap, _TaylorOperation
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like

@ -22,7 +22,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
    ListSlice_, VmapOperation_
    ListSlice_, VmapOperation_, TaylorOperation_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive

@ -401,6 +401,29 @@ class GradOperation(GradOperation_):
        return self.grad_fn


class _TaylorOperation(TaylorOperation_):
    """
    Generate the higher order derivatives function for the input function.
    """
    def __init__(self):
        """Initialize TaylorOperation."""
        TaylorOperation_.__init__(self, 'taylorgrad')
        self.grad_fn = None
        self.fn = None

    def __call__(self, fn):
        if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
        taylor_grad_ = _TaylorOperation()
        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
        @ms_function
        def after_taylor_grad(*args):
            return taylor_grad_(fn)(*args)
        self.grad_fn = after_taylor_grad
        self.fn = fn
        return self.grad_fn


class _Grad(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function by position for the input function.

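`_TaylorOperation` caches one `ms_function`-wrapped closure per `fn`. The new `jet` and `derivative` helpers in `functional.py` drive it with a stacked input whose row 0 holds the primal values and rows 1..n hold the input series; a rough sketch of that call pattern (based on the functional.py code further down, not an additional API):

```python
# Sketch of how functional.py drives _TaylorOperation (see jet/derivative below).
from mindspore.ops.composite import _TaylorOperation

_taylor = _TaylorOperation()

def propagate_series(fn, stacked_input):
    """stacked_input[0] is the primal, stacked_input[1:] the input series terms."""
    derivative_fn = _taylor(fn)          # expands to Taylor(fn) during compilation
    return derivative_fn(stacked_input)  # primal and output series, stacked the same way
```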
@ -32,7 +32,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad, Shard, _Vmap
from .composite import _Grad, Shard, _Vmap, _TaylorOperation
from .._c_expression import security

typeof = Primitive('typeof')

@ -48,6 +48,7 @@ eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
size = P.Size()
ones = P.Ones()
ones_like = P.OnesLike()
shape = P.Shape()
dyn_shape = P.TensorShape()

@ -285,6 +286,187 @@ def grad(fn, grad_position=0, sens_param=False):
|
|||
return grad_by_position_with_sens(fn, None, grad_position)
|
||||
return grad_by_position(fn, None, grad_position)
|
||||
|
||||
@constexpr
|
||||
def _trans_jet_inputs(primals_item, series_item):
|
||||
"""Trans inputs of jet"""
|
||||
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
|
||||
if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
|
||||
raise TypeError(f"For `F.jet`, the elements' types of primals and series should be the same and belong to "
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
||||
f" {dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
|
||||
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
||||
return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
|
||||
return primals_item, series_item
|
||||
|
||||
@constexpr
|
||||
def _check_jet_inputs(primals, series):
|
||||
"""Check inputs of jet"""
|
||||
if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
|
||||
raise TypeError(f"For 'F.jet', the 'primals' and `series` should be both Tensor or tuple, "
|
||||
f"but got {type(primals).__name__} and {type(series).__name__}.")
|
||||
if isinstance(primals, Tensor):
|
||||
if primals.shape != series.shape[1:]:
|
||||
raise ValueError("The shape of each element should be the same as the primals.")
|
||||
return _trans_jet_inputs(primals, series)
|
||||
if isinstance(primals, tuple):
|
||||
if len(primals) != len(series):
|
||||
raise ValueError("The lengths of primals and series should be the same.")
|
||||
check_primals = []
|
||||
check_series = []
|
||||
for i, j in zip(primals, series):
|
||||
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
|
||||
check_primals.append(trans_primals_item)
|
||||
check_series.append(trans_series_item)
|
||||
return check_primals, check_series
|
||||
|
||||
_taylor = _TaylorOperation()
|
||||
|
||||
def jet(fn, primals, series):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
first to `n`-th order differentiations, original inputs and first to `n`-th order derivative of original inputs
|
||||
must be provided together. Generally, it is recommended to set the values of given first order derivative to 1,
|
||||
while the other to 0.
|
||||
|
||||
Args:
|
||||
fn (Union(Cell, function)): Function to do TaylorOperation.
|
||||
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
|
||||
series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of series should be the same as inputs.
|
||||
For each Tensor, the length of first dimension `i` represents the `1` to `i+1`-th order of derivative of
|
||||
output with respect to the inputs will be figured out.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of out_primals and out_series.
|
||||
|
||||
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
|
||||
- **out_series** (Tensors or List of Tensors) - The `1` to `i+1`-th order of derivative of output with respect
|
||||
to the inputs.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.context as context
|
||||
>>> import mindspore.ops as P
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops.functional import jet
|
||||
>>> context.set_context(mode=context.GRAPH_MODE)
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super().__init__()
|
||||
... self.sin = P.Sin()
|
||||
... self.exp = P.Exp()
|
||||
... def construct(self, x):
|
||||
... out1 = self.sin(x)
|
||||
... out2 = self.exp(out1)
|
||||
... return out2
|
||||
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
|
||||
>>> net = Net()
|
||||
>>> out_primals, out_series = jet(net, primals, series)
|
||||
>>> print(out_primals, out_series)
|
||||
"""
|
||||
primals, series = _check_jet_inputs(primals, series)
|
||||
derivative_fn = _taylor(fn)
|
||||
concat_op = P.Concat()
|
||||
if isinstance(primals, list) and list_len(primals) > 1:
|
||||
inputs = list(map(lambda x, y: concat_op(((expand_dims(x, 0), y))), primals, series))
|
||||
outputs = derivative_fn(*inputs)
|
||||
else:
|
||||
inputs = concat_op((expand_dims(primals, 0), series))
|
||||
outputs = derivative_fn(inputs)
|
||||
if isinstance(outputs, list) and list_len(outputs) > 1:
|
||||
out_primals = [element[0] for element in outputs]
|
||||
out_series = [element[1:] for element in outputs]
|
||||
else:
|
||||
out_primals = outputs[0]
|
||||
out_series = outputs[1:]
|
||||
return out_primals, out_series
|
||||
|
||||
|
||||
@constexpr
|
||||
def _trans_derivative_inputs(primals_item):
|
||||
"""Trans inputs of derivative"""
|
||||
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
|
||||
if not dtype(primals_item) in value_type:
|
||||
raise TypeError(f"For `F.derivative`, the elements of primals should belong to "
|
||||
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
|
||||
f" {dtype(primals_item).__name__}.")
|
||||
if dtype(primals_item) in [mstype.int32, mstype.int64]:
|
||||
return cast(primals_item, mstype.float64)
|
||||
return primals_item
|
||||
|
||||
|
||||
def derivative(fn, primals, order):
|
||||
"""
|
||||
This function is designed to calculate the higher order differentiation of given composite function. To figure out
|
||||
`order`-th order differentiations, original inputs and order must be provided together. In particular, the value of
|
||||
input first order derivative is set to 1, while the other to 0.
|
||||
|
||||
Args:
|
||||
fn (Union(Cell, function)): Function to do TaylorOperation.
|
||||
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
|
||||
order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
|
||||
figured out.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of out_primals and out_series.
|
||||
|
||||
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
|
||||
- **out_series** (Tensors or List of Tensors) - The `order`-th order of derivative of output with respect
|
||||
to the inputs.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.context as context
|
||||
>>> import mindspore.ops as P
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops.functional import derivative
|
||||
>>> context.set_context(mode=context.GRAPH_MODE)
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super().__init__()
|
||||
... self.sin = P.Sin()
|
||||
... self.exp = P.Exp()
|
||||
... def construct(self, x):
|
||||
... out1 = self.sin(x)
|
||||
... out2 = self.exp(out1)
|
||||
... return out2
|
||||
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
>>> order = 3
|
||||
>>> net = Net()
|
||||
>>> out_primals, out_series = derivative(net, primals, order)
|
||||
>>> print(out_primals, out_series)
|
||||
"""
|
||||
derivative_fn = _taylor(fn)
|
||||
concat_op = P.Concat()
|
||||
series_one = 1
|
||||
if isinstance(primals, tuple):
|
||||
trans_primals = [_trans_derivative_inputs(item) for item in primals]
|
||||
inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
|
||||
if order > 1:
|
||||
inputs = list(map(lambda x: concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))), inputs))
|
||||
outputs = derivative_fn(*inputs)
|
||||
else:
|
||||
primals = _trans_derivative_inputs(primals)
|
||||
series = zeros((order,) + primals.shape, dtype(primals))
|
||||
series[0] = series_one
|
||||
inputs = concat_op((expand_dims(primals, 0), series))
|
||||
outputs = derivative_fn(inputs)
|
||||
if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
|
||||
out_primals = [element[0] for element in outputs]
|
||||
out_series = [element[-1] for element in outputs]
|
||||
else:
|
||||
out_primals = outputs[0]
|
||||
out_series = outputs[-1]
|
||||
return out_primals, out_series
|
||||
|
||||
|
||||
def jvp(fn, inputs, v):
|
||||
"""
|
||||
|
|
|
@ -38,6 +38,11 @@ class MultipleInputsOutputNet(nn.Cell):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_single_input_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
net = SingleInputNet()
|
||||
|
@ -53,6 +58,11 @@ def test_vjp_single_input_graph():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_vjp_multiple_inputs_default_v_graph():
|
||||
"""
|
||||
Features: Function vjp
|
||||
Description: Test vjp with single input, single output and default v in graph mode.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
|
||||
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||
|
@ -140,8 +150,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
|
|||
net_out, vjp_out = vjp(self.net, inputs, vectors)
|
||||
return net_out, vjp_out
|
||||
|
||||
test_net = Net(SingleInputNet())
|
||||
primal, grad = test_net(x, v)
|
||||
test_net_graph = Net(SingleInputNet())
|
||||
primal, grad = test_net_graph(x, v)
|
||||
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
|
||||
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
|
||||
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
|
||||
|
|
|
@ -18,7 +18,6 @@ import pytest
|
|||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import ms_function
|
||||
from mindspore.ops.functional import vjp
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
@ -37,12 +36,17 @@ class MultipleInputsOutputNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
def test_vjp_single_input_pynative():
    """
    Features: Function vjp
    Description: Test vjp with single input, single output and default v in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    primal, grad = vjp(net, x, v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -51,58 +55,38 @@ def test_vjp_single_input_graph():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_inputs_default_v_graph():
def test_vjp_multiple_inputs_default_v_pynative():
    """
    Features: Function vjp
    Description: Test vjp with multiple inputs, multiple outputs and default v in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputsOutputNet()
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    primal, grad = vjp(net, (x, y), (v, v))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_ms_function_single_input_single_output_default_v_graph():
def test_vjp_input_function_single_input_single_output_default_v_pynative():
    """
    Features: Function vjp
    Description: Test vjp with ms_function, single input, single output and default v in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()

    @ms_function
    def vjp_with_ms_function(inputs, vectors):
        output, vjp_grad = vjp(net, inputs, vectors)
        return output, vjp_grad

    primal, grad = vjp_with_ms_function(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_input_function_single_input_single_output_default_v_graph():
    """
    Features: Function vjp
    Description: Test vjp with function, single input, single output and default v in graph mode.
    Description: Test vjp with function, single input, single output and default v in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -112,8 +96,8 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
        return inputs**3

    primal, grad = vjp(test_function, x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
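As a sanity check on the constants used throughout these vjp tests: for the cubing function above, vjp returns the primal x**3 together with the vector-Jacobian product 3 * x**2 * v, which reproduces the expected tensors exactly (plain NumPy, outside MindSpore):

import numpy as np

x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
v = np.ones_like(x)
print(x ** 3)          # [[ 1.  8.] [27. 64.]] -> expect_primal
print(3 * x ** 2 * v)  # [[ 3. 12.] [27. 48.]] -> expect_grad
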
@ -121,10 +105,10 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_construct_single_input_single_output_default_v_graph():
def test_vjp_construct_single_input_single_output_default_v_pynative():
    """
    Features: Function vjp
    Description: Test vjp with function, single input, single output and default v in graph mode.
    Description: Test vjp with function, single input, single output and default v in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -139,8 +123,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
            net_out, vjp_out = vjp(self.net, inputs, vectors)
            return net_out, vjp_out

    test_net = Net(SingleInputNet())
    primal, grad = test_net(x, v)
    test_net_pynative = Net(SingleInputNet())
    primal, grad = test_net_pynative(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in graph mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative

context.set_context(mode=context.GRAPH_MODE)


class MultipleInputSingleOutputNet(nn.Cell):
    def __init__(self):
        super(MultipleInputSingleOutputNet, self).__init__()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.exp = P.Exp()

    def construct(self, x, y):
        out1 = self.sin(x)
        out2 = self.cos(y)
        out3 = out1 * out2 + out1 / out2
        out = self.exp(out3)
        return out


class SingleInputSingleOutputNet(nn.Cell):
    def __init__(self):
        super(SingleInputSingleOutputNet, self).__init__()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.exp = P.Exp()

    def construct(self, x):
        out1 = self.sin(x)
        out2 = self.cos(out1)
        out3 = self.exp(out2)
        out = out1 + out2 - out3
        return out

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_single_input_single_output_graph_mode():
    """
    Features: Function jet
    Description: Test jet with single input in graph mode.
    Expectation: No exception.
    """
    primals = Tensor([1., 1.])
    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
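The expected constants can be reproduced by hand: SingleInputSingleOutputNet computes f(x) = sin(x) + cos(sin(x)) - exp(cos(sin(x))), and with the series seed (1, 0, 0) the first row of out_series is simply f'(x). A quick NumPy check (the closed-form derivative below is worked out manually for verification only, not produced by jet):

import numpy as np

f = lambda t: np.sin(t) + np.cos(np.sin(t)) - np.exp(np.cos(np.sin(t)))
df = lambda t: np.cos(t) * (1 - np.sin(np.sin(t)) + np.exp(np.cos(np.sin(t))) * np.sin(np.sin(t)))
print(f(1.0))   # ~ -0.43931 -> expected_primals
print(df(1.0))  # ~  0.92187 -> first row of expected_series
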

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_single_input_single_output_graph_mode():
    """
    Features: Function derivative
    Description: Test derivative with single input in graph mode.
    Expectation: No exception.
    """
    primals = Tensor([1., 1.])
    order = 3
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_multiple_input_single_output_graph_mode():
    """
    Features: Function jet
    Description: Test jet with multiple inputs in graph mode.
    Expectation: No exception.
    """
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
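Likewise for the multiple-input case: MultipleInputSingleOutputNet evaluates exp(sin(x) * cos(y) + sin(x) / cos(y)), so the expected primal at x = y = 1 can be checked directly (illustrative NumPy check only):

import numpy as np

x = y = 1.0
print(np.exp(np.sin(x) * np.cos(y) + np.sin(x) / np.cos(y)))  # ~ 7.47868 -> expected_primals
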

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_multiple_input_single_output_graph_mode():
    """
    Features: Function derivative
    Description: Test derivative with multiple inputs in graph mode.
    Expectation: No exception.
    """
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    order = 3
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
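Comparing the graph-mode tests above: for the same network and order, derivative() reports only the highest-order coefficient, while jet() returns the full series, so the last row of jet's expected_series coincides with derivative's expected_series (-0.74808 for the single-input net, 1237.959 for the multiple-input net). A tiny check of that relationship against the constants above (an observation drawn from these tests, not from separate documentation):

# Constants copied from the expected values in the tests above (order = 3 throughout).
single_input = {"jet_series_last": -0.74808, "derivative_series": -0.74808}
multi_input = {"jet_series_last": 1237.959, "derivative_series": 1237.959}
for case in (single_input, multi_input):
    assert case["jet_series_last"] == case["derivative_series"]
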
@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in pynative mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative

context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    def __init__(self):
        super(SingleInputSingleOutputNet, self).__init__()
        self.exp = P.Exp()
        self.cos = P.Cos()
        self.sin = P.Sin()

    def construct(self, x):
        out1 = self.sin(x)
        out2 = self.cos(out1)
        out3 = self.exp(out2)
        out = out1 + out2 - out3
        return out


class MultipleInputSingleOutputNet(nn.Cell):
    def __init__(self):
        super(MultipleInputSingleOutputNet, self).__init__()
        self.exp = P.Exp()
        self.cos = P.Cos()
        self.sin = P.Sin()

    def construct(self, x, y):
        out1 = self.sin(x)
        out2 = self.cos(y)
        out3 = out1 * out2 + out1 / out2
        out = self.exp(out3)
        return out

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_multiple_input_single_output_pynative_mode():
    """
    Features: Function jet
    Description: Test jet with multiple inputs in pynative mode.
    Expectation: No exception.
    """
    series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_multiple_input_single_output_pynative_mode():
    """
    Features: Function derivative
    Description: Test derivative with multiple inputs in pynative mode.
    Expectation: No exception.
    """
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    order = 3
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_single_input_single_output_pynative_mode():
    """
    Features: Function jet
    Description: Test jet with single input in pynative mode.
    Expectation: No exception.
    """
    primals = Tensor([1., 1.])
    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_single_input_single_output_pynative_mode():
    """
    Features: Function derivative
    Description: Test derivative with single input in pynative mode.
    Expectation: No exception.
    """
    primals = Tensor([1., 1.])
    order = 3
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)