1
2
3
4
5
6 package rsa
7
8
9
10 import (
11 "big"
12 "crypto/rand"
13 "crypto/subtle"
14 "hash"
15 "io"
16 "os"
17 )
18
// Shared small constants. They are only ever read (used as operands),
// never used as a receiver that gets modified.
var (
	bigZero = big.NewInt(0)
	bigOne  = big.NewInt(1)
)
21
22
// PublicKey is the public half of an RSA key. Encryption computes
// m^E mod N.
type PublicKey struct {
	N *big.Int // modulus
	E int      // public exponent
}
27
28
29 type PrivateKey struct {
30 PublicKey
31 D *big.Int
32 Primes []*big.Int
33
34
35
36 Precomputed PrecomputedValues
37 }
38
39 type PrecomputedValues struct {
40 Dp, Dq *big.Int
41 Qinv *big.Int
42
43
44
45
46
47 CRTValues []CRTValue
48 }
49
50
// CRTValue holds the precomputed CRT data for one additional prime
// (index 2 and up) of a multi-prime key.
type CRTValue struct {
	Exp   *big.Int // D mod (prime-1)
	Coeff *big.Int // R^-1 mod prime
	R     *big.Int // product of all primes that precede this one
}
56
57
58
59
60 func (priv *PrivateKey) Validate() os.Error {
61
62
63
64
65 for _, prime := range priv.Primes {
66 if !big.ProbablyPrime(prime, 20) {
67 return os.NewError("prime factor is composite")
68 }
69 }
70
71
72 modulus := new(big.Int).Set(bigOne)
73 for _, prime := range priv.Primes {
74 modulus.Mul(modulus, prime)
75 }
76 if modulus.Cmp(priv.N) != 0 {
77 return os.NewError("invalid modulus")
78 }
79
80 totient := new(big.Int).Set(bigOne)
81 for _, prime := range priv.Primes {
82 pminus1 := new(big.Int).Sub(prime, bigOne)
83 totient.Mul(totient, pminus1)
84 }
85 e := big.NewInt(int64(priv.E))
86 gcd := new(big.Int)
87 x := new(big.Int)
88 y := new(big.Int)
89 big.GcdInt(gcd, x, y, totient, e)
90 if gcd.Cmp(bigOne) != 0 {
91 return os.NewError("invalid public exponent E")
92 }
93
94 de := new(big.Int).Mul(priv.D, e)
95 de.Mod(de, totient)
96 if de.Cmp(bigOne) != 0 {
97 return os.NewError("invalid private exponent D")
98 }
99 return nil
100 }
101
102
103 func GenerateKey(random io.Reader, bits int) (priv *PrivateKey, err os.Error) {
104 return GenerateMultiPrimeKey(random, 2, bits)
105 }
106
107
108
109
110
111
112
113
114
115
116
117 func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (priv *PrivateKey, err os.Error) {
118 priv = new(PrivateKey)
119
120
121
122
123
124
125
126
127 priv.E = 3
128
129 if nprimes < 2 {
130 return nil, os.NewError("rsa.GenerateMultiPrimeKey: nprimes must be >= 2")
131 }
132
133 primes := make([]*big.Int, nprimes)
134
135 NextSetOfPrimes:
136 for {
137 todo := bits
138 for i := 0; i < nprimes; i++ {
139 primes[i], err = rand.Prime(random, todo/(nprimes-i))
140 if err != nil {
141 return nil, err
142 }
143 todo -= primes[i].BitLen()
144 }
145
146
147 for i, prime := range primes {
148 for j := 0; j < i; j++ {
149 if prime.Cmp(primes[j]) == 0 {
150 continue NextSetOfPrimes
151 }
152 }
153 }
154
155 n := new(big.Int).Set(bigOne)
156 totient := new(big.Int).Set(bigOne)
157 pminus1 := new(big.Int)
158 for _, prime := range primes {
159 n.Mul(n, prime)
160 pminus1.Sub(prime, bigOne)
161 totient.Mul(totient, pminus1)
162 }
163
164 g := new(big.Int)
165 priv.D = new(big.Int)
166 y := new(big.Int)
167 e := big.NewInt(int64(priv.E))
168 big.GcdInt(g, priv.D, y, e, totient)
169
170 if g.Cmp(bigOne) == 0 {
171 priv.D.Add(priv.D, totient)
172 priv.Primes = primes
173 priv.N = n
174
175 break
176 }
177 }
178
179 priv.Precompute()
180 return
181 }
182
183
// incCounter increments c as a 32-bit big-endian integer (c[3] is the
// least significant byte). It wraps around to zero on overflow.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			return // no carry into the next more-significant byte
		}
	}
}
196
197
198
199 func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
200 var counter [4]byte
201
202 done := 0
203 for done < len(out) {
204 hash.Write(seed)
205 hash.Write(counter[0:4])
206 digest := hash.Sum()
207 hash.Reset()
208
209 for i := 0; i < len(digest) && done < len(out); i++ {
210 out[done] ^= digest[i]
211 done++
212 }
213 incCounter(&counter)
214 }
215 }
216
217
218
// MessageTooLongError is returned by EncryptOAEP when the message does not
// fit in the OAEP-padded block for the given public key.
type MessageTooLongError struct{}

