aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpacien2018-11-25 16:45:35 +0100
committerpacien2018-11-25 16:45:35 +0100
commit680c0a3c94f0bb84a2773bc9a95dc5399b6925fb (patch)
tree8b7efa786f14aa4d17ab22ab11b55eda1981519f
parent643d2d72fab23df30d29c10614bfa89648cd3655 (diff)
downloadgziplike-680c0a3c94f0bb84a2773bc9a95dc5399b6925fb.tar.gz
Fix bitreader look-ahead overflow
-rw-r--r--src/bitreader.nim52
-rw-r--r--src/integers.nim7
-rw-r--r--tests/tbitreader.nim23
-rw-r--r--tests/tintegers.nim7
4 files changed, 53 insertions, 36 deletions
diff --git a/src/bitreader.nim b/src/bitreader.nim
index 757c1b3..7afb13d 100644
--- a/src/bitreader.nim
+++ b/src/bitreader.nim
@@ -17,49 +17,33 @@
17import streams 17import streams
18import integers 18import integers
19 19
20# Stream functions
21
22proc newEIO(msg: string): ref IOError =
23 new(result)
24 result.msg = msg
25
26proc read[T](s: Stream, t: typedesc[T]): T =
27 if readData(s, addr(result), sizeof(T)) != sizeof(T):
28 raise newEIO("cannot read from stream")
29
30proc peek[T](s: Stream, t: typedesc[T]): T =
31 if peekData(s, addr(result), sizeof(T)) != sizeof(T):
32 raise newEIO("cannot read from stream")
33
34# BitReader
35
36type BitReader* = ref object 20type BitReader* = ref object
37 stream: Stream 21 stream: Stream
38 bitOffset: int 22 bitOffset: int
23 overflowBuffer: uint8
39 24
40proc bitReader*(stream: Stream): BitReader = 25proc bitReader*(stream: Stream): BitReader =
41 BitReader(stream: stream, bitOffset: 0) 26 BitReader(stream: stream, bitOffset: 0, overflowBuffer: 0)
42 27
43proc atEnd*(bitReader: BitReader): bool = 28proc atEnd*(bitReader: BitReader): bool =
44 bitReader.stream.atEnd() 29 bitReader.bitOffset == 0 and bitReader.stream.atEnd()
45 30
46proc readBits*[T: SomeUnsignedInt](bitReader: BitReader, bits: int, to: typedesc[T]): T = 31proc readBits*[T: SomeUnsignedInt](bitReader: BitReader, bits: int, to: typedesc[T]): T =
47 let targetBitLength = sizeof(T) * wordBitLength 32 if bits < 0 or bits > sizeof(T) * wordBitLength: raise newException(RangeError, "invalid bit length")
48 if bits < 0 or bits > targetBitLength: 33 if bits == 0: return 0
49 raise newException(RangeError, "invalid bit length") 34 var bitsRead = 0
50 elif bits == 0: 35 if bitReader.bitOffset > 0:
51 result = 0 36 let bitsFromBuffer = min(bits, wordBitLength - bitReader.bitOffset)
52 elif bits < targetBitLength - bitReader.bitOffset: 37 result = (bitReader.overflowBuffer shr bitReader.bitOffset).leastSignificantBits(bitsFromBuffer)
53 result = bitReader.stream.peek(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) 38 bitReader.bitOffset = (bitReader.bitOffset + bitsFromBuffer) mod wordBitLength
54 elif bits == targetBitLength - bitReader.bitOffset: 39 bitsRead += bitsFromBuffer
55 result = bitReader.stream.read(T) shl (targetBitLength - bits - bitReader.bitOffset) shr (targetBitLength - bits) 40 while bits - bitsRead >= wordBitLength:
56 else: 41 result = result or (bitReader.stream.readUint8().T shl bitsRead)
57 let rightBits = targetBitLength - bitReader.bitOffset 42 bitsRead += wordBitLength
58 let leftBits = bits - rightBits 43 if bits - bitsRead > 0:
59 let right = bitReader.stream.read(T) shr bitReader.bitOffset 44 bitReader.overflowBuffer = bitReader.stream.readUint8()
60 let left = bitReader.stream.peek(T) shl (targetBitLength - leftBits) shr (targetBitLength - bits) 45 bitReader.bitOffset = bits - bitsRead
61 result = left or right 46 result = result or (bitReader.overflowBuffer.leastSignificantBits(bitReader.bitOffset).T shl bitsRead)
62 bitReader.bitOffset = (bitReader.bitOffset + bits) mod wordBitLength
63 47
64proc readBool*(bitReader: BitReader): bool = 48proc readBool*(bitReader: BitReader): bool =
65 bitReader.readBits(1, uint8) != 0 49 bitReader.readBits(1, uint8) != 0
diff --git a/src/integers.nim b/src/integers.nim
index fddbfdc..7b0f166 100644
--- a/src/integers.nim
+++ b/src/integers.nim
@@ -15,13 +15,16 @@
15# along with this program. If not, see <https://www.gnu.org/licenses/>. 15# along with this program. If not, see <https://www.gnu.org/licenses/>.
16 16
17const wordBitLength* = 8 17const wordBitLength* = 8
18const wordBitMask* = 0b1111_1111'u8
19 18
20proc `/^`*[T: Natural](x, y: T): T = 19proc `/^`*[T: Natural](x, y: T): T =
21 (x + y - 1) div y 20 (x + y - 1) div y
22 21
23proc truncateToUint8*(x: SomeUnsignedInt): uint8 = 22proc truncateToUint8*(x: SomeUnsignedInt): uint8 =
24 (x and wordBitMask).uint8 23 (x and uint8.high).uint8
24
25proc leastSignificantBits*[T: SomeUnsignedInt](x: T, bits: int): T =
26 let maskOffset = sizeof(T) * wordBitLength - bits
27 if maskOffset >= 0: (x shl maskOffset) shr maskOffset else: x
25 28
26iterator chunks*(totalBitLength: int, chunkType: typedesc[SomeInteger]): tuple[index: int, chunkBitLength: int] = 29iterator chunks*(totalBitLength: int, chunkType: typedesc[SomeInteger]): tuple[index: int, chunkBitLength: int] =
27 let chunkBitLength = sizeof(chunkType) * wordBitLength 30 let chunkBitLength = sizeof(chunkType) * wordBitLength
diff --git a/tests/tbitreader.nim b/tests/tbitreader.nim
index 8285f63..294f6c9 100644
--- a/tests/tbitreader.nim
+++ b/tests/tbitreader.nim
@@ -49,6 +49,29 @@ suite "bitreader":
49 expect IOError: discard bitReader.readBits(16, uint16) 49 expect IOError: discard bitReader.readBits(16, uint16)
50 check bitReader.atEnd() 50 check bitReader.atEnd()
51 51
52 test "readBits (look-ahead overflow)":
53 let stream = newStringStream()
54 defer: stream.close()
55 stream.write(0xAB'u8)
56 stream.setPosition(0)
57
58 let bitReader = stream.bitReader()
59 check bitReader.readBits(4, uint16) == 0x000B'u16
60 check bitReader.readBits(4, uint16) == 0x000A'u16
61 check bitReader.atEnd()
62
63 test "readBits (from buffer composition)":
64 let stream = newStringStream()
65 defer: stream.close()
66 stream.write(0xABCD'u16)
67 stream.setPosition(0)
68
69 let bitReader = stream.bitReader()
70 check bitReader.readBits(4, uint16) == 0x000D'u16
71 check bitReader.readBits(8, uint16) == 0x00BC'u16
72 check bitReader.readBits(4, uint16) == 0x000A'u16
73 check bitReader.atEnd()
74
52 test "readSeq": 75 test "readSeq":
53 let stream = newStringStream() 76 let stream = newStringStream()
54 defer: stream.close() 77 defer: stream.close()
diff --git a/tests/tintegers.nim b/tests/tintegers.nim
index c77abec..956e4aa 100644
--- a/tests/tintegers.nim
+++ b/tests/tintegers.nim
@@ -27,6 +27,13 @@ suite "integers":
27 check truncateToUint8(0x00FA'u16) == 0xFA'u8 27 check truncateToUint8(0x00FA'u16) == 0xFA'u8
28 check truncateToUint8(0xFFFA'u16) == 0xFA'u8 28 check truncateToUint8(0xFFFA'u16) == 0xFA'u8
29 29
30 test "leastSignificantBits":
31 check leastSignificantBits(0xFF'u8, 3) == 0b0000_0111'u8
32 check leastSignificantBits(0b0001_0101'u8, 3) == 0b0000_0101'u8
33 check leastSignificantBits(0xFF'u8, 10) == 0xFF'u8
34 check leastSignificantBits(0xFFFF'u16, 16) == 0xFFFF'u16
35 check leastSignificantBits(0xFFFF'u16, 8) == 0x00FF'u16
36
30 test "chunks iterator": 37 test "chunks iterator":
31 check toSeq(chunks(70, uint32)) == @[(0, 32), (1, 32), (2, 6)] 38 check toSeq(chunks(70, uint32)) == @[(0, 32), (1, 32), (2, 6)]
32 check toSeq(chunks(32, uint16)) == @[(0, 16), (1, 16)] 39 check toSeq(chunks(32, uint16)) == @[(0, 16), (1, 16)]