...
Run Format

Source file src/math/big/nat.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// This file implements unsigned multi-precision integers (natural
     6	// numbers). They are the building blocks for the implementation
     7	// of signed integers, rationals, and floating-point numbers.
     8	
     9	package big
    10	
    11	import "math/rand"
    12	
// nat is an unsigned multi-precision integer (a natural number).
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements. The least significant
// digit is x[0].
//
// A number is normalized if the slice contains no leading 0 digits
// (i.e. x[n-1] != 0). During arithmetic operations, denormalized values
// may occur but are always normalized before returning the final result.
// The normalized representation of 0 is the empty or nil slice (length = 0).
type nat []Word
    26	
    27	var (
    28		natOne = nat{1}
    29		natTwo = nat{2}
    30		natTen = nat{10}
    31	)
    32	
    33	func (z nat) clear() {
    34		for i := range z {
    35			z[i] = 0
    36		}
    37	}
    38	
    39	func (z nat) norm() nat {
    40		i := len(z)
    41		for i > 0 && z[i-1] == 0 {
    42			i--
    43		}
    44		return z[0:i]
    45	}
    46	
    47	func (z nat) make(n int) nat {
    48		if n <= cap(z) {
    49			return z[:n] // reuse z
    50		}
    51		// Choosing a good value for e has significant performance impact
    52		// because it increases the chance that a value can be reused.
    53		const e = 4 // extra capacity
    54		return make(nat, n, n+e)
    55	}
    56	
    57	func (z nat) setWord(x Word) nat {
    58		if x == 0 {
    59			return z[:0]
    60		}
    61		z = z.make(1)
    62		z[0] = x
    63		return z
    64	}
    65	
    66	func (z nat) setUint64(x uint64) nat {
    67		// single-digit values
    68		if w := Word(x); uint64(w) == x {
    69			return z.setWord(w)
    70		}
    71	
    72		// compute number of words n required to represent x
    73		n := 0
    74		for t := x; t > 0; t >>= _W {
    75			n++
    76		}
    77	
    78		// split x into n words
    79		z = z.make(n)
    80		for i := range z {
    81			z[i] = Word(x & _M)
    82			x >>= _W
    83		}
    84	
    85		return z
    86	}
    87	
    88	func (z nat) set(x nat) nat {
    89		z = z.make(len(x))
    90		copy(z, x)
    91		return z
    92	}
    93	
    94	func (z nat) add(x, y nat) nat {
    95		m := len(x)
    96		n := len(y)
    97	
    98		switch {
    99		case m < n:
   100			return z.add(y, x)
   101		case m == 0:
   102			// n == 0 because m >= n; result is 0
   103			return z[:0]
   104		case n == 0:
   105			// result is x
   106			return z.set(x)
   107		}
   108		// m > 0
   109	
   110		z = z.make(m + 1)
   111		c := addVV(z[0:n], x, y)
   112		if m > n {
   113			c = addVW(z[n:m], x[n:], c)
   114		}
   115		z[m] = c
   116	
   117		return z.norm()
   118	}
   119	
   120	func (z nat) sub(x, y nat) nat {
   121		m := len(x)
   122		n := len(y)
   123	
   124		switch {
   125		case m < n:
   126			panic("underflow")
   127		case m == 0:
   128			// n == 0 because m >= n; result is 0
   129			return z[:0]
   130		case n == 0:
   131			// result is x
   132			return z.set(x)
   133		}
   134		// m > 0
   135	
   136		z = z.make(m)
   137		c := subVV(z[0:n], x, y)
   138		if m > n {
   139			c = subVW(z[n:], x[n:], c)
   140		}
   141		if c != 0 {
   142			panic("underflow")
   143		}
   144	
   145		return z.norm()
   146	}
   147	
   148	func (x nat) cmp(y nat) (r int) {
   149		m := len(x)
   150		n := len(y)
   151		if m != n || m == 0 {
   152			switch {
   153			case m < n:
   154				r = -1
   155			case m > n:
   156				r = 1
   157			}
   158			return
   159		}
   160	
   161		i := m - 1
   162		for i > 0 && x[i] == y[i] {
   163			i--
   164		}
   165	
   166		switch {
   167		case x[i] < y[i]:
   168			r = -1
   169		case x[i] > y[i]:
   170			r = 1
   171		}
   172		return
   173	}
   174	
   175	func (z nat) mulAddWW(x nat, y, r Word) nat {
   176		m := len(x)
   177		if m == 0 || y == 0 {
   178			return z.setWord(r) // result is r
   179		}
   180		// m > 0
   181	
   182		z = z.make(m + 1)
   183		z[m] = mulAddVWW(z[0:m], x, y, r)
   184	
   185		return z.norm()
   186	}
   187	
   188	// basicMul multiplies x and y and leaves the result in z.
   189	// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   190	func basicMul(z, x, y nat) {
   191		z[0 : len(x)+len(y)].clear() // initialize z
   192		for i, d := range y {
   193			if d != 0 {
   194				z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   195			}
   196		}
   197	}
   198	
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
//
// x, y and m must each be exactly n words long; the function panics
// otherwise.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	// c is the carry out of the top word left over from the previous
	// iteration.
	var c Word
	for i := 0; i < n; i++ {
		d := y[i]
		// z += x*d; c2 is the word carried out of z.
		c2 := addMulVVW(z, x, d)
		// With k = -1/m mod 2**_W, adding m*t makes z's lowest word zero.
		t := z[0] * k
		c3 := addMulVVW(z, m, t)
		// Divide by 2**_W: drop the (now zero) lowest word ...
		copy(z, z[1:])
		// ... and put the sum of all carries into the top word.
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		// Record whether summing the carries overflowed a word.
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// A pending carry means the true value is 2**(n*_W) + z; subtract m
	// once. subVV's borrow out is discarded here.
	if c != 0 {
		subVV(z, z, m)
	}
	return z
}
   239	
   240	// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   241	// Factored out for readability - do not use outside karatsuba.
   242	func karatsubaAdd(z, x nat, n int) {
   243		if c := addVV(z[0:n], z, x); c != 0 {
   244			addVW(z[n:n+n>>1], z[n:], c)
   245		}
   246	}
   247	
   248	// Like karatsubaAdd, but does subtract.
   249	func karatsubaSub(z, x nat, n int) {
   250		if c := subVV(z[0:n], z, x); c != 0 {
   251			subVW(z[n:n+n>>1], z[n:], c)
   252		}
   253	}
   254	
   255	// Operands that are shorter than karatsubaThreshold are multiplied using
   256	// "grade school" multiplication; for longer operands the Karatsuba algorithm
   257	// is used.
   258	var karatsubaThreshold int = 40 // computed by calibrate.go
   259	
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
//
// NOTE(review): as written, the code itself only needs len(x) == len(y).
// Odd lengths and lengths below karatsubaThreshold fall back to basicMul.
// mul calls this with a length from karatsubaLen, which need not be a
// power of 2. Confirm the intended precondition before tightening it.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	// (for s < 0 either operand may be the flipped one; either way p is
	// the magnitude of the negative product)
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   358	
   359	// alias reports whether x and y share the same base array.
   360	func alias(x, y nat) bool {
   361		return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   362	}
   363	
   364	// addAt implements z += x<<(_W*i); z must be long enough.
   365	// (we don't use nat.add because we need z to stay the same
   366	// slice, and we don't need to normalize z after each addition)
   367	func addAt(z, x nat, i int) {
   368		if n := len(x); n > 0 {
   369			if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   370				j := i + n
   371				if j < len(z) {
   372					addVW(z[j:], z[j:], c)
   373				}
   374			}
   375		}
   376	}
   377	
   378	func max(x, y int) int {
   379		if x > y {
   380			return x
   381		}
   382		return y
   383	}
   384	
   385	// karatsubaLen computes an approximation to the maximum k <= n such that
   386	// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   387	// result is the largest number that can be divided repeatedly by 2 before
   388	// becoming about the value of karatsubaThreshold.
   389	func karatsubaLen(n int) int {
   390		i := uint(0)
   391		for n > karatsubaThreshold {
   392			n >>= 1
   393			i++
   394		}
   395		return n << i
   396	}
   397	
