...
Run Format

Source file src/cmd/fix/fix.go

Documentation: cmd/fix

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"path"
    14  	"reflect"
    15  	"strconv"
    16  	"strings"
    17  )
    18  
// fix describes one rewrite that can be registered with register and
// later applied to a parsed Go source file.
type fix struct {
	name     string               // name of the fix; byName orders fixes by this field
	date     string               // date that fix was introduced, in YYYY-MM-DD format; byDate orders fixes by this field
	f        func(*ast.File) bool // the rewrite itself; NOTE(review): the result presumably reports whether the file was changed — confirm against the caller
	desc     string               // NOTE(review): presumably a human-readable description of the fix — confirm
	disabled bool                 // whether this fix should be disabled by default
}
    26  
    27  // main runs sort.Sort(byName(fixes)) before printing list of fixes.
    28  type byName []fix
    29  
    30  func (f byName) Len() int           { return len(f) }
    31  func (f byName) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    32  func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
    33  
    34  // main runs sort.Sort(byDate(fixes)) before applying fixes.
    35  type byDate []fix
    36  
    37  func (f byDate) Len() int           { return len(f) }
    38  func (f byDate) Swap(i, j int)      { f[i], f[j] = f[j], f[i] }
    39  func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
    40  
// fixes is the package-level list of fixes added through register,
// in the order they were registered.
var fixes []fix

// register appends f to the package-level fixes list.
func register(f fix) {
	fixes = append(fixes, f)
}
    46  
    47  // walk traverses the AST x, calling visit(y) for each node y in the tree but
    48  // also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
    49  // in a bottom-up traversal.
    50  func walk(x interface{}, visit func(interface{})) {
    51  	walkBeforeAfter(x, nop, visit)
    52  }
    53  
// nop ignores its argument. walk passes it to walkBeforeAfter as the
// "before" callback so that only the "after" callback does any work.
func nop(interface{}) {}
    55  
    56  // walkBeforeAfter is like walk but calls before(x) before traversing
    57  // x's children and after(x) afterward.
    58  func walkBeforeAfter(x interface{}, before, after func(interface{})) {
    59  	before(x)
    60  
    61  	switch n := x.(type) {
    62  	default:
    63  		panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
    64  
    65  	case nil:
    66  
    67  	// pointers to interfaces
    68  	case *ast.Decl:
    69  		walkBeforeAfter(*n, before, after)
    70  	case *ast.Expr:
    71  		walkBeforeAfter(*n, before, after)
    72  	case *ast.Spec:
    73  		walkBeforeAfter(*n, before, after)
    74  	case *ast.Stmt:
    75  		walkBeforeAfter(*n, before, after)
    76  
    77  	// pointers to struct pointers
    78  	case **ast.BlockStmt:
    79  		walkBeforeAfter(*n, before, after)
    80  	case **ast.CallExpr:
    81  		walkBeforeAfter(*n, before, after)
    82  	case **ast.FieldList:
    83  		walkBeforeAfter(*n, before, after)
    84  	case **ast.FuncType:
    85  		walkBeforeAfter(*n, before, after)
    86  	case **ast.Ident:
    87  		walkBeforeAfter(*n, before, after)
    88  	case **ast.BasicLit:
    89  		walkBeforeAfter(*n, before, after)
    90  
    91  	// pointers to slices
    92  	case *[]ast.Decl:
    93  		walkBeforeAfter(*n, before, after)
    94  	case *[]ast.Expr:
    95  		walkBeforeAfter(*n, before, after)
    96  	case *[]*ast.File:
    97  		walkBeforeAfter(*n, before, after)
    98  	case *[]*ast.Ident:
    99  		walkBeforeAfter(*n, before, after)
   100  	case *[]ast.Spec:
   101  		walkBeforeAfter(*n, before, after)
   102  	case *[]ast.Stmt:
   103  		walkBeforeAfter(*n, before, after)
   104  
   105  	// These are ordered and grouped to match ../../go/ast/ast.go
   106  	case *ast.Field:
   107  		walkBeforeAfter(&n.Names, before, after)
   108  		walkBeforeAfter(&n.Type, before, after)
   109  		walkBeforeAfter(&n.Tag, before, after)
   110  	case *ast.FieldList:
   111  		for _, field := range n.List {
   112  			walkBeforeAfter(field, before, after)
   113  		}
   114  	case *ast.BadExpr:
   115  	case *ast.Ident:
   116  	case *ast.Ellipsis:
   117  		walkBeforeAfter(&n.Elt, before, after)
   118  	case *ast.BasicLit:
   119  	case *ast.FuncLit:
   120  		walkBeforeAfter(&n.Type, before, after)
   121  		walkBeforeAfter(&n.Body, before, after)
   122  	case *ast.CompositeLit:
   123  		walkBeforeAfter(&n.Type, before, after)
   124  		walkBeforeAfter(&n.Elts, before, after)
   125  	case *ast.ParenExpr:
   126  		walkBeforeAfter(&n.X, before, after)
   127  	case *ast.SelectorExpr:
   128  		walkBeforeAfter(&n.X, before, after)
   129  	case *ast.IndexExpr:
   130  		walkBeforeAfter(&n.X, before, after)
   131  		walkBeforeAfter(&n.Index, before, after)
   132  	case *ast.SliceExpr:
   133  		walkBeforeAfter(&n.X, before, after)
   134  		if n.Low != nil {
   135  			walkBeforeAfter(&n.Low, before, after)
   136  		}
   137  		if n.High != nil {
   138  			walkBeforeAfter(&n.High, before, after)
   139  		}
   140  	case *ast.TypeAssertExpr:
   141  		walkBeforeAfter(&n.X, before, after)
   142  		walkBeforeAfter(&n.Type, before, after)
   143  	case *ast.CallExpr:
   144  		walkBeforeAfter(&n.Fun, before, after)
   145  		walkBeforeAfter(&n.Args, before, after)
   146  	case *ast.StarExpr:
   147  		walkBeforeAfter(&n.X, before, after)
   148  	case *ast.UnaryExpr:
   149  		walkBeforeAfter(&n.X, before, after)
   150  	case *ast.BinaryExpr:
   151  		walkBeforeAfter(&n.X, before, after)
   152  		walkBeforeAfter(&n.Y, before, after)
   153  	case *ast.KeyValueExpr:
   154  		walkBeforeAfter(&n.Key, before, after)
   155  		walkBeforeAfter(&n.Value, before, after)
   156  
   157  	case *ast.ArrayType:
   158  		walkBeforeAfter(&n.Len, before, after)
   159  		walkBeforeAfter(&n.Elt, before, after)
   160  	case *ast.StructType:
   161  		walkBeforeAfter(&n.Fields, before, after)
   162  	case *ast.FuncType:
   163  		walkBeforeAfter(&n.Params, before, after)
   164  		if n.Results != nil {
   165  			walkBeforeAfter(&n.Results, before, after)
   166  		}
   167  	case *ast.InterfaceType:
   168  		walkBeforeAfter(&n.Methods, before, after)
   169  	case *ast.MapType:
   170  		walkBeforeAfter(&n.Key, before, after)
   171  		walkBeforeAfter(&n.Value, before, after)
   172  	case *ast.ChanType:
   173  		walkBeforeAfter(&n.Value, before, after)
   174  
   175  	case *ast.BadStmt:
   176  	case *ast.DeclStmt:
   177  		walkBeforeAfter(&n.Decl, before, after)
   178  	case *ast.EmptyStmt:
   179  	case *ast.LabeledStmt:
   180  		walkBeforeAfter(&n.Stmt, before, after)
   181  	case *ast.ExprStmt:
   182  		walkBeforeAfter(&n.X, before, after)
   183  	case *ast.SendStmt:
   184  		walkBeforeAfter(&n.Chan, before, after)
   185  		walkBeforeAfter(&n.Value, before, after)
   186  	case *ast.IncDecStmt:
   187  		walkBeforeAfter(&n.X, before, after)
   188  	case *ast.AssignStmt:
   189  		walkBeforeAfter(&n.Lhs, before, after)
   190  		walkBeforeAfter(&n.Rhs, before, after)
   191  	case *ast.GoStmt:
   192  		walkBeforeAfter(&n.Call, before, after)
   193  	case *ast.DeferStmt:
   194  		walkBeforeAfter(&n.Call, before, after)
   195  	case *ast.ReturnStmt:
   196  		walkBeforeAfter(&n.Results, before, after)
   197  	case *ast.BranchStmt:
   198  	case *ast.BlockStmt:
   199  		walkBeforeAfter(&n.List, before, after)
   200  	case *ast.IfStmt:
   201  		walkBeforeAfter(&n.Init, before, after)
   202  		walkBeforeAfter(&n.Cond, before, after)
   203  		walkBeforeAfter(&n.Body, before, after)
   204  		walkBeforeAfter(&n.Else, before, after)
   205  	case *ast.CaseClause:
   206  		walkBeforeAfter(&n.List, before, after)
   207  		walkBeforeAfter(&n.Body, before, after)
   208  	case *ast.SwitchStmt:
   209  		walkBeforeAfter(&n.Init, before, after)
   210  		walkBeforeAfter(&n.Tag, before, after)
   211  		walkBeforeAfter(&n.Body, before, after)
   212  	case *ast.TypeSwitchStmt:
   213  		walkBeforeAfter(&n.Init, before, after)
   214  		walkBeforeAfter(&n.Assign, before, after)
   215  		walkBeforeAfter(&n.Body, before, after)
   216  	case *ast.CommClause:
   217  		walkBeforeAfter(&n.Comm, before, after)
   218  		walkBeforeAfter(&n.Body, before, after)
   219  	case *ast.SelectStmt:
   220  		walkBeforeAfter(&n.Body, before, after)
   221  	case *ast.ForStmt:
   222  		walkBeforeAfter(&n.Init, before, after)
   223  		walkBeforeAfter(&n.Cond, before, after)
   224  		walkBeforeAfter(&n.Post, before, after)
   225  		walkBeforeAfter(&n.Body, before, after)
   226  	case *ast.RangeStmt:
   227  		walkBeforeAfter(&n.Key, before, after)
   228  		walkBeforeAfter(&n.Value, before, after)
   229  		walkBeforeAfter(&n.X, before, after)
   230  		walkBeforeAfter(&n.Body, before, after)
   231  
   232  	case *ast.ImportSpec:
   233  	case *ast.ValueSpec:
   234  		walkBeforeAfter(&n.Type, before, after)
   235  		walkBeforeAfter(&n.Values, before, after)
   236  		walkBeforeAfter(&n.Names, before, after)
   237  	case *ast.TypeSpec:
   238  		walkBeforeAfter(&n.Type, before, after)
   239  
   240  	case *ast.BadDecl:
   241  	case *ast.GenDecl:
   242  		walkBeforeAfter(&n.Specs, before, after)
   243  	case *ast.FuncDecl:
   244  		if n.Recv != nil {
   245  			walkBeforeAfter(&n.Recv, before, after)
   246  		}
   247  		walkBeforeAfter(&n.Type, before, after)
   248  		if n.Body != nil {
   249  			walkBeforeAfter(&n.Body, before, after)
   250  		}
   251  
   252  	case *ast.File:
   253  		walkBeforeAfter(&n.Decls, before, after)
   254  
   255  	case *ast.Package:
   256  		walkBeforeAfter(&n.Files, before, after)
   257  
   258  	case []*ast.File:
   259  		for i := range n {
   260  			walkBeforeAfter(&n[i], before, after)
   261  		}
   262  	case []ast.Decl:
   263  		for i := range n {
   264  			walkBeforeAfter(&n[i], before, after)
   265  		}
   266  	case []ast.Expr:
   267  		for i := range n {
   268  			walkBeforeAfter(&n[i], before, after)
   269  		}
   270  	case []*ast.Ident:
   271  		for i := range n {
   272  			walkBeforeAfter(&n[i], before, after)
   273  		}
   274  	case []ast.Stmt:
   275  		for i := range n {
   276  			walkBeforeAfter(&n[i], before, after)
   277  		}
   278  	case []ast.Spec:
   279  		for i := range n {
   280  			walkBeforeAfter(&n[i], before, after)
   281  		}
   282  	}
   283  	after(x)
   284  }
   285  
   286  // imports reports whether f imports path.
   287  func imports(f *ast.File, path string) bool {
   288  	return importSpec(f, path) != nil
   289  }
   290  
   291  // importSpec returns the import spec if f imports path,
   292  // or nil otherwise.
   293  func importSpec(f *ast.File, path string) *ast.ImportSpec {
   294  	for _, s := range f.Imports {
   295  		if importPath(s) == path {
   296  			return s
   297  		}
   298  	}
   299  	return nil
   300  }
   301  
   302  // importPath returns the unquoted import path of s,
   303  // or "" if the path is not properly quoted.
   304  func importPath(s *ast.ImportSpec) string {
   305  	t, err := strconv.Unquote(s.Path.Value)
   306  	if err == nil {
   307  		return t
   308  	}
   309  	return ""
   310  }
   311  
   312  // declImports reports whether gen contains an import of path.
   313  func declImports(gen *ast.GenDecl, path string) bool {
   314  	if gen.Tok != token.IMPORT {
   315  		return false
   316  	}
   317  	for _, spec := range gen.Specs {
   318  		impspec := spec.(*ast.ImportSpec)
   319  		if importPath(impspec) == path {
   320  			return true
   321  		}
   322  	}
   323  	return false
   324  }
   325  
   326  // isPkgDot reports whether t is the expression "pkg.name"
   327  // where pkg is an imported identifier.
   328  func isPkgDot(t ast.Expr, pkg, name string) bool {
   329  	sel, ok := t.(*ast.SelectorExpr)
   330  	return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name
   331  }
   332  
   333  // isPtrPkgDot reports whether f is the expression "*pkg.name"
   334  // where pkg is an imported identifier.
   335  func isPtrPkgDot(t ast.Expr, pkg, name string) bool {
   336  	ptr, ok := t.(*ast.StarExpr)
   337  	return ok && isPkgDot(ptr.X, pkg, name)
   338  }
   339  
   340  // isTopName reports whether n is a top-level unresolved identifier with the given name.
   341  func isTopName(n ast.Expr, name string) bool {
   342  	id, ok := n.(*ast.Ident)
   343  	return ok && id.Name == name && id.Obj == nil
   344  }
   345  
   346  // isName reports whether n is an identifier with the given name.
   347  func isName(n ast.Expr, name string) bool {
   348  	id, ok := n.(*ast.Ident)
   349  	return ok && id.String() == name
   350  }
   351  
   352  // isCall reports whether t is a call to pkg.name.
   353  func isCall(t ast.Expr, pkg, name string) bool {
   354  	call, ok := t.(*ast.CallExpr)
   355  	return ok && isPkgDot(call.Fun, pkg, name)
   356  }
   357  
   358  // If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil.
   359  func isIdent(n interface{}) *ast.Ident {
   360  	id, _ := n.(*ast.Ident)
   361  	return id
   362  }
   363  
   364  // refersTo reports whether n is a reference to the same object as x.
   365  func refersTo(n ast.Node, x *ast.Ident) bool {
   366  	id, ok := n.(*ast.Ident)
   367  	// The test of id.Name == x.Name handles top-level unresolved
   368  	// identifiers, which all have Obj == nil.
   369  	return ok && id.Obj == x.Obj && id.Name == x.Name
   370  }
   371  
   372  // isBlank reports whether n is the blank identifier.
   373  func isBlank(n ast.Expr) bool {
   374  	return isName(n, "_")
   375  }
   376  
   377  // isEmptyString reports whether n is an empty string literal.
   378  func isEmptyString(n ast.Expr) bool {
   379  	lit, ok := n.(*ast.BasicLit)
   380  	return ok && lit.Kind == token.STRING && len(lit.Value) == 2
   381  }
   382  
   383  func warn(pos token.Pos, msg string, args ...interface{}) {
   384  	if pos.IsValid() {
   385  		msg = "%s: " + msg
   386  		arg1 := []interface{}{fset.Position(pos).String()}
   387  		args = append(arg1, args...)
   388  	}
   389  	fmt.Fprintf(os.Stderr, msg+"\n", args...)
   390  }
   391  
   392  // countUses returns the number of uses of the identifier x in scope.
   393  func countUses(x *ast.Ident, scope []ast.Stmt) int {
   394  	count := 0
   395  	ff := func(n interface{}) {
   396  		if n, ok := n.(ast.Node); ok && refersTo(n, x) {
   397  			count++
   398  		}
   399  	}
   400  	for _, n := range scope {
   401  		walk(n, ff)
   402  	}
   403  	return count
   404  }
   405  
   406  // rewriteUses replaces all uses of the identifier x and !x in scope
   407  // with f(x.Pos()) and fnot(x.Pos()).
   408  func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) {
   409  	var lastF ast.Expr
   410  	ff := func(n interface{}) {
   411  		ptr, ok := n.(*ast.Expr)
   412  		if !ok {
   413  			return
   414  		}
   415  		nn := *ptr
   416  
   417  		// The child node was just walked and possibly replaced.
   418  		// If it was replaced and this is a negation, replace with fnot(p).
   419  		not, ok := nn.(*ast.UnaryExpr)
   420  		if ok && not.Op == token.NOT && not.X == lastF {
   421  			*ptr = fnot(nn.Pos())
   422  			return
   423  		}
   424  		if refersTo(nn, x) {
   425  			lastF = f(nn.Pos())
   426  			*ptr = lastF
   427  		}
   428  	}
   429  	for _, n := range scope {
   430  		walk(n, ff)
   431  	}
   432  }
   433  
   434  // assignsTo reports whether any of the code in scope assigns to or takes the address of x.
   435  func assignsTo(x *ast.Ident, scope []ast.Stmt) bool {
   436  	assigned := false
   437  	ff := func(n interface{}) {
   438  		if assigned {
   439  			return
   440  		}
   441  		switch n := n.(type) {
   442  		case *ast.UnaryExpr:
   443  			// use of &x
   444  			if n.Op == token.AND && refersTo(n.X, x) {
   445  				assigned = true
   446  				return
   447  			}
   448  		case *ast.AssignStmt:
   449  			for _, l := range n.Lhs {
   450  				if refersTo(l, x) {
   451  					assigned = true
   452  					return
   453  				}
   454  			}
   455  		}
   456  	}
   457  	for _, n := range scope {
   458  		if assigned {
   459  			break
   460  		}
   461  		walk(n, ff)
   462  	}
   463  	return assigned
   464  }
   465  
   466  // newPkgDot returns an ast.Expr referring to "pkg.name" at position pos.
   467  func newPkgDot(pos token.Pos, pkg, name string) ast.Expr {
   468  	return &ast.SelectorExpr{
   469  		X: &ast.Ident{
   470  			NamePos: pos,
   471  			Name:    pkg,
   472  		},
   473  		Sel: &ast.Ident{
   474  			NamePos: pos,
   475  			Name:    name,
   476  		},
   477  	}
   478  }
   479  
   480  // renameTop renames all references to the top-level name old.
   481  // It returns true if it makes any changes.
   482  func renameTop(f *ast.File, old, new string) bool {
   483  	var fixed bool
   484  
   485  	// Rename any conflicting imports
   486  	// (assuming package name is last element of path).
   487  	for _, s := range f.Imports {
   488  		if s.Name != nil {
   489  			if s.Name.Name == old {
   490  				s.Name.Name = new
   491  				fixed = true
   492  			}
   493  		} else {
   494  			_, thisName := path.Split(importPath(s))
   495  			if thisName == old {
   496  				s.Name = ast.NewIdent(new)
   497  				fixed = true
   498  			}
   499  		}
   500  	}
   501  
   502  	// Rename any top-level declarations.
   503  	for _, d := range f.Decls {
   504  		switch d := d.(type) {
   505  		case *ast.FuncDecl:
   506  			if d.Recv == nil && d.Name.Name == old {
   507  				d.Name.Name = new
   508  				d.Name.Obj.Name = new
   509  				fixed = true
   510  			}
   511  		case *ast.GenDecl:
   512  			for _, s := range d.Specs {
   513  				switch s := s.(type) {
   514  				case *ast.TypeSpec:
   515  					if s.Name.Name == old {
   516  						s.Name.Name = new
   517  						s.Name.Obj.Name = new
   518  						fixed = true
   519  					}
   520  				case *ast.ValueSpec:
   521  					for _, n := range s.Names {
   522  						if n.Name == old {
   523  							n.Name = new
   524  							n.Obj.Name = new
   525  							fixed = true
   526  						}
   527  					}
   528  				}
   529  			}
   530  		}
   531  	}
   532  
   533  	// Rename top-level old to new, both unresolved names
   534  	// (probably defined in another file) and names that resolve
   535  	// to a declaration we renamed.
   536  	walk(f, func(n interface{}) {
   537  		id, ok := n.(*ast.Ident)
   538  		if ok && isTopName(id, old) {
   539  			id.Name = new
   540  			fixed = true
   541  		}
   542  		if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
   543  			id.Name = id.Obj.Name
   544  			fixed = true
   545  		}
   546  	})
   547  
   548  	return fixed
   549  }
   550  
   551  // matchLen returns the length of the longest prefix shared by x and y.
   552  func matchLen(x, y string) int {
   553  	i := 0
   554  	for i < len(x) && i < len(y) && x[i] == y[i] {
   555  		i++
   556  	}
   557  	return i
   558  }
   559  
   560  // addImport adds the import path to the file f, if absent.
   561  func addImport(f *ast.File, ipath string) (added bool) {
   562  	if imports(f, ipath) {
   563  		return false
   564  	}
   565  
   566  	// Determine name of import.
   567  	// Assume added imports follow convention of using last element.
   568  	_, name := path.Split(ipath)
   569  
   570  	// Rename any conflicting top-level references from name to name_.
   571  	renameTop(f, name, name+"_")
   572  
   573  	newImport := &ast.ImportSpec{
   574  		Path: &ast.BasicLit{
   575  			Kind:  token.STRING,
   576  			Value: strconv.Quote(ipath),
   577  		},
   578  	}
   579  
   580  	// Find an import decl to add to.
   581  	var (
   582  		bestMatch  = -1
   583  		lastImport = -1
   584  		impDecl    *ast.GenDecl
   585  		impIndex   = -1
   586  	)
   587  	for i, decl := range f.Decls {
   588  		gen, ok := decl.(*ast.GenDecl)
   589  		if ok && gen.Tok == token.IMPORT {
   590  			lastImport = i
   591  			// Do not add to import "C", to avoid disrupting the
   592  			// association with its doc comment, breaking cgo.
   593  			if declImports(gen, "C") {
   594  				continue
   595  			}
   596  
   597  			// Compute longest shared prefix with imports in this block.
   598  			for j, spec := range gen.Specs {
   599  				impspec := spec.(*ast.ImportSpec)
   600  				n := matchLen(importPath(impspec), ipath)
   601  				if n > bestMatch {
   602  					bestMatch = n
   603  					impDecl = gen
   604  					impIndex = j
   605  				}
   606  			}
   607  		}
   608  	}
   609  
   610  	// If no import decl found, add one after the last import.
   611  	if impDecl == nil {
   612  		impDecl = &ast.GenDecl{
   613  			Tok: token.IMPORT,
   614  		}
   615  		f.Decls = append(f.Decls, nil)
   616  		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
   617  		f.Decls[lastImport+1] = impDecl
   618  	}
   619  
   620  	// Ensure the import decl has parentheses, if needed.
   621  	if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
   622  		impDecl.Lparen = impDecl.Pos()
   623  	}
   624  
   625  	insertAt := impIndex + 1
   626  	if insertAt == 0 {
   627  		insertAt = len(impDecl.Specs)
   628  	}
   629  	impDecl.Specs = append(impDecl.Specs, nil)
   630  	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
   631  	impDecl.Specs[insertAt] = newImport
   632  	if insertAt > 0 {
   633  		// Assign same position as the previous import,
   634  		// so that the sorter sees it as being in the same block.
   635  		prev := impDecl.Specs[insertAt-1]
   636  		newImport.Path.ValuePos = prev.Pos()
   637  		newImport.EndPos = prev.Pos()
   638  	}
   639  
   640  	f.Imports = append(f.Imports, newImport)
   641  	return true
   642  }
   643  
   644  // deleteImport deletes the import path from the file f, if present.
   645  func deleteImport(f *ast.File, path string) (deleted bool) {
   646  	oldImport := importSpec(f, path)
   647  
   648  	// Find the import node that imports path, if any.
   649  	for i, decl := range f.Decls {
   650  		gen, ok := decl.(*ast.GenDecl)
   651  		if !ok || gen.Tok != token.IMPORT {
   652  			continue
   653  		}
   654  		for j, spec := range gen.Specs {
   655  			impspec := spec.(*ast.ImportSpec)
   656  			if oldImport != impspec {
   657  				continue
   658  			}
   659  
   660  			// We found an import spec that imports path.
   661  			// Delete it.
   662  			deleted = true
   663  			copy(gen.Specs[j:], gen.Specs[j+1:])
   664  			gen.Specs = gen.Specs[:len(gen.Specs)-1]
   665  
   666  			// If this was the last import spec in this decl,
   667  			// delete the decl, too.
   668  			if len(gen.Specs) == 0 {
   669  				copy(f.Decls[i:], f.Decls[i+1:])
   670  				f.Decls = f.Decls[:len(f.Decls)-1]
   671  			} else if len(gen.Specs) == 1 {
   672  				gen.Lparen = token.NoPos // drop parens
   673  			}
   674  			if j > 0 {
   675  				// We deleted an entry but now there will be
   676  				// a blank line-sized hole where the import was.
   677  				// Close the hole by making the previous
   678  				// import appear to "end" where this one did.
   679  				gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
   680  			}
   681  			break
   682  		}
   683  	}
   684  
   685  	// Delete it from f.Imports.
   686  	for i, imp := range f.Imports {
   687  		if imp == oldImport {
   688  			copy(f.Imports[i:], f.Imports[i+1:])
   689  			f.Imports = f.Imports[:len(f.Imports)-1]
   690  			break
   691  		}
   692  	}
   693  
   694  	return
   695  }
   696  
   697  // rewriteImport rewrites any import of path oldPath to path newPath.
   698  func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
   699  	for _, imp := range f.Imports {
   700  		if importPath(imp) == oldPath {
   701  			rewrote = true
   702  			// record old End, because the default is to compute
   703  			// it using the length of imp.Path.Value.
   704  			imp.EndPos = imp.End()
   705  			imp.Path.Value = strconv.Quote(newPath)
   706  		}
   707  	}
   708  	return
   709  }
   710  
   711  func usesImport(f *ast.File, path string) (used bool) {
   712  	spec := importSpec(f, path)
   713  	if spec == nil {
   714  		return
   715  	}
   716  
   717  	name := spec.Name.String()
   718  	switch name {
   719  	case "<nil>":
   720  		// If the package name is not explicitly specified,
   721  		// make an educated guess. This is not guaranteed to be correct.
   722  		lastSlash := strings.LastIndex(path, "/")
   723  		if lastSlash == -1 {
   724  			name = path
   725  		} else {
   726  			name = path[lastSlash+1:]
   727  		}
   728  	case "_", ".":
   729  		// Not sure if this import is used - err on the side of caution.
   730  		return true
   731  	}
   732  
   733  	walk(f, func(n interface{}) {
   734  		sel, ok := n.(*ast.SelectorExpr)
   735  		if ok && isTopName(sel.X, name) {
   736  			used = true
   737  		}
   738  	})
   739  
   740  	return
   741  }
   742  
   743  func expr(s string) ast.Expr {
   744  	x, err := parser.ParseExpr(s)
   745  	if err != nil {
   746  		panic("parsing " + s + ": " + err.Error())
   747  	}
   748  	// Remove position information to avoid spurious newlines.
   749  	killPos(reflect.ValueOf(x))
   750  	return x
   751  }
   752  
   753  var posType = reflect.TypeOf(token.Pos(0))
   754  
   755  func killPos(v reflect.Value) {
   756  	switch v.Kind() {
   757  	case reflect.Ptr, reflect.Interface:
   758  		if !v.IsNil() {
   759  			killPos(v.Elem())
   760  		}
   761  	case reflect.Slice:
   762  		n := v.Len()
   763  		for i := 0; i < n; i++ {
   764  			killPos(v.Index(i))
   765  		}
   766  	case reflect.Struct:
   767  		n := v.NumField()
   768  		for i := 0; i < n; i++ {
   769  			f := v.Field(i)
   770  			if f.Type() == posType {
   771  				f.SetInt(0)
   772  				continue
   773  			}
   774  			killPos(f)
   775  		}
   776  	}
   777  }
   778  
