...
Run Format

Source file src/math/big/nat.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	// Package big implements multi-precision arithmetic (big numbers).
     6	// The following numeric types are supported:
     7	//
     8	//   Int    signed integers
     9	//   Rat    rational numbers
    10	//   Float  floating-point numbers
    11	//
    12	// Methods are typically of the form:
    13	//
    14	//   func (z *T) Unary(x *T) *T        // z = op x
    15	//   func (z *T) Binary(x, y *T) *T    // z = x op y
    16	//   func (x *T) M() T1                // v = x.M()
    17	//
    18	// with T one of Int, Rat, or Float. For unary and binary operations, the
    19	// result is the receiver (usually named z in that case); if it is one of
    20	// the operands x or y it may be overwritten (and its memory reused).
    21	// To enable chaining of operations, the result is also returned. Methods
    22	// returning a result other than *Int, *Rat, or *Float take an operand as
    23	// the receiver (usually named x in that case).
    24	//
    25	package big
    26	
    27	// This file contains operations on unsigned multi-precision integers.
    28	// These are the building blocks for the operations on signed integers
    29	// and rationals.
    30	
    31	import "math/rand"
    32	
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements (x[0] is the least
// significant digit).
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word
    46	
// Frequently used small values as single-digit nats.
// NOTE(review): these slices are shared package-wide; presumably they must
// never be used as a result receiver or otherwise modified - confirm.
var (
	natOne = nat{1}
	natTwo = nat{2}
	natTen = nat{10}
)
    52	
    53	func (z nat) clear() {
    54		for i := range z {
    55			z[i] = 0
    56		}
    57	}
    58	
    59	func (z nat) norm() nat {
    60		i := len(z)
    61		for i > 0 && z[i-1] == 0 {
    62			i--
    63		}
    64		return z[0:i]
    65	}
    66	
    67	func (z nat) make(n int) nat {
    68		if n <= cap(z) {
    69			return z[:n] // reuse z
    70		}
    71		// Choosing a good value for e has significant performance impact
    72		// because it increases the chance that a value can be reused.
    73		const e = 4 // extra capacity
    74		return make(nat, n, n+e)
    75	}
    76	
    77	func (z nat) setWord(x Word) nat {
    78		if x == 0 {
    79			return z[:0]
    80		}
    81		z = z.make(1)
    82		z[0] = x
    83		return z
    84	}
    85	
    86	func (z nat) setUint64(x uint64) nat {
    87		// single-digit values
    88		if w := Word(x); uint64(w) == x {
    89			return z.setWord(w)
    90		}
    91	
    92		// compute number of words n required to represent x
    93		n := 0
    94		for t := x; t > 0; t >>= _W {
    95			n++
    96		}
    97	
    98		// split x into n words
    99		z = z.make(n)
   100		for i := range z {
   101			z[i] = Word(x & _M)
   102			x >>= _W
   103		}
   104	
   105		return z
   106	}
   107	
   108	func (z nat) set(x nat) nat {
   109		z = z.make(len(x))
   110		copy(z, x)
   111		return z
   112	}
   113	
   114	func (z nat) add(x, y nat) nat {
   115		m := len(x)
   116		n := len(y)
   117	
   118		switch {
   119		case m < n:
   120			return z.add(y, x)
   121		case m == 0:
   122			// n == 0 because m >= n; result is 0
   123			return z[:0]
   124		case n == 0:
   125			// result is x
   126			return z.set(x)
   127		}
   128		// m > 0
   129	
   130		z = z.make(m + 1)
   131		c := addVV(z[0:n], x, y)
   132		if m > n {
   133			c = addVW(z[n:m], x[n:], c)
   134		}
   135		z[m] = c
   136	
   137		return z.norm()
   138	}
   139	
   140	func (z nat) sub(x, y nat) nat {
   141		m := len(x)
   142		n := len(y)
   143	
   144		switch {
   145		case m < n:
   146			panic("underflow")
   147		case m == 0:
   148			// n == 0 because m >= n; result is 0
   149			return z[:0]
   150		case n == 0:
   151			// result is x
   152			return z.set(x)
   153		}
   154		// m > 0
   155	
   156		z = z.make(m)
   157		c := subVV(z[0:n], x, y)
   158		if m > n {
   159			c = subVW(z[n:], x[n:], c)
   160		}
   161		if c != 0 {
   162			panic("underflow")
   163		}
   164	
   165		return z.norm()
   166	}
   167	
   168	func (x nat) cmp(y nat) (r int) {
   169		m := len(x)
   170		n := len(y)
   171		if m != n || m == 0 {
   172			switch {
   173			case m < n:
   174				r = -1
   175			case m > n:
   176				r = 1
   177			}
   178			return
   179		}
   180	
   181		i := m - 1
   182		for i > 0 && x[i] == y[i] {
   183			i--
   184		}
   185	
   186		switch {
   187		case x[i] < y[i]:
   188			r = -1
   189		case x[i] > y[i]:
   190			r = 1
   191		}
   192		return
   193	}
   194	
   195	func (z nat) mulAddWW(x nat, y, r Word) nat {
   196		m := len(x)
   197		if m == 0 || y == 0 {
   198			return z.setWord(r) // result is r
   199		}
   200		// m > 0
   201	
   202		z = z.make(m + 1)
   203		z[m] = mulAddVWW(z[0:m], x, y, r)
   204	
   205		return z.norm()
   206	}
   207	
   208	// basicMul multiplies x and y and leaves the result in z.
   209	// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   210	func basicMul(z, x, y nat) {
   211		z[0 : len(x)+len(y)].clear() // initialize z
   212		for i, d := range y {
   213			if d != 0 {
   214				z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   215			}
   216		}
   217	}
   218	
   219	// montgomery computes x*y*2^(-n*_W) mod m,
   220	// assuming k = -1/m mod 2^_W.
   221	// z is used for storing the result which is returned;
   222	// z must not alias x, y or m.
   223	func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
   224		var c1, c2 Word
   225		z = z.make(n)
   226		z.clear()
   227		for i := 0; i < n; i++ {
   228			d := y[i]
   229			c1 += addMulVVW(z, x, d)
   230			t := z[0] * k
   231			c2 = addMulVVW(z, m, t)
   232	
   233			copy(z, z[1:])
   234			z[n-1] = c1 + c2
   235			if z[n-1] < c1 {
   236				c1 = 1
   237			} else {
   238				c1 = 0
   239			}
   240		}
   241		if c1 != 0 {
   242			subVV(z, z, m)
   243		}
   244		return z
   245	}
   246	
   247	// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   248	// Factored out for readability - do not use outside karatsuba.
   249	func karatsubaAdd(z, x nat, n int) {
   250		if c := addVV(z[0:n], z, x); c != 0 {
   251			addVW(z[n:n+n>>1], z[n:], c)
   252		}
   253	}
   254	
   255	// Like karatsubaAdd, but does subtract.
   256	func karatsubaSub(z, x nat, n int) {
   257		if c := subVV(z[0:n], z, x); c != 0 {
   258			subVW(z[n:n+n>>1], z[n:], c)
   259		}
   260	}
   261	