// String returns the error text, which lets MessageTooLongError be used as
// an os.Error.
func (e MessageTooLongError) String() string {
	const text = "message too long for RSA public key size"
	return text
}
224
225 func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
226 e := big.NewInt(int64(pub.E))
227 c.Exp(m, e, pub.N)
228 return c
229 }
230
231
232
233
234 func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) (out []byte, err os.Error) {
235 hash.Reset()
236 k := (pub.N.BitLen() + 7) / 8
237 if len(msg) > k-2*hash.Size()-2 {
238 err = MessageTooLongError{}
239 return
240 }
241
242 hash.Write(label)
243 lHash := hash.Sum()
244 hash.Reset()
245
246 em := make([]byte, k)
247 seed := em[1 : 1+hash.Size()]
248 db := em[1+hash.Size():]
249
250 copy(db[0:hash.Size()], lHash)
251 db[len(db)-len(msg)-1] = 1
252 copy(db[len(db)-len(msg):], msg)
253
254 _, err = io.ReadFull(random, seed)
255 if err != nil {
256 return
257 }
258
259 mgf1XOR(db, hash, seed)
260 mgf1XOR(seed, hash, db)
261
262 m := new(big.Int)
263 m.SetBytes(em)
264 c := encrypt(new(big.Int), pub, m)
265 out = c.Bytes()
266
267 if len(out) < k {
268
269 t := make([]byte, k)
270 copy(t[k-len(out):], out)
271 out = t
272 }
273
274 return
275 }
276
277
278
// DecryptionError is returned by decrypt and DecryptOAEP when a ciphertext
// is out of range or fails the padding checks.
type DecryptionError struct{}

// String returns the error text, which lets DecryptionError be used as an
// os.Error.
func (e DecryptionError) String() string {
	const text = "RSA decryption error"
	return text
}
282
283
284
// VerificationError reports a failed RSA verification.
// NOTE(review): nothing in this part of the file returns it; presumably
// signature verification elsewhere in the package does. Confirm.
type VerificationError struct{}

