1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "fmt"
16 "hash"
17 "io"
18 "net"
19 "sync"
20 "sync/atomic"
21 "time"
22 )
23
24
25
// Conn represents a TLS connection layered over an underlying net.Conn.
// Its methods drive the handshake and the TLS record layer on top of conn.
type Conn struct {
	// Not reassigned anywhere in this file.
	conn        net.Conn
	isClient    bool
	handshakeFn func() error // invoked by Handshake to run the handshake

	// handshakeStatus is 1 once the handshake has completed (see
	// handshakeComplete); handleRenegotiation resets it to 0. In this file
	// it is only accessed through sync/atomic.
	handshakeStatus uint32

	// handshakeMutex serializes Handshake, renegotiation and the accessors
	// (ConnectionState, OCSPResponse, VerifyHostname) that read handshake
	// results.
	handshakeMutex sync.Mutex
	handshakeErr   error   // sticky result of the most recent handshake
	vers           uint16  // protocol version in use
	haveVers       bool    // when false, readRecordOrCCS applies extra sanity checks to the first record
	config         *Config // configuration for this connection

	// handshakes counts successful handshakes on this connection; the
	// RenegotiateOnceAsClient policy consults it.
	handshakes       int
	didResume        bool // reported as ConnectionState.DidResume
	cipherSuite      uint16
	ocspResponse     []byte   // reported as ConnectionState.OCSPResponse
	scts             [][]byte // reported as ConnectionState.SignedCertificateTimestamps
	peerCertificates []*x509.Certificate

	// verifiedChains is reported as ConnectionState.VerifiedChains;
	// VerifyHostname refuses to run when it is empty.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName.
	serverName string

	// secureRenegotiation is not read in this part of the file; presumably it
	// records whether the peer supports secure renegotiation — TODO confirm.
	secureRenegotiation bool

	// ekm exports keying material; connectionStateLocked only exposes it
	// when renegotiation is disabled.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// resumptionSecret is not used in this part of the file; presumably the
	// TLS 1.3 resumption secret used for session tickets — TODO confirm.
	resumptionSecret []byte

	// ticketKeys is not used in this part of the file; presumably the
	// session ticket keys for this connection — TODO confirm.
	ticketKeys []ticketKey

	// clientFinishedIsFirst selects which Finished value (clientFinished or
	// serverFinished) is reported as ConnectionState.TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the single close_notify attempt made
	// by closeNotify; it is returned on every later call.
	closeNotifyErr error

	// closeNotifySent is set once closeNotify has tried to send the alert;
	// Write fails with errShutdown afterwards.
	closeNotifySent bool

	// clientFinished and serverFinished hold Finished values from the
	// handshake; one of them is exposed as ConnectionState.TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol.
	clientProtocol string

	// Record-layer state.
	in, out   halfConn     // per-direction protection state: in for reading, out for writing
	rawInput  bytes.Buffer // raw bytes read from conn, starting at a record header
	input     bytes.Reader // decrypted application data; aliases memory owned by rawInput
	hand      bytes.Buffer // handshake bytes waiting to be parsed by readHandshake
	buffering bool         // when true, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // records queued while buffering; sent by flush

	// bytesSent counts bytes written to conn by write and flush.
	// packetsSent counts calls to maxPayloadSizeForWrite that reached the
	// dynamic sizing logic. Both drive dynamic record sizing.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records that neither advanced the
	// handshake nor delivered application data; exceeding maxUselessRecords
	// aborts the connection.
	retryCount int

	// activeCall coordinates Write and Close via sync/atomic: the low bit is
	// set once Close has been called, and each in-flight Write adds 2.
	activeCall int32

	// tmp is scratch space; sendAlertLocked builds alert records in it.
	tmp [16]byte
}
118
119
120
121
122
123
124 func (c *Conn) LocalAddr() net.Addr {
125 return c.conn.LocalAddr()
126 }
127
128
129 func (c *Conn) RemoteAddr() net.Addr {
130 return c.conn.RemoteAddr()
131 }
132
133
134
135
136 func (c *Conn) SetDeadline(t time.Time) error {
137 return c.conn.SetDeadline(t)
138 }
139
140
141
142 func (c *Conn) SetReadDeadline(t time.Time) error {
143 return c.conn.SetReadDeadline(t)
144 }
145
146
147
148
149 func (c *Conn) SetWriteDeadline(t time.Time) error {
150 return c.conn.SetWriteDeadline(t)
151 }
152
153
154
// halfConn holds the record-protection state for one direction (reading or
// writing) of a Conn. The embedded Mutex guards it.
type halfConn struct {
	sync.Mutex

	err     error       // sticky error, set by setErrorLocked
	version uint16      // protocol version, set by prepareCipherSpec
	cipher  interface{} // active cipher: cipher.Stream, aead or cbcMode; nil before keys are installed
	mac     hash.Hash   // active MAC, used with the non-AEAD ciphers
	seq     [8]byte     // 64-bit big-endian record sequence number (see incSeq)

	scratchBuf [13]byte // reusable space for AEAD additional data and MAC input, avoiding allocations

	nextCipher interface{} // cipher staged by prepareCipherSpec, activated by changeCipherSpec
	nextMac    hash.Hash   // MAC staged by prepareCipherSpec, activated by changeCipherSpec

	trafficSecret []byte // current TLS 1.3 traffic secret (see setTrafficSecret)
}
171
// permanentError wraps a net.Error so that it never reports itself as
// temporary. Its message, timeout status and the wrapped error (through
// Unwrap) are preserved.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message.
func (e *permanentError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	return e.err.Timeout()
}

