...
Run Format

Source file src/sort/genzfunc.go

Documentation: sort

  // Copyright 2016 The Go Authors. All rights reserved.
  // Use of this source code is governed by a BSD-style
  // license that can be found in the LICENSE file.
  
  // +build ignore
  
  // This program is run via "go generate" (via a directive in sort.go)
  // to generate zfuncversion.go.
  //
  // It copies sort.go to zfuncversion.go, only retaining funcs which
  // take a "data Interface" parameter, and renaming each to have a
  // "_func" suffix and taking a "data lessSwap" instead. It then rewrites
  // each internal function call to the appropriate _func variants.
  
  package main
  
  import (
  	"bytes"
  	"go/ast"
  	"go/format"
  	"go/parser"
  	"go/token"
  	"io/ioutil"
  	"log"
  	"regexp"
  )
  
  var fset = token.NewFileSet()
  
  func main() {
  	af, err := parser.ParseFile(fset, "sort.go", nil, 0)
  	if err != nil {
  		log.Fatal(err)
  	}
  	af.Doc = nil
  	af.Imports = nil
  	af.Comments = nil
  
  	var newDecl []ast.Decl
  	for _, d := range af.Decls {
  		fd, ok := d.(*ast.FuncDecl)
  		if !ok {
  			continue
  		}
  		if fd.Recv != nil || fd.Name.IsExported() {
  			continue
  		}
  		typ := fd.Type
  		if len(typ.Params.List) < 1 {
  			continue
  		}
  		arg0 := typ.Params.List[0]
  		arg0Name := arg0.Names[0].Name
  		arg0Type := arg0.Type.(*ast.Ident)
  		if arg0Name != "data" || arg0Type.Name != "Interface" {
  			continue
  		}
  		arg0Type.Name = "lessSwap"
  
  		newDecl = append(newDecl, fd)
  	}
  	af.Decls = newDecl
  	ast.Walk(visitFunc(rewriteCalls), af)
  
  	var out bytes.Buffer
  	if err := format.Node(&out, fset, af); err != nil {
  		log.Fatalf("format.Node: %v", err)
  	}
  
  	// Get rid of blank lines after removal of comments.
  	src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
  
  	// Add comments to each func, for the lost reader.
  	// This is so much easier than adding comments via the AST
  	// and trying to get position info correct.
  	src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
  
  	// Final gofmt.
  	src, err = format.Source(src)
  	if err != nil {
  		log.Fatalf("format.Source: %v on\n%s", err, src)
  	}
  
  	out.Reset()
  	out.WriteString(`// DO NOT EDIT; AUTO-GENERATED from sort.go using genzfunc.go
  
  // Copyright 2016 The Go Authors. All rights reserved.
  // Use of this source code is governed by a BSD-style
  // license that can be found in the LICENSE file.
  
  `)
  	out.Write(src)
  
  	const target = "zfuncversion.go"
  	if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
  		log.Fatal(err)
  	}
  }
  
  type visitFunc func(ast.Node) ast.Visitor
  
  func (f visitFunc) Visit(n ast.Node) ast.Visitor { return f(n) }
  
  func rewriteCalls(n ast.Node) ast.Visitor {
  	ce, ok := n.(*ast.CallExpr)
  	if ok {
  		rewriteCall(ce)
  	}
  	return visitFunc(rewriteCalls)
  }
  
  func rewriteCall(ce *ast.CallExpr) {
  	ident, ok := ce.Fun.(*ast.Ident)
  	if !ok {
  		// e.g. skip SelectorExpr (data.Less(..) calls)
  		return
  	}
  	// skip casts
  	if ident.Name == "int" || ident.Name == "uint" {
  		return
  	}
  	if len(ce.Args) < 1 {
  		return
  	}
  	ident.Name += "_func"
  }
  

View as plain text