// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package tls import ( "context" "errors" "reflect" "testing" ) type testQUICConn struct { t *testing.T conn *QUICConn readSecret map[QUICEncryptionLevel]suiteSecret writeSecret map[QUICEncryptionLevel]suiteSecret gotParams []byte complete bool } func newTestQUICClient(t *testing.T, config *Config) *testQUICConn { q := &testQUICConn{t: t} q.conn = QUICClient(&QUICConfig{ TLSConfig: config, }) t.Cleanup(func() { q.conn.Close() }) return q } func newTestQUICServer(t *testing.T, config *Config) *testQUICConn { q := &testQUICConn{t: t} q.conn = QUICServer(&QUICConfig{ TLSConfig: config, }) t.Cleanup(func() { q.conn.Close() }) return q } type suiteSecret struct { suite uint16 secret []byte } func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { if _, ok := q.writeSecret[level]; !ok { q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level) } if level == QUICEncryptionLevelApplication && !q.complete { q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level) } if _, ok := q.readSecret[level]; ok { q.t.Errorf("SetReadSecret for level %v called twice", level) } if q.readSecret == nil { q.readSecret = map[QUICEncryptionLevel]suiteSecret{} } switch level { case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: q.readSecret[level] = suiteSecret{suite, secret} default: q.t.Errorf("SetReadSecret for unexpected level %v", level) } } func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { if _, ok := q.writeSecret[level]; ok { q.t.Errorf("SetWriteSecret for level %v called twice", level) } if q.writeSecret == nil { q.writeSecret = map[QUICEncryptionLevel]suiteSecret{} } switch level { case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: q.writeSecret[level] = suiteSecret{suite, secret} default: q.t.Errorf("SetWriteSecret for unexpected level %v", level) } } var errTransportParametersRequired = errors.New("transport parameters required") func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent func(e QUICEvent, src, dst *testQUICConn) bool) error { a, b := cli, srv for _, c := range []*testQUICConn{a, b} { if !c.conn.conn.quic.started { if err := c.conn.Start(ctx); err != nil { return err } } } idleCount := 0 for { e := a.conn.NextEvent() if onEvent != nil && onEvent(e, a, b) { continue } switch e.Kind { case QUICNoEvent: idleCount++ if idleCount == 2 { if !a.complete || !b.complete { return errors.New("handshake incomplete") } return nil } a, b = b, a case QUICSetReadSecret: a.setReadSecret(e.Level, e.Suite, e.Data) case QUICSetWriteSecret: a.setWriteSecret(e.Level, e.Suite, e.Data) case QUICWriteData: if err := b.conn.HandleData(e.Level, e.Data); err != nil { return err } case QUICTransportParameters: a.gotParams = e.Data if a.gotParams == nil { a.gotParams = []byte{} } case QUICTransportParametersRequired: return errTransportParametersRequired case QUICHandshakeDone: a.complete = true if a == srv { opts := QUICSessionTicketOptions{} if err := srv.conn.SendSessionTicket(opts); err != nil { return err } } } if e.Kind != QUICNoEvent { idleCount = 0 } } } func TestQUICConnection(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok { t.Errorf("client has no Handshake secret") } if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok { t.Errorf("client has no Application secret") } if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok { t.Errorf("server has no Handshake secret") } if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok { t.Errorf("server has no Application secret") } for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} { if _, ok := cli.readSecret[level]; !ok { t.Errorf("client has no %v read secret", level) } if _, ok := srv.readSecret[level]; !ok { t.Errorf("server has no %v read secret", level) } if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) { t.Errorf("client read secret does not match server write secret for level %v", level) } if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) { t.Errorf("client write secret does not match server read secret for level %v", level) } } } func TestQUICSessionResumption(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.MinVersion = VersionTLS13 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) clientConfig.ServerName = "example.go.dev" serverConfig := testConfig.Clone() serverConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, serverConfig) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during first connection handshake: %v", err) } if cli.conn.ConnectionState().DidResume { t.Errorf("first connection unexpectedly used session resumption") } cli2 := newTestQUICClient(t, clientConfig) cli2.conn.SetTransportParameters(nil) srv2 := newTestQUICServer(t, serverConfig) srv2.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil { t.Fatalf("error during second connection handshake: %v", err) } if !cli2.conn.ConnectionState().DidResume { t.Errorf("second connection did not use session resumption") } } func TestQUICFragmentaryData(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.MinVersion = VersionTLS13 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) clientConfig.ServerName = "example.go.dev" serverConfig := testConfig.Clone() serverConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, serverConfig) srv.conn.SetTransportParameters(nil) onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { if e.Kind == QUICWriteData { // Provide the data one byte at a time. for i := range e.Data { if err := dst.conn.HandleData(e.Level, e.Data[i:i+1]); err != nil { t.Errorf("HandleData: %v", err) break } } return true } return false } if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { t.Fatalf("error during first connection handshake: %v", err) } } func TestQUICPostHandshakeClientAuthentication(t *testing.T) { // RFC 9001, Section 4.4. config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } certReq := new(certificateRequestMsgTLS13) certReq.ocspStapling = true certReq.scts = true certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() certReqBytes, err := certReq.marshal() if err != nil { t.Fatal(err) } if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ byte(typeCertificateRequest), byte(0), byte(0), byte(len(certReqBytes)), }, certReqBytes...)); err == nil { t.Fatalf("post-handshake authentication request: got no error, want one") } } func TestQUICPostHandshakeKeyUpdate(t *testing.T) { // RFC 9001, Section 6. config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } keyUpdate := new(keyUpdateMsg) keyUpdateBytes, err := keyUpdate.marshal() if err != nil { t.Fatal(err) } if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ byte(typeKeyUpdate), byte(0), byte(0), byte(len(keyUpdateBytes)), }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) { t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err) } } func TestQUICPostHandshakeMessageTooLarge(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } size := maxHandshake + 1 if err := cli.conn.HandleData(QUICEncryptionLevelApplication, []byte{ byte(typeNewSessionTicket), byte(size >> 16), byte(size >> 8), byte(size), }); err == nil { t.Fatalf("%v-byte post-handshake message: got no error, want one", size) } } func TestQUICHandshakeError(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.MinVersion = VersionTLS13 clientConfig.InsecureSkipVerify = false clientConfig.ServerName = "name" serverConfig := testConfig.Clone() serverConfig.MinVersion = VersionTLS13 cli := newTestQUICClient(t, clientConfig) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, serverConfig) srv.conn.SetTransportParameters(nil) err := runTestQUICConnection(context.Background(), cli, srv, nil) if !errors.Is(err, AlertError(alertBadCertificate)) { t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err) } var e *CertificateVerificationError if !errors.As(err, &e) { t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err) } } // Test that QUICConn.ConnectionState can be used during the handshake, // and that it reports the application protocol as soon as it has been // negotiated. func TestQUICConnectionState(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 config.NextProtos = []string{"h3"} cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { cliCS := cli.conn.ConnectionState() if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok { if want, got := cliCS.NegotiatedProtocol, "h3"; want != got { t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) } } srvCS := srv.conn.ConnectionState() if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok { if want, got := srvCS.NegotiatedProtocol, "h3"; want != got { t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) } } return false } if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { t.Fatalf("error during connection handshake: %v", err) } } func TestQUICStartContextPropagation(t *testing.T) { const key = "key" const value = "value" ctx := context.WithValue(context.Background(), key, value) config := testConfig.Clone() config.MinVersion = VersionTLS13 calls := 0 config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) { calls++ got, _ := info.Context().Value(key).(string) if got != value { t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value) } return nil, nil } cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } if calls != 1 { t.Errorf("GetConfigForClient called %v times, want 1", calls) } } func TestQUICDelayedTransportParameters(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.MinVersion = VersionTLS13 clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) clientConfig.ServerName = "example.go.dev" serverConfig := testConfig.Clone() serverConfig.MinVersion = VersionTLS13 cliParams := "client params" srvParams := "server params" cli := newTestQUICClient(t, clientConfig) srv := newTestQUICServer(t, serverConfig) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err) } cli.conn.SetTransportParameters([]byte(cliParams)) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err) } srv.conn.SetTransportParameters([]byte(srvParams)) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } if got, want := string(cli.gotParams), srvParams; got != want { t.Errorf("client got transport params: %q, want %q", got, want) } if got, want := string(srv.gotParams), cliParams; got != want { t.Errorf("server got transport params: %q, want %q", got, want) } } func TestQUICEmptyTransportParameters(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) srv := newTestQUICServer(t, config) srv.conn.SetTransportParameters(nil) if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { t.Fatalf("error during connection handshake: %v", err) } if cli.gotParams == nil { t.Errorf("client did not get transport params") } if srv.gotParams == nil { t.Errorf("server did not get transport params") } if len(cli.gotParams) != 0 { t.Errorf("client got transport params: %v, want empty", cli.gotParams) } if len(srv.gotParams) != 0 { t.Errorf("server got transport params: %v, want empty", srv.gotParams) } } func TestQUICCanceledWaitingForData(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.SetTransportParameters(nil) cli.conn.Start(context.Background()) for cli.conn.NextEvent().Kind != QUICNoEvent { } err := cli.conn.Close() if !errors.Is(err, alertCloseNotify) { t.Errorf("conn.Close() = %v, want alertCloseNotify", err) } } func TestQUICCanceledWaitingForTransportParams(t *testing.T) { config := testConfig.Clone() config.MinVersion = VersionTLS13 cli := newTestQUICClient(t, config) cli.conn.Start(context.Background()) for cli.conn.NextEvent().Kind != QUICTransportParametersRequired { } err := cli.conn.Close() if !errors.Is(err, alertCloseNotify) { t.Errorf("conn.Close() = %v, want alertCloseNotify", err) } }