Source file src/crypto/elliptic/elliptic_test.go

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package elliptic
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"math/big"
    12  	"testing"
    13  )
    14  
    15  // genericParamsForCurve returns the dereferenced CurveParams for
    16  // the specified curve. This is used to avoid the logic for
    17  // upgrading a curve to its specific implementation, forcing
    18  // usage of the generic implementation.
    19  func genericParamsForCurve(c Curve) *CurveParams {
    20  	d := *(c.Params())
    21  	return &d
    22  }
    23  
    24  func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
    25  	tests := []struct {
    26  		name  string
    27  		curve Curve
    28  	}{
    29  		{"P256", P256()},
    30  		{"P256/Params", genericParamsForCurve(P256())},
    31  		{"P224", P224()},
    32  		{"P224/Params", genericParamsForCurve(P224())},
    33  		{"P384", P384()},
    34  		{"P384/Params", genericParamsForCurve(P384())},
    35  		{"P521", P521()},
    36  		{"P521/Params", genericParamsForCurve(P521())},
    37  	}
    38  	if testing.Short() {
    39  		tests = tests[:1]
    40  	}
    41  	for _, test := range tests {
    42  		curve := test.curve
    43  		t.Run(test.name, func(t *testing.T) {
    44  			t.Parallel()
    45  			f(t, curve)
    46  		})
    47  	}
    48  }
    49  
    50  func TestOnCurve(t *testing.T) {
    51  	t.Parallel()
    52  	testAllCurves(t, func(t *testing.T, curve Curve) {
    53  		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
    54  			t.Error("basepoint is not on the curve")
    55  		}
    56  	})
    57  }
    58  
    59  func TestOffCurve(t *testing.T) {
    60  	t.Parallel()
    61  	testAllCurves(t, func(t *testing.T, curve Curve) {
    62  		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
    63  		if curve.IsOnCurve(x, y) {
    64  			t.Errorf("point off curve is claimed to be on the curve")
    65  		}
    66  
    67  		byteLen := (curve.Params().BitSize + 7) / 8
    68  		b := make([]byte, 1+2*byteLen)
    69  		b[0] = 4 // uncompressed point
    70  		x.FillBytes(b[1 : 1+byteLen])
    71  		y.FillBytes(b[1+byteLen : 1+2*byteLen])
    72  
    73  		x1, y1 := Unmarshal(curve, b)
    74  		if x1 != nil || y1 != nil {
    75  			t.Errorf("unmarshaling a point not on the curve succeeded")
    76  		}
    77  	})
    78  }
    79  
    80  func TestInfinity(t *testing.T) {
    81  	t.Parallel()
    82  	testAllCurves(t, testInfinity)
    83  }
    84  
    85  func isInfinity(x, y *big.Int) bool {
    86  	return x.Sign() == 0 && y.Sign() == 0
    87  }
    88  
    89  func testInfinity(t *testing.T, curve Curve) {
    90  	x0, y0 := new(big.Int), new(big.Int)
    91  	xG, yG := curve.Params().Gx, curve.Params().Gy
    92  
    93  	if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) {
    94  		t.Errorf("x^q != ∞")
    95  	}
    96  	if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) {
    97  		t.Errorf("x^0 != ∞")
    98  	}
    99  
   100  	if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
   101  		t.Errorf("∞^k != ∞")
   102  	}
   103  	if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) {
   104  		t.Errorf("∞^0 != ∞")
   105  	}
   106  
   107  	if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) {
   108  		t.Errorf("b^q != ∞")
   109  	}
   110  	if !isInfinity(curve.ScalarBaseMult([]byte{0})) {
   111  		t.Errorf("b^0 != ∞")
   112  	}
   113  
   114  	if !isInfinity(curve.Double(x0, y0)) {
   115  		t.Errorf("2∞ != ∞")
   116  	}
   117  	// There is no other point of order two on the NIST curves (as they have
   118  	// cofactor one), so Double can't otherwise return the point at infinity.
   119  
   120  	nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1))
   121  	x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes())
   122  	x, y = curve.Add(x, y, xG, yG)
   123  	if !isInfinity(x, y) {
   124  		t.Errorf("x^(q-1) + x != ∞")
   125  	}
   126  	x, y = curve.Add(xG, yG, x0, y0)
   127  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   128  		t.Errorf("x+∞ != x")
   129  	}
   130  	x, y = curve.Add(x0, y0, xG, yG)
   131  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   132  		t.Errorf("∞+x != x")
   133  	}
   134  
   135  	if curve.IsOnCurve(x0, y0) {
   136  		t.Errorf("IsOnCurve(∞) == true")
   137  	}
   138  
   139  	if xx, yy := Unmarshal(curve, Marshal(curve, x0, y0)); xx != nil || yy != nil {
   140  		t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
   141  	}
   142  	// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
   143  	// two valid points with x = 0.
   144  	if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
   145  		t.Errorf("Unmarshal(∞) did not return an error")
   146  	}
   147  	byteLen := (curve.Params().BitSize + 7) / 8
   148  	buf := make([]byte, byteLen*2+1)
   149  	buf[0] = 4 // Uncompressed format.
   150  	if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil {
   151  		t.Errorf("Unmarshal((0,0)) did not return an error")
   152  	}
   153  }
   154  
   155  func TestMarshal(t *testing.T) {
   156  	t.Parallel()
   157  	testAllCurves(t, func(t *testing.T, curve Curve) {
   158  		_, x, y, err := GenerateKey(curve, rand.Reader)
   159  		if err != nil {
   160  			t.Fatal(err)
   161  		}
   162  		serialized := Marshal(curve, x, y)
   163  		xx, yy := Unmarshal(curve, serialized)
   164  		if xx == nil {
   165  			t.Fatal("failed to unmarshal")
   166  		}
   167  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   168  			t.Fatal("unmarshal returned different values")
   169  		}
   170  	})
   171  }
   172  
   173  func TestUnmarshalToLargeCoordinates(t *testing.T) {
   174  	t.Parallel()
   175  	// See https://golang.org/issues/20482.
   176  	testAllCurves(t, testUnmarshalToLargeCoordinates)
   177  }
   178  
   179  func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
   180  	p := curve.Params().P
   181  	byteLen := (p.BitLen() + 7) / 8
   182  
   183  	// Set x to be greater than curve's parameter P – specifically, to P+5.
   184  	// Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
   185  	// curve.
   186  	x := new(big.Int).Add(p, big.NewInt(5))
   187  	y := curve.Params().polynomial(x)
   188  	y.ModSqrt(y, p)
   189  
   190  	invalid := make([]byte, byteLen*2+1)
   191  	invalid[0] = 4 // uncompressed encoding
   192  	x.FillBytes(invalid[1 : 1+byteLen])
   193  	y.FillBytes(invalid[1+byteLen:])
   194  
   195  	if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   196  		t.Errorf("Unmarshal accepts invalid X coordinate")
   197  	}
   198  
   199  	if curve == p256 {
   200  		// This is a point on the curve with a small y value, small enough that
   201  		// we can add p and still be within 32 bytes.
   202  		x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
   203  		y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
   204  		y.Add(y, p)
   205  
   206  		if p.Cmp(y) > 0 || y.BitLen() != 256 {
   207  			t.Fatal("y not within expected range")
   208  		}
   209  
   210  		// marshal
   211  		x.FillBytes(invalid[1 : 1+byteLen])
   212  		y.FillBytes(invalid[1+byteLen:])
   213  
   214  		if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   215  			t.Errorf("Unmarshal accepts invalid Y coordinate")
   216  		}
   217  	}
   218  }
   219  
   220  // TestInvalidCoordinates tests big.Int values that are not valid field elements
   221  // (negative or bigger than P). They are expected to return false from
   222  // IsOnCurve, all other behavior is undefined.
   223  func TestInvalidCoordinates(t *testing.T) {
   224  	t.Parallel()
   225  	testAllCurves(t, testInvalidCoordinates)
   226  }
   227  
   228  func testInvalidCoordinates(t *testing.T, curve Curve) {
   229  	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
   230  		if curve.IsOnCurve(x, y) {
   231  			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
   232  		}
   233  	}
   234  
   235  	p := curve.Params().P
   236  	_, x, y, _ := GenerateKey(curve, rand.Reader)
   237  	xx, yy := new(big.Int), new(big.Int)
   238  
   239  	// Check if the sign is getting dropped.
   240  	xx.Neg(x)
   241  	checkIsOnCurveFalse("-x, y", xx, y)
   242  	yy.Neg(y)
   243  	checkIsOnCurveFalse("x, -y", x, yy)
   244  
   245  	// Check if negative values are reduced modulo P.
   246  	xx.Sub(x, p)
   247  	checkIsOnCurveFalse("x-P, y", xx, y)
   248  	yy.Sub(y, p)
   249  	checkIsOnCurveFalse("x, y-P", x, yy)
   250  
   251  	// Check if positive values are reduced modulo P.
   252  	xx.Add(x, p)
   253  	checkIsOnCurveFalse("x+P, y", xx, y)
   254  	yy.Add(y, p)
   255  	checkIsOnCurveFalse("x, y+P", x, yy)
   256  
   257  	// Check if the overflow is dropped.
   258  	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
   259  	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
   260  	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
   261  	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
   262  
   263  	// Check if P is treated like zero (if possible).
   264  	// y^2 = x^3 - 3x + B
   265  	// y = mod_sqrt(x^3 - 3x + B)
   266  	// y = mod_sqrt(B) if x = 0
   267  	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
   268  	if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
   269  		if !curve.IsOnCurve(big.NewInt(0), yy) {
   270  			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
   271  		}
   272  		checkIsOnCurveFalse("P, y", p, yy)
   273  	}
   274  }
   275  
   276  func TestMarshalCompressed(t *testing.T) {
   277  	t.Parallel()
   278  	t.Run("P-256/03", func(t *testing.T) {
   279  		data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   280  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   281  		y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
   282  		testMarshalCompressed(t, P256(), x, y, data)
   283  	})
   284  	t.Run("P-256/02", func(t *testing.T) {
   285  		data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   286  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   287  		y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
   288  		testMarshalCompressed(t, P256(), x, y, data)
   289  	})
   290  
   291  	t.Run("Invalid", func(t *testing.T) {
   292  		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
   293  		X, Y := UnmarshalCompressed(P256(), data)
   294  		if X != nil || Y != nil {
   295  			t.Error("expected an error for invalid encoding")
   296  		}
   297  	})
   298  
   299  	if testing.Short() {
   300  		t.Skip("skipping other curves on short test")
   301  	}
   302  
   303  	testAllCurves(t, func(t *testing.T, curve Curve) {
   304  		_, x, y, err := GenerateKey(curve, rand.Reader)
   305  		if err != nil {
   306  			t.Fatal(err)
   307  		}
   308  		testMarshalCompressed(t, curve, x, y, nil)
   309  	})
   310  
   311  }
   312  
   313  func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
   314  	if !curve.IsOnCurve(x, y) {
   315  		t.Fatal("invalid test point")
   316  	}
   317  	got := MarshalCompressed(curve, x, y)
   318  	if want != nil && !bytes.Equal(got, want) {
   319  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
   320  	}
   321  
   322  	X, Y := UnmarshalCompressed(curve, got)
   323  	if X == nil || Y == nil {
   324  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
   325  	}
   326  
   327  	if !curve.IsOnCurve(X, Y) {
   328  		t.Error("UnmarshalCompressed returned a point not on the curve")
   329  	}
   330  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
   331  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
   332  	}
   333  }
   334  
   335  func TestLargeIsOnCurve(t *testing.T) {
   336  	t.Parallel()
   337  	testAllCurves(t, func(t *testing.T, curve Curve) {
   338  		large := big.NewInt(1)
   339  		large.Lsh(large, 1000)
   340  		if curve.IsOnCurve(large, large) {
   341  			t.Errorf("(2^1000, 2^1000) is reported on the curve")
   342  		}
   343  	})
   344  }
   345  
   346  func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
   347  	tests := []struct {
   348  		name  string
   349  		curve Curve
   350  	}{
   351  		{"P256", P256()},
   352  		{"P224", P224()},
   353  		{"P384", P384()},
   354  		{"P521", P521()},
   355  	}
   356  	for _, test := range tests {
   357  		curve := test.curve
   358  		b.Run(test.name, func(b *testing.B) {
   359  			f(b, curve)
   360  		})
   361  	}
   362  }
   363  
   364  func BenchmarkScalarBaseMult(b *testing.B) {
   365  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   366  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   367  		b.ReportAllocs()
   368  		b.ResetTimer()
   369  		for i := 0; i < b.N; i++ {
   370  			x, _ := curve.ScalarBaseMult(priv)
   371  			// Prevent the compiler from optimizing out the operation.
   372  			priv[0] ^= byte(x.Bits()[0])
   373  		}
   374  	})
   375  }
   376  
   377  func BenchmarkScalarMult(b *testing.B) {
   378  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   379  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   380  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   381  		b.ReportAllocs()
   382  		b.ResetTimer()
   383  		for i := 0; i < b.N; i++ {
   384  			x, y = curve.ScalarMult(x, y, priv)
   385  		}
   386  	})
   387  }
   388  
   389  func BenchmarkMarshalUnmarshal(b *testing.B) {
   390  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   391  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   392  		b.Run("Uncompressed", func(b *testing.B) {
   393  			b.ReportAllocs()
   394  			for i := 0; i < b.N; i++ {
   395  				buf := Marshal(curve, x, y)
   396  				xx, yy := Unmarshal(curve, buf)
   397  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   398  					b.Error("Unmarshal output different from Marshal input")
   399  				}
   400  			}
   401  		})
   402  		b.Run("Compressed", func(b *testing.B) {
   403  			b.ReportAllocs()
   404  			for i := 0; i < b.N; i++ {
   405  				buf := MarshalCompressed(curve, x, y)
   406  				xx, yy := UnmarshalCompressed(curve, buf)
   407  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   408  					b.Error("Unmarshal output different from Marshal input")
   409  				}
   410  			}
   411  		})
   412  	})
   413  }
   414  

View as plain text