diff --git a/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala b/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala index bd2eb820d..4258c100f 100644 --- a/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala +++ b/shared/src/main/scala-2.13+/scala/xml/ScalaVersionSpecific.scala @@ -1,7 +1,7 @@ package scala.xml import scala.collection.immutable.StrictOptimizedSeqOps -import scala.collection.{SeqOps, IterableOnce, immutable, mutable} +import scala.collection.{View, SeqOps, IterableOnce, immutable, mutable} import scala.collection.BuildFrom import scala.collection.mutable.Builder @@ -20,6 +20,21 @@ private[xml] trait ScalaVersionSpecificNodeSeq override def fromSpecific(coll: IterableOnce[Node]): NodeSeq = (NodeSeq.newBuilder ++= coll).result() override def newSpecificBuilder: mutable.Builder[Node, NodeSeq] = NodeSeq.newBuilder override def empty: NodeSeq = NodeSeq.Empty + def concat(suffix: IterableOnce[Node]): NodeSeq = + fromSpecific(iterator ++ suffix.iterator) + @inline final def ++ (suffix: Seq[Node]): NodeSeq = concat(suffix) + def appended(base: Node): NodeSeq = + fromSpecific(new View.Appended(this, base)) + def appendedAll(suffix: IterableOnce[Node]): NodeSeq = + concat(suffix) + def prepended(base: Node): NodeSeq = + fromSpecific(new View.Prepended(base, this)) + def prependedAll(prefix: IterableOnce[Node]): NodeSeq = + fromSpecific(prefix.iterator ++ iterator) + def map(f: Node => Node): NodeSeq = + fromSpecific(new View.Map(this, f)) + def flatMap(f: Node => IterableOnce[Node]): NodeSeq = + fromSpecific(new View.FlatMap(this, f)) } private[xml] trait ScalaVersionSpecificNodeBuffer { self: NodeBuffer => diff --git a/shared/src/test/scala/scala/xml/NodeSeqTest.scala b/shared/src/test/scala/scala/xml/NodeSeqTest.scala new file mode 100644 index 000000000..08b7c0d79 --- /dev/null +++ b/shared/src/test/scala/scala/xml/NodeSeqTest.scala @@ -0,0 +1,23 @@ +package scala.xml + +import scala.xml.NodeSeq.seqToNodeSeq + +import org.junit.Test +import org.junit.Assert.assertEquals +import org.junit.Assert.fail + +class NodeSeqTest { + + @Test + def testAppend: Unit = { // Bug #392. + val a: NodeSeq = Hello + val b = Hi + a ++ Hi match { + case res: NodeSeq => assertEquals(2, res.size) + case res: Seq[Node] => fail("Should be NodeSeq") // Unreachable code? + } + val res: NodeSeq = a ++ b + val exp = NodeSeq.fromSeq(Seq(Hello, Hi)) + assertEquals(exp, res) + } +}