...
Run Format

Source file src/math/big/nat.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// This file implements unsigned multi-precision integers (natural
     6	// numbers). They are the building blocks for the implementation
     7	// of signed integers, rationals, and floating-point numbers.
     8	
     9	package big
    10	
    11	import (
    12		"math/rand"
    13		"sync"
    14	)
    15	
// nat is an unsigned multi-precision integer (a natural number).
//
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements; x[0] is therefore the
// least significant digit.
//
// A number is normalized if the slice contains no leading 0 digits,
// i.e. its last (most significant) element is non-zero.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word
    29	
// Small constant values shared across the package. They are handed out
// directly (for example, expNNWindowed stores natOne in its powers
// table), so their elements should be treated as read-only.
var (
	natOne = nat{1}
	natTwo = nat{2}
	natTen = nat{10}
)
    35	
    36	func (z nat) clear() {
    37		for i := range z {
    38			z[i] = 0
    39		}
    40	}
    41	
    42	func (z nat) norm() nat {
    43		i := len(z)
    44		for i > 0 && z[i-1] == 0 {
    45			i--
    46		}
    47		return z[0:i]
    48	}
    49	
    50	func (z nat) make(n int) nat {
    51		if n <= cap(z) {
    52			return z[:n] // reuse z
    53		}
    54		// Choosing a good value for e has significant performance impact
    55		// because it increases the chance that a value can be reused.
    56		const e = 4 // extra capacity
    57		return make(nat, n, n+e)
    58	}
    59	
    60	func (z nat) setWord(x Word) nat {
    61		if x == 0 {
    62			return z[:0]
    63		}
    64		z = z.make(1)
    65		z[0] = x
    66		return z
    67	}
    68	
    69	func (z nat) setUint64(x uint64) nat {
    70		// single-digit values
    71		if w := Word(x); uint64(w) == x {
    72			return z.setWord(w)
    73		}
    74	
    75		// compute number of words n required to represent x
    76		n := 0
    77		for t := x; t > 0; t >>= _W {
    78			n++
    79		}
    80	
    81		// split x into n words
    82		z = z.make(n)
    83		for i := range z {
    84			z[i] = Word(x & _M)
    85			x >>= _W
    86		}
    87	
    88		return z
    89	}
    90	
    91	func (z nat) set(x nat) nat {
    92		z = z.make(len(x))
    93		copy(z, x)
    94		return z
    95	}
    96	
    97	func (z nat) add(x, y nat) nat {
    98		m := len(x)
    99		n := len(y)
   100	
   101		switch {
   102		case m < n:
   103			return z.add(y, x)
   104		case m == 0:
   105			// n == 0 because m >= n; result is 0
   106			return z[:0]
   107		case n == 0:
   108			// result is x
   109			return z.set(x)
   110		}
   111		// m > 0
   112	
   113		z = z.make(m + 1)
   114		c := addVV(z[0:n], x, y)
   115		if m > n {
   116			c = addVW(z[n:m], x[n:], c)
   117		}
   118		z[m] = c
   119	
   120		return z.norm()
   121	}
   122	
   123	func (z nat) sub(x, y nat) nat {
   124		m := len(x)
   125		n := len(y)
   126	
   127		switch {
   128		case m < n:
   129			panic("underflow")
   130		case m == 0:
   131			// n == 0 because m >= n; result is 0
   132			return z[:0]
   133		case n == 0:
   134			// result is x
   135			return z.set(x)
   136		}
   137		// m > 0
   138	
   139		z = z.make(m)
   140		c := subVV(z[0:n], x, y)
   141		if m > n {
   142			c = subVW(z[n:], x[n:], c)
   143		}
   144		if c != 0 {
   145			panic("underflow")
   146		}
   147	
   148		return z.norm()
   149	}
   150	
   151	func (x nat) cmp(y nat) (r int) {
   152		m := len(x)
   153		n := len(y)
   154		if m != n || m == 0 {
   155			switch {
   156			case m < n:
   157				r = -1
   158			case m > n:
   159				r = 1
   160			}
   161			return
   162		}
   163	
   164		i := m - 1
   165		for i > 0 && x[i] == y[i] {
   166			i--
   167		}
   168	
   169		switch {
   170		case x[i] < y[i]:
   171			r = -1
   172		case x[i] > y[i]:
   173			r = 1
   174		}
   175		return
   176	}
   177	
   178	func (z nat) mulAddWW(x nat, y, r Word) nat {
   179		m := len(x)
   180		if m == 0 || y == 0 {
   181			return z.setWord(r) // result is r
   182		}
   183		// m > 0
   184	
   185		z = z.make(m + 1)
   186		z[m] = mulAddVWW(z[0:m], x, y, r)
   187	
   188		return z.norm()
   189	}
   190	
   191	// basicMul multiplies x and y and leaves the result in z.
   192	// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   193	func basicMul(z, x, y nat) {
   194		z[0 : len(x)+len(y)].clear() // initialize z
   195		for i, d := range y {
   196			if d != 0 {
   197				z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   198			}
   199		}
   200	}
   201	
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
//
// It panics if x, y and m do not all have length n.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	var c Word // carry out of the top word, propagated between rounds
	for i := 0; i < n; i++ {
		d := y[i]
		c2 := addMulVVW(z, x, d) // accumulate x*y[i]; c2 is the word carried out
		// Since k = -1/m mod 2**_W, adding t*m makes the lowest word of z zero.
		t := z[0] * k
		c3 := addMulVVW(z, m, t)
		copy(z, z[1:]) // discard the lowest word, i.e. divide by 2**_W
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		// If either addition wrapped around, carry 1 into the next round.
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	if c != 0 {
		// The value does not fit in n words; subtract m once to compensate.
		subVV(z, z, m)
	}
	return z
}
   242	
   243	// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   244	// Factored out for readability - do not use outside karatsuba.
   245	func karatsubaAdd(z, x nat, n int) {
   246		if c := addVV(z[0:n], z, x); c != 0 {
   247			addVW(z[n:n+n>>1], z[n:], c)
   248		}
   249	}
   250	
   251	// Like karatsubaAdd, but does subtract.
   252	func karatsubaSub(z, x nat, n int) {
   253		if c := subVV(z[0:n], z, x); c != 0 {
   254			subVW(z[n:n+n>>1], z[n:], c)
   255		}
   256	}
   257	
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used. The threshold is measured in words and is a variable (not a
// constant) so that it can be adjusted at run time.
var karatsubaThreshold int = 40 // computed by calibrate.go
   262	
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
// The upper part of z is used as scratch space and holds garbage
// on return.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// Each term is added at an offset of n2 words; p is subtracted
	// instead when the sign s of xd*yd is negative.
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   361	
   362	// alias reports whether x and y share the same base array.
   363	func alias(x, y nat) bool {
   364		return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   365	}
   366	
   367	// addAt implements z += x<<(_W*i); z must be long enough.
   368	// (we don't use nat.add because we need z to stay the same
   369	// slice, and we don't need to normalize z after each addition)
   370	func addAt(z, x nat, i int) {
   371		if n := len(x); n > 0 {
   372			if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   373				j := i + n
   374				if j < len(z) {
   375					addVW(z[j:], z[j:], c)
   376				}
   377			}
   378		}
   379	}
   380	
   381	func max(x, y int) int {
   382		if x > y {
   383			return x
   384		}
   385		return y
   386	}
   387	
   388	// karatsubaLen computes an approximation to the maximum k <= n such that
   389	// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   390	// result is the largest number that can be divided repeatedly by 2 before
   391	// becoming about the value of karatsubaThreshold.
   392	func karatsubaLen(n int) int {
   393		i := uint(0)
   394		for n > karatsubaThreshold {
   395			n >>= 1
   396			i++
   397		}
   398		return n << i
   399	}
   400	
