Source file src/pkg/crypto/tls/conn.go
1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "io"
16 "net"
17 "sync"
18 "time"
19 )
20
21
22
23 type Conn struct {
24
25 conn net.Conn
26 isClient bool
27
28
29 handshakeMutex sync.Mutex
30 vers uint16
31 haveVers bool
32 config *Config
33 handshakeComplete bool
34 didResume bool
35 cipherSuite uint16
36 ocspResponse []byte
37 peerCertificates []*x509.Certificate
38
39
40 verifiedChains [][]*x509.Certificate
41
42 serverName string
43
44 clientProtocol string
45 clientProtocolFallback bool
46
47
48 connErr
49
50
51 in, out halfConn
52 rawInput *block
53 input *block
54 hand bytes.Buffer
55
56 tmp [16]byte
57 }
58
// connErr holds the first permanent error seen on a connection. It is safe
// for concurrent use.
type connErr struct {
	mu    sync.Mutex
	value error
}

// setError records err unless an error has already been recorded, and
// always returns err itself so callers can write `return c.setError(err)`.
func (e *connErr) setError(err error) error {
	e.mu.Lock()
	if e.value == nil {
		e.value = err
	}
	e.mu.Unlock()
	return err
}

// error returns the recorded error, or nil if none has been recorded.
func (e *connErr) error() error {
	e.mu.Lock()
	recorded := e.value
	e.mu.Unlock()
	return recorded
}
79
80
81
82
83
84
85 func (c *Conn) LocalAddr() net.Addr {
86 return c.conn.LocalAddr()
87 }
88
89
90 func (c *Conn) RemoteAddr() net.Addr {
91 return c.conn.RemoteAddr()
92 }
93
94
95
96
97 func (c *Conn) SetDeadline(t time.Time) error {
98 return c.conn.SetDeadline(t)
99 }
100
101
102
103 func (c *Conn) SetReadDeadline(t time.Time) error {
104 return c.conn.SetReadDeadline(t)
105 }
106
107
108
109
110 func (c *Conn) SetWriteDeadline(t time.Time) error {
111 return c.conn.SetWriteDeadline(t)
112 }
113
114
115
116 type halfConn struct {
117 sync.Mutex
118 version uint16
119 cipher interface{}
120 mac macFunction
121 seq [8]byte
122 bfree *block
123
124 nextCipher interface{}
125 nextMac macFunction
126
127
128 inDigestBuf, outDigestBuf []byte
129 }
130
131
132
133 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
134 hc.version = version
135 hc.nextCipher = cipher
136 hc.nextMac = mac
137 }
138
139
140
141 func (hc *halfConn) changeCipherSpec() error {
142 if hc.nextCipher == nil {
143 return alertInternalError
144 }
145 hc.cipher = hc.nextCipher
146 hc.mac = hc.nextMac
147 hc.nextCipher = nil
148 hc.nextMac = nil
149 return nil
150 }
151
152
153 func (hc *halfConn) incSeq() {
154 for i := 7; i >= 0; i-- {
155 hc.seq[i]++
156 if hc.seq[i] != 0 {
157 return
158 }
159 }
160
161
162
163
164 panic("TLS: sequence number wraparound")
165 }
166
167
168 func (hc *halfConn) resetSeq() {
169 for i := range hc.seq {
170 hc.seq[i] = 0
171 }
172 }
173
174
175
176
// removePadding strips TLS CBC padding from a decrypted payload, whose last
// byte is the padding length and whose preceding paddingLen bytes must all
// equal that length. The check is written to take time that depends only on
// len(payload), not on the padding contents.
//
// It returns the payload with the padding and length byte removed, plus
// 255 if the padding was well formed or 0 if it was not. On bad padding
// only the final byte is removed. An empty payload is returned unchanged
// with 0.
func removePadding(payload []byte) ([]byte, byte) {
	if len(payload) < 1 {
		return payload, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// good is 0xff when paddingLen <= len(payload)-1 (t does not underflow,
	// so the top bit of its low 32 bits is clear), and 0 otherwise.
	good := byte(int32(^t) >> 31)

	// Examine every byte that could belong to the padding: at most 255
	// padding bytes plus the length byte itself, i.e. 256 bytes, or the
	// whole payload if shorter. The payload length is public, so this
	// branch does not leak secret data.
	toCheck := 256
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (this byte is claimed as
		// padding) and 0 otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together and replicate the result across the
	// byte, so that good ends up exactly 0x00 or 0xff.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Convert before adding: with paddingLen == 255 the sum is 256, which
	// would wrap to 0 in byte arithmetic and strip nothing.
	toRemove := int(good&paddingLen) + 1
	return payload[:len(payload)-toRemove], good
}
211
212
213
214
// removePaddingSSL30 strips SSL 3.0 CBC padding from a decrypted payload.
// Only the final (length) byte is examined: it must not claim more bytes
// than the payload holds. The padding bytes themselves are not checked.
//
// It returns the trimmed payload and 255 on success, or the payload
// unchanged and 0 if it is empty or the length byte is out of range.
func removePaddingSSL30(payload []byte) ([]byte, byte) {
	size := len(payload)
	if size == 0 {
		return payload, 0
	}
	// The padding length excludes the length byte, so strip one extra.
	strip := int(payload[size-1]) + 1
	if strip <= size {
		return payload[:size-strip], 255
	}
	return payload, 0
}
227
// roundUp returns a rounded up to the next multiple of b (a itself if it
// is already a multiple).
func roundUp(a, b int) int {
	shortfall := b - a%b
	return a + shortfall%b
}
231
232
// decrypt decrypts and authenticates, in place, the record held in b.
// b.data holds the record header (recordHeaderLen bytes) followed by the
// payload. On return the payload has been decrypted, any block-cipher
// padding and the MAC have been stripped, b has been resized to match, and
// the header's length field (bytes 3-4) describes the remaining plaintext.
//
// It returns true and a zero alert on success, or false and the alert to
// send to the peer (alertBadRecordMAC) on failure.
func (hc *halfConn) decrypt(b *block) (bool, alert) {
	// The payload follows the record header.
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	// paddingGood stays 255 unless padding removal reports failure. It is
	// acted on only together with the MAC comparison below, so bad padding
	// and a bad MAC yield the same alert.
	paddingGood := byte(255)

	// Decrypt with the active cipher, if any.
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case cipher.BlockMode:
			blockSize := c.BlockSize()

			// The ciphertext must be whole blocks, and long enough to hold
			// a MAC plus at least one padding-length byte.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(macSize+1, blockSize) {
				return false, alertBadRecordMAC
			}

			c.CryptBlocks(payload, payload)
			if hc.version == versionSSL30 {
				payload, paddingGood = removePaddingSSL30(payload)
			} else {
				payload, paddingGood = removePadding(payload)
			}
			b.resize(recordHeaderLen + len(payload))

			// NOTE(review): the MAC below is computed over a plaintext
			// whose length depends on the padding removed, so its running
			// time may still vary with the padding value.
		default:
			panic("unknown cipher type")
		}
	}

	// Check and strip the MAC.
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, alertBadRecordMAC
		}

		// Rewrite the header length to the plaintext length before
		// computing the MAC over header+plaintext in b.data.
		n := len(payload) - macSize
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		b.resize(recordHeaderLen + n)
		// payload still references the bytes past the shrunken b.data,
		// so the received MAC remains readable here.
		remoteMAC := payload[n:]
		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
		hc.incSeq()

		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, alertBadRecordMAC
		}
		// Keep the result so its storage can be passed to the next MAC call.
		hc.inDigestBuf = localMAC
	}

	return true, 0
}
302
303
304
305
306
307
// padToBlockSize splits payload for CBC encryption. prefix is the longest
// leading part of payload that is a whole number of blocks; it aliases
// payload. finalBlock is a newly allocated block holding the remaining
// bytes followed by TLS padding, where every padding byte equals the number
// of padding bytes minus one. A full block of padding is added when payload
// is already block-aligned.
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
	tail := len(payload) % blockSize
	split := len(payload) - tail
	padByte := byte(blockSize - tail - 1)

	finalBlock = make([]byte, blockSize)
	filled := copy(finalBlock, payload[split:])
	for i := filled; i < blockSize; i++ {
		finalBlock[i] = padByte
	}
	return payload[:split], finalBlock
}
319
320
321 func (hc *halfConn) encrypt(b *block) (bool, alert) {
322
323 if hc.mac != nil {
324 mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
325 hc.incSeq()
326
327 n := len(b.data)
328 b.resize(n + len(mac))
329 copy(b.data[n:], mac)
330 hc.outDigestBuf = mac
331 }
332
333 payload := b.data[recordHeaderLen:]
334
335
336 if hc.cipher != nil {
337 switch c := hc.cipher.(type) {
338 case cipher.Stream:
339 c.XORKeyStream(payload, payload)
340 case cipher.BlockMode:
341 prefix, finalBlock := padToBlockSize(payload, c.BlockSize())
342 b.resize(recordHeaderLen + len(prefix) + len(finalBlock))
343 c.CryptBlocks(b.data[recordHeaderLen:], prefix)
344 c.CryptBlocks(b.data[recordHeaderLen+len(prefix):], finalBlock)
345 default:
346 panic("unknown cipher type")
347 }
348 }
349
350
351 n := len(b.data) - recordHeaderLen
352 b.data[3] = byte(n >> 8)
353 b.data[4] = byte(n)
354
355 return true, 0
356 }
357
358
// A block is a growable byte buffer used to hold record data, with a read
// offset for draining it and a link for chaining blocks on a free list.
type block struct {
	data []byte // buffered bytes
	off  int    // index of the next byte Read will return
	link *block // next block on a halfConn's free list
}

// resize sets len(b.data) to n, growing the capacity first if needed.
// Existing contents up to the old length are preserved.
func (b *block) resize(n int) {
	b.reserve(n) // no-op when the capacity already suffices
	b.data = b.data[:n]
}

// reserve ensures cap(b.data) >= n without changing len(b.data). New
// capacity starts at 1024 bytes and doubles until it is large enough.
func (b *block) reserve(n int) {
	size := cap(b.data)
	if size >= n {
		return
	}
	if size == 0 {
		size = 1024
	}
	for size < n {
		size <<= 1
	}
	b.data = append(make([]byte, 0, size), b.data...)
}

// readFromUntil reads from r into b until len(b.data) >= n. A single Read
// may fill all spare capacity, so more than n bytes may end up buffered.
// It returns nil once at least n bytes are held, even if that last Read
// also failed. Otherwise it returns the error that stopped reading.
func (b *block) readFromUntil(r io.Reader, n int) error {
	if len(b.data) >= n {
		return nil
	}

	b.reserve(n)
	for {
		got, err := r.Read(b.data[len(b.data):cap(b.data)])
		b.data = b.data[:len(b.data)+got]
		switch {
		case len(b.data) >= n:
			return nil
		case err != nil:
			return err
		}
	}
}

// Read copies unread bytes of b.data (starting at b.off) into p and advances
// b.off. It never returns an error; 0 means nothing was left or p is empty.
func (b *block) Read(p []byte) (n int, err error) {
	unread := b.data[b.off:]
	n = copy(p, unread)
	b.off += n
	return n, nil
}
418
419
420 func (hc *halfConn) newBlock() *block {
421 b := hc.bfree
422 if b == nil {
423 return new(block)
424 }
425 hc.bfree = b.link
426 b.link = nil
427 b.resize(0)
428 return b
429 }
430
431
432
433
434
435 func (hc *halfConn) freeBlock(b *block) {
436 b.link = hc.bfree
437 hc.bfree = b
438 }
439
440
441
442
443 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
444 if len(b.data) <= n {
445 return b, nil
446 }
447 bb := hc.newBlock()
448 bb.resize(len(b.data) - n)
449 copy(bb.data, b.data[n:])
450 b.data = b.data[0:n]
451 return b, bb
452 }
453
454
455
456
457 func (c *Conn) readRecord(want recordType) error {
458
459
460
461 switch want {
462 default:
463 return c.sendAlert(alertInternalError)
464 case recordTypeHandshake, recordTypeChangeCipherSpec:
465 if c.handshakeComplete {
466 return c.sendAlert(alertInternalError)
467 }
468 case recordTypeApplicationData:
469 if !c.handshakeComplete {
470 return c.sendAlert(alertInternalError)
471 }
472 }
473
474 Again:
475 if c.rawInput == nil {
476 c.rawInput = c.in.newBlock()
477 }
478 b := c.rawInput
479
480
481 if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
482
483
484
485
486
487
488 if e, ok := err.(net.Error); !ok || !e.Temporary() {
489 c.setError(err)
490 }
491 return err
492 }
493 typ := recordType(b.data[0])
494
495
496
497
498
499 if want == recordTypeHandshake && typ == 0x80 {
500 c.sendAlert(alertProtocolVersion)
501 return errors.New("tls: unsupported SSLv2 handshake received")
502 }
503
504 vers := uint16(b.data[1])<<8 | uint16(b.data[2])
505 n := int(b.data[3])<<8 | int(b.data[4])
506 if c.haveVers && vers != c.vers {
507 return c.sendAlert(alertProtocolVersion)
508 }
509 if n > maxCiphertext {
510 return c.sendAlert(alertRecordOverflow)
511 }
512 if !c.haveVers {
513
514
515
516
517
518
519
520
521 if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
522 return c.sendAlert(alertUnexpectedMessage)
523 }
524 }
525 if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
526 if err == io.EOF {
527 err = io.ErrUnexpectedEOF
528 }
529 if e, ok := err.(net.Error); !ok || !e.Temporary() {
530 c.setError(err)
531 }
532 return err
533 }
534
535
536 b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
537 b.off = recordHeaderLen
538 if ok, err := c.in.decrypt(b); !ok {
539 return c.sendAlert(err)
540 }
541 data := b.data[b.off:]
542 if len(data) > maxPlaintext {
543 c.sendAlert(alertRecordOverflow)
544 c.in.freeBlock(b)
545 return c.error()
546 }
547
548 switch typ {
549 default:
550 c.sendAlert(alertUnexpectedMessage)
551
552 case recordTypeAlert:
553 if len(data) != 2 {
554 c.sendAlert(alertUnexpectedMessage)
555 break
556 }
557 if alert(data[1]) == alertCloseNotify {
558 c.setError(io.EOF)
559 break
560 }
561 switch data[0] {
562 case alertLevelWarning:
563
564 c.in.freeBlock(b)
565 goto Again
566 case alertLevelError:
567 c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
568 default:
569 c.sendAlert(alertUnexpectedMessage)
570 }
571
572 case recordTypeChangeCipherSpec:
573 if typ != want || len(data) != 1 || data[0] != 1 {
574 c.sendAlert(alertUnexpectedMessage)
575 break
576 }
577 err := c.in.changeCipherSpec()
578 if err != nil {
579 c.sendAlert(err.(alert))
580 }
581
582 case recordTypeApplicationData:
583 if typ != want {
584 c.sendAlert(alertUnexpectedMessage)
585 break
586 }
587 c.input = b
588 b = nil
589
590 case recordTypeHandshake:
591
592 if typ != want {
593 return c.sendAlert(alertNoRenegotiation)
594 }
595 c.hand.Write(data)
596 }
597
598 if b != nil {
599 c.in.freeBlock(b)
600 }
601 return c.error()
602 }
603
604
605
606 func (c *Conn) sendAlertLocked(err alert) error {
607 switch err {
608 case alertNoRenegotiation, alertCloseNotify:
609 c.tmp[0] = alertLevelWarning
610 default:
611 c.tmp[0] = alertLevelError
612 }
613 c.tmp[1] = byte(err)
614 c.writeRecord(recordTypeAlert, c.tmp[0:2])
615
616 if err != alertCloseNotify {
617 return c.setError(&net.OpError{Op: "local error", Err: err})
618 }
619 return nil
620 }
621
622
623
624 func (c *Conn) sendAlert(err alert) error {
625 c.out.Lock()
626 defer c.out.Unlock()
627 return c.sendAlertLocked(err)
628 }
629
630
631
632
633 func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
634 b := c.out.newBlock()
635 for len(data) > 0 {
636 m := len(data)
637 if m > maxPlaintext {
638 m = maxPlaintext
639 }
640 b.resize(recordHeaderLen + m)
641 b.data[0] = byte(typ)
642 vers := c.vers
643 if vers == 0 {
644 vers = maxVersion
645 }
646 b.data[1] = byte(vers >> 8)
647 b.data[2] = byte(vers)
648 b.data[3] = byte(m >> 8)
649 b.data[4] = byte(m)
650 copy(b.data[recordHeaderLen:], data)
651 c.out.encrypt(b)
652 _, err = c.conn.Write(b.data)
653 if err != nil {
654 break
655 }
656 n += m
657 data = data[m:]
658 }
659 c.out.freeBlock(b)
660
661 if typ == recordTypeChangeCipherSpec {
662 err = c.out.changeCipherSpec()
663 if err != nil {
664
665
666 c.tmp[0] = alertLevelError
667 c.tmp[1] = byte(err.(alert))
668 c.writeRecord(recordTypeAlert, c.tmp[0:2])
669 return n, c.setError(&net.OpError{Op: "local error", Err: err})
670 }
671 }
672 return
673 }
674
675
676
677
678 func (c *Conn) readHandshake() (interface{}, error) {
679 for c.hand.Len() < 4 {
680 if err := c.error(); err != nil {
681 return nil, err
682 }
683 if err := c.readRecord(recordTypeHandshake); err != nil {
684 return nil, err
685 }
686 }
687
688 data := c.hand.Bytes()
689 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
690 if n > maxHandshake {
691 c.sendAlert(alertInternalError)
692 return nil, c.error()
693 }
694 for c.hand.Len() < 4+n {
695 if err := c.error(); err != nil {
696 return nil, err
697 }
698 if err := c.readRecord(recordTypeHandshake); err != nil {
699 return nil, err
700 }
701 }
702 data = c.hand.Next(4 + n)
703 var m handshakeMessage
704 switch data[0] {
705 case typeClientHello:
706 m = new(clientHelloMsg)
707 case typeServerHello:
708 m = new(serverHelloMsg)
709 case typeCertificate:
710 m = new(certificateMsg)
711 case typeCertificateRequest:
712 m = new(certificateRequestMsg)
713 case typeCertificateStatus:
714 m = new(certificateStatusMsg)
715 case typeServerKeyExchange:
716 m = new(serverKeyExchangeMsg)
717 case typeServerHelloDone:
718 m = new(serverHelloDoneMsg)
719 case typeClientKeyExchange:
720 m = new(clientKeyExchangeMsg)
721 case typeCertificateVerify:
722 m = new(certificateVerifyMsg)
723 case typeNextProtocol:
724 m = new(nextProtoMsg)
725 case typeFinished:
726 m = new(finishedMsg)
727 default:
728 c.sendAlert(alertUnexpectedMessage)
729 return nil, alertUnexpectedMessage
730 }
731
732
733
734
735 data = append([]byte(nil), data...)
736
737 if !m.unmarshal(data) {
738 c.sendAlert(alertUnexpectedMessage)
739 return nil, alertUnexpectedMessage
740 }
741 return m, nil
742 }
743
744
745 func (c *Conn) Write(b []byte) (int, error) {
746 if err := c.error(); err != nil {
747 return 0, err
748 }
749
750 if err := c.Handshake(); err != nil {
751 return 0, c.setError(err)
752 }
753
754 c.out.Lock()
755 defer c.out.Unlock()
756
757 if !c.handshakeComplete {
758 return 0, alertInternalError
759 }
760
761
762
763
764
765
766
767
768
769
770 var m int
771 if len(b) > 1 && c.vers <= versionTLS10 {
772 if _, ok := c.out.cipher.(cipher.BlockMode); ok {
773 n, err := c.writeRecord(recordTypeApplicationData, b[:1])
774 if err != nil {
775 return n, c.setError(err)
776 }
777 m, b = 1, b[1:]
778 }
779 }
780
781 n, err := c.writeRecord(recordTypeApplicationData, b)
782 return n + m, c.setError(err)
783 }
784
785
786
787 func (c *Conn) Read(b []byte) (n int, err error) {
788 if err = c.Handshake(); err != nil {
789 return
790 }
791
792 c.in.Lock()
793 defer c.in.Unlock()
794
795 for c.input == nil && c.error() == nil {
796 if err := c.readRecord(recordTypeApplicationData); err != nil {
797
798 return 0, err
799 }
800 }
801 if err := c.error(); err != nil {
802 return 0, err
803 }
804 n, err = c.input.Read(b)
805 if c.input.off >= len(c.input.data) {
806 c.in.freeBlock(c.input)
807 c.input = nil
808 }
809 return n, nil
810 }
811
812
813 func (c *Conn) Close() error {
814 var alertErr error
815
816 c.handshakeMutex.Lock()
817 defer c.handshakeMutex.Unlock()
818 if c.handshakeComplete {
819 alertErr = c.sendAlert(alertCloseNotify)
820 }
821
822 if err := c.conn.Close(); err != nil {
823 return err
824 }
825 return alertErr
826 }
827
828
829
830
831
832 func (c *Conn) Handshake() error {
833 c.handshakeMutex.Lock()
834 defer c.handshakeMutex.Unlock()
835 if err := c.error(); err != nil {
836 return err
837 }
838 if c.handshakeComplete {
839 return nil
840 }
841 if c.isClient {
842 return c.clientHandshake()
843 }
844 return c.serverHandshake()
845 }
846
847
848 func (c *Conn) ConnectionState() ConnectionState {
849 c.handshakeMutex.Lock()
850 defer c.handshakeMutex.Unlock()
851
852 var state ConnectionState
853 state.HandshakeComplete = c.handshakeComplete
854 if c.handshakeComplete {
855 state.NegotiatedProtocol = c.clientProtocol
856 state.DidResume = c.didResume
857 state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
858 state.CipherSuite = c.cipherSuite
859 state.PeerCertificates = c.peerCertificates
860 state.VerifiedChains = c.verifiedChains
861 state.ServerName = c.serverName
862 }
863
864 return state
865 }
866
867
868
869 func (c *Conn) OCSPResponse() []byte {
870 c.handshakeMutex.Lock()
871 defer c.handshakeMutex.Unlock()
872
873 return c.ocspResponse
874 }
875
876
877
878
879 func (c *Conn) VerifyHostname(host string) error {
880 c.handshakeMutex.Lock()
881 defer c.handshakeMutex.Unlock()
882 if !c.isClient {
883 return errors.New("VerifyHostname called on TLS server connection")
884 }
885 if !c.handshakeComplete {
886 return errors.New("TLS handshake has not yet been performed")
887 }
888 return c.peerCertificates[0].VerifyHostname(host)
889 }
View as plain text