Pattern matching on Array types, working for reals.

def f[T](a: Array[T]) = a match {
    case x: Array[Int]      => x(0)
    case x: Array[Double]   => 2
    // etc.
  }

I'd also like to thank "instantiateTypeVar" for displacing the
mechanical spiders and giant squid beings which used to fill my
nightmares.  Now that I know true horror, I welcome the squid.

Closes #2755, review by odersky.

git-svn-id: http://lampsvn.epfl.ch/svn-repos/scala/scala/trunk@23180 5e8d7ff9-d8ef-0310-90f0-a4852d11357a
This commit is contained in:
extempore 2010-10-04 05:28:27 +00:00
parent 608824b938
commit faa8737d9a
7 changed files with 139 additions and 44 deletions

View File

@ -1341,9 +1341,9 @@ trait Infer {
}
check(tp, List())
}
/** Type intersection of simple type <code>tp1</code> with general
* type <code>tp2</code>. The result eliminates some redundancies.
/** Type intersection of simple type tp1 with general type tp2.
* The result eliminates some redundancies.
*/
def intersect(tp1: Type, tp2: Type): Type = {
if (tp1 <:< tp2) tp1
@ -1351,7 +1351,7 @@ trait Infer {
else {
val reduced2 = tp2 match {
case rtp @ RefinedType(parents2, decls2) =>
copyRefinedType(rtp, parents2 filter (p2 => !(tp1 <:< p2)), decls2)
copyRefinedType(rtp, parents2 filterNot (tp1 <:< _), decls2)
case _ =>
tp2
}
@ -1360,36 +1360,57 @@ trait Infer {
}
def inferTypedPattern(pos: Position, pattp: Type, pt0: Type): Type = {
val pt = widen(pt0)
val pt = widen(pt0)
val ptparams = freeTypeParamsOfTerms.collect(pt)
val tpparams = freeTypeParamsOfTerms.collect(pattp)
def ptMatchesPattp = pt matchesPattern pattp
def pattpMatchesPt = pattp matchesPattern pt
/** If we can absolutely rule out a match we can fail fast. */
if (pt.isFinalType && !(pt matchesPattern pattp))
error(pos, "scrutinee is incompatible with pattern type"+foundReqMsg(pattp, pt))
/** If we can absolutely rule out a match we can fail early.
* This is the case if the scrutinee has no unresolved type arguments
* and is a "final type", meaning final + invariant in all type parameters.
*/
if (pt.isFinalType && ptparams.isEmpty && !ptMatchesPattp)
error(pos, "scrutinee is incompatible with pattern type" + foundReqMsg(pattp, pt))
checkCheckable(pos, pattp, "pattern ")
if (!(pattp <:< pt)) {
val tpparams = freeTypeParamsOfTerms.collect(pattp)
if (settings.debug.value) log("free type params (1) = " + tpparams)
if (pattp <:< pt) ()
else {
if (settings.debug.value)
log("free type params (1) = " + tpparams)
var tvars = tpparams map freshVar
var tp = pattp.instantiateTypeParams(tpparams, tvars)
if (!((tp <:< pt) && isInstantiatable(tvars))) {
var tp = pattp.instantiateTypeParams(tpparams, tvars)
if ((tp <:< pt) && isInstantiatable(tvars)) ()
else {
tvars = tpparams map freshVar
tp = pattp.instantiateTypeParams(tpparams, tvars)
val ptparams = freeTypeParamsOfTerms.collect(pt)
if (settings.debug.value) log("free type params (2) = " + ptparams)
tp = pattp.instantiateTypeParams(tpparams, tvars)
if (settings.debug.value)
log("free type params (2) = " + ptparams)
val ptvars = ptparams map freshVar
val pt1 = pt.instantiateTypeParams(ptparams, ptvars)
// See ticket #2486 we have this example of code which would incorrectly
// fail without verifying that !(pattp matchesPattern pt)
if (!(isPopulated(tp, pt1) && isInstantiatable(tvars ::: ptvars)) && !(pattp matchesPattern pt)) {
error(pos, "pattern type is incompatible with expected type"+foundReqMsg(pattp, pt))
return pattp
val pt1 = pt.instantiateTypeParams(ptparams, ptvars)
// See ticket #2486 for an example of code which would incorrectly
// fail if we didn't allow for pattpMatchesPt.
if (isPopulated(tp, pt1) && isInstantiatable(tvars ++ ptvars) || pattpMatchesPt)
ptvars foreach instantiateTypeVar
else {
error(pos, "pattern type is incompatible with expected type" + foundReqMsg(pattp, pt))
return pattp
}
ptvars foreach instantiateTypeVar
}
tvars foreach instantiateTypeVar
}
intersect(pt, pattp)
/** If the scrutinee has free type parameters but the pattern does not,
* we have to flip the arguments so the expected type is treated as more
* general when calculating the intersection. See run/bug2755.scala.
*/
if (tpparams.isEmpty && ptparams.nonEmpty) intersect(pattp, pt)
else intersect(pt, pattp)
}
def inferModulePattern(pat: Tree, pt: Type) =

View File

@ -1,21 +1,21 @@
patmat-type-check.scala:18: error: scrutinee is incompatible with pattern type;
patmat-type-check.scala:22: error: scrutinee is incompatible with pattern type;
found : Seq[A]
required: java.lang.String
def f1 = "bob".reverse match { case Seq('b', 'o', 'b') => true } // fail
^
patmat-type-check.scala:19: error: scrutinee is incompatible with pattern type;
patmat-type-check.scala:23: error: scrutinee is incompatible with pattern type;
found : Seq[A]
required: Array[Char]
def f2 = "bob".toArray match { case Seq('b', 'o', 'b') => true } // fail
^
patmat-type-check.scala:23: error: scrutinee is incompatible with pattern type;
patmat-type-check.scala:27: error: scrutinee is incompatible with pattern type;
found : Seq[A]
required: Test.Bop2
def f3(x: Bop2) = x match { case Seq('b', 'o', 'b') => true } // fail
^
patmat-type-check.scala:27: error: scrutinee is incompatible with pattern type;
patmat-type-check.scala:30: error: scrutinee is incompatible with pattern type;
found : Seq[A]
required: Test.Bop3[T]
def f4[T](x: Bop3[T]) = x match { case Seq('b', 'o', 'b') => true } // fail
^
required: Test.Bop3[Char]
def f4[T](x: Bop3[Char]) = x match { case Seq('b', 'o', 'b') => true } // fail
^
four errors found

View File

@ -14,6 +14,10 @@ object Test
final class Bop5[T, U, -V]
def s4[T1, T2](x: Bop5[_, T1, T2]) = x match { case Seq('b', 'o', 'b') => true }
// free type parameter, allowed
final class Bop3[T]
def f4[T](x: Bop3[T]) = x match { case Seq('b', 'o', 'b') => true }
// String and Array are final/invariant, disallowed
def f1 = "bob".reverse match { case Seq('b', 'o', 'b') => true } // fail
def f2 = "bob".toArray match { case Seq('b', 'o', 'b') => true } // fail
@ -23,6 +27,5 @@ object Test
def f3(x: Bop2) = x match { case Seq('b', 'o', 'b') => true } // fail
// final, invariant type parameter, should be disallowed
final class Bop3[T]
def f4[T](x: Bop3[T]) = x match { case Seq('b', 'o', 'b') => true } // fail
def f4[T](x: Bop3[Char]) = x match { case Seq('b', 'o', 'b') => true } // fail
}

View File

@ -1,14 +1,4 @@
t3692.scala:11: warning: type Integer in package scala is deprecated: use <code>java.lang.Integer</code> instead
case m0: Map[Int, Int] => new java.util.HashMap[Integer, Integer]
^
t3692.scala:12: warning: type Integer in package scala is deprecated: use <code>java.lang.Integer</code> instead
case m1: Map[Int, V] => new java.util.HashMap[Integer, V]
^
t3692.scala:13: warning: type Integer in package scala is deprecated: use <code>java.lang.Integer</code> instead
case m2: Map[T, Int] => new java.util.HashMap[T, Integer]
^
t3692.scala:13: error: unreachable code
t3692.scala:15: error: unreachable code
case m2: Map[T, Int] => new java.util.HashMap[T, Integer]
^
three warnings found
one error found

View File

@ -1,3 +1,5 @@
import java.lang.Integer
object ManifestTester {
def main(args: Array[String]) = {
val map = Map("John" -> 1, "Josh" -> 2)

View File

@ -0,0 +1,21 @@
1
2
3
4
5
6
7
1
2
3
4
5
6
7
1
2
3
4
5
6
7

View File

@ -0,0 +1,58 @@
// Test cases: the only place we can cut and paste without crying
// ourself to sleep.
object Test {
def f1(a: Any) = a match {
case x: Array[Int] => x(0)
case x: Array[Double] => 2
case x: Array[Float] => x.sum.toInt
case x: Array[String] => x.size
case x: Array[AnyRef] => 5
case x: Array[_] => 6
case _ => 7
}
def f2(a: Array[_]) = a match {
case x: Array[Int] => x(0)
case x: Array[Double] => 2
case x: Array[Float] => x.sum.toInt
case x: Array[String] => x.size
case x: Array[AnyRef] => 5
case x: Array[_] => 6
case _ => 7
}
def f3[T](a: Array[T]) = a match {
case x: Array[Int] => x(0)
case x: Array[Double] => 2
case x: Array[Float] => x.sum.toInt
case x: Array[String] => x.size
case x: Array[AnyRef] => 5
case x: Array[_] => 6
case _ => 7
}
def main(args: Array[String]): Unit = {
println(f1(Array(1, 2, 3)))
println(f1(Array(1.0, -2.0, 3.0, 1.0)))
println(f1(Array(1.0f, 2.0f, 3.0f, -3.0f)))
println(f1((1 to 4).toArray map (_.toString)))
println(f1(new Array[Any](10))) // should match as Array[AnyRef]
println(f1(Array(1L)))
println(f1(null))
println(f2(Array(1, 2, 3)))
println(f2(Array(1.0, -2.0, 3.0, 1.0)))
println(f2(Array(1.0f, 2.0f, 3.0f, -3.0f)))
println(f2((1 to 4).toArray map (_.toString)))
println(f2(new Array[Any](10))) // should match as Array[AnyRef]
println(f2(Array(1L)))
println(f2(null))
println(f3(Array(1, 2, 3)))
println(f3(Array(1.0, -2.0, 3.0, 1.0)))
println(f3(Array(1.0f, 2.0f, 3.0f, -3.0f)))
println(f3((1 to 4).toArray map (_.toString)))
println(f3(new Array[Any](10))) // should match as Array[AnyRef]
println(f3(Array(1L)))
println(f3(null))
}
}