...
Run Format

Source file src/math/big/nat.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// This file implements unsigned multi-precision integers (natural
     6	// numbers). They are the building blocks for the implementation
     7	// of signed integers, rationals, and floating-point numbers.
     8	
     9	package big
    10	
    11	import (
    12		"math/rand"
    13		"sync"
    14	)
    15	
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements. The least significant
// digit is stored first (x[0]).
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word
    29	
    30	var (
    31		natOne = nat{1}
    32		natTwo = nat{2}
    33		natTen = nat{10}
    34	)
    35	
    36	func (z nat) clear() {
    37		for i := range z {
    38			z[i] = 0
    39		}
    40	}
    41	
    42	func (z nat) norm() nat {
    43		i := len(z)
    44		for i > 0 && z[i-1] == 0 {
    45			i--
    46		}
    47		return z[0:i]
    48	}
    49	
    50	func (z nat) make(n int) nat {
    51		if n <= cap(z) {
    52			return z[:n] // reuse z
    53		}
    54		// Choosing a good value for e has significant performance impact
    55		// because it increases the chance that a value can be reused.
    56		const e = 4 // extra capacity
    57		return make(nat, n, n+e)
    58	}
    59	
    60	func (z nat) setWord(x Word) nat {
    61		if x == 0 {
    62			return z[:0]
    63		}
    64		z = z.make(1)
    65		z[0] = x
    66		return z
    67	}
    68	
    69	func (z nat) setUint64(x uint64) nat {
    70		// single-digit values
    71		if w := Word(x); uint64(w) == x {
    72			return z.setWord(w)
    73		}
    74	
    75		// compute number of words n required to represent x
    76		n := 0
    77		for t := x; t > 0; t >>= _W {
    78			n++
    79		}
    80	
    81		// split x into n words
    82		z = z.make(n)
    83		for i := range z {
    84			z[i] = Word(x & _M)
    85			x >>= _W
    86		}
    87	
    88		return z
    89	}
    90	
    91	func (z nat) set(x nat) nat {
    92		z = z.make(len(x))
    93		copy(z, x)
    94		return z
    95	}
    96	
    97	func (z nat) add(x, y nat) nat {
    98		m := len(x)
    99		n := len(y)
   100	
   101		switch {
   102		case m < n:
   103			return z.add(y, x)
   104		case m == 0:
   105			// n == 0 because m >= n; result is 0
   106			return z[:0]
   107		case n == 0:
   108			// result is x
   109			return z.set(x)
   110		}
   111		// m > 0
   112	
   113		z = z.make(m + 1)
   114		c := addVV(z[0:n], x, y)
   115		if m > n {
   116			c = addVW(z[n:m], x[n:], c)
   117		}
   118		z[m] = c
   119	
   120		return z.norm()
   121	}
   122	
   123	func (z nat) sub(x, y nat) nat {
   124		m := len(x)
   125		n := len(y)
   126	
   127		switch {
   128		case m < n:
   129			panic("underflow")
   130		case m == 0:
   131			// n == 0 because m >= n; result is 0
   132			return z[:0]
   133		case n == 0:
   134			// result is x
   135			return z.set(x)
   136		}
   137		// m > 0
   138	
   139		z = z.make(m)
   140		c := subVV(z[0:n], x, y)
   141		if m > n {
   142			c = subVW(z[n:], x[n:], c)
   143		}
   144		if c != 0 {
   145			panic("underflow")
   146		}
   147	
   148		return z.norm()
   149	}
   150	
   151	func (x nat) cmp(y nat) (r int) {
   152		m := len(x)
   153		n := len(y)
   154		if m != n || m == 0 {
   155			switch {
   156			case m < n:
   157				r = -1
   158			case m > n:
   159				r = 1
   160			}
   161			return
   162		}
   163	
   164		i := m - 1
   165		for i > 0 && x[i] == y[i] {
   166			i--
   167		}
   168	
   169		switch {
   170		case x[i] < y[i]:
   171			r = -1
   172		case x[i] > y[i]:
   173			r = 1
   174		}
   175		return
   176	}
   177	
   178	func (z nat) mulAddWW(x nat, y, r Word) nat {
   179		m := len(x)
   180		if m == 0 || y == 0 {
   181			return z.setWord(r) // result is r
   182		}
   183		// m > 0
   184	
   185		z = z.make(m + 1)
   186		z[m] = mulAddVWW(z[0:m], x, y, r)
   187	
   188		return z.norm()
   189	}
   190	
   191	// basicMul multiplies x and y and leaves the result in z.
   192	// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   193	func basicMul(z, x, y nat) {
   194		z[0 : len(x)+len(y)].clear() // initialize z
   195		for i, d := range y {
   196			if d != 0 {
   197				z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   198			}
   199		}
   200	}
   201	
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
// It panics if len(x), len(y) or len(m) differs from n.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	// c holds the single carry bit that overflows the n-word accumulator z.
	var c Word
	for i := 0; i < n; i++ {
		d := y[i]
		// z += x*d; c2 is the word carried out of z.
		c2 := addMulVVW(z, x, d)
		// Because k = -1/m mod 2**_W, adding m*t clears the low word of z,
		// so the copy below is an exact division by 2**_W.
		t := z[0] * k
		c3 := addMulVVW(z, m, t)
		copy(z, z[1:])
		// The new top word is c + c2 + c3; a wrap-around in either sum
		// produces the carry bit for the next iteration.
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// A remaining carry means the true value is 2**(n*_W) + z. Subtracting
	// m once brings it back into n words (given the preconditions), and the
	// borrow out of subVV cancels c, which is why it is ignored.
	if c != 0 {
		subVV(z, z, m)
	}
	return z
}
   242	
   243	// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   244	// Factored out for readability - do not use outside karatsuba.
   245	func karatsubaAdd(z, x nat, n int) {
   246		if c := addVV(z[0:n], z, x); c != 0 {
   247			addVW(z[n:n+n>>1], z[n:], c)
   248		}
   249	}
   250	
   251	// Like karatsubaAdd, but does subtract.
   252	func karatsubaSub(z, x nat, n int) {
   253		if c := subVV(z[0:n], z, x); c != 0 {
   254			subVW(z[n:n+n>>1], z[n:], c)
   255		}
   256	}
   257	
   258	// Operands that are shorter than karatsubaThreshold are multiplied using
   259	// "grade school" multiplication; for longer operands the Karatsuba algorithm
   260	// is used.
   261	var karatsubaThreshold int = 40 // computed by calibrate.go
   262	
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n. mul passes n = karatsubaLen(...),
// i.e. p<<i with p <= karatsubaThreshold, so that n can be halved
// repeatedly; any odd or small length is handled by basicMul.
// The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   361	
   362	// alias reports whether x and y share the same base array.
   363	func alias(x, y nat) bool {
   364		return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   365	}
   366	
   367	// addAt implements z += x<<(_W*i); z must be long enough.
   368	// (we don't use nat.add because we need z to stay the same
   369	// slice, and we don't need to normalize z after each addition)
   370	func addAt(z, x nat, i int) {
   371		if n := len(x); n > 0 {
   372			if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   373				j := i + n
   374				if j < len(z) {
   375					addVW(z[j:], z[j:], c)
   376				}
   377			}
   378		}
   379	}
   380	
   381	func max(x, y int) int {
   382		if x > y {
   383			return x
   384		}
   385		return y
   386	}
   387	
   388	// karatsubaLen computes an approximation to the maximum k <= n such that
   389	// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   390	// result is the largest number that can be divided repeatedly by 2 before
   391	// becoming about the value of karatsubaThreshold.
   392	func karatsubaLen(n int) int {
   393		i := uint(0)
   394		for n > karatsubaThreshold {
   395			n >>= 1
   396			i++
   397		}
   398		return n << i
   399	}
   400	
