Skip to content

Commit 6f221b8

Browse files
committed
BufferedReplayInputStream shallow copy fix
1 parent 88d6a14 commit 6f221b8

File tree

4 files changed

+167
-76
lines changed

4 files changed

+167
-76
lines changed

decoder/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ repositories {
2626

2727
dependencies {
2828
testImplementation("org.junit.jupiter:junit-jupiter:5.9.0")
29+
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.4")
2930
}
3031

3132
sourceSets {

decoder/src/main/kotlin/app/redwarp/gif/decoder/streams/BufferedReplayInputStream.kt

Lines changed: 119 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -22,115 +22,158 @@ import java.io.InputStream
2222
* then we will stop reading the original stream, and only use the in memory data.
2323
* Load and keep the whole [InputStream] in memory, should be avoided for huge GIFs.
2424
*/
25-
internal class BufferedReplayInputStream(inputStream: InputStream) : ReplayInputStream() {
26-
private val inputStream = inputStream.buffered()
27-
private var outputStream: ByteArrayOutputStream? = ByteArrayOutputStream()
28-
private var position = 0
29-
private var totalCount = 0
30-
private var replay = false
31-
32-
private var _loadedData: ByteArray? = null
33-
private val loadedData: ByteArray
34-
get() {
35-
return _loadedData ?: requireNotNull(outputStream).toByteArray()
36-
.also(this::_loadedData::set)
37-
}
25+
internal class BufferedReplayInputStream private constructor(
26+
inputStream: InputStream?,
27+
private var loadedData: ByteArray?,
28+
private var state: State
29+
) : ReplayInputStream() {
30+
constructor(inputStream: InputStream) : this(inputStream, null, State())
31+
private constructor(loadedData: ByteArray, state: State) : this(null, loadedData, state)
32+
33+
private var reader: Reader? = inputStream?.let { Reader(it) }
3834

3935
override fun seek(position: Int) {
40-
replay = true
41-
if (_loadedData == null) {
42-
_loadedData = requireNotNull(outputStream).toByteArray()
43-
outputStream?.close()
44-
outputStream = null
45-
inputStream.close()
46-
}
47-
this.position = position
36+
setReplayIfNeeded()
37+
this.state.position = position
4838
}
4939

5040
override fun getPosition(): Int {
51-
return if (!replay) requireNotNull(outputStream).size()
52-
else {
53-
position
54-
}
41+
return reader?.size() ?: state.position
5542
}
5643

5744
override fun read(): Int {
58-
return if (!replay) {
59-
val read = inputStream.read()
60-
outputStream?.write(read)
61-
totalCount++
62-
read
63-
} else {
64-
(loadedData[position].toInt() and 0xFF).also {
65-
position += 1
45+
return reader?.read()
46+
?: (requireNotNull(loadedData)[state.position].toInt() and 0xFF).also {
47+
state.position += 1
6648
}
67-
}
6849
}
6950

7051
override fun read(byteArray: ByteArray, offset: Int, length: Int): Int {
71-
if (!replay) {
72-
val readCount = inputStream.read(byteArray, offset, length)
52+
reader.let { reader ->
53+
if (reader != null) {
54+
return reader.read(byteArray, offset, length)
55+
} else {
56+
val readCount =
57+
if (length > readableBytes()) readableBytes() else length
58+
59+
requireNotNull(loadedData).copyInto(
60+
destination = byteArray,
61+
destinationOffset = 0,
62+
startIndex = state.position + offset,
63+
endIndex = state.position + offset + readCount
64+
)
65+
state.position += readCount
66+
67+
return readCount
68+
}
69+
}
70+
}
7371

74-
if (readCount > 0) {
75-
outputStream?.write(byteArray, offset, readCount)
72+
override fun read(byteArray: ByteArray): Int {
73+
reader.let { reader ->
74+
if (reader != null) {
75+
return reader.read(byteArray)
76+
} else {
77+
val readCount =
78+
if (byteArray.size > readableBytes()) readableBytes() else byteArray.size
79+
80+
requireNotNull(loadedData).copyInto(
81+
destination = byteArray,
82+
destinationOffset = 0,
83+
startIndex = state.position,
84+
endIndex = state.position + readCount
85+
)
86+
87+
state.position += readCount
88+
89+
return readCount
7690
}
77-
totalCount += readCount
91+
}
92+
}
7893

79-
return readCount
80-
} else {
81-
val readCount =
82-
if (length > readableBytes()) readableBytes() else length
83-
84-
loadedData.copyInto(
85-
destination = byteArray,
86-
destinationOffset = 0,
87-
startIndex = position + offset,
88-
endIndex = position + offset + readCount
89-
)
90-
position += readCount
94+
override fun close() {
95+
reader?.close()
96+
}
9197

92-
return readCount
98+
override fun shallowClone(): ReplayInputStream {
99+
setReplayIfNeeded()
100+
return BufferedReplayInputStream(requireNotNull(loadedData), state.copy())
101+
}
102+
103+
private fun readableBytes(): Int {
104+
return loadedData?.let {
105+
it.size - state.position
106+
} ?: 0
107+
}
108+
109+
@Synchronized
110+
private fun setReplayIfNeeded() {
111+
if (loadedData != null) return
112+
113+
reader?.let { reader ->
114+
reader.readAll()
115+
loadedData = reader.toByteArray().also {
116+
reader.close()
117+
this.reader = null
118+
}
93119
}
94120
}
95121

96-
override fun read(byteArray: ByteArray): Int {
97-
if (!replay) {
122+
private data class State(var position: Int = 0)
123+
124+
private class Reader(inputStream: InputStream) : AutoCloseable {
125+
private val inputStream = inputStream.buffered()
126+
private var outputStream: ByteArrayOutputStream = ByteArrayOutputStream()
127+
private var totalCount = 0
128+
129+
fun size(): Int {
130+
return outputStream.size()
131+
}
132+
133+
override fun close() {
134+
inputStream.close()
135+
outputStream.close()
136+
}
137+
138+
fun read(): Int {
139+
return inputStream.read().also {
140+
outputStream.write(it)
141+
totalCount++
142+
}
143+
}
144+
145+
fun read(byteArray: ByteArray): Int {
98146
val readCount = inputStream.read(byteArray)
99147

100148
if (readCount > 0) {
101-
outputStream?.write(byteArray, 0, readCount)
149+
outputStream.write(byteArray, 0, readCount)
102150
}
103151
totalCount += readCount
104152

105153
return readCount
106-
} else {
107-
val readCount =
108-
if (byteArray.size > readableBytes()) readableBytes() else byteArray.size
154+
}
109155

110-
loadedData.copyInto(
111-
destination = byteArray,
112-
destinationOffset = 0,
113-
startIndex = position,
114-
endIndex = position + readCount
115-
)
156+
fun read(byteArray: ByteArray, offset: Int, length: Int): Int {
157+
val readCount = inputStream.read(byteArray, offset, length)
116158

117-
position += readCount
159+
if (readCount > 0) {
160+
outputStream.write(byteArray, offset, readCount)
161+
}
162+
totalCount += readCount
118163

119164
return readCount
120165
}
121-
}
122166

123-
override fun close() {
124-
inputStream.close()
125-
}
167+
fun readAll(): Int {
168+
val bytes = inputStream.readAllBytes()
126169

127-
override fun shallowClone(): ReplayInputStream {
128-
return this
129-
}
170+
outputStream.write(bytes)
130171

131-
private fun readableBytes(): Int {
132-
return _loadedData?.let {
133-
it.size - position
134-
} ?: 0
172+
return bytes.size
173+
}
174+
175+
fun toByteArray(): ByteArray {
176+
return outputStream.toByteArray()
177+
}
135178
}
136179
}

decoder/src/test/kotlin/app/redwarp/gif/decoder/GifTest.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package app.redwarp.gif.decoder
1616

1717
import app.redwarp.gif.decoder.descriptors.Dimension
18+
import kotlinx.coroutines.launch
19+
import kotlinx.coroutines.runBlocking
1820
import org.junit.jupiter.api.Assertions.assertArrayEquals
1921
import org.junit.jupiter.api.Assertions.assertEquals
2022
import org.junit.jupiter.api.Assertions.assertTrue
@@ -269,6 +271,24 @@ class GifTest {
269271
assertTrue(result.isFailure)
270272
}
271273

274+
@Test
275+
fun gif_shallowCloned_noIssuesWithConcurrency() = runBlocking {
276+
val gifDescriptor = Parser.parse(File("../assets/domo-no-dispose.gif")).getOrThrow()
277+
val originalGif = Gif(gifDescriptor)
278+
279+
repeat(1000) { id ->
280+
launch {
281+
val gif = Gif.from(originalGif)
282+
val index = id % gif.frameCount
283+
val pixels = IntArray(gif.dimension.size)
284+
gif.getFrame(index, pixels)
285+
val expectedPixels = loadExpectedPixels(File("../assets/frames/domo_$index.png"))
286+
287+
assertArrayEquals(expectedPixels, pixels)
288+
}
289+
}
290+
}
291+
272292
private fun loadExpectedPixels(file: File): IntArray {
273293
val input = ImageIO.read(file)
274294
val image = BufferedImage(input.width, input.height, BufferedImage.TYPE_INT_ARGB)

decoder/src/test/kotlin/app/redwarp/gif/decoder/streams/BufferedReplayInputStreamTest.kt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
*/
1515
package app.redwarp.gif.decoder.streams
1616

17+
import kotlinx.coroutines.Dispatchers
18+
import kotlinx.coroutines.launch
19+
import kotlinx.coroutines.runBlocking
20+
import kotlinx.coroutines.withContext
1721
import org.junit.jupiter.api.Assertions.assertArrayEquals
1822
import org.junit.jupiter.api.Assertions.assertEquals
1923
import org.junit.jupiter.api.Test
@@ -89,4 +93,27 @@ class BufferedReplayInputStreamTest {
8993

9094
assertEquals(3, bufferedReplayInputStream.getPosition())
9195
}
96+
97+
@Test
98+
fun shallowCopy_noIssuesWithConcurrency() =
99+
runBlocking {
100+
val originalData = byteArrayOf(0x01, 0x02, 0x03, 0x04)
101+
102+
val bufferedReplayInputStream = BufferedReplayInputStream(originalData.inputStream())
103+
bufferedReplayInputStream.read(ByteArray(128))
104+
105+
repeat(10_000) { id ->
106+
launch {
107+
val index = id % 4
108+
val copied = bufferedReplayInputStream.shallowClone()
109+
copied.seek(index)
110+
assertEquals(
111+
originalData[index],
112+
withContext(Dispatchers.IO) {
113+
copied.read().toByte()
114+
}
115+
)
116+
}
117+
}
118+
}
92119
}

0 commit comments

Comments
 (0)