// mul sets z = x*y and returns the normalized result. z's storage is
// reused only when it does not alias x or y. Short operands use basicMul.
// Longer ones multiply their low k words with karatsuba and add in the
// remaining partial products.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is a scratch product buffer reused across the mul calls below.
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		// (in this loop i is a word offset that advances one k-word digit
		// at a time, so addAt's shifts are in words)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   481	
   482	// mulRange computes the product of all the unsigned integers in the
   483	// range [a, b] inclusively. If a > b (empty range), the result is 1.
   484	func (z nat) mulRange(a, b uint64) nat {
   485		switch {
   486		case a == 0:
   487			// cut long ranges short (optimization)
   488			return z.setUint64(0)
   489		case a > b:
   490			return z.setUint64(1)
   491		case a == b:
   492			return z.setUint64(a)
   493		case a+1 == b:
   494			return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   495		}
   496		m := (a + b) / 2
   497		return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   498	}
   499	
   500	// q = (x-r)/y, with 0 <= r < y
   501	func (z nat) divW(x nat, y Word) (q nat, r Word) {
   502		m := len(x)
   503		switch {
   504		case y == 0:
   505			panic("division by zero")
   506		case y == 1:
   507			q = z.set(x) // result is x
   508			return
   509		case m == 0:
   510			q = z[:0] // result is 0
   511			return
   512		}
   513		// m > 0
   514		z = z.make(m)
   515		r = divWVW(z, 0, x, y)
   516		q = z.norm()
   517		return
   518	}
   519	
   520	func (z nat) div(z2, u, v nat) (q, r nat) {
   521		if len(v) == 0 {
   522			panic("division by zero")
   523		}
   524	
   525		if u.cmp(v) < 0 {
   526			q = z[:0]
   527			r = z2.set(u)
   528			return
   529		}
   530	
   531		if len(v) == 1 {
   532			var r2 Word
   533			q, r2 = z.divW(u, v[0])
   534			r = z2.setWord(r2)
   535			return
   536		}
   537	
   538		q, r = z.divLarge(z2, u, v)
   539		return
   540	}
   541	