// String returns the error text.
func (e VerificationError) String() string {
	const text = "RSA verification error"
	return text
}
288
289
290
291 func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
292 g := new(big.Int)
293 x := new(big.Int)
294 y := new(big.Int)
295 big.GcdInt(g, x, y, a, n)
296 if g.Cmp(bigOne) != 0 {
297
298
299
300
301 return
302 }
303
304 if x.Cmp(bigOne) < 0 {
305
306
307 x.Add(x, n)
308 }
309
310 return x, true
311 }
312
313
314
315 func (priv *PrivateKey) Precompute() {
316 if priv.Precomputed.Dp != nil {
317 return
318 }
319
320 priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
321 priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
322
323 priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
324 priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
325
326 priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
327
328 r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
329 priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
330 for i := 2; i < len(priv.Primes); i++ {
331 prime := priv.Primes[i]
332 values := &priv.Precomputed.CRTValues[i-2]
333
334 values.Exp = new(big.Int).Sub(prime, bigOne)
335 values.Exp.Mod(priv.D, values.Exp)
336
337 values.R = new(big.Int).Set(r)
338 values.Coeff = new(big.Int).ModInverse(r, prime)
339
340 r.Mul(r, prime)
341 }
342 }
343
344
345
346 func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err os.Error) {
347
348 if c.Cmp(priv.N) > 0 {
349 err = DecryptionError{}
350 return
351 }
352
353 var ir *big.Int
354 if random != nil {
355
356
357
358
359
360 var r *big.Int
361
362 for {
363 r, err = rand.Int(random, priv.N)
364 if err != nil {
365 return
366 }
367 if r.Cmp(bigZero) == 0 {
368 r = bigOne
369 }
370 var ok bool
371 ir, ok = modInverse(r, priv.N)
372 if ok {
373 break
374 }
375 }
376 bigE := big.NewInt(int64(priv.E))
377 rpowe := new(big.Int).Exp(r, bigE, priv.N)
378 cCopy := new(big.Int).Set(c)
379 cCopy.Mul(cCopy, rpowe)
380 cCopy.Mod(cCopy, priv.N)
381 c = cCopy
382 }
383
384 if priv.Precomputed.Dp == nil {
385 m = new(big.Int).Exp(c, priv.D, priv.N)
386 } else {
387
388 m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
389 m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
390 m.Sub(m, m2)
391 if m.Sign() < 0 {
392 m.Add(m, priv.Primes[0])
393 }
394 m.Mul(m, priv.Precomputed.Qinv)
395 m.Mod(m, priv.Primes[0])
396 m.Mul(m, priv.Primes[1])
397 m.Add(m, m2)
398
399 for i, values := range priv.Precomputed.CRTValues {
400 prime := priv.Primes[2+i]
401 m2.Exp(c, values.Exp, prime)
402 m2.Sub(m2, m)
403 m2.Mul(m2, values.Coeff)
404 m2.Mod(m2, prime)
405 if m2.Sign() < 0 {
406 m2.Add(m2, prime)
407 }
408 m2.Mul(m2, values.R)
409 m.Add(m, m2)
410 }
411 }
412
413 if ir != nil {
414
415 m.Mul(m, ir)
416 m.Mod(m, priv.N)
417 }
418
419 return
420 }
421
422
423
424 func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err os.Error) {
425 k := (priv.N.BitLen() + 7) / 8
426 if len(ciphertext) > k ||
427 k < hash.Size()*2+2 {
428 err = DecryptionError{}
429 return
430 }
431
432 c := new(big.Int).SetBytes(ciphertext)
433
434 m, err := decrypt(random, priv, c)
435 if err != nil {
436 return
437 }
438
439 hash.Write(label)
440 lHash := hash.Sum()
441 hash.Reset()
442
443
444
445
446
447
448 em := leftPad(m.Bytes(), k)
449
450 firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
451
452 seed := em[1 : hash.Size()+1]
453 db := em[hash.Size()+1:]
454
455 mgf1XOR(seed, hash, db)
456 mgf1XOR(db, hash, seed)
457
458 lHash2 := db[0:hash.Size()]
459
460
461
462
463
464 lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
465
466
467
468
469
470
471 var lookingForIndex, index, invalid int
472 lookingForIndex = 1
473 rest := db[hash.Size():]
474
475 for i := 0; i < len(rest); i++ {
476 equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
477 equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
478 index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
479 lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
480 invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
481 }
482
483 if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
484 err = DecryptionError{}
485 return
486 }
487
488 msg = rest[index+1:]
489 return
490 }
491
492
493
// leftPad returns a new slice of length size holding input right-aligned,
// i.e. preceded by zero bytes. If input is longer than size, only its
// first size bytes are copied.
func leftPad(input []byte, size int) (out []byte) {
	out = make([]byte, size)
	offset := size - len(input)
	if offset < 0 {
		offset = 0
	}
	copy(out[offset:], input)
	return out
}