// Temporary always reports false: the error is sticky.
func (e *permanentError) Temporary() bool {
	return false
}
180
181 func (hc *halfConn) setErrorLocked(err error) error {
182 if e, ok := err.(net.Error); ok {
183 hc.err = &permanentError{err: e}
184 } else {
185 hc.err = err
186 }
187 return hc.err
188 }
189
190
191
192 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
193 hc.version = version
194 hc.nextCipher = cipher
195 hc.nextMac = mac
196 }
197
198
199
200 func (hc *halfConn) changeCipherSpec() error {
201 if hc.nextCipher == nil || hc.version == VersionTLS13 {
202 return alertInternalError
203 }
204 hc.cipher = hc.nextCipher
205 hc.mac = hc.nextMac
206 hc.nextCipher = nil
207 hc.nextMac = nil
208 for i := range hc.seq {
209 hc.seq[i] = 0
210 }
211 return nil
212 }
213
214 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
215 hc.trafficSecret = secret
216 key, iv := suite.trafficKey(secret)
217 hc.cipher = suite.aead(key, iv)
218 for i := range hc.seq {
219 hc.seq[i] = 0
220 }
221 }
222
223
224 func (hc *halfConn) incSeq() {
225 for i := 7; i >= 0; i-- {
226 hc.seq[i]++
227 if hc.seq[i] != 0 {
228 return
229 }
230 }
231
232
233
234
235 panic("TLS: sequence number wraparound")
236 }
237
238
239
240
241 func (hc *halfConn) explicitNonceLen() int {
242 if hc.cipher == nil {
243 return 0
244 }
245
246 switch c := hc.cipher.(type) {
247 case cipher.Stream:
248 return 0
249 case aead:
250 return c.explicitNonceLen()
251 case cbcMode:
252
253 if hc.version >= VersionTLS11 {
254 return c.BlockSize()
255 }
256 return 0
257 default:
258 panic("unknown cipher type")
259 }
260 }
261
262
263
264
// extractPadding inspects the CBC padding at the end of payload in constant
// time. It returns toRemove, the number of trailing bytes to strip (padding
// plus the length byte itself), and good, which is 255 if the padding is
// well-formed and 0 otherwise.
//
// All data-dependent decisions are made with masks rather than branches, so
// the running time does not depend on the (secret) padding contents.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If paddingLen <= len(payload)-1, t does not wrap, so the sign bit of
	// int32(^t) is set and good becomes 0xff; otherwise good becomes 0.
	good = byte(int32(^t) >> 31)

	// A one-byte length can claim at most 255 padding bytes, so the last
	// 256 bytes (padding plus length byte) are enough to check.
	toCheck := 256
	// The payload length is public, so branching on it is fine.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen, i.e. byte i from the end is padding.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear bits of good wherever a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together and broadcast the result to every bit,
	// so that good ends up as exactly 0xff or 0x00.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// On bad padding, zero paddingLen. toRemove then becomes 1, and the
	// caller's MAC check still covers the bytes that would have been
	// stripped.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
