adapt code for msvc

This commit is contained in:
taipingchangan 2022-09-03 09:49:25 +08:00
parent 644b98c101
commit 31ac782bf7
31 changed files with 153 additions and 57 deletions

View File

@ -13,23 +13,25 @@ if(MSVC)
cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES)
message("CPU_CORE number = ${CPU_CORES}")
math(EXPR MP_NUM "${CPU_CORES} * 2")
set(CMAKE_C_FLAGS "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_C_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1 /MP${MP_NUM} /EHsc")
set(CMAKE_C_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_C_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_C_FLAGS "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_C_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1 /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_C_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_C_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_CXX_FLAGS "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_CXX_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1 /MP${MP_NUM} /EHsc")
set(CMAKE_CXX_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_CXX_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc")
set(CMAKE_CXX_FLAGS "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_CXX_FLAGS_DEBUG "/MDd /Zi /Ob0 /Od /RTC1 /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_CXX_FLAGS_RELEASE "/MD /O2 /Ob2 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "/MD /Zi /O2 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
set(CMAKE_CXX_FLAGS_MINSIZEREL "/MD /O1 /Ob1 /DNDEBUG /MP${MP_NUM} /EHsc /bigobj")
# resolve std::min/std::max and opencv::min opencv:max had defined in windows.h
add_definitions(-DNOMINMAX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4251 /wd4819 /wd4715 /wd4244 /wd4267 /wd4716")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4251 /wd4819 /wd4715 /wd4244 /wd4267 /wd4716")
# resolve ERROR had defined in windows.h
add_definitions(-DNOGDI)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4251 /wd4819 /wd4715 /wd4244 /wd4267 /wd4716 /wd4566 /wd4273")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4251 /wd4819 /wd4715 /wd4244 /wd4267 /wd4716 /wd4566 /wd4273")
if(ENABLE_GPU)
message("init cxx_flags on windows_gpu")

View File

@ -58,7 +58,9 @@ endif()
if(ENABLE_GPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cub.cmake)
if(NOT MSVC)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/fast_transformers.cmake)
endif()
if(ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
endif()

View File

@ -115,10 +115,18 @@ if(ENABLE_GPU)
if(DEFINED ENV{CUDNN_HOME} AND NOT $ENV{CUDNN_HOME} STREQUAL "")
set(CUDNN_INCLUDE_DIR $ENV{CUDNN_HOME}/include)
if(WIN32)
set(CUDNN_LIBRARY_DIR $ENV{CUDNN_HOME}/lib)
else()
set(CUDNN_LIBRARY_DIR $ENV{CUDNN_HOME}/lib64)
endif()
find_path(CUDNN_INCLUDE_PATH cudnn.h HINTS ${CUDNN_INCLUDE_DIR} NO_DEFAULT_PATH)
find_library(CUDNN_LIBRARY_PATH "cudnn" HINTS ${CUDNN_LIBRARY_DIR} NO_DEFAULT_PATH)
if(WIN32)
find_library(CUBLAS_LIBRARY_PATH "cublas" HINTS ${CUDA_PATH}/lib/x64)
else()
find_library(CUBLAS_LIBRARY_PATH "cublas" HINTS ${CUDNN_LIBRARY_DIR})
endif()
if(CUDNN_INCLUDE_PATH STREQUAL CUDNN_INCLUDE_PATH-NOTFOUND)
message(FATAL_ERROR "Failed to find cudnn header file, please set environment variable CUDNN_HOME to \
cudnn installation position.")
@ -162,7 +170,9 @@ if(ENABLE_GPU)
## set NVCC ARCH FLAG
set(CUDA_NVCC_FLAGS)
set_nvcc_flag(CUDA_NVCC_FLAGS)
if(NOT MSVC)
add_definitions(-Wno-unknown-pragmas) # Avoid compilation warnings from cuda/thrust
endif()
if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
list(APPEND CUDA_NVCC_FLAGS -G)
message("CUDA_NVCC_FLAGS" ${CUDA_NVCC_FLAGS})
@ -203,7 +213,9 @@ if(CMAKE_SYSTEM_NAME MATCHES "Windows" AND NOT MSVC)
endif()
# Set compile flags to ensure float compute consistency.
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-fast-math")
endif()
if(ENABLE_MPI)
add_compile_definitions(ENABLE_MPI)
@ -453,8 +465,12 @@ endif()
if(MODE_ASCEND_ALL)
target_link_libraries(mindspore PUBLIC -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore PUBLIC -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece
-Wl,--end-group)
if(MSVC)
target_link_libraries(mindspore PUBLIC proto_input mindspore::protobuf mindspore::sentencepiece)
else()
target_link_libraries(mindspore PUBLIC -Wl,--start-group proto_input mindspore::protobuf
mindspore::sentencepiece -Wl,--end-group)
endif()
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore PUBLIC -Wl proto_input mindspore::protobuf mindspore::sentencepiece -Wl)
else()
@ -495,8 +511,12 @@ endif()
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore PUBLIC mindspore::pybind11_module)
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core
mindspore_common mindspore_backend)
if(NOT MSVC)
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core
mindspore_common mindspore_backend)
else()
target_link_libraries(_c_expression PRIVATE mindspore_core mindspore_common mindspore_backend mindspore)
endif()
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore PUBLIC proto_input mindspore::protobuf mindspore::eigen mindspore::json)
target_link_libraries(_c_expression PRIVATE -Wl,-all_load mindspore proto_input -Wl,-noall_load mindspore_core

View File

@ -1414,18 +1414,19 @@ BackendOpRunInfoPtr SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, cons
[cnode](const std::pair<KernelWithIndex, std::vector<std::vector<size_t>>> &output_index) {
return output_index.first.first == cnode;
});
pynative::BaseOpRunInfo base_op_run_info = {.has_dynamic_input = common::AnfAlgo::IsNodeInputDynamicShape(cnode),
.has_dynamic_output = shape->IsDynamic(),
.is_mixed_precision_cast = false,
.lazy_build = !shape->IsDynamic(),
.op_name = primitive->name(),
.next_op_name = std::string(),
.graph_info = graph_info,
.device_target = GetOpRunDeviceTarget(primitive),
.next_input_index = 0,
.input_tensor = tensor_info.input_tensors,
.input_mask = tensor_info.input_tensors_mask,
.abstract = abstract};
pynative::BaseOpRunInfo base_op_run_info;
base_op_run_info.has_dynamic_input = common::AnfAlgo::IsNodeInputDynamicShape(cnode);
base_op_run_info.has_dynamic_output = shape->IsDynamic();
base_op_run_info.is_mixed_precision_cast = false;
base_op_run_info.lazy_build = !shape->IsDynamic();
base_op_run_info.op_name = primitive->name();
base_op_run_info.next_op_name = std::string();
base_op_run_info.graph_info = graph_info;
base_op_run_info.device_target = GetOpRunDeviceTarget(primitive);
base_op_run_info.next_input_index = 0;
base_op_run_info.input_tensor = tensor_info.input_tensors;
base_op_run_info.input_mask = tensor_info.input_tensors_mask;
base_op_run_info.abstract = abstract;
return std::make_shared<BackendOpRunInfo>(base_op_run_info, primitive.get(), false, is_gradient_out);
}

