Source file
src/crypto/tls/handshake_messages.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "strings"
10 )
11
// clientHelloMsg holds the parsed fields of a ClientHello handshake message.
//
// raw caches the wire encoding: marshal returns it unchanged when it is
// non-nil and stores its result there, and unmarshal records its input in it.
// The boolean fields record whether the corresponding extension is present.
// secureRenegotiationSupported is additionally set by unmarshal when the
// cipher suite list contains scsvRenegotiation.
type clientHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuites                 []uint16
	compressionMethods           []uint8
	nextProtoNeg                 bool
	serverName                   string
	ocspStapling                 bool
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	ticketSupported              bool
	sessionTicket                []uint8
	supportedSignatureAlgorithms []SignatureScheme
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocols                []string
}
32
33 func (m *clientHelloMsg) equal(i interface{}) bool {
34 m1, ok := i.(*clientHelloMsg)
35 if !ok {
36 return false
37 }
38
39 return bytes.Equal(m.raw, m1.raw) &&
40 m.vers == m1.vers &&
41 bytes.Equal(m.random, m1.random) &&
42 bytes.Equal(m.sessionId, m1.sessionId) &&
43 eqUint16s(m.cipherSuites, m1.cipherSuites) &&
44 bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
45 m.nextProtoNeg == m1.nextProtoNeg &&
46 m.serverName == m1.serverName &&
47 m.ocspStapling == m1.ocspStapling &&
48 m.scts == m1.scts &&
49 eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
50 bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
51 m.ticketSupported == m1.ticketSupported &&
52 bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
53 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
54 m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
55 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
56 eqStrings(m.alpnProtocols, m1.alpnProtocols)
57 }
58
// marshal returns the wire encoding of the message, including the 4-byte
// handshake header (type plus 24-bit length). The result is cached in m.raw
// and returned as-is on later calls.
//
// The function works in two passes: it first computes the exact size of the
// message and its extensions, then allocates one buffer and fills it in
// while advancing the slice z past each field written.
//
// It panics if any ALPN protocol name is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Fixed part: version (2), random (32), session id length (1) and data,
	// cipher suite length (2) and data, compression length (1) and data.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// extensionsLength accumulates only the extension payload sizes here;
	// the 4-byte header of each extension is added once numExtensions is known.
	extensionsLength := 0
	if m.nextProtoNeg {
		numExtensions++
	}
	if m.ocspStapling {
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			// Each name is written with a one-byte length prefix, so it
			// must be between 1 and 255 bytes long.
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		numExtensions++
	}
	if numExtensions > 0 {
		// 4 bytes of header (type + length) per extension, plus the
		// 2-byte total extensions length.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Big-endian byte count of the cipher suite list, i.e. 2*len(cipherSuites).
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The two length bytes stay zero: this extension carries no data.
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Extension body: 2-byte list length, 1-byte name type (z[2],
		// left zero), 2-byte name length, then the name itself.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// Extension body is 5 bytes: a status type byte followed by four
		// bytes that are left zero.
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // status type; presumably equals statusTypeOCSP — confirm
		// Skip the header (4) plus the 5-byte body; the rest stays zero.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// Extension body: 2-byte list length followed by 2 bytes per curve.
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// Extension body: 1-byte list length followed by 1 byte per format.
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// Extension body is the raw ticket, which may be empty.
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// Extension body: 2-byte list length followed by 2 bytes per algorithm.
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		// Extension body: 1-byte length followed by the renegotiation data.
		// Only the low byte of the extension length is written (z[2] is 0).
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths aliases the four length bytes (extension length, then
		// protocol list length); they are filled in after the names are
		// written and their total size is known.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// Header only; the two length bytes stay zero.
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
299
// unmarshal parses data, a complete ClientHello handshake message including
// its 4-byte header, into m. It reports whether the message was well formed.
// On failure m may be left partially populated.
//
// Slice fields such as random, sessionId, compressionMethods and
// sessionTicket alias data rather than copying it.
//
// NOTE(review): supportedCurves, supportedPoints and secureRenegotiation are
// not cleared before parsing, so values from a previous use of m would
// survive if the corresponding extension is absent — confirm callers always
// pass a fresh message.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	// 4-byte header + 2-byte version + 32-byte random + 1-byte session id
	// length is 39 bytes; the check demands 42.
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 2 {
		return false
	}
	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
	// each suite is a uint16, the number must be even.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return false
	}
	numCipherSuites := cipherSuiteLen / 2
	m.cipherSuites = make([]uint16, numCipherSuites)
	for i := 0; i < numCipherSuites; i++ {
		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
		// The signalling cipher suite value also marks renegotiation support.
		if m.cipherSuites[i] == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
	}
	data = data[2+cipherSuiteLen:]
	if len(data) < 1 {
		return false
	}
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return false
	}
	m.compressionMethods = data[1 : 1+compressionMethodsLen]

	data = data[1+compressionMethodsLen:]

	// Clear extension-derived state before parsing the extensions block.
	m.nextProtoNeg = false
	m.serverName = ""
	m.ocspStapling = false
	m.ticketSupported = false
	m.sessionTicket = nil
	m.supportedSignatureAlgorithms = nil
	m.alpnProtocols = nil
	m.scts = false

	if len(data) == 0 {
		// ClientHello is optionally followed by extension data
		return true
	}
	if len(data) < 2 {
		return false
	}

	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	// The declared length must cover exactly the remaining bytes.
	if extensionsLength != len(data) {
		return false
	}

	for len(data) != 0 {
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		// Unknown extension types fall through the switch and are skipped.
		switch extension {
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return false
			}
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return false
			}
			for len(d) > 0 {
				if len(d) < 3 {
					return false
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return false
				}
				if nameType == 0 {
					m.serverName = string(d[:nameLen])
					// A name ending in a dot is rejected outright.
					if strings.HasSuffix(m.serverName, ".") {
						return false
					}
					// Only the first name of type 0 is used; break
					// leaves the enclosing for loop.
					break
				}
				d = d[nameLen:]
			}
		case extensionNextProtoNeg:
			if length > 0 {
				return false
			}
			m.nextProtoNeg = true
		case extensionStatusRequest:
			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
		case extensionSupportedCurves:
			// Body: 2-byte list length, then 2 bytes per curve.
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l%2 == 1 || length != l+2 {
				return false
			}
			numCurves := l / 2
			m.supportedCurves = make([]CurveID, numCurves)
			d := data[2:]
			for i := 0; i < numCurves; i++ {
				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
				d = d[2:]
			}
		case extensionSupportedPoints:
			// Body: 1-byte list length, then 1 byte per point format.
			if length < 1 {
				return false
			}
			l := int(data[0])
			if length != l+1 {
				return false
			}
			m.supportedPoints = make([]uint8, l)
			copy(m.supportedPoints, data[1:])
		case extensionSessionTicket:
			// The whole extension body is the ticket; it may be empty.
			m.ticketSupported = true
			m.sessionTicket = data[:length]
		case extensionSignatureAlgorithms:
			// Body: 2-byte list length, then 2 bytes per algorithm.
			if length < 2 || length&1 != 0 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			n := l / 2
			d := data[2:]
			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
			for i := range m.supportedSignatureAlgorithms {
				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
				d = d[2:]
			}
		case extensionRenegotiationInfo:
			// Body: 1-byte length that must match the remaining bytes.
			if length == 0 {
				return false
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// Body: 2-byte list length, then length-prefixed names.
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			d := data[2:length]
			for len(d) != 0 {
				stringLen := int(d[0])
				d = d[1:]
				// Empty names and names running past the list are invalid.
				if stringLen == 0 || stringLen > len(d) {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
				d = d[stringLen:]
			}
		case extensionSCT:
			// The extension must carry no data.
			m.scts = true
			if length != 0 {
				return false
			}
		}
		data = data[length:]
	}

	return true
}
505
// serverHelloMsg holds the parsed fields of a ServerHello handshake message.
//
// raw caches the wire encoding: marshal returns it unchanged when it is
// non-nil and stores its result there, and unmarshal records its input in it.
// nextProtos is only encoded when nextProtoNeg is set. scts holds the list
// carried by the SCT extension, one byte slice per entry.
type serverHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	scts                         [][]byte
	ticketSupported              bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string
}
522
523 func (m *serverHelloMsg) equal(i interface{}) bool {
524 m1, ok := i.(*serverHelloMsg)
525 if !ok {
526 return false
527 }
528
529 if len(m.scts) != len(m1.scts) {
530 return false
531 }
532 for i, sct := range m.scts {
533 if !bytes.Equal(sct, m1.scts[i]) {
534 return false
535 }
536 }
537
538 return bytes.Equal(m.raw, m1.raw) &&
539 m.vers == m1.vers &&
540 bytes.Equal(m.random, m1.random) &&
541 bytes.Equal(m.sessionId, m1.sessionId) &&
542 m.cipherSuite == m1.cipherSuite &&
543 m.compressionMethod == m1.compressionMethod &&
544 m.nextProtoNeg == m1.nextProtoNeg &&
545 eqStrings(m.nextProtos, m1.nextProtos) &&
546 m.ocspStapling == m1.ocspStapling &&
547 m.ticketSupported == m1.ticketSupported &&
548 m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
549 bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
550 m.alpnProtocol == m1.alpnProtocol
551 }
552
553 func (m *serverHelloMsg) marshal() []byte {
554 if m.raw != nil {
555 return m.raw
556 }
557
558 length := 38 + len(m.sessionId)
559 numExtensions := 0
560 extensionsLength := 0
561
562 nextProtoLen := 0
563 if m.nextProtoNeg {
564 numExtensions++
565 for _, v := range m.nextProtos {
566 nextProtoLen += len(v)
567 }
568 nextProtoLen += len(m.nextProtos)
569 extensionsLength += nextProtoLen
570 }
571 if m.ocspStapling {
572 numExtensions++
573 }
574 if m.ticketSupported {
575 numExtensions++
576 }
577 if m.secureRenegotiationSupported {
578 extensionsLength += 1 + len(m.secureRenegotiation)
579 numExtensions++
580 }
581 if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
582 if alpnLen >= 256 {
583 panic("invalid ALPN protocol")
584 }
585 extensionsLength += 2 + 1 + alpnLen
586 numExtensions++
587 }
588 sctLen := 0
589 if len(m.scts) > 0 {
590 for _, sct := range m.scts {
591 sctLen += len(sct) + 2
592 }
593 extensionsLength += 2 + sctLen
594 numExtensions++
595 }
596
597 if numExtensions > 0 {
598 extensionsLength += 4 * numExtensions
599 length += 2 + extensionsLength
600 }
601
602 x := make([]byte, 4+length)
603 x[0] = typeServerHello
604 x[1] = uint8(length >> 16)
605 x[2] = uint8(length >> 8)
606 x[3] = uint8(length)
607 x[4] = uint8(m.vers >> 8)
608 x[5] = uint8(m.vers)
609 copy(x[6:38], m.random)
610 x[38] = uint8(len(m.sessionId))
611 copy(x[39:39+len(m.sessionId)], m.sessionId)
612 z := x[39+len(m.sessionId):]
613 z[0] = uint8(m.cipherSuite >> 8)
614 z[1] = uint8(m.cipherSuite)
615 z[2] = m.compressionMethod
616
617 z = z[3:]
618 if numExtensions > 0 {
619 z[0] = byte(extensionsLength >> 8)
620 z[1] = byte(extensionsLength)
621 z = z[2:]
622 }
623 if m.nextProtoNeg {
624 z[0] = byte(extensionNextProtoNeg >> 8)
625 z[1] = byte(extensionNextProtoNeg & 0xff)
626 z[2] = byte(nextProtoLen >> 8)
627 z[3] = byte(nextProtoLen)
628 z = z[4:]
629
630 for _, v := range m.nextProtos {
631 l := len(v)
632 if l > 255 {
633 l = 255
634 }
635 z[0] = byte(l)
636 copy(z[1:], []byte(v[0:l]))
637 z = z[1+l:]
638 }
639 }
640 if m.ocspStapling {
641 z[0] = byte(extensionStatusRequest >> 8)
642 z[1] = byte(extensionStatusRequest)
643 z = z[4:]
644 }
645 if m.ticketSupported {
646 z[0] = byte(extensionSessionTicket >> 8)
647 z[1] = byte(extensionSessionTicket)
648 z = z[4:]
649 }
650 if m.secureRenegotiationSupported {
651 z[0] = byte(extensionRenegotiationInfo >> 8)
652 z[1] = byte(extensionRenegotiationInfo & 0xff)
653 z[2] = 0
654 z[3] = byte(len(m.secureRenegotiation) + 1)
655 z[4] = byte(len(m.secureRenegotiation))
656 z = z[5:]
657 copy(z, m.secureRenegotiation)
658 z = z[len(m.secureRenegotiation):]
659 }
660 if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
661 z[0] = byte(extensionALPN >> 8)
662 z[1] = byte(extensionALPN & 0xff)
663 l := 2 + 1 + alpnLen
664 z[2] = byte(l >> 8)
665 z[3] = byte(l)
666 l -= 2
667 z[4] = byte(l >> 8)
668 z[5] = byte(l)
669 l -= 1
670 z[6] = byte(l)
671 copy(z[7:], []byte(m.alpnProtocol))
672 z = z[7+alpnLen:]
673 }
674 if sctLen > 0 {
675 z[0] = byte(extensionSCT >> 8)
676 z[1] = byte(extensionSCT)
677 l := sctLen + 2
678 z[2] = byte(l >> 8)
679 z[3] = byte(l)
680 z[4] = byte(sctLen >> 8)
681 z[5] = byte(sctLen)
682
683 z = z[6:]
684 for _, sct := range m.scts {
685 z[0] = byte(len(sct) >> 8)
686 z[1] = byte(len(sct))
687 copy(z[2:], sct)
688 z = z[len(sct)+2:]
689 }
690 }
691
692 m.raw = x
693
694 return x
695 }
696
// unmarshal parses data, a complete ServerHello handshake message including
// its 4-byte header, into m. It reports whether the message was well formed.
// On failure m may be left partially populated.
//
// Slice fields such as random, sessionId, secureRenegotiation and the SCT
// entries alias data rather than copying it.
//
// NOTE(review): secureRenegotiation and secureRenegotiationSupported are not
// cleared before parsing, so values from a previous use of m would survive
// if the extension is absent — confirm callers always pass a fresh message.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	// 4-byte header + 2-byte version + 32-byte random + 1-byte session id
	// length + 2-byte cipher suite + 1-byte compression method = 42.
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 3 {
		return false
	}
	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
	m.compressionMethod = data[2]
	data = data[3:]

	// Clear extension-derived state before parsing the extensions block.
	m.nextProtoNeg = false
	m.nextProtos = nil
	m.ocspStapling = false
	m.scts = nil
	m.ticketSupported = false
	m.alpnProtocol = ""

	if len(data) == 0 {
		// ServerHello is optionally followed by extension data
		return true
	}
	if len(data) < 2 {
		return false
	}

	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	// The declared length must cover exactly the remaining bytes.
	if len(data) != extensionsLength {
		return false
	}

	for len(data) != 0 {
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		// Unknown extension types fall through the switch and are skipped.
		switch extension {
		case extensionNextProtoNeg:
			// Body: a sequence of names, each with a 1-byte length prefix.
			m.nextProtoNeg = true
			d := data[:length]
			for len(d) > 0 {
				l := int(d[0])
				d = d[1:]
				if l == 0 || l > len(d) {
					return false
				}
				m.nextProtos = append(m.nextProtos, string(d[:l]))
				d = d[l:]
			}
		case extensionStatusRequest:
			// The extension must carry no data.
			if length > 0 {
				return false
			}
			m.ocspStapling = true
		case extensionSessionTicket:
			// The extension must carry no data.
			if length > 0 {
				return false
			}
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			// Body: 1-byte length that must match the remaining bytes.
			if length == 0 {
				return false
			}
			d := data[:length]
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			// Body: 2-byte list length, then exactly one name with a
			// 1-byte length prefix; both lengths must match exactly.
			d := data[:length]
			if len(d) < 3 {
				return false
			}
			l := int(d[0])<<8 | int(d[1])
			if l != len(d)-2 {
				return false
			}
			d = d[2:]
			l = int(d[0])
			if l != len(d)-1 {
				return false
			}
			d = d[1:]
			if len(d) == 0 {
				// ALPN protocols must not be empty.
				return false
			}
			m.alpnProtocol = string(d)
		case extensionSCT:
			// Body: 2-byte list length, then entries each with a 2-byte
			// length prefix. Empty lists and empty entries are rejected.
			d := data[:length]

			if len(d) < 2 {
				return false
			}
			l := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != l || l == 0 {
				return false
			}

			m.scts = make([][]byte, 0, 3)
			for len(d) != 0 {
				if len(d) < 2 {
					return false
				}
				sctLen := int(d[0])<<8 | int(d[1])
				d = d[2:]
				if sctLen == 0 || len(d) < sctLen {
					return false
				}
				m.scts = append(m.scts, d[:sctLen])
				d = d[sctLen:]
			}
		}
		data = data[length:]
	}

	return true
}
836
// certificateMsg holds a Certificate handshake message: a list of
// certificates, each kept as an opaque byte slice. raw caches the wire
// encoding produced by marshal or consumed by unmarshal.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
841
842 func (m *certificateMsg) equal(i interface{}) bool {
843 m1, ok := i.(*certificateMsg)
844 if !ok {
845 return false
846 }
847
848 return bytes.Equal(m.raw, m1.raw) &&
849 eqByteSlices(m.certificates, m1.certificates)
850 }
851
852 func (m *certificateMsg) marshal() (x []byte) {
853 if m.raw != nil {
854 return m.raw
855 }
856
857 var i int
858 for _, slice := range m.certificates {
859 i += len(slice)
860 }
861
862 length := 3 + 3*len(m.certificates) + i
863 x = make([]byte, 4+length)
864 x[0] = typeCertificate
865 x[1] = uint8(length >> 16)
866 x[2] = uint8(length >> 8)
867 x[3] = uint8(length)
868
869 certificateOctets := length - 3
870 x[4] = uint8(certificateOctets >> 16)
871 x[5] = uint8(certificateOctets >> 8)
872 x[6] = uint8(certificateOctets)
873
874 y := x[7:]
875 for _, slice := range m.certificates {
876 y[0] = uint8(len(slice) >> 16)
877 y[1] = uint8(len(slice) >> 8)
878 y[2] = uint8(len(slice))
879 copy(y[3:], slice)
880 y = y[3+len(slice):]
881 }
882
883 m.raw = x
884 return
885 }
886
887 func (m *certificateMsg) unmarshal(data []byte) bool {
888 if len(data) < 7 {
889 return false
890 }
891
892 m.raw = data
893 certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
894 if uint32(len(data)) != certsLen+7 {
895 return false
896 }
897
898 numCerts := 0
899 d := data[7:]
900 for certsLen > 0 {
901 if len(d) < 4 {
902 return false
903 }
904 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
905 if uint32(len(d)) < 3+certLen {
906 return false
907 }
908 d = d[3+certLen:]
909 certsLen -= 3 + certLen
910 numCerts++
911 }
912
913 m.certificates = make([][]byte, numCerts)
914 d = data[7:]
915 for i := 0; i < numCerts; i++ {
916 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
917 m.certificates[i] = d[3 : 3+certLen]
918 d = d[3+certLen:]
919 }
920
921 return true
922 }
923
// serverKeyExchangeMsg holds a ServerKeyExchange handshake message. key is
// the opaque message body following the 4-byte header; raw caches the wire
// encoding.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
928
929 func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
930 m1, ok := i.(*serverKeyExchangeMsg)
931 if !ok {
932 return false
933 }
934
935 return bytes.Equal(m.raw, m1.raw) &&
936 bytes.Equal(m.key, m1.key)
937 }
938
939 func (m *serverKeyExchangeMsg) marshal() []byte {
940 if m.raw != nil {
941 return m.raw
942 }
943 length := len(m.key)
944 x := make([]byte, length+4)
945 x[0] = typeServerKeyExchange
946 x[1] = uint8(length >> 16)
947 x[2] = uint8(length >> 8)
948 x[3] = uint8(length)
949 copy(x[4:], m.key)
950
951 m.raw = x
952 return x
953 }
954
955 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
956 m.raw = data
957 if len(data) < 4 {
958 return false
959 }
960 m.key = data[4:]
961 return true
962 }
963
// certificateStatusMsg holds a CertificateStatus handshake message.
// response is only encoded and decoded when statusType equals
// statusTypeOCSP; raw caches the wire encoding.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
969
970 func (m *certificateStatusMsg) equal(i interface{}) bool {
971 m1, ok := i.(*certificateStatusMsg)
972 if !ok {
973 return false
974 }
975
976 return bytes.Equal(m.raw, m1.raw) &&
977 m.statusType == m1.statusType &&
978 bytes.Equal(m.response, m1.response)
979 }
980
981 func (m *certificateStatusMsg) marshal() []byte {
982 if m.raw != nil {
983 return m.raw
984 }
985
986 var x []byte
987 if m.statusType == statusTypeOCSP {
988 x = make([]byte, 4+4+len(m.response))
989 x[0] = typeCertificateStatus
990 l := len(m.response) + 4
991 x[1] = byte(l >> 16)
992 x[2] = byte(l >> 8)
993 x[3] = byte(l)
994 x[4] = statusTypeOCSP
995
996 l -= 4
997 x[5] = byte(l >> 16)
998 x[6] = byte(l >> 8)
999 x[7] = byte(l)
1000 copy(x[8:], m.response)
1001 } else {
1002 x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
1003 }
1004
1005 m.raw = x
1006 return x
1007 }
1008
1009 func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1010 m.raw = data
1011 if len(data) < 5 {
1012 return false
1013 }
1014 m.statusType = data[4]
1015
1016 m.response = nil
1017 if m.statusType == statusTypeOCSP {
1018 if len(data) < 8 {
1019 return false
1020 }
1021 respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
1022 if uint32(len(data)) != 4+4+respLen {
1023 return false
1024 }
1025 m.response = data[8:]
1026 }
1027 return true
1028 }
1029
// serverHelloDoneMsg is the ServerHelloDone handshake message; it has no
// body and therefore no fields.
type serverHelloDoneMsg struct{}
1031
1032 func (m *serverHelloDoneMsg) equal(i interface{}) bool {
1033 _, ok := i.(*serverHelloDoneMsg)
1034 return ok
1035 }
1036
1037 func (m *serverHelloDoneMsg) marshal() []byte {
1038 x := make([]byte, 4)
1039 x[0] = typeServerHelloDone
1040 return x
1041 }
1042
1043 func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1044 return len(data) == 4
1045 }
1046
// clientKeyExchangeMsg holds a ClientKeyExchange handshake message.
// ciphertext is the opaque message body following the 4-byte header; raw
// caches the wire encoding.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
1051
1052 func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1053 m1, ok := i.(*clientKeyExchangeMsg)
1054 if !ok {
1055 return false
1056 }
1057
1058 return bytes.Equal(m.raw, m1.raw) &&
1059 bytes.Equal(m.ciphertext, m1.ciphertext)
1060 }
1061
1062 func (m *clientKeyExchangeMsg) marshal() []byte {
1063 if m.raw != nil {
1064 return m.raw
1065 }
1066 length := len(m.ciphertext)
1067 x := make([]byte, length+4)
1068 x[0] = typeClientKeyExchange
1069 x[1] = uint8(length >> 16)
1070 x[2] = uint8(length >> 8)
1071 x[3] = uint8(length)
1072 copy(x[4:], m.ciphertext)
1073
1074 m.raw = x
1075 return x
1076 }
1077
1078 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1079 m.raw = data
1080 if len(data) < 4 {
1081 return false
1082 }
1083 l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1084 if l != len(data)-4 {
1085 return false
1086 }
1087 m.ciphertext = data[4:]
1088 return true
1089 }
1090
// finishedMsg holds a Finished handshake message. verifyData is the message
// body following the 4-byte header; raw caches the wire encoding.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
1095
1096 func (m *finishedMsg) equal(i interface{}) bool {
1097 m1, ok := i.(*finishedMsg)
1098 if !ok {
1099 return false
1100 }
1101
1102 return bytes.Equal(m.raw, m1.raw) &&
1103 bytes.Equal(m.verifyData, m1.verifyData)
1104 }
1105
1106 func (m *finishedMsg) marshal() (x []byte) {
1107 if m.raw != nil {
1108 return m.raw
1109 }
1110
1111 x = make([]byte, 4+len(m.verifyData))
1112 x[0] = typeFinished
1113 x[3] = byte(len(m.verifyData))
1114 copy(x[4:], m.verifyData)
1115 m.raw = x
1116 return
1117 }
1118
1119 func (m *finishedMsg) unmarshal(data []byte) bool {
1120 m.raw = data
1121 if len(data) < 4 {
1122 return false
1123 }
1124 m.verifyData = data[4:]
1125 return true
1126 }
1127
// nextProtoMsg holds a NextProtocol handshake message carrying a single
// protocol name; raw caches the wire encoding.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
1132
1133 func (m *nextProtoMsg) equal(i interface{}) bool {
1134 m1, ok := i.(*nextProtoMsg)
1135 if !ok {
1136 return false
1137 }
1138
1139 return bytes.Equal(m.raw, m1.raw) &&
1140 m.proto == m1.proto
1141 }
1142
1143 func (m *nextProtoMsg) marshal() []byte {
1144 if m.raw != nil {
1145 return m.raw
1146 }
1147 l := len(m.proto)
1148 if l > 255 {
1149 l = 255
1150 }
1151
1152 padding := 32 - (l+2)%32
1153 length := l + padding + 2
1154 x := make([]byte, length+4)
1155 x[0] = typeNextProtocol
1156 x[1] = uint8(length >> 16)
1157 x[2] = uint8(length >> 8)
1158 x[3] = uint8(length)
1159
1160 y := x[4:]
1161 y[0] = byte(l)
1162 copy(y[1:], []byte(m.proto[0:l]))
1163 y = y[1+l:]
1164 y[0] = byte(padding)
1165
1166 m.raw = x
1167
1168 return x
1169 }
1170
1171 func (m *nextProtoMsg) unmarshal(data []byte) bool {
1172 m.raw = data
1173
1174 if len(data) < 5 {
1175 return false
1176 }
1177 data = data[4:]
1178 protoLen := int(data[0])
1179 data = data[1:]
1180 if len(data) < protoLen {
1181 return false
1182 }
1183 m.proto = string(data[0:protoLen])
1184 data = data[protoLen:]
1185
1186 if len(data) < 1 {
1187 return false
1188 }
1189 paddingLen := int(data[0])
1190 data = data[1:]
1191 if len(data) != paddingLen {
1192 return false
1193 }
1194
1195 return true
1196 }
1197
// certificateRequestMsg holds a CertificateRequest handshake message; raw
// caches the wire encoding.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes the list
	// of supported signature algorithms. It is not carried on the wire:
	// marshal and unmarshal both consult it to decide whether that list is
	// written or expected, so it must be set before calling unmarshal.
	hasSignatureAndHash bool

	certificateTypes             []byte
	supportedSignatureAlgorithms []SignatureScheme
	certificateAuthorities       [][]byte
}
1209
1210 func (m *certificateRequestMsg) equal(i interface{}) bool {
1211 m1, ok := i.(*certificateRequestMsg)
1212 if !ok {
1213 return false
1214 }
1215
1216 return bytes.Equal(m.raw, m1.raw) &&
1217 bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1218 eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1219 eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
1220 }
1221
1222 func (m *certificateRequestMsg) marshal() (x []byte) {
1223 if m.raw != nil {
1224 return m.raw
1225 }
1226
1227
1228 length := 1 + len(m.certificateTypes) + 2
1229 casLength := 0
1230 for _, ca := range m.certificateAuthorities {
1231 casLength += 2 + len(ca)
1232 }
1233 length += casLength
1234
1235 if m.hasSignatureAndHash {
1236 length += 2 + 2*len(m.supportedSignatureAlgorithms)
1237 }
1238
1239 x = make([]byte, 4+length)
1240 x[0] = typeCertificateRequest
1241 x[1] = uint8(length >> 16)
1242 x[2] = uint8(length >> 8)
1243 x[3] = uint8(length)
1244
1245 x[4] = uint8(len(m.certificateTypes))
1246
1247 copy(x[5:], m.certificateTypes)
1248 y := x[5+len(m.certificateTypes):]
1249
1250 if m.hasSignatureAndHash {
1251 n := len(m.supportedSignatureAlgorithms) * 2
1252 y[0] = uint8(n >> 8)
1253 y[1] = uint8(n)
1254 y = y[2:]
1255 for _, sigAlgo := range m.supportedSignatureAlgorithms {
1256 y[0] = uint8(sigAlgo >> 8)
1257 y[1] = uint8(sigAlgo)
1258 y = y[2:]
1259 }
1260 }
1261
1262 y[0] = uint8(casLength >> 8)
1263 y[1] = uint8(casLength)
1264 y = y[2:]
1265 for _, ca := range m.certificateAuthorities {
1266 y[0] = uint8(len(ca) >> 8)
1267 y[1] = uint8(len(ca))
1268 y = y[2:]
1269 copy(y, ca)
1270 y = y[len(ca):]
1271 }
1272
1273 m.raw = x
1274 return
1275 }
1276
1277 func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1278 m.raw = data
1279
1280 if len(data) < 5 {
1281 return false
1282 }
1283
1284 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1285 if uint32(len(data))-4 != length {
1286 return false
1287 }
1288
1289 numCertTypes := int(data[4])
1290 data = data[5:]
1291 if numCertTypes == 0 || len(data) <= numCertTypes {
1292 return false
1293 }
1294
1295 m.certificateTypes = make([]byte, numCertTypes)
1296 if copy(m.certificateTypes, data) != numCertTypes {
1297 return false
1298 }
1299
1300 data = data[numCertTypes:]
1301
1302 if m.hasSignatureAndHash {
1303 if len(data) < 2 {
1304 return false
1305 }
1306 sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1307 data = data[2:]
1308 if sigAndHashLen&1 != 0 {
1309 return false
1310 }
1311 if len(data) < int(sigAndHashLen) {
1312 return false
1313 }
1314 numSigAlgos := sigAndHashLen / 2
1315 m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
1316 for i := range m.supportedSignatureAlgorithms {
1317 m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1318 data = data[2:]
1319 }
1320 }
1321
1322 if len(data) < 2 {
1323 return false
1324 }
1325 casLength := uint16(data[0])<<8 | uint16(data[1])
1326 data = data[2:]
1327 if len(data) < int(casLength) {
1328 return false
1329 }
1330 cas := make([]byte, casLength)
1331 copy(cas, data)
1332 data = data[casLength:]
1333
1334 m.certificateAuthorities = nil
1335 for len(cas) > 0 {
1336 if len(cas) < 2 {
1337 return false
1338 }
1339 caLen := uint16(cas[0])<<8 | uint16(cas[1])
1340 cas = cas[2:]
1341
1342 if len(cas) < int(caLen) {
1343 return false
1344 }
1345
1346 m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1347 cas = cas[caLen:]
1348 }
1349
1350 return len(data) == 0
1351 }
1352
// certificateVerifyMsg holds a CertificateVerify handshake message; raw
// caches the wire encoding. hasSignatureAndHash is not carried on the wire:
// marshal and unmarshal consult it to decide whether signatureAlgorithm is
// written or expected ahead of the signature.
type certificateVerifyMsg struct {
	raw                 []byte
	hasSignatureAndHash bool
	signatureAlgorithm  SignatureScheme
	signature           []byte
}
1359
1360 func (m *certificateVerifyMsg) equal(i interface{}) bool {
1361 m1, ok := i.(*certificateVerifyMsg)
1362 if !ok {
1363 return false
1364 }
1365
1366 return bytes.Equal(m.raw, m1.raw) &&
1367 m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1368 m.signatureAlgorithm == m1.signatureAlgorithm &&
1369 bytes.Equal(m.signature, m1.signature)
1370 }
1371
1372 func (m *certificateVerifyMsg) marshal() (x []byte) {
1373 if m.raw != nil {
1374 return m.raw
1375 }
1376
1377
1378 siglength := len(m.signature)
1379 length := 2 + siglength
1380 if m.hasSignatureAndHash {
1381 length += 2
1382 }
1383 x = make([]byte, 4+length)
1384 x[0] = typeCertificateVerify
1385 x[1] = uint8(length >> 16)
1386 x[2] = uint8(length >> 8)
1387 x[3] = uint8(length)
1388 y := x[4:]
1389 if m.hasSignatureAndHash {
1390 y[0] = uint8(m.signatureAlgorithm >> 8)
1391 y[1] = uint8(m.signatureAlgorithm)
1392 y = y[2:]
1393 }
1394 y[0] = uint8(siglength >> 8)
1395 y[1] = uint8(siglength)
1396 copy(y[2:], m.signature)
1397
1398 m.raw = x
1399
1400 return
1401 }
1402
1403 func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1404 m.raw = data
1405
1406 if len(data) < 6 {
1407 return false
1408 }
1409
1410 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1411 if uint32(len(data))-4 != length {
1412 return false
1413 }
1414
1415 data = data[4:]
1416 if m.hasSignatureAndHash {
1417 m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1418 data = data[2:]
1419 }
1420
1421 if len(data) < 2 {
1422 return false
1423 }
1424 siglength := int(data[0])<<8 + int(data[1])
1425 data = data[2:]
1426 if len(data) != siglength {
1427 return false
1428 }
1429
1430 m.signature = data
1431
1432 return true
1433 }
1434
// newSessionTicketMsg holds a NewSessionTicket handshake message; raw
// caches the wire encoding. Only the ticket is modelled: marshal leaves the
// four bytes between the header and the ticket length as zero, and
// unmarshal does not read them.
type newSessionTicketMsg struct {
	raw    []byte
	ticket []byte
}
1439
1440 func (m *newSessionTicketMsg) equal(i interface{}) bool {
1441 m1, ok := i.(*newSessionTicketMsg)
1442 if !ok {
1443 return false
1444 }
1445
1446 return bytes.Equal(m.raw, m1.raw) &&
1447 bytes.Equal(m.ticket, m1.ticket)
1448 }
1449
1450 func (m *newSessionTicketMsg) marshal() (x []byte) {
1451 if m.raw != nil {
1452 return m.raw
1453 }
1454
1455
1456 ticketLen := len(m.ticket)
1457 length := 2 + 4 + ticketLen
1458 x = make([]byte, 4+length)
1459 x[0] = typeNewSessionTicket
1460 x[1] = uint8(length >> 16)
1461 x[2] = uint8(length >> 8)
1462 x[3] = uint8(length)
1463 x[8] = uint8(ticketLen >> 8)
1464 x[9] = uint8(ticketLen)
1465 copy(x[10:], m.ticket)
1466
1467 m.raw = x
1468
1469 return
1470 }
1471
1472 func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1473 m.raw = data
1474
1475 if len(data) < 10 {
1476 return false
1477 }
1478
1479 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1480 if uint32(len(data))-4 != length {
1481 return false
1482 }
1483
1484 ticketLen := int(data[8])<<8 + int(data[9])
1485 if len(data)-10 != ticketLen {
1486 return false
1487 }
1488
1489 m.ticket = data[10:]
1490
1491 return true
1492 }
1493
// helloRequestMsg is the HelloRequest handshake message; it has no body and
// therefore no fields.
type helloRequestMsg struct {
}
1496
1497 func (*helloRequestMsg) marshal() []byte {
1498 return []byte{typeHelloRequest, 0, 0, 0}
1499 }
1500
1501 func (*helloRequestMsg) unmarshal(data []byte) bool {
1502 return len(data) == 4
1503 }
1504
// eqUint16s reports whether x and y have the same length and the same
// elements in the same order. A nil slice equals an empty one.
func eqUint16s(x, y []uint16) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if x[idx] != y[idx] {
			return false
		}
	}
	return true
}
1516
1517 func eqCurveIDs(x, y []CurveID) bool {
1518 if len(x) != len(y) {
1519 return false
1520 }
1521 for i, v := range x {
1522 if y[i] != v {
1523 return false
1524 }
1525 }
1526 return true
1527 }
1528
// eqStrings reports whether x and y have the same length and the same
// elements in the same order. A nil slice equals an empty one.
func eqStrings(x, y []string) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if x[idx] != y[idx] {
			return false
		}
	}
	return true
}
1540
// eqByteSlices reports whether x and y have the same number of entries and
// each pair of corresponding entries holds the same bytes. Nil and empty
// slices compare equal at both levels.
func eqByteSlices(x, y [][]byte) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if !bytes.Equal(x[idx], y[idx]) {
			return false
		}
	}
	return true
}
1552
1553 func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
1554 if len(x) != len(y) {
1555 return false
1556 }
1557 for i, v := range x {
1558 if v != y[i] {
1559 return false
1560 }
1561 }
1562 return true
1563 }
1564
View as plain text