311
// roundUp returns a rounded up to the next multiple of b, or a itself if it
// is already a multiple. b must be non-zero; callers pass a cipher block size.
func roundUp(a, b int) int {
	deficit := (b - a%b) % b
	return a + deficit
}
315
316
// cbcMode is a block-mode cipher whose IV can be replaced between records,
// as needed for explicit per-record IVs.
type cbcMode interface {
	SetIV([]byte)
	cipher.BlockMode
}
321
322
323
324 func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
325 var plaintext []byte
326 typ := recordType(record[0])
327 payload := record[recordHeaderLen:]
328
329
330
331 if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
332 return payload, typ, nil
333 }
334
335 paddingGood := byte(255)
336 paddingLen := 0
337
338 explicitNonceLen := hc.explicitNonceLen()
339
340 if hc.cipher != nil {
341 switch c := hc.cipher.(type) {
342 case cipher.Stream:
343 c.XORKeyStream(payload, payload)
344 case aead:
345 if len(payload) < explicitNonceLen {
346 return nil, 0, alertBadRecordMAC
347 }
348 nonce := payload[:explicitNonceLen]
349 if len(nonce) == 0 {
350 nonce = hc.seq[:]
351 }
352 payload = payload[explicitNonceLen:]
353
354 var additionalData []byte
355 if hc.version == VersionTLS13 {
356 additionalData = record[:recordHeaderLen]
357 } else {
358 additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
359 additionalData = append(additionalData, record[:3]...)
360 n := len(payload) - c.Overhead()
361 additionalData = append(additionalData, byte(n>>8), byte(n))
362 }
363
364 var err error
365 plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
366 if err != nil {
367 return nil, 0, alertBadRecordMAC
368 }
369 case cbcMode:
370 blockSize := c.BlockSize()
371 minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
372 if len(payload)%blockSize != 0 || len(payload) < minPayload {
373 return nil, 0, alertBadRecordMAC
374 }
375
376 if explicitNonceLen > 0 {
377 c.SetIV(payload[:explicitNonceLen])
378 payload = payload[explicitNonceLen:]
379 }
380 c.CryptBlocks(payload, payload)
381
382
383
384
385
386
387
388 paddingLen, paddingGood = extractPadding(payload)
389 default:
390 panic("unknown cipher type")
391 }
392
393 if hc.version == VersionTLS13 {
394 if typ != recordTypeApplicationData {
395 return nil, 0, alertUnexpectedMessage
396 }
397 if len(plaintext) > maxPlaintext+1 {
398 return nil, 0, alertRecordOverflow
399 }
400
401 for i := len(plaintext) - 1; i >= 0; i-- {
402 if plaintext[i] != 0 {
403 typ = recordType(plaintext[i])
404 plaintext = plaintext[:i]
405 break
406 }
407 if i == 0 {
408 return nil, 0, alertUnexpectedMessage
409 }
410 }
411 }
412 } else {
413 plaintext = payload
414 }
415
416 if hc.mac != nil {
417 macSize := hc.mac.Size()
418 if len(payload) < macSize {
419 return nil, 0, alertBadRecordMAC
420 }
421
422 n := len(payload) - macSize - paddingLen
423 n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
424 record[3] = byte(n >> 8)
425 record[4] = byte(n)
426 remoteMAC := payload[n : n+macSize]
427 localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
428
429
430
431
432
433
434
435
436 macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
437 if macAndPaddingGood != 1 {
438 return nil, 0, alertBadRecordMAC
439 }
440
441 plaintext = payload[:n]
442 }
443
444 hc.incSeq()
445 return plaintext, typ, nil
446 }
447
448
449
450
// sliceForAppend extends in by n bytes. It returns head, the extended slice,
// and tail, its last n bytes. When cap(in) is large enough, head reuses
// in's backing array; otherwise a new array is allocated and in is copied
// into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
461
462
463
// encrypt protects payload and appends the result to record, which must
// already hold a recordHeaderLen-byte record header. It returns the extended
// slice with the header's length field set to the final body length, and
// increments the write sequence number.
//
// If no cipher is active, payload is appended unmodified and neither the
// header nor the sequence number is touched. rand supplies explicit
// nonces/IVs in the cases where they are random; only a read error from
// rand is returned.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// Short explicit nonces for non-CBC ciphers are filled with the
			// sequence number rather than random bytes. Presumably this is
			// because an 8-byte random AEAD nonce would be too small to
			// avoid collisions — confirm.
			copy(explicitNonce, hc.seq[:])
		} else {
			// CBC IVs (and long nonces) are random.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC the plaintext, then encrypt plaintext || MAC.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Append the real content type to the plaintext, and label the
			// outer record as application data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			// Seal in place over the plaintext that follows the header.
			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq || record header.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always between 1 and blockSize bytes of padding.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		// Every padding byte, including the final length byte, holds paddingLen-1.
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update the length to include the nonce, the MAC and any block padding.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
548
549
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg describes what is wrong with the record header.
	Msg string

	// RecordHeader holds the first (up to) five buffered input bytes, i.e.
	// the offending record header.
	RecordHeader [5]byte

	// Conn is the underlying connection. In this file it is set only when
	// the first record does not look like a TLS handshake, and is nil for
	// the other header errors.
	Conn net.Conn
}

