1
2
3
4
5
6
7 package tls
8
9
10
11
12
13
14 import (
15 "bytes"
16 "context"
17 "crypto"
18 "crypto/ecdsa"
19 "crypto/ed25519"
20 "crypto/rsa"
21 "crypto/x509"
22 "encoding/pem"
23 "errors"
24 "fmt"
25 "net"
26 "os"
27 "strings"
28 "time"
29 )
30
31
32
33
34
35 func Server(conn net.Conn, config *Config) *Conn {
36 c := &Conn{
37 conn: conn,
38 config: config,
39 }
40 c.handshakeFn = c.serverHandshake
41 return c
42 }
43
44
45
46
47
48 func Client(conn net.Conn, config *Config) *Conn {
49 c := &Conn{
50 conn: conn,
51 config: config,
52 isClient: true,
53 }
54 c.handshakeFn = c.clientHandshake
55 return c
56 }
57
58
// listener wraps a net.Listener so that connections it accepts are handed
// back as server-side TLS connections.
type listener struct {
	// Listener is the wrapped listener; every method except Accept is
	// promoted from it unchanged.
	net.Listener
	// config is passed to Server for each accepted connection.
	config *Config
}
63
64
65
66 func (l *listener) Accept() (net.Conn, error) {
67 c, err := l.Listener.Accept()
68 if err != nil {
69 return nil, err
70 }
71 return Server(c, l.config), nil
72 }
73
74
75
76
77
78 func NewListener(inner net.Listener, config *Config) net.Listener {
79 l := new(listener)
80 l.Listener = inner
81 l.config = config
82 return l
83 }
84
85
86
87
88
89 func Listen(network, laddr string, config *Config) (net.Listener, error) {
90 if config == nil || len(config.Certificates) == 0 &&
91 config.GetCertificate == nil && config.GetConfigForClient == nil {
92 return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
93 }
94 l, err := net.Listen(network, laddr)
95 if err != nil {
96 return nil, err
97 }
98 return NewListener(l, config), nil
99 }
100
// timeoutError is the error value dial sends on its handshake channel when
// the dial timer fires before the handshake has completed. It carries no
// state.
type timeoutError struct{}

// Error returns the fixed timeout message.
func (timeoutError) Error() string {
	const msg = "tls: DialWithDialer timed out"
	return msg
}

