forked from mindspore-Ecosystem/mindspore
fix ReluV2's mask shape in derelu fusion pass
This commit is contained in:
parent
b48d663c69
commit
cd6e8d6542
|
@ -46,6 +46,8 @@
|
|||
#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h"
|
||||
#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
|
||||
#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
|
||||
#include "pre_activate/ascend/format_type/insert_trans_op.h"
|
||||
#include "pre_activate/pass/getitem_tuple.h"
|
||||
#include "pre_activate/pass/optimize_dependence.h"
|
||||
|
@ -94,8 +96,10 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -89,6 +90,9 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
|
|||
auto reduce_sum = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
auto mul1 = reduce_sum->input(1);
|
||||
if (mul1->fullname_with_scope().find("bert/encoder") == std::string::npos) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsUsedByOthers(graph, mul1)) {
|
||||
MS_LOG(INFO) << "Mul1 is used by others, quit fusion!";
|
||||
return nullptr;
|
||||
|
|
|
@ -50,9 +50,22 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
|
|||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_scope(relu->scope());
|
||||
|
||||
// ReluV2's 2nd output is mask whose data type is uint8 and value is 0 or 1, so shape is an empty vector
|
||||
// ReluV2's 2nd output is mask whose data type is uint8
|
||||
TypeId mask_dtype = kNumberTypeUInt8;
|
||||
std::vector<size_t> mask_shape;
|
||||
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
|
||||
if (mask_shape.size() != 4) {
|
||||
MS_LOG(WARNING) << "relu's infer shape size not equal 4";
|
||||
return nullptr;
|
||||
}
|
||||
auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0);
|
||||
if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) {
|
||||
mask_shape[1] = (mask_shape[1] + 31) / 32;
|
||||
mask_shape.push_back(4);
|
||||
} else {
|
||||
mask_shape[1] = (mask_shape[1] + 15) / 16;
|
||||
mask_shape.push_back(2);
|
||||
}
|
||||
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype};
|
||||
auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
|
||||
|
@ -91,6 +104,9 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
MS_EXCEPTION_IF_NULL(relu);
|
||||
|
||||
auto relu_v2 = CreateReluV2(graph, relu);
|
||||
if (relu_v2 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> relu_v2_node_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ constexpr auto kStreamActiveOpName = "StreamActive";
|
|||
constexpr auto kAssignAddOpName = "AssignAdd";
|
||||
constexpr auto kSendOpName = "Send";
|
||||
constexpr auto kRecvOpName = "Recv";
|
||||
constexpr auto kReluV2OpName = "ReluV2";
|
||||
constexpr auto kReluV2OpName = "ReLUV2";
|
||||
constexpr auto kReluGradV2OpName = "ReluGradV2";
|
||||
|
||||
// attr key name
|
||||
|
|
|
@ -32,6 +32,11 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
|
|||
TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
auto bert_scope = std::make_shared<Scope>("bert/encoder");
|
||||
for (auto node : TopoSort(g->get_return())) {
|
||||
node->set_scope(bert_scope);
|
||||
}
|
||||
|
||||
std::vector<int> shp{1, 1, 1, 1};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
|
|
|
@ -17,7 +17,7 @@ from mindspore.ops import Primitive
|
|||
|
||||
relu = P.ReLU()
|
||||
relu_grad = Primitive('ReluGrad')
|
||||
relu_v2 = Primitive('ReluV2')
|
||||
relu_v2 = Primitive('ReLUV2')
|
||||
relu_grad_v2 = Primitive('ReluGradV2')
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
|
Loading…
Reference in New Issue