Source file src/cmd/fix/main.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"internal/diff"
    17  	"io"
    18  	"io/fs"
    19  	"os"
    20  	"path/filepath"
    21  	"sort"
    22  	"strconv"
    23  	"strings"
    24  )
    25  
    26  var (
    27  	fset     = token.NewFileSet()
    28  	exitCode = 0
    29  )
    30  
    31  var allowedRewrites = flag.String("r", "",
    32  	"restrict the rewrites to this comma-separated list")
    33  
    34  var forceRewrites = flag.String("force", "",
    35  	"force these fixes to run even if the code looks updated")
    36  
    37  var allowed, force map[string]bool
    38  
    39  var (
    40  	doDiff       = flag.Bool("diff", false, "display diffs instead of rewriting files")
    41  	goVersionStr = flag.String("go", "", "go language version for files")
    42  
    43  	goVersion int // 115 for go1.15
    44  )
    45  
    46  // enable for debugging fix failures
    47  const debug = false // display incorrectly reformatted source and exit
    48  
    49  func usage() {
    50  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    51  	flag.PrintDefaults()
    52  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    53  	sort.Sort(byName(fixes))
    54  	for _, f := range fixes {
    55  		if f.disabled {
    56  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    57  		} else {
    58  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    59  		}
    60  		desc := strings.TrimSpace(f.desc)
    61  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    62  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    63  	}
    64  	os.Exit(2)
    65  }
    66  
    67  func main() {
    68  	flag.Usage = usage
    69  	flag.Parse()
    70  
    71  	if *goVersionStr != "" {
    72  		if !strings.HasPrefix(*goVersionStr, "go") {
    73  			report(fmt.Errorf("invalid -go=%s", *goVersionStr))
    74  			os.Exit(exitCode)
    75  		}
    76  		majorStr := (*goVersionStr)[len("go"):]
    77  		minorStr := "0"
    78  		if before, after, found := strings.Cut(majorStr, "."); found {
    79  			majorStr, minorStr = before, after
    80  		}
    81  		major, err1 := strconv.Atoi(majorStr)
    82  		minor, err2 := strconv.Atoi(minorStr)
    83  		if err1 != nil || err2 != nil || major < 0 || major >= 100 || minor < 0 || minor >= 100 {
    84  			report(fmt.Errorf("invalid -go=%s", *goVersionStr))
    85  			os.Exit(exitCode)
    86  		}
    87  
    88  		goVersion = major*100 + minor
    89  	}
    90  
    91  	sort.Sort(byDate(fixes))
    92  
    93  	if *allowedRewrites != "" {
    94  		allowed = make(map[string]bool)
    95  		for _, f := range strings.Split(*allowedRewrites, ",") {
    96  			allowed[f] = true
    97  		}
    98  	}
    99  
   100  	if *forceRewrites != "" {
   101  		force = make(map[string]bool)
   102  		for _, f := range strings.Split(*forceRewrites, ",") {
   103  			force[f] = true
   104  		}
   105  	}
   106  
   107  	if flag.NArg() == 0 {
   108  		if err := processFile("standard input", true); err != nil {
   109  			report(err)
   110  		}
   111  		os.Exit(exitCode)
   112  	}
   113  
   114  	for i := 0; i < flag.NArg(); i++ {
   115  		path := flag.Arg(i)
   116  		switch dir, err := os.Stat(path); {
   117  		case err != nil:
   118  			report(err)
   119  		case dir.IsDir():
   120  			walkDir(path)
   121  		default:
   122  			if err := processFile(path, false); err != nil {
   123  				report(err)
   124  			}
   125  		}
   126  	}
   127  
   128  	os.Exit(exitCode)
   129  }
   130  
   131  const parserMode = parser.ParseComments
   132  
   133  func gofmtFile(f *ast.File) ([]byte, error) {
   134  	var buf bytes.Buffer
   135  	if err := format.Node(&buf, fset, f); err != nil {
   136  		return nil, err
   137  	}
   138  	return buf.Bytes(), nil
   139  }
   140  
   141  func processFile(filename string, useStdin bool) error {
   142  	var f *os.File
   143  	var err error
   144  	var fixlog strings.Builder
   145  
   146  	if useStdin {
   147  		f = os.Stdin
   148  	} else {
   149  		f, err = os.Open(filename)
   150  		if err != nil {
   151  			return err
   152  		}
   153  		defer f.Close()
   154  	}
   155  
   156  	src, err := io.ReadAll(f)
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	// Make sure file is in canonical format.
   167  	// This "fmt" pseudo-fix cannot be disabled.
   168  	newSrc, err := gofmtFile(file)
   169  	if err != nil {
   170  		return err
   171  	}
   172  	if !bytes.Equal(newSrc, src) {
   173  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   174  		if err != nil {
   175  			return err
   176  		}
   177  		file = newFile
   178  		fmt.Fprintf(&fixlog, " fmt")
   179  	}
   180  
   181  	// Apply all fixes to file.
   182  	newFile := file
   183  	fixed := false
   184  	for _, fix := range fixes {
   185  		if allowed != nil && !allowed[fix.name] {
   186  			continue
   187  		}
   188  		if fix.disabled && !force[fix.name] {
   189  			continue
   190  		}
   191  		if fix.f(newFile) {
   192  			fixed = true
   193  			fmt.Fprintf(&fixlog, " %s", fix.name)
   194  
   195  			// AST changed.
   196  			// Print and parse, to update any missing scoping
   197  			// or position information for subsequent fixers.
   198  			newSrc, err := gofmtFile(newFile)
   199  			if err != nil {
   200  				return err
   201  			}
   202  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   203  			if err != nil {
   204  				if debug {
   205  					fmt.Printf("%s", newSrc)
   206  					report(err)
   207  					os.Exit(exitCode)
   208  				}
   209  				return err
   210  			}
   211  		}
   212  	}
   213  	if !fixed {
   214  		return nil
   215  	}
   216  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   217  
   218  	// Print AST.  We did that after each fix, so this appears
   219  	// redundant, but it is necessary to generate gofmt-compatible
   220  	// source code in a few cases. The official gofmt style is the
   221  	// output of the printer run on a standard AST generated by the parser,
   222  	// but the source we generated inside the loop above is the
   223  	// output of the printer run on a mangled AST generated by a fixer.
   224  	newSrc, err = gofmtFile(newFile)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	if *doDiff {
   230  		os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
   231  		return nil
   232  	}
   233  
   234  	if useStdin {
   235  		os.Stdout.Write(newSrc)
   236  		return nil
   237  	}
   238  
   239  	return os.WriteFile(f.Name(), newSrc, 0)
   240  }
   241  
   242  func gofmt(n any) string {
   243  	var gofmtBuf strings.Builder
   244  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   245  		return "<" + err.Error() + ">"
   246  	}
   247  	return gofmtBuf.String()
   248  }
   249  
   250  func report(err error) {
   251  	scanner.PrintError(os.Stderr, err)
   252  	exitCode = 2
   253  }
   254  
   255  func walkDir(path string) {
   256  	filepath.WalkDir(path, visitFile)
   257  }
   258  
   259  func visitFile(path string, f fs.DirEntry, err error) error {
   260  	if err == nil && isGoFile(f) {
   261  		err = processFile(path, false)
   262  	}
   263  	if err != nil {
   264  		report(err)
   265  	}
   266  	return nil
   267  }
   268  
   269  func isGoFile(f fs.DirEntry) bool {
   270  	// ignore non-Go files
   271  	name := f.Name()
   272  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   273  }
   274  

View as plain text