// Operands that are shorter than karatsubaThreshold (measured in Words)
// are multiplied using "grade school" multiplication; for longer operands
// the Karatsuba algorithm is used.
var karatsubaThreshold int = 40 // computed by calibrate.go
   266	
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
// The remainder of z is used as scratch space and is overwritten.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// (each addend is n digits long and is added at digit offset n2;
	// p is subtracted instead when the sign s is negative)
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   365	
   366	// alias reports whether x and y share the same base array.
   367	func alias(x, y nat) bool {
   368		return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   369	}
   370	
   371	// addAt implements z += x<<(_W*i); z must be long enough.
   372	// (we don't use nat.add because we need z to stay the same
   373	// slice, and we don't need to normalize z after each addition)
   374	func addAt(z, x nat, i int) {
   375		if n := len(x); n > 0 {
   376			if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   377				j := i + n
   378				if j < len(z) {
   379					addVW(z[j:], z[j:], c)
   380				}
   381			}
   382		}
   383	}
   384	
   385	func max(x, y int) int {
   386		if x > y {
   387			return x
   388		}
   389		return y
   390	}
   391	
   392	// karatsubaLen computes an approximation to the maximum k <= n such that
   393	// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   394	// result is the largest number that can be divided repeatedly by 2 before
   395	// becoming about the value of karatsubaThreshold.
   396	func karatsubaLen(n int) int {
   397		i := uint(0)
   398		for n > karatsubaThreshold {
   399			n >>= 1
   400			i++
   401		}
   402		return n << i
   403	}
   404	
