1
2
3
4
5 package tls
6
// clientHelloMsg carries the contents of a ClientHello handshake message.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed. The remaining fields mirror the
// individual parts of the message. nextProtoNeg, serverName, ocspStapling,
// supportedCurves and supportedPoints each correspond to one optional
// extension; a false/empty value means the extension is not emitted by
// marshal.
type clientHelloMsg struct {
	raw                []byte
	vers               uint16
	random, sessionId  []byte
	cipherSuites       []uint16
	compressionMethods []uint8
	nextProtoNeg       bool
	serverName         string
	ocspStapling       bool
	supportedCurves    []uint16
	supportedPoints    []uint8
}
20
21 func (m *clientHelloMsg) marshal() []byte {
22 if m.raw != nil {
23 return m.raw
24 }
25
26 length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
27 numExtensions := 0
28 extensionsLength := 0
29 if m.nextProtoNeg {
30 numExtensions++
31 }
32 if m.ocspStapling {
33 extensionsLength += 1 + 2 + 2
34 numExtensions++
35 }
36 if len(m.serverName) > 0 {
37 extensionsLength += 5 + len(m.serverName)
38 numExtensions++
39 }
40 if len(m.supportedCurves) > 0 {
41 extensionsLength += 2 + 2*len(m.supportedCurves)
42 numExtensions++
43 }
44 if len(m.supportedPoints) > 0 {
45 extensionsLength += 1 + len(m.supportedPoints)
46 numExtensions++
47 }
48 if numExtensions > 0 {
49 extensionsLength += 4 * numExtensions
50 length += 2 + extensionsLength
51 }
52
53 x := make([]byte, 4+length)
54 x[0] = typeClientHello
55 x[1] = uint8(length >> 16)
56 x[2] = uint8(length >> 8)
57 x[3] = uint8(length)
58 x[4] = uint8(m.vers >> 8)
59 x[5] = uint8(m.vers)
60 copy(x[6:38], m.random)
61 x[38] = uint8(len(m.sessionId))
62 copy(x[39:39+len(m.sessionId)], m.sessionId)
63 y := x[39+len(m.sessionId):]
64 y[0] = uint8(len(m.cipherSuites) >> 7)
65 y[1] = uint8(len(m.cipherSuites) << 1)
66 for i, suite := range m.cipherSuites {
67 y[2+i*2] = uint8(suite >> 8)
68 y[3+i*2] = uint8(suite)
69 }
70 z := y[2+len(m.cipherSuites)*2:]
71 z[0] = uint8(len(m.compressionMethods))
72 copy(z[1:], m.compressionMethods)
73
74 z = z[1+len(m.compressionMethods):]
75 if numExtensions > 0 {
76 z[0] = byte(extensionsLength >> 8)
77 z[1] = byte(extensionsLength)
78 z = z[2:]
79 }
80 if m.nextProtoNeg {
81 z[0] = byte(extensionNextProtoNeg >> 8)
82 z[1] = byte(extensionNextProtoNeg)
83
84 z = z[4:]
85 }
86 if len(m.serverName) > 0 {
87 z[0] = byte(extensionServerName >> 8)
88 z[1] = byte(extensionServerName)
89 l := len(m.serverName) + 5
90 z[2] = byte(l >> 8)
91 z[3] = byte(l)
92 z = z[4:]
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113 z[0] = byte((len(m.serverName) + 3) >> 8)
114 z[1] = byte(len(m.serverName) + 3)
115 z[3] = byte(len(m.serverName) >> 8)
116 z[4] = byte(len(m.serverName))
117 copy(z[5:], []byte(m.serverName))
118 z = z[l:]
119 }
120 if m.ocspStapling {
121
122 z[0] = byte(extensionStatusRequest >> 8)
123 z[1] = byte(extensionStatusRequest)
124 z[2] = 0
125 z[3] = 5
126 z[4] = 1
127
128 z = z[9:]
129 }
130 if len(m.supportedCurves) > 0 {
131
132 z[0] = byte(extensionSupportedCurves >> 8)
133 z[1] = byte(extensionSupportedCurves)
134 l := 2 + 2*len(m.supportedCurves)
135 z[2] = byte(l >> 8)
136 z[3] = byte(l)
137 l -= 2
138 z[4] = byte(l >> 8)
139 z[5] = byte(l)
140 z = z[6:]
141 for _, curve := range m.supportedCurves {
142 z[0] = byte(curve >> 8)
143 z[1] = byte(curve)
144 z = z[2:]
145 }
146 }
147 if len(m.supportedPoints) > 0 {
148
149 z[0] = byte(extensionSupportedPoints >> 8)
150 z[1] = byte(extensionSupportedPoints)
151 l := 1 + len(m.supportedPoints)
152 z[2] = byte(l >> 8)
153 z[3] = byte(l)
154 l--
155 z[4] = byte(l)
156 z = z[5:]
157 for _, pointFormat := range m.supportedPoints {
158 z[0] = byte(pointFormat)
159 z = z[1:]
160 }
161 }
162
163 m.raw = x
164
165 return x
166 }
167
168 func (m *clientHelloMsg) unmarshal(data []byte) bool {
169 if len(data) < 42 {
170 return false
171 }
172 m.raw = data
173 m.vers = uint16(data[4])<<8 | uint16(data[5])
174 m.random = data[6:38]
175 sessionIdLen := int(data[38])
176 if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
177 return false
178 }
179 m.sessionId = data[39 : 39+sessionIdLen]
180 data = data[39+sessionIdLen:]
181 if len(data) < 2 {
182 return false
183 }
184
185
186 cipherSuiteLen := int(data[0])<<8 | int(data[1])
187 if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
188 return false
189 }
190 numCipherSuites := cipherSuiteLen / 2
191 m.cipherSuites = make([]uint16, numCipherSuites)
192 for i := 0; i < numCipherSuites; i++ {
193 m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
194 }
195 data = data[2+cipherSuiteLen:]
196 if len(data) < 1 {
197 return false
198 }
199 compressionMethodsLen := int(data[0])
200 if len(data) < 1+compressionMethodsLen {
201 return false
202 }
203 m.compressionMethods = data[1 : 1+compressionMethodsLen]
204
205 data = data[1+compressionMethodsLen:]
206
207 m.nextProtoNeg = false
208 m.serverName = ""
209 m.ocspStapling = false
210
211 if len(data) == 0 {
212
213 return true
214 }
215 if len(data) < 2 {
216 return false
217 }
218
219 extensionsLength := int(data[0])<<8 | int(data[1])
220 data = data[2:]
221 if extensionsLength != len(data) {
222 return false
223 }
224
225 for len(data) != 0 {
226 if len(data) < 4 {
227 return false
228 }
229 extension := uint16(data[0])<<8 | uint16(data[1])
230 length := int(data[2])<<8 | int(data[3])
231 data = data[4:]
232 if len(data) < length {
233 return false
234 }
235
236 switch extension {
237 case extensionServerName:
238 if length < 2 {
239 return false
240 }
241 numNames := int(data[0])<<8 | int(data[1])
242 d := data[2:]
243 for i := 0; i < numNames; i++ {
244 if len(d) < 3 {
245 return false
246 }
247 nameType := d[0]
248 nameLen := int(d[1])<<8 | int(d[2])
249 d = d[3:]
250 if len(d) < nameLen {
251 return false
252 }
253 if nameType == 0 {
254 m.serverName = string(d[0:nameLen])
255 break
256 }
257 d = d[nameLen:]
258 }
259 case extensionNextProtoNeg:
260 if length > 0 {
261 return false
262 }
263 m.nextProtoNeg = true
264 case extensionStatusRequest:
265 m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
266 case extensionSupportedCurves:
267
268 if length < 2 {
269 return false
270 }
271 l := int(data[0])<<8 | int(data[1])
272 if l%2 == 1 || length != l+2 {
273 return false
274 }
275 numCurves := l / 2
276 m.supportedCurves = make([]uint16, numCurves)
277 d := data[2:]
278 for i := 0; i < numCurves; i++ {
279 m.supportedCurves[i] = uint16(d[0])<<8 | uint16(d[1])
280 d = d[2:]
281 }
282 case extensionSupportedPoints:
283
284 if length < 1 {
285 return false
286 }
287 l := int(data[0])
288 if length != l+1 {
289 return false
290 }
291 m.supportedPoints = make([]uint8, l)
292 copy(m.supportedPoints, data[1:])
293 }
294 data = data[length:]
295 }
296
297 return true
298 }
299
// serverHelloMsg carries the contents of a ServerHello handshake message.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed. nextProtoNeg/nextProtos and
// ocspStapling describe the two extensions this type can encode; nextProtos
// is only written by marshal when nextProtoNeg is set.
type serverHelloMsg struct {
	raw               []byte
	vers              uint16
	random, sessionId []byte
	cipherSuite       uint16
	compressionMethod uint8
	nextProtoNeg      bool
	nextProtos        []string
	ocspStapling      bool
}
311
312 func (m *serverHelloMsg) marshal() []byte {
313 if m.raw != nil {
314 return m.raw
315 }
316
317 length := 38 + len(m.sessionId)
318 numExtensions := 0
319 extensionsLength := 0
320
321 nextProtoLen := 0
322 if m.nextProtoNeg {
323 numExtensions++
324 for _, v := range m.nextProtos {
325 nextProtoLen += len(v)
326 }
327 nextProtoLen += len(m.nextProtos)
328 extensionsLength += nextProtoLen
329 }
330 if m.ocspStapling {
331 numExtensions++
332 }
333 if numExtensions > 0 {
334 extensionsLength += 4 * numExtensions
335 length += 2 + extensionsLength
336 }
337
338 x := make([]byte, 4+length)
339 x[0] = typeServerHello
340 x[1] = uint8(length >> 16)
341 x[2] = uint8(length >> 8)
342 x[3] = uint8(length)
343 x[4] = uint8(m.vers >> 8)
344 x[5] = uint8(m.vers)
345 copy(x[6:38], m.random)
346 x[38] = uint8(len(m.sessionId))
347 copy(x[39:39+len(m.sessionId)], m.sessionId)
348 z := x[39+len(m.sessionId):]
349 z[0] = uint8(m.cipherSuite >> 8)
350 z[1] = uint8(m.cipherSuite)
351 z[2] = uint8(m.compressionMethod)
352
353 z = z[3:]
354 if numExtensions > 0 {
355 z[0] = byte(extensionsLength >> 8)
356 z[1] = byte(extensionsLength)
357 z = z[2:]
358 }
359 if m.nextProtoNeg {
360 z[0] = byte(extensionNextProtoNeg >> 8)
361 z[1] = byte(extensionNextProtoNeg)
362 z[2] = byte(nextProtoLen >> 8)
363 z[3] = byte(nextProtoLen)
364 z = z[4:]
365
366 for _, v := range m.nextProtos {
367 l := len(v)
368 if l > 255 {
369 l = 255
370 }
371 z[0] = byte(l)
372 copy(z[1:], []byte(v[0:l]))
373 z = z[1+l:]
374 }
375 }
376 if m.ocspStapling {
377 z[0] = byte(extensionStatusRequest >> 8)
378 z[1] = byte(extensionStatusRequest)
379 z = z[4:]
380 }
381
382 m.raw = x
383
384 return x
385 }
386
387 func (m *serverHelloMsg) unmarshal(data []byte) bool {
388 if len(data) < 42 {
389 return false
390 }
391 m.raw = data
392 m.vers = uint16(data[4])<<8 | uint16(data[5])
393 m.random = data[6:38]
394 sessionIdLen := int(data[38])
395 if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
396 return false
397 }
398 m.sessionId = data[39 : 39+sessionIdLen]
399 data = data[39+sessionIdLen:]
400 if len(data) < 3 {
401 return false
402 }
403 m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
404 m.compressionMethod = data[2]
405 data = data[3:]
406
407 m.nextProtoNeg = false
408 m.nextProtos = nil
409 m.ocspStapling = false
410
411 if len(data) == 0 {
412
413 return true
414 }
415 if len(data) < 2 {
416 return false
417 }
418
419 extensionsLength := int(data[0])<<8 | int(data[1])
420 data = data[2:]
421 if len(data) != extensionsLength {
422 return false
423 }
424
425 for len(data) != 0 {
426 if len(data) < 4 {
427 return false
428 }
429 extension := uint16(data[0])<<8 | uint16(data[1])
430 length := int(data[2])<<8 | int(data[3])
431 data = data[4:]
432 if len(data) < length {
433 return false
434 }
435
436 switch extension {
437 case extensionNextProtoNeg:
438 m.nextProtoNeg = true
439 d := data
440 for len(d) > 0 {
441 l := int(d[0])
442 d = d[1:]
443 if l == 0 || l > len(d) {
444 return false
445 }
446 m.nextProtos = append(m.nextProtos, string(d[0:l]))
447 d = d[l:]
448 }
449 case extensionStatusRequest:
450 if length > 0 {
451 return false
452 }
453 m.ocspStapling = true
454 }
455 data = data[length:]
456 }
457
458 return true
459 }
460
// certificateMsg carries a Certificate handshake message: an ordered list
// of opaque certificate byte strings.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
465
466 func (m *certificateMsg) marshal() (x []byte) {
467 if m.raw != nil {
468 return m.raw
469 }
470
471 var i int
472 for _, slice := range m.certificates {
473 i += len(slice)
474 }
475
476 length := 3 + 3*len(m.certificates) + i
477 x = make([]byte, 4+length)
478 x[0] = typeCertificate
479 x[1] = uint8(length >> 16)
480 x[2] = uint8(length >> 8)
481 x[3] = uint8(length)
482
483 certificateOctets := length - 3
484 x[4] = uint8(certificateOctets >> 16)
485 x[5] = uint8(certificateOctets >> 8)
486 x[6] = uint8(certificateOctets)
487
488 y := x[7:]
489 for _, slice := range m.certificates {
490 y[0] = uint8(len(slice) >> 16)
491 y[1] = uint8(len(slice) >> 8)
492 y[2] = uint8(len(slice))
493 copy(y[3:], slice)
494 y = y[3+len(slice):]
495 }
496
497 m.raw = x
498 return
499 }
500
501 func (m *certificateMsg) unmarshal(data []byte) bool {
502 if len(data) < 7 {
503 return false
504 }
505
506 m.raw = data
507 certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
508 if uint32(len(data)) != certsLen+7 {
509 return false
510 }
511
512 numCerts := 0
513 d := data[7:]
514 for certsLen > 0 {
515 if len(d) < 4 {
516 return false
517 }
518 certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2])
519 if uint32(len(d)) < 3+certLen {
520 return false
521 }
522 d = d[3+certLen:]
523 certsLen -= 3 + certLen
524 numCerts++
525 }
526
527 m.certificates = make([][]byte, numCerts)
528 d = data[7:]
529 for i := 0; i < numCerts; i++ {
530 certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2])
531 m.certificates[i] = d[3 : 3+certLen]
532 d = d[3+certLen:]
533 }
534
535 return true
536 }
537
// serverKeyExchangeMsg carries a ServerKeyExchange handshake message. The
// payload is kept as an opaque byte string in key.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
542
543 func (m *serverKeyExchangeMsg) marshal() []byte {
544 if m.raw != nil {
545 return m.raw
546 }
547 length := len(m.key)
548 x := make([]byte, length+4)
549 x[0] = typeServerKeyExchange
550 x[1] = uint8(length >> 16)
551 x[2] = uint8(length >> 8)
552 x[3] = uint8(length)
553 copy(x[4:], m.key)
554
555 m.raw = x
556 return x
557 }
558
559 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
560 m.raw = data
561 if len(data) < 4 {
562 return false
563 }
564 m.key = data[4:]
565 return true
566 }
567
// certificateStatusMsg carries a CertificateStatus handshake message: a
// one-byte status type and, for the OCSP status type, the response bytes.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
573
574 func (m *certificateStatusMsg) marshal() []byte {
575 if m.raw != nil {
576 return m.raw
577 }
578
579 var x []byte
580 if m.statusType == statusTypeOCSP {
581 x = make([]byte, 4+4+len(m.response))
582 x[0] = typeCertificateStatus
583 l := len(m.response) + 4
584 x[1] = byte(l >> 16)
585 x[2] = byte(l >> 8)
586 x[3] = byte(l)
587 x[4] = statusTypeOCSP
588
589 l -= 4
590 x[5] = byte(l >> 16)
591 x[6] = byte(l >> 8)
592 x[7] = byte(l)
593 copy(x[8:], m.response)
594 } else {
595 x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
596 }
597
598 m.raw = x
599 return x
600 }
601
602 func (m *certificateStatusMsg) unmarshal(data []byte) bool {
603 m.raw = data
604 if len(data) < 5 {
605 return false
606 }
607 m.statusType = data[4]
608
609 m.response = nil
610 if m.statusType == statusTypeOCSP {
611 if len(data) < 8 {
612 return false
613 }
614 respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
615 if uint32(len(data)) != 4+4+respLen {
616 return false
617 }
618 m.response = data[8:]
619 }
620 return true
621 }
622
// serverHelloDoneMsg is the ServerHelloDone handshake message. It has no
// fields: its wire form is the 4-byte handshake header alone.
type serverHelloDoneMsg struct{}
624
625 func (m *serverHelloDoneMsg) marshal() []byte {
626 x := make([]byte, 4)
627 x[0] = typeServerHelloDone
628 return x
629 }
630
631 func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
632 return len(data) == 4
633 }
634
// clientKeyExchangeMsg carries a ClientKeyExchange handshake message. The
// payload is kept as an opaque byte string in ciphertext.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
639
640 func (m *clientKeyExchangeMsg) marshal() []byte {
641 if m.raw != nil {
642 return m.raw
643 }
644 length := len(m.ciphertext)
645 x := make([]byte, length+4)
646 x[0] = typeClientKeyExchange
647 x[1] = uint8(length >> 16)
648 x[2] = uint8(length >> 8)
649 x[3] = uint8(length)
650 copy(x[4:], m.ciphertext)
651
652 m.raw = x
653 return x
654 }
655
656 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
657 m.raw = data
658 if len(data) < 4 {
659 return false
660 }
661 l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
662 if l != len(data)-4 {
663 return false
664 }
665 m.ciphertext = data[4:]
666 return true
667 }
668
// finishedMsg carries a Finished handshake message; verifyData is its
// payload (12 bytes on the wire).
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
673
674 func (m *finishedMsg) marshal() (x []byte) {
675 if m.raw != nil {
676 return m.raw
677 }
678
679 x = make([]byte, 16)
680 x[0] = typeFinished
681 x[3] = 12
682 copy(x[4:], m.verifyData)
683 m.raw = x
684 return
685 }
686
687 func (m *finishedMsg) unmarshal(data []byte) bool {
688 m.raw = data
689 if len(data) != 4+12 {
690 return false
691 }
692 m.verifyData = data[4:]
693 return true
694 }
695
// nextProtoMsg carries a NextProtocol handshake message; proto is the
// protocol name it announces.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
700
701 func (m *nextProtoMsg) marshal() []byte {
702 if m.raw != nil {
703 return m.raw
704 }
705 l := len(m.proto)
706 if l > 255 {
707 l = 255
708 }
709
710 padding := 32 - (l+2)%32
711 length := l + padding + 2
712 x := make([]byte, length+4)
713 x[0] = typeNextProtocol
714 x[1] = uint8(length >> 16)
715 x[2] = uint8(length >> 8)
716 x[3] = uint8(length)
717
718 y := x[4:]
719 y[0] = byte(l)
720 copy(y[1:], []byte(m.proto[0:l]))
721 y = y[1+l:]
722 y[0] = byte(padding)
723
724 m.raw = x
725
726 return x
727 }
728
729 func (m *nextProtoMsg) unmarshal(data []byte) bool {
730 m.raw = data
731
732 if len(data) < 5 {
733 return false
734 }
735 data = data[4:]
736 protoLen := int(data[0])
737 data = data[1:]
738 if len(data) < protoLen {
739 return false
740 }
741 m.proto = string(data[0:protoLen])
742 data = data[protoLen:]
743
744 if len(data) < 1 {
745 return false
746 }
747 paddingLen := int(data[0])
748 data = data[1:]
749 if len(data) != paddingLen {
750 return false
751 }
752
753 return true
754 }
755
// certificateRequestMsg carries a CertificateRequest handshake message: a
// list of one-byte certificate types and a list of opaque certificate
// authority byte strings.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type certificateRequestMsg struct {
	raw                    []byte
	certificateTypes       []byte
	certificateAuthorities [][]byte
}
761
762 func (m *certificateRequestMsg) marshal() (x []byte) {
763 if m.raw != nil {
764 return m.raw
765 }
766
767
768 length := 1 + len(m.certificateTypes) + 2
769 for _, ca := range m.certificateAuthorities {
770 length += 2 + len(ca)
771 }
772
773 x = make([]byte, 4+length)
774 x[0] = typeCertificateRequest
775 x[1] = uint8(length >> 16)
776 x[2] = uint8(length >> 8)
777 x[3] = uint8(length)
778
779 x[4] = uint8(len(m.certificateTypes))
780
781 copy(x[5:], m.certificateTypes)
782 y := x[5+len(m.certificateTypes):]
783
784 numCA := len(m.certificateAuthorities)
785 y[0] = uint8(numCA >> 8)
786 y[1] = uint8(numCA)
787 y = y[2:]
788 for _, ca := range m.certificateAuthorities {
789 y[0] = uint8(len(ca) >> 8)
790 y[1] = uint8(len(ca))
791 y = y[2:]
792 copy(y, ca)
793 y = y[len(ca):]
794 }
795
796 m.raw = x
797
798 return
799 }
800
801 func (m *certificateRequestMsg) unmarshal(data []byte) bool {
802 m.raw = data
803
804 if len(data) < 5 {
805 return false
806 }
807
808 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
809 if uint32(len(data))-4 != length {
810 return false
811 }
812
813 numCertTypes := int(data[4])
814 data = data[5:]
815 if numCertTypes == 0 || len(data) <= numCertTypes {
816 return false
817 }
818
819 m.certificateTypes = make([]byte, numCertTypes)
820 if copy(m.certificateTypes, data) != numCertTypes {
821 return false
822 }
823
824 data = data[numCertTypes:]
825 if len(data) < 2 {
826 return false
827 }
828
829 numCAs := uint16(data[0])<<16 | uint16(data[1])
830 data = data[2:]
831
832 m.certificateAuthorities = make([][]byte, numCAs)
833 for i := uint16(0); i < numCAs; i++ {
834 if len(data) < 2 {
835 return false
836 }
837 caLen := uint16(data[0])<<16 | uint16(data[1])
838
839 data = data[2:]
840 if len(data) < int(caLen) {
841 return false
842 }
843
844 ca := make([]byte, caLen)
845 copy(ca, data)
846 m.certificateAuthorities[i] = ca
847 data = data[caLen:]
848 }
849
850 if len(data) > 0 {
851 return false
852 }
853
854 return true
855 }
856
// certificateVerifyMsg carries a CertificateVerify handshake message;
// signature is its payload, kept as an opaque byte string.
//
// raw holds the complete wire encoding: marshal caches its output there and
// unmarshal stores the slice it was handed.
type certificateVerifyMsg struct {
	raw       []byte
	signature []byte
}
861
862 func (m *certificateVerifyMsg) marshal() (x []byte) {
863 if m.raw != nil {
864 return m.raw
865 }
866
867
868 siglength := len(m.signature)
869 length := 2 + siglength
870 x = make([]byte, 4+length)
871 x[0] = typeCertificateVerify
872 x[1] = uint8(length >> 16)
873 x[2] = uint8(length >> 8)
874 x[3] = uint8(length)
875 x[4] = uint8(siglength >> 8)
876 x[5] = uint8(siglength)
877 copy(x[6:], m.signature)
878
879 m.raw = x
880
881 return
882 }
883
884 func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
885 m.raw = data
886
887 if len(data) < 6 {
888 return false
889 }
890
891 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
892 if uint32(len(data))-4 != length {
893 return false
894 }
895
896 siglength := int(data[4])<<8 + int(data[5])
897 if len(data)-6 != siglength {
898 return false
899 }
900
901 m.signature = data[6:]
902
903 return true
904 }