forked from mindspore-Ecosystem/mindspore
add emblookup dynamic
This commit is contained in:
parent
c31749e3d0
commit
963b063374
|
@ -48,32 +48,54 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp
|
|||
|
||||
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
node_ = kernel_node;
|
||||
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "param must be at least 1D";
|
||||
}
|
||||
first_dim_size_ = input_shape[0];
|
||||
outer_dim_size_ = 1;
|
||||
for (size_t i = 1; i < input_shape.size(); ++i) {
|
||||
outer_dim_size_ *= input_shape[i];
|
||||
}
|
||||
indices_lens_ = 1;
|
||||
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (const auto &shape : indices_shape) {
|
||||
indices_lens_ *= shape;
|
||||
}
|
||||
indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
|
||||
offset_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrOffset);
|
||||
}
|
||||
indices_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) const {
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (node_ != nullptr) {
|
||||
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0);
|
||||
if (input_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "param must be at least 1D";
|
||||
}
|
||||
first_dim_size_ = input_shape[0];
|
||||
outer_dim_size_ = 1;
|
||||
for (size_t i = 1; i < input_shape.size(); ++i) {
|
||||
outer_dim_size_ *= input_shape[i];
|
||||
}
|
||||
|
||||
indices_lens_ = 1;
|
||||
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
|
||||
for (const auto &shape : indices_shape) {
|
||||
indices_lens_ *= shape;
|
||||
}
|
||||
}
|
||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
const size_t thread_num = 16;
|
||||
std::thread threads[16];
|
||||
const size_t kMaxThreadNum = 16;
|
||||
size_t thread_num = indices_lens_ / 10000 + 1;
|
||||
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
|
||||
std::thread threads[kMaxThreadNum];
|
||||
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t task_offset = 0;
|
||||
|
|
|
@ -32,8 +32,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) const;
|
||||
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
|
@ -42,6 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel {
|
|||
size_t first_dim_size_{1};
|
||||
size_t outer_dim_size_{1};
|
||||
TypeId indices_data_type_{kNumberTypeInt32};
|
||||
CNodePtr node_ = nullptr;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
|
|
|
@ -228,6 +228,8 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -358,6 +358,41 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto params_shp = params->shape();
|
||||
MS_EXCEPTION_IF_NULL(params);
|
||||
MS_EXCEPTION_IF_NULL(params_shp);
|
||||
auto params_shape = params_shp->shape();
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto indices_shp = indices->shape();
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(indices_shp);
|
||||
auto indices_shape = indices_shp->shape();
|
||||
auto indices_max_shape = indices_shp->max_shape();
|
||||
ShapeVector shape;
|
||||
ShapeVector max_shape;
|
||||
shape.insert(shape.end(), indices_shape.begin(), indices_shape.end());
|
||||
shape.insert(shape.end(), params_shape.begin() + 1, params_shape.end());
|
||||
if (!indices_max_shape.empty()) {
|
||||
max_shape.insert(max_shape.end(), indices_max_shape.begin(), indices_max_shape.end());
|
||||
max_shape.insert(max_shape.end(), params_shape.begin() + 1, params_shape.end());
|
||||
} else {
|
||||
max_shape = shape;
|
||||
}
|
||||
ShapeVector min_shape;
|
||||
for (size_t i = 0; i < max_shape.size(); ++i) {
|
||||
min_shape.emplace_back(1);
|
||||
}
|
||||
|
||||
AbstractTensorPtr ret =
|
||||
std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
|
|
|
@ -54,6 +54,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
|
||||
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
|
||||
|
|
|
@ -4141,12 +4141,20 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
|
||||
params_shp = params['shape']
|
||||
if len(params_shp) != 2:
|
||||
raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp))
|
||||
out_shape = indices['shape'] + params_shp[1:]
|
||||
if 'max_shape' in indices:
|
||||
out_max_shape = indices['max_shape'] + params_shp[1:]
|
||||
else:
|
||||
out_max_shape = out_shape
|
||||
if 'min_shape' in indices:
|
||||
out_min_shape = indices['min_shape'] + params_shp[1:]
|
||||
else:
|
||||
out_min_shape = out_shape
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
'value': None,
|
||||
'max_shape': out_max_shape,
|
||||
'min_shape': out_min_shape}
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue