...
Run Format

Source file src/crypto/tls/handshake_messages.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package tls
     6	
     7	import "bytes"
     8	
     9	type clientHelloMsg struct {
    10		raw                 []byte
    11		vers                uint16
    12		random              []byte
    13		sessionId           []byte
    14		cipherSuites        []uint16
    15		compressionMethods  []uint8
    16		nextProtoNeg        bool
    17		serverName          string
    18		ocspStapling        bool
    19		supportedCurves     []CurveID
    20		supportedPoints     []uint8
    21		ticketSupported     bool
    22		sessionTicket       []uint8
    23		signatureAndHashes  []signatureAndHash
    24		secureRenegotiation bool
    25		alpnProtocols       []string
    26	}
    27	
    28	func (m *clientHelloMsg) equal(i interface{}) bool {
    29		m1, ok := i.(*clientHelloMsg)
    30		if !ok {
    31			return false
    32		}
    33	
    34		return bytes.Equal(m.raw, m1.raw) &&
    35			m.vers == m1.vers &&
    36			bytes.Equal(m.random, m1.random) &&
    37			bytes.Equal(m.sessionId, m1.sessionId) &&
    38			eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    39			bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    40			m.nextProtoNeg == m1.nextProtoNeg &&
    41			m.serverName == m1.serverName &&
    42			m.ocspStapling == m1.ocspStapling &&
    43			eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
    44			bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    45			m.ticketSupported == m1.ticketSupported &&
    46			bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
    47			eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
    48			m.secureRenegotiation == m1.secureRenegotiation &&
    49			eqStrings(m.alpnProtocols, m1.alpnProtocols)
    50	}
    51	
    52	func (m *clientHelloMsg) marshal() []byte {
    53		if m.raw != nil {
    54			return m.raw
    55		}
    56	
    57		length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
    58		numExtensions := 0
    59		extensionsLength := 0
    60		if m.nextProtoNeg {
    61			numExtensions++
    62		}
    63		if m.ocspStapling {
    64			extensionsLength += 1 + 2 + 2
    65			numExtensions++
    66		}
    67		if len(m.serverName) > 0 {
    68			extensionsLength += 5 + len(m.serverName)
    69			numExtensions++
    70		}
    71		if len(m.supportedCurves) > 0 {
    72			extensionsLength += 2 + 2*len(m.supportedCurves)
    73			numExtensions++
    74		}
    75		if len(m.supportedPoints) > 0 {
    76			extensionsLength += 1 + len(m.supportedPoints)
    77			numExtensions++
    78		}
    79		if m.ticketSupported {
    80			extensionsLength += len(m.sessionTicket)
    81			numExtensions++
    82		}
    83		if len(m.signatureAndHashes) > 0 {
    84			extensionsLength += 2 + 2*len(m.signatureAndHashes)
    85			numExtensions++
    86		}
    87		if m.secureRenegotiation {
    88			extensionsLength += 1
    89			numExtensions++
    90		}
    91		if len(m.alpnProtocols) > 0 {
    92			extensionsLength += 2
    93			for _, s := range m.alpnProtocols {
    94				if l := len(s); l == 0 || l > 255 {
    95					panic("invalid ALPN protocol")
    96				}
    97				extensionsLength++
    98				extensionsLength += len(s)
    99			}
   100			numExtensions++
   101		}
   102		if numExtensions > 0 {
   103			extensionsLength += 4 * numExtensions
   104			length += 2 + extensionsLength
   105		}
   106	
   107		x := make([]byte, 4+length)
   108		x[0] = typeClientHello
   109		x[1] = uint8(length >> 16)
   110		x[2] = uint8(length >> 8)
   111		x[3] = uint8(length)
   112		x[4] = uint8(m.vers >> 8)
   113		x[5] = uint8(m.vers)
   114		copy(x[6:38], m.random)
   115		x[38] = uint8(len(m.sessionId))
   116		copy(x[39:39+len(m.sessionId)], m.sessionId)
   117		y := x[39+len(m.sessionId):]
   118		y[0] = uint8(len(m.cipherSuites) >> 7)
   119		y[1] = uint8(len(m.cipherSuites) << 1)
   120		for i, suite := range m.cipherSuites {
   121			y[2+i*2] = uint8(suite >> 8)
   122			y[3+i*2] = uint8(suite)
   123		}
   124		z := y[2+len(m.cipherSuites)*2:]
   125		z[0] = uint8(len(m.compressionMethods))
   126		copy(z[1:], m.compressionMethods)
   127	
   128		z = z[1+len(m.compressionMethods):]
   129		if numExtensions > 0 {
   130			z[0] = byte(extensionsLength >> 8)
   131			z[1] = byte(extensionsLength)
   132			z = z[2:]
   133		}
   134		if m.nextProtoNeg {
   135			z[0] = byte(extensionNextProtoNeg >> 8)
   136			z[1] = byte(extensionNextProtoNeg & 0xff)
   137			// The length is always 0
   138			z = z[4:]
   139		}
   140		if len(m.serverName) > 0 {
   141			z[0] = byte(extensionServerName >> 8)
   142			z[1] = byte(extensionServerName & 0xff)
   143			l := len(m.serverName) + 5
   144			z[2] = byte(l >> 8)
   145			z[3] = byte(l)
   146			z = z[4:]
   147	
   148			// RFC 3546, section 3.1
   149			//
   150			// struct {
   151			//     NameType name_type;
   152			//     select (name_type) {
   153			//         case host_name: HostName;
   154			//     } name;
   155			// } ServerName;
   156			//
   157			// enum {
   158			//     host_name(0), (255)
   159			// } NameType;
   160			//
   161			// opaque HostName<1..2^16-1>;
   162			//
   163			// struct {
   164			//     ServerName server_name_list<1..2^16-1>
   165			// } ServerNameList;
   166	
   167			z[0] = byte((len(m.serverName) + 3) >> 8)
   168			z[1] = byte(len(m.serverName) + 3)
   169			z[3] = byte(len(m.serverName) >> 8)
   170			z[4] = byte(len(m.serverName))
   171			copy(z[5:], []byte(m.serverName))
   172			z = z[l:]
   173		}
   174		if m.ocspStapling {
   175			// RFC 4366, section 3.6
   176			z[0] = byte(extensionStatusRequest >> 8)
   177			z[1] = byte(extensionStatusRequest)
   178			z[2] = 0
   179			z[3] = 5
   180			z[4] = 1 // OCSP type
   181			// Two zero valued uint16s for the two lengths.
   182			z = z[9:]
   183		}
   184		if len(m.supportedCurves) > 0 {
   185			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   186			z[0] = byte(extensionSupportedCurves >> 8)
   187			z[1] = byte(extensionSupportedCurves)
   188			l := 2 + 2*len(m.supportedCurves)
   189			z[2] = byte(l >> 8)
   190			z[3] = byte(l)
   191			l -= 2
   192			z[4] = byte(l >> 8)
   193			z[5] = byte(l)
   194			z = z[6:]
   195			for _, curve := range m.supportedCurves {
   196				z[0] = byte(curve >> 8)
   197				z[1] = byte(curve)
   198				z = z[2:]
   199			}
   200		}
   201		if len(m.supportedPoints) > 0 {
   202			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   203			z[0] = byte(extensionSupportedPoints >> 8)
   204			z[1] = byte(extensionSupportedPoints)
   205			l := 1 + len(m.supportedPoints)
   206			z[2] = byte(l >> 8)
   207			z[3] = byte(l)
   208			l--
   209			z[4] = byte(l)
   210			z = z[5:]
   211			for _, pointFormat := range m.supportedPoints {
   212				z[0] = byte(pointFormat)
   213				z = z[1:]
   214			}
   215		}
   216		if m.ticketSupported {
   217			// http://tools.ietf.org/html/rfc5077#section-3.2
   218			z[0] = byte(extensionSessionTicket >> 8)
   219			z[1] = byte(extensionSessionTicket)
   220			l := len(m.sessionTicket)
   221			z[2] = byte(l >> 8)
   222			z[3] = byte(l)
   223			z = z[4:]
   224			copy(z, m.sessionTicket)
   225			z = z[len(m.sessionTicket):]
   226		}
   227		if len(m.signatureAndHashes) > 0 {
   228			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   229			z[0] = byte(extensionSignatureAlgorithms >> 8)
   230			z[1] = byte(extensionSignatureAlgorithms)
   231			l := 2 + 2*len(m.signatureAndHashes)
   232			z[2] = byte(l >> 8)
   233			z[3] = byte(l)
   234			z = z[4:]
   235	
   236			l -= 2
   237			z[0] = byte(l >> 8)
   238			z[1] = byte(l)
   239			z = z[2:]
   240			for _, sigAndHash := range m.signatureAndHashes {
   241				z[0] = sigAndHash.hash
   242				z[1] = sigAndHash.signature
   243				z = z[2:]
   244			}
   245		}
   246		if m.secureRenegotiation {
   247			z[0] = byte(extensionRenegotiationInfo >> 8)
   248			z[1] = byte(extensionRenegotiationInfo & 0xff)
   249			z[2] = 0
   250			z[3] = 1
   251			z = z[5:]
   252		}
   253		if len(m.alpnProtocols) > 0 {
   254			z[0] = byte(extensionALPN >> 8)
   255			z[1] = byte(extensionALPN & 0xff)
   256			lengths := z[2:]
   257			z = z[6:]
   258	
   259			stringsLength := 0
   260			for _, s := range m.alpnProtocols {
   261				l := len(s)
   262				z[0] = byte(l)
   263				copy(z[1:], s)
   264				z = z[1+l:]
   265				stringsLength += 1 + l
   266			}
   267	
   268			lengths[2] = byte(stringsLength >> 8)
   269			lengths[3] = byte(stringsLength)
   270			stringsLength += 2
   271			lengths[0] = byte(stringsLength >> 8)
   272			lengths[1] = byte(stringsLength)
   273		}
   274	
   275		m.raw = x
   276	
   277		return x
   278	}
   279	
   280	func (m *clientHelloMsg) unmarshal(data []byte) bool {
   281		if len(data) < 42 {
   282			return false
   283		}
   284		m.raw = data
   285		m.vers = uint16(data[4])<<8 | uint16(data[5])
   286		m.random = data[6:38]
   287		sessionIdLen := int(data[38])
   288		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   289			return false
   290		}
   291		m.sessionId = data[39 : 39+sessionIdLen]
   292		data = data[39+sessionIdLen:]
   293		if len(data) < 2 {
   294			return false
   295		}
   296		// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   297		// they are uint16s, the number must be even.
   298		cipherSuiteLen := int(data[0])<<8 | int(data[1])
   299		if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   300			return false
   301		}
   302		numCipherSuites := cipherSuiteLen / 2
   303		m.cipherSuites = make([]uint16, numCipherSuites)
   304		for i := 0; i < numCipherSuites; i++ {
   305			m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   306			if m.cipherSuites[i] == scsvRenegotiation {
   307				m.secureRenegotiation = true
   308			}
   309		}
   310		data = data[2+cipherSuiteLen:]
   311		if len(data) < 1 {
   312			return false
   313		}
   314		compressionMethodsLen := int(data[0])
   315		if len(data) < 1+compressionMethodsLen {
   316			return false
   317		}
   318		m.compressionMethods = data[1 : 1+compressionMethodsLen]
   319	
   320		data = data[1+compressionMethodsLen:]
   321	
   322		m.nextProtoNeg = false
   323		m.serverName = ""
   324		m.ocspStapling = false
   325		m.ticketSupported = false
   326		m.sessionTicket = nil
   327		m.signatureAndHashes = nil
   328		m.alpnProtocols = nil
   329	
   330		if len(data) == 0 {
   331			// ClientHello is optionally followed by extension data
   332			return true
   333		}
   334		if len(data) < 2 {
   335			return false
   336		}
   337	
   338		extensionsLength := int(data[0])<<8 | int(data[1])
   339		data = data[2:]
   340		if extensionsLength != len(data) {
   341			return false
   342		}
   343	
   344		for len(data) != 0 {
   345			if len(data) < 4 {
   346				return false
   347			}
   348			extension := uint16(data[0])<<8 | uint16(data[1])
   349			length := int(data[2])<<8 | int(data[3])
   350			data = data[4:]
   351			if len(data) < length {
   352				return false
   353			}
   354	
   355			switch extension {
   356			case extensionServerName:
   357				if length < 2 {
   358					return false
   359				}
   360				numNames := int(data[0])<<8 | int(data[1])
   361				d := data[2:]
   362				for i := 0; i < numNames; i++ {
   363					if len(d) < 3 {
   364						return false
   365					}
   366					nameType := d[0]
   367					nameLen := int(d[1])<<8 | int(d[2])
   368					d = d[3:]
   369					if len(d) < nameLen {
   370						return false
   371					}
   372					if nameType == 0 {
   373						m.serverName = string(d[0:nameLen])
   374						break
   375					}
   376					d = d[nameLen:]
   377				}
   378			case extensionNextProtoNeg:
   379				if length > 0 {
   380					return false
   381				}
   382				m.nextProtoNeg = true
   383			case extensionStatusRequest:
   384				m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   385			case extensionSupportedCurves:
   386				// http://tools.ietf.org/html/rfc4492#section-5.5.1
   387				if length < 2 {
   388					return false
   389				}
   390				l := int(data[0])<<8 | int(data[1])
   391				if l%2 == 1 || length != l+2 {
   392					return false
   393				}
   394				numCurves := l / 2
   395				m.supportedCurves = make([]CurveID, numCurves)
   396				d := data[2:]
   397				for i := 0; i < numCurves; i++ {
   398					m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
   399					d = d[2:]
   400				}
   401			case extensionSupportedPoints:
   402				// http://tools.ietf.org/html/rfc4492#section-5.5.2
   403				if length < 1 {
   404					return false
   405				}
   406				l := int(data[0])
   407				if length != l+1 {
   408					return false
   409				}
   410				m.supportedPoints = make([]uint8, l)
   411				copy(m.supportedPoints, data[1:])
   412			case extensionSessionTicket:
   413				// http://tools.ietf.org/html/rfc5077#section-3.2
   414				m.ticketSupported = true
   415				m.sessionTicket = data[:length]
   416			case extensionSignatureAlgorithms:
   417				// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   418				if length < 2 || length&1 != 0 {
   419					return false
   420				}
   421				l := int(data[0])<<8 | int(data[1])
   422				if l != length-2 {
   423					return false
   424				}
   425				n := l / 2
   426				d := data[2:]
   427				m.signatureAndHashes = make([]signatureAndHash, n)
   428				for i := range m.signatureAndHashes {
   429					m.signatureAndHashes[i].hash = d[0]
   430					m.signatureAndHashes[i].signature = d[1]
   431					d = d[2:]
   432				}
   433			case extensionRenegotiationInfo + 1:
   434				if length != 1 || data[0] != 0 {
   435					return false
   436				}
   437				m.secureRenegotiation = true
   438			case extensionALPN:
   439				if length < 2 {
   440					return false
   441				}
   442				l := int(data[0])<<8 | int(data[1])
   443				if l != length-2 {
   444					return false
   445				}
   446				d := data[2:length]
   447				for len(d) != 0 {
   448					stringLen := int(d[0])
   449					d = d[1:]
   450					if stringLen == 0 || stringLen > len(d) {
   451						return false
   452					}
   453					m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
   454					d = d[stringLen:]
   455				}
   456			}
   457			data = data[length:]
   458		}
   459	
   460		return true
   461	}
   462	
   463	type serverHelloMsg struct {
   464		raw                 []byte
   465		vers                uint16
   466		random              []byte
   467		sessionId           []byte
   468		cipherSuite         uint16
   469		compressionMethod   uint8
   470		nextProtoNeg        bool
   471		nextProtos          []string
   472		ocspStapling        bool
   473		ticketSupported     bool
   474		secureRenegotiation bool
   475		alpnProtocol        string
   476	}
   477	
   478	func (m *serverHelloMsg) equal(i interface{}) bool {
   479		m1, ok := i.(*serverHelloMsg)
   480		if !ok {
   481			return false
   482		}
   483	
   484		return bytes.Equal(m.raw, m1.raw) &&
   485			m.vers == m1.vers &&
   486			bytes.Equal(m.random, m1.random) &&
   487			bytes.Equal(m.sessionId, m1.sessionId) &&
   488			m.cipherSuite == m1.cipherSuite &&
   489			m.compressionMethod == m1.compressionMethod &&
   490			m.nextProtoNeg == m1.nextProtoNeg &&
   491			eqStrings(m.nextProtos, m1.nextProtos) &&
   492			m.ocspStapling == m1.ocspStapling &&
   493			m.ticketSupported == m1.ticketSupported &&
   494			m.secureRenegotiation == m1.secureRenegotiation &&
   495			m.alpnProtocol == m1.alpnProtocol
   496	}
   497	
   498	func (m *serverHelloMsg) marshal() []byte {
   499		if m.raw != nil {
   500			return m.raw
   501		}
   502	
   503		length := 38 + len(m.sessionId)
   504		numExtensions := 0
   505		extensionsLength := 0
   506	
   507		nextProtoLen := 0
   508		if m.nextProtoNeg {
   509			numExtensions++
   510			for _, v := range m.nextProtos {
   511				nextProtoLen += len(v)
   512			}
   513			nextProtoLen += len(m.nextProtos)
   514			extensionsLength += nextProtoLen
   515		}
   516		if m.ocspStapling {
   517			numExtensions++
   518		}
   519		if m.ticketSupported {
   520			numExtensions++
   521		}
   522		if m.secureRenegotiation {
   523			extensionsLength += 1
   524			numExtensions++
   525		}
   526		if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   527			if alpnLen >= 256 {
   528				panic("invalid ALPN protocol")
   529			}
   530			extensionsLength += 2 + 1 + alpnLen
   531			numExtensions++
   532		}
   533	
   534		if numExtensions > 0 {
   535			extensionsLength += 4 * numExtensions
   536			length += 2 + extensionsLength
   537		}
   538	
   539		x := make([]byte, 4+length)
   540		x[0] = typeServerHello
   541		x[1] = uint8(length >> 16)
   542		x[2] = uint8(length >> 8)
   543		x[3] = uint8(length)
   544		x[4] = uint8(m.vers >> 8)
   545		x[5] = uint8(m.vers)
   546		copy(x[6:38], m.random)
   547		x[38] = uint8(len(m.sessionId))
   548		copy(x[39:39+len(m.sessionId)], m.sessionId)
   549		z := x[39+len(m.sessionId):]
   550		z[0] = uint8(m.cipherSuite >> 8)
   551		z[1] = uint8(m.cipherSuite)
   552		z[2] = uint8(m.compressionMethod)
   553	
   554		z = z[3:]
   555		if numExtensions > 0 {
   556			z[0] = byte(extensionsLength >> 8)
   557			z[1] = byte(extensionsLength)
   558			z = z[2:]
   559		}
   560		if m.nextProtoNeg {
   561			z[0] = byte(extensionNextProtoNeg >> 8)
   562			z[1] = byte(extensionNextProtoNeg & 0xff)
   563			z[2] = byte(nextProtoLen >> 8)
   564			z[3] = byte(nextProtoLen)
   565			z = z[4:]
   566	
   567			for _, v := range m.nextProtos {
   568				l := len(v)
   569				if l > 255 {
   570					l = 255
   571				}
   572				z[0] = byte(l)
   573				copy(z[1:], []byte(v[0:l]))
   574				z = z[1+l:]
   575			}
   576		}
   577		if m.ocspStapling {
   578			z[0] = byte(extensionStatusRequest >> 8)
   579			z[1] = byte(extensionStatusRequest)
   580			z = z[4:]
   581		}
   582		if m.ticketSupported {
   583			z[0] = byte(extensionSessionTicket >> 8)
   584			z[1] = byte(extensionSessionTicket)
   585			z = z[4:]
   586		}
   587		if m.secureRenegotiation {
   588			z[0] = byte(extensionRenegotiationInfo >> 8)
   589			z[1] = byte(extensionRenegotiationInfo & 0xff)
   590			z[2] = 0
   591			z[3] = 1
   592			z = z[5:]
   593		}
   594		if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
   595			z[0] = byte(extensionALPN >> 8)
   596			z[1] = byte(extensionALPN & 0xff)
   597			l := 2 + 1 + alpnLen
   598			z[2] = byte(l >> 8)
   599			z[3] = byte(l)
   600			l -= 2
   601			z[4] = byte(l >> 8)
   602			z[5] = byte(l)
   603			l -= 1
   604			z[6] = byte(l)
   605			copy(z[7:], []byte(m.alpnProtocol))
   606			z = z[7+alpnLen:]
   607		}
   608	
   609		m.raw = x
   610	
   611		return x
   612	}
   613	
   614	func (m *serverHelloMsg) unmarshal(data []byte) bool {
   615		if len(data) < 42 {
   616			return false
   617		}
   618		m.raw = data
   619		m.vers = uint16(data[4])<<8 | uint16(data[5])
   620		m.random = data[6:38]
   621		sessionIdLen := int(data[38])
   622		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   623			return false
   624		}
   625		m.sessionId = data[39 : 39+sessionIdLen]
   626		data = data[39+sessionIdLen:]
   627		if len(data) < 3 {
   628			return false
   629		}
   630		m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   631		m.compressionMethod = data[2]
   632		data = data[3:]
   633	
   634		m.nextProtoNeg = false
   635		m.nextProtos = nil
   636		m.ocspStapling = false
   637		m.ticketSupported = false
   638		m.alpnProtocol = ""
   639	
   640		if len(data) == 0 {
   641			// ServerHello is optionally followed by extension data
   642			return true
   643		}
   644		if len(data) < 2 {
   645			return false
   646		}
   647	
   648		extensionsLength := int(data[0])<<8 | int(data[1])
   649		data = data[2:]
   650		if len(data) != extensionsLength {
   651			return false
   652		}
   653	
   654		for len(data) != 0 {
   655			if len(data) < 4 {
   656				return false
   657			}
   658			extension := uint16(data[0])<<8 | uint16(data[1])
   659			length := int(data[2])<<8 | int(data[3])
   660			data = data[4:]
   661			if len(data) < length {
   662				return false
   663			}
   664	
   665			switch extension {
   666			case extensionNextProtoNeg:
   667				m.nextProtoNeg = true
   668				d := data[:length]
   669				for len(d) > 0 {
   670					l := int(d[0])
   671					d = d[1:]
   672					if l == 0 || l > len(d) {
   673						return false
   674					}
   675					m.nextProtos = append(m.nextProtos, string(d[:l]))
   676					d = d[l:]
   677				}
   678			case extensionStatusRequest:
   679				if length > 0 {
   680					return false
   681				}
   682				m.ocspStapling = true
   683			case extensionSessionTicket:
   684				if length > 0 {
   685					return false
   686				}
   687				m.ticketSupported = true
   688			case extensionRenegotiationInfo:
   689				if length != 1 || data[0] != 0 {
   690					return false
   691				}
   692				m.secureRenegotiation = true
   693			case extensionALPN:
   694				d := data[:length]
   695				if len(d) < 3 {
   696					return false
   697				}
   698				l := int(d[0])<<8 | int(d[1])
   699				if l != len(d)-2 {
   700					return false
   701				}
   702				d = d[2:]
   703				l = int(d[0])
   704				if l != len(d)-1 {
   705					return false
   706				}
   707				d = d[1:]
   708				m.alpnProtocol = string(d)
   709			}
   710			data = data[length:]
   711		}
   712	
   713		return true
   714	}
   715	
   716	type certificateMsg struct {
   717		raw          []byte
   718		certificates [][]byte
   719	}
   720	
   721	func (m *certificateMsg) equal(i interface{}) bool {
   722		m1, ok := i.(*certificateMsg)
   723		if !ok {
   724			return false
   725		}
   726	
   727		return bytes.Equal(m.raw, m1.raw) &&
   728			eqByteSlices(m.certificates, m1.certificates)
   729	}
   730	
   731	func (m *certificateMsg) marshal() (x []byte) {
   732		if m.raw != nil {
   733			return m.raw
   734		}
   735	
   736		var i int
   737		for _, slice := range m.certificates {
   738			i += len(slice)
   739		}
   740	
   741		length := 3 + 3*len(m.certificates) + i
   742		x = make([]byte, 4+length)
   743		x[0] = typeCertificate
   744		x[1] = uint8(length >> 16)
   745		x[2] = uint8(length >> 8)
   746		x[3] = uint8(length)
   747	
   748		certificateOctets := length - 3
   749		x[4] = uint8(certificateOctets >> 16)
   750		x[5] = uint8(certificateOctets >> 8)
   751		x[6] = uint8(certificateOctets)
   752	
   753		y := x[7:]
   754		for _, slice := range m.certificates {
   755			y[0] = uint8(len(slice) >> 16)
   756			y[1] = uint8(len(slice) >> 8)
   757			y[2] = uint8(len(slice))
   758			copy(y[3:], slice)
   759			y = y[3+len(slice):]
   760		}
   761	
   762		m.raw = x
   763		return
   764	}
   765	
   766	func (m *certificateMsg) unmarshal(data []byte) bool {
   767		if len(data) < 7 {
   768			return false
   769		}
   770	
   771		m.raw = data
   772		certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
   773		if uint32(len(data)) != certsLen+7 {
   774			return false
   775		}
   776	
   777		numCerts := 0
   778		d := data[7:]
   779		for certsLen > 0 {
   780			if len(d) < 4 {
   781				return false
   782			}
   783			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   784			if uint32(len(d)) < 3+certLen {
   785				return false
   786			}
   787			d = d[3+certLen:]
   788			certsLen -= 3 + certLen
   789			numCerts++
   790		}
   791	
   792		m.certificates = make([][]byte, numCerts)
   793		d = data[7:]
   794		for i := 0; i < numCerts; i++ {
   795			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   796			m.certificates[i] = d[3 : 3+certLen]
   797			d = d[3+certLen:]
   798		}
   799	
   800		return true
   801	}
   802	
   803	type serverKeyExchangeMsg struct {
   804		raw []byte
   805		key []byte
   806	}
   807	
   808	func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   809		m1, ok := i.(*serverKeyExchangeMsg)
   810		if !ok {
   811			return false
   812		}
   813	
   814		return bytes.Equal(m.raw, m1.raw) &&
   815			bytes.Equal(m.key, m1.key)
   816	}
   817	
   818	func (m *serverKeyExchangeMsg) marshal() []byte {
   819		if m.raw != nil {
   820			return m.raw
   821		}
   822		length := len(m.key)
   823		x := make([]byte, length+4)
   824		x[0] = typeServerKeyExchange
   825		x[1] = uint8(length >> 16)
   826		x[2] = uint8(length >> 8)
   827		x[3] = uint8(length)
   828		copy(x[4:], m.key)
   829	
   830		m.raw = x
   831		return x
   832	}
   833	
   834	func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   835		m.raw = data
   836		if len(data) < 4 {
   837			return false
   838		}
   839		m.key = data[4:]
   840		return true
   841	}
   842	
   843	type certificateStatusMsg struct {
   844		raw        []byte
   845		statusType uint8
   846		response   []byte
   847	}
   848	
   849	func (m *certificateStatusMsg) equal(i interface{}) bool {
   850		m1, ok := i.(*certificateStatusMsg)
   851		if !ok {
   852			return false
   853		}
   854	
   855		return bytes.Equal(m.raw, m1.raw) &&
   856			m.statusType == m1.statusType &&
   857			bytes.Equal(m.response, m1.response)
   858	}
   859	
   860	func (m *certificateStatusMsg) marshal() []byte {
   861		if m.raw != nil {
   862			return m.raw
   863		}
   864	
   865		var x []byte
   866		if m.statusType == statusTypeOCSP {
   867			x = make([]byte, 4+4+len(m.response))
   868			x[0] = typeCertificateStatus
   869			l := len(m.response) + 4
   870			x[1] = byte(l >> 16)
   871			x[2] = byte(l >> 8)
   872			x[3] = byte(l)
   873			x[4] = statusTypeOCSP
   874	
   875			l -= 4
   876			x[5] = byte(l >> 16)
   877			x[6] = byte(l >> 8)
   878			x[7] = byte(l)
   879			copy(x[8:], m.response)
   880		} else {
   881			x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
   882		}
   883	
   884		m.raw = x
   885		return x
   886	}
   887	
   888	func (m *certificateStatusMsg) unmarshal(data []byte) bool {
   889		m.raw = data
   890		if len(data) < 5 {
   891			return false
   892		}
   893		m.statusType = data[4]
   894	
   895		m.response = nil
   896		if m.statusType == statusTypeOCSP {
   897			if len(data) < 8 {
   898				return false
   899			}
   900			respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   901			if uint32(len(data)) != 4+4+respLen {
   902				return false
   903			}
   904			m.response = data[8:]
   905		}
   906		return true
   907	}
   908	
   909	type serverHelloDoneMsg struct{}
   910	
   911	func (m *serverHelloDoneMsg) equal(i interface{}) bool {
   912		_, ok := i.(*serverHelloDoneMsg)
   913		return ok
   914	}
   915	
   916	func (m *serverHelloDoneMsg) marshal() []byte {
   917		x := make([]byte, 4)
   918		x[0] = typeServerHelloDone
   919		return x
   920	}
   921	
   922	func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
   923		return len(data) == 4
   924	}
   925	
   926	type clientKeyExchangeMsg struct {
   927		raw        []byte
   928		ciphertext []byte
   929	}
   930	
   931	func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
   932		m1, ok := i.(*clientKeyExchangeMsg)
   933		if !ok {
   934			return false
   935		}
   936	
   937		return bytes.Equal(m.raw, m1.raw) &&
   938			bytes.Equal(m.ciphertext, m1.ciphertext)
   939	}
   940	
   941	func (m *clientKeyExchangeMsg) marshal() []byte {
   942		if m.raw != nil {
   943			return m.raw
   944		}
   945		length := len(m.ciphertext)
   946		x := make([]byte, length+4)
   947		x[0] = typeClientKeyExchange
   948		x[1] = uint8(length >> 16)
   949		x[2] = uint8(length >> 8)
   950		x[3] = uint8(length)
   951		copy(x[4:], m.ciphertext)
   952	
   953		m.raw = x
   954		return x
   955	}
   956	
   957	func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
   958		m.raw = data
   959		if len(data) < 4 {
   960			return false
   961		}
   962		l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   963		if l != len(data)-4 {
   964			return false
   965		}
   966		m.ciphertext = data[4:]
   967		return true
   968	}
   969	
   970	type finishedMsg struct {
   971		raw        []byte
   972		verifyData []byte
   973	}
   974	
   975	func (m *finishedMsg) equal(i interface{}) bool {
   976		m1, ok := i.(*finishedMsg)
   977		if !ok {
   978			return false
   979		}
   980	
   981		return bytes.Equal(m.raw, m1.raw) &&
   982			bytes.Equal(m.verifyData, m1.verifyData)
   983	}
   984	
   985	func (m *finishedMsg) marshal() (x []byte) {
   986		if m.raw != nil {
   987			return m.raw
   988		}
   989	
   990		x = make([]byte, 4+len(m.verifyData))
   991		x[0] = typeFinished
   992		x[3] = byte(len(m.verifyData))
   993		copy(x[4:], m.verifyData)
   994		m.raw = x
   995		return
   996	}
   997	
   998	func (m *finishedMsg) unmarshal(data []byte) bool {
   999		m.raw = data
  1000		if len(data) < 4 {
  1001			return false
  1002		}
  1003		m.verifyData = data[4:]
  1004		return true
  1005	}
  1006	
  1007	type nextProtoMsg struct {
  1008		raw   []byte
  1009		proto string
  1010	}
  1011	
  1012	func (m *nextProtoMsg) equal(i interface{}) bool {
  1013		m1, ok := i.(*nextProtoMsg)
  1014		if !ok {
  1015			return false
  1016		}
  1017	
  1018		return bytes.Equal(m.raw, m1.raw) &&
  1019			m.proto == m1.proto
  1020	}
  1021	
  1022	func (m *nextProtoMsg) marshal() []byte {
  1023		if m.raw != nil {
  1024			return m.raw
  1025		}
  1026		l := len(m.proto)
  1027		if l > 255 {
  1028			l = 255
  1029		}
  1030	
  1031		padding := 32 - (l+2)%32
  1032		length := l + padding + 2
  1033		x := make([]byte, length+4)
  1034		x[0] = typeNextProtocol
  1035		x[1] = uint8(length >> 16)
  1036		x[2] = uint8(length >> 8)
  1037		x[3] = uint8(length)
  1038	
  1039		y := x[4:]
  1040		y[0] = byte(l)
  1041		copy(y[1:], []byte(m.proto[0:l]))
  1042		y = y[1+l:]
  1043		y[0] = byte(padding)
  1044	
  1045		m.raw = x
  1046	
  1047		return x
  1048	}
  1049	
  1050	func (m *nextProtoMsg) unmarshal(data []byte) bool {
  1051		m.raw = data
  1052	
  1053		if len(data) < 5 {
  1054			return false
  1055		}
  1056		data = data[4:]
  1057		protoLen := int(data[0])
  1058		data = data[1:]
  1059		if len(data) < protoLen {
  1060			return false
  1061		}
  1062		m.proto = string(data[0:protoLen])
  1063		data = data[protoLen:]
  1064	
  1065		if len(data) < 1 {
  1066			return false
  1067		}
  1068		paddingLen := int(data[0])
  1069		data = data[1:]
  1070		if len(data) != paddingLen {
  1071			return false
  1072		}
  1073	
  1074		return true
  1075	}
  1076	
  1077	type certificateRequestMsg struct {
  1078		raw []byte
  1079		// hasSignatureAndHash indicates whether this message includes a list
  1080		// of signature and hash functions. This change was introduced with TLS
  1081		// 1.2.
  1082		hasSignatureAndHash bool
  1083	
  1084		certificateTypes       []byte
  1085		signatureAndHashes     []signatureAndHash
  1086		certificateAuthorities [][]byte
  1087	}
  1088	
  1089	func (m *certificateRequestMsg) equal(i interface{}) bool {
  1090		m1, ok := i.(*certificateRequestMsg)
  1091		if !ok {
  1092			return false
  1093		}
  1094	
  1095		return bytes.Equal(m.raw, m1.raw) &&
  1096			bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
  1097			eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
  1098			eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
  1099	}
  1100	
  1101	func (m *certificateRequestMsg) marshal() (x []byte) {
  1102		if m.raw != nil {
  1103			return m.raw
  1104		}
  1105	
  1106		// See http://tools.ietf.org/html/rfc4346#section-7.4.4
  1107		length := 1 + len(m.certificateTypes) + 2
  1108		casLength := 0
  1109		for _, ca := range m.certificateAuthorities {
  1110			casLength += 2 + len(ca)
  1111		}
  1112		length += casLength
  1113	
  1114		if m.hasSignatureAndHash {
  1115			length += 2 + 2*len(m.signatureAndHashes)
  1116		}
  1117	
  1118		x = make([]byte, 4+length)
  1119		x[0] = typeCertificateRequest
  1120		x[1] = uint8(length >> 16)
  1121		x[2] = uint8(length >> 8)
  1122		x[3] = uint8(length)
  1123	
  1124		x[4] = uint8(len(m.certificateTypes))
  1125	
  1126		copy(x[5:], m.certificateTypes)
  1127		y := x[5+len(m.certificateTypes):]
  1128	
  1129		if m.hasSignatureAndHash {
  1130			n := len(m.signatureAndHashes) * 2
  1131			y[0] = uint8(n >> 8)
  1132			y[1] = uint8(n)
  1133			y = y[2:]
  1134			for _, sigAndHash := range m.signatureAndHashes {
  1135				y[0] = sigAndHash.hash
  1136				y[1] = sigAndHash.signature
  1137				y = y[2:]
  1138			}
  1139		}
  1140	
  1141		y[0] = uint8(casLength >> 8)
  1142		y[1] = uint8(casLength)
  1143		y = y[2:]
  1144		for _, ca := range m.certificateAuthorities {
  1145			y[0] = uint8(len(ca) >> 8)
  1146			y[1] = uint8(len(ca))
  1147			y = y[2:]
  1148			copy(y, ca)
  1149			y = y[len(ca):]
  1150		}
  1151	
  1152		m.raw = x
  1153		return
  1154	}
  1155	
  1156	func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1157		m.raw = data
  1158	
  1159		if len(data) < 5 {
  1160			return false
  1161		}
  1162	
  1163		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1164		if uint32(len(data))-4 != length {
  1165			return false
  1166		}
  1167	
  1168		numCertTypes := int(data[4])
  1169		data = data[5:]
  1170		if numCertTypes == 0 || len(data) <= numCertTypes {
  1171			return false
  1172		}
  1173	
  1174		m.certificateTypes = make([]byte, numCertTypes)
  1175		if copy(m.certificateTypes, data) != numCertTypes {
  1176			return false
  1177		}
  1178	
  1179		data = data[numCertTypes:]
  1180	
  1181		if m.hasSignatureAndHash {
  1182			if len(data) < 2 {
  1183				return false
  1184			}
  1185			sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1186			data = data[2:]
  1187			if sigAndHashLen&1 != 0 {
  1188				return false
  1189			}
  1190			if len(data) < int(sigAndHashLen) {
  1191				return false
  1192			}
  1193			numSigAndHash := sigAndHashLen / 2
  1194			m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
  1195			for i := range m.signatureAndHashes {
  1196				m.signatureAndHashes[i].hash = data[0]
  1197				m.signatureAndHashes[i].signature = data[1]
  1198				data = data[2:]
  1199			}
  1200		}
  1201	
  1202		if len(data) < 2 {
  1203			return false
  1204		}
  1205		casLength := uint16(data[0])<<8 | uint16(data[1])
  1206		data = data[2:]
  1207		if len(data) < int(casLength) {
  1208			return false
  1209		}
  1210		cas := make([]byte, casLength)
  1211		copy(cas, data)
  1212		data = data[casLength:]
  1213	
  1214		m.certificateAuthorities = nil
  1215		for len(cas) > 0 {
  1216			if len(cas) < 2 {
  1217				return false
  1218			}
  1219			caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1220			cas = cas[2:]
  1221	
  1222			if len(cas) < int(caLen) {
  1223				return false
  1224			}
  1225	
  1226			m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1227			cas = cas[caLen:]
  1228		}
  1229		if len(data) > 0 {
  1230			return false
  1231		}
  1232	
  1233		return true
  1234	}
  1235	
