1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "fmt"
16 "io"
17 "net"
18 "sync"
19 "sync/atomic"
20 "time"
21 )
22
23
24
// A Conn represents a secured connection. It implements the net.Conn
// interface: LocalAddr, RemoteAddr and the deadline setters delegate to conn.
type Conn struct {
	// constant
	// conn is the underlying transport; all record I/O goes through it.
	conn net.Conn
	// isClient selects clientHandshake or serverHandshake in Handshake.
	isClient bool

	// handshakeStatus is 1 once the handshake has completed (see
	// handshakeComplete). It is accessed with sync/atomic so it can be read
	// without handshakeMutex; handleRenegotiation resets it to 0 before
	// starting a new handshake.
	handshakeStatus uint32
	// handshakeMutex is held by Handshake, handleRenegotiation,
	// ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	// handshakeErr is the result of the most recent handshake; Handshake
	// returns it on every later call once it is non-nil.
	handshakeErr error
	// vers is the protocol version in use. Once haveVers is set, readRecord
	// rejects records carrying any other version.
	vers     uint16
	haveVers bool
	config   *Config

	// handshakes counts successfully completed handshakes, including
	// renegotiations (incremented by Handshake and handleRenegotiation).
	handshakes int
	// The following are reported through ConnectionState.
	didResume        bool
	cipherSuite      uint16
	ocspResponse     []byte   // also returned by OCSPResponse
	scts             [][]byte // reported as SignedCertificateTimestamps
	peerCertificates []*x509.Certificate
	// verifiedChains holds the verified certificate chains, as opposed to
	// the raw peerCertificates; VerifyHostname refuses to run when empty.
	verifiedChains [][]*x509.Certificate
	// serverName is reported as ConnectionState.ServerName even before the
	// handshake completes.
	serverName string
	// secureRenegotiation is not read in this part of the file; presumably
	// it records whether secure renegotiation was negotiated — TODO confirm
	// against the handshake code.
	secureRenegotiation bool
	// ekm is exposed via ConnectionState for exporting keying material,
	// unless renegotiation is enabled, in which case
	// noExportedKeyingMaterial is reported instead.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// clientFinishedIsFirst chooses whether clientFinished or serverFinished
	// is reported as ConnectionState.TLSUnique (non-resumed sessions only).
	clientFinishedIsFirst bool

	// closeNotifyErr is the error, if any, from sending close_notify.
	closeNotifyErr error
	// closeNotifySent is set once closeNotify has tried to send the
	// close_notify alert; afterwards Write returns errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished hold 12 bytes of Finished message
	// data from each side; used here for TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is the negotiated application protocol;
	// NegotiatedProtocolIsMutual is reported as !clientProtocolFallback.
	clientProtocol         string
	clientProtocolFallback bool

	// input/output
	// in is locked by Read and Handshake; out by the writers (Write,
	// writeRecord, sendAlert, closeNotify).
	in, out halfConn
	// rawInput holds bytes read from conn that have not been processed yet;
	// it starts with a record header.
	rawInput *block
	// input holds decrypted application data waiting to be returned by
	// Read, starting at input.off.
	input *block
	// hand accumulates handshake message bytes until a whole message is
	// available to readHandshake.
	hand bytes.Buffer
	// When buffering is set, write appends to sendBuf instead of writing to
	// conn; flush sends sendBuf and clears buffering.
	buffering bool
	sendBuf   []byte

	// bytesSent counts bytes written to conn by write/flush, and
	// packetsSent counts records sized by maxPayloadSizeForWrite while
	// dynamic record sizing is active. Both drive that record sizing.
	bytesSent   int64
	packetsSent int64

	// warnCount counts consecutive warning alerts received; readRecord
	// resets it on any non-empty, non-alert record.
	warnCount int

	// activeCall interlocks Write and Close: bit 0 is set by Close, and each
	// in-progress Write adds 2. Accessed only with sync/atomic.
	activeCall int32

	// tmp is scratch space; sendAlertLocked builds alerts in tmp[0:2].
	tmp [16]byte
}
107
108
109
110
111
112
113 func (c *Conn) LocalAddr() net.Addr {
114 return c.conn.LocalAddr()
115 }
116
117
118 func (c *Conn) RemoteAddr() net.Addr {
119 return c.conn.RemoteAddr()
120 }
121
122
123
124
125 func (c *Conn) SetDeadline(t time.Time) error {
126 return c.conn.SetDeadline(t)
127 }
128
129
130
131 func (c *Conn) SetReadDeadline(t time.Time) error {
132 return c.conn.SetReadDeadline(t)
133 }
134
135
136
137
138 func (c *Conn) SetWriteDeadline(t time.Time) error {
139 return c.conn.SetWriteDeadline(t)
140 }
141
142
143
// A halfConn holds the record-layer state for one direction of the
// connection (Conn.in for reading, Conn.out for writing). The embedded mutex
// guards it; methods with a Locked suffix expect it to be held.
type halfConn struct {
	sync.Mutex

	// err is the error stored by setErrorLocked; Read, Write and
	// readHandshake return it instead of doing further I/O.
	err error
	// version is the protocol version set by prepareCipherSpec.
	version uint16
	// cipher is the active record protection: a cipher.Stream, an aead or a
	// cbcMode (the only types encrypt/decrypt accept); nil means none yet.
	cipher interface{}
	// mac is the active MAC, or nil.
	mac macFunction
	// seq is the 64-bit record sequence number, big-endian (incSeq carries
	// from seq[7] toward seq[0]).
	seq [8]byte
	// bfree is the head of a free list of blocks, linked via block.link.
	bfree *block
	// additionalData is scratch space for AEAD additional data: 8 bytes of
	// sequence number, 3 header bytes, 2 bytes of plaintext length.
	additionalData [13]byte

	// nextCipher and nextMac are staged by prepareCipherSpec and become
	// active in changeCipherSpec.
	nextCipher interface{}
	nextMac    macFunction

	// inDigestBuf and outDigestBuf hold the previous MAC output. Each is
	// passed back as the first argument of the next MAC call, presumably so
	// the buffer can be reused and allocation avoided — TODO confirm
	// against macFunction.
	inDigestBuf, outDigestBuf []byte
}
161
162 func (hc *halfConn) setErrorLocked(err error) error {
163 hc.err = err
164 return err
165 }
166
167
168
169 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
170 hc.version = version
171 hc.nextCipher = cipher
172 hc.nextMac = mac
173 }
174
175
176
177 func (hc *halfConn) changeCipherSpec() error {
178 if hc.nextCipher == nil {
179 return alertInternalError
180 }
181 hc.cipher = hc.nextCipher
182 hc.mac = hc.nextMac
183 hc.nextCipher = nil
184 hc.nextMac = nil
185 for i := range hc.seq {
186 hc.seq[i] = 0
187 }
188 return nil
189 }
190
191
192 func (hc *halfConn) incSeq() {
193 for i := 7; i >= 0; i-- {
194 hc.seq[i]++
195 if hc.seq[i] != 0 {
196 return
197 }
198 }
199
200
201
202
203 panic("TLS: sequence number wraparound")
204 }
205
206
207
208
// extractPadding inspects the block-cipher padding at the end of payload.
// The final byte holds the padding length, and each padding byte must equal
// it. It returns toRemove, the number of trailing bytes to strip (padding
// plus the length byte), and good, which is 255 if the padding is valid and
// 0 otherwise.
//
// The loop count depends only on len(payload), never on the padding value,
// and validity is accumulated with masks instead of branches. This keeps
// the running time from depending on the padding contents.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	// t underflows (wraps) when paddingLen > len(payload)-1. In that case
	// the low 32 bits of ^t have bit 31 clear, otherwise set, so good is
	// 0xff exactly when the padding fits inside payload.
	t := uint(len(payload)-1) - uint(paddingLen)
	// if len(payload) >= (paddingLen - 1) then the MSB of t is zero
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length plus the actual length field
	toCheck := 256
	// The length of the padded data is public, so we can use an if here
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff if i <= paddingLen, i.e. this byte is padding;
		// otherwise 0 and the byte is ignored.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bit of good where a padding byte differs from
		// paddingLen (& binds tighter than ^).
		good &^= mask&paddingLen ^ mask&b
	}

	// We AND together the bits of good and replicate the result across
	// all the bits: after the shifts the top bit is set only if all eight
	// bits were, and the arithmetic shift spreads it to 0xff or 0x00.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	toRemove = int(paddingLen) + 1
	return
}
244
245
246
247
// extractPaddingSSL30 is the SSL 3.0 counterpart of extractPadding. Only the
// trailing length byte is checked (the padding bytes' values are not), and
// the check branches on the padding length. It returns the number of bytes
// to strip and 255, or (0, 0) if the input is empty or the padding would
// not fit.
func extractPaddingSSL30(payload []byte) (toRemove int, good byte) {
	if len(payload) == 0 {
		return 0, 0
	}

	// The last byte counts the padding bytes before it, so one more byte
	// (itself) must also be removed.
	toRemove = 1 + int(payload[len(payload)-1])
	if toRemove <= len(payload) {
		return toRemove, 255
	}
	return 0, 0
}
260
// roundUp returns a rounded up to the next multiple of b. A value that is
// already a multiple is returned unchanged.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
264
265
// cbcMode is a cipher.BlockMode whose IV can be replaced. decrypt and encrypt
// call SetIV with the per-record explicit IV that precedes the payload.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
270
271
272
273
// decrypt removes the record protection from the record in b, working in
// place.
//
// If the record is malformed or fails authentication, it returns ok=false
// and the alert to send. On success, prefixLen is the number of bytes at the
// front of b.data (record header plus any explicit IV/nonce) that precede
// the plaintext. When a MAC is in use, b has been resized to end at the end
// of the plaintext and the header length field (b.data[3:5]) rewritten to
// the plaintext length. The sequence number is incremented only on success.
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
	// pull out payload
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	paddingGood := byte(255)
	paddingLen := 0
	explicitIVLen := 0

	// decrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			explicitIVLen = c.explicitNonceLen()
			if len(payload) < explicitIVLen {
				return false, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitIVLen]
			payload = payload[explicitIVLen:]

			// With no explicit nonce on the wire, the sequence number
			// serves as the nonce.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}

			// Additional data: sequence number, record type and version
			// (header bytes 0-2), then the plaintext length, i.e. the
			// ciphertext length minus the AEAD overhead.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			n := len(payload) - c.Overhead()
			hc.additionalData[11] = byte(n >> 8)
			hc.additionalData[12] = byte(n)
			var err error
			payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
			if err != nil {
				return false, 0, alertBadRecordMAC
			}
			b.resize(recordHeaderLen + explicitIVLen + len(payload))
		case cbcMode:
			blockSize := c.BlockSize()
			// From TLS 1.1 on, each record carries an explicit IV of one
			// block.
			if hc.version >= VersionTLS11 {
				explicitIVLen = blockSize
			}

			// The payload must be whole blocks and big enough for the IV,
			// the MAC and at least the padding-length byte.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) {
				return false, 0, alertBadRecordMAC
			}

			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			c.CryptBlocks(payload, payload)
			// Bad padding is not reported here. paddingGood is folded
			// into the MAC check below, so a padding failure and a MAC
			// failure both produce alertBadRecordMAC on the same path.
			if hc.version == VersionSSL30 {
				paddingLen, paddingGood = extractPaddingSSL30(payload)
			} else {
				paddingLen, paddingGood = extractPadding(payload)
			}
		default:
			panic("unknown cipher type")
		}
	}

	// check, strip mac
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, 0, alertBadRecordMAC
		}

		// strip mac off payload, b.data
		// n is the plaintext length. With bogus padding it can go
		// negative; ConstantTimeSelect clamps it to 0 without branching on
		// its value (uint32(n)>>31 is 1 exactly when bit 31 of n is set).
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The header's length field is rewritten to n before it is fed to
		// the MAC below.
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// The bytes after the MAC (the padding) are passed as the last
		// argument; presumably this lets the MAC keep its timing
		// independent of paddingLen — TODO confirm against macFunction.
		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:])

		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, 0, alertBadRecordMAC
		}
		hc.inDigestBuf = localMAC

		b.resize(recordHeaderLen + explicitIVLen + n)
	}
	hc.incSeq()

	return true, recordHeaderLen + explicitIVLen, 0
}
372
373
374
375
376
377
// padToBlockSize prepares payload for CBC encryption.
//
// prefix is the leading run of whole blocks, as a subslice of payload (not a
// copy). finalBlock is a newly allocated block holding the remaining tail
// bytes followed by padding. Every padding byte holds the padding length
// minus one, and at least one padding byte is always added, so an aligned
// payload gets a whole extra block.
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
	tail := len(payload) % blockSize
	cut := len(payload) - tail
	pad := byte(blockSize - tail - 1)

	finalBlock = make([]byte, blockSize)
	filled := copy(finalBlock, payload[cut:])
	for j := filled; j < len(finalBlock); j++ {
		finalBlock[j] = pad
	}
	return payload[:cut], finalBlock
}
389
390
// encrypt protects the record in b in place.
//
// b.data holds the record header, then explicitIVLen bytes of explicit
// IV/nonce (already filled in by the caller), then the plaintext. The
// function first appends the MAC, if any. It then encrypts, growing b for
// the AEAD overhead or CBC padding as needed. Finally it rewrites the header
// length field to the final payload length and increments the sequence
// number. It always returns true, 0.
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
	// mac
	if hc.mac != nil {
		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil)

		n := len(b.data)
		b.resize(n + len(mac))
		copy(b.data[n:], mac)
		hc.outDigestBuf = mac
	}

	payload := b.data[recordHeaderLen:]

	// encrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			payloadLen := len(b.data) - recordHeaderLen - explicitIVLen
			b.resize(len(b.data) + c.Overhead())
			// The explicit nonce is used if present, otherwise the
			// sequence number.
			nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload := b.data[recordHeaderLen+explicitIVLen:]
			payload = payload[:payloadLen]

			// Additional data: sequence number, header bytes 0-2 (type and
			// version), plaintext length.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			hc.additionalData[11] = byte(payloadLen >> 8)
			hc.additionalData[12] = byte(payloadLen)

			// Seal in place; the overhead fits in the space reserved by
			// the resize above.
			c.Seal(payload[:0], nonce, payload, hc.additionalData[:])
		case cbcMode:
			blockSize := c.BlockSize()
			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			prefix, finalBlock := padToBlockSize(payload, blockSize)
			b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock))
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix)
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock)
		default:
			panic("unknown cipher type")
		}
	}

	// update length to include MAC and any block padding needed.
	n := len(b.data) - recordHeaderLen
	b.data[3] = byte(n >> 8)
	b.data[4] = byte(n)
	hc.incSeq()

	return true, 0
}
448
449
// A block is a growable byte buffer holding record data. off is the read
// position used by Read, and link chains blocks on a halfConn free list.
type block struct {
	data []byte
	off  int // index of the next byte Read will return
	link *block
}

