Source file src/internal/zstd/block.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zstd
     6  
     7  import (
     8  	"io"
     9  )
    10  
// debug can be set in the source to print debug info using println.
// When true, execSeqs prints the literal, offset, and match values of
// every sequence it decodes.
const debug = false
    13  
    14  // compressedBlock decompresses a compressed block, storing the decompressed
    15  // data in r.buffer. The blockSize argument is the compressed size.
    16  // RFC 3.1.1.3.
    17  func (r *Reader) compressedBlock(blockSize int) error {
    18  	if len(r.compressedBuf) >= blockSize {
    19  		r.compressedBuf = r.compressedBuf[:blockSize]
    20  	} else {
    21  		// We know that blockSize <= 128K,
    22  		// so this won't allocate an enormous amount.
    23  		need := blockSize - len(r.compressedBuf)
    24  		r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
    25  	}
    26  
    27  	if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
    28  		return r.wrapNonEOFError(0, err)
    29  	}
    30  
    31  	data := block(r.compressedBuf)
    32  	off := 0
    33  	r.buffer = r.buffer[:0]
    34  
    35  	litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
    36  	if err != nil {
    37  		return err
    38  	}
    39  	r.literals = litbuf
    40  
    41  	off = litoff
    42  
    43  	seqCount, off, err := r.initSeqs(data, off)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	if seqCount == 0 {
    49  		// No sequences, just literals.
    50  		if off < len(data) {
    51  			return r.makeError(off, "extraneous data after no sequences")
    52  		}
    53  
    54  		r.buffer = append(r.buffer, litbuf...)
    55  
    56  		return nil
    57  	}
    58  
    59  	return r.execSeqs(data, off, litbuf, seqCount)
    60  }
    61  
// seqCode is the kind of sequence codes we have to handle.
// Values of this type are used as indexes into seqCodeInfo and into
// the Reader's per-kind sequence tables.
type seqCode int

const (
	seqLiteral seqCode = iota // codes giving the number of literal bytes to copy
	seqOffset                 // codes giving the match offset
	seqMatch                  // codes giving the match length
)
    70  
// seqCodeInfoData is the information needed to set up seqTables and
// seqTableBits for a particular kind of sequence code.
// setSeqTable consults it when installing a table for a block.
type seqCodeInfoData struct {
	predefTable     []fseBaselineEntry // predefined FSE; installed as-is for Predefined_Mode
	predefTableBits int                // number of bits in predefTable
	maxSym          int                // max symbol value in FSE; passed to readFSE
	maxBits         int                // max bits for FSE; also sizes the table buffers (1<<maxBits entries)

	// toBaseline converts from an FSE table to an FSE baseline table.
	// It is called with the Reader, an offset into the block, the
	// source FSE entries, and the destination baseline slice, whose
	// length the caller has already set to match the source.
	// NOTE(review): the offset is presumably only used for error
	// reporting — confirm against the implementations.
	toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
}
    82  
// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
// It is indexed by seqCode; setSeqTable looks up the entry for the
// kind of table it is setting up.
var seqCodeInfo = [3]seqCodeInfoData{
	seqLiteral: {
		predefTable:     predefinedLiteralTable[:],
		predefTableBits: 6,
		maxSym:          35,
		maxBits:         9,
		toBaseline:      (*Reader).makeLiteralBaselineFSE,
	},
	seqOffset: {
		predefTable:     predefinedOffsetTable[:],
		predefTableBits: 5,
		maxSym:          31,
		maxBits:         8,
		toBaseline:      (*Reader).makeOffsetBaselineFSE,
	},
	seqMatch: {
		predefTable:     predefinedMatchTable[:],
		predefTableBits: 6,
		maxSym:          52,
		maxBits:         9,
		toBaseline:      (*Reader).makeMatchBaselineFSE,
	},
}
   107  
   108  // initSeqs reads the Sequences_Section_Header and sets up the FSE
   109  // tables used to read the sequence codes. It returns the number of
   110  // sequences and the new offset. RFC 3.1.1.3.2.1.
   111  func (r *Reader) initSeqs(data block, off int) (int, int, error) {
   112  	if off >= len(data) {
   113  		return 0, 0, r.makeEOFError(off)
   114  	}
   115  
   116  	seqHdr := data[off]
   117  	off++
   118  	if seqHdr == 0 {
   119  		return 0, off, nil
   120  	}
   121  
   122  	var seqCount int
   123  	if seqHdr < 128 {
   124  		seqCount = int(seqHdr)
   125  	} else if seqHdr < 255 {
   126  		if off >= len(data) {
   127  			return 0, 0, r.makeEOFError(off)
   128  		}
   129  		seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
   130  		off++
   131  	} else {
   132  		if off+1 >= len(data) {
   133  			return 0, 0, r.makeEOFError(off)
   134  		}
   135  		seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
   136  		off += 2
   137  	}
   138  
   139  	// Read the Symbol_Compression_Modes byte.
   140  
   141  	if off >= len(data) {
   142  		return 0, 0, r.makeEOFError(off)
   143  	}
   144  	symMode := data[off]
   145  	if symMode&3 != 0 {
   146  		return 0, 0, r.makeError(off, "invalid symbol compression mode")
   147  	}
   148  	off++
   149  
   150  	// Set up the FSE tables used to decode the sequence codes.
   151  
   152  	var err error
   153  	off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
   154  	if err != nil {
   155  		return 0, 0, err
   156  	}
   157  
   158  	off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
   159  	if err != nil {
   160  		return 0, 0, err
   161  	}
   162  
   163  	off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
   164  	if err != nil {
   165  		return 0, 0, err
   166  	}
   167  
   168  	return seqCount, off, nil
   169  }
   170  
   171  // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
   172  // r.seqTableBits for kind. We store these in the Reader because one of
   173  // the modes simply reuses the value from the last block in the frame.
   174  func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
   175  	info := &seqCodeInfo[kind]
   176  	switch mode {
   177  	case 0:
   178  		// Predefined_Mode
   179  		r.seqTables[kind] = info.predefTable
   180  		r.seqTableBits[kind] = uint8(info.predefTableBits)
   181  		return off, nil
   182  
   183  	case 1:
   184  		// RLE_Mode
   185  		if off >= len(data) {
   186  			return 0, r.makeEOFError(off)
   187  		}
   188  		rle := data[off]
   189  		off++
   190  
   191  		// Build a simple baseline table that always returns rle.
   192  
   193  		entry := []fseEntry{
   194  			{
   195  				sym:  rle,
   196  				bits: 0,
   197  				base: 0,
   198  			},
   199  		}
   200  		if cap(r.seqTableBuffers[kind]) == 0 {
   201  			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
   202  		}
   203  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
   204  		if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
   205  			return 0, err
   206  		}
   207  
   208  		r.seqTables[kind] = r.seqTableBuffers[kind]
   209  		r.seqTableBits[kind] = 0
   210  		return off, nil
   211  
   212  	case 2:
   213  		// FSE_Compressed_Mode
   214  		if cap(r.fseScratch) < 1<<info.maxBits {
   215  			r.fseScratch = make([]fseEntry, 1<<info.maxBits)
   216  		}
   217  		r.fseScratch = r.fseScratch[:1<<info.maxBits]
   218  
   219  		tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
   220  		if err != nil {
   221  			return 0, err
   222  		}
   223  		r.fseScratch = r.fseScratch[:1<<tableBits]
   224  
   225  		if cap(r.seqTableBuffers[kind]) == 0 {
   226  			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
   227  		}
   228  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
   229  
   230  		if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
   231  			return 0, err
   232  		}
   233  
   234  		r.seqTables[kind] = r.seqTableBuffers[kind]
   235  		r.seqTableBits[kind] = uint8(tableBits)
   236  		return roff, nil
   237  
   238  	case 3:
   239  		// Repeat_Mode
   240  		if len(r.seqTables[kind]) == 0 {
   241  			return 0, r.makeError(off, "missing repeat sequence FSE table")
   242  		}
   243  		return off, nil
   244  	}
   245  	panic("unreachable")
   246  }
   247  
// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
//
// data is the compressed block and off is the offset at which the
// sequence data starts. litbuf holds the literal bytes that the
// sequences consume, and seqCount is the number of sequences to run.
// The decoded output is appended to r.buffer; any literals left over
// after the last sequence are appended at the end. The tables in
// r.seqTables and r.seqTableBits must already be set up.
//
// An error is returned if the bit reader fails, if the output grows
// too large, if a sequence asks for more literals than remain, if a
// match offset is invalid, or if input remains after the last sequence.
func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
	// Set up the initial states for the sequence code readers.

	// The reader is created over data with len(data)-1 and off as its
	// bounds. NOTE(review): presumably it consumes bits backward from
	// the last byte toward off — confirm against makeReverseBitReader.
	rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
	if err != nil {
		return err
	}

	// The initial states are read in the order literal, offset, match,
	// each using the bit width of its table.
	literalState, err := rbr.val(r.seqTableBits[seqLiteral])
	if err != nil {
		return err
	}

	offsetState, err := rbr.val(r.seqTableBits[seqOffset])
	if err != nil {
		return err
	}

	matchState, err := rbr.val(r.seqTableBits[seqMatch])
	if err != nil {
		return err
	}

	// Read and perform all the sequences. RFC 3.1.1.4.

	seq := 0
	for seq < seqCount {
		// Reject the block once the output produced so far plus the
		// literals still pending exceeds 128 KiB.
		if len(r.buffer)+len(litbuf) > 128<<10 {
			return rbr.makeError("uncompressed size too big")
		}

		// Look up the table entry for each current state.
		ptoffset := &r.seqTables[seqOffset][offsetState]
		ptmatch := &r.seqTables[seqMatch][matchState]
		ptliteral := &r.seqTables[seqLiteral][literalState]

		// Each value is the entry's baseline plus extra bits from the
		// stream. The extra bits are read in the order offset, match,
		// literal.
		add, err := rbr.val(ptoffset.basebits)
		if err != nil {
			return err
		}
		offset := ptoffset.baseline + add

		add, err = rbr.val(ptmatch.basebits)
		if err != nil {
			return err
		}
		match := ptmatch.baseline + add

		add, err = rbr.val(ptliteral.basebits)
		if err != nil {
			return err
		}
		literal := ptliteral.baseline + add

		// Handle repeat offsets. RFC 3.1.1.5.
		// See the comment in makeOffsetBaselineFSE.
		if ptoffset.basebits > 1 {
			// A new offset: it becomes the most recent entry in the
			// history and the older entries shift down.
			r.repeatedOffset3 = r.repeatedOffset2
			r.repeatedOffset2 = r.repeatedOffset1
			r.repeatedOffset1 = offset
		} else {
			// Here offset selects an entry from the history rather than
			// being a distance itself. With no literals in this
			// sequence the selector is shifted up by one.
			if literal == 0 {
				offset++
			}
			switch offset {
			case 1:
				// Most recent offset; the history is unchanged.
				offset = r.repeatedOffset1
			case 2:
				// Second most recent; swap it to the front.
				offset = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 3:
				// Third most recent; rotate it to the front.
				offset = r.repeatedOffset3
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 4:
				// Only reachable via the shift above: the most recent
				// offset minus one, pushed as a new entry.
				offset = r.repeatedOffset1 - 1
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			}
		}

		seq++
		if seq < seqCount {
			// Update the states.
			// This is skipped after the final sequence. The state bits
			// are read in the order literal, match, offset, which
			// differs from the order used for the extra bits above.
			add, err = rbr.val(ptliteral.bits)
			if err != nil {
				return err
			}
			literalState = uint32(ptliteral.base) + add

			add, err = rbr.val(ptmatch.bits)
			if err != nil {
				return err
			}
			matchState = uint32(ptmatch.base) + add

			add, err = rbr.val(ptoffset.bits)
			if err != nil {
				return err
			}
			offsetState = uint32(ptoffset.base) + add
		}

		// The next sequence is now in literal, offset, match.

		if debug {
			println("literal", literal, "offset", offset, "match", match)
		}

		// Copy literal bytes from litbuf.
		// The copied bytes are removed from the front of litbuf.
		if literal > uint32(len(litbuf)) {
			return rbr.makeError("literal byte overflow")
		}
		if literal > 0 {
			r.buffer = append(r.buffer, litbuf[:literal]...)
			litbuf = litbuf[literal:]
		}

		// Copy the match from earlier output.
		if match > 0 {
			if err := r.copyFromWindow(&rbr, offset, match); err != nil {
				return err
			}
		}
	}

	// Literals not consumed by any sequence go at the end of the output.
	r.buffer = append(r.buffer, litbuf...)

	// NOTE(review): cnt is presumably the number of bits the reader has
	// not yet consumed — confirm against reverseBitReader.
	if rbr.cnt != 0 {
		return r.makeError(off, "extraneous data after sequences")
	}

	return nil
}
   384  
   385  // Copy match bytes from the decoded output, or the window, at offset.
   386  func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
   387  	if offset == 0 {
   388  		return rbr.makeError("invalid zero offset")
   389  	}
   390  
   391  	// Offset may point into the buffer or the window and
   392  	// match may extend past the end of the initial buffer.
   393  	// |--r.window--|--r.buffer--|
   394  	//        |<-----offset------|
   395  	//        |------match----------->|
   396  	bufferOffset := uint32(0)
   397  	lenBlock := uint32(len(r.buffer))
   398  	if lenBlock < offset {
   399  		lenWindow := r.window.len()
   400  		copy := offset - lenBlock
   401  		if copy > lenWindow {
   402  			return rbr.makeError("offset past window")
   403  		}
   404  		windowOffset := lenWindow - copy
   405  		if copy > match {
   406  			copy = match
   407  		}
   408  		r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy)
   409  		match -= copy
   410  	} else {
   411  		bufferOffset = lenBlock - offset
   412  	}
   413  
   414  	// We are being asked to copy data that we are adding to the
   415  	// buffer in the same copy.
   416  	for match > 0 {
   417  		copy := uint32(len(r.buffer)) - bufferOffset
   418  		if copy > match {
   419  			copy = match
   420  		}
   421  		r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...)
   422  		match -= copy
   423  	}
   424  	return nil
   425  }
   426  

View as plain text