// Timeout always reports true.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
106
107
108
109
110
111
112
113
114 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
115 return dial(context.Background(), dialer, network, addr, config)
116 }
117
118 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
119
120
121
122 timeout := netDialer.Timeout
123
124 if !netDialer.Deadline.IsZero() {
125 deadlineTimeout := time.Until(netDialer.Deadline)
126 if timeout == 0 || deadlineTimeout < timeout {
127 timeout = deadlineTimeout
128 }
129 }
130
131
132 var hsErrCh chan error
133 if timeout != 0 || ctx.Done() != nil {
134 hsErrCh = make(chan error, 2)
135 }
136 if timeout != 0 {
137 timer := time.AfterFunc(timeout, func() {
138 hsErrCh <- timeoutError{}
139 })
140 defer timer.Stop()
141 }
142
143 rawConn, err := netDialer.DialContext(ctx, network, addr)
144 if err != nil {
145 return nil, err
146 }
147
148 colonPos := strings.LastIndex(addr, ":")
149 if colonPos == -1 {
150 colonPos = len(addr)
151 }
152 hostname := addr[:colonPos]
153
154 if config == nil {
155 config = defaultConfig()
156 }
157
158
159 if config.ServerName == "" {
160
161 c := config.Clone()
162 c.ServerName = hostname
163 config = c
164 }
165
166 conn := Client(rawConn, config)
167
168 if hsErrCh == nil {
169 err = conn.Handshake()
170 } else {
171 go func() {
172 hsErrCh <- conn.Handshake()
173 }()
174
175 select {
176 case <-ctx.Done():
177 err = ctx.Err()
178 case err = <-hsErrCh:
179 if err != nil {
180
181
182
183 if e := ctx.Err(); e != nil {
184 err = e
185 }
186 }
187 }
188 }
189
190 if err != nil {
191 rawConn.Close()
192 return nil, err
193 }
194
195 return conn, nil
196 }
197
198
199
200
201
202
203
204 func Dial(network, addr string, config *Config) (*Conn, error) {
205 return DialWithDialer(new(net.Dialer), network, addr, config)
206 }
207
208
209
// Dialer dials TLS connections, combining an optional underlying net.Dialer
// with an optional TLS configuration.
type Dialer struct {
	// NetDialer is the dialer used for the underlying network connection.
	// When it is nil, the netDialer method substitutes a zero-value
	// net.Dialer.
	NetDialer *net.Dialer

	// Config is the TLS configuration DialContext passes to dial. When it
	// is nil, dial falls back to defaultConfig().
	Config *Config
}
222
223
224
225
226
227 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
228 return d.DialContext(context.Background(), network, addr)
229 }
230
231 func (d *Dialer) netDialer() *net.Dialer {
232 if d.NetDialer != nil {
233 return d.NetDialer
234 }
235 return new(net.Dialer)
236 }
237
238
239
240
241
242
243
244
245
246
247 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
248 c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
249 if err != nil {
250
251 return nil, err
252 }
253 return c, nil
254 }
255
256
257
258
259
260
261 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
262 certPEMBlock, err := os.ReadFile(certFile)
263 if err != nil {
264 return Certificate{}, err
265 }
266 keyPEMBlock, err := os.ReadFile(keyFile)
267 if err != nil {
268 return Certificate{}, err
269 }
270 return X509KeyPair(certPEMBlock, keyPEMBlock)
271 }
272
273
274
275
276 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
277 fail := func(err error) (Certificate, error) { return Certificate{}, err }
278
279 var cert Certificate
280 var skippedBlockTypes []string
281 for {
282 var certDERBlock *pem.Block
283 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
284 if certDERBlock == nil {
285 break
286 }
287 if certDERBlock.Type == "CERTIFICATE" {
288 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
289 } else {
290 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
291 }
292 }
293
294 if len(cert.Certificate) == 0 {
295 if len(skippedBlockTypes) == 0 {
296 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
297 }
298 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
299 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
300 }
301 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
302 }
303
304 skippedBlockTypes = skippedBlockTypes[:0]
305 var keyDERBlock *pem.Block
306 for {
307 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
308 if keyDERBlock == nil {
309 if len(skippedBlockTypes) == 0 {
310 return fail(errors.New("tls: failed to find any PEM data in key input"))
311 }
312 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
313 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
314 }
315 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
316 }
317 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
318 break
319 }
320 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
321 }
322
323
324
325 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
326 if err != nil {
327 return fail(err)
328 }
329
330 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
331 if err != nil {
332 return fail(err)
333 }
334
335 switch pub := x509Cert.PublicKey.(type) {
336 case *rsa.PublicKey:
337 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
338 if !ok {
339 return fail(errors.New("tls: private key type does not match public key type"))
340 }
341 if pub.N.Cmp(priv.N) != 0 {
342 return fail(errors.New("tls: private key does not match public key"))
343 }
344 case *ecdsa.PublicKey:
345 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
346 if !ok {
347 return fail(errors.New("tls: private key type does not match public key type"))
348 }
349 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
350 return fail(errors.New("tls: private key does not match public key"))
351 }
352 case ed25519.PublicKey:
353 priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
354 if !ok {
355 return fail(errors.New("tls: private key type does not match public key type"))
356 }
357 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
358 return fail(errors.New("tls: private key does not match public key"))
359 }
360 default:
361 return fail(errors.New("tls: unknown public key algorithm"))
362 }
363
364 return cert, nil
365 }
366
367
368
369
// parsePrivateKey decodes a DER-encoded private key, trying the supported
// encodings in a fixed order: PKCS #1 (RSA), then PKCS #8, then SEC 1 (EC).
//
// A PKCS #8 payload is accepted only if it holds an RSA, ECDSA or Ed25519
// key; any other key type inside a valid PKCS #8 wrapper is rejected
// immediately without trying the remaining encodings. If no encoding
// matches, a generic parse error is returned.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
	if rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(der); pkcs1Err == nil {
		return rsaKey, nil
	}

	if wrapped, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		switch wrapped.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
			return wrapped, nil
		}
		return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
	}

	ecKey, sec1Err := x509.ParseECPrivateKey(der)
	if sec1Err != nil {
		return nil, errors.New("tls: failed to parse private key")
	}
	return ecKey, nil
}
388
View as plain text