forked from mindspore-Ecosystem/mindspore
!23206 Remove side_effect_mem mark for random operator
Merge pull request !23206 from Margaret_wangrui/random_op_2
This commit is contained in:
commit
41a5c0ae26
|
@ -134,20 +134,22 @@ static bool HasSideEffect(const AnfNodePtr &node) {
|
|||
|
||||
// If true do not merge the node.
|
||||
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
bool has_random_effect = false;
|
||||
auto prim_main = GetCNodePrimitive(main);
|
||||
auto prim_node = GetCNodePrimitive(node);
|
||||
// if has random effect, when generate by different op (not same object), do not merge.
|
||||
if (prim_main != nullptr) {
|
||||
if (prim_main == prim_node) {
|
||||
return false;
|
||||
}
|
||||
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
||||
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
||||
has_random_effect = GetValue<bool>(effect_val);
|
||||
bool has_random_effect = GetValue<bool>(effect_val);
|
||||
if (has_random_effect) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (prim_main != prim_node) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return has_random_effect;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
|
||||
|
|
|
@ -35,7 +35,6 @@
|
|||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "pipeline/jit/static_analysis/order_enforce.h"
|
||||
#include "pipeline/jit/static_analysis/remove_monad.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "pipeline/jit/static_analysis/program_specialize.h"
|
||||
|
@ -486,19 +485,6 @@ bool OrderEnforceAction(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool RemoveRandomOpMonadAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove-Random-Op-Monad error, manager is null";
|
||||
}
|
||||
auto func_graph = res->func_graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove-Random-Op-Monad error, graph is null";
|
||||
}
|
||||
pipeline::RemoveRandomOpMonad(func_graph);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
if (res->manager() == nullptr) {
|
||||
|
@ -1125,7 +1111,6 @@ std::vector<ActionItem> GePipeline() {
|
|||
(void)actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
|
||||
(void)actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
||||
(void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
(void)actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));
|
||||
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
return actions;
|
||||
}
|
||||
|
@ -1141,8 +1126,6 @@ std::vector<ActionItem> VmPipeline() {
|
|||
|
||||
(void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
|
||||
(void)actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));
|
||||
|
||||
// eliminate forward cnode for grad graph
|
||||
(void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
|
||||
|
||||
|
@ -1198,7 +1181,6 @@ std::vector<ActionItem> PServerPipeline() {
|
|||
auto actions = CommonPipeline();
|
||||
(void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
||||
(void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
|
||||
(void)actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));
|
||||
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
|
||||
(void)actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
|
||||
return actions;
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/remove_monad.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
namespace {
|
||||
class RemoveMonad {
|
||||
public:
|
||||
explicit RemoveMonad(const FuncGraphPtr &func_graph) : func_graph_(func_graph), manager_(func_graph->manager()) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
}
|
||||
~RemoveMonad() = default;
|
||||
|
||||
void Run() {
|
||||
auto nodes = TopoSort(func_graph_->get_return());
|
||||
for (auto &node : nodes) {
|
||||
if (node->isa<CNode>()) {
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
if (prim != nullptr && CheckPrimRandomEffect(prim)) {
|
||||
// Remove monad input
|
||||
RemoveMonadFromRandomNodes(node);
|
||||
}
|
||||
}
|
||||
// Remove random nodes from monad chain
|
||||
if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
RemoveRandomNodesFromMonadChain(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool CheckPrimRandomEffect(const PrimitivePtr &prim) {
|
||||
bool has_random_effect = false;
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto effect_val = prim->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
||||
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
||||
has_random_effect = GetValue<bool>(effect_val);
|
||||
}
|
||||
return has_random_effect;
|
||||
}
|
||||
|
||||
void RemoveMonadFromRandomNodes(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_random_node_inputs;
|
||||
// Remove monad input, in order to parallel execution of random number operators
|
||||
(void)std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(new_random_node_inputs),
|
||||
[](const AnfNodePtr &input) { return !HasAbstractMonad(input); });
|
||||
auto new_random_node = func_graph_->NewCNode(new_random_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
new_random_node->set_abstract(node->abstract());
|
||||
new_random_node->set_scope(node->scope());
|
||||
(void)manager_->Replace(node, new_random_node);
|
||||
}
|
||||
|
||||
void RemoveRandomNodesFromMonadChain(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const size_t first_index = 1;
|
||||
const size_t attach_index = 2;
|
||||
auto monad_input = cnode->input(first_index);
|
||||
auto attach_input = cnode->input(attach_index);
|
||||
if (attach_input->isa<CNode>()) {
|
||||
auto prim = GetCNodePrimitive(attach_input);
|
||||
if (prim != nullptr && CheckPrimRandomEffect(prim)) {
|
||||
(void)manager_->Replace(cnode, monad_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const FuncGraphPtr &func_graph_;
|
||||
FuncGraphManagerPtr manager_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Remove monad from random operator of the given graph.
|
||||
void RemoveRandomOpMonad(const FuncGraphPtr &func_graph) {
|
||||
RemoveMonad remover(func_graph);
|
||||
remover.Run();
|
||||
}
|
||||
} // namespace mindspore::pipeline
|
|
@ -1,27 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_MONAD_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_MONAD_H_
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::pipeline {
|
||||
// Remove monad from random operator of the given graph.
|
||||
void RemoveRandomOpMonad(const FuncGraphPtr &func_graph);
|
||||
} // namespace mindspore::pipeline
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_REMOVE_MONAD_H_
|
|
@ -3422,7 +3422,6 @@ class DropoutGenMask(Primitive):
|
|||
validator.check_value_type("Seed0", Seed0, [int], self.name)
|
||||
validator.check_value_type("Seed1", Seed1, [int], self.name)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class DropoutDoMask(Primitive):
|
||||
|
|
|
@ -62,7 +62,6 @@ class StandardNormal(PrimitiveWithInfer):
|
|||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize StandardNormal"""
|
||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
|
@ -120,7 +119,6 @@ class StandardLaplace(PrimitiveWithInfer):
|
|||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize StandardLaplace"""
|
||||
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
Validator.check_value_type('seed', seed, [int], self.name)
|
||||
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||
|
@ -184,7 +182,6 @@ class Gamma(PrimitiveWithInfer):
|
|||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize Gamma"""
|
||||
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
|
@ -249,7 +246,6 @@ class Poisson(PrimitiveWithInfer):
|
|||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize Poisson"""
|
||||
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
|
@ -322,7 +318,6 @@ class UniformInt(PrimitiveWithInfer):
|
|||
def __init__(self, seed=0, seed2=0):
|
||||
"""Initialize UniformInt"""
|
||||
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
|
@ -430,7 +425,6 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
|
|||
Validator.check_positive_int(count, "count", self.name)
|
||||
Validator.check_value_type('seed', seed, [int], self.name)
|
||||
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
|
@ -493,7 +487,6 @@ class RandomCategorical(PrimitiveWithInfer):
|
|||
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
|
||||
outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
|
||||
def __infer__(self, logits, num_samples, seed):
|
||||
|
@ -561,7 +554,6 @@ class Multinomial(PrimitiveWithInfer):
|
|||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("_random_effect", True)
|
||||
|
||||
def __infer__(self, inputs, num_samples):
|
||||
|
|
Loading…
Reference in New Issue