package block
import (
"fmt"
"hash"
"io"
"os"
)
// EAXTagError reports that the tag found at the end of an EAX stream does not
// match the tag computed while decrypting it. Any data already read from the
// stream must be treated as unauthenticated.
type EAXTagError struct {
	Read     []byte // tag bytes taken from the input stream
	Computed []byte // tag computed locally from nonce, header and ciphertext
}

// String describes the mismatch, showing both tags in hex.
func (e *EAXTagError) String() string {
	const format = "crypto/block: EAX tag mismatch: read %x but computed %x"
	return fmt.Sprintf(format, e.Read, e.Computed)
}
// setupEAX performs the steps of EAX that precede the message itself. It uses
// the tweaked MAC OMAC^t(M) = CMAC([t]_n || M), where [t]_n is an n-byte
// block of zeros whose last byte is t.
//
// It returns:
//   - ctrIV: OMAC^0(iv), the initial counter block for CTR mode;
//   - tag: the first tagBytes bytes of OMAC^0(iv) XOR OMAC^1(hdr), which the
//     caller completes with finishEAX;
//   - cmac: a CMAC that has already absorbed [2]_n, so whatever is written to
//     it next is MACed as OMAC^2(...).
//
// It panics if len(iv) != c.BlockSize(), or if tagBytes is negative or larger
// than the block size. A tagBytes of 0 is accepted but produces an empty tag,
// i.e. no authentication.
func setupEAX(c Cipher, iv, hdr []byte, tagBytes int) (ctrIV, tag []byte, cmac hash.Hash) {
	n := len(iv)
	if n != c.BlockSize() {
		panic(fmt.Sprintln("crypto/block: EAX: iv length", n, "!=", c.BlockSize()))
	}
	// Without this check an out-of-range tagBytes fails below with an opaque
	// slice-bounds panic instead of a message naming the real problem.
	if tagBytes < 0 || tagBytes > n {
		panic(fmt.Sprintf("crypto/block: EAX: tag length %d not in range [0, %d]", tagBytes, n))
	}

	buf := make([]byte, n) // [0]_n; its last byte is bumped to build [1]_n and [2]_n
	cmac = NewCMAC(c)

	// OMAC^0(iv): both the CTR starting block and the first term of the tag.
	cmac.Write(buf)
	cmac.Write(iv)
	sum := cmac.Sum()
	ctrIV = copy(sum)
	tag = copy(sum[0:tagBytes])

	// OMAC^1(hdr) is XORed into the tag.
	cmac.Reset()
	buf[n-1] = 1
	cmac.Write(buf)
	cmac.Write(hdr)
	sum = cmac.Sum()
	for i := 0; i < tagBytes; i++ {
		tag[i] ^= sum[i]
	}

	// Prime the CMAC for OMAC^2 over the data the caller streams through it
	// afterwards; finishEAX folds that result into tag.
	cmac.Reset()
	buf[n-1] = 2
	cmac.Write(buf)
	return
}
func finishEAX(tag []byte, cmac hash.Hash) {
sum := cmac.Sum()
for i := range tag {
tag[i] ^= sum[i]
}
}
// cmacWriter passes writes through to w and feeds the bytes that w accepted
// into cmac.
type cmacWriter struct {
	w    io.Writer // destination for the data
	cmac hash.Hash // MAC over everything successfully written to w
}

// Write forwards p to the underlying writer and returns its result unchanged.
func (cw *cmacWriter) Write(p []byte) (n int, err os.Error) {
	written, werr := cw.w.Write(p)
	// MAC only what actually reached w, even on a short or failed write.
	cw.cmac.Write(p[0:written])
	return written, werr
}
// eaxEncrypter is the io.WriteCloser returned by NewEAXEncrypter.
type eaxEncrypter struct {
	ctr io.Writer  // CTR-mode writer whose output goes to cw; set to nil by Close
	cw  cmacWriter // forwards ciphertext to the destination writer and MACs it
	tag []byte     // partial tag from setupEAX; finishEAX completes it in Close
}

// NewEAXEncrypter returns a writer that encrypts everything written to it with
// cipher c in EAX mode, using nonce iv and header hdr, and writes the
// ciphertext to w. hdr is authenticated but never written to w. Close appends
// the tagBytes-byte tag to w; it does not close w.
func NewEAXEncrypter(c Cipher, iv []byte, hdr []byte, tagBytes int, w io.Writer) io.WriteCloser {
	enc := &eaxEncrypter{cw: cmacWriter{w: w}}
	ctrIV, tag, mac := setupEAX(c, iv, hdr, tagBytes)
	enc.tag, enc.cw.cmac = tag, mac
	enc.ctr = NewCTRWriter(c, ctrIV, &enc.cw)
	return enc
}

// Write encrypts p and sends the ciphertext on to the destination writer.
func (x *eaxEncrypter) Write(p []byte) (n int, err os.Error) {
	n, err = x.ctr.Write(p)
	return
}

// Close completes the tag and writes it directly to the destination writer,
// bypassing the MAC. If the writer accepts fewer bytes without reporting an
// error, Close returns io.ErrShortWrite. Write must not be called after Close,
// because the CTR writer is cleared here.
func (x *eaxEncrypter) Close() os.Error {
	x.ctr = nil
	finishEAX(x.tag, x.cw.cmac)
	n, err := x.cw.w.Write(x.tag)
	switch {
	case err != nil:
		return err
	case n != len(x.tag):
		return io.ErrShortWrite
	}
	return nil
}
// cmacReader reads the ciphertext stream from r and feeds every byte it hands
// to the caller into cmac. It always withholds the most recent len(tag) bytes
// read from r. Once r is exhausted, those withheld bytes are the EAX tag,
// which is neither returned nor MACed.
//
// NewEAXDecrypter creates tag and tmp with length 0 and capacity tagBytes. tag
// is filled on the first Read; after that both keep length tagBytes, and the
// small-read path swaps them.
type cmacReader struct {
	r    io.Reader
	cmac hash.Hash
	tag  []byte // trailing bytes held back from the caller (the tag at end of stream)
	tmp  []byte // scratch buffer for the small-read path, swapped with tag
}

// Read returns stream data that is known not to be part of the trailing tag.
func (cr *cmacReader) Read(p []byte) (n int, err os.Error) {
	// Nothing can be released until the holdback buffer is full; otherwise the
	// caller might receive bytes that turn out to be the tag.
	tag := cr.tag
	if len(tag) < cap(tag) {
		nt := len(tag)
		nn, err1 := io.ReadFull(cr.r, tag[nt:cap(tag)])
		tag = tag[0 : nt+nn]
		cr.tag = tag
		if err1 != nil {
			return 0, err1
		}
	}

	// From here on len(tag) == cap(tag). Conceptually the stream is
	// tag ++ newBytes: the caller gets its first n bytes, and its last
	// tagBytes bytes become the new holdback.
	tagBytes := len(tag)
	if len(p) > 4*tagBytes {
		// Large buffer: read directly into p after a tagBytes-sized gap, then
		// rotate the holdback through that gap to avoid copying all of p.
		n, err = cr.r.Read(p[tagBytes:])
		if n > 0 {
			for i := 0; i < tagBytes; i++ {
				p[i] = tag[i]
			}
			// p[0:tagBytes+n] is now the full sequence. This stays correct
			// when n < tagBytes and these reads overlap the gap just filled.
			for i := 0; i < tagBytes; i++ {
				tag[i] = p[n+i]
			}
		}
	} else {
		// Small buffer: read into p, then redistribute tag ++ p[0:n] so the
		// first n bytes land in p and the last tagBytes bytes land in tmp.
		n, err = cr.r.Read(p)
		if n > 0 {
			// tmp may still have length 0 from construction; reslice it to
			// the holdback size before indexing.
			tmp := cr.tmp[0:tagBytes]
			// Walk backwards so each p[i-tagBytes] is read before it is
			// overwritten.
			for i := n + tagBytes - 1; i >= 0; i-- {
				var c byte
				if i < tagBytes {
					c = tag[i]
				} else {
					c = p[i-tagBytes]
				}
				if i < n {
					p[i] = c
				} else {
					// Sequence positions n..n+tagBytes-1 map to tmp[0..tagBytes-1].
					tmp[i-n] = c
				}
			}
			cr.tmp, cr.tag = tag, tmp
		}
	}

	cr.cmac.Write(p[0:n])
	return
}
// eaxDecrypter is the io.Reader returned by NewEAXDecrypter.
type eaxDecrypter struct {
	ctr io.Reader  // CTR-mode reader over cr; set to nil once the tag has been checked
	cr  cmacReader // ciphertext source: MACs what it returns and holds back the trailing tag
	tag []byte     // locally computed tag; completed by finishEAX in checkTag
	end os.Error   // result returned by every Read after the tag has been checked
}

