Add dynamic shape support for the EmbeddingLookup operator

This commit is contained in:
fangzehua 2020-11-13 15:29:02 +08:00
parent c31749e3d0
commit 963b063374
6 changed files with 77 additions and 9 deletions

View File

@ -48,32 +48,54 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp
// Initializes kernel state from the graph node: caches shapes of the params
// (input 0) and indices (input 1) tensors and the optional `offset` attribute.
// Throws (MS_LOG(EXCEPTION)) when the params tensor is 0-D.
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  // Keep the node so LaunchKernel can re-read the inferred shapes at launch
  // time — needed for dynamic-shape inputs whose sizes are unknown here.
  node_ = kernel_node;
  std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  if (input_shape.empty()) {
    MS_LOG(EXCEPTION) << "param must be at least 1D";
  }
  // first_dim_size_: number of rows in the embedding table.
  // outer_dim_size_: number of elements per row (product of trailing dims).
  first_dim_size_ = input_shape[0];
  outer_dim_size_ = 1;
  for (size_t i = 1; i < input_shape.size(); ++i) {
    outer_dim_size_ *= input_shape[i];
  }
  // Total number of indices to look up (product of all indices dims).
  indices_lens_ = 1;
  std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  for (const auto &shape : indices_shape) {
    indices_lens_ *= shape;
  }
  if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
    offset_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrOffset);
  }
  // Fix: assign once — the original set indices_data_type_ twice with the
  // identical expression (duplicate left over from the diff/merge).
  indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
}
template <typename T>
void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) const {
const std::vector<kernel::AddressPtr> &outputs) {
if (node_ != nullptr) {
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
if (input_shape.empty()) {
MS_LOG(EXCEPTION) << "param must be at least 1D";
}
first_dim_size_ = input_shape[0];
outer_dim_size_ = 1;
for (size_t i = 1; i < input_shape.size(); ++i) {
outer_dim_size_ *= input_shape[i];
}
indices_lens_ = 1;
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
for (const auto &shape : indices_shape) {
indices_lens_ *= shape;
}
}
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
const size_t thread_num = 16;
std::thread threads[16];
const size_t kMaxThreadNum = 16;
size_t thread_num = indices_lens_ / 10000 + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;

View File

@ -32,8 +32,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) const;
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
protected:
void CheckParam(const CNodePtr &kernel_node);
@ -42,6 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
size_t first_dim_size_{1};
size_t outer_dim_size_{1};
TypeId indices_data_type_{kNumberTypeInt32};
CNodePtr node_ = nullptr;
};
MS_REG_CPU_KERNEL(

View File

@ -228,6 +228,8 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -358,6 +358,41 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
}
// Shape/type inference for EmbeddingLookup(params, indices):
//   output shape = indices_shape ++ params_shape[1:]
// Supports dynamic shapes: when indices carries a max_shape, the output
// max_shape is derived from it; min_shape is an all-ones lower bound.
// Element type is propagated from params.
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  // Fix: null-check BEFORE dereferencing — the original called
  // params->shape() first, making the check useless.
  MS_EXCEPTION_IF_NULL(params);
  auto params_shp = params->shape();
  MS_EXCEPTION_IF_NULL(params_shp);
  auto params_shape = params_shp->shape();
  // Fix: guard against 0-D params — params_shape.begin() + 1 below would be
  // undefined behavior on an empty vector.
  if (params_shape.empty()) {
    MS_LOG(EXCEPTION) << op_name << " requires 'params' to be at least 1-D.";
  }
  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  // Fix: same ordering issue as params — check before use.
  MS_EXCEPTION_IF_NULL(indices);
  auto indices_shp = indices->shape();
  MS_EXCEPTION_IF_NULL(indices_shp);
  auto indices_shape = indices_shp->shape();
  auto indices_max_shape = indices_shp->max_shape();
  // Static output shape: indices dims followed by the per-row params dims.
  ShapeVector shape;
  shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
  shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
  // Dynamic shape: propagate the indices max_shape when present; otherwise the
  // max_shape degenerates to the static shape.
  ShapeVector max_shape;
  if (!indices_max_shape.empty()) {
    max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
    max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
  } else {
    max_shape = shape;
  }
  // Lower bound: one element per dimension, rank matching max_shape.
  ShapeVector min_shape(max_shape.size(), 1);
  AbstractTensorPtr ret =
    std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  return ret;
}
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();

View File

@ -54,6 +54,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},

View File

@ -4141,12 +4141,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
params_shp = params['shape']
if len(params_shp) != 2:
raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp))
out_shape = indices['shape'] + params_shp[1:]
if 'max_shape' in indices:
out_max_shape = indices['max_shape'] + params_shp[1:]
else:
out_max_shape = out_shape
if 'min_shape' in indices:
out_min_shape = indices['min_shape'] + params_shp[1:]
else:
out_min_shape = out_shape
out = {'shape': out_shape,
'dtype': params['dtype'],
'value': None}
'value': None,
'max_shape': out_max_shape,
'min_shape': out_min_shape}
return out