Run Format

Source file src/pkg/crypto/tls/handshake_messages.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package tls
     6	
     7	import "bytes"
     8	
     9	type clientHelloMsg struct {
    10		raw                []byte
    11		vers               uint16
    12		random             []byte
    13		sessionId          []byte
    14		cipherSuites       []uint16
    15		compressionMethods []uint8
    16		nextProtoNeg       bool
    17		serverName         string
    18		ocspStapling       bool
    19		supportedCurves    []uint16
    20		supportedPoints    []uint8
    21		ticketSupported    bool
    22		sessionTicket      []uint8
    23	}
    24	
    25	func (m *clientHelloMsg) equal(i interface{}) bool {
    26		m1, ok := i.(*clientHelloMsg)
    27		if !ok {
    28			return false
    29		}
    30	
    31		return bytes.Equal(m.raw, m1.raw) &&
    32			m.vers == m1.vers &&
    33			bytes.Equal(m.random, m1.random) &&
    34			bytes.Equal(m.sessionId, m1.sessionId) &&
    35			eqUint16s(m.cipherSuites, m1.cipherSuites) &&
    36			bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
    37			m.nextProtoNeg == m1.nextProtoNeg &&
    38			m.serverName == m1.serverName &&
    39			m.ocspStapling == m1.ocspStapling &&
    40			eqUint16s(m.supportedCurves, m1.supportedCurves) &&
    41			bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
    42			m.ticketSupported == m1.ticketSupported &&
    43			bytes.Equal(m.sessionTicket, m1.sessionTicket)
    44	}
    45	
    46	func (m *clientHelloMsg) marshal() []byte {
    47		if m.raw != nil {
    48			return m.raw
    49		}
    50	
    51		length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
    52		numExtensions := 0
    53		extensionsLength := 0
    54		if m.nextProtoNeg {
    55			numExtensions++
    56		}
    57		if m.ocspStapling {
    58			extensionsLength += 1 + 2 + 2
    59			numExtensions++
    60		}
    61		if len(m.serverName) > 0 {
    62			extensionsLength += 5 + len(m.serverName)
    63			numExtensions++
    64		}
    65		if len(m.supportedCurves) > 0 {
    66			extensionsLength += 2 + 2*len(m.supportedCurves)
    67			numExtensions++
    68		}
    69		if len(m.supportedPoints) > 0 {
    70			extensionsLength += 1 + len(m.supportedPoints)
    71			numExtensions++
    72		}
    73		if m.ticketSupported {
    74			extensionsLength += len(m.sessionTicket)
    75			numExtensions++
    76		}
    77		if numExtensions > 0 {
    78			extensionsLength += 4 * numExtensions
    79			length += 2 + extensionsLength
    80		}
    81	
    82		x := make([]byte, 4+length)
    83		x[0] = typeClientHello
    84		x[1] = uint8(length >> 16)
    85		x[2] = uint8(length >> 8)
    86		x[3] = uint8(length)
    87		x[4] = uint8(m.vers >> 8)
    88		x[5] = uint8(m.vers)
    89		copy(x[6:38], m.random)
    90		x[38] = uint8(len(m.sessionId))
    91		copy(x[39:39+len(m.sessionId)], m.sessionId)
    92		y := x[39+len(m.sessionId):]
    93		y[0] = uint8(len(m.cipherSuites) >> 7)
    94		y[1] = uint8(len(m.cipherSuites) << 1)
    95		for i, suite := range m.cipherSuites {
    96			y[2+i*2] = uint8(suite >> 8)
    97			y[3+i*2] = uint8(suite)
    98		}
    99		z := y[2+len(m.cipherSuites)*2:]
   100		z[0] = uint8(len(m.compressionMethods))
   101		copy(z[1:], m.compressionMethods)
   102	
   103		z = z[1+len(m.compressionMethods):]
   104		if numExtensions > 0 {
   105			z[0] = byte(extensionsLength >> 8)
   106			z[1] = byte(extensionsLength)
   107			z = z[2:]
   108		}
   109		if m.nextProtoNeg {
   110			z[0] = byte(extensionNextProtoNeg >> 8)
   111			z[1] = byte(extensionNextProtoNeg)
   112			// The length is always 0
   113			z = z[4:]
   114		}
   115		if len(m.serverName) > 0 {
   116			z[0] = byte(extensionServerName >> 8)
   117			z[1] = byte(extensionServerName)
   118			l := len(m.serverName) + 5
   119			z[2] = byte(l >> 8)
   120			z[3] = byte(l)
   121			z = z[4:]
   122	
   123			// RFC 3546, section 3.1
   124			//
   125			// struct {
   126			//     NameType name_type;
   127			//     select (name_type) {
   128			//         case host_name: HostName;
   129			//     } name;
   130			// } ServerName;
   131			//
   132			// enum {
   133			//     host_name(0), (255)
   134			// } NameType;
   135			//
   136			// opaque HostName<1..2^16-1>;
   137			//
   138			// struct {
   139			//     ServerName server_name_list<1..2^16-1>
   140			// } ServerNameList;
   141	
   142			z[0] = byte((len(m.serverName) + 3) >> 8)
   143			z[1] = byte(len(m.serverName) + 3)
   144			z[3] = byte(len(m.serverName) >> 8)
   145			z[4] = byte(len(m.serverName))
   146			copy(z[5:], []byte(m.serverName))
   147			z = z[l:]
   148		}
   149		if m.ocspStapling {
   150			// RFC 4366, section 3.6
   151			z[0] = byte(extensionStatusRequest >> 8)
   152			z[1] = byte(extensionStatusRequest)
   153			z[2] = 0
   154			z[3] = 5
   155			z[4] = 1 // OCSP type
   156			// Two zero valued uint16s for the two lengths.
   157			z = z[9:]
   158		}
   159		if len(m.supportedCurves) > 0 {
   160			// http://tools.ietf.org/html/rfc4492#section-5.5.1
   161			z[0] = byte(extensionSupportedCurves >> 8)
   162			z[1] = byte(extensionSupportedCurves)
   163			l := 2 + 2*len(m.supportedCurves)
   164			z[2] = byte(l >> 8)
   165			z[3] = byte(l)
   166			l -= 2
   167			z[4] = byte(l >> 8)
   168			z[5] = byte(l)
   169			z = z[6:]
   170			for _, curve := range m.supportedCurves {
   171				z[0] = byte(curve >> 8)
   172				z[1] = byte(curve)
   173				z = z[2:]
   174			}
   175		}
   176		if len(m.supportedPoints) > 0 {
   177			// http://tools.ietf.org/html/rfc4492#section-5.5.2
   178			z[0] = byte(extensionSupportedPoints >> 8)
   179			z[1] = byte(extensionSupportedPoints)
   180			l := 1 + len(m.supportedPoints)
   181			z[2] = byte(l >> 8)
   182			z[3] = byte(l)
   183			l--
   184			z[4] = byte(l)
   185			z = z[5:]
   186			for _, pointFormat := range m.supportedPoints {
   187				z[0] = byte(pointFormat)
   188				z = z[1:]
   189			}
   190		}
   191		if m.ticketSupported {
   192			// http://tools.ietf.org/html/rfc5077#section-3.2
   193			z[0] = byte(extensionSessionTicket >> 8)
   194			z[1] = byte(extensionSessionTicket)
   195			l := len(m.sessionTicket)
   196			z[2] = byte(l >> 8)
   197			z[3] = byte(l)
   198			z = z[4:]
   199			copy(z, m.sessionTicket)
   200			z = z[len(m.sessionTicket):]
   201		}
   202	
   203		m.raw = x
   204	
   205		return x
   206	}
   207	
   208	func (m *clientHelloMsg) unmarshal(data []byte) bool {
   209		if len(data) < 42 {
   210			return false
   211		}
   212		m.raw = data
   213		m.vers = uint16(data[4])<<8 | uint16(data[5])
   214		m.random = data[6:38]
   215		sessionIdLen := int(data[38])
   216		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   217			return false
   218		}
   219		m.sessionId = data[39 : 39+sessionIdLen]
   220		data = data[39+sessionIdLen:]
   221		if len(data) < 2 {
   222			return false
   223		}
   224		// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
   225		// they are uint16s, the number must be even.
   226		cipherSuiteLen := int(data[0])<<8 | int(data[1])
   227		if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
   228			return false
   229		}
   230		numCipherSuites := cipherSuiteLen / 2
   231		m.cipherSuites = make([]uint16, numCipherSuites)
   232		for i := 0; i < numCipherSuites; i++ {
   233			m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
   234		}
   235		data = data[2+cipherSuiteLen:]
   236		if len(data) < 1 {
   237			return false
   238		}
   239		compressionMethodsLen := int(data[0])
   240		if len(data) < 1+compressionMethodsLen {
   241			return false
   242		}
   243		m.compressionMethods = data[1 : 1+compressionMethodsLen]
   244	
   245		data = data[1+compressionMethodsLen:]
   246	
   247		m.nextProtoNeg = false
   248		m.serverName = ""
   249		m.ocspStapling = false
   250		m.ticketSupported = false
   251		m.sessionTicket = nil
   252	
   253		if len(data) == 0 {
   254			// ClientHello is optionally followed by extension data
   255			return true
   256		}
   257		if len(data) < 2 {
   258			return false
   259		}
   260	
   261		extensionsLength := int(data[0])<<8 | int(data[1])
   262		data = data[2:]
   263		if extensionsLength != len(data) {
   264			return false
   265		}
   266	
   267		for len(data) != 0 {
   268			if len(data) < 4 {
   269				return false
   270			}
   271			extension := uint16(data[0])<<8 | uint16(data[1])
   272			length := int(data[2])<<8 | int(data[3])
   273			data = data[4:]
   274			if len(data) < length {
   275				return false
   276			}
   277	
   278			switch extension {
   279			case extensionServerName:
   280				if length < 2 {
   281					return false
   282				}
   283				numNames := int(data[0])<<8 | int(data[1])
   284				d := data[2:]
   285				for i := 0; i < numNames; i++ {
   286					if len(d) < 3 {
   287						return false
   288					}
   289					nameType := d[0]
   290					nameLen := int(d[1])<<8 | int(d[2])
   291					d = d[3:]
   292					if len(d) < nameLen {
   293						return false
   294					}
   295					if nameType == 0 {
   296						m.serverName = string(d[0:nameLen])
   297						break
   298					}
   299					d = d[nameLen:]
   300				}
   301			case extensionNextProtoNeg:
   302				if length > 0 {
   303					return false
   304				}
   305				m.nextProtoNeg = true
   306			case extensionStatusRequest:
   307				m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
   308			case extensionSupportedCurves:
   309				// http://tools.ietf.org/html/rfc4492#section-5.5.1
   310				if length < 2 {
   311					return false
   312				}
   313				l := int(data[0])<<8 | int(data[1])
   314				if l%2 == 1 || length != l+2 {
   315					return false
   316				}
   317				numCurves := l / 2
   318				m.supportedCurves = make([]uint16, numCurves)
   319				d := data[2:]
   320				for i := 0; i < numCurves; i++ {
   321					m.supportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
   322					d = d[2:]
   323				}
   324			case extensionSupportedPoints:
   325				// http://tools.ietf.org/html/rfc4492#section-5.5.2
   326				if length < 1 {
   327					return false
   328				}
   329				l := int(data[0])
   330				if length != l+1 {
   331					return false
   332				}
   333				m.supportedPoints = make([]uint8, l)
   334				copy(m.supportedPoints, data[1:])
   335			case extensionSessionTicket:
   336				// http://tools.ietf.org/html/rfc5077#section-3.2
   337				m.ticketSupported = true
   338				m.sessionTicket = data[:length]
   339			}
   340			data = data[length:]
   341		}
   342	
   343		return true
   344	}
   345	
   346	type serverHelloMsg struct {
   347		raw               []byte
   348		vers              uint16
   349		random            []byte
   350		sessionId         []byte
   351		cipherSuite       uint16
   352		compressionMethod uint8
   353		nextProtoNeg      bool
   354		nextProtos        []string
   355		ocspStapling      bool
   356		ticketSupported   bool
   357	}
   358	
   359	func (m *serverHelloMsg) equal(i interface{}) bool {
   360		m1, ok := i.(*serverHelloMsg)
   361		if !ok {
   362			return false
   363		}
   364	
   365		return bytes.Equal(m.raw, m1.raw) &&
   366			m.vers == m1.vers &&
   367			bytes.Equal(m.random, m1.random) &&
   368			bytes.Equal(m.sessionId, m1.sessionId) &&
   369			m.cipherSuite == m1.cipherSuite &&
   370			m.compressionMethod == m1.compressionMethod &&
   371			m.nextProtoNeg == m1.nextProtoNeg &&
   372			eqStrings(m.nextProtos, m1.nextProtos) &&
   373			m.ocspStapling == m1.ocspStapling &&
   374			m.ticketSupported == m1.ticketSupported
   375	}
   376	
   377	func (m *serverHelloMsg) marshal() []byte {
   378		if m.raw != nil {
   379			return m.raw
   380		}
   381	
   382		length := 38 + len(m.sessionId)
   383		numExtensions := 0
   384		extensionsLength := 0
   385	
   386		nextProtoLen := 0
   387		if m.nextProtoNeg {
   388			numExtensions++
   389			for _, v := range m.nextProtos {
   390				nextProtoLen += len(v)
   391			}
   392			nextProtoLen += len(m.nextProtos)
   393			extensionsLength += nextProtoLen
   394		}
   395		if m.ocspStapling {
   396			numExtensions++
   397		}
   398		if m.ticketSupported {
   399			numExtensions++
   400		}
   401		if numExtensions > 0 {
   402			extensionsLength += 4 * numExtensions
   403			length += 2 + extensionsLength
   404		}
   405	
   406		x := make([]byte, 4+length)
   407		x[0] = typeServerHello
   408		x[1] = uint8(length >> 16)
   409		x[2] = uint8(length >> 8)
   410		x[3] = uint8(length)
   411		x[4] = uint8(m.vers >> 8)
   412		x[5] = uint8(m.vers)
   413		copy(x[6:38], m.random)
   414		x[38] = uint8(len(m.sessionId))
   415		copy(x[39:39+len(m.sessionId)], m.sessionId)
   416		z := x[39+len(m.sessionId):]
   417		z[0] = uint8(m.cipherSuite >> 8)
   418		z[1] = uint8(m.cipherSuite)
   419		z[2] = uint8(m.compressionMethod)
   420	
   421		z = z[3:]
   422		if numExtensions > 0 {
   423			z[0] = byte(extensionsLength >> 8)
   424			z[1] = byte(extensionsLength)
   425			z = z[2:]
   426		}
   427		if m.nextProtoNeg {
   428			z[0] = byte(extensionNextProtoNeg >> 8)
   429			z[1] = byte(extensionNextProtoNeg)
   430			z[2] = byte(nextProtoLen >> 8)
   431			z[3] = byte(nextProtoLen)
   432			z = z[4:]
   433	
   434			for _, v := range m.nextProtos {
   435				l := len(v)
   436				if l > 255 {
   437					l = 255
   438				}
   439				z[0] = byte(l)
   440				copy(z[1:], []byte(v[0:l]))
   441				z = z[1+l:]
   442			}
   443		}
   444		if m.ocspStapling {
   445			z[0] = byte(extensionStatusRequest >> 8)
   446			z[1] = byte(extensionStatusRequest)
   447			z = z[4:]
   448		}
   449		if m.ticketSupported {
   450			z[0] = byte(extensionSessionTicket >> 8)
   451			z[1] = byte(extensionSessionTicket)
   452			z = z[4:]
   453		}
   454	
   455		m.raw = x
   456	
   457		return x
   458	}
   459	
   460	func (m *serverHelloMsg) unmarshal(data []byte) bool {
   461		if len(data) < 42 {
   462			return false
   463		}
   464		m.raw = data
   465		m.vers = uint16(data[4])<<8 | uint16(data[5])
   466		m.random = data[6:38]
   467		sessionIdLen := int(data[38])
   468		if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
   469			return false
   470		}
   471		m.sessionId = data[39 : 39+sessionIdLen]
   472		data = data[39+sessionIdLen:]
   473		if len(data) < 3 {
   474			return false
   475		}
   476		m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
   477		m.compressionMethod = data[2]
   478		data = data[3:]
   479	
   480		m.nextProtoNeg = false
   481		m.nextProtos = nil
   482		m.ocspStapling = false
   483		m.ticketSupported = false
   484	
   485		if len(data) == 0 {
   486			// ServerHello is optionally followed by extension data
   487			return true
   488		}
   489		if len(data) < 2 {
   490			return false
   491		}
   492	
   493		extensionsLength := int(data[0])<<8 | int(data[1])
   494		data = data[2:]
   495		if len(data) != extensionsLength {
   496			return false
   497		}
   498	
   499		for len(data) != 0 {
   500			if len(data) < 4 {
   501				return false
   502			}
   503			extension := uint16(data[0])<<8 | uint16(data[1])
   504			length := int(data[2])<<8 | int(data[3])
   505			data = data[4:]
   506			if len(data) < length {
   507				return false
   508			}
   509	
   510			switch extension {
   511			case extensionNextProtoNeg:
   512				m.nextProtoNeg = true
   513				d := data[:length]
   514				for len(d) > 0 {
   515					l := int(d[0])
   516					d = d[1:]
   517					if l == 0 || l > len(d) {
   518						return false
   519					}
   520					m.nextProtos = append(m.nextProtos, string(d[:l]))
   521					d = d[l:]
   522				}
   523			case extensionStatusRequest:
   524				if length > 0 {
   525					return false
   526				}
   527				m.ocspStapling = true
   528			case extensionSessionTicket:
   529				if length > 0 {
   530					return false
   531				}
   532				m.ticketSupported = true
   533			}
   534			data = data[length:]
   535		}
   536	
   537		return true
   538	}
   539	
   540	type certificateMsg struct {
   541		raw          []byte
   542		certificates [][]byte
   543	}
   544	
   545	func (m *certificateMsg) equal(i interface{}) bool {
   546		m1, ok := i.(*certificateMsg)
   547		if !ok {
   548			return false
   549		}
   550	
   551		return bytes.Equal(m.raw, m1.raw) &&
   552			eqByteSlices(m.certificates, m1.certificates)
   553	}
   554	
   555	func (m *certificateMsg) marshal() (x []byte) {
   556		if m.raw != nil {
   557			return m.raw
   558		}
   559	
   560		var i int
   561		for _, slice := range m.certificates {
   562			i += len(slice)
   563		}
   564	
   565		length := 3 + 3*len(m.certificates) + i
   566		x = make([]byte, 4+length)
   567		x[0] = typeCertificate
   568		x[1] = uint8(length >> 16)
   569		x[2] = uint8(length >> 8)
   570		x[3] = uint8(length)
   571	
   572		certificateOctets := length - 3
   573		x[4] = uint8(certificateOctets >> 16)
   574		x[5] = uint8(certificateOctets >> 8)
   575		x[6] = uint8(certificateOctets)
   576	
   577		y := x[7:]
   578		for _, slice := range m.certificates {
   579			y[0] = uint8(len(slice) >> 16)
   580			y[1] = uint8(len(slice) >> 8)
   581			y[2] = uint8(len(slice))
   582			copy(y[3:], slice)
   583			y = y[3+len(slice):]
   584		}
   585	
   586		m.raw = x
   587		return
   588	}
   589	
   590	func (m *certificateMsg) unmarshal(data []byte) bool {
   591		if len(data) < 7 {
   592			return false
   593		}
   594	
   595		m.raw = data
   596		certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
   597		if uint32(len(data)) != certsLen+7 {
   598			return false
   599		}
   600	
   601		numCerts := 0
   602		d := data[7:]
   603		for certsLen > 0 {
   604			if len(d) < 4 {
   605				return false
   606			}
   607			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   608			if uint32(len(d)) < 3+certLen {
   609				return false
   610			}
   611			d = d[3+certLen:]
   612			certsLen -= 3 + certLen
   613			numCerts++
   614		}
   615	
   616		m.certificates = make([][]byte, numCerts)
   617		d = data[7:]
   618		for i := 0; i < numCerts; i++ {
   619			certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
   620			m.certificates[i] = d[3 : 3+certLen]
   621			d = d[3+certLen:]
   622		}
   623	
   624		return true
   625	}
   626	
   627	type serverKeyExchangeMsg struct {
   628		raw []byte
   629		key []byte
   630	}
   631	
   632	func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
   633		m1, ok := i.(*serverKeyExchangeMsg)
   634		if !ok {
   635			return false
   636		}
   637	
   638		return bytes.Equal(m.raw, m1.raw) &&
   639			bytes.Equal(m.key, m1.key)
   640	}
   641	
   642	func (m *serverKeyExchangeMsg) marshal() []byte {
   643		if m.raw != nil {
   644			return m.raw
   645		}
   646		length := len(m.key)
   647		x := make([]byte, length+4)
   648		x[0] = typeServerKeyExchange
   649		x[1] = uint8(length >> 16)
   650		x[2] = uint8(length >> 8)
   651		x[3] = uint8(length)
   652		copy(x[4:], m.key)
   653	
   654		m.raw = x
   655		return x
   656	}
   657	
   658	func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
   659		m.raw = data
   660		if len(data) < 4 {
   661			return false
   662		}
   663		m.key = data[4:]
   664		return true
   665	}
   666	
   667	type certificateStatusMsg struct {
   668		raw        []byte
   669		statusType uint8
   670		response   []byte
   671	}
   672	
   673	func (m *certificateStatusMsg) equal(i interface{}) bool {
   674		m1, ok := i.(*certificateStatusMsg)
   675		if !ok {
   676			return false
   677		}
   678	
   679		return bytes.Equal(m.raw, m1.raw) &&
   680			m.statusType == m1.statusType &&
   681			bytes.Equal(m.response, m1.response)
   682	}
   683	
   684	func (m *certificateStatusMsg) marshal() []byte {
   685		if m.raw != nil {
   686			return m.raw
   687		}
   688	
   689		var x []byte
   690		if m.statusType == statusTypeOCSP {
   691			x = make([]byte, 4+4+len(m.response))
   692			x[0] = typeCertificateStatus
   693			l := len(m.response) + 4
   694			x[1] = byte(l >> 16)
   695			x[2] = byte(l >> 8)
   696			x[3] = byte(l)
   697			x[4] = statusTypeOCSP
   698	
   699			l -= 4
   700			x[5] = byte(l >> 16)
   701			x[6] = byte(l >> 8)
   702			x[7] = byte(l)
   703			copy(x[8:], m.response)
   704		} else {
   705			x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
   706		}
   707	
   708		m.raw = x
   709		return x
   710	}
   711	
   712	func (m *certificateStatusMsg) unmarshal(data []byte) bool {
   713		m.raw = data
   714		if len(data) < 5 {
   715			return false
   716		}
   717		m.statusType = data[4]
   718	
   719		m.response = nil
   720		if m.statusType == statusTypeOCSP {
   721			if len(data) < 8 {
   722				return false
   723			}
   724			respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
   725			if uint32(len(data)) != 4+4+respLen {
   726				return false
   727			}
   728			m.response = data[8:]
   729		}
   730		return true
   731	}
   732	
   733	type serverHelloDoneMsg struct{}
   734	
   735	func (m *serverHelloDoneMsg) equal(i interface{}) bool {
   736		_, ok := i.(*serverHelloDoneMsg)
   737		return ok
   738	}
   739	
   740	func (m *serverHelloDoneMsg) marshal() []byte {
   741		x := make([]byte, 4)
   742		x[0] = typeServerHelloDone
   743		return x
   744	}
   745	
   746	func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
   747		return len(data) == 4
   748	}
   749	
   750	type clientKeyExchangeMsg struct {
   751		raw        []byte
   752		ciphertext []byte
   753	}
   754	
   755	func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
   756		m1, ok := i.(*clientKeyExchangeMsg)
   757		if !ok {
   758			return false
   759		}
   760	
   761		return bytes.Equal(m.raw, m1.raw) &&
   762			bytes.Equal(m.ciphertext, m1.ciphertext)
   763	}
   764	
   765	func (m *clientKeyExchangeMsg) marshal() []byte {
   766		if m.raw != nil {
   767			return m.raw
   768		}
   769		length := len(m.ciphertext)
   770		x := make([]byte, length+4)
   771		x[0] = typeClientKeyExchange
   772		x[1] = uint8(length >> 16)
   773		x[2] = uint8(length >> 8)
   774		x[3] = uint8(length)
   775		copy(x[4:], m.ciphertext)
   776	
   777		m.raw = x
   778		return x
   779	}
   780	
   781	func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
   782		m.raw = data
   783		if len(data) < 4 {
   784			return false
   785		}
   786		l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
   787		if l != len(data)-4 {
   788			return false
   789		}
   790		m.ciphertext = data[4:]
   791		return true
   792	}
   793	
   794	type finishedMsg struct {
   795		raw        []byte
   796		verifyData []byte
   797	}
   798	
   799	func (m *finishedMsg) equal(i interface{}) bool {
   800		m1, ok := i.(*finishedMsg)
   801		if !ok {
   802			return false
   803		}
   804	
   805		return bytes.Equal(m.raw, m1.raw) &&
   806			bytes.Equal(m.verifyData, m1.verifyData)
   807	}
   808	
   809	func (m *finishedMsg) marshal() (x []byte) {
   810		if m.raw != nil {
   811			return m.raw
   812		}
   813	
   814		x = make([]byte, 4+len(m.verifyData))
   815		x[0] = typeFinished
   816		x[3] = byte(len(m.verifyData))
   817		copy(x[4:], m.verifyData)
   818		m.raw = x
   819		return
   820	}
   821	
   822	func (m *finishedMsg) unmarshal(data []byte) bool {
   823		m.raw = data
   824		if len(data) < 4 {
   825			return false
   826		}
   827		m.verifyData = data[4:]
   828		return true
   829	}
   830	
   831	type nextProtoMsg struct {
   832		raw   []byte
   833		proto string
   834	}
   835	
   836	func (m *nextProtoMsg) equal(i interface{}) bool {
   837		m1, ok := i.(*nextProtoMsg)
   838		if !ok {
   839			return false
   840		}
   841	
   842		return bytes.Equal(m.raw, m1.raw) &&
   843			m.proto == m1.proto
   844	}
   845	
   846	func (m *nextProtoMsg) marshal() []byte {
   847		if m.raw != nil {
   848			return m.raw
   849		}
   850		l := len(m.proto)
   851		if l > 255 {
   852			l = 255
   853		}
   854	
   855		padding := 32 - (l+2)%32
   856		length := l + padding + 2
   857		x := make([]byte, length+4)
   858		x[0] = typeNextProtocol
   859		x[1] = uint8(length >> 16)
   860		x[2] = uint8(length >> 8)
   861		x[3] = uint8(length)
   862	
   863		y := x[4:]
   864		y[0] = byte(l)
   865		copy(y[1:], []byte(m.proto[0:l]))
   866		y = y[1+l:]
   867		y[0] = byte(padding)
   868	
   869		m.raw = x
   870	
   871		return x
   872	}
   873	
   874	func (m *nextProtoMsg) unmarshal(data []byte) bool {
   875		m.raw = data
   876	
   877		if len(data) < 5 {
   878			return false
   879		}
   880		data = data[4:]
   881		protoLen := int(data[0])
   882		data = data[1:]
   883		if len(data) < protoLen {
   884			return false
   885		}
   886		m.proto = string(data[0:protoLen])
   887		data = data[protoLen:]
   888	
   889		if len(data) < 1 {
   890			return false
   891		}
   892		paddingLen := int(data[0])
   893		data = data[1:]
   894		if len(data) != paddingLen {
   895			return false
   896		}
   897	
   898		return true
   899	}
   900	
   901	type certificateRequestMsg struct {
   902		raw                    []byte
   903		certificateTypes       []byte
   904		certificateAuthorities [][]byte
   905	}
   906	
   907	func (m *certificateRequestMsg) equal(i interface{}) bool {
   908		m1, ok := i.(*certificateRequestMsg)
   909		if !ok {
   910			return false
   911		}
   912	
   913		return bytes.Equal(m.raw, m1.raw) &&
   914			bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
   915			eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
   916	}
   917	
   918	func (m *certificateRequestMsg) marshal() (x []byte) {
   919		if m.raw != nil {
   920			return m.raw
   921		}
   922	
   923		// See http://tools.ietf.org/html/rfc4346#section-7.4.4
   924		length := 1 + len(m.certificateTypes) + 2
   925		casLength := 0
   926		for _, ca := range m.certificateAuthorities {
   927			casLength += 2 + len(ca)
   928		}
   929		length += casLength
   930	
   931		x = make([]byte, 4+length)
   932		x[0] = typeCertificateRequest
   933		x[1] = uint8(length >> 16)
   934		x[2] = uint8(length >> 8)
   935		x[3] = uint8(length)
   936	
   937		x[4] = uint8(len(m.certificateTypes))
   938	
   939		copy(x[5:], m.certificateTypes)
   940		y := x[5+len(m.certificateTypes):]
   941		y[0] = uint8(casLength >> 8)
   942		y[1] = uint8(casLength)
   943		y = y[2:]
   944		for _, ca := range m.certificateAuthorities {
   945			y[0] = uint8(len(ca) >> 8)
   946			y[1] = uint8(len(ca))
   947			y = y[2:]
   948			copy(y, ca)
   949			y = y[len(ca):]
   950		}
   951	
   952		m.raw = x
   953		return
   954	}
   955	
   956	func (m *certificateRequestMsg) unmarshal(data []byte) bool {
   957		m.raw = data
   958	
   959		if len(data) < 5 {
   960			return false
   961		}
   962	
   963		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
   964		if uint32(len(data))-4 != length {
   965			return false
   966		}
   967	
   968		numCertTypes := int(data[4])
   969		data = data[5:]
   970		if numCertTypes == 0 || len(data) <= numCertTypes {
   971			return false
   972		}
   973	
   974		m.certificateTypes = make([]byte, numCertTypes)
   975		if copy(m.certificateTypes, data) != numCertTypes {
   976			return false
   977		}
   978	
   979		data = data[numCertTypes:]
   980	
   981		if len(data) < 2 {
   982			return false
   983		}
   984		casLength := uint16(data[0])<<8 | uint16(data[1])
   985		data = data[2:]
   986		if len(data) < int(casLength) {
   987			return false
   988		}
   989		cas := make([]byte, casLength)
   990		copy(cas, data)
   991		data = data[casLength:]
   992	
   993		m.certificateAuthorities = nil
   994		for len(cas) > 0 {
   995			if len(cas) < 2 {
   996				return false
   997			}
   998			caLen := uint16(cas[0])<<8 | uint16(cas[1])
   999			cas = cas[2:]
  1000	
  1001			if len(cas) < int(caLen) {
  1002				return false
  1003			}
  1004	
  1005			m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1006			cas = cas[caLen:]
  1007		}
  1008		if len(data) > 0 {
  1009			return false
  1010		}
  1011	
  1012		return true
  1013	}
  1014	
  1015	type certificateVerifyMsg struct {
  1016		raw       []byte
  1017		signature []byte
  1018	}
  1019	
  1020	func (m *certificateVerifyMsg) equal(i interface{}) bool {
  1021		m1, ok := i.(*certificateVerifyMsg)
  1022		if !ok {
  1023			return false
  1024		}
  1025	
  1026		return bytes.Equal(m.raw, m1.raw) &&
  1027			bytes.Equal(m.signature, m1.signature)
  1028	}
  1029	
  1030	func (m *certificateVerifyMsg) marshal() (x []byte) {
  1031		if m.raw != nil {
  1032			return m.raw
  1033		}
  1034	
  1035		// See http://tools.ietf.org/html/rfc4346#section-7.4.8
  1036		siglength := len(m.signature)
  1037		length := 2 + siglength
  1038		x = make([]byte, 4+length)
  1039		x[0] = typeCertificateVerify
  1040		x[1] = uint8(length >> 16)
  1041		x[2] = uint8(length >> 8)
  1042		x[3] = uint8(length)
  1043		x[4] = uint8(siglength >> 8)
  1044		x[5] = uint8(siglength)
  1045		copy(x[6:], m.signature)
  1046	
  1047		m.raw = x
  1048	
  1049		return
  1050	}
  1051	
  1052	func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1053		m.raw = data
  1054	
  1055		if len(data) < 6 {
  1056			return false
  1057		}
  1058	
  1059		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1060		if uint32(len(data))-4 != length {
  1061			return false
  1062		}
  1063	
  1064		siglength := int(data[4])<<8 + int(data[5])
  1065		if len(data)-6 != siglength {
  1066			return false
  1067		}
  1068	
  1069		m.signature = data[6:]
  1070	
  1071		return true
  1072	}
  1073	
  1074	type newSessionTicketMsg struct {
  1075		raw    []byte
  1076		ticket []byte
  1077	}
  1078	
  1079	func (m *newSessionTicketMsg) equal(i interface{}) bool {
  1080		m1, ok := i.(*newSessionTicketMsg)
  1081		if !ok {
  1082			return false
  1083		}
  1084	
  1085		return bytes.Equal(m.raw, m1.raw) &&
  1086			bytes.Equal(m.ticket, m1.ticket)
  1087	}
  1088	
  1089	func (m *newSessionTicketMsg) marshal() (x []byte) {
  1090		if m.raw != nil {
  1091			return m.raw
  1092		}
  1093	
  1094		// See http://tools.ietf.org/html/rfc5077#section-3.3
  1095		ticketLen := len(m.ticket)
  1096		length := 2 + 4 + ticketLen
  1097		x = make([]byte, 4+length)
  1098		x[0] = typeNewSessionTicket
  1099		x[1] = uint8(length >> 16)
  1100		x[2] = uint8(length >> 8)
  1101		x[3] = uint8(length)
  1102		x[8] = uint8(ticketLen >> 8)
  1103		x[9] = uint8(ticketLen)
  1104		copy(x[10:], m.ticket)
  1105	
  1106		m.raw = x
  1107	
  1108		return
  1109	}
  1110	
  1111	func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1112		m.raw = data
  1113	
  1114		if len(data) < 10 {
  1115			return false
  1116		}
  1117	
  1118		length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1119		if uint32(len(data))-4 != length {
  1120			return false
  1121		}
  1122	
  1123		ticketLen := int(data[8])<<8 + int(data[9])
  1124		if len(data)-10 != ticketLen {
  1125			return false
  1126		}
  1127	
  1128		m.ticket = data[10:]
  1129	
  1130		return true
  1131	}
  1132	
  1133	func eqUint16s(x, y []uint16) bool {
  1134		if len(x) != len(y) {
  1135			return false
  1136		}
  1137		for i, v := range x {
  1138			if y[i] != v {
  1139				return false
  1140			}
  1141		}
  1142		return true
  1143	}
  1144	
  1145	func eqStrings(x, y []string) bool {
  1146		if len(x) != len(y) {
  1147			return false
  1148		}
  1149		for i, v := range x {
  1150			if y[i] != v {
  1151				return false
  1152			}
  1153		}
  1154		return true
  1155	}
  1156	
  1157	func eqByteSlices(x, y [][]byte) bool {
  1158		if len(x) != len(y) {
  1159			return false
  1160		}
  1161		for i, v := range x {
  1162			if !bytes.Equal(v, y[i]) {
  1163				return false
  1164			}
  1165		}
  1166		return true
  1167	}

View as plain text