// Error returns Msg with a "tls: " prefix.
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
564
565 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
566 err.Msg = msg
567 err.Conn = conn
568 copy(err.RecordHeader[:], c.rawInput.Bytes())
569 return err
570 }
571
572 func (c *Conn) readRecord() error {
573 return c.readRecordOrCCS(false)
574 }
575
576 func (c *Conn) readChangeCipherSpec() error {
577 return c.readRecordOrCCS(true)
578 }
579
580
581
582
583
584
585
586
587
588
589
590
591
// readRecordOrCCS reads one record from c.conn into c.rawInput, decrypts it
// with c.in and dispatches it by content type:
//   - alert: close_notify yields io.EOF; other alerts become errors, except
//     pre-TLS 1.3 warning alerts, which are ignored (and counted as retries);
//   - change_cipher_spec: ignored in TLS 1.3; otherwise accepted only when
//     expectChangeCipherSpec is true, in which case c.in switches to its
//     pending cipher;
//   - application data: stored in c.input (empty records are ignored);
//   - handshake: appended to c.hand.
//
// Protocol violations send an alert and become c.in's sticky error.
// Temporary network errors are returned without being made sticky, so the
// read can be retried. It must not be called while c.input still holds
// data, because c.input aliases memory owned by c.rawInput.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// Reading modifies c.rawInput, which owns the memory c.input points into.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly at a record boundary is reported as a clean io.EOF,
		// even though no close_notify was received.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// A first byte of 0x80 before the handshake completes is treated as an
	// SSLv2-style ClientHello and rejected.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	// Once a pre-TLS 1.3 version is negotiated, every record must carry it.
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be suspicious that the peer is not speaking TLS at
		// all, and bail out before reading a full body. Only alerts and
		// handshakes are plausible here, with a version below 0x1000.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// Parsed as (TLS13 && n > maxCiphertextTLS13) || n > maxCiphertext.
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the record body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Consume and decrypt the record; data aliases c.rawInput's buffer.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data is never accepted unprotected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// A state-advancing record: reset the useless-record counter.
		c.retryCount = 0
	}

	// In TLS 1.3, handshake messages must not be interleaved with other
	// record types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// In TLS 1.3 every other alert is treated as fatal.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages may not be fragmented across a CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// TLS 1.3 ignores change_cipher_spec records (a compatibility
		// artifact). Note c.vers may still be unset early in the handshake,
		// in which case this branch is not taken.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Ignore a bounded number of empty application data records.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data is owned by c.rawInput (from Next above). This avoids a copy,
		// and is safe because c.rawInput is not touched again until c.input
		// is drained (checked at the top of this function).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
752
753
754
755 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
756 c.retryCount++
757 if c.retryCount > maxUselessRecords {
758 c.sendAlert(alertUnexpectedMessage)
759 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
760 }
761 return c.readRecordOrCCS(expectChangeCipherSpec)
762 }
763
764
765
766
// atLeastReader reads from R and reports io.EOF as soon as at least N bytes
// have been read in total. This makes bytes.Buffer.ReadFrom stop once the
// required amount is buffered. If R ends before N bytes were read, it
// reports io.ErrUnexpectedEOF instead.
type atLeastReader struct {
	R io.Reader
	N int64
}

