Source file src/errors/wrap_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package errors_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io/fs"
    11  	"os"
    12  	"reflect"
    13  	"testing"
    14  )
    15  
    16  func TestIs(t *testing.T) {
    17  	err1 := errors.New("1")
    18  	erra := wrapped{"wrap 2", err1}
    19  	errb := wrapped{"wrap 3", erra}
    20  
    21  	err3 := errors.New("3")
    22  
    23  	poser := &poser{"either 1 or 3", func(err error) bool {
    24  		return err == err1 || err == err3
    25  	}}
    26  
    27  	testCases := []struct {
    28  		err    error
    29  		target error
    30  		match  bool
    31  	}{
    32  		{nil, nil, true},
    33  		{err1, nil, false},
    34  		{err1, err1, true},
    35  		{erra, err1, true},
    36  		{errb, err1, true},
    37  		{err1, err3, false},
    38  		{erra, err3, false},
    39  		{errb, err3, false},
    40  		{poser, err1, true},
    41  		{poser, err3, true},
    42  		{poser, erra, false},
    43  		{poser, errb, false},
    44  		{errorUncomparable{}, errorUncomparable{}, true},
    45  		{errorUncomparable{}, &errorUncomparable{}, false},
    46  		{&errorUncomparable{}, errorUncomparable{}, true},
    47  		{&errorUncomparable{}, &errorUncomparable{}, false},
    48  		{errorUncomparable{}, err1, false},
    49  		{&errorUncomparable{}, err1, false},
    50  		{multiErr{}, err1, false},
    51  		{multiErr{err1, err3}, err1, true},
    52  		{multiErr{err3, err1}, err1, true},
    53  		{multiErr{err1, err3}, errors.New("x"), false},
    54  		{multiErr{err3, errb}, errb, true},
    55  		{multiErr{err3, errb}, erra, true},
    56  		{multiErr{err3, errb}, err1, true},
    57  		{multiErr{errb, err3}, err1, true},
    58  		{multiErr{poser}, err1, true},
    59  		{multiErr{poser}, err3, true},
    60  		{multiErr{nil}, nil, false},
    61  	}
    62  	for _, tc := range testCases {
    63  		t.Run("", func(t *testing.T) {
    64  			if got := errors.Is(tc.err, tc.target); got != tc.match {
    65  				t.Errorf("Is(%v, %v) = %v, want %v", tc.err, tc.target, got, tc.match)
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  type poser struct {
    72  	msg string
    73  	f   func(error) bool
    74  }
    75  
    76  var poserPathErr = &fs.PathError{Op: "poser"}
    77  
    78  func (p *poser) Error() string     { return p.msg }
    79  func (p *poser) Is(err error) bool { return p.f(err) }
    80  func (p *poser) As(err any) bool {
    81  	switch x := err.(type) {
    82  	case **poser:
    83  		*x = p
    84  	case *errorT:
    85  		*x = errorT{"poser"}
    86  	case **fs.PathError:
    87  		*x = poserPathErr
    88  	default:
    89  		return false
    90  	}
    91  	return true
    92  }
    93  
    94  func TestAs(t *testing.T) {
    95  	var errT errorT
    96  	var errP *fs.PathError
    97  	var timeout interface{ Timeout() bool }
    98  	var p *poser
    99  	_, errF := os.Open("non-existing")
   100  	poserErr := &poser{"oh no", nil}
   101  
   102  	testCases := []struct {
   103  		err    error
   104  		target any
   105  		match  bool
   106  		want   any // value of target on match
   107  	}{{
   108  		nil,
   109  		&errP,
   110  		false,
   111  		nil,
   112  	}, {
   113  		wrapped{"pitied the fool", errorT{"T"}},
   114  		&errT,
   115  		true,
   116  		errorT{"T"},
   117  	}, {
   118  		errF,
   119  		&errP,
   120  		true,
   121  		errF,
   122  	}, {
   123  		errorT{},
   124  		&errP,
   125  		false,
   126  		nil,
   127  	}, {
   128  		wrapped{"wrapped", nil},
   129  		&errT,
   130  		false,
   131  		nil,
   132  	}, {
   133  		&poser{"error", nil},
   134  		&errT,
   135  		true,
   136  		errorT{"poser"},
   137  	}, {
   138  		&poser{"path", nil},
   139  		&errP,
   140  		true,
   141  		poserPathErr,
   142  	}, {
   143  		poserErr,
   144  		&p,
   145  		true,
   146  		poserErr,
   147  	}, {
   148  		errors.New("err"),
   149  		&timeout,
   150  		false,
   151  		nil,
   152  	}, {
   153  		errF,
   154  		&timeout,
   155  		true,
   156  		errF,
   157  	}, {
   158  		wrapped{"path error", errF},
   159  		&timeout,
   160  		true,
   161  		errF,
   162  	}, {
   163  		multiErr{},
   164  		&errT,
   165  		false,
   166  		nil,
   167  	}, {
   168  		multiErr{errors.New("a"), errorT{"T"}},
   169  		&errT,
   170  		true,
   171  		errorT{"T"},
   172  	}, {
   173  		multiErr{errorT{"T"}, errors.New("a")},
   174  		&errT,
   175  		true,
   176  		errorT{"T"},
   177  	}, {
   178  		multiErr{errorT{"a"}, errorT{"b"}},
   179  		&errT,
   180  		true,
   181  		errorT{"a"},
   182  	}, {
   183  		multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
   184  		&errT,
   185  		true,
   186  		errorT{"a"},
   187  	}, {
   188  		multiErr{wrapped{"path error", errF}},
   189  		&timeout,
   190  		true,
   191  		errF,
   192  	}, {
   193  		multiErr{nil},
   194  		&errT,
   195  		false,
   196  		nil,
   197  	}}
   198  	for i, tc := range testCases {
   199  		name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target)
   200  		// Clear the target pointer, in case it was set in a previous test.
   201  		rtarget := reflect.ValueOf(tc.target)
   202  		rtarget.Elem().Set(reflect.Zero(reflect.TypeOf(tc.target).Elem()))
   203  		t.Run(name, func(t *testing.T) {
   204  			match := errors.As(tc.err, tc.target)
   205  			if match != tc.match {
   206  				t.Fatalf("match: got %v; want %v", match, tc.match)
   207  			}
   208  			if !match {
   209  				return
   210  			}
   211  			if got := rtarget.Elem().Interface(); got != tc.want {
   212  				t.Fatalf("got %#v, want %#v", got, tc.want)
   213  			}
   214  		})
   215  	}
   216  }
   217  
   218  func TestAsValidation(t *testing.T) {
   219  	var s string
   220  	testCases := []any{
   221  		nil,
   222  		(*int)(nil),
   223  		"error",
   224  		&s,
   225  	}
   226  	err := errors.New("error")
   227  	for _, tc := range testCases {
   228  		t.Run(fmt.Sprintf("%T(%v)", tc, tc), func(t *testing.T) {
   229  			defer func() {
   230  				recover()
   231  			}()
   232  			if errors.As(err, tc) {
   233  				t.Errorf("As(err, %T(%v)) = true, want false", tc, tc)
   234  				return
   235  			}
   236  			t.Errorf("As(err, %T(%v)) did not panic", tc, tc)
   237  		})
   238  	}
   239  }
   240  
   241  func BenchmarkIs(b *testing.B) {
   242  	err1 := errors.New("1")
   243  	err2 := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}
   244  
   245  	for i := 0; i < b.N; i++ {
   246  		if !errors.Is(err2, err1) {
   247  			b.Fatal("Is failed")
   248  		}
   249  	}
   250  }
   251  
   252  func BenchmarkAs(b *testing.B) {
   253  	err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
   254  	for i := 0; i < b.N; i++ {
   255  		var target errorT
   256  		if !errors.As(err, &target) {
   257  			b.Fatal("As failed")
   258  		}
   259  	}
   260  }
   261  
   262  func TestUnwrap(t *testing.T) {
   263  	err1 := errors.New("1")
   264  	erra := wrapped{"wrap 2", err1}
   265  
   266  	testCases := []struct {
   267  		err  error
   268  		want error
   269  	}{
   270  		{nil, nil},
   271  		{wrapped{"wrapped", nil}, nil},
   272  		{err1, nil},
   273  		{erra, err1},
   274  		{wrapped{"wrap 3", erra}, erra},
   275  	}
   276  	for _, tc := range testCases {
   277  		if got := errors.Unwrap(tc.err); got != tc.want {
   278  			t.Errorf("Unwrap(%v) = %v, want %v", tc.err, got, tc.want)
   279  		}
   280  	}
   281  }
   282  
   283  type errorT struct{ s string }
   284  
   285  func (e errorT) Error() string { return fmt.Sprintf("errorT(%s)", e.s) }
   286  
   287  type wrapped struct {
   288  	msg string
   289  	err error
   290  }
   291  
   292  func (e wrapped) Error() string { return e.msg }
   293  func (e wrapped) Unwrap() error { return e.err }
   294  
   295  type multiErr []error
   296  
   297  func (m multiErr) Error() string   { return "multiError" }
   298  func (m multiErr) Unwrap() []error { return []error(m) }
   299  
   300  type errorUncomparable struct {
   301  	f []string
   302  }
   303  
   304  func (errorUncomparable) Error() string {
   305  	return "uncomparable error"
   306  }
   307  
   308  func (errorUncomparable) Is(target error) bool {
   309  	_, ok := target.(errorUncomparable)
   310  	return ok
   311  }
   312  

View as plain text