1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "hash"
15 "io"
16 "net"
17 "os"
18 "sync"
19 )
20
21
22
// A Conn represents a secured connection layered over an underlying
// net.Conn. Its exported methods (Read, Write, Close, LocalAddr, RemoteAddr
// and the timeout setters) mirror the net.Conn method set.
type Conn struct {
	// Transport.
	conn     net.Conn // underlying connection that records are read from and written to
	isClient bool     // Handshake runs clientHandshake when true, serverHandshake otherwise

	// Handshake state; Handshake, ConnectionState, OCSPResponse and
	// VerifyHostname hold handshakeMutex while using it.
	handshakeMutex    sync.Mutex
	vers              uint16              // protocol version; writeRecord uses maxVersion while this is 0
	haveVers          bool                // once true, readRecord rejects records whose version != vers
	config            *Config             // connection configuration (not referenced in this part of the file)
	handshakeComplete bool                // gates which record types readRecord accepts
	cipherSuite       uint16              // reported via ConnectionState
	ocspResponse      []byte              // returned by OCSPResponse
	peerCertificates  []*x509.Certificate // certificates presented by the peer; VerifyHostname checks [0]

	// verifiedChains is reported via ConnectionState. Presumably these are the
	// chains the peer's certificate was verified against (TODO confirm in the
	// handshake code).
	verifiedChains [][]*x509.Certificate

	// clientProtocol is reported as ConnectionState.NegotiatedProtocol;
	// NegotiatedProtocolIsMutual is reported as !clientProtocolFallback.
	clientProtocol         string
	clientProtocolFallback bool

	// Sticky error: written only through setError, read through error().
	errMutex sync.Mutex
	err      os.Error

	// Record-layer state.
	in, out  halfConn     // per-direction cipher state; Read locks in, Write/sendAlert lock out
	rawInput *block       // bytes read from conn but not yet consumed as a record
	input    *block       // decrypted application data waiting for Read
	hand     bytes.Buffer // handshake bytes waiting for readHandshake

	tmp [16]byte // scratch space; sendAlertLocked builds two-byte alert records in tmp[0:2]
}
56
57 func (c *Conn) setError(err os.Error) os.Error {
58 c.errMutex.Lock()
59 defer c.errMutex.Unlock()
60
61 if c.err == nil {
62 c.err = err
63 }
64 return err
65 }
66
67 func (c *Conn) error() os.Error {
68 c.errMutex.Lock()
69 defer c.errMutex.Unlock()
70
71 return c.err
72 }
73
74
75
76
77
78
79 func (c *Conn) LocalAddr() net.Addr {
80 return c.conn.LocalAddr()
81 }
82
83
84 func (c *Conn) RemoteAddr() net.Addr {
85 return c.conn.RemoteAddr()
86 }
87
88
89
90 func (c *Conn) SetTimeout(nsec int64) os.Error {
91 return c.conn.SetTimeout(nsec)
92 }
93
94
95
96
97 func (c *Conn) SetReadTimeout(nsec int64) os.Error {
98 return c.conn.SetReadTimeout(nsec)
99 }
100
101
102
103 func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
104 return os.NewError("TLS does not support SetWriteTimeout")
105 }
106
107
108
// A halfConn holds the record-layer state for one direction (reading or
// writing) of a Conn. The embedded mutex serializes use of that direction:
// Conn.Read locks c.in; Conn.Write and sendAlert lock c.out.
type halfConn struct {
	sync.Mutex
	cipher interface{} // active cipher: cipher.Stream or cipher.BlockMode; nil means no encryption
	mac    hash.Hash   // active MAC; nil means records are not MACed
	seq    [8]byte     // 64-bit big-endian record sequence number (see incSeq)
	bfree  *block      // free list of reusable blocks, linked through block.link

	nextCipher interface{} // pending cipher, made active by changeCipherSpec
	nextMac    hash.Hash   // pending MAC, made active by changeCipherSpec
}
119
120
121
122 func (hc *halfConn) prepareCipherSpec(cipher interface{}, mac hash.Hash) {
123 hc.nextCipher = cipher
124 hc.nextMac = mac
125 }
126
127
128
129 func (hc *halfConn) changeCipherSpec() os.Error {
130 if hc.nextCipher == nil {
131 return alertInternalError
132 }
133 hc.cipher = hc.nextCipher
134 hc.mac = hc.nextMac
135 hc.nextCipher = nil
136 hc.nextMac = nil
137 return nil
138 }
139
140
141 func (hc *halfConn) incSeq() {
142 for i := 7; i >= 0; i-- {
143 hc.seq[i]++
144 if hc.seq[i] != 0 {
145 return
146 }
147 }
148
149
150
151
152 panic("TLS: sequence number wraparound")
153 }
154
155
156 func (hc *halfConn) resetSeq() {
157 for i := range hc.seq {
158 hc.seq[i] = 0
159 }
160 }
161
162
163
164
// removePadding strips TLS CBC padding (RFC 2246, section 6.2.3.2) from a
// decrypted payload. The last byte gives the padding length L, and the L
// bytes before it must all equal L.
//
// The checks are written without branches on the payload contents, so the
// work done does not depend on the (secret) padding bytes. It returns a prefix
// of payload and a byte that is 255 if the padding was valid and 0 otherwise.
// On invalid padding only the final length byte is removed; the caller
// (decrypt) rejects the record later based on the returned flag.
func removePadding(payload []byte) ([]byte, byte) {
	if len(payload) < 1 {
		return payload, 0
	}

	paddingLen := payload[len(payload)-1]
	// t underflows (wraps to a huge value) when paddingLen > len(payload)-1.
	t := uint(len(payload)-1) - uint(paddingLen)
	// good is 255 if the claimed padding fits in the payload, else 0. This
	// relies on record lengths being far below 2^31, so the sign bit of
	// int32(^t) reflects whether t underflowed.
	good := byte(int32(^t) >> 31)

	// Always examine the same number of bytes (up to 255, the maximum padding
	// length) so the loop's duration does not depend on paddingLen. The
	// payload length itself is public, so branching on it is fine.
	toCheck := 255
	if toCheck+1 > len(payload) {
		toCheck = len(payload) - 1
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (byte i from the end is padding),
		// 0 otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear the bits of good wherever a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all eight bits of good together into the top bit, then smear that
	// bit across the whole byte (arithmetic shift), giving exactly 0 or 255.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Valid: remove paddingLen padding bytes plus the length byte.
	// Invalid: remove only the length byte.
	toRemove := good&paddingLen + 1
	return payload[:len(payload)-int(toRemove)], good
}
199
// roundUp returns the smallest multiple of b that is >= a. b must be non-zero;
// a zero b panics with a division by zero.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
203
204
// decrypt processes the body of the record in b in place, using the active
// read-side cipher state. It:
//   - decrypts the body;
//   - strips block-cipher padding;
//   - if a MAC is configured, verifies and strips the MAC and rewrites the
//     header's length field (b.data[3:5]) to the plaintext length.
//
// The sequence number is advanced whenever a MAC is checked. It returns
// (true, 0) on success. It returns (false, alertBadRecordMAC) if the record is
// malformed, has bad padding, or fails MAC verification. It panics if the
// cipher is of an unknown type.
func (hc *halfConn) decrypt(b *block) (bool, alert) {
	// The record body follows the fixed-size header.
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	// 255 means "padding valid"; only the block-cipher path can clear it.
	paddingGood := byte(255)

	// decrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case cipher.BlockMode:
			blockSize := c.BlockSize()

			// Ciphertext must be whole blocks and long enough to hold at least
			// the MAC plus the padding-length byte.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(macSize+1, blockSize) {
				return false, alertBadRecordMAC
			}

			c.CryptBlocks(payload, payload)
			payload, paddingGood = removePadding(payload)
			b.resize(recordHeaderLen + len(payload))

			// NOTE: a timing side channel remains in the MAC check below.
			// The number of bytes hashed depends on how much padding
			// removePadding stripped, so MAC verification time can still leak
			// information about the decrypted padding.
		default:
			panic("unknown cipher type")
		}
	}

	// check, strip mac
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, alertBadRecordMAC
		}

		// Strip the MAC from payload and b.data, and fix the header length so
		// that the MAC input below covers the plaintext length.
		n := len(payload) - macSize
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		b.resize(recordHeaderLen + n)
		remoteMAC := payload[n:]

		// MAC input: sequence number || header || plaintext.
		hc.mac.Reset()
		hc.mac.Write(hc.seq[0:])
		hc.incSeq()
		hc.mac.Write(b.data)

		// Bad padding and a bad MAC produce the same alert, so the peer cannot
		// tell them apart from the response.
		if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 || paddingGood != 255 {
			return false, alertBadRecordMAC
		}
	}

	return true, 0
}
272
273
274
275
276
277
// padToBlockSize prepares payload for CBC encryption with TLS-style padding.
// It returns:
//   - prefix: the leading whole blocks of payload, aliasing payload's storage;
//   - finalBlock: a freshly allocated block holding the leftover bytes,
//     followed by padding bytes that each hold the value (padding length - 1).
//
// When payload is already block-aligned, finalBlock consists entirely of
// padding.
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
	tail := len(payload) % blockSize
	split := len(payload) - tail
	padByte := byte(blockSize - tail - 1)

	prefix = payload[:split]
	finalBlock = make([]byte, blockSize)
	filled := copy(finalBlock, payload[split:])
	for i := filled; i < blockSize; i++ {
		finalBlock[i] = padByte
	}
	return prefix, finalBlock
}
289
290
// encrypt protects the record in b in place, using the active write-side
// cipher state. It:
//   - appends the MAC, if a MAC is configured (advancing the sequence number);
//   - pads and encrypts the body (everything after the header);
//   - rewrites the header's length field (b.data[3:5]) to the final body
//     length.
//
// It always returns (true, 0). It panics if the cipher is of an unknown type.
func (hc *halfConn) encrypt(b *block) (bool, alert) {
	// mac
	if hc.mac != nil {
		// MAC input: sequence number || header (with plaintext length) || plaintext.
		hc.mac.Reset()
		hc.mac.Write(hc.seq[0:])
		hc.incSeq()
		hc.mac.Write(b.data)
		mac := hc.mac.Sum()
		n := len(b.data)
		b.resize(n + len(mac))
		copy(b.data[n:], mac)
	}

	payload := b.data[recordHeaderLen:]

	// encrypt
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case cipher.BlockMode:
			prefix, finalBlock := padToBlockSize(payload, c.BlockSize())
			// resize may reallocate b.data, so encrypt from prefix (which
			// still points at the plaintext) into the resized buffer.
			b.resize(recordHeaderLen + len(prefix) + len(finalBlock))
			c.CryptBlocks(b.data[recordHeaderLen:], prefix)
			c.CryptBlocks(b.data[recordHeaderLen+len(prefix):], finalBlock)
		default:
			panic("unknown cipher type")
		}
	}

	// The length now includes the MAC and any block padding.
	n := len(b.data) - recordHeaderLen
	b.data[3] = byte(n >> 8)
	b.data[4] = byte(n)

	return true, 0
}
328
329
// A block is a growable byte buffer used for records. Blocks are recycled
// through a halfConn's free list.
type block struct {
	data []byte
	off  int    // read position used by Read; readRecord sets it past the record header
	link *block // next block in the halfConn free list (see newBlock/freeBlock)
}
335
336
337 func (b *block) resize(n int) {
338 if n > cap(b.data) {
339 b.reserve(n)
340 }
341 b.data = b.data[0:n]
342 }
343
344
345 func (b *block) reserve(n int) {
346 if cap(b.data) >= n {
347 return
348 }
349 m := cap(b.data)
350 if m == 0 {
351 m = 1024
352 }
353 for m < n {
354 m *= 2
355 }
356 data := make([]byte, len(b.data), m)
357 copy(data, b.data)
358 b.data = data
359 }
360
361
362
363 func (b *block) readFromUntil(r io.Reader, n int) os.Error {
364
365 if len(b.data) >= n {
366 return nil
367 }
368
369
370 b.reserve(n)
371 for {
372 m, err := r.Read(b.data[len(b.data):cap(b.data)])
373 b.data = b.data[0 : len(b.data)+m]
374 if len(b.data) >= n {
375 break
376 }
377 if err != nil {
378 return err
379 }
380 }
381 return nil
382 }
383
384 func (b *block) Read(p []byte) (n int, err os.Error) {
385 n = copy(p, b.data[b.off:])
386 b.off += n
387 return
388 }
389
390
391 func (hc *halfConn) newBlock() *block {
392 b := hc.bfree
393 if b == nil {
394 return new(block)
395 }
396 hc.bfree = b.link
397 b.link = nil
398 b.resize(0)
399 return b
400 }
401
402
403
404
405
406 func (hc *halfConn) freeBlock(b *block) {
407 b.link = hc.bfree
408 hc.bfree = b
409 }
410
411
412
413
414 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
415 if len(b.data) <= n {
416 return b, nil
417 }
418 bb := hc.newBlock()
419 bb.resize(len(b.data) - n)
420 copy(bb.data, b.data[n:])
421 b.data = b.data[0:n]
422 return b, bb
423 }
424
425
426
427
// readRecord reads one TLS record from the underlying connection, decrypts
// and authenticates it with c.in, and dispatches it by content type:
//   - application data is stored in c.input for Read to consume;
//   - handshake data is appended to c.hand;
//   - a ChangeCipherSpec makes the pending read cipher active;
//   - a close_notify alert records os.EOF;
//   - other warning-level alerts are skipped and the next record is read;
//   - error-level alerts are recorded as a "remote error".
//
// want is the record type the caller expects. Handshake and ChangeCipherSpec
// records may only be requested before the handshake completes, and
// application data only after it; any other request is an internal error. A
// handshake record that arrives when a different type was requested is
// treated as renegotiation and answered with no_renegotiation.
//
// Temporary net.Errors from the underlying connection are returned without
// being recorded, so the caller may retry. Other failures are recorded through
// setError/sendAlert. Apart from the early returns, the result is c.error(),
// which is nil unless an error has been recorded.
//
// Conn.Read calls this while holding c.in's lock and with c.input == nil.
func (c *Conn) readRecord(want recordType) os.Error {
	// The caller must be in sync with the connection state (renegotiation is
	// not supported).
	switch want {
	default:
		return c.sendAlert(alertInternalError)
	case recordTypeHandshake, recordTypeChangeCipherSpec:
		if c.handshakeComplete {
			return c.sendAlert(alertInternalError)
		}
	case recordTypeApplicationData:
		if !c.handshakeComplete {
			return c.sendAlert(alertInternalError)
		}
	}

Again:
	// rawInput holds bytes already read from the wire but not yet consumed,
	// possibly including the start of the next record.
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read the record header.
	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF at a record boundary is passed through unchanged, unlike the
		// mid-record EOF below, which becomes io.ErrUnexpectedEOF. Temporary
		// network errors are not made sticky, so the caller may retry.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.setError(err)
		}
		return err
	}
	typ := recordType(b.data[0])
	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	n := int(b.data[3])<<8 | int(b.data[4])
	if c.haveVers && vers != c.vers {
		return c.sendAlert(alertProtocolVersion)
	}
	if n > maxCiphertext {
		return c.sendAlert(alertRecordOverflow)
	}
	// Read the rest of the record body.
	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if err == os.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.setError(err)
		}
		return err
	}

	// Split off this record; any following bytes stay in c.rawInput for the
	// next call.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	b.off = recordHeaderLen
	if ok, err := c.in.decrypt(b); !ok {
		// NOTE(review): b is not returned to the free list on this path.
		return c.sendAlert(err)
	}
	data := b.data[b.off:]
	if len(data) > maxPlaintext {
		c.sendAlert(alertRecordOverflow)
		c.in.freeBlock(b)
		return c.error()
	}

	switch typ {
	default:
		c.sendAlert(alertUnexpectedMessage)

	case recordTypeAlert:
		// An alert record is exactly two bytes: level, description.
		if len(data) != 2 {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		if alert(data[1]) == alertCloseNotify {
			c.setError(os.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record.
			c.in.freeBlock(b)
			goto Again
		case alertLevelError:
			c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
		default:
			c.sendAlert(alertUnexpectedMessage)
		}

	case recordTypeChangeCipherSpec:
		// A ChangeCipherSpec record is the single byte 1.
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		err := c.in.changeCipherSpec()
		if err != nil {
			c.sendAlert(err.(alert))
		}

	case recordTypeApplicationData:
		if typ != want {
			c.sendAlert(alertUnexpectedMessage)
			break
		}
		// Hand the block to Read; it must not be freed below.
		c.input = b
		b = nil

	case recordTypeHandshake:
		// An unrequested handshake record is treated as a renegotiation
		// attempt. NOTE(review): b is not freed on this return path.
		if typ != want {
			return c.sendAlert(alertNoRenegotiation)
		}
		c.hand.Write(data)
	}

	if b != nil {
		c.in.freeBlock(b)
	}
	return c.error()
}
551
552
553
554 func (c *Conn) sendAlertLocked(err alert) os.Error {
555 c.tmp[0] = alertLevelError
556 if err == alertNoRenegotiation {
557 c.tmp[0] = alertLevelWarning
558 }
559 c.tmp[1] = byte(err)
560 c.writeRecord(recordTypeAlert, c.tmp[0:2])
561
562 if err != alertCloseNotify {
563 return c.setError(&net.OpError{Op: "local error", Error: err})
564 }
565 return nil
566 }
567
568
569
570 func (c *Conn) sendAlert(err alert) os.Error {
571 c.out.Lock()
572 defer c.out.Unlock()
573 return c.sendAlertLocked(err)
574 }
575
576
577
578
579 func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
580 b := c.out.newBlock()
581 for len(data) > 0 {
582 m := len(data)
583 if m > maxPlaintext {
584 m = maxPlaintext
585 }
586 b.resize(recordHeaderLen + m)
587 b.data[0] = byte(typ)
588 vers := c.vers
589 if vers == 0 {
590 vers = maxVersion
591 }
592 b.data[1] = byte(vers >> 8)
593 b.data[2] = byte(vers)
594 b.data[3] = byte(m >> 8)
595 b.data[4] = byte(m)
596 copy(b.data[recordHeaderLen:], data)
597 c.out.encrypt(b)
598 _, err = c.conn.Write(b.data)
599 if err != nil {
600 break
601 }
602 n += m
603 data = data[m:]
604 }
605 c.out.freeBlock(b)
606
607 if typ == recordTypeChangeCipherSpec {
608 err = c.out.changeCipherSpec()
609 if err != nil {
610
611
612 c.tmp[0] = alertLevelError
613 c.tmp[1] = byte(err.(alert))
614 c.writeRecord(recordTypeAlert, c.tmp[0:2])
615 c.err = &net.OpError{Op: "local error", Error: err}
616 return n, c.err
617 }
618 }
619 return
620 }
621
622
623
624
625 func (c *Conn) readHandshake() (interface{}, os.Error) {
626 for c.hand.Len() < 4 {
627 if c.err != nil {
628 return nil, c.err
629 }
630 c.readRecord(recordTypeHandshake)
631 }
632
633 data := c.hand.Bytes()
634 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
635 if n > maxHandshake {
636 c.sendAlert(alertInternalError)
637 return nil, c.err
638 }
639 for c.hand.Len() < 4+n {
640 if c.err != nil {
641 return nil, c.err
642 }
643 c.readRecord(recordTypeHandshake)
644 }
645 data = c.hand.Next(4 + n)
646 var m handshakeMessage
647 switch data[0] {
648 case typeClientHello:
649 m = new(clientHelloMsg)
650 case typeServerHello:
651 m = new(serverHelloMsg)
652 case typeCertificate:
653 m = new(certificateMsg)
654 case typeCertificateRequest:
655 m = new(certificateRequestMsg)
656 case typeCertificateStatus:
657 m = new(certificateStatusMsg)
658 case typeServerKeyExchange:
659 m = new(serverKeyExchangeMsg)
660 case typeServerHelloDone:
661 m = new(serverHelloDoneMsg)
662 case typeClientKeyExchange:
663 m = new(clientKeyExchangeMsg)
664 case typeCertificateVerify:
665 m = new(certificateVerifyMsg)
666 case typeNextProtocol:
667 m = new(nextProtoMsg)
668 case typeFinished:
669 m = new(finishedMsg)
670 default:
671 c.sendAlert(alertUnexpectedMessage)
672 return nil, alertUnexpectedMessage
673 }
674
675
676
677
678 data = append([]byte(nil), data...)
679
680 if !m.unmarshal(data) {
681 c.sendAlert(alertUnexpectedMessage)
682 return nil, alertUnexpectedMessage
683 }
684 return m, nil
685 }
686
687
688 func (c *Conn) Write(b []byte) (n int, err os.Error) {
689 if err = c.Handshake(); err != nil {
690 return
691 }
692
693 c.out.Lock()
694 defer c.out.Unlock()
695
696 if !c.handshakeComplete {
697 return 0, alertInternalError
698 }
699 if c.err != nil {
700 return 0, c.err
701 }
702 return c.writeRecord(recordTypeApplicationData, b)
703 }
704
705
706
707 func (c *Conn) Read(b []byte) (n int, err os.Error) {
708 if err = c.Handshake(); err != nil {
709 return
710 }
711
712 c.in.Lock()
713 defer c.in.Unlock()
714
715 for c.input == nil && c.err == nil {
716 if err := c.readRecord(recordTypeApplicationData); err != nil {
717
718 return 0, err
719 }
720 }
721 if c.err != nil {
722 return 0, c.err
723 }
724 n, err = c.input.Read(b)
725 if c.input.off >= len(c.input.data) {
726 c.in.freeBlock(c.input)
727 c.input = nil
728 }
729 return n, nil
730 }
731
732
733 func (c *Conn) Close() os.Error {
734 if err := c.Handshake(); err != nil {
735 return err
736 }
737 return c.sendAlert(alertCloseNotify)
738 }
739
740
741
742
743
744 func (c *Conn) Handshake() os.Error {
745 c.handshakeMutex.Lock()
746 defer c.handshakeMutex.Unlock()
747 if err := c.error(); err != nil {
748 return err
749 }
750 if c.handshakeComplete {
751 return nil
752 }
753 if c.isClient {
754 return c.clientHandshake()
755 }
756 return c.serverHandshake()
757 }
758
759
760 func (c *Conn) ConnectionState() ConnectionState {
761 c.handshakeMutex.Lock()
762 defer c.handshakeMutex.Unlock()
763
764 var state ConnectionState
765 state.HandshakeComplete = c.handshakeComplete
766 if c.handshakeComplete {
767 state.NegotiatedProtocol = c.clientProtocol
768 state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
769 state.CipherSuite = c.cipherSuite
770 state.PeerCertificates = c.peerCertificates
771 state.VerifiedChains = c.verifiedChains
772 }
773
774 return state
775 }
776
777
778
779 func (c *Conn) OCSPResponse() []byte {
780 c.handshakeMutex.Lock()
781 defer c.handshakeMutex.Unlock()
782
783 return c.ocspResponse
784 }
785
786
787
788
789 func (c *Conn) VerifyHostname(host string) os.Error {
790 c.handshakeMutex.Lock()
791 defer c.handshakeMutex.Unlock()
792 if !c.isClient {
793 return os.NewError("VerifyHostname called on TLS server connection")
794 }
795 if !c.handshakeComplete {
796 return os.NewError("TLS handshake has not yet been performed")
797 }
798 return c.peerCertificates[0].VerifyHostname(host)
799 }