View File

@ -174,6 +174,16 @@ if(ENABLE_D)
endif()
if(ENABLE_GPU)
if(WIN32)
target_link_libraries(mindspore_shared_lib PRIVATE cuda_ops
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/curand.lib
${CUDNN_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/cudart.lib
${CUDA_PATH}/lib/x64/cuda.lib
${CUDA_PATH}/lib/x64/cusolver.lib
${CUDA_PATH}/lib/x64/cufft.lib)
else()
target_link_libraries(mindspore_shared_lib PRIVATE cuda_ops
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib64/libcurand.so
@ -182,6 +192,7 @@ if(ENABLE_GPU)
${CUDA_PATH}/lib64/stubs/libcuda.so
${CUDA_PATH}/lib64/libcusolver.so
${CUDA_PATH}/lib64/libcufft.so)
endif()
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Linux")

View File

@ -15,14 +15,12 @@
*/
#include "distributed/persistent/storage/file_io_utils.h"
#include <dirent.h>
#include <unistd.h>
#include <fstream>
#include "mindspore/core/utils/file_utils.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "utils/os.h"
namespace mindspore {
namespace distributed {
@ -127,13 +125,16 @@ bool FileIOUtils::IsFileOrDirExist(const std::string &path) {
}
void FileIOUtils::CreateFile(const std::string &file_path, mode_t mode) {
(void)mode;
if (IsFileOrDirExist(file_path)) {
return;
}
std::ofstream output_file(file_path);
output_file.close();
#ifndef _MSC_VER
ChangeFileMode(file_path, mode);
#endif
}
void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
@ -142,7 +143,11 @@ void FileIOUtils::CreateDir(const std::string &dir_path, mode_t mode) {
}
#if defined(_WIN32) || defined(_WIN64)
#ifndef _MSC_VER
int ret = mkdir(dir_path.c_str());
#else
int ret = _mkdir(dir_path.c_str());
#endif
#else
int ret = mkdir(dir_path.c_str(), mode);
if (ret == 0) {
@ -173,7 +178,11 @@ void FileIOUtils::CreateDirRecursive(const std::string &dir_path, mode_t mode) {
}
#if defined(_WIN32) || defined(_WIN64)
#ifndef _MSC_VER
int32_t ret = mkdir(tmp_dir_path);
#else
int32_t ret = _mkdir(tmp_dir_path);
#endif
if (ret != 0) {
MS_LOG(EXCEPTION) << "Failed to create directory recursion: " << dir_path << ". Errno = " << errno;
}

View File

@ -22,6 +22,10 @@
#include <vector>
#include <string>
#include <utility>
#include "utils/os.h"
#ifdef CreateFile
#undef CreateFile
#endif
namespace mindspore {
namespace distributed {

View File

@ -216,7 +216,7 @@ void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_si
for (int64_t i = 0; i < host_size; ++i) {
host_range.emplace_back(static_cast<T>(i));
}
#if defined(__APPLE__)
#if defined(__APPLE__) || defined(_MSC_VER)
std::random_device rd;
std::mt19937 rng(rd());
std::shuffle(host_range.begin(), host_range.end(), rng);

View File

@ -366,7 +366,7 @@ static std::set<CNodePtr> SetParameterLayout(const FuncGraphPtr &root, const Fun
PrimitivePtr prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
if (!attrs.contains(parallel::IN_STRATEGY)) {
if (attrs.count(parallel::IN_STRATEGY) == 0) {
auto empty_strategies = GenerateEmptyStrategies(cnode);
attrs[parallel::IN_STRATEGY] = ShapesToValueTuplePtr(empty_strategies);
}

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_BLOCKING_QUEUE_H
#define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_BLOCKING_QUEUE_H
#include <unistd.h>
#include <iostream>
#include <memory>
#include <mutex>

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H
#define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H
#include <unistd.h>
#include <string>
#include <memory>
#include <vector>

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_MGR_H
#define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_MGR_H
#include <unistd.h>
#include <iostream>
#include <functional>
#include <map>

View File

@ -24,6 +24,7 @@
#include <functional>
#include "utils/log_adapter.h"
#include "utils/os.h"
#include "include/common/visible.h"
#define DP_DEBUG MS_LOG(DEBUG) << "[DuplexPipe] "

View File

@ -20,6 +20,7 @@
#include <memory>
#include <map>
#include <set>
#include <optional>
#include "nlohmann/json.hpp"
#include "ir/anf.h"
#include "ir/dtype.h"

View File

@ -47,7 +47,8 @@ class FlangerOp : public TensorOp {
void Print(std::ostream &out) const override {
out << Name() << ": sample_rate: " << sample_rate_ << ", delay:" << delay_ << ", depth: " << depth_
<< ", regen: " << regen_ << ", width: " << width_ << ", speed: " << speed_ << ", phase: " << phase_
<< ", Modulation: " << static_cast<int>(Modulation_) << ", Interpolation: " << Interpolation_ << std::endl;
<< ", Modulation: " << static_cast<int>(Modulation_) << ", Interpolation: " << static_cast<int>(Interpolation_)
<< std::endl;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

View File

@ -45,8 +45,9 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/..:$ORIGIN/../lib")
endif()
endif()
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "-fPIE ${CMAKE_CXX_FLAGS}")
endif()
if(ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)

View File

@ -25,6 +25,7 @@
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#undef BitTest
namespace mindspore {
namespace dataset {
Status BaseRequest::Wait() {

View File

@ -28,6 +28,7 @@
#ifdef CACHE_LOCAL_CLIENT
#include "minddata/dataset/util/sig_handler.h"
#endif
#undef BitTest
namespace mindspore {
namespace dataset {

View File

@ -1020,8 +1020,12 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
throw(std::runtime_error(ex.what()));
} catch (...) {
ReleaseResource(phase);
#ifndef _MSC_VER
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName;
#else
MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: ";
#endif
}
return ret_value;
}

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include "plugin/device/cpu/hal/device/cpu_kernel_runtime.h"
#include <unistd.h>
#include <string>
#include <vector>
#include <memory>

View File

@ -34,10 +34,19 @@ using KernelRunFunc = RandomPoissonCpuKernelMod::KernelRunFunc;
}
static unsigned int s_seed = static_cast<unsigned int>(time(nullptr));
#ifndef _MSC_VER
EIGEN_DEVICE_FUNC uint64_t get_random_seed() {
auto rnd = rand_r(&s_seed);
return IntToSize(rnd);
}
#else
EIGEN_DEVICE_FUNC uint64_t get_random_seed() {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint64_t> distribution(0, std::numeric_limits<uint64_t>::max());
return distribution(gen);
}
#endif
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE uint64_t PCG_XSH_RS_state(uint64_t seed) {
seed = (seed == 0) ? get_random_seed() : seed;

View File

@ -61,7 +61,9 @@ class EltwiseCpuKernelFunc : public CpuKernelFunc {
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto iter = eltwise_func_map.find(kernel_name_);
if (iter == eltwise_func_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(eltwise_func_map)
MS_LOG(EXCEPTION) << "For 'EltWise Op', the kernel name must be in "
<< kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, TypeComputeFunc>>>(
eltwise_func_map)
<< ", but got " << kernel_name_;
}
std::vector<KernelAttr> support_list;
@ -165,8 +167,11 @@ bool EltWiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
}
auto iter = additional_kernel_attr_map_.find(kernel_name_);
if (iter == additional_kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(additional_kernel_attr_map_)
<< ", but got " << kernel_name_;
MS_LOG(ERROR)
<< "For 'EltWise Op', the kernel name must be in "
<< kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::EltwiseCpuFuncCreator>>>(
additional_kernel_attr_map_)
<< ", but got " << kernel_name_;
return false;
}
additional_func_ = iter->second[index].second();
@ -180,8 +185,9 @@ bool EltWiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
}
auto iter = mkl_kernel_attr_map_.find(kernel_name_);
if (iter == mkl_kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(mkl_kernel_attr_map_)
<< ", but got " << kernel_name_;
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in "
<< kernel::Map2Str<std::map, std::vector<KernelAttr>>(mkl_kernel_attr_map_) << ", but got "
<< kernel_name_;
return false;
}
}
@ -247,8 +253,9 @@ std::vector<KernelAttr> EltWiseCpuKernelMod::GetOpSupport() {
// only mkl_kernel_attr_map_ need to be checked since it contains all kind of ops
auto iter = mkl_kernel_attr_map_.find(kernel_name_);
if (iter == mkl_kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(mkl_kernel_attr_map_)
<< ", but got " << kernel_name_;
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in "
<< kernel::Map2Str<std::map, std::vector<KernelAttr>>(mkl_kernel_attr_map_) << ", but got "
<< kernel_name_;
return std::vector<KernelAttr>{};
}
std::vector<KernelAttr> support_list;

View File

@ -51,7 +51,7 @@ generate_simd_code(AVX 8 "\"avx\", \"avx2\"")
generate_simd_code(AVX512 16 \"avx512f\")
generate_simd_header_code()
if(ENABLE_CPU)
if(ENABLE_CPU AND NOT MSVC)
set(CMAKE_C_FLAGS "-Wno-attributes ${CMAKE_C_FLAGS}")
endif()

View File

@ -80,10 +80,10 @@ class BufferCPUSampleKernelMod : public DeprecatedNativeCpuKernelMod {
for (size_t i = 0; i < IntToSize(count_addr[0]); ++i) {
(void)indexes.emplace_back(i);
}
#if !defined(__APPLE__)
random_shuffle(indexes.begin(), indexes.end(), [&](int i) { return std::rand() % i; });
#else
#if defined(__APPLE__) || defined(_MSC_VER)
std::shuffle(indexes.begin(), indexes.end(), generator_);
#else
random_shuffle(indexes.begin(), indexes.end(), [&](int i) { return std::rand() % i; });
#endif
} else {
std::uniform_int_distribution<> distrib(0, count_addr[0] - 1); // random integers in a range [a,b]

View File

@ -35,6 +35,17 @@ target_link_libraries(mindspore_gpu PRIVATE mindspore::event mindspore::event_pt
if(ENABLE_GPU)
message("add gpu lib to mindspore_gpu")
if(WIN32)
target_link_libraries(mindspore_gpu PRIVATE cuda_ops
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/curand.lib
${CUDNN_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/cudart.lib
${CUDA_PATH}/lib/x64/cuda.lib
${CUDA_PATH}/lib/x64/cusolver.lib
${CUDA_PATH}/lib/x64/cufft.lib
${CUDA_PATH}/lib/x64/cusparse.lib)
else()
target_link_libraries(mindspore_gpu PRIVATE cuda_ops
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib64/libcurand.so
@ -44,6 +55,7 @@ if(ENABLE_GPU)
${CUDA_PATH}/lib64/libcusolver.so
${CUDA_PATH}/lib64/libcufft.so
${CUDA_PATH}/lib64/libcusparse.so)
endif()
endif()
if(ENABLE_DEBUGGER)

View File

@ -28,6 +28,16 @@ if(ENABLE_GPU)
endif()
cuda_add_library(cuda_ops SHARED ${CUDA_OPS_SRC_LIST} $<TARGET_OBJECTS:cuda_common_obj>)
message("add gpu lib to cuda_ops")
if(WIN32)
target_link_libraries(cuda_ops mindspore_core
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/curand.lib
${CUDNN_LIBRARY_PATH}
${CUDA_PATH}/lib/x64/cudart.lib
${CUDA_PATH}/lib/x64/cuda.lib
${CUDA_PATH}/lib/x64/cusolver.lib
${CUDA_PATH}/lib/x64/cufft.lib)
else()
target_link_libraries(cuda_ops mindspore_core
${CUBLAS_LIBRARY_PATH}
${CUDA_PATH}/lib64/libcurand.so
@ -36,4 +46,5 @@ if(ENABLE_GPU)
${CUDA_PATH}/lib64/stubs/libcuda.so
${CUDA_PATH}/lib64/libcusolver.so
${CUDA_PATH}/lib64/libcufft.so)
endif()
endif()

View File

@ -83,11 +83,11 @@ class RequestProcessResult {
operator bool() const = delete;
RequestProcessResult &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) {
RequestProcessResult &operator<(const LogStream &stream) noexcept {
msg_ = stream.stream()->str();
return *this;
}
RequestProcessResult &operator=(const std::string &message) noexcept __attribute__((visibility("default"))) {
RequestProcessResult &operator=(const std::string &message) noexcept {
msg_ = message;
return *this;
}

View File

@ -75,8 +75,9 @@ TypePtr ConcatInferType(const PrimitivePtr &primitive, const std::vector<Abstrac
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input must be a list or tuple of tensors. But got"
<< input_args[0]->ToString() << ".";
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the input must be a list or tuple of tensors. But got: " << input_args[0]->ToString()
<< ".";
}
auto elements = input_args[0]->isa<abstract::AbstractTuple>()
? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()

View File

@ -27,6 +27,7 @@
#include <thread>
#include <functional>
#include "utils/macros.h"
#include "utils/os.h"
#include "utils/overload.h"
#include "./securec.h"
#ifdef USE_GLOG

View File

@ -38,13 +38,13 @@
#endif
#ifdef _MSC_VER
#define NO_RETURN __declspec(noreturn)
#define NO_RETURN
#else
#define NO_RETURN __attribute__((noreturn))
#endif
#ifdef _MSC_VER
#define ALWAYS_INLINE __declspec(__forceinline)
#define ALWAYS_INLINE
#else
#define ALWAYS_INLINE __attribute__((__always_inline__))
#endif

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CORE_UTILS_SYSTEM_FILE_SYSTEM_H_
#define MINDSPORE_CORE_UTILS_SYSTEM_FILE_SYSTEM_H_
#include <sys/param.h>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
@ -28,6 +27,7 @@
#include <vector>
#include "utils/system/base.h"
#include "utils/log_adapter.h"
#include "utils/os.h"
#include "include/common/debug/common.h"
namespace mindspore {