diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 159ffcd67af..d470716f358 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -587,7 +587,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg std::vector grads; grads.push_back(NewValueNode(prim::kPrimMakeTuple)); - grads.push_back(NewValueNode(newenv)); + grads.push_back(NewEnviron(b)); for (int64_t i = 0; i < tuple_size; ++i) { grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); } @@ -628,7 +628,7 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args std::vector grads; grads.push_back(NewValueNode(prim::kPrimMakeTuple)); - grads.push_back(NewValueNode(newenv)); + grads.push_back(NewEnviron(b)); for (int64_t i = 0; i < list_size; ++i) { grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)})); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 75411b4353e..a2513049c1c 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -128,7 +128,7 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { fv_adjoint->second->RegisterKUser(default_val_node, 1); anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node); } - auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node}); + auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironGet), din, embed_node, default_val_node}); MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " << fv->func_graph()->ToString() << " " << fv->ToString() << "."; MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << "."; @@ -421,7 +421,7 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { fv_adjoint->second->RegisterKUser(node, 1); auto sens = fv_adjoint->second->dout(); new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), + NewValueNode(prim::kPrimEnvironSet), new_grad_fv, node, sens, @@ -448,7 +448,7 @@ AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { fv_adjoint.second->RegisterKUser(node, 1); auto sens = fv_adjoint.second->dout(); new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), + NewValueNode(prim::kPrimEnvironSet), new_grad_fv, node, sens, @@ -486,9 +486,9 @@ void DFunctor::MapMorphism() { // Set output for tape closure. AnfNodePtr grad_fv; if (lift_fv_before_grad) { - grad_fv = AttachFvDoutToTape(NewValueNode(newenv)); + grad_fv = AttachFvDoutToTape(NewEnviron(tape_)); } else { - grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); + grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewEnviron(tape_))); } std::vector inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv}; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index ee842f4c5d9..6a5ee188b5d 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -115,7 +115,7 @@ class DFunctor : public std::enable_shared_from_this { mindspore::HashMap anfnode_to_adjoin_; // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. mindspore::HashMap anfnode_to_adjoin_indirect_fv_; - // Cache for fv node -> pair, zeros_like>, so EnvGetItemTransform in optimizer + // Cache for fv node -> pair, zeros_like>, so EnvironGetTransform in optimizer // can hit its cache if fv_node is same. mindspore::HashMap> anfnode_to_envitem_; FuncGraphPtr primal_graph_; diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index b285d0f11e9..01204bf97a0 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -611,7 +611,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr & std::vector args; args.push_back(NewValueNode(prim::kPrimMakeTuple)); - args.push_back(NewValueNode(newenv)); + args.push_back(NewEnviron(bprop_fg)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); if (!extra_args.empty()) { args.insert(args.end(), extra_args.cbegin(), extra_args.cend()); @@ -622,7 +622,7 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr & // Set bprop output as (env, dx) std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); std::string python_ops("_tuple_add"); - auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); + auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg); auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name)); if (!extra_args.empty()) { extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple)); diff --git a/mindspore/ccsrc/frontend/optimizer/environ_conversion.cc b/mindspore/ccsrc/frontend/optimizer/environ_conversion.cc new file mode 100644 index 00000000000..4b8570bd586 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/environ_conversion.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/environ_conversion.h" + +#include +#include +#include + +#include "utils/hash_map.h" +#include "abstract/abstract_function.h" +#include "utils/flags.h" +#include "utils/utils.h" +#include "utils/anf_utils.h" +#include "utils/symbolic.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +using SymbolicKeyConversionMap = + mindspore::HashMap; + +namespace { +bool IsAbstractEnvType(const abstract::AbstractBasePtr &abs) { + if (abs != nullptr && abs->isa() && abs->GetTypeTrack()->isa()) { + return true; + } + return false; +} + +abstract::AbstractBasePtr TransformAbstractRecursively(const abstract::AbstractBasePtr &orig_abs, + const abstract::AbstractBasePtr target_abs) { + if (orig_abs == nullptr) { + return nullptr; + } + if (IsAbstractEnvType(orig_abs)) { + return target_abs; + } + if (!orig_abs->isa()) { + return nullptr; + } + const auto &abs_seq = orig_abs->cast(); + MS_EXCEPTION_IF_NULL(abs_seq); + const auto &elements = abs_seq->elements(); + abstract::AbstractBasePtrList new_elements; + bool transformed = false; + for (auto elem : elements) { + if (elem->isa()) { + auto inner_abs_seq = elem->cast(); + auto transformed_abs = TransformAbstractRecursively(inner_abs_seq, target_abs); + if (transformed_abs != nullptr) { + new_elements.push_back(transformed_abs); + transformed = true; + } else { + new_elements.push_back(elem); + } + } else if (IsAbstractEnvType(elem)) { + new_elements.push_back(target_abs); + transformed = true; + } else { + new_elements.push_back(elem); + } + } + if (transformed) { + abstract::AbstractBasePtr new_abs; + if (abs_seq->isa()) { + new_abs = std::make_shared(new_elements); + } else if (abs_seq->isa()) { + new_abs = std::make_shared(new_elements); + } else { + MS_LOG(EXCEPTION) << "abs_seq is not AbstractTuple or AbstractList, but: " << abs_seq->ToString(); + } + return new_abs; + } + // No transformation. + return nullptr; +} + +void TransformNodeAbstractIfEnvType(const AnfNodePtr &node, const abstract::AbstractBasePtr &target_abs) { + auto transformed_abs = TransformAbstractRecursively(node->abstract(), target_abs); + if (transformed_abs != nullptr) { + node->set_abstract(transformed_abs); + } +} + +TypeId GetValueType(const CNodePtr &cnode) { + // (EnvironSet/EnvironGet, environ, key, value/default) + if (cnode->inputs().size() != 4) { + MS_LOG(EXCEPTION) << "EnvrionSet/EnvironGet cnode should have 4 inputs, but: " << cnode->DebugString(); + } + const auto &value_abstract = cnode->input(3)->abstract(); + if (value_abstract == nullptr) { + MS_LOG(EXCEPTION) << "4th input of EnvironSet/EnvironGet cnode should abstract, but not set, node: " + << cnode->DebugString(); + } + if (IsAbstractEnvType(value_abstract)) { + return kObjectTypeEnvType; + } else { + return kObjectTypeTensorType; + } +} + +AnfNodePtr GetTransformedKeyNode(const AnfNodePtr &old_key_node, SymbolicKeyConversionMap *para_symbol_map) { + const auto &symbolic_key_inst = GetValueNode(old_key_node); + uint64_t transformed_key = 0; + auto &symbolic_key_map = *para_symbol_map; + auto iter = symbolic_key_map.find(symbolic_key_inst); + if (iter != symbolic_key_map.end()) { + transformed_key = iter->second; + } else { + static uint64_t key_counter = 0; + transformed_key = ++key_counter; + symbolic_key_map.emplace(std::make_pair(symbolic_key_inst, transformed_key)); + } + auto tensor_key = std::make_shared(transformed_key); + auto transformed_key_node = NewValueNode(tensor_key); + transformed_key_node->set_abstract(tensor_key->ToAbstract()); + return transformed_key_node; +} +} // namespace + +bool EnvironConversion(const pipeline::ResourcePtr &resource) { + SymbolicKeyConversionMap symbolic_key_map; + static AbstractBasePtr scalar_abs = std::make_shared(kAnyValue, kInt64); + static AbstractBasePtr tensor_abs = std::make_shared(scalar_abs); + static std::string attr_name = "value_type"; + const int kPrimitiveOffset = 0; + const int kSymbolicKeyOffset = 2; + auto mng = resource->manager(); + const auto &all_nodes = mng->all_nodes(); + auto txn = mng->Transact(); + for (const auto &node : all_nodes) { + TransformNodeAbstractIfEnvType(node, tensor_abs); + if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet) && !IsPrimitiveCNode(node, prim::kPrimEnvironGet)) { + continue; + } + const auto &cnode = node->cast(); + // Prim + AnfNodePtr transformed_prim_node; + const auto &type_id = GetValueType(cnode); + PrimitivePtr prim = GetValueNode(cnode->input(kPrimitiveOffset)); + MS_EXCEPTION_IF_NULL(prim); + const auto &old_attr = prim->GetAttr(attr_name); + if (old_attr != nullptr) { + const auto old_attr_value = GetValue(old_attr); + if (old_attr_value != type_id) { + if (IsPrimitiveCNode(node, prim::kPrimEnvironSet)) { + prim = std::make_shared(prim::kEnvironSet); + } else { + prim = std::make_shared(prim::kEnvironGet); + } + MS_EXCEPTION_IF_NULL(prim); + prim->set_attr(attr_name, MakeValue(static_cast(type_id))); + transformed_prim_node = NewValueNode(prim); + txn.SetEdge(node, kPrimitiveOffset, transformed_prim_node); + } + } else { + prim->set_attr(attr_name, MakeValue(static_cast(type_id))); + } + // Abstract of Environ & Value will be set by previous TransformNodeAbstract function. + // Key + if (!IsValueNode(cnode->input(kSymbolicKeyOffset))) { + MS_LOG(EXCEPTION) << "should be SymbolicKey, but: " << cnode->input(kSymbolicKeyOffset)->ToString(); + } + const auto &transformed_key_node = GetTransformedKeyNode(cnode->input(kSymbolicKeyOffset), &symbolic_key_map); + txn.SetEdge(node, kSymbolicKeyOffset, transformed_key_node); + } + + txn.Commit(); + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/core/utils/symbolic.cc b/mindspore/ccsrc/frontend/optimizer/environ_conversion.h similarity index 51% rename from mindspore/core/utils/symbolic.cc rename to mindspore/ccsrc/frontend/optimizer/environ_conversion.h index fd9d01d808d..8db2c439ddc 100644 --- a/mindspore/core/utils/symbolic.cc +++ b/mindspore/ccsrc/frontend/optimizer/environ_conversion.h @@ -1,6 +1,4 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,24 +14,16 @@ * limitations under the License. */ -#include "utils/symbolic.h" +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_ -#include +#include "pipeline/jit/resource.h" namespace mindspore { -std::ostream &operator<<(std::ostream &out, const std::shared_ptr &objPtr) { - out << "()"; - return out; -} - -bool EnvInstance::operator==(const EnvInstance &other) const { return true; } - -bool EnvInstance::operator==(const Value &other) const { - if (other.isa()) { - auto other_env_inst = static_cast(&other); - return *this == *other_env_inst; - } - return false; -} -std::shared_ptr newenv = std::make_shared(); +/* namespace to support opt */ +namespace opt { +bool EnvironConversion(const pipeline::ResourcePtr &resource); +} // namespace opt } // namespace mindspore + +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_ENVIRON_CONVERSION_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 412431fd772..16f7be435cf 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -19,7 +19,7 @@ #include "frontend/optimizer/irpass/branch_culling.h" #include "frontend/optimizer/irpass/cast_eliminate.h" #include "frontend/optimizer/irpass/convert.h" -#include "frontend/optimizer/irpass/env_item_eliminate.h" +#include "frontend/optimizer/irpass/environ_eliminate.h" #include "frontend/optimizer/irpass/grad_var_prepare.h" #include "frontend/optimizer/irpass/gradient_eliminate.h" #include "frontend/optimizer/irpass/inline.h" @@ -119,26 +119,28 @@ OptimizeIRPassLib::OptimizeIRPassLib() { MakeSubstitution(std::make_shared(), "reduce_all_const_elim", prim::kPrimAllReduce); real_op_eliminate_ = MakeSubstitution(std::make_shared(), "real_op_eliminate", prim::kPrimRealInner); - // Env Item Eliminate - env_get_item_eliminate_ = - MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - env_get_item_add_eliminate_ = - MakeSubstitution(std::make_shared(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem); - env_get_set_item_eliminate_ = - MakeSubstitution(std::make_shared(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem); - env_get_item_depend_swap_ = - MakeSubstitution(std::make_shared(), "env_get_item_depend_swap", prim::kPrimEnvGetItem); + // Environ Item Eliminate + environ_get_eliminate_ = + MakeSubstitution(std::make_shared(), "environ_get_eliminate", prim::kPrimEnvironGet); + environ_get_add_eliminate_ = + MakeSubstitution(std::make_shared(), "environ_get_add_eliminate_", prim::kPrimEnvironGet); + environ_get_set_eliminate_ = + MakeSubstitution(std::make_shared(), "environ_get_set_eliminate", prim::kPrimEnvironGet); + environ_get_depend_swap_ = + MakeSubstitution(std::make_shared(), "environ_get_depend_swap", prim::kPrimEnvironGet); + environ_add_const_eliminate_ = MakeSubstitution(std::make_shared(), + "environ_add_const_eliminate_", prim::kPrimEnvironAdd); - incorporate_env_getitem_bypass_recursive_ = - MakeSubstitution(std::make_shared(true), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), - "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); - incorporate_env_getitem_ = - MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_environ_get_bypass_recursive_ = + MakeSubstitution(std::make_shared(true), "incorporate_environ_get", prim::kPrimEnvironGet); + incorporate_environ_get_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_environ_get_switch", prim::kPrimEnvironGet); + incorporate_environ_get_ = + MakeSubstitution(std::make_shared(), "incorporate_environ_get", prim::kPrimEnvironGet); - incorporate_env_getitem_switch_layer_ = - MakeSubstitution(std::make_shared(), "incorporate_env_getitem_switch_layer", - prim::kPrimEnvGetItem); + incorporate_environ_get_switch_layer_ = + MakeSubstitution(std::make_shared(), "incorporate_environ_get_switch_layer", + prim::kPrimEnvironGet); // Ref eliminate make_ref_eliminate_ = @@ -157,8 +159,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); - float_env_getitem_switch_ = - MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + float_environ_get_switch_ = + MakeSubstitution(std::make_shared(), "float_environ_get_switch", prim::kPrimEnvironGet); exchange_switch_depend_value_ = MakeSubstitution(std::make_shared(), "exchange_switch_depend_value", prim::kPrimSwitch); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index cdacaffcb81..88c19a641a0 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -65,14 +65,15 @@ class OptimizeIRPassLib { SubstitutionPtr real_op_eliminate_; // Env Item Eliminate - SubstitutionPtr env_get_item_eliminate_; - SubstitutionPtr env_get_item_add_eliminate_; - SubstitutionPtr env_get_set_item_eliminate_; - SubstitutionPtr env_get_item_depend_swap_; - SubstitutionPtr incorporate_env_getitem_; - SubstitutionPtr incorporate_env_getitem_bypass_recursive_; - SubstitutionPtr incorporate_env_getitem_switch_; - SubstitutionPtr incorporate_env_getitem_switch_layer_; + SubstitutionPtr environ_get_eliminate_; + SubstitutionPtr environ_get_add_eliminate_; + SubstitutionPtr environ_get_set_eliminate_; + SubstitutionPtr environ_get_depend_swap_; + SubstitutionPtr environ_add_const_eliminate_; + SubstitutionPtr incorporate_environ_get_; + SubstitutionPtr incorporate_environ_get_bypass_recursive_; + SubstitutionPtr incorporate_environ_get_switch_; + SubstitutionPtr incorporate_environ_get_switch_layer_; // Ref eliminate SubstitutionPtr make_ref_eliminate_; @@ -84,7 +85,7 @@ class OptimizeIRPassLib { // Branch culling SubstitutionPtr switch_simplify_; SubstitutionPtr float_tuple_getitem_switch_; - SubstitutionPtr float_env_getitem_switch_; + SubstitutionPtr float_environ_get_switch_; SubstitutionPtr exchange_switch_depend_value_; SubstitutionPtr switch_partial_eliminater_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc index 4f8f8d28701..d3b17c1eb80 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc @@ -58,8 +58,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}}, - {prim::kPrimEnvGetItem, {1}}, - {prim::kPrimEnvSetItem, {1}}, + {prim::kPrimEnvironGet, {1}}, + {prim::kPrimEnvironSet, {1}}, {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, @@ -83,7 +83,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { std::vector>> white_list( {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}}, - {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, + {prim::kPrimEnvironGet, {1}}, {prim::kPrimEnvironSet, {1}}, {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}}, {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 5864edd0e21..49b30dae7b6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -66,16 +66,16 @@ class FloatTupleGetItemSwitch : public OptimizerCaller { } }; -// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => -// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} -class FloatEnvGetItemSwitch : public OptimizerCaller { +// {prim::kPrimEnvironGet, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => +// {prim::kPrimSwitch, X1, {prim::kPrimEnvironGet, X2, X4, X5}, {prim::kPrimEnvironGet, X3, X4, X5}} +class FloatEnvironGetSwitch : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode cond, true_br, false_br, x, x2; MATCH_REPLACE(node, - PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), - PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); + PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvironGet, true_br, x, x2), + PPrimitive(prim::kPrimEnvironGet, false_br, x, x2))); return nullptr; } diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h similarity index 67% rename from mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h rename to mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h index fc7f358e295..55c8793a48e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/environ_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_ #include #include @@ -37,10 +37,15 @@ namespace mindspore { namespace opt { namespace irpass { namespace internal { -class EnvGetitemTransform { +constexpr int kEnvironGetSetInputSize = 4; +constexpr int kEnvironOffset = 1; +constexpr int kSymbolicKeyOffset = 2; +constexpr int kValueOffset = 3; + +class EnvironGetTransform { public: - EnvGetitemTransform() : cache_() {} - ~EnvGetitemTransform() = default; + EnvironGetTransform() : cache_() {} + ~EnvironGetTransform() = default; FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { if (cache_.find(fg) == cache_.end()) { @@ -56,22 +61,22 @@ class EnvGetitemTransform { } auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); - auto env = new_fg->output(); - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4) { + auto environ_node = new_fg->output(); + while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) { + // {prim::kPrimEnvironSet, environ, symbolickey, value} + auto &inputs = environ_node->cast()->inputs(); + if (inputs.size() != kEnvironGetSetInputSize) { MS_LOG(WARNING) << "Input size should be 4"; return nullptr; } - if (!IsValueNode(inputs[2])) { + if (!IsValueNode(inputs[kSymbolicKeyOffset])) { MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; return nullptr; } - env = inputs[1]; - auto value = inputs[3]; - auto key2 = GetValueNode(inputs[2]); + environ_node = inputs[kEnvironOffset]; + auto value = inputs[kValueOffset]; + auto key2 = GetValueNode(inputs[kSymbolicKeyOffset]); if (*key2 == *key) { new_fg->set_output(value); cache[hash_key] = new_fg; @@ -79,7 +84,8 @@ class EnvGetitemTransform { return new_fg; } } - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); + new_fg->set_output( + new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node})); cache[hash_key] = new_fg; } @@ -92,10 +98,10 @@ class EnvGetitemTransform { cache_; }; -class EnvGetitemTransformACrossGraph { +class EnvironGetTransformACrossGraph { public: - EnvGetitemTransformACrossGraph() : cache_() {} - ~EnvGetitemTransformACrossGraph() = default; + EnvironGetTransformACrossGraph() : cache_() {} + ~EnvironGetTransformACrossGraph() = default; FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { if (cache_.find(fg) == cache_.end()) { @@ -120,29 +126,30 @@ class EnvGetitemTransformACrossGraph { auto new_fg = TransformableClone(fg_inner, std::make_shared(ss.str())); new_fg_outer->set_output(NewValueNode(new_fg)); - auto env = new_fg->output(); - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4) { + auto environ_node = new_fg->output(); + while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) { + // {prim::kPrimEnvironSet, environ, symbolickey, value} + auto &inputs = environ_node->cast()->inputs(); + if (inputs.size() != kEnvironGetSetInputSize) { MS_LOG(WARNING) << "Input size should be 4"; return nullptr; } - if (!IsValueNode(inputs[2])) { + if (!IsValueNode(inputs[kSymbolicKeyOffset])) { MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; return nullptr; } - env = inputs[1]; - auto value = inputs[3]; - auto key2 = GetValueNode(inputs[2]); + environ_node = inputs[kEnvironOffset]; + auto value = inputs[kValueOffset]; + auto key2 = GetValueNode(inputs[kSymbolicKeyOffset]); if (*key2 == *key) { new_fg->set_output(value); cache[hash_key] = new_fg_outer; return new_fg_outer; } } - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); + new_fg->set_output( + new_fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, NewValueNode(key), default_node})); cache[hash_key] = new_fg_outer; } @@ -156,48 +163,48 @@ class EnvGetitemTransformACrossGraph { }; } // namespace internal -// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y -class EnvGetItemEliminater : public AnfVisitor { +// {prim::kPrimEnvironGet, C1, C2, Y} -> Y +class EnvironGetEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode c1, c2, y; - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y, - (IsValueNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node)))); + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironGet, c1, c2, y), y, + (IsNewEnvironNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node)))); return nullptr; } }; -// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> -// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} -class EnvGetItemAddEliminater : public AnfVisitor { +// {prim::kPrimEnvironGet, {prim::kPrimEnvironAdd, X, Y}, C, Z} -> +// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvironGet, X, C, Z}, {prim::kPrimEnvironGet, Y, C, Z}} +class EnvironGetAddEliminater : public AnfVisitor { public: - EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} - ~EnvGetItemAddEliminater() override = default; + EnvironGetAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} + ~EnvironGetAddEliminater() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsAddCNode = [](const AnfNodePtr &node) -> bool { - return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; + return IsPrimitiveCNode(node, prim::kPrimEnvironAdd) && node->cast()->size() == 3; }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimEnvironGet, {IsAddCNode, IsVNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } - // {prim::kPrimEnvGetItem, {...}, C, Z} + // {prim::kPrimEnvironGet, {...}, C, Z} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto c = cnode->input(2); auto z = cnode->input(3); - // {prim::kPrimEnvAdd, X, Y} + // {prim::kPrimEnvironAdd, X, Y} auto x = inp1->input(1); auto y = inp1->input(2); auto fg = node->func_graph(); - auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); - auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); + auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), x, c, z}); + auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvironGet), y, c, z}); return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); } @@ -209,17 +216,17 @@ class EnvGetItemAddEliminater : public AnfVisitor { ValuePtr PrimHyperAdd_; }; -// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} -class EnvGetSetItemEliminater : public AnfVisitor { +// {prim::kPrimEnvironGet, {prim::kPrimEnvironSet, X, C1, Y}, C2, Z} +class EnvironGetSetEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsSetCNode = [](const AnfNodePtr &node) -> bool { - if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { + if (!IsPrimitiveCNode(node, prim::kPrimEnvironSet)) { return false; } - // {prim::kPrimEnvSetItem, X, C1, Y} + // {prim::kPrimEnvironSet, X, C1, Y} auto &inputs = node->cast()->inputs(); if (inputs.size() != 4) { return false; @@ -227,21 +234,21 @@ class EnvGetSetItemEliminater : public AnfVisitor { return IsValueNode(inputs[2]); }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimEnvironGet, {IsSetCNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } - // {prim::kPrimEnvGetItem, {...}, C2, Z} + // {prim::kPrimEnvironGet, {...}, C2, Z} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key2 = cnode->input(2); auto c2 = GetValueNode(key2); auto default_v = cnode->input(3); - // {prim::kPrimEnvSetItem, X, C1, Y} - auto env = inp1->input(1); + // {prim::kPrimEnvironSet, X, C1, Y} + AnfNodePtr environ_node = inp1->input(1); auto c1 = GetValueNode(inp1->input(2)); auto last_set = inp1->input(3); @@ -249,27 +256,27 @@ class EnvGetSetItemEliminater : public AnfVisitor { return last_set; } - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4) { + while (IsPrimitiveCNode(environ_node, prim::kPrimEnvironSet)) { + // {prim::kPrimEnvironSet, environ, symbolickey, value} + auto &inputs = environ_node->cast()->inputs(); + if (inputs.size() != internal::kEnvironGetSetInputSize) { MS_LOG(WARNING) << "Input size should be 4"; return nullptr; } - if (!IsValueNode(inputs[2])) { + if (!IsValueNode(inputs[internal::kSymbolicKeyOffset])) { MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; return nullptr; } - env = inputs[1]; - last_set = inputs[3]; - auto symbolic_c1 = GetValueNode(inputs[2]); + environ_node = inputs[internal::kEnvironOffset]; + last_set = inputs[internal::kValueOffset]; + auto symbolic_c1 = GetValueNode(inputs[internal::kSymbolicKeyOffset]); if (*symbolic_c1 == *c2) { return last_set; } } - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvironGet), environ_node, key2, default_v}); } void Visit(const AnfNodePtr &) override { is_match_ = true; } @@ -278,9 +285,9 @@ class EnvGetSetItemEliminater : public AnfVisitor { bool is_match_{false}; }; -// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} -> -// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2} -class EnvGetItemDependSwap : public OptimizerCaller { +// {prim::kPrimEnvironGet, {prim::kPrimDepend, X1, X2}, item, dflt} -> +// {prim::kPrimDepend, {prim::kPrimEnvironGet, X1, item, dflt}, X2} +class EnvironGetDependSwap : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { @@ -290,18 +297,30 @@ class EnvGetItemDependSwap : public OptimizerCaller { ScopeGuard scope_guard(scope); PatternNode x1, x2, item, dflt; - MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt), - PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvGetItem, x1, item, dflt), x2)); + MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvironGet, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt), + PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvironGet, x1, item, dflt), x2)); return nullptr; } }; -// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} -class IncorporateEnvGetitem : public AnfVisitor { +// {prim::kPrimEnvironAdd, C1, X} -> X +// {prim::kPrimEnvironAdd, X, C1} -> X +class EnvironAddConstEliminater : public AnfVisitor { public: - explicit IncorporateEnvGetitem(bool bypass_recursive = false) - : env_get_item_transform_(), bypass_recursive_(bypass_recursive) {} - ~IncorporateEnvGetitem() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode c1, x; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, c1, x), x, (IsNewEnvironNode(c1.GetNode(node)))); + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvironAdd, x, c1), x, (IsNewEnvironNode(c1.GetNode(node)))); + return nullptr; + } +}; + +// {prim::kPrimEnvironGet, {G, Xs}, C, Y} +class IncorporateEnvironGet : public AnfVisitor { + public: + explicit IncorporateEnvironGet(bool bypass_recursive = false) + : environ_get_transform_(), bypass_recursive_(bypass_recursive) {} + ~IncorporateEnvironGet() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; @@ -316,13 +335,13 @@ class IncorporateEnvGetitem : public AnfVisitor { } return IsValueNode(cnode->input(0)); }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimEnvironGet, {IsGCNode, IsValueNode, IsNode})(node); if (!is_match_) { return nullptr; } - // {prim::kPrimEnvGetItem, {...}, C, Y} + // {prim::kPrimEnvironGet, {...}, C, Y} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key = GetValueNode(cnode->input(2)); @@ -331,9 +350,9 @@ class IncorporateEnvGetitem : public AnfVisitor { // {G, Xs} auto inputs = inp1->inputs(); auto fg = GetValueNode(inputs[0]); - auto new_fg = env_get_item_transform_(fg, key, default_v); + auto new_fg = environ_get_transform_(fg, key, default_v); if (fg->recursive() && bypass_recursive_) { - MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString(); + MS_LOG(DEBUG) << "Bypass EnvironGet transform for recursive fg=" << fg->ToString(); return nullptr; } if (new_fg == nullptr) { @@ -350,15 +369,15 @@ class IncorporateEnvGetitem : public AnfVisitor { private: bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; + internal::EnvironGetTransform environ_get_transform_; bool bypass_recursive_; }; -// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} -class IncorporateEnvGetitemSwitch : public AnfVisitor { +// {prim::kPrimEnvironGet, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} +class IncorporateEnvironGetSwitch : public AnfVisitor { public: - IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} - ~IncorporateEnvGetitemSwitch() override = default; + IncorporateEnvironGetSwitch() : environ_get_transform_() {} + ~IncorporateEnvironGetSwitch() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; @@ -374,12 +393,12 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimEnvironGet, {IsSwNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } - // {prim::kPrimEnvGetItem, {...}, C, Y} + // {prim::kPrimEnvironGet, {...}, C, Y} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key = GetValueNode(cnode->input(2)); @@ -398,8 +417,8 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { auto x = sw->input(1); auto g1 = GetValueNode(sw->input(2)); auto g2 = GetValueNode(sw->input(3)); - auto new_g1 = env_get_item_transform_(g1, key, default_v); - auto new_g2 = env_get_item_transform_(g2, key, default_v); + auto new_g1 = environ_get_transform_(g1, key, default_v); + auto new_g2 = environ_get_transform_(g2, key, default_v); if (new_g1 == nullptr || new_g2 == nullptr) { return nullptr; } @@ -416,14 +435,14 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { private: bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; + internal::EnvironGetTransform environ_get_transform_; }; -// {prim::kPrimEnvGetItem, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y} -class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { +// {prim::kPrimEnvironGet, {{{prim::kPrimSwitchLayer, X, {prim::kPrimMakeTuple, G1, G2...}}, Xs}, Ys}, C, Y} +class IncorporateEnvironGetSwitchLayer : public AnfVisitor { public: - IncorporateEnvGetitemSwitchLayer() : env_get_item_transform_() {} - ~IncorporateEnvGetitemSwitchLayer() override = default; + IncorporateEnvironGetSwitchLayer() : environ_get_transform_() {} + ~IncorporateEnvironGetSwitchLayer() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; @@ -431,11 +450,11 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { return nullptr; } is_match_ = false; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimEnvironGet, {IsCNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } - // {prim::kPrimEnvGetItem, {...}, C, Y} + // {prim::kPrimEnvironGet, {...}, C, Y} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key = GetValueNode(cnode->input(2)); @@ -463,7 +482,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { std::vector graphs{}; auto graphs_cnode = sw->input(2)->cast(); auto &graphs_inputs = graphs_cnode->inputs(); - if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 && + const int kMinInputSize = 2; + if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= kMinInputSize && IsValueNode(graphs_inputs[1])) { (void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs), [](const AnfNodePtr &vnode) { return GetValueNode(vnode); }); @@ -475,7 +495,7 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { auto fg = node->func_graph(); std::vector layers; for (auto &graph : graphs) { - auto fg_transform = env_get_item_transform_(graph, key, default_v); + auto fg_transform = environ_get_transform_(graph, key, default_v); if (fg_transform == nullptr) { return nullptr; } @@ -493,9 +513,9 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { private: bool is_match_{false}; - internal::EnvGetitemTransformACrossGraph env_get_item_transform_; + internal::EnvironGetTransformACrossGraph environ_get_transform_; }; } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ENVIRON_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 758d2d7bf51..0af88eafa50 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -648,15 +648,15 @@ class IncorporateGetitemSwitch : public AnfVisitor { MS_EXCEPTION_IF_NULL(switch_call); const auto &switch_call_cnode = switch_call->cast(); MS_EXCEPTION_IF_NULL(switch_call_cnode); - // If exist env_getitem/env_setitem in this funcgraph or - // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; + // If exist EnvironGet/EnvironSet in this funcgraph or + // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any EnvironGet or EnvironSet; std::vector tp_cnodes_and_index; auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index); bool multiple_use = (tp_cnodes_and_index.size() > 1); if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use && (tp_cnodes_and_index.size() == switch_call_users_counter)) { - if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && - !ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) { + if (!internal::HasMoreJ(optimizer) && !ExistEnvironNode(fg) && !ExistEnvironNodeInTupleItem(g1_) && + !ExistEnvironNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) { MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString() << ", switch: " << switch_->DebugString(); const auto g1_output_size = internal::GetOutputSize(g1_->output()); @@ -834,15 +834,15 @@ class IncorporateGetitemSwitch : public AnfVisitor { return node_users.size(); } - static bool inline ExistEnvNode(const FuncGraphPtr &fg) { + static bool inline ExistEnvironNode(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); auto &nodes = fg->value_nodes(); return std::any_of(nodes.begin(), nodes.end(), [](const auto &node) { - return IsPrimitive(node.first, prim::kPrimEnvSetItem) || IsPrimitive(node.first, prim::kPrimEnvGetItem); + return IsPrimitive(node.first, prim::kPrimEnvironSet) || IsPrimitive(node.first, prim::kPrimEnvironGet); }); } - static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) { + static bool inline ExistEnvironNodeInTupleItem(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); const auto &output = fg->output(); if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { @@ -852,7 +852,7 @@ class IncorporateGetitemSwitch : public AnfVisitor { const auto &inputs = cnode->inputs(); return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) { auto sub_fg = GetValueNode(input); - if (sub_fg != nullptr && ExistEnvNode(sub_fg)) { + if (sub_fg != nullptr && ExistEnvironNode(sub_fg)) { return true; } return false; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc index de2fd3cde61..8fef00e8e8d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -167,12 +167,12 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state) { return input_monad; } -bool ExistEnvGetItem(FuncGraphManagerPtr manager) { +bool ExistEnvironGet(FuncGraphManagerPtr manager) { const FuncGraphSet &fgs = manager->func_graphs(); for (auto &fg : fgs) { auto &nodes = fg->value_nodes(); bool exist = std::any_of(nodes.begin(), nodes.end(), - [](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvGetItem); }); + [](const auto &node) { return IsPrimitive(node.first, prim::kPrimEnvironGet); }); if (exist) { return true; } @@ -181,10 +181,10 @@ bool ExistEnvGetItem(FuncGraphManagerPtr manager) { } // Convert: -// cnode1 = env_setitem(EnvInstance, para1, attach1) -// cnode2 = env_setitem(cnode1, para2, attach2) +// cnode1 = EnvironSet(EnvCreate(), para1, attach1) +// cnode2 = EnvironSet(cnode1, para2, attach2) // ... -// cnode_n = env_setitem(cnode_n-1, para_n-1, attach_n-1) +// cnode_n = EnvironSet(cnode_n-1, para_n-1, attach_n-1) // maketuple = maketuple(cnode_n, ...) // updatestate = updatestate(umonad, maketuple) // To: @@ -194,26 +194,26 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st std::vector env_nodes; std::vector new_maketuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; size_t input_size = make_tuple->inputs().size(); - bool has_env_setitem = false; + bool has_environ_set = false; for (size_t i = 1; i < input_size; i++) { auto node = make_tuple->input(i); - if (IsPrimitiveCNode(node, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(node, make_tuple)) { + if (IsPrimitiveCNode(node, prim::kPrimEnvironSet) && OnlyUsedByOneNode(node, make_tuple)) { env_nodes.emplace_back(node); - has_env_setitem = true; + has_environ_set = true; } else if (node->isa() && !IsPrimitiveCNode(node, prim::kPrimUpdateState)) { new_maketuple_inputs.emplace_back(node); } } - if (!has_env_setitem) { + if (!has_environ_set) { return nullptr; } - // Check env_setitem in MakeTuple + // Check EnvironSet in MakeTuple auto mgr = GetManager(update_state); if (mgr == nullptr) { return nullptr; } - // If exist env_getitem, don't eliminate env_setitem. - if (ExistEnvGetItem(mgr)) { + // If exist EnvironGet, don't eliminate EnvironSet. + if (ExistEnvironGet(mgr)) { return nullptr; } const size_t first_index = 1; @@ -228,7 +228,7 @@ AnfNodePtr EliminateUpdateStateMakeTupleWithUselessEnv(const CNodePtr &update_st auto env_cnode = env->cast(); auto env_input = env_cnode->input(first_index); auto attach = env_cnode->input(attach_index); - if (IsPrimitiveCNode(env_input, prim::kPrimEnvSetItem) && OnlyUsedByOneNode(env_input, env_cnode)) { + if (IsPrimitiveCNode(env_input, prim::kPrimEnvironSet) && OnlyUsedByOneNode(env_input, env_cnode)) { env_nodes.emplace_back(env_input); new_maketuple_inputs.insert(new_maketuple_inputs.begin() + no_env_node_size, attach); } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 25002846545..7c8dbf08b4b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -464,12 +464,11 @@ constexpr char TUPLE_DIV[] = "tuple_div"; constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array"; constexpr char VIRTUALLOSS[] = "VirtualLoss"; constexpr char RETURN[] = "Return"; -constexpr char ENV_GETITEM[] = "env_getitem"; +constexpr char ENVIRONGET[] = "EnvironGet"; constexpr char IDENTITY[] = "identity"; constexpr char PARTIAL[] = "partial"; -constexpr char ENVSETITEM[] = "env_setitem"; -constexpr char ENVGETITEM[] = "env_getitem"; -constexpr char ENVADD[] = "env_add"; +constexpr char ENVIRONSET[] = "EnvironSet"; +constexpr char ENVIRONADD[] = "EnvironAdd"; constexpr char MAKEREFKEY[] = "MakeRefKey"; constexpr char MAKEREF[] = "make_ref"; constexpr char GETREFKEY[] = "get_ref_key"; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 2a4243d3a08..aafaaf67794 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -2862,7 +2862,7 @@ static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) { if (!node->isa()) { continue; } - if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) { + if (!IsPrimitiveCNode(node, prim::kPrimEnvironGet)) { return FindGrad(node->cast(), ++curr_depth); } else { return node; diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 8a074355188..7a65ac314c5 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -45,7 +45,6 @@ namespace py = pybind11; -using EnvInstance = mindspore::EnvInstance; using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy; using Pipeline = mindspore::pipeline::Pipeline; using PrimitivePy = mindspore::PrimitivePy; @@ -118,8 +117,6 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.") .def("generate_arguments_key", &GraphExecutorPy::GenerateArgumentsKey, "Generate unique key of argument."); - (void)py::class_>(m, "EnvInstance_").def(py::init()); - (void)m.def("real_run_op", &mindspore::pynative::RealRunOp, "Run op pynatively."); (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id"); (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl"); diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 0f54638102e..90719e7a958 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -502,7 +502,6 @@ std::vector GetDataConverters() { std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), - std::make_shared>(ObjCast>), std::make_shared(PYTHON_CLASS_MEMBER_NAMESPACE, [](const py::object &obj) -> ValuePtr { auto res = diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 40301196c79..5638d017110 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -42,6 +42,7 @@ #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" #include "frontend/optimizer/recompute.h" #include "frontend/optimizer/slice_activation_in_recompute.h" +#include "frontend/optimizer/environ_conversion.h" #include "utils/log_adapter.h" #include "pipeline/jit/pipeline_split.h" #include "pipeline/pynative/pynative_execute.h" @@ -198,12 +199,12 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) { (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize())); opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_}; (void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate)); - opt::OptPassConfig env_eliminate = opt::OptPassConfig({ + opt::OptPassConfig environ_eliminate = opt::OptPassConfig({ irpass.incorporate_call_, irpass.incorporate_call_switch_, irpass.incorporate_getitem_set_, }); - (void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate)); + (void)map.emplace_back(std::make_pair("environ_eliminate", environ_eliminate)); } auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map); @@ -264,10 +265,11 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_, - irpass.env_get_item_eliminate_, - irpass.env_get_item_add_eliminate_, - irpass.env_get_set_item_eliminate_, - irpass.env_get_item_depend_swap_, + irpass.environ_get_eliminate_, + irpass.environ_get_add_eliminate_, + irpass.environ_get_set_eliminate_, + irpass.environ_get_depend_swap_, + irpass.environ_add_const_eliminate_, irpass.cast_eliminate_, irpass.reshape_eliminate_, @@ -301,16 +303,16 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.specialize_transform_, irpass.merge_addn_, irpass.float_tuple_getitem_switch_, - irpass.float_env_getitem_switch_, + irpass.float_environ_get_switch_, irpass.inline_, irpass.updatestate_useless_node_eliminater_, irpass.tuple_list_get_item_eliminator_, irpass.incorporate_getitem_set_, irpass.incorporate_call_, irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_bypass_recursive_, - irpass.incorporate_env_getitem_switch_, - irpass.env_get_item_eliminate_, + irpass.incorporate_environ_get_bypass_recursive_, + irpass.incorporate_environ_get_switch_, + irpass.environ_get_eliminate_, irpass.depend_value_elim_, irpass.all_reduce_const_elim_, }, @@ -434,13 +436,14 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, - irpass.incorporate_env_getitem_, - irpass.incorporate_env_getitem_switch_, - irpass.env_get_item_eliminate_, - irpass.env_get_item_add_eliminate_, - irpass.env_get_set_item_eliminate_, - irpass.env_get_item_depend_swap_, - irpass.incorporate_env_getitem_switch_layer_, + irpass.incorporate_environ_get_, + irpass.incorporate_environ_get_switch_, + irpass.environ_get_eliminate_, + irpass.environ_get_add_eliminate_, + irpass.environ_get_set_eliminate_, + irpass.environ_get_depend_swap_, + irpass.environ_add_const_eliminate_, + irpass.incorporate_environ_get_switch_layer_, irpass.value_based_eliminate_, irpass.virtual_accu_grad_, irpass.virtual_assign_add_, @@ -732,6 +735,15 @@ bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) { return true; } +bool EnvironConversionPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res); + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + if (enable_closure) { + opt::EnvironConversion(res); + } + return true; +} + std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_before_recompute", OptBeforeRecomputeGroup}, {"opt_a", OptPassAGroup}, @@ -744,6 +756,7 @@ std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStru {"add_cache_embedding", AddCacheEmbeddingPass}, {"add_recomputation", AddRecomputationPass}, {"cse_after_recomputation", OptAfterRecomputeGroup}, + {"environ_conv", EnvironConversionPass}, {"slice_recompute_activation", SliceRecomputeActivationPass}}; std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 6d4cb25dbf2..8360b8b295e 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -123,12 +123,14 @@ AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index a76a35d56af..466d73f5fb2 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -43,7 +43,15 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr return args_spec_list[0]; } -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplEnvironCreate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: None. + CheckArgsSize(primitive->name(), args_spec_list, 0); + static AbstractBasePtr abs_env = std::make_shared(kAnyValue, std::make_shared()); + return abs_env; +} + +AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { MS_EXCEPTION_IF_NULL(primitive); // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). @@ -53,7 +61,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt TypePtr type = key->GetTypeTrack(); MS_EXCEPTION_IF_NULL(type); if (type->type_id() != kObjectTypeSymbolicKeyType) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); + MS_LOG(EXCEPTION) << "EnvironGet evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); } auto context = MsContext::GetInstance(); @@ -76,7 +84,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt return expected; } -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, +AbstractBasePtr InferImplEnvironSet(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). CheckArgsSize(primitive->name(), args_spec_list, 3); @@ -86,7 +94,7 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt MS_EXCEPTION_IF_NULL(key_value_ptr); auto key_value_track = key_value_ptr->cast(); if (key_value_track == nullptr) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: " + MS_LOG(EXCEPTION) << "EnvironSet evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: " << key_value_ptr->ToString(); } auto expected = key_value_track->abstract(); @@ -94,8 +102,8 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt return std::make_shared(kAnyValue, std::make_shared()); } -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { +AbstractBasePtr InferImplEnvironAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). CheckArgsSize(primitive->name(), args_spec_list, 2); return std::make_shared(kAnyValue, std::make_shared()); diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 98d9f542cda..192505e199d 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -200,9 +200,10 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimLoad, R{InferImplLoad, nullptr, true}}, // Set impl to null as it will use PartialEvaluator; {prim::kPrimPartial, R{nullptr, nullptr, true}}, - {prim::kPrimEnvGetItem, R{InferImplEnvGetItem, nullptr, true}}, - {prim::kPrimEnvSetItem, R{InferImplEnvSetItem, nullptr, true}}, - {prim::kPrimEnvAdd, R{InferImplEnvAdd, nullptr, true}}, + {prim::kPrimEnvironCreate, R{InferImplEnvironCreate, nullptr, true}}, + {prim::kPrimEnvironGet, R{InferImplEnvironGet, nullptr, true}}, + {prim::kPrimEnvironSet, R{InferImplEnvironSet, nullptr, true}}, + {prim::kPrimEnvironAdd, R{InferImplEnvironAdd, nullptr, true}}, {prim::kPrimMakeRefKey, R{InferImplMakeRefKey, nullptr, true}}, {prim::kPrimMakeRef, R{InferImplMakeRef, nullptr, true}}, {prim::kPrimGetRefKey, R{InferImplGetRefKey, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 06ee56c5516..fba78aeb1cc 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -120,6 +120,13 @@ constexpr auto kSparseApplyAdadelta = "SparseApplyAdadelta"; constexpr auto kRoll = "Roll"; constexpr auto kTanh = "Tanh"; +// Others +constexpr auto kEnvironCreate = "EnvironCreate"; +constexpr auto kEnvironSet = "EnvironSet"; +constexpr auto kEnvironGet = "EnvironGet"; +constexpr auto kEnvironAdd = "EnvironAdd"; + +// // Here list all primitives used in backend or some special primitives used by core. // GetNext inline const PrimitivePtr kPrimGetNext = std::make_shared(kGetNext); @@ -708,9 +715,10 @@ inline const PrimitivePtr kPrimListAppend = std::make_shared("list_ap inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); // Other miscellaneous -inline const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -inline const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); +inline const PrimitivePtr kPrimEnvironCreate = std::make_shared(kEnvironCreate); +inline const PrimitivePtr kPrimEnvironSet = std::make_shared(kEnvironSet); +inline const PrimitivePtr kPrimEnvironGet = std::make_shared(kEnvironGet); +inline const PrimitivePtr kPrimEnvironAdd = std::make_shared(kEnvironAdd); inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); diff --git a/mindspore/core/utils/symbolic.h b/mindspore/core/utils/symbolic.h index 18d20d4f112..a076863b034 100644 --- a/mindspore/core/utils/symbolic.h +++ b/mindspore/core/utils/symbolic.h @@ -26,7 +26,9 @@ #include "utils/hash_map.h" #include "ir/anf.h" +#include "ir/func_graph.h" #include "abstract/abstract_value.h" +#include "base/core_ops.h" namespace mindspore { class SymbolicKeyInstance : public Value { @@ -97,37 +99,15 @@ struct SymbolicKeyInstanceEqual { } }; -using EnvInstanceContentsMap = - mindspore::HashMap; +static inline AnfNodePtr NewEnviron(const FuncGraphPtr &fg) { + return fg->NewCNode({NewValueNode(prim::kPrimEnvironCreate)}); +} -// Environment mapping keys to values. -// Keys are SymbolicKeyInstances, which represent nodes in the graph along -// with inferred properties. -class EnvInstance : public Value { - public: - friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); +static inline bool IsNewEnvironNode(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimEnvironCreate); } - EnvInstance() = default; - ~EnvInstance() override = default; - MS_DECLARE_PARENT(EnvInstance, Value); - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - bool operator==(const EnvInstance &other) const; - bool operator==(const Value &other) const override; - EnvInstance(const EnvInstance &v) : Value(v) {} - EnvInstance(EnvInstance &&v) = default; - EnvInstance &operator=(EnvInstance &&src) noexcept { return *this; }; - - std::size_t hash() const override { - // deterministic characteristic of member variables. - return tid(); - } -}; - -using EnvInstancePtr = std::shared_ptr; - -extern std::shared_ptr newenv; +static inline abstract::AbstractBasePtr MakeEnvironAbstract() { + return std::make_shared(kAnyValue, std::make_shared()); +} } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_SYMBOLIC_H_ diff --git a/mindspore/python/mindspore/common/dtype.py b/mindspore/python/mindspore/common/dtype.py index 37aa97a78ab..971db62a541 100644 --- a/mindspore/python/mindspore/common/dtype.py +++ b/mindspore/python/mindspore/common/dtype.py @@ -17,7 +17,7 @@ """Data type for MindSpore.""" import numpy as np -from .._c_expression import typing, EnvInstance_ +from .._c_expression import typing from .._c_expression.typing import Type __dtype__ = [ @@ -161,7 +161,6 @@ _simple_types = { np.float16: float16, np.float32: float32, np.float64: float64, - EnvInstance_: env_type, } diff --git a/mindspore/python/mindspore/ops/composite/base.py b/mindspore/python/mindspore/ops/composite/base.py index 01655209f58..bd8e631b8a6 100644 --- a/mindspore/python/mindspore/ops/composite/base.py +++ b/mindspore/python/mindspore/ops/composite/base.py @@ -20,7 +20,7 @@ from functools import partial from types import FunctionType from mindspore import context -from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \ +from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \ TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_ from ...common import dtype as mstype from ...common.api import ms_function, _pynative_executor, _wrap_func @@ -29,7 +29,7 @@ from ..operations import _grad_ops from .. import operations as P from .. import signature as sig -__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] +__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] def add_flags(fn=None, **flags): @@ -852,7 +852,7 @@ zip_operation = _ZipOperation('zip_operation') env_get = MultitypeFuncGraph("env_get") -env_getitem = Primitive('env_getitem') +environ_get = Primitive('EnvironGet') ref_to_embed = _grad_ops.RefToEmbed() zeros_like = P.ZerosLike() @@ -860,4 +860,4 @@ zeros_like = P.ZerosLike() @env_get.register("EnvType", "Tensor") def _tensor_env_get(env, parameter): """Used to get env.""" - return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter)) + return environ_get(env, ref_to_embed(parameter), zeros_like(parameter)) diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py index a5391e517fb..fb95bd44688 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py @@ -254,7 +254,7 @@ def _add_env(x, y): Returns: EnvType, equal to x + y. """ - return F.env_add(x, y) + return F.environ_add(x, y) @add.register("Tuple", "Tuple") diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py index 54b2a31b51b..27bad0e3d89 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py @@ -58,9 +58,6 @@ def _ones_like_sparse_tensor(x): return F.make_sparse_tensor(F.sparse_tensor_get_indices(x), values, F.sparse_tensor_get_dense_shape(x)) -newenv = base.EnvInstance_() - - @ones_like_leaf.register("Function") def _ones_like_func(x): """ @@ -70,10 +67,10 @@ def _ones_like_func(x): x (Function): x Returns: - EnvInstance_, value is newenv. + A instance of EnvType. """ # Unused parameters are placeholders. - return newenv + return F.environ_create() ones_like = base.HyperMap(ones_like_leaf) diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index f1659139aca..84bc1d1f48a 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -37,9 +37,6 @@ def _zeros_like_bool(x): return False -newenv = base.EnvInstance_() - - @zeros_like_leaf.register("Function") def _zeros_like_func(x): """ @@ -49,10 +46,10 @@ def _zeros_like_func(x): x (Function): x Returns: - EnvInstance_, value is newenv. + A instance of EnvType. """ # Unused parameters are placeholders. - return newenv + return F.environ_create() @zeros_like_leaf.register("Tensor") diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index bcbcb495500..3c1f233de11 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -456,9 +456,10 @@ zeros_like = P.ZerosLike() distribute = Primitive('distribute') embed = Primitive('embed') ref_to_embed = _grad_ops.RefToEmbed() -env_setitem = Primitive('env_setitem') -env_getitem = Primitive('env_getitem') -env_add = Primitive('env_add') +environ_create = Primitive('EnvironCreate') +environ_set = Primitive('EnvironSet') +environ_get = Primitive('EnrironGet') +environ_add = Primitive('EnvironAdd') J = Primitive('J') SliceGetItem = Primitive("SliceGetItem") switch = Primitive('Switch') diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py index 16a3450d4b6..f75960af512 100644 --- a/mindspore/python/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py @@ -2047,7 +2047,7 @@ class RefToEmbed(Primitive): Make a key from Ref. The Key is a symbolic_key, is a embedding on Parameter, which is used as a key of the variable in env_type, - and get items by operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref. + and get items by operation `EnvironGet` with the symbolic_key instance. The `Parameter` is a ref. Inputs: - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref. diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 22d77698789..2a1d99f3805 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -346,19 +346,28 @@ TEST_F(TestOps, EmbedTest) { ASSERT_EQ(prim->name(), kPrimEmbed->name()); } -TEST_F(TestOps, EnvSetItemTest) { - auto prim = std::make_shared("env_setitem"); - ASSERT_EQ(prim->name(), kPrimEnvSetItem->name()); +/// Feature: Check primitive name equivalence +/// Description: EnvironSet primitive name equivalence +/// Expectation: Equal +TEST_F(TestOps, EnvironSetTest) { + auto prim = std::make_shared("EnvironSet"); + ASSERT_EQ(prim->name(), kPrimEnvironSet->name()); } -TEST_F(TestOps, EnvGetItemTest) { - auto prim = std::make_shared("env_getitem"); - ASSERT_EQ(prim->name(), kPrimEnvGetItem->name()); +/// Feature: Check primitive name equivalence +/// Description: EnvironGet primitive name equivalence +/// Expectation: Equal +TEST_F(TestOps, EnvironGetTest) { + auto prim = std::make_shared("EnvironGet"); + ASSERT_EQ(prim->name(), kPrimEnvironGet->name()); } -TEST_F(TestOps, EnvAddest) { - auto prim = std::make_shared("env_add"); - ASSERT_EQ(prim->name(), kPrimEnvAdd->name()); +/// Feature: Check primitive name equivalence +/// Description: EnvironAdd primitive name equivalence +/// Expectation: Equal +TEST_F(TestOps, EnvironAddTest) { + auto prim = std::make_shared("EnvironAdd"); + ASSERT_EQ(prim->name(), kPrimEnvironAdd->name()); } // Neural Network diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index 71891ea5d8c..d723d9cc06a 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -370,81 +370,81 @@ TEST_F(TestPrim, test_partial) { ASSERT_TRUE(res->ToString() == expected->ToString()); } -// def test_env(x, y): -// return env_setitem(newenv, embed(x), y) -TEST_F(TestPrim, test_env_setitem) { +/// Feature: Check inference result of primitive by build a FuncGraph with single primitive. +/// Description: Check inference result of primitive EnvironSet. +/// Expectation: Equal +TEST_F(TestPrim, test_environ_set) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(static_cast(1), false); AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); - FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); + FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvironSet, 3); - AbstractBasePtr abstract_env = ToAbstract(newenv); + AbstractBasePtr abstract_environ = MakeEnvironAbstract(); AbstractBasePtr abstract_y = FromValue(static_cast(2), false); - args_spec_list = {abstract_env, embed_x, abstract_y}; + args_spec_list = {abstract_environ, embed_x, abstract_y}; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*res == *exp); } -// def test_env(x, y, z): -// e = env_setitem(newenv, embed(x), y) -// return env_getitem(e, embed(x), z) -TEST_F(TestPrim, test_env_getitem) { +/// Feature: Check inference result of primitive by build a FuncGraph with single primitive. +/// Description: Check inference result of primitive EnvironGet. +/// Expectation: Equal +TEST_F(TestPrim, test_environ_get) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(static_cast(1), false); AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); - FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); + FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3); - AbstractBasePtr abstract_env = ToAbstract(newenv); + AbstractBasePtr abstract_environ = MakeEnvironAbstract(); AbstractBasePtr abstract_y = FromValue(static_cast(2), false); - args_spec_list = {abstract_env, embed_x, abstract_y}; + args_spec_list = {abstract_environ, embed_x, abstract_y}; - AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); + AbstractBasePtr res = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*res == *exp); - FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3); + FuncGraphPtr graph_environ_get = MakeFuncGraph(prim::kPrimEnvironGet, 3); AbstractBasePtr abstract_z = FromValue(static_cast(3), false); args_spec_list = {res, embed_x, abstract_z}; - res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract(); + res = engine_->Run(graph_environ_get, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *abstract_x); } -// def test_env(x, y, z): -// e1 = env_setitem(newenv, embed(x), y) -// e2 = env_setitem(newenv, embed(x), z) -// return env_add(e1, e2) -TEST_F(TestPrim, test_env_add) { +/// Feature: Check inference result of primitive by build a FuncGraph with single primitive. +/// Description: Check inference result of primitive EnvironAdd. +/// Expectation: Equal +TEST_F(TestPrim, test_environ_add) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(static_cast(1), false); AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); - FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); + FuncGraphPtr graph_environ_set = MakeFuncGraph(prim::kPrimEnvironSet, 3); - AbstractBasePtr abstract_env = ToAbstract(newenv); + AbstractBasePtr abstract_environ = MakeEnvironAbstract(); AbstractBasePtr abstract_y = FromValue(static_cast(2), false); - args_spec_list = {abstract_env, embed_x, abstract_y}; + args_spec_list = {abstract_environ, embed_x, abstract_y}; - AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); + AbstractBasePtr abstract_e1 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*abstract_e1 == *exp); AbstractBasePtr abstract_z = FromValue(static_cast(3), false); - args_spec_list = {abstract_env, embed_x, abstract_z}; + args_spec_list = {abstract_environ, embed_x, abstract_z}; - AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); + AbstractBasePtr abstract_e2 = engine_->Run(graph_environ_set, args_spec_list).inferred->abstract(); ASSERT_TRUE(*abstract_e2 == *exp); - FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); + FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvironAdd, 2); args_spec_list = {abstract_e1, abstract_e2}; AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract(); diff --git a/tests/ut/python/pipeline/infer/test_auto_monad.py b/tests/ut/python/pipeline/infer/test_auto_monad.py index 2f48ed373fc..e616e19c262 100644 --- a/tests/ut/python/pipeline/infer/test_auto_monad.py +++ b/tests/ut/python/pipeline/infer/test_auto_monad.py @@ -315,7 +315,7 @@ def test_while_if(): # should compile success without zeros_like_tensor args mismatch, the generated graph files -# should not contain env_getitem or env_setitem. +# should not contain EnvironGet or EnvironSet. # InsertGradientOf primitive will make func_graph bprop_construct had BackPropAutoMonad flag set, # so all graph it used will be checked if any side effect it has, so the hyper_map_zeros_like # will have U as parameter, but the call site zeros_like(fv) don't have U argument. @@ -347,7 +347,7 @@ def test_grad_fv_and_insert_gradient_of(): net = FvAndInsertGradientNet() input_data = Tensor(np.array([1.0], np.float32)) - # if use grad_all_list, the generated graph will have env_setitem + # if use grad_all_list, the generated graph will have EnvironSet # as gradient for inputs is constant zero, so it will depend on result of grad. grad_net = grad_by_list(net, ParameterTuple(net.trainable_params())) print(grad_net(input_data))