forked from OSSInnovation/mindspore
Merge branch 'master' into code_sync_incubator_f3c32baf_to_master_fcfc75a3_0811
This commit is contained in:
commit
4964f7703a
|
@ -5,9 +5,6 @@
|
|||
[submodule "third_party/googletest"]
|
||||
path = third_party/googletest
|
||||
url = https://github.com/google/googletest.git
|
||||
[submodule "third_party/incubator-tvm"]
|
||||
path = third_party/incubator-tvm
|
||||
url = https://github.com/apache/incubator-tvm.git
|
||||
[submodule "third_party/protobuf"]
|
||||
path = third_party/protobuf
|
||||
url = https://github.com/protocolbuffers/protobuf.git
|
||||
|
@ -17,7 +14,7 @@
|
|||
url = https://gitee.com/mindspore/akg.git
|
||||
[submodule "graphengine"]
|
||||
path = graphengine
|
||||
url = https://gitee.com/ms-incubator/graphengine.git
|
||||
url = https://gitee.com/mindspore/graphengine.git
|
||||
[submodule "third_party/OpenCL-CLHPP"]
|
||||
path = third_party/OpenCL-CLHPP
|
||||
url = https://github.com/KhronosGroup/OpenCL-CLHPP.git
|
||||
|
|
|
@ -98,6 +98,7 @@ endif()
|
|||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
add_subdirectory(mindspore/ccsrc)
|
||||
add_subdirectory(mindspore/core)
|
||||
if (ENABLE_TESTCASES)
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
|
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit 5fe7e5c8377dccfd35c9f661e10ed3dc136208c5
|
||||
Subproject commit 8f9af74f59837579034610a741f5b8f33db12515
|
48
build.sh
48
build.sh
|
@ -109,7 +109,7 @@ checkopts()
|
|||
ENABLE_GPU="off"
|
||||
|
||||
# Process the options
|
||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:En' opt
|
||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:EnT:' opt
|
||||
do
|
||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||
case "${opt}" in
|
||||
|
@ -282,6 +282,11 @@ checkopts()
|
|||
ENABLE_IBVERBS="on"
|
||||
echo "enable IBVERBS for parameter server"
|
||||
;;
|
||||
T)
|
||||
check_on_off $OPTARG T
|
||||
SUPPORT_TRAIN=$OPTARG
|
||||
echo "support train on device "
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option ${opt}!"
|
||||
usage
|
||||
|
@ -397,7 +402,7 @@ checkndk() {
|
|||
if [ "${ANDROID_NDK}" ]; then
|
||||
echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
|
||||
else
|
||||
echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
|
||||
echo -e "\e[31mplease set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
@ -569,6 +574,39 @@ build_minddata_lite_deps()
|
|||
build_jpeg_turbo
|
||||
}
|
||||
|
||||
prepare_md_lite() {
|
||||
if [ "${COMPILE_MINDDATA_LITE}" == "on" ]; then
|
||||
echo "packaging minddata"
|
||||
cp ${BASEPATH}/mindspore/ccsrc/minddata/dataset/include/*h ${OUTPUT_DIR}/include/
|
||||
cp ${BASEPATH}/mindspore/lite/build/minddata/libminddata-lite.so ${OUTPUT_DIR}/lib/
|
||||
if [[ "$LITE_PLATFORM" == "x86_64" ]]; then
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libjpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libturbojpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/opencv/lib/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/libopencv_core.so ${OUTPUT_DIR}/third_party/opencv/lib/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/libopencv_imgcodecs.so ${OUTPUT_DIR}/third_party/opencv/lib/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/libopencv_imgproc.so ${OUTPUT_DIR}/third_party/opencv/lib/
|
||||
elif [[ "$LITE_PLATFORM" == "arm64" ]]; then
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libjpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libturbojpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/opencv/lib/arm64-v8a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/arm64-v8a/libopencv_core.so ${OUTPUT_DIR}/third_party/opencv/lib/arm64-v8a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/arm64-v8a/libopencv_imgcodecs.so ${OUTPUT_DIR}/third_party/opencv/lib/arm64-v8a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/arm64-v8a/libopencv_imgproc.so ${OUTPUT_DIR}/third_party/opencv/lib/arm64-v8a/
|
||||
elif [[ "$LITE_PLATFORM" == "arm32" ]]; then
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libjpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
cp -r ${BASEPATH}/third_party/libjpeg-turbo/lib/libturbojpeg.so ${OUTPUT_DIR}/third_party/libjpeg-turbo/lib/
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/opencv/lib/armeabi-v7a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/armeabi-v7a/libopencv_core.so ${OUTPUT_DIR}/third_party/opencv/lib/armeabi-v7a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/armeabi-v7a/libopencv_imgcodecs.so ${OUTPUT_DIR}/third_party/opencv/lib/armeabi-v7a/
|
||||
cp -r ${BASEPATH}/third_party/opencv/build/lib/armeabi-v7a/libopencv_imgproc.so ${OUTPUT_DIR}/third_party/opencv/lib/armeabi-v7a/
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
build_lite()
|
||||
{
|
||||
echo "start build mindspore lite project"
|
||||
|
@ -632,6 +670,7 @@ build_lite()
|
|||
mkdir -p ${OUTPUT_DIR}/converter && mkdir -p ${OUTPUT_DIR}/time_profile
|
||||
mkdir -p ${OUTPUT_DIR}/benchmark && mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib
|
||||
mkdir -p ${OUTPUT_DIR}/third_party
|
||||
prepare_md_lite
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/converter/converter_lite ${OUTPUT_DIR}/converter/
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/
|
||||
|
@ -643,8 +682,7 @@ build_lite()
|
|||
cp ${BASEPATH}/mindspore/lite/build/src/libmindspore-lite.so ${OUTPUT_DIR}/lib/
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/protobuf/lib
|
||||
cp -r ${BASEPATH}/third_party/protobuf/build/include/ ${OUTPUT_DIR}/third_party/protobuf/
|
||||
cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19 ${OUTPUT_DIR}/third_party/protobuf/lib/
|
||||
cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 ${OUTPUT_DIR}/third_party/protobuf/lib/
|
||||
cp -r ${BASEPATH}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 ${OUTPUT_DIR}/third_party/protobuf/lib/libprotobuf.so.19
|
||||
mkdir -p ${OUTPUT_DIR}/third_party/flatbuffers
|
||||
cp -r ${BASEPATH}/third_party/flatbuffers/include/ ${OUTPUT_DIR}/third_party/flatbuffers/
|
||||
cd ..
|
||||
|
@ -657,6 +695,7 @@ build_lite()
|
|||
mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark
|
||||
mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib
|
||||
mkdir -p ${OUTPUT_DIR}/third_party
|
||||
prepare_md_lite
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/
|
||||
cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/
|
||||
|
@ -677,6 +716,7 @@ build_lite()
|
|||
mkdir -p ${OUTPUT_DIR}/time_profile && mkdir -p ${OUTPUT_DIR}/benchmark
|
||||
mkdir -p ${OUTPUT_DIR}/include && mkdir -p ${OUTPUT_DIR}/lib
|
||||
mkdir -p ${OUTPUT_DIR}/third_party
|
||||
prepare_md_lite
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/benchmark/benchmark ${OUTPUT_DIR}/benchmark/
|
||||
cp ${BASEPATH}/mindspore/lite/build/tools/time_profile/timeprofile ${OUTPUT_DIR}/time_profile/
|
||||
cp ${BASEPATH}/mindspore/lite/include/*.h ${OUTPUT_DIR}/include/
|
||||
|
|
|
@ -8,11 +8,12 @@ endif()
|
|||
set(jpeg_turbo_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||
mindspore_add_pkg(jpeg_turbo
|
||||
VER 2.0.4
|
||||
LIBS jpeg
|
||||
LIBS jpeg turbojpeg
|
||||
URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz
|
||||
MD5 44c43e4a9fb352f47090804529317c88
|
||||
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DCMAKE_SKIP_RPATH=TRUE
|
||||
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DCMAKE_SKIP_RPATH=TRUE -DWITH_SIMD=ON
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/jpeg_turbo/jpeg_turbo.patch001
|
||||
)
|
||||
include_directories(${jpeg_turbo_INC})
|
||||
add_library(mindspore::jpeg_turbo ALIAS jpeg_turbo::jpeg)
|
||||
add_library(mindspore::turbojpeg ALIAS jpeg_turbo::turbojpeg)
|
||||
|
|
|
@ -52,12 +52,6 @@ install(
|
|||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS mindspore_gvar
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
if (USE_GLOG)
|
||||
file(GLOB_RECURSE GLOG_LIB_LIST ${glog_LIBPATH}/libglog*)
|
||||
install(
|
||||
|
@ -146,15 +140,6 @@ if (ENABLE_MPI)
|
|||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
file(GLOB_RECURSE MPI_LIB_LIST
|
||||
${ompi_LIBPATH}/libmpi${CMAKE_SHARED_LIBRARY_SUFFIX}*
|
||||
${ompi_LIBPATH}/libopen*${CMAKE_SHARED_LIBRARY_SUFFIX}*
|
||||
)
|
||||
install(
|
||||
FILES ${MPI_LIB_LIST}
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 377b2165184fbfbb32829266822438e439861f14
|
||||
Subproject commit 622af6c1c50034bea5a08bd409c5a410782bfe53
|
|
@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
|
|||
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
is_class_member, parse_cb, resolve_symbol)
|
||||
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor)
|
||||
from .serialize import *
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
|
@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
|
||||
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
|
||||
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
|
||||
'create_slice_obj']
|
||||
'create_slice_obj', 'convert_to_ms_tensor']
|
||||
|
|
|
@ -25,6 +25,7 @@ from dataclasses import is_dataclass
|
|||
import asttokens
|
||||
import mindspore.nn as nn
|
||||
from mindspore import log as logger
|
||||
from mindspore import Tensor as MsTensor
|
||||
from mindspore import ops
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.api import _MindSporeFunction
|
||||
|
@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
|
|||
return methods
|
||||
|
||||
|
||||
def convert_to_ms_tensor(data):
|
||||
"""Convert C++ tensor to mindspore tensor."""
|
||||
return MsTensor(data)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
|
|
@ -130,7 +130,7 @@ set(SUB_COMP
|
|||
frontend/operator
|
||||
pipeline/jit
|
||||
pipeline/pynative
|
||||
common debug gvar pybind_api utils vm
|
||||
common debug pybind_api utils vm
|
||||
)
|
||||
|
||||
foreach (_comp ${SUB_COMP})
|
||||
|
@ -141,32 +141,21 @@ foreach (_comp ${SUB_COMP})
|
|||
add_dependencies(_mindspore_${sub}_obj proto_input )
|
||||
endif ()
|
||||
endforeach ()
|
||||
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/base base)
|
||||
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_base_obj>)
|
||||
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/abstract abstract)
|
||||
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_abstract_obj>)
|
||||
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util)
|
||||
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_core_utils_obj>)
|
||||
add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir)
|
||||
list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>)
|
||||
add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input )
|
||||
|
||||
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
|
||||
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
|
||||
|
||||
target_link_libraries(proto_input mindspore::protobuf)
|
||||
|
||||
target_link_libraries(mindspore mindspore_core)
|
||||
|
||||
if (ENABLE_DEBUGGER)
|
||||
# debugger: link grpc
|
||||
target_link_libraries(proto_input mindspore::grpc++)
|
||||
endif()
|
||||
|
||||
target_link_libraries(mindspore proto_input)
|
||||
if (ENABLE_MPI AND ENABLE_CPU)
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers mpi_adapter)
|
||||
else ()
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers)
|
||||
endif ()
|
||||
|
||||
if (NOT WIN32)
|
||||
target_link_libraries(mindspore dl)
|
||||
|
@ -242,7 +231,6 @@ set_target_properties(_c_expression PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
|
|||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||
target_link_libraries(mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
|
||||
else ()
|
||||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
|
@ -253,7 +241,6 @@ else ()
|
|||
endif()
|
||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
|
||||
target_link_libraries(_c_expression PRIVATE mindspore_gvar)
|
||||
endif ()
|
||||
|
||||
if (USE_GLOG)
|
||||
|
@ -297,7 +284,7 @@ add_library(inference SHARED
|
|||
${LOAD_ONNX_SRC}
|
||||
)
|
||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore::protobuf)
|
||||
|
||||
if (ENABLE_CPU)
|
||||
target_link_libraries(inference PRIVATE mindspore::dnnl mindspore::mkldnn)
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/allgather_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -45,9 +45,7 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto input_data_num = inputs[0]->size / sizeof(float);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
return MPIAllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <thread>
|
||||
#include "backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -49,11 +49,8 @@ bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<kernel::AddressP
|
|||
const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
|
||||
size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
for (int i = 0; i < split_num_; i++) {
|
||||
mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
|
||||
input_split_lens);
|
||||
MPIAllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, input_split_lens);
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
#include "backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_adapter.h"
|
||||
#include "runtime/device/cpu/mpi/mpi_interface.h"
|
||||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -24,7 +24,7 @@ namespace {
|
|||
constexpr auto kRanksGroup = "group";
|
||||
} // namespace
|
||||
|
||||
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {}
|
||||
ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(kMPIOpTypeSum) {}
|
||||
|
||||
void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op");
|
||||
|
@ -46,9 +46,7 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto output_data_num = outputs[0]->size / sizeof(float);
|
||||
auto mpi_instance = device::cpu::MPIAdapter::Instance();
|
||||
MS_EXCEPTION_IF_NULL(mpi_instance);
|
||||
return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||
return MPIReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -13,8 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <thread>
|
||||
#include "backend/kernel_compiler/cpu/sub_cpu_kernel.h"
|
||||
#include <sys/time.h>
|
||||
#include <thread>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -182,30 +182,59 @@ class ArrayReduceGpuKernel : public GpuKernel {
|
|||
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
|
||||
std::vector<int> inputA;
|
||||
std::vector<size_t> outputC_shape = output_shape;
|
||||
const int split_dim = 4;
|
||||
|
||||
if (input_shape.size() <= split_dim) {
|
||||
ShapeNdTo4d(input_shape, &inputA);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
|
||||
inputA[1], inputA[2], inputA[3]),
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
|
||||
inputA[0], inputA[1], inputA[2], inputA[3]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
|
||||
for (auto dim : input_shape) {
|
||||
inputA.emplace_back(SizeToInt(dim));
|
||||
}
|
||||
}
|
||||
|
||||
if (axis_[0] == -1) {
|
||||
outputC_shape.resize(input_shape.size(), 1);
|
||||
if (outputC_shape.size() <= split_dim) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
|
||||
all_match_ = true;
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
|
||||
}
|
||||
|
||||
for (auto dim : inputA) {
|
||||
if (dim != 1) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
all_match_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> outputC;
|
||||
if (!keep_dims_) {
|
||||
for (auto i : axis_) {
|
||||
(void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
|
||||
}
|
||||
}
|
||||
std::vector<int> outputC;
|
||||
|
||||
if (outputC_shape.size() <= split_dim) {
|
||||
ShapeNdTo4d(outputC_shape, &outputC);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
|
||||
outputC[0], outputC[1], outputC[2], outputC[3]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
|
||||
for (auto dim : outputC_shape) {
|
||||
outputC.emplace_back(SizeToInt(dim));
|
||||
}
|
||||
}
|
||||
|
||||
if (inputA == outputC) {
|
||||
all_match_ = true;
|
||||
}
|
||||
|
|
|
@ -69,6 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
|
|||
memcpy_flag_ = true;
|
||||
}
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemSet failed in ScatterNdGpuFwdKernel::Launch.");
|
||||
|
||||
const size_t input_size = input_size_ / sizeof(T);
|
||||
const size_t output_size = output_size_ / sizeof(T);
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
|
@ -54,6 +55,11 @@ struct RealDivFunc {
|
|||
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
|
||||
};
|
||||
|
||||
template <typename T, typename S>
|
||||
struct DivFunc {
|
||||
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
|
||||
};
|
||||
|
||||
template <typename T, typename S>
|
||||
struct MulFunc {
|
||||
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
|
||||
|
@ -95,7 +101,6 @@ struct AbsGradFunc {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
template <>
|
||||
struct PowerFunc<half, bool> {
|
||||
// invalid branch
|
||||
|
@ -104,72 +109,100 @@ struct PowerFunc<half, bool> {
|
|||
|
||||
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
|
||||
|
||||
|
||||
template <typename T, typename S, typename Func>
|
||||
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
|
||||
const int &r0, const int &r1, const int &r2, const int &r3,
|
||||
const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
const T *input0, const T *input1, S *output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
|
||||
int i = pos / (d1 * d2 * d3) % d0;
|
||||
int j = pos / (d2 * d3) % d1;
|
||||
int k = pos / d3 % d2;
|
||||
int l = pos % d3;
|
||||
const int &l4, const int &l5, const int &l6, const int &r0,
|
||||
const int &r1, const int &r2, const int &r3, const int &r4,
|
||||
const int &r5, const int &r6, const int &d0, const int &d1,
|
||||
const int &d2, const int &d3, const int &d4, const int &d5,
|
||||
const int &d6, const T *input0, const T *input1, S *output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
|
||||
pos += blockDim.x * gridDim.x) {
|
||||
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
|
||||
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
|
||||
int k = pos / (d3 * d4 * d5 * d6) % d2;
|
||||
int l = pos / (d4 * d5 * d6) % d3;
|
||||
int m = pos / (d5 * d6) % d4;
|
||||
int n = pos / d6 % d5;
|
||||
int o = pos % d6;
|
||||
|
||||
int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
|
||||
int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
|
||||
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
|
||||
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
|
||||
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
|
||||
l_index += Index(l, l3) * l4 * l5 * l6;
|
||||
l_index += Index(m, l4) * l5 * l6;
|
||||
l_index += Index(n, l5) * l6;
|
||||
l_index += Index(o, l6);
|
||||
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
|
||||
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
|
||||
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
|
||||
r_index += Index(l, r3) * r4 * r5 * r6;
|
||||
r_index += Index(m, r4) * r5 * r6;
|
||||
r_index += Index(n, r5) * r6;
|
||||
r_index += Index(o, r6);
|
||||
output[pos] = Func()(input0[l_index], input1[r_index]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
|
||||
const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
|
||||
enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
|
||||
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
|
||||
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
|
||||
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
|
||||
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
|
||||
const T *input1, S *output) {
|
||||
switch (op) {
|
||||
case BROADCAST_TYPE_GREATER:
|
||||
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_LESS:
|
||||
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
||||
d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_MINIMUM:
|
||||
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_MAXIMUM:
|
||||
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_POWER:
|
||||
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_REALDIV:
|
||||
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_MUL:
|
||||
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
||||
d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_SUB:
|
||||
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
||||
d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_ADD:
|
||||
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
||||
d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_FLOORDIV:
|
||||
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_ABSGRAD:
|
||||
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
|
||||
output);
|
||||
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
|
||||
d2, d3, d4, d5, d6, input0, input1, output);
|
||||
case BROADCAST_TYPE_DIV:
|
||||
return BroadcastOperator<T, S, DivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
|
||||
d3, d4, d5, d6, input0, input1, output);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
|
||||
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
|
||||
const T *input0, const T *input1, S *output, cudaStream_t stream) {
|
||||
int size = d0 * d1 * d2 * d3;
|
||||
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
|
||||
input0, input1, output);
|
||||
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
|
||||
S *output, cudaStream_t stream) {
|
||||
int size = 1;
|
||||
for (auto d : output_shape) {
|
||||
size *= d;
|
||||
}
|
||||
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
|
||||
lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3], lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
|
||||
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4], rhs_shape[5], rhs_shape[6], output_shape[0],
|
||||
output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5], output_shape[6], op, input0,
|
||||
input1, output);
|
||||
}
|
||||
|
||||
template <typename T, typename S, typename Func>
|
||||
|
@ -205,6 +238,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
|
|||
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
|
||||
case BROADCAST_TYPE_ABSGRAD:
|
||||
return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
|
||||
case BROADCAST_TYPE_DIV:
|
||||
return NoBroadcastOperator<T, S, DivFunc<T, S>>(nums, input0, input1, output);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,8 +250,8 @@ void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, cons
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0,
|
||||
const int o1, const int o2, const int o3, const T *input_addr, T *output_addr) {
|
||||
__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
|
||||
const int o2, const int o3, const T *input_addr, T *output_addr) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) {
|
||||
int i = pos / (o1 * o2 * o3) % o0;
|
||||
int j = pos / (o2 * o3) % o1;
|
||||
|
@ -236,30 +271,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
|
|||
output_addr);
|
||||
}
|
||||
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const half *input0, const half *input1, half *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
|
||||
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
|
||||
enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
|
||||
cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
|
||||
const float *input1, bool *output, cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
|
||||
const float *input1, float *output, cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
|
||||
const half *input1, bool *output, cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
|
||||
const half *input1, half *output, cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
|
||||
const int *input1, int *output, cudaStream_t stream);
|
||||
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
|
||||
const int *input1, bool *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
||||
bool *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
|
||||
|
@ -268,10 +297,10 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *
|
|||
bool *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
|
||||
half *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
|
||||
int *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
|
||||
bool *output, cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output,
|
||||
cudaStream_t stream);
|
||||
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
|
||||
cudaStream_t stream);
|
||||
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
|
||||
const int &o2, const int &o3, const float *input_addr, float *output_addr,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
|
||||
|
||||
#include <vector>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
enum BroadcastOpType {
|
||||
|
@ -31,13 +32,14 @@ enum BroadcastOpType {
|
|||
BROADCAST_TYPE_ADD = 8,
|
||||
BROADCAST_TYPE_FLOORDIV = 9,
|
||||
BROADCAST_TYPE_ABSGRAD = 10,
|
||||
BROADCAST_TYPE_DIV = 11,
|
||||
BROADCAST_TYPE_INVALID = 0xffffffff,
|
||||
};
|
||||
|
||||
template <typename T, typename S>
|
||||
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
|
||||
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
|
||||
const T *input0, const T *input1, S *output, cudaStream_t stream);
|
||||
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
|
||||
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
|
||||
S *output, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename S>
|
||||
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
|
||||
|
|
|
@ -25,10 +25,10 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
|
|||
const size_t right_y = i * 4 + 3;
|
||||
|
||||
S valid_flag = false;
|
||||
valid_flag |= !(box[left_x] >= 0.f);
|
||||
valid_flag |= !(box[left_y] >= 0.f);
|
||||
valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
|
||||
valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
|
||||
valid_flag |= !(box[left_x] >= static_cast<T>(0.0));
|
||||
valid_flag |= !(box[left_y] >= static_cast<T>(0.0));
|
||||
valid_flag |= !(img_metas[1] * img_metas[2] - static_cast<T>(1.0) >= box[right_x]);
|
||||
valid_flag |= !(img_metas[0] * img_metas[2] - static_cast<T>(1.0) >= box[right_y]);
|
||||
|
||||
valid[i] = !valid_flag;
|
||||
}
|
||||
|
@ -43,3 +43,5 @@ void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid,
|
|||
|
||||
template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -18,12 +18,85 @@
|
|||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
// Element-wise device copy of `size` values.
// NOTE: despite the parameter names, data flows from `output` INTO `input`; i.e. the first
// argument is the destination and the second argument is the source.
// Each thread starts at its global index and advances by the total number of launched threads
// (grid-stride loop), so any launch configuration covers the whole range.
__global__ void Copy(T *input, T *output, size_t size) {
  size_t stride = blockDim.x * gridDim.x;
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < size) {
    input[idx] = output[idx];
    idx += stride;
  }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
|
||||
write_index += blockDim.x * gridDim.x) {
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == 0) {
|
||||
output[read_index] = 0;
|
||||
} else {
|
||||
size_t read_index2 = (j - 1) * stride2 + offset;
|
||||
output[read_index] = input[read_index2];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shifts `input` by one position towards lower indices along the scanned axis and writes the
// result to `output`:
//   output[dim1 - 1] = 0
//   output[j]        = input[j + 1]   for j in [0, dim1 - 2]
// Addressing: element (i, j, k) lives at i * stride + j * stride2 + k, with i in [0, dim0),
// j in [0, dim1) and k in [0, dim2). One logical work item handles one (i, k) lane; work items
// are distributed over threads with a grid-stride loop.
template <typename T>
__global__ void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                          size_t stride2) {
  // Nothing to shift along an empty axis; also keeps `dim1 - 1` below from wrapping around.
  if (dim1 == 0) {
    return;
  }
  const size_t num = dim0 * dim2;
  const size_t step = blockDim.x * gridDim.x;
  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
    const size_t i = write_index / dim2 % dim0;
    const size_t k = write_index % dim2;
    const size_t offset = i * stride + k;
    // The last slot has no right-hand neighbour, so it receives the additive identity.
    output[(dim1 - 1) * stride2 + offset] = 0;
    // Walk from high to low indices (same write order as before) using an unsigned counter, so
    // no size_t -> int narrowing or signed/unsigned comparison is involved.
    for (size_t j = dim1 - 1; j > 0; --j) {
      output[(j - 1) * stride2 + offset] = input[j * stride2 + offset];
    }
  }
}
|
||||
// Inclusive cumulative sum along the middle axis, accumulated from the highest index down:
//   output[dim1 - 1] = input[dim1 - 1]
//   output[j]        = output[j + 1] + input[j]   for j = dim1 - 2 ... 0
// Addressing: element (i, j, k) lives at i * stride + j * stride2 + k, with i in [0, dim0),
// j in [0, dim1) and k in [0, dim2). One logical work item handles one (i, k) lane; work items
// are distributed over threads with a grid-stride loop.
template <typename T>
__global__ void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                                    size_t stride2) {
  // Nothing to accumulate along an empty axis; also keeps `dim1 - 1` below from wrapping around.
  if (dim1 == 0) {
    return;
  }
  const size_t num = dim0 * dim2;
  const size_t step = blockDim.x * gridDim.x;
  for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
    const size_t i = write_index / dim2 % dim0;
    const size_t k = write_index % dim2;
    const size_t offset = i * stride + k;
    // Seed the scan with the last element of the lane.
    size_t next = (dim1 - 1) * stride2 + offset;
    output[next] = input[next];
    // Unsigned counter: avoids the size_t -> int narrowing and the signed/unsigned comparison.
    for (size_t j = dim1 - 1; j > 0; --j) {
      const size_t cur = (j - 1) * stride2 + offset;
      // The partial sum is read back from `output`, exactly as the previous iteration stored it.
      output[cur] = output[next] + input[cur];
      next = cur;
    }
  }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
|
@ -39,12 +112,32 @@ __global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size
|
|||
}
|
||||
}
|
||||
template <typename T>
|
||||
void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
|
||||
cudaStream_t stream) {
|
||||
void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) {
|
||||
int size = dim0 * dim2;
|
||||
if (exclusive_) {
|
||||
if (reverse_) {
|
||||
RightMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
|
||||
CumSumKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride,
|
||||
stride2);
|
||||
} else {
|
||||
LeftMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
|
||||
CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride, stride2);
|
||||
}
|
||||
} else {
|
||||
if (reverse_) {
|
||||
CumSumKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride,
|
||||
stride2);
|
||||
} else {
|
||||
CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template void CumSum<float>(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, cudaStream_t stream);
|
||||
template void CumSum<float>(const float *input, float *output, float *workspace, size_t dim0, size_t dim1, size_t dim2,
|
||||
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
||||
template void CumSum<half>(const half *input, half *output, half *workspace, size_t dim0, size_t dim1, size_t dim2,
|
||||
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
||||
|
|
|
@ -17,6 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
template <typename T>
|
||||
void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
|
||||
cudaStream_t stream);
|
||||
void CumSum(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
|
|
|
@ -16,27 +16,26 @@
|
|||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__device__ T CoordinateMax(const T a, const T b) {
|
||||
__device__ float CoordinateMax(const float a, const float b) {
|
||||
return (a > b ? a : b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T CoordinateMin(const T a, const T b) {
|
||||
__device__ float CoordinateMin(const float a, const float b) {
|
||||
return (a < b ? a : b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
|
||||
const size_t input_len_0) {
|
||||
T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
|
||||
T overlaps_coordinate[IOU_DIMENSION];
|
||||
const T epsilon = 1e-10;
|
||||
float location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
|
||||
float overlaps_coordinate[IOU_DIMENSION];
|
||||
const float epsilon = 1e-10;
|
||||
const float offset = 1.0;
|
||||
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
for (size_t j = 0; j < IOU_DIMENSION; j++) {
|
||||
location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
|
||||
location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
|
||||
location_coordinate[0][j] = static_cast<float>(box1[(i % input_len_0) * IOU_DIMENSION + j]);
|
||||
location_coordinate[1][j] = static_cast<float>(box2[(i / input_len_0) * IOU_DIMENSION + j]);
|
||||
}
|
||||
|
||||
overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
|
||||
|
@ -44,18 +43,18 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
|
|||
overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
|
||||
overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
|
||||
|
||||
T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
|
||||
T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
|
||||
T overlaps = overlaps_w * overlaps_h;
|
||||
float overlaps_w = CoordinateMax(0.0, overlaps_coordinate[2] - overlaps_coordinate[0] + offset);
|
||||
float overlaps_h = CoordinateMax(0.0, overlaps_coordinate[3] - overlaps_coordinate[1] + offset);
|
||||
float overlaps = overlaps_w * overlaps_h;
|
||||
|
||||
T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
|
||||
location_coordinate[0][1] + 1);
|
||||
T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
|
||||
location_coordinate[1][1] + 1);
|
||||
float area1 = (location_coordinate[0][2] - location_coordinate[0][0] + offset) * (location_coordinate[0][3] -
|
||||
location_coordinate[0][1] + offset);
|
||||
float area2 = (location_coordinate[1][2] - location_coordinate[1][0] + offset) * (location_coordinate[1][3] -
|
||||
location_coordinate[1][1] + offset);
|
||||
if (mode == 0) {
|
||||
iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
|
||||
iou_results[i] = static_cast<T>(overlaps / (area1 + area2 - overlaps + epsilon));
|
||||
} else {
|
||||
iou_results[i] = overlaps / (area2 + epsilon);
|
||||
iou_results[i] = static_cast<T>(overlaps / (area2 + epsilon));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,3 +69,5 @@ void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const
|
|||
|
||||
template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
|
||||
const size_t &input_len_0, cudaStream_t cuda_stream);
|
||||
template void IOU(const size_t &size, const half *box1, const half *box2, half *iou_results, const size_t &mode,
|
||||
const size_t &input_len_0, cudaStream_t cuda_stream);
|
||||
|
|
|
@ -84,6 +84,40 @@ class GpuKernel : public KernelMod {
|
|||
}
|
||||
}
|
||||
|
||||
// set the tensor descriptor for cudnn/cublas
|
||||
void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
|
||||
cudnnDataType_t data_type) {
|
||||
if (shape.size() < 3) {
|
||||
MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor don't support" << shape.size() << "D.";
|
||||
}
|
||||
const int nbDims = shape.size();
|
||||
int *dim = new (std::nothrow) int[nbDims];
|
||||
if (dim == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "malloc dim failed.";
|
||||
}
|
||||
int *stride = new (std::nothrow) int[nbDims];
|
||||
if (stride == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "malloc stride failed.";
|
||||
}
|
||||
|
||||
for (int i = 0; i < nbDims; i++) {
|
||||
dim[i] = SizeToInt(shape[i]);
|
||||
stride[i] = 1;
|
||||
}
|
||||
|
||||
for (int i = nbDims - 2; i >= 0; i--) {
|
||||
stride[i] = stride[i + 1] * SizeToInt(shape[i + 1]);
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(descriptor, data_type, nbDims, dim, stride),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
|
||||
delete[] dim;
|
||||
dim = nullptr;
|
||||
delete[] stride;
|
||||
stride = nullptr;
|
||||
}
|
||||
|
||||
// choose the suitable datatype for cudnn/cublas
|
||||
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
|
||||
auto type = kCudnnDtypeMap.find(Type);
|
||||
|
|
|
@ -59,6 +59,9 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
AbsGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastOpGpuKernel, float, float)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastOpGpuKernel, float, float)
|
||||
|
||||
// fp16
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
|
@ -101,6 +104,9 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
AbsGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BroadcastOpGpuKernel, half, half)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BroadcastOpGpuKernel, half, half)
|
||||
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
|
@ -118,14 +124,14 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
MS_REG_GPU_KERNEL_TWO(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastOpGpuKernel, int, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 7;
|
||||
template <typename T, typename S>
|
||||
class BroadcastOpGpuKernel : public GpuKernel {
|
||||
public:
|
||||
|
@ -45,9 +46,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
S *output = GetDeviceAddress<S>(outputs, 0);
|
||||
|
||||
if (need_broadcast_) {
|
||||
Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
|
||||
rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
|
||||
rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
|
@ -60,10 +60,13 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
need_broadcast_ = IsBroadcast(shape1, shape2);
|
||||
if (need_broadcast_ && shape1.size() > 4) {
|
||||
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
|
||||
if (need_broadcast_ && shape1.size() > 7) {
|
||||
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
|
||||
}
|
||||
|
||||
lhs_shape_.resize(MAX_DIMS, 1);
|
||||
rhs_shape_.resize(MAX_DIMS, 1);
|
||||
output_shape_.resize(MAX_DIMS, 1);
|
||||
for (size_t i = 0; i < shape3.size(); i++) {
|
||||
output_shape_[i] = shape3[i];
|
||||
output_num_ *= shape3[i];
|
||||
|
@ -99,7 +102,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
|
||||
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
|
||||
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
|
||||
{"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
|
||||
{"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, {"Div", BROADCAST_TYPE_DIV},
|
||||
};
|
||||
|
||||
auto iter = kBroadcastTypeMap.find(kernel_name);
|
||||
|
@ -127,9 +130,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
|
|||
int input1_num_;
|
||||
int input2_num_;
|
||||
int output_num_;
|
||||
int lhs_shape_[4] = {1, 1, 1, 1};
|
||||
int rhs_shape_[4] = {1, 1, 1, 1};
|
||||
int output_shape_[4] = {1, 1, 1, 1};
|
||||
std::vector<int> lhs_shape_;
|
||||
std::vector<int> rhs_shape_;
|
||||
std::vector<int> output_shape_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
|
|
@ -20,5 +20,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CumSumGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CumSumGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class CumSumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
|
||||
CumSumGpuKernel() : exclusive_(false), reverse_(false), axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
|
||||
~CumSumGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -38,7 +38,8 @@ class CumSumGpuKernel : public GpuKernel {
|
|||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_,
|
||||
T *ws_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
CumSum(input_addr, output_addr, ws_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, exclusive_, reverse_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
@ -51,6 +52,8 @@ class CumSumGpuKernel : public GpuKernel {
|
|||
input_size_0_ = sizeof(T);
|
||||
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
axis_ = GetAttr<int>(kernel_node, "axis");
|
||||
exclusive_ = GetAttr<bool>(kernel_node, "exclusive");
|
||||
reverse_ = GetAttr<bool>(kernel_node, "reverse");
|
||||
int input_dim_length = SizeToInt(shape_.size());
|
||||
if (axis_ >= input_dim_length) {
|
||||
MS_LOG(EXCEPTION) << "Axis out of bounds.";
|
||||
|
@ -70,6 +73,7 @@ class CumSumGpuKernel : public GpuKernel {
|
|||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_0_);
|
||||
output_size_list_.push_back(input_size_0_);
|
||||
workspace_size_list_.push_back(input_size_0_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -87,6 +91,8 @@ class CumSumGpuKernel : public GpuKernel {
|
|||
stride2_ = dims_[2];
|
||||
return;
|
||||
}
|
||||
bool exclusive_;
|
||||
bool reverse_;
|
||||
int axis_;
|
||||
size_t input_size_0_;
|
||||
size_t stride_;
|
||||
|
|
|
@ -83,12 +83,19 @@ class ActivationGpuFwdKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
std::vector<int> shape;
|
||||
ShapeNdTo4d(input_shape, &shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
|
||||
"cudnnSetActivationDescriptor failed");
|
||||
|
||||
const int split_dim = 4;
|
||||
if (input_shape.size() <= split_dim) {
|
||||
ShapeNdTo4d(input_shape, &shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
|
||||
shape[0], shape[1], shape[2], shape[3]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -90,12 +90,18 @@ class ActivationGradGpuKernel : public GpuKernel {
|
|||
return true;
|
||||
}
|
||||
std::vector<int> shape;
|
||||
ShapeNdTo4d(input_shape, &shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
|
||||
"SetActivationDescriptor failed");
|
||||
|
||||
const int split_dim = 4;
|
||||
if (input_shape.size() <= split_dim) {
|
||||
ShapeNdTo4d(input_shape, &shape);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
|
||||
shape[0], shape[1], shape[2], shape[3]),
|
||||
"SetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
|
|
|
@ -54,12 +54,18 @@ class DropoutGpuFwdKernel : public GpuKernel {
|
|||
float *mask_f = GetDeviceAddress<float>(workspace, 0);
|
||||
|
||||
if (!states_init_) {
|
||||
curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
|
||||
curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT),
|
||||
"Failed to create generator");
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)),
|
||||
"Failed to SetPseudoRandomGeneratorSeed");
|
||||
MS_EXCEPTION_IF_NULL(mask_generator_);
|
||||
states_init_ = true;
|
||||
}
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to set stream for generator");
|
||||
// curandGen only support float or double for mask.
|
||||
curandGenerateUniform(mask_generator_, mask_f, num_count_);
|
||||
CHECK_CURAND_RET_WITH_EXCEPT(curandGenerateUniform(mask_generator_, mask_f, num_count_),
|
||||
"Failed to generate uniform");
|
||||
DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
|
|
|
@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
CheckValid,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidGpuKernel, float, bool)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
CheckValid,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidGpuKernel, half, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,5 +21,8 @@ namespace kernel {
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
IOUGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
IOUGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -80,6 +80,7 @@
|
|||
#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
|
||||
|
@ -124,6 +125,10 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond1Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond2Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond3Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneAssignCond4Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond1Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond2Fusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneCond3Fusion>());
|
||||
|
@ -308,6 +313,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
}
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
|
|
|
@ -27,15 +27,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(relu_input);
|
||||
auto getitem = relu_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
auto getitem = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
auto bnupdate = getitem->input(1);
|
||||
MS_EXCEPTION_IF_NULL(bnupdate);
|
||||
|
@ -68,10 +68,11 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
|
||||
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
MatchBnupdateDoubleOutputEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,8 +39,8 @@ class BnupdateEltwiseFusionPass : public FusionBasePass {
|
|||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
private:
|
||||
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
void MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@ const int8_t MAX_ELTWISE_NUM = 3;
|
|||
const int8_t MIN_ELTWISE_SIZE = 2;
|
||||
const int8_t ELTWISE_INPUT_SIZE = 2;
|
||||
const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
|
||||
const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
|
||||
const int8_t ELTWISE_USE = 1;
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/kernel_fusion.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNodePtr &cnode,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto matmul = cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
if (matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul)) {
|
||||
std::vector<int> output_used_num{SizeToInt(manager->node_users()[matmul].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, matmul};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
|
||||
// Walks the graph in topological order (starting from its return node) and
// tries the MatMul + ConfusionTransposeD match on every eligible node.
//
// A node is eligible when it is a real CNode kernel, has no fusion id yet and
// is not the Return primitive. Matches are appended to `candidate_fusion`.
void MatmulConfusionTranposeFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
                                                                 FusedNodeRecord *candidate_fusion) {
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  std::vector<AnfNodePtr> sorted_nodes = TopoSort(kernel_graph.get_return());
  for (auto &candidate : sorted_nodes) {
    // Same checks, in the same order, as the skip condition they replace.
    const bool eligible = AnfAlgo::IsRealCNodeKernel(candidate) &&
                          !fusion_id_allocator->HasFusionIdAttr(candidate) &&
                          !AnfAlgo::CheckPrimitiveType(candidate, prim::kPrimReturn);
    if (!eligible) {
      continue;
    }

    auto cnode = candidate->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);

    const bool is_confusion_transpose = (AnfAlgo::GetCNodeName(cnode) == kConfusionTransposeDOpName);
    if (is_confusion_transpose) {
      MatchMatmulConfusionTranpose(cnode, kernel_graph, candidate_fusion);
    }
  }
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_
|
||||
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||
|
||||
class MatmulConfusionTranposeFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit MatmulConfusionTranposeFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {}
|
||||
~MatmulConfusionTranposeFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
private:
|
||||
void MatchMatmulConfusionTranpose(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_MATMUL_CONFUSIONTRANSPOSE_FUSION_PASS_H_
|
|
@ -172,7 +172,6 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|||
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
||||
<< (*alternative_kernel_info)->ToString();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
|
||||
ChangeNodeInferInfo(next_cnode, node, cast_index);
|
||||
if (node->inputs().size() < kCastInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
|
||||
}
|
||||
|
|
|
@ -15,30 +15,9 @@
|
|||
*/
|
||||
#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto prim = std::make_shared<Primitive>(kAdamApplyOneOpName);
|
||||
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};
|
||||
for (const auto &input_var : input_vars_) {
|
||||
auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
new_node_inputs.push_back(input_node);
|
||||
}
|
||||
for (const auto &mul_x_input_var : mul_x_input_vars_) {
|
||||
auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]);
|
||||
MS_EXCEPTION_IF_NULL(mul_x_input_node);
|
||||
new_node_inputs.push_back(mul_x_input_node);
|
||||
}
|
||||
auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
|
||||
MS_EXCEPTION_IF_NULL(add2_y_node);
|
||||
new_node_inputs.push_back(add2_y_node);
|
||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
const BaseRef AdamApplyOneFusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
|
@ -104,16 +83,152 @@ const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
|
|||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||
}
|
||||
|
||||
// Pattern for the AdamApplyOne sub-graph followed by three Assign nodes that
// are chained through Depend nodes.
//
// Operand order for this variant:
//   mul3     = Mul(mul_x[3], Square(input[0]))
//   add2     = TensorAdd(sqrt0, add2_y)
//   sub0 rhs = Mul(input[4], true_div0)
const BaseRef AdamApplyOneAssignFusion::DefinePattern() const {
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add1 = mul2 + mul3, then sqrt0 = Sqrt(add1).
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef square0 = VectorRef({prim::kPrimSquare, input_vars_[0]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], square0});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef sqrt0 = VectorRef({sqrt_prim, add1});

  // add0 = mul0 + mul1.
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});

  // sub0 = input[3] - input[4] * (add0 / (sqrt0 + add2_y)).
  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef true_div0 = VectorRef({real_div_prim, add0, add2});
  VectorRef mul4 = VectorRef({prim::kPrimMul, input_vars_[4], true_div0});
  VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], mul4});

  // Assign results back to input[3], input[2], input[1]; Depend nodes chain them.
  VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
  VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
  VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
  return VectorRef({prim::kPrimDepend, depend1, assign2});
}
|
||||
|
||||
// Variant of the AdamApplyOne + Assign pattern.
//
// Operand order for this variant:
//   mul3     = Mul(mul_x[3], Square(input[0]))
//   add2     = TensorAdd(add2_y, sqrt0)
//   sub0 rhs = Mul(input[4], true_div0)
const BaseRef AdamApplyOneAssignCond1Fusion::DefinePattern() const {
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add1 = mul2 + mul3, then sqrt0 = Sqrt(add1).
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef square0 = VectorRef({prim::kPrimSquare, input_vars_[0]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], square0});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef sqrt0 = VectorRef({sqrt_prim, add1});

  // add0 = mul0 + mul1.
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});

  // sub0 = input[3] - input[4] * (add0 / (add2_y + sqrt0)).
  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0});
  VectorRef true_div0 = VectorRef({real_div_prim, add0, add2});
  VectorRef mul4 = VectorRef({prim::kPrimMul, input_vars_[4], true_div0});
  VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], mul4});

  // Assign results back to input[3], input[2], input[1]; Depend nodes chain them.
  VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
  VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
  VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
  return VectorRef({prim::kPrimDepend, depend1, assign2});
}
|
||||
|
||||
// Variant of the AdamApplyOne + Assign pattern.
//
// Operand order for this variant:
//   mul3     = Mul(Square(input[0]), mul_x[3])
//   add2     = TensorAdd(sqrt0, add2_y)
//   sub0 rhs = Mul(true_div0, input[4])
const BaseRef AdamApplyOneAssignCond2Fusion::DefinePattern() const {
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add1 = mul2 + mul3, then sqrt0 = Sqrt(add1).
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef square0 = VectorRef({prim::kPrimSquare, input_vars_[0]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, square0, mul_x_input_vars_[3]});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef sqrt0 = VectorRef({sqrt_prim, add1});

  // add0 = mul0 + mul1.
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});

  // sub0 = input[3] - (add0 / (sqrt0 + add2_y)) * input[4].
  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef true_div0 = VectorRef({real_div_prim, add0, add2});
  VectorRef mul4 = VectorRef({prim::kPrimMul, true_div0, input_vars_[4]});
  VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], mul4});

  // Assign results back to input[3], input[2], input[1]; Depend nodes chain them.
  VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
  VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
  VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
  return VectorRef({prim::kPrimDepend, depend1, assign2});
}
|
||||
|
||||
// Variant of the AdamApplyOne + Assign pattern.
//
// Operand order for this variant:
//   mul3     = Mul(mul_x[3], Square(input[0]))
//   add2     = TensorAdd(sqrt0, add2_y)
//   sub0 rhs = Mul(true_div0, input[4])
const BaseRef AdamApplyOneAssignCond3Fusion::DefinePattern() const {
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add1 = mul2 + mul3, then sqrt0 = Sqrt(add1).
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef square0 = VectorRef({prim::kPrimSquare, input_vars_[0]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], square0});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef sqrt0 = VectorRef({sqrt_prim, add1});

  // add0 = mul0 + mul1.
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});

  // sub0 = input[3] - (add0 / (sqrt0 + add2_y)) * input[4].
  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_});
  VectorRef true_div0 = VectorRef({real_div_prim, add0, add2});
  VectorRef mul4 = VectorRef({prim::kPrimMul, true_div0, input_vars_[4]});
  VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], mul4});

  // Assign results back to input[3], input[2], input[1]; Depend nodes chain them.
  VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
  VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
  VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
  return VectorRef({prim::kPrimDepend, depend1, assign2});
}
|
||||
|
||||
// Variant of the AdamApplyOne + Assign pattern.
//
// Operand order for this variant:
//   mul3     = Mul(mul_x[3], Square(input[0]))
//   add2     = TensorAdd(add2_y, sqrt0)
//   sub0 rhs = Mul(true_div0, input[4])
const BaseRef AdamApplyOneAssignCond4Fusion::DefinePattern() const {
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);

  // add1 = mul2 + mul3, then sqrt0 = Sqrt(add1).
  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
  VectorRef square0 = VectorRef({prim::kPrimSquare, input_vars_[0]});
  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], square0});
  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
  VectorRef sqrt0 = VectorRef({sqrt_prim, add1});

  // add0 = mul0 + mul1.
  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});

  // sub0 = input[3] - (add0 / (add2_y + sqrt0)) * input[4].
  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0});
  VectorRef true_div0 = VectorRef({real_div_prim, add0, add2});
  VectorRef mul4 = VectorRef({prim::kPrimMul, true_div0, input_vars_[4]});
  VectorRef sub0 = VectorRef({sub0_var_, input_vars_[3], mul4});

  // Assign results back to input[3], input[2], input[1]; Depend nodes chain them.
  VectorRef assign0 = VectorRef({prim::kPrimAssign, input_vars_[3], sub0});
  VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
  VectorRef assign1 = VectorRef({prim::kPrimAssign, input_vars_[2], add0});
  VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
  VectorRef assign2 = VectorRef({prim::kPrimAssign, input_vars_[1], add1});
  return VectorRef({prim::kPrimDepend, depend1, assign2});
}
|
||||
|
||||
// Builds the fused replacement CNode for a matched pattern.
//
// func_graph - graph in which the new CNode is created.
// equiv      - match result; every pattern variable is looked up in it.
// final_node - root of the matched pattern. When it is a Depend node the
//              fused op is kAdamApplyOneAssignOpName, otherwise
//              kAdamApplyOneOpName.
//
// The new node's inputs are, in order: the primitive, the nodes bound to
// input_vars_, the nodes bound to mul_x_input_vars_, and the node bound to
// add2_y_. An exception is raised if any bound node is null.
AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                      const AnfNodePtr &final_node) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);

  const bool rooted_at_depend = AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend);
  PrimitivePtr prim =
    std::make_shared<Primitive>(rooted_at_depend ? kAdamApplyOneAssignOpName : kAdamApplyOneOpName);

  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};

  // Resolves one pattern variable to its matched node and appends it.
  auto append_matched = [&equiv, &new_node_inputs](const VarPtr &var) {
    auto matched_node = utils::cast<AnfNodePtr>((*equiv)[var]);
    MS_EXCEPTION_IF_NULL(matched_node);
    new_node_inputs.push_back(matched_node);
  };

  for (const auto &input_var : input_vars_) {
    append_matched(input_var);
  }
  for (const auto &mul_x_input_var : mul_x_input_vars_) {
    append_matched(mul_x_input_var);
  }
  append_matched(add2_y_);

  return func_graph->NewCNode(new_node_inputs);
}
|
||||
|
||||
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
|
||||
auto sub0 = node;
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
|
||||
auto iter_sub0 = (*equiv).find(sub0_var_);
|
||||
if (iter_sub0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched.";
|
||||
}
|
||||
sub0 = utils::cast<AnfNodePtr>(iter_sub0->second);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(sub0);
|
||||
if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_node = CreateAdamApplyOneNode(func_graph, equiv);
|
||||
auto new_node = CreateAdamApplyOneNode(func_graph, equiv, node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_scope(node->scope());
|
||||
new_node->set_scope(sub0->scope());
|
||||
// Set abstract of new node
|
||||
AbstractBasePtrList new_node_abstract_list;
|
||||
auto iter_add0 = (*equiv).find(add0_var_);
|
||||
|
@ -130,7 +245,7 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
|
|||
MS_EXCEPTION_IF_NULL(add1);
|
||||
new_node_abstract_list.push_back(add1->abstract());
|
||||
new_node_abstract_list.push_back(add0->abstract());
|
||||
new_node_abstract_list.push_back(node->abstract());
|
||||
new_node_abstract_list.push_back(sub0->abstract());
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list);
|
||||
new_node->set_abstract(abstract_tuple);
|
||||
// Create tuple_getitem node for outputs
|
||||
|
|
|
@ -40,6 +40,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
|||
add2_y_ = std::make_shared<Var>();
|
||||
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
|
||||
sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
|
||||
}
|
||||
|
||||
~AdamApplyOneFusion() override = default;
|
||||
|
@ -47,12 +48,14 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
|||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
|
||||
AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const AnfNodePtr &final_node) const;
|
||||
std::vector<VarPtr> input_vars_;
|
||||
std::vector<VarPtr> mul_x_input_vars_;
|
||||
VarPtr add2_y_;
|
||||
VarPtr add0_var_;
|
||||
VarPtr add1_var_;
|
||||
VarPtr sub0_var_;
|
||||
};
|
||||
|
||||
class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
|
||||
|
@ -90,6 +93,51 @@ class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
|
|||
~AdamApplyOneCond4Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneAssignFusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneAssignFusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_assign_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneAssignFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneAssignCond1Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneAssignCond1Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_assign_cond1_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneAssignCond1Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneAssignCond2Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneAssignCond2Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_assign_cond2_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneAssignCond2Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneAssignCond3Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneAssignCond3Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_assign_cond3_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneAssignCond3Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneAssignCond4Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneAssignCond4Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_assign_cond4_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneAssignCond4Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
|
||||
|
|
|
@ -62,7 +62,14 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
if (!real_input->isa<Parameter>() && !real_input->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
return nullptr;
|
||||
}
|
||||
bool cnode_input_changed = false;
|
||||
|
|
|
@ -863,7 +863,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
|
|||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
|
||||
if (ms_context->enable_pynative_infer()) {
|
||||
if (ms_context->execution_mode() == kPynativeMode) {
|
||||
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
|
||||
}
|
||||
if (tensor->is_dirty()) {
|
||||
|
|
|
@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) {
|
|||
|
||||
MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
|
||||
}
|
||||
|
||||
std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) {
|
||||
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
|
||||
if (dlen < 0) {
|
||||
for (int i = 0; i < -dlen; ++i) {
|
||||
(void)shpx.insert(shpx.begin(), 1);
|
||||
}
|
||||
} else if (dlen > 0) {
|
||||
for (int i = 0; i < dlen; i++) {
|
||||
(void)shpy.insert(shpy.begin(), 1);
|
||||
}
|
||||
}
|
||||
if (shpx.size() != shpy.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
|
||||
}
|
||||
std::vector<int> shp;
|
||||
for (size_t i = 0; i < shpx.size(); i++) {
|
||||
auto a = shpx[i];
|
||||
auto b = shpy[i];
|
||||
if (a == 1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == 1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == -1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == -1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == b) {
|
||||
shp.push_back(a);
|
||||
} else {
|
||||
return std::vector<int>();
|
||||
}
|
||||
}
|
||||
return shp;
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list);
|
|||
ValuePtr BoolAnd(const ValuePtrList &list);
|
||||
ValuePtr BoolOr(const ValuePtrList &list);
|
||||
ValuePtr BoolEq(const ValuePtrList &list);
|
||||
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
|
|||
}
|
||||
|
||||
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
||||
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptrGraph->debug_info()->set_name("hyper_map");
|
||||
FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
|
||||
ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
||||
ptr_graph->debug_info()->set_name("hyper_map");
|
||||
|
||||
AnfNodePtr ptrFnArg = nullptr;
|
||||
std::size_t i = 0;
|
||||
ArgsPairList argmap;
|
||||
ArgsPairList argmap2;
|
||||
if (fn_leaf_ == nullptr) {
|
||||
ptrFnArg = ptrGraph->add_parameter();
|
||||
ptrFnArg = ptr_graph->add_parameter();
|
||||
i = 1;
|
||||
}
|
||||
|
||||
std::size_t size = args_spec_list.size();
|
||||
for (; i < size; ++i) {
|
||||
argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
|
||||
argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
|
||||
}
|
||||
|
||||
argmap2 = Harmonize(ptrGraph, argmap);
|
||||
ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
|
||||
return ptrGraph;
|
||||
argmap2 = Harmonize(ptr_graph, argmap);
|
||||
ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
|
||||
return ptr_graph;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||
|
@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
|
|||
inputs.push_back(opsTupleItem);
|
||||
inputs.push_back(cnode);
|
||||
inputs.push_back(NewValueNode(1));
|
||||
AnfNodePtr ptrBprop = ret->NewCNode(inputs);
|
||||
AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
|
||||
|
||||
doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
|
||||
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
||||
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
|
||||
ValueNodePtr opsTupleItem) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
AnfNodePtr ptrBPropArg = nullptr;
|
||||
AnfNodePtr ptr_bprop_arg = nullptr;
|
||||
if (sens_param_) {
|
||||
ptrBPropArg = func_graph->add_parameter();
|
||||
ptr_bprop_arg = func_graph->add_parameter();
|
||||
} else {
|
||||
auto ones_like = prim::GetPythonOps("ones_like");
|
||||
ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
|
||||
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
|
||||
}
|
||||
|
||||
AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
|
||||
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
|
||||
|
||||
CNodePtr fv_bprop = nullptr;
|
||||
if (get_by_list_) {
|
||||
// python code: grads = hyper_map(F.partial(env_get, env), weights)
|
||||
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
|
||||
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
|
||||
AnfNodePtr partial_env_get =
|
||||
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
|
||||
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
|
||||
|
@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
|
|||
|
||||
CNodePtr inputs_bprop = nullptr;
|
||||
if (get_all_) {
|
||||
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
|
||||
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
|
||||
}
|
||||
|
||||
// Gradients wrt inputs and parameters
|
||||
|
@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
|
|||
}
|
||||
|
||||
// Gradients wrt first input.
|
||||
// ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
||||
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
|
||||
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
||||
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
|
||||
}
|
||||
|
||||
// Generate the graph.
|
||||
|
@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
|
||||
MS_EXCEPTION_IF_NULL(real_fn);
|
||||
|
||||
FuncGraphPtr ptrGraph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(ptrGraph);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
|
||||
FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
|
||||
FuncGraphPtr ptr_graph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(ptr_graph);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
||||
FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
|
||||
TraceManager::EndTrace();
|
||||
auto nparam = ptrGraph->parameters().size();
|
||||
auto nparam = ptr_graph->parameters().size();
|
||||
|
||||
std::ostringstream ss;
|
||||
ss << "grad{" << nparam << "}";
|
||||
dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
dfBuilder->debug_info()->set_name(ss.str());
|
||||
ParameterPtr param_graph = dfBuilder->add_parameter();
|
||||
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
df_builder->debug_info()->set_name(ss.str());
|
||||
ParameterPtr param_graph = df_builder->add_parameter();
|
||||
|
||||
AnfNodePtr weights = nullptr;
|
||||
if (get_by_list_) {
|
||||
weights = dfBuilder->add_parameter();
|
||||
weights = df_builder->add_parameter();
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimJ));
|
||||
inputs.push_back(param_graph);
|
||||
auto jf = dfBuilder->NewCNode(inputs);
|
||||
auto jf = df_builder->NewCNode(inputs);
|
||||
// df is checked in GetGrad
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
|
||||
auto df = GetGrad(jf, weights, ptrGraph->parameters());
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
||||
auto df = GetGrad(jf, weights, ptr_graph->parameters());
|
||||
TraceManager::EndTrace();
|
||||
dfBuilder->set_output(NewValueNode(df));
|
||||
df_builder->set_output(NewValueNode(df));
|
||||
|
||||
return dfBuilder;
|
||||
return df_builder;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
|
||||
|
@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
|
|||
|
||||
*step_value = CheckSliceMember(slice->step(), step_default, step_name);
|
||||
if (*step_value == 0) {
|
||||
MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0.";
|
||||
MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
|
||||
}
|
||||
|
||||
if (*step_value < 0) {
|
||||
|
@ -941,7 +941,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
|
|||
*stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
|
||||
if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
|
||||
!CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
|
||||
MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
|
||||
MS_EXCEPTION(ValueError) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index
|
||||
<< " out of range, tuple size " << tuple_size << ".";
|
||||
}
|
||||
|
||||
|
|
|
@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
|
|||
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
|
||||
TypeId *arg_type = nullptr) {
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
if (is_write) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||
} else {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
auto ref = arg_value->cast<abstract::AbstractRefPtr>();
|
||||
arg_value = ref->ref();
|
||||
if (!is_write && ref->need_cast()) {
|
||||
auto tensor_type = ref->target_type();
|
||||
*arg_type_id = tensor_type->type_id();
|
||||
if (arg_type != nullptr) {
|
||||
*arg_type = kObjectTypeTensorType;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
|
@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
|
|||
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
|
||||
continue;
|
||||
}
|
||||
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
|
||||
<< " to " << it->second;
|
||||
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
|
||||
}
|
||||
}
|
||||
|
@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
|
||||
TypePtr type = args_spec_list[i]->GetTypeTrack();
|
||||
if (type && type->type_id() == kObjectTypeRef) {
|
||||
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
|
||||
if (sig == SignatureEnumRW::kRWRead) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
||||
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
||||
if (ref_abs && ref_abs->need_cast()) {
|
||||
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
|
||||
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
|
||||
}
|
||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
||||
param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
|
||||
write_indices.insert(i);
|
||||
}
|
||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
||||
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
|
||||
<< args_spec_list[i]->ToString();
|
||||
op_inputs.push_back(param);
|
||||
}
|
||||
// process default
|
||||
|
|
|
@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
|
|||
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
|
||||
}
|
||||
|
||||
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||
// No need to check, check will be done in infer.
|
||||
auto ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret_graph->debug_info()->set_name("UnpackCall");
|
||||
|
||||
AnfNodePtr fnNode = ret_graph->add_parameter();
|
||||
AnfNodePtr fn_node = ret_graph->add_parameter();
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(fnNode);
|
||||
elems.push_back(fn_node);
|
||||
for (size_t index = 1; index < arg_length; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (args_spec_list[index]->isa<AbstractTuple>()) {
|
||||
|
|
|
@ -31,160 +31,43 @@ ValuePtr GetPythonOps(const std::string &op_name,
|
|||
const std::string &module_name = "mindspore._extends.parse.standard_method",
|
||||
bool use_signature = false);
|
||||
|
||||
// Arithmetic
|
||||
inline const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
|
||||
inline const PrimitivePtr kPrimScalarSub = std::make_shared<Primitive>("scalar_sub");
|
||||
inline const PrimitivePtr kPrimScalarMul = std::make_shared<Primitive>("scalar_mul");
|
||||
inline const PrimitivePtr kPrimScalarDiv = std::make_shared<Primitive>("scalar_div");
|
||||
inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared<Primitive>("scalar_floordiv");
|
||||
inline const PrimitivePtr kPrimScalarMod = std::make_shared<Primitive>("scalar_mod");
|
||||
inline const PrimitivePtr kPrimScalarPow = std::make_shared<Primitive>("scalar_pow");
|
||||
inline const PrimitivePtr kPrimScalarTrunc = std::make_shared<Primitive>("scalar_trunc");
|
||||
inline const PrimitivePtr kPrimScalarFloor = std::make_shared<Primitive>("scalar_floor");
|
||||
inline const PrimitivePtr kPrimScalarUadd = std::make_shared<Primitive>("scalar_uadd");
|
||||
inline const PrimitivePtr kPrimScalarUsub = std::make_shared<Primitive>("scalar_usub");
|
||||
inline const PrimitivePtr kPrimScalarExp = std::make_shared<Primitive>("scalar_exp");
|
||||
inline const PrimitivePtr kPrimScalarLog = std::make_shared<Primitive>("scalar_log");
|
||||
inline const PrimitivePtr kPrimScalarSin = std::make_shared<Primitive>("scalar_sin");
|
||||
inline const PrimitivePtr kPrimScalarCos = std::make_shared<Primitive>("scalar_cos");
|
||||
inline const PrimitivePtr kPrimScalarTan = std::make_shared<Primitive>("scalar_tan");
|
||||
|
||||
// Comparisons
|
||||
inline const PrimitivePtr kPrimScalarEq = std::make_shared<Primitive>("scalar_eq");
|
||||
inline const PrimitivePtr kPrimScalarLt = std::make_shared<Primitive>("scalar_lt");
|
||||
inline const PrimitivePtr kPrimScalarGt = std::make_shared<Primitive>("scalar_gt");
|
||||
inline const PrimitivePtr kPrimScalarNe = std::make_shared<Primitive>("scalar_ne");
|
||||
inline const PrimitivePtr kPrimScalarLe = std::make_shared<Primitive>("scalar_le");
|
||||
inline const PrimitivePtr kPrimScalarGe = std::make_shared<Primitive>("scalar_ge");
|
||||
inline const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
|
||||
inline const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
|
||||
inline const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
|
||||
inline const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
|
||||
inline const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
|
||||
inline const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
|
||||
inline const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
|
||||
inline const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
|
||||
inline const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
|
||||
inline const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
|
||||
|
||||
// Primitives only used by frontend;
|
||||
// Type introspection
|
||||
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
|
||||
|
||||
inline const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute");
|
||||
inline const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot");
|
||||
inline const PrimitivePtr kPrimIm2Col = std::make_shared<Primitive>("im2col");
|
||||
inline const PrimitivePtr kPrimCol2Im = std::make_shared<Primitive>("col2im");
|
||||
inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared<Primitive>("im2col_v1");
|
||||
inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared<Primitive>("col2im_v1");
|
||||
|
||||
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
|
||||
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
|
||||
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
|
||||
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
|
||||
inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
|
||||
inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
|
||||
// Arrays
|
||||
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
|
||||
inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("array_to_scalar");
|
||||
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
|
||||
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
|
||||
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||
inline const PrimitivePtr kPrimSize = std::make_shared<Primitive>("Size");
|
||||
inline const PrimitivePtr kPrimArgMax = std::make_shared<Primitive>("Argmax");
|
||||
inline const PrimitivePtr kPrimPack = std::make_shared<Primitive>("Pack");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>("UnsortedSegmentSum");
|
||||
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
|
||||
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
|
||||
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
|
||||
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
|
||||
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
|
||||
inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
|
||||
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
|
||||
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
|
||||
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
|
||||
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
|
||||
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
|
||||
// Structures
|
||||
|
||||
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
|
||||
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
|
||||
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
|
||||
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
|
||||
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
|
||||
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
|
||||
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
|
||||
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
|
||||
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
|
||||
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
|
||||
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
|
||||
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
|
||||
|
||||
// NN
|
||||
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
|
||||
inline const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
|
||||
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
|
||||
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
|
||||
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
|
||||
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
|
||||
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
|
||||
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
|
||||
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
|
||||
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
|
||||
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
|
||||
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
|
||||
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
|
||||
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
|
||||
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
|
||||
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
|
||||
inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
|
||||
inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
|
||||
inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
||||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
|
||||
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
|
||||
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
|
||||
std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
|
||||
inline const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
|
||||
inline const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
|
||||
inline const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
|
||||
inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
|
||||
inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
|
||||
inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
|
||||
inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
|
||||
inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
|
||||
inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
|
||||
inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
||||
inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
||||
inline const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
|
||||
inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
|
||||
inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
|
||||
|
||||
// Comm ops
|
||||
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
|
||||
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
|
||||
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
|
||||
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
|
||||
|
||||
// RowTensor
|
||||
inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
|
||||
inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues");
|
||||
inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared<Primitive>("RowTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared<Primitive>("RowTensorGetDenseShape");
|
||||
|
||||
// SparseTensor
|
||||
inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
|
||||
class UnpackGraphPrimitive : public Primitive {
|
||||
public:
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -15,360 +13,266 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/operator/ops_front_infer_function.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/tensor_py.h"
|
||||
|
||||
using mindspore::tensor::TensorPy;
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
enum State {
|
||||
SAME,
|
||||
X_ONE,
|
||||
Y_ONE,
|
||||
};
|
||||
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
struct SlideInfo {
|
||||
int start;
|
||||
int step;
|
||||
int stop;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples or two lists.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
}
|
||||
|
||||
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractTuple>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractList>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
size_t keys_size = keys->size();
|
||||
if (values->size() != keys_size) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
|
||||
}
|
||||
|
||||
std::vector<AbstractAttribute> key_value;
|
||||
AbstractScalarPtr key;
|
||||
AbstractBasePtrList key_list = keys->elements();
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(keyPtr);
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and a keyword.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_input = GetValue<std::string>(key_value);
|
||||
std::string key_actual = kwarg->get_key();
|
||||
if (key_actual != key_input) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
|
||||
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
|
||||
}
|
||||
return kwarg->get_arg();
|
||||
}
|
||||
|
||||
// Builds an AbstractSlice from three inputs (start, end, step).
// Each input must be either an AbstractNone or an AbstractScalar whose built value is an Int32Imm;
// anything else raises a TypeError that reports the offending argument position.
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  CheckArgsSize(primitive->name(), args_spec_list, 3);

  const size_t arg_count = args_spec_list.size();
  for (size_t pos = 0; pos < arg_count; ++pos) {
    const AbstractBasePtr &arg = args_spec_list[pos];
    MS_EXCEPTION_IF_NULL(arg);
    if (arg->isa<AbstractScalar>()) {
      // A scalar bound is only accepted when it carries an int32 value.
      if (!dyn_cast<AbstractScalar>(arg)->BuildValue()->isa<Int32Imm>()) {
        MS_EXCEPTION(TypeError) << "MakeSlice eval " << pos
                                << " parameter is an AbstractScalar, but is not an int32 number.";
      }
    } else if (!arg->isa<AbstractNone>()) {
      MS_EXCEPTION(TypeError) << "MakeSlice eval " << pos << " parameter is neither AbstractScalar nor AbstractNone.";
    }
  }

  // Argument order is start, end, step.
  return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
}
|
||||
|
||||
// Eval the return type of make_record
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||
if (args_spec_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||
<< args_spec_list.size() << ".";
|
||||
}
|
||||
|
||||
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||
int arg1 = 0;
|
||||
int arg2 = 0;
|
||||
if (!args_spec_list.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
auto arg_value = args_spec_list[0]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg1 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
if (args_spec_list.size() >= 2) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto arg_value = args_spec_list[1]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg2 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
auto cls = dyn_cast<Class>(type_ptr);
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
ClassAttrVector attributes = cls->GetAttributes();
|
||||
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||
|
||||
std::vector<AbstractAttribute> abs_attributes;
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||
abs_attributes.push_back(elem);
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list and a scalar whose value is an int32 number.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
|
||||
// and continue
|
||||
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
|
||||
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
|
||||
}
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
std::size_t nelems = queue->elements().size();
|
||||
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
|
||||
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
|
||||
}
|
||||
|
||||
std::size_t uidx_v = 0;
|
||||
if (idx_v >= 0) {
|
||||
uidx_v = IntToSize(idx_v);
|
||||
} else {
|
||||
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
|
||||
}
|
||||
return queue->elements()[uidx_v];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
if (idx_v < 0) {
|
||||
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||
<< ".";
|
||||
}
|
||||
|
||||
size_t uidx_v = IntToSize(idx_v);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
std::size_t nelems = elements.size();
|
||||
if (uidx_v >= nelems) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||
<< ".";
|
||||
}
|
||||
elements[uidx_v] = args_spec_list[2];
|
||||
return std::make_shared<T>(elements);
|
||||
}
|
||||
|
||||
// Element read for AbstractTuple; forwards to InferTupleOrListGetItem with the primitive's name.
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferTupleOrListGetItem<AbstractTuple>(prim_name, args_spec_list);
}
|
||||
|
||||
// Element read for AbstractList; forwards to InferTupleOrListGetItem with the primitive's name.
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferTupleOrListGetItem<AbstractList>(prim_name, args_spec_list);
}
|
||||
|
||||
// Element replacement for AbstractTuple; forwards to InferTupleOrListSetItem with the primitive's name.
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferTupleOrListSetItem<AbstractTuple>(prim_name, args_spec_list);
}
|
||||
|
||||
// Element replacement for AbstractList; forwards to InferTupleOrListSetItem with the primitive's name.
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferTupleOrListSetItem<AbstractList>(prim_name, args_spec_list);
}
|
||||
|
||||
// Looks up a string key in an AbstractDictionary.
//   args_spec_list[0]: the dictionary; args_spec_list[1]: an AbstractScalar whose built value must be
//   a StringImm.
// Returns the abstract value stored under the first entry whose key equals the string.
// Logs an exception when the key is not a string or no entry matches.
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  CheckArgsSize(prim_name, args_spec_list, 2);
  AbstractDictionaryPtr dict_abs = CheckArg<AbstractDictionary>(prim_name, args_spec_list, 0);
  AbstractScalarPtr key_abs = CheckArg<AbstractScalar>(prim_name, args_spec_list, 1);

  ValuePtr key_val = key_abs->BuildValue();
  if (!key_val->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << prim_name << " evaluator key should be string, but got " << key_val->ToString();
  }
  const std::string wanted = GetValue<std::string>(key_val);

  // Linear scan over a local copy of the entries; stops at the first matching key.
  std::vector<AbstractAttribute> entries = dict_abs->elements();
  auto hit = entries.begin();
  while (hit != entries.end() && hit->first != wanted) {
    ++hit;
  }

  if (hit == entries.end()) {
    MS_LOG(EXCEPTION) << "The key " << wanted << " does not exist in the dict:" << args_spec_list[0]->ToString();
  }
  return hit->second;
}
|
||||
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
if (args_spec_list.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
|
||||
if (it != dict_elems.end()) {
|
||||
int index = it - dict_elems.begin();
|
||||
dict_elems[IntToSize(index)] = new_ele;
|
||||
} else {
|
||||
dict_elems.push_back(new_ele);
|
||||
auto arg_value = args_spec_list[2]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||
slide->step = GetValue<int>(arg_value);
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
if (args_spec_list.size() == 2) {
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 1) {
|
||||
slide->stop = arg1;
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
|
||||
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
|
||||
const size_t n = reverse_x.size();
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
State curr;
|
||||
const int32_t x_i = reverse_x[i];
|
||||
const int32_t y_i = reverse_y[i];
|
||||
const int reduce_idx = SizeToInt(n - 1 - i);
|
||||
if (x_i == y_i) {
|
||||
curr = SAME;
|
||||
} else if (x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
curr = X_ONE;
|
||||
} else if (y_i == 1) {
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
curr = Y_ONE;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs";
|
||||
}
|
||||
if (curr == SAME && x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
|
||||
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
|
||||
}
|
||||
|
||||
// Computes the reduction-index tuples for two shapes.
//   x_shape / y_shape: shape elements; each must be a non-null Int32Imm, otherwise an exception is
//   raised instead of dereferencing a failed cast.
// Both shapes are reversed, the shorter one is padded with 1s to the longer one's length, and
// ComputeReduceIndex fills the per-side axis lists.
// Returns an AbstractTuple holding two AbstractTuples: the x indices, then the y indices.
AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
  // Reads one dimension; cast<Int32ImmPtr>() yields null for any other value kind, so check before use.
  auto to_dim = [](const ValuePtr &v) -> int {
    MS_EXCEPTION_IF_NULL(v);
    auto imm = v->cast<Int32ImmPtr>();
    MS_EXCEPTION_IF_NULL(imm);
    return imm->value();
  };

  std::vector<int> reverse_x;
  std::vector<int> reverse_y;
  (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), to_dim);
  (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), to_dim);

  // Pad the shorter reversed shape with 1s so both have the same length.
  if (reverse_x.size() > reverse_y.size()) {
    reverse_y.resize(reverse_x.size(), 1);
  } else {
    reverse_x.resize(reverse_y.size(), 1);
  }

  std::vector<int> grad_x_reduce_idx;
  std::vector<int> grad_y_reduce_idy;
  ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);

  // Wrap every axis number as an abstract value.
  AbstractBasePtrList abs_list_x;
  AbstractBasePtrList abs_list_y;
  (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
                       [](int v) { return abstract::FromValue(v); });
  (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
                       [](int v) { return abstract::FromValue(v); });

  AbstractBasePtrList elem_list;
  elem_list.push_back(std::make_shared<AbstractTuple>(abs_list_x));
  elem_list.push_back(std::make_shared<AbstractTuple>(abs_list_y));
  return std::make_shared<AbstractTuple>(elem_list);
}
|
||||
|
||||
// Wraps the type of its single argument in an AbstractType.
//   args_spec_list[0]: any non-null abstract value; its BuildType() result becomes the payload.
// Logs an exception when the argument count is not exactly one.
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list) {
  const size_t arg_count = args_spec_list.size();
  if (arg_count != 1) {
    MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << arg_count << ".";
  }

  const AbstractBasePtr subject = args_spec_list[0];
  MS_EXCEPTION_IF_NULL(subject);
  return std::make_shared<AbstractType>(subject->BuildType());
}
|
||||
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a pointer to an AbstractBase object and a pointer to a Type
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||
(void)AbstractJoin(list->elements());
|
||||
return list;
|
||||
AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
|
||||
|
||||
auto mode_v = abs_type->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(mode_v);
|
||||
if (!mode_v->isa<Type>()) {
|
||||
MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
|
||||
TypePtr mode_t = mode_v->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
bool v = IsSubtype(args_spec_list[0], mode_t);
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
|
||||
if (x_shape.size() != y_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes the abstract shape tuple that results from reducing over the given axes.
//   x_shape: abstract tuple whose size is taken as the rank.
//   x_shp_value: value expected to be a ValueTuple of Int32Imm dimensions (at least `rank` entries).
//   axis_value_ptr: tuple of axes; each is validated by CheckAxis against [-rank, rank - 1].
//   primitive: supplies the name passed to CheckAxis.
// With no axes, returns a tuple of `rank` scalars of value 1. Otherwise every listed axis (given as
// a positive or negative index) becomes 1 and the remaining dimensions are copied from x_shp_value.
// Null pointers and failed casts raise an exception at the point of use instead of being dereferenced.
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
                                   const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(axis_value_ptr);
  size_t x_rank = x_shape->size();
  std::set<int> axis_set;
  auto axis_data = axis_value_ptr->value();
  if (axis_data.empty()) {
    int size = 1;
    AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
    return std::make_shared<AbstractTuple>(values);
  }

  MS_EXCEPTION_IF_NULL(primitive);
  for (auto &elem : axis_data) {
    int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
    (void)axis_set.insert(e_value);
  }

  // cast<>() yields null when the value is not a ValueTuple; check before reading its elements.
  MS_EXCEPTION_IF_NULL(x_shp_value);
  auto x_shp_tuple = x_shp_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(x_shp_tuple);
  auto x_shp_data = x_shp_tuple->value();
  if (x_shp_data.size() < x_rank) {
    MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  }

  AbstractBasePtrList values;
  for (size_t i = 0; i < x_rank; i++) {
    // An axis may have been given either as i or as its negative counterpart i - rank.
    if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
      auto axis_v = MakeValue(1);
      values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
    } else {
      MS_EXCEPTION_IF_NULL(x_shp_data[i]);
      auto dim_imm = x_shp_data[i]->cast<Int32ImmPtr>();
      MS_EXCEPTION_IF_NULL(dim_imm);
      int dim_value = dim_imm->value();
      auto dim = MakeValue(dim_value);
      values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
    }
  }

  return std::make_shared<AbstractTuple>(values);
}
|
||||
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
// this primitive get the index that need to reduce
|
||||
// input: x's shape and y's shape, inputs should be tuple
|
||||
// output: tuple of x and y 's reduce index, reduce index should be a tuple
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_x_value);
|
||||
|
||||
ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_y_value);
|
||||
|
||||
const std::vector<ValuePtr> x_shape = arg_x_value->value();
|
||||
const std::vector<ValuePtr> y_shape = arg_y_value->value();
|
||||
bool is_same_shape = CompareShape(x_shape, y_shape);
|
||||
// if it is the same shape , do not need reduce , return empty tuple
|
||||
if (is_same_shape) {
|
||||
AbstractBasePtrList empty_list;
|
||||
auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
|
||||
AbstractBasePtrList elem_list;
|
||||
elem_list.push_back(x_reduce_idx);
|
||||
elem_list.push_back(y_reduce_idx);
|
||||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
|
||||
return BroadcastGradientArgsDiff(x_shape, y_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
|
@ -430,41 +334,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
|
|||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
// Computes the abstract shape tuple that results from reducing over the given axes.
//   x_shape: abstract tuple whose size is taken as the rank.
//   x_shp_value: value expected to be a ValueTuple of Int32Imm dimensions (at least `rank` entries).
//   axis_value_ptr: tuple of axes; each is validated by CheckAxis against [-rank, rank - 1].
//   primitive: supplies the name passed to CheckAxis.
// With no axes, returns a tuple of `rank` scalars of value 1. Otherwise every listed axis (given as
// a positive or negative index) becomes 1 and the remaining dimensions are copied from x_shp_value.
// Null pointers and failed casts raise an exception at the point of use instead of being dereferenced.
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
                                   const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(axis_value_ptr);
  size_t x_rank = x_shape->size();
  std::set<int> axis_set;
  auto axis_data = axis_value_ptr->value();
  if (axis_data.empty()) {
    int size = 1;
    AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
    return std::make_shared<AbstractTuple>(values);
  }

  MS_EXCEPTION_IF_NULL(primitive);
  for (auto &elem : axis_data) {
    int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
    (void)axis_set.insert(e_value);
  }

  // cast<>() yields null when the value is not a ValueTuple; check before reading its elements.
  MS_EXCEPTION_IF_NULL(x_shp_value);
  auto x_shp_tuple = x_shp_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(x_shp_tuple);
  auto x_shp_data = x_shp_tuple->value();
  if (x_shp_data.size() < x_rank) {
    MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
  }

  AbstractBasePtrList values;
  for (size_t i = 0; i < x_rank; i++) {
    // An axis may have been given either as i or as its negative counterpart i - rank.
    if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
      auto axis_v = MakeValue(1);
      values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
    } else {
      MS_EXCEPTION_IF_NULL(x_shp_data[i]);
      auto dim_imm = x_shp_data[i]->cast<Int32ImmPtr>();
      MS_EXCEPTION_IF_NULL(dim_imm);
      int dim_value = dim_imm->value();
      auto dim = MakeValue(dim_value);
      values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
    }
  }

  return std::make_shared<AbstractTuple>(values);
}
|
||||
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: x_shape, axis
|
||||
|
@ -563,7 +432,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
|
|||
|
||||
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
|
||||
py::array data = py::array(data_tuple);
|
||||
auto tensor = TensorPy::MakeTensor(data);
|
||||
auto tensor = tensor::TensorPy::MakeTensor(data);
|
||||
auto ret = tensor->ToAbstract();
|
||||
ret->set_value(tensor);
|
||||
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
|
||||
|
@ -596,76 +465,6 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<AbstractScalar>(result_v, result_v->type());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples or two lists.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
}
|
||||
|
||||
// Equality of two AbstractTuple inputs; forwards to InferImplTupleOrListEqual with the primitive's name.
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferImplTupleOrListEqual<AbstractTuple>(prim_name, args_spec_list);
}
|
||||
|
||||
// Equality of two AbstractList inputs; forwards to InferImplTupleOrListEqual with the primitive's name.
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferImplTupleOrListEqual<AbstractList>(prim_name, args_spec_list);
}
|
||||
|
||||
// Three integers filled in by CalcSlidePara from its int32 arguments.
// Members carry no default initializers; CalcSlidePara writes only some of them depending on the
// argument count, so the remaining ones keep whatever the caller stored beforehand.
struct SlideInfo {
  int start;  // Written when two or three arguments are supplied.
  int step;   // Written only when three arguments are supplied.
  int stop;   // Written when one, two or three arguments are supplied.
};
|
||||
|
||||
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||
int arg1 = 0;
|
||||
int arg2 = 0;
|
||||
if (!args_spec_list.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto arg_value = args_spec_list[0]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg1 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() >= 2) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto arg_value = args_spec_list[1]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg2 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto arg_value = args_spec_list[2]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
slide->step = GetValue<int>(arg_value);
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 2) {
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 1) {
|
||||
slide->stop = arg1;
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
if (args_spec_list.empty()) {
|
||||
|
@ -709,5 +508,145 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
|
|||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Clone();
|
||||
}
|
||||
|
||||
// Equality of two AbstractTuple inputs; forwards to InferImplTupleOrListEqual with the primitive's name.
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferImplTupleOrListEqual<AbstractTuple>(prim_name, args_spec_list);
}
|
||||
|
||||
// Equality of two AbstractList inputs; forwards to InferImplTupleOrListEqual with the primitive's name.
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferImplTupleOrListEqual<AbstractList>(prim_name, args_spec_list);
}
|
||||
|
||||
// Compares two string scalars for equality.
//   args_spec_list[0], args_spec_list[1]: AbstractScalars whose built values must both be StringImm.
// Returns an AbstractScalar holding true when the two strings are equal, false otherwise.
// Logs an exception when either value is not a string.
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  CheckArgsSize(prim_name, args_spec_list, 2);
  AbstractScalarPtr lhs_abs = CheckArg<AbstractScalar>(prim_name, args_spec_list, 0);
  AbstractScalarPtr rhs_abs = CheckArg<AbstractScalar>(prim_name, args_spec_list, 1);

  ValuePtr lhs = lhs_abs->BuildValue();
  ValuePtr rhs = rhs_abs->BuildValue();
  const bool both_strings = lhs->isa<StringImm>() && rhs->isa<StringImm>();
  if (!both_strings) {
    MS_LOG(EXCEPTION) << prim_name << " requires 2 parameters are string, but got param0: " << lhs->ToString()
                      << ", param1: " << rhs->ToString();
  }

  bool equal = (lhs->cast<StringImmPtr>()->value() == rhs->cast<StringImmPtr>()->value());
  return std::make_shared<AbstractScalar>(equal);
}
|
||||
|
||||
// Concatenates two string scalars.
//   args_spec_list[0], args_spec_list[1]: AbstractScalars whose built values must both be StringImm.
// Returns an AbstractScalar holding the first string followed by the second.
// Logs an exception when either value is not a string.
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  CheckArgsSize(prim_name, args_spec_list, 2);
  AbstractScalarPtr lhs_abs = CheckArg<AbstractScalar>(prim_name, args_spec_list, 0);
  AbstractScalarPtr rhs_abs = CheckArg<AbstractScalar>(prim_name, args_spec_list, 1);

  ValuePtr lhs = lhs_abs->BuildValue();
  ValuePtr rhs = rhs_abs->BuildValue();
  const bool both_strings = lhs->isa<StringImm>() && rhs->isa<StringImm>();
  if (!both_strings) {
    MS_LOG(EXCEPTION) << prim_name << " requires 2 parameters are string, but got param0: " << lhs->ToString()
                      << ", param1: " << rhs->ToString();
  }

  std::string joined = (lhs->cast<StringImmPtr>()->value() + rhs->cast<StringImmPtr>()->value());
  return std::make_shared<AbstractScalar>(joined);
}
|
||||
|
||||
// Length of an AbstractDictionary input; forwards to InferTupleOrListOrDictLen with the primitive's name.
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  return InferTupleOrListOrDictLen<AbstractDictionary>(prim_name, args_spec_list);
}
|
||||
|
||||
// Abstract evaluation of the J primitive on a single argument.
//   - If the argument is not an AbstractFunction, it is wrapped in an AbstractJTagged.
//   - Otherwise every function atom visited through Visit() is wrapped in a
//     JTransformedAbstractClosure and the collection is combined with
//     AbstractFunction::MakeAbstractFunction.
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const AbstractBasePtrList &args_spec_list) {
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();

  AbstractFunctionPtr func_abs = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (func_abs == nullptr) {
    return std::make_shared<AbstractJTagged>(args_spec_list[0]);
  }

  AbstractFuncAtomPtrList transformed;
  func_abs->Visit([&transformed](const AbstractFuncAtomPtr &atom) {
    transformed.push_back(std::make_shared<JTransformedAbstractClosure>(atom));
  });
  return AbstractFunction::MakeAbstractFunction(transformed);
}
|
||||
|
||||
// Takes exactly one input and returns the result of calling Broaden() on it.
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  const std::string prim_name = primitive->name();
  CheckArgsSize(prim_name, args_spec_list, 1);
  const AbstractBasePtr &input = args_spec_list[0];
  return input->Broaden();
}
|
||||
|
||||
// Eval the return type of make_record
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||
if (args_spec_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||
<< args_spec_list.size() << ".";
|
||||
}
|
||||
|
||||
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
}
|
||||
|
||||
auto cls = dyn_cast<Class>(type_ptr);
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
ClassAttrVector attributes = cls->GetAttributes();
|
||||
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||
|
||||
std::vector<AbstractAttribute> abs_attributes;
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||
abs_attributes.push_back(elem);
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
// Registration of the evaluators with their primitives, one macro invocation per primitive.
// NOTE(review): the macro is defined outside this view; each invocation passes a registration name,
// the primitive constant and the InferImpl* function to associate with it — confirm against the macro.
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
                                      InferImplBroadcastGradientArgs);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
class RegisterFrontendPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, false};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterFrontendPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
|
||||
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
|
@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// Ref eliminate
|
||||
make_ref_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
|
||||
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
get_ref_param_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
|
||||
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
|
||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue});
|
||||
|
||||
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
|
||||
IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
|
|
|
@ -20,9 +20,6 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace irpass {
|
||||
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
return nullptr;
|
||||
}
|
||||
PatternNode x, y, z, xs;
|
||||
PConstant one_(node, false, 1);
|
||||
PConstant one_scalar_(node, false, 1, true);
|
||||
|
@ -32,6 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
PConstant const_2(node);
|
||||
PConstant any_const(node);
|
||||
|
||||
if (MsContext::GetInstance()->execution_mode() != kPynativeMode) {
|
||||
MATCH_REPLACE(node, x + zero_, x); // Add by zero
|
||||
MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x); // Scalar Add by zero
|
||||
|
@ -40,8 +38,12 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
|
||||
// Scalar Mul by zero
|
||||
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
|
||||
}
|
||||
// Prim Eliminate (identity)
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// ConstantDuplicateMul
|
||||
auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
|
||||
|
@ -95,37 +97,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt
|
|||
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
||||
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
Reset();
|
||||
// {prim::kPrimAddN, Zs}
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto addn = node->cast<CNodePtr>();
|
||||
if (addn->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
||||
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto addn_maketuple = addn->input(1);
|
||||
|
||||
auto fg = all_reduce_fg_;
|
||||
// addn inputs cross the graph, make the inputs same as allreduce node.
|
||||
PatternNode x, y, z;
|
||||
auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
|
||||
auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
|
||||
auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
|
||||
auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
|
||||
auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
|
||||
auto fg = all_reduce_pat.GetFuncGraph();
|
||||
auto z_ = z.GetNode(node);
|
||||
// If addn inputs cross the graph, make the inputs same as allreduce node.
|
||||
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
||||
auto cnode_z = z_->cast<CNodePtr>();
|
||||
z_ = NewCNode(cnode_z->inputs(), fg);
|
||||
}
|
||||
|
||||
auto addn_op_node = addn->input(0);
|
||||
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
|
||||
auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
|
||||
auto addn_op_node = addn_cnode->input(0);
|
||||
auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
|
||||
auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
|
||||
mul_cnode_ = mul_pat.GetOriginalNode();
|
||||
auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
|
||||
auto addn_maketuple = admktup_pat.GetOriginalNode();
|
||||
|
||||
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
||||
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg);
|
||||
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
||||
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
|
||||
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
|
||||
AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
|
||||
AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
|
||||
ProcessDependEdge(fg, addn_maketuple, all_reduce);
|
||||
return mul;
|
||||
};
|
||||
MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
|
||||
|
@ -146,48 +148,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN
|
|||
}
|
||||
}
|
||||
|
||||
void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) {
|
||||
if (level_ == 0) {
|
||||
level_ = 1;
|
||||
is_reduce_match_ = false;
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
||||
AnfVisitor::Match(prim::kPrimMul)(node);
|
||||
level_ = 0;
|
||||
if (is_reduce_match_) {
|
||||
mul_ = node->cast<CNodePtr>()->input(0);
|
||||
mul_cnode_ = node->cast<CNodePtr>();
|
||||
y_ = tmp_;
|
||||
} else {
|
||||
z_ = node;
|
||||
}
|
||||
}
|
||||
|
||||
if (level_ == 1) {
|
||||
// {prim::kPrimAllReduce, X}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->size() > 1) {
|
||||
all_reduce_ = cnode->input(0);
|
||||
x_ = cnode->input(1);
|
||||
is_reduce_match_ = true;
|
||||
all_reduce_fg_ = cnode->func_graph();
|
||||
}
|
||||
} else {
|
||||
tmp_ = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AdjustAllReduceMulAdd::Reset() {
|
||||
level_ = 0;
|
||||
is_reduce_match_ = false;
|
||||
x_ = nullptr;
|
||||
y_ = nullptr;
|
||||
z_ = nullptr;
|
||||
tmp_ = nullptr;
|
||||
all_reduce_fg_ = nullptr;
|
||||
}
|
||||
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,20 +38,14 @@ namespace irpass {
|
|||
|
||||
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
||||
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
||||
class AdjustAllReduceMulAdd : public AnfVisitor {
|
||||
class AdjustAllReduceMulAdd : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
|
||||
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node);
|
||||
void Visit(const AnfNodePtr &node) override;
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
int level_{0};
|
||||
bool is_reduce_match_{false};
|
||||
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
||||
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
|
||||
FuncGraphPtr all_reduce_fg_{nullptr};
|
||||
AnfNodePtr mul_cnode_{nullptr};
|
||||
};
|
||||
|
||||
class ArithmeticSimplify : public OptimizerCaller {
|
||||
|
|
|
@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller {
|
|||
};
|
||||
|
||||
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||
class GetRefParamEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||
class GetMakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -197,6 +197,9 @@ class CostGraph {
|
|||
inputs_tensor_name_list_.push_back(inputs_tensor_name);
|
||||
}
|
||||
const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
|
||||
void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) {
|
||||
inputs_tensor_name_list_ = inputs_tensor_name_list;
|
||||
}
|
||||
void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
|
||||
auto ret = tuple_getitem_list_.insert(tuple_getitem);
|
||||
if (ret.second == false) {
|
||||
|
|
|
@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
|
|||
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
|
||||
using TileCost = SoftmaxCost;
|
||||
using TileCostPtr = std::shared_ptr<TileCost>;
|
||||
using ConcatCost = TileCost;
|
||||
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
|
||||
|
||||
class TmpIdentityCost : public OperatorCost {
|
||||
public:
|
||||
|
|
|
@ -136,6 +136,7 @@ REGISTER(EmbeddingLookupInfo);
|
|||
REGISTER(TileInfo);
|
||||
REGISTER(StridedSliceInfo);
|
||||
REGISTER(DropoutInfo);
|
||||
REGISTER(ConcatInfo);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
|
||||
MAKE_TUPLE,
|
||||
J,
|
||||
LIST_GETITEM,
|
||||
ARRAY_GETITEM,
|
||||
|
|
|
@ -0,0 +1,268 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ops_info/concat_info.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
// Reads the AXIS attribute and stores it in axis_ as a non-negative index.
//
// A negative axis is interpreted relative to the rank of the first input
// (axis + rank). Returns FAILED when the attribute is missing, is not an
// Int32Imm, the operator has no input shapes, or the normalized axis is
// outside [0, rank); returns SUCCESS otherwise.
Status ConcatInfo::GetAttrs() {
  auto axis_iter = attrs_.find(AXIS);
  if (axis_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(axis_iter->second);
  if (!axis_iter->second->isa<Int32Imm>()) {
    MS_LOG(ERROR) << name_ << ": The value of axis is not int";
    return FAILED;
  }
  int axis = axis_iter->second->cast<Int32ImmPtr>()->value();

  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  int dim = SizeToInt(inputs_shape_[0].size());

  // A negative axis counts back from the last dimension.
  if (axis < 0) {
    axis = axis + dim;
  }

  // Reject an axis that is still out of range: storing a negative value in the
  // unsigned axis_ would silently wrap around to a huge index.
  if (axis < 0 || axis >= dim) {
    MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis << ", the dim is " << dim;
    return FAILED;
  }

  axis_ = static_cast<size_t>(axis);
  return SUCCESS;
}
|
||||
|
||||
// Validates a user or generated strategy for Concat.
//
// Beyond the generic CheckStrategyValue check, the strategy must:
//   * contain exactly one entry per input tensor,
//   * have, for every input, as many elements as that input has dimensions,
//   * not split the concat axis (its entry must be 1), and
//   * be identical for all inputs.
// Returns SUCCESS when all conditions hold, FAILED (with an error log) otherwise.
Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  if (stra.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be equal to the size of inputs shape";
    return FAILED;
  }

  const Dimensions &first_strategy = stra[0];
  for (size_t i = 0; i < stra.size(); ++i) {
    // Bind by reference: the per-input vectors do not need to be copied.
    const Dimensions &strategy_ele = stra[i];
    if (strategy_ele.size() != inputs_shape_[i].size()) {
      MS_LOG(ERROR) << name_ << ": The size of strategy element must be equal to the size of input shape";
      return FAILED;
    }

    if (axis_ >= strategy_ele.size()) {
      MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
      return FAILED;
    }

    if (strategy_ele[axis_] != 1) {
      MS_LOG(ERROR) << name_ << ": The axis can not be split";
      return FAILED;
    }

    // Whole-vector comparison also covers inputs whose rank differs from the
    // first input, instead of indexing the first strategy out of bounds.
    if (strategy_ele != first_strategy) {
      MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
      return FAILED;
    }
  }

  return SUCCESS;
}
|
||||
|
||||
// Uses the strategy of the first input as the device matrix shape.
// Returns FAILED when the stored strategy has no input dimensions, SUCCESS otherwise.
Status ConcatInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.empty()) {
    // ": " separator added so the operator name does not run into the message.
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  dev_matrix_shape_ = stra[0];
  return SUCCESS;
}
|
||||
|
||||
// Builds one tensor map from the rank of the first input and appends it once
// per input to inputs_tensor_map_ and once to outputs_tensor_map_.
// For rank r the map is [r-1, r-2, ..., 0].
// Returns FAILED when there are no input shapes, SUCCESS otherwise.
Status ConcatInfo::InferTensorMap() {
  TensorMap tensor_map;
  if (inputs_shape_.empty()) {
    // ": " separator added so the operator name does not run into the message.
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  // The rank is taken from inputs_shape_[0] rather than dev_matrix_shape_,
  // because the device matrix may not be fully split across all devices.
  int32_t size = SizeToInt(inputs_shape_[0].size());
  for (int i = 0; i < size; ++i) {
    tensor_map.push_back(size - i - 1);
  }

  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    inputs_tensor_map_.push_back(tensor_map);
  }
  outputs_tensor_map_.push_back(tensor_map);
  return SUCCESS;
}
|
||||
|
||||
// Rebuilds mirror_ops_: one group is derived from the first input's tensor map,
// and the mirror operators created for it are stored once per input.
// When no group is produced, mirror_ops_ stays empty and SUCCESS is returned.
Status ConcatInfo::InferMirrorOps() {
  mirror_ops_.clear();
  if (inputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
    return FAILED;
  }

  std::vector<Group> mirror_groups;
  Shape first_tensor_map = inputs_tensor_map_[0];
  if (CreateGroupByTensorMap(first_tensor_map, &mirror_groups) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group for input failed.";
    return FAILED;
  }

  if (mirror_groups.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
    return SUCCESS;
  }

  // Every input shares the same mirror operators.
  auto &mirror_group = mirror_groups.front();
  OperatorVector shared_mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
  mirror_ops_.insert(mirror_ops_.end(), inputs_shape_.size(), shared_mirror_op);

  return SUCCESS;
}
|
||||
|
||||
// Derives a TensorInfo for every input and for the single output from the
// device matrix, the tensor maps and the shapes. Returns FAILED when any of
// the shape or tensor-map containers is empty or a layout cannot be initialized.
Status ConcatInfo::InferTensorInfo() {
  const bool missing_args =
    inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty();
  if (missing_args) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  // The same layout object is re-initialized for each input, as before.
  TensorLayout input_layout;
  for (size_t idx = 0; idx < inputs_shape_.size(); ++idx) {
    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[idx], inputs_shape_[idx]) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
      return FAILED;
    }
    inputs_tensor_info_.push_back(TensorInfo(input_layout));
  }

  TensorLayout output_layout;
  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
    return FAILED;
  }
  outputs_tensor_info_.push_back(TensorInfo(output_layout));
  return SUCCESS;
}
|
||||
|
||||
void ConcatInfo::ReComputeBatchSplitFlagList() {
|
||||
for (size_t i = 0; i < inputs_shape_.size(); i++) {
|
||||
split_flag_list_[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Forwards to SetCostUnderStrategyBase and logs an error when it fails.
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  if (SetCostUnderStrategyBase(strategy) == SUCCESS) {
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
  return FAILED;
}
|
||||
|
||||
// Generates candidate strategies for the given stage and sets the cost for each one.
//
// Each input gets the same splittable mask: 0 at the concat axis and 1 for
// every other dimension. Returns FAILED when attribute inference fails, there
// are no input shapes, or strategy generation fails. A candidate whose cost
// cannot be set is skipped without failing the call.
Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
    return FAILED;
  }
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  Shape axis_mask;
  const size_t rank = inputs_shape_[0].size();
  for (size_t dim_idx = 0; dim_idx < rank; ++dim_idx) {
    axis_mask.push_back(dim_idx == axis_ ? 0 : 1);
  }
  // One identical mask per input tensor.
  Shapes splittable_inputs(inputs_shape_.size(), axis_mask);

  is_auto_parallel_ = true;
  std::vector<StrategyPtr> candidates;
  if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &candidates) != SUCCESS) {
    return FAILED;
  }

  size_t success = 0;
  for (auto &candidate : candidates) {
    PrintStrategy(candidate);
    if (SetCostUnderStrategy(candidate) != SUCCESS) {
      continue;
    }
    success++;
    MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
    PrintStrategy(candidate);
  }
  return SUCCESS;
}
|
||||
|
||||
// Initializes the operator for the given strategy through
// InitWithAutoRepeatCalc and logs the outcome.
Status ConcatInfo::Init(const StrategyPtr &strategy) {
  const bool init_ok = (InitWithAutoRepeatCalc(strategy) == SUCCESS);
  if (!init_ok) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}
