forked from mindspore-Ecosystem/mindspore
multi_input
This commit is contained in:
parent
4ce1cf4529
commit
5cf29f2478
|
@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_attr);
|
||||
TypeId input_dtype = kernel_attr->GetInputAttr(0).first;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 1; i < input_num; ++i) {
|
||||
kernel_attr->AddInputAttr(input_dtype);
|
||||
}
|
||||
|
||||
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t i = 1; i < output_num; ++i) {
|
||||
kernel_attr->AddOutputAttr(output_dtype);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
|
@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
|
||||
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
|
||||
auto kernel_attr = kernel_attrs[index];
|
||||
if (kernel_attr.GetAllSame()) {
|
||||
ExpandKernelAttr(kernel_node, &kernel_attr);
|
||||
}
|
||||
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
||||
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
||||
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
|
||||
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
|
||||
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -46,8 +46,14 @@ class KernelAttr {
|
|||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &SetAllSameAttr(bool all_same) {
|
||||
all_same_ = all_same;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||
bool GetAllSame() const { return all_same_; }
|
||||
|
||||
size_t GetInputSize() const { return input_type_.size(); }
|
||||
size_t GetOutputSize() const { return output_type_.size(); }
|
||||
|
@ -55,6 +61,7 @@ class KernelAttr {
|
|||
private:
|
||||
std::vector<DataType> input_type_;
|
||||
std::vector<DataType> output_type_;
|
||||
bool all_same_;
|
||||
};
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
|
|
|
@ -39,16 +39,8 @@ class AddNCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
AddN,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
AddNCPUKernel);
|
||||
MS_REG_CPU_KERNEL(AddN,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
AddNCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,10 +41,9 @@ class ConcatCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Concat,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ConcatCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Concat,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ConcatCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -59,28 +59,27 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
|||
auto creators = iter->second;
|
||||
for (size_t index = 0; index < creators.size(); ++index) {
|
||||
auto attr_creator = creators[index];
|
||||
if (CPUKernelSingleAttrCheck(attr_creator, kernel_info)) {
|
||||
if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) {
|
||||
return std::make_pair(true, index);
|
||||
}
|
||||
}
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
|
||||
const KernelBuildInfo &kernel_info) {
|
||||
bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) {
|
||||
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
|
||||
if (kernel_info.GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) {
|
||||
MS_LOG(DEBUG) << "cpu kernel attr check failed. input index: " << i << ".";
|
||||
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetInputDeviceType(i) << ", "
|
||||
<< "register type:" << attr_creator.first.GetInputAttr(i).first;
|
||||
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
|
||||
if (kernel_info.GetInputDeviceType(i) != dtype) {
|
||||
MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
|
||||
<< ", register type:" << dtype;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
|
||||
if (kernel_info.GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) {
|
||||
MS_LOG(DEBUG) << "cpu kernel attr check failed. output index: " << i << ".";
|
||||
MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetOutputDeviceType(i) << ", "
|
||||
<< "register type:" << attr_creator.first.GetOutputAttr(i).first;
|
||||
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
|
||||
if (kernel_info.GetOutputDeviceType(i) != dtype) {
|
||||
MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
|
||||
<< ", register type:" << dtype;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,6 @@ class CPUKernelFactory {
|
|||
public:
|
||||
static CPUKernelFactory &GetInstance();
|
||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name);
|
||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||
|
||||
|
@ -44,8 +43,7 @@ class CPUKernelFactory {
|
|||
~CPUKernelFactory() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
|
||||
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
|
||||
bool CPUKernelSingleAttrCheck(const std::pair<KernelAttr, CPUKernelCreator> &attr_creator,
|
||||
const KernelBuildInfo &kernel_info);
|
||||
bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info);
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
|
||||
};
|
||||
|
||||
|
|
|
@ -71,13 +71,13 @@ def test_in2_axis1():
|
|||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
class Concat_Axis2(nn.Cell):
|
||||
class Concat_in3_Axis2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Concat_Axis2, self).__init__()
|
||||
super(Concat_in3_Axis2, self).__init__()
|
||||
self.cat = P.Concat(axis=-1)
|
||||
|
||||
def construct(self, x1, x2):
|
||||
return self.cat((x1, x2))
|
||||
def construct(self, x1, x2, x3):
|
||||
return self.cat((x1, x2, x3))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
|
@ -86,10 +86,10 @@ def test_in3_axis2():
|
|||
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
|
||||
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
|
||||
x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
|
||||
cat = Concat_Axis2()
|
||||
output_ms = cat(x1, x2)
|
||||
cat = Concat_in3_Axis2()
|
||||
output_ms = cat(x1, x2, x3)
|
||||
print("output:\n", output_ms)
|
||||
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1)
|
||||
output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1)
|
||||
|
||||
error = np.ones(shape=output_np.shape) * 10e-6
|
||||
diff = output_ms.asnumpy() - output_np
|
||||
|
|
Loading…
Reference in New Issue