// mul sets z to x*y and returns the normalized result.
// z's storage is reused unless it aliases x or y. Small operands are
// multiplied with basicMul; larger ones use karatsuba on the low k
// words of each operand and then add in the remaining partial products.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		var t nat // scratch product, reused across the additions below

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of (at most) k words of x
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   484	
   485	// mulRange computes the product of all the unsigned integers in the
   486	// range [a, b] inclusively. If a > b (empty range), the result is 1.
   487	func (z nat) mulRange(a, b uint64) nat {
   488		switch {
   489		case a == 0:
   490			// cut long ranges short (optimization)
   491			return z.setUint64(0)
   492		case a > b:
   493			return z.setUint64(1)
   494		case a == b:
   495			return z.setUint64(a)
   496		case a+1 == b:
   497			return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   498		}
   499		m := (a + b) / 2
   500		return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   501	}
   502	
   503	// q = (x-r)/y, with 0 <= r < y
   504	func (z nat) divW(x nat, y Word) (q nat, r Word) {
   505		m := len(x)
   506		switch {
   507		case y == 0:
   508			panic("division by zero")
   509		case y == 1:
   510			q = z.set(x) // result is x
   511			return
   512		case m == 0:
   513			q = z[:0] // result is 0
   514			return
   515		}
   516		// m > 0
   517		z = z.make(m)
   518		r = divWVW(z, 0, x, y)
   519		q = z.norm()
   520		return
   521	}
   522	
   523	func (z nat) div(z2, u, v nat) (q, r nat) {
   524		if len(v) == 0 {
   525			panic("division by zero")
   526		}
   527	
   528		if u.cmp(v) < 0 {
   529			q = z[:0]
   530			r = z2.set(u)
   531			return
   532		}
   533	
   534		if len(v) == 1 {
   535			var r2 Word
   536			q, r2 = z.divW(u, v[0])
   537			r = z2.setWord(r2)
   538			return
   539		}
   540	
   541		q, r = z.divLarge(z2, u, v)
   542		return
   543	}
   544	
   545	// getNat returns a nat of len n. The contents may not be zero.
   546	func getNat(n int) nat {
   547		var z nat
   548		if v := natPool.Get(); v != nil {
   549			z = v.(nat)
   550		}
   551		return z.make(n)
   552	}
   553	
