!11794 remove useless code of dot
From: @yuan_shen_zhou Reviewed-by: @liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
066ebe516e
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -86,7 +86,6 @@ convert_object_map = {
|
|||
T.floordiv: multitype_ops.floordiv,
|
||||
T.mod: multitype_ops.mod,
|
||||
T.pow: multitype_ops.pow_,
|
||||
T.matmul: F.dot,
|
||||
T.lshift: NO_IMPLEMENT,
|
||||
T.rshift: NO_IMPLEMENT,
|
||||
T.and_: multitype_ops.logical_and,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,7 +32,7 @@ bool CNodeHasTupleInput(const CNodePtr &cnode) {
|
|||
}
|
||||
if (IsValueNode<Primitive>(inputs[i])) {
|
||||
// unexpected high order primitvie as cnode input when transform graph
|
||||
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString();
|
||||
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitive as input" << cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
auto abs = inputs[i]->abstract();
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -170,7 +170,6 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__ge__", std::string("ge")}, // C.ge
|
||||
{"expand_as", std::string("expand_tensor_as")}, // C.expand_as
|
||||
{"view", std::string("view")}, // C.view
|
||||
{"__matmul__", prim::kPrimDot}, // P.dot,
|
||||
{"__len__", prim::kPrimArrayLen}, // P.array_len,
|
||||
{"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
|
||||
{"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
|
||||
|
@ -352,7 +351,7 @@ void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
|
|||
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
|
||||
MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
|
||||
pynative_short_life_primitives_.insert(prim);
|
||||
pynative_new_primtives_squence_.push_back(prim->ToString());
|
||||
}
|
||||
|
|
|
@ -27,8 +27,6 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,35 +34,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
|||
return abs_base;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
AbstractTensorPtr input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
|
||||
ShapePtr x_shp = input_x->shape();
|
||||
auto x_shp_value = x_shp->shape();
|
||||
ShapePtr y_shp = input_y->shape();
|
||||
auto y_shp_value = y_shp->shape();
|
||||
// Should be matrix which shape size is 2.
|
||||
if (x_shp_value.size() != 2 || y_shp_value.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are "
|
||||
<< x_shp_value.size() << ", " << y_shp_value.size() << " ";
|
||||
}
|
||||
if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) {
|
||||
MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}";
|
||||
}
|
||||
|
||||
auto x_element = input_x->element();
|
||||
MS_EXCEPTION_IF_NULL(x_element);
|
||||
(void)x_element->Join(input_y->element());
|
||||
auto param = {x_shp_value[0], y_shp_value[1]};
|
||||
|
||||
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(param));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: condition, true branch, false branch
|
||||
|
|
|
@ -26,7 +26,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
static PrimitiveEvalImplMap prim_eval_implement_map = {
|
||||
// Statements
|
||||
{prim::kPrimReturn, {InferImplReturn, true}},
|
||||
{prim::kPrimDot, {InferImplDot, true}},
|
||||
{prim::kPrimSwitch, {InferImplSwitch, true}},
|
||||
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
|
||||
{prim::kPrimIs_, {InferImplIs_, true}},
|
||||
|
|
|
@ -67,7 +67,6 @@ inline const PrimitivePtr kPrimLogicalOr = std::make_shared<Primitive>("LogicalO
|
|||
inline const PrimitivePtr kPrimLogicalNot = std::make_shared<Primitive>("LogicalNot");
|
||||
|
||||
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
|
||||
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
|
||||
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
|
||||
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -188,12 +188,6 @@ def bprop_array_to_scalar(x, out, dout):
|
|||
return (F.scalar_to_array(dout),)
|
||||
|
||||
|
||||
@bprops.register("dot")
|
||||
def bprop_dot(x, y, out, dout):
|
||||
"""Backpropagator for primitive `dot`."""
|
||||
return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)
|
||||
|
||||
|
||||
@bprops.register("reshape")
|
||||
def bprop_reshape(xs, shp, out, dout):
|
||||
"""Backpropagator for primitive `reshape`."""
|
||||
|
|
|
@ -142,7 +142,6 @@ in_dict = Primitive("in_dict")
|
|||
not_in_dict = Primitive("not_in_dict")
|
||||
mixed_precision_cast = Primitive("mixed_precision_cast")
|
||||
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
||||
dot = Primitive('dot')
|
||||
array_reduce = Primitive('array_reduce')
|
||||
zeros_like = P.ZerosLike()
|
||||
distribute = Primitive('distribute')
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -287,11 +287,6 @@ TEST_F(TestOps, TransposeTest) {
|
|||
ASSERT_EQ(prim->name(), kPrimTranspose->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, DotTest) {
|
||||
auto prim = std::make_shared<Primitive>("dot");
|
||||
ASSERT_EQ(prim->name(), kPrimDot->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, Im2ColTest) {
|
||||
auto prim = std::make_shared<Primitive>("im2col");
|
||||
ASSERT_EQ(prim->name(), kPrimIm2Col->name());
|
||||
|
|
|
@ -169,11 +169,6 @@ TEST_F(TestAD, test_prim_array_to_scalar) {
|
|||
AssertExpect("test_prim_array_to_scalar", dg);
|
||||
}
|
||||
|
||||
TEST_F(TestAD, test_prim_dot) {
|
||||
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDot), resourcePtr);
|
||||
AssertExpect("test_prim_dot", dg);
|
||||
}
|
||||
|
||||
TEST_F(TestAD, test_prim_distribute) {
|
||||
FuncGraphPtr dg = Kprim(NewValueNode(prim::kPrimDistribute), resourcePtr);
|
||||
AssertExpect("test_prim_distribute", dg);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -291,22 +291,6 @@ TEST_F(TestPrim, test_J_2) {
|
|||
ASSERT_TRUE(res_J_1 != nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TestPrim, test_dot) {
|
||||
auto dot = std::make_shared<Primitive>("dot");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(dot, 2);
|
||||
|
||||
auto a1 = UTPrimUtils::ArrayFloat64Of({2, 3});
|
||||
auto a2 = UTPrimUtils::ArrayFloat64Of({3, 4});
|
||||
std::vector<int64_t> expectedA = {2, 4};
|
||||
auto expected = UTPrimUtils::ArrayFloat64Of({2, 4});
|
||||
|
||||
AbstractBasePtrList args_spec_list = {a1, a2};
|
||||
|
||||
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||
|
||||
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
|
||||
}
|
||||
|
||||
// tail half
|
||||
TEST_F(TestPrim, test_switch1) {
|
||||
PrimitivePtr switch_ = std::make_shared<Primitive>("switch");
|
||||
|
|
Loading…
Reference in New Issue