Source file src/crypto/tls/quic_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"reflect"
    11  	"testing"
    12  )
    13  
    14  type testQUICConn struct {
    15  	t           *testing.T
    16  	conn        *QUICConn
    17  	readSecret  map[QUICEncryptionLevel]suiteSecret
    18  	writeSecret map[QUICEncryptionLevel]suiteSecret
    19  	gotParams   []byte
    20  	complete    bool
    21  }
    22  
    23  func newTestQUICClient(t *testing.T, config *Config) *testQUICConn {
    24  	q := &testQUICConn{t: t}
    25  	q.conn = QUICClient(&QUICConfig{
    26  		TLSConfig: config,
    27  	})
    28  	t.Cleanup(func() {
    29  		q.conn.Close()
    30  	})
    31  	return q
    32  }
    33  
    34  func newTestQUICServer(t *testing.T, config *Config) *testQUICConn {
    35  	q := &testQUICConn{t: t}
    36  	q.conn = QUICServer(&QUICConfig{
    37  		TLSConfig: config,
    38  	})
    39  	t.Cleanup(func() {
    40  		q.conn.Close()
    41  	})
    42  	return q
    43  }
    44  
    45  type suiteSecret struct {
    46  	suite  uint16
    47  	secret []byte
    48  }
    49  
    50  func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    51  	if _, ok := q.writeSecret[level]; !ok {
    52  		q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level)
    53  	}
    54  	if level == QUICEncryptionLevelApplication && !q.complete {
    55  		q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level)
    56  	}
    57  	if _, ok := q.readSecret[level]; ok {
    58  		q.t.Errorf("SetReadSecret for level %v called twice", level)
    59  	}
    60  	if q.readSecret == nil {
    61  		q.readSecret = map[QUICEncryptionLevel]suiteSecret{}
    62  	}
    63  	switch level {
    64  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    65  		q.readSecret[level] = suiteSecret{suite, secret}
    66  	default:
    67  		q.t.Errorf("SetReadSecret for unexpected level %v", level)
    68  	}
    69  }
    70  
    71  func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
    72  	if _, ok := q.writeSecret[level]; ok {
    73  		q.t.Errorf("SetWriteSecret for level %v called twice", level)
    74  	}
    75  	if q.writeSecret == nil {
    76  		q.writeSecret = map[QUICEncryptionLevel]suiteSecret{}
    77  	}
    78  	switch level {
    79  	case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication:
    80  		q.writeSecret[level] = suiteSecret{suite, secret}
    81  	default:
    82  		q.t.Errorf("SetWriteSecret for unexpected level %v", level)
    83  	}
    84  }
    85  
    86  var errTransportParametersRequired = errors.New("transport parameters required")
    87  
    88  func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent func(e QUICEvent, src, dst *testQUICConn) bool) error {
    89  	a, b := cli, srv
    90  	for _, c := range []*testQUICConn{a, b} {
    91  		if !c.conn.conn.quic.started {
    92  			if err := c.conn.Start(ctx); err != nil {
    93  				return err
    94  			}
    95  		}
    96  	}
    97  	idleCount := 0
    98  	for {
    99  		e := a.conn.NextEvent()
   100  		if onEvent != nil && onEvent(e, a, b) {
   101  			continue
   102  		}
   103  		switch e.Kind {
   104  		case QUICNoEvent:
   105  			idleCount++
   106  			if idleCount == 2 {
   107  				if !a.complete || !b.complete {
   108  					return errors.New("handshake incomplete")
   109  				}
   110  				return nil
   111  			}
   112  			a, b = b, a
   113  		case QUICSetReadSecret:
   114  			a.setReadSecret(e.Level, e.Suite, e.Data)
   115  		case QUICSetWriteSecret:
   116  			a.setWriteSecret(e.Level, e.Suite, e.Data)
   117  		case QUICWriteData:
   118  			if err := b.conn.HandleData(e.Level, e.Data); err != nil {
   119  				return err
   120  			}
   121  		case QUICTransportParameters:
   122  			a.gotParams = e.Data
   123  			if a.gotParams == nil {
   124  				a.gotParams = []byte{}
   125  			}
   126  		case QUICTransportParametersRequired:
   127  			return errTransportParametersRequired
   128  		case QUICHandshakeDone:
   129  			a.complete = true
   130  			if a == srv {
   131  				opts := QUICSessionTicketOptions{}
   132  				if err := srv.conn.SendSessionTicket(opts); err != nil {
   133  					return err
   134  				}
   135  			}
   136  		}
   137  		if e.Kind != QUICNoEvent {
   138  			idleCount = 0
   139  		}
   140  	}
   141  }
   142  
   143  func TestQUICConnection(t *testing.T) {
   144  	config := testConfig.Clone()
   145  	config.MinVersion = VersionTLS13
   146  
   147  	cli := newTestQUICClient(t, config)
   148  	cli.conn.SetTransportParameters(nil)
   149  
   150  	srv := newTestQUICServer(t, config)
   151  	srv.conn.SetTransportParameters(nil)
   152  
   153  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   154  		t.Fatalf("error during connection handshake: %v", err)
   155  	}
   156  
   157  	if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok {
   158  		t.Errorf("client has no Handshake secret")
   159  	}
   160  	if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok {
   161  		t.Errorf("client has no Application secret")
   162  	}
   163  	if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok {
   164  		t.Errorf("server has no Handshake secret")
   165  	}
   166  	if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok {
   167  		t.Errorf("server has no Application secret")
   168  	}
   169  	for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} {
   170  		if _, ok := cli.readSecret[level]; !ok {
   171  			t.Errorf("client has no %v read secret", level)
   172  		}
   173  		if _, ok := srv.readSecret[level]; !ok {
   174  			t.Errorf("server has no %v read secret", level)
   175  		}
   176  		if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) {
   177  			t.Errorf("client read secret does not match server write secret for level %v", level)
   178  		}
   179  		if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) {
   180  			t.Errorf("client write secret does not match server read secret for level %v", level)
   181  		}
   182  	}
   183  }
   184  
   185  func TestQUICSessionResumption(t *testing.T) {
   186  	clientConfig := testConfig.Clone()
   187  	clientConfig.MinVersion = VersionTLS13
   188  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   189  	clientConfig.ServerName = "example.go.dev"
   190  
   191  	serverConfig := testConfig.Clone()
   192  	serverConfig.MinVersion = VersionTLS13
   193  
   194  	cli := newTestQUICClient(t, clientConfig)
   195  	cli.conn.SetTransportParameters(nil)
   196  	srv := newTestQUICServer(t, serverConfig)
   197  	srv.conn.SetTransportParameters(nil)
   198  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   199  		t.Fatalf("error during first connection handshake: %v", err)
   200  	}
   201  	if cli.conn.ConnectionState().DidResume {
   202  		t.Errorf("first connection unexpectedly used session resumption")
   203  	}
   204  
   205  	cli2 := newTestQUICClient(t, clientConfig)
   206  	cli2.conn.SetTransportParameters(nil)
   207  	srv2 := newTestQUICServer(t, serverConfig)
   208  	srv2.conn.SetTransportParameters(nil)
   209  	if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil {
   210  		t.Fatalf("error during second connection handshake: %v", err)
   211  	}
   212  	if !cli2.conn.ConnectionState().DidResume {
   213  		t.Errorf("second connection did not use session resumption")
   214  	}
   215  }
   216  
   217  func TestQUICFragmentaryData(t *testing.T) {
   218  	clientConfig := testConfig.Clone()
   219  	clientConfig.MinVersion = VersionTLS13
   220  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   221  	clientConfig.ServerName = "example.go.dev"
   222  
   223  	serverConfig := testConfig.Clone()
   224  	serverConfig.MinVersion = VersionTLS13
   225  
   226  	cli := newTestQUICClient(t, clientConfig)
   227  	cli.conn.SetTransportParameters(nil)
   228  	srv := newTestQUICServer(t, serverConfig)
   229  	srv.conn.SetTransportParameters(nil)
   230  	onEvent := func(e QUICEvent, src, dst *testQUICConn) bool {
   231  		if e.Kind == QUICWriteData {
   232  			// Provide the data one byte at a time.
   233  			for i := range e.Data {
   234  				if err := dst.conn.HandleData(e.Level, e.Data[i:i+1]); err != nil {
   235  					t.Errorf("HandleData: %v", err)
   236  					break
   237  				}
   238  			}
   239  			return true
   240  		}
   241  		return false
   242  	}
   243  	if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil {
   244  		t.Fatalf("error during first connection handshake: %v", err)
   245  	}
   246  }
   247  
   248  func TestQUICPostHandshakeClientAuthentication(t *testing.T) {
   249  	// RFC 9001, Section 4.4.
   250  	config := testConfig.Clone()
   251  	config.MinVersion = VersionTLS13
   252  	cli := newTestQUICClient(t, config)
   253  	cli.conn.SetTransportParameters(nil)
   254  	srv := newTestQUICServer(t, config)
   255  	srv.conn.SetTransportParameters(nil)
   256  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   257  		t.Fatalf("error during connection handshake: %v", err)
   258  	}
   259  
   260  	certReq := new(certificateRequestMsgTLS13)
   261  	certReq.ocspStapling = true
   262  	certReq.scts = true
   263  	certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
   264  	certReqBytes, err := certReq.marshal()
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   269  		byte(typeCertificateRequest),
   270  		byte(0), byte(0), byte(len(certReqBytes)),
   271  	}, certReqBytes...)); err == nil {
   272  		t.Fatalf("post-handshake authentication request: got no error, want one")
   273  	}
   274  }
   275  
   276  func TestQUICPostHandshakeKeyUpdate(t *testing.T) {
   277  	// RFC 9001, Section 6.
   278  	config := testConfig.Clone()
   279  	config.MinVersion = VersionTLS13
   280  	cli := newTestQUICClient(t, config)
   281  	cli.conn.SetTransportParameters(nil)
   282  	srv := newTestQUICServer(t, config)
   283  	srv.conn.SetTransportParameters(nil)
   284  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   285  		t.Fatalf("error during connection handshake: %v", err)
   286  	}
   287  
   288  	keyUpdate := new(keyUpdateMsg)
   289  	keyUpdateBytes, err := keyUpdate.marshal()
   290  	if err != nil {
   291  		t.Fatal(err)
   292  	}
   293  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{
   294  		byte(typeKeyUpdate),
   295  		byte(0), byte(0), byte(len(keyUpdateBytes)),
   296  	}, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) {
   297  		t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err)
   298  	}
   299  }
   300  
   301  func TestQUICPostHandshakeMessageTooLarge(t *testing.T) {
   302  	config := testConfig.Clone()
   303  	config.MinVersion = VersionTLS13
   304  	cli := newTestQUICClient(t, config)
   305  	cli.conn.SetTransportParameters(nil)
   306  	srv := newTestQUICServer(t, config)
   307  	srv.conn.SetTransportParameters(nil)
   308  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   309  		t.Fatalf("error during connection handshake: %v", err)
   310  	}
   311  
   312  	size := maxHandshake + 1
   313  	if err := cli.conn.HandleData(QUICEncryptionLevelApplication, []byte{
   314  		byte(typeNewSessionTicket),
   315  		byte(size >> 16),
   316  		byte(size >> 8),
   317  		byte(size),
   318  	}); err == nil {
   319  		t.Fatalf("%v-byte post-handshake message: got no error, want one", size)
   320  	}
   321  }
   322  
   323  func TestQUICHandshakeError(t *testing.T) {
   324  	clientConfig := testConfig.Clone()
   325  	clientConfig.MinVersion = VersionTLS13
   326  	clientConfig.InsecureSkipVerify = false
   327  	clientConfig.ServerName = "name"
   328  
   329  	serverConfig := testConfig.Clone()
   330  	serverConfig.MinVersion = VersionTLS13
   331  
   332  	cli := newTestQUICClient(t, clientConfig)
   333  	cli.conn.SetTransportParameters(nil)
   334  	srv := newTestQUICServer(t, serverConfig)
   335  	srv.conn.SetTransportParameters(nil)
   336  	err := runTestQUICConnection(context.Background(), cli, srv, nil)
   337  	if !errors.Is(err, AlertError(alertBadCertificate)) {
   338  		t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err)
   339  	}
   340  	var e *CertificateVerificationError
   341  	if !errors.As(err, &e) {
   342  		t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err)
   343  	}
   344  }
   345  
   346  // Test that QUICConn.ConnectionState can be used during the handshake,
   347  // and that it reports the application protocol as soon as it has been
   348  // negotiated.
   349  func TestQUICConnectionState(t *testing.T) {
   350  	config := testConfig.Clone()
   351  	config.MinVersion = VersionTLS13
   352  	config.NextProtos = []string{"h3"}
   353  	cli := newTestQUICClient(t, config)
   354  	cli.conn.SetTransportParameters(nil)
   355  	srv := newTestQUICServer(t, config)
   356  	srv.conn.SetTransportParameters(nil)
   357  	onEvent := func(e QUICEvent, src, dst *testQUICConn) bool {
   358  		cliCS := cli.conn.ConnectionState()
   359  		if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok {
   360  			if want, got := cliCS.NegotiatedProtocol, "h3"; want != got {
   361  				t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   362  			}
   363  		}
   364  		srvCS := srv.conn.ConnectionState()
   365  		if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok {
   366  			if want, got := srvCS.NegotiatedProtocol, "h3"; want != got {
   367  				t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got)
   368  			}
   369  		}
   370  		return false
   371  	}
   372  	if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil {
   373  		t.Fatalf("error during connection handshake: %v", err)
   374  	}
   375  }
   376  
   377  func TestQUICStartContextPropagation(t *testing.T) {
   378  	const key = "key"
   379  	const value = "value"
   380  	ctx := context.WithValue(context.Background(), key, value)
   381  	config := testConfig.Clone()
   382  	config.MinVersion = VersionTLS13
   383  	calls := 0
   384  	config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) {
   385  		calls++
   386  		got, _ := info.Context().Value(key).(string)
   387  		if got != value {
   388  			t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value)
   389  		}
   390  		return nil, nil
   391  	}
   392  	cli := newTestQUICClient(t, config)
   393  	cli.conn.SetTransportParameters(nil)
   394  	srv := newTestQUICServer(t, config)
   395  	srv.conn.SetTransportParameters(nil)
   396  	if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil {
   397  		t.Fatalf("error during connection handshake: %v", err)
   398  	}
   399  	if calls != 1 {
   400  		t.Errorf("GetConfigForClient called %v times, want 1", calls)
   401  	}
   402  }
   403  
   404  func TestQUICDelayedTransportParameters(t *testing.T) {
   405  	clientConfig := testConfig.Clone()
   406  	clientConfig.MinVersion = VersionTLS13
   407  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
   408  	clientConfig.ServerName = "example.go.dev"
   409  
   410  	serverConfig := testConfig.Clone()
   411  	serverConfig.MinVersion = VersionTLS13
   412  
   413  	cliParams := "client params"
   414  	srvParams := "server params"
   415  
   416  	cli := newTestQUICClient(t, clientConfig)
   417  	srv := newTestQUICServer(t, serverConfig)
   418  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   419  		t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err)
   420  	}
   421  	cli.conn.SetTransportParameters([]byte(cliParams))
   422  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired {
   423  		t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err)
   424  	}
   425  	srv.conn.SetTransportParameters([]byte(srvParams))
   426  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   427  		t.Fatalf("error during connection handshake: %v", err)
   428  	}
   429  
   430  	if got, want := string(cli.gotParams), srvParams; got != want {
   431  		t.Errorf("client got transport params: %q, want %q", got, want)
   432  	}
   433  	if got, want := string(srv.gotParams), cliParams; got != want {
   434  		t.Errorf("server got transport params: %q, want %q", got, want)
   435  	}
   436  }
   437  
   438  func TestQUICEmptyTransportParameters(t *testing.T) {
   439  	config := testConfig.Clone()
   440  	config.MinVersion = VersionTLS13
   441  
   442  	cli := newTestQUICClient(t, config)
   443  	cli.conn.SetTransportParameters(nil)
   444  	srv := newTestQUICServer(t, config)
   445  	srv.conn.SetTransportParameters(nil)
   446  	if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil {
   447  		t.Fatalf("error during connection handshake: %v", err)
   448  	}
   449  
   450  	if cli.gotParams == nil {
   451  		t.Errorf("client did not get transport params")
   452  	}
   453  	if srv.gotParams == nil {
   454  		t.Errorf("server did not get transport params")
   455  	}
   456  	if len(cli.gotParams) != 0 {
   457  		t.Errorf("client got transport params: %v, want empty", cli.gotParams)
   458  	}
   459  	if len(srv.gotParams) != 0 {
   460  		t.Errorf("server got transport params: %v, want empty", srv.gotParams)
   461  	}
   462  }
   463  
   464  func TestQUICCanceledWaitingForData(t *testing.T) {
   465  	config := testConfig.Clone()
   466  	config.MinVersion = VersionTLS13
   467  	cli := newTestQUICClient(t, config)
   468  	cli.conn.SetTransportParameters(nil)
   469  	cli.conn.Start(context.Background())
   470  	for cli.conn.NextEvent().Kind != QUICNoEvent {
   471  	}
   472  	err := cli.conn.Close()
   473  	if !errors.Is(err, alertCloseNotify) {
   474  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   475  	}
   476  }
   477  
   478  func TestQUICCanceledWaitingForTransportParams(t *testing.T) {
   479  	config := testConfig.Clone()
   480  	config.MinVersion = VersionTLS13
   481  	cli := newTestQUICClient(t, config)
   482  	cli.conn.Start(context.Background())
   483  	for cli.conn.NextEvent().Kind != QUICTransportParametersRequired {
   484  	}
   485  	err := cli.conn.Close()
   486  	if !errors.Is(err, alertCloseNotify) {
   487  		t.Errorf("conn.Close() = %v, want alertCloseNotify", err)
   488  	}
   489  }
   490  

View as plain text