// putNat hands x to natPool so that a later getNat call can reuse its
// storage. The caller must not use x afterwards.
func putNat(x nat) {
	natPool.Put(x)
}

// natPool holds scratch nat values passed to putNat for reuse by getNat.
var natPool sync.Pool
   559	
// q = (uIn-r)/v, with 0 <= r < v
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
//
// NOTE: z and u are each checked for aliasing against uIn and v, but
// not against each other. Callers must pass distinct storage for the
// quotient and the remainder, since both are written during the loop.
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	qhatv := getNat(n + 1) // scratch space for qhat*v
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Normalize: shift v (and uIn) left so that the top bit of v is set.
	var v1 nat
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 = getNat(n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient digit qhat from the leading words.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Multiply and subtract: u[j:] -= qhat*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// qhat was one too large: add v back and decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}
	if v1 != nil {
		putNat(v1)
	}
	putNat(qhatv)

	q = q.norm()
	shrVU(u, u, shift) // undo the normalization shift to obtain the remainder
	r = u.norm()

	return q, r
}
   642	
   643	// Length of x in bits. x must be normalized.
   644	func (x nat) bitLen() int {
   645		if i := len(x) - 1; i >= 0 {
   646			return i*_W + bitLen(x[i])
   647		}
   648		return 0
   649	}
   650	
// deBruijn32 is the 32-bit multiplier used by trailingZeroBits:
// multiplying it by a single-bit value and keeping the top 5 bits
// yields an index into deBruijn32Lookup.
const deBruijn32 = 0x077CB531

