forked from mindspore-Ecosystem/mindspore
multi_input
This commit is contained in:
parent
4ce1cf4529
commit
5cf29f2478
|
@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_attr);
|
||||||
|
TypeId input_dtype = kernel_attr->GetInputAttr(0).first;
|
||||||
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
|
for (size_t i = 1; i < input_num; ++i) {
|
||||||
|
kernel_attr->AddInputAttr(input_dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
|
||||||
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||||
|
for (size_t i = 1; i < output_num; ++i) {
|
||||||
|
kernel_attr->AddOutputAttr(output_dtype);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
|
@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||||
|
|
||||||
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||||
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
|
auto kernel_attr = kernel_attrs[index];
|
||||||
|
if (kernel_attr.GetAllSame()) {
|
||||||
|
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||||
|
}
|
||||||
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
||||||
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
||||||
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
||||||
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||||
for (auto &input_index : input_not_cnode_indexes) {
|
for (auto &input_index : input_not_cnode_indexes) {
|
||||||
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
|
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,8 +46,14 @@ class KernelAttr {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KernelAttr &SetAllSameAttr(bool all_same) {
|
||||||
|
all_same_ = all_same;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||||
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||||
|
bool GetAllSame() const { return all_same_; }
|
||||||
|
|
||||||
size_t GetInputSize() const { return input_type_.size(); }
|
size_t GetInputSize() const { return input_type_.size(); }
|
||||||
size_t GetOutputSize() const { return output_type_.size(); }
|
size_t GetOutputSize() const { return output_type_.size(); }
|
||||||
|
@ -55,6 +61,7 @@ class KernelAttr {
|
||||||
private:
|
private:
|
||||||
std::vector<DataType> input_type_;
|
std::vector<DataType> input_type_;
|
||||||
std::vector<DataType> output_type_;
|
std::vector<DataType> output_type_;
|
||||||
|
bool all_same_;
|
||||||
};
|
};
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -39,16 +39,8 @@ class AddNCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
|
||||||
AddN,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
AddNCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(AddN,
|
MS_REG_CPU_KERNEL(AddN,
|
||||||
KernelAttr()
|
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
AddNCPUKernel);
|
AddNCPUKernel);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -41,9 +41,8 @@ class ConcatCPUKernel : public CPUKernel {
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(
|
MS_REG_CPU_KERNEL(Concat,
|
||||||
Concat,
|
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ConcatCPUKernel);
|
ConcatCPUKernel);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -59,28 +59,27 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
||||||
auto creators = iter->second;
|
auto creators = iter->second;
|
||||||
for (size_t index = 0; index < creators.size(); ++index) {
|
for (size_t index = 0; index < creators.size(); ++index) {
|
||||||
auto attr_creator = creators[index];
|
auto attr_creator = creators[index];
|
||||||
if (CPUKernelSingleAttrCheck(attr_creator, kernel_info)) {
|
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
|
||||||
return std::make_pair(true, index);
|
return std::make_pair(true, index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return std::make_pair(false, 0);
|
return std::make_pair(false, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
|
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) {
|
||||||
const KernelBuildInfo &kernel_info) {
|
|
||||||
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
|
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
|
||||||
if (kernel_info.GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) {
|
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
|
||||||
MS_LOG(DEBUG) << "cpu kernel attr check failed. input index: " << i << ".";
|
if (kernel_info.GetInputDeviceType(i) != dtype) {
|
||||||
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetInputDeviceType(i) << ", "
|
MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
|
||||||
<< "register type:" << attr_creator.first.GetInputAttr(i).first;
|
<< ", register type:" << dtype;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
|
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
|
||||||
if (kernel_info.GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) {
|
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
|
||||||
MS_LOG(DEBUG) << "cpu kernel attr check failed. output index: " << i << ".";
|
if (kernel_info.GetOutputDeviceType(i) != dtype) {
|
||||||
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetOutputDeviceType(i) << ", "
|
MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
|
||||||
<< "register type:" << attr_creator.first.GetOutputAttr(i).first;
|
<< ", register type:" << dtype;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,6 @@ class CPUKernelFactory {
|
||||||
public:
|
public:
|
||||||
static CPUKernelFactory &GetInstance();
|
static CPUKernelFactory &GetInstance();
|
||||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
||||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name);
|
|
||||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||||
|
|
||||||
|
@ -44,8 +43,7 @@ class CPUKernelFactory {
|
||||||
~CPUKernelFactory() = default;
|
~CPUKernelFactory() = default;
|
||||||
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
|
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
|
||||||
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
|
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
|
||||||
bool CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
|
bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info);
|
||||||
const KernelBuildInfo &kernel_info);
|
|
||||||
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
|
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -71,13 +71,13 @@ def test_in2_axis1():
|
||||||
assert np.all(diff < error)
|
assert np.all(diff < error)
|
||||||
assert np.all(-diff < error)
|
assert np.all(-diff < error)
|
||||||
|
|
||||||
class Concat_Axis2(nn.Cell):
|
class Concat_in3_Axis2(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Concat_Axis2, self).__init__()
|
super(Concat_in3_Axis2, self).__init__()
|
||||||
self.cat = P.Concat(axis=-1)
|
self.cat = P.Concat(axis=-1)
|
||||||
|
|
||||||
def construct(self, x1, x2):
|
def construct(self, x1, x2, x3):
|
||||||
return self.cat((x1, x2))
|
return self.cat((x1, x2, x3))
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_cpu
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@ -86,10 +86,10 @@ def test_in3_axis2():
|
||||||
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
|
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
|
||||||
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
|
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
|
||||||
x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
|
x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
|
||||||
cat = Concat_Axis2()
|
cat = Concat_in3_Axis2()
|
||||||
output_ms = cat(x1, x2)
|
output_ms = cat(x1, x2, x3)
|
||||||
print("output:\n", output_ms)
|
print("output:\n", output_ms)
|
||||||
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1)
|
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1)
|
||||||
|
|
||||||
error = np.ones(shape=output_np.shape) * 10e-6
|
error = np.ones(shape=output_np.shape) * 10e-6
|
||||||
diff = output_ms.asnumpy() - output_np
|
diff = output_ms.asnumpy() - output_np
|
||||||
|
|
Loading…
Reference in New Issue