...
Run Format

Source file src/math/big/nat.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// Package big implements multi-precision arithmetic (big numbers).
     6	// The following numeric types are supported:
     7	//
     8	//   Int    signed integers
     9	//   Rat    rational numbers
    10	//   Float  floating-point numbers
    11	//
    12	// Methods are typically of the form:
    13	//
    14	//   func (z *T) Unary(x *T) *T        // z = op x
    15	//   func (z *T) Binary(x, y *T) *T    // z = x op y
    16	//   func (x *T) M() T1                // v = x.M()
    17	//
    18	// with T one of Int, Rat, or Float. For unary and binary operations, the
    19	// result is the receiver (usually named z in that case); if it is one of
    20	// the operands x or y it may be overwritten (and its memory reused).
    21	// To enable chaining of operations, the result is also returned. Methods
    22	// returning a result other than *Int, *Rat, or *Float take an operand as
    23	// the receiver (usually named x in that case).
    24	//
    25	package big
    26	
    27	// This file contains operations on unsigned multi-precision integers.
    28	// These are the building blocks for the operations on signed integers
    29	// and rationals.
    30	
    31	import "math/rand"
    32	
// nat is an unsigned multi-precision integer. A value x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements (x[0] is the least
// significant digit).
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word
    46	
// Small constant values in nat form.
// NOTE(review): natOne is stored directly (not copied) as powers[0] in
// expNNWindowed, so these values presumably must never be modified in place.
var (
	natOne = nat{1}
	natTwo = nat{2}
	natTen = nat{10}
)
    52	
    53	func (z nat) clear() {
    54		for i := range z {
    55			z[i] = 0
    56		}
    57	}
    58	
    59	func (z nat) norm() nat {
    60		i := len(z)
    61		for i > 0 && z[i-1] == 0 {
    62			i--
    63		}
    64		return z[0:i]
    65	}
    66	
    67	func (z nat) make(n int) nat {
    68		if n <= cap(z) {
    69			return z[:n] // reuse z
    70		}
    71		// Choosing a good value for e has significant performance impact
    72		// because it increases the chance that a value can be reused.
    73		const e = 4 // extra capacity
    74		return make(nat, n, n+e)
    75	}
    76	
    77	func (z nat) setWord(x Word) nat {
    78		if x == 0 {
    79			return z[:0]
    80		}
    81		z = z.make(1)
    82		z[0] = x
    83		return z
    84	}
    85	
    86	func (z nat) setUint64(x uint64) nat {
    87		// single-digit values
    88		if w := Word(x); uint64(w) == x {
    89			return z.setWord(w)
    90		}
    91	
    92		// compute number of words n required to represent x
    93		n := 0
    94		for t := x; t > 0; t >>= _W {
    95			n++
    96		}
    97	
    98		// split x into n words
    99		z = z.make(n)
   100		for i := range z {
   101			z[i] = Word(x & _M)
   102			x >>= _W
   103		}
   104	
   105		return z
   106	}
   107	
   108	func (z nat) set(x nat) nat {
   109		z = z.make(len(x))
   110		copy(z, x)
   111		return z
   112	}
   113	
   114	func (z nat) add(x, y nat) nat {
   115		m := len(x)
   116		n := len(y)
   117	
   118		switch {
   119		case m < n:
   120			return z.add(y, x)
   121		case m == 0:
   122			// n == 0 because m >= n; result is 0
   123			return z[:0]
   124		case n == 0:
   125			// result is x
   126			return z.set(x)
   127		}
   128		// m > 0
   129	
   130		z = z.make(m + 1)
   131		c := addVV(z[0:n], x, y)
   132		if m > n {
   133			c = addVW(z[n:m], x[n:], c)
   134		}
   135		z[m] = c
   136	
   137		return z.norm()
   138	}
   139	
   140	func (z nat) sub(x, y nat) nat {
   141		m := len(x)
   142		n := len(y)
   143	
   144		switch {
   145		case m < n:
   146			panic("underflow")
   147		case m == 0:
   148			// n == 0 because m >= n; result is 0
   149			return z[:0]
   150		case n == 0:
   151			// result is x
   152			return z.set(x)
   153		}
   154		// m > 0
   155	
   156		z = z.make(m)
   157		c := subVV(z[0:n], x, y)
   158		if m > n {
   159			c = subVW(z[n:], x[n:], c)
   160		}
   161		if c != 0 {
   162			panic("underflow")
   163		}
   164	
   165		return z.norm()
   166	}
   167	
   168	func (x nat) cmp(y nat) (r int) {
   169		m := len(x)
   170		n := len(y)
   171		if m != n || m == 0 {
   172			switch {
   173			case m < n:
   174				r = -1
   175			case m > n:
   176				r = 1
   177			}
   178			return
   179		}
   180	
   181		i := m - 1
   182		for i > 0 && x[i] == y[i] {
   183			i--
   184		}
   185	
   186		switch {
   187		case x[i] < y[i]:
   188			r = -1
   189		case x[i] > y[i]:
   190			r = 1
   191		}
   192		return
   193	}
   194	
   195	func (z nat) mulAddWW(x nat, y, r Word) nat {
   196		m := len(x)
   197		if m == 0 || y == 0 {
   198			return z.setWord(r) // result is r
   199		}
   200		// m > 0
   201	
   202		z = z.make(m + 1)
   203		z[m] = mulAddVWW(z[0:m], x, y, r)
   204	
   205		return z.norm()
   206	}
   207	
   208	// basicMul multiplies x and y and leaves the result in z.
   209	// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   210	func basicMul(z, x, y nat) {
   211		z[0 : len(x)+len(y)].clear() // initialize z
   212		for i, d := range y {
   213			if d != 0 {
   214				z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   215			}
   216		}
   217	}
   218	
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// the inputs x and y are n-word values (0 <= x, y < 2**(n*_W)) and the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
// It panics if x, y and m do not all have length n.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	var c1, c2, c3 Word
	z = z.make(n)
	z.clear()
	for i := 0; i < n; i++ {
		d := y[i]
		// accumulate x*y[i] into z; c2 receives the carry word
		c2 = addMulVVW(z, x, d)
		// with k = -1/m mod 2**_W, adding t*m makes the low word of z zero
		t := z[0] * k
		c3 = addMulVVW(z, m, t)
		// drop the low word, i.e. shift z down by one word
		copy(z, z[1:])
		cx := c1 + c2
		cy := cx + c3
		// the new top word is the sum of the pending carries
		z[n-1] = cy
		// c1 records whether adding up the carries overflowed a Word
		if cx < c2 || cy < c3 {
			c1 = 1
		} else {
			c1 = 0
		}
	}
	// A leftover carry means the value does not fit in n words:
	// subtract m once (the borrow returned by subVV is discarded).
	if c1 != 0 {
		subVV(z, z, m)
	}
	return z
}
   259	
   260	// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   261	// Factored out for readability - do not use outside karatsuba.
   262	func karatsubaAdd(z, x nat, n int) {
   263		if c := addVV(z[0:n], z, x); c != 0 {
   264			addVW(z[n:n+n>>1], z[n:], c)
   265		}
   266	}
   267	
   268	// Like karatsubaAdd, but does subtract.
   269	func karatsubaSub(z, x nat, n int) {
   270		if c := subVV(z[0:n], z, x); c != 0 {
   271			subVW(z[n:n+n>>1], z[n:], c)
   272		}
   273	}
   274	