// Read reads into p from R and updates the remaining byte count N.
func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	stillShort := r.N > 0
	switch {
	case stillShort && err == io.EOF:
		return n, io.ErrUnexpectedEOF
	case !stillShort && err == nil:
		return n, io.EOF
	default:
		return n, err
	}
}
786
787
788
789 func (c *Conn) readFromUntil(r io.Reader, n int) error {
790 if c.rawInput.Len() >= n {
791 return nil
792 }
793 needs := n - c.rawInput.Len()
794
795
796
797 c.rawInput.Grow(needs + bytes.MinRead)
798 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
799 return err
800 }
801
802
803 func (c *Conn) sendAlertLocked(err alert) error {
804 switch err {
805 case alertNoRenegotiation, alertCloseNotify:
806 c.tmp[0] = alertLevelWarning
807 default:
808 c.tmp[0] = alertLevelError
809 }
810 c.tmp[1] = byte(err)
811
812 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
813 if err == alertCloseNotify {
814
815 return writeErr
816 }
817
818 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
819 }
820
821
822 func (c *Conn) sendAlert(err alert) error {
823 c.out.Lock()
824 defer c.out.Unlock()
825 return c.sendAlertLocked(err)
826 }
827
// tcpMSSEstimate is the per-packet size budget that maxPayloadSizeForWrite
// uses for early application-data records. Presumably an estimate of the
// TCP maximum segment size, per its name — confirm.
const tcpMSSEstimate = 1208

// recordSizeBoostThreshold is the number of bytes sent after which
// maxPayloadSizeForWrite stops limiting record sizes and always allows
// maxPlaintext.
const recordSizeBoostThreshold = 128 * 1024
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
859 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
860 return maxPlaintext
861 }
862
863 if c.bytesSent >= recordSizeBoostThreshold {
864 return maxPlaintext
865 }
866
867
868 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
869 if c.out.cipher != nil {
870 switch ciph := c.out.cipher.(type) {
871 case cipher.Stream:
872 payloadBytes -= c.out.mac.Size()
873 case cipher.AEAD:
874 payloadBytes -= ciph.Overhead()
875 case cbcMode:
876 blockSize := ciph.BlockSize()
877
878
879 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
880
881
882 payloadBytes -= c.out.mac.Size()
883 default:
884 panic("unknown cipher type")
885 }
886 }
887 if c.vers == VersionTLS13 {
888 payloadBytes--
889 }
890
891
892 pkt := c.packetsSent
893 c.packetsSent++
894 if pkt > 1000 {
895 return maxPlaintext
896 }
897
898 n := payloadBytes * int(pkt+1)
899 if n > maxPlaintext {
900 n = maxPlaintext
901 }
902 return n
903 }
904
905 func (c *Conn) write(data []byte) (int, error) {
906 if c.buffering {
907 c.sendBuf = append(c.sendBuf, data...)
908 return len(data), nil
909 }
910
911 n, err := c.conn.Write(data)
912 c.bytesSent += int64(n)
913 return n, err
914 }
915
916 func (c *Conn) flush() (int, error) {
917 if len(c.sendBuf) == 0 {
918 return 0, nil
919 }
920
921 n, err := c.conn.Write(c.sendBuf)
922 c.bytesSent += int64(n)
923 c.sendBuf = nil
924 c.buffering = false
925 return n, err
926 }
927
928
// outBufPool recycles the buffers writeRecordLocked uses to assemble
// outgoing records. It holds *[]byte rather than []byte, presumably so the
// pooled value is pointer-shaped and Put does not allocate (cf. staticcheck
// SA6002).
var outBufPool = sync.Pool{
	New: func() interface{} {
		var buf []byte
		return &buf
	},
}
934
935
936
// writeRecordLocked protects data with c.out and writes it as one or more
// records of type typ. Each record carries at most
// maxPayloadSizeForWrite(typ) bytes of data. It returns how many bytes of
// data were consumed, which is not the number of bytes written to the wire.
// After a pre-TLS 1.3 ChangeCipherSpec record, c.out switches to its pending
// cipher. The caller is expected to hold c.out's lock.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the (possibly grown) buffer back through the pointer from
		// Get instead of calling Put(&outBuf). Presumably this keeps the
		// local outBuf slice header from escaping to the heap.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		// Build the record header in the reused buffer.
		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// No version negotiated yet: advertise TLS 1.0 in the record
			// header (presumably for compatibility with servers that reject
			// higher record versions on the initial ClientHello).
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 records carry the TLS 1.2 version number in the
			// header (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
