Source file src/cmd/gofmt/rewrite.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"reflect"
    14  	"strings"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  func initRewrite() {
    20  	if *rewriteRule == "" {
    21  		rewrite = nil // disable any previous rewrite
    22  		return
    23  	}
    24  	f := strings.Split(*rewriteRule, "->")
    25  	if len(f) != 2 {
    26  		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
    27  		os.Exit(2)
    28  	}
    29  	pattern := parseExpr(f[0], "pattern")
    30  	replace := parseExpr(f[1], "replacement")
    31  	rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
    32  		return rewriteFile(fset, pattern, replace, p)
    33  	}
    34  }
    35  
    36  // parseExpr parses s as an expression.
    37  // It might make sense to expand this to allow statement patterns,
    38  // but there are problems with preserving formatting and also
    39  // with what a wildcard for a statement looks like.
    40  func parseExpr(s, what string) ast.Expr {
    41  	x, err := parser.ParseExpr(s)
    42  	if err != nil {
    43  		fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
    44  		os.Exit(2)
    45  	}
    46  	return x
    47  }
    48  
    49  // Keep this function for debugging.
    50  /*
    51  func dump(msg string, val reflect.Value) {
    52  	fmt.Printf("%s:\n", msg)
    53  	ast.Print(fileSet, val.Interface())
    54  	fmt.Println()
    55  }
    56  */
    57  
    58  // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
    59  func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
    60  	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
    61  	m := make(map[string]reflect.Value)
    62  	pat := reflect.ValueOf(pattern)
    63  	repl := reflect.ValueOf(replace)
    64  
    65  	var rewriteVal func(val reflect.Value) reflect.Value
    66  	rewriteVal = func(val reflect.Value) reflect.Value {
    67  		// don't bother if val is invalid to start with
    68  		if !val.IsValid() {
    69  			return reflect.Value{}
    70  		}
    71  		val = apply(rewriteVal, val)
    72  		clear(m)
    73  		if match(m, pat, val) {
    74  			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
    75  		}
    76  		return val
    77  	}
    78  
    79  	r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
    80  	r.Comments = cmap.Filter(r).Comments() // recreate comments list
    81  	return r
    82  }
    83  
    84  // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
    85  func set(x, y reflect.Value) {
    86  	// don't bother if x cannot be set or y is invalid
    87  	if !x.CanSet() || !y.IsValid() {
    88  		return
    89  	}
    90  	defer func() {
    91  		if x := recover(); x != nil {
    92  			if s, ok := x.(string); ok &&
    93  				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
    94  				// x cannot be set to y - ignore this rewrite
    95  				return
    96  			}
    97  			panic(x)
    98  		}
    99  	}()
   100  	x.Set(y)
   101  }
   102  
   103  // Values/types for special cases.
   104  var (
   105  	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
   106  	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
   107  
   108  	identType     = reflect.TypeOf((*ast.Ident)(nil))
   109  	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
   110  	positionType  = reflect.TypeOf(token.NoPos)
   111  	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
   112  	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
   113  )
   114  
   115  // apply replaces each AST field x in val with f(x), returning val.
   116  // To avoid extra conversions, f operates on the reflect.Value form.
   117  func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
   118  	if !val.IsValid() {
   119  		return reflect.Value{}
   120  	}
   121  
   122  	// *ast.Objects introduce cycles and are likely incorrect after
   123  	// rewrite; don't follow them but replace with nil instead
   124  	if val.Type() == objectPtrType {
   125  		return objectPtrNil
   126  	}
   127  
   128  	// similarly for scopes: they are likely incorrect after a rewrite;
   129  	// replace them with nil
   130  	if val.Type() == scopePtrType {
   131  		return scopePtrNil
   132  	}
   133  
   134  	switch v := reflect.Indirect(val); v.Kind() {
   135  	case reflect.Slice:
   136  		for i := 0; i < v.Len(); i++ {
   137  			e := v.Index(i)
   138  			set(e, f(e))
   139  		}
   140  	case reflect.Struct:
   141  		for i := 0; i < v.NumField(); i++ {
   142  			e := v.Field(i)
   143  			set(e, f(e))
   144  		}
   145  	case reflect.Interface:
   146  		e := v.Elem()
   147  		set(v, f(e))
   148  	}
   149  	return val
   150  }
   151  
   152  func isWildcard(s string) bool {
   153  	rune, size := utf8.DecodeRuneInString(s)
   154  	return size == len(s) && unicode.IsLower(rune)
   155  }
   156  
   157  // match reports whether pattern matches val,
   158  // recording wildcard submatches in m.
   159  // If m == nil, match checks whether pattern == val.
   160  func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
   161  	// Wildcard matches any expression. If it appears multiple
   162  	// times in the pattern, it must match the same expression
   163  	// each time.
   164  	if m != nil && pattern.IsValid() && pattern.Type() == identType {
   165  		name := pattern.Interface().(*ast.Ident).Name
   166  		if isWildcard(name) && val.IsValid() {
   167  			// wildcards only match valid (non-nil) expressions.
   168  			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
   169  				if old, ok := m[name]; ok {
   170  					return match(nil, old, val)
   171  				}
   172  				m[name] = val
   173  				return true
   174  			}
   175  		}
   176  	}
   177  
   178  	// Otherwise, pattern and val must match recursively.
   179  	if !pattern.IsValid() || !val.IsValid() {
   180  		return !pattern.IsValid() && !val.IsValid()
   181  	}
   182  	if pattern.Type() != val.Type() {
   183  		return false
   184  	}
   185  
   186  	// Special cases.
   187  	switch pattern.Type() {
   188  	case identType:
   189  		// For identifiers, only the names need to match
   190  		// (and none of the other *ast.Object information).
   191  		// This is a common case, handle it all here instead
   192  		// of recursing down any further via reflection.
   193  		p := pattern.Interface().(*ast.Ident)
   194  		v := val.Interface().(*ast.Ident)
   195  		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
   196  	case objectPtrType, positionType:
   197  		// object pointers and token positions always match
   198  		return true
   199  	case callExprType:
   200  		// For calls, the Ellipsis fields (token.Pos) must
   201  		// match since that is how f(x) and f(x...) are different.
   202  		// Check them here but fall through for the remaining fields.
   203  		p := pattern.Interface().(*ast.CallExpr)
   204  		v := val.Interface().(*ast.CallExpr)
   205  		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
   206  			return false
   207  		}
   208  	}
   209  
   210  	p := reflect.Indirect(pattern)
   211  	v := reflect.Indirect(val)
   212  	if !p.IsValid() || !v.IsValid() {
   213  		return !p.IsValid() && !v.IsValid()
   214  	}
   215  
   216  	switch p.Kind() {
   217  	case reflect.Slice:
   218  		if p.Len() != v.Len() {
   219  			return false
   220  		}
   221  		for i := 0; i < p.Len(); i++ {
   222  			if !match(m, p.Index(i), v.Index(i)) {
   223  				return false
   224  			}
   225  		}
   226  		return true
   227  
   228  	case reflect.Struct:
   229  		for i := 0; i < p.NumField(); i++ {
   230  			if !match(m, p.Field(i), v.Field(i)) {
   231  				return false
   232  			}
   233  		}
   234  		return true
   235  
   236  	case reflect.Interface:
   237  		return match(m, p.Elem(), v.Elem())
   238  	}
   239  
   240  	// Handle token integers, etc.
   241  	return p.Interface() == v.Interface()
   242  }
   243  
   244  // subst returns a copy of pattern with values from m substituted in place
   245  // of wildcards and pos used as the position of tokens from the pattern.
   246  // if m == nil, subst returns a copy of pattern and doesn't change the line
   247  // number information.
   248  func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
   249  	if !pattern.IsValid() {
   250  		return reflect.Value{}
   251  	}
   252  
   253  	// Wildcard gets replaced with map value.
   254  	if m != nil && pattern.Type() == identType {
   255  		name := pattern.Interface().(*ast.Ident).Name
   256  		if isWildcard(name) {
   257  			if old, ok := m[name]; ok {
   258  				return subst(nil, old, reflect.Value{})
   259  			}
   260  		}
   261  	}
   262  
   263  	if pos.IsValid() && pattern.Type() == positionType {
   264  		// use new position only if old position was valid in the first place
   265  		if old := pattern.Interface().(token.Pos); !old.IsValid() {
   266  			return pattern
   267  		}
   268  		return pos
   269  	}
   270  
   271  	// Otherwise copy.
   272  	switch p := pattern; p.Kind() {
   273  	case reflect.Slice:
   274  		if p.IsNil() {
   275  			// Do not turn nil slices into empty slices. go/ast
   276  			// guarantees that certain lists will be nil if not
   277  			// populated.
   278  			return reflect.Zero(p.Type())
   279  		}
   280  		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
   281  		for i := 0; i < p.Len(); i++ {
   282  			v.Index(i).Set(subst(m, p.Index(i), pos))
   283  		}
   284  		return v
   285  
   286  	case reflect.Struct:
   287  		v := reflect.New(p.Type()).Elem()
   288  		for i := 0; i < p.NumField(); i++ {
   289  			v.Field(i).Set(subst(m, p.Field(i), pos))
   290  		}
   291  		return v
   292  
   293  	case reflect.Pointer:
   294  		v := reflect.New(p.Type()).Elem()
   295  		if elem := p.Elem(); elem.IsValid() {
   296  			v.Set(subst(m, elem, pos).Addr())
   297  		}
   298  		return v
   299  
   300  	case reflect.Interface:
   301  		v := reflect.New(p.Type()).Elem()
   302  		if elem := p.Elem(); elem.IsValid() {
   303  			v.Set(subst(m, elem, pos))
   304  		}
   305  		return v
   306  	}
   307  
   308  	return pattern
   309  }
   310  

View as plain text