!11794 remove useless code of dot

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-02-01 17:38:26 +08:00 committed by Gitee
commit 066ebe516e
12 changed files with 9 additions and 77 deletions

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -86,7 +86,6 @@ convert_object_map = {
T.floordiv: multitype_ops.floordiv,
T.mod: multitype_ops.mod,
T.pow: multitype_ops.pow_,
T.matmul: F.dot,
T.lshift: NO_IMPLEMENT,
T.rshift: NO_IMPLEMENT,
T.and_: multitype_ops.logical_and,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,7 +32,7 @@ bool CNodeHasTupleInput(const CNodePtr &cnode) {
}
if (IsValueNode<Primitive>(inputs[i])) {
// unexpected high order primitvie as cnode input when transform graph
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString();
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
return false;
}
auto abs = inputs[i]->abstract();

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -170,7 +170,6 @@ BuiltInTypeMap &GetMethodMap() {
{"__ge__", std::string("ge")}, // C.ge
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
{"view", std::string("view")}, // C.view
{"__matmul__", prim::kPrimDot}, // P.dot,
{"__len__", prim::kPrimArrayLen}, // P.array_len,
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
@ -352,7 +351,7 @@ void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
return;
}
MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
pynative_short_life_primitives_.insert(prim);
pynative_new_primtives_squence_.push_back(prim->ToString());
}

View File

@ -27,8 +27,6 @@ namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
return abs_base;
}
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
ShapePtr x_shp = input_x->shape();
auto x_shp_value = x_shp->shape();
ShapePtr y_shp = input_y->shape();
auto y_shp_value = y_shp->shape();
// Should be matrix which shape size is 2.
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
<< x_shp_value.size() << ", " << y_shp_value.size() << " ";
}
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
}
auto x_element = input_x->element();
MS_EXCEPTION_IF_NULL(x_element);
(void)x_element->Join(input_y->element());
auto param = {x_shp_value[0], y_shp_value[1]};
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
}
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
const AbstractBasePtrList &args_spec_list) {
// Inputs: condition, true branch, false branch

View File

@ -26,7 +26,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},

View File

@ -67,7 +67,6 @@ inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalO
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -188,12 +188,6 @@ def bprop_array_to_scalar(x, out, dout):
return (F.scalar_to_array(dout),)
@bprops.register("dot")
def bprop_dot(x, y, out, dout):
"""Backpropagator for primitive `dot`."""
return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)
@bprops.register("reshape")
def bprop_reshape(xs, shp, out, dout):
"""Backpropagator for primitive `reshape`."""

View File

@ -142,7 +142,6 @@ in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -287,11 +287,6 @@ TEST_F(TestOps, TransposeTest) {
ASSERT_EQ(prim->name(), kPrimTranspose->name());
}
TEST_F(TestOps, DotTest) {
auto prim = std::make_shared<Primitive>("dot");
ASSERT_EQ(prim->name(), kPrimDot->name());
}
TEST_F(TestOps, Im2ColTest) {
auto prim = std::make_shared<Primitive>("im2col");
ASSERT_EQ(prim->name(), kPrimIm2Col->name());

View File

@ -169,11 +169,6 @@ TEST_F(TestAD, test_prim_array_to_scalar) {
AssertExpect("test_prim_array_to_scalar", dg);
}
TEST_F(TestAD, test_prim_dot) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDot), resourcePtr);
AssertExpect("test_prim_dot", dg);
}
TEST_F(TestAD, test_prim_distribute) {
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDistribute), resourcePtr);
AssertExpect("test_prim_distribute", dg);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -291,22 +291,6 @@ TEST_F(TestPrim, test_J_2) {
ASSERT_TRUE(res_J_1 != nullptr);
}
TEST_F(TestPrim, test_dot) {
auto dot = std::make_shared<Primitive>("dot");
FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);
auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
std::vector<int64_t> expectedA = {2, 4};
auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});
AbstractBasePtrList args_spec_list = {a1, a2};
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
}
// tail half
TEST_F(TestPrim, test_switch1) {
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");