// Operands that are shorter than karatsubaThreshold (measured in words)
// are multiplied using "grade school" multiplication; for longer operands
// the Karatsuba algorithm is used. It is declared as a variable rather
// than a constant; per the trailing comment its value comes from
// calibrate.go.
var karatsubaThreshold int = 40 // computed by calibrate.go
   279	
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n]; the
// remainder of z is used as scratch space and is overwritten.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// (each term below is added at word offset n2; p is subtracted
	// instead when the sign s of xd*yd is negative)
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   378	
   379	// alias reports whether x and y share the same base array.
   380	func alias(x, y nat) bool {
   381		return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   382	}
   383	
   384	// addAt implements z += x<<(_W*i); z must be long enough.
   385	// (we don't use nat.add because we need z to stay the same
   386	// slice, and we don't need to normalize z after each addition)
   387	func addAt(z, x nat, i int) {
   388		if n := len(x); n > 0 {
   389			if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   390				j := i + n
   391				if j < len(z) {
   392					addVW(z[j:], z[j:], c)
   393				}
   394			}
   395		}
   396	}
   397	
   398	func max(x, y int) int {
   399		if x > y {
   400			return x
   401		}
   402		return y
   403	}
   404	
   405	// karatsubaLen computes an approximation to the maximum k <= n such that
   406	// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   407	// result is the largest number that can be divided repeatedly by 2 before
   408	// becoming about the value of karatsubaThreshold.
   409	func karatsubaLen(n int) int {
   410		i := uint(0)
   411		for n > karatsubaThreshold {
   412			n >>= 1
   413			i++
   414		}
   415		return n << i
   416	}
   417	