994
995
996
997 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
998 c.out.Lock()
999 defer c.out.Unlock()
1000
1001 return c.writeRecordLocked(typ, data)
1002 }
1003
1004
1005
1006 func (c *Conn) readHandshake() (interface{}, error) {
1007 for c.hand.Len() < 4 {
1008 if err := c.readRecord(); err != nil {
1009 return nil, err
1010 }
1011 }
1012
1013 data := c.hand.Bytes()
1014 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1015 if n > maxHandshake {
1016 c.sendAlertLocked(alertInternalError)
1017 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
1018 }
1019 for c.hand.Len() < 4+n {
1020 if err := c.readRecord(); err != nil {
1021 return nil, err
1022 }
1023 }
1024 data = c.hand.Next(4 + n)
1025 var m handshakeMessage
1026 switch data[0] {
1027 case typeHelloRequest:
1028 m = new(helloRequestMsg)
1029 case typeClientHello:
1030 m = new(clientHelloMsg)
1031 case typeServerHello:
1032 m = new(serverHelloMsg)
1033 case typeNewSessionTicket:
1034 if c.vers == VersionTLS13 {
1035 m = new(newSessionTicketMsgTLS13)
1036 } else {
1037 m = new(newSessionTicketMsg)
1038 }
1039 case typeCertificate:
1040 if c.vers == VersionTLS13 {
1041 m = new(certificateMsgTLS13)
1042 } else {
1043 m = new(certificateMsg)
1044 }
1045 case typeCertificateRequest:
1046 if c.vers == VersionTLS13 {
1047 m = new(certificateRequestMsgTLS13)
1048 } else {
1049 m = &certificateRequestMsg{
1050 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1051 }
1052 }
1053 case typeCertificateStatus:
1054 m = new(certificateStatusMsg)
1055 case typeServerKeyExchange:
1056 m = new(serverKeyExchangeMsg)
1057 case typeServerHelloDone:
1058 m = new(serverHelloDoneMsg)
1059 case typeClientKeyExchange:
1060 m = new(clientKeyExchangeMsg)
1061 case typeCertificateVerify:
1062 m = &certificateVerifyMsg{
1063 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1064 }
1065 case typeFinished:
1066 m = new(finishedMsg)
1067 case typeEncryptedExtensions:
1068 m = new(encryptedExtensionsMsg)
1069 case typeEndOfEarlyData:
1070 m = new(endOfEarlyDataMsg)
1071 case typeKeyUpdate:
1072 m = new(keyUpdateMsg)
1073 default:
1074 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1075 }
1076
1077
1078
1079
1080 data = append([]byte(nil), data...)
1081
1082 if !m.unmarshal(data) {
1083 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1084 }
1085 return m, nil
1086 }
1087
// errShutdown is returned by Write once a close_notify has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1091
1092
1093
1094
1095
1096
1097
// Write sends b as application data, running the handshake first if it has
// not completed. It returns the number of bytes of b written.
//
// Write coordinates with Close through c.activeCall. It fails with
// net.ErrClosed once Close has been called. While it runs it holds a
// registration (+2), so a concurrent Close does not try to send close_notify.
// After close_notify has been sent, Write returns errShutdown. Write errors
// become c.out's sticky error.
func (c *Conn) Write(b []byte) (int, error) {
	// Register as an in-flight Write unless Close has already set the low bit.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			break
		}
	}
	defer atomic.AddInt32(&c.activeCall, -2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// A previous write failure is sticky.
	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// With TLS 1.0 and a block-mode cipher, the first byte of b goes in a
	// record of its own, followed by the rest. This is presumably the 1/n-1
	// record split that mitigates chosen-plaintext attacks on TLS 1.0's
	// predictable CBC IVs (BEAST) — confirm.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	// On success err is nil, and c.out.err (already nil, checked above)
	// stays nil.
	return n + m, c.out.setErrorLocked(err)
}
1153
1154
1155 func (c *Conn) handleRenegotiation() error {
1156 if c.vers == VersionTLS13 {
1157 return errors.New("tls: internal error: unexpected renegotiation")
1158 }
1159
1160 msg, err := c.readHandshake()
1161 if err != nil {
1162 return err
1163 }
1164
1165 helloReq, ok := msg.(*helloRequestMsg)
1166 if !ok {
1167 c.sendAlert(alertUnexpectedMessage)
1168 return unexpectedMessageError(helloReq, msg)
1169 }
1170
1171 if !c.isClient {
1172 return c.sendAlert(alertNoRenegotiation)
1173 }
1174
1175 switch c.config.Renegotiation {
1176 case RenegotiateNever:
1177 return c.sendAlert(alertNoRenegotiation)
1178 case RenegotiateOnceAsClient:
1179 if c.handshakes > 1 {
1180 return c.sendAlert(alertNoRenegotiation)
1181 }
1182 case RenegotiateFreelyAsClient:
1183
1184 default:
1185 c.sendAlert(alertInternalError)
1186 return errors.New("tls: unknown Renegotiation value")
1187 }
1188
1189 c.handshakeMutex.Lock()
1190 defer c.handshakeMutex.Unlock()
1191
1192 atomic.StoreUint32(&c.handshakeStatus, 0)
1193 if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
1194 c.handshakes++
1195 }
1196 return c.handshakeErr
1197 }
1198
1199
1200
1201 func (c *Conn) handlePostHandshakeMessage() error {
1202 if c.vers != VersionTLS13 {
1203 return c.handleRenegotiation()
1204 }
1205
1206 msg, err := c.readHandshake()
1207 if err != nil {
1208 return err
1209 }
1210
1211 c.retryCount++
1212 if c.retryCount > maxUselessRecords {
1213 c.sendAlert(alertUnexpectedMessage)
1214 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1215 }
1216
1217 switch msg := msg.(type) {
1218 case *newSessionTicketMsgTLS13:
1219 return c.handleNewSessionTicket(msg)
1220 case *keyUpdateMsg:
1221 return c.handleKeyUpdate(msg)
1222 default:
1223 c.sendAlert(alertUnexpectedMessage)
1224 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1225 }
1226 }
1227
1228 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1229 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1230 if cipherSuite == nil {
1231 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1232 }
1233
1234 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1235 c.in.setTrafficSecret(cipherSuite, newSecret)
1236
1237 if keyUpdate.updateRequested {
1238 c.out.Lock()
1239 defer c.out.Unlock()
1240
1241 msg := &keyUpdateMsg{}
1242 _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1243 if err != nil {
1244
1245 c.out.setErrorLocked(err)
1246 return nil
1247 }
1248
1249 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1250 c.out.setTrafficSecret(cipherSuite, newSecret)
1251 }
1252
1253 return nil
1254 }
1255
1256
1257
1258
1259
1260
1261
1262 func (c *Conn) Read(b []byte) (int, error) {
1263 if err := c.Handshake(); err != nil {
1264 return 0, err
1265 }
1266 if len(b) == 0 {
1267
1268
1269 return 0, nil
1270 }
1271
1272 c.in.Lock()
1273 defer c.in.Unlock()
1274
1275 for c.input.Len() == 0 {
1276 if err := c.readRecord(); err != nil {
1277 return 0, err
1278 }
1279 for c.hand.Len() > 0 {
1280 if err := c.handlePostHandshakeMessage(); err != nil {
1281 return 0, err
1282 }
1283 }
1284 }
1285
1286 n, _ := c.input.Read(b)
1287
1288
1289
1290
1291
1292
1293
1294
1295 if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
1296 recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
1297 if err := c.readRecord(); err != nil {
1298 return n, err
1299 }
1300 }
1301
1302 return n, nil
1303 }
1304
1305
// Close closes the connection. If the handshake has completed, it first
// tries to send a close_notify alert via closeNotify. The underlying
// connection is closed either way. An error from closing it takes
// precedence; otherwise a close_notify failure is returned, wrapped.
//
// Close coordinates with Write through c.activeCall. A second Close returns
// net.ErrClosed. If a Write is in flight, Close only closes the underlying
// connection and does not attempt close_notify.
func (c *Conn) Close() error {
	// Set the "closed" bit in activeCall, remembering whether any Writes
	// were in flight at that moment.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is running and holds c.out's lock, which closeNotify
		// would need. Skip close_notify and just close the transport
		// (presumably unblocking that Write).
		return c.conn.Close()
	}

	var alertErr error
	if c.handshakeComplete() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1340