// q = (uIn-r)/v, with 0 <= r < v
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv is scratch space for the product qhat*v (n+1 words).
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Normalize: shift v left by nlz(v[n-1]) bits, and shift uIn by the
	// same amount into u. The extra top word of u receives the bits
	// shifted out.
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	// Compute one quotient word per iteration, most significant first.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate qhat from the top two words of the current window of u
		// and the top word of v. When u[j+n] equals v[n-1], the estimate
		// is the largest word value.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Subtract qhat*v from the current window u[j : j+n+1].
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so qhat was one too large:
			// add v back and decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	// Undo the D1 normalization so u holds the true remainder.
	q = q.norm()
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
   619	
   620	// Length of x in bits. x must be normalized.
   621	func (x nat) bitLen() int {
   622		if i := len(x) - 1; i >= 0 {
   623			return i*_W + bitLen(x[i])
   624		}
   625		return 0
   626	}
   627	
   628	const deBruijn32 = 0x077CB531
   629	
   630	var deBruijn32Lookup = []byte{
   631		0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
   632		31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
   633	}
   634	
   635	const deBruijn64 = 0x03f79d71b4ca8b09
   636	
   637	var deBruijn64Lookup = []byte{
   638		0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
   639		62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
   640		63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
   641		54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
   642	}
   643	
   644	// trailingZeroBits returns the number of consecutive least significant zero
   645	// bits of x.
   646	func trailingZeroBits(x Word) uint {
   647		// x & -x leaves only the right-most bit set in the word. Let k be the
   648		// index of that bit. Since only a single bit is set, the value is two
   649		// to the power of k. Multiplying by a power of two is equivalent to
   650		// left shifting, in this case by k bits.  The de Bruijn constant is
   651		// such that all six bit, consecutive substrings are distinct.
   652		// Therefore, if we have a left shifted version of this constant we can
   653		// find by how many bits it was shifted by looking at which six bit
   654		// substring ended up at the top of the word.
   655		// (Knuth, volume 4, section 7.3.1)
   656		switch _W {
   657		case 32:
   658			return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   659		case 64:
   660			return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   661		default:
   662			panic("unknown word size")
   663		}
   664	}
   665	
   666	// trailingZeroBits returns the number of consecutive least significant zero
   667	// bits of x.
   668	func (x nat) trailingZeroBits() uint {
   669		if len(x) == 0 {
   670			return 0
   671		}
   672		var i uint
   673		for x[i] == 0 {
   674			i++
   675		}
   676		// x[i] != 0
   677		return i*_W + trailingZeroBits(x[i])
   678	}
   679	
   680	// z = x << s
   681	func (z nat) shl(x nat, s uint) nat {
   682		m := len(x)
   683		if m == 0 {
   684			return z[:0]
   685		}
   686		// m > 0
   687	
   688		n := m + int(s/_W)
   689		z = z.make(n + 1)
   690		z[n] = shlVU(z[n-m:n], x, s%_W)
   691		z[0 : n-m].clear()
   692	
   693		return z.norm()
   694	}
   695	
   696	// z = x >> s
   697	func (z nat) shr(x nat, s uint) nat {
   698		m := len(x)
   699		n := m - int(s/_W)
   700		if n <= 0 {
   701			return z[:0]
   702		}
   703		// n > 0
   704	
   705		z = z.make(n)
   706		shrVU(z, x[m-n:], s%_W)
   707	
   708		return z.norm()
   709	}
   710	
   711	func (z nat) setBit(x nat, i uint, b uint) nat {
   712		j := int(i / _W)
   713		m := Word(1) << (i % _W)
   714		n := len(x)
   715		switch b {
   716		case 0:
   717			z = z.make(n)
   718			copy(z, x)
   719			if j >= n {
   720				// no need to grow
   721				return z
   722			}
   723			z[j] &^= m
   724			return z.norm()
   725		case 1:
   726			if j >= n {
   727				z = z.make(j + 1)
   728				z[n:].clear()
   729			} else {
   730				z = z.make(n)
   731			}
   732			copy(z, x)
   733			z[j] |= m
   734			// no need to normalize
   735			return z
   736		}
   737		panic("set bit is not 0 or 1")
   738	}
   739	
   740	// bit returns the value of the i'th bit, with lsb == bit 0.
   741	func (x nat) bit(i uint) uint {
   742		j := i / _W
   743		if j >= uint(len(x)) {
   744			return 0
   745		}
   746		// 0 <= j < len(x)
   747		return uint(x[j] >> (i % _W) & 1)
   748	}
   749	
   750	// sticky returns 1 if there's a 1 bit within the
   751	// i least significant bits, otherwise it returns 0.
   752	func (x nat) sticky(i uint) uint {
   753		j := i / _W
   754		if j >= uint(len(x)) {
   755			if len(x) == 0 {
   756				return 0
   757			}
   758			return 1
   759		}
   760		// 0 <= j < len(x)
   761		for _, x := range x[:j] {
   762			if x != 0 {
   763				return 1
   764			}
   765		}
   766		if x[j]<<(_W-i%_W) != 0 {
   767			return 1
   768		}
   769		return 0
   770	}
   771	
   772	func (z nat) and(x, y nat) nat {
   773		m := len(x)
   774		n := len(y)
   775		if m > n {
   776			m = n
   777		}
   778		// m <= n
   779	
   780		z = z.make(m)
   781		for i := 0; i < m; i++ {
   782			z[i] = x[i] & y[i]
   783		}
   784	
   785		return z.norm()
   786	}
   787	
   788	func (z nat) andNot(x, y nat) nat {
   789		m := len(x)
   790		n := len(y)
   791		if n > m {
   792			n = m
   793		}
   794		// m >= n
   795	
   796		z = z.make(m)
   797		for i := 0; i < n; i++ {
   798			z[i] = x[i] &^ y[i]
   799		}
   800		copy(z[n:m], x[n:m])
   801	
   802		return z.norm()
   803	}
   804	
   805	func (z nat) or(x, y nat) nat {
   806		m := len(x)
   807		n := len(y)
   808		s := x
   809		if m < n {
   810			n, m = m, n
   811			s = y
   812		}
   813		// m >= n
   814	
   815		z = z.make(m)
   816		for i := 0; i < n; i++ {
   817			z[i] = x[i] | y[i]
   818		}
   819		copy(z[n:m], s[n:m])
   820	
   821		return z.norm()
   822	}
   823	
   824	func (z nat) xor(x, y nat) nat {
   825		m := len(x)
   826		n := len(y)
   827		s := x
   828		if m < n {
   829			n, m = m, n
   830			s = y
   831		}
   832		// m >= n
   833	
   834		z = z.make(m)
   835		for i := 0; i < n; i++ {
   836			z[i] = x[i] ^ y[i]
   837		}
   838		copy(z[n:m], s[n:m])
   839	
   840		return z.norm()
   841	}
   842	
   843	// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   844	func greaterThan(x1, x2, y1, y2 Word) bool {
   845		return x1 > y1 || x1 == y1 && x2 > y2
   846	}
   847	
   848	// modW returns x % d.
   849	func (x nat) modW(d Word) (r Word) {
   850		// TODO(agl): we don't actually need to store the q value.
   851		var q nat
   852		q = q.make(len(x))
   853		return divWVW(q, 0, x, d)
   854	}
   855	
   856	// random creates a random integer in [0..limit), using the space in z if
   857	// possible. n is the bit length of limit.
   858	func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   859		if alias(z, limit) {
   860			z = nil // z is an alias for limit - cannot reuse
   861		}
   862		z = z.make(len(limit))
   863	
   864		bitLengthOfMSW := uint(n % _W)
   865		if bitLengthOfMSW == 0 {
   866			bitLengthOfMSW = _W
   867		}
   868		mask := Word((1 << bitLengthOfMSW) - 1)
   869	
   870		for {
   871			switch _W {
   872			case 32:
   873				for i := range z {
   874					z[i] = Word(rand.Uint32())
   875				}
   876			case 64:
   877				for i := range z {
   878					z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   879				}
   880			default:
   881				panic("unknown word size")
   882			}
   883			z[len(limit)-1] &= mask
   884			if z.cmp(limit) < 0 {
   885				break
   886			}
   887		}
   888	
   889		return z.norm()
   890	}
   891	
   892	// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   893	// otherwise it sets z to x**y. The result is the value of z.
   894	func (z nat) expNN(x, y, m nat) nat {
   895		if alias(z, x) || alias(z, y) {
   896			// We cannot allow in-place modification of x or y.
   897			z = nil
   898		}
   899	
   900		// x**y mod 1 == 0
   901		if len(m) == 1 && m[0] == 1 {
   902			return z.setWord(0)
   903		}
   904		// m == 0 || m > 1
   905	
   906		// x**0 == 1
   907		if len(y) == 0 {
   908			return z.setWord(1)
   909		}
   910		// y > 0
   911	
   912		// x**1 mod m == x mod m
   913		if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   914			_, z = z.div(z, x, m)
   915			return z
   916		}
   917		// y > 1
   918	
   919		if len(m) != 0 {
   920			// We likely end up being as long as the modulus.
   921			z = z.make(len(m))
   922		}
   923		z = z.set(x)
   924	
   925		// If the base is non-trivial and the exponent is large, we use
   926		// 4-bit, windowed exponentiation. This involves precomputing 14 values
   927		// (x^2...x^15) but then reduces the number of multiply-reduces by a
   928		// third. Even for a 32-bit exponent, this reduces the number of
   929		// operations. Uses Montgomery method for odd moduli.
   930		if len(x) > 1 && len(y) > 1 && len(m) > 0 {
   931			if m[0]&1 == 1 {
   932				return z.expNNMontgomery(x, y, m)
   933			}
   934			return z.expNNWindowed(x, y, m)
   935		}
   936	
   937		v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   938		shift := nlz(v) + 1
   939		v <<= shift
   940		var q nat
   941	
   942		const mask = 1 << (_W - 1)
   943	
   944		// We walk through the bits of the exponent one by one. Each time we
   945		// see a bit, we square, thus doubling the power. If the bit is a one,
   946		// we also multiply by x, thus adding one to the power.
   947	
   948		w := _W - int(shift)
   949		// zz and r are used to avoid allocating in mul and div as
   950		// otherwise the arguments would alias.
   951		var zz, r nat
   952		for j := 0; j < w; j++ {
   953			zz = zz.mul(z, z)
   954			zz, z = z, zz
   955	
   956			if v&mask != 0 {
   957				zz = zz.mul(z, x)
   958				zz, z = z, zz
   959			}
   960	
   961			if len(m) != 0 {
   962				zz, r = zz.div(r, z, m)
   963				zz, r, q, z = q, z, zz, r
   964			}
   965	
   966			v <<= 1
   967		}
   968	
   969		for i := len(y) - 2; i >= 0; i-- {
   970			v = y[i]
   971	
   972			for j := 0; j < _W; j++ {
   973				zz = zz.mul(z, z)
   974				zz, z = z, zz
   975	
   976				if v&mask != 0 {
   977					zz = zz.mul(z, x)
   978					zz, z = z, zz
   979				}
   980	
   981				if len(m) != 0 {
   982					zz, r = zz.div(r, z, m)
   983					zz, r, q, z = q, z, zz, r
   984				}
   985	
   986				v <<= 1
   987			}
   988		}
   989	
   990		return z.norm()
   991	}
   992	
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// In this file it is reached from expNN only when len(x) > 1, len(y) > 1
// and m is nonzero and even (odd moduli go to expNNMontgomery instead).
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	// powers[0] is the shared natOne and powers[1] is x itself. Entries
	// 2..15 are reduced mod m as they are built.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	for i := 2; i < 1<<n; i += 2 {
		// powers[i] = powers[i/2]**2 mod m and powers[i+1] = powers[i]*x mod m.
		// After each div, the remainder is swapped into powers[...] and the
		// old product buffer is recycled as r.
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	// Process y from its most significant word down, n bits at a time.
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Before every window except the very first (where z == 1),
			// square z four times, reducing mod m after each squaring.
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain.  Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// Multiply by x raised to this window's value (the top n bits
			// of yi), then reduce.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
  1055	
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// In this file it is reached from expNN only when len(x) > 1, len(y) > 1
// and m is odd (m[0]&1 == 1).
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused.
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	// NOTE(review): RR is both the receiver (quotient storage) and the
	// remainder storage of this div. The quotient is discarded, so this
	// relies on div never placing q and r in overlapping words - confirm.
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = RR.div(RR, zz, m)
	if len(RR) < numWords {
		// Zero-extend RR to numWords words, as montgomery requires.
		// NOTE(review): this reuses zz's storage and relies on
		// zz[:numWords] still being zero (the cleared low words of the
		// shl result above).
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i
	// (in Montgomery form: montgomery(v, RR) yields v*2**(numWords*_W) mod m)
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// Four squarings, alternating z and zz as destination so
				// montgomery's output never aliases its inputs.
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// Multiply by x raised to this window's value (top n bits of yi).
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
  1148	
  1149	// probablyPrime performs n Miller-Rabin tests to check whether x is prime.
  1150	// If x is prime, it returns true.
  1151	// If x is not prime, it returns false with probability at least 1 - ¼ⁿ.
  1152	//
  1153	// It is not suitable for judging primes that an adversary may have crafted
  1154	// to fool this test.
  1155	func (n nat) probablyPrime(reps int) bool {
  1156		if len(n) == 0 {
  1157			return false
  1158		}
  1159	
  1160		if len(n) == 1 {
  1161			if n[0] < 2 {
  1162				return false
  1163			}
  1164	
  1165			if n[0]%2 == 0 {
  1166				return n[0] == 2
  1167			}
  1168	
  1169			// We have to exclude these cases because we reject all
  1170			// multiples of these numbers below.
  1171			switch n[0] {
  1172			case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1173				return true
  1174			}
  1175		}
  1176	
  1177		if n[0]&1 == 0 {
  1178			return false // n is even
  1179		}
  1180	
  1181		const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1182		const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1183	
  1184		var r Word
  1185		switch _W {
  1186		case 32:
  1187			r = n.modW(primesProduct32)
  1188		case 64:
  1189			r = n.modW(primesProduct64 & _M)
  1190		default:
  1191			panic("Unknown word size")
  1192		}
  1193	
  1194		if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1195			r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1196			return false
  1197		}
  1198	
  1199		if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1200			r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1201			return false
  1202		}
  1203	
  1204		nm1 := nat(nil).sub(n, natOne)
  1205		// determine q, k such that nm1 = q << k
  1206		k := nm1.trailingZeroBits()
  1207		q := nat(nil).shr(nm1, k)
  1208	
  1209		nm3 := nat(nil).sub(nm1, natTwo)
  1210		rand := rand.New(rand.NewSource(int64(n[0])))
  1211	
  1212		var x, y, quotient nat
  1213		nm3Len := nm3.bitLen()
  1214	
  1215	NextRandom:
  1216		for i := 0; i < reps; i++ {
  1217			x = x.random(rand, nm3, nm3Len)
  1218			x = x.add(x, natTwo)
  1219			y = y.expNN(x, q, n)
  1220			if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1221				continue
  1222			}
  1223			for j := uint(1); j < k; j++ {
  1224				y = y.mul(y, y)
  1225				quotient, y = quotient.div(y, y, n)
  1226				if y.cmp(nm1) == 0 {
  1227					continue NextRandom
  1228				}
  1229				if y.cmp(natOne) == 0 {
  1230					return false
  1231				}
  1232			}
  1233			return false
  1234		}
  1235	
  1236		return true
  1237	}
  1238	
  1239	// bytes writes the value of z into buf using big-endian encoding.
  1240	// len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1241	// slice buf[i:]. The number i of unused bytes at the beginning of
  1242	// buf is returned as result.
  1243	func (z nat) bytes(buf []byte) (i int) {
  1244		i = len(buf)
  1245		for _, d := range z {
  1246			for j := 0; j < _S; j++ {
  1247				i--
  1248				buf[i] = byte(d)
  1249				d >>= 8
  1250			}
  1251		}
  1252	
  1253		for i < len(buf) && buf[i] == 0 {
  1254			i++
  1255		}
  1256	
  1257		return
  1258	}
  1259	
  1260	// setBytes interprets buf as the bytes of a big-endian unsigned
  1261	// integer, sets z to that value, and returns z.
  1262	func (z nat) setBytes(buf []byte) nat {
  1263		z = z.make((len(buf) + _S - 1) / _S)
  1264	
  1265		k := 0
  1266		s := uint(0)
  1267		var d Word
  1268		for i := len(buf); i > 0; i-- {
  1269			d |= Word(buf[i-1]) << s
  1270			if s += 8; s == _S*8 {
  1271				z[k] = d
  1272				k++
  1273				s = 0
  1274				d = 0
  1275			}
  1276		}
  1277		if k < len(z) {
  1278			z[k] = d
  1279		}
  1280	
  1281		return z.norm()
  1282	}
  1283	

View as plain text