...
Run Format

Source file src/cmd/fix/import_test.go

Documentation: cmd/fix

  // Copyright 2011 The Go Authors. All rights reserved.
  // Use of this source code is governed by a BSD-style
  // license that can be found in the LICENSE file.
  
  package main
  
  import "go/ast"
  
  func init() {
  	addTestCases(importTests, nil)
  }
  
  var importTests = []testCase{
  	{
  		Name: "import.0",
  		Fn:   addImportFn("os"),
  		In: `package main
  
  import (
  	"os"
  )
  `,
  		Out: `package main
  
  import (
  	"os"
  )
  `,
  	},
  	{
  		Name: "import.1",
  		Fn:   addImportFn("os"),
  		In: `package main
  `,
  		Out: `package main
  
  import "os"
  `,
  	},
  	{
  		Name: "import.2",
  		Fn:   addImportFn("os"),
  		In: `package main
  
  // Comment
  import "C"
  `,
  		Out: `package main
  
  // Comment
  import "C"
  import "os"
  `,
  	},
  	{
  		Name: "import.3",
  		Fn:   addImportFn("os"),
  		In: `package main
  
  // Comment
  import "C"
  
  import (
  	"io"
  	"utf8"
  )
  `,
  		Out: `package main
  
  // Comment
  import "C"
  
  import (
  	"io"
  	"os"
  	"utf8"
  )
  `,
  	},
  	{
  		Name: "import.4",
  		Fn:   deleteImportFn("os"),
  		In: `package main
  
  import (
  	"os"
  )
  `,
  		Out: `package main
  `,
  	},
  	{
  		Name: "import.5",
  		Fn:   deleteImportFn("os"),
  		In: `package main
  
  // Comment
  import "C"
  import "os"
  `,
  		Out: `package main
  
  // Comment
  import "C"
  `,
  	},
  	{
  		Name: "import.6",
  		Fn:   deleteImportFn("os"),
  		In: `package main
  
  // Comment
  import "C"
  
  import (
  	"io"
  	"os"
  	"utf8"
  )
  `,
  		Out: `package main
  
  // Comment
  import "C"
  
  import (
  	"io"
  	"utf8"
  )
  `,
  	},
  	{
  		Name: "import.7",
  		Fn:   deleteImportFn("io"),
  		In: `package main
  
  import (
  	"io"   // a
  	"os"   // b
  	"utf8" // c
  )
  `,
  		Out: `package main
  
  import (
  	// a
  	"os"   // b
  	"utf8" // c
  )
  `,
  	},
  	{
  		Name: "import.8",
  		Fn:   deleteImportFn("os"),
  		In: `package main
  
  import (
  	"io"   // a
  	"os"   // b
  	"utf8" // c
  )
  `,
  		Out: `package main
  
  import (
  	"io" // a
  	// b
  	"utf8" // c
  )
  `,
  	},
  	{
  		Name: "import.9",
  		Fn:   deleteImportFn("utf8"),
  		In: `package main
  
  import (
  	"io"   // a
  	"os"   // b
  	"utf8" // c
  )
  `,
  		Out: `package main
  
  import (
  	"io" // a
  	"os" // b
  	// c
  )
  `,
  	},
  	{
  		Name: "import.10",
  		Fn:   deleteImportFn("io"),
  		In: `package main
  
  import (
  	"io"
  	"os"
  	"utf8"
  )
  `,
  		Out: `package main
  
  import (
  	"os"
  	"utf8"
  )
  `,
  	},
  	{
  		Name: "import.11",
  		Fn:   deleteImportFn("os"),
  		In: `package main
  
  import (
  	"io"
  	"os"
  	"utf8"
  )
  `,
  		Out: `package main
  
  import (
  	"io"
  	"utf8"
  )
  `,
  	},
  	{
  		Name: "import.12",
  		Fn:   deleteImportFn("utf8"),
  		In: `package main
  
  import (
  	"io"
  	"os"
  	"utf8"
  )
  `,
  		Out: `package main
  
  import (
  	"io"
  	"os"
  )
  `,
  	},
  	{
  		Name: "import.13",
  		Fn:   rewriteImportFn("utf8", "encoding/utf8"),
  		In: `package main
  
  import (
  	"io"
  	"os"
  	"utf8" // thanks ken
  )
  `,
  		Out: `package main
  
  import (
  	"encoding/utf8" // thanks ken
  	"io"
  	"os"
  )
  `,
  	},
  	{
  		Name: "import.14",
  		Fn:   rewriteImportFn("asn1", "encoding/asn1"),
  		In: `package main
  
  import (
  	"asn1"
  	"crypto"
  	"crypto/rsa"
  	_ "crypto/sha1"
  	"crypto/x509"
  	"crypto/x509/pkix"
  	"time"
  )
  
  var x = 1
  `,
  		Out: `package main
  
  import (
  	"crypto"
  	"crypto/rsa"
  	_ "crypto/sha1"
  	"crypto/x509"
  	"crypto/x509/pkix"
  	"encoding/asn1"
  	"time"
  )
  
  var x = 1
  `,
  	},
  	{
  		Name: "import.15",
  		Fn:   rewriteImportFn("url", "net/url"),
  		In: `package main
  
  import (
  	"bufio"
  	"net"
  	"path"
  	"url"
  )
  
  var x = 1 // comment on x, not on url
  `,
  		Out: `package main
  
  import (
  	"bufio"
  	"net"
  	"net/url"
  	"path"
  )
  
  var x = 1 // comment on x, not on url
  `,
  	},
  	{
  		Name: "import.16",
  		Fn:   rewriteImportFn("http", "net/http", "template", "text/template"),
  		In: `package main
  
  import (
  	"flag"
  	"http"
  	"log"
  	"template"
  )
  
  var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
  `,
  		Out: `package main
  
  import (
  	"flag"
  	"log"
  	"net/http"
  	"text/template"
  )
  
  var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
  `,
  	},
  	{
  		Name: "import.17",
  		Fn:   addImportFn("x/y/z", "x/a/c"),
  		In: `package main
  
  // Comment
  import "C"
  
  import (
  	"a"
  	"b"
  
  	"x/w"
  
  	"d/f"
  )
  `,
  		Out: `package main
  
  // Comment
  import "C"
  
  import (
  	"a"
  	"b"
  
  	"x/a/c"
  	"x/w"
  	"x/y/z"
  
  	"d/f"
  )
  `,
  	},
  	{
  		Name: "import.18",
  		Fn:   addDelImportFn("e", "o"),
  		In: `package main
  
  import (
  	"f"
  	"o"
  	"z"
  )
  `,
  		Out: `package main
  
  import (
  	"e"
  	"f"
  	"z"
  )
  `,
  	},
  }
  
  func addImportFn(path ...string) func(*ast.File) bool {
  	return func(f *ast.File) bool {
  		fixed := false
  		for _, p := range path {
  			if !imports(f, p) {
  				addImport(f, p)
  				fixed = true
  			}
  		}
  		return fixed
  	}
  }
  
  func deleteImportFn(path string) func(*ast.File) bool {
  	return func(f *ast.File) bool {
  		if imports(f, path) {
  			deleteImport(f, path)
  			return true
  		}
  		return false
  	}
  }
  
  func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
  	return func(f *ast.File) bool {
  		fixed := false
  		if !imports(f, p1) {
  			addImport(f, p1)
  			fixed = true
  		}
  		if imports(f, p2) {
  			deleteImport(f, p2)
  			fixed = true
  		}
  		return fixed
  	}
  }
  
  func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
  	return func(f *ast.File) bool {
  		fixed := false
  		for i := 0; i < len(oldnew); i += 2 {
  			if imports(f, oldnew[i]) {
  				rewriteImport(f, oldnew[i], oldnew[i+1])
  				fixed = true
  			}
  		}
  		return fixed
  	}
  }
  

View as plain text