// resize sets len(b.data) to n, growing the backing array first if needed.
func (b *block) resize(n int) {
	b.reserve(n)
	b.data = b.data[:n]
}

// reserve guarantees cap(b.data) >= n, keeping the current contents. Storage
// starts at 1024 bytes and doubles until large enough.
func (b *block) reserve(n int) {
	if n <= cap(b.data) {
		return
	}
	newCap := cap(b.data)
	if newCap == 0 {
		newCap = 1024
	}
	for newCap < n {
		newCap <<= 1
	}
	grown := make([]byte, len(b.data), newCap)
	copy(grown, b.data)
	b.data = grown
}

// readFromUntil reads from r into b until len(b.data) >= n. Each Read may
// fill all spare capacity, so b can end up holding more than n bytes. A read
// error is reported only if it arrives before n bytes are available.
func (b *block) readFromUntil(r io.Reader, n int) error {
	if len(b.data) >= n {
		return nil
	}
	b.reserve(n)
	for len(b.data) < n {
		m, err := r.Read(b.data[len(b.data):cap(b.data)])
		b.data = b.data[:len(b.data)+m]
		if len(b.data) < n && err != nil {
			return err
		}
	}
	return nil
}

// Read copies unread bytes into p and advances off. When nothing is left it
// returns 0 with a nil error rather than io.EOF.
func (b *block) Read(p []byte) (n int, err error) {
	n = copy(p, b.data[b.off:])
	b.off += n
	return n, nil
}
511
512
513 func (hc *halfConn) newBlock() *block {
514 b := hc.bfree
515 if b == nil {
516 return new(block)
517 }
518 hc.bfree = b.link
519 b.link = nil
520 b.resize(0)
521 return b
522 }
523
524
525
526
527
528 func (hc *halfConn) freeBlock(b *block) {
529 b.link = hc.bfree
530 hc.bfree = b
531 }
532
533
534
535
536 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
537 if len(b.data) <= n {
538 return b, nil
539 }
540 bb := hc.newBlock()
541 bb.resize(len(b.data) - n)
542 copy(bb.data, b.data[n:])
543 b.data = b.data[0:n]
544 return b, bb
545 }
546
547
// RecordHeaderError is returned when a received record header is rejected
// during the header checks in readRecord.
type RecordHeaderError struct {
	// Msg explains why the header was rejected.
	Msg string
	// RecordHeader holds the leading bytes of the offending record, which
	// begin with its header.
	RecordHeader [5]byte
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	msg := "tls: " + e.Msg
	return msg
}
557
558 func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
559 err.Msg = msg
560 copy(err.RecordHeader[:], c.rawInput.data)
561 return err
562 }
563
564
565
// readRecord reads the next TLS record from the connection and processes it
// according to its type. want is the record type the caller is prepared to
// accept. Requests inconsistent with the handshake state are rejected with
// an internal_error alert.
//
// Processing by record type:
//   - application data is handed over in c.input;
//   - handshake bytes are appended to c.hand;
//   - ChangeCipherSpec activates c.in's staged cipher;
//   - alerts are recorded in c.in.err, except warning alerts, which are
//     skipped up to maxWarnAlertCount in a row.
//
// It returns c.in.err after processing, or the error from reading the
// connection. It modifies c.in without taking its lock. Read calls it with
// c.in held; presumably every caller must hold c.in — TODO confirm for the
// handshake code.
func (c *Conn) readRecord(want recordType) error {
	// Caller must be in sync with connection:
	// handshake data if handshake not yet completed,
	// else application data.
	switch want {
	default:
		c.sendAlert(alertInternalError)
		return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
	case recordTypeHandshake, recordTypeChangeCipherSpec:
		if c.handshakeComplete() {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
		}
	case recordTypeApplicationData:
		if !c.handshakeComplete() {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
		}
	}

Again:
	// Any bytes left over from earlier reads stay in c.rawInput and may
	// already contain (part of) this record.
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read header, payload.
	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An io.EOF before a complete header is passed through unchanged.
		// An EOF inside a record body is turned into io.ErrUnexpectedEOF
		// further down. Errors reporting Temporary() are returned without
		// being recorded, so the read can be retried.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	typ := recordType(b.data[0])

	// 0x80 is not a valid TLS record type. A first byte of 0x80 while a
	// handshake is expected is treated as an SSLv2-format handshake, which
	// is not supported.
	if want == recordTypeHandshake && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
	}

	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	n := int(b.data[3])<<8 | int(b.data[4])
	if c.haveVers && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if !c.haveVers {
		// First message, be extra suspicious: the peer may not be
		// speaking TLS at all. Require the expected type (or an alert)
		// and a version below 0x1000 before reading the claimed body.
		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 {
			c.sendAlert(alertUnexpectedMessage)
			return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
		}
	}
	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message.
	// Bytes beyond this record remain in c.rawInput for the next call.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	ok, off, alertValue := c.in.decrypt(b)
	if !ok {
		c.in.freeBlock(b)
		return c.in.setErrorLocked(c.sendAlert(alertValue))
	}
	b.off = off
	data := b.data[b.off:]
	if len(data) > maxPlaintext {
		err := c.sendAlert(alertRecordOverflow)
		c.in.freeBlock(b)
		return c.in.setErrorLocked(err)
	}

	if typ != recordTypeAlert && len(data) > 0 {
		// this is a valid non-alert message: reset the count of alerts
		c.warnCount = 0
	}

	switch typ {
	default:
		c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// close_notify is honored at any alert level and ends the read
		// side with io.EOF.
		if alert(data[1]) == alertCloseNotify {
			c.in.setErrorLocked(io.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// drop on the floor
			c.in.freeBlock(b)

			c.warnCount++
			if c.warnCount > maxWarnAlertCount {
				c.sendAlert(alertUnexpectedMessage)
				return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
			}

			// b is already freed, so restart directly instead of falling
			// through to the cleanup at the end of the function.
			goto Again
		case alertLevelError:
			c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// Handshake messages are not allowed to fragment across the CCS
		if c.hand.Len() > 0 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		err := c.in.changeCipherSpec()
		if err != nil {
			c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if typ != want {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// Ownership of b moves to c.input; clearing b keeps it from being
		// freed below.
		c.input = b
		b = nil

	case recordTypeHandshake:
		// Handshake records outside a handshake are tolerated only on a
		// client that allows renegotiation; Read then passes them to
		// handleRenegotiation.
		if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
		}
		c.hand.Write(data)
	}

	if b != nil {
		c.in.freeBlock(b)
	}
	return c.in.err
}
735
736
737 func (c *Conn) sendAlertLocked(err alert) error {
738 switch err {
739 case alertNoRenegotiation, alertCloseNotify:
740 c.tmp[0] = alertLevelWarning
741 default:
742 c.tmp[0] = alertLevelError
743 }
744 c.tmp[1] = byte(err)
745
746 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
747 if err == alertCloseNotify {
748
749 return writeErr
750 }
751
752 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
753 }
754
755
756 func (c *Conn) sendAlert(err alert) error {
757 c.out.Lock()
758 defer c.out.Unlock()
759 return c.sendAlertLocked(err)
760 }
761
const (
	// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
	// size. maxPayloadSizeForWrite uses it to keep the early application
	// data records of a connection roughly one segment in size. 1208
	// equals 1280 - 40 - 32, which presumably stands for the IPv6 minimum
	// MTU minus an IPv6 header and a TCP header with timestamps — confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes that must have been
	// sent (Conn.bytesSent) before maxPayloadSizeForWrite stops limiting
	// application data records and always allows maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792 func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
793 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
794 return maxPlaintext
795 }
796
797 if c.bytesSent >= recordSizeBoostThreshold {
798 return maxPlaintext
799 }
800
801
802 macSize := 0
803 if c.out.mac != nil {
804 macSize = c.out.mac.Size()
805 }
806
807 payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen
808 if c.out.cipher != nil {
809 switch ciph := c.out.cipher.(type) {
810 case cipher.Stream:
811 payloadBytes -= macSize
812 case cipher.AEAD:
813 payloadBytes -= ciph.Overhead()
814 case cbcMode:
815 blockSize := ciph.BlockSize()
816
817
818 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
819
820
821 payloadBytes -= macSize
822 default:
823 panic("unknown cipher type")
824 }
825 }
826
827
828 pkt := c.packetsSent
829 c.packetsSent++
830 if pkt > 1000 {
831 return maxPlaintext
832 }
833
834 n := payloadBytes * int(pkt+1)
835 if n > maxPlaintext {
836 n = maxPlaintext
837 }
838 return n
839 }
840
841 func (c *Conn) write(data []byte) (int, error) {
842 if c.buffering {
843 c.sendBuf = append(c.sendBuf, data...)
844 return len(data), nil
845 }
846
847 n, err := c.conn.Write(data)
848 c.bytesSent += int64(n)
849 return n, err
850 }
851
852 func (c *Conn) flush() (int, error) {
853 if len(c.sendBuf) == 0 {
854 return 0, nil
855 }
856
857 n, err := c.conn.Write(c.sendBuf)
858 c.bytesSent += int64(n)
859 c.sendBuf = nil
860 c.buffering = false
861 return n, err
862 }
863
864
865
866 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
867 b := c.out.newBlock()
868 defer c.out.freeBlock(b)
869
870 var n int
871 for len(data) > 0 {
872 explicitIVLen := 0
873 explicitIVIsSeq := false
874
875 var cbc cbcMode
876 if c.out.version >= VersionTLS11 {
877 var ok bool
878 if cbc, ok = c.out.cipher.(cbcMode); ok {
879 explicitIVLen = cbc.BlockSize()
880 }
881 }
882 if explicitIVLen == 0 {
883 if c, ok := c.out.cipher.(aead); ok {
884 explicitIVLen = c.explicitNonceLen()
885
886
887
888
889
890
891
892 explicitIVIsSeq = explicitIVLen > 0
893 }
894 }
895 m := len(data)
896 if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload {
897 m = maxPayload
898 }
899 b.resize(recordHeaderLen + explicitIVLen + m)
900 b.data[0] = byte(typ)
901 vers := c.vers
902 if vers == 0 {
903
904
905 vers = VersionTLS10
906 }
907 b.data[1] = byte(vers >> 8)
908 b.data[2] = byte(vers)
909 b.data[3] = byte(m >> 8)
910 b.data[4] = byte(m)
911 if explicitIVLen > 0 {
912 explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
913 if explicitIVIsSeq {
914 copy(explicitIV, c.out.seq[:])
915 } else {
916 if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
917 return n, err
918 }
919 }
920 }
921 copy(b.data[recordHeaderLen+explicitIVLen:], data)
922 c.out.encrypt(b, explicitIVLen)
923 if _, err := c.write(b.data); err != nil {
924 return n, err
925 }
926 n += m
927 data = data[m:]
928 }
929
930 if typ == recordTypeChangeCipherSpec {
931 if err := c.out.changeCipherSpec(); err != nil {
932 return n, c.sendAlertLocked(err.(alert))
933 }
934 }
935
936 return n, nil
937 }
938
939
940
941 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
942 c.out.Lock()
943 defer c.out.Unlock()
944
945 return c.writeRecordLocked(typ, data)
946 }
947
948
949
950 func (c *Conn) readHandshake() (interface{}, error) {
951 for c.hand.Len() < 4 {
952 if err := c.in.err; err != nil {
953 return nil, err
954 }
955 if err := c.readRecord(recordTypeHandshake); err != nil {
956 return nil, err
957 }
958 }
959
960 data := c.hand.Bytes()
961 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
962 if n > maxHandshake {
963 c.sendAlertLocked(alertInternalError)
964 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
965 }
966 for c.hand.Len() < 4+n {
967 if err := c.in.err; err != nil {
968 return nil, err
969 }
970 if err := c.readRecord(recordTypeHandshake); err != nil {
971 return nil, err
972 }
973 }
974 data = c.hand.Next(4 + n)
975 var m handshakeMessage
976 switch data[0] {
977 case typeHelloRequest:
978 m = new(helloRequestMsg)
979 case typeClientHello:
980 m = new(clientHelloMsg)
981 case typeServerHello:
982 m = new(serverHelloMsg)
983 case typeNewSessionTicket:
984 m = new(newSessionTicketMsg)
985 case typeCertificate:
986 m = new(certificateMsg)
987 case typeCertificateRequest:
988 m = &certificateRequestMsg{
989 hasSignatureAndHash: c.vers >= VersionTLS12,
990 }
991 case typeCertificateStatus:
992 m = new(certificateStatusMsg)
993 case typeServerKeyExchange:
994 m = new(serverKeyExchangeMsg)
995 case typeServerHelloDone:
996 m = new(serverHelloDoneMsg)
997 case typeClientKeyExchange:
998 m = new(clientKeyExchangeMsg)
999 case typeCertificateVerify:
1000 m = &certificateVerifyMsg{
1001 hasSignatureAndHash: c.vers >= VersionTLS12,
1002 }
1003 case typeNextProtocol:
1004 m = new(nextProtoMsg)
1005 case typeFinished:
1006 m = new(finishedMsg)
1007 default:
1008 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1009 }
1010
1011
1012
1013
1014 data = append([]byte(nil), data...)
1015
1016 if !m.unmarshal(data) {
1017 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1018 }
1019 return m, nil
1020 }
1021
var (
	// errClosed is returned by Write and Close once Close has been called.
	errClosed = errors.New("tls: use of closed connection")
	// errShutdown is returned by Write after close_notify has been sent.
	errShutdown = errors.New("tls: protocol is shutdown")
)
1026
1027
// Write writes data to the connection, running the handshake first if it has
// not completed. Any write error is recorded on c.out and returned by every
// later Write.
func (c *Conn) Write(b []byte) (int, error) {
	// interlock with Close below
	for {
		x := atomic.LoadInt32(&c.activeCall)
		// Bit 0 means Close has been called.
		if x&1 != 0 {
			return 0, errClosed
		}
		// Register this Write by adding 2. defer is function-scoped, so
		// the matching decrement runs when Write returns, not when the
		// loop exits.
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// For block ciphers on TLS 1.0 and earlier, the first byte goes out in
	// a record of its own, and the rest follows in further records (the
	// "1/n-1 record split"). This is a known mitigation for the BEAST
	// chosen-plaintext attack on CBC with predictable IVs.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1083
1084
1085 func (c *Conn) handleRenegotiation() error {
1086 msg, err := c.readHandshake()
1087 if err != nil {
1088 return err
1089 }
1090
1091 _, ok := msg.(*helloRequestMsg)
1092 if !ok {
1093 c.sendAlert(alertUnexpectedMessage)
1094 return alertUnexpectedMessage
1095 }
1096
1097 if !c.isClient {
1098 return c.sendAlert(alertNoRenegotiation)
1099 }
1100
1101 switch c.config.Renegotiation {
1102 case RenegotiateNever:
1103 return c.sendAlert(alertNoRenegotiation)
1104 case RenegotiateOnceAsClient:
1105 if c.handshakes > 1 {
1106 return c.sendAlert(alertNoRenegotiation)
1107 }
1108 case RenegotiateFreelyAsClient:
1109
1110 default:
1111 c.sendAlert(alertInternalError)
1112 return errors.New("tls: unknown Renegotiation value")
1113 }
1114
1115 c.handshakeMutex.Lock()
1116 defer c.handshakeMutex.Unlock()
1117
1118 atomic.StoreUint32(&c.handshakeStatus, 0)
1119 if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
1120 c.handshakes++
1121 }
1122 return c.handshakeErr
1123 }
1124
1125
1126
// Read reads application data from the connection, running the handshake
// first if it has not completed. Deadlines set with SetDeadline or
// SetReadDeadline apply to the underlying reads, and their errors are
// returned to the caller.
func (c *Conn) Read(b []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return
	}
	if len(b) == 0 {
		// Checked after Handshake, presumably so that Read(nil) can still
		// be used just to drive the handshake.
		return
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Records with no data are skipped, but only up to
	// maxConsecutiveEmptyRecords in a row; after that, io.ErrNoProgress is
	// returned.
	const maxConsecutiveEmptyRecords = 100
	for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
		for c.input == nil && c.in.err == nil {
			if err := c.readRecord(recordTypeApplicationData); err != nil {
				// Soft error, like EAGAIN
				return 0, err
			}
			if c.hand.Len() > 0 {
				// Handshake data after the handshake is only accepted by
				// readRecord on renegotiation-enabled clients; hand it
				// to handleRenegotiation.
				if err := c.handleRenegotiation(); err != nil {
					return 0, err
				}
			}
		}
		if err := c.in.err; err != nil {
			return 0, err
		}

		n, err = c.input.Read(b)
		if c.input.off >= len(c.input.data) {
			c.in.freeBlock(c.input)
			c.input = nil
		}

		// If the input is drained and the next buffered record is an
		// alert, process it now. A close_notify (or fatal alert) is then
		// reported together with the last data as (n, err) rather than
		// (n, nil) followed by (0, err).
		//
		// NOTE(review): readRecord may block here if the alert record is
		// only partially buffered. It also restarts after a warning alert
		// and waits for the next record, even though n bytes are already
		// available to return. Confirm this is acceptable.
		if ri := c.rawInput; ri != nil &&
			n != 0 && err == nil &&
			c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
			if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
				err = recErr // will be io.EOF on closeNotify
			}
		}

		if n != 0 || err != nil {
			return n, err
		}
	}

	return 0, io.ErrNoProgress
}
1193
1194
// Close closes the connection. When the handshake has completed and no Write
// is in progress, it first tries to send close_notify. An error from closing
// the underlying connection takes precedence over one from the alert.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall, remembering whether
	// any Write was registered (x != 0) at that moment.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight and holds (or will take) c.out. Sending
		// close_notify would have to wait for that lock, so just close the
		// transport. This is presumably how a caller breaks a stuck Write.
		return c.conn.Close()
	}

	var alertErr error

	if c.handshakeComplete() {
		alertErr = c.closeNotify()
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1228
1229 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1230
1231
1232
1233
1234 func (c *Conn) CloseWrite() error {
1235 if !c.handshakeComplete() {
1236 return errEarlyCloseWrite
1237 }
1238
1239 return c.closeNotify()
1240 }
1241
1242 func (c *Conn) closeNotify() error {
1243 c.out.Lock()
1244 defer c.out.Unlock()
1245
1246 if !c.closeNotifySent {
1247 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1248 c.closeNotifySent = true
1249 }
1250 return c.closeNotifyErr
1251 }
1252
1253
1254
1255
1256
1257 func (c *Conn) Handshake() error {
1258 c.handshakeMutex.Lock()
1259 defer c.handshakeMutex.Unlock()
1260
1261 if err := c.handshakeErr; err != nil {
1262 return err
1263 }
1264 if c.handshakeComplete() {
1265 return nil
1266 }
1267
1268 c.in.Lock()
1269 defer c.in.Unlock()
1270
1271 if c.isClient {
1272 c.handshakeErr = c.clientHandshake()
1273 } else {
1274 c.handshakeErr = c.serverHandshake()
1275 }
1276 if c.handshakeErr == nil {
1277 c.handshakes++
1278 } else {
1279
1280
1281 c.flush()
1282 }
1283
1284 if c.handshakeErr == nil && !c.handshakeComplete() {
1285 panic("handshake should have had a result.")
1286 }
1287
1288 return c.handshakeErr
1289 }
1290
1291
1292 func (c *Conn) ConnectionState() ConnectionState {
1293 c.handshakeMutex.Lock()
1294 defer c.handshakeMutex.Unlock()
1295
1296 var state ConnectionState
1297 state.HandshakeComplete = c.handshakeComplete()
1298 state.ServerName = c.serverName
1299
1300 if state.HandshakeComplete {
1301 state.Version = c.vers
1302 state.NegotiatedProtocol = c.clientProtocol
1303 state.DidResume = c.didResume
1304 state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
1305 state.CipherSuite = c.cipherSuite
1306 state.PeerCertificates = c.peerCertificates
1307 state.VerifiedChains = c.verifiedChains
1308 state.SignedCertificateTimestamps = c.scts
1309 state.OCSPResponse = c.ocspResponse
1310 if !c.didResume {
1311 if c.clientFinishedIsFirst {
1312 state.TLSUnique = c.clientFinished[:]
1313 } else {
1314 state.TLSUnique = c.serverFinished[:]
1315 }
1316 }
1317 if c.config.Renegotiation != RenegotiateNever {
1318 state.ekm = noExportedKeyingMaterial
1319 } else {
1320 state.ekm = c.ekm
1321 }
1322 }
1323
1324 return state
1325 }
1326
1327
1328
1329 func (c *Conn) OCSPResponse() []byte {
1330 c.handshakeMutex.Lock()
1331 defer c.handshakeMutex.Unlock()
1332
1333 return c.ocspResponse
1334 }
1335
1336
1337
1338
1339 func (c *Conn) VerifyHostname(host string) error {
1340 c.handshakeMutex.Lock()
1341 defer c.handshakeMutex.Unlock()
1342 if !c.isClient {
1343 return errors.New("tls: VerifyHostname called on TLS server connection")
1344 }
1345 if !c.handshakeComplete() {
1346 return errors.New("tls: handshake has not yet been performed")
1347 }
1348 if len(c.verifiedChains) == 0 {
1349 return errors.New("tls: handshake did not verify certificate chain")
1350 }
1351 return c.peerCertificates[0].VerifyHostname(host)
1352 }
1353
1354 func (c *Conn) handshakeComplete() bool {
1355 return atomic.LoadUint32(&c.handshakeStatus) == 1
1356 }
1357
View as plain text