// mul sets z to the product x*y and returns the normalized result.
// Single-word multipliers are handled by mulAddWW, operands shorter than
// karatsubaThreshold by basicMul, and longer operands by karatsuba on the
// low k words followed by explicit addition of the remaining partial
// products. z's storage is not reused if it aliases x or y.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of (at most) k words of x
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   501	
   502	// mulRange computes the product of all the unsigned integers in the
   503	// range [a, b] inclusively. If a > b (empty range), the result is 1.
   504	func (z nat) mulRange(a, b uint64) nat {
   505		switch {
   506		case a == 0:
   507			// cut long ranges short (optimization)
   508			return z.setUint64(0)
   509		case a > b:
   510			return z.setUint64(1)
   511		case a == b:
   512			return z.setUint64(a)
   513		case a+1 == b:
   514			return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   515		}
   516		m := (a + b) / 2
   517		return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   518	}
   519	
   520	// q = (x-r)/y, with 0 <= r < y
   521	func (z nat) divW(x nat, y Word) (q nat, r Word) {
   522		m := len(x)
   523		switch {
   524		case y == 0:
   525			panic("division by zero")
   526		case y == 1:
   527			q = z.set(x) // result is x
   528			return
   529		case m == 0:
   530			q = z[:0] // result is 0
   531			return
   532		}
   533		// m > 0
   534		z = z.make(m)
   535		r = divWVW(z, 0, x, y)
   536		q = z.norm()
   537		return
   538	}
   539	
   540	func (z nat) div(z2, u, v nat) (q, r nat) {
   541		if len(v) == 0 {
   542			panic("division by zero")
   543		}
   544	
   545		if u.cmp(v) < 0 {
   546			q = z[:0]
   547			r = z2.set(u)
   548			return
   549		}
   550	
   551		if len(v) == 1 {
   552			var r2 Word
   553			q, r2 = z.divW(u, v[0])
   554			r = z2.setWord(r2)
   555			return
   556		}
   557	
   558		q, r = z.divLarge(z2, u, v)
   559		return
   560	}
   561	