1341 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1342
1343
1344
1345
1346 func (c *Conn) CloseWrite() error {
1347 if !c.handshakeComplete() {
1348 return errEarlyCloseWrite
1349 }
1350
1351 return c.closeNotify()
1352 }
1353
1354 func (c *Conn) closeNotify() error {
1355 c.out.Lock()
1356 defer c.out.Unlock()
1357
1358 if !c.closeNotifySent {
1359
1360 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1361 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1362 c.closeNotifySent = true
1363
1364 c.SetWriteDeadline(time.Now())
1365 }
1366 return c.closeNotifyErr
1367 }
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377 func (c *Conn) Handshake() error {
1378 c.handshakeMutex.Lock()
1379 defer c.handshakeMutex.Unlock()
1380
1381 if err := c.handshakeErr; err != nil {
1382 return err
1383 }
1384 if c.handshakeComplete() {
1385 return nil
1386 }
1387
1388 c.in.Lock()
1389 defer c.in.Unlock()
1390
1391 c.handshakeErr = c.handshakeFn()
1392 if c.handshakeErr == nil {
1393 c.handshakes++
1394 } else {
1395
1396
1397 c.flush()
1398 }
1399
1400 if c.handshakeErr == nil && !c.handshakeComplete() {
1401 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1402 }
1403
1404 return c.handshakeErr
1405 }
1406
1407
1408 func (c *Conn) ConnectionState() ConnectionState {
1409 c.handshakeMutex.Lock()
1410 defer c.handshakeMutex.Unlock()
1411 return c.connectionStateLocked()
1412 }
1413
1414 func (c *Conn) connectionStateLocked() ConnectionState {
1415 var state ConnectionState
1416 state.HandshakeComplete = c.handshakeComplete()
1417 state.Version = c.vers
1418 state.NegotiatedProtocol = c.clientProtocol
1419 state.DidResume = c.didResume
1420 state.NegotiatedProtocolIsMutual = true
1421 state.ServerName = c.serverName
1422 state.CipherSuite = c.cipherSuite
1423 state.PeerCertificates = c.peerCertificates
1424 state.VerifiedChains = c.verifiedChains
1425 state.SignedCertificateTimestamps = c.scts
1426 state.OCSPResponse = c.ocspResponse
1427 if !c.didResume && c.vers != VersionTLS13 {
1428 if c.clientFinishedIsFirst {
1429 state.TLSUnique = c.clientFinished[:]
1430 } else {
1431 state.TLSUnique = c.serverFinished[:]
1432 }
1433 }
1434 if c.config.Renegotiation != RenegotiateNever {
1435 state.ekm = noExportedKeyingMaterial
1436 } else {
1437 state.ekm = c.ekm
1438 }
1439 return state
1440 }
1441
1442
1443
1444 func (c *Conn) OCSPResponse() []byte {
1445 c.handshakeMutex.Lock()
1446 defer c.handshakeMutex.Unlock()
1447
1448 return c.ocspResponse
1449 }
1450
1451
1452
1453
1454 func (c *Conn) VerifyHostname(host string) error {
1455 c.handshakeMutex.Lock()
1456 defer c.handshakeMutex.Unlock()
1457 if !c.isClient {
1458 return errors.New("tls: VerifyHostname called on TLS server connection")
1459 }
1460 if !c.handshakeComplete() {
1461 return errors.New("tls: handshake has not yet been performed")
1462 }
1463 if len(c.verifiedChains) == 0 {
1464 return errors.New("tls: handshake did not verify certificate chain")
1465 }
1466 return c.peerCertificates[0].VerifyHostname(host)
1467 }
1468
1469 func (c *Conn) handshakeComplete() bool {
1470 return atomic.LoadUint32(&c.handshakeStatus) == 1
1471 }
1472
View as plain text