// mul sets z to x*y and returns the normalized result.
// Small operands are multiplied with basicMul; otherwise the low k digits
// of both operands are multiplied with karatsuba and the remaining
// partial products are added via recursive mul calls.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		// make x the longer operand
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		var t nat // scratch product, reused across iterations

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		// (i advances in steps of k digits through the upper part of x)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   488	
   489	// mulRange computes the product of all the unsigned integers in the
   490	// range [a, b] inclusively. If a > b (empty range), the result is 1.
   491	func (z nat) mulRange(a, b uint64) nat {
   492		switch {
   493		case a == 0:
   494			// cut long ranges short (optimization)
   495			return z.setUint64(0)
   496		case a > b:
   497			return z.setUint64(1)
   498		case a == b:
   499			return z.setUint64(a)
   500		case a+1 == b:
   501			return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   502		}
   503		m := (a + b) / 2
   504		return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   505	}
   506	
   507	// q = (x-r)/y, with 0 <= r < y
   508	func (z nat) divW(x nat, y Word) (q nat, r Word) {
   509		m := len(x)
   510		switch {
   511		case y == 0:
   512			panic("division by zero")
   513		case y == 1:
   514			q = z.set(x) // result is x
   515			return
   516		case m == 0:
   517			q = z[:0] // result is 0
   518			return
   519		}
   520		// m > 0
   521		z = z.make(m)
   522		r = divWVW(z, 0, x, y)
   523		q = z.norm()
   524		return
   525	}
   526	
   527	func (z nat) div(z2, u, v nat) (q, r nat) {
   528		if len(v) == 0 {
   529			panic("division by zero")
   530		}
   531	
   532		if u.cmp(v) < 0 {
   533			q = z[:0]
   534			r = z2.set(u)
   535			return
   536		}
   537	
   538		if len(v) == 1 {
   539			var r2 Word
   540			q, r2 = z.divW(u, v[0])
   541			r = z2.setWord(r2)
   542			return
   543		}
   544	
   545		q, r = z.divLarge(z2, u, v)
   546		return
   547	}
   548	
// q = (uIn-r)/v, with 0 <= r < v
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv holds the product qhat*v (n+1 digits) in each round.
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Normalize: shift v left so that its top digit has no leading zero
	// bits, and shift the dividend (copied into u) by the same amount.
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	// Produce one quotient digit per iteration, most significant first.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient digit qhat from the top digits.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Multiply and subtract: u[j:j+n+1] -= qhat*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The estimate was one too large: add v back and
			// decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// Undo the normalization shift to obtain the remainder.
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
   626	
   627	// Length of x in bits. x must be normalized.
   628	func (x nat) bitLen() int {
   629		if i := len(x) - 1; i >= 0 {
   630			return i*_W + bitLen(x[i])
   631		}
   632		return 0
   633	}
   634	
// deBruijn32 and deBruijn32Lookup are the multiplier and the 32-entry
// lookup table used by trailingZeroBits for 32-bit Words.
const deBruijn32 = 0x077CB531

var deBruijn32Lookup = []byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// deBruijn64 and deBruijn64Lookup are the multiplier and the 64-entry
// lookup table used by trailingZeroBits for 64-bit Words.
const deBruijn64 = 0x03f79d71b4ca8b09

