forked from mindspore-Ecosystem/mindspore
synchronize latest ascend software 04 Jun 2020
This commit is contained in:
parent
39338c8627
commit
8da4c1a763
|
@ -7,6 +7,9 @@ endif ()
|
||||||
|
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
|
||||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
|
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
|
||||||
|
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
|
endif ()
|
||||||
|
|
||||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||||
|
|
|
@ -36,6 +36,7 @@ elseif (DEFINED ENV{D_LINK_PATH})
|
||||||
find_library(hccl libhccl.so ${GE_LIB_PATH})
|
find_library(hccl libhccl.so ${GE_LIB_PATH})
|
||||||
find_library(cce libcce.so ${GE_LIB_PATH})
|
find_library(cce libcce.so ${GE_LIB_PATH})
|
||||||
find_library(resource libresource.so ${GE_LIB_PATH})
|
find_library(resource libresource.so ${GE_LIB_PATH})
|
||||||
|
find_library(error_manager liberror_manager.so ${GE_LIB_PATH})
|
||||||
else()
|
else()
|
||||||
# Ascend mode
|
# Ascend mode
|
||||||
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
|
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||||
|
@ -54,6 +55,7 @@ else()
|
||||||
find_library(msprof libmsprof.so ${ASCEND_RUNTIME_PATH})
|
find_library(msprof libmsprof.so ${ASCEND_RUNTIME_PATH})
|
||||||
find_library(register libregister.so ${ASCEND_RUNTIME_PATH})
|
find_library(register libregister.so ${ASCEND_RUNTIME_PATH})
|
||||||
find_library(resource libresource.so ${ASCEND_RUNTIME_PATH})
|
find_library(resource libresource.so ${ASCEND_RUNTIME_PATH})
|
||||||
|
find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_PATH})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# compile libraries from following directories
|
# compile libraries from following directories
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
|
||||||
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||||
mindspore_add_pkg(gtest
|
mindspore_add_pkg(gtest
|
||||||
VER 1.8.0
|
VER 1.8.0
|
||||||
|
|
|
@ -8,7 +8,7 @@ elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||||
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-attributes -Wno-unknown-pragmas")
|
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-attributes -Wno-unknown-pragmas")
|
||||||
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-unused-value -Wno-implicit-fallthrough")
|
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-unused-value -Wno-implicit-fallthrough")
|
||||||
else()
|
else()
|
||||||
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
|
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
|
||||||
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
|
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
|
||||||
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
set(protobuf_USE_STATIC_LIBS ON)
|
set(protobuf_USE_STATIC_LIBS ON)
|
||||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
||||||
else()
|
elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
||||||
|
else()
|
||||||
|
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||||
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||||
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
|
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
|
Subproject commit 9248a2fd15ffc64d9d04b40c6b2836d1c94ca0b4
|
|
@ -32,6 +32,8 @@ namespace tbe {
|
||||||
static std::map<string, string> tbe_func_adapter_map = {
|
static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"softmax", "softmax_v2"},
|
{"softmax", "softmax_v2"},
|
||||||
{"log_softmax", "log_softmax_v2"},
|
{"log_softmax", "log_softmax_v2"},
|
||||||
|
{"apply_momentum", "apply_momentum_d"},
|
||||||
|
{"apply_ftrl", "apply_ftrl_d"},
|
||||||
{"re_lu6", "relu6"},
|
{"re_lu6", "relu6"},
|
||||||
{"re_lu6_grad", "relu6_grad"},
|
{"re_lu6_grad", "relu6_grad"},
|
||||||
{"re_lu", "relu"},
|
{"re_lu", "relu"},
|
||||||
|
@ -89,7 +91,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"batch_to_space_nd", "batch_to_space_nd_d"},
|
{"batch_to_space_nd", "batch_to_space_nd_d"},
|
||||||
{"resize_bilinear", "resize_bilinear_v2_d"},
|
{"resize_bilinear", "resize_bilinear_v2_d"},
|
||||||
{"resize_bilinear_grad", "resize_bilinear_v2_grad"},
|
{"resize_bilinear_grad", "resize_bilinear_v2_grad"},
|
||||||
{"adam", "apply_adam"},
|
{"adam", "apply_adam_d"},
|
||||||
{"r_oi_align", "roi_align"},
|
{"r_oi_align", "roi_align"},
|
||||||
{"r_oi_align_grad", "roi_align_grad"},
|
{"r_oi_align_grad", "roi_align_grad"},
|
||||||
{"i_ou", "iou"},
|
{"i_ou", "iou"},
|
||||||
|
|
|
@ -32,19 +32,6 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
|
||||||
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||||
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
|
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
|
||||||
}
|
}
|
||||||
void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node,
|
|
||||||
std::vector<AnfNodePtr> *new_outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(old_cnode);
|
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(new_outputs);
|
|
||||||
auto node_to_output = old_cnode->input(kAccumIndex + 1);
|
|
||||||
MS_EXCEPTION_IF_NULL(node_to_output);
|
|
||||||
AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()};
|
|
||||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
||||||
new_node->set_abstract(abstract_tuple);
|
|
||||||
// Create Output
|
|
||||||
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs);
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
const BaseRef MomentumLossscaleFusion::DefinePattern() const {
|
const BaseRef MomentumLossscaleFusion::DefinePattern() const {
|
||||||
|
@ -94,14 +81,9 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
|
||||||
input_names_value[3] = "x1";
|
input_names_value[3] = "x1";
|
||||||
input_names_value.emplace_back("x2");
|
input_names_value.emplace_back("x2");
|
||||||
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
|
||||||
|
new_node->set_abstract(node->abstract());
|
||||||
new_node->set_scope(node->scope());
|
new_node->set_scope(node->scope());
|
||||||
// Create Outputs
|
return new_node;
|
||||||
std::vector<AnfNodePtr> new_outputs;
|
|
||||||
AddInputToOutput(func_graph, cnode, new_node, &new_outputs);
|
|
||||||
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
|
|
||||||
}
|
|
||||||
return new_outputs[0];
|
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -212,7 +212,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameIOU), ADPT_DESC(Iou)},
|
{string(kNameIOU), ADPT_DESC(Iou)},
|
||||||
{string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)},
|
{string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)},
|
||||||
{string(kNameSlice), ADPT_DESC(SliceD)},
|
{string(kNameSlice), ADPT_DESC(SliceD)},
|
||||||
{string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)},
|
{string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)},
|
||||||
{string(kNameMaxPool), ADPT_DESC(MaxPool)},
|
{string(kNameMaxPool), ADPT_DESC(MaxPool)},
|
||||||
{string(kNameAvgPool), ADPT_DESC(AvgPool)},
|
{string(kNameAvgPool), ADPT_DESC(AvgPool)},
|
||||||
{string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
|
{string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
|
||||||
|
@ -395,7 +395,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
||||||
{string(kNameSign), ADPT_DESC(Sign)},
|
{string(kNameSign), ADPT_DESC(Sign)},
|
||||||
{string(kNameRound), ADPT_DESC(Round)},
|
{string(kNameRound), ADPT_DESC(Round)},
|
||||||
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
|
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
|
||||||
{string(kNameDiag), ADPT_DESC(Diag)},
|
{string(kNameDiag), ADPT_DESC(Diag)},
|
||||||
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
||||||
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
||||||
|
@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}};
|
{string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}};
|
||||||
#ifdef ENABLE_GE
|
#ifdef ENABLE_GE
|
||||||
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
|
||||||
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdam);
|
adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
|
||||||
#endif
|
#endif
|
||||||
return adpt_map;
|
return adpt_map;
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,11 +127,12 @@ INPUT_MAP(Constant) = EMPTY_INPUT_MAP;
|
||||||
ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
|
ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
|
||||||
OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// ApplyMomentum
|
// ApplyMomentumD
|
||||||
INPUT_MAP(ApplyMomentum) = {
|
INPUT_MAP(ApplyMomentumD) = {
|
||||||
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
|
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
|
||||||
ATTR_MAP(ApplyMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
|
||||||
OUTPUT_MAP(ApplyMomentum) = {{0, OUTPUT_DESC(var)}};
|
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
|
||||||
|
|
||||||
// ScalarSummary
|
// ScalarSummary
|
||||||
INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
|
INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
|
||||||
|
@ -470,7 +471,16 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
|
||||||
{10, INPUT_DESC(grad)}};
|
{10, INPUT_DESC(grad)}};
|
||||||
ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
|
ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
|
||||||
{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
|
{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
|
OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
|
||||||
|
|
||||||
|
// ApplyAdamD
|
||||||
|
INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
|
||||||
|
{4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
|
||||||
|
{7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)},
|
||||||
|
{10, INPUT_DESC(grad)}};
|
||||||
|
ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
|
||||||
|
{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
|
||||||
|
OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
|
||||||
|
|
||||||
// Relu6
|
// Relu6
|
||||||
INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
|
||||||
|
@ -823,7 +833,7 @@ OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}};
|
||||||
// Cast
|
// Cast
|
||||||
INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}};
|
||||||
INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
|
INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
|
||||||
ATTR_MAP(Cast) = {{"Truncate", ATTR_DESC(truncate, AnyTraits<bool>())}};
|
ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// Reciprocal
|
// Reciprocal
|
||||||
|
@ -1194,12 +1204,12 @@ INPUT_MAP(Round) = {{1, INPUT_DESC(x)}};
|
||||||
ATTR_MAP(Round) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Round) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// ApplyFtrl
|
// ApplyFtrlD
|
||||||
INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
|
INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
|
||||||
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
|
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
|
||||||
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
|
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
|
||||||
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
|
||||||
|
|
||||||
// Diag
|
// Diag
|
||||||
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||||
|
|
|
@ -120,6 +120,8 @@ DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
|
||||||
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
|
DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
|
||||||
DECLARE_OP_ADAPTER(ApplyAdam)
|
DECLARE_OP_ADAPTER(ApplyAdam)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyAdam)
|
DECLARE_OP_USE_OUTPUT(ApplyAdam)
|
||||||
|
DECLARE_OP_ADAPTER(ApplyAdamD)
|
||||||
|
DECLARE_OP_USE_OUTPUT(ApplyAdamD)
|
||||||
DECLARE_OP_ADAPTER(Relu6)
|
DECLARE_OP_ADAPTER(Relu6)
|
||||||
DECLARE_OP_USE_OUTPUT(Relu6)
|
DECLARE_OP_USE_OUTPUT(Relu6)
|
||||||
DECLARE_OP_ADAPTER(Relu6Grad)
|
DECLARE_OP_ADAPTER(Relu6Grad)
|
||||||
|
@ -319,8 +321,8 @@ DECLARE_OP_ADAPTER(Assign)
|
||||||
DECLARE_OP_USE_OUTPUT(Assign)
|
DECLARE_OP_USE_OUTPUT(Assign)
|
||||||
DECLARE_OP_ADAPTER(Constant)
|
DECLARE_OP_ADAPTER(Constant)
|
||||||
DECLARE_OP_USE_OUTPUT(Constant)
|
DECLARE_OP_USE_OUTPUT(Constant)
|
||||||
DECLARE_OP_ADAPTER(ApplyMomentum)
|
DECLARE_OP_ADAPTER(ApplyMomentumD)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyMomentum)
|
DECLARE_OP_USE_OUTPUT(ApplyMomentumD)
|
||||||
// ** Summary Operations **
|
// ** Summary Operations **
|
||||||
DECLARE_OP_ADAPTER(Summary)
|
DECLARE_OP_ADAPTER(Summary)
|
||||||
|
|
||||||
|
@ -454,8 +456,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
|
||||||
DECLARE_OP_USE_OUTPUT(LarsV2Update)
|
DECLARE_OP_USE_OUTPUT(LarsV2Update)
|
||||||
DECLARE_OP_ADAPTER(Round)
|
DECLARE_OP_ADAPTER(Round)
|
||||||
DECLARE_OP_USE_OUTPUT(Round)
|
DECLARE_OP_USE_OUTPUT(Round)
|
||||||
DECLARE_OP_ADAPTER(ApplyFtrl)
|
DECLARE_OP_ADAPTER(ApplyFtrlD)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
|
||||||
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
|
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
|
||||||
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
|
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
|
||||||
DECLARE_OP_ADAPTER(Diag)
|
DECLARE_OP_ADAPTER(Diag)
|
||||||
|
|
|
@ -32,30 +32,32 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
|
||||||
.input(6, "l2", False, "required", "all") \
|
.input(6, "l2", False, "required", "all") \
|
||||||
.input(7, "lr_power", False, "required", "all") \
|
.input(7, "lr_power", False, "required", "all") \
|
||||||
.output(0, "var", False, "required", "all") \
|
.output(0, "var", False, "required", "all") \
|
||||||
|
.output(1, "accum", False, "required", "all") \
|
||||||
|
.output(2, "linear", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_5HD) \
|
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||||
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_FracZ) \
|
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||||
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_C1HWNCoC0) \
|
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||||
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_FracZ) \
|
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||||
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_C1HWNCoC0) \
|
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,22 +30,23 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
|
||||||
.input(3, "grad", False, "required", "all") \
|
.input(3, "grad", False, "required", "all") \
|
||||||
.input(4, "momentum", False, "required", "all") \
|
.input(4, "momentum", False, "required", "all") \
|
||||||
.output(0, "var", False, "required", "all") \
|
.output(0, "var", False, "required", "all") \
|
||||||
|
.output(1, "accum", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
|
||||||
DataType.F16_Default, DataType.F16_5HD) \
|
DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
|
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
|
||||||
DataType.F16_Default, DataType.F16_C1HWNCoC0) \
|
DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
|
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
|
||||||
DataType.F16_Default, DataType.F16_FracZ) \
|
DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
|
||||||
DataType.F32_Default, DataType.F32_5HD) \
|
DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
|
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
|
||||||
DataType.F32_Default, DataType.F32_C1HWNCoC0) \
|
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
|
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
|
||||||
DataType.F32_Default, DataType.F32_FracZ) \
|
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1507,8 +1507,11 @@ class ApplyMomentum(PrimitiveWithInfer):
|
||||||
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
|
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
|
||||||
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
|
self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
|
self.is_tbe = context.get_context("device_target") == "Ascend"
|
||||||
|
|
||||||
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
|
def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
|
||||||
|
if self.is_tbe:
|
||||||
|
return v_shape, v_shape
|
||||||
return v_shape
|
return v_shape
|
||||||
|
|
||||||
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
|
def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
|
||||||
|
@ -1519,6 +1522,8 @@ class ApplyMomentum(PrimitiveWithInfer):
|
||||||
validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
|
||||||
|
if self.is_tbe:
|
||||||
|
return g_dtype, g_dtype
|
||||||
return g_dtype
|
return g_dtype
|
||||||
|
|
||||||
|
|
||||||
|
@ -2810,13 +2815,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
|
||||||
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
||||||
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
||||||
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
||||||
return var_shape
|
return var_shape, accum_shape
|
||||||
|
|
||||||
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
|
def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
|
||||||
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
|
args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
|
||||||
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
|
validator.check_tensor_type_same(args, (mstype.float32,), self.name)
|
||||||
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
|
validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
|
||||||
return var_type
|
return var_type, accum_type
|
||||||
|
|
||||||
|
|
||||||
class ApplyProximalAdagrad(PrimitiveWithInfer):
|
class ApplyProximalAdagrad(PrimitiveWithInfer):
|
||||||
|
@ -3074,11 +3079,14 @@ class ApplyFtrl(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
|
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||||
|
self.is_tbe = context.get_context("device_target") == "Ascend"
|
||||||
|
|
||||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
|
||||||
lr_power_shape):
|
lr_power_shape):
|
||||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||||
|
if self.is_tbe:
|
||||||
|
return var_shape, var_shape, var_shape
|
||||||
return var_shape
|
return var_shape
|
||||||
|
|
||||||
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
|
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
|
||||||
|
@ -3090,6 +3098,8 @@ class ApplyFtrl(PrimitiveWithInfer):
|
||||||
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
|
||||||
|
if self.is_tbe:
|
||||||
|
return var_type, var_type, var_type
|
||||||
return var_type
|
return var_type
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -199,10 +199,10 @@ def test_bert_percision():
|
||||||
|
|
||||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||||
loss_value = np.array(callback.loss_list)
|
loss_value = np.array(callback.loss_list)
|
||||||
assert np.allclose(loss_value[0], 12.207198, 0, 0.000001)
|
assert np.allclose(loss_value[0], 12.206575, 0, 0.000001)
|
||||||
|
|
||||||
expect_loss_value = [12.207198, 11.980881, 11.984844, 11.879381, 11.832978, 12.411333, 12.009284,
|
expect_loss_value = [12.206575, 11.980493, 11.984225, 11.878742, 11.832555, 12.410444, 12.008799,
|
||||||
12.621277, 12.223178, 12.427385]
|
12.620619, 12.22254, 12.4261055]
|
||||||
print("loss value: {}".format(loss_value))
|
print("loss value: {}".format(loss_value))
|
||||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,6 @@ def test_momentum_lossscale_fusion(tag):
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
def after(input0, input1, input2, input3, input4):
|
def after(input0, input1, input2, input3, input4):
|
||||||
return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0))
|
return make_tuple(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant))
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
|
@ -103,7 +103,7 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT
|
||||||
|
|
||||||
/* 获取梯度参数切分方案 */
|
/* 获取梯度参数切分方案 */
|
||||||
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
|
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
|
||||||
u32 *segmentNum, u32 *segmentIdx) {
|
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force) {
|
||||||
return HCCL_SUCCESS;
|
return HCCL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
|
||||||
gradient, variable, moment):
|
gradient, variable, moment):
|
||||||
""" tensor_run_opt """
|
""" tensor_run_opt """
|
||||||
success = True
|
success = True
|
||||||
new_weight = opt(variable, moment, learning_rate, gradient, momentum)
|
new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0]
|
||||||
success = F.depend(success, F.assign(variable, new_weight))
|
success = F.depend(success, F.assign(variable, new_weight))
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
|
@ -1058,6 +1058,7 @@ test_case_nn_ops = [
|
||||||
('SparseApplyAdagrad', {
|
('SparseApplyAdagrad', {
|
||||||
'block': P.SparseApplyAdagrad(0.5),
|
'block': P.SparseApplyAdagrad(0.5),
|
||||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
|
'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
|
||||||
|
'desc_bprop': [[3, 3], [3, 3]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
('SparseApplyFtrl', {
|
('SparseApplyFtrl', {
|
||||||
'block': SparseApplyFtrlNet(),
|
'block': SparseApplyFtrlNet(),
|
||||||
|
|
Loading…
Reference in New Issue