|
||||
|
||||
// Initializes the operator for cost-model evaluation through
// InitForCostModelWithAutoRepeatCalc and logs the outcome.
Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) {
  const bool init_ok = (InitForCostModelWithAutoRepeatCalc(strategy) == SUCCESS);
  if (!init_ok) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/value.h"
|
||||
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
|
||||
#include "frontend/parallel/ops_info/operator_info.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
class ConcatInfo : public OperatorInfo {
|
||||
public:
|
||||
ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
|
||||
const PrimitiveAttrs &attrs)
|
||||
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
|
||||
~ConcatInfo() override = default;
|
||||
|
||||
Status Init(const StrategyPtr &strategy) override;
|
||||
Status InitForCostModel(const StrategyPtr &strategy) override;
|
||||
Status GenerateStrategies(int32_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
void ReComputeBatchSplitFlagList() override;
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferMirrorOps() override;
|
||||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferTensorInfo() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
|
||||
private:
|
||||
size_t axis_ = 0;
|
||||
};
|
||||
|
||||
using ConcatInfoPtr = std::shared_ptr<ConcatInfo>;
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
|
|
@ -39,5 +39,6 @@
|
|||
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
|
||||
#include "frontend/parallel/ops_info/tile_info.h"
|
||||
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
||||
#include "frontend/parallel/ops_info/concat_info.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
|
||||
|
|
|
@ -56,9 +56,11 @@ constexpr char kMomentum[] = "momentum";
|
|||
|
||||
constexpr char kApplyMomentum[] = "ApplyMomentum";
|
||||
constexpr char kSparseAdam[] = "Adam";
|
||||
constexpr char kSparseLazyAdam[] = "LazyAdam";
|
||||
constexpr char kSparseFtrl[] = "Ftrl";
|
||||
constexpr char kApplyMomentumOp[] = "Momentum";
|
||||
constexpr char kSparseAdamOp[] = "Adam";
|
||||
constexpr char kSparseLazyAdamOp[] = "LazyAdam";
|
||||
constexpr char kSparseFtrlOp[] = "FTRL";
|
||||
|
||||
constexpr int kInitWeightsCmd = 10;
|
||||
|
|
|
@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr
|
|||
inputs_.push_back(momentum);
|
||||
}
|
||||
|
||||
void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
|
||||
size_t lr_offset = 0;
|
||||
float *lr = values.data() + lr_offset;
|
||||
auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
}
|
||||
|
||||
const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; }
|
||||
|
||||
const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; }
|
||||
|
|
|
@ -82,6 +82,7 @@ class MomentumOptimInfo : public DenseOptimInfo {
|
|||
const AddressPtr &gradient, const AddressPtr &momentum);
|
||||
~MomentumOptimInfo() override = default;
|
||||
|
||||
void Update(const Values &values, const Lengths &lens) override;
|
||||
const AddressPtr &gradient();
|
||||
const AddressPtr &indices();
|
||||
size_t grad_index() override;
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
|
||||
|
@ -374,6 +375,11 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
|
|||
const CNodePtr cnode = GetCNode(optim_op_name);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (optim_name == kSparseAdam) {
|
||||
std::shared_ptr<PServerKernel> optimizer =
|
||||
std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
|
||||
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
||||
optimizers_[key] = optimizer;
|
||||
} else if (optim_name == kSparseLazyAdam) {
|
||||
std::shared_ptr<PServerKernel> optimizer =
|
||||
std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
|
||||
optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
|
||||
|
|
|
@ -25,19 +25,22 @@ namespace ps {
|
|||
std::unordered_map<std::string, int> Util::optimizer_to_ids{
|
||||
{kApplyMomentum, 0},
|
||||
{kSparseAdam, 1},
|
||||
{kSparseFtrl, 2},
|
||||
{kSparseLazyAdam, 2},
|
||||
{kSparseFtrl, 3},
|
||||
};
|
||||
|
||||
std::unordered_map<int, std::string> Util::id_to_optimizers{
|
||||
{0, kApplyMomentum},
|
||||
{1, kSparseAdam},
|
||||
{2, kSparseFtrl},
|
||||
{2, kSparseLazyAdam},
|
||||
{3, kSparseFtrl},
|
||||
};
|
||||
|
||||
std::unordered_map<int, std::string> Util::id_to_optimizer_nodes{
|
||||
{0, kApplyMomentumOp},
|
||||
{1, kSparseAdamOp},
|
||||
{2, kSparseFtrlOp},
|
||||
{2, kSparseLazyAdamOp},
|
||||
{3, kSparseFtrlOp},
|
||||
};
|
||||
|
||||
bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); }
|
||||
|
|
|
@ -118,6 +118,9 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
|||
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
|
||||
std::vector<bool> is_parameter;
|
||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||
}
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
auto input = node_inputs[i];
|
||||
|
||||
|
@ -192,6 +195,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
|
|||
std::vector<size_t> inputs_type_len;
|
||||
std::vector<AnfNodePtr> node_inputs{node->inputs()};
|
||||
|
||||
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
|
||||
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
|
||||
}
|
||||
|
||||
// extract input element length
|
||||
for (auto &input : node_inputs) {
|
||||
if (IsValueNode<RefKey>(input)) {
|
||||
|
@ -255,7 +262,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
|
||||
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
|
||||
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
|
||||
// clang-format on
|
||||
|
@ -275,7 +282,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
|
||||
if (bool_result) {
|
||||
if (bool_result && (prim->name() != MAKE_TUPLE)) {
|
||||
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
|
||||
} else if (prim->name() == CAST) {
|
||||
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
|
||||
|
@ -520,6 +527,10 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|||
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
|
||||
<< " does not match the Prim: " << prim->name();
|
||||
}
|
||||
|
||||
// Needed by rec_parser
|
||||
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
|
||||
|
||||
cnode->set_user_data<OperatorInfo>(current_op_ptr);
|
||||
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
||||
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
||||
|
@ -1117,6 +1128,27 @@ CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Needed by rec_parser. After the OperatorInfo called `name` has been created, every entry of the cost graph's
// inputs-tensor-name list that equals `uniqueid` is replaced by the first entry of the row whose index is the
// position of that operator in entire_costgraph->GetOperators().
//
// name:     name of the OperatorInfo to look up among the cost graph's operators.
// uniqueid: the tensor name to be replaced wherever it occurs in the list.
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
  // Position of the first operator called `name`; equals the operator count when nothing matches.
  size_t iter_ops = 0;
  for (const auto &op : entire_costgraph->GetOperators()) {
    if (op->name() == name) {
      break;
    }
    ++iter_ops;
  }

  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();

  // Guard the lookup of the replacement value: when no operator matched (or the matching row is empty) the
  // index would run past the end of the list, so leave the names untouched in that case.
  if (iter_ops < input_tensor_names.size() && !input_tensor_names[iter_ops].empty()) {
    // Copy the replacement once; it cannot change during the scan (replacing it with itself is a no-op).
    const std::string replacement = input_tensor_names[iter_ops][0];
    for (auto &row : input_tensor_names) {
      for (auto &tensor_name : row) {
        if (tensor_name == uniqueid) {
          tensor_name = replacement;
        }
      }
    }
  }

  entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
}
|
||||
|
||||
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
||||
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
|
||||
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
|
||||
|
|
|
@ -59,6 +59,8 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
|
|||
std::vector<std::vector<std::string>> input_tensor_names);
|
||||
|
||||
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node);
|
||||
|
||||
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_
|
||||
|
|
|
@ -267,6 +267,33 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
|
|||
return tensorinfo_in.tensor_layout();
|
||||
}
|
||||
|
||||
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == prim_name) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string GetPrimName(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsValueNode<Primitive>(node->input(0))) {
|
||||
MS_LOG(EXCEPTION) << "The node is not a primitive";
|
||||
}
|
||||
auto value_node = node->input(0)->cast<ValueNodePtr>();
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->name();
|
||||
}
|
||||
|
||||
OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsParallelCareNode(node)) {
|
||||
|
@ -274,7 +301,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
|||
}
|
||||
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
|
||||
if (distribute_operator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
|
||||
MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
|
||||
}
|
||||
return distribute_operator;
|
||||
}
|
||||
|
@ -423,6 +450,11 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodeIndexSet node_set = manager->node_users()[node];
|
||||
CNodePtr insert_node_new;
|
||||
|
||||
if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
||||
return;
|
||||
}
|
||||
if (IsValueNode<Primitive>(node->input(0))) {
|
||||
auto current_value = node->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(current_value);
|
||||
|
@ -875,9 +907,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
||||
return;
|
||||
}
|
||||
|
||||
if (mirror_ops.size() != node_size - 1) {
|
||||
MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size()
|
||||
<< ", node_size is " << node_size;
|
||||
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
|
||||
<< node_size - 1;
|
||||
}
|
||||
for (size_t index = 1; index < node_size; ++index) {
|
||||
OperatorVector backward_op = mirror_ops[index - 1];
|
||||
|
@ -993,7 +1031,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
|
|||
const std::vector<Shapes> &shape_list) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
|
||||
if (operator_ == nullptr) {
|
||||
if ((operator_ == nullptr) && (prim->name() != MAKE_TUPLE)) {
|
||||
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
|
||||
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
|
||||
MS_EXCEPTION_IF_NULL(operator_);
|
||||
|
@ -1177,8 +1215,13 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
|
|||
continue;
|
||||
}
|
||||
if (input_shapes.size() != 1) {
|
||||
if (inputs_size == 2) { // like concat
|
||||
shape_inputs = input_shapes;
|
||||
break;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
|
||||
}
|
||||
}
|
||||
shape_inputs.push_back(input_shapes[0]);
|
||||
}
|
||||
shape_all.push_back(shape_inputs);
|
||||
|
@ -1269,8 +1312,8 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
|
|||
}
|
||||
TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
|
||||
Shape slice_shape = tensorinfo_in.slice_shape();
|
||||
MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
|
||||
<< MakeValue(slice_shape)->ToString();
|
||||
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
|
||||
<< MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
// Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
|
||||
|
@ -1450,6 +1493,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
SetVirtualDatasetStrategy(cnode);
|
||||
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
||||
if (prim->name() == MAKE_TUPLE) {
|
||||
continue;
|
||||
}
|
||||
auto attrs = prim->attrs();
|
||||
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
|
||||
if (IsParallelCareNode(cnode)) {
|
||||
|
@ -2045,13 +2091,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
// the make_tuple is parallel care node, but it may have not operator info
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
if (distribute_operator == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
|
||||
// insert forward ops
|
||||
InsertForwardOps(distribute_operator, cnode);
|
||||
|
@ -2074,13 +2120,12 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
||||
if (distribute_operator == nullptr) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
// StepReplace
|
||||
StepReplace(distribute_operator, cnode);
|
||||
}
|
||||
|
@ -2330,6 +2375,44 @@ Status ParallelInit() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
|
||||
for (auto &node : all_nodes) {
|
||||
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!cnode->in_forward_flag()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto make_tuple_user = manager->node_users()[cnode];
|
||||
if (make_tuple_user.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
|
||||
}
|
||||
CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
|
||||
|
||||
std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode);
|
||||
if (!IsParallelCareNode(make_tuple_next_cnode)) {
|
||||
MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info";
|
||||
continue;
|
||||
}
|
||||
if (make_tuple_next_cnode->inputs().size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got "
|
||||
<< make_tuple_next_cnode->inputs().size() - 1;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name;
|
||||
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
cnode->set_user_data<OperatorInfo>(op_info);
|
||||
}
|
||||
}
|
||||
|
||||
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
|
@ -2383,6 +2466,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|||
ExtractInformation(all_nodes);
|
||||
ReshapeInit(all_nodes);
|
||||
}
|
||||
|
||||
HandleForwardMakeTuple(all_nodes);
|
||||
|
||||
// save strategy as checkpoint for multi-train
|
||||
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
|
||||
CheckpointStrategy(root);
|
||||
|
|
|
@ -149,6 +149,8 @@ Status ParallelInit();
|
|||
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
|
||||
|
||||
std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
|
||||
|
||||
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
# Collect every .cc file found recursively under this directory (paths relative to it) into MS_GVAR_SRC_LIST.
file(GLOB_RECURSE MS_GVAR_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc)
|
||||
# Compile each collected source with SUBMODULE_ID defined as mindspore::SubModuleId::SM_COMMON.
set_property(SOURCE ${MS_GVAR_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON)
|
||||
# Build the collected sources into the shared library mindspore_gvar.
add_library(mindspore_gvar SHARED ${MS_GVAR_SRC_LIST})
|
||||
if (APPLE)
|
||||
# On Apple platforms, turn on the MACOSX_RPATH target property for the library.
set_target_properties(mindspore_gvar PROPERTIES MACOSX_RPATH ON)
|
||||
endif ()
|
|
@ -62,12 +62,15 @@ add_subdirectory(text)
|
|||
add_dependencies(utils core)
|
||||
add_dependencies(kernels-image core)
|
||||
add_dependencies(kernels-data core)
|
||||
add_dependencies(kernels-soft-dvpp-image core soft-dvpp-utils)
|
||||
add_dependencies(kernels core)
|
||||
add_dependencies(engine-datasetops-source core)
|
||||
add_dependencies(engine-datasetops-source-sampler core)
|
||||
add_dependencies(engine-datasetops core)
|
||||
add_dependencies(engine-datasetops-mapop core)
|
||||
add_dependencies(engine-opt core)
|
||||
add_dependencies(engine-cache-client core)
|
||||
add_dependencies(engine-cache-server core)
|
||||
add_dependencies(engine-perf core)
|
||||
add_dependencies(engine-gnn core)
|
||||
add_dependencies(engine core)
|
||||
|
@ -88,6 +91,8 @@ set(submodules
|
|||
$<TARGET_OBJECTS:kernels-image>
|
||||
$<TARGET_OBJECTS:kernels-data>
|
||||
$<TARGET_OBJECTS:cpp-API>
|
||||
$<TARGET_OBJECTS:kernels-soft-dvpp-image>
|
||||
$<TARGET_OBJECTS:soft-dvpp-utils>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
||||
$<TARGET_OBJECTS:engine-datasetops-mapop>
|
||||
|
@ -126,7 +131,7 @@ endif()
|
|||
######################################################################
|
||||
|
||||
################# Link with external libraries ########################
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore)
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||
if (ENABLE_PYTHON)
|
||||
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
|
||||
|
@ -141,7 +146,7 @@ else()
|
|||
target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY})
|
||||
endif()
|
||||
endif()
|
||||
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs
|
||||
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::turbojpeg mindspore::opencv_core mindspore::opencv_imgcodecs
|
||||
mindspore::opencv_imgproc mindspore::tinyxml2 mindspore::sentencepiece mindspore::sentencepiece_train ${ICU_LIB})
|
||||
if (ENABLE_GPUQUE)
|
||||
target_link_libraries(_c_dataengine PRIVATE gpu_queue
|
||||
|
|
|
@ -61,11 +61,19 @@ namespace api {
|
|||
} while (false)
|
||||
|
||||
// Function to create the iterator, which will build and launch the execution tree.
|
||||
std::shared_ptr<Iterator> Dataset::CreateIterator() {
|
||||
std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
|
||||
std::shared_ptr<Iterator> iter;
|
||||
try {
|
||||
auto ds = shared_from_this();
|
||||
|
||||
// The specified columns will be selected from the dataset and passed down the pipeline
|
||||
// in the order specified, other columns will be discarded.
|
||||
if (!columns.empty()) {
|
||||
ds = ds->Project(columns);
|
||||
}
|
||||
|
||||
iter = std::make_shared<Iterator>();
|
||||
Status rc = iter->BuildAndLaunchTree(shared_from_this());
|
||||
Status rc = iter->BuildAndLaunchTree(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "CreateIterator failed." << rc;
|
||||
return nullptr;
|
||||
|
@ -629,13 +637,13 @@ bool VOCDataset::ValidateParams() {
|
|||
}
|
||||
Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
MS_LOG(ERROR) << "[Segmentation] imagesets_file is invalid or not exist";
|
||||
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
|
||||
return false;
|
||||
}
|
||||
} else if (task_ == "Detection") {
|
||||
Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
MS_LOG(ERROR) << "[Detection] imagesets_file is invalid or not exist.";
|
||||
MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
@ -655,18 +663,33 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
|
|||
sampler_ = CreateDefaultSampler();
|
||||
}
|
||||
|
||||
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
|
||||
(void)builder->SetDir(dataset_dir_);
|
||||
(void)builder->SetTask(task_);
|
||||
(void)builder->SetMode(mode_);
|
||||
(void)builder->SetNumWorkers(num_workers_);
|
||||
(void)builder->SetSampler(std::move(sampler_->Build()));
|
||||
(void)builder->SetDecode(decode_);
|
||||
(void)builder->SetClassIndex(class_index_);
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
VOCOp::TaskType task_type_;
|
||||
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_EMPTY_IF_ERROR(builder->Build(&op));
|
||||
node_ops.push_back(op);
|
||||
if (task_ == "Segmentation") {
|
||||
task_type_ = VOCOp::TaskType::Segmentation;
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
} else if (task_ == "Detection") {
|
||||
task_type_ = VOCOp::TaskType::Detection;
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(schema->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
}
|
||||
|
||||
std::shared_ptr<VOCOp> voc_op;
|
||||
voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
|
||||
node_ops.push_back(voc_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,19 @@ void Iterator::GetNextRow(TensorMap *row) {
|
|||
}
|
||||
}
|
||||
|
||||
// Get the next row from the data pipeline.
|
||||
// Fetches the next row from the data pipeline into `row` as a vector of tensors.
// `row` is always cleared first; when fetching fails, an error is logged and `row` is left empty.
void Iterator::GetNextRow(TensorVec *row) {
  TensorRow tensor_row;
  Status rc = iterator_->FetchNextTensorRow(&tensor_row);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "GetNextRow: Failed to get next row.";
    row->clear();
    // Stop here: nothing from a failed fetch should be copied into the caller's row.
    return;
  }
  // Generate a vector as return
  row->clear();
  std::copy(tensor_row.begin(), tensor_row.end(), std::back_inserter(*row));
}
|
||||
|
||||
// Shut down the data pipeline.
|
||||
void Iterator::Stop() {
|
||||
// Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_.
|
||||
|
@ -61,13 +74,20 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
|||
RETURN_STATUS_UNEXPECTED("Node operation returned nothing");
|
||||
}
|
||||
|
||||
auto root_op = root_ops.front();
|
||||
|
||||
RETURN_UNEXPECTED_IF_NULL(root_op);
|
||||
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(root_op));
|
||||
|
||||
q.push(std::make_pair(ds, root_op));
|
||||
// Iterate through all the DatasetOps returned by Dataset's Build(), associate them
|
||||
// with the execution tree and add the child and parent relationship between the nodes
|
||||
// Note that some Dataset objects might return more than one DatasetOps
|
||||
// e.g. MapDataset will return [ProjectOp, MapOp] if project_columns is set for MapDataset
|
||||
std::shared_ptr<DatasetOp> prev_op = nullptr;
|
||||
for (auto op : root_ops) {
|
||||
RETURN_IF_NOT_OK(tree_->AssociateNode(op));
|
||||
if (prev_op != nullptr) {
|
||||
RETURN_IF_NOT_OK(prev_op->AddChild(op));
|
||||
}
|
||||
prev_op = op;
|
||||
}
|
||||
// Add the last DatasetOp to the queue to be BFS.
|
||||
q.push(std::make_pair(ds, root_ops.back()));
|
||||
|
||||
// Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes)
|
||||
while (!q.empty()) {
|
||||
|
@ -94,7 +114,7 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
|||
q.push(std::make_pair(child, child_ops.back()));
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_ops.front()));
|
||||
}
|
||||
|
||||
// Launch the execution tree.
|
||||
|
|
|
@ -28,8 +28,10 @@
|
|||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/invert_op.h"
|
||||
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
|
||||
|
@ -48,6 +50,8 @@
|
|||
#include "minddata/dataset/kernels/image/resize_bilinear_op.h"
|
||||
#include "minddata/dataset/kernels/image/resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
|
||||
#include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h"
|
||||
#include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_resize_jpeg_op.h"
|
||||
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -92,6 +96,12 @@ PYBIND_REGISTER(CenterCropOp, 1, ([](const py::module *m) {
|
|||
.def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
|
||||
}));
|
||||
|
||||
// Python binding for MixUpBatchOp: registers the class under the name "MixUpBatchOp" (derived from TensorOp)
// with a constructor taking a single float keyword argument "alpha".
PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<MixUpBatchOp, TensorOp, std::shared_ptr<MixUpBatchOp>>(
|
||||
*m, "MixUpBatchOp", "Tensor operation to mixup a batch of images")
|
||||
.def(py::init<float>(), py::arg("alpha"));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
|
||||
*m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
|
||||
|
@ -108,6 +118,19 @@ PYBIND_REGISTER(ResizeWithBBoxOp, 1, ([](const py::module *m) {
|
|||
py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
|
||||
}));
|
||||
|
||||
// Python binding for RandomAffineOp: registers the class under the name "RandomAffineOp" (derived from TensorOp).
// The constructor takes degrees, translate_range, scale_range, shear_ranges, interpolation and fill_value;
// each argument defaults to the corresponding RandomAffineOp constant.
PYBIND_REGISTER(RandomAffineOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomAffineOp, TensorOp, std::shared_ptr<RandomAffineOp>>(
|
||||
*m, "RandomAffineOp", "Tensor operation to apply random affine transformations on an image.")
|
||||
.def(py::init<std::vector<float_t>, std::vector<float_t>, std::vector<float_t>,
|
||||
std::vector<float_t>, InterpolationMode, std::vector<uint8_t>>(),
|
||||
py::arg("degrees") = RandomAffineOp::kDegreesRange,
|
||||
py::arg("translate_range") = RandomAffineOp::kTranslationPercentages,
|
||||
py::arg("scale_range") = RandomAffineOp::kScaleRange,
|
||||
py::arg("shear_ranges") = RandomAffineOp::kShearRanges,
|
||||
py::arg("interpolation") = RandomAffineOp::kDefInterpolation,
|
||||
py::arg("fill_value") = RandomAffineOp::kFillValue);
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RandomResizeWithBBoxOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
|
||||
|
@ -341,6 +364,24 @@ PYBIND_REGISTER(RandomSelectSubpolicyOp, 1, ([](const py::module *m) {
|
|||
return std::make_shared<RandomSelectSubpolicyOp>(cpp_policy);
|
||||
}));
|
||||
}));
|
||||
// Python binding for SoftDvppDecodeResizeJpegOp: registers the class under the same name (derived from TensorOp)
// with a constructor taking two int32 keyword arguments, "targetHeight" and "targetWidth" (no defaults).
PYBIND_REGISTER(SoftDvppDecodeResizeJpegOp, 1, ([](const py::module *m) {
|
||||
(void)py::class_<SoftDvppDecodeResizeJpegOp, TensorOp, std::shared_ptr<SoftDvppDecodeResizeJpegOp>>(
|
||||
*m, "SoftDvppDecodeResizeJpegOp", "TensorOp to use soft dvpp decode and resize jpeg image.")
|
||||
.def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth"));
|
||||
}));
|
||||
// Python binding for SoftDvppDecodeRandomCropResizeJpegOp (derived from TensorOp).
// The constructor takes targetHeight and targetWidth (required), followed by scaleLb, scaleUb, aspectLb,
// aspectUb and maxIter, whose defaults are taken from the RandomCropDecodeResizeOp constants.
PYBIND_REGISTER(
  SoftDvppDecodeRandomCropResizeJpegOp, 1, ([](const py::module *m) {
    (void)
      py::class_<SoftDvppDecodeRandomCropResizeJpegOp, TensorOp, std::shared_ptr<SoftDvppDecodeRandomCropResizeJpegOp>>(
        *m, "SoftDvppDecodeRandomCropResizeJpegOp",
        "TensorOp to use soft dvpp decode, random crop and resize jpeg image.")
        .def(py::init<int32_t, int32_t, float, float, float, float, int32_t>(), py::arg("targetHeight"),
             py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
             py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
             py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
             py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
             py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
  }));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,12 +48,12 @@ PYBIND_REGISTER(
|
|||
ShardPkSample, 1, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||
*m, "MindrecordPkSampler")
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
|
||||
if (shuffle == true) {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed());
|
||||
GetSeed(), num_samples);
|
||||
} else {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
|
||||
}
|
||||
}));
|
||||
}));
|
||||
|
|
|
@ -21,8 +21,12 @@
|
|||
#include "minddata/dataset/kernels/image/crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/cut_out_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
|
||||
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
|
||||
#include "minddata/dataset/kernels/image/normalize_op.h"
|
||||
#include "minddata/dataset/kernels/data/one_hot_op.h"
|
||||
#include "minddata/dataset/kernels/image/pad_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
|
@ -81,6 +85,26 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) {
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create HwcToChwOperation.
|
||||
std::shared_ptr<HwcToChwOperation> HWC2CHW() {
|
||||
auto op = std::make_shared<HwcToChwOperation>();
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create MixUpBatchOperation.
|
||||
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
|
||||
auto op = std::make_shared<MixUpBatchOperation>(alpha);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create NormalizeOperation.
|
||||
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
|
||||
auto op = std::make_shared<NormalizeOperation>(mean, std);
|
||||
|
@ -91,6 +115,16 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create OneHotOperation.
|
||||
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes) {
|
||||
auto op = std::make_shared<OneHotOperation>(num_classes);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create PadOperation.
|
||||
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
|
||||
BorderType padding_mode) {
|
||||
|
@ -114,10 +148,27 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
|
|||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomAffineOperation.
|
||||
std::shared_ptr<RandomAffineOperation> RandomAffine(const std::vector<float_t> °rees,
|
||||
const std::vector<float_t> &translate_range,
|
||||
const std::vector<float_t> &scale_range,
|
||||
const std::vector<float_t> &shear_ranges,
|
||||
InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value) {
|
||||
auto op = std::make_shared<RandomAffineOperation>(degrees, translate_range, scale_range, shear_ranges, interpolation,
|
||||
fill_value);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
// Function to create RandomCropOperation.
|
||||
std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
|
||||
bool pad_if_needed, std::vector<uint8_t> fill_value) {
|
||||
auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value);
|
||||
bool pad_if_needed, std::vector<uint8_t> fill_value,
|
||||
BorderType padding_mode) {
|
||||
auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value, padding_mode);
|
||||
// Input validation
|
||||
if (!op->ValidateParams()) {
|
||||
return nullptr;
|
||||
|
@ -271,6 +322,25 @@ bool DecodeOperation::ValidateParams() { return true; }
|
|||
|
||||
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
|
||||
|
||||
// HwcToChwOperation
|
||||
bool HwcToChwOperation::ValidateParams() { return true; }
|
||||
|
||||
std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<HwcToChwOp>(); }
|
||||
|
||||
// MixUpBatchOperation
|
||||
MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
|
||||
|
||||
bool MixUpBatchOperation::ValidateParams() {
|
||||
if (alpha_ < 0) {
|
||||
MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
|
||||
|
||||
// NormalizeOperation
|
||||
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
|
||||
|
||||
|
@ -292,6 +362,20 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
|||
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
|
||||
}
|
||||
|
||||
// OneHotOperation
|
||||
OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
|
||||
|
||||
bool OneHotOperation::ValidateParams() {
|
||||
if (num_classes_ < 0) {
|
||||
MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. Number of classes: " << num_classes_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
|
||||
|
||||
// PadOperation
|
||||
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
|
||||
|
@ -401,10 +485,90 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
|
|||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomAffineOperation
|
||||
RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees,
|
||||
const std::vector<float_t> &translate_range,
|
||||
const std::vector<float_t> &scale_range,
|
||||
const std::vector<float_t> &shear_ranges, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value)
|
||||
: degrees_(degrees),
|
||||
translate_range_(translate_range),
|
||||
scale_range_(scale_range),
|
||||
shear_ranges_(shear_ranges),
|
||||
interpolation_(interpolation),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
bool RandomAffineOperation::ValidateParams() {
|
||||
// Degrees
|
||||
if (degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomAffine: degrees vector has incorrect size: degrees.size() = " << degrees_.size();
|
||||
return false;
|
||||
}
|
||||
if (degrees_[0] > degrees_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of degrees range is greater than maximum: min = " << degrees_[0]
|
||||
<< ", max = " << degrees_[1];
|
||||
return false;
|
||||
}
|
||||
// Translate
|
||||
if (translate_range_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomAffine: translate_range vector has incorrect size: translate_range.size() = "
|
||||
<< translate_range_.size();
|
||||
return false;
|
||||
}
|
||||
if (translate_range_[0] > translate_range_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of translate range is greater than maximum: min = " << translate_range_[0]
|
||||
<< ", max = " << translate_range_[1];
|
||||
return false;
|
||||
}
|
||||
// Scale
|
||||
if (scale_range_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
|
||||
<< scale_range_.size();
|
||||
return false;
|
||||
}
|
||||
if (scale_range_[0] > scale_range_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of scale range is greater than maximum: min = " << scale_range_[0]
|
||||
<< ", max = " << scale_range_[1];
|
||||
return false;
|
||||
}
|
||||
// Shear
|
||||
if (shear_ranges_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomAffine: shear_ranges vector has incorrect size: shear_ranges.size() = "
|
||||
<< shear_ranges_.size();
|
||||
return false;
|
||||
}
|
||||
if (shear_ranges_[0] > shear_ranges_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of horizontal shear range is greater than maximum: min = "
|
||||
<< shear_ranges_[0] << ", max = " << shear_ranges_[1];
|
||||
return false;
|
||||
}
|
||||
if (shear_ranges_[2] > shear_ranges_[3]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
|
||||
<< ", max = " << scale_range_[3];
|
||||
return false;
|
||||
}
|
||||
// Fill Value
|
||||
if (fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomAffine: fill_value vector has incorrect size: fill_value.size() = " << fill_value_.size();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
||||
auto tensor_op = std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_,
|
||||
interpolation_, fill_value_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// RandomCropOperation
|
||||
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {}
|
||||
std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: size_(size),
|
||||
padding_(padding),
|
||||
pad_if_needed_(pad_if_needed),
|
||||
fill_value_(fill_value),
|
||||
padding_mode_(padding_mode) {}
|
||||
|
||||
bool RandomCropOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
|
@ -443,7 +607,7 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
|
|||
}
|
||||
|
||||
auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
|
||||
BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b);
|
||||
padding_mode_, pad_if_needed_, fill_r, fill_g, fill_b);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||
// Set total repeats and total epochs for the DeviceQueueOp
|
||||
node->set_total_repeats(num_epochs_);
|
||||
node->set_num_repeats_per_epoch(1);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the eoe operator stack save area
|
||||
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
|
||||
op_stack *current_stack = eoe_op_stacks_.top().get();
|
||||
|
|
|
@ -92,6 +92,12 @@ class RepeatPass : public NodePass {
|
|||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Set the epoch count for DeviceQueue
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
/// for use with a controlling repeat above it.
|
||||
/// \param[in] node The node being visited
|
||||
|
|
|
@ -196,8 +196,9 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
}
|
||||
|
||||
/// \brief Function to create an Iterator over the Dataset pipeline
|
||||
/// \param[in] columns List of columns to be used to specify the order of columns
|
||||
/// \return Shared pointer to the Iterator
|
||||
std::shared_ptr<Iterator> CreateIterator();
|
||||
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
|
||||
|
||||
/// \brief Function to create a BatchDataset
|
||||
/// \notes Combines batch_size number of consecutive rows into batches
|
||||
|
@ -452,6 +453,12 @@ class VOCDataset : public Dataset {
|
|||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
const std::string kColumnImage = "image";
|
||||
const std::string kColumnTarget = "target";
|
||||
const std::string kColumnBbox = "bbox";
|
||||
const std::string kColumnLabel = "label";
|
||||
const std::string kColumnDifficult = "difficult";
|
||||
const std::string kColumnTruncate = "truncate";
|
||||
std::string dataset_dir_;
|
||||
std::string task_;
|
||||
std::string mode_;
|
||||
|
|
|
@ -37,6 +37,7 @@ namespace api {
|
|||
class Dataset;
|
||||
|
||||
using TensorMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;
|
||||
using TensorVec = std::vector<std::shared_ptr<Tensor>>;
|
||||
|
||||
// Abstract class for iterating over the dataset.
|
||||
class Iterator {
|
||||
|
@ -53,9 +54,15 @@ class Iterator {
|
|||
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
|
||||
|
||||
/// \brief Function to get the next row from the data pipeline.
|
||||
/// \note Type of return data is a map(with column name).
|
||||
/// \param[out] row - the output tensor row.
|
||||
void GetNextRow(TensorMap *row);
|
||||
|
||||
/// \brief Function to get the next row from the data pipeline.
|
||||
/// \note Type of return data is a vector(without column name).
|
||||
/// \param[out] row - the output tensor row.
|
||||
void GetNextRow(TensorVec *row);
|
||||
|
||||
/// \brief Function to shut down the data pipeline.
|
||||
void Stop();
|
||||
|
||||
|
|
|
@ -51,8 +51,12 @@ class CenterCropOperation;
|
|||
class CropOperation;
|
||||
class CutOutOperation;
|
||||
class DecodeOperation;
|
||||
class HwcToChwOperation;
|
||||
class MixUpBatchOperation;
|
||||
class NormalizeOperation;
|
||||
class OneHotOperation;
|
||||
class PadOperation;
|
||||
class RandomAffineOperation;
|
||||
class RandomColorAdjustOperation;
|
||||
class RandomCropOperation;
|
||||
class RandomHorizontalFlipOperation;
|
||||
|
@ -90,6 +94,18 @@ std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1)
|
|||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
|
||||
|
||||
/// \brief Function to create a HwcToChw TensorOperation.
|
||||
/// \notes Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<HwcToChwOperation> HWC2CHW();
|
||||
|
||||
/// \brief Function to create a MixUpBatch TensorOperation.
|
||||
/// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and
|
||||
/// Batch must be called before calling this function.
|
||||
/// \param[in] alpha hyperparameter of beta distribution (default = 1.0)
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1);
|
||||
|
||||
/// \brief Function to create a Normalize TensorOperation.
|
||||
/// \notes Normalize the input image with respect to mean and standard deviation.
|
||||
/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order.
|
||||
|
@ -97,6 +113,12 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
|
|||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std);
|
||||
|
||||
/// \brief Function to create a OneHot TensorOperation.
|
||||
/// \notes Convert the labels into OneHot format.
|
||||
/// \param[in] num_classes number of classes.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
|
||||
|
||||
/// \brief Function to create a Pad TensorOp
|
||||
/// \notes Pads the image according to padding parameters
|
||||
/// \param[in] padding A vector representing the number of pixels to pad the image
|
||||
|
@ -119,6 +141,23 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
|
|||
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
|
||||
BorderType padding_mode = BorderType::kConstant);
|
||||
|
||||
/// \brief Function to create a RandomAffine TensorOperation.
|
||||
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
|
||||
/// \param[in] degrees A float vector size 2, representing the starting and ending degree
|
||||
/// \param[in] translate_range A float vector size 2, representing percentages of translation on x and y axes.
|
||||
/// \param[in] scale_range A float vector size 2, representing the starting and ending scales in the range.
|
||||
/// \param[in] shear_ranges A float vector size 4, representing the starting and ending shear degrees horizontally and
|
||||
/// vertically.
|
||||
/// \param[in] interpolation An enum for the mode of interpolation
|
||||
/// \param[in] fill_value A uint8_t vector size 3, representing the pixel intensity of the borders, it is used to
|
||||
/// fill R, G, B channels respectively.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomAffineOperation> RandomAffine(
|
||||
const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
|
||||
const std::vector<float_t> &scale_range = {1.0, 1.0}, const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
|
||||
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
||||
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||
|
||||
/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
|
||||
/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
|
||||
/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
|
||||
|
@ -148,8 +187,8 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
|
|||
/// fill R, G, B channels respectively.
|
||||
/// \return Shared pointer to the current TensorOperation.
|
||||
std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
|
||||
bool pad_if_needed = false,
|
||||
std::vector<uint8_t> fill_value = {0, 0, 0});
|
||||
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
|
||||
BorderType padding_mode = BorderType::kConstant);
|
||||
|
||||
/// \brief Function to create a RandomHorizontalFlip TensorOperation.
|
||||
/// \notes Tensor operation to perform random horizontal flip.
|
||||
|
@ -258,6 +297,29 @@ class DecodeOperation : public TensorOperation {
|
|||
bool rgb_;
|
||||
};
|
||||
|
||||
class HwcToChwOperation : public TensorOperation {
|
||||
public:
|
||||
~HwcToChwOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
};
|
||||
|
||||
class MixUpBatchOperation : public TensorOperation {
|
||||
public:
|
||||
explicit MixUpBatchOperation(float alpha = 1);
|
||||
|
||||
~MixUpBatchOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
class NormalizeOperation : public TensorOperation {
|
||||
public:
|
||||
NormalizeOperation(std::vector<float> mean, std::vector<float> std);
|
||||
|
@ -273,6 +335,20 @@ class NormalizeOperation : public TensorOperation {
|
|||
std::vector<float> std_;
|
||||
};
|
||||
|
||||
class OneHotOperation : public TensorOperation {
|
||||
public:
|
||||
explicit OneHotOperation(int32_t num_classes_);
|
||||
|
||||
~OneHotOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
float num_classes_;
|
||||
};
|
||||
|
||||
class PadOperation : public TensorOperation {
|
||||
public:
|
||||
PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
|
||||
|
@ -290,6 +366,29 @@ class PadOperation : public TensorOperation {
|
|||
BorderType padding_mode_;
|
||||
};
|
||||
|
||||
class RandomAffineOperation : public TensorOperation {
|
||||
public:
|
||||
RandomAffineOperation(const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
|
||||
const std::vector<float_t> &scale_range = {1.0, 1.0},
|
||||
const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
|
||||
InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
|
||||
const std::vector<uint8_t> &fill_value = {0, 0, 0});
|
||||
|
||||
~RandomAffineOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float_t> degrees_; // min_degree, max_degree
|
||||
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
|
||||
std::vector<float_t> scale_range_; // min_scale, max_scale
|
||||
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
|
||||
InterpolationMode interpolation_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
};
|
||||
|
||||
class RandomColorAdjustOperation : public TensorOperation {
|
||||
public:
|
||||
RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
|
||||
|
@ -311,7 +410,8 @@ class RandomColorAdjustOperation : public TensorOperation {
|
|||
class RandomCropOperation : public TensorOperation {
|
||||
public:
|
||||
RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
|
||||
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0});
|
||||
bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
|
||||
BorderType padding_mode = BorderType::kConstant);
|
||||
|
||||
~RandomCropOperation() = default;
|
||||
|
||||
|
@ -324,6 +424,7 @@ class RandomCropOperation : public TensorOperation {
|
|||
std::vector<int32_t> padding_;
|
||||
bool pad_if_needed_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
BorderType padding_mode_;
|
||||
};
|
||||
|
||||
class RandomHorizontalFlipOperation : public TensorOperation {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
|
@ -648,5 +649,30 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Split a tensor along its first dimension and append one CVTensor per slice to *output.
// Tensors of rank 0 or 1 are rejected, since there is nothing to unpack.
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
                                   std::vector<std::shared_ptr<CVTensor>> *output) {
  const std::vector<int64_t> dims = input->shape().AsVector();
  if (dims.size() <= 1) {
    RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack");
  }

  // Shape of a single slice: every dimension except the leading one.
  TensorShape slice_shape(std::vector<int64_t>(dims.begin() + 1, dims.end()));
  TensorShape remaining({-1});
  // Only the leading coordinate advances; all other coordinates stay at 0.
  std::vector<int64_t> position(dims.size(), 0);

  const int64_t num_slices = dims[0];
  for (int64_t slice_id = 0; slice_id < num_slices; ++slice_id) {
    position[0] = slice_id;

    uchar *slice_addr = nullptr;
    RETURN_IF_NOT_OK(input->StartAddrOfIndex(position, &slice_addr, &remaining));

    std::shared_ptr<Tensor> slice;
    RETURN_IF_NOT_OK(input->CreateFromMemory(slice_shape, input->type(), slice_addr, &slice));

    std::shared_ptr<CVTensor> cv_slice = CVTensor::AsCVTensor(std::move(slice));
    if (cv_slice->mat().data == nullptr) {
      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
    }
    output->push_back(cv_slice);
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -152,6 +152,17 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
|
||||
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
|
||||
std::shared_ptr<Tensor> append);
|
||||
|
||||
// helper for concat, always append to the input, and pass that to the output
|
||||
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
|
||||
std::shared_ptr<Tensor> append);
|
||||
|
||||
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors
|
||||
/// @param input[in] input tensor
|
||||
/// @param output[out] output tensor
|
||||
/// @return Status ok/error
|
||||
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
|
||||
std::vector<std::shared_ptr<CVTensor>> *output);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_subdirectory(soft_dvpp)
|
||||
add_library(kernels-image OBJECT
|
||||
affine_op.cc
|
||||
auto_contrast_op.cc
|
||||
center_crop_op.cc
|
||||
crop_op.cc
|
||||
|
@ -10,8 +12,11 @@ add_library(kernels-image OBJECT
|
|||
hwc_to_chw_op.cc
|
||||
image_utils.cc
|
||||
invert_op.cc
|
||||
math_utils.cc
|
||||
mixup_batch_op.cc
|
||||
normalize_op.cc
|
||||
pad_op.cc
|
||||
random_affine_op.cc
|
||||
random_color_adjust_op.cc
|
||||
random_crop_decode_resize_op.cc
|
||||
random_crop_and_resize_with_bbox_op.cc
|
||||
|
@ -34,3 +39,4 @@ add_library(kernels-image OBJECT
|
|||
resize_with_bbox_op.cc
|
||||
random_resize_with_bbox_op.cc
|
||||
)
|
||||
add_dependencies(kernels-image kernels-soft-dvpp-image)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/affine_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const InterpolationMode AffineOp::kDefInterpolation = InterpolationMode::kNearestNeighbour;
|
||||
const float_t AffineOp::kDegrees = 0.0;
|
||||
const std::vector<float_t> AffineOp::kTranslation = {0.0, 0.0};
|
||||
const float_t AffineOp::kScale = 1.0;
|
||||
const std::vector<float_t> AffineOp::kShear = {0.0, 0.0};
|
||||
const std::vector<uint8_t> AffineOp::kFillValue = {0, 0, 0};
|
||||
|
||||
AffineOp::AffineOp(float_t degrees, const std::vector<float_t> &translation, float_t scale,
|
||||
const std::vector<float_t> &shear, InterpolationMode interpolation,
|
||||
const std::vector<uint8_t> &fill_value)
|
||||
: degrees_(degrees),
|
||||
translation_(translation),
|
||||
scale_(scale),
|
||||
shear_(shear),
|
||||
interpolation_(interpolation),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
/// \brief Apply the configured rotation, translation, scale and shear to the input image.
/// \param[in] input Image tensor to transform
/// \param[out] output Transformed image, created with the same shape and type as the input
/// \return Status ok/error
Status AffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // The members below are indexed at fixed positions, so reject vectors that are too short
  // instead of reading out of bounds.
  if (translation_.size() < 2 || shear_.size() < 2 || fill_value_.size() < 3) {
    RETURN_STATUS_UNEXPECTED("AffineOp: translation and shear need 2 values each and fill_value needs 3 values");
  }
  float_t translation_x = translation_[0];
  float_t translation_y = translation_[1];
  float_t degrees = 0.0;
  DegreesToRadians(degrees_, &degrees);
  float_t shear_x = shear_[0];
  float_t shear_y = shear_[1];
  DegreesToRadians(shear_x, &shear_x);
  // The y shear is negated before conversion, so "degrees + shear_y" below is (a - sy) in the formula.
  DegreesToRadians(-1 * shear_y, &shear_y);
  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);

  // Apply Affine Transformation
  //       T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
  //       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
  //       RSS is rotation with scale and shear matrix
  //       RSS(a, s, (sx, sy)) =
  //       = R(a) * S(s) * SHy(sy) * SHx(sx)
  //       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
  //         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
  //         [ 0                    , 0                                        , 1 ]
  //
  // where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
  // SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
  //          [0, 1      ]              [-tan(s), 1]
  //
  // Thus, the affine matrix is M = T * C * RSS * C^-1

  float_t cx = ((input_cv->mat().cols - 1) / 2.0);
  float_t cy = ((input_cv->mat().rows - 1) / 2.0);
  // Calculate RSS
  std::vector<float_t> matrix{scale_ * cos(degrees + shear_y) / cos(shear_y),
                              scale_ * (-1 * cos(degrees + shear_y) * tan(shear_x) / cos(shear_y) - sin(degrees)),
                              0,
                              scale_ * sin(degrees + shear_y) / cos(shear_y),
                              scale_ * (-1 * sin(degrees + shear_y) * tan(shear_x) / cos(shear_y) + cos(degrees)),
                              0};
  // Compute T * C * RSS * C^-1
  matrix[2] = (1 - matrix[0]) * cx - matrix[1] * cy + translation_x;
  matrix[5] = (1 - matrix[4]) * cy - matrix[3] * cx + translation_y;

  try {
    // View the six coefficients as the 2x3 matrix expected by cv::warpAffine.
    cv::Mat affine_mat(matrix);
    affine_mat = affine_mat.reshape(1, {2, 3});

    std::shared_ptr<CVTensor> output_cv;
    RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
    RETURN_UNEXPECTED_IF_NULL(output_cv);
    cv::warpAffine(input_cv->mat(), output_cv->mat(), affine_mat, input_cv->mat().size(),
                   GetCVInterpolationMode(interpolation_), cv::BORDER_CONSTANT,
                   cv::Scalar(fill_value_[0], fill_value_[1], fill_value_[2]));
    (*output) = std::static_pointer_cast<Tensor>(output_cv);
  } catch (const cv::Exception &e) {
    // Report OpenCV failures through Status rather than letting the exception escape.
    RETURN_STATUS_UNEXPECTED("Error in affine transformation");
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class AffineOp : public TensorOp {
|
||||
public:
|
||||
/// Default values
|
||||
static const float_t kDegrees;
|
||||
static const std::vector<float_t> kTranslation;
|
||||
static const float_t kScale;
|
||||
static const std::vector<float_t> kShear;
|
||||
static const InterpolationMode kDefInterpolation;
|
||||
static const std::vector<uint8_t> kFillValue;
|
||||
|
||||
/// Constructor
|
||||
public:
|
||||
explicit AffineOp(float_t degrees, const std::vector<float_t> &translation = kTranslation, float_t scale = kScale,
|
||||
const std::vector<float_t> &shear = kShear, InterpolationMode interpolation = kDefInterpolation,
|
||||
const std::vector<uint8_t> &fill_value = kFillValue);
|
||||
|
||||
~AffineOp() override = default;
|
||||
|
||||
std::string Name() const override { return kAffineOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// Member variables
|
||||
private:
|
||||
std::string kAffineOp = "AffineOp";
|
||||
|
||||
protected:
|
||||
float_t degrees_;
|
||||
std::vector<float_t> translation_; // translation_x and translation_y
|
||||
float_t scale_;
|
||||
std::vector<float_t> shear_; // shear_x and shear_y
|
||||
InterpolationMode interpolation_;
|
||||
std::vector<uint8_t> fill_value_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_AFFINE_OP_H_
|
|
@ -21,6 +21,7 @@
|
|||
#include <utility>
|
||||
#include <opencv2/imgcodecs.hpp>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
|
@ -631,36 +632,9 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
hist.col(0).copyTo(hist_vec);
|
||||
// Ignore values in ignore
|
||||
for (const auto &item : ignore) hist_vec[item] = 0;
|
||||
int32_t n = std::accumulate(hist_vec.begin(), hist_vec.end(), 0);
|
||||
// Find pixel values that are in the low cutoff and high cutoff.
|
||||
int32_t cut = static_cast<int32_t>((cutoff / 100.0) * n);
|
||||
if (cut != 0) {
|
||||
for (int32_t lo = 0; lo < 256 && cut > 0; lo++) {
|
||||
if (cut > hist_vec[lo]) {
|
||||
cut -= hist_vec[lo];
|
||||
hist_vec[lo] = 0;
|
||||
} else {
|
||||
hist_vec[lo] -= cut;
|
||||
cut = 0;
|
||||
}
|
||||
}
|
||||
cut = static_cast<int32_t>((cutoff / 100.0) * n);
|
||||
for (int32_t hi = 255; hi >= 0 && cut > 0; hi--) {
|
||||
if (cut > hist_vec[hi]) {
|
||||
cut -= hist_vec[hi];
|
||||
hist_vec[hi] = 0;
|
||||
} else {
|
||||
hist_vec[hi] -= cut;
|
||||
cut = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t lo = 0;
|
||||
int32_t hi = 255;
|
||||
for (; lo < 256 && !hist_vec[lo]; lo++) {
|
||||
}
|
||||
for (; hi >= 0 && !hist_vec[hi]; hi--) {
|
||||
}
|
||||
int32_t lo = 0;
|
||||
RETURN_IF_NOT_OK(ComputeUpperAndLowerPercentiles(&hist_vec, cutoff, cutoff, &hi, &lo));
|
||||
if (hi <= lo) {
|
||||
for (int32_t i = 0; i < 256; i++) {
|
||||
table.push_back(i);
|
||||
|
@ -685,7 +659,6 @@ Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
std::shared_ptr<CVTensor> output_cv;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(result, &output_cv));
|
||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||
(*output) = std::static_pointer_cast<Tensor>(output_cv);
|
||||
(*output)->Reshape(input->shape());
|
||||
} catch (const cv::Exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Error in auto contrast");
|
||||
|
@ -983,5 +956,24 @@ Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads the header of an encoded JPEG to obtain the image dimensions.
// @param input: tensor whose buffer holds the encoded JPEG bytes
// @param img_width: receives cinfo.output_width as computed by libjpeg
// @param img_height: receives cinfo.output_height as computed by libjpeg
// @return Status::OK() on success; an unexpected-error status carrying the exception
//     message when a std::runtime_error is raised while reading the header
Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height) {
  struct jpeg_decompress_struct decompressor {};
  struct JpegErrorManagerCustom error_manager {};
  decompressor.err = jpeg_std_error(&error_manager.pub);
  // Replace the standard fatal-error hook with the project's custom one.
  error_manager.pub.error_exit = JpegErrorExitCustom;
  try {
    jpeg_create_decompress(&decompressor);
    JpegSetSource(&decompressor, input->GetBuffer(), input->SizeInBytes());
    (void)jpeg_read_header(&decompressor, TRUE);
    jpeg_calc_output_dimensions(&decompressor);
  } catch (std::runtime_error &e) {
    // Release libjpeg state before reporting the failure.
    jpeg_destroy_decompress(&decompressor);
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  // Capture the dimensions first, then release the decompressor.
  const auto decoded_height = decompressor.output_height;
  const auto decoded_width = decompressor.output_width;
  jpeg_destroy_decompress(&decompressor);
  *img_height = decoded_height;
  *img_width = decoded_width;
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -268,6 +268,12 @@ Status PadBBoxes(const std::shared_ptr<Tensor> *bboxList, const size_t &bboxCoun
|
|||
Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
|
||||
int32_t target_height_, int orig_width, int orig_height);
|
||||
|
||||
// Get jpeg image width and height
|
||||
// @param input: Tensor containing the undecoded (still encoded) JPEG image as 1D bytes
|
||||
// @param img_width: the jpeg image width
|
||||
// @param img_height: the jpeg image height
|
||||
Status GetJpegImageInfo(const std::shared_ptr<Tensor> &input, int *img_width, int *img_height);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
|
||||
#include <opencv2/imgproc/types_c.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Trims the requested percentages of the total count off both ends of a histogram
///     and reports the outermost bins that still hold counts.
/// \param[in,out] hist Histogram bin counts; bins consumed by the trimming are reduced or zeroed in place.
/// \param[in] hi_p Percentage of the total count to remove starting from the last bin.
/// \param[in] low_p Percentage of the total count to remove starting from the first bin.
/// \param[out] hi Index of the highest bin with a non-zero count after trimming (-1 if every bin is zero).
/// \param[out] lo Index of the lowest bin with a non-zero count after trimming; the scan stops at the
///     last bin index if no non-zero bin is found before it.
/// \return Status::OK() on success, an unexpected-error status for null arguments or a caught exception.
Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p, int32_t low_p, int32_t *hi,
                                       int32_t *lo) {
  if (hist == nullptr || hi == nullptr || lo == nullptr) {
    RETURN_STATUS_UNEXPECTED("Error in ComputeUpperAndLowerPercentiles: null pointer argument");
  }
  try {
    // Signed bin count keeps all index comparisons below signed-vs-signed.
    const int32_t num_bins = static_cast<int32_t>(hist->size());
    int32_t n = std::accumulate(hist->begin(), hist->end(), 0);

    // Remove `low_p` percent of the samples from the low end.
    int32_t cut = static_cast<int32_t>((low_p / 100.0) * n);
    // The bound must be num_bins (not num_bins + 1): indexing bin `num_bins` is out of range and
    // would be reached whenever the cut is larger than the total count (e.g. low_p > 100).
    for (int32_t lb = 0; lb < num_bins && cut > 0; lb++) {
      if (cut > (*hist)[lb]) {
        cut -= (*hist)[lb];
        (*hist)[lb] = 0;
      } else {
        (*hist)[lb] -= cut;
        cut = 0;
      }
    }

    // Remove `hi_p` percent of the samples from the high end.
    cut = static_cast<int32_t>((hi_p / 100.0) * n);
    for (int32_t ub = num_bins - 1; ub >= 0 && cut > 0; ub--) {
      if (cut > (*hist)[ub]) {
        cut -= (*hist)[ub];
        (*hist)[ub] = 0;
      } else {
        (*hist)[ub] -= cut;
        cut = 0;
      }
    }

    // Locate the first and last bins that still contain samples.
    *lo = 0;
    *hi = num_bins - 1;
    for (; (*lo) < (*hi) && !(*hist)[*lo]; (*lo)++) {
    }
    for (; (*hi) >= 0 && !(*hist)[*hi]; (*hi)--) {
    }
  } catch (const std::exception &e) {
    const char *err_msg = e.what();
    std::string err_message = "Error in ComputeUpperAndLowerPercentiles: ";
    err_message += err_msg;
    RETURN_STATUS_UNEXPECTED(err_message);
  }
  return Status::OK();
}
|
||||
|
||||
/// Converts an angle given in degrees to radians and stores it in *radians_target.
/// Always returns Status::OK().
Status DegreesToRadians(float_t degrees, float_t *radians_target) {
  const auto radians = CV_PI * degrees / 180.0;
  *radians_target = radians;
  return Status::OK();
}
|
||||
|
||||
/// Draws one value from a uniform real distribution constructed with bounds (a, b), using the
/// supplied engine, and writes it to *result. Any std::exception raised while doing so is
/// converted into an unexpected-error status prefixed with "Error in GenerateRealNumber: ".
Status GenerateRealNumber(float_t a, float_t b, std::mt19937 *rnd, float_t *result) {
  try {
    std::uniform_real_distribution<float_t> uniform(a, b);
    const float_t sample = uniform(*rnd);
    *result = sample;
  } catch (const std::exception &e) {
    std::string err_message("Error in GenerateRealNumber: ");
    err_message.append(e.what());
    RETURN_STATUS_UNEXPECTED(err_message);
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \brief Returns lower and upper pth percentiles of the input histogram.
|
||||
/// \param[in,out] hist: Input histogram (mutated in place during the computation)
|
||||
/// \param[in] hi_p: Right side percentile
|
||||
/// \param[in] low_p: Left side percentile
|
||||
/// \param[out] hi: Value at high end percentile
|
||||
/// \param[out] lo: Value at low end percentile
|
||||
Status ComputeUpperAndLowerPercentiles(std::vector<int32_t> *hist, int32_t hi_p, int32_t low_p, int32_t *hi,
|
||||
int32_t *lo);
|
||||
|
||||
/// \brief Converts degrees input to radians.
|
||||
/// \param[in] degrees: Input degrees
|
||||
/// \param[out] radians_target: Radians output
|
||||
Status DegreesToRadians(float_t degrees, float_t *radians_target);
|
||||
|
||||
/// \brief Generates a random real number in [a,b).
|
||||
/// \param[in] a: Start of range
|
||||
/// \param[in] b: End of range
|
||||
/// \param[in] rnd: Random device
|
||||
/// \param[out] result: Random number in range [a,b)
|
||||
Status GenerateRealNumber(float_t a, float_t b, std::mt19937 *rnd, float_t *result);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MATH_UTILS_H_
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Stores alpha and seeds the op's random engine with the value returned by GetSeed().
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) {
  const auto seed = GetSeed();
  rnd_.seed(seed);
}
|
||||
|
||||
// Blends every image (and its label row) of a batch with another, randomly chosen image of the
// same batch: out = lam * original + (1 - lam) * partner.
// Column 0 of `input` must be a 4-D image batch and column 1 a 2-D label batch with the same
// leading dimension. On success the mixed images and the float32 labels are appended to `output`.
Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
  if (input.size() < 2) {
    RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
  }

  const auto &image_batch = input.at(0);
  const auto &label_batch = input.at(1);
  std::vector<int64_t> image_shape = image_batch->shape().AsVector();
  std::vector<int64_t> label_shape = label_batch->shape().AsVector();

  // Validate the ranks and that both columns describe the same number of samples.
  const bool is_batched =
    label_shape.size() == 2 && image_shape.size() == 4 && image_shape[0] == label_shape[0];
  if (!is_batched) {
    RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
  }

  // Either dimension 1 (CHW) or dimension 3 (HWC) has to look like a channel count of 1 or 3.
  const bool channel_first = image_shape[1] == 1 || image_shape[1] == 3;
  const bool channel_last = image_shape[3] == 1 || image_shape[3] == 3;
  if (!channel_first && !channel_last) {
    RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
  }

  // Split the batch into one CVTensor per image.
  std::vector<std::shared_ptr<CVTensor>> images;
  RETURN_IF_NOT_OK(BatchTensorToCVTensorVector(image_batch, &images));

  // Draw the mixing factor:
  // if x1 ~ Gamma(a1, 1) and x2 ~ Gamma(a2, 1), then x1 / (x1 + x2) ~ Beta(a1, a2).
  std::gamma_distribution<float> gamma(alpha_, 1);
  float x1 = gamma(rnd_);
  float x2 = gamma(rnd_);
  float lam = x1 / (x1 + x2);

  // Random permutation that assigns a mixing partner to every sample.
  std::vector<int64_t> rand_indx;
  for (int64_t idx = 0; idx < images.size(); ++idx) {
    rand_indx.push_back(idx);
  }
  std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);

  // Mix the labels into a float32 copy of the label tensor.
  std::shared_ptr<Tensor> out_labels;
  RETURN_IF_NOT_OK(TypeCast(label_batch, &out_labels, DataType("float32")));
  for (int64_t row = 0; row < label_shape[0]; ++row) {
    for (int64_t col = 0; col < label_shape[1]; ++col) {
      uint64_t own_value, partner_value;
      RETURN_IF_NOT_OK(label_batch->GetItemAt(&own_value, {row, col}));
      RETURN_IF_NOT_OK(label_batch->GetItemAt(&partner_value, {rand_indx[row], col}));
      RETURN_IF_NOT_OK(out_labels->SetItemAt({row, col}, lam * own_value + (1 - lam) * partner_value));
    }
  }

  // Mix the images: the partner image is viewed directly inside the input batch memory.
  for (int64_t idx = 0; idx < images.size(); ++idx) {
    TensorShape remaining({-1});
    uchar *partner_addr = nullptr;
    std::shared_ptr<Tensor> partner;
    RETURN_IF_NOT_OK(image_batch->StartAddrOfIndex({rand_indx[idx], 0, 0, 0}, &partner_addr, &remaining));
    RETURN_IF_NOT_OK(image_batch->CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
                                                   image_batch->type(), partner_addr, &partner));
    std::shared_ptr<CVTensor> rand_image = CVTensor::AsCVTensor(std::move(partner));
    if (!rand_image->mat().data) {
      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
    }
    images[idx]->mat() = lam * images[idx]->mat() + (1 - lam) * rand_image->mat();
  }

  // Reassemble the mixed images into one tensor shaped like the input batch.
  std::shared_ptr<Tensor> output_image;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(image_batch->shape(), image_batch->type(), &output_image));
  for (int64_t idx = 0; idx < images.size(); ++idx) {
    RETURN_IF_NOT_OK(output_image->InsertTensor({idx}, images[idx]));
  }
  output->push_back(output_image);
  output->push_back(out_labels);

  return Status::OK();
}
|
||||
|
||||
// Writes the op name followed by its alpha value and a trailing newline to `out`.
void MixUpBatchOp::Print(std::ostream &out) const {
  out << "MixUpBatchOp: ";
  out << "alpha: " << alpha_;
  out << "\n";
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class MixUpBatchOp : public TensorOp {
|
||||
public:
|
||||
// Default values, also used by python_bindings.cc
|
||||
|
||||
explicit MixUpBatchOp(float alpha);
|
||||
|
||||
~MixUpBatchOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
std::string Name() const override { return kMixUpBatchOp; }
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/image/random_affine_op.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/kernels/image/math_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
const std::vector<float_t> RandomAffineOp::kDegreesRange = {0.0, 0.0};
|
||||
const std::vector<float_t> RandomAffineOp::kTranslationPercentages = {0.0, 0.0};
|
||||
const std::vector<float_t> RandomAffineOp::kScaleRange = {1.0, 1.0};
|
||||
const std::vector<float_t> RandomAffineOp::kShearRanges = {0.0, 0.0, 0.0, 0.0};
|
||||
const InterpolationMode RandomAffineOp::kDefInterpolation = InterpolationMode::kNearestNeighbour;
|
||||
const std::vector<uint8_t> RandomAffineOp::kFillValue = {0, 0, 0};
|
||||
|
||||
// Stores the sampling ranges, forwards interpolation mode and fill value to the inherited
// AffineOp members, and seeds the random engine with GetSeed(). The base class is constructed
// with 0.0; the concrete transform parameters are assigned on every Compute() call.
RandomAffineOp::RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t> translate_range,
                               std::vector<float_t> scale_range, std::vector<float_t> shear_ranges,
                               InterpolationMode interpolation, std::vector<uint8_t> fill_value)
    : AffineOp(0.0),
      degrees_range_(degrees),
      translate_range_(translate_range),
      scale_range_(scale_range),
      shear_ranges_(shear_ranges) {
  rnd_.seed(GetSeed());
  // Members inherited from AffineOp cannot appear in the initializer list above.
  fill_value_ = fill_value;
  interpolation_ = interpolation;
}
|
||||
|
||||
// Samples a random rotation, translation, scale and shear from the configured ranges, stores them
// in the inherited AffineOp parameters and then delegates the actual warp to AffineOp::Compute.
// @param input: image tensor; dimension 0 is read as height and dimension 1 as width
// @param output: receives the transformed image
// @return the status of the first failing step, otherwise the status of AffineOp::Compute
Status RandomAffineOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  dsize_t height = input->shape()[0];
  dsize_t width = input->shape()[1];
  // translate_range_ holds {max x fraction, max y fraction}: the horizontal bound therefore scales
  // with the image width and the vertical bound with the image height.
  float_t max_dx = translate_range_[0] * width;
  float_t max_dy = translate_range_[1] * height;
  float_t degrees = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(degrees_range_[0], degrees_range_[1], &rnd_, &degrees));
  float_t translation_x = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(-1 * max_dx, max_dx, &rnd_, &translation_x));
  float_t translation_y = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(-1 * max_dy, max_dy, &rnd_, &translation_y));
  float_t scale = 1.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(scale_range_[0], scale_range_[1], &rnd_, &scale));
  float_t shear_x = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(shear_ranges_[0], shear_ranges_[1], &rnd_, &shear_x));
  float_t shear_y = 0.0;
  RETURN_IF_NOT_OK(GenerateRealNumber(shear_ranges_[2], shear_ranges_[3], &rnd_, &shear_y));
  // assign to base class variables
  degrees_ = degrees;
  scale_ = scale;
  translation_[0] = translation_x;
  translation_[1] = translation_y;
  shear_[0] = shear_x;
  shear_[1] = shear_y;
  return AffineOp::Compute(input, output);
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/image/affine_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class RandomAffineOp : public AffineOp {
|
||||
public:
|
||||
/// Default values, also used by python_bindings.cc
|
||||
static const std::vector<float_t> kDegreesRange;
|
||||
static const std::vector<float_t> kTranslationPercentages;
|
||||
static const std::vector<float_t> kScaleRange;
|
||||
static const std::vector<float_t> kShearRanges;
|
||||
static const InterpolationMode kDefInterpolation;
|
||||
static const std::vector<uint8_t> kFillValue;
|
||||
|
||||
explicit RandomAffineOp(std::vector<float_t> degrees, std::vector<float_t> translate_range = kTranslationPercentages,
|
||||
std::vector<float_t> scale_range = kScaleRange,
|
||||
std::vector<float_t> shear_ranges = kShearRanges,
|
||||
InterpolationMode interpolation = kDefInterpolation,
|
||||
std::vector<uint8_t> fill_value = kFillValue);
|
||||
|
||||
~RandomAffineOp() override = default;
|
||||
|
||||
std::string Name() const override { return kRandomAffineOp; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
private:
|
||||
std::string kRandomAffineOp = "RandomAffineOp";
|
||||
std::vector<float_t> degrees_range_; // min_degree, max_degree
|
||||
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
|
||||
std::vector<float_t> scale_range_; // min_scale, max_scale
|
||||
std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
|
||||
std::mt19937 rnd_; // random device
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RANDOM_AFFINE_OP_H_
|
|
@ -37,22 +37,9 @@ Status RandomCropDecodeResizeOp::Compute(const std::shared_ptr<Tensor> &input, s
|
|||
RETURN_IF_NOT_OK(op.Compute(input, &decoded));
|
||||
return RandomCropAndResizeOp::Compute(decoded, output);
|
||||
} else {
|
||||
struct jpeg_decompress_struct cinfo {};
|
||||
struct JpegErrorManagerCustom jerr {};
|
||||
cinfo.err = jpeg_std_error(&jerr.pub);
|
||||
jerr.pub.error_exit = JpegErrorExitCustom;
|
||||
try {
|
||||
jpeg_create_decompress(&cinfo);
|
||||
JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes());
|
||||
(void)jpeg_read_header(&cinfo, TRUE);
|
||||
jpeg_calc_output_dimensions(&cinfo);
|
||||
} catch (std::runtime_error &e) {
|
||||
jpeg_destroy_decompress(&cinfo);
|
||||
RETURN_STATUS_UNEXPECTED(e.what());
|
||||
}
|
||||
int h_in = cinfo.output_height;
|
||||
int w_in = cinfo.output_width;
|
||||
jpeg_destroy_decompress(&cinfo);
|
||||
int h_in = 0;
|
||||
int w_in = 0;
|
||||
RETURN_IF_NOT_OK(GetJpegImageInfo(input, &w_in, &h_in));
|
||||
|
||||
int x = 0;
|
||||
int y = 0;
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Tag every .cc file below this directory with the minddata sub-module id.
file(GLOB_RECURSE _CURRENT_SRC_FILES
     RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "*.cc")
set_property(
    SOURCE ${_CURRENT_SRC_FILES}
    PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# The utils subdirectory provides its own CMakeLists.txt.
add_subdirectory(utils)

# Object library holding the two soft-dvpp image ops in this directory.
add_library(kernels-soft-dvpp-image OBJECT
    soft_dvpp_decode_resize_jpeg_op.cc
    soft_dvpp_decode_random_crop_resize_jpeg_op.cc)
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h"
|
||||
#include <string>
|
||||
|
||||
#include "opencv2/opencv.hpp"
|
||||
|
||||
#include "minddata/dataset/core/cv_tensor.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Forwards all arguments to RandomCropAndResizeOp and fixes the interpolation mode to kLinear.
SoftDvppDecodeRandomCropResizeJpegOp::SoftDvppDecodeRandomCropResizeJpegOp(int32_t target_height,
                                                                           int32_t target_width, float scale_lb,
                                                                           float scale_ub, float aspect_lb,
                                                                           float aspect_ub, int32_t max_iter)
    : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub,
                            InterpolationMode::kLinear, max_iter) {
}
|
||||
|
||||
// Reads the JPEG dimensions of `input`, asks GetCropBox for a crop rectangle and stores it in
// `crop_info` as left/up/right/down edges.
Status SoftDvppDecodeRandomCropResizeJpegOp::GetCropInfo(const std::shared_ptr<Tensor> &input,
                                                         SoftDpCropInfo *crop_info) {
  int img_width = 0;
  int img_height = 0;
  RETURN_IF_NOT_OK(GetJpegImageInfo(input, &img_width, &img_height));

  int left = 0;
  int top = 0;
  int crop_height = 0;
  int crop_width = 0;
  RETURN_IF_NOT_OK(GetCropBox(img_height, img_width, &left, &top, &crop_height, &crop_width));

  // right/down are derived from the already stored left/up values plus the crop extent.
  crop_info->left = left;
  crop_info->up = top;
  crop_info->right = crop_info->left + crop_width;
  crop_info->down = crop_info->up + crop_height;
  return Status::OK();
}
|
||||
|
||||
// Decodes a JPEG, crops a randomly chosen region and resizes it to target_height_ x target_width_
// in a single DecodeAndCropAndResizeJpeg call.
// @param input: tensor holding the encoded JPEG bytes (anything else is rejected)
// @param output: receives a CVTensor created from the 3-channel 8-bit result image
// @return Status::OK() on success, an unexpected-error status otherwise
Status SoftDvppDecodeRandomCropResizeJpegOp::Compute(const std::shared_ptr<Tensor> &input,
                                                     std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  if (!IsNonEmptyJPEG(input)) {
    RETURN_STATUS_UNEXPECTED("SoftDvppDecodeRandomCropResizeJpeg only support process jpeg image.");
  }
  SoftDpCropInfo crop_info;
  RETURN_IF_NOT_OK(GetCropInfo(input, &crop_info));
  try {
    // The soft-dvpp info struct takes a non-const input pointer, hence the const_cast.
    unsigned char *buffer = const_cast<unsigned char *>(input->GetBuffer());
    CHECK_FAIL_RETURN_UNEXPECTED(buffer != nullptr, "The input image buffer is empty.");
    SoftDpProcsessInfo info;
    info.input_buffer = static_cast<uint8_t *>(buffer);
    info.input_buffer_size = input->SizeInBytes();
    info.output_width = target_width_;
    info.output_height = target_height_;
    // The result is written straight into the pixel buffer of this Mat.
    cv::Mat out_rgb_img(target_height_, target_width_, CV_8UC3);
    info.output_buffer = out_rgb_img.data;
    info.output_buffer_size = target_width_ * target_height_ * 3;
    info.is_v_before_u = true;
    int ret = DecodeAndCropAndResizeJpeg(&info, crop_info);
    // Name the function that was actually called, and only build the message when it failed.
    if (ret != 0) {
      std::string error_info("Soft dvpp DecodeAndCropAndResizeJpeg failed with return code: ");
      error_info += std::to_string(ret);
      RETURN_STATUS_UNEXPECTED(error_info);
    }
    std::shared_ptr<CVTensor> cv_tensor = nullptr;
    RETURN_IF_NOT_OK(CVTensor::CreateFromMat(out_rgb_img, &cv_tensor));
    *output = std::static_pointer_cast<Tensor>(cv_tensor);
  } catch (const cv::Exception &e) {
    RETURN_STATUS_UNEXPECTED("Error in soft dvpp image decode and resize.");
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue