!30211 [MS] Add Higher Order Differentiation

Merge pull request !30211 from chenzhuo/taylor_v1
i-robot 2022-03-11 06:09:07 +00:00 committed by Gitee
commit 7f9056ad69
32 changed files with 1202 additions and 65 deletions

View File

@ -944,6 +944,80 @@ REGISTER_PYBIND_DEFINE(VmapOperation_, ([](const py::module *m) {
.def(py::init<std::string &>(), py::arg("fn"));
}));
TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
// def Taylor(func:read):
signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
}
FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
std::vector<AnfNodePtr> inputs;
inputs.push_back(k);
MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
inputs.push_back(k_child->add_parameter());
}
// Taylor(fn)(input params)
auto k_app = k_child->NewCNodeInOrder(inputs);
k_child->set_output(k_app);
return k_child;
}
// Generate the graph to calculate higher order derivatives.
FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION)
<< "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
<< args_spec_list[0]->ToString();
}
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
FuncGraphPtr grad_fg = nullptr;
MS_LOG(INFO) << "'TaylorOperation' forward_graph" << forward_graph->debug_info();
grad_fg = std::make_shared<FuncGraph>();
auto nparam = forward_graph->parameters().size();
std::ostringstream ss;
ss << "taylorgrad{" << nparam << "}";
grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
grad_fg->debug_info()->set_name(ss.str());
ParameterPtr param_graph = grad_fg->add_parameter();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimTaylor));
inputs.push_back(param_graph);
// Taylor(fn)
auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
FuncGraphPtr k_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
}
grad_fg->set_output(NewValueNode(k_child));
// return Taylor(fn)(inputs)
return grad_fg;
}
REGISTER_PYBIND_DEFINE(TaylorOperation_, ([](const py::module *m) {
(void)py::class_<TaylorOperation, MetaFuncGraph, std::shared_ptr<TaylorOperation>>(
*m, "TaylorOperation_")
.def(py::init<std::string &>(), py::arg("fn"));
}));
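For orientation, a rough Python-level sketch of the graph that `GenerateFuncGraph` builds (illustrative pseudocode only, not generated code): the outer `taylorgrad{n}` graph marks the forward graph with `Taylor`, and the child graph from `GetTaylorGrad` simply forwards the user inputs to the marked graph.

# Pseudocode sketch of the expansion; `Taylor` stands for kPrimTaylor.
def taylorgrad(fn):
    marked = Taylor(fn)
    def k_child(*forward_params):       # one parameter per forward-graph parameter
        return marked(*forward_params)  # Taylor(fn)(input params)
    return k_child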
// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
size_t args_num = args_spec_list.size();

View File

@ -166,6 +166,17 @@ class GradOperation : public MetaFuncGraph {
};
using GradOperationPtr = std::shared_ptr<GradOperation>;
class TaylorOperation : public MetaFuncGraph {
public:
explicit TaylorOperation(const std::string &name);
~TaylorOperation() override = default;
MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph);
FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params);
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
using TaylorOperationPtr = std::shared_ptr<TaylorOperation>;
class ListMap {
public:
explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }

View File

@ -738,6 +738,25 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
}
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: An object of AbstractFunction.
CheckArgsSize(primitive->name(), args_spec_list, 1);
MS_LOG(DEBUG) << "evaluate Taylor: " << args_spec_list[0]->ToString();
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(x);
AbstractFuncAtomPtrList taylor_v;
auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
taylor_v.push_back(taylor_closure);
};
x->Visit(build_taylor_v);
return AbstractFunction::MakeAbstractFunction(taylor_v);
}
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: func, in_axes, out_axes, device, level.
@ -846,6 +865,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,

View File

@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -21,6 +21,7 @@
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"

View File

@ -22,6 +22,7 @@
#include "base/core_ops.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/vmap_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"
namespace mindspore {
@ -35,8 +36,8 @@ class ExpandMetaFg {
// to the implementation of `kPrimVmap`.
(void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandJPrim>());
(void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandVmapPrim>());
(void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandTaylorPrim>());
}
virtual ~ExpandMetaFg() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);

View File

@ -0,0 +1,158 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Whitelist of ops that have a taylor rule.
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
prim::kPrimExp->name()};
// The ops below are excluded when considering taylor rules.
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
// Cache of taylor rule func graphs keyed by the primitives that have already been replaced.
mindspore::HashMap<PrimitivePtr, FuncGraphPtr> taylor_ops_cache_;
FuncGraphPtr GetTaylorRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
// Set a child scope named "grad'PrimitiveName'" for the taylor rule function,
// and add "Gradients" to the front.
static const std::string gradients_scope = "Gradients/";
static const std::string grad_op_child_scope_prefix = "/grad";
MS_EXCEPTION_IF_NULL(prim);
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
// Get the taylor rule function registered in Python for the primitive.
FuncGraphPtr func_graph = nullptr;
py::function taylor_fn;
if (prim->is_base()) {
taylor_fn = GetTaylorRuleFunction(prim->name());
} else {
taylor_fn = prim->cast<PrimitivePyPtr>()->GetTaylorRuleFunction();
if (py::isinstance<py::none>(taylor_fn)) {
taylor_fn = GetTaylorRuleFunction(prim->name());
}
}
if (!taylor_fn || py::isinstance<py::none>(taylor_fn)) {
MS_LOG(INFO) << "Fail to find taylor rule function for " << prim->name() << ". taylor_fn: " << py::str(taylor_fn);
return nullptr;
}
func_graph = parse::ParsePythonCode(taylor_fn);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Fail to parse taylor rule function for " << prim->name() << ".";
return nullptr;
}
auto taylor_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
if (taylor_rule_flag) {
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
}
pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
(void)parse::ResolveFuncGraph(func_graph, res);
return func_graph;
}
FuncGraphPtr GetTaylorPyObj(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
auto fg = GetTaylorRule(prim, resources);
return fg;
}
FuncGraphPtr GetTaylorPrimitive(const AnfNodePtr &node, const pipeline::ResourceBasePtr &resources) {
auto prim_node = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(prim_node);
auto iter = taylor_ops_cache_.find(prim_node);
if (iter != taylor_ops_cache_.end()) {
return iter->second;
}
FuncGraphPtr primitive_taylor = GetTaylorPyObj(prim_node, resources);
MS_EXCEPTION_IF_NULL(primitive_taylor);
taylor_ops_cache_[prim_node] = primitive_taylor;
return primitive_taylor;
}
FuncGraphPtr TaylorFunctor(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
const auto &value_nodes = func_graph->value_nodes();
auto manager = resources->manager();
manager->AddFuncGraph(func_graph);
std::vector<AnfNodePtr> taylor_node_list;
for (const auto &value_pair : value_nodes) {
auto node = value_pair.first;
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<Primitive>(node)) {
auto prim_node = GetValueNode<PrimitivePtr>(node);
if (taylor_ops.count(prim_node->name())) {
taylor_node_list.push_back(node);
} else if (!taylor_exception_ops.count(prim_node->name())) {
MS_LOG(EXCEPTION) << "The operation " << prim_node->name()
<< " is not supported in taylor higher order differentiation currently.";
}
}
}
for (size_t i = 0; i < taylor_node_list.size(); i++) {
FuncGraphPtr taylor_node_graph = GetTaylorPrimitive(taylor_node_list[i], resources);
manager->Replace(taylor_node_list[i], NewValueNode(taylor_node_graph));
}
taylor_ops_cache_.clear();
MS_LOG(INFO) << "return replaced taylor node: " << func_graph->ToString() << " replace end.";
return func_graph;
}
AnfNodePtr ExpandTaylor(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
if (IsValueNode<FuncGraph>(vnode)) {
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandTaylor now";
auto newfg = TaylorFunctor(func_graph, resource);
return NewValueNode(newfg);
}
return nullptr;
}
} // namespace internal
bool ExpandTaylorPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
// Search all taylor nodes.
bool change = false;
auto manager = optimizer->manager();
for (auto &taylor_node : prim_nodes_) {
auto taylor_fg_node = taylor_node->input(1);
auto taylor_fg = GetValueNode<FuncGraphPtr>(taylor_fg_node);
if (taylor_fg == nullptr) {
MS_LOG(EXCEPTION) << "Unexpected Taylor node, input func graph should not be null, node: "
<< taylor_fg->ToString();
}
// Copy original forward graph in case of the influence of usage in other place.
auto taylor_fg_copy = BasicClone(taylor_fg, true);
manager->AddFuncGraph(taylor_fg_copy);
auto taylor_fg_copy_node = NewValueNode(taylor_fg_copy);
// Return expanded taylor graph.
auto expanded_taylor = internal::ExpandTaylor(taylor_fg_copy_node->cast<ValueNodePtr>(), optimizer->resource());
manager->Replace(taylor_node, expanded_taylor);
change = true;
}
return change;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "include/common/utils/primitive_utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimTaylor, C}
class ExpandTaylorPrim : public ExpandMetaFGPrim {
public:
ExpandTaylorPrim() { prim_ = prim::kPrimTaylor; }
virtual ~ExpandTaylorPrim() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_

View File

@ -353,6 +353,7 @@ constexpr char SIMPLE_MEAN[] = "SimpleMean";
constexpr char FLATTEN[] = "Flatten";
constexpr char J[] = "J";
constexpr char SHARD[] = "Shard";
constexpr char Taylor[] = "Taylor";
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
constexpr char COS[] = "Cos";
constexpr char ACOS[] = "ACos";

View File

@ -30,6 +30,10 @@ COMMON_EXPORT py::function GetBpropFunctionByObj(const py::object &obj);
COMMON_EXPORT py::function GetBpropFunction(const std::string &name);
COMMON_EXPORT py::function GetTaylorRuleFunctionByObj(const py::object &obj);
COMMON_EXPORT py::function GetTaylorRuleFunction(const std::string &name);
COMMON_EXPORT py::function GetComputeFunction(const std::string &name);
COMMON_EXPORT BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);

View File

@ -52,6 +52,7 @@
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/shard_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))

View File

@ -572,6 +572,27 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
return res;
}
EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return eval_result;
}
// Call the original evaluator, get the result: y = f(x)
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
MS_EXCEPTION_IF_NULL(result);
evaluator_cache_mgr_->SetValue(args_spec_list, result);
return result;
}
EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;

View File

@ -372,6 +372,38 @@ class JEvaluator : public Evaluator {
AbstractFunctionPtr orig_func_;
};
class TaylorEvaluator : public Evaluator {
public:
TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
: Evaluator("TaylorEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
~TaylorEvaluator() override = default;
MS_DECLARE_PARENT(TaylorEvaluator, Evaluator);
AnfNodePtr bound_node() const override {
if (evaluator_ != nullptr) {
return evaluator_->bound_node();
}
return bound_node_.lock();
}
void set_bound_node(const AnfNodePtr &node) override {
if (evaluator_ != nullptr) {
evaluator_->set_bound_node(node);
}
bound_node_ = AnfNodeWeakPtr(node);
}
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:
EvaluatorPtr evaluator_;
AbstractFunctionPtr orig_func_;
};
class ShardEvaluator : public Evaluator {
public:
ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)

View File

@ -501,6 +501,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransfor
return vmap_evaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto taylorevaluator = std::make_shared<TaylorEvaluator>(evaluator_orig, func_orig);
return taylorevaluator;
}
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn();
@ -548,6 +556,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
} else if (func->isa<VmapTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<VmapTransformedAbstractClosure>>());
} else if (func->isa<TaylorTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<TaylorTransformedAbstractClosure>>());
} else if (func->isa<ShardTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
} else if (func->isa<VirtualAbstractClosure>()) {

View File

@ -324,6 +324,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &fn);

View File

@ -151,6 +151,18 @@ py::function PrimitivePy::GetBpropFunction() {
auto fn = GetBpropFunctionByObj(python_obj_);
return fn;
}
py::function PrimitivePy::GetTaylorRuleFunction() {
static const char *const get_taylor_rule_func_name = "get_taylor_rule";
if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
return fn;
}
auto fn = GetTaylorRuleFunctionByObj(python_obj_);
return fn;
}
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {

View File

@ -50,6 +50,7 @@ class PrimitivePy : public Primitive {
const bool parse_info_ = true;
py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0);
py::function GetBpropFunction();
py::function GetTaylorRuleFunction();
void set_signatures(const std::vector<Signature> &signatures);
const std::vector<Signature> &signatures() const { return signatures_; }
const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; }

View File

@ -41,6 +41,18 @@ py::function GetBpropFunction(const std::string &name) {
return fn;
}
py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
static const std::string get_taylor_fprop_fn = "get_taylor_fprop_fn";
static const std::string ad_module = "mindspore.ops._grad";
py::function fn = python_adapter::GetPyFn(ad_module, get_taylor_fprop_fn)(obj);
return fn;
}
py::function GetTaylorRuleFunction(const std::string &name) {
auto fn = GetTaylorRuleFunctionByObj(py::str(name));
return fn;
}
py::function GetComputeFunction(const std::string &name) {
static const std::string module = "mindspore._extends.builtin_operations";
py::module mod = py::module::import(common::SafeCStr(module));

View File

@ -297,6 +297,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
return hash_value;
}
bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<TaylorTransformedAbstractClosure>()) {
return false;
}
auto other_transformed = static_cast<const TaylorTransformedAbstractClosure *>(&other);
return fn_ == other_transformed->fn_;
}
std::size_t TaylorTransformedAbstractClosure::hash() const {
MS_EXCEPTION_IF_NULL(fn_);
auto hash_value = hash_combine(tid(), fn_->hash());
return hash_value;
}
bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<ShardTransformedAbstractClosure>()) {
return false;

View File

@ -339,6 +339,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
AbstractFuncAtomPtr fn_;
};
/// \brief TaylorTransformedAbstractClosure defines interface for abstract of Function
/// transformed through the application of Taylor.
class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom {
public:
/// \brief Constructor of TaylorTransformedAbstractClosure
///
/// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor.
explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
/// \brief Destructor of TaylorTransformedAbstractClosure
~TaylorTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom)
/// \brief Get the AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
///
/// \return The AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
AbstractFuncAtomPtr fn() { return fn_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); }
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override;
std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; }
private:
AbstractFuncAtomPtr fn_;
};
/// \brief ShardTransformedAbstractClosure defines interface for abstract of Function
/// transformed through the application of Shard.
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {

View File

@ -171,6 +171,7 @@ constexpr auto kCSRDiv = "CSRDiv";
// Meta Function Graph
constexpr auto kJ = "J";
constexpr auto kVmap = "Vmap";
constexpr auto kTaylor = "Taylor";
// Others
constexpr auto kMakeTuple = "MakeTuple";
@ -828,6 +829,7 @@ GVAR_DEF(PrimitivePtr, kPrimStateSetItem, std::make_shared<Primitive>("state_set
GVAR_DEF(PrimitivePtr, kPrimJ, std::make_shared<Primitive>(kJ, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimVmap, std::make_shared<Primitive>(kVmap, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimShard, std::make_shared<Primitive>("Shard", kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimTaylor, std::make_shared<Primitive>(kTaylor));
// Used to build graph which have keyword arguments
GVAR_DEF(PrimitivePtr, kPrimExtractKeywordArg, std::make_shared<Primitive>("extract_keyword_arg"));

View File

@ -313,7 +313,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
}
}
// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap.
// Check if the function graph is embedded with `MetaFGPrim`, which currently covers kPrimJ, kPrimVmap and kPrimTaylor.
bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(meta_fg_prim_total_);
MS_EXCEPTION_IF_NULL(fg);
@ -704,7 +704,8 @@ void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfN
signals_->InvalidateComputer();
}
}
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
fg->AddMetaFgPrimValueNode(input);
}
} else if (fg != nullptr && fg != input->func_graph()) {
@ -725,7 +726,8 @@ void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const An
signals_->InvalidateComputer();
}
}
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
fg->DropMetaFgPrimValueNode(input);
}
} else if (fg != nullptr && fg != input->func_graph()) {
@ -1096,7 +1098,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
return false;
}
// Check MetaFgPrim (J/Vmap) FuncGraph input.
// Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
if (!meta_fg_prim_values.empty()) {
auto contains_meta_fg_prim =
@ -1107,9 +1109,10 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
return func_graph->seen_ != seen_num;
}
if (IsValueNode<Primitive>(iter.first)) {
// Exclude the primitive of MetaFgPrim (J/Vmap) itself.
// Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
auto prim = GetValueNode<PrimitivePtr>(iter.first);
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name());
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
prim->name() != prim::kPrimTaylor->name());
}
return false;
});
@ -1119,35 +1122,38 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
}
}
// Check MetaFgPrim (J/Vmap) CNode as FV.
// Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
const auto &fv_nodes = fg->free_variables();
if (!fv_nodes.empty()) {
auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
// Check if the FV is a MetaFgPrim (J/Vmap) call CNode.
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap)) {
// Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
IsPrimitiveCNode(iter.first, prim::kPrimTaylor)) {
return true;
}
return false;
});
if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap) ("
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
<< contains_meta_fg_prim_cnode->first->DebugString() << ")";
return true;
}
}
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
// Taylor(func_graph) or Taylor(Primitive).
fg->seen_ = seen_num;
for (auto &item : fg->func_graphs_used()) {
auto used_g = item.first;
if (SeekMetaFgPrim(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
<< "Taylor(func_graph) or Taylor(Primitive)";
return true;
}
}
MS_LOG(DEBUG) << fg->ToString()
<< " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
<< "Taylor(func_graph) or Taylor(Primitive)";
return false;
}

View File

@ -15,7 +15,7 @@
"""grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops
from .grad_base import get_bprop_fn
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule
from .grad_base import get_bprop_fn, get_taylor_fprop_fn
__all__ = ['get_bprop_fn']
__all__ = ['get_bprop_fn', 'get_taylor_fprop_fn']

View File

@ -37,8 +37,28 @@ class BpropRegistry(Registry):
return deco
class TaylorFpropRegistry(Registry):
"""Registry class for registry functions for taylor grad on Primitive or string."""
def register(self, prim):
"""register the function."""
def deco(fn):
"""Decorate the function."""
if isinstance(prim, str):
self[prim] = fn
elif issubclass(prim, Primitive):
self[id(prim)] = fn
self[prim.__name__] = fn
return fn
return deco
bprop_getters = BpropRegistry()
bprops = BpropRegistry()
taylor_fprop_getters = TaylorFpropRegistry()
taylor_fprops = TaylorFpropRegistry()
def get_bprop_fn(prim):
@ -47,3 +67,11 @@ def get_bprop_fn(prim):
if out:
return out(prim)
return bprops.get(prim, None)
def get_taylor_fprop_fn(prim):
"""get taylor function by primitive obj or prim name for c++"""
out = taylor_fprop_getters.get(prim, None)
if out:
return out(prim)
return taylor_fprops.get(prim, None)

View File

@ -0,0 +1,166 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the taylor rules of operations."""
from mindspore import nn
import mindspore as ms
from ..primitive import Primitive
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import taylor_fprop_getters
def _factorial(order):
"""Return [0!, 1!, 2!,..., order!]."""
range_op = nn.Range(1, order + 1)
ones_op = P.Ones()
concat_op = P.Concat()
factorial_zero = ones_op(1, ms.float32)
factorial_positive = range_op().astype(ms.float32)
for i in range(1, order):
factorial_positive[i] *= factorial_positive[i - 1]
factorial = concat_op((factorial_zero, factorial_positive))
return factorial
@taylor_fprop_getters.register(P.Add)
@taylor_fprop_getters.register(P.Sub)
def taylor_add_or_sub(self):
"""Higher order derivatives rule definition for `Add` or `Sub`operation."""
if isinstance(self, str):
prim = Primitive(self)
else:
prim = self
def taylor_fprop_add_or_sub(input_x, input_y):
series = prim(input_x, input_y)
return series
return taylor_fprop_add_or_sub
@taylor_fprop_getters.register(P.Mul)
def taylor_mul(self):
"""Higher order derivatives rule definition for `Mul` operation."""
mul_func = P.Mul()
def taylor_fprop_mul(input_x, input_y):
primals = mul_func(input_x[0], input_y[0])
series_num = len(input_x) - 1
factorial = _factorial(series_num)
series = zeros_like(input_x)
series[0] = primals
for k in range(1, series_num + 1):
for i in range(0, k + 1):
tmp = input_x[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] *= factorial[k]
return series
return taylor_fprop_mul
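For reference, with entry k of each input tuple holding the k-th derivative (not the Taylor coefficient), the loop in `taylor_fprop_mul` reads as the Leibniz product rule:

(x \cdot y)^{(k)} = k! \sum_{i=0}^{k} \frac{x^{(i)}}{i!} \cdot \frac{y^{(k-i)}}{(k-i)!} = \sum_{i=0}^{k} \binom{k}{i} x^{(i)} y^{(k-i)}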
@taylor_fprop_getters.register(P.RealDiv)
def taylor_realdiv(self):
"""Higher order derivatives rule definition for `RealDiv` operation."""
div_op = P.Div()
def taylor_fprop_realdiv(input_x, input_y):
primals = div_op(input_x[0], input_y[0])
series_num = len(input_x) - 1
factorial = _factorial(series_num)
series = zeros_like(input_x)
series[0] = primals
for k in range(1, series_num + 1):
for i in range(0, k):
tmp = series[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] = (input_x[k] - factorial[k] * series[k]) / input_y[0]
return series
return taylor_fprop_realdiv
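Likewise, `taylor_fprop_realdiv` solves the Leibniz rule for the quotient q = x / y, which gives the recurrence implemented above:

q^{(k)} = \frac{1}{y^{(0)}} \Bigl( x^{(k)} - \sum_{i=0}^{k-1} \binom{k}{i} q^{(i)} y^{(k-i)} \Bigr)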
@taylor_fprop_getters.register(P.Exp)
def taylor_exp(self):
"""Higher order derivatives rule definition for `Exp` operation."""
exp_ = P.Exp()
def taylor_fprop_exp(inputs):
primals = exp_(inputs[0])
series_num = len(inputs) - 1
factorial = _factorial(series_num)
series = zeros_like(inputs)
series[0] = primals
for k in range(1, series_num + 1):
for i in range(1, k + 1):
tmp = i * inputs[i] * series[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] *= factorial[k - 1]
return series
return taylor_fprop_exp
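The `Exp` rule follows from f' = g' \cdot f for f = exp(g); differentiating k-1 more times with the Leibniz rule yields the form matched by the factorial scaling above:

f^{(k)} = (k-1)! \sum_{i=1}^{k} \frac{i \, g^{(i)}}{i!} \cdot \frac{f^{(k-i)}}{(k-i)!}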
@taylor_fprop_getters.register(P.Sin)
def taylor_sin(self):
"""Higher order derivatives rule definition for `Sin` operation."""
cos = P.Cos()
sin = P.Sin()
def taylor_fprop_sin(inputs):
primal_sin = sin(inputs[0])
primal_cos = cos(inputs[0])
series_sin = zeros_like(inputs)
series_cos = zeros_like(inputs)
series_sin[0] = primal_sin
series_cos[0] = primal_cos
series_num = len(inputs) - 1
factorial = _factorial(series_num)
for k in range(1, series_num + 1):
for i in range(1, k + 1):
series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
series_sin[k] *= factorial[k - 1]
series_cos[k] *= factorial[k - 1]
return series_sin
return taylor_fprop_sin
@taylor_fprop_getters.register(P.Cos)
def taylor_cos(self):
"""Higher order derivatives rule definition for `Cos` operation."""
cos = P.Cos()
sin = P.Sin()
def taylor_fprop_cos(inputs):
primal_cos = cos(inputs[0])
primal_sin = sin(inputs[0])
series_cos = zeros_like(inputs)
series_sin = zeros_like(inputs)
series_cos[0] = primal_cos
series_sin[0] = primal_sin
series_num = len(inputs) - 1
factorial = _factorial(series_num)
for k in range(1, series_num + 1):
for i in range(1, k + 1):
series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
series_cos[k] *= factorial[k - 1]
series_sin[k] *= factorial[k - 1]
return series_cos
return taylor_fprop_cos
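`Sin` and `Cos` are propagated as a coupled pair, since s' = g' \cdot c and c' = -g' \cdot s for s = sin(g), c = cos(g); both rules compute both series, but each returns only its own:

s^{(k)} = (k-1)! \sum_{i=1}^{k} \frac{i \, g^{(i)}}{i!} \cdot \frac{c^{(k-i)}}{(k-i)!}, \qquad c^{(k)} = -(k-1)! \sum_{i=1}^{k} \frac{i \, g^{(i)}}{i!} \cdot \frac{s^{(k-i)}}{(k-i)!}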

View File

@ -21,7 +21,7 @@ Pre-defined combination of operators.
from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
core, env_get, tail, zip_operation, Shard, _Vmap
core, env_get, tail, zip_operation, Shard, _Vmap, _TaylorOperation
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like

View File

@ -22,7 +22,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
ListSlice_, VmapOperation_
ListSlice_, VmapOperation_, TaylorOperation_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@ -401,6 +401,29 @@ class GradOperation(GradOperation_):
return self.grad_fn
class _TaylorOperation(TaylorOperation_):
"""
Generate the higher order derivatives function for the input function.
"""
def __init__(self):
"""Initialize TaylorOperation."""
TaylorOperation_.__init__(self, 'taylorgrad')
self.grad_fn = None
self.fn = None
def __call__(self, fn):
if self.grad_fn is not None and self.fn == fn:
return self.grad_fn
taylor_grad_ = _TaylorOperation()
# Wrap with ms_function so the Taylor expansion always runs in GRAPH_MODE.
@ms_function
def after_taylor_grad(*args):
return taylor_grad_(fn)(*args)
self.grad_fn = after_taylor_grad
self.fn = fn
return self.grad_fn
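As a usage sketch (not part of the patch; the import path and the stacked input layout are inferred from this diff, in particular from the `jet`/`derivative` helpers added below), `_TaylorOperation` is applied to a function once, and the returned callable consumes a tensor whose first row is the primal and whose remaining rows are the input derivatives:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.composite.base import _TaylorOperation  # import path assumed from this diff

context.set_context(mode=context.GRAPH_MODE)

class SinNet(nn.Cell):
    def __init__(self):
        super(SinNet, self).__init__()
        self.sin = P.Sin()

    def construct(self, x):
        return self.sin(x)

taylor_fn = _TaylorOperation()(SinNet())
# Row 0: primal x; rows 1..3: first to third order input derivatives (1, 0, 0).
stacked = Tensor(np.array([[1., 2.], [1., 1.], [0., 0.], [0., 0.]]).astype(np.float32))
series_out = taylor_fn(stacked)  # row k is expected to hold the k-th derivative of sin(x)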
class _Grad(GradOperation_):
"""
A higher-order function which is used to generate the gradient function by position for the input function.

View File

@ -32,7 +32,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad, Shard, _Vmap
from .composite import _Grad, Shard, _Vmap, _TaylorOperation
from .._c_expression import security
typeof = Primitive('typeof')
@ -48,6 +48,7 @@ eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
size = P.Size()
ones = P.Ones()
ones_like = P.OnesLike()
shape = P.Shape()
dyn_shape = P.TensorShape()
@ -285,6 +286,187 @@ def grad(fn, grad_position=0, sens_param=False):
return grad_by_position_with_sens(fn, None, grad_position)
return grad_by_position(fn, None, grad_position)
@constexpr
def _trans_jet_inputs(primals_item, series_item):
"""Trans inputs of jet"""
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
raise TypeError(f"For `F.jet`, the elements' types of primals and series should be the same and belong to "
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
f" {dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
if dtype(primals_item) in [mstype.int32, mstype.int64]:
return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
return primals_item, series_item
@constexpr
def _check_jet_inputs(primals, series):
"""Check inputs of jet"""
if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
raise TypeError(f"For 'F.jet', the 'primals' and `series` should be both Tensor or tuple, "
f"but got {type(primals).__name__} and {type(series).__name__}.")
if isinstance(primals, Tensor):
if primals.shape != series.shape[1:]:
raise ValueError("The shape of each element should be the same as the primals.")
return _trans_jet_inputs(primals, series)
if isinstance(primals, tuple):
if len(primals) != len(series):
raise ValueError("The lengths of primals and series should be the same.")
check_primals = []
check_series = []
for i, j in zip(primals, series):
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
check_primals.append(trans_primals_item)
check_series.append(trans_series_item)
return check_primals, check_series
_taylor = _TaylorOperation()
def jet(fn, primals, series):
"""
This function is designed to calculate the higher order differentiation of a given composite function. To figure
out the first to `n`-th order differentiations, the original inputs and the first to `n`-th order derivatives of
the original inputs must be provided together. Generally, it is recommended to set the given first order
derivatives to 1 and the others to 0.
Args:
fn (Union(Cell, function)): Function to do TaylorOperation.
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of `series` should be the same as `primals`.
For each Tensor, the length `i` of the first dimension means that the `1`-st to `i`-th order derivatives of
the output with respect to the inputs will be figured out.
Returns:
Tuple, tuple of out_primals and out_series.
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
- **out_series** (Tensors or List of Tensors) - The `1`-st to `i`-th order derivatives of the output with respect
to the inputs.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> import mindspore.context as context
>>> import mindspore.ops as P
>>> from mindspore import Tensor
>>> from mindspore.ops.functional import jet
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Net(nn.Cell):
... def __init__(self):
... super().__init__()
... self.sin = P.Sin()
... self.exp = P.Exp()
... def construct(self, x):
... out1 = self.sin(x)
... out2 = self.exp(out1)
... return out2
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
>>> net = Net()
>>> out_primals, out_series = jet(net, primals, series)
>>> print(out_primals, out_series)
"""
primals, series = _check_jet_inputs(primals, series)
derivative_fn = _taylor(fn)
concat_op = P.Concat()
if isinstance(primals, list) and list_len(primals) > 1:
inputs = list(map(lambda x, y: concat_op((expand_dims(x, 0), y)), primals, series))
outputs = derivative_fn(*inputs)
else:
inputs = concat_op((expand_dims(primals, 0), series))
outputs = derivative_fn(inputs)
if isinstance(outputs, list) and list_len(outputs) > 1:
out_primals = [element[0] for element in outputs]
out_series = [element[1:] for element in outputs]
else:
out_primals = outputs[0]
out_series = outputs[1:]
return out_primals, out_series
@constexpr
def _trans_derivative_inputs(primals_item):
"""Trans inputs of derivative"""
value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
if not dtype(primals_item) in value_type:
raise TypeError(f"For `F.derivative`, the elements of primals should belong to "
f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
f" {dtype(primals_item).__name__}.")
if dtype(primals_item) in [mstype.int32, mstype.int64]:
return cast(primals_item, mstype.float64)
return primals_item
def derivative(fn, primals, order):
"""
This function is designed to calculate the higher order differentiation of a given composite function. To figure
out the `order`-th order differentiation, the original inputs and the order must be provided together. In
particular, the value of the input first order derivative is set to 1, while the others are set to 0.
Args:
fn (Union(Cell, function)): Function to do TaylorOperation.
primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
order (int): The order of differentiation. For each Tensor, the `order`-th order derivative of the output with
respect to the inputs will be figured out.
Returns:
Tuple, tuple of out_primals and out_series.
- **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
- **out_series** (Tensors or List of Tensors) - The `order`-th order derivative of the output with respect
to the inputs.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> import mindspore.context as context
>>> import mindspore.ops as P
>>> from mindspore import Tensor
>>> from mindspore.ops.functional import derivative
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Net(nn.Cell):
... def __init__(self):
... super().__init__()
... self.sin = P.Sin()
... self.exp = P.Exp()
... def construct(self, x):
... out1 = self.sin(x)
... out2 = self.exp(out1)
... return out2
>>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> order = 3
>>> net = Net()
>>> out_primals, out_series = derivative(net, primals, order)
>>> print(out_primals, out_series)
"""
derivative_fn = _taylor(fn)
concat_op = P.Concat()
series_one = 1
if isinstance(primals, tuple):
trans_primals = [_trans_derivative_inputs(item) for item in primals]
inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
if order > 1:
inputs = list(map(lambda x: concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))), inputs))
outputs = derivative_fn(*inputs)
else:
primals = _trans_derivative_inputs(primals)
series = zeros((order,) + primals.shape, dtype(primals))
series[0] = series_one
inputs = concat_op((expand_dims(primals, 0), series))
outputs = derivative_fn(inputs)
if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
out_primals = [element[0] for element in outputs]
out_series = [element[-1] for element in outputs]
else:
out_primals = outputs[0]
out_series = outputs[-1]
return out_primals, out_series
def jvp(fn, inputs, v):
"""

View File

@ -38,6 +38,11 @@ class MultipleInputsOutputNet(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
"""
Features: Function vjp
Description: Test vjp with single input, single output and default v in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
@ -53,6 +58,11 @@ def test_vjp_single_input_graph():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_inputs_default_v_graph():
"""
Features: Function vjp
Description: Test vjp with multiple inputs, multiple outputs and default v in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
@ -140,8 +150,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
net_out, vjp_out = vjp(self.net, inputs, vectors)
return net_out, vjp_out
test_net = Net(SingleInputNet())
primal, grad = test_net(x, v)
test_net_graph = Net(SingleInputNet())
primal, grad = test_net_graph(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())

View File

@ -18,7 +18,6 @@ import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops.functional import vjp
context.set_context(mode=context.PYNATIVE_MODE)
@ -37,12 +36,17 @@ class MultipleInputsOutputNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
def test_vjp_single_input_pynative():
"""
Features: Function vjp
Description: Test vjp with single input, single output and default v in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -51,58 +55,38 @@ def test_vjp_single_input_graph():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_inputs_default_v_graph():
def test_vjp_multiple_inputs_default_v_pynative():
"""
Features: Function vjp
Description: Test vjp with multiple inputs, multiple outputs and default v in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_ms_function_single_input_single_output_default_v_graph():
def test_vjp_input_function_single_input_single_output_default_v_pynative():
"""
Features: Function vjp
Description: Test vjp with ms_function, single input, single output and default v in graph mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
@ms_function
def vjp_with_ms_function(inputs, vectors):
output, vjp_grad = vjp(net, inputs, vectors)
return output, vjp_grad
primal, grad = vjp_with_ms_function(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_input_function_single_input_single_output_default_v_graph():
"""
Features: Function vjp
Description: Test vjp with function, single input, single output and default v in graph mode.
Description: Test vjp with function, single input, single output and default v in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -112,8 +96,8 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
return inputs**3
primal, grad = vjp(test_function, x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -121,10 +105,10 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_construct_single_input_single_output_default_v_graph():
def test_vjp_construct_single_input_single_output_default_v_pynative():
"""
Features: Function vjp
Description: Test vjp with function, single input, single output and default v in graph mode.
Description: Test vjp with function, single input, single output and default v in pynative mode.
Expectation: No exception.
"""
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -139,8 +123,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
net_out, vjp_out = vjp(self.net, inputs, vectors)
return net_out, vjp_out
test_net = Net(SingleInputNet())
primal, grad = test_net(x, v)
test_net_pynative = Net(SingleInputNet())
primal, grad = test_net_pynative(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())

View File

@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in graph mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative
context.set_context(mode=context.GRAPH_MODE)
class MultipleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(MultipleInputSingleOutputNet, self).__init__()
self.sin = P.Sin()
self.cos = P.Cos()
self.exp = P.Exp()
def construct(self, x, y):
out1 = self.sin(x)
out2 = self.cos(y)
out3 = out1 * out2 + out1 / out2
out = self.exp(out3)
return out
class SingleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(SingleInputSingleOutputNet, self).__init__()
self.sin = P.Sin()
self.cos = P.Cos()
self.exp = P.Exp()
def construct(self, x):
out1 = self.sin(x)
out2 = self.cos(out1)
out3 = self.exp(out2)
out = out1 + out2 - out3
return out
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_single_input_single_output_graph_mode():
"""
Features: Function jet
Description: Test jet with single input in graph mode.
Expectation: No exception.
"""
primals = Tensor([1., 1.])
series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
net = SingleInputSingleOutputNet()
expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
out_primals, out_series = jet(net, primals, series)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_single_input_single_output_graph_mode():
"""
Features: Function derivative
Description: Test derivative with single input in graph mode.
Expectation: No exception.
"""
primals = Tensor([1., 1.])
order = 3
net = SingleInputSingleOutputNet()
expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
out_primals, out_series = derivative(net, primals, order)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_multiple_input_single_output_graph_mode():
"""
Features: Function jet
Description: Test jet with multiple inputs in graph mode.
Expectation: No exception.
"""
primals = (Tensor([1., 1.]), Tensor([1., 1.]))
series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
net = MultipleInputSingleOutputNet()
expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
out_primals, out_series = jet(net, primals, series)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_multiple_input_single_output_graph_mode():
"""
Features: Function derivative
Description: Test derivative with multiple inputs in graph mode.
Expectation: No exception.
"""
primals = (Tensor([1., 1.]), Tensor([1., 1.]))
order = 3
net = MultipleInputSingleOutputNet()
expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
out_primals, out_series = derivative(net, primals, order)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)

View File

@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in pynative mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative
context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(SingleInputSingleOutputNet, self).__init__()
self.exp = P.Exp()
self.cos = P.Cos()
self.sin = P.Sin()
def construct(self, x):
out1 = self.sin(x)
out2 = self.cos(out1)
out3 = self.exp(out2)
out = out1 + out2 - out3
return out
class MultipleInputSingleOutputNet(nn.Cell):
def __init__(self):
super(MultipleInputSingleOutputNet, self).__init__()
self.exp = P.Exp()
self.cos = P.Cos()
self.sin = P.Sin()
def construct(self, x, y):
out1 = self.sin(x)
out2 = self.cos(y)
out3 = out1 * out2 + out1 / out2
out = self.exp(out3)
return out
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_multiple_input_single_output_pynative_mode():
"""
Features: Function jet
Description: Test jet with multiple inputs in pynative mode.
Expectation: No exception.
"""
series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
primals = (Tensor([1., 1.]), Tensor([1., 1.]))
net = MultipleInputSingleOutputNet()
expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
out_primals, out_series = jet(net, primals, series)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_multiple_input_single_output_pynative_mode():
"""
Features: Function derivative
Description: Test derivative with multiple inputs in pynative mode.
Expectation: No exception.
"""
primals = (Tensor([1., 1.]), Tensor([1., 1.]))
order = 3
net = MultipleInputSingleOutputNet()
expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
out_primals, out_series = derivative(net, primals, order)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jet_single_input_single_output_pynative_mode():
"""
Features: Function jet
Description: Test jet with single input in pynative mode.
Expectation: No exception.
"""
primals = Tensor([1., 1.])
series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
net = SingleInputSingleOutputNet()
expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
out_primals, out_series = jet(net, primals, series)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_derivative_single_input_single_output_pynative_mode():
"""
Features: Function derivative
Description: Test derivative with single input in pynative mode.
Expectation: No exception.
"""
primals = Tensor([1., 1.])
order = 3
net = SingleInputSingleOutputNet()
expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
out_primals, out_series = derivative(net, primals, order)
assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)