// mul sets z to x*y and returns the normalized product. z's storage is
// reused unless it aliases x or y. Short operands use basicMul; longer ones
// multiply the low k words with karatsuba and add the remaining partial
// products with addAt.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is scratch storage for each partial product; it is reassigned
		// on every mul so its backing array is reused across iterations.
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0 at word offset i and xi*y1 at word offset i+k,
		// walking x in chunks of k words
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   484	
   485	// mulRange computes the product of all the unsigned integers in the
   486	// range [a, b] inclusively. If a > b (empty range), the result is 1.
   487	func (z nat) mulRange(a, b uint64) nat {
   488		switch {
   489		case a == 0:
   490			// cut long ranges short (optimization)
   491			return z.setUint64(0)
   492		case a > b:
   493			return z.setUint64(1)
   494		case a == b:
   495			return z.setUint64(a)
   496		case a+1 == b:
   497			return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   498		}
   499		m := (a + b) / 2
   500		return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   501	}
   502	
   503	// q = (x-r)/y, with 0 <= r < y
   504	func (z nat) divW(x nat, y Word) (q nat, r Word) {
   505		m := len(x)
   506		switch {
   507		case y == 0:
   508			panic("division by zero")
   509		case y == 1:
   510			q = z.set(x) // result is x
   511			return
   512		case m == 0:
   513			q = z[:0] // result is 0
   514			return
   515		}
   516		// m > 0
   517		z = z.make(m)
   518		r = divWVW(z, 0, x, y)
   519		q = z.norm()
   520		return
   521	}
   522	
   523	func (z nat) div(z2, u, v nat) (q, r nat) {
   524		if len(v) == 0 {
   525			panic("division by zero")
   526		}
   527	
   528		if u.cmp(v) < 0 {
   529			q = z[:0]
   530			r = z2.set(u)
   531			return
   532		}
   533	
   534		if len(v) == 1 {
   535			var r2 Word
   536			q, r2 = z.divW(u, v[0])
   537			r = z2.setWord(r2)
   538			return
   539		}
   540	
   541		q, r = z.divLarge(z2, u, v)
   542		return
   543	}
   544	
   545	// getNat returns a *nat of len n. The contents may not be zero.
   546	// The pool holds *nat to avoid allocation when converting to interface{}.
   547	func getNat(n int) *nat {
   548		var z *nat
   549		if v := natPool.Get(); v != nil {
   550			z = v.(*nat)
   551		}
   552		if z == nil {
   553			z = new(nat)
   554		}
   555		*z = z.make(n)
   556		return z
   557	}
   558	
   559	func putNat(x *nat) {
   560		natPool.Put(x)
   561	}
   562	
   563	var natPool sync.Pool
   564	