// A rename describes a single renaming: references to Old are replaced
// by New in files that import OldImport.
type rename struct {
	OldImport string // only apply rename if this import is present
	NewImport string // add this import during rewrite
	Old       string // old name: p.T or *p.T
	New       string // new name: p.T or *p.T
}
   786  
   787  func renameFix(tab []rename) func(*ast.File) bool {
   788  	return func(f *ast.File) bool {
   789  		return renameFixTab(f, tab)
   790  	}
   791  }
   792  
   793  func parseName(s string) (ptr bool, pkg, nam string) {
   794  	i := strings.Index(s, ".")
   795  	if i < 0 {
   796  		panic("parseName: invalid name " + s)
   797  	}
   798  	if strings.HasPrefix(s, "*") {
   799  		ptr = true
   800  		s = s[1:]
   801  		i--
   802  	}
   803  	pkg = s[:i]
   804  	nam = s[i+1:]
   805  	return
   806  }
   807  
   808  func renameFixTab(f *ast.File, tab []rename) bool {
   809  	fixed := false
   810  	added := map[string]bool{}
   811  	check := map[string]bool{}
   812  	for _, t := range tab {
   813  		if !imports(f, t.OldImport) {
   814  			continue
   815  		}
   816  		optr, opkg, onam := parseName(t.Old)
   817  		walk(f, func(n interface{}) {
   818  			np, ok := n.(*ast.Expr)
   819  			if !ok {
   820  				return
   821  			}
   822  			x := *np
   823  			if optr {
   824  				p, ok := x.(*ast.StarExpr)
   825  				if !ok {
   826  					return
   827  				}
   828  				x = p.X
   829  			}
   830  			if !isPkgDot(x, opkg, onam) {
   831  				return
   832  			}
   833  			if t.NewImport != "" && !added[t.NewImport] {
   834  				addImport(f, t.NewImport)
   835  				added[t.NewImport] = true
   836  			}
   837  			*np = expr(t.New)
   838  			check[t.OldImport] = true
   839  			fixed = true
   840  		})
   841  	}
   842  
   843  	for ipath := range check {
   844  		if !usesImport(f, ipath) {
   845  			deleteImport(f, ipath)
   846  		}
   847  	}
   848  	return fixed
   849  }
   850  

View as plain text