// deBruijn32Lookup maps that 5-bit index to the position of the bit.
var deBruijn32Lookup = []byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// deBruijn64 is the 64-bit counterpart of deBruijn32; the top 6 bits
// of the product index into deBruijn64Lookup.
const deBruijn64 = 0x03f79d71b4ca8b09

// deBruijn64Lookup maps that 6-bit index to the position of the bit.
var deBruijn64Lookup = []byte{
	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
   666	
   667	// trailingZeroBits returns the number of consecutive least significant zero
   668	// bits of x.
   669	func trailingZeroBits(x Word) uint {
   670		// x & -x leaves only the right-most bit set in the word. Let k be the
   671		// index of that bit. Since only a single bit is set, the value is two
   672		// to the power of k. Multiplying by a power of two is equivalent to
   673		// left shifting, in this case by k bits. The de Bruijn constant is
   674		// such that all six bit, consecutive substrings are distinct.
   675		// Therefore, if we have a left shifted version of this constant we can
   676		// find by how many bits it was shifted by looking at which six bit
   677		// substring ended up at the top of the word.
   678		// (Knuth, volume 4, section 7.3.1)
   679		switch _W {
   680		case 32:
   681			return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   682		case 64:
   683			return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   684		default:
   685			panic("unknown word size")
   686		}
   687	}
   688	
   689	// trailingZeroBits returns the number of consecutive least significant zero
   690	// bits of x.
   691	func (x nat) trailingZeroBits() uint {
   692		if len(x) == 0 {
   693			return 0
   694		}
   695		var i uint
   696		for x[i] == 0 {
   697			i++
   698		}
   699		// x[i] != 0
   700		return i*_W + trailingZeroBits(x[i])
   701	}
   702	
   703	// z = x << s
   704	func (z nat) shl(x nat, s uint) nat {
   705		m := len(x)
   706		if m == 0 {
   707			return z[:0]
   708		}
   709		// m > 0
   710	
   711		n := m + int(s/_W)
   712		z = z.make(n + 1)
   713		z[n] = shlVU(z[n-m:n], x, s%_W)
   714		z[0 : n-m].clear()
   715	
   716		return z.norm()
   717	}
   718	
   719	// z = x >> s
   720	func (z nat) shr(x nat, s uint) nat {
   721		m := len(x)
   722		n := m - int(s/_W)
   723		if n <= 0 {
   724			return z[:0]
   725		}
   726		// n > 0
   727	
   728		z = z.make(n)
   729		shrVU(z, x[m-n:], s%_W)
   730	
   731		return z.norm()
   732	}
   733	
   734	func (z nat) setBit(x nat, i uint, b uint) nat {
   735		j := int(i / _W)
   736		m := Word(1) << (i % _W)
   737		n := len(x)
   738		switch b {
   739		case 0:
   740			z = z.make(n)
   741			copy(z, x)
   742			if j >= n {
   743				// no need to grow
   744				return z
   745			}
   746			z[j] &^= m
   747			return z.norm()
   748		case 1:
   749			if j >= n {
   750				z = z.make(j + 1)
   751				z[n:].clear()
   752			} else {
   753				z = z.make(n)
   754			}
   755			copy(z, x)
   756			z[j] |= m
   757			// no need to normalize
   758			return z
   759		}
   760		panic("set bit is not 0 or 1")
   761	}
   762	
   763	// bit returns the value of the i'th bit, with lsb == bit 0.
   764	func (x nat) bit(i uint) uint {
   765		j := i / _W
   766		if j >= uint(len(x)) {
   767			return 0
   768		}
   769		// 0 <= j < len(x)
   770		return uint(x[j] >> (i % _W) & 1)
   771	}
   772	
   773	// sticky returns 1 if there's a 1 bit within the
   774	// i least significant bits, otherwise it returns 0.
   775	func (x nat) sticky(i uint) uint {
   776		j := i / _W
   777		if j >= uint(len(x)) {
   778			if len(x) == 0 {
   779				return 0
   780			}
   781			return 1
   782		}
   783		// 0 <= j < len(x)
   784		for _, x := range x[:j] {
   785			if x != 0 {
   786				return 1
   787			}
   788		}
   789		if x[j]<<(_W-i%_W) != 0 {
   790			return 1
   791		}
   792		return 0
   793	}
   794	
   795	func (z nat) and(x, y nat) nat {
   796		m := len(x)
   797		n := len(y)
   798		if m > n {
   799			m = n
   800		}
   801		// m <= n
   802	
   803		z = z.make(m)
   804		for i := 0; i < m; i++ {
   805			z[i] = x[i] & y[i]
   806		}
   807	
   808		return z.norm()
   809	}
   810	
   811	func (z nat) andNot(x, y nat) nat {
   812		m := len(x)
   813		n := len(y)
   814		if n > m {
   815			n = m
   816		}
   817		// m >= n
   818	
   819		z = z.make(m)
   820		for i := 0; i < n; i++ {
   821			z[i] = x[i] &^ y[i]
   822		}
   823		copy(z[n:m], x[n:m])
   824	
   825		return z.norm()
   826	}
   827	
   828	func (z nat) or(x, y nat) nat {
   829		m := len(x)
   830		n := len(y)
   831		s := x
   832		if m < n {
   833			n, m = m, n
   834			s = y
   835		}
   836		// m >= n
   837	
   838		z = z.make(m)
   839		for i := 0; i < n; i++ {
   840			z[i] = x[i] | y[i]
   841		}
   842		copy(z[n:m], s[n:m])
   843	
   844		return z.norm()
   845	}
   846	
   847	func (z nat) xor(x, y nat) nat {
   848		m := len(x)
   849		n := len(y)
   850		s := x
   851		if m < n {
   852			n, m = m, n
   853			s = y
   854		}
   855		// m >= n
   856	
   857		z = z.make(m)
   858		for i := 0; i < n; i++ {
   859			z[i] = x[i] ^ y[i]
   860		}
   861		copy(z[n:m], s[n:m])
   862	
   863		return z.norm()
   864	}
   865	
   866	// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   867	func greaterThan(x1, x2, y1, y2 Word) bool {
   868		return x1 > y1 || x1 == y1 && x2 > y2
   869	}
   870	
   871	// modW returns x % d.
   872	func (x nat) modW(d Word) (r Word) {
   873		// TODO(agl): we don't actually need to store the q value.
   874		var q nat
   875		q = q.make(len(x))
   876		return divWVW(q, 0, x, d)
   877	}
   878	
   879	// random creates a random integer in [0..limit), using the space in z if
   880	// possible. n is the bit length of limit.
   881	func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   882		if alias(z, limit) {
   883			z = nil // z is an alias for limit - cannot reuse
   884		}
   885		z = z.make(len(limit))
   886	
   887		bitLengthOfMSW := uint(n % _W)
   888		if bitLengthOfMSW == 0 {
   889			bitLengthOfMSW = _W
   890		}
   891		mask := Word((1 << bitLengthOfMSW) - 1)
   892	
   893		for {
   894			switch _W {
   895			case 32:
   896				for i := range z {
   897					z[i] = Word(rand.Uint32())
   898				}
   899			case 64:
   900				for i := range z {
   901					z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   902				}
   903			default:
   904				panic("unknown word size")
   905			}
   906			z[len(limit)-1] &= mask
   907			if z.cmp(limit) < 0 {
   908				break
   909			}
   910		}
   911	
   912		return z.norm()
   913	}
   914	
   915	// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   916	// otherwise it sets z to x**y. The result is the value of z.
   917	func (z nat) expNN(x, y, m nat) nat {
   918		if alias(z, x) || alias(z, y) {
   919			// We cannot allow in-place modification of x or y.
   920			z = nil
   921		}
   922	
   923		// x**y mod 1 == 0
   924		if len(m) == 1 && m[0] == 1 {
   925			return z.setWord(0)
   926		}
   927		// m == 0 || m > 1
   928	
   929		// x**0 == 1
   930		if len(y) == 0 {
   931			return z.setWord(1)
   932		}
   933		// y > 0
   934	
   935		// x**1 mod m == x mod m
   936		if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   937			_, z = z.div(z, x, m)
   938			return z
   939		}
   940		// y > 1
   941	
   942		if len(m) != 0 {
   943			// We likely end up being as long as the modulus.
   944			z = z.make(len(m))
   945		}
   946		z = z.set(x)
   947	
   948		// If the base is non-trivial and the exponent is large, we use
   949		// 4-bit, windowed exponentiation. This involves precomputing 14 values
   950		// (x^2...x^15) but then reduces the number of multiply-reduces by a
   951		// third. Even for a 32-bit exponent, this reduces the number of
   952		// operations. Uses Montgomery method for odd moduli.
   953		if len(x) > 1 && len(y) > 1 && len(m) > 0 {
   954			if m[0]&1 == 1 {
   955				return z.expNNMontgomery(x, y, m)
   956			}
   957			return z.expNNWindowed(x, y, m)
   958		}
   959	
   960		v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   961		shift := nlz(v) + 1
   962		v <<= shift
   963		var q nat
   964	
   965		const mask = 1 << (_W - 1)
   966	
   967		// We walk through the bits of the exponent one by one. Each time we
   968		// see a bit, we square, thus doubling the power. If the bit is a one,
   969		// we also multiply by x, thus adding one to the power.
   970	
   971		w := _W - int(shift)
   972		// zz and r are used to avoid allocating in mul and div as
   973		// otherwise the arguments would alias.
   974		var zz, r nat
   975		for j := 0; j < w; j++ {
   976			zz = zz.mul(z, z)
   977			zz, z = z, zz
   978	
   979			if v&mask != 0 {
   980				zz = zz.mul(z, x)
   981				zz, z = z, zz
   982			}
   983	
   984			if len(m) != 0 {
   985				zz, r = zz.div(r, z, m)
   986				zz, r, q, z = q, z, zz, r
   987			}
   988	
   989			v <<= 1
   990		}
   991	
   992		for i := len(y) - 2; i >= 0; i-- {
   993			v = y[i]
   994	
   995			for j := 0; j < _W; j++ {
   996				zz = zz.mul(z, z)
   997				zz, z = z, zz
   998	
   999				if v&mask != 0 {
  1000					zz = zz.mul(z, x)
  1001					zz, z = z, zz
  1002				}
  1003	
  1004				if len(m) != 0 {
  1005					zz, r = zz.div(r, z, m)
  1006					zz, r, q, z = q, z, zz, r
  1007				}
  1008	
  1009				v <<= 1
  1010			}
  1011		}
  1012	
  1013		return z.norm()
  1014	}
  1015	
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// z is used as storage for the result; m must not be zero since every
// intermediate product is reduced with div.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Fill the table two entries at a time: x^i = (x^(i/2))^2 and
	// x^(i+1) = x^i * x, each reduced mod m.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	// Process the exponent n bits at a time, most significant bits first.
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Skip the squarings for the very first window: z is still 1.
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// multiply by the power selected by the top n bits of yi
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
  1078	
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// The computation of k0 below relies on m being odd; expNN only calls
// this method for odd m.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// zero-extend x to numWords words
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused.
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = RR.div(RR, zz, m)
	if len(RR) < numWords {
		// zero-extend RR to numWords words, as montgomery requires
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i (in Montgomery representation)
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Skip the squarings for the very first window: z is still 1.
			if i != len(y)-1 || j != 0 {
				// four squarings, alternating between the z and zz buffers
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// multiply by the power selected by the top n bits of yi
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
  1171	
  1172	// probablyPrime performs n Miller-Rabin tests to check whether x is prime.
  1173	// If x is prime, it returns true.
  1174	// If x is not prime, it returns false with probability at least 1 - ¼ⁿ.
  1175	//
  1176	// It is not suitable for judging primes that an adversary may have crafted
  1177	// to fool this test.
  1178	func (n nat) probablyPrime(reps int) bool {
  1179		if len(n) == 0 {
  1180			return false
  1181		}
  1182	
  1183		if len(n) == 1 {
  1184			if n[0] < 2 {
  1185				return false
  1186			}
  1187	
  1188			if n[0]%2 == 0 {
  1189				return n[0] == 2
  1190			}
  1191	
  1192			// We have to exclude these cases because we reject all
  1193			// multiples of these numbers below.
  1194			switch n[0] {
  1195			case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1196				return true
  1197			}
  1198		}
  1199	
  1200		if n[0]&1 == 0 {
  1201			return false // n is even
  1202		}
  1203	
  1204		const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1205		const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1206	
  1207		var r Word
  1208		switch _W {
  1209		case 32:
  1210			r = n.modW(primesProduct32)
  1211		case 64:
  1212			r = n.modW(primesProduct64 & _M)
  1213		default:
  1214			panic("Unknown word size")
  1215		}
  1216	
  1217		if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1218			r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1219			return false
  1220		}
  1221	
  1222		if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1223			r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1224			return false
  1225		}
  1226	
  1227		nm1 := nat(nil).sub(n, natOne)
  1228		// determine q, k such that nm1 = q << k
  1229		k := nm1.trailingZeroBits()
  1230		q := nat(nil).shr(nm1, k)
  1231	
  1232		nm3 := nat(nil).sub(nm1, natTwo)
  1233		rand := rand.New(rand.NewSource(int64(n[0])))
  1234	
  1235		var x, y, quotient nat
  1236		nm3Len := nm3.bitLen()
  1237	
  1238	NextRandom:
  1239		for i := 0; i < reps; i++ {
  1240			x = x.random(rand, nm3, nm3Len)
  1241			x = x.add(x, natTwo)
  1242			y = y.expNN(x, q, n)
  1243			if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1244				continue
  1245			}
  1246			for j := uint(1); j < k; j++ {
  1247				y = y.mul(y, y)
  1248				quotient, y = quotient.div(y, y, n)
  1249				if y.cmp(nm1) == 0 {
  1250					continue NextRandom
  1251				}
  1252				if y.cmp(natOne) == 0 {
  1253					return false
  1254				}
  1255			}
  1256			return false
  1257		}
  1258	
  1259		return true
  1260	}
  1261	
  1262	// bytes writes the value of z into buf using big-endian encoding.
  1263	// len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1264	// slice buf[i:]. The number i of unused bytes at the beginning of
  1265	// buf is returned as result.
  1266	func (z nat) bytes(buf []byte) (i int) {
  1267		i = len(buf)
  1268		for _, d := range z {
  1269			for j := 0; j < _S; j++ {
  1270				i--
  1271				buf[i] = byte(d)
  1272				d >>= 8
  1273			}
  1274		}
  1275	
  1276		for i < len(buf) && buf[i] == 0 {
  1277			i++
  1278		}
  1279	
  1280		return
  1281	}
  1282	
  1283	// setBytes interprets buf as the bytes of a big-endian unsigned
  1284	// integer, sets z to that value, and returns z.
  1285	func (z nat) setBytes(buf []byte) nat {
  1286		z = z.make((len(buf) + _S - 1) / _S)
  1287	
  1288		k := 0
  1289		s := uint(0)
  1290		var d Word
  1291		for i := len(buf); i > 0; i-- {
  1292			d |= Word(buf[i-1]) << s
  1293			if s += 8; s == _S*8 {
  1294				z[k] = d
  1295				k++
  1296				s = 0
  1297				d = 0
  1298			}
  1299		}
  1300		if k < len(z) {
  1301			z[k] = d
  1302		}
  1303	
  1304		return z.norm()
  1305	}
  1306	

View as plain text