1
2
3
4
5 package flate
6
7 import (
8 "math"
9 "sort"
10 )
11
// huffmanEncoder holds a Huffman code as two parallel tables, both indexed
// by symbol value.
type huffmanEncoder struct {
	// codeBits[s] is the length in bits of the code assigned to symbol s.
	// generate stores 0 here for symbols whose frequency is zero.
	codeBits []uint8

	// code[s] is the bit pattern assigned to symbol s, in the form
	// returned by reverseBits where the table builders fill it in.
	code []uint16
}
16
// literalNode pairs a symbol with the number of times it occurs.
type literalNode struct {
	literal uint16 // symbol value; used as the index into the encoder tables
	freq    int32  // occurrence count of the symbol
}
21
// chain is one node of the singly linked lists that the chain-generation
// loop builds for each level.
type chain struct {
	// freq is the frequency this node was created with: either the
	// frequency of a single leaf or the combined frequency of a pair.
	freq int32

	// leafCount is the number of leaves consumed on this node's level up
	// to and including this node. Adding a leaf increments it; adding a
	// pair carries it over unchanged.
	leafCount int32

	// up is the next node of the list. A leaf node inherits it from the
	// node it replaces; a pair node points at the last chain of the level
	// below. The list ends at a node whose up is nil.
	up *chain
}
32
33 type levelInfo struct {
34
35 level int32
36
37
38 lastChain *chain
39
40
41 nextCharFreq int32
42
43
44
45 nextPairFreq int32
46
47
48
49 needed int32
50
51
52 up *levelInfo
53
54
55 down *levelInfo
56 }
57
58 func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} }
59
60 func newHuffmanEncoder(size int) *huffmanEncoder {
61 return &huffmanEncoder{make([]uint8, size), make([]uint16, size)}
62 }
63
64
65 func generateFixedLiteralEncoding() *huffmanEncoder {
66 h := newHuffmanEncoder(maxLit)
67 codeBits := h.codeBits
68 code := h.code
69 var ch uint16
70 for ch = 0; ch < maxLit; ch++ {
71 var bits uint16
72 var size uint8
73 switch {
74 case ch < 144:
75
76 bits = ch + 48
77 size = 8
78 break
79 case ch < 256:
80
81 bits = ch + 400 - 144
82 size = 9
83 break
84 case ch < 280:
85
86 bits = ch - 256
87 size = 7
88 break
89 default:
90
91 bits = ch + 192 - 280
92 size = 8
93 }
94 codeBits[ch] = size
95 code[ch] = reverseBits(bits, size)
96 }
97 return h
98 }
99
100 func generateFixedOffsetEncoding() *huffmanEncoder {
101 h := newHuffmanEncoder(30)
102 codeBits := h.codeBits
103 code := h.code
104 for ch := uint16(0); ch < 30; ch++ {
105 codeBits[ch] = 5
106 code[ch] = reverseBits(ch, 5)
107 }
108 return h
109 }
110
111 var fixedLiteralEncoding *huffmanEncoder = generateFixedLiteralEncoding()
112 var fixedOffsetEncoding *huffmanEncoder = generateFixedOffsetEncoding()
113
114 func (h *huffmanEncoder) bitLength(freq []int32) int64 {
115 var total int64
116 for i, f := range freq {
117 if f != 0 {
118 total += int64(f) * int64(h.codeBits[i])
119 }
120 }
121 return total
122 }
123
124
// generateChains runs the iterative chain-generation loop, starting at top
// and moving between levels through their up/down links. It returns once
// top has generated every chain recorded in its needed count.
//
// top and the levels below it must already be initialised: each visited
// level's lastChain is dereferenced. list holds the leaves and is resliced
// one element past its length to store the maxNode sentinel, so the caller
// must provide that spare capacity.
// NOTE(review): leaves are consumed in index order as "the next smallest",
// which presumably requires list to be sorted by ascending freq — confirm
// against callers.
func (h *huffmanEncoder) generateChains(top *levelInfo, list []literalNode) {
	n := len(list)
	list = list[0 : n+1]
	list[n] = maxNode()

	l := top
	for {
		if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
			// This level has run out of both leaves and pairs, so stop
			// all work on it. Marking the level above as having no more
			// pairs ensures this level is never asked for another one.
			l.lastChain = nil
			l.needed = 0
			l = l.up
			l.nextPairFreq = math.MaxInt32
			continue
		}

		// Remember the frequency of the chain about to be replaced; it is
		// half of the pair reported upward when this level finishes.
		prevFreq := l.lastChain.freq
		if l.nextCharFreq < l.nextPairFreq {
			// The next item on this level is a leaf: take it and advance
			// nextCharFreq to the following leaf (or the sentinel).
			n := l.lastChain.leafCount + 1
			l.lastChain = &chain{l.nextCharFreq, n, l.lastChain.up}
			l.nextCharFreq = list[n].freq
		} else {
			// The next item on this level is a pair from the level below.
			// Having consumed that pair, the lower level must generate
			// two more chains before it can offer another.
			l.lastChain = &chain{l.nextPairFreq, l.lastChain.leafCount, l.down.lastChain}
			l.down.needed = 2
		}

		if l.needed--; l.needed == 0 {
			// This level has generated everything requested of it. Its
			// last two chains form the next pair available to the level
			// above; if there is no level above, the work is complete.
			up := l.up
			if up == nil {
				// All done!
				return
			}
			up.nextPairFreq = prevFreq + l.lastChain.freq
			l = up
		} else {
			// More chains are needed here. If lower levels owe chains,
			// descend to the lowest one that does and continue there.
			for l.down.needed > 0 {
				l = l.down
			}
		}
	}
}
178
179
180
181
182
183
184
185
186
187
188
189
190
191 func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
192 n := int32(len(list))
193 list = list[0 : n+1]
194 list[n] = maxNode()
195
196
197
198 maxBits = minInt32(maxBits, n-1)
199
200
201
202
203
204 top := &levelInfo{needed: 0}
205 chain2 := &chain{list[1].freq, 2, new(chain)}
206 for level := int32(1); level <= maxBits; level++ {
207
208
209 top = &levelInfo{
210 level: level,
211 lastChain: chain2,
212 nextCharFreq: list[2].freq,
213 nextPairFreq: list[0].freq + list[1].freq,
214 down: top,
215 }
216 top.down.up = top
217 if level == 1 {
218 top.nextPairFreq = math.MaxInt32
219 }
220 }
221
222
223 top.needed = 2*n - 4
224
225 l := top
226 for {
227 if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
228
229
230
231
232 l.lastChain = nil
233 l.needed = 0
234 l = l.up
235 l.nextPairFreq = math.MaxInt32
236 continue
237 }
238
239 prevFreq := l.lastChain.freq
240 if l.nextCharFreq < l.nextPairFreq {
241
242 n := l.lastChain.leafCount + 1
243 l.lastChain = &chain{l.nextCharFreq, n, l.lastChain.up}
244 l.nextCharFreq = list[n].freq
245 } else {
246
247
248
249 l.lastChain = &chain{l.nextPairFreq, l.lastChain.leafCount, l.down.lastChain}
250 l.down.needed = 2
251 }
252
253 if l.needed--; l.needed == 0 {
254
255
256
257
258 up := l.up
259 if up == nil {
260
261 break
262 }
263 up.nextPairFreq = prevFreq + l.lastChain.freq
264 l = up
265 } else {
266
267 for l.down.needed > 0 {
268 l = l.down
269 }
270 }
271 }
272
273
274
275 if top.lastChain.leafCount != n {
276 panic("top.lastChain.leafCount != n")
277 }
278
279 bitCount := make([]int32, maxBits+1)
280 bits := 1
281 for chain := top.lastChain; chain.up != nil; chain = chain.up {
282
283
284 bitCount[bits] = chain.leafCount - chain.up.leafCount
285 bits++
286 }
287 return bitCount
288 }
289
290
291
292 func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) {
293 code := uint16(0)
294 for n, bits := range bitCount {
295 code <<= 1
296 if n == 0 || bits == 0 {
297 continue
298 }
299
300
301
302
303 chunk := list[len(list)-int(bits):]
304 sortByLiteral(chunk)
305 for _, node := range chunk {
306 h.codeBits[node.literal] = uint8(n)
307 h.code[node.literal] = reverseBits(code, uint8(n))
308 code++
309 }
310 list = list[0 : len(list)-int(bits)]
311 }
312 }
313
314
315
316
317
318 func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
319 list := make([]literalNode, len(freq)+1)
320
321 count := 0
322
323 for i, f := range freq {
324 if f != 0 {
325 list[count] = literalNode{uint16(i), f}
326 count++
327 } else {
328 h.codeBits[i] = 0
329 }
330 }
331
332 h.codeBits = h.codeBits[0:len(freq)]
333 list = list[0:count]
334 if count <= 2 {
335
336
337 for i, node := range list {
338
339 h.codeBits[node.literal] = 1
340 h.code[node.literal] = uint16(i)
341 }
342 return
343 }
344 sortByFreq(list)
345
346
347 bitCount := h.bitCounts(list, maxBits)
348
349 h.assignEncodingAndSize(bitCount, list)
350 }
351
352 type literalNodeSorter struct {
353 a []literalNode
354 less func(i, j int) bool
355 }
356
357 func (s literalNodeSorter) Len() int { return len(s.a) }
358
359 func (s literalNodeSorter) Less(i, j int) bool {
360 return s.less(i, j)
361 }
362
363 func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] }
364
365 func sortByFreq(a []literalNode) {
366 s := &literalNodeSorter{a, func(i, j int) bool {
367 if a[i].freq == a[j].freq {
368 return a[i].literal < a[j].literal
369 }
370 return a[i].freq < a[j].freq
371 }}
372 sort.Sort(s)
373 }
374
375 func sortByLiteral(a []literalNode) {
376 s := &literalNodeSorter{a, func(i, j int) bool { return a[i].literal < a[j].literal }}
377 sort.Sort(s)
378 }