// divLarge computes q = (uIn-r)/v, with 0 <= r < v, i.e. the quotient and
// remainder of uIn divided by the multi-word divisor v.
// Uses z as storage for q, and u as storage for r if possible.
// z and u are each discarded if they alias uIn or v, but they are not
// checked against each other: both are resliced from index 0, so callers
// must not pass storage for q and r that shares a backing array.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv holds the product qhat*v (n+1 words) in each step.
	qhatvp := getNat(n + 1)
	qhatv := *qhatvp
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1. Normalize: shift v left so its top bit is set, and shift uIn by
	// the same amount into u (one extra word for the bits shifted out).
	var v1p *nat
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1p = getNat(n)
		v1 := *v1p
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2. Loop over the quotient digits from most to least significant.
	vn1 := v[n-1]
	for j := m; j >= 0; j-- {
		// D3. Estimate the quotient digit qhat from the top words of u.
		qhat := Word(_M)
		if ujn := u[j+n]; ujn != vn1 {
			var rhat Word
			qhat, rhat = divWW(ujn, u[j+n-1], vn1)

			// x1 | x2 = q̂v_{n-2}
			vn2 := v[n-2]
			x1, x2 := mulWW(qhat, vn2)
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			ujn2 := u[j+n-2]
			for greaterThan(x1, x2, rhat, ujn2) {
				qhat--
				prevRhat := rhat
				rhat += vn1
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, vn2)
			}
		}

		// D4. Multiply and subtract qhat*v from u[j : j+n+1]; if that
		// goes negative, qhat was one too large: add v back once.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}
	if v1p != nil {
		putNat(v1p)
	}
	putNat(qhatvp)

	// Undo the normalization shift to obtain the remainder.
	q = q.norm()
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
   652	
   653	// Length of x in bits. x must be normalized.
   654	func (x nat) bitLen() int {
   655		if i := len(x) - 1; i >= 0 {
   656			return i*_W + bitLen(x[i])
   657		}
   658		return 0
   659	}
   660	
// deBruijn32 and deBruijn32Lookup are used by trailingZeroBits for 32-bit
// words: multiplying deBruijn32 by 1<<k and taking the top 5 bits of the
// 32-bit product gives an index whose deBruijn32Lookup entry is k.
const deBruijn32 = 0x077CB531