var deBruijn64Lookup = []byte{
	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
   650	
   651	// trailingZeroBits returns the number of consecutive least significant zero
   652	// bits of x.
   653	func trailingZeroBits(x Word) uint {
   654		// x & -x leaves only the right-most bit set in the word. Let k be the
   655		// index of that bit. Since only a single bit is set, the value is two
   656		// to the power of k. Multiplying by a power of two is equivalent to
   657		// left shifting, in this case by k bits.  The de Bruijn constant is
   658		// such that all six bit, consecutive substrings are distinct.
   659		// Therefore, if we have a left shifted version of this constant we can
   660		// find by how many bits it was shifted by looking at which six bit
   661		// substring ended up at the top of the word.
   662		// (Knuth, volume 4, section 7.3.1)
   663		switch _W {
   664		case 32:
   665			return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   666		case 64:
   667			return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   668		default:
   669			panic("unknown word size")
   670		}
   671	}
   672	
   673	// trailingZeroBits returns the number of consecutive least significant zero
   674	// bits of x.
   675	func (x nat) trailingZeroBits() uint {
   676		if len(x) == 0 {
   677			return 0
   678		}
   679		var i uint
   680		for x[i] == 0 {
   681			i++
   682		}
   683		// x[i] != 0
   684		return i*_W + trailingZeroBits(x[i])
   685	}
   686	
   687	// z = x << s
   688	func (z nat) shl(x nat, s uint) nat {
   689		m := len(x)
   690		if m == 0 {
   691			return z[:0]
   692		}
   693		// m > 0
   694	
   695		n := m + int(s/_W)
   696		z = z.make(n + 1)
   697		z[n] = shlVU(z[n-m:n], x, s%_W)
   698		z[0 : n-m].clear()
   699	
   700		return z.norm()
   701	}
   702	
   703	// z = x >> s
   704	func (z nat) shr(x nat, s uint) nat {
   705		m := len(x)
   706		n := m - int(s/_W)
   707		if n <= 0 {
   708			return z[:0]
   709		}
   710		// n > 0
   711	
   712		z = z.make(n)
   713		shrVU(z, x[m-n:], s%_W)
   714	
   715		return z.norm()
   716	}
   717	
   718	func (z nat) setBit(x nat, i uint, b uint) nat {
   719		j := int(i / _W)
   720		m := Word(1) << (i % _W)
   721		n := len(x)
   722		switch b {
   723		case 0:
   724			z = z.make(n)
   725			copy(z, x)
   726			if j >= n {
   727				// no need to grow
   728				return z
   729			}
   730			z[j] &^= m
   731			return z.norm()
   732		case 1:
   733			if j >= n {
   734				z = z.make(j + 1)
   735				z[n:].clear()
   736			} else {
   737				z = z.make(n)
   738			}
   739			copy(z, x)
   740			z[j] |= m
   741			// no need to normalize
   742			return z
   743		}
   744		panic("set bit is not 0 or 1")
   745	}
   746	
   747	// bit returns the value of the i'th bit, with lsb == bit 0.
   748	func (x nat) bit(i uint) uint {
   749		j := i / _W
   750		if j >= uint(len(x)) {
   751			return 0
   752		}
   753		// 0 <= j < len(x)
   754		return uint(x[j] >> (i % _W) & 1)
   755	}
   756	
   757	// sticky returns 1 if there's a 1 bit within the
   758	// i least significant bits, otherwise it returns 0.
   759	func (x nat) sticky(i uint) uint {
   760		j := i / _W
   761		if j >= uint(len(x)) {
   762			if len(x) == 0 {
   763				return 0
   764			}
   765			return 1
   766		}
   767		// 0 <= j < len(x)
   768		for _, x := range x[:j] {
   769			if x != 0 {
   770				return 1
   771			}
   772		}
   773		if x[j]<<(_W-i%_W) != 0 {
   774			return 1
   775		}
   776		return 0
   777	}
   778	
   779	func (z nat) and(x, y nat) nat {
   780		m := len(x)
   781		n := len(y)
   782		if m > n {
   783			m = n
   784		}
   785		// m <= n
   786	
   787		z = z.make(m)
   788		for i := 0; i < m; i++ {
   789			z[i] = x[i] & y[i]
   790		}
   791	
   792		return z.norm()
   793	}
   794	
   795	func (z nat) andNot(x, y nat) nat {
   796		m := len(x)
   797		n := len(y)
   798		if n > m {
   799			n = m
   800		}
   801		// m >= n
   802	
   803		z = z.make(m)
   804		for i := 0; i < n; i++ {
   805			z[i] = x[i] &^ y[i]
   806		}
   807		copy(z[n:m], x[n:m])
   808	
   809		return z.norm()
   810	}
   811	
   812	func (z nat) or(x, y nat) nat {
   813		m := len(x)
   814		n := len(y)
   815		s := x
   816		if m < n {
   817			n, m = m, n
   818			s = y
   819		}
   820		// m >= n
   821	
   822		z = z.make(m)
   823		for i := 0; i < n; i++ {
   824			z[i] = x[i] | y[i]
   825		}
   826		copy(z[n:m], s[n:m])
   827	
   828		return z.norm()
   829	}
   830	
   831	func (z nat) xor(x, y nat) nat {
   832		m := len(x)
   833		n := len(y)
   834		s := x
   835		if m < n {
   836			n, m = m, n
   837			s = y
   838		}
   839		// m >= n
   840	
   841		z = z.make(m)
   842		for i := 0; i < n; i++ {
   843			z[i] = x[i] ^ y[i]
   844		}
   845		copy(z[n:m], s[n:m])
   846	
   847		return z.norm()
   848	}
   849	
   850	// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   851	func greaterThan(x1, x2, y1, y2 Word) bool {
   852		return x1 > y1 || x1 == y1 && x2 > y2
   853	}
   854	
   855	// modW returns x % d.
   856	func (x nat) modW(d Word) (r Word) {
   857		// TODO(agl): we don't actually need to store the q value.
   858		var q nat
   859		q = q.make(len(x))
   860		return divWVW(q, 0, x, d)
   861	}
   862	
   863	// random creates a random integer in [0..limit), using the space in z if
   864	// possible. n is the bit length of limit.
   865	func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   866		if alias(z, limit) {
   867			z = nil // z is an alias for limit - cannot reuse
   868		}
   869		z = z.make(len(limit))
   870	
   871		bitLengthOfMSW := uint(n % _W)
   872		if bitLengthOfMSW == 0 {
   873			bitLengthOfMSW = _W
   874		}
   875		mask := Word((1 << bitLengthOfMSW) - 1)
   876	
   877		for {
   878			switch _W {
   879			case 32:
   880				for i := range z {
   881					z[i] = Word(rand.Uint32())
   882				}
   883			case 64:
   884				for i := range z {
   885					z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   886				}
   887			default:
   888				panic("unknown word size")
   889			}
   890			z[len(limit)-1] &= mask
   891			if z.cmp(limit) < 0 {
   892				break
   893			}
   894		}
   895	
   896		return z.norm()
   897	}
   898	
   899	// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   900	// otherwise it sets z to x**y. The result is the value of z.
   901	func (z nat) expNN(x, y, m nat) nat {
   902		if alias(z, x) || alias(z, y) {
   903			// We cannot allow in-place modification of x or y.
   904			z = nil
   905		}
   906	
   907		// x**y mod 1 == 0
   908		if len(m) == 1 && m[0] == 1 {
   909			return z.setWord(0)
   910		}
   911		// m == 0 || m > 1
   912	
   913		// x**0 == 1
   914		if len(y) == 0 {
   915			return z.setWord(1)
   916		}
   917		// y > 0
   918	
   919		// x**1 mod m == x mod m
   920		if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   921			_, z = z.div(z, x, m)
   922			return z
   923		}
   924		// y > 1
   925	
   926		if len(m) != 0 {
   927			// We likely end up being as long as the modulus.
   928			z = z.make(len(m))
   929		}
   930		z = z.set(x)
   931	
   932		// If the base is non-trivial and the exponent is large, we use
   933		// 4-bit, windowed exponentiation. This involves precomputing 14 values
   934		// (x^2...x^15) but then reduces the number of multiply-reduces by a
   935		// third. Even for a 32-bit exponent, this reduces the number of
   936		// operations. Uses Montgomery method for odd moduli.
   937		if len(x) > 1 && len(y) > 1 && len(m) > 0 {
   938			if m[0]&1 == 1 {
   939				return z.expNNMontgomery(x, y, m)
   940			}
   941			return z.expNNWindowed(x, y, m)
   942		}
   943	
   944		v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   945		shift := nlz(v) + 1
   946		v <<= shift
   947		var q nat
   948	
   949		const mask = 1 << (_W - 1)
   950	
   951		// We walk through the bits of the exponent one by one. Each time we
   952		// see a bit, we square, thus doubling the power. If the bit is a one,
   953		// we also multiply by x, thus adding one to the power.
   954	
   955		w := _W - int(shift)
   956		// zz and r are used to avoid allocating in mul and div as
   957		// otherwise the arguments would alias.
   958		var zz, r nat
   959		for j := 0; j < w; j++ {
   960			zz = zz.mul(z, z)
   961			zz, z = z, zz
   962	
   963			if v&mask != 0 {
   964				zz = zz.mul(z, x)
   965				zz, z = z, zz
   966			}
   967	
   968			if len(m) != 0 {
   969				zz, r = zz.div(r, z, m)
   970				zz, r, q, z = q, z, zz, r
   971			}
   972	
   973			v <<= 1
   974		}
   975	
   976		for i := len(y) - 2; i >= 0; i-- {
   977			v = y[i]
   978	
   979			for j := 0; j < _W; j++ {
   980				zz = zz.mul(z, z)
   981				zz, z = z, zz
   982	
   983				if v&mask != 0 {
   984					zz = zz.mul(z, x)
   985					zz, z = z, zz
   986				}
   987	
   988				if len(m) != 0 {
   989					zz, r = zz.div(r, z, m)
   990					zz, r, q, z = q, z, zz, r
   991				}
   992	
   993				v <<= 1
   994			}
   995		}
   996	
   997		return z.norm()
   998	}
   999	
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// The exponent y is consumed 4 bits at a time, most significant bits
// first: each step squares z four times and multiplies by the
// precomputed power of x selected by the current 4 bits.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Each round fills in one even power (by squaring powers[i/2]) and the
	// following odd power (by multiplying with x), both reduced mod m.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// The squarings are skipped for the very first window,
			// where z is still 1.
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain.  Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// Multiply by x^(top n bits of yi) and reduce.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
  1062	
  1063	// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
  1064	// Uses Montgomery representation.
  1065	func (z nat) expNNMontgomery(x, y, m nat) nat {
  1066		var zz, one, rr, RR nat
  1067	
  1068		numWords := len(m)
  1069	
  1070		// We want the lengths of x and m to be equal.
  1071		if len(x) > numWords {
  1072			_, rr = rr.div(rr, x, m)
  1073		} else if len(x) < numWords {
  1074			rr = rr.make(numWords)
  1075			rr.clear()
  1076			for i := range x {
  1077				rr[i] = x[i]
  1078			}
  1079		} else {
  1080			rr = x
  1081		}
  1082		x = rr
  1083	
  1084		// Ideally the precomputations would be performed outside, and reused
  1085		// k0 = -mˆ-1 mod 2ˆ_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
  1086		// Iteration for Multiplicative Inverses Modulo Prime Powers".
  1087		k0 := 2 - m[0]
  1088		t := m[0] - 1
  1089		for i := 1; i < _W; i <<= 1 {
  1090			t *= t
  1091			k0 *= (t + 1)
  1092		}
  1093		k0 = -k0
  1094	
  1095		// RR = 2ˆ(2*_W*len(m)) mod m
  1096		RR = RR.setWord(1)
  1097		zz = zz.shl(RR, uint(2*numWords*_W))
  1098		_, RR = RR.div(RR, zz, m)
  1099		if len(RR) < numWords {
  1100			zz = zz.make(numWords)
  1101			copy(zz, RR)
  1102			RR = zz
  1103		}
  1104		// one = 1, with equal length to that of m
  1105		one = one.make(numWords)
  1106		one.clear()
  1107		one[0] = 1
  1108	
  1109		const n = 4
  1110		// powers[i] contains x^i
  1111		var powers [1 << n]nat
  1112		powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
  1113		powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
  1114		for i := 2; i < 1<<n; i++ {
  1115			powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
  1116		}
  1117	
  1118		// initialize z = 1 (Montgomery 1)
  1119		z = z.make(numWords)
  1120		copy(z, powers[0])
  1121	
  1122		zz = zz.make(numWords)
  1123	
  1124		// same windowed exponent, but with Montgomery multiplications
  1125		for i := len(y) - 1; i >= 0; i-- {
  1126			yi := y[i]
  1127			for j := 0; j < _W; j += n {
  1128				if i != len(y)-1 || j != 0 {
  1129					zz = zz.montgomery(z, z, m, k0, numWords)
  1130					z = z.montgomery(zz, zz, m, k0, numWords)
  1131					zz = zz.montgomery(z, z, m, k0, numWords)
  1132					z = z.montgomery(zz, zz, m, k0, numWords)
  1133				}
  1134				zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
  1135				z, zz = zz, z
  1136				yi <<= n
  1137			}
  1138		}
  1139		// convert to regular number
  1140		zz = zz.montgomery(z, one, m, k0, numWords)
  1141		return zz.norm()
  1142	}
  1143	
  1144	// probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
  1145	// If it returns true, n is prime with probability 1 - 1/4^reps.
  1146	// If it returns false, n is not prime.
  1147	func (n nat) probablyPrime(reps int) bool {
  1148		if len(n) == 0 {
  1149			return false
  1150		}
  1151	
  1152		if len(n) == 1 {
  1153			if n[0] < 2 {
  1154				return false
  1155			}
  1156	
  1157			if n[0]%2 == 0 {
  1158				return n[0] == 2
  1159			}
  1160	
  1161			// We have to exclude these cases because we reject all
  1162			// multiples of these numbers below.
  1163			switch n[0] {
  1164			case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1165				return true
  1166			}
  1167		}
  1168	
  1169		if n[0]&1 == 0 {
  1170			return false // n is even
  1171		}
  1172	
  1173		const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1174		const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1175	
  1176		var r Word
  1177		switch _W {
  1178		case 32:
  1179			r = n.modW(primesProduct32)
  1180		case 64:
  1181			r = n.modW(primesProduct64 & _M)
  1182		default:
  1183			panic("Unknown word size")
  1184		}
  1185	
  1186		if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1187			r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1188			return false
  1189		}
  1190	
  1191		if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1192			r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1193			return false
  1194		}
  1195	
  1196		nm1 := nat(nil).sub(n, natOne)
  1197		// determine q, k such that nm1 = q << k
  1198		k := nm1.trailingZeroBits()
  1199		q := nat(nil).shr(nm1, k)
  1200	
  1201		nm3 := nat(nil).sub(nm1, natTwo)
  1202		rand := rand.New(rand.NewSource(int64(n[0])))
  1203	
  1204		var x, y, quotient nat
  1205		nm3Len := nm3.bitLen()
  1206	
  1207	NextRandom:
  1208		for i := 0; i < reps; i++ {
  1209			x = x.random(rand, nm3, nm3Len)
  1210			x = x.add(x, natTwo)
  1211			y = y.expNN(x, q, n)
  1212			if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1213				continue
  1214			}
  1215			for j := uint(1); j < k; j++ {
  1216				y = y.mul(y, y)
  1217				quotient, y = quotient.div(y, y, n)
  1218				if y.cmp(nm1) == 0 {
  1219					continue NextRandom
  1220				}
  1221				if y.cmp(natOne) == 0 {
  1222					return false
  1223				}
  1224			}
  1225			return false
  1226		}
  1227	
  1228		return true
  1229	}
  1230	
  1231	// bytes writes the value of z into buf using big-endian encoding.
  1232	// len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1233	// slice buf[i:]. The number i of unused bytes at the beginning of
  1234	// buf is returned as result.
  1235	func (z nat) bytes(buf []byte) (i int) {
  1236		i = len(buf)
  1237		for _, d := range z {
  1238			for j := 0; j < _S; j++ {
  1239				i--
  1240				buf[i] = byte(d)
  1241				d >>= 8
  1242			}
  1243		}
  1244	
  1245		for i < len(buf) && buf[i] == 0 {
  1246			i++
  1247		}
  1248	
  1249		return
  1250	}
  1251	
  1252	// setBytes interprets buf as the bytes of a big-endian unsigned
  1253	// integer, sets z to that value, and returns z.
  1254	func (z nat) setBytes(buf []byte) nat {
  1255		z = z.make((len(buf) + _S - 1) / _S)
  1256	
  1257		k := 0
  1258		s := uint(0)
  1259		var d Word
  1260		for i := len(buf); i > 0; i-- {
  1261			d |= Word(buf[i-1]) << s
  1262			if s += 8; s == _S*8 {
  1263				z[k] = d
  1264				k++
  1265				s = 0
  1266				d = 0
  1267			}
  1268		}
  1269		if k < len(z) {
  1270			z[k] = d
  1271		}
  1272	
  1273		return z.norm()
  1274	}
  1275	

View as plain text