Run Format

Source file src/pkg/crypto/tls/handshake_messages.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package tls
     6	
     7	import "bytes"
     8	
     9	type clientHelloMsg struct {
    10		raw                []byte
    11		vers               uint16
    12		random             []byte
    13		sessionId          []byte
    14		cipherSuites       []uint16
    15		compressionMethods []uint8
    16		nextProtoNeg       bool
    17		serverName         string
    18		ocspStapling       bool
    19		supportedCurves    []uint16
    20		supportedPoints    []uint8
    21		ticketSupported    bool
    22		sessionTicket      []uint8
    23		signatureAndHashes []signatureAndHash
    24	}
    25	
    26	func (m *clientHelloMsg) equal(i interface{}) bool {
    27		m1, ok := i.(*clientHelloMsg)
    28		if !ok {
    29			return false
    30		}
    31	
    32		return bytes.Equal(m.raw, m1.raw) &&
    33			m.vers == m1.vers &&
    34			bytes.Equal(m.random, m1.random) &&
    35			bytes.Equal(m.sessionId, m1.sessionId) &&
    36			eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    37			bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    38			m.nextProtoNeg == m1.nextProtoNeg &&
    39			m.serverName == m1.serverName &&
    40			m.ocspStapling == m1.ocspStapling &&
    41			eqUint16s(m.supportedCurves, m1.supportedCurves) &&
    42			bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    43			m.ticketSupported == m1.ticketSupported &&
    44			bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    45			eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
    46	}
    47	
    48	func (m *clientHelloMsg) marshal() []byte {
    49		if m.raw != nil {
    50			return m.raw
    51		}
    52	
    53		length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
    54		numExtensions := 0
    55		extensionsLength := 0
    56		if m.nextProtoNeg {
    57			numExtensions++
    58		}
    59		if m.ocspStapling {
    60			extensionsLength += 1 + 2 + 2
    61			numExtensions++
    62		}
    63		if len(m.serverName) > 0 {
    64			extensionsLength += 5 + len(m.serverName)
    65			numExtensions++
    66		}
    67		if len(m.supportedCurves) > 0 {
    68			extensionsLength += 2 + 2*len(m.supportedCurves)
    69			numExtensions++
    70		}
    71		if len(m.supportedPoints) > 0 {
    72			extensionsLength += 1 + len(m.supportedPoints)
    73			numExtensions++
    74		}
    75		if m.ticketSupported {
    76			extensionsLength += len(m.sessionTicket)
    77			numExtensions++
    78		}
    79		if len(m.signatureAndHashes) > 0 {
    80			extensionsLength += 2 + 2*len(m.signatureAndHashes)
    81			numExtensions++
    82		}
    83		if numExtensions > 0 {
    84			extensionsLength += 4 * numExtensions
    85			length += 2 + extensionsLength
    86		}
    87	
    88		x := make([]byte, 4+length)
    89		x[0] = typeClientHello
    90		x[1] = uint8(length >> 16)
    91		x[2] = uint8(length >> 8)
    92		x[3] = uint8(length)
    93		x[4] = uint8(m.vers >> 8)
    94		x[5] = uint8(m.vers)
    95		copy(x[6:38], m.random)
    96		x[38] = uint8(len(m.sessionId))
    97		copy(x[39:39+len(m.sessionId)], m.sessionId)
    98		y := x[39+len(m.sessionId):]
    99		y[0] = uint8(len(m.cipherSuites) >> 7)
   100		y[1] = uint8(len(m.cipherSuites) << 1)
   101		for i, suite := range m.cipherSuites {
   102			y[2+i*2] = uint8(suite >> 8)
   103			y[3+i*2] = uint8(suite)
   104		}
   105		z := y[2+len(m.cipherSuites)*2:]
   106		z[0] = uint8(len(m.compressionMethods))
   107		copy(z[1:], m.compressionMethods)
   108	
   109		z = z[1+len(m.compressionMethods):]
   110		if numExtensions > 0 {
   111			z[0] = byte(extensionsLength >> 8)
   112			z[1] = byte(extensionsLength)
   113			z = z[2:]
   114		}
   115		if m.nextProtoNeg {
   116			z[0] = byte(extensionNextProtoNeg >> 8)
   117			z[1] = byte(extensionNextProtoNeg)
   118			// The length is always 0
   119			z = z[4:]
   120		}
   121		if len(m.serverName) > 0 {
   122			z[0] = byte(extensionServerName >> 8)
   123			z[1] = byte(extensionServerName)
   124			l := len(m.serverName) + 5
   125			z[2] = byte(l >> 8)
   126			z[3] = byte(l)
   127			z = z[4:]
   128	
   129			// RFC 3546, section 3.1
   130			//
   131			// struct {
   132			//     NameType name_type;
   133			//     select (name_type) {
   134			//         case host_name: HostName;
   135			//     } name;
   136			// } ServerName;
   137			//
   138			// enum {
   139			//     host_name(0), (255)
   140			// } NameType;
   141			//
   142			// opaque HostName<1..2^16-1>;
   143			//
   144			// struct {
   145			//     ServerName server_name_list<1..2^16-1>
   146			// } ServerNameList;
   147	
   148			z[0] = byte((len(m.serverName) + 3) >> 8)
   149			z[1] = byte(len(m.serverName) + 3)
   150			z[3] = byte(len(m.serverName) >> 8)
   151			z[4] = byte(len(m.serverName))
   152			copy(z[5:], []byte(m.serverName))
   153			z = z[l:]
   154		}
   155		if m.ocspStapling {
   156			// RFC 4366, section 3.6
   157			z[0] = byte(extensionStatusRequest >> 8)
   158			z[1] = byte(extensionStatusRequest)
   159			z[2] = 0
   160			z[3] = 5
   161			z[4] = 1 // OCSP type
   162			// Two zero valued uint16s for the two lengths.
   163			z = z[9:]
   164		}
   165		if len(m.supportedCurves) > 0 {
   166			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   167			z[0] = byte(extensionSupportedCurves >> 8)
   168			z[1] = byte(extensionSupportedCurves)
   169			l := 2 + 2*len(m.supportedCurves)
   170			z[2] = byte(l >> 8)
   171			z[3] = byte(l)
   172			l -= 2
   173			z[4] = byte(l >> 8)
   174			z[5] = byte(l)
   175			z = z[6:]
   176			for _, curve := range m.supportedCurves {
   177				z[0] = byte(curve >> 8)
   178				z[1] = byte(curve)
   179				z = z[2:]
   180			}
   181		}
   182		if len(m.supportedPoints) > 0 {
   183			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   184			z[0] = byte(extensionSupportedPoints >> 8)
   185			z[1] = byte(extensionSupportedPoints)
   186			l := 1 + len(m.supportedPoints)
   187			z[2] = byte(l >> 8)
   188			z[3] = byte(l)
   189			l--
   190			z[4] = byte(l)
   191			z = z[5:]
   192			for _, pointFormat := range m.supportedPoints {
   193				z[0] = byte(pointFormat)
   194				z = z[1:]
   195			}
   196		}
   197		if m.ticketSupported {
   198			// http://tools.ietf.org/html/rfc5077#section-3.2
   199			z[0] = byte(extensionSessionTicket >> 8)
   200			z[1] = byte(extensionSessionTicket)
   201			l := len(m.sessionTicket)
   202			z[2] = byte(l >> 8)
   203			z[3] = byte(l)
   204			z = z[4:]
   205			copy(z, m.sessionTicket)
   206			z = z[len(m.sessionTicket):]
   207		}
   208		if len(m.signatureAndHashes) > 0 {
   209			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   210			z[0] = byte(extensionSignatureAlgorithms >> 8)
   211			z[1] = byte(extensionSignatureAlgorithms)
   212			l := 2 + 2*len(m.signatureAndHashes)
   213			z[2] = byte(l >> 8)
   214			z[3] = byte(l)
   215			z = z[4:]
   216	
   217			l -= 2
   218			z[0] = byte(l >> 8)
   219			z[1] = byte(l)
   220			z = z[2:]
   221			for _, sigAndHash := range m.signatureAndHashes {
   222				z[0] = sigAndHash.hash
   223				z[1] = sigAndHash.signature
   224				z = z[2:]
   225			}
   226		}
   227	
   228		m.raw = x
   229	
   230		return x
   231	}
   232	
   233	func (m *clientHelloMsg) unmarshal(data []byte) bool {
   234		if len(data) < 42 {
   235			return false
   236		}
   237		m.raw = data
   238		m.vers = uint16(data[4])<<8 | uint16(data[5])
   239		m.random = data[6:38]
   240		sessionIdLen := int(data[38])
   241		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   242			return false
   243		}
   244		m.sessionId = data[39 : 39+sessionIdLen]
   245		data = data[39+sessionIdLen:]
   246		if len(data) < 2 {
   247			return false
   248		}
   249		// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   250		// they are uint16s, the number must be even.
   251		cipherSuiteLen := int(data[0])<<8 | int(data[1])
   252		if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   253			return false
   254		}
   255		numCipherSuites := cipherSuiteLen / 2
   256		m.cipherSuites = make([]uint16, numCipherSuites)
   257		for i := 0; i < numCipherSuites; i++ {
   258			m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   259		}
   260		data = data[2+cipherSuiteLen:]
   261		if len(data) < 1 {
   262			return false
   263		}
   264		compressionMethodsLen := int(data[0])
   265		if len(data) < 1+compressionMethodsLen {
   266			return false
   267		}
   268		m.compressionMethods = data[1 : 1+compressionMethodsLen]
   269	
   270		data = data[1+compressionMethodsLen:]
   271	
   272		m.nextProtoNeg = false
   273		m.serverName = ""
   274		m.ocspStapling = false
   275		m.ticketSupported = false
   276		m.sessionTicket = nil
   277		m.signatureAndHashes = nil
   278	
   279		if len(data) == 0 {
   280			// ClientHello is optionally followed by extension data
   281			return true
   282		}
   283		if len(data) < 2 {
   284			return false
   285		}
   286	
   287		extensionsLength := int(data[0])<<8 | int(data[1])
   288		data = data[2:]
   289		if extensionsLength != len(data) {
   290			return false
   291		}
   292	
   293		for len(data) != 0 {
   294			if len(data) < 4 {
   295				return false
   296			}
   297			extension := uint16(data[0])<<8 | uint16(data[1])
   298			length := int(data[2])<<8 | int(data[3])
   299			data = data[4:]
   300			if len(data) < length {
   301				return false
   302			}
   303	
   304			switch extension {
   305			case extensionServerName:
   306				if length < 2 {
   307					return false
   308				}
   309				numNames := int(data[0])<<8 | int(data[1])
   310				d := data[2:]
   311				for i := 0; i < numNames; i++ {
   312					if len(d) < 3 {
   313						return false
   314					}
   315					nameType := d[0]
   316					nameLen := int(d[1])<<8 | int(d[2])
   317					d = d[3:]
   318					if len(d) < nameLen {
   319						return false
   320					}
   321					if nameType == 0 {
   322						m.serverName = string(d[0:nameLen])
   323						break
   324					}
   325					d = d[nameLen:]
   326				}
   327			case extensionNextProtoNeg:
   328				if length > 0 {
   329					return false
   330				}
   331				m.nextProtoNeg = true
   332			case extensionStatusRequest:
   333				m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   334			case extensionSupportedCurves:
   335				// http://tools.ietf.org/html/rfc4492#section-5.5.1
   336				if length < 2 {
   337					return false
   338				}
   339				l := int(data[0])<<8 | int(data[1])
   340				if l%2 == 1 || length != l+2 {
   341					return false
   342				}
   343				numCurves := l / 2
   344				m.supportedCurves = make([]uint16, numCurves)
   345				d := data[2:]
   346				for i := 0; i < numCurves; i++ {
   347					m.supportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
   348					d = d[2:]
   349				}
   350			case extensionSupportedPoints:
   351				// http://tools.ietf.org/html/rfc4492#section-5.5.2
   352				if length < 1 {
   353					return false
   354				}
   355				l := int(data[0])
   356				if length != l+1 {
   357					return false
   358				}
   359				m.supportedPoints = make([]uint8, l)
   360				copy(m.supportedPoints, data[1:])
   361			case extensionSessionTicket:
   362				// http://tools.ietf.org/html/rfc5077#section-3.2
   363				m.ticketSupported = true
   364				m.sessionTicket = data[:length]
   365			case extensionSignatureAlgorithms:
   366				// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   367				if length < 2 || length&1 != 0 {
   368					return false
   369				}
   370				l := int(data[0])<<8 | int(data[1])
   371				if l != length-2 {
   372					return false
   373				}
   374				n := l / 2
   375				d := data[2:]
   376				m.signatureAndHashes = make([]signatureAndHash, n)
   377				for i := range m.signatureAndHashes {
   378					m.signatureAndHashes[i].hash = d[0]
   379					m.signatureAndHashes[i].signature = d[1]
   380					d = d[2:]
   381				}
   382			}
   383			data = data[length:]
   384		}
   385	
   386		return true
   387	}
   388	
   389	type serverHelloMsg struct {
   390		raw               []byte
   391		vers              uint16
   392		random            []byte
   393		sessionId         []byte
   394		cipherSuite       uint16
   395		compressionMethod uint8
   396		nextProtoNeg      bool
   397		nextProtos        []string
   398		ocspStapling      bool
   399		ticketSupported   bool
   400	}
   401	
   402	func (m *serverHelloMsg) equal(i interface{}) bool {
   403		m1, ok := i.(*serverHelloMsg)
   404		if !ok {
   405			return false
   406		}
   407	
   408		return bytes.Equal(m.raw, m1.raw) &&
   409			m.vers == m1.vers &&
   410			bytes.Equal(m.random, m1.random) &&
   411			bytes.Equal(m.sessionId, m1.sessionId) &&
   412			m.cipherSuite == m1.cipherSuite &&
   413			m.compressionMethod == m1.compressionMethod &&
   414			m.nextProtoNeg == m1.nextProtoNeg &&
   415			eqStrings(m.nextProtos, m1.nextProtos) &&
   416			m.ocspStapling == m1.ocspStapling &&
   417			m.ticketSupported == m1.ticketSupported
   418	}
   419	
   420	func (m *serverHelloMsg) marshal() []byte {
   421		if m.raw != nil {
   422			return m.raw
   423		}
   424	
   425		length := 38 + len(m.sessionId)
   426		numExtensions := 0
   427		extensionsLength := 0
   428	
   429		nextProtoLen := 0
   430		if m.nextProtoNeg {
   431			numExtensions++
   432			for _, v := range m.nextProtos {
   433				nextProtoLen += len(v)
   434			}
   435			nextProtoLen += len(m.nextProtos)
   436			extensionsLength += nextProtoLen
   437		}
   438		if m.ocspStapling {
   439			numExtensions++
   440		}
   441		if m.ticketSupported {
   442			numExtensions++
   443		}
   444		if numExtensions > 0 {
   445			extensionsLength += 4 * numExtensions
   446			length += 2 + extensionsLength
   447		}
   448	
   449		x := make([]byte, 4+length)
   450		x[0] = typeServerHello
   451		x[1] = uint8(length >> 16)
   452		x[2] = uint8(length >> 8)
   453		x[3] = uint8(length)
   454		x[4] = uint8(m.vers >> 8)
   455		x[5] = uint8(m.vers)
   456		copy(x[6:38], m.random)
   457		x[38] = uint8(len(m.sessionId))
   458		copy(x[39:39+len(m.sessionId)], m.sessionId)
   459		z := x[39+len(m.sessionId):]
   460		z[0] = uint8(m.cipherSuite >> 8)
   461		z[1] = uint8(m.cipherSuite)
   462		z[2] = uint8(m.compressionMethod)
   463	
   464		z = z[3:]
   465		if numExtensions > 0 {
   466			z[0] = byte(extensionsLength >> 8)
   467			z[1] = byte(extensionsLength)
   468			z = z[2:]
   469		}
   470		if m.nextProtoNeg {
   471			z[0] = byte(extensionNextProtoNeg >> 8)
   472			z[1] = byte(extensionNextProtoNeg)
   473			z[2] = byte(nextProtoLen >> 8)
   474			z[3] = byte(nextProtoLen)
   475			z = z[4:]
   476	
   477			for _, v := range m.nextProtos {
   478				l := len(v)
   479				if l > 255 {
   480					l = 255
   481				}
   482				z[0] = byte(l)
   483				copy(z[1:], []byte(v[0:l]))
   484				z = z[1+l:]
   485			}
   486		}
   487		if m.ocspStapling {
   488			z[0] = byte(extensionStatusRequest >> 8)
   489			z[1] = byte(extensionStatusRequest)
   490			z = z[4:]
   491		}
   492		if m.ticketSupported {
   493			z[0] = byte(extensionSessionTicket >> 8)
   494			z[1] = byte(extensionSessionTicket)
   495			z = z[4:]
   496		}
   497	
   498		m.raw = x
   499	
   500		return x
   501	}
   502	
   503	func (m *serverHelloMsg) unmarshal(data []byte) bool {
   504		if len(data) < 42 {
   505			return false
   506		}
   507		m.raw = data
   508		m.vers = uint16(data[4])<<8 | uint16(data[5])
   509		m.random = data[6:38]
   510		sessionIdLen := int(data[38])
   511		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   512			return false
   513		}
   514		m.sessionId = data[39 : 39+sessionIdLen]
   515		data = data[39+sessionIdLen:]
   516		if len(data) < 3 {
   517			return false
   518		}
   519		m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   520		m.compressionMethod = data[2]
   521		data = data[3:]
   522	
   523		m.nextProtoNeg = false
   524		m.nextProtos = nil
   525		m.ocspStapling = false
   526		m.ticketSupported = false
   527	
   528		if len(data) == 0 {
   529			// ServerHello is optionally followed by extension data
   530			return true
   531		}
   532		if len(data) < 2 {
   533			return false
   534		}
   535	
   536		extensionsLength := int(data[0])<<8 | int(data[1])
   537		data = data[2:]
   538		if len(data) != extensionsLength {
   539			return false
   540		}
   541	
   542		for len(data) != 0 {
   543			if len(data) < 4 {
   544				return false
   545			}
   546			extension := uint16(data[0])<<8 | uint16(data[1])
   547			length := int(data[2])<<8 | int(data[3])
   548			data = data[4:]
   549			if len(data) < length {
   550				return false
   551			}
   552	
   553			switch extension {
   554			case extensionNextProtoNeg:
   555				m.nextProtoNeg = true
   556				d := data[:length]
   557				for len(d) > 0 {
   558					l := int(d[0])
   559					d = d[1:]
   560					if l == 0 || l > len(d) {
   561						return false
   562					}
   563					m.nextProtos = append(m.nextProtos, string(d[:l]))
   564					d = d[l:]
   565				}
   566			case extensionStatusRequest:
   567				if length > 0 {
   568					return false
   569				}
   570				m.ocspStapling = true
   571			case extensionSessionTicket:
   572				if length > 0 {
   573					return false
   574				}
   575				m.ticketSupported = true
   576			}
   577			data = data[length:]
   578		}
   579	
   580		return true
   581	}
   582	
   583	type certificateMsg struct {
   584		raw          []byte
   585		certificates [][]byte
   586	}
   587	
   588	func (m *certificateMsg) equal(i interface{}) bool {
   589		m1, ok := i.(*certificateMsg)
   590		if !ok {
   591			return false
   592		}
   593	
   594		return bytes.Equal(m.raw, m1.raw) &&
   595			eqByteSlices(m.certificates, m1.certificates)
   596	}
   597	
   598	func (m *certificateMsg) marshal() (x []byte) {
   599		if m.raw != nil {
   600			return m.raw
   601		}
   602	
   603		var i int
   604		for _, slice := range m.certificates {
   605			i += len(slice)
   606		}
   607	
   608		length := 3 + 3*len(m.certificates) + i
   609		x = make([]byte, 4+length)
   610		x[0] = typeCertificate
   611		x[1] = uint8(length >> 16)
   612		x[2] = uint8(length >> 8)
   613		x[3] = uint8(length)
   614	
   615		certificateOctets := length - 3
   616		x[4] = uint8(certificateOctets >> 16)
   617		x[5] = uint8(certificateOctets >> 8)
   618		x[6] = uint8(certificateOctets)
   619	
   620		y := x[7:]
   621		for _, slice := range m.certificates {
   622			y[0] = uint8(len(slice) >> 16)
   623			y[1] = uint8(len(slice) >> 8)
   624			y[2] = uint8(len(slice))
   625			copy(y[3:], slice)
   626			y = y[3+len(slice):]
   627		}
   628	
   629		m.raw = x
   630		return
   631	}
   632	
   633	func (m *certificateMsg) unmarshal(data []byte) bool {
   634		if len(data) < 7 {
   635			return false
   636		}
   637	
   638		m.raw = data
   639		certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
   640		if uint32(len(data)) != certsLen+7 {
   641			return false
   642		}
   643	
   644		numCerts := 0
   645		d := data[7:]
   646		for certsLen > 0 {
   647			if len(d) < 4 {
   648				return false
   649			}
   650			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   651			if uint32(len(d)) < 3+certLen {
   652				return false
   653			}
   654			d = d[3+certLen:]
   655			certsLen -= 3 + certLen
   656			numCerts++
   657		}
   658	
   659		m.certificates = make([][]byte, numCerts)
   660		d = data[7:]
   661		for i := 0; i < numCerts; i++ {
   662			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   663			m.certificates[i] = d[3 : 3+certLen]
   664			d = d[3+certLen:]
   665		}
   666	
   667		return true
   668	}
   669	
   670	type serverKeyExchangeMsg struct {
   671		raw []byte
   672		key []byte
   673	}
   674	
   675	func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   676		m1, ok := i.(*serverKeyExchangeMsg)
   677		if !ok {
   678			return false
   679		}
   680	
   681		return bytes.Equal(m.raw, m1.raw) &&
   682			bytes.Equal(m.key, m1.key)
   683	}
   684	
   685	func (m *serverKeyExchangeMsg) marshal() []byte {
   686		if m.raw != nil {
   687			return m.raw
   688		}
   689		length := len(m.key)
   690		x := make([]byte, length+4)
   691		x[0] = typeServerKeyExchange
   692		x[1] = uint8(length >> 16)
   693		x[2] = uint8(length >> 8)
   694		x[3] = uint8(length)
   695		copy(x[4:], m.key)
   696	
   697		m.raw = x
   698		return x
   699	}
   700	
   701	func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   702		m.raw = data
   703		if len(data) < 4 {
   704			return false
   705		}
   706		m.key = data[4:]
   707		return true
   708	}
   709	
   710	type certificateStatusMsg struct {
   711		raw        []byte
   712		statusType uint8
   713		response   []byte
   714	}
   715	
   716	func (m *certificateStatusMsg) equal(i interface{}) bool {
   717		m1, ok := i.(*certificateStatusMsg)
   718		if !ok {
   719			return false
   720		}
   721	
   722		return bytes.Equal(m.raw, m1.raw) &&
   723			m.statusType == m1.statusType &&
   724			bytes.Equal(m.response, m1.response)
   725	}
   726	
   727	func (m *certificateStatusMsg) marshal() []byte {
   728		if m.raw != nil {
   729			return m.raw
   730		}
   731	
   732		var x []byte
   733		if m.statusType == statusTypeOCSP {
   734			x = make([]byte, 4+4+len(m.response))
   735			x[0] = typeCertificateStatus
   736			l := len(m.response) + 4
   737			x[1] = byte(l >> 16)
   738			x[2] = byte(l >> 8)
   739			x[3] = byte(l)
   740			x[4] = statusTypeOCSP
   741	
   742			l -= 4
   743			x[5] = byte(l >> 16)
   744			x[6] = byte(l >> 8)
   745			x[7] = byte(l)
   746			copy(x[8:], m.response)
   747		} else {
   748			x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
   749		}
   750	
   751		m.raw = x
   752		return x
   753	}
   754	
   755	func (m *certificateStatusMsg) unmarshal(data []byte) bool {
   756		m.raw = data
   757		if len(data) < 5 {
   758			return false
   759		}
   760		m.statusType = data[4]
   761	
   762		m.response = nil
   763		if m.statusType == statusTypeOCSP {
   764			if len(data) < 8 {
   765				return false
   766			}
   767			respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   768			if uint32(len(data)) != 4+4+respLen {
   769				return false
   770			}
   771			m.response = data[8:]
   772		}
   773		return true
   774	}
   775	
   776	type serverHelloDoneMsg struct{}
   777	
   778	func (m *serverHelloDoneMsg) equal(i interface{}) bool {
   779		_, ok := i.(*serverHelloDoneMsg)
   780		return ok
   781	}
   782	
   783	func (m *serverHelloDoneMsg) marshal() []byte {
   784		x := make([]byte, 4)
   785		x[0] = typeServerHelloDone
   786		return x
   787	}
   788	
   789	func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
   790		return len(data) == 4
   791	}
   792	
   793	type clientKeyExchangeMsg struct {
   794		raw        []byte
   795		ciphertext []byte
   796	}
   797	
   798	func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
   799		m1, ok := i.(*clientKeyExchangeMsg)
   800		if !ok {
   801			return false
   802		}
   803	
   804		return bytes.Equal(m.raw, m1.raw) &&
   805			bytes.Equal(m.ciphertext, m1.ciphertext)
   806	}
   807	
   808	func (m *clientKeyExchangeMsg) marshal() []byte {
   809		if m.raw != nil {
   810			return m.raw
   811		}
   812		length := len(m.ciphertext)
   813		x := make([]byte, length+4)
   814		x[0] = typeClientKeyExchange
   815		x[1] = uint8(length >> 16)
   816		x[2] = uint8(length >> 8)
   817		x[3] = uint8(length)
   818		copy(x[4:], m.ciphertext)
   819	
   820		m.raw = x
   821		return x
   822	}
   823	
   824	func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
   825		m.raw = data
   826		if len(data) < 4 {
   827			return false
   828		}
   829		l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   830		if l != len(data)-4 {
   831			return false
   832		}
   833		m.ciphertext = data[4:]
   834		return true
   835	}
   836	
   837	type finishedMsg struct {
   838		raw        []byte
   839		verifyData []byte
   840	}
   841	
   842	func (m *finishedMsg) equal(i interface{}) bool {
   843		m1, ok := i.(*finishedMsg)
   844		if !ok {
   845			return false
   846		}
   847	
   848		return bytes.Equal(m.raw, m1.raw) &&
   849			bytes.Equal(m.verifyData, m1.verifyData)
   850	}
   851	
   852	func (m *finishedMsg) marshal() (x []byte) {
   853		if m.raw != nil {
   854			return m.raw
   855		}
   856	
   857		x = make([]byte, 4+len(m.verifyData))
   858		x[0] = typeFinished
   859		x[3] = byte(len(m.verifyData))
   860		copy(x[4:], m.verifyData)
   861		m.raw = x
   862		return
   863	}
   864	
   865	func (m *finishedMsg) unmarshal(data []byte) bool {
   866		m.raw = data
   867		if len(data) < 4 {
   868			return false
   869		}
   870		m.verifyData = data[4:]
   871		return true
   872	}
   873	
   874	type nextProtoMsg struct {
   875		raw   []byte
   876		proto string
   877	}
   878	
   879	func (m *nextProtoMsg) equal(i interface{}) bool {
   880		m1, ok := i.(*nextProtoMsg)
   881		if !ok {
   882			return false
   883		}
   884	
   885		return bytes.Equal(m.raw, m1.raw) &&
   886			m.proto == m1.proto
   887	}
   888	
   889	func (m *nextProtoMsg) marshal() []byte {
   890		if m.raw != nil {
   891			return m.raw
   892		}
   893		l := len(m.proto)
   894		if l > 255 {
   895			l = 255
   896		}
   897	
   898		padding := 32 - (l+2)%32
   899		length := l + padding + 2
   900		x := make([]byte, length+4)
   901		x[0] = typeNextProtocol
   902		x[1] = uint8(length >> 16)
   903		x[2] = uint8(length >> 8)
   904		x[3] = uint8(length)
   905	
   906		y := x[4:]
   907		y[0] = byte(l)
   908		copy(y[1:], []byte(m.proto[0:l]))
   909		y = y[1+l:]
   910		y[0] = byte(padding)
   911	
   912		m.raw = x
   913	
   914		return x
   915	}
   916	
   917	func (m *nextProtoMsg) unmarshal(data []byte) bool {
   918		m.raw = data
   919	
   920		if len(data) < 5 {
   921			return false
   922		}
   923		data = data[4:]
   924		protoLen := int(data[0])
   925		data = data[1:]
   926		if len(data) < protoLen {
   927			return false
   928		}
   929		m.proto = string(data[0:protoLen])
   930		data = data[protoLen:]
   931	
   932		if len(data) < 1 {
   933			return false
   934		}
   935		paddingLen := int(data[0])
   936		data = data[1:]
   937		if len(data) != paddingLen {
   938			return false
   939		}
   940	
   941		return true
   942	}
   943	
   944	type certificateRequestMsg struct {
   945		raw []byte
   946		// hasSignatureAndHash indicates whether this message includes a list
   947		// of signature and hash functions. This change was introduced with TLS
   948		// 1.2.
   949		hasSignatureAndHash bool
   950	
   951		certificateTypes       []byte
   952		signatureAndHashes     []signatureAndHash
   953		certificateAuthorities [][]byte
   954	}
   955	
   956	func (m *certificateRequestMsg) equal(i interface{}) bool {
   957		m1, ok := i.(*certificateRequestMsg)
   958		if !ok {
   959			return false
   960		}
   961	
   962		return bytes.Equal(m.raw, m1.raw) &&
   963			bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
   964			eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
   965			eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
   966	}
   967	
   968	func (m *certificateRequestMsg) marshal() (x []byte) {
   969		if m.raw != nil {
   970			return m.raw
   971		}
   972	
   973		// See http://tools.ietf.org/html/rfc4346#section-7.4.4
   974		length := 1 + len(m.certificateTypes) + 2
   975		casLength := 0
   976		for _, ca := range m.certificateAuthorities {
   977			casLength += 2 + len(ca)
   978		}
   979		length += casLength
   980	
   981		if m.hasSignatureAndHash {
   982			length += 2 + 2*len(m.signatureAndHashes)
   983		}
   984	
   985		x = make([]byte, 4+length)
   986		x[0] = typeCertificateRequest
   987		x[1] = uint8(length >> 16)
   988		x[2] = uint8(length >> 8)
   989		x[3] = uint8(length)
   990	
   991		x[4] = uint8(len(m.certificateTypes))
   992	
   993		copy(x[5:], m.certificateTypes)
   994		y := x[5+len(m.certificateTypes):]
   995	
   996		if m.hasSignatureAndHash {
   997			n := len(m.signatureAndHashes) * 2
   998			y[0] = uint8(n >> 8)
   999			y[1] = uint8(n)
  1000			y = y[2:]
  1001			for _, sigAndHash := range m.signatureAndHashes {
  1002				y[0] = sigAndHash.hash
  1003				y[1] = sigAndHash.signature
  1004				y = y[2:]
  1005			}
  1006		}
  1007	
  1008		y[0] = uint8(casLength >> 8)
  1009		y[1] = uint8(casLength)
  1010		y = y[2:]
  1011		for _, ca := range m.certificateAuthorities {
  1012			y[0] = uint8(len(ca) >> 8)
  1013			y[1] = uint8(len(ca))
  1014			y = y[2:]
  1015			copy(y, ca)
  1016			y = y[len(ca):]
  1017		}
  1018	
  1019		m.raw = x
  1020		return
  1021	}
  1022	
  1023	func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1024		m.raw = data
  1025	
  1026		if len(data) < 5 {
  1027			return false
  1028		}
  1029	
  1030		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1031		if uint32(len(data))-4 != length {
  1032			return false
  1033		}
  1034	
  1035		numCertTypes := int(data[4])
  1036		data = data[5:]
  1037		if numCertTypes == 0 || len(data) <= numCertTypes {
  1038			return false
  1039		}
  1040	
  1041		m.certificateTypes = make([]byte, numCertTypes)
  1042		if copy(m.certificateTypes, data) != numCertTypes {
  1043			return false
  1044		}
  1045	
  1046		data = data[numCertTypes:]
  1047	
  1048		if m.hasSignatureAndHash {
  1049			if len(data) < 2 {
  1050				return false
  1051			}
  1052			sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1053			data = data[2:]
  1054			if sigAndHashLen&1 != 0 {
  1055				return false
  1056			}
  1057			if len(data) < int(sigAndHashLen) {
  1058				return false
  1059			}
  1060			numSigAndHash := sigAndHashLen / 2
  1061			m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
  1062			for i := range m.signatureAndHashes {
  1063				m.signatureAndHashes[i].hash = data[0]
  1064				m.signatureAndHashes[i].signature = data[1]
  1065				data = data[2:]
  1066			}
  1067		}
  1068	
  1069		if len(data) < 2 {
  1070			return false
  1071		}
  1072		casLength := uint16(data[0])<<8 | uint16(data[1])
  1073		data = data[2:]
  1074		if len(data) < int(casLength) {
  1075			return false
  1076		}
  1077		cas := make([]byte, casLength)
  1078		copy(cas, data)
  1079		data = data[casLength:]
  1080	
  1081		m.certificateAuthorities = nil
  1082		for len(cas) > 0 {
  1083			if len(cas) < 2 {
  1084				return false
  1085			}
  1086			caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1087			cas = cas[2:]
  1088	
  1089			if len(cas) < int(caLen) {
  1090				return false
  1091			}
  1092	
  1093			m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1094			cas = cas[caLen:]
  1095		}
  1096		if len(data) > 0 {
  1097			return false
  1098		}
  1099	
  1100		return true
  1101	}
  1102	
  1103	type certificateVerifyMsg struct {
  1104		raw                 []byte
  1105		hasSignatureAndHash bool
  1106		signatureAndHash    signatureAndHash
  1107		signature           []byte
  1108	}
  1109	
  1110	func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1111		m1, ok := i.(*certificateVerifyMsg)
  1112		if !ok {
  1113			return false
  1114		}
  1115	
  1116		return bytes.Equal(m.raw, m1.raw) &&
  1117			m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1118			m.signatureAndHash.hash == m1.signatureAndHash.hash &&
  1119			m.signatureAndHash.signature == m1.signatureAndHash.signature &&
  1120			bytes.Equal(m.signature, m1.signature)
  1121	}
  1122	
  1123	func (m *certificateVerifyMsg) marshal() (x []byte) {
  1124		if m.raw != nil {
  1125			return m.raw
  1126		}
  1127	
  1128		// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1129		siglength := len(m.signature)
  1130		length := 2 + siglength
  1131		if m.hasSignatureAndHash {
  1132			length += 2
  1133		}
  1134		x = make([]byte, 4+length)
  1135		x[0] = typeCertificateVerify
  1136		x[1] = uint8(length >> 16)
  1137		x[2] = uint8(length >> 8)
  1138		x[3] = uint8(length)
  1139		y := x[4:]
  1140		if m.hasSignatureAndHash {
  1141			y[0] = m.signatureAndHash.hash
  1142			y[1] = m.signatureAndHash.signature
  1143			y = y[2:]
  1144		}
  1145		y[0] = uint8(siglength >> 8)
  1146		y[1] = uint8(siglength)
  1147		copy(y[2:], m.signature)
  1148	
  1149		m.raw = x
  1150	
  1151		return
  1152	}
  1153	
  1154	func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1155		m.raw = data
  1156	
  1157		if len(data) < 6 {
  1158			return false
  1159		}
  1160	
  1161		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1162		if uint32(len(data))-4 != length {
  1163			return false
  1164		}
  1165	
  1166		data = data[4:]
  1167		if m.hasSignatureAndHash {
  1168			m.signatureAndHash.hash = data[0]
  1169			m.signatureAndHash.signature = data[1]
  1170			data = data[2:]
  1171		}
  1172	
  1173		if len(data) < 2 {
  1174			return false
  1175		}
  1176		siglength := int(data[0])<<8 + int(data[1])
  1177		data = data[2:]
  1178		if len(data) != siglength {
  1179			return false
  1180		}
  1181	
  1182		m.signature = data
  1183	
  1184		return true
  1185	}
  1186	
  1187	type newSessionTicketMsg struct {
  1188		raw    []byte
  1189		ticket []byte
  1190	}
  1191	
  1192	func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1193		m1, ok := i.(*newSessionTicketMsg)
  1194		if !ok {
  1195			return false
  1196		}
  1197	
  1198		return bytes.Equal(m.raw, m1.raw) &&
  1199			bytes.Equal(m.ticket, m1.ticket)
  1200	}
  1201	
  1202	func (m *newSessionTicketMsg) marshal() (x []byte) {
  1203		if m.raw != nil {
  1204			return m.raw
  1205		}
  1206	
  1207		// See http://tools.ietf.org/html/rfc5077#section-3.3
  1208		ticketLen := len(m.ticket)
  1209		length := 2 + 4 + ticketLen
  1210		x = make([]byte, 4+length)
  1211		x[0] = typeNewSessionTicket
  1212		x[1] = uint8(length >> 16)
  1213		x[2] = uint8(length >> 8)
  1214		x[3] = uint8(length)
  1215		x[8] = uint8(ticketLen >> 8)
  1216		x[9] = uint8(ticketLen)
  1217		copy(x[10:], m.ticket)
  1218	
  1219		m.raw = x
  1220	
  1221		return
  1222	}
  1223	
  1224	func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1225		m.raw = data
  1226	
  1227		if len(data) < 10 {
  1228			return false
  1229		}
  1230	
  1231		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1232		if uint32(len(data))-4 != length {
  1233			return false
  1234		}
  1235	
  1236		ticketLen := int(data[8])<<8 + int(data[9])
  1237		if len(data)-10 != ticketLen {
  1238			return false
  1239		}
  1240	
  1241		m.ticket = data[10:]
  1242	
  1243		return true
  1244	}
  1245	
  1246	func eqUint16s(x, y []uint16) bool {
  1247		if len(x) != len(y) {
  1248			return false
  1249		}
  1250		for i, v := range x {
  1251			if y[i] != v {
  1252				return false
  1253			}
  1254		}
  1255		return true
  1256	}
  1257	
  1258	func eqStrings(x, y []string) bool {
  1259		if len(x) != len(y) {
  1260			return false
  1261		}
  1262		for i, v := range x {
  1263			if y[i] != v {
  1264				return false
  1265			}
  1266		}
  1267		return true
  1268	}
  1269	
  1270	func eqByteSlices(x, y [][]byte) bool {
  1271		if len(x) != len(y) {
  1272			return false
  1273		}
  1274		for i, v := range x {
  1275			if !bytes.Equal(v, y[i]) {
  1276				return false
  1277			}
  1278		}
  1279		return true
  1280	}
  1281	
  1282	func eqSignatureAndHashes(x, y []signatureAndHash) bool {
  1283		if len(x) != len(y) {
  1284			return false
  1285		}
  1286		for i, v := range x {
  1287			v2 := y[i]
  1288			if v.hash != v2.hash || v.signature != v2.signature {
  1289				return false
  1290			}
  1291		}
  1292		return true
  1293	}

View as plain text