// NewEAXDecrypter returns a reader that decrypts, with cipher c, the EAX
// ciphertext read from r, using nonce iv and header hdr. hdr is authenticated
// but is not expected to appear in r. The last tagBytes bytes of r must be the
// tag.
//
// Plaintext is returned as it is decrypted, before the tag can be checked. The
// tag is verified once r is exhausted, and a mismatch is reported as an
// *EAXTagError. After such an error, callers must discard everything they
// have read.
func NewEAXDecrypter(c Cipher, iv []byte, hdr []byte, tagBytes int, r io.Reader) io.Reader {
	x := new(eaxDecrypter)
	x.cr.r = r
	x.cr.tag = make([]byte, 0, tagBytes)
	x.cr.tmp = make([]byte, 0, tagBytes)
	var ctrIV []byte
	ctrIV, x.tag, x.cr.cmac = setupEAX(c, iv, hdr, tagBytes)
	x.ctr = NewCTRReader(c, ctrIV, &x.cr)
	return x
}

// checkTag completes the computed tag and compares it with the tag held back
// by cr. It disables further decryption by clearing x.ctr. On mismatch it
// returns an *EAXTagError; this includes a tag shortened by a truncated
// stream.
func (x *eaxDecrypter) checkTag() os.Error {
	x.ctr = nil
	finishEAX(x.tag, x.cr.cmac)
	if !eaxTagsEqual(x.tag, x.cr.tag) {
		e := new(EAXTagError)
		e.Computed = copy(x.tag)
		e.Read = copy(x.cr.tag)
		return e
	}
	return nil
}

// eaxTagsEqual reports whether a and b hold the same bytes. For equal-length
// inputs it always inspects every byte. As a result, the running time does not
// reveal how long a prefix of a forged tag matches.
func eaxTagsEqual(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	var diff byte
	for i := 0; i < len(a); i++ {
		diff |= a[i] ^ b[i]
	}
	return diff == 0
}

// Read decrypts up to len(p) bytes into p. When the underlying stream ends,
// Read checks the tag. The end may be signalled by os.EOF or by a (0, nil)
// result. On mismatch Read returns an *EAXTagError; otherwise it passes the
// end-of-stream result through. Every later call repeats that final result.
func (x *eaxDecrypter) Read(p []byte) (n int, err os.Error) {
	if x.ctr == nil {
		// The tag has already been checked (checkTag clears x.ctr).
		return 0, x.end
	}
	if len(p) == 0 {
		// An empty read says nothing about end of stream. It would likely come
		// back as (0, nil) and be mistaken for the end, triggering a
		// premature tag check.
		return 0, nil
	}
	n, err = x.ctr.Read(p)
	if err == os.EOF || (n == 0 && err == nil) {
		// Every ciphertext byte has now passed through cr's MAC.
		if tagErr := x.checkTag(); tagErr != nil {
			err = tagErr
		}
		x.end = err
	}
	return n, err
}