...
Run Format

Source file src/encoding/asn1/marshal.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package asn1
     6	
     7	import (
     8		"bytes"
     9		"errors"
    10		"fmt"
    11		"io"
    12		"math/big"
    13		"reflect"
    14		"time"
    15		"unicode/utf8"
    16	)
    17	
    18	// A forkableWriter is an in-memory buffer that can be
    19	// 'forked' to create new forkableWriters that bracket the
    20	// original.  After
    21	//    pre, post := w.fork()
    22	// the overall sequence of bytes represented is logically w+pre+post.
    23	type forkableWriter struct {
    24		*bytes.Buffer
    25		pre, post *forkableWriter
    26	}
    27	
    28	func newForkableWriter() *forkableWriter {
    29		return &forkableWriter{new(bytes.Buffer), nil, nil}
    30	}
    31	
    32	func (f *forkableWriter) fork() (pre, post *forkableWriter) {
    33		if f.pre != nil || f.post != nil {
    34			panic("have already forked")
    35		}
    36		f.pre = newForkableWriter()
    37		f.post = newForkableWriter()
    38		return f.pre, f.post
    39	}
    40	
    41	func (f *forkableWriter) Len() (l int) {
    42		l += f.Buffer.Len()
    43		if f.pre != nil {
    44			l += f.pre.Len()
    45		}
    46		if f.post != nil {
    47			l += f.post.Len()
    48		}
    49		return
    50	}
    51	
    52	func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
    53		n, err = out.Write(f.Bytes())
    54		if err != nil {
    55			return
    56		}
    57	
    58		var nn int
    59	
    60		if f.pre != nil {
    61			nn, err = f.pre.writeTo(out)
    62			n += nn
    63			if err != nil {
    64				return
    65			}
    66		}
    67	
    68		if f.post != nil {
    69			nn, err = f.post.writeTo(out)
    70			n += nn
    71		}
    72		return
    73	}
    74	
    75	func marshalBase128Int(out *forkableWriter, n int64) (err error) {
    76		if n == 0 {
    77			err = out.WriteByte(0)
    78			return
    79		}
    80	
    81		l := 0
    82		for i := n; i > 0; i >>= 7 {
    83			l++
    84		}
    85	
    86		for i := l - 1; i >= 0; i-- {
    87			o := byte(n >> uint(i*7))
    88			o &= 0x7f
    89			if i != 0 {
    90				o |= 0x80
    91			}
    92			err = out.WriteByte(o)
    93			if err != nil {
    94				return
    95			}
    96		}
    97	
    98		return nil
    99	}
   100	
   101	func marshalInt64(out *forkableWriter, i int64) (err error) {
   102		n := int64Length(i)
   103	
   104		for ; n > 0; n-- {
   105			err = out.WriteByte(byte(i >> uint((n-1)*8)))
   106			if err != nil {
   107				return
   108			}
   109		}
   110	
   111		return nil
   112	}
   113	
   114	func int64Length(i int64) (numBytes int) {
   115		numBytes = 1
   116	
   117		for i > 127 {
   118			numBytes++
   119			i >>= 8
   120		}
   121	
   122		for i < -128 {
   123			numBytes++
   124			i >>= 8
   125		}
   126	
   127		return
   128	}
   129	
   130	func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
   131		if n.Sign() < 0 {
   132			// A negative number has to be converted to two's-complement
   133			// form. So we'll subtract 1 and invert. If the
   134			// most-significant-bit isn't set then we'll need to pad the
   135			// beginning with 0xff in order to keep the number negative.
   136			nMinus1 := new(big.Int).Neg(n)
   137			nMinus1.Sub(nMinus1, bigOne)
   138			bytes := nMinus1.Bytes()
   139			for i := range bytes {
   140				bytes[i] ^= 0xff
   141			}
   142			if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   143				err = out.WriteByte(0xff)
   144				if err != nil {
   145					return
   146				}
   147			}
   148			_, err = out.Write(bytes)
   149		} else if n.Sign() == 0 {
   150			// Zero is written as a single 0 zero rather than no bytes.
   151			err = out.WriteByte(0x00)
   152		} else {
   153			bytes := n.Bytes()
   154			if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   155				// We'll have to pad this with 0x00 in order to stop it
   156				// looking like a negative number.
   157				err = out.WriteByte(0)
   158				if err != nil {
   159					return
   160				}
   161			}
   162			_, err = out.Write(bytes)
   163		}
   164		return
   165	}
   166	
   167	func marshalLength(out *forkableWriter, i int) (err error) {
   168		n := lengthLength(i)
   169	
   170		for ; n > 0; n-- {
   171			err = out.WriteByte(byte(i >> uint((n-1)*8)))
   172			if err != nil {
   173				return
   174			}
   175		}
   176	
   177		return nil
   178	}
   179	
   180	func lengthLength(i int) (numBytes int) {
   181		numBytes = 1
   182		for i > 255 {
   183			numBytes++
   184			i >>= 8
   185		}
   186		return
   187	}
   188	
   189	func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
   190		b := uint8(t.class) << 6
   191		if t.isCompound {
   192			b |= 0x20
   193		}
   194		if t.tag >= 31 {
   195			b |= 0x1f
   196			err = out.WriteByte(b)
   197			if err != nil {
   198				return
   199			}
   200			err = marshalBase128Int(out, int64(t.tag))
   201			if err != nil {
   202				return
   203			}
   204		} else {
   205			b |= uint8(t.tag)
   206			err = out.WriteByte(b)
   207			if err != nil {
   208				return
   209			}
   210		}
   211	
   212		if t.length >= 128 {
   213			l := lengthLength(t.length)
   214			err = out.WriteByte(0x80 | byte(l))
   215			if err != nil {
   216				return
   217			}
   218			err = marshalLength(out, t.length)
   219			if err != nil {
   220				return
   221			}
   222		} else {
   223			err = out.WriteByte(byte(t.length))
   224			if err != nil {
   225				return
   226			}
   227		}
   228	
   229		return nil
   230	}
   231	
   232	func marshalBitString(out *forkableWriter, b BitString) (err error) {
   233		paddingBits := byte((8 - b.BitLength%8) % 8)
   234		err = out.WriteByte(paddingBits)
   235		if err != nil {
   236			return
   237		}
   238		_, err = out.Write(b.Bytes)
   239		return
   240	}
   241	
   242	func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
   243		if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   244			return StructuralError{"invalid object identifier"}
   245		}
   246	
   247		err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
   248		if err != nil {
   249			return
   250		}
   251		for i := 2; i < len(oid); i++ {
   252			err = marshalBase128Int(out, int64(oid[i]))
   253			if err != nil {
   254				return
   255			}
   256		}
   257	
   258		return
   259	}
   260	
   261	func marshalPrintableString(out *forkableWriter, s string) (err error) {
   262		b := []byte(s)
   263		for _, c := range b {
   264			if !isPrintable(c) {
   265				return StructuralError{"PrintableString contains invalid character"}
   266			}
   267		}
   268	
   269		_, err = out.Write(b)
   270		return
   271	}
   272	
   273	func marshalIA5String(out *forkableWriter, s string) (err error) {
   274		b := []byte(s)
   275		for _, c := range b {
   276			if c > 127 {
   277				return StructuralError{"IA5String contains invalid character"}
   278			}
   279		}
   280	
   281		_, err = out.Write(b)
   282		return
   283	}
   284	
   285	func marshalUTF8String(out *forkableWriter, s string) (err error) {
   286		_, err = out.Write([]byte(s))
   287		return
   288	}
   289	
   290	func marshalTwoDigits(out *forkableWriter, v int) (err error) {
   291		err = out.WriteByte(byte('0' + (v/10)%10))
   292		if err != nil {
   293			return
   294		}
   295		return out.WriteByte(byte('0' + v%10))
   296	}
   297	
   298	func marshalFourDigits(out *forkableWriter, v int) (err error) {
   299		var bytes [4]byte
   300		for i := range bytes {
   301			bytes[3-i] = '0' + byte(v%10)
   302			v /= 10
   303		}
   304		_, err = out.Write(bytes[:])
   305		return
   306	}
   307	
   308	func outsideUTCRange(t time.Time) bool {
   309		year := t.Year()
   310		return year < 1950 || year >= 2050
   311	}
   312	
   313	func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
   314		year := t.Year()
   315	
   316		switch {
   317		case 1950 <= year && year < 2000:
   318			err = marshalTwoDigits(out, int(year-1900))
   319		case 2000 <= year && year < 2050:
   320			err = marshalTwoDigits(out, int(year-2000))
   321		default:
   322			return StructuralError{"cannot represent time as UTCTime"}
   323		}
   324		if err != nil {
   325			return
   326		}
   327	
   328		return marshalTimeCommon(out, t)
   329	}
   330	
   331	func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
   332		year := t.Year()
   333		if year < 0 || year > 9999 {
   334			return StructuralError{"cannot represent time as GeneralizedTime"}
   335		}
   336		if err = marshalFourDigits(out, year); err != nil {
   337			return
   338		}
   339	
   340		return marshalTimeCommon(out, t)
   341	}
   342	
   343	func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
   344		_, month, day := t.Date()
   345	
   346		err = marshalTwoDigits(out, int(month))
   347		if err != nil {
   348			return
   349		}
   350	
   351		err = marshalTwoDigits(out, day)
   352		if err != nil {
   353			return
   354		}
   355	
   356		hour, min, sec := t.Clock()
   357	
   358		err = marshalTwoDigits(out, hour)
   359		if err != nil {
   360			return
   361		}
   362	
   363		err = marshalTwoDigits(out, min)
   364		if err != nil {
   365			return
   366		}
   367	
   368		err = marshalTwoDigits(out, sec)
   369		if err != nil {
   370			return
   371		}
   372	
   373		_, offset := t.Zone()
   374	
   375		switch {
   376		case offset/60 == 0:
   377			err = out.WriteByte('Z')
   378			return
   379		case offset > 0:
   380			err = out.WriteByte('+')
   381		case offset < 0:
   382			err = out.WriteByte('-')
   383		}
   384	
   385		if err != nil {
   386			return
   387		}
   388	
   389		offsetMinutes := offset / 60
   390		if offsetMinutes < 0 {
   391			offsetMinutes = -offsetMinutes
   392		}
   393	
   394		err = marshalTwoDigits(out, offsetMinutes/60)
   395		if err != nil {
   396			return
   397		}
   398	
   399		err = marshalTwoDigits(out, offsetMinutes%60)
   400		return
   401	}
   402	
   403	func stripTagAndLength(in []byte) []byte {
   404		_, offset, err := parseTagAndLength(in, 0)
   405		if err != nil {
   406			return in
   407		}
   408		return in[offset:]
   409	}
   410	
   411	func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
   412		switch value.Type() {
   413		case flagType:
   414			return nil
   415		case timeType:
   416			t := value.Interface().(time.Time)
   417			if params.timeType == tagGeneralizedTime || outsideUTCRange(t) {
   418				return marshalGeneralizedTime(out, t)
   419			} else {
   420				return marshalUTCTime(out, t)
   421			}
   422		case bitStringType:
   423			return marshalBitString(out, value.Interface().(BitString))
   424		case objectIdentifierType:
   425			return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
   426		case bigIntType:
   427			return marshalBigInt(out, value.Interface().(*big.Int))
   428		}
   429	
   430		switch v := value; v.Kind() {
   431		case reflect.Bool:
   432			if v.Bool() {
   433				return out.WriteByte(255)
   434			} else {
   435				return out.WriteByte(0)
   436			}
   437		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   438			return marshalInt64(out, int64(v.Int()))
   439		case reflect.Struct:
   440			t := v.Type()
   441	
   442			startingField := 0
   443	
   444			// If the first element of the structure is a non-empty
   445			// RawContents, then we don't bother serializing the rest.
   446			if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
   447				s := v.Field(0)
   448				if s.Len() > 0 {
   449					bytes := make([]byte, s.Len())
   450					for i := 0; i < s.Len(); i++ {
   451						bytes[i] = uint8(s.Index(i).Uint())
   452					}
   453					/* The RawContents will contain the tag and
   454					 * length fields but we'll also be writing
   455					 * those ourselves, so we strip them out of
   456					 * bytes */
   457					_, err = out.Write(stripTagAndLength(bytes))
   458					return
   459				} else {
   460					startingField = 1
   461				}
   462			}
   463	
   464			for i := startingField; i < t.NumField(); i++ {
   465				var pre *forkableWriter
   466				pre, out = out.fork()
   467				err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
   468				if err != nil {
   469					return
   470				}
   471			}
   472			return
   473		case reflect.Slice:
   474			sliceType := v.Type()
   475			if sliceType.Elem().Kind() == reflect.Uint8 {
   476				bytes := make([]byte, v.Len())
   477				for i := 0; i < v.Len(); i++ {
   478					bytes[i] = uint8(v.Index(i).Uint())
   479				}
   480				_, err = out.Write(bytes)
   481				return
   482			}
   483	
   484			var fp fieldParameters
   485			for i := 0; i < v.Len(); i++ {
   486				var pre *forkableWriter
   487				pre, out = out.fork()
   488				err = marshalField(pre, v.Index(i), fp)
   489				if err != nil {
   490					return
   491				}
   492			}
   493			return
   494		case reflect.String:
   495			switch params.stringType {
   496			case tagIA5String:
   497				return marshalIA5String(out, v.String())
   498			case tagPrintableString:
   499				return marshalPrintableString(out, v.String())
   500			default:
   501				return marshalUTF8String(out, v.String())
   502			}
   503		}
   504	
   505		return StructuralError{"unknown Go type"}
   506	}
   507	
   508	func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
   509		// If the field is an interface{} then recurse into it.
   510		if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   511			return marshalField(out, v.Elem(), params)
   512		}
   513	
   514		if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   515			return
   516		}
   517	
   518		if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
   519			defaultValue := reflect.New(v.Type()).Elem()
   520			defaultValue.SetInt(*params.defaultValue)
   521	
   522			if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
   523				return
   524			}
   525		}
   526	
   527		// If no default value is given then the zero value for the type is
   528		// assumed to be the default value. This isn't obviously the correct
   529		// behaviour, but it's what Go has traditionally done.
   530		if params.optional && params.defaultValue == nil {
   531			if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   532				return
   533			}
   534		}
   535	
   536		if v.Type() == rawValueType {
   537			rv := v.Interface().(RawValue)
   538			if len(rv.FullBytes) != 0 {
   539				_, err = out.Write(rv.FullBytes)
   540			} else {
   541				err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
   542				if err != nil {
   543					return
   544				}
   545				_, err = out.Write(rv.Bytes)
   546			}
   547			return
   548		}
   549	
   550		tag, isCompound, ok := getUniversalType(v.Type())
   551		if !ok {
   552			err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   553			return
   554		}
   555		class := classUniversal
   556	
   557		if params.timeType != 0 && tag != tagUTCTime {
   558			return StructuralError{"explicit time type given to non-time member"}
   559		}
   560	
   561		if params.stringType != 0 && tag != tagPrintableString {
   562			return StructuralError{"explicit string type given to non-string member"}
   563		}
   564	
   565		switch tag {
   566		case tagPrintableString:
   567			if params.stringType == 0 {
   568				// This is a string without an explicit string type. We'll use
   569				// a PrintableString if the character set in the string is
   570				// sufficiently limited, otherwise we'll use a UTF8String.
   571				for _, r := range v.String() {
   572					if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
   573						if !utf8.ValidString(v.String()) {
   574							return errors.New("asn1: string not valid UTF-8")
   575						}
   576						tag = tagUTF8String
   577						break
   578					}
   579				}
   580			} else {
   581				tag = params.stringType
   582			}
   583		case tagUTCTime:
   584			if params.timeType == tagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
   585				tag = tagGeneralizedTime
   586			}
   587		}
   588	
   589		if params.set {
   590			if tag != tagSequence {
   591				return StructuralError{"non sequence tagged as set"}
   592			}
   593			tag = tagSet
   594		}
   595	
   596		tags, body := out.fork()
   597	
   598		err = marshalBody(body, v, params)
   599		if err != nil {
   600			return
   601		}
   602	
   603		bodyLen := body.Len()
   604	
   605		var explicitTag *forkableWriter
   606		if params.explicit {
   607			explicitTag, tags = tags.fork()
   608		}
   609	
   610		if !params.explicit && params.tag != nil {
   611			// implicit tag.
   612			tag = *params.tag
   613			class = classContextSpecific
   614		}
   615	
   616		err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
   617		if err != nil {
   618			return
   619		}
   620	
   621		if params.explicit {
   622			err = marshalTagAndLength(explicitTag, tagAndLength{
   623				class:      classContextSpecific,
   624				tag:        *params.tag,
   625				length:     bodyLen + tags.Len(),
   626				isCompound: true,
   627			})
   628		}
   629	
   630		return nil
   631	}
   632	
   633	// Marshal returns the ASN.1 encoding of val.
   634	//
   635	// In addition to the struct tags recognised by Unmarshal, the following can be
   636	// used:
   637	//
   638	//	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
   639	//	omitempty:	causes empty slices to be skipped
   640	//	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
   641	//	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
   642	func Marshal(val interface{}) ([]byte, error) {
   643		var out bytes.Buffer
   644		v := reflect.ValueOf(val)
   645		f := newForkableWriter()
   646		err := marshalField(f, v, fieldParameters{})
   647		if err != nil {
   648			return nil, err
   649		}
   650		_, err = f.writeTo(&out)
   651		return out.Bytes(), nil
   652	}
   653	

View as plain text