...
Run Format

Source file src/encoding/asn1/marshal.go

     1	// Copyright 2009 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package asn1
     6	
     7	import (
     8		"bytes"
     9		"errors"
    10		"fmt"
    11		"io"
    12		"math/big"
    13		"reflect"
    14		"time"
    15		"unicode/utf8"
    16	)
    17	
    18	// A forkableWriter is an in-memory buffer that can be
    19	// 'forked' to create new forkableWriters that bracket the
    20	// original.  After
    21	//    pre, post := w.fork();
    22	// the overall sequence of bytes represented is logically w+pre+post.
    23	type forkableWriter struct {
    24		*bytes.Buffer
    25		pre, post *forkableWriter
    26	}
    27	
    28	func newForkableWriter() *forkableWriter {
    29		return &forkableWriter{new(bytes.Buffer), nil, nil}
    30	}
    31	
    32	func (f *forkableWriter) fork() (pre, post *forkableWriter) {
    33		if f.pre != nil || f.post != nil {
    34			panic("have already forked")
    35		}
    36		f.pre = newForkableWriter()
    37		f.post = newForkableWriter()
    38		return f.pre, f.post
    39	}
    40	
    41	func (f *forkableWriter) Len() (l int) {
    42		l += f.Buffer.Len()
    43		if f.pre != nil {
    44			l += f.pre.Len()
    45		}
    46		if f.post != nil {
    47			l += f.post.Len()
    48		}
    49		return
    50	}
    51	
    52	func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
    53		n, err = out.Write(f.Bytes())
    54		if err != nil {
    55			return
    56		}
    57	
    58		var nn int
    59	
    60		if f.pre != nil {
    61			nn, err = f.pre.writeTo(out)
    62			n += nn
    63			if err != nil {
    64				return
    65			}
    66		}
    67	
    68		if f.post != nil {
    69			nn, err = f.post.writeTo(out)
    70			n += nn
    71		}
    72		return
    73	}
    74	
    75	func marshalBase128Int(out *forkableWriter, n int64) (err error) {
    76		if n == 0 {
    77			err = out.WriteByte(0)
    78			return
    79		}
    80	
    81		l := 0
    82		for i := n; i > 0; i >>= 7 {
    83			l++
    84		}
    85	
    86		for i := l - 1; i >= 0; i-- {
    87			o := byte(n >> uint(i*7))
    88			o &= 0x7f
    89			if i != 0 {
    90				o |= 0x80
    91			}
    92			err = out.WriteByte(o)
    93			if err != nil {
    94				return
    95			}
    96		}
    97	
    98		return nil
    99	}
   100	
   101	func marshalInt64(out *forkableWriter, i int64) (err error) {
   102		n := int64Length(i)
   103	
   104		for ; n > 0; n-- {
   105			err = out.WriteByte(byte(i >> uint((n-1)*8)))
   106			if err != nil {
   107				return
   108			}
   109		}
   110	
   111		return nil
   112	}
   113	
   114	func int64Length(i int64) (numBytes int) {
   115		numBytes = 1
   116	
   117		for i > 127 {
   118			numBytes++
   119			i >>= 8
   120		}
   121	
   122		for i < -128 {
   123			numBytes++
   124			i >>= 8
   125		}
   126	
   127		return
   128	}
   129	
   130	func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
   131		if n.Sign() < 0 {
   132			// A negative number has to be converted to two's-complement
   133			// form. So we'll subtract 1 and invert. If the
   134			// most-significant-bit isn't set then we'll need to pad the
   135			// beginning with 0xff in order to keep the number negative.
   136			nMinus1 := new(big.Int).Neg(n)
   137			nMinus1.Sub(nMinus1, bigOne)
   138			bytes := nMinus1.Bytes()
   139			for i := range bytes {
   140				bytes[i] ^= 0xff
   141			}
   142			if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   143				err = out.WriteByte(0xff)
   144				if err != nil {
   145					return
   146				}
   147			}
   148			_, err = out.Write(bytes)
   149		} else if n.Sign() == 0 {
   150			// Zero is written as a single 0 zero rather than no bytes.
   151			err = out.WriteByte(0x00)
   152		} else {
   153			bytes := n.Bytes()
   154			if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   155				// We'll have to pad this with 0x00 in order to stop it
   156				// looking like a negative number.
   157				err = out.WriteByte(0)
   158				if err != nil {
   159					return
   160				}
   161			}
   162			_, err = out.Write(bytes)
   163		}
   164		return
   165	}
   166	
   167	func marshalLength(out *forkableWriter, i int) (err error) {
   168		n := lengthLength(i)
   169	
   170		for ; n > 0; n-- {
   171			err = out.WriteByte(byte(i >> uint((n-1)*8)))
   172			if err != nil {
   173				return
   174			}
   175		}
   176	
   177		return nil
   178	}
   179	
   180	func lengthLength(i int) (numBytes int) {
   181		numBytes = 1
   182		for i > 255 {
   183			numBytes++
   184			i >>= 8
   185		}
   186		return
   187	}
   188	
   189	func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
   190		b := uint8(t.class) << 6
   191		if t.isCompound {
   192			b |= 0x20
   193		}
   194		if t.tag >= 31 {
   195			b |= 0x1f
   196			err = out.WriteByte(b)
   197			if err != nil {
   198				return
   199			}
   200			err = marshalBase128Int(out, int64(t.tag))
   201			if err != nil {
   202				return
   203			}
   204		} else {
   205			b |= uint8(t.tag)
   206			err = out.WriteByte(b)
   207			if err != nil {
   208				return
   209			}
   210		}
   211	
   212		if t.length >= 128 {
   213			l := lengthLength(t.length)
   214			err = out.WriteByte(0x80 | byte(l))
   215			if err != nil {
   216				return
   217			}
   218			err = marshalLength(out, t.length)
   219			if err != nil {
   220				return
   221			}
   222		} else {
   223			err = out.WriteByte(byte(t.length))
   224			if err != nil {
   225				return
   226			}
   227		}
   228	
   229		return nil
   230	}
   231	
   232	func marshalBitString(out *forkableWriter, b BitString) (err error) {
   233		paddingBits := byte((8 - b.BitLength%8) % 8)
   234		err = out.WriteByte(paddingBits)
   235		if err != nil {
   236			return
   237		}
   238		_, err = out.Write(b.Bytes)
   239		return
   240	}
   241	
   242	func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
   243		if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   244			return StructuralError{"invalid object identifier"}
   245		}
   246	
   247		err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
   248		if err != nil {
   249			return
   250		}
   251		for i := 2; i < len(oid); i++ {
   252			err = marshalBase128Int(out, int64(oid[i]))
   253			if err != nil {
   254				return
   255			}
   256		}
   257	
   258		return
   259	}
   260	
   261	func marshalPrintableString(out *forkableWriter, s string) (err error) {
   262		b := []byte(s)
   263		for _, c := range b {
   264			if !isPrintable(c) {
   265				return StructuralError{"PrintableString contains invalid character"}
   266			}
   267		}
   268	
   269		_, err = out.Write(b)
   270		return
   271	}
   272	
   273	func marshalIA5String(out *forkableWriter, s string) (err error) {
   274		b := []byte(s)
   275		for _, c := range b {
   276			if c > 127 {
   277				return StructuralError{"IA5String contains invalid character"}
   278			}
   279		}
   280	
   281		_, err = out.Write(b)
   282		return
   283	}
   284	
   285	func marshalUTF8String(out *forkableWriter, s string) (err error) {
   286		_, err = out.Write([]byte(s))
   287		return
   288	}
   289	
   290	func marshalTwoDigits(out *forkableWriter, v int) (err error) {
   291		err = out.WriteByte(byte('0' + (v/10)%10))
   292		if err != nil {
   293			return
   294		}
   295		return out.WriteByte(byte('0' + v%10))
   296	}
   297	
   298	func marshalFourDigits(out *forkableWriter, v int) (err error) {
   299		var bytes [4]byte
   300		for i := range bytes {
   301			bytes[3-i] = '0' + byte(v%10)
   302			v /= 10
   303		}
   304		_, err = out.Write(bytes[:])
   305		return
   306	}
   307	
   308	func outsideUTCRange(t time.Time) bool {
   309		year := t.Year()
   310		return year < 1950 || year >= 2050
   311	}
   312	
   313	func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
   314		year := t.Year()
   315	
   316		switch {
   317		case 1950 <= year && year < 2000:
   318			err = marshalTwoDigits(out, int(year-1900))
   319		case 2000 <= year && year < 2050:
   320			err = marshalTwoDigits(out, int(year-2000))
   321		default:
   322			return StructuralError{"cannot represent time as UTCTime"}
   323		}
   324		if err != nil {
   325			return
   326		}
   327	
   328		return marshalTimeCommon(out, t)
   329	}
   330	
   331	func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
   332		year := t.Year()
   333		if year < 0 || year > 9999 {
   334			return StructuralError{"cannot represent time as GeneralizedTime"}
   335		}
   336		if err = marshalFourDigits(out, year); err != nil {
   337			return
   338		}
   339	
   340		return marshalTimeCommon(out, t)
   341	}
   342	
   343	func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
   344		_, month, day := t.Date()
   345	
   346		err = marshalTwoDigits(out, int(month))
   347		if err != nil {
   348			return
   349		}
   350	
   351		err = marshalTwoDigits(out, day)
   352		if err != nil {
   353			return
   354		}
   355	
   356		hour, min, sec := t.Clock()
   357	
   358		err = marshalTwoDigits(out, hour)
   359		if err != nil {
   360			return
   361		}
   362	
   363		err = marshalTwoDigits(out, min)
   364		if err != nil {
   365			return
   366		}
   367	
   368		err = marshalTwoDigits(out, sec)
   369		if err != nil {
   370			return
   371		}
   372	
   373		_, offset := t.Zone()
   374	
   375		switch {
   376		case offset/60 == 0:
   377			err = out.WriteByte('Z')
   378			return
   379		case offset > 0:
   380			err = out.WriteByte('+')
   381		case offset < 0:
   382			err = out.WriteByte('-')
   383		}
   384	
   385		if err != nil {
   386			return
   387		}
   388	
   389		offsetMinutes := offset / 60
   390		if offsetMinutes < 0 {
   391			offsetMinutes = -offsetMinutes
   392		}
   393	
   394		err = marshalTwoDigits(out, offsetMinutes/60)
   395		if err != nil {
   396			return
   397		}
   398	
   399		err = marshalTwoDigits(out, offsetMinutes%60)
   400		return
   401	}
   402	
   403	func stripTagAndLength(in []byte) []byte {
   404		_, offset, err := parseTagAndLength(in, 0)
   405		if err != nil {
   406			return in
   407		}
   408		return in[offset:]
   409	}
   410	
   411	func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
   412		switch value.Type() {
   413		case timeType:
   414			t := value.Interface().(time.Time)
   415			if outsideUTCRange(t) {
   416				return marshalGeneralizedTime(out, t)
   417			} else {
   418				return marshalUTCTime(out, t)
   419			}
   420		case bitStringType:
   421			return marshalBitString(out, value.Interface().(BitString))
   422		case objectIdentifierType:
   423			return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
   424		case bigIntType:
   425			return marshalBigInt(out, value.Interface().(*big.Int))
   426		}
   427	
   428		switch v := value; v.Kind() {
   429		case reflect.Bool:
   430			if v.Bool() {
   431				return out.WriteByte(255)
   432			} else {
   433				return out.WriteByte(0)
   434			}
   435		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   436			return marshalInt64(out, int64(v.Int()))
   437		case reflect.Struct:
   438			t := v.Type()
   439	
   440			startingField := 0
   441	
   442			// If the first element of the structure is a non-empty
   443			// RawContents, then we don't bother serializing the rest.
   444			if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
   445				s := v.Field(0)
   446				if s.Len() > 0 {
   447					bytes := make([]byte, s.Len())
   448					for i := 0; i < s.Len(); i++ {
   449						bytes[i] = uint8(s.Index(i).Uint())
   450					}
   451					/* The RawContents will contain the tag and
   452					 * length fields but we'll also be writing
   453					 * those ourselves, so we strip them out of
   454					 * bytes */
   455					_, err = out.Write(stripTagAndLength(bytes))
   456					return
   457				} else {
   458					startingField = 1
   459				}
   460			}
   461	
   462			for i := startingField; i < t.NumField(); i++ {
   463				var pre *forkableWriter
   464				pre, out = out.fork()
   465				err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
   466				if err != nil {
   467					return
   468				}
   469			}
   470			return
   471		case reflect.Slice:
   472			sliceType := v.Type()
   473			if sliceType.Elem().Kind() == reflect.Uint8 {
   474				bytes := make([]byte, v.Len())
   475				for i := 0; i < v.Len(); i++ {
   476					bytes[i] = uint8(v.Index(i).Uint())
   477				}
   478				_, err = out.Write(bytes)
   479				return
   480			}
   481	
   482			var fp fieldParameters
   483			for i := 0; i < v.Len(); i++ {
   484				var pre *forkableWriter
   485				pre, out = out.fork()
   486				err = marshalField(pre, v.Index(i), fp)
   487				if err != nil {
   488					return
   489				}
   490			}
   491			return
   492		case reflect.String:
   493			switch params.stringType {
   494			case tagIA5String:
   495				return marshalIA5String(out, v.String())
   496			case tagPrintableString:
   497				return marshalPrintableString(out, v.String())
   498			default:
   499				return marshalUTF8String(out, v.String())
   500			}
   501		}
   502	
   503		return StructuralError{"unknown Go type"}
   504	}
   505	
   506	func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
   507		// If the field is an interface{} then recurse into it.
   508		if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   509			return marshalField(out, v.Elem(), params)
   510		}
   511	
   512		if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   513			return
   514		}
   515	
   516		if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
   517			defaultValue := reflect.New(v.Type()).Elem()
   518			defaultValue.SetInt(*params.defaultValue)
   519	
   520			if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
   521				return
   522			}
   523		}
   524	
   525		// If no default value is given then the zero value for the type is
   526		// assumed to be the default value. This isn't obviously the correct
   527		// behaviour, but it's what Go has traditionally done.
   528		if params.optional && params.defaultValue == nil {
   529			if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   530				return
   531			}
   532		}
   533	
   534		if v.Type() == rawValueType {
   535			rv := v.Interface().(RawValue)
   536			if len(rv.FullBytes) != 0 {
   537				_, err = out.Write(rv.FullBytes)
   538			} else {
   539				err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
   540				if err != nil {
   541					return
   542				}
   543				_, err = out.Write(rv.Bytes)
   544			}
   545			return
   546		}
   547	
   548		tag, isCompound, ok := getUniversalType(v.Type())
   549		if !ok {
   550			err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   551			return
   552		}
   553		class := classUniversal
   554	
   555		if params.stringType != 0 && tag != tagPrintableString {
   556			return StructuralError{"explicit string type given to non-string member"}
   557		}
   558	
   559		switch tag {
   560		case tagPrintableString:
   561			if params.stringType == 0 {
   562				// This is a string without an explicit string type. We'll use
   563				// a PrintableString if the character set in the string is
   564				// sufficiently limited, otherwise we'll use a UTF8String.
   565				for _, r := range v.String() {
   566					if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
   567						if !utf8.ValidString(v.String()) {
   568							return errors.New("asn1: string not valid UTF-8")
   569						}
   570						tag = tagUTF8String
   571						break
   572					}
   573				}
   574			} else {
   575				tag = params.stringType
   576			}
   577		case tagUTCTime:
   578			if outsideUTCRange(v.Interface().(time.Time)) {
   579				tag = tagGeneralizedTime
   580			}
   581		}
   582	
   583		if params.set {
   584			if tag != tagSequence {
   585				return StructuralError{"non sequence tagged as set"}
   586			}
   587			tag = tagSet
   588		}
   589	
   590		tags, body := out.fork()
   591	
   592		err = marshalBody(body, v, params)
   593		if err != nil {
   594			return
   595		}
   596	
   597		bodyLen := body.Len()
   598	
   599		var explicitTag *forkableWriter
   600		if params.explicit {
   601			explicitTag, tags = tags.fork()
   602		}
   603	
   604		if !params.explicit && params.tag != nil {
   605			// implicit tag.
   606			tag = *params.tag
   607			class = classContextSpecific
   608		}
   609	
   610		err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
   611		if err != nil {
   612			return
   613		}
   614	
   615		if params.explicit {
   616			err = marshalTagAndLength(explicitTag, tagAndLength{
   617				class:      classContextSpecific,
   618				tag:        *params.tag,
   619				length:     bodyLen + tags.Len(),
   620				isCompound: true,
   621			})
   622		}
   623	
   624		return nil
   625	}
   626	
   627	// Marshal returns the ASN.1 encoding of val.
   628	//
   629	// In addition to the struct tags recognised by Unmarshal, the following can be
   630	// used:
   631	//
   632	//	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
   633	//	omitempty:	causes empty slices to be skipped
   634	//	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
   635	//	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
   636	func Marshal(val interface{}) ([]byte, error) {
   637		var out bytes.Buffer
   638		v := reflect.ValueOf(val)
   639		f := newForkableWriter()
   640		err := marshalField(f, v, fieldParameters{})
   641		if err != nil {
   642			return nil, err
   643		}
   644		_, err = f.writeTo(&out)
   645		return out.Bytes(), nil
   646	}
   647	

View as plain text