// divLarge computes q = (uIn-r)/v, with 0 <= r < v.
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv holds the product qhat*v computed in step D4
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1. Shift v left so that its top word has no leading zero bits,
	// and shift the dividend (copied into u) by the same amount.
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2. Produce the quotient words from most to least significant.
	for j := m; j >= 0; j-- {
		// D3. Estimate the quotient word qhat.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4. Multiply and subtract: u[j:j+n+1] -= qhat*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so qhat was one too large:
			// add v back and decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// undo the shift from step D1 to obtain the remainder
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
   639	
   640	// Length of x in bits. x must be normalized.
   641	func (x nat) bitLen() int {
   642		if i := len(x) - 1; i >= 0 {
   643			return i*_W + bitLen(x[i])
   644		}
   645		return 0
   646	}
   647	
// deBruijn32 is the multiplier used by trailingZeroBits for 32-bit words;
// the top 5 bits of (x&-x)*deBruijn32 index deBruijn32Lookup.
const deBruijn32 = 0x077CB531

// deBruijn32Lookup is the 32-entry table trailingZeroBits indexes with those
// top 5 bits to obtain its result.
var deBruijn32Lookup = []byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// deBruijn64 is the multiplier used by trailingZeroBits for 64-bit words;
// the top 6 bits of (x&-x)*deBruijn64 index deBruijn64Lookup.
const deBruijn64 = 0x03f79d71b4ca8b09

// deBruijn64Lookup is the 64-entry table trailingZeroBits indexes with those
// top 6 bits to obtain its result.
var deBruijn64Lookup = []byte{
	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
   663	
   664	// trailingZeroBits returns the number of consecutive least significant zero
   665	// bits of x.
   666	func trailingZeroBits(x Word) uint {
   667		// x & -x leaves only the right-most bit set in the word. Let k be the
   668		// index of that bit. Since only a single bit is set, the value is two
   669		// to the power of k. Multiplying by a power of two is equivalent to
   670		// left shifting, in this case by k bits.  The de Bruijn constant is
   671		// such that all six bit, consecutive substrings are distinct.
   672		// Therefore, if we have a left shifted version of this constant we can
   673		// find by how many bits it was shifted by looking at which six bit
   674		// substring ended up at the top of the word.
   675		// (Knuth, volume 4, section 7.3.1)
   676		switch _W {
   677		case 32:
   678			return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   679		case 64:
   680			return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   681		default:
   682			panic("unknown word size")
   683		}
   684	}
   685	
   686	// trailingZeroBits returns the number of consecutive least significant zero
   687	// bits of x.
   688	func (x nat) trailingZeroBits() uint {
   689		if len(x) == 0 {
   690			return 0
   691		}
   692		var i uint
   693		for x[i] == 0 {
   694			i++
   695		}
   696		// x[i] != 0
   697		return i*_W + trailingZeroBits(x[i])
   698	}
   699	
   700	// z = x << s
   701	func (z nat) shl(x nat, s uint) nat {
   702		m := len(x)
   703		if m == 0 {
   704			return z[:0]
   705		}
   706		// m > 0
   707	
   708		n := m + int(s/_W)
   709		z = z.make(n + 1)
   710		z[n] = shlVU(z[n-m:n], x, s%_W)
   711		z[0 : n-m].clear()
   712	
   713		return z.norm()
   714	}
   715	
   716	// z = x >> s
   717	func (z nat) shr(x nat, s uint) nat {
   718		m := len(x)
   719		n := m - int(s/_W)
   720		if n <= 0 {
   721			return z[:0]
   722		}
   723		// n > 0
   724	
   725		z = z.make(n)
   726		shrVU(z, x[m-n:], s%_W)
   727	
   728		return z.norm()
   729	}
   730	
   731	func (z nat) setBit(x nat, i uint, b uint) nat {
   732		j := int(i / _W)
   733		m := Word(1) << (i % _W)
   734		n := len(x)
   735		switch b {
   736		case 0:
   737			z = z.make(n)
   738			copy(z, x)
   739			if j >= n {
   740				// no need to grow
   741				return z
   742			}
   743			z[j] &^= m
   744			return z.norm()
   745		case 1:
   746			if j >= n {
   747				z = z.make(j + 1)
   748				z[n:].clear()
   749			} else {
   750				z = z.make(n)
   751			}
   752			copy(z, x)
   753			z[j] |= m
   754			// no need to normalize
   755			return z
   756		}
   757		panic("set bit is not 0 or 1")
   758	}
   759	
   760	// bit returns the value of the i'th bit, with lsb == bit 0.
   761	func (x nat) bit(i uint) uint {
   762		j := i / _W
   763		if j >= uint(len(x)) {
   764			return 0
   765		}
   766		// 0 <= j < len(x)
   767		return uint(x[j] >> (i % _W) & 1)
   768	}
   769	
   770	// sticky returns 1 if there's a 1 bit within the
   771	// i least significant bits, otherwise it returns 0.
   772	func (x nat) sticky(i uint) uint {
   773		j := i / _W
   774		if j >= uint(len(x)) {
   775			if len(x) == 0 {
   776				return 0
   777			}
   778			return 1
   779		}
   780		// 0 <= j < len(x)
   781		for _, x := range x[:j] {
   782			if x != 0 {
   783				return 1
   784			}
   785		}
   786		if x[j]<<(_W-i%_W) != 0 {
   787			return 1
   788		}
   789		return 0
   790	}
   791	
   792	func (z nat) and(x, y nat) nat {
   793		m := len(x)
   794		n := len(y)
   795		if m > n {
   796			m = n
   797		}
   798		// m <= n
   799	
   800		z = z.make(m)
   801		for i := 0; i < m; i++ {
   802			z[i] = x[i] & y[i]
   803		}
   804	
   805		return z.norm()
   806	}
   807	
   808	func (z nat) andNot(x, y nat) nat {
   809		m := len(x)
   810		n := len(y)
   811		if n > m {
   812			n = m
   813		}
   814		// m >= n
   815	
   816		z = z.make(m)
   817		for i := 0; i < n; i++ {
   818			z[i] = x[i] &^ y[i]
   819		}
   820		copy(z[n:m], x[n:m])
   821	
   822		return z.norm()
   823	}
   824	
   825	func (z nat) or(x, y nat) nat {
   826		m := len(x)
   827		n := len(y)
   828		s := x
   829		if m < n {
   830			n, m = m, n
   831			s = y
   832		}
   833		// m >= n
   834	
   835		z = z.make(m)
   836		for i := 0; i < n; i++ {
   837			z[i] = x[i] | y[i]
   838		}
   839		copy(z[n:m], s[n:m])
   840	
   841		return z.norm()
   842	}
   843	
   844	func (z nat) xor(x, y nat) nat {
   845		m := len(x)
   846		n := len(y)
   847		s := x
   848		if m < n {
   849			n, m = m, n
   850			s = y
   851		}
   852		// m >= n
   853	
   854		z = z.make(m)
   855		for i := 0; i < n; i++ {
   856			z[i] = x[i] ^ y[i]
   857		}
   858		copy(z[n:m], s[n:m])
   859	
   860		return z.norm()
   861	}
   862	
   863	// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   864	func greaterThan(x1, x2, y1, y2 Word) bool {
   865		return x1 > y1 || x1 == y1 && x2 > y2
   866	}
   867	
   868	// modW returns x % d.
   869	func (x nat) modW(d Word) (r Word) {
   870		// TODO(agl): we don't actually need to store the q value.
   871		var q nat
   872		q = q.make(len(x))
   873		return divWVW(q, 0, x, d)
   874	}
   875	
   876	// random creates a random integer in [0..limit), using the space in z if
   877	// possible. n is the bit length of limit.
   878	func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   879		if alias(z, limit) {
   880			z = nil // z is an alias for limit - cannot reuse
   881		}
   882		z = z.make(len(limit))
   883	
   884		bitLengthOfMSW := uint(n % _W)
   885		if bitLengthOfMSW == 0 {
   886			bitLengthOfMSW = _W
   887		}
   888		mask := Word((1 << bitLengthOfMSW) - 1)
   889	
   890		for {
   891			switch _W {
   892			case 32:
   893				for i := range z {
   894					z[i] = Word(rand.Uint32())
   895				}
   896			case 64:
   897				for i := range z {
   898					z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   899				}
   900			default:
   901				panic("unknown word size")
   902			}
   903			z[len(limit)-1] &= mask
   904			if z.cmp(limit) < 0 {
   905				break
   906			}
   907		}
   908	
   909		return z.norm()
   910	}
   911	
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
// Multi-word x and y with a non-zero modulus are delegated to
// expNNMontgomery (odd m) or expNNWindowed (even m); all other cases use
// bit-by-bit square-and-multiply.
func (z nat) expNN(x, y, m nat) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// x**1 mod m == x mod m
	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
		_, z = z.div(z, x, m)
		return z
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))
	}
	z = z.set(x)

	// If the base is non-trivial and the exponent is large, we use
	// 4-bit, windowed exponentiation. This involves precomputing 14 values
	// (x^2...x^15) but then reduces the number of multiply-reduces by a
	// third. Even for a 32-bit exponent, this reduces the number of
	// operations. Uses Montgomery method for odd moduli.
	if len(x) > 1 && len(y) > 1 && len(m) > 0 {
		if m[0]&1 == 1 {
			return z.expNNMontgomery(x, y, m)
		}
		return z.expNNWindowed(x, y, m)
	}

	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	// Skip the leading zero bits and the top 1 bit of the exponent:
	// z already holds x, which accounts for that top bit.
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.mul(z, z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			zz, r = zz.div(r, z, m)
			// After div, zz holds the quotient and r the remainder.
			// Rotate the buffers so that z becomes the remainder and
			// the other slices remain available as scratch space.
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Process the remaining (full) words of the exponent, most
	// significant first.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.mul(z, z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				// same buffer rotation as above: z becomes the remainder
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}
  1012	
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// The exponent is consumed four bits at a time, most significant first:
// for each window z is raised to the 16th power (four squarings) and then
// multiplied by the precomputed power of x selected by the window's bits.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		// powers[i] = powers[i/2]**2 mod m
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		// powers[i+1] = powers[i]*x mod m
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// The squarings are skipped for the very first window,
			// where z is still 1.
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain.  Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// multiply by the power selected by the top n bits of yi
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			// move the next window into the top bits
			yi <<= n
		}
	}

	return z.norm()
}
  1075	
  1076	// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
  1077	// Uses Montgomery representation.
  1078	func (z nat) expNNMontgomery(x, y, m nat) nat {
  1079		var zz, one, rr, RR nat
  1080	
  1081		numWords := len(m)
  1082	
  1083		// We want the lengths of x and m to be equal.
  1084		if len(x) > numWords {
  1085			_, rr = rr.div(rr, x, m)
  1086		} else if len(x) < numWords {
  1087			rr = rr.make(numWords)
  1088			rr.clear()
  1089			for i := range x {
  1090				rr[i] = x[i]
  1091			}
  1092		} else {
  1093			rr = x
  1094		}
  1095		x = rr
  1096	
  1097		// Ideally the precomputations would be performed outside, and reused
  1098		// k0 = -mˆ-1 mod 2ˆ_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
  1099		// Iteration for Multiplicative Inverses Modulo Prime Powers".
  1100		k0 := 2 - m[0]
  1101		t := m[0] - 1
  1102		for i := 1; i < _W; i <<= 1 {
  1103			t *= t
  1104			k0 *= (t + 1)
  1105		}
  1106		k0 = -k0
  1107	
  1108		// RR = 2ˆ(2*_W*len(m)) mod m
  1109		RR = RR.setWord(1)
  1110		zz = zz.shl(RR, uint(2*numWords*_W))
  1111		_, RR = RR.div(RR, zz, m)
  1112		if len(RR) < numWords {
  1113			zz = zz.make(numWords)
  1114			copy(zz, RR)
  1115			RR = zz
  1116		}
  1117		// one = 1, with equal length to that of m
  1118		one = one.make(numWords)
  1119		one.clear()
  1120		one[0] = 1
  1121	
  1122		const n = 4
  1123		// powers[i] contains x^i
  1124		var powers [1 << n]nat
  1125		powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
  1126		powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
  1127		for i := 2; i < 1<<n; i++ {
  1128			powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
  1129		}
  1130	
  1131		// initialize z = 1 (Montgomery 1)
  1132		z = z.make(numWords)
  1133		copy(z, powers[0])
  1134	
  1135		zz = zz.make(numWords)
  1136	
  1137		// same windowed exponent, but with Montgomery multiplications
  1138		for i := len(y) - 1; i >= 0; i-- {
  1139			yi := y[i]
  1140			for j := 0; j < _W; j += n {
  1141				if i != len(y)-1 || j != 0 {
  1142					zz = zz.montgomery(z, z, m, k0, numWords)
  1143					z = z.montgomery(zz, zz, m, k0, numWords)
  1144					zz = zz.montgomery(z, z, m, k0, numWords)
  1145					z = z.montgomery(zz, zz, m, k0, numWords)
  1146				}
  1147				zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
  1148				z, zz = zz, z
  1149				yi <<= n
  1150			}
  1151		}
  1152		// convert to regular number
  1153		zz = zz.montgomery(z, one, m, k0, numWords)
  1154	
  1155		// One last reduction, just in case.
  1156		// See golang.org/issue/13907.
  1157		if zz.cmp(m) >= 0 {
  1158			// Common case is m has high bit set; in that case,
  1159			// since zz is the same length as m, there can be just
  1160			// one multiple of m to remove. Just subtract.
  1161			// We think that the subtract should be sufficient in general,
  1162			// so do that unconditionally, but double-check,
  1163			// in case our beliefs are wrong.
  1164			// The div is not expected to be reached.
  1165			zz = zz.sub(zz, m)
  1166			if zz.cmp(m) >= 0 {
  1167				_, zz = nat(nil).div(nil, zz, m)
  1168			}
  1169		}
  1170	
  1171		return zz.norm()
  1172	}
  1173	
  1174	// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
  1175	// If it returns true, n is prime with probability 1 - 1/4^reps.
  1176	// If it returns false, n is not prime.
  1177	func (n nat) probablyPrime(reps int) bool {
  1178		if len(n) == 0 {
  1179			return false
  1180		}
  1181	
  1182		if len(n) == 1 {
  1183			if n[0] < 2 {
  1184				return false
  1185			}
  1186	
  1187			if n[0]%2 == 0 {
  1188				return n[0] == 2
  1189			}
  1190	
  1191			// We have to exclude these cases because we reject all
  1192			// multiples of these numbers below.
  1193			switch n[0] {
  1194			case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1195				return true
  1196			}
  1197		}
  1198	
  1199		if n[0]&1 == 0 {
  1200			return false // n is even
  1201		}
  1202	
  1203		const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1204		const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1205	
  1206		var r Word
  1207		switch _W {
  1208		case 32:
  1209			r = n.modW(primesProduct32)
  1210		case 64:
  1211			r = n.modW(primesProduct64 & _M)
  1212		default:
  1213			panic("Unknown word size")
  1214		}
  1215	
  1216		if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1217			r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1218			return false
  1219		}
  1220	
  1221		if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1222			r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1223			return false
  1224		}
  1225	
  1226		nm1 := nat(nil).sub(n, natOne)
  1227		// determine q, k such that nm1 = q << k
  1228		k := nm1.trailingZeroBits()
  1229		q := nat(nil).shr(nm1, k)
  1230	
  1231		nm3 := nat(nil).sub(nm1, natTwo)
  1232		rand := rand.New(rand.NewSource(int64(n[0])))
  1233	
  1234		var x, y, quotient nat
  1235		nm3Len := nm3.bitLen()
  1236	
  1237	NextRandom:
  1238		for i := 0; i < reps; i++ {
  1239			x = x.random(rand, nm3, nm3Len)
  1240			x = x.add(x, natTwo)
  1241			y = y.expNN(x, q, n)
  1242			if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1243				continue
  1244			}
  1245			for j := uint(1); j < k; j++ {
  1246				y = y.mul(y, y)
  1247				quotient, y = quotient.div(y, y, n)
  1248				if y.cmp(nm1) == 0 {
  1249					continue NextRandom
  1250				}
  1251				if y.cmp(natOne) == 0 {
  1252					return false
  1253				}
  1254			}
  1255			return false
  1256		}
  1257	
  1258		return true
  1259	}
  1260	
  1261	// bytes writes the value of z into buf using big-endian encoding.
  1262	// len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1263	// slice buf[i:]. The number i of unused bytes at the beginning of
  1264	// buf is returned as result.
  1265	func (z nat) bytes(buf []byte) (i int) {
  1266		i = len(buf)
  1267		for _, d := range z {
  1268			for j := 0; j < _S; j++ {
  1269				i--
  1270				buf[i] = byte(d)
  1271				d >>= 8
  1272			}
  1273		}
  1274	
  1275		for i < len(buf) && buf[i] == 0 {
  1276			i++
  1277		}
  1278	
  1279		return
  1280	}
  1281	
  1282	// setBytes interprets buf as the bytes of a big-endian unsigned
  1283	// integer, sets z to that value, and returns z.
  1284	func (z nat) setBytes(buf []byte) nat {
  1285		z = z.make((len(buf) + _S - 1) / _S)
  1286	
  1287		k := 0
  1288		s := uint(0)
  1289		var d Word
  1290		for i := len(buf); i > 0; i-- {
  1291			d |= Word(buf[i-1]) << s
  1292			if s += 8; s == _S*8 {
  1293				z[k] = d
  1294				k++
  1295				s = 0
  1296				d = 0
  1297			}
  1298		}
  1299		if k < len(z) {
  1300			z[k] = d
  1301		}
  1302	
  1303		return z.norm()
  1304	}
  1305	

View as plain text