Revert "[feat] [assistant] [I3T96T] add new Dataset operator CMUARCTICDataset"

This reverts commit b077aa1cab.

Revert "[feat] [assistant] [I3T96X] add new Dataset operator LibriSpeechDataset"

This reverts commit 4e6f7dc97d.

delete pass_registry_test.cc

Comment out the hiai_nlu_model_multi.pb related line
ms_yan 2021-08-22 22:16:37 +08:00
parent 580a97ba20
commit 36a8886ca2
1974 changed files with 24190 additions and 24985 deletions

1.txt (0 changes)
View File

View File

@@ -18,7 +18,7 @@
 SET BASE_PATH=%CD%
 SET BUILD_PATH=%BASE_PATH%/build
-SET threads=6
+SET threads=8
 SET ENABLE_GITEE=OFF
 set VERSION_MAJOR=''

View File

@@ -61,7 +61,7 @@ usage()
 echo " -l Compile with python dependency, default on"
 echo " -S Enable enable download cmake compile dependency from gitee , default off"
 echo " -k Enable make clean, clean up compilation generated cache "
-echo " -W Enable x86_64 SSE or AVX instruction set, use [sse|neon|avx|avx512|off], default off for lite and avx for CPU"
+echo " -W Enable SIMD instruction set, use [sse|neon|avx|avx512|off], default avx for cloud CPU backend"
 echo " -H Enable hidden"
 echo " -L Link and specify Tensor-RT library path, default disable Tensor-RT lib linking"
 echo " -y Compile the symbol table switch and save the symbol table to the directory output"

View File

@@ -1,44 +0,0 @@
-set(FFMPEG_FLAGS
-    --disable-programs
-    --disable-doc
-    --disable-debug
-    --disable-avdevice
-    --disable-postproc
-    --disable-avfilter
-    --disable-network
-    --disable-encoders
-    --disable-hwaccels
-    --disable-muxers
-    --disable-bsfs
-    --disable-protocols
-    --enable-protocol=file
-    --enable-protocol=pipe
-    --disable-indevs
-    --disable-outdevs
-    --disable-devices
-    --disable-filters
-    --disable-bzlib
-    --disable-iconv
-    --disable-libxcb
-    --disable-lzma
-    --disable-sdl2
-    --disable-xlib
-    --disable-zlib)
-
-set(REQ_URL "https://github.com/FFmpeg/FFmpeg/archive/n4.3.1.tar.gz")
-set(MD5 "426ca412ca61634a248c787e29507206")
-
-mindspore_add_pkg(ffmpeg
-    VER 4.3.1
-    LIBS avcodec avformat avutil swresample swscale
-    URL ${REQ_URL}
-    MD5 ${MD5}
-    CONFIGURE_COMMAND ./configure --disable-static --enable-shared --disable-x86asm ${FFMPEG_FLAGS}
-)
-
-include_directories(${ffmpeg_INC})
-add_library(mindspore::avcodec ALIAS ffmpeg::avcodec)
-add_library(mindspore::avformat ALIAS ffmpeg::avformat)
-add_library(mindspore::avutil ALIAS ffmpeg::avutil)
-add_library(mindspore::swresample ALIAS ffmpeg::swresample)
-add_library(mindspore::swscale ALIAS ffmpeg::swscale)

View File

@@ -1,13 +1,15 @@
-set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
-set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
-if(NOT ENABLE_GLIBCXX)
-    set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
-endif()
 if(BUILD_LITE)
+    set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
+    set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_C_FLAGS}")
+    set(glog_LDFLAGS "${SECURE_SHARED_LINKER_FLAGS}")
     set(glog_patch "")
     set(glog_lib glog)
 else()
+    set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
+    set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
+    if(NOT ENABLE_GLIBCXX)
+        set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
+    endif()
     set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001)
     set(glog_lib mindspore_glog)
 endif()

View File

@@ -9,7 +9,7 @@ endif()
 if(ENABLE_GITEE)
     set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
-    set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
+    set(MD5 "36ea0d9a709c6667b2798a62f6b197ae")
     set(INCLUDE "./include")
 else()
     set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
@@ -23,4 +23,4 @@ mindspore_add_pkg(nlohmann_json
     URL ${REQ_URL}
     MD5 ${MD5})
 include_directories(${nlohmann_json_INC})
 add_library(mindspore::json ALIAS nlohmann_json)

View File

@@ -198,12 +198,6 @@ if(NOT ENABLE_GE)
         set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
         if(ENABLE_D)
-            install(
-                TARGETS ms_profile
-                DESTINATION ${INSTALL_LIB_DIR}
-                COMPONENT mindspore
-            )
             install(
                 TARGETS hccl_plugin
                 DESTINATION ${INSTALL_LIB_DIR}

View File

@@ -330,8 +330,6 @@ elseif(WIN32)
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
-        install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
-                DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
@@ -462,8 +460,6 @@ else()
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
-        install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
-                DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
                 DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
         install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema

View File

@@ -23,12 +23,6 @@
 #include "include/api/data_type.h"
 #include "include/api/dual_abi_helper.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 namespace mindspore {
 class Model;
 class ModelImpl;

View File

@@ -22,12 +22,6 @@
 #include <memory>
 #include "include/api/callback/callback.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 namespace mindspore {
 
 class CkptSaver: public TrainCallBack {

View File

@@ -21,12 +21,6 @@
 #include <utility>
 #include "include/api/callback/callback.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 using GraphPoint = std::pair<int, float>;
 
 namespace mindspore {

View File

@@ -22,12 +22,6 @@
 #include <memory>
 #include "include/api/callback/callback.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 namespace mindspore {
 
 constexpr int DONT_UPDATE_LR = 0;

View File

@@ -22,12 +22,6 @@
 #include <memory>
 #include "include/api/callback/callback.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 namespace mindspore {
 
 class TimeMonitor: public TrainCallBack {

View File

@@ -24,12 +24,6 @@
 #include "include/api/callback/callback.h"
 #include "include/api/metrics/accuracy.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 using GraphPoint = std::pair<int, float>;
 
 namespace mindspore {

View File

@@ -23,12 +23,6 @@
 #include "include/api/data_type.h"
 #include "include/api/dual_abi_helper.h"
 
-#ifdef _WIN32
-#define MS_API __declspec(dllexport)
-#else
-#define MS_API __attribute__((visibility("default")))
-#endif
-
 namespace mindspore {
 
 class MixPrecisionCfg {

View File

@@ -105,14 +105,29 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
     return std::static_pointer_cast<T>(shared_from_this());
   }
 
+  /// \brief obtain provider's name
+  ///
+  /// \return provider's name.
   std::string GetProvider() const;
+  /// \brief set provider's name.
+  ///
+  /// \param[in] provider define the provider's name.
   void SetProvider(const std::string &provider);
+  /// \brief obtain provider's device type.
+  ///
+  /// \return provider's device type.
   std::string GetProviderDevice() const;
+  /// \brief set provider's device type.
+  ///
+  /// \param[in] device define the provider's device type. EG: CPU.
   void SetProviderDevice(const std::string &device);
+  /// \brief set memory allocator.
+  ///
+  /// \param[in] allocator define the memory allocator which can be defined by user.
   void SetAllocator(const std::shared_ptr<Allocator> &allocator);
+  /// \brief obtain memory allocator.
+  ///
+  /// \return memory allocator.
   std::shared_ptr<Allocator> GetAllocator() const;
 
  protected:

View File

@@ -24,9 +24,16 @@
 #include "include/api/context.h"
 
 namespace mindspore::kernel {
+/// \brief The Kernel class is used to define a MindSpore Kernel.
 class Kernel {
  public:
   Kernel() = default;
+  /// \brief Constructor.
+  ///
+  /// \param[in] inputs define the input tensors for kernel.
+  /// \param[in] outputs define the output tensors for kernel.
+  /// \param[in] primitive define the primitive of kernel generated by flatbuffers.
+  /// \param[in] ctx define the context for kernel.
   Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
          const schema::Primitive *primitive, const mindspore::Context *ctx)
       : context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
@@ -34,32 +41,65 @@ class Kernel {
       type_ = primitive->value_type();
     }
   }
+  /// \brief Destructor.
   virtual ~Kernel() = default;
+  /// \brief prepare for executing kernel.
+  ///
+  /// \return result code.
   virtual int Prepare() = 0;
+  /// \brief execute the kernel.
+  ///
+  /// \return result code.
   virtual int Execute() = 0;
+  /// \brief resize the kernel input shape, memory need to refresh.
+  ///
+  /// \return result code.
   virtual int ReSize() = 0;
+  /// \brief set kernel's input tensors.
+  ///
+  /// \param[in] in_tensors define the input tensors.
   virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }
+  /// \brief set kernel's input tensor.
+  ///
+  /// \param[in] in_tensor define the input tensor.
+  /// \param[in] index define the index of the input tensor.
   virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }
+  /// \brief set kernel's output tensors.
+  ///
+  /// \param[in] out_tensors define the output tensors.
   virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }
+  /// \brief set kernel's output tensor.
+  ///
+  /// \param[in] out_tensor define the output tensor.
+  /// \param[in] index define the index of the output tensor.
   virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }
+  /// \brief obtain kernel's input tensors.
+  ///
+  /// \return input tensors.
   virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
+  /// \brief obtain kernel's output tensors.
+  ///
+  /// \return output tensors.
   virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
+  /// \brief obtain kernel's name.
+  ///
+  /// \return kernel's name.
   std::string name() const { return this->name_; }
+  /// \brief set kernel's name.
+  ///
+  /// \param[in] name define the kernel's name.
   void set_name(const std::string &name) { this->name_ = name; }
+  /// \brief obtain kernel's context.
+  ///
+  /// \return kernel's context.
   const mindspore::Context *context() const { return this->context_; }
+  /// \brief obtain kernel's type.
+  ///
+  /// \return kernel's type.
   virtual schema::PrimitiveType type() const { return type_; }
+  /// \brief obtain the primitive of kernel generated by flatbuffers.
+  ///
+  /// \return the primitive of kernel generated by flatbuffers.
   const schema::Primitive *primitive() const { return this->primitive_; }
 
  protected:

View File

@@ -27,12 +27,16 @@
 #ifndef MS_API
 #ifdef _WIN32
+#ifdef _MSC_VER
 #ifdef BUILDING_DLL
 #define MS_API __declspec(dllexport)
 #else
 #define MS_API __declspec(dllimport)
 #endif
 #else
+#define MS_API __declspec(dllexport)
+#endif
+#else
 #define MS_API __attribute__((visibility("default")))
 #endif
 #endif

View File

@@ -148,7 +148,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
     Check argument integer.
 
     Example:
-    - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
+    - number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
     """
     rel_fn = Rel.get_fns(rel)
     prim_name = f'in `{prim_name}`' if prim_name else ''
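For reference, a minimal usage sketch of the validator whose docstring is corrected above. The call matches the signature shown in the hunk header; the import path is an assumption on my part and not confirmed by this diff:

# assumed import location for the validator and the Rel comparison enum
from mindspore._checkparam import Rel, check_number

# validate that a user-supplied argument is a non-negative int;
# raises an error mentioning the primitive name otherwise
batch_size = check_number(32, 0, Rel.GE, int, "batch_size", "ExampleOp")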

View File

@@ -18,7 +18,6 @@ from .addn import AddN
 from .assign_add import AssignAdd
 from .batchnorm import BatchNorm
 from .batchnorm_grad import BatchNormGrad
-from .bias_add import BiasAdd
 from .bias_add_grad import BiasAddGrad
 from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
 from .conv2d import Conv2D
@@ -26,7 +25,6 @@ from .complex import CAbs, CAdd, CDiv, CMul, CSub
 from .dropout_grad import DropoutGrad
 from .equal_count import EqualCount
 from .erfc import Erfc
-from .expand_dims import ExpandDims
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay
 from .fused_mul_add import FusedMulAdd
@@ -51,6 +49,7 @@ from .sigmoid import Sigmoid
 from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
 from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
 from .sigmoid_grad import SigmoidGrad
+from .slice import Slice
 from .softmax import Softmax
 from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
 from .softmax_grad_ext import SoftmaxGradExt

View File

@@ -80,6 +80,9 @@ class Expander:
 class ExpanderInfoValidator:
     """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
 
+    def __init__(self):
+        """Init"""
+
     @staticmethod
     def _add_check_function(kls, func):
         """
@@ -198,8 +201,8 @@ def to_frac_z_axis(ori_shape, ori_axis):
     return frac_z_axis
 
 
-def infer_shape_from_fractalNz(fractal):
-    "get original shape from fractalNz shape"
+def infer_shape_from_fractalnz(fractal):
+    "get original shape from fractalnz shape"
     shape = []
     dims = len(fractal)
     batch = dims - 4
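Only the first lines of the renamed helper are visible above; as orientation, a minimal sketch of the shape recovery it performs, assuming the usual FRAC_NZ packing where the trailing four dimensions are (n1, m1, m0, n0) so that m = m1*m0 and n = n1*n0 (an assumption, not stated in this diff):

def infer_shape_from_fractalnz_sketch(fractal):
    """Recover the original (..., m, n) shape from a FRAC_NZ shape (..., n1, m1, m0, n0)."""
    shape = list(fractal[:-4])        # leading batch dimensions are kept as-is
    n1, m1, m0, n0 = fractal[-4:]     # assumed ordering of the fractal dimensions
    shape.append(m1 * m0)             # original rows
    shape.append(n1 * n0)             # original columns
    return shape

# e.g. a (64, 128) matrix stored with 16x16 fractal blocks:
assert infer_shape_from_fractalnz_sketch([8, 4, 16, 16]) == [64, 128]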

View File

@@ -24,6 +24,7 @@ from .expand_dims import ExpandDims
 @VLD.check_attrs('is_training', 'momentum', 'epsilon')
 class BatchNorm(Expander):
     """BatchNorm expander"""
+
     def _expand(self, graph_builder):
         # get op info
         input_x = self.inputs[0]
@@ -42,81 +43,8 @@ class BatchNorm(Expander):
             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
 
         if self.attrs['is_training']:
-            reduce_axis = ()
-            shape_x = input_x.shape
-            if input_x.data_format == DF.NHWC:
-                reduce_axis = (0, 1, 2)
-                num = shape_x[0] * shape_x[1] * shape_x[2]
-            else:
-                reduce_axis = (0, 2, 3)
-                num = shape_x[0] * shape_x[2] * shape_x[3]
-            num_rec = 1.0 / num
-            num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
-            # compute mean value of input_x
-            mean_sum = graph_builder.emit(
-                'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-            mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
-            # compute variance of input_x
-            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-                mean_muls_expand = graph_builder.emit(
-                    'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
-            else:
-                mean_muls_expand = mean_muls
-            var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
-            var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
-            var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-            var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
-            # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
-            scalar_one = 1.0
-            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
-            y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
-            y_sqrt = graph_builder.emit('Sqrt', [y_add])
-            y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
-            # compute res_y
-            tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
-            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-                y_sqrt_rec_expand = graph_builder.emit(
-                    'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
-            else:
-                y_sqrt_rec_expand = y_sqrt_rec
-            y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
-            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-                input_scale_expand = graph_builder.emit(
-                    'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
-            else:
-                input_scale_expand = input_scale
-            res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
-            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-                input_offset_expand = graph_builder.emit(
-                    'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
-            else:
-                input_offset_expand = input_offset
-            res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
-            # compute mean_res
-            momentum_sub = scalar_one - self.attrs['momentum']
-            momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
-            new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
-            momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
-            current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
-            updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
-            mean_res = graph_builder.emit(
-                'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
-            # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
-            var_num = float(num) / (num - 1)
-            var_num_v = graph_builder.value(input_scale.dtype, var_num)
-            var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
-            new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
-            current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
-            updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
-            variance_res = graph_builder.emit(
-                'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
-                attrs={'fake_output': True})
+            self.inputs[0] = input_x
+            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
 
             if input_x_new_type != input_x_ori_type:
                 res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
             return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
@@ -140,3 +68,88 @@ class BatchNorm(Expander):
         if input_x_new_type != input_x_ori_type:
             res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
         return res_y, var_add, var_add, var_add, var_add
+
+    def _bn_train(self, graph_builder):
+        """expand BatchNorm for training mode"""
+        input_x = self.inputs[0]
+        input_scale = self.inputs[1]
+        input_offset = self.inputs[2]
+        input_mean = self.inputs[3]
+        input_variance = self.inputs[4]
+        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
+        reduce_axis = ()
+        shape_x = input_x.shape
+        if input_x.data_format == DF.NHWC:
+            reduce_axis = (0, 1, 2)
+            num = shape_x[0] * shape_x[1] * shape_x[2]
+        else:
+            reduce_axis = (0, 2, 3)
+            num = shape_x[0] * shape_x[2] * shape_x[3]
+        num_rec = 1.0 / num
+        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
+        # compute mean value of input_x
+        mean_sum = graph_builder.emit(
+            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
+        # compute variance of input_x
+        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
+            mean_muls_expand = graph_builder.emit(
+                'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
+        else:
+            mean_muls_expand = mean_muls
+        var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
+        var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
+        var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
+        # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
+        scalar_one = 1.0
+        scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
+        y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
+        y_sqrt = graph_builder.emit('Sqrt', [y_add])
+        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
+        # compute res_y
+        tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
+        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
+            y_sqrt_rec_expand = graph_builder.emit(
+                'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
+        else:
+            y_sqrt_rec_expand = y_sqrt_rec
+        y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
+        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
+            input_scale_expand = graph_builder.emit(
+                'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
+        else:
+            input_scale_expand = input_scale
+        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
+        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
+            input_offset_expand = graph_builder.emit(
+                'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
+        else:
+            input_offset_expand = input_offset
+        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
+        # compute mean_res
+        momentum_sub = scalar_one - self.attrs['momentum']
+        momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
+        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
+        momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
+        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
+        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
+        mean_res = graph_builder.emit(
+            'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
+        # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
+        var_num = float(num) / (num - 1)
+        var_num_v = graph_builder.value(input_scale.dtype, var_num)
+        var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
+        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
+        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
+        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
+        variance_res = graph_builder.emit(
+            'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
+            attrs={'fake_output': True})
+        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
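For orientation, a small NumPy sketch (mine, not part of the commit) of the arithmetic the relocated _bn_train emits, assuming NCHW input: per-channel mean/variance, normalization with scale and offset, and the moving-statistics update with unbiased variance.

import numpy as np

def bn_train_reference(x, scale, offset, moving_mean, moving_var, momentum=0.9, eps=1e-5):
    """NCHW batch-norm forward in training mode, mirroring the expanded graph above."""
    axis = (0, 2, 3)                                  # reduce over N, H, W
    num = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.mean(axis=axis)
    var = ((x - mean.reshape(1, -1, 1, 1)) ** 2).mean(axis=axis)
    inv_std = 1.0 / np.sqrt(var + eps)                # y_sqrt_rec in the expander
    y = scale.reshape(1, -1, 1, 1) * (x - mean.reshape(1, -1, 1, 1)) * inv_std.reshape(1, -1, 1, 1) \
        + offset.reshape(1, -1, 1, 1)
    sample_var = var * num / (num - 1)                # unbiased variance for the running stats
    new_mean = (1 - momentum) * moving_mean + momentum * mean
    new_var = (1 - momentum) * moving_var + momentum * sample_var
    return y, new_mean, new_var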

View File

@@ -17,12 +17,14 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
 from .expand_dims import ExpandDims
 
+
 @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.check_attrs('is_training', 'epsilon')
 class BatchNormGrad(Expander):
     """BatchNormGrad expander"""
+
     def _expand(self, graph_builder):
         # get op info
         input_dy = self.inputs[0]

View File

@@ -1,48 +0,0 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ===========================================================================
-"""generate json desc for bias_add"""
-from mindspore._extends.graph_kernel.model.model import DataFormat as DF
-from ._utils import Expander, ExpanderInfoValidator as VLD
-from .expand_dims import ExpandDims
-
-
-@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
-@VLD.add_format(DF.NCHW, DF.DEFAULT)
-@VLD.add_format(DF.NHWC, DF.DEFAULT)
-class BiasAdd(Expander):
-    """BiasAdd expander"""
-
-    def _expand(self, graph_builder):
-        input_x, input_y = self.inputs
-
-        if input_x.data_format == DF.NCHW:
-            input_y_expand = graph_builder.emit(
-                'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
-            result = graph_builder.emit('Add', [input_x, input_y_expand])
-        elif input_x.data_format == DF.DEFAULT:
-            if len(input_x.shape) == 2:
-                result = graph_builder.emit('Add', [input_x, input_y])
-            elif len(input_x.shape) == 3:
-                input_y_expand = graph_builder.emit(
-                    'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, 1)})
-                result = graph_builder.emit('Add', [input_x, input_y_expand])
-            else:  # len == 4
-                input_y_expand = graph_builder.emit(
-                    'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
-                result = graph_builder.emit('Add', [input_x, input_y_expand])
-        else:  # NHWC
-            result = graph_builder.emit('Add', [input_x, input_y])
-
-        return result

View File

@@ -15,6 +15,7 @@
 """generate json desc for FusedMulAdd"""
 from ._utils import Expander
 
+
 class FusedMulAdd(Expander):
     """FusedMulAdd expander"""

View File

@@ -15,13 +15,15 @@
 """generate json desc for LayerNorm"""
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
-from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
+from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
+
 
 @VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
 @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
 class LayerNorm(Expander):
     """LayerNorm expander"""
+
     def _expand(self, graph_builder):
         input_x, input_gamma, input_beta = self.inputs
         processor = self.processor
@@ -36,7 +38,7 @@ class LayerNorm(Expander):
         ori_shape_x = input_x.shape
         if input_x.data_format == DF.FRAC_NZ:
-            ori_shape_x = infer_shape_from_fractalNz(ori_shape_x)
+            ori_shape_x = infer_shape_from_fractalnz(ori_shape_x)
 
         # Calculate the scaling ratio of the average
         if begin_norm_axis < 0:

View File

@@ -17,6 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
+
 @VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
 class MatMul(Expander):
     """
@@ -24,7 +25,7 @@ class MatMul(Expander):
     """
 
     def __init__(self, expand_info):
-        super().__init__(expand_info)
+        super(MatMul, self).__init__(expand_info)
         self.transpose_a = self.attrs['transpose_a']
         self.transpose_b = self.attrs['transpose_b']
         self.left_format = self.attrs['left_format']
@@ -47,28 +48,28 @@ class MatMul(Expander):
         if input_num < 2:
             raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
 
-    def _trans_shape(self, shape):
-        trans_shape = list(shape)
-        trans_shape[-2] = shape[-1]
-        trans_shape[-1] = shape[-2]
-        return trans_shape
-
     def _expand(self, graph_builder):
+        def transpose(shape):
+            trans_shape = list(shape)
+            trans_shape[-2] = shape[-1]
+            trans_shape[-1] = shape[-2]
+            return trans_shape
         if not self._optimize_to_mul():
             raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
-        #Matmul is replaced by Mul([b m k], [b k n]) when k==1
+        # Matmul is replaced by Mul([b m k], [b k n]) when k==1
         input_a = self.inputs[0]
         input_b = self.inputs[1]
         if self.transpose_a:
-            shape_a_trans = self._trans_shape(self.shape_a)
+            shape_a_trans = transpose(self.shape_a)
             input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
         if self.transpose_b:
-            shape_b_trans = self._trans_shape(self.shape_b)
+            shape_b_trans = transpose(self.shape_b)
            input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
         result = graph_builder.emit('Mul', [input_a, input_b])
         if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
             result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
         return result
 
 
 class BatchMatMul(MatMul):
     """BatchMatMul expander"""

View File

@@ -24,7 +24,7 @@ class MinimumGrad(Expander):
     def _check(self):
         if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
             raise GKException("both grad_x and grad_y are False.")
-        return super()._check()
+        return super(MinimumGrad, self)._check()
 
     def _expand(self, graph_builder):
         input_x, input_y, input_dout = self.inputs
@@ -34,7 +34,8 @@ class MinimumGrad(Expander):
         dx = graph_builder.emit('Mul', [le_result, input_dout])
         dy = graph_builder.emit('Sub', [input_dout, dx])
 
-        # for minimumgrad op, output_shape should be equal to input_shape, but some elementwise operating may broadcast input_shape
+        # for minimumgrad op, output_shape should be equal to input_shape,
+        # but some elementwise operating may broadcast input_shape
         # then output_shape not equal to original input_shape, so need to reduce output to let them equal
         reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
         reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
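For intuition, a small NumPy sketch (mine, not the get_reduce_axis implementation) of the reduce-to-input-shape step the reworded comment describes: gradients computed on a broadcast shape are summed back over the broadcast axes so dx/dy match the original inputs.

import numpy as np

def reduce_to_shape(grad, target_shape):
    """Sum a broadcasted gradient back to the original input shape."""
    extra = grad.ndim - len(target_shape)              # leading axes added by broadcasting
    axes = tuple(range(extra)) + tuple(
        i + extra for i, dim in enumerate(target_shape) if dim == 1 and grad.shape[i + extra] != 1)
    reduced = grad.sum(axis=axes) if axes else grad
    return reduced.reshape(target_shape)

dout = np.ones((4, 3))            # gradient in the broadcast shape
x_shape = (1, 3)                  # original input was broadcast along axis 0
assert reduce_to_shape(dout, x_shape).shape == x_shape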

View File

@@ -15,7 +15,8 @@
 """generate json desc for softmax"""
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
-from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
+from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
+
 
 @VLD.add_format(DF.FRAC_NZ)
 @VLD.add_format(DF.DEFAULT)
@@ -30,7 +31,7 @@ class Softmax(Expander):
         ori_shape = input_x.shape
         if input_x.data_format == DF.FRAC_NZ:
-            ori_shape = infer_shape_from_fractalNz(input_x.shape)
+            ori_shape = infer_shape_from_fractalnz(input_x.shape)
 
         for i, _ in enumerate(list(axis)):
             if axis[i] < 0:

View File

@@ -15,7 +15,8 @@
 """generate json desc for SoftmaxGradExt"""
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
-from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
+from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
+
 
 @VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@@ -29,7 +30,7 @@ class SoftmaxGradExt(Expander):
         ori_shape = x.shape
         if x.data_format == DF.FRAC_NZ:
-            ori_shape = infer_shape_from_fractalNz(ori_shape)
+            ori_shape = infer_shape_from_fractalnz(ori_shape)
         if not axis:
             axis = []
             for i, _ in enumerate(ori_shape):

View File

@@ -15,7 +15,7 @@
 """generate json desc for SquareSumV1"""
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
-from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
+from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
 
 
 @VLD.add_format(DF.FRAC_NZ)
@@ -30,7 +30,7 @@ class SquareSumV1(Expander):
         ori_shape = x.shape
         if x.data_format == DF.FRAC_NZ:
-            ori_shape = infer_shape_from_fractalNz(ori_shape)
+            ori_shape = infer_shape_from_fractalnz(ori_shape)
         if not axis:
             axis = []
             for i, _ in enumerate(ori_shape):

View File

@@ -17,6 +17,8 @@ from .model import PrimLib
 
 class ParalGain:
+    """Paral Gain"""
+
     def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
         self.fusion_type = fusion_type
         self.bottleneck = bottleneck
@@ -41,7 +43,9 @@ class ScheduleAnalyzer:
         self.ops = graph.ops
         self.dom_op = [out.op for out in outputs]
 
-    def prod(self, shape):
+    @staticmethod
+    def prod(shape):
+        """Compute shape product"""
         res = shape[0]
         for i in range(1, len(shape)):
             res = res * shape[i]
@@ -254,7 +258,7 @@ class ScheduleAnalyzer:
         fusion_type = "block_fusion"
         type_info = None
 
-        activate_pipeline_optimization = False # Disable pipeline optimization for now.
+        activate_pipeline_optimization = False  # Disable pipeline optimization for now.
         if activate_pipeline_optimization:
             pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
                 blocks, op_sizes, exclude_gid)
@@ -287,4 +291,5 @@ def block_parallel_estimate(graphs):
 
 def parallel_estimate(graphs):
+    """Estimate parallel gain"""
     return block_parallel_estimate(graphs)

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ===========================================================================
 """Cost model splitter"""
-import os
 from functools import reduce as prod_reduce
 from mindspore import log as logger
 from .model import PrimLib, Graph, Tensor, Operator
@@ -39,20 +38,24 @@ class GraphSplitByPattern:
         def sync(self, x, y):
             """sync from y to x"""
             for i in self.alive:
-                if self.map[y][i] and not self.map[x][i]:
-                    self.map[x][i] = True
+                self._link(self.map[y][i], x, i)
+
+        def _link(self, cond, f, t):
+            """link from `f` to `t`"""
+            if cond:
+                self.map[f][t] = True
 
         def fuse(self, x, y):
             """fuse y to x"""
             for i in self.alive:
+                # i is the succeeding node of y, links the x's previous nodes to i
                 if self.map[y][i] and not self.map[x][i]:
                     for pre in self.alive:
-                        if self.map[pre][x] and not self.map[pre][i]:
-                            self.map[pre][i] = True
+                        self._link(self.map[pre][x], pre, i)
+                # i is the previous node of y, link i to x's succeeding nodes
                 if self.map[i][y] and not self.map[i][x]:
                     for suc in self.alive:
-                        if self.map[x][suc] and not self.map[i][suc]:
-                            self.map[i][suc] = True
+                        self._link(self.map[x][suc], i, suc)
             self.alive.remove(y)
 
     class Area:
@@ -67,6 +70,10 @@ class GraphSplitByPattern:
                 self.stitch_ops = set()
                 self.stitch_atomic_ops = set()
 
+            def has_stitch_op(self):
+                """check stitch_op exists"""
+                return self.stitch_ops or self.stitch_atomic_ops
+
         def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
             self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
             self.ops = [] if init_op is None else [init_op]
@@ -286,31 +293,35 @@ class GraphSplitByPattern:
     def fuse(self, selector):
         """Fuse areas"""
-        changed = False
-        while True:
+        def _fuse_area():
             for dominant in self.areas:
                 result = selector(dominant)
-                if result is not None and result[0]:
-                    fuse_areas, is_forward = result
-                    fuse_areas = self.limit_area_size(dominant, fuse_areas)
-                    if not fuse_areas:
-                        continue
-                    if is_forward:
-                        for area in fuse_areas:
-                            dominant.fuse(area)
-                            self.set_area_map(area.ops, dominant)
-                            self.areas.remove(area)
-                    else:
-                        forward_area = dominant
-                        for area in fuse_areas:
-                            area.fuse(forward_area)
-                            self.set_area_map(forward_area.ops, area)
-                            self.areas.remove(forward_area)
-                            forward_area = area
-                    changed = True
-                    break
-            else:
-                return changed
+                if result is None or not result[0]:
+                    continue
+                fuse_areas, is_forward = result
+                fuse_areas = self.limit_area_size(dominant, fuse_areas)
+                if not fuse_areas:
+                    continue
+                if is_forward:
+                    for area in fuse_areas:
+                        dominant.fuse(area)
+                        self.set_area_map(area.ops, dominant)
+                        self.areas.remove(area)
+                else:
+                    forward_area = dominant
+                    for area in fuse_areas:
+                        area.fuse(forward_area)
+                        self.set_area_map(forward_area.ops, area)
+                        self.areas.remove(forward_area)
+                        forward_area = area
+                return True
+            return False
+
+        changed, do_again = False, True
+        while do_again:
+            do_again = _fuse_area()
+            changed = changed or do_again
+        return changed
 
     def fuse_recom(self, selector):
         """Fuse recompute area to its user"""
@@ -348,21 +359,6 @@ class GraphSplitByPattern:
             graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
         return subgraphs, graphmodes
 
-    def dump_subgraphs(self, subgraphs):
-        """Dump subgraphs"""
-        if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
-            subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
-            for i, sub in enumerate(subgraphs):
-                subgraphs_str += str("============") + str(i) + "\n"
-                subgraphs_str += str(sub)
-            dirname = 'subgraphs'
-            if not os.path.exists(dirname):
-                os.makedirs(dirname)
-            graphname = self.graph.name
-            filename = dirname + '/' + graphname + '.log'
-            with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
-                f.write(subgraphs_str)
-
     def pattern_fuse(self, fuse_func=None):
         """fuse Areas by pattern repeatedly"""
         del fuse_func
@@ -376,34 +372,38 @@ class GraphSplitByPattern:
         # Note: after this function, the input output relation is not maintained.
         self.split_output_reshapes()
         subgraphs, graphmodes = self.to_subgraphs()
-        self.dump_subgraphs(subgraphs)
         return subgraphs, graphmodes
 
     def split_output_reshapes(self):
-        """Force split the output reshapes into other new """
+        """Force split the output Reshapes into other new area"""
+        def _remove_output_reshape(reshape_ops, other_ops):
+            def _run():
+                for op in reshape_ops:
+                    if any([to_op in other_ops for to_op in op.output.to_ops]):
+                        reshape_ops.remove(op)
+                        other_ops.append(op)
+                        return True
+                return False
+            while _run():
+                pass
+
         new_areas = []
         for area in self.areas:
-            out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
-            remain_ops = [op for op in area.ops if op not in out_reshape_ops]
-            if not remain_ops or not out_reshape_ops:
+            reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
+            other_ops = [op for op in area.ops if op not in reshape_ops]
+            if not other_ops or not reshape_ops:
                 continue
-            changed = True
-            while changed:
-                changed = False
-                for op in out_reshape_ops:
-                    if any([to_op in remain_ops for to_op in op.output.to_ops]):
-                        out_reshape_ops.remove(op)
-                        remain_ops.append(op)
-                        changed = True
-                        break
-            if out_reshape_ops:
-                for op in out_reshape_ops:
-                    a = self.Area(op, False, 0, self.reach_tab)
-                    self.set_default_mode(a)
-                    new_areas.append(a)
-                area.ops = remain_ops
-                if len(remain_ops) == 1:
-                    self.set_default_mode(area)
+            # remove the output reshape from "reshape_ops" and add it into "other_ops"
+            _remove_output_reshape(reshape_ops, other_ops)
+            if not reshape_ops:
+                continue
+            for op in reshape_ops:
+                a = self.Area(op, False, 0, self.reach_tab)
+                self.set_default_mode(a)
+                new_areas.append(a)
+            area.ops = other_ops
+            if len(other_ops) == 1:
+                self.set_default_mode(area)
         if new_areas:
             self.areas += new_areas
@@ -472,8 +472,8 @@ class GraphSplitByPattern:
                 region_ops.append(op)
                 return False, None, weight, True
             # region fails to grow
-            MAX_WEIGHT = 20
-            if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
+            max_weight = 20
+            if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
                 return False, None, weight, False
             # region grows successfully
             weight = weight + 1
@@ -486,7 +486,7 @@ class GraphSplitByPattern:
         cheap_regions = []
         for output in outputs:
             # tensor should have user other than user_area to be fused
-            if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2:
+            if len(output.to_ops) < 2:
                 continue
             region_ops = []
             grow = True
@@ -533,14 +533,7 @@ class GraphSplitByPattern:
         """find recompute regions and copy them out to new Areas"""
         def do_recompute_fuse():
             """split the unfusing pattern by add recompute area"""
-            recompute_suc = False
-            orig_areas = []
-            orig_areas.extend(self.areas)
-            for dom in orig_areas:
-                if dom not in self.areas or not dom.out_relations:
-                    continue
-                cheap_regions = self.find_cheap_regions(dom)
-                dom_changed = False
+            def recompute_cheap_region(dom):
                 for cheap_region in cheap_regions:
                     user_areas = self.select_user_area(cheap_region[-1].output)
                     if not user_areas:
@@ -550,12 +543,17 @@ class GraphSplitByPattern:
                     self.pattern_fuse(self.fuse_recom)
                     self.clear_recompute()
                     if self.recom_res:
-                        recompute_suc = True
-                        # Copy region at most once for this dom
-                        dom_changed = True
-                        break
-                if dom_changed:
-                    break
+                        return True
+                return False
+
+            recompute_suc = False
+            orig_areas = []
+            orig_areas.extend(self.areas)
+            for dom in orig_areas:
+                if dom not in self.areas or not dom.out_relations:
+                    continue
+                cheap_regions = self.find_cheap_regions(dom)
+                if recompute_cheap_region(dom):
+                    recompute_suc = True
             return recompute_suc
 
         if self.enable_recompute:
@@ -563,9 +561,6 @@ class GraphSplitByPattern:
             self.pattern_fuse()
 
-use_poly_reduce = True
-
-
 class GraphSplitGpu(GraphSplitByPattern):
     """Graph splitter"""
     BORADCAST_FUSE_DEPTH = 20
@@ -616,7 +611,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             return fused, True
 
         def _broadcast_pat_exclude(dom, a, r):
-            if use_poly_reduce and a.pattern == PrimLib.REDUCE:
+            if a.pattern == PrimLib.REDUCE:
                 return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
             return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
@@ -641,34 +636,14 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fused.append(a)
             return fused, False
 
-        def _check_reduce_exclude(dom):
-            if use_poly_reduce:
-                return False
-            # exclude large all-reduce
-            if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
-                    dom.ops[0].inputs[0].get_size() > 10000:
-                return True
-
-            # exclude multi output
-            for a in dom.in_relations.keys():
-                if len(a.out_relations) > 1:
-                    return True
-                if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
-                    return True
-            return False
-
         def _reduce_pat_exclude(_, a, r):
             if len(a.ops) > self.REDUCE_FUSE_DEPTH:
                 return True
-            if use_poly_reduce:
-                return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
-            return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
+            return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
 
         def _reduce_depth(dom):
             if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
                 return None
-            if _check_reduce_exclude(dom):
-                return None
             a, r = list(dom.in_relations.items())[0]
             if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
                     _is_atomic_add_available(dom):
@@ -681,8 +656,6 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _reduce_width(dom):
             if dom.pattern != PrimLib.REDUCE:
                 return None
-            if _check_reduce_exclude(dom):
-                return None
             fused = []
             for a, r in dom.in_relations.items():
                 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
@@ -763,16 +736,16 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _may_stitch(dom, a, r):
             if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
-                if _reduce_nums(a.ops) < 2:
-                    dom_outs = [op.output for op in dom.ops]
-                    a_ins = [op_input for op in a.ops for op_input in op.inputs]
-                    a_outs = [op.output for op in a.ops]
-                    a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
-                    stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
-                    if _same_stitch_axis(stitch_tensors, a_final_outs):
-                        for tensor in stitch_tensors:
-                            if _tensor_size(tensor) >= 1024 * 1024:
-                                return True
+                if _reduce_nums(a.ops) >= 2:
+                    return False
+                dom_outs = [op.output for op in dom.ops]
+                a_ins = [op_input for op in a.ops for op_input in op.inputs]
+                a_outs = [op.output for op in a.ops]
+                a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
+                stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
+                if not _same_stitch_axis(stitch_tensors, a_final_outs):
+                    return False
+                return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
             return False
 
         def _reduce_stitch(dom):
@ -785,14 +758,15 @@ class GraphSplitGpu(GraphSplitByPattern):
fused = [] fused = []
for a, r in dom.out_relations.items(): for a, r in dom.out_relations.items():
if _may_stitch(dom, a, r): if not _may_stitch(dom, a, r):
if a.pattern == PrimLib.REDUCE: continue
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: if a.pattern == PrimLib.REDUCE:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a) fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
return fused, False return fused, False
def _transpose(dom): def _transpose(dom):
@ -804,6 +778,16 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a) fused.append(a)
return fused, True return fused, True
def _strided_slice(dom):
if dom.dom_op().prim != "StridedSlice":
return None
fused = []
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
len(a.out_relations) == 1 and not a.is_output:
fused.append(a)
return fused, True
def _fuse_loop(): def _fuse_loop():
changed = True changed = True
while changed: while changed:
@ -814,10 +798,10 @@ class GraphSplitGpu(GraphSplitByPattern):
changed = self.fuse(_reduce_width) or changed changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce: changed = self.fuse(_strided_slice) or changed
changed = self.fuse(_reduce_output) or changed changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion: if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose) self.fuse(_transpose)
def _fuse_once(fuse_func): def _fuse_once(fuse_func):
@ -825,9 +809,8 @@ class GraphSplitGpu(GraphSplitByPattern):
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width): fuse_func(_broadcast_width):
return return
if use_poly_reduce: if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)): return
return
fuse_func(_transpose) fuse_func(_transpose)
return return
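Note: the reworked _fuse_loop above is a fixed-point driver — it keeps re-running the pattern passes (including the new _strided_slice pass) until a full sweep produces no change. A minimal standalone sketch of that driver shape, with illustrative names rather than the real splitter API:

def run_fuse_loop(fuse, passes):
    # Apply every fusion pass repeatedly until a full sweep changes nothing.
    changed = True
    while changed:
        changed = False
        for fuse_pass in passes:
            changed = fuse(fuse_pass) or changed

# Toy usage: each "fusion" pops one item from a work list until it is empty.
work = [3, 2, 1]
run_fuse_loop(lambda p: bool(work) and work.pop() is not None, ["dummy_pass"])
print(work)  # []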

View File

@ -216,6 +216,7 @@ class PrimLib:
'Transpose': Prim(OPAQUE), 'Transpose': Prim(OPAQUE),
'Tile': Prim(BROADCAST), 'Tile': Prim(BROADCAST),
'BroadcastTo': Prim(BROADCAST), 'BroadcastTo': Prim(BROADCAST),
'StridedSlice': Prim(OPAQUE),
'MatMul': Prim(OPAQUE), 'MatMul': Prim(OPAQUE),
'TransData': Prim(OPAQUE), 'TransData': Prim(OPAQUE),
'BatchMatMul': Prim(OPAQUE), 'BatchMatMul': Prim(OPAQUE),
@ -421,14 +422,13 @@ class Graph:
for t in op.inputs: for t in op.inputs:
if t not in inputs and t.op not in self.ops: if t not in inputs and t.op not in self.ops:
inputs.append(t) inputs.append(t)
if op.output not in outputs: if op.output in outputs:
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops: continue
outputs.append(op.output) if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
else: outputs.append(op.output)
for d in op.output.to_ops: continue
if d not in self.ops: if any([succ not in self.ops for succ in op.output.to_ops]):
outputs.append(op.output) outputs.append(op.output)
break
if self.inputs: if self.inputs:
inputs = self.inputs inputs = self.inputs
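Note: the rewritten output deduction above replaces the nested branches with early continues; the rule it encodes can be read in isolation as a small predicate. A sketch with hypothetical Tensor/op stand-ins:

from collections import namedtuple

# Hypothetical stand-ins for the model's Tensor type and op set.
Tensor = namedtuple("Tensor", "name para_type to_ops")
PARA_OUTPUT = "output"

def is_graph_output(tensor, graph_ops):
    # A tensor is a graph output if it is tagged as an output parameter,
    # has no consumers, or is consumed by an op outside the subgraph.
    if tensor.para_type == PARA_OUTPUT or not tensor.to_ops:
        return True
    return any(succ not in graph_ops for succ in tensor.to_ops)

t = Tensor("t0", None, to_ops=["op_outside"])
print(is_graph_output(t, graph_ops={"op_inside"}))  # True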

View File

@ -28,11 +28,13 @@ class GraphBuilder:
self.graph = Graph(name, []) self.graph = Graph(name, [])
def set_input(self, *para): def set_input(self, *para):
"""set input to graph inputs"""
for t in para: for t in para:
t.para_type = Tensor.PARA_INPUT t.para_type = Tensor.PARA_INPUT
self.graph.inputs.append(t) self.graph.inputs.append(t)
def set_output(self, *para): def set_output(self, *para):
"""set output to graph inputs"""
for t in para: for t in para:
t.para_type = Tensor.PARA_OUTPUT t.para_type = Tensor.PARA_OUTPUT
self.graph.outputs.append(t) self.graph.outputs.append(t)
@ -50,6 +52,8 @@ class GraphBuilder:
def graph_scope(self, name): def graph_scope(self, name):
"""The graph scope to be processed""" """The graph scope to be processed"""
class GraphScope: class GraphScope:
"""Graph Scope"""
def __init__(self, gb): def __init__(self, gb):
self.gb = gb self.gb = gb
@ -77,7 +81,6 @@ class GraphBuilder:
"""Create a new Value""" """Create a new Value"""
if name in (None, ''): if name in (None, ''):
name = self._alloc_tensor_name() name = self._alloc_tensor_name()
v = Value(name, dtype, value) v = Value(name, dtype, value)
return v return v
@ -105,6 +108,7 @@ class GraphBuilder:
return output return output
def get(self): def get(self):
"""Get graphs"""
return self.graphs return self.graphs
@ -123,34 +127,14 @@ class CompositeGraph:
def load(self, desc): def load(self, desc):
"""Load Graph from json""" """Load Graph from json"""
def _attr_of(op, inputs, output): def _attr_of(op):
def _get_axis_while_none(input_shape, output_shape): if not op['attr']:
red_axis = [] return dict()
if len(output_shape) == len(input_shape):
for i, s in enumerate(output_shape):
if s == 1 and input_shape[i] > 1:
red_axis.append(i)
else:
red_axis = list(range(len(output_shape)))
return red_axis
attr = {} attr = {}
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'): for a in op['attr']:
for a in op['attr']: if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
if a['name'] == 'axis': attr['reduce_axis'] = a['value']
red_axis, dim_size = [], len(inputs[0].shape) else:
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
elif op['attr']:
for a in op['attr']:
attr[a['name']] = a['value'] attr[a['name']] = a['value']
return attr return attr
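Note: the new _attr_of no longer recomputes reduce axes from input/output shapes; it only renames an already-normalized 'axis' attribute to 'reduce_axis' and copies everything else through. A self-contained sketch of that mapping (the sample op dict is illustrative):

def attr_of(op):
    if not op['attr']:
        return dict()
    attr = {}
    for a in op['attr']:
        if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
            attr['reduce_axis'] = a['value']
        else:
            attr[a['name']] = a['value']
    return attr

sample = {'name': 'ReduceSum', 'attr': [{'name': 'axis', 'value': [0, 1]}]}
print(attr_of(sample))  # {'reduce_axis': [0, 1]}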
@ -166,7 +150,6 @@ class CompositeGraph:
'shape'], out_desc['data_type'], out_desc['format'] 'shape'], out_desc['data_type'], out_desc['format']
self.tensors[name] = builder.tensor( self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
cur_fusion = None
for op in desc['op_desc']: for op in desc['op_desc']:
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc'] out_desc = op['output_desc']
@ -177,25 +160,17 @@ class CompositeGraph:
inputs[1].para_type = Tensor.PARA_OUTPUT inputs[1].para_type = Tensor.PARA_OUTPUT
output = inputs[2] output = inputs[2]
self.tensors[name] = output self.tensors[name] = output
else: continue
output = self.tensors.get(name, None) output = self.tensors.get(name, None)
if not output: if not output:
output = builder.tensor( output = builder.tensor(shape, dtype, data_format, name=name)
shape, dtype, data_format, name=name) self.tensors[name] = output
self.tensors[name] = output builder.op(op['name'], output, inputs, attrs=_attr_of(op))
builder.op(op['name'], output, inputs,
attrs=_attr_of(op, inputs, output))
if 'fusion' in op:
if cur_fusion is None:
cur_fusion = output
else:
cur_fusion.add_buddy(output)
if op['fusion'].endswith('_end'):
cur_fusion = None
self.graph = builder.get()[0] self.graph = builder.get()[0]
self.desc = desc self.desc = desc
def add_stitch_info(self, subgraph, desc): def add_stitch_info(self, subgraph, desc):
"""add stitch info to desc"""
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops: if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)} buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
if subgraph.stitch_info.stitch_atomic_ops: if subgraph.stitch_info.stitch_atomic_ops:
@ -204,6 +179,7 @@ class CompositeGraph:
return desc return desc
def add_recompute_ops(self, subgraph, desc): def add_recompute_ops(self, subgraph, desc):
"""add recompute ops to desc"""
if subgraph.recompute_ops: if subgraph.recompute_ops:
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops] desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
return desc return desc
@ -227,43 +203,40 @@ class CompositeGraph:
inputs, outputs = subgraph.deduce_parameters() inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops) graph_ops = set(subgraph.ops)
inplace_assign, inplace_assign_z = self._pre_dump(outputs) inplace_assign, inplace_assign_z = self._pre_dump(outputs)
for key in self.desc:
def dump_output(t):
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
def dump_op_desc(d):
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
return inplace_desc
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops or op in subgraph.recompute_ops:
return d
return None
for key in self.desc.keys():
if key == 'input_desc': if key == 'input_desc':
desc[key] = [ desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc': elif key == 'output_desc':
out_desc = [] desc[key] = list(map(dump_output, outputs))
for t in outputs:
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
out_desc.append(
{'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
else:
out_desc.append(
{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
desc[key] = out_desc
elif key == 'op_desc': elif key == 'op_desc':
op_desc = [] op_desc = map(dump_op_desc, self.desc[key])
for d in self.desc[key]: desc[key] = [d for d in op_desc if d is not None]
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
op_desc.append(inplace_desc)
else:
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops or op in subgraph.recompute_ops:
op_desc.append(d)
desc[key] = op_desc
elif key == 'op': elif key == 'op':
desc[key] = subgraph.name desc[key] = subgraph.name
else: else:

View File

@ -16,7 +16,7 @@
import copy import copy
import sys import sys
from functools import reduce from functools import reduce as prod_reduce
from .model import GraphKernelUnsupportedException as GKException from .model import GraphKernelUnsupportedException as GKException
from .model import PrimLib, DataFormat as DF from .model import PrimLib, DataFormat as DF
@ -101,22 +101,24 @@ class OpInfer:
class _Elemwise(OpInfer): class _Elemwise(OpInfer):
"""Common infer for elementwise operators""" """Common infer for elementwise operators"""
@staticmethod
def _broadcast_shape(self, shapes): def broadcast_shape(shapes):
"""deduce broadcast shape using same rules as numpy""" """deduce broadcast shape using same rules as numpy"""
dim_size = max([len(shape) for shape in shapes]) dim_size = max([len(shape) for shape in shapes])
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes] align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
out_shape = [1] * dim_size out_shape = [1] * dim_size
for i in range(dim_size): for i in range(dim_size):
for align_shape in align_shapes: for align_shape in align_shapes:
if align_shape[i] > 1: if align_shape[i] == 1:
if out_shape[i] == 1: continue
out_shape[i] = align_shape[i] if out_shape[i] == 1:
if out_shape[i] != align_shape[i]: out_shape[i] = align_shape[i]
raise GKException("shape broadcast failed!") elif out_shape[i] != align_shape[i]:
raise GKException("shape broadcast failed!")
return out_shape return out_shape
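Note: the new broadcast_shape staticmethod follows the numpy alignment rule — shapes are right-aligned, padded with leading 1s, and each dimension must either match or be 1. A standalone rendering with an illustrative call:

def broadcast_shape(shapes):
    # Right-align shapes, pad missing leading dims with 1, then merge per dim.
    dim_size = max(len(s) for s in shapes)
    aligned = [[1] * (dim_size - len(s)) + list(s) for s in shapes]
    out = [1] * dim_size
    for i in range(dim_size):
        for s in aligned:
            if s[i] == 1:
                continue
            if out[i] == 1:
                out[i] = s[i]
            elif out[i] != s[i]:
                raise ValueError("shape broadcast failed!")
    return out

print(broadcast_shape([[16, 1, 8], [4, 8], [1]]))  # [16, 4, 8]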
def _to_nz(self, default_shape): @staticmethod
def defaultformat_to_nz(default_shape):
"""default format shape to fractal_Nz format shape""" """default format shape to fractal_Nz format shape"""
if len(default_shape) not in (1, 2): if len(default_shape) not in (1, 2):
raise GKException("shape is too long!") raise GKException("shape is too long!")
@ -142,17 +144,17 @@ class _Elemwise(OpInfer):
"""returns the output shape with broadcast""" """returns the output shape with broadcast"""
# in case all inputs are default format/NHWC/NCHW # in case all inputs are default format/NHWC/NCHW
is_default = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for input in self.inputs] is_default = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for op_input in self.inputs]
if all(is_default): if all(is_default):
return self._broadcast_shape([input.shape for input in self.inputs]) return self.broadcast_shape([op_input.shape for op_input in self.inputs])
# in case formats are fractal_nz, default_format/NHWC/NCHW(optional) # in case formats are fractal_nz, default_format/NHWC/NCHW(optional)
is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ) is_default_frac_nz = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
for input in self.inputs] for op_input in self.inputs]
if all(is_default_frac_nz): if all(is_default_frac_nz):
nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape nz_shapes = [self.defaultformat_to_nz(op_input.shape) if op_input.data_format != DF.FRAC_NZ
for input in self.inputs] else op_input.shape for op_input in self.inputs]
return self._broadcast_shape(nz_shapes) return self.broadcast_shape(nz_shapes)
raise GKException("Only support default and fractal_nz") raise GKException("Only support default and fractal_nz")
@ -214,9 +216,11 @@ class _Reshape(OpInfer):
class Reshape(_Reshape): class Reshape(_Reshape):
"""Reshape op infer"""
def _check_shape(self): def _check_shape(self):
size_before_reshape = reduce(lambda x, y: x * y, self.inputs[0].shape) size_before_reshape = prod_reduce(lambda x, y: x * y, self.inputs[0].shape)
size_after_reshape = reduce(lambda x, y: x * y, self.attrs["shape"]) size_after_reshape = prod_reduce(lambda x, y: x * y, self.attrs["shape"])
if size_before_reshape != size_after_reshape: if size_before_reshape != size_after_reshape:
raise GKException("The shape product before and after reshaping should be equal") raise GKException("The shape product before and after reshaping should be equal")
@ -225,11 +229,15 @@ class Reshape(_Reshape):
class Cast(_Elemwise): class Cast(_Elemwise):
"""Cast op infer"""
def _infer_type(self): def _infer_type(self):
return self.attrs["dst_type"] return self.attrs["dst_type"]
class InplaceAssign(_Elemwise): class InplaceAssign(_Elemwise):
"""InplaceAssign op infer"""
def _infer_shape(self): def _infer_shape(self):
return self.inputs[2].shape return self.inputs[2].shape
@ -241,6 +249,8 @@ class InplaceAssign(_Elemwise):
class BroadcastTo(OpInfer): class BroadcastTo(OpInfer):
"""BroadcastTo op infer"""
def _infer_shape(self): def _infer_shape(self):
return self.attrs["shape"] return self.attrs["shape"]
@ -256,6 +266,8 @@ class _CompareOp(_Elemwise):
class CImag(OpInfer): class CImag(OpInfer):
"""CImag op infer"""
def _check_type(self): def _check_type(self):
if self.inputs[0].dtype != "complex64": if self.inputs[0].dtype != "complex64":
raise GKException( raise GKException(
@ -266,6 +278,8 @@ class CImag(OpInfer):
class CReal(OpInfer): class CReal(OpInfer):
"""CReal op infer"""
def _check_type(self): def _check_type(self):
if self.inputs[0].dtype != "complex64": if self.inputs[0].dtype != "complex64":
raise GKException( raise GKException(
@ -276,6 +290,8 @@ class CReal(OpInfer):
class Complex(OpInfer): class Complex(OpInfer):
"""Complex op infer"""
def _check_type(self): def _check_type(self):
if self.inputs[0].dtype != "float32": if self.inputs[0].dtype != "float32":
raise GKException( raise GKException(
@ -288,26 +304,28 @@ class Complex(OpInfer):
class Less(_CompareOp): class Less(_CompareOp):
pass """Less op infer"""
class LessEqual(_CompareOp): class LessEqual(_CompareOp):
pass """LessEqual op infer"""
class Equal(_CompareOp): class Equal(_CompareOp):
pass """Equal op infer"""
class Greater(_CompareOp): class Greater(_CompareOp):
pass """Greater op infer"""
class GreaterEqual(_CompareOp): class GreaterEqual(_CompareOp):
pass """GreaterEqual op infer"""
class Select(_Elemwise): class Select(_Elemwise):
"""Select op infer"""
def _check_type(self): def _check_type(self):
if self.inputs[0].dtype != "bool": if self.inputs[0].dtype != "bool":
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype)) raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype))
@ -319,6 +337,7 @@ class Select(_Elemwise):
def check_format_any(formats, checked_format): def check_format_any(formats, checked_format):
"""Check whether input format in formats list"""
if not isinstance(formats, (list, tuple)): if not isinstance(formats, (list, tuple)):
raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats))) raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats)))
if checked_format not in formats: if checked_format not in formats:
@ -326,11 +345,13 @@ def check_format_any(formats, checked_format):
def check_nd(data, nd): def check_nd(data, nd):
"""Check whether data are nd format"""
if not isinstance(data, (list, tuple)) or len(data) != nd: if not isinstance(data, (list, tuple)) or len(data) != nd:
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data)) raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
def conv_had_pad(pad_list, pad_mode): def conv_had_pad(pad_list, pad_mode):
"""Check whether conv need to add pad"""
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4: if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list)) raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]: if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:

View File

@ -57,11 +57,11 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
return return
utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH) utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt") filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
with open(filename, "a+") as f: with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
f.write("********** main graph: {} **********\n".format(graph_desc.name)) f.write("********** main graph: {} **********\n".format(graph_desc.name))
f.write("input json:\n{}\n".format(graph_json)) f.write("input json:\n{}\n".format(graph_json))
f.write("graph desc:\n{}\n".format(str(graph_desc))) f.write("graph desc:\n{}\n".format(str(graph_desc)))
if len(subgraphs) > 1: if len(subgraphs) > 1 or subgraphs[0].stitch_info.has_stitch_op():
for i, g in enumerate(subgraphs): for i, g in enumerate(subgraphs):
f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i])) f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
f.write("{}\n".format(str(g))) f.write("{}\n".format(str(g)))

View File

@ -26,3 +26,5 @@ def create_dir(pathname):
os.mkdir(pathname) os.mkdir(pathname)
except OSError: except OSError:
pass pass
finally:
pass

View File

@ -32,7 +32,7 @@ from te_fusion.parallel_compilation import init_multi_process_env, start_ga_mult
get_finished_compilation_task get_finished_compilation_task
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \ from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
BuildType, adjust_custom_op_info, pack_op_args BuildType, adjust_custom_op_info, pack_op_args, get_module_name
from .tbe_job import TbeJob, JobStatus from .tbe_job import TbeJob, JobStatus
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"] PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
@ -242,7 +242,7 @@ def check_support(job: TbeJob):
op_func_name = compute_op_info["func_name"] op_func_name = compute_op_info["func_name"]
if op_func_name in ("resize_nearest_neighbor_v2_grad_d", "resize_bilinear_v2_grad"): if op_func_name in ("resize_nearest_neighbor_v2_grad_d", "resize_bilinear_v2_grad"):
attrs.pop(-2) attrs.pop(-2)
op_module_name = compute_op_info["module_name"] op_module_name = get_module_name(compute_op_info)
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
func_name = "check_supported" func_name = "check_supported"
@ -281,7 +281,7 @@ def select_op_format(job: TbeJob):
compute_op_info = compute_op_info_list[0] compute_op_info = compute_op_info_list[0]
adjust_custom_op_info(compute_op_info) adjust_custom_op_info(compute_op_info)
inputs, outputs, attrs = assemble_op_args(compute_op_info) inputs, outputs, attrs = assemble_op_args(compute_op_info)
op_module_name = compute_op_info["module_name"] op_module_name = get_module_name(compute_op_info)
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
op_func_name = "op_select_format" op_func_name = "op_select_format"
@ -317,7 +317,7 @@ def _pre_build_compute_op_info(compute_op, job):
if l1_size != -1: if l1_size != -1:
set_L1_info("op_L1_space", -1) set_L1_info("op_L1_space", -1)
inputs, outputs, attrs = assemble_op_args(compute_op) inputs, outputs, attrs = assemble_op_args(compute_op)
op_module_name = compute_op["module_name"] op_module_name = get_module_name(compute_op)
py_module_path = compute_op["py_module_path"] py_module_path = compute_op["py_module_path"]
op_func_name = compute_op["func_name"] op_func_name = compute_op["func_name"]
op_type = compute_op["type"] op_type = compute_op["type"]
@ -340,8 +340,8 @@ def _pre_build_compute_op_info(compute_op, job):
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode)) job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
options = get_options_info(job.content) options = get_options_info(job.content)
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, unknown_shape, dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, unknown_shape,
(inputs, outputs, attrs, options), int64_mode, dynamic_compile_static, job.rl_tune_switch, (inputs, outputs, attrs, options), int64_mode, dynamic_compile_static, unknown_shape,
job.rl_tune_list, job.pass_list, job.op_tune_switch, job.op_tune_list) job.rl_tune_switch, job.rl_tune_list, job.pass_list, job.op_tune_switch, job.op_tune_list)
def get_prebuild_output(op_name): def get_prebuild_output(op_name):
@ -391,7 +391,7 @@ def build_single_pre_op(job: TbeJob):
inputs, outputs, attrs = assemble_op_args(compute_op_info) inputs, outputs, attrs = assemble_op_args(compute_op_info)
op_type = compute_op_info["type"] op_type = compute_op_info["type"]
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
op_module_name = compute_op_info["module_name"] op_module_name = get_module_name(compute_op_info)
op_kernel_name = compute_op_info["op_name"] op_kernel_name = compute_op_info["op_name"]
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
op_func_name = compute_op_info["func_name"] op_func_name = compute_op_info["func_name"]
@ -404,9 +404,9 @@ def build_single_pre_op(job: TbeJob):
fuzz_build_info = get_fuzz_build_info(job.content) fuzz_build_info = get_fuzz_build_info(job.content)
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name,
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode, op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
None, None, dynamic_compile_static, op_pattern, json.dumps(fuzz_build_info), None, None, dynamic_compile_static, unknown_shape, op_pattern,
job.rl_tune_switch, job.rl_tune_list, job.pass_list, job.op_tune_switch, json.dumps(fuzz_build_info), job.rl_tune_switch, job.rl_tune_list, job.pass_list,
job.op_tune_list) job.op_tune_switch, job.op_tune_list)
return True return True
@ -487,7 +487,7 @@ def rl_tune_single_op(job: TbeJob):
inputs, outputs, attrs = assemble_op_args(compute_op_info) inputs, outputs, attrs = assemble_op_args(compute_op_info)
op_type = compute_op_info["type"] op_type = compute_op_info["type"]
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
op_module_name = compute_op_info["module_name"] op_module_name = get_module_name(compute_op_info)
op_kernel_name = compute_op_info["op_name"] op_kernel_name = compute_op_info["op_name"]
full_name = compute_op_info["name"] full_name = compute_op_info["name"]
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
@ -503,7 +503,7 @@ def rl_tune_single_op(job: TbeJob):
device_id = job.content["SocInfo"]["deviceId"] device_id = job.content["SocInfo"]["deviceId"]
try: try:
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape, build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
(inputs, outputs, attrs), int64_mode, dynamic_compile_static, op_pattern, (inputs, outputs, attrs), int64_mode, dynamic_compile_static, unknown_shape, op_pattern,
auto_tiling_mode, device_id, json.dumps(fuzz_build_info)) auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
@ -547,7 +547,7 @@ def rl_tune_fusion_op(job: TbeJob):
compute_op_list = get_compute_op_list(job.content) compute_op_list = get_compute_op_list(job.content)
op_module_names_str = "" op_module_names_str = ""
for op in compute_op_list: for op in compute_op_list:
op_module_names_str = op_module_names_str + "," + op["module_name"] op_module_names_str = op_module_names_str + "," + get_module_name(op)
op_module_names_str = op_module_names_str[1:] op_module_names_str = op_module_names_str[1:]
from schedule_search.rl_online_tune import dispatch_fusion_tune_task from schedule_search.rl_online_tune import dispatch_fusion_tune_task
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str, res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,

View File

@ -179,8 +179,6 @@ def get_options_info(job_content):
options["op_debug_level"] = job_content["SocInfo"]["op_debug_level"] options["op_debug_level"] = job_content["SocInfo"]["op_debug_level"]
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"] options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"] options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
options["op_compiler_cache_dir"] = job_content["SocInfo"]["op_compiler_cache_dir"]
options["op_compiler_cache_mode"] = job_content["SocInfo"]["op_compiler_cache_mode"]
options["mdl_bank_path"] = job_content["SocInfo"]["op_debug_level"] options["mdl_bank_path"] = job_content["SocInfo"]["op_debug_level"]
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"] options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
options["deviceId"] = job_content["SocInfo"]["deviceId"] options["deviceId"] = job_content["SocInfo"]["deviceId"]
@ -220,6 +218,19 @@ def get_func_names(job_content):
return func_names return func_names
def get_module_name(compute_op_info):
"""
get the op implementation module name from compute_op_info; dynamic-shape ops
are redirected to the ".dynamic." implementation module
:param compute_op_info: compute op info dict
:return: op module name
"""
unknown_shape = compute_op_info["unknown_shape"]
op_module_name = compute_op_info["module_name"]
if unknown_shape:
op_module_name = op_module_name.split(".")[0] + ".dynamic." + op_module_name.split(".")[-1]
return op_module_name
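Note: an illustrative call of get_module_name as defined above, with hypothetical compute op dicts — a static-shape op keeps its module name, a dynamic-shape op is redirected to the ".dynamic." implementation module:

# Assumes get_module_name from above is in scope; the dicts and module names are illustrative.
static_op = {"unknown_shape": False, "module_name": "impl.add"}
dynamic_op = {"unknown_shape": True, "module_name": "impl.add"}
print(get_module_name(static_op))   # impl.add
print(get_module_name(dynamic_op))  # impl.dynamic.add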
def adjust_custom_op_info(compute_op_info): def adjust_custom_op_info(compute_op_info):
""" """
adjust custom op info adjust custom op info

View File

@ -71,12 +71,13 @@ def _get_message(msg, args):
class TbeJob: class TbeJob:
""" Tbe compilation job """ """ Tbe compilation job """
def __init__(self, source_id, job_id, job_type, content, json_str, sys_info): def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info):
self.source_id = source_id self.source_id = source_id
self.id = job_id self.id = job_id
self.type = JobType(job_type) self.type = JobType(job_type)
self.status = JobStatus.JOB_INITIAL self.status = JobStatus.JOB_INITIAL
self.content = content self.content = content
self.fusion_op_name = fusion_op_name
self.result = "" self.result = ""
self.process_info = [] self.process_info = []
self.json_string = json_str self.json_string = json_str
@ -149,8 +150,8 @@ class TbeJob:
result["source_id"] = self.source_id result["source_id"] = self.source_id
result["job_id"] = self.id result["job_id"] = self.id
result["job_type"] = self.type.value result["job_type"] = self.type.value
result["fusion_op_name"] = self.fusion_op_name
result["result"] = self.result result["result"] = self.result
self.debug("Resp result:{}".format(json.dumps(result)))
process_info = [] process_info = []
for info in self.process_info: for info in self.process_info:
msg = {"index": info.index, "level": info.level.value, "message": info.info} msg = {"index": info.index, "level": info.level.value, "message": info.info}

View File

@ -102,8 +102,9 @@ class TbeJobManager:
source_id = job_json["source_id"] source_id = job_json["source_id"]
job_type = job_json["job_type"] job_type = job_json["job_type"]
sys_info = self._get_job_sys_info() sys_info = self._get_job_sys_info()
job = TbeJob(source_id, job_id, job_type, job_json["job_content"], job_str, sys_info) fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][
job.debug("Req job string: {}".format(job_str)) "fusion_op_name"]
job = TbeJob(source_id, job_id, job_type, job_json["job_content"], fusion_op_name, job_str, sys_info)
post_job(self._all_jobs, job) post_job(self._all_jobs, job)
if not self.tbe_initialize and job.type != JobType.INITIALIZE_JOB: if not self.tbe_initialize and job.type != JobType.INITIALIZE_JOB:
job.error( job.error(
@ -115,6 +116,7 @@ class TbeJobManager:
return res return res
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
# pylint: disable=no-value-for-parameter
sys_info = self._get_job_sys_info() sys_info = self._get_job_sys_info()
job = TbeJob(-1, -1, "", None, job_str, sys_info) if job is None else job job = TbeJob(-1, -1, "", None, job_str, sys_info) if job is None else job
job.status = JobStatus.JOB_FAILED job.status = JobStatus.JOB_FAILED
@ -261,9 +263,6 @@ class TbeJobManager:
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
target_job = get_job(self._running_jobs, target_source_id, target_job_id) target_job = get_job(self._running_jobs, target_source_id, target_job_id)
if target_job: if target_job:
query_job.debug("Found job in Running jobs, source_id:{}, job_id:{}".format(target_source_id,
target_job_id))
target_job.debug("Be Queried")
query_job.result = target_job.get_result() query_job.result = target_job.get_result()
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
target_job = get_job(self._all_jobs, target_source_id, target_job_id) target_job = get_job(self._all_jobs, target_source_id, target_job_id)
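Note: the fusion_op_name fallback added above can be expressed more tersely with dict.get; a sketch with a hypothetical payload dict:

job_content = {"l1_size": -1}  # hypothetical job_content without a fusion_op_name key
fusion_op_name = job_content.get("fusion_op_name", "NA")
print(fusion_op_name)  # NA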

View File

@ -16,7 +16,6 @@
import os import os
from mindspore import log as logger from mindspore import log as logger
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single
class Messager: class Messager:
@ -146,9 +145,7 @@ class AkgBuilder():
def handle(self, messager, arg): def handle(self, messager, arg):
"""Handle message about akg""" """Handle message about akg"""
if arg == 'AKG/PID': if arg == 'AKG/START':
messager.send_res(os.getpid())
elif arg == 'AKG/START':
messager.send_ack() messager.send_ack()
process_num_str = messager.get_message() process_num_str = messager.get_message()
messager.send_ack() messager.send_ack()
@ -173,17 +170,8 @@ class AkgBuilder():
else: else:
messager.send_ack(False) messager.send_ack(False)
break break
elif arg == 'AKG/COMPILE': else:
messager.send_ack() raise RuntimeError("Unknown message type: %s" % arg)
json = messager.get_message()
try:
akg_compile_single(json, self.attrs)
except ValueError:
messager.send_ack(False)
messager.exit()
finally:
pass
messager.send_ack()
def get_logger(): def get_logger():

View File

@ -297,20 +297,14 @@ if(MODE_ASCEND_ALL)
${ASCEND_DRIVER_BACK_PATH}) ${ASCEND_DRIVER_BACK_PATH})
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH} find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}
${ASCEND_DRIVER_BACK_PATH}) ${ASCEND_DRIVER_BACK_PATH})
find_library(PROFILING msprofiler_fwkacl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH}) find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH})
find_library(OPT_FEATURE opt_feature ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(OPT_FEATURE opt_feature ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
add_library(ms_profile SHARED
${CMAKE_CURRENT_SOURCE_DIR}/runtime/device/ascend/profiling/profiling_callback_register.cc)
set_target_properties(ms_profile PROPERTIES LINKER_LANGUAGE CXX)
target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init)
target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive
mindspore::protobuf -Wl,--end-group)
target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER} -Wl,--no-as-needed target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER} -Wl,--no-as-needed
${OPTILING} ${PLATFORM} ${ACL} ${OPT_FEATURE}) ${OPTILING} ${PLATFORM} ${ACL} ${OPT_FEATURE} ${PROFILING})
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group) target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows") elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece
@ -325,7 +319,7 @@ endif()
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set_property(SOURCE "pipeline/jit/init.cc" PROPERTY set_property(SOURCE "pipeline/jit/init.cc" PROPERTY
COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE)
pybind11_add_module(_c_expression NO_EXTRAS "pipeline/jit/init.cc") pybind11_add_module(_c_expression NO_EXTRAS "pipeline/jit/init.cc" NO_EXTRAS)
MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}") MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}")
if(CMAKE_SYSTEM_NAME MATCHES "Linux") if(CMAKE_SYSTEM_NAME MATCHES "Linux")
@ -375,9 +369,6 @@ else()
proto_input -Wl,--no-whole-archive) proto_input -Wl,--no-whole-archive)
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
target_link_libraries(_c_expression PRIVATE mindspore_gvar) target_link_libraries(_c_expression PRIVATE mindspore_gvar)
if(MODE_ASCEND_ALL)
target_link_libraries(_c_expression PRIVATE -Wl,--no-as-needed ms_profile)
endif()
endif() endif()
if(USE_GLOG) if(USE_GLOG)

View File

@ -36,6 +36,7 @@ if(ENABLE_CPU)
"cpu/ps/*.cc" "cpu/ps/*.cc"
"cpu/quantum/*.cc" "cpu/quantum/*.cc"
"cpu/pyfunc/*.cc" "cpu/pyfunc/*.cc"
"cpu/rl/*.cc"
) )
if(NOT ENABLE_MPI) if(NOT ENABLE_MPI)
@ -84,6 +85,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
endif() endif()
if(ENABLE_GPU) if(ENABLE_GPU)

View File

@ -197,17 +197,37 @@ int32_t AkgKernelPool::Init(const std::vector<JsonNodePair> &build_args) {
} }
AkgKernelPool::~AkgKernelPool() { AkgKernelPool::~AkgKernelPool() {
// Detach shared memory {
auto ret = shmdt(reinterpret_cast<void *>(kernel_lists_[0])); LockMng lock(fd_);
if (ret < 0) { if (!lock.locked_) {
MS_LOG(EXCEPTION) << "Shared_mem detach failed, errno:" << strerror(errno); MS_LOG(EXCEPTION) << "Failed to acquire lock.";
} }
// Realse shared_memroy struct shmid_ds buf;
if (is_creator_) { auto ret = shmctl(shm_id_, IPC_STAT, &buf);
ret = shmctl(shm_id_, IPC_RMID, nullptr); if (ret == -1) {
MS_LOG(EXCEPTION) << "Failed to get the info of shared memory, errno:" << strerror(errno);
}
bool need_delete_by_last = false;
// if the creator exits unexpectedly and fails to delete the shm, the last process will try to delete the shm
if (((buf.shm_perm.mode & SHM_DEST) == 0) && (buf.shm_nattch == 1)) {
need_delete_by_last = true;
}
// Detach shared memory
ret = shmdt(reinterpret_cast<void *>(kernel_lists_[0]));
if (ret < 0) { if (ret < 0) {
MS_LOG(EXCEPTION) << "Realse shared_mem failed, errno:" << strerror(errno); MS_LOG(EXCEPTION) << "Shared_mem detach failed, errno:" << strerror(errno);
}
// Release shared_memory
if (is_creator_ || need_delete_by_last) {
ret = shmctl(shm_id_, IPC_RMID, nullptr);
if (ret < 0) {
MS_LOG(EXCEPTION) << "Realse shared_mem failed, errno:" << strerror(errno);
}
} }
} }
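Note: a logic-only Python sketch (no real System V IPC) of the cleanup rule added in the destructor above — the creator always removes the segment, and otherwise the last attached process removes it when the creator died without marking it for destruction (SHM_DEST unset):

def should_remove_segment(is_creator, shm_dest_flag_set, attach_count):
    # Creator always cleans up; otherwise only the last survivor of an unmarked segment does.
    if is_creator:
        return True
    return (not shm_dest_flag_set) and attach_count == 1

print(should_remove_segment(False, False, 1))  # True: last survivor cleans up
print(should_remove_segment(False, True, 1))   # False: segment already marked for removal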
@ -354,35 +374,6 @@ int32_t AkgKernelPool::Wait() {
return -1; return -1;
} }
std::vector<std::string> AkgKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) {
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
std::vector<std::string> jsons;
std::unordered_set<std::string> kernel_name_set;
for (const auto &[json_generator, anf_node] : build_args) {
MS_EXCEPTION_IF_NULL(anf_node);
auto kernel_name = json_generator.kernel_name();
MS_LOG(DEBUG) << "Akg start compile op: " << kernel_name;
auto cached_kernel_pack = AkgSearchCache(kernel_name);
if (cached_kernel_pack != nullptr) {
MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
AkgSetKernelMod(cached_kernel_pack, json_generator, anf_node);
continue;
}
if (kernel_name_set.count(kernel_name) != 0) {
repeat_nodes_.push_back({json_generator, anf_node});
continue;
}
kernel_name_set.insert(kernel_name);
auto kernel_json = json_generator.kernel_json_str();
AkgSaveJsonInfo(kernel_name, kernel_json);
jsons.push_back(kernel_json);
}
return jsons;
}
std::vector<JsonNodePair> AkgKernelBuilder::GetNotCachedKernels(const std::vector<JsonNodePair> &build_args) { std::vector<JsonNodePair> AkgKernelBuilder::GetNotCachedKernels(const std::vector<JsonNodePair> &build_args) {
std::unordered_set<std::string> kernel_name_set; std::unordered_set<std::string> kernel_name_set;
std::vector<JsonNodePair> new_build_args; std::vector<JsonNodePair> new_build_args;
@ -432,8 +423,8 @@ bool AkgKernelBuilder::HandleRepeatNodes() {
<< anf_node->fullname_with_scope() << "]."; << anf_node->fullname_with_scope() << "].";
return false; return false;
} }
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope[" MS_LOG(DEBUG) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "]."; << anf_node->fullname_with_scope() << "].";
AkgSetKernelMod(cached_kernel_pack, json_generator, anf_node); AkgSetKernelMod(cached_kernel_pack, json_generator, anf_node);
} }
return true; return true;
@ -555,7 +546,7 @@ bool AkgKernelBuilder::AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf
} }
if (json_and_node.empty()) { if (json_and_node.empty()) {
MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; MS_LOG(INFO) << "There is no akg kernel to be compiled.";
return true; return true;
} }

View File

@ -47,7 +47,6 @@ class AkgKernelBuilder {
bool AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); bool AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
private: private:
std::vector<std::string> GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args);
std::vector<JsonNodePair> GetNotCachedKernels(const std::vector<JsonNodePair> &build_args); std::vector<JsonNodePair> GetNotCachedKernels(const std::vector<JsonNodePair> &build_args);
std::vector<std::string> GetKernelJsonsByHashId(const std::vector<JsonNodePair> &build_args, std::vector<std::string> GetKernelJsonsByHashId(const std::vector<JsonNodePair> &build_args,
std::set<size_t> fetched_ids); std::set<size_t> fetched_ids);
@ -91,7 +90,6 @@ class AkgKernelPool {
int32_t UpdateAndWait(const std::set<size_t> &ids); int32_t UpdateAndWait(const std::set<size_t> &ids);
constexpr inline static size_t kMaxKernelNum_{1000}; constexpr inline static size_t kMaxKernelNum_{1000};
constexpr inline static key_t kSharedMemKey_{0x57565845};
// allocate memory for todo_list, doing_list, done_list // allocate memory for todo_list, doing_list, done_list
constexpr inline static size_t kListNum_{3}; constexpr inline static size_t kListNum_{3};

View File

@ -15,12 +15,6 @@
*/ */
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <map>
#include <vector>
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"

View File

@ -16,12 +16,6 @@
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <sstream>
#include <tuple>
#if ENABLE_GPU #if ENABLE_GPU
#include <cuda.h> #include <cuda.h>
#endif #endif

View File

@ -15,7 +15,6 @@
*/ */
#include "backend/kernel_compiler/akg/akg_kernel_metadata.h" #include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"

View File

@ -16,13 +16,6 @@
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
@ -34,11 +27,11 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name) { KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name) {
return tbe::TbeUtils::SearchCache(kernel_name, kProcessorAiCore); return tbe::TbeUtils::SearchCache(kernel_name, true);
} }
KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name) { KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name) {
return tbe::TbeUtils::InsertCache(kernel_name, kProcessorAiCore); return tbe::TbeUtils::InsertCache(kernel_name, kProcessorAiCore, true);
} }
void AkgAscendKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack, void AkgAscendKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,

View File

@ -49,6 +49,5 @@ void AkgGpuKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
void AkgGpuKernelBuilder::AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) { void AkgGpuKernelBuilder::AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) {
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path()); kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -15,8 +15,7 @@
*/ */
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" #include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
#include <fstream>
#include <algorithm>
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
@ -126,7 +125,7 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
[](const AddressPtr &output) -> void * { return reinterpret_cast<void *>(&(output->addr)); }); [](const AddressPtr &output) -> void * { return reinterpret_cast<void *>(&(output->addr)); });
if (!workspace.empty()) { if (!workspace.empty()) {
(void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs), (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs),
[](const AddressPtr &addr) -> void * { return addr->addr; }); [](const AddressPtr &addr) -> void * { return reinterpret_cast<void *>(&(addr->addr)); });
} }
result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4], result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4],
thread_info[5], 0, reinterpret_cast<CUstream>(stream_ptr), thread_info[5], 0, reinterpret_cast<CUstream>(stream_ptr),

View File

@ -970,5 +970,39 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
} }
return offset; return offset;
} }
size_t UnitSizeInBytes(const mindspore::TypeId &t) {
size_t bytes = 0;
switch (t) {
case kNumberTypeBool:
case kNumberTypeInt8:
case kNumberTypeUInt8:
bytes = sizeof(int8_t);
break;
case kNumberTypeInt16:
case kNumberTypeUInt16:
case kNumberTypeFloat16:
bytes = sizeof(int16_t);
break;
case kNumberTypeInt:
case kNumberTypeUInt:
case kNumberTypeInt32:
case kNumberTypeUInt32:
case kNumberTypeFloat:
case kNumberTypeFloat32:
bytes = sizeof(int32_t);
break;
case kNumberTypeUInt64:
case kNumberTypeInt64:
case kNumberTypeFloat64:
bytes = sizeof(int64_t);
break;
default:
MS_LOG(EXCEPTION) << "Invalid types " << t;
break;
}
return bytes;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
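Note: a companion Python sketch of the same mapping in UnitSizeInBytes (dtype name to element size in bytes); the string keys are illustrative, the C++ version switches on TypeId enums:

UNIT_SIZE_IN_BYTES = {
    "bool": 1, "int8": 1, "uint8": 1,
    "int16": 2, "uint16": 2, "float16": 2,
    "int32": 4, "uint32": 4, "float32": 4,
    "int64": 8, "uint64": 8, "float64": 8,
}

def unit_size_in_bytes(dtype_name):
    if dtype_name not in UNIT_SIZE_IN_BYTES:
        raise ValueError("Invalid type " + dtype_name)
    return UNIT_SIZE_IN_BYTES[dtype_name]

print(unit_size_in_bytes("float16"))  # 2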

View File

@ -143,6 +143,7 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape); std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start, size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
const std::vector<int64_t> &stop); const std::vector<int64_t> &stop);
size_t UnitSizeInBytes(const mindspore::TypeId &t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -83,7 +83,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &input
MS_LOG(EXCEPTION) << "AdamFp32 failed."; MS_LOG(EXCEPTION) << "AdamFp32 failed.";
} }
}; };
CPUKernelUtils::ParallelForAutoSearch(task, lens, &parallel_search_info_); ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
} }
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {

View File

@ -19,6 +19,7 @@
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "nnacl/fp32/power_fp32.h" #include "nnacl/fp32/power_fp32.h"
#include "nnacl/fp32/sub_fp32.h" #include "nnacl/fp32/sub_fp32.h"
#include "nnacl/fp32/mul_fp32.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -54,7 +55,7 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
ElementSub(input1 + start, input2 + start, out + start, end - start); ElementSub(input1 + start, input2 + start, out + start, end - start);
}; };
CPUKernelUtils::ParallelFor(task, output_size_, MAX_SUB_SERIAL_SIZE); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return; return;
} }
if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) { if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) {
@ -65,7 +66,7 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
ElementOptSub(input1 + start, input2, out + start, end - start, &op_para); ElementOptSub(input1 + start, input2, out + start, end - start, &op_para);
} }
}; };
CPUKernelUtils::ParallelFor(task, output_size_, MAX_SUB_SERIAL_SIZE); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return; return;
} }
} }
@ -84,6 +85,26 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
template <typename T> template <typename T>
void ArithmeticCPUKernel<T>::Mul(const T *input1, const T *input2, T *out) { void ArithmeticCPUKernel<T>::Mul(const T *input1, const T *input2, T *out) {
if constexpr (std::is_same_v<T, float>) {
if (input_shape1_ == input_shape2_) {
auto task = [&](size_t start, size_t end) {
ElementMul(input1 + start, input2 + start, out + start, end - start);
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return;
}
if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) {
auto task = [&](size_t start, size_t end) {
if (op_para.in_elements_num0_ == 1) {
ElementOptMul(input1, input2 + start, out + start, end - start, &op_para);
} else {
ElementOptMul(input1 + start, input2, out + start, end - start, &op_para);
}
};
ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return;
}
}
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) { auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter; auto iter = base_iter;
@ -128,21 +149,21 @@ void ArithmeticCPUKernel<T>::RealDiv(const T *input1, const T *input2, T *out) {
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1); ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1);
}; };
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return; return;
} }
if (op_para.in_elements_num0_ == 1) { if (op_para.in_elements_num0_ == 1) {
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
ElementRealDiv<T>(input1, input2 + start, out + start, end - start, 0, 1); ElementRealDiv<T>(input1, input2 + start, out + start, end - start, 0, 1);
}; };
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return; return;
} }
if (op_para.in_elements_num1_ == 1) { if (op_para.in_elements_num1_ == 1) {
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
ElementRealDiv<T>(input1 + start, input2, out + start, end - start, 1, 0); ElementRealDiv<T>(input1 + start, input2, out + start, end - start, 1, 0);
}; };
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
return; return;
} }
@ -339,7 +360,7 @@ void ArithmeticCPUKernel<T>::SquaredDifference(const T *input1, const T *input2,
iter.GenNextPos(); iter.GenNextPos();
} }
}; };
CPUKernelUtils::ParallelFor(task, output_size_); ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
} }
template <typename T> template <typename T>
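Note: a logic-only sketch of the Mul fast paths added above — same-shape inputs use the plain elementwise loop, a single-element operand is broadcast; the real kernel dispatches chunks of this work through ParallelLaunchAutoSearch rather than a Python loop:

def elementwise_mul(a, b, out):
    # Fast paths only: equal-length inputs or a one-element operand.
    if len(a) == len(b):
        for i in range(len(out)):
            out[i] = a[i] * b[i]
    elif len(a) == 1:
        for i in range(len(out)):
            out[i] = a[0] * b[i]
    elif len(b) == 1:
        for i in range(len(out)):
            out[i] = a[i] * b[0]
    else:
        raise ValueError("general broadcast path not sketched here")

out = [0.0] * 3
elementwise_mul([2.0], [1.0, 2.0, 3.0], out)
print(out)  # [2.0, 4.0, 6.0]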

View File

@ -77,6 +77,8 @@ MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t); MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float); MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t); MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, float);
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T( MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel, int64_t); ArithmeticCPUKernel, int64_t);


@ -20,6 +20,7 @@
#include <map> #include <map>
#include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h" #include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "nnacl/fp32/exp_fp32.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -31,7 +32,15 @@ void Square(const T *in, T *out, size_t size) {
out[i] = in[i] * in[i]; out[i] = in[i] * in[i];
} }
}; };
CPUKernelUtils::ParallelFor(task, size, MAX_SQUARE_SERIAL_SIZE); ParallelLaunch(task, size, MAX_SQUARE_SERIAL_SIZE);
}
template <typename T>
void Exp(const T *in, T *out, size_t size) {
if constexpr (std::is_same_v<T, float>) {
auto task = [&in, &out](size_t start, size_t end) { ExpFp32(in + start, out + start, end - start); };
ParallelLaunch(task, size, MAX_EXP_SERIAL_SIZE);
}
} }
template <typename T> template <typename T>
@ -57,7 +66,7 @@ void Neg(const T *in, T *out, size_t size) {
out[i] = -in[i]; out[i] = -in[i];
} }
}; };
CPUKernelUtils::ParallelFor(task, size, MAX_NEG_SERIAL_SIZE); ParallelLaunch(task, size, MAX_NEG_SERIAL_SIZE);
} }
template <typename T> template <typename T>
@ -262,6 +271,7 @@ void Identity(const T *in, T *out, size_t size) {
static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG}, static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
{prim::kPrimSquare->name(), SQUARE}, {prim::kPrimSquare->name(), SQUARE},
{prim::kPrimOnesLike->name(), ONESLIKE}, {prim::kPrimOnesLike->name(), ONESLIKE},
{prim::kPrimExp->name(), EXP},
{prim::kPrimZerosLike->name(), ZEROSLIKE}, {prim::kPrimZerosLike->name(), ZEROSLIKE},
{prim::kPrimLogicalNot->name(), LOGICALNOT}, {prim::kPrimLogicalNot->name(), LOGICALNOT},
{prim::kPrimSign->name(), SIGN}, {prim::kPrimSign->name(), SIGN},
@ -324,17 +334,29 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
T *output = reinterpret_cast<T *>(outputs[0]->addr); T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = { static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = {
{SQUARE, Square<T>}, {SIGN, Sign<T>}, {SQUARE, Square<T>},
{NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>}, {SIGN, Sign<T>},
{ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>}, {NEG, Neg<T>},
{FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>}, {LOGICALNOT, LogicalNot<T>},
{GELU, Gelu<T>}, {SIN, Sin<T>}, {ONESLIKE, OnesLike<T>},
{COS, Cos<T>}, {TAN, Tan<T>}, {ZEROSLIKE, ZerosLike<T>},
{ASIN, Asin<T>}, {ACOS, ACos<T>}, {FLOOR, Floor<T>},
{ATAN, Atan<T>}, {SINH, Sinh<T>}, {RECIPROCAL, Reciprocal<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>}, {GELU, Gelu<T>},
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}, {SIN, Sin<T>},
{RINT, Rint<T>}, {ROUND, Round<T>}}; {COS, Cos<T>},
{TAN, Tan<T>},
{ASIN, Asin<T>},
{ACOS, ACos<T>},
{ATAN, Atan<T>},
{SINH, Sinh<T>},
{COSH, Cosh<T>},
{ASINH, Asinh<T>},
{ACOSH, Acosh<T>},
{ATANH, Atanh<T>},
{RINT, Rint<T>},
{ROUND, Round<T>},
{EXP, Exp<T>}};
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) { if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens); kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
} else { } else {


@ -20,8 +20,9 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
const float MAX_NEG_SERIAL_SIZE = 20000; const float MAX_NEG_SERIAL_SIZE = 5000;
const float MAX_SQUARE_SERIAL_SIZE = 20000; const float MAX_SQUARE_SERIAL_SIZE = 5000;
const float MAX_EXP_SERIAL_SIZE = 15000;
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -58,6 +59,10 @@ class IdentityCPUKernel : public ArithmeticSelfCPUKernel {
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel); ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel); ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),


@ -90,7 +90,7 @@ bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
ElementAdd(src_addr + n_offset, bias_addr, output_addr + n_offset, input_shape_[1]); ElementAdd(src_addr + n_offset, bias_addr, output_addr + n_offset, input_shape_[1]);
} }
}; };
CPUKernelUtils::ParallelForAutoSearch(task, input_shape_[0], &parallel_search_info_); ParallelLaunchAutoSearch(task, input_shape_[0], this, &parallel_search_info_);
} }
return true; return true;
} }


@ -55,7 +55,7 @@ bool BiasAddGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const s
auto task = [&](size_t start, size_t end) { auto task = [&](size_t start, size_t end) {
ReduceSumDim2Axis0(end - start, input_shape_[1], input_shape_[0], input_addr + start, output_addr + start); ReduceSumDim2Axis0(end - start, input_shape_[1], input_shape_[0], input_addr + start, output_addr + start);
}; };
CPUKernelUtils::ParallelForAutoSearch(task, input_shape_[1], &parallel_search_info_); ParallelLaunchAutoSearch(task, input_shape_[1], this, &parallel_search_info_);
} }
return true; return true;
} }


@ -74,7 +74,7 @@ bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, c
} }
} }
}; };
CPUKernelUtils::ParallelForAutoSearch(task, before_axis, &parallel_search_info_); ParallelLaunchAutoSearch(task, before_axis, this, &parallel_search_info_);
return true; return true;
} }


@ -138,6 +138,77 @@ void CPUKernelUtils::ParallelForAutoSearch(const CTask &task, size_t count, Para
} }
} }
ActorThreadPool *GetActorMgrInnerThreadPool() {
auto actor_manager = ActorMgr::GetActorMgrRef();
auto thread_pool = actor_manager->GetActorThreadPool();
  // Init thread_pool if env is windows or ascend, in case it has not been initialized in graph_scheduler.
if (thread_pool == nullptr) {
const size_t kMaxThreadNum = 23;
size_t max_thread_num = std::thread::hardware_concurrency() - 1;
if (max_thread_num < 1) {
max_thread_num = 1;
}
max_thread_num = max_thread_num < kMaxThreadNum ? max_thread_num : kMaxThreadNum;
actor_manager->Initialize(true, 0, max_thread_num);
thread_pool = actor_manager->GetActorThreadPool();
MS_EXCEPTION_IF_NULL(thread_pool);
}
return thread_pool;
}
// Use threadpool of mindrt
void ParallelLaunch(const CTask &task, size_t count, float block_size, Content content) {
auto thread_pool = GetActorMgrInnerThreadPool();
size_t kernel_thread_num = thread_pool->GetKernelThreadNum();
if (kernel_thread_num == 0) {
MS_LOG(EXCEPTION) << "Actor inner pool has been init, but kernel thread is 0!";
}
size_t thread_num = count < block_size * kernel_thread_num ? std::ceil(count / block_size) : kernel_thread_num;
size_t once_compute_size = (count + thread_num - 1) / thread_num;
size_t task_num = count / once_compute_size;
if (count % once_compute_size != 0) {
task_num += 1;
}
auto func = [&](void *, int task_id, float, float) {
size_t start = task_id * once_compute_size;
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
task(start, end);
return common::SUCCESS;
};
thread_pool->ParallelLaunch(func, content, task_num);
}
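
For reference, a minimal standalone sketch (not part of the patch) that reproduces the thread-partitioning arithmetic above, using hypothetical values count = 1000, block_size = 128 and 8 kernel threads:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  size_t count = 1000, kernel_thread_num = 8;  // hypothetical workload size and pool size
  float block_size = 128.0f;
  // Same formula as ParallelLaunch: cap the thread count by ceil(count / block_size).
  size_t thread_num = count < block_size * kernel_thread_num
                        ? static_cast<size_t>(std::ceil(count / block_size))
                        : kernel_thread_num;                          // -> 8
  size_t once_compute_size = (count + thread_num - 1) / thread_num;   // -> 125 elements per task
  size_t task_num = count / once_compute_size;                        // -> 8 tasks for these values
  if (count % once_compute_size != 0) {
    task_num += 1;  // round up when the last chunk is partial
  }
  std::printf("threads=%zu chunk=%zu tasks=%zu\n", thread_num, once_compute_size, task_num);
  return 0;
}
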
void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content,
ParallelSearchInfo *parallel_search_info) {
const size_t MAX_POW = 6;
const size_t AVG_COUNT = 5;
size_t current_pow = parallel_search_info->search_count / AVG_COUNT;
if (current_pow < MAX_POW) {
if (parallel_search_info->search_count % AVG_COUNT == 0) {
parallel_search_info->tmp_sum_cost_time = 0;
}
float block_size = static_cast<float>(count) / std::pow(2.0f, current_pow);
double start_time = GetTime();
ParallelLaunch(task, count, block_size, content);
double cost_time = GetTime() - start_time;
parallel_search_info->tmp_sum_cost_time += cost_time;
parallel_search_info->search_count++;
if (parallel_search_info->search_count % AVG_COUNT == 0) {
double avg_time = parallel_search_info->tmp_sum_cost_time / AVG_COUNT;
if (parallel_search_info->min_cost_time > avg_time) {
parallel_search_info->min_cost_time = avg_time;
parallel_search_info->best_block_size = block_size;
parallel_search_info->best_pow = current_pow;
} else if (current_pow - parallel_search_info->best_pow >= 2) {
parallel_search_info->search_count = AVG_COUNT * MAX_POW;
}
}
} else {
ParallelLaunch(task, count, parallel_search_info->best_block_size, content);
}
}
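
All call sites migrated in this commit follow the same pattern; a minimal sketch of such a caller (assuming a kernel with output_size_ and parallel_search_info_ members, as in ArithmeticCPUKernel, with in1/in2/out as placeholder buffers):

  // Sketch of a typical caller after this patch: `this` supplies the Content handle and
  // parallel_search_info_ caches the best block size found by the auto search.
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      out[i] = in1[i] + in2[i];  // element-wise work on the half-open range [start, end)
    }
  };
  ParallelLaunchAutoSearch(task, output_size_, this, &parallel_search_info_);
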
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) { std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
if (axis < 0) { if (axis < 0) {
axis = axis + SizeToInt(shape.size()); axis = axis + SizeToInt(shape.size());


@ -25,6 +25,8 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "runtime/framework/graph_scheduler.h"
#include "actor/actormgr.h"
using mindspore::kernel::Address; using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
@ -62,6 +64,7 @@ const char DELTA[] = "delta";
const char SORTED[] = "sorted"; const char SORTED[] = "sorted";
const char ADJ_ST[] = "adjoint_st"; const char ADJ_ST[] = "adjoint_st";
const char ADJ_dT[] = "adjoint_dt"; const char ADJ_dT[] = "adjoint_dt";
const char PERIODS[] = "periods";
enum OperateType { enum OperateType {
ADD = 0, ADD = 0,
@ -119,6 +122,7 @@ enum OperateType {
ATAN2, ATAN2,
RINT, RINT,
ROUND, ROUND,
EXP,
IDENTITY, IDENTITY,
}; };
@ -152,6 +156,19 @@ class CPUKernel : public kernel::KernelMod {
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_; std::vector<size_t> workspace_size_list_;
ParallelSearchInfo parallel_search_info_; ParallelSearchInfo parallel_search_info_;
template <typename T>
inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
if (index >= addr_list.size()) {
MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
}
if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index;
}
return reinterpret_cast<T *>(addr_list[index]->addr);
}
}; };
class CPUKernelUtils { class CPUKernelUtils {
@ -209,6 +226,12 @@ class TransposeIterator {
std::vector<size_t> axes_; std::vector<size_t> axes_;
size_t pos_{0}; size_t pos_{0};
}; };
ActorThreadPool *GetActorMgrInnerThreadPool();
void ParallelLaunch(const CTask &task, size_t count, float block_size = 128.0, Content content = nullptr);
void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content,
ParallelSearchInfo *parallel_search_info);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
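
A sketch of how the new GetDeviceAddress helper would typically be used inside a derived kernel's Launch method (the float dtype and single input/output are illustrative, not taken from this patch):

  // Sketch only: the helper centralizes the index, null and size checks that kernels
  // previously repeated before each reinterpret_cast.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) {
    auto *in = GetDeviceAddress<float>(inputs, 0);    // throws if index 0 is missing or empty
    auto *out = GetDeviceAddress<float>(outputs, 0);
    out[0] = in[0];
    return true;
  }
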


@ -144,8 +144,7 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
const int bottom_y_index = ceilf(target_y); const int bottom_y_index = ceilf(target_y);
const int left_x_index = floorf(target_x); const int left_x_index = floorf(target_x);
const int right_x_index = ceilf(target_x); const int right_x_index = ceilf(target_x);
const float y_lerp = target_y - top_y_index;
const float x_lerp = target_x - left_x_index;
const float top_left = static_cast<float>( const float top_left = static_cast<float>(
input_image[((box_index * input_height_ + top_y_index) * input_width_ + left_x_index) * channel_ + input_image[((box_index * input_height_ + top_y_index) * input_width_ + left_x_index) * channel_ +
pos_channel]); pos_channel]);
@ -158,9 +157,9 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
const float bottom_right = static_cast<float>( const float bottom_right = static_cast<float>(
input_image[((box_index * input_height_ + bottom_y_index) * input_width_ + right_x_index) * channel_ + input_image[((box_index * input_height_ + bottom_y_index) * input_width_ + right_x_index) * channel_ +
pos_channel]); pos_channel]);
const float top = top_left + (top_right - top_left) * x_lerp; const float top = top_left + (top_right - top_left) * (target_x - left_x_index);
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * (target_x - left_x_index);
output[pos] = top + (bottom - top) * y_lerp; output[pos] = top + (bottom - top) * (target_y - top_y_index);
} else if (method_ == 3) { } else if (method_ == 3) {
int y1h = static_cast<int>(y1 * input_height_); int y1h = static_cast<int>(y1 * input_height_);
int x1w = static_cast<int>(x1 * input_width_); int x1w = static_cast<int>(x1 * input_width_);
@ -170,36 +169,37 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
int h = ((y2h - y1h + 1) > 1) ? y2h - y1h + 1 : 1; int h = ((y2h - y1h + 1) > 1) ? y2h - y1h + 1 : 1;
float y_point = (pos_y + 0.5) * (h / static_cast<float>(final_height_)) - 0.5; float y_point = (pos_y + 0.5) * (h / static_cast<float>(final_height_)) - 0.5;
int top_y_index = floorf(y_point); int top_y_index = std::min(std::max(0, static_cast<int>(floorf(y_point))), h - 1);
top_y_index = std::min(std::max(0, top_y_index), h - 1); int bottom_y_index = std::min(std::max(0, static_cast<int>(ceilf(y_point))), h - 1);
int bottom_y_index = ceilf(y_point);
bottom_y_index = std::min(std::max(0, bottom_y_index), h - 1);
float x_point = (pos_x + 0.5) * (w / static_cast<float>(final_width_)) - 0.5; float x_point = (pos_x + 0.5) * (w / static_cast<float>(final_width_)) - 0.5;
int left_x_index = floorf(x_point); int left_x_index = std::min(std::max(0, static_cast<int>(floorf(x_point))), w - 1);
left_x_index = std::min(std::max(0, left_x_index), w - 1); int right_x_index = std::min(std::max(0, static_cast<int>(ceilf(x_point))), w - 1);
int right_x_index = ceilf(x_point);
right_x_index = std::min(std::max(0, right_x_index), w - 1);
const float y_lerp = y_point - top_y_index; const float y_lerp = y_point - top_y_index;
const float x_lerp = x_point - left_x_index; const float x_lerp = x_point - left_x_index;
const int y_top_index = box_index * input_height_ + y1h + top_y_index;
const int y_bottom_index = box_index * input_height_ + y1h + bottom_y_index;
const float top_left = const int y_top_index = std::max(0, y1h + top_y_index);
static_cast<float>(input_image[(y_top_index * input_width_ + x1w + left_x_index) * channel_ + pos_channel]); const int y_bottom_index = std::max(0, y1h + bottom_y_index);
const float top_right = const int x_left_index = std::max(0, x1w + left_x_index);
static_cast<float>(input_image[(y_top_index * input_width_ + x1w + right_x_index) * channel_ + pos_channel]); const int x_right_index = std::max(0, x1w + right_x_index);
const float top_left = static_cast<float>(
input_image[((box_index * input_height_ + y_top_index) * input_width_ + x_left_index) * channel_ +
pos_channel]);
const float top_right = static_cast<float>(
input_image[((box_index * input_height_ + y_top_index) * input_width_ + x_right_index) * channel_ +
pos_channel]);
const float bottom_left = static_cast<float>( const float bottom_left = static_cast<float>(
input_image[(y_bottom_index * input_width_ + x1w + left_x_index) * channel_ + pos_channel]); input_image[((box_index * input_height_ + y_bottom_index) * input_width_ + x_left_index) * channel_ +
pos_channel]);
const float bottom_right = static_cast<float>( const float bottom_right = static_cast<float>(
input_image[(y_bottom_index * input_width_ + x1w + right_x_index) * channel_ + pos_channel]); input_image[((box_index * input_height_ + y_bottom_index) * input_width_ + x_right_index) * channel_ +
pos_channel]);
output[pos] = top_left * (1 - y_lerp) * (1 - x_lerp) + bottom_right * y_lerp * x_lerp +
top_right * (1 - y_lerp) * x_lerp + bottom_left * y_lerp * (1 - x_lerp);
float ret = top_left * (1 - y_lerp) * (1 - x_lerp) + bottom_right * y_lerp * x_lerp +
top_right * (1 - y_lerp) * x_lerp + bottom_left * y_lerp * (1 - x_lerp);
output[pos] = ret;
} else { } else {
// Nearest Neighbour // Nearest Neighbour
const int closest_x_index = roundf(target_x); const int closest_x_index = roundf(target_x);
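
For clarity, a standalone sketch of the clamped bilinear interpolation that the rewritten branch performs (a simplified single-channel image, independent of the kernel's box and channel indexing):

  #include <algorithm>
  #include <cmath>

  // Sketch: clamp the four neighbour indices to [0, size - 1], then lerp horizontally
  // on the top and bottom rows and vertically between the two results.
  float Bilinear(float y, float x, int h, int w, const float *img /* h x w */) {
    int top = std::min(std::max(0, static_cast<int>(std::floor(y))), h - 1);
    int bottom = std::min(std::max(0, static_cast<int>(std::ceil(y))), h - 1);
    int left = std::min(std::max(0, static_cast<int>(std::floor(x))), w - 1);
    int right = std::min(std::max(0, static_cast<int>(std::ceil(x))), w - 1);
    float x_lerp = x - left;
    float y_lerp = y - top;
    float top_v = img[top * w + left] + (img[top * w + right] - img[top * w + left]) * x_lerp;
    float bottom_v = img[bottom * w + left] + (img[bottom * w + right] - img[bottom * w + left]) * x_lerp;
    return top_v + (bottom_v - top_v) * y_lerp;
  }
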


@ -35,15 +35,14 @@ class CropAndResizeCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
int method_; int method_{1};
float extrapolation_value_; float extrapolation_value_{0.0};
int input_crop_size_; int output_size_{0};
int output_size_; int input_height_{0};
int input_height_; int input_width_{0};
int input_width_; int final_height_{0};
int final_height_; int final_width_{0};
int final_width_; int channel_{0};
int channel_;
}; };
MS_REG_CPU_KERNEL_T(CropAndResize, MS_REG_CPU_KERNEL_T(CropAndResize,


@ -259,9 +259,9 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
const auto input1 = reinterpret_cast<T *>(inputs[1]->addr); const auto input1 = reinterpret_cast<T *>(inputs[1]->addr);
auto output = reinterpret_cast<T *>(outputs[0]->addr); auto output = reinterpret_cast<T *>(outputs[0]->addr);
CPUKernelUtils::ParallelForAutoSearch( ParallelLaunchAutoSearch(
std::bind(elt_map.at(kernel_name_), this, input0, input1, output, std::placeholders::_1, std::placeholders::_2), std::bind(elt_map.at(kernel_name_), this, input0, input1, output, std::placeholders::_1, std::placeholders::_2),
outputs[0]->size / sizeof(T), &parallel_search_info_); outputs[0]->size / sizeof(T), this, &parallel_search_info_);
return true; return true;
} }
} // namespace kernel } // namespace kernel


@ -30,7 +30,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// The duration between two downloading requests when return code is ResponseCode_SucNotReady. // The duration between two PullWeights requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPullWeights = 200; constexpr int kRetryDurationOfPullWeights = 200;
template <typename T> template <typename T>
class FusedPullWeightKernel : public CPUKernel { class FusedPullWeightKernel : public CPUKernel {
@ -51,19 +51,17 @@ class FusedPullWeightKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(fbb); MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++; total_iteration_++;
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() != MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_
fl::kTrainBeginStepNum) { << ", step number needs to run per iteration: " << step_num_per_iteration;
if (step_num_per_iteration != fl::kOneStepPerIteration &&
total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) {
return true; return true;
} }
fl_iteration_++; fl_iteration_++;
if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) { MS_LOG(INFO) << "Launching pulling weight for federated learning iteration " << fl_iteration_;
MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
fl_iteration_ = 1;
}
MS_LOG(INFO) << "Start pulling weight for federated learning iteration " << fl_iteration_;
if (!BuildPullWeightReq(fbb)) { if (!BuildPullWeightReq(fbb)) {
MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed."; MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
return false; return false;
@ -73,11 +71,16 @@ class FusedPullWeightKernel : public CPUKernel {
const schema::ResponsePullWeight *pull_weight_rsp = nullptr; const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady; int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) { while (retcode == schema::ResponseCode_SucNotReady) {
if (!fl::worker::FLWorker::GetInstance().running()) {
MS_LOG(WARNING) << "Worker has finished.";
return true;
}
if (!fl::worker::FLWorker::GetInstance().SendToServer( if (!fl::worker::FLWorker::GetInstance().SendToServer(
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) { 0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped."; MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. Retry later.";
fl::worker::FLWorker::GetInstance().SetIterationRunning(); retcode = schema::ResponseCode_SucNotReady;
return true; std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
continue;
} }
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg); MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
@ -88,6 +91,8 @@ class FusedPullWeightKernel : public CPUKernel {
fl_iteration_ = pull_weight_rsp->iteration(); fl_iteration_ = pull_weight_rsp->iteration();
MS_LOG(DEBUG) << "Server is not ready for downloading yet. Reason: " << pull_weight_rsp->reason()->str() MS_LOG(DEBUG) << "Server is not ready for downloading yet. Reason: " << pull_weight_rsp->reason()->str()
<< ". Retry later."; << ". Retry later.";
// Recreate fbb to avoid memory leak of FlatBuffers.
fbb = std::make_shared<fl::FBBuilder>();
if (!BuildPullWeightReq(fbb)) { if (!BuildPullWeightReq(fbb)) {
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed."; MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed.";
return false; return false;
@ -116,7 +121,7 @@ class FusedPullWeightKernel : public CPUKernel {
return false; return false;
} }
} }
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " success. Iteration: " << fl_iteration_;
fl::worker::FLWorker::GetInstance().SetIterationRunning(); fl::worker::FLWorker::GetInstance().SetIterationRunning();
return true; return true;
} }
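
The behavioural change here is that a failed send or a not-ready server no longer drops the iteration: the worker now sleeps kRetryDurationOfPullWeights and retries until it succeeds or the worker stops. A generic sketch of that retry loop (the predicates are placeholders, not the FL worker API):

  #include <chrono>
  #include <thread>

  // Sketch of the retry pattern: keep asking until the server is ready, sleep a fixed
  // duration between attempts, and bail out quietly once the worker itself has finished.
  template <typename StillRunning, typename SendAndCheckReady>
  bool RetryUntilReady(StillRunning still_running, SendAndCheckReady send_and_check_ready, int retry_ms) {
    while (!send_and_check_ready()) {
      if (!still_running()) {
        return true;  // worker finished: give up instead of blocking forever
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(retry_ms));
    }
    return true;
  }
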


@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// The duration between two uploading requests when return code is ResponseCode_SucNotReady. // The duration between two PushWeights requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPushWeights = 200; constexpr int kRetryDurationOfPushWeights = 200;
template <typename T> template <typename T>
class FusedPushWeightKernel : public CPUKernel { class FusedPushWeightKernel : public CPUKernel {
@ -49,19 +49,17 @@ class FusedPushWeightKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(fbb); MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++; total_iteration_++;
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() != MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_
fl::kTrainBeginStepNum) { << ", step number needs to run per iteration: " << step_num_per_iteration;
if (step_num_per_iteration != fl::kOneStepPerIteration &&
total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) {
return true; return true;
} }
fl_iteration_++; fl_iteration_++;
if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) { MS_LOG(INFO) << "Launching pushing weight for federated learning iteration " << fl_iteration_;
MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
fl_iteration_ = 1;
}
MS_LOG(INFO) << "Start pushing weight for federated learning iteration " << fl_iteration_;
if (!BuildPushWeightReq(fbb, inputs)) { if (!BuildPushWeightReq(fbb, inputs)) {
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed."; MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
return false; return false;
@ -73,13 +71,17 @@ class FusedPushWeightKernel : public CPUKernel {
const schema::ResponsePushWeight *push_weight_rsp = nullptr; const schema::ResponsePushWeight *push_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady; int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) { while (retcode == schema::ResponseCode_SucNotReady) {
if (!fl::worker::FLWorker::GetInstance().running()) {
MS_LOG(WARNING) << "Worker has finished.";
return true;
}
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(), if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
ps::core::TcpUserCommand::kPushWeight, ps::core::TcpUserCommand::kPushWeight,
&push_weight_rsp_msg)) { &push_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i << " failed.";
<< " failed. This iteration is dropped."; retcode = schema::ResponseCode_SucNotReady;
fl::worker::FLWorker::GetInstance().SetIterationCompleted(); std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
return true; continue;
} }
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg); MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
@ -105,8 +107,7 @@ class FusedPushWeightKernel : public CPUKernel {
} }
} }
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_; MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " success. Iteration: " << fl_iteration_;
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true; return true;
} }


@ -52,6 +52,26 @@ MS_REG_CPU_KERNEL_T(
MaskedSelect, MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
MaskedSelectCPUKernel, int); MaskedSelectCPUKernel, int);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
MaskedSelectCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
MaskedSelectCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
MaskedSelectCPUKernel, float16);
MS_REG_CPU_KERNEL_T(
MaskedSelect,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
MaskedSelectCPUKernel, double);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_


@ -58,6 +58,38 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32), .AddOutputAttr(kNumberTypeInt32),
MaskedSelectGradCPUKernel, int); MaskedSelectGradCPUKernel, int);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MaskedSelectGradCPUKernel, float16);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
MaskedSelectGradCPUKernel, double);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
MaskedSelectGradCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
MaskedSelectGradCPUKernel, int64_t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_


@ -86,6 +86,8 @@ bool MirrorPadCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, c
LaunchKernel<float16>(inputs, outputs); LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) { } else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) { } else if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int>(inputs, outputs); LaunchKernel<int>(inputs, outputs);
} else { } else {


@ -74,6 +74,11 @@ MS_REG_CPU_KERNEL(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MirrorPadCPUKernel); MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
MirrorPadCPUKernel); MirrorPadCPUKernel);
@ -88,6 +93,11 @@ MS_REG_CPU_KERNEL(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MirrorPadCPUKernel); MirrorPadCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
MirrorPadCPUKernel);
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), MirrorPad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MirrorPadCPUKernel); MirrorPadCPUKernel);


@ -110,6 +110,8 @@ bool MirrorPadGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input
LaunchKernel<float16>(inputs, workspace, outputs); LaunchKernel<float16>(inputs, workspace, outputs);
} else if (dtype_ == kNumberTypeFloat32) { } else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, workspace, outputs); LaunchKernel<float>(inputs, workspace, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, workspace, outputs);
} else if (dtype_ == kNumberTypeInt32) { } else if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int>(inputs, workspace, outputs); LaunchKernel<int>(inputs, workspace, outputs);
} else { } else {
@ -130,6 +132,8 @@ void MirrorPadGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
InitWorkspaceSize<float16>(); InitWorkspaceSize<float16>();
} else if (dtype_ == kNumberTypeFloat32) { } else if (dtype_ == kNumberTypeFloat32) {
InitWorkspaceSize<float>(); InitWorkspaceSize<float>();
} else if (dtype_ == kNumberTypeFloat64) {
InitWorkspaceSize<double>();
} else if (dtype_ == kNumberTypeInt32) { } else if (dtype_ == kNumberTypeInt32) {
InitWorkspaceSize<int>(); InitWorkspaceSize<int>();
} }


@ -90,6 +90,11 @@ MS_REG_CPU_KERNEL(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
MirrorPadGradCPUKernel); MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
MirrorPadGrad, MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
@ -105,6 +110,11 @@ MS_REG_CPU_KERNEL(
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
MirrorPadGradCPUKernel); MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL(
MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
MirrorPadGradCPUKernel);
MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL(
MirrorPadGrad, MirrorPadGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),


@ -52,8 +52,6 @@ MS_REG_CPU_KERNEL(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutp
EltWiseCPUKernel); EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel); EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_CPU_KERNEL(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseCPUKernel); EltWiseCPUKernel);
MS_REG_CPU_KERNEL(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), MS_REG_CPU_KERNEL(Softplus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


@ -111,22 +111,16 @@ bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<
} }
dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
dnnl::memory::format_tag mem_tag; static const std::vector<dnnl::memory::format_tag> tag_vec = {
auto dim_size = dims.size(); dnnl::memory::format_tag::a, dnnl::memory::format_tag::ab, dnnl::memory::format_tag::abc,
if (dim_size == 5) { dnnl::memory::format_tag::abcd, dnnl::memory::format_tag::abcde, dnnl::memory::format_tag::abcdef,
mem_tag = dnnl::memory::format_tag::abcde; dnnl::memory::format_tag::abcdefg};
} else if (dim_size == 4) {
mem_tag = dnnl::memory::format_tag::abcd; auto rank = dims.size();
} else if (dim_size == 3) { if (rank > tag_vec.size()) {
mem_tag = dnnl::memory::format_tag::abc; MS_LOG(EXCEPTION) << "The kernel does not support construct " << rank << "-D tensor dnnl memory format_tag.";
} else if (dim_size == 2) {
mem_tag = dnnl::memory::format_tag::ab;
} else if (dim_size == 1) {
mem_tag = dnnl::memory::format_tag::a;
} else {
MS_LOG(EXCEPTION) << "Kernel dims invalid " << dim_size;
} }
return mem_tag; return tag_vec[rank - 1];
} }
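
A small usage sketch of the table-driven lookup above, as it would be called from inside an MKL kernel (the dimension values are arbitrary):

  // Sketch: the rank of dims now selects the tag directly, e.g. a 4-D tensor
  // maps to dnnl::memory::format_tag::abcd; ranks above 7 raise an exception.
  dnnl::memory::dims dims{2, 3, 16, 16};
  dnnl::memory::format_tag tag = GetDefaultFormatTag(dims);  // == dnnl::memory::format_tag::abcd
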
dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) { dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector<size_t> &shape) {


@ -36,9 +36,6 @@ class MulCPUKernel : public MKLCPUKernel {
private: private:
bool need_swap_{false}; bool need_swap_{false};
}; };
MS_REG_CPU_KERNEL(Mul, KernelAttr(), MulCPUKernel);
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore


@ -45,7 +45,7 @@ if(MSLITE_STRING_KERNEL)
${KERNEL_SRC_INFER_STRING} ${KERNEL_SRC_INFER_STRING}
) )
endif() endif()
if(MSLITE_CONTROL_TENSORLIST) if(MSLITE_CONTROLFLOW_TENSORLIST)
file(GLOB KERNEL_SRC_INFER_CONTROL_TENSORLIST file(GLOB KERNEL_SRC_INFER_CONTROL_TENSORLIST
${NNACL_DIR}/infer/control/*.c ${NNACL_DIR}/infer/control/*.c
) )


@ -29,10 +29,28 @@ asm_function MatmulFloatNeon64Opt
mov x21, #48 // sizeof(float) * 12 mov x21, #48 // sizeof(float) * 12
mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth
cmp x9, #3 // c4
beq C4Stride
cbnz x9, NoC8Steps cbnz x9, NoC8Steps
mov x11, x2 mov x11, x2
mov x21, #32 mov x21, #32
mul x16, x6, x21 // row * 8 * sizeof(float) mul x16, x6, x21 // row * 8 * sizeof(float)
b NoC8Steps
C4Stride:
mov x18, #48 // 12 * sizeof(float)
mov x22, #4
mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row
mul x8, x8, x22 // col stride
// col >= 4 , block stride 192, otherwise 12 * 4 * col
cmp x7, #4
bge C4StrideCommon
mul x18, x18, x7 // block stride
b LoopRowStart
C4StrideCommon:
mov x18, #192 // block stride
b LoopRowStart
NoC8Steps: NoC8Steps:
cmp x9, #2 cmp x9, #2
bne NoWinoSteps bne NoWinoSteps
@ -46,10 +64,14 @@ NoWinoSteps:
mul x8, x8, x21 mul x8, x8, x21
LoopRowStart: LoopRowStart:
cmp x9, #3
bne RowStart
mov x20, x2
RowStart:
cmp x6, #4 cmp x6, #4
ble LoopRow4 ble LoopRow4
cmp x6, #8 cmp x6, #8
blt LoopRow8 ble LoopRow8
LoopRow: LoopRow:
mov x14, x1 // reload rhs ptr mov x14, x1 // reload rhs ptr
@ -58,7 +80,12 @@ LoopRow:
LoopCol: LoopCol:
cbz x9, NoReloadDst cbz x9, NoReloadDst
cmp x9, #3
beq C4ReloadDst
mov x11, x2 mov x11, x2
b NoReloadDst
C4ReloadDst:
mov x11, x20
NoReloadDst: NoReloadDst:
mov x10, x0 // reload lhs ptr mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth mov x19, x5 // reload depth
@ -192,7 +219,7 @@ LoopRow:
fmin v29.4s, v29.4s, v2.4s fmin v29.4s, v29.4s, v2.4s
fmin v30.4s, v30.4s, v2.4s fmin v30.4s, v30.4s, v2.4s
fmin v31.4s, v31.4s, v2.4s fmin v31.4s, v31.4s, v2.4s
Relu: Relu:
dup v3.4s, wzr dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s fmax v8.4s, v8.4s, v3.4s
@ -324,7 +351,12 @@ LoopRow8:
LoopCol8: LoopCol8:
cbz x9, NoReloadDst8 cbz x9, NoReloadDst8
cmp x9, #3
beq C4ReloadDst8
mov x11, x2 mov x11, x2
b NoReloadDst8
C4ReloadDst8:
mov x11, x20
NoReloadDst8: NoReloadDst8:
mov x10, x0 // reload lhs ptr mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth mov x19, x5 // reload depth
@ -426,7 +458,7 @@ LoopRow8:
fmin v21.4s, v21.4s, v2.4s fmin v21.4s, v21.4s, v2.4s
fmin v22.4s, v22.4s, v2.4s fmin v22.4s, v22.4s, v2.4s
fmin v23.4s, v23.4s, v2.4s fmin v23.4s, v23.4s, v2.4s
Relu8: Relu8:
dup v3.4s, wzr dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s fmax v8.4s, v8.4s, v3.4s
@ -529,7 +561,12 @@ LoopRow4:
LoopCol4: LoopCol4:
cbz x9, NoReloadDst4 cbz x9, NoReloadDst4
cmp x9, #3
beq C4ReloadDst4
mov x11, x2 mov x11, x2
b NoReloadDst4
C4ReloadDst4:
mov x11, x20
NoReloadDst4: NoReloadDst4:
mov x10, x0 // reload lhs ptr mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth mov x19, x5 // reload depth
@ -599,7 +636,7 @@ LoopRow4:
fmin v13.4s, v13.4s, v2.4s fmin v13.4s, v13.4s, v2.4s
fmin v14.4s, v14.4s, v2.4s fmin v14.4s, v14.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s fmin v15.4s, v15.4s, v2.4s
Relu4: Relu4:
dup v3.4s, wzr dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s fmax v8.4s, v8.4s, v3.4s
@ -669,6 +706,8 @@ LoopRow4:
Write: Write:
cmp x9, #2 cmp x9, #2
beq WriteWino beq WriteWino
cmp x9, #3
beq WriteC4
cbz x9, WriteC8 cbz x9, WriteC8
cmp x13, #1 cmp x13, #1
beq Write1 beq Write1
@ -1102,6 +1141,508 @@ LoopRow4:
beq WriteEnd beq WriteEnd
st1 {v30.4s, v31.4s}, [x11], x8 st1 {v30.4s, v31.4s}, [x11], x8
add x11, x11, #32 add x11, x11, #32
b WriteEnd
WriteC4:
cmp x13, #1
beq C4Write1
cmp x13, #2
beq C4Write2
cmp x13, #3
beq C4Write3
cmp x13, #4
beq C4Write4
cmp x13, #5
beq C4Write5
cmp x13, #6
beq C4Write6
cmp x13, #7
beq C4Write7
b C4Write8
C4Write1:
// add x20, x11, x8
str s8, [x11], #4
cmp x6, #1
beq WriteEnd
str s10, [x11], #4
cmp x6, #2
beq WriteEnd
str s12, [x11], #4
cmp x6, #3
beq WriteEnd
str s14, [x11], #4
cmp x6, #4
beq WriteEnd
str s16, [x11], #4
cmp x6, #5
beq WriteEnd
str s18, [x11], #4
cmp x6, #6
beq WriteEnd
str s20, [x11], #4
cmp x6, #7
beq WriteEnd
str s22, [x11], #4
cmp x6, #8
beq WriteEnd
str s24, [x11], #4
cmp x6, #9
beq WriteEnd
str s26, [x11], #4
cmp x6, #10
beq WriteEnd
str s28, [x11], #4
cmp x6, #11
beq WriteEnd
str s30, [x11], #4
b WriteEnd
C4Write2:
// add x20, x11, x8
st1 {v8.2s}, [x11], #8
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11], #8
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11], #8
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11], #8
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11], #8
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11], #8
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11], #8
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11], #8
cmp x6, #8
beq WriteEnd
st1 {v24.2s}, [x11], #8
cmp x6, #9
beq WriteEnd
st1 {v26.2s}, [x11], #8
cmp x6, #10
beq WriteEnd
st1 {v28.2s}, [x11], #8
cmp x6, #11
beq WriteEnd
st1 {v30.2s}, [x11], #8
b WriteEnd
C4Write3:
// add x20, x11, x8
add x19, x11, #8
st1 {v8.2s}, [x11]
add x11, x11, #12
st1 {v8.s}[2], [x19]
add x19, x19, #12
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11]
add x11, x11, #12
st1 {v10.s}[2], [x19]
add x19, x19, #12
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11]
add x11, x11, #12
st1 {v12.s}[2], [x19]
add x19, x19, #12
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11]
add x11, x11, #12
st1 {v14.s}[2], [x19]
add x19, x19, #12
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11]
add x11, x11, #12
st1 {v16.s}[2], [x19]
add x19, x19, #12
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11]
add x11, x11, #12
st1 {v18.s}[2], [x19]
add x19, x19, #12
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11]
add x11, x11, #12
st1 {v20.s}[2], [x19]
add x19, x19, #12
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11]
add x11, x11, #12
st1 {v22.s}[2], [x19]
add x19, x19, #12
cmp x6, #8
beq WriteEnd
st1 {v24.2s}, [x11]
add x11, x11, #12
st1 {v24.s}[2], [x19]
add x19, x19, #12
cmp x6, #9
beq WriteEnd
st1 {v26.2s}, [x11]
add x11, x11, #12
st1 {v26.s}[2], [x19]
add x19, x19, #12
cmp x6, #10
beq WriteEnd
st1 {v28.2s}, [x11]
add x11, x11, #12
st1 {v28.s}[2], [x19]
add x19, x19, #12
cmp x6, #11
beq WriteEnd
st1 {v30.2s}, [x11]
add x11, x11, #12
st1 {v30.s}[2], [x19]
add x19, x19, #12
b WriteEnd
C4Write4:
add x20, x11, x8
st1 {v8.4s}, [x11], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], #16
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], #16
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], #16
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11], #16
b WriteEnd
C4Write5:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #20
str s9, [x19]
add x19, x19, #20
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #20
str s11, [x19]
add x19, x19, #20
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #20
str s13, [x19]
add x19, x19, #20
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #20
str s15, [x19]
add x19, x19, #20
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #20
str s17, [x19]
add x19, x19, #20
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #20
str s19, [x19]
add x19, x19, #20
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #20
str s21, [x19]
add x19, x19, #20
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
add x11, x11, #20
str s23, [x19]
add x19, x19, #20
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11]
add x11, x11, #20
str s25, [x19]
add x19, x19, #20
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11]
add x11, x11, #20
str s27, [x19]
add x19, x19, #20
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11]
add x11, x11, #20
str s29, [x19]
add x19, x19, #20
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
str s31, [x19]
b WriteEnd
C4Write6:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #24
st1 {v9.2s}, [x19]
add x19, x19, #24
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #24
st1 {v11.2s}, [x19]
add x19, x19, #24
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #24
st1 {v13.2s}, [x19]
add x19, x19, #24
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #24
st1 {v15.2s}, [x19]
add x19, x19, #24
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #24
st1 {v17.2s}, [x19]
add x19, x19, #24
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #24
st1 {v19.2s}, [x19]
add x19, x19, #24
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #24
st1 {v21.2s}, [x19]
add x19, x19, #24
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
add x11, x11, #24
st1 {v23.2s}, [x19]
add x19, x19, #24
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11]
add x11, x11, #24
st1 {v25.2s}, [x19]
add x19, x19, #24
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11]
add x11, x11, #24
st1 {v27.2s}, [x19]
add x19, x19, #24
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11]
add x11, x11, #24
st1 {v29.2s}, [x19]
add x19, x19, #24
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.2s}, [x19]
b WriteEnd
C4Write7:
add x19, x11, #16
add x16, x11, #24
mov x10, #28
st1 {v8.4s}, [x11], x10
st1 {v9.2s}, [x19], x10
st1 {v9.s}[2], [x16], x10
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], x10
st1 {v11.2s}, [x19], x10
st1 {v11.s}[2], [x16], x10
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], x10
st1 {v13.2s}, [x19], x10
st1 {v13.s}[2], [x16], x10
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], x10
st1 {v15.2s}, [x19], x10
st1 {v15.s}[2], [x16], x10
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], x10
st1 {v17.2s}, [x19], x10
st1 {v17.s}[2], [x16], x10
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], x10
st1 {v19.2s}, [x19], x10
st1 {v19.s}[2], [x16], x10
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], x10
st1 {v21.2s}, [x19], x10
st1 {v21.s}[2], [x16], x10
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], x10
st1 {v23.2s}, [x19], x10
st1 {v23.s}[2], [x16], x10
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], x10
st1 {v25.2s}, [x19], x10
st1 {v25.s}[2], [x16], x10
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], x10
st1 {v27.2s}, [x19], x10
st1 {v27.s}[2], [x16], x10
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], x10
st1 {v29.2s}, [x19], x10
st1 {v29.s}[2], [x16], x10
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.2s}, [x19]
st1 {v31.s}[2], [x16]
b WriteEnd
C4Write8:
add x19, x11, x8
add x20, x19, x8
st1 {v8.4s}, [x11], #16
st1 {v9.4s}, [x19], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
st1 {v11.4s}, [x19], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
st1 {v13.4s}, [x19], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
st1 {v15.4s}, [x19], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
st1 {v17.4s}, [x19], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
st1 {v19.4s}, [x19], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
st1 {v21.4s}, [x19], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
st1 {v23.4s}, [x19], #16
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], #16
st1 {v25.4s}, [x19], #16
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], #16
st1 {v27.4s}, [x19], #16
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], #16
st1 {v29.4s}, [x19], #16
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.4s}, [x19]
b WriteEnd
WriteEnd: WriteEnd:
subs x13, x13, #8 // rhs col - 8 subs x13, x13, #8 // rhs col - 8
@ -1115,11 +1656,16 @@ LoopRow4:
LoopColEnd: LoopColEnd:
add x0, x0, x17 add x0, x0, x17
cbz x9, C8DstStep cbz x9, C8DstStep
cmp x9, #3
beq C4DstStep
mov x21, #4 mov x21, #4
mul x21, x21, x7 mul x21, x21, x7
sub x11, x11, x21 sub x11, x11, x21
mov x2, x11 mov x2, x11
b NoDstStep b NoDstStep
C4DstStep:
add x2, x2, x18
b NoDstStep
C8DstStep: C8DstStep:
add x2, x2, #384 add x2, x2, #384
mov x11, x2 mov x11, x2


@ -29,10 +29,27 @@ asm_function MatmulFloatNeon64OptRow12
mov x21, #48 // sizeof(float) * 12 mov x21, #48 // sizeof(float) * 12
mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth
cmp x9, #3 // c4
beq C4Stride
cbnz x9, NoC8Steps cbnz x9, NoC8Steps
mov x11, x2 mov x11, x2
mov x21, #32 mov x21, #32
mul x16, x6, x21 // row * 8 * sizeof(float) mul x16, x6, x21 // row * 8 * sizeof(float)
b NoC8Steps
C4Stride:
mov x18, #48 // 12 * sizeof(float)
mov x22, #4
mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row
mul x8, x8, x22 // col stride
// col >= 4 , block stride 192, otherwise 12 * 4 * col
cmp x7, #4
bge C4StrideCommon
mul x18, x18, x7 // block stride
b LoopRowStart
C4StrideCommon:
mov x18, #192 // block stride
b LoopRowStart
NoC8Steps: NoC8Steps:
cmp x9, #2 cmp x9, #2
bne NoWinoSteps bne NoWinoSteps
@ -45,6 +62,10 @@ NoWinoSteps:
mov x21, #4 mov x21, #4
mul x8, x8, x21 mul x8, x8, x21
LoopRowStart:
cmp x9, #3
bne LoopRow
mov x20, x2
LoopRow: LoopRow:
mov x14, x1 // reload rhs ptr mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col mov x13, x7 // reload rhs col
@ -52,7 +73,12 @@ LoopRow:
LoopCol: LoopCol:
cbz x9, NoReloadDst cbz x9, NoReloadDst
cmp x9, #3
beq C4ReloadDst
mov x11, x2 mov x11, x2
b NoReloadDst
C4ReloadDst:
mov x11, x20
NoReloadDst: NoReloadDst:
mov x10, x0 // reload lhs ptr mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth mov x19, x5 // reload depth
@ -186,7 +212,7 @@ LoopRow:
fmin v29.4s, v29.4s, v2.4s fmin v29.4s, v29.4s, v2.4s
fmin v30.4s, v30.4s, v2.4s fmin v30.4s, v30.4s, v2.4s
fmin v31.4s, v31.4s, v2.4s fmin v31.4s, v31.4s, v2.4s
Relu: Relu:
dup v3.4s, wzr dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s fmax v8.4s, v8.4s, v3.4s
@ -312,6 +338,8 @@ LoopRow:
Write: Write:
cmp x9, #2 cmp x9, #2
beq WriteWino beq WriteWino
cmp x9, #3
beq WriteC4
cbz x9, WriteC8 cbz x9, WriteC8
cmp x13, #1 cmp x13, #1
beq Write1 beq Write1
@ -370,7 +398,7 @@ LoopRow:
str s26, [x11] str s26, [x11]
cmp x6, #10 cmp x6, #10
beq WriteEnd beq WriteEnd
add x11, x11, x8 add x11, x11, x8
str s28, [x11] str s28, [x11]
cmp x6, #11 cmp x6, #11
beq WriteEnd beq WriteEnd
@ -745,7 +773,458 @@ LoopRow:
beq WriteEnd beq WriteEnd
st1 {v30.4s, v31.4s}, [x11], x8 st1 {v30.4s, v31.4s}, [x11], x8
add x11, x11, #32 add x11, x11, #32
b WriteEnd
WriteC4:
cmp x13, #1
beq C4Write1
cmp x13, #2
beq C4Write2
cmp x13, #3
beq C4Write3
cmp x13, #4
beq C4Write4
cmp x13, #5
beq C4Write5
cmp x13, #6
beq C4Write6
cmp x13, #7
beq C4Write7
b C4Write8
C4Write1:
str s8, [x11], #4
cmp x6, #1
beq WriteEnd
str s10, [x11], #4
cmp x6, #2
beq WriteEnd
str s12, [x11], #4
cmp x6, #3
beq WriteEnd
str s14, [x11], #4
cmp x6, #4
beq WriteEnd
str s16, [x11], #4
cmp x6, #5
beq WriteEnd
str s18, [x11], #4
cmp x6, #6
beq WriteEnd
str s20, [x11], #4
cmp x6, #7
beq WriteEnd
str s22, [x11], #4
cmp x6, #8
beq WriteEnd
str s24, [x11], #4
cmp x6, #9
beq WriteEnd
str s26, [x11], #4
cmp x6, #10
beq WriteEnd
str s28, [x11], #4
cmp x6, #11
beq WriteEnd
str s30, [x11], #4
b WriteEnd
C4Write2:
st1 {v8.2s}, [x11], #8
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11], #8
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11], #8
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11], #8
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11], #8
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11], #8
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11], #8
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11], #8
cmp x6, #8
beq WriteEnd
st1 {v24.2s}, [x11], #8
cmp x6, #9
beq WriteEnd
st1 {v26.2s}, [x11], #8
cmp x6, #10
beq WriteEnd
st1 {v28.2s}, [x11], #8
cmp x6, #11
beq WriteEnd
st1 {v30.2s}, [x11], #8
b WriteEnd
C4Write3:
add x19, x11, #8
st1 {v8.2s}, [x11]
add x11, x11, #12
st1 {v8.s}[2], [x19]
add x19, x19, #12
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11]
add x11, x11, #12
st1 {v10.s}[2], [x19]
add x19, x19, #12
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11]
add x11, x11, #12
st1 {v12.s}[2], [x19]
add x19, x19, #12
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11]
add x11, x11, #12
st1 {v14.s}[2], [x19]
add x19, x19, #12
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11]
add x11, x11, #12
st1 {v16.s}[2], [x19]
add x19, x19, #12
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11]
add x11, x11, #12
st1 {v18.s}[2], [x19]
add x19, x19, #12
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11]
add x11, x11, #12
st1 {v20.s}[2], [x19]
add x19, x19, #12
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11]
add x11, x11, #12
st1 {v22.s}[2], [x19]
add x19, x19, #12
cmp x6, #8
beq WriteEnd
st1 {v24.2s}, [x11]
add x11, x11, #12
st1 {v24.s}[2], [x19]
add x19, x19, #12
cmp x6, #9
beq WriteEnd
st1 {v26.2s}, [x11]
add x11, x11, #12
st1 {v26.s}[2], [x19]
add x19, x19, #12
cmp x6, #10
beq WriteEnd
st1 {v28.2s}, [x11]
add x11, x11, #12
st1 {v28.s}[2], [x19]
add x19, x19, #12
cmp x6, #11
beq WriteEnd
st1 {v30.2s}, [x11]
add x11, x11, #12
st1 {v30.s}[2], [x19]
add x19, x19, #12
b WriteEnd
C4Write4:
st1 {v8.4s}, [x11], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], #16
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], #16
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], #16
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11], #16
b WriteEnd
C4Write5:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #20
str s9, [x19]
add x19, x19, #20
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #20
str s11, [x19]
add x19, x19, #20
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #20
str s13, [x19]
add x19, x19, #20
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #20
str s15, [x19]
add x19, x19, #20
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #20
str s17, [x19]
add x19, x19, #20
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #20
str s19, [x19]
add x19, x19, #20
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #20
str s21, [x19]
add x19, x19, #20
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
add x11, x11, #20
str s23, [x19]
add x19, x19, #20
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11]
add x11, x11, #20
str s25, [x19]
add x19, x19, #20
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11]
add x11, x11, #20
str s27, [x19]
add x19, x19, #20
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11]
add x11, x11, #20
str s29, [x19]
add x19, x19, #20
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
str s31, [x19]
b WriteEnd
C4Write6:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #24
st1 {v9.2s}, [x19]
add x19, x19, #24
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #24
st1 {v11.2s}, [x19]
add x19, x19, #24
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #24
st1 {v13.2s}, [x19]
add x19, x19, #24
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #24
st1 {v15.2s}, [x19]
add x19, x19, #24
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #24
st1 {v17.2s}, [x19]
add x19, x19, #24
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #24
st1 {v19.2s}, [x19]
add x19, x19, #24
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #24
st1 {v21.2s}, [x19]
add x19, x19, #24
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
add x11, x11, #24
st1 {v23.2s}, [x19]
add x19, x19, #24
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11]
add x11, x11, #24
st1 {v25.2s}, [x19]
add x19, x19, #24
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11]
add x11, x11, #24
st1 {v27.2s}, [x19]
add x19, x19, #24
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11]
add x11, x11, #24
st1 {v29.2s}, [x19]
add x19, x19, #24
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.2s}, [x19]
b WriteEnd
C4Write7:
add x19, x11, #16
add x16, x11, #24
mov x10, #28
st1 {v8.4s}, [x11], x10
st1 {v9.2s}, [x19], x10
st1 {v9.s}[2], [x16], x10
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], x10
st1 {v11.2s}, [x19], x10
st1 {v11.s}[2], [x16], x10
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], x10
st1 {v13.2s}, [x19], x10
st1 {v13.s}[2], [x16], x10
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], x10
st1 {v15.2s}, [x19], x10
st1 {v15.s}[2], [x16], x10
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], x10
st1 {v17.2s}, [x19], x10
st1 {v17.s}[2], [x16], x10
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], x10
st1 {v19.2s}, [x19], x10
st1 {v19.s}[2], [x16], x10
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], x10
st1 {v21.2s}, [x19], x10
st1 {v21.s}[2], [x16], x10
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], x10
st1 {v23.2s}, [x19], x10
st1 {v23.s}[2], [x16], x10
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], x10
st1 {v25.2s}, [x19], x10
st1 {v25.s}[2], [x16], x10
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], x10
st1 {v27.2s}, [x19], x10
st1 {v27.s}[2], [x16], x10
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], x10
st1 {v29.2s}, [x19], x10
st1 {v29.s}[2], [x16], x10
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.2s}, [x19]
st1 {v31.s}[2], [x16]
b WriteEnd
C4Write8:
add x19, x11, x8
add x20, x19, x8
st1 {v8.4s}, [x11], #16
st1 {v9.4s}, [x19], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
st1 {v11.4s}, [x19], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
st1 {v13.4s}, [x19], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
st1 {v15.4s}, [x19], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
st1 {v17.4s}, [x19], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
st1 {v19.4s}, [x19], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
st1 {v21.4s}, [x19], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
st1 {v23.4s}, [x19], #16
cmp x6, #8
beq WriteEnd
st1 {v24.4s}, [x11], #16
st1 {v25.4s}, [x19], #16
cmp x6, #9
beq WriteEnd
st1 {v26.4s}, [x11], #16
st1 {v27.4s}, [x19], #16
cmp x6, #10
beq WriteEnd
st1 {v28.4s}, [x11], #16
st1 {v29.4s}, [x19], #16
cmp x6, #11
beq WriteEnd
st1 {v30.4s}, [x11]
st1 {v31.4s}, [x19]
WriteEnd:
subs x13, x13, #8 // rhs col - 8
bgt LoopCol
@@ -753,11 +1232,16 @@ LoopRow:
LoopColEnd:
add x0, x0, x17
cbz x9, C8DstStep
cmp x9, #3
beq C4DstStep
mov x21, #4
mul x21, x21, x7
sub x11, x11, x21
mov x2, x11
b NoDstStep
C4DstStep:
add x2, x2, x18
b NoDstStep
C8DstStep:
add x2, x2, #384
mov x11, x2

View File

@@ -28,11 +28,29 @@ asm_function MatmulFloatNeon64OptRow4
ldr x9, [sp, #8]
mov x21, #48 // sizeof(float) * 12
mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth
cmp x9, #3 // c4
beq C4Stride
cbnz x9, NoC8Steps
mov x11, x2
mov x21, #32
mul x16, x6, x21 // row * 8 * sizeof(float)
b NoC8Steps
C4Stride:
mov x18, #16 // 4 * sizeof(float)
mov x22, #4
mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row
mul x8, x8, x22 // col stride
// col >= 4 , block stride 64, otherwise 4 * 4 * col
cmp x7, #4
bge C4StrideCommon
mul x18, x18, x7 // block stride
b LoopRowStart
C4StrideCommon:
mov x18, #64 // block stride
b LoopRowStart
NoC8Steps:
cmp x9, #2
bne NoWinoSteps
@@ -45,6 +63,10 @@ NoWinoSteps:
mov x21, #4
mul x8, x8, x21
LoopRowStart:
cmp x9, #3
bne LoopRow4
mov x20, x2
LoopRow4:
mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col
@@ -52,7 +74,12 @@ LoopRow4:
LoopCol4:
cbz x9, NoReloadDst4
cmp x9, #3
beq C4ReloadDst4
mov x11, x2
b NoReloadDst4
C4ReloadDst4:
mov x11, x20
NoReloadDst4:
mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth
@@ -194,6 +221,8 @@ LoopRow4:
Write:
cmp x9, #2
beq WriteWino
cmp x9, #3
beq WriteC4
cbz x9, WriteC8
cmp x13, #1
beq Write1
@@ -369,7 +398,168 @@ LoopRow4:
beq WriteEnd
st1 {v14.4s, v15.4s}, [x11], x8
add x11, x11, #32
b WriteEnd
WriteC4:
cmp x13, #1
beq C4Write1
cmp x13, #2
beq C4Write2
cmp x13, #3
beq C4Write3
cmp x13, #4
beq C4Write4
cmp x13, #5
beq C4Write5
cmp x13, #6
beq C4Write6
cmp x13, #7
beq C4Write7
b C4Write8
C4Write1:
str s8, [x11], #4
cmp x6, #1
beq WriteEnd
str s10, [x11], #4
cmp x6, #2
beq WriteEnd
str s12, [x11], #4
cmp x6, #3
beq WriteEnd
str s14, [x11], #4
b WriteEnd
C4Write2:
st1 {v8.2s}, [x11], #8
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11], #8
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11], #8
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11], #8
b WriteEnd
C4Write3:
add x19, x11, #8
st1 {v8.2s}, [x11]
add x11, x11, #12
st1 {v8.s}[2], [x19]
add x19, x19, #12
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11]
add x11, x11, #12
st1 {v10.s}[2], [x19]
add x19, x19, #12
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11]
add x11, x11, #12
st1 {v12.s}[2], [x19]
add x19, x19, #12
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11]
st1 {v14.s}[2], [x19]
b WriteEnd
C4Write4:
st1 {v8.4s}, [x11], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
b WriteEnd
C4Write5:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #20
str s9, [x19]
add x19, x19, #20
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #20
str s11, [x19]
add x19, x19, #20
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #20
str s13, [x19]
add x19, x19, #20
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
str s15, [x19]
b WriteEnd
C4Write6:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #24
st1 {v9.2s}, [x19]
add x19, x19, #24
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #24
st1 {v11.2s}, [x19]
add x19, x19, #24
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #24
st1 {v13.2s}, [x19]
add x19, x19, #24
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
st1 {v15.2s}, [x19]
b WriteEnd
C4Write7:
add x19, x11, #16
add x16, x11, #24
mov x10, #28
st1 {v8.4s}, [x11], x10
st1 {v9.2s}, [x19], x10
st1 {v9.s}[2], [x16], x10
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], x10
st1 {v11.2s}, [x19], x10
st1 {v11.s}[2], [x16], x10
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], x10
st1 {v13.2s}, [x19], x10
st1 {v13.s}[2], [x16], x10
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], x10
st1 {v15.2s}, [x19], x10
st1 {v15.s}[2], [x16], x10
b WriteEnd
C4Write8:
add x19, x11, x8
add x20, x19, x8
st1 {v8.4s}, [x11], #16
st1 {v9.4s}, [x19], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
st1 {v11.4s}, [x19], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
st1 {v13.4s}, [x19], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
st1 {v15.4s}, [x19], #16
WriteEnd:
subs x13, x13, #8 // rhs col - 8
bgt LoopCol4
@@ -378,11 +568,16 @@ LoopRow4:
LoopColEnd:
add x0, x0, x17
cbz x9, C8DstStep
cmp x9, #3
beq C4DstStep
mov x21, #4
mul x21, x21, x7
sub x11, x11, x21
mov x2, x11
b NoDstStep
C4DstStep:
add x2, x2, x18
b NoDstStep
C8DstStep:
add x2, x2, #384
mov x11, x2

View File

@@ -29,10 +29,27 @@ asm_function MatmulFloatNeon64OptRow8
mov x21, #48 // sizeof(float) * 12
mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth
cmp x9, #3 // c4
beq C4Stride
cbnz x9, NoC8Steps
mov x11, x2
mov x21, #32
mul x16, x6, x21 // row * 8 * sizeof(float)
b NoC8Steps
C4Stride:
mov x18, #32 // 8 * sizeof(float)
mov x22, #4
mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row
mul x8, x8, x22 // col stride
// col >= 4 , block stride 128, otherwise 8 * 4 * col
cmp x7, #4
bge C4StrideCommon
mul x18, x18, x7 // block stride
b LoopRowStart
C4StrideCommon:
mov x18, #128 // block stride
b LoopRowStart
NoC8Steps:
cmp x9, #2
bne NoWinoSteps
@@ -45,6 +62,10 @@ NoWinoSteps:
mov x21, #4
mul x8, x8, x21
LoopRowStart:
cmp x9, #3
bne LoopRow8
mov x20, x2
LoopRow8:
mov x14, x1 // reload rhs ptr
mov x13, x7 // reload rhs col
@@ -52,7 +73,12 @@ LoopRow8:
LoopCol8:
cbz x9, NoReloadDst8
cmp x9, #3
beq C4ReloadDst8
mov x11, x2
b NoReloadDst8
C4ReloadDst8:
mov x11, x20
NoReloadDst8:
mov x10, x0 // reload lhs ptr
mov x19, x5 // reload depth
@@ -254,6 +280,8 @@ LoopRow8:
Write:
cmp x9, #2
beq WriteWino
cmp x9, #3
beq WriteC4
cbz x9, WriteC8
cmp x13, #1
beq Write1
@@ -557,7 +585,312 @@ LoopRow8:
beq WriteEnd
st1 {v22.4s, v23.4s}, [x11], x8
add x11, x11, #32
b WriteEnd
WriteC4:
cmp x13, #1
beq C4Write1
cmp x13, #2
beq C4Write2
cmp x13, #3
beq C4Write3
cmp x13, #4
beq C4Write4
cmp x13, #5
beq C4Write5
cmp x13, #6
beq C4Write6
cmp x13, #7
beq C4Write7
b C4Write8
C4Write1:
str s8, [x11], #4
cmp x6, #1
beq WriteEnd
str s10, [x11], #4
cmp x6, #2
beq WriteEnd
str s12, [x11], #4
cmp x6, #3
beq WriteEnd
str s14, [x11], #4
cmp x6, #4
beq WriteEnd
str s16, [x11], #4
cmp x6, #5
beq WriteEnd
str s18, [x11], #4
cmp x6, #6
beq WriteEnd
str s20, [x11], #4
cmp x6, #7
beq WriteEnd
str s22, [x11], #4
b WriteEnd
C4Write2:
st1 {v8.2s}, [x11], #8
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11], #8
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11], #8
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11], #8
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11], #8
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11], #8
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11], #8
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11], #8
b WriteEnd
C4Write3:
add x19, x11, #8
st1 {v8.2s}, [x11]
add x11, x11, #12
st1 {v8.s}[2], [x19]
add x19, x19, #12
cmp x6, #1
beq WriteEnd
st1 {v10.2s}, [x11]
add x11, x11, #12
st1 {v10.s}[2], [x19]
add x19, x19, #12
cmp x6, #2
beq WriteEnd
st1 {v12.2s}, [x11]
add x11, x11, #12
st1 {v12.s}[2], [x19]
add x19, x19, #12
cmp x6, #3
beq WriteEnd
st1 {v14.2s}, [x11]
add x11, x11, #12
st1 {v14.s}[2], [x19]
add x19, x19, #12
cmp x6, #4
beq WriteEnd
st1 {v16.2s}, [x11]
add x11, x11, #12
st1 {v16.s}[2], [x19]
add x19, x19, #12
cmp x6, #5
beq WriteEnd
st1 {v18.2s}, [x11]
add x11, x11, #12
st1 {v18.s}[2], [x19]
add x19, x19, #12
cmp x6, #6
beq WriteEnd
st1 {v20.2s}, [x11]
add x11, x11, #12
st1 {v20.s}[2], [x19]
add x19, x19, #12
cmp x6, #7
beq WriteEnd
st1 {v22.2s}, [x11]
st1 {v22.s}[2], [x19]
b WriteEnd
C4Write4:
st1 {v8.4s}, [x11], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
b WriteEnd
C4Write5:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #20
str s9, [x19]
add x19, x19, #20
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #20
str s11, [x19]
add x19, x19, #20
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #20
str s13, [x19]
add x19, x19, #20
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #20
str s15, [x19]
add x19, x19, #20
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #20
str s17, [x19]
add x19, x19, #20
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #20
str s19, [x19]
add x19, x19, #20
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #20
str s21, [x19]
add x19, x19, #20
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
str s23, [x19]
b WriteEnd
C4Write6:
add x19, x11, #16
st1 {v8.4s}, [x11]
add x11, x11, #24
st1 {v9.2s}, [x19]
add x19, x19, #24
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11]
add x11, x11, #24
st1 {v11.2s}, [x19]
add x19, x19, #24
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11]
add x11, x11, #24
st1 {v13.2s}, [x19]
add x19, x19, #24
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11]
add x11, x11, #24
st1 {v15.2s}, [x19]
add x19, x19, #24
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11]
add x11, x11, #24
st1 {v17.2s}, [x19]
add x19, x19, #24
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11]
add x11, x11, #24
st1 {v19.2s}, [x19]
add x19, x19, #24
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11]
add x11, x11, #24
st1 {v21.2s}, [x19]
add x19, x19, #24
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11]
st1 {v23.2s}, [x19]
b WriteEnd
C4Write7:
add x19, x11, #16
add x16, x11, #24
mov x10, #28
st1 {v8.4s}, [x11], x10
st1 {v9.2s}, [x19], x10
st1 {v9.s}[2], [x16], x10
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], x10
st1 {v11.2s}, [x19], x10
st1 {v11.s}[2], [x16], x10
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], x10
st1 {v13.2s}, [x19], x10
st1 {v13.s}[2], [x16], x10
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], x10
st1 {v15.2s}, [x19], x10
st1 {v15.s}[2], [x16], x10
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], x10
st1 {v17.2s}, [x19], x10
st1 {v17.s}[2], [x16], x10
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], x10
st1 {v19.2s}, [x19], x10
st1 {v19.s}[2], [x16], x10
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], x10
st1 {v21.2s}, [x19], x10
st1 {v21.s}[2], [x16], x10
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], x10
st1 {v23.2s}, [x19], x10
st1 {v23.s}[2], [x16], x10
b WriteEnd
C4Write8:
add x19, x11, x8
add x20, x19, x8
st1 {v8.4s}, [x11], #16
st1 {v9.4s}, [x19], #16
cmp x6, #1
beq WriteEnd
st1 {v10.4s}, [x11], #16
st1 {v11.4s}, [x19], #16
cmp x6, #2
beq WriteEnd
st1 {v12.4s}, [x11], #16
st1 {v13.4s}, [x19], #16
cmp x6, #3
beq WriteEnd
st1 {v14.4s}, [x11], #16
st1 {v15.4s}, [x19], #16
cmp x6, #4
beq WriteEnd
st1 {v16.4s}, [x11], #16
st1 {v17.4s}, [x19], #16
cmp x6, #5
beq WriteEnd
st1 {v18.4s}, [x11], #16
st1 {v19.4s}, [x19], #16
cmp x6, #6
beq WriteEnd
st1 {v20.4s}, [x11], #16
st1 {v21.4s}, [x19], #16
cmp x6, #7
beq WriteEnd
st1 {v22.4s}, [x11], #16
st1 {v23.4s}, [x19], #16
WriteEnd:
subs x13, x13, #8 // rhs col - 8
bgt LoopCol8
@@ -565,11 +898,16 @@ LoopRow8:
LoopColEnd:
add x0, x0, x17
cbz x9, C8DstStep
cmp x9, #3
beq C4DstStep
mov x21, #4
mul x21, x21, x7
sub x11, x11, x21
mov x2, x11
b NoDstStep
C4DstStep:
add x2, x2, x18
b NoDstStep
C8DstStep:
add x2, x2, #384
mov x11, x2
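The C4 write path added to both row-4 and row-8 kernels picks its destination block step from the column count: reading the C4Stride setup above, the step appears to be row_tile * 4 * sizeof(float) when col >= 4 (the #64 / #128 constants) and row_tile * col * sizeof(float) otherwise. A minimal C sketch of that rule; the helper name C4BlockStrideBytes is invented here for illustration and is not part of the patch:

#include <stdio.h>

/* Sketch of the C4 destination block step as read from the assembly above:
 * a full 4-channel block when col >= 4, otherwise a col-wide partial block. */
static int C4BlockStrideBytes(int row_tile, int col) {
  const int c4_tile = 4;                 /* channel tile of the C4 layout */
  const int unit = (int)sizeof(float);   /* 4 bytes per element */
  int cols = col >= c4_tile ? c4_tile : col;
  return row_tile * cols * unit;
}

int main(void) {
  printf("%d\n", C4BlockStrideBytes(4, 8));  /* 64, matches MatmulFloatNeon64OptRow4 */
  printf("%d\n", C4BlockStrideBytes(8, 8));  /* 128, matches MatmulFloatNeon64OptRow8 */
  printf("%d\n", C4BlockStrideBytes(8, 2));  /* 64 = 8 * 2 * 4, the col < 4 case */
  return 0;
}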

View File

@@ -29,12 +29,14 @@ int Gather(const void *input, int outer_size, int inner_size, int limit, const i
int8_t *int8_out_m = int8_out + inner_size * m * indices_element_size * data_size;
for (int i = 0; i < indices_element_size; ++i) {
int index = indices[i];
if (index < -limit || indices[i] >= limit) {
return NNACL_ERR;
}
if (indices[i] < 0) {
index = limit + indices[i];
}
memcpy(int8_out_m + i * inner_size * data_size, int8_in_m + index * inner_size * data_size,
data_size * inner_size);
}
}
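The revised Gather bounds check above now accepts negative indices in [-limit, 0) and wraps them to the end of the axis instead of rejecting them (the old printf error path is gone). A small standalone C sketch of that normalization; the helper name NormalizeGatherIndex is made up for illustration and is not part of the patch:

#include <stdio.h>

/* Accept indices in [-limit, limit); map negatives to limit + index. */
static int NormalizeGatherIndex(int index, int limit, int *normalized) {
  if (index < -limit || index >= limit) {
    return -1;  /* out of range: caller reports NNACL_ERR */
  }
  *normalized = index < 0 ? limit + index : index;
  return 0;
}

int main(void) {
  int idx = 0;
  if (NormalizeGatherIndex(-1, 10, &idx) == 0) {
    printf("index -1 with limit 10 maps to %d\n", idx);  /* prints 9 */
  }
  return 0;
}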

View File

@@ -43,7 +43,7 @@ void PadSliceParameterTo8D(SliceParameter *param) {
param->param_length_ = DIMENSION_8D;
}
void DoSlice(const void *input, void *output, const SliceParameter *param, int thread_id, int data_size) {
int8_t *int8_in = (int8_t *)input;
int8_t *int8_out = (int8_t *)output;
@@ -94,14 +94,14 @@ void DoSlice(const void *input, void *output, SliceParameter *param, int thread_
}
}
static bool WhetherCopyByAxis(const int begin[], const int end[], const int shape[], int dim) {
for (int i = dim + 1; i < DIMENSION_8D; ++i) {
if (begin[i] != 0 || end[i] != shape[i]) return false;
}
return true;
}
void DoSliceNoParallel(const void *input, void *output, const SliceParameter *param, int data_size) {
int8_t *int8_in = (int8_t *)input;
int8_t *int8_out = (int8_t *)output;

View File

@@ -25,8 +25,8 @@ extern "C" {
#endif
void PadSliceParameterTo8D(SliceParameter *param);
void DoSlice(const void *input, void *output, const SliceParameter *param, int thread_id, int data_size);
void DoSliceNoParallel(const void *input, void *output, const SliceParameter *param, int data_size);
#ifdef __cplusplus
}
#endif

View File

@@ -20,12 +20,12 @@
#include "nnacl/errorcode.h"
int DoSplit(void *in_data, void **out_data, const int *input_shape, int offset, int num_unit,
const SplitParameter *split_param, int data_size) {
int8_t *int8_in = (int8_t *)in_data;
int num_split = split_param->num_split_;
int *split_sizes = split_param->split_sizes_;
const int *strides = split_param->strides_;
int split_dim = split_param->split_dim_;
int in_stride = strides[split_dim];

View File

@@ -24,7 +24,7 @@
extern "C" {
#endif
int DoSplit(void *in_data, void **out_data, const int *input_shape, int offset, int num_unit,
const SplitParameter *split_param, int data_size);
#ifdef __cplusplus
}
#endif

View File

@@ -18,7 +18,7 @@
#include <string.h>
#include "nnacl/errorcode.h"
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, const SplitWithOverlapParameter *param,
const int *start_indices, const int *end_indices) {
if (in_data == NULL || out_data == NULL) {
return NNACL_NULL_PTR;

View File

@@ -23,7 +23,7 @@
#ifdef __cplusplus
extern "C" {
#endif
int DoSplitWithOverlapParallel(char *in_data, char **out_data, int slice_idx, const SplitWithOverlapParameter *param,
const int *start_indices, const int *end_indices);
#ifdef __cplusplus
}

View File

@@ -16,7 +16,7 @@
#include "nnacl/base/unstack_base.h"
void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size) {
const int8_t *in_addr = (int8_t *)input;
for (int j = 0; j < para->num_; j++) {
int8_t *out_addr = (int8_t *)output[j];

View File

@@ -24,7 +24,7 @@
#ifdef __cplusplus
extern "C" {
#endif
void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size);
#ifdef __cplusplus
}
#endif

View File

@@ -54,6 +54,7 @@ typedef struct ConvParameter {
int channel_multiplie_;
int output_padding_w_;
int output_padding_h_;
int out_format_;
} ConvParameter;
typedef struct SlidingWindowParam {

View File

@@ -69,7 +69,7 @@ int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *
}
int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -123,7 +123,7 @@ int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -180,7 +180,7 @@ int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
}
int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -239,7 +239,7 @@ int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *
}
int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -299,7 +299,7 @@ int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -365,7 +365,7 @@ int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
}
int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -418,7 +418,7 @@ int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *
}
int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -470,7 +470,7 @@ int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -527,7 +527,7 @@ int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
}
int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -581,7 +581,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *
}
int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -641,7 +641,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -704,7 +704,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
}
int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -755,7 +755,7 @@ int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float1
}
int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[0] != 0);
@@ -778,7 +778,7 @@ int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float1
return NNACL_OK;
}
int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
NNACL_ASSERT(input1[0] != 0);
@@ -814,7 +814,7 @@ int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, floa
}
int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -875,7 +875,7 @@ int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float
}
int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -922,7 +922,7 @@ int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input
}
int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size, const ArithmeticParameter *param) {
ElementOptSubFp16(input0, input1, output, element_size, param);
return ElementMulFp16(output, output, output, element_size);
}
@@ -944,7 +944,7 @@ int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -993,7 +993,7 @@ int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16
}
int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1042,7 +1042,7 @@ int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_
}
int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1091,7 +1091,7 @@ int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *
}
int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1140,7 +1140,7 @@ int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *o
}
int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1189,7 +1189,7 @@ int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8
}
int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1238,7 +1238,7 @@ int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t
}
int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);
@@ -1287,7 +1287,7 @@ int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, ui
}
int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param) {
#ifdef ENABLE_NEON
float16x8_t vin0_opt = vdupq_n_f16(input0[0]);
float16x8_t vin1_opt = vdupq_n_f16(input1[0]);

View File

@@ -31,55 +31,55 @@ void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_
ArithmeticParameter *param);
int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output,
int element_size, const ArithmeticParameter *param);
int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size,
const ArithmeticParameter *param);
int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);
int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size);

View File

@@ -16,21 +16,21 @@
#include <math.h>
#include "nnacl/fp16/arithmetic_self_fp16.h"
int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = fabsf(input[i]);
}
return NNACL_OK;
}
int ElementCosFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = cosf(input[i]);
}
return NNACL_OK;
}
int ElementLogFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
if (input[i] <= 0) {
return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO;
@@ -40,14 +40,14 @@ int ElementLogFp16(float16_t *input, float16_t *output, int element_size) {
return NNACL_OK;
}
int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = input[i] * input[i];
}
return NNACL_OK;
}
int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
if (input[i] < 0) {
return NNACL_ERRCODE_SQRT_NEGATIVE;
@@ -57,56 +57,56 @@ int ElementSqrtFp16(float16_t *input, float16_t *output, int element_size) {
return NNACL_OK;
}
int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = 1.f / sqrtf(input[i]);
}
return NNACL_OK;
}
int ElementSinFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = sinf(input[i]);
}
return NNACL_OK;
}
int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = (float)(!((bool)(input[i])));
}
return NNACL_OK;
}
int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = roundf(input[i]);
}
return NNACL_OK;
}
int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = floorf(input[i]);
}
return NNACL_OK;
}
int ElementCeilFp16(const float16_t *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = ceilf(input[i]);
}
return NNACL_OK;
}
int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
output[i] = -input[i];
}
return NNACL_OK;
}
int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; ++i) {
if (input[i] == 0.0f) {
return NNACL_ERR;
@@ -116,7 +116,7 @@ int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size)
return NNACL_OK;
}
int ElementErfFp16(const float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = erff(input[i]);
}

View File

@@ -23,33 +23,33 @@
#ifdef __cplusplus
extern "C" {
#endif
int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size);
int ElementCosFp16(const float16_t *input, float16_t *output, int element_size);
int ElementLogFp16(const float16_t *input, float16_t *output, int element_size);
int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size);
int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size);
int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size);
int ElementSinFp16(const float16_t *input, float16_t *output, int element_size);
int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size);
int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size);
int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size);
int ElementCeilFp16(const float16_t *input, float16_t *output, int number);
int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size);
int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size);
int ElementErfFp16(const float16_t *input, float16_t *output, int element_size);
#ifdef __cplusplus
}
#endif

View File

@@ -17,7 +17,7 @@
#include "nnacl/fp16/batchnorm_fp16.h"
#include <math.h>
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, const BatchNormParameter *param,
int task_id, float16_t *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
@@ -36,7 +36,7 @@ void BatchNormFp16(const float16_t *input, const void *mean, const void *varianc
}
void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, const BatchNormParameter *param, int task_id, void *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);

View File

@@ -22,10 +22,10 @@
extern "C" {
#endif
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, const BatchNormParameter *param,
int task_id, float16_t *output);
void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, const BatchNormParameter *param, int task_id, void *output);
void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var,
const BatchNormParameter *param, float16_t *save_mean, float16_t *save_var);
#ifdef __cplusplus

Some files were not shown because too many files have changed in this diff.