// certificateVerifyMsg holds a CertificateVerify handshake message (see
// http://tools.ietf.org/html/rfc4346#section-7.4.8).
//
// raw is the encoded message: unmarshal stores its input there, and marshal
// returns it unchanged whenever it is non-nil. hasSignatureAndHash is an input
// to both marshal and unmarshal rather than something decoded from the bytes.
// When it is set, the body begins with signatureAndHash (a hash byte, then a
// signature byte). signature follows with a 2-byte length prefix; after
// unmarshal it is a sub-slice of raw rather than a copy.
type certificateVerifyMsg struct {
	raw                 []byte
	hasSignatureAndHash bool
	signatureAndHash    signatureAndHash
	signature           []byte
}
  1242	
  1243	func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1244		m1, ok := i.(*certificateVerifyMsg)
  1245		if !ok {
  1246			return false
  1247		}
  1248	
  1249		return bytes.Equal(m.raw, m1.raw) &&
  1250			m.hasSignatureAndHash == m1.hasSignatureAndHash &&
  1251			m.signatureAndHash.hash == m1.signatureAndHash.hash &&
  1252			m.signatureAndHash.signature == m1.signatureAndHash.signature &&
  1253			bytes.Equal(m.signature, m1.signature)
  1254	}
  1255	
  1256	func (m *certificateVerifyMsg) marshal() (x []byte) {
  1257		if m.raw != nil {
  1258			return m.raw
  1259		}
  1260	
  1261		// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1262		siglength := len(m.signature)
  1263		length := 2 + siglength
  1264		if m.hasSignatureAndHash {
  1265			length += 2
  1266		}
  1267		x = make([]byte, 4+length)
  1268		x[0] = typeCertificateVerify
  1269		x[1] = uint8(length >> 16)
  1270		x[2] = uint8(length >> 8)
  1271		x[3] = uint8(length)
  1272		y := x[4:]
  1273		if m.hasSignatureAndHash {
  1274			y[0] = m.signatureAndHash.hash
  1275			y[1] = m.signatureAndHash.signature
  1276			y = y[2:]
  1277		}
  1278		y[0] = uint8(siglength >> 8)
  1279		y[1] = uint8(siglength)
  1280		copy(y[2:], m.signature)
  1281	
  1282		m.raw = x
  1283	
  1284		return
  1285	}
  1286	
  1287	func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1288		m.raw = data
  1289	
  1290		if len(data) < 6 {
  1291			return false
  1292		}
  1293	
  1294		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1295		if uint32(len(data))-4 != length {
  1296			return false
  1297		}
  1298	
  1299		data = data[4:]
  1300		if m.hasSignatureAndHash {
  1301			m.signatureAndHash.hash = data[0]
  1302			m.signatureAndHash.signature = data[1]
  1303			data = data[2:]
  1304		}
  1305	
  1306		if len(data) < 2 {
  1307			return false
  1308		}
  1309		siglength := int(data[0])<<8 + int(data[1])
  1310		data = data[2:]
  1311		if len(data) != siglength {
  1312			return false
  1313		}
  1314	
  1315		m.signature = data
  1316	
  1317		return true
  1318	}
  1319	
