forked from OSchip/llvm-project
[flang] Shape analysis for result of MATMUL
Implement shape analysis for the result of the MATMUL generic transformational intrinsic function, based on the shapes of its arguments. Correct the names of the arguments to match the standard, too. Reviewed By: PeteSteinfeld Differential Revision: https://reviews.llvm.org/D82250
This commit is contained in:
parent
1728dec255
commit
16d24e4543
|
@ -496,28 +496,28 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
|
|||
{"logical", {{"l", AnyLogical}, DefaultingKIND}, KINDLogical},
|
||||
{"log_gamma", {{"x", SameReal}}, SameReal},
|
||||
{"matmul",
|
||||
{{"array_a", AnyLogical, Rank::vector},
|
||||
{"array_b", AnyLogical, Rank::matrix}},
|
||||
{{"matrix_a", AnyLogical, Rank::vector},
|
||||
{"matrix_b", AnyLogical, Rank::matrix}},
|
||||
ResultLogical, Rank::vector, IntrinsicClass::transformationalFunction},
|
||||
{"matmul",
|
||||
{{"array_a", AnyLogical, Rank::matrix},
|
||||
{"array_b", AnyLogical, Rank::vector}},
|
||||
{{"matrix_a", AnyLogical, Rank::matrix},
|
||||
{"matrix_b", AnyLogical, Rank::vector}},
|
||||
ResultLogical, Rank::vector, IntrinsicClass::transformationalFunction},
|
||||
{"matmul",
|
||||
{{"array_a", AnyLogical, Rank::matrix},
|
||||
{"array_b", AnyLogical, Rank::matrix}},
|
||||
{{"matrix_a", AnyLogical, Rank::matrix},
|
||||
{"matrix_b", AnyLogical, Rank::matrix}},
|
||||
ResultLogical, Rank::matrix, IntrinsicClass::transformationalFunction},
|
||||
{"matmul",
|
||||
{{"array_a", AnyNumeric, Rank::vector},
|
||||
{"array_b", AnyNumeric, Rank::matrix}},
|
||||
{{"matrix_a", AnyNumeric, Rank::vector},
|
||||
{"matrix_b", AnyNumeric, Rank::matrix}},
|
||||
ResultNumeric, Rank::vector, IntrinsicClass::transformationalFunction},
|
||||
{"matmul",
|
||||
{{"array_a", AnyNumeric, Rank::matrix},
|
||||
{"array_b", AnyNumeric, Rank::vector}},
|
||||
{{"matrix_a", AnyNumeric, Rank::matrix},
|
||||
{"matrix_b", AnyNumeric, Rank::vector}},
|
||||
ResultNumeric, Rank::vector, IntrinsicClass::transformationalFunction},
|
||||
{"matmul",
|
||||
{{"array_a", AnyNumeric, Rank::matrix},
|
||||
{"array_b", AnyNumeric, Rank::matrix}},
|
||||
{{"matrix_a", AnyNumeric, Rank::matrix},
|
||||
{"matrix_b", AnyNumeric, Rank::matrix}},
|
||||
ResultNumeric, Rank::matrix, IntrinsicClass::transformationalFunction},
|
||||
{"maskl", {{"i", AnyInt}, DefaultingKIND}, KINDInt},
|
||||
{"maskr", {{"i", AnyInt}, DefaultingKIND}, KINDInt},
|
||||
|
@ -1904,7 +1904,6 @@ std::optional<SpecificCall> IntrinsicProcTable::Implementation::Probe(
|
|||
}
|
||||
|
||||
if (call.isSubroutineCall) {
|
||||
parser::Messages buffer;
|
||||
auto subrRange{subroutines_.equal_range(call.name)};
|
||||
for (auto iter{subrRange.first}; iter != subrRange.second; ++iter) {
|
||||
if (auto specificCall{
|
||||
|
|
|
@ -545,6 +545,23 @@ auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
|
|||
if (!call.arguments().empty()) {
|
||||
return (*this)(call.arguments()[0]);
|
||||
}
|
||||
} else if (intrinsic->name == "matmul") {
|
||||
if (call.arguments().size() == 2) {
|
||||
if (auto ashape{(*this)(call.arguments()[0])}) {
|
||||
if (auto bshape{(*this)(call.arguments()[1])}) {
|
||||
if (ashape->size() == 1 && bshape->size() == 2) {
|
||||
bshape->erase(bshape->begin());
|
||||
return std::move(*bshape); // matmul(vector, matrix)
|
||||
} else if (ashape->size() == 2 && bshape->size() == 1) {
|
||||
ashape->pop_back();
|
||||
return std::move(*ashape); // matmul(matrix, vector)
|
||||
} else if (ashape->size() == 2 && bshape->size() == 2) {
|
||||
(*ashape)[1] = std::move((*bshape)[1]);
|
||||
return std::move(*ashape); // matmul(matrix, matrix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (intrinsic->name == "reshape") {
|
||||
if (call.arguments().size() >= 2 && call.arguments().at(1)) {
|
||||
// SHAPE(RESHAPE(array,shape)) -> shape
|
||||
|
|
|
@ -607,7 +607,8 @@ static void CheckExplicitInterfaceArg(evaluate::ActualArgument &arg,
|
|||
// ok
|
||||
} else {
|
||||
messages.Say(
|
||||
"Actual argument is not a variable or typed expression"_err_en_US);
|
||||
"Actual argument '%s' associated with %s is not a variable or typed expression"_err_en_US,
|
||||
expr->AsFortran(), dummyName);
|
||||
}
|
||||
} else {
|
||||
const Symbol &assumed{DEREF(arg.GetAssumedTypeDummy())};
|
||||
|
|
Loading…
Reference in New Issue