var deBruijn32Lookup = [...]byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// deBruijn64 and deBruijn64Lookup play the same role for 64-bit words,
// using the top 6 bits of the 64-bit product.
const deBruijn64 = 0x03f79d71b4ca8b09

var deBruijn64Lookup = [...]byte{
	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
   676	
   677	// trailingZeroBits returns the number of consecutive least significant zero
   678	// bits of x.
   679	func trailingZeroBits(x Word) uint {
   680		// x & -x leaves only the right-most bit set in the word. Let k be the
   681		// index of that bit. Since only a single bit is set, the value is two
   682		// to the power of k. Multiplying by a power of two is equivalent to
   683		// left shifting, in this case by k bits. The de Bruijn constant is
   684		// such that all six bit, consecutive substrings are distinct.
   685		// Therefore, if we have a left shifted version of this constant we can
   686		// find by how many bits it was shifted by looking at which six bit
   687		// substring ended up at the top of the word.
   688		// (Knuth, volume 4, section 7.3.1)
   689		switch _W {
   690		case 32:
   691			return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   692		case 64:
   693			return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   694		default:
   695			panic("unknown word size")
   696		}
   697	}
   698	
   699	// trailingZeroBits returns the number of consecutive least significant zero
   700	// bits of x.
   701	func (x nat) trailingZeroBits() uint {
   702		if len(x) == 0 {
   703			return 0
   704		}
   705		var i uint
   706		for x[i] == 0 {
   707			i++
   708		}
   709		// x[i] != 0
   710		return i*_W + trailingZeroBits(x[i])
   711	}
   712	
   713	// z = x << s
   714	func (z nat) shl(x nat, s uint) nat {
   715		m := len(x)
   716		if m == 0 {
   717			return z[:0]
   718		}
   719		// m > 0
   720	
   721		n := m + int(s/_W)
   722		z = z.make(n + 1)
   723		z[n] = shlVU(z[n-m:n], x, s%_W)
   724		z[0 : n-m].clear()
   725	
   726		return z.norm()
   727	}
   728	
   729	// z = x >> s
   730	func (z nat) shr(x nat, s uint) nat {
   731		m := len(x)
   732		n := m - int(s/_W)
   733		if n <= 0 {
   734			return z[:0]
   735		}
   736		// n > 0
   737	
   738		z = z.make(n)
   739		shrVU(z, x[m-n:], s%_W)
   740	
   741		return z.norm()
   742	}
   743	
   744	func (z nat) setBit(x nat, i uint, b uint) nat {
   745		j := int(i / _W)
   746		m := Word(1) << (i % _W)
   747		n := len(x)
   748		switch b {
   749		case 0:
   750			z = z.make(n)
   751			copy(z, x)
   752			if j >= n {
   753				// no need to grow
   754				return z
   755			}
   756			z[j] &^= m
   757			return z.norm()
   758		case 1:
   759			if j >= n {
   760				z = z.make(j + 1)
   761				z[n:].clear()
   762			} else {
   763				z = z.make(n)
   764			}
   765			copy(z, x)
   766			z[j] |= m
   767			// no need to normalize
   768			return z
   769		}
   770		panic("set bit is not 0 or 1")
   771	}
   772	
   773	// bit returns the value of the i'th bit, with lsb == bit 0.
   774	func (x nat) bit(i uint) uint {
   775		j := i / _W
   776		if j >= uint(len(x)) {
   777			return 0
   778		}
   779		// 0 <= j < len(x)
   780		return uint(x[j] >> (i % _W) & 1)
   781	}
   782	
   783	// sticky returns 1 if there's a 1 bit within the
   784	// i least significant bits, otherwise it returns 0.
   785	func (x nat) sticky(i uint) uint {
   786		j := i / _W
   787		if j >= uint(len(x)) {
   788			if len(x) == 0 {
   789				return 0
   790			}
   791			return 1
   792		}
   793		// 0 <= j < len(x)
   794		for _, x := range x[:j] {
   795			if x != 0 {
   796				return 1
   797			}
   798		}
   799		if x[j]<<(_W-i%_W) != 0 {
   800			return 1
   801		}
   802		return 0
   803	}
   804	
   805	func (z nat) and(x, y nat) nat {
   806		m := len(x)
   807		n := len(y)
   808		if m > n {
   809			m = n
   810		}
   811		// m <= n
   812	
   813		z = z.make(m)
   814		for i := 0; i < m; i++ {
   815			z[i] = x[i] & y[i]
   816		}
   817	
   818		return z.norm()
   819	}
   820	
   821	func (z nat) andNot(x, y nat) nat {
   822		m := len(x)
   823		n := len(y)
   824		if n > m {
   825			n = m
   826		}
   827		// m >= n
   828	
   829		z = z.make(m)
   830		for i := 0; i < n; i++ {
   831			z[i] = x[i] &^ y[i]
   832		}
   833		copy(z[n:m], x[n:m])
   834	
   835		return z.norm()
   836	}
   837	
   838	func (z nat) or(x, y nat) nat {
   839		m := len(x)
   840		n := len(y)
   841		s := x
   842		if m < n {
   843			n, m = m, n
   844			s = y
   845		}
   846		// m >= n
   847	
   848		z = z.make(m)
   849		for i := 0; i < n; i++ {
   850			z[i] = x[i] | y[i]
   851		}
   852		copy(z[n:m], s[n:m])
   853	
   854		return z.norm()
   855	}
   856	
   857	func (z nat) xor(x, y nat) nat {
   858		m := len(x)
   859		n := len(y)
   860		s := x
   861		if m < n {
   862			n, m = m, n
   863			s = y
   864		}
   865		// m >= n
   866	
   867		z = z.make(m)
   868		for i := 0; i < n; i++ {
   869			z[i] = x[i] ^ y[i]
   870		}
   871		copy(z[n:m], s[n:m])
   872	
   873		return z.norm()
   874	}
   875	
   876	// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   877	func greaterThan(x1, x2, y1, y2 Word) bool {
   878		return x1 > y1 || x1 == y1 && x2 > y2
   879	}
   880	
   881	// modW returns x % d.
   882	func (x nat) modW(d Word) (r Word) {
   883		// TODO(agl): we don't actually need to store the q value.
   884		var q nat
   885		q = q.make(len(x))
   886		return divWVW(q, 0, x, d)
   887	}
   888	
   889	// random creates a random integer in [0..limit), using the space in z if
   890	// possible. n is the bit length of limit.
   891	func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   892		if alias(z, limit) {
   893			z = nil // z is an alias for limit - cannot reuse
   894		}
   895		z = z.make(len(limit))
   896	
   897		bitLengthOfMSW := uint(n % _W)
   898		if bitLengthOfMSW == 0 {
   899			bitLengthOfMSW = _W
   900		}
   901		mask := Word((1 << bitLengthOfMSW) - 1)
   902	
   903		for {
   904			switch _W {
   905			case 32:
   906				for i := range z {
   907					z[i] = Word(rand.Uint32())
   908				}
   909			case 64:
   910				for i := range z {
   911					z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   912				}
   913			default:
   914				panic("unknown word size")
   915			}
   916			z[len(limit)-1] &= mask
   917			if z.cmp(limit) < 0 {
   918				break
   919			}
   920		}
   921	
   922		return z.norm()
   923	}
   924	
   925	// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   926	// otherwise it sets z to x**y. The result is the value of z.
   927	func (z nat) expNN(x, y, m nat) nat {
   928		if alias(z, x) || alias(z, y) {
   929			// We cannot allow in-place modification of x or y.
   930			z = nil
   931		}
   932	
   933		// x**y mod 1 == 0
   934		if len(m) == 1 && m[0] == 1 {
   935			return z.setWord(0)
   936		}
   937		// m == 0 || m > 1
   938	
   939		// x**0 == 1
   940		if len(y) == 0 {
   941			return z.setWord(1)
   942		}
   943		// y > 0
   944	
   945		// x**1 mod m == x mod m
   946		if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   947			_, z = z.div(z, x, m)
   948			return z
   949		}
   950		// y > 1
   951	
   952		if len(m) != 0 {
   953			// We likely end up being as long as the modulus.
   954			z = z.make(len(m))
   955		}
   956		z = z.set(x)
   957	
   958		// If the base is non-trivial and the exponent is large, we use
   959		// 4-bit, windowed exponentiation. This involves precomputing 14 values
   960		// (x^2...x^15) but then reduces the number of multiply-reduces by a
   961		// third. Even for a 32-bit exponent, this reduces the number of
   962		// operations. Uses Montgomery method for odd moduli.
   963		if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
   964			if m[0]&1 == 1 {
   965				return z.expNNMontgomery(x, y, m)
   966			}
   967			return z.expNNWindowed(x, y, m)
   968		}
   969	
   970		v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   971		shift := nlz(v) + 1
   972		v <<= shift
   973		var q nat
   974	
   975		const mask = 1 << (_W - 1)
   976	
   977		// We walk through the bits of the exponent one by one. Each time we
   978		// see a bit, we square, thus doubling the power. If the bit is a one,
   979		// we also multiply by x, thus adding one to the power.
   980	
   981		w := _W - int(shift)
   982		// zz and r are used to avoid allocating in mul and div as
   983		// otherwise the arguments would alias.
   984		var zz, r nat
   985		for j := 0; j < w; j++ {
   986			zz = zz.mul(z, z)
   987			zz, z = z, zz
   988	
   989			if v&mask != 0 {
   990				zz = zz.mul(z, x)
   991				zz, z = z, zz
   992			}
   993	
   994			if len(m) != 0 {
   995				zz, r = zz.div(r, z, m)
   996				zz, r, q, z = q, z, zz, r
   997			}
   998	
   999			v <<= 1
  1000		}
  1001	
  1002		for i := len(y) - 2; i >= 0; i-- {
  1003			v = y[i]
  1004	
  1005			for j := 0; j < _W; j++ {
  1006				zz = zz.mul(z, z)
  1007				zz, z = z, zz
  1008	
  1009				if v&mask != 0 {
  1010					zz = zz.mul(z, x)
  1011					zz, z = z, zz
  1012				}
  1013	
  1014				if len(m) != 0 {
  1015					zz, r = zz.div(r, z, m)
  1016					zz, r, q, z = q, z, zz, r
  1017				}
  1018	
  1019				v <<= 1
  1020			}
  1021		}
  1022	
  1023		return z.norm()
  1024	}
  1025	
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// m must be non-empty: every intermediate result is reduced mod m.
// The exponent is scanned from its most significant word downwards,
// n bits at a time. Each window squares z n times and then multiplies by
// the precomputed power selected by the window's bits. The very first
// window skips the squarings because z is still 1.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i (reduced mod m for i >= 2; powers[1] is x itself).
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// Multiply by x^w, where w is the top n bits of yi.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
  1088	
  1089	// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
  1090	// Uses Montgomery representation.
  1091	func (z nat) expNNMontgomery(x, y, m nat) nat {
  1092		numWords := len(m)
  1093	
  1094		// We want the lengths of x and m to be equal.
  1095		// It is OK if x >= m as long as len(x) == len(m).
  1096		if len(x) > numWords {
  1097			_, x = nat(nil).div(nil, x, m)
  1098			// Note: now len(x) <= numWords, not guaranteed ==.
  1099		}
  1100		if len(x) < numWords {
  1101			rr := make(nat, numWords)
  1102			copy(rr, x)
  1103			x = rr
  1104		}
  1105	
  1106		// Ideally the precomputations would be performed outside, and reused
  1107		// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
  1108		// Iteration for Multiplicative Inverses Modulo Prime Powers".
  1109		k0 := 2 - m[0]
  1110		t := m[0] - 1
  1111		for i := 1; i < _W; i <<= 1 {
  1112			t *= t
  1113			k0 *= (t + 1)
  1114		}
  1115		k0 = -k0
  1116	
  1117		// RR = 2**(2*_W*len(m)) mod m
  1118		RR := nat(nil).setWord(1)
  1119		zz := nat(nil).shl(RR, uint(2*numWords*_W))
  1120		_, RR = RR.div(RR, zz, m)
  1121		if len(RR) < numWords {
  1122			zz = zz.make(numWords)
  1123			copy(zz, RR)
  1124			RR = zz
  1125		}
  1126		// one = 1, with equal length to that of m
  1127		one := make(nat, numWords)
  1128		one[0] = 1
  1129	
  1130		const n = 4
  1131		// powers[i] contains x^i
  1132		var powers [1 << n]nat
  1133		powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
  1134		powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
  1135		for i := 2; i < 1<<n; i++ {
  1136			powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
  1137		}
  1138	
  1139		// initialize z = 1 (Montgomery 1)
  1140		z = z.make(numWords)
  1141		copy(z, powers[0])
  1142	
  1143		zz = zz.make(numWords)
  1144	
  1145		// same windowed exponent, but with Montgomery multiplications
  1146		for i := len(y) - 1; i >= 0; i-- {
  1147			yi := y[i]
  1148			for j := 0; j < _W; j += n {
  1149				if i != len(y)-1 || j != 0 {
  1150					zz = zz.montgomery(z, z, m, k0, numWords)
  1151					z = z.montgomery(zz, zz, m, k0, numWords)
  1152					zz = zz.montgomery(z, z, m, k0, numWords)
  1153					z = z.montgomery(zz, zz, m, k0, numWords)
  1154				}
  1155				zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
  1156				z, zz = zz, z
  1157				yi <<= n
  1158			}
  1159		}
  1160		// convert to regular number
  1161		zz = zz.montgomery(z, one, m, k0, numWords)
  1162	
  1163		// One last reduction, just in case.
  1164		// See golang.org/issue/13907.
  1165		if zz.cmp(m) >= 0 {
  1166			// Common case is m has high bit set; in that case,
  1167			// since zz is the same length as m, there can be just
  1168			// one multiple of m to remove. Just subtract.
  1169			// We think that the subtract should be sufficient in general,
  1170			// so do that unconditionally, but double-check,
  1171			// in case our beliefs are wrong.
  1172			// The div is not expected to be reached.
  1173			zz = zz.sub(zz, m)
  1174			if zz.cmp(m) >= 0 {
  1175				_, zz = nat(nil).div(nil, zz, m)
  1176			}
  1177		}
  1178	
  1179		return zz.norm()
  1180	}
  1181	
  1182	// bytes writes the value of z into buf using big-endian encoding.
  1183	// len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1184	// slice buf[i:]. The number i of unused bytes at the beginning of
  1185	// buf is returned as result.
  1186	func (z nat) bytes(buf []byte) (i int) {
  1187		i = len(buf)
  1188		for _, d := range z {
  1189			for j := 0; j < _S; j++ {
  1190				i--
  1191				buf[i] = byte(d)
  1192				d >>= 8
  1193			}
  1194		}
  1195	
  1196		for i < len(buf) && buf[i] == 0 {
  1197			i++
  1198		}
  1199	
  1200		return
  1201	}
  1202	
  1203	// setBytes interprets buf as the bytes of a big-endian unsigned
  1204	// integer, sets z to that value, and returns z.
  1205	func (z nat) setBytes(buf []byte) nat {
  1206		z = z.make((len(buf) + _S - 1) / _S)
  1207	
  1208		k := 0
  1209		s := uint(0)
  1210		var d Word
  1211		for i := len(buf); i > 0; i-- {
  1212			d |= Word(buf[i-1]) << s
  1213			if s += 8; s == _S*8 {
  1214				z[k] = d
  1215				k++
  1216				s = 0
  1217				d = 0
  1218			}
  1219		}
  1220		if k < len(z) {
  1221			z[k] = d
  1222		}
  1223	
  1224		return z.norm()
  1225	}
  1226	
  1227	// sqrt sets z = ⌊√x⌋
  1228	func (z nat) sqrt(x nat) nat {
  1229		if x.cmp(natOne) <= 0 {
  1230			return z.set(x)
  1231		}
  1232		if alias(z, x) {
  1233			z = nil
  1234		}
  1235	
  1236		// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
  1237		// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
  1238		// https://members.loria.fr/PZimmermann/mca/pub226.html
  1239		// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
  1240		// otherwise it converges to the correct z and stays there.
  1241		var z1, z2 nat
  1242		z1 = z
  1243		z1 = z1.setUint64(1)
  1244		z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x
  1245		for n := 0; ; n++ {
  1246			z2, _ = z2.div(nil, x, z1)
  1247			z2 = z2.add(z2, z1)
  1248			z2 = z2.shr(z2, 1)
  1249			if z2.cmp(z1) >= 0 {
  1250				// z1 is answer.
  1251				// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
  1252				if n&1 == 0 {
  1253					return z1
  1254				}
  1255				return z.set(z1)
  1256			}
  1257			z1, z2 = z2, z1
  1258		}
  1259	}
  1260	

View as plain text