// newSessionTicketMsg holds a NewSessionTicket handshake message (see
// http://tools.ietf.org/html/rfc5077#section-3.3).
//
// raw is the encoded message: unmarshal stores its input there, and marshal
// returns it unchanged whenever it is non-nil. ticket is the opaque ticket
// carried after the 2-byte ticket length; after unmarshal it is a sub-slice of
// raw rather than a copy. The 4 bytes between the handshake header and the
// ticket length are not represented here. marshal leaves them zero and
// unmarshal skips them.
type newSessionTicketMsg struct {
	raw    []byte
	ticket []byte
}
  1324	
  1325	func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1326		m1, ok := i.(*newSessionTicketMsg)
  1327		if !ok {
  1328			return false
  1329		}
  1330	
  1331		return bytes.Equal(m.raw, m1.raw) &&
  1332			bytes.Equal(m.ticket, m1.ticket)
  1333	}
  1334	
  1335	func (m *newSessionTicketMsg) marshal() (x []byte) {
  1336		if m.raw != nil {
  1337			return m.raw
  1338		}
  1339	
  1340		// See http://tools.ietf.org/html/rfc5077#section-3.3
  1341		ticketLen := len(m.ticket)
  1342		length := 2 + 4 + ticketLen
  1343		x = make([]byte, 4+length)
  1344		x[0] = typeNewSessionTicket
  1345		x[1] = uint8(length >> 16)
  1346		x[2] = uint8(length >> 8)
  1347		x[3] = uint8(length)
  1348		x[8] = uint8(ticketLen >> 8)
  1349		x[9] = uint8(ticketLen)
  1350		copy(x[10:], m.ticket)
  1351	
  1352		m.raw = x
  1353	
  1354		return
  1355	}
  1356	
  1357	func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1358		m.raw = data
  1359	
  1360		if len(data) < 10 {
  1361			return false
  1362		}
  1363	
  1364		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1365		if uint32(len(data))-4 != length {
  1366			return false
  1367		}
  1368	
  1369		ticketLen := int(data[8])<<8 + int(data[9])
  1370		if len(data)-10 != ticketLen {
  1371			return false
  1372		}
  1373	
  1374		m.ticket = data[10:]
  1375	
  1376		return true
  1377	}
  1378	
  1379	func eqUint16s(x, y []uint16) bool {
  1380		if len(x) != len(y) {
  1381			return false
  1382		}
  1383		for i, v := range x {
  1384			if y[i] != v {
  1385				return false
  1386			}
  1387		}
  1388		return true
  1389	}
  1390	
  1391	func eqCurveIDs(x, y []CurveID) bool {
  1392		if len(x) != len(y) {
  1393			return false
  1394		}
  1395		for i, v := range x {
  1396			if y[i] != v {
  1397				return false
  1398			}
  1399		}
  1400		return true
  1401	}
  1402	
  1403	func eqStrings(x, y []string) bool {
  1404		if len(x) != len(y) {
  1405			return false
  1406		}
  1407		for i, v := range x {
  1408			if y[i] != v {
  1409				return false
  1410			}
  1411		}
  1412		return true
  1413	}
  1414	
  1415	func eqByteSlices(x, y [][]byte) bool {
  1416		if len(x) != len(y) {
  1417			return false
  1418		}
  1419		for i, v := range x {
  1420			if !bytes.Equal(v, y[i]) {
  1421				return false
  1422			}
  1423		}
  1424		return true
  1425	}
  1426	
  1427	func eqSignatureAndHashes(x, y []signatureAndHash) bool {
  1428		if len(x) != len(y) {
  1429			return false
  1430		}
  1431		for i, v := range x {
  1432			v2 := y[i]
  1433			if v.hash != v2.hash || v.signature != v2.signature {
  1434				return false
  1435			}
  1436		}
  1437		return true
  1438	}
  1439	

View as plain text