multi_input

This commit is contained in:
sunsuodong 2020-05-27 14:57:57 +08:00
parent 4ce1cf4529
commit 5cf29f2478
7 changed files with 52 additions and 38 deletions

View File

@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
}
return true;
}
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
MS_EXCEPTION_IF_NULL(kernel_attr);
TypeId input_dtype = kernel_attr->GetInputAttr(0).first;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 1; i < input_num; ++i) {
kernel_attr->AddInputAttr(input_dtype);
}
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t i = 1; i < output_num; ++i) {
kernel_attr->AddOutputAttr(output_dtype);
}
}
} // namespace
void SetKernelInfo(const CNodePtr &kernel_node) {
@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
auto kernel_attr = kernel_attrs[index];
if (kernel_attr.GetAllSame()) {
ExpandKernelAttr(kernel_node, &kernel_attr);
}
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
for (auto &input_index : input_not_cnode_indexes) {
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
}
break;
}

View File

@ -46,8 +46,14 @@ class KernelAttr {
return *this;
}
KernelAttr &SetAllSameAttr(bool all_same) {
all_same_ = all_same;
return *this;
}
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
bool GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }
@ -55,6 +61,7 @@ class KernelAttr {
private:
std::vector<DataType> input_type_;
std::vector<DataType> output_type_;
bool all_same_;
};
} // namespace cpu
} // namespace device

View File

@ -39,16 +39,8 @@ class AddNCPUKernel : public CPUKernel {
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL(
AddN,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel);
MS_REG_CPU_KERNEL(AddN,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -41,10 +41,9 @@ class ConcatCPUKernel : public CPUKernel {
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL(
Concat,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatCPUKernel);
MS_REG_CPU_KERNEL(Concat,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -59,28 +59,27 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) {
auto attr_creator = creators[index];
if (CPUKernelSingleAttrCheck(attr_creator, kernel_info)) {
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
return std::make_pair(true, index);
}
}
return std::make_pair(false, 0);
}
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
const KernelBuildInfo &kernel_info) {
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) {
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
if (kernel_info.GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) {
MS_LOG(DEBUG) << "cpu kernel attr check failed. input index: " << i << ".";
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetInputDeviceType(i) << ", "
<< "register type:" << attr_creator.first.GetInputAttr(i).first;
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
if (kernel_info.GetInputDeviceType(i) != dtype) {
MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
<< ", register type:" << dtype;
return false;
}
}
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
if (kernel_info.GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) {
MS_LOG(DEBUG) << "cpu kernel attr check failed. output index: " << i << ".";
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetOutputDeviceType(i) << ", "
<< "register type:" << attr_creator.first.GetOutputAttr(i).first;
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
if (kernel_info.GetOutputDeviceType(i) != dtype) {
MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
<< ", register type:" << dtype;
return false;
}
}

View File

@ -35,7 +35,6 @@ class CPUKernelFactory {
public:
static CPUKernelFactory &GetInstance();
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name);
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
@ -44,8 +43,7 @@ class CPUKernelFactory {
~CPUKernelFactory() = default;
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
bool CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
const KernelBuildInfo &kernel_info);
bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info);
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
};

View File

@ -71,13 +71,13 @@ def test_in2_axis1():
assert np.all(diff < error)
assert np.all(-diff < error)
class Concat_Axis2(nn.Cell):
class Concat_in3_Axis2(nn.Cell):
def __init__(self):
super(Concat_Axis2, self).__init__()
super(Concat_in3_Axis2, self).__init__()
self.cat = P.Concat(axis=-1)
def construct(self, x1, x2):
return self.cat((x1, x2))
def construct(self, x1, x2, x3):
return self.cat((x1, x2, x3))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@ -86,10 +86,10 @@ def test_in3_axis2():
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
cat = Concat_Axis2()
output_ms = cat(x1, x2)
cat = Concat_in3_Axis2()
output_ms = cat(x1, x2, x3)
print("output:\n", output_ms)
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1)
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1)
error = np.ones(shape=output_np.shape) * 10e-6
diff = output_ms.asnumpy() - output_np