Integrated contributed non-recursive implementation of permutations,

combinations, subsets, by EastSun.  Also gave mutable Seqs an
in-place transform method like the one Map has.  And couldn't
resist slightly reformulating a few set methods, because how can we
settle for "forall(that.contains)" when we could have "this forall that".
(Which is also what normal people hear when we talk about sets.)

Closes #4060, #3644, review by moors.

git-svn-id: http://lampsvn.epfl.ch/svn-repos/scala/scala/trunk@24042 5e8d7ff9-d8ef-0310-90f0-a4852d11357a
This commit is contained in:
extempore 2011-01-20 20:49:03 +00:00
parent 1e2ef90b5f
commit 8cf68bbadf
6 changed files with 225 additions and 27 deletions

View File

@ -372,18 +372,127 @@ trait SeqLike[+A, +Repr] extends IterableLike[A, Repr] { self =>
* @return An Iterator which traverses the distinct permutations of this $coll. * @return An Iterator which traverses the distinct permutations of this $coll.
* @example `"abb".permutations = Iterator(abb, bab, bba)` * @example `"abb".permutations = Iterator(abb, bab, bba)`
*/ */
def permutations: Iterator[Repr] = { def permutations: Iterator[Repr] =
val seen = mutable.HashSet[A]() if (isEmpty) Iterator(repr)
val xs = thisCollection.toIndexedSeq else new PermutationsItr
/** Iterates over combinations.
*
* @return An Iterator which traverses the possible n-element combinations of this $coll.
* @example `"abbbc".combinations(2) = Iterator(ab, ac, bb, bc)`
*/
def combinations(n: Int): Iterator[Repr] =
if (n < 0 || n > size) Iterator.empty
else new CombinationsItr(n)
private class PermutationsItr extends Iterator[Repr] {
private[this] val (elms, idxs) = init()
private var _hasNext = true
if (xs.isEmpty) Iterator.empty def hasNext = _hasNext
else if (xs.tail.isEmpty) Iterator(repr) def next: Repr = {
else xs.indices collect { if (!hasNext)
case idx if !seen(xs(idx)) => Iterator.empty.next
seen += xs(idx)
val rest = (xs take idx) ++ (xs drop (idx + 1)) val result = (self.newBuilder ++= elms).result
rest.permutations map (newBuilder += xs(idx) ++= _ result) var i = idxs.length - 2
} reduceLeft (_ ++ _) while(i >= 0 && idxs(i) >= idxs(i+1))
i -= 1
if (i < 0)
_hasNext = false
else {
var j = idxs.length - 1
while(idxs(j) <= idxs(i)) j -= 1
swap(i,j)
val len = (idxs.length - i) / 2
var k = 1
while (k <= len) {
swap(i+k, idxs.length - k)
k += 1
}
}
result
}
private def swap(i: Int, j: Int) {
var tmpI = idxs(i)
idxs(i) = idxs(j)
idxs(j) = tmpI
var tmpE = elms(i)
elms(i) = elms(j)
elms(j) = tmpE
}
private[this] def init() = {
val m = mutable.HashMap[A, Int]()
val (es, is) = thisCollection map (e => (e, m.getOrElseUpdate(e, m.size))) sortBy (_._2) unzip
(es.toBuffer, is.toArray)
}
}
private class CombinationsItr(n: Int) extends Iterator[Repr] {
// generating all nums such that:
// (1) nums(0) + .. + nums(length-1) = n
// (2) 0 <= nums(i) <= cnts(i), where 0 <= i <= cnts.length-1
private val (elms, cnts, nums) = init()
private val offs = cnts.scanLeft(0)(_ + _)
private var _hasNext = true
def hasNext = _hasNext
def next: Repr = {
if (!hasNext)
Iterator.empty.next
/** Calculate this result. */
val buf = self.newBuilder
for(k <- 0 until nums.length; j <- 0 until nums(k))
buf += elms(offs(k)+j)
val res = buf.result
/** Prepare for the next call to next. */
var idx = nums.length - 1
while (idx >= 0 && nums(idx) == cnts(idx))
idx -= 1
idx = nums.lastIndexWhere(_ > 0, idx - 1)
if (idx < 0)
_hasNext = false
else {
var sum = nums.slice(idx + 1, nums.length).sum + 1
nums(idx) -= 1
for (k <- (idx+1) until nums.length) {
nums(k) = sum min cnts(k)
sum -= nums(k)
}
}
res
}
/** Rearrange seq to newSeq a0a0..a0a1..a1...ak..ak such that
* seq.count(_ == aj) == cnts(j)
*
* @return (newSeq,cnts,nums)
*/
private def init(): (IndexedSeq[A], Array[Int], Array[Int]) = {
val m = mutable.HashMap[A, Int]()
// e => (e, weight(e))
val (es, is) = thisCollection map (e => (e, m.getOrElseUpdate(e, m.size))) sortBy (_._2) unzip
val cs = new Array[Int](m.size)
is foreach (i => cs(i) += 1)
val ns = new Array[Int](cs.length)
var r = n
0 until ns.length foreach { k =>
ns(k) = r min cs(k)
r -= ns(k)
}
(es.toIndexedSeq, cs, ns)
}
} }
/** Returns new $coll wih elements in reversed order. /** Returns new $coll wih elements in reversed order.

View File

@ -141,7 +141,7 @@ self =>
* @param elem the element to test for membership. * @param elem the element to test for membership.
* @return `true` if `elem` is contained in this set, `false` otherwise. * @return `true` if `elem` is contained in this set, `false` otherwise.
*/ */
def apply(elem: A): Boolean = contains(elem) def apply(elem: A): Boolean = this contains elem
/** Computes the intersection between this set and another set. /** Computes the intersection between this set and another set.
* *
@ -149,7 +149,7 @@ self =>
* @return a new set consisting of all elements that are both in this * @return a new set consisting of all elements that are both in this
* set and in the given set `that`. * set and in the given set `that`.
*/ */
def intersect(that: Set[A]): This = filter(that.contains) def intersect(that: Set[A]): This = this filter that
/** Computes the intersection between this set and another set. /** Computes the intersection between this set and another set.
* *
@ -158,7 +158,7 @@ self =>
* @return a new set consisting of all elements that are both in this * @return a new set consisting of all elements that are both in this
* set and in the given set `that`. * set and in the given set `that`.
*/ */
def &(that: Set[A]): This = intersect(that) def &(that: Set[A]): This = this intersect that
/** This method is an alias for `intersect`. /** This method is an alias for `intersect`.
* It computes an intersection with set `that`. * It computes an intersection with set `that`.
@ -166,7 +166,8 @@ self =>
* *
* @param that the set to intersect with * @param that the set to intersect with
*/ */
@deprecated("use & instead") def ** (that: Set[A]): This = intersect(that) @deprecated("use & instead")
def ** (that: Set[A]): This = &(that)
/** Computes the union between of set and another set. /** Computes the union between of set and another set.
* *
@ -174,7 +175,7 @@ self =>
* @return a new set consisting of all elements that are in this * @return a new set consisting of all elements that are in this
* set or in the given set `that`. * set or in the given set `that`.
*/ */
def union(that: Set[A]): This = this.++(that) def union(that: Set[A]): This = this ++ that
/** Computes the union between this set and another set. /** Computes the union between this set and another set.
* *
@ -183,7 +184,7 @@ self =>
* @return a new set consisting of all elements that are in this * @return a new set consisting of all elements that are in this
* set or in the given set `that`. * set or in the given set `that`.
*/ */
def | (that: Set[A]): This = union(that) def | (that: Set[A]): This = this union that
/** Computes the difference of this set and another set. /** Computes the difference of this set and another set.
* *
@ -191,7 +192,7 @@ self =>
* @return a set containing those elements of this * @return a set containing those elements of this
* set that are not also contained in the given set `that`. * set that are not also contained in the given set `that`.
*/ */
def diff(that: Set[A]): This = --(that) def diff(that: Set[A]): This = this -- that
/** The difference of this set and another set. /** The difference of this set and another set.
* *
@ -200,7 +201,7 @@ self =>
* @return a set containing those elements of this * @return a set containing those elements of this
* set that are not also contained in the given set `that`. * set that are not also contained in the given set `that`.
*/ */
def &~(that: Set[A]): This = diff(that) def &~(that: Set[A]): This = this diff that
/** Tests whether this set is a subset of another set. /** Tests whether this set is a subset of another set.
* *
@ -208,7 +209,74 @@ self =>
* @return `true` if this set is a subset of `that`, i.e. if * @return `true` if this set is a subset of `that`, i.e. if
* every element of this set is also an element of `that`. * every element of this set is also an element of `that`.
*/ */
def subsetOf(that: Set[A]): Boolean = forall(that.contains) def subsetOf(that: Set[A]) = this forall that
/** An iterator over all subsets of this set of the given size.
*
* @param len the size of the subsets.
* @return the iterator.
*/
def subsets(len: Int): Iterator[This] = {
if (len < 0 || len > size) throw new IllegalArgumentException(len.toString)
else new SubsetsItr(self.toIndexedSeq, len)
}
/** An iterator over all subsets of this set.
*
* @return the iterator.
*/
def subsets: Iterator[This] = new Iterator[This] {
private val elms = self.toIndexedSeq
private var len = 0
private var itr: Iterator[This] = Iterator.empty
def hasNext = len <= elms.size || itr.hasNext
def next = {
if (!itr.hasNext) {
if (len > elms.size) Iterator.empty.next
else {
itr = new SubsetsItr(elms, len)
len += 1
}
}
itr.next
}
}
/** An Iterator include all subsets containing exactly len elements.
* If the elements in 'This' type is ordered, then the subsets will also be in the same order.
* ListSet(1,2,3).subsets => {1},{2},{3},{1,2},{1,3},{2,3},{1,2,3}}
*
* @author Eastsun
* @date 2010.12.6
*/
private class SubsetsItr(elms: IndexedSeq[A], len: Int) extends Iterator[This] {
private val idxs = Array.range(0, len+1)
private var _hasNext = true
idxs(len) = elms.size
def hasNext = _hasNext
def next: This = {
if (!hasNext) Iterator.empty.next
val buf = self.newBuilder
idxs.slice(0, len) foreach (idx => buf += elms(idx))
val result = buf.result
var i = len - 1
while (i >= 0 && idxs(i) == idxs(i+1)-1) i -= 1
if (i < 0) _hasNext = false
else {
idxs(i) += 1
for (j <- (i+1) until len)
idxs(j) = idxs(j-1) + 1
}
result
}
}
/** Defines the prefix of this object's `toString` representation. /** Defines the prefix of this object's `toString` representation.
* @return a string representation which starts the result of `toString` applied to this set. * @return a string representation which starts the result of `toString` applied to this set.

View File

@ -79,8 +79,6 @@ object BitSet extends BitSetFactory[BitSet] {
else new BitSetN(elems) else new BitSetN(elems)
} }
private val hashSeed = "BitSet".hashCode
class BitSet1(val elems: Long) extends BitSet { class BitSet1(val elems: Long) extends BitSet {
protected def nwords = 1 protected def nwords = 1
protected def word(idx: Int) = if (idx == 0) elems else 0L protected def word(idx: Int) = if (idx == 0) elems else 0L

View File

@ -55,14 +55,14 @@ object IntMap {
// develops. Case objects and custom equality don't mix without // develops. Case objects and custom equality don't mix without
// careful handling. // careful handling.
override def equals(that : Any) = that match { override def equals(that : Any) = that match {
case (that : AnyRef) if (this eq that) => true; case _: this.type => true
case (that : IntMap[_]) => false; // The only empty IntMaps are eq Nil case _: IntMap[_] => false // The only empty IntMaps are eq Nil
case that => super.equals(that); case _ => super.equals(that)
} }
}; }
private[immutable] case class Tip[+T](key : Int, value : T) extends IntMap[T]{ private[immutable] case class Tip[+T](key : Int, value : T) extends IntMap[T]{
def withValue[S](s : S) = def withValue[S](s: S) =
if (s.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this.asInstanceOf[IntMap.Tip[S]]; if (s.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this.asInstanceOf[IntMap.Tip[S]];
else IntMap.Tip(key, s); else IntMap.Tip(key, s);
} }

View File

@ -28,4 +28,19 @@ trait SeqLike[A, +This <: SeqLike[A, This] with Seq[A]]
* @throws IndexOutofBoundsException if the index is not valid. * @throws IndexOutofBoundsException if the index is not valid.
*/ */
def update(idx: Int, elem: A) def update(idx: Int, elem: A)
/** Applies a transformation function to all values contained in this sequence.
* The transformation function produces new values from existing elements.
*
* @param f the transformation to apply
* @return the sequence itself.
*/
def transform(f: A => A): this.type = {
var i = 0
iterator foreach { el =>
update(i, f(el))
i += 1
}
this
}
} }

View File

@ -0,0 +1,8 @@
object Test {
val x = 1 to 10 toBuffer
def main(args: Array[String]): Unit = {
x transform (_ * 2)
assert(x.sum == (1 to 10).sum * 2)
}
}