forked from mindspore-Ecosystem/mindspore
!30211 [MS]Add Higher Order Differentiation
Merge pull request !30211 from chenzhuo/taylor_v1
This commit is contained in:
@ -944,6 +944,80 @@ REGISTER_PYBIND_DEFINE(VmapOperation_, ([](const py::module *m) {
.def(py::init<std::string &>(), py::arg("fn"));
// Builds the Taylor meta func graph named `name` and declares a single read-only argument called `func`.
TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
// def Taylor(func:read):
signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
// Creates a new func graph flagged FUNC_GRAPH_FLAG_CORE that holds one call node built from `inputs`,
// and returns that graph.
// `k` is the node to be applied and `forward_graph_params` are the forward graph's parameters.
// NOTE(review): how `inputs` is filled from `k` / `forward_graph_params` is not visible in this excerpt.
FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
std::vector<AnfNodePtr> inputs;
MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
// One iteration per forward parameter.
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
// Taylor(fn)(input params)
auto k_app = k_child->NewCNodeInOrder(inputs);
return k_child;
// Generate the graph to calculate higher order derivatives.
// `args_spec_list[0]` must be an AbstractFunction. Its func graph is flagged FUNC_GRAPH_FLAG_DEFER_INLINE,
// and a new graph `grad_fg`, flagged FUNC_GRAPH_FLAG_CORE and given one parameter, is built and returned.
FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// An empty argument list is rejected with the message below.
if (args_spec_list.empty()) {
<< "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
// dyn_cast gives nullptr when the first argument is not a function abstract.
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
<< args_spec_list[0]->ToString();
// NOTE(review): no null check on `real_fn` is visible before it is dereferenced on the next line.
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
FuncGraphPtr forward_graph = real_fn->func_graph();
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
FuncGraphPtr grad_fg = nullptr;
MS_LOG(INFO) << "'TaylorOperation' forward_graph" << forward_graph->debug_info();
grad_fg = std::make_shared<FuncGraph>();
auto nparam = forward_graph->parameters().size();
// Builds the text "taylorgrad{<nparam>}"; where `ss` is consumed is not visible in this excerpt.
std::ostringstream ss;
ss << "taylorgrad{" << nparam << "}";
grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
// Single parameter of the generated graph (presumably the forward function; confirm against the full file).
ParameterPtr param_graph = grad_fg->add_parameter();
std::vector<AnfNodePtr> inputs;
// Taylor(fn)
auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
FuncGraphPtr k_child = nullptr;
// Trace info for the child graph is derived from the forward graph's debug info.
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
// return Taylor(fn)(inputs)
return grad_fg;
// Exposes TaylorOperation to Python as `TaylorOperation_`, derived from MetaFuncGraph, held by shared_ptr,
// and constructible from one string argument named `fn`.
REGISTER_PYBIND_DEFINE(TaylorOperation_, ([](const py::module *m) {
(void)py::class_<TaylorOperation, MetaFuncGraph, std::shared_ptr<TaylorOperation>>(
*m, "TaylorOperation_")
.def(py::init<std::string &>(), py::arg("fn"));
// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
size_t args_num = args_spec_list.size();
@ -166,6 +166,17 @@ class GradOperation : public MetaFuncGraph {
using GradOperationPtr = std::shared_ptr<GradOperation>;
// Meta func graph used for higher order (Taylor) differentiation of a forward function.
class TaylorOperation : public MetaFuncGraph {
// `name` is forwarded to MetaFuncGraph.
explicit TaylorOperation(const std::string &name);
~TaylorOperation() override = default;
MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph);
// Builds the child graph that applies node `k`; `forward_graph_params` are the forward graph's parameters.
FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params);
// Builds the graph for the function abstract in `args_spec_list[0]`.
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
using TaylorOperationPtr = std::shared_ptr<TaylorOperation>;
class ListMap {
explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
@ -738,6 +738,25 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
// Front-end inference for the Taylor primitive.
// Expects exactly one argument and returns an abstract function built from `taylor_v`, a list of
// TaylorTransformedAbstractClosure objects.
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: An object of AbstractFunction.
CheckArgsSize(primitive->name(), args_spec_list, 1);
MS_LOG(DEBUG) << "evaluate Taylor: " << args_spec_list[0]->ToString();
// NOTE(review): no null check on `x` is visible here; confirm one exists in the full file.
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
AbstractFuncAtomPtrList taylor_v;
// Wraps one function atom in a Taylor closure; presumably appended to `taylor_v` (rest of the lambda not shown).
auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
return AbstractFunction::MakeAbstractFunction(taylor_v);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: func, in_axes, out_axes, device, level.
@ -846,6 +865,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -21,6 +21,7 @@
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"
@ -22,6 +22,7 @@
#include "base/core_ops.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/vmap_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"
namespace mindspore {
@ -35,8 +36,8 @@ class ExpandMetaFg {
// to the implementation of `kPrimVmap`.
virtual ~ExpandMetaFg() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
@ -0,0 +1,158 @@
* Copyright 2022 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
#include <string>
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// White list of ops with taylor rule. TaylorFunctor looks primitives up here by name.
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
// The ops below are excluded when considering taylor rules: TaylorFunctor raises for a primitive that is
// in neither this set nor the white list.
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
// Cache from a primitive to the func graph obtained for its taylor rule; read and filled by GetTaylorPrimitive.
mindspore::HashMap<PrimitivePtr, FuncGraphPtr> taylor_ops_cache_;
// Looks up the python taylor rule function for `prim`, parses it into a func graph and resolves it.
// Returns nullptr when no rule function is found or when parsing fails.
// `resources` may be null; a fresh pipeline::Resource is then used for resolving.
FuncGraphPtr GetTaylorRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
// Set a child scope named "grad'PrimitiveName'" for the taylor rule function,
// and add "Gradients" to the front.
static const std::string gradients_scope = "Gradients/";
static const std::string grad_op_child_scope_prefix = "/grad";
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
// Find the registered python taylor rule function, then parse it.
// NOTE(review): the original comment mentioned loading from mindir first; no such step is visible here.
FuncGraphPtr func_graph = nullptr;
py::function taylor_fn;
// Base primitives are looked up by name only.
if (prim->is_base()) {
taylor_fn = GetTaylorRuleFunction(prim->name());
} else {
// Python-backed primitives are asked first; a None result falls back to the lookup by name.
taylor_fn = prim->cast<PrimitivePyPtr>()->GetTaylorRuleFunction();
if (py::isinstance<py::none>(taylor_fn)) {
taylor_fn = GetTaylorRuleFunction(prim->name());
if (!taylor_fn || py::isinstance<py::none>(taylor_fn)) {
MS_LOG(INFO) << "Fail to find taylor rule function for " << prim->name() << ". taylor_fn: " << py::str(taylor_fn);
return nullptr;
func_graph = parse::ParsePythonCode(taylor_fn);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Fail to parse taylor rule function for " << prim->name() << ".";
return nullptr;
// When the primitive carries GRAPH_FLAG_SIDE_EFFECT_BACKPROP, mark the parsed graph with kFuncGraphFlagReAutoMonad.
auto taylor_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
if (taylor_rule_flag) {
func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
(void)parse::ResolveFuncGraph(func_graph, res);
return func_graph;
// Thin wrapper: returns whatever GetTaylorRule yields for `prim`, which may be nullptr.
FuncGraphPtr GetTaylorPyObj(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
auto fg = GetTaylorRule(prim, resources);
return fg;
// Returns the taylor rule graph for the primitive held by value node `node`.
// A hit in `taylor_ops_cache_` is returned directly; otherwise the graph is obtained via GetTaylorPyObj
// and stored in the cache (including a nullptr result).
FuncGraphPtr GetTaylorPrimitive(const AnfNodePtr &node, const pipeline::ResourceBasePtr &resources) {
auto prim_node = GetValueNode<PrimitivePtr>(node);
auto iter = taylor_ops_cache_.find(prim_node);
if (iter != taylor_ops_cache_.end()) {
return iter->second;
FuncGraphPtr primitive_taylor = GetTaylorPyObj(prim_node, resources);
taylor_ops_cache_[prim_node] = primitive_taylor;
return primitive_taylor;
// Replaces the primitive value nodes of `func_graph` listed in `taylor_node_list` with value nodes holding
// their taylor rule graphs, then returns `func_graph`.
// Raises for a primitive that is in neither `taylor_ops` nor `taylor_exception_ops`.
// NOTE(review): `resources` is dereferenced without a visible null check, although GetTaylorRule tolerates null.
FuncGraphPtr TaylorFunctor(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
const auto &value_nodes = func_graph->value_nodes();
auto manager = resources->manager();
std::vector<AnfNodePtr> taylor_node_list;
// Classify every primitive value node of the graph.
for (const auto &value_pair : value_nodes) {
auto node = value_pair.first;
if (IsValueNode<Primitive>(node)) {
auto prim_node = GetValueNode<PrimitivePtr>(node);
// White-listed op; presumably collected into `taylor_node_list` (branch body not visible in this excerpt).
if (taylor_ops.count(prim_node->name())) {
} else if (!taylor_exception_ops.count(prim_node->name())) {
MS_LOG(EXCEPTION) << "The operation " << prim_node->name()
<< " is not supported in taylor higher order differentiation currently.";
// NOTE(review): GetTaylorPrimitive can return nullptr (see GetTaylorRule); that would reach NewValueNode here.
for (size_t i = 0; i < taylor_node_list.size(); i++) {
FuncGraphPtr taylor_node_graph = GetTaylorPrimitive(taylor_node_list[i], resources);
manager->Replace(taylor_node_list[i], NewValueNode(taylor_node_graph));
MS_LOG(INFO) << "return replaced taylor node: " << func_graph->ToString() << " replace end.";
return func_graph;
// If `vnode` holds a FuncGraph, runs TaylorFunctor on it under the node's scope and returns a new value
// node for the resulting graph; otherwise returns nullptr.
AnfNodePtr ExpandTaylor(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
if (IsValueNode<FuncGraph>(vnode)) {
ScopeGuard scope_guard(vnode->scope());
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandTaylor now";
auto newfg = TaylorFunctor(func_graph, resource);
return NewValueNode(newfg);
return nullptr;
} // namespace internal
// Expands every Taylor node held in `prim_nodes_`.
//
// For each node, the func graph stored in input(1) is cloned, the clone is passed through
// internal::ExpandTaylor, and the Taylor node is replaced by the result through the optimizer's manager.
//
// \param[in] func_graph Not read by this function.
// \param[in] optimizer Provides the manager used for replacement and the resource given to ExpandTaylor.
// \return true if at least one Taylor node was replaced, false otherwise.
bool ExpandTaylorPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  // Search all taylor nodes.
  bool change = false;
  auto manager = optimizer->manager();
  for (auto &taylor_node : prim_nodes_) {
    auto taylor_fg_node = taylor_node->input(1);
    auto taylor_fg = GetValueNode<FuncGraphPtr>(taylor_fg_node);
    if (taylor_fg == nullptr) {
      // `taylor_fg` is null on this path, so describe the Taylor node instead of dereferencing the null graph.
      MS_LOG(EXCEPTION) << "Unexpected Taylor node, input func graph should not be null, node: "
                        << taylor_node->DebugString();
    }
    // Copy original forward graph in case of the influence of usage in other place.
    auto taylor_fg_copy = BasicClone(taylor_fg, true);
    auto taylor_fg_copy_node = NewValueNode(taylor_fg_copy);
    // Return expanded taylor graph.
    auto expanded_taylor = internal::ExpandTaylor(taylor_fg_copy_node->cast<ValueNodePtr>(), optimizer->resource());
    manager->Replace(taylor_node, expanded_taylor);
    change = true;
  }
  return change;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
@ -0,0 +1,46 @@
* Copyright 2022 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "include/common/utils/primitive_utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimTaylor, C}
// Pass object that expands Taylor primitive nodes; `prim_` is set to prim::kPrimTaylor on construction.
class ExpandTaylorPrim : public ExpandMetaFGPrim {
ExpandTaylorPrim() { prim_ = prim::kPrimTaylor; }
virtual ~ExpandTaylorPrim() = default;
// Returns true when at least one Taylor node was replaced.
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
} // namespace irpass
} // namespace opt
} // namespace mindspore
@ -353,6 +353,7 @@ constexpr char SIMPLE_MEAN[] = "SimpleMean";
constexpr char FLATTEN[] = "Flatten";
constexpr char J[] = "J";
constexpr char SHARD[] = "Shard";
constexpr char Taylor[] = "Taylor";
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
constexpr char COS[] = "Cos";
constexpr char ACOS[] = "ACos";
@ -30,6 +30,10 @@ COMMON_EXPORT py::function GetBpropFunctionByObj(const py::object &obj);
COMMON_EXPORT py::function GetBpropFunction(const std::string &name);
COMMON_EXPORT py::function GetTaylorRuleFunctionByObj(const py::object &obj);
COMMON_EXPORT py::function GetTaylorRuleFunction(const std::string &name);
COMMON_EXPORT py::function GetComputeFunction(const std::string &name);
COMMON_EXPORT BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
@ -52,6 +52,7 @@
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/shard_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
@ -572,6 +572,27 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
return res;
// Evaluates the wrapped evaluator on the given argument configs, with caching keyed by the argument abstracts.
// The third parameter (output config) is unused; nullptr is passed to the wrapped evaluator.
EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
// Collect the abstract of every argument config.
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
return conf->ObtainEvalResult()->abstract();
// Return the cached result when these abstracts were evaluated before.
auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return eval_result;
// Call the original evaluator, get the result: y = f(x)
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
evaluator_cache_mgr_->SetValue(args_spec_list, result);
return result;
EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &) {
AbstractBasePtrList args_spec_list;
@ -372,6 +372,38 @@ class JEvaluator : public Evaluator {
AbstractFunctionPtr orig_func_;
// Evaluator for a Taylor-transformed function; it wraps the evaluator of the original function.
class TaylorEvaluator : public Evaluator {
// `evaluator` is the original function's evaluator, `orig_func` its abstract function.
TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
: Evaluator("TaylorEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
~TaylorEvaluator() override = default;
MS_DECLARE_PARENT(TaylorEvaluator, Evaluator);
// Prefers the wrapped evaluator's bound node when a wrapped evaluator exists.
AnfNodePtr bound_node() const override {
if (evaluator_ != nullptr) {
return evaluator_->bound_node();
return bound_node_.lock();
void set_bound_node(const AnfNodePtr &node) override {
if (evaluator_ != nullptr) {
bound_node_ = AnfNodeWeakPtr(node);
// Eval() is not supported on this evaluator; callers must use Run().
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
// NOTE(review): dereferences `evaluator_` without the null check used in bound_node().
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
// Evaluator of the original function.
EvaluatorPtr evaluator_;
// Abstract of the original function.
AbstractFunctionPtr orig_func_;
class ShardEvaluator : public Evaluator {
ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
@ -501,6 +501,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransfor
return vmap_evaluator;
// Builds a TaylorEvaluator around the evaluator of the function wrapped by the Taylor closure `func`.
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto taylorevaluator = std::make_shared<TaylorEvaluator>(evaluator_orig, func_orig);
return taylorevaluator;
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
AbstractFunctionPtr func_orig = func->fn();
@ -548,6 +556,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
} else if (func->isa<VmapTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<VmapTransformedAbstractClosure>>());
} else if (func->isa<TaylorTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<TaylorTransformedAbstractClosure>>());
} else if (func->isa<ShardTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
} else if (func->isa<VirtualAbstractClosure>()) {
@ -324,6 +324,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);
EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &fn);
@ -151,6 +151,18 @@ py::function PrimitivePy::GetBpropFunction() {
auto fn = GetBpropFunctionByObj(python_obj_);
return fn;
auto fn = GetBpropFunctionByObj(python_obj_);
return fn;
// Returns the taylor rule function of this primitive's python object.
// If the python object has a `get_taylor_rule` attribute, it is called and its result returned;
// otherwise the function is looked up through GetTaylorRuleFunctionByObj.
py::function PrimitivePy::GetTaylorRuleFunction() {
static const char *const get_taylor_rule_func_name = "get_taylor_rule";
if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
return fn;
auto fn = GetTaylorRuleFunctionByObj(python_obj_);
return fn;
py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {
@ -50,6 +50,7 @@ class PrimitivePy : public Primitive {
const bool parse_info_ = true;
py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0);
py::function GetBpropFunction();
py::function GetTaylorRuleFunction();
void set_signatures(const std::vector<Signature> &signatures);
const std::vector<Signature> &signatures() const { return signatures_; }
const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; }
@ -41,6 +41,18 @@ py::function GetBpropFunction(const std::string &name) {
return fn;
// Calls python `mindspore.ops._grad.get_taylor_fprop_fn(obj)` and returns the result as a py::function.
py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
static const std::string get_taylor_fprop_fn = "get_taylor_fprop_fn";
static const std::string ad_module = "mindspore.ops._grad";
py::function fn = python_adapter::GetPyFn(ad_module, get_taylor_fprop_fn)(obj);
return fn;
// Looks up the taylor rule function by primitive name, passing the name to python as a str.
py::function GetTaylorRuleFunction(const std::string &name) {
auto fn = GetTaylorRuleFunctionByObj(py::str(name));
return fn;
py::function GetComputeFunction(const std::string &name) {
static const std::string module = "mindspore._extends.builtin_operations";
py::module mod = py::module::import(common::SafeCStr(module));
@ -297,6 +297,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
return hash_value;
// Equal only to another TaylorTransformedAbstractClosure whose `fn_` compares equal to this one's.
// NOTE(review): `fn_ == other->fn_` compares the AbstractFuncAtomPtr handles themselves; confirm that
// comparing handles rather than the pointed-to abstracts is intended.
bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<TaylorTransformedAbstractClosure>()) {
return false;
// Safe after the isa<> check above.
auto other_transformed = static_cast<const TaylorTransformedAbstractClosure *>(&other);
return fn_ == other_transformed->fn_;
// Hash combining this class's type id with the hash of the wrapped function abstract.
std::size_t TaylorTransformedAbstractClosure::hash() const {
auto hash_value = hash_combine(tid(), fn_->hash());
return hash_value;
bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<ShardTransformedAbstractClosure>()) {
return false;
@ -339,6 +339,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
AbstractFuncAtomPtr fn_;
/// \brief TaylorTransformedAbstractClosure defines interface for abstract of Function
/// transformed through the application of Taylor.
class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom {
/// \brief Constructor of TaylorTransformedAbstractClosure
/// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor.
explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
/// \brief Destructor of TaylorTransformedAbstractClosure
~TaylorTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom)
/// \brief Get the AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
/// \return The AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to.
AbstractFuncAtomPtr fn() { return fn_; }
/// \brief Create a new TaylorTransformedAbstractClosure that wraps the same `fn_`.
AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); }
/// \brief Equality against another abstract function (defined out of line).
bool operator==(const AbstractFunction &other) const override;
/// \brief Hash of this closure (defined out of line).
std::size_t hash() const override;
/// \brief Render as "Taylor(<wrapped function>)".
std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; }
/// The wrapped function abstract.
AbstractFuncAtomPtr fn_;
/// \brief ShardTransformedAbstractClosure defines interface for abstract of Function
/// transformed through the application of Shard.
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
@ -171,6 +171,7 @@ constexpr auto kCSRDiv = "CSRDiv";
// Meta Function Graph
constexpr auto kJ = "J";
constexpr auto kVmap = "Vmap";
constexpr auto kTaylor = "Taylor";
// Others
constexpr auto kMakeTuple = "MakeTuple";
@ -828,6 +829,7 @@ GVAR_DEF(PrimitivePtr, kPrimStateSetItem, std::make_shared<Primitive>("state_set
GVAR_DEF(PrimitivePtr, kPrimJ, std::make_shared<Primitive>(kJ, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimVmap, std::make_shared<Primitive>(kVmap, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimShard, std::make_shared<Primitive>("Shard", kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimTaylor, std::make_shared<Primitive>(kTaylor));
// Used to build graph which have keyword arguments
GVAR_DEF(PrimitivePtr, kPrimExtractKeywordArg, std::make_shared<Primitive>("extract_keyword_arg"));
@ -313,7 +313,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap.
// Check if the function graph embeds a `MetaFGPrim`, which currently covers kPrimJ, kPrimVmap and kPrimTaylor.
bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
@ -704,7 +704,8 @@ void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfN
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
} else if (fg != nullptr && fg != input->func_graph()) {
@ -725,7 +726,8 @@ void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const An
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
IsPrimitiveCNode(node, prim::kPrimTaylor)) {
} else if (fg != nullptr && fg != input->func_graph()) {
@ -1096,7 +1098,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
return false;
// Check MetaFgPrim (J/Vmap) FuncGraph input.
// Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
if (!meta_fg_prim_values.empty()) {
auto contains_meta_fg_prim =
@ -1107,9 +1109,10 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
return func_graph->seen_ != seen_num;
if (IsValueNode<Primitive>(iter.first)) {
// Exclude the primitive of MetaFgPrim (J/Vmap) itself.
// Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
auto prim = GetValueNode<PrimitivePtr>(iter.first);
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name());
return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
prim->name() != prim::kPrimTaylor->name());
return false;
@ -1119,35 +1122,38 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
// Check MetaFgPrim (J/Vmap) CNode as FV.
// Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
const auto &fv_nodes = fg->free_variables();
if (!fv_nodes.empty()) {
auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
// Check if the FV is a MetaFgPrim (J/Vmap) call CNode.
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap)) {
// Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
IsPrimitiveCNode(iter.first, prim::kPrimTaylor)) {
return true;
return false;
if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap) ("
MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
<< contains_meta_fg_prim_cnode->first->DebugString() << ")";
return true;
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)
// Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
// Taylor(func_graph) or Taylor(Primitive).
fg->seen_ = seen_num;
for (auto &item : fg->func_graphs_used()) {
auto used_g = item.first;
if (SeekMetaFgPrim(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
<< " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
<< "Taylor(func_graph) or Taylor(Primitive)";
return true;
MS_LOG(DEBUG) << fg->ToString()
<< " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
<< "Taylor(func_graph) or Taylor(Primitive)";
return false;
@ -15,7 +15,7 @@
"""grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops
from .grad_base import get_bprop_fn
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule
from .grad_base import get_bprop_fn, get_taylor_fprop_fn
__all__ = ['get_bprop_fn']
__all__ = ['get_bprop_fn', 'get_taylor_fprop_fn']
@ -37,8 +37,28 @@ class BpropRegistry(Registry):
return deco
class TaylorFpropRegistry(Registry):
"""Registry of taylor rule functions, keyed by a primitive name string or by a Primitive subclass."""
def register(self, prim):
"""Return a decorator that registers the decorated function under `prim`."""
def deco(fn):
"""Store `fn` in the registry and return it unchanged."""
# A string key is stored as given.
if isinstance(prim, str):
self[prim] = fn
# A Primitive subclass is stored under both id(prim) and its class name.
elif issubclass(prim, Primitive):
self[id(prim)] = fn
self[prim.__name__] = fn
return fn
return deco
# Registries for bprop functions.
bprop_getters = BpropRegistry()
bprops = BpropRegistry()
# Registries for taylor rules; both are consulted by get_taylor_fprop_fn.
taylor_fprop_getters = TaylorFpropRegistry()
taylor_fprops = TaylorFpropRegistry()
def get_bprop_fn(prim):
@ -47,3 +67,11 @@ def get_bprop_fn(prim):
if out:
return out(prim)
return bprops.get(prim, None)
def get_taylor_fprop_fn(prim):
"""Get the taylor rule function for a primitive object or primitive name (called from C++).

A getter registered in `taylor_fprop_getters` takes precedence and is called with `prim`;
otherwise the entry in `taylor_fprops` is returned, or None when there is none.
"""
out = taylor_fprop_getters.get(prim, None)
if out:
return out(prim)
return taylor_fprops.get(prim, None)
@ -0,0 +1,166 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the taylor rules of operations."""
from mindspore import nn
import mindspore as ms
from ..primitive import Primitive
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import taylor_fprop_getters
def _factorial(order):
"""Return [0!, 1!, 2!,..., order!]."""
range_op = nn.Range(1, order + 1)
ones_op = P.Ones()
concat_op = P.Concat()
# 0! as a one-element float32 tensor.
factorial_zero = ones_op(1, ms.float32)
# Starts as the output of Range(1, order + 1) cast to float32, then becomes a running product in place.
factorial_positive = range_op().astype(ms.float32)
for i in range(1, order):
factorial_positive[i] *= factorial_positive[i - 1]
# Prepend 0! to the factorials of the positive orders.
factorial = concat_op((factorial_zero, factorial_positive))
return factorial
def taylor_add_or_sub(self):
"""Higher order derivatives rule definition for `Add` or `Sub` operation.

Returns a function that applies the primitive directly to both inputs.
"""
# A primitive name is turned into a Primitive instance.
if isinstance(self, str):
prim = Primitive(self)
# NOTE(review): presumably the `else` branch of the check above; the `else:` line is not visible here.
prim = self
def taylor_fprop_add_or_sub(input_x, input_y):
series = prim(input_x, input_y)
return series
return taylor_fprop_add_or_sub
def taylor_mul(self):
"""Higher order derivatives rule definition for `Mul` operation."""
mul_func = P.Mul()
def taylor_fprop_mul(input_x, input_y):
# Index 0 of each input is used as the primal value; the remaining indices are the series terms.
primals = mul_func(input_x[0], input_y[0])
series_num = len(input_x) - 1
factorial = _factorial(series_num)
series = zeros_like(input_x)
series[0] = primals
# Term k accumulates input_x[i] * input_y[k - i] / ((k - i)! * i!) over i and is scaled by k!.
# NOTE(review): loop nesting is inferred; indentation is not preserved in this excerpt.
for k in range(1, series_num + 1):
for i in range(0, k + 1):
tmp = input_x[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] *= factorial[k]
return series
return taylor_fprop_mul
def taylor_realdiv(self):
"""Higher order derivatives rule definition for `RealDiv` operation."""
div_op = P.Div()
def taylor_fprop_realdiv(input_x, input_y):
# Index 0 of each input is used as the primal value.
primals = div_op(input_x[0], input_y[0])
series_num = len(input_x) - 1
factorial = _factorial(series_num)
series = zeros_like(input_x)
series[0] = primals
# Term k is built from the already computed terms series[0..k-1], then divided by input_y[0].
# NOTE(review): loop nesting is inferred; indentation is not preserved in this excerpt.
for k in range(1, series_num + 1):
for i in range(0, k):
tmp = series[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] = (input_x[k] - factorial[k] * series[k]) / input_y[0]
return series
return taylor_fprop_realdiv
def taylor_exp(self):
"""Higher order derivatives rule definition for `Exp` operation."""
exp_ = P.Exp()
def taylor_fprop_exp(inputs):
# Index 0 of `inputs` is used as the primal value.
primals = exp_(inputs[0])
series_num = len(inputs) - 1
factorial = _factorial(series_num)
series = zeros_like(inputs)
series[0] = primals
# Term k accumulates i * inputs[i] * series[k - i] / ((k - i)! * i!) over i and is scaled by (k - 1)!.
# NOTE(review): loop nesting is inferred; indentation is not preserved in this excerpt.
for k in range(1, series_num + 1):
for i in range(1, k + 1):
tmp = i * inputs[i] * series[k - i] / (factorial[k - i] * factorial[i])
series[k] += tmp
series[k] *= factorial[k - 1]
return series
return taylor_fprop_exp
def taylor_sin(self):
    """Higher order derivatives rule definition for `Sin` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_sin(inputs):
        # The sine and cosine expansions are built together because each order of one
        # is computed from the lower orders of the other.
        sin_terms = zeros_like(inputs)
        cos_terms = zeros_like(inputs)
        sin_terms[0] = sin(inputs[0])
        cos_terms[0] = cos(inputs[0])
        order = len(inputs) - 1
        fact = _factorial(order)
        for k in range(1, order + 1):
            for i in range(1, k + 1):
                weighted = i * inputs[i]
                denom = fact[i] * fact[k - i]
                sin_terms[k] += weighted * cos_terms[k - i] / denom
                cos_terms[k] -= weighted * sin_terms[k - i] / denom
            sin_terms[k] *= fact[k - 1]
            cos_terms[k] *= fact[k - 1]
        # Only the sine expansion is the result of this rule.
        return sin_terms

    return taylor_fprop_sin
def taylor_cos(self):
    """Higher order derivatives rule definition for `Cos` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_cos(inputs):
        # The cosine and sine expansions are built together because each order of one
        # is computed from the lower orders of the other.
        cos_terms = zeros_like(inputs)
        sin_terms = zeros_like(inputs)
        cos_terms[0] = cos(inputs[0])
        sin_terms[0] = sin(inputs[0])
        order = len(inputs) - 1
        fact = _factorial(order)
        for k in range(1, order + 1):
            for i in range(1, k + 1):
                weighted = i * inputs[i]
                denom = fact[i] * fact[k - i]
                cos_terms[k] -= weighted * sin_terms[k - i] / denom
                sin_terms[k] += weighted * cos_terms[k - i] / denom
            cos_terms[k] *= fact[k - 1]
            sin_terms[k] *= fact[k - 1]
        # Only the cosine expansion is the result of this rule.
        return cos_terms

    return taylor_fprop_cos
@ -21,7 +21,7 @@ Pre-defined combination of operators.
from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
core, env_get, tail, zip_operation, Shard, _Vmap
core, env_get, tail, zip_operation, Shard, _Vmap, _TaylorOperation
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like
@ -22,7 +22,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
ListSlice_, VmapOperation_
ListSlice_, VmapOperation_, TaylorOperation_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@ -401,6 +401,29 @@ class GradOperation(GradOperation_):
return self.grad_fn
class _TaylorOperation(TaylorOperation_):
    """
    Generate the higher order derivatives function for the input function.
    """

    def __init__(self):
        """Initialize TaylorOperation."""
        TaylorOperation_.__init__(self, 'taylorgrad')
        # One-entry cache: the last wrapped function and the wrapper generated for it.
        self.grad_fn = None
        self.fn = None

    def __call__(self, fn):
        """Return the callable that evaluates the higher order derivatives of `fn`.

        The wrapper built for the most recent `fn` is cached and returned again
        when the same `fn` is passed.
        """
        if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
        taylor_grad_ = _TaylorOperation()

        # If calling Grad in GRAPH_MODE or calling Grad in ms_function, do grad in GRAPH_MODE
        # The wrapper has to be compiled: executed as plain Python, `taylor_grad_(fn)` would
        # enter this `__call__` again, build another wrapper and recurse without end.
        @ms_function
        def after_taylor_grad(*args):
            return taylor_grad_(fn)(*args)

        self.grad_fn = after_taylor_grad
        self.fn = fn
        return self.grad_fn
class _Grad(GradOperation_):
A higher-order function which is used to generate the gradient function by position for the input function.
@ -32,7 +32,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad, Shard, _Vmap
from .composite import _Grad, Shard, _Vmap, _TaylorOperation
from .._c_expression import security
typeof = Primitive('typeof')
@ -48,6 +48,7 @@ eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
size = P.Size()
ones = P.Ones()
ones_like = P.OnesLike()
shape = P.Shape()
dyn_shape = P.TensorShape()
@ -285,6 +286,187 @@ def grad(fn, grad_position=0, sens_param=False):
return grad_by_position_with_sens(fn, None, grad_position)
return grad_by_position(fn, None, grad_position)
def _trans_jet_inputs(primals_item, series_item):
    """Validate the dtypes of one primals/series pair and cast integer pairs to float64."""
    integer_types = [mstype.int32, mstype.int64]
    supported_types = integer_types + [mstype.float32, mstype.float64]
    primals_type = dtype(primals_item)
    # Both items must share one of the supported dtypes.
    if primals_type not in supported_types or primals_type != dtype(series_item):
        raise TypeError(f"For `F.jet`, the elements' types of primals and series should be the same and belong to "
                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
                        f" {dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
    if primals_type not in integer_types:
        return primals_item, series_item
    return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
def _check_jet_inputs(primals, series):
"""Check inputs of jet"""
if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
raise TypeError(f"For 'F.jet', the 'primals' and `series` should be both Tensor or tuple, "
f"but got {type(primals).__name__} and {type(series).__name__}.")
if isinstance(primals, Tensor):
if primals.shape != series.shape[1:]:
raise ValueError("The shape of each element should be the same as the primals.")
return _trans_jet_inputs(primals, series)
if isinstance(primals, tuple):
if len(primals) != len(series):
raise ValueError("The lengths of primals and series should be the same.")
check_primals = []
check_series = []
for i, j in zip(primals, series):
trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
return check_primals, check_series
_taylor = _TaylorOperation()
def jet(fn, primals, series):
    """
    This function is designed to calculate the higher order differentiation of given composite function. To figure out
    first to `n`-th order differentiations, original inputs and first to `n`-th order derivative of original inputs
    must be provided together. Generally, it is recommended to set the values of given first order derivative to 1,
    while the other to 0.

    Args:
        fn (Union(Cell, function)): Function to do TaylorOperation.
        primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
        series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of series should be the same as inputs.
            For each Tensor, the first dimension indexes the derivative order: entry `i` (zero-based) holds the
            `i+1`-th order derivative of the corresponding input.

    Returns:
        Tuple, tuple of out_primals and out_series.

        - **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
        - **out_series** (Tensors or List of Tensors) - The `1` to `i+1`-th order of derivative of output with respect
          to the inputs.

    Raises:
        TypeError: If `primals` and `series` fail the type checks of `_check_jet_inputs`.
        ValueError: If `primals` and `series` fail the shape or length checks of `_check_jet_inputs`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.context as context
        >>> import mindspore.ops as P
        >>> from mindspore import Tensor
        >>> from mindspore.ops.functional import jet
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.sin = P.Sin()
        ...         self.exp = P.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
        ...         return out2
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
        >>> net = Net()
        >>> out_primals, out_series = jet(net, primals, series)
        >>> print(out_primals, out_series)
    """
    primals, series = _check_jet_inputs(primals, series)
    derivative_fn = _taylor(fn)
    concat_op = P.Concat()
    # Each input handed to the Taylor function stacks the primal (index 0) on top of its series.
    if isinstance(primals, list) and list_len(primals) > 1:
        inputs = [concat_op((expand_dims(x, 0), y)) for x, y in zip(primals, series)]
        outputs = derivative_fn(*inputs)
    else:
        # Single input: must not run after the multi-input branch, or it would overwrite its result.
        inputs = concat_op((expand_dims(primals, 0), series))
        outputs = derivative_fn(inputs)
    # Split every output back into its primal (index 0) and its series (remaining indices).
    if isinstance(outputs, list) and list_len(outputs) > 1:
        out_primals = [element[0] for element in outputs]
        out_series = [element[1:] for element in outputs]
    else:
        out_primals = outputs[0]
        out_series = outputs[1:]
    return out_primals, out_series
def _trans_derivative_inputs(primals_item):
    """Validate the dtype of one primals element and cast integer elements to float64."""
    integer_types = [mstype.int32, mstype.int64]
    item_type = dtype(primals_item)
    if item_type not in integer_types + [mstype.float32, mstype.float64]:
        raise TypeError(f"For `F.derivative`, the elements of primals should belong to "
                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got"
                        f" {dtype(primals_item).__name__}.")
    if item_type not in integer_types:
        return primals_item
    return cast(primals_item, mstype.float64)
def derivative(fn, primals, order):
    """
    This function is designed to calculate the higher order differentiation of given composite function. To figure out
    `order`-th order differentiations, original inputs and order must be provided together. In particular, the value of
    input first order derivative is set to 1, while the other to 0.

    Args:
        fn (Union(Cell, function)): Function to do TaylorOperation.
        primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
        order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
            figured out.

    Returns:
        Tuple, tuple of out_primals and out_series.

        - **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
        - **out_series** (Tensors or List of Tensors) - The `order`-th order of derivative of output with respect
          to the inputs.

    Raises:
        TypeError: If an element of `primals` has a dtype rejected by `_trans_derivative_inputs`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.context as context
        >>> import mindspore.ops as P
        >>> from mindspore import Tensor
        >>> from mindspore.ops.functional import derivative
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.sin = P.Sin()
        ...         self.exp = P.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
        ...         return out2
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> order = 3
        >>> net = Net()
        >>> out_primals, out_series = derivative(net, primals, order)
        >>> print(out_primals, out_series)
    """
    derivative_fn = _taylor(fn)
    concat_op = P.Concat()
    series_one = 1
    if isinstance(primals, tuple):
        trans_primals = [_trans_derivative_inputs(item) for item in primals]
        # Stack each primal with a first order term of ones ...
        inputs = [concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))) for x in trans_primals]
        if order > 1:
            # ... and pad the remaining `order - 1` terms with zeros.
            inputs = [concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))) for x in inputs]
        outputs = derivative_fn(*inputs)
    else:
        # Single tensor: must not run after the tuple branch, or it would overwrite its result.
        primals = _trans_derivative_inputs(primals)
        series = zeros((order,) + primals.shape, dtype(primals))
        series[0] = series_one
        inputs = concat_op((expand_dims(primals, 0), series))
        outputs = derivative_fn(inputs)
    # Index 0 of an output is its primal, the last index is the highest computed order.
    if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
        out_primals = [element[0] for element in outputs]
        out_series = [element[-1] for element in outputs]
    else:
        out_primals = outputs[0]
        out_series = outputs[-1]
    return out_primals, out_series
def jvp(fn, inputs, v):
@ -38,6 +38,11 @@ class MultipleInputsOutputNet(nn.Cell):
def test_vjp_single_input_graph():
Features: Function vjp
Description: Test vjp with single input, single output and default v in graph mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
@ -53,6 +58,11 @@ def test_vjp_single_input_graph():
def test_vjp_multiple_inputs_default_v_graph():
Features: Function vjp
Description: Test vjp with single input, single output and default v in graph mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
@ -140,8 +150,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
net_out, vjp_out = vjp(, inputs, vectors)
return net_out, vjp_out
test_net = Net(SingleInputNet())
primal, grad = test_net(x, v)
test_net_graph = Net(SingleInputNet())
primal, grad = test_net_graph(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
@ -18,7 +18,6 @@ import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops.functional import vjp
@ -37,12 +36,17 @@ class MultipleInputsOutputNet(nn.Cell):
def test_vjp_single_input_graph():
def test_vjp_single_input_pynative():
Features: Function vjp
Description: Test vjp with single input, single output and default v in pynative mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -51,58 +55,38 @@ def test_vjp_single_input_graph():
def test_vjp_multiple_inputs_default_v_graph():
def test_vjp_multiple_inputs_default_v_pynative():
Features: Function vjp
Description: Test vjp with multiple inputs, multiple outputs and default v in pynative mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputsOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
primal, grad = vjp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
def test_vjp_ms_function_single_input_single_output_default_v_graph():
def test_vjp_input_function_single_input_single_output_default_v_pynative():
Features: Function vjp
Description: Test vjp with ms_function, single input, single output and default v in graph mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputNet()
def vjp_with_ms_function(inputs, vectors):
output, vjp_grad = vjp(net, inputs, vectors)
return output, vjp_grad
primal, grad = vjp_with_ms_function(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
def test_vjp_input_function_single_input_single_output_default_v_graph():
Features: Function vjp
Description: Test vjp with function, single input, single output and default v in graph mode.
Description: Test vjp with function, single input, single output and default v in pynative mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -112,8 +96,8 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
return inputs**3
primal, grad = vjp(test_function, x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@ -121,10 +105,10 @@ def test_vjp_input_function_single_input_single_output_default_v_graph():
def test_vjp_construct_single_input_single_output_default_v_graph():
def test_vjp_construct_single_input_single_output_default_v_pynative():
Features: Function vjp
Description: Test vjp with function, single input, single output and default v in graph mode.
Description: Test vjp with function, single input, single output and default v in pynative mode.
Expectation: No exception.
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
@ -139,8 +123,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
net_out, vjp_out = vjp(, inputs, vectors)
return net_out, vjp_out
test_net = Net(SingleInputNet())
primal, grad = test_net(x, v)
test_net_pynative = Net(SingleInputNet())
primal, grad = test_net_pynative(x, v)
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in graph mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative
class MultipleInputSingleOutputNet(nn.Cell):
    """Two-input network computing exp(sin(x) * cos(y) + sin(x) / cos(y))."""

    def __init__(self):
        super().__init__()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.exp = P.Exp()

    def construct(self, x, y):
        sin_x = self.sin(x)
        cos_y = self.cos(y)
        combined = sin_x * cos_y + sin_x / cos_y
        return self.exp(combined)
class SingleInputSingleOutputNet(nn.Cell):
    """One-input network computing s + cos(s) - exp(cos(s)) with s = sin(x)."""

    def __init__(self):
        super().__init__()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.exp = P.Exp()

    def construct(self, x):
        sin_x = self.sin(x)
        cos_of_sin = self.cos(sin_x)
        exp_of_cos = self.exp(cos_of_sin)
        return sin_x + cos_of_sin - exp_of_cos
def test_jet_single_input_single_output_graph_mode():
    """
    Features: Function jet
    Description: Test jet with single input in graph mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.GRAPH_MODE)
    primals = Tensor([1., 1.])
    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
def test_derivative_single_input_single_output_graph_mode():
    """
    Features: Function derivative
    Description: Test derivative with single input in graph mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.GRAPH_MODE)
    primals = Tensor([1., 1.])
    order = 3
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
def test_jet_multiple_input_single_output_graph_mode():
    """
    Features: Function jet
    Description: Test jet with multiple inputs in graph mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.GRAPH_MODE)
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
def test_derivative_multiple_input_single_output_graph_mode():
    """
    Features: Function derivative
    Description: Test derivative with multiple inputs in graph mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.GRAPH_MODE)
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    order = 3
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
@ -0,0 +1,142 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test taylor differentiation in pynative mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.ops.functional import jet, derivative
class SingleInputSingleOutputNet(nn.Cell):
    """One-input network computing s + cos(s) - exp(cos(s)) with s = sin(x)."""

    def __init__(self):
        super().__init__()
        self.exp = P.Exp()
        self.cos = P.Cos()
        self.sin = P.Sin()

    def construct(self, x):
        sin_x = self.sin(x)
        cos_of_sin = self.cos(sin_x)
        exp_of_cos = self.exp(cos_of_sin)
        return sin_x + cos_of_sin - exp_of_cos
class MultipleInputSingleOutputNet(nn.Cell):
    """Two-input network computing exp(sin(x) * cos(y) + sin(x) / cos(y))."""

    def __init__(self):
        super().__init__()
        self.exp = P.Exp()
        self.cos = P.Cos()
        self.sin = P.Sin()

    def construct(self, x, y):
        sin_x = self.sin(x)
        cos_y = self.cos(y)
        combined = sin_x * cos_y + sin_x / cos_y
        return self.exp(combined)
def test_jet_multiple_input_single_output_pynative_mode():
    """
    Features: Function jet
    Description: Test jet with multiple inputs in pynative mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.PYNATIVE_MODE)
    series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]]))
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
def test_derivative_multiple_input_single_output_pynative_mode():
    """
    Features: Function derivative
    Description: Test derivative with multiple inputs in pynative mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.PYNATIVE_MODE)
    primals = (Tensor([1., 1.]), Tensor([1., 1.]))
    order = 3
    net = MultipleInputSingleOutputNet()
    expected_primals = np.array([7.47868, 7.47868]).astype(np.float32)
    expected_series = np.array([1237.959, 1237.959]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
def test_jet_single_input_single_output_pynative_mode():
    """
    Features: Function jet
    Description: Test jet with single input in pynative mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.PYNATIVE_MODE)
    primals = Tensor([1., 1.])
    series = Tensor([[1., 1.], [0., 0.], [0., 0.]])
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32)
    out_primals, out_series = jet(net, primals, series)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
def test_derivative_single_input_single_output_pynative_mode():
    """
    Features: Function derivative
    Description: Test derivative with single input in pynative mode.
    Expectation: No exception.
    """
    # Select the mode explicitly so the test does not depend on what ran before it.
    context.set_context(mode=context.PYNATIVE_MODE)
    primals = Tensor([1., 1.])
    order = 3
    net = SingleInputSingleOutputNet()
    expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32)
    expected_series = np.array([-0.74808, -0.74808]).astype(np.float32)
    out_primals, out_series = derivative(net, primals, order)
    assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4)
    assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4)
Reference in New Issue