// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import "go/ast" func init() { addTestCases(importTests, nil) } var importTests = []testCase{ { Name: "import.0", Fn: addImportFn("os"), In: `package main import ( "os" ) `, Out: `package main import ( "os" ) `, }, { Name: "import.1", Fn: addImportFn("os"), In: `package main `, Out: `package main import "os" `, }, { Name: "import.2", Fn: addImportFn("os"), In: `package main // Comment import "C" `, Out: `package main // Comment import "C" import "os" `, }, { Name: "import.3", Fn: addImportFn("os"), In: `package main // Comment import "C" import ( "io" "utf8" ) `, Out: `package main // Comment import "C" import ( "io" "os" "utf8" ) `, }, { Name: "import.4", Fn: deleteImportFn("os"), In: `package main import ( "os" ) `, Out: `package main `, }, { Name: "import.5", Fn: deleteImportFn("os"), In: `package main // Comment import "C" import "os" `, Out: `package main // Comment import "C" `, }, { Name: "import.6", Fn: deleteImportFn("os"), In: `package main // Comment import "C" import ( "io" "os" "utf8" ) `, Out: `package main // Comment import "C" import ( "io" "utf8" ) `, }, { Name: "import.7", Fn: deleteImportFn("io"), In: `package main import ( "io" // a "os" // b "utf8" // c ) `, Out: `package main import ( // a "os" // b "utf8" // c ) `, }, { Name: "import.8", Fn: deleteImportFn("os"), In: `package main import ( "io" // a "os" // b "utf8" // c ) `, Out: `package main import ( "io" // a // b "utf8" // c ) `, }, { Name: "import.9", Fn: deleteImportFn("utf8"), In: `package main import ( "io" // a "os" // b "utf8" // c ) `, Out: `package main import ( "io" // a "os" // b // c ) `, }, { Name: "import.10", Fn: deleteImportFn("io"), In: `package main import ( "io" "os" "utf8" ) `, Out: `package main import ( "os" "utf8" ) `, }, { Name: "import.11", Fn: deleteImportFn("os"), In: `package main import ( "io" "os" "utf8" ) `, Out: `package main import ( "io" "utf8" ) `, }, { Name: "import.12", Fn: deleteImportFn("utf8"), In: `package main import ( "io" "os" "utf8" ) `, Out: `package main import ( "io" "os" ) `, }, { Name: "import.13", Fn: rewriteImportFn("utf8", "encoding/utf8"), In: `package main import ( "io" "os" "utf8" // thanks ken ) `, Out: `package main import ( "encoding/utf8" // thanks ken "io" "os" ) `, }, { Name: "import.14", Fn: rewriteImportFn("asn1", "encoding/asn1"), In: `package main import ( "asn1" "crypto" "crypto/rsa" _ "crypto/sha1" "crypto/x509" "crypto/x509/pkix" "time" ) var x = 1 `, Out: `package main import ( "crypto" "crypto/rsa" _ "crypto/sha1" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "time" ) var x = 1 `, }, { Name: "import.15", Fn: rewriteImportFn("url", "net/url"), In: `package main import ( "bufio" "net" "path" "url" ) var x = 1 // comment on x, not on url `, Out: `package main import ( "bufio" "net" "net/url" "path" ) var x = 1 // comment on x, not on url `, }, { Name: "import.16", Fn: rewriteImportFn("http", "net/http", "template", "text/template"), In: `package main import ( "flag" "http" "log" "template" ) var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 `, Out: `package main import ( "flag" "log" "net/http" "text/template" ) var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 `, }, { Name: "import.17", Fn: addImportFn("x/y/z", "x/a/c"), In: `package main // Comment import "C" import ( "a" "b" "x/w" "d/f" ) `, Out: `package main // Comment import "C" import ( "a" "b" "x/a/c" "x/w" "x/y/z" "d/f" ) `, }, { Name: "import.18", Fn: addDelImportFn("e", "o"), In: `package main import ( "f" "o" "z" ) `, Out: `package main import ( "e" "f" "z" ) `, }, } func addImportFn(path ...string) func(*ast.File) bool { return func(f *ast.File) bool { fixed := false for _, p := range path { if !imports(f, p) { addImport(f, p) fixed = true } } return fixed } } func deleteImportFn(path string) func(*ast.File) bool { return func(f *ast.File) bool { if imports(f, path) { deleteImport(f, path) return true } return false } } func addDelImportFn(p1 string, p2 string) func(*ast.File) bool { return func(f *ast.File) bool { fixed := false if !imports(f, p1) { addImport(f, p1) fixed = true } if imports(f, p2) { deleteImport(f, p2) fixed = true } return fixed } } func rewriteImportFn(oldnew ...string) func(*ast.File) bool { return func(f *ast.File) bool { fixed := false for i := 0; i < len(oldnew); i += 2 { if imports(f, oldnew[i]) { rewriteImport(f, oldnew[i], oldnew[i+1]) fixed = true } } return fixed } }