+func (g *GenFile) Write(p []byte) (n int, err error) {
+ return g.buf.Write(p)
+}
+
+func (g *GenFile) Import(importPath GoImportPath) {
+ g.manualImports[importPath] = true
+}
+
+func (g *GenFile) GoIdent(ident GoIdent) string {
+ if ident.GoImportPath == g.goImportPath {
+ return ident.GoName
+ }
+ if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
+ return string(packageName) + "." + ident.GoName
+ }
+ packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
+ g.packageNames[ident.GoImportPath] = packageName
+ return string(packageName) + "." + ident.GoName
+}
+
+func (g *GenFile) P(v ...interface{}) {
+ for _, x := range v {
+ switch x := x.(type) {
+ case GoIdent:
+ fmt.Fprint(&g.buf, g.GoIdent(x))
+ default:
+ fmt.Fprint(&g.buf, x)
+ }
+ }
+ fmt.Fprintln(&g.buf)
+}
+
+func (g *GenFile) Content() ([]byte, error) {
+ if !strings.HasSuffix(g.filename, ".go") {
+ return g.buf.Bytes(), nil
+ }
+ return g.injectImports(g.buf.Bytes())
+}
+
+func getImportClass(importPath string) int {
+ if !strings.Contains(importPath, ".") {
+ return 0 /* std */
+ }
+ return 1 /* External */
+}
+
+// injectImports parses source, injects import block declaration with all imports and return formatted
+func (g *GenFile) injectImports(original []byte) ([]byte, error) {
+ // Parse source code
+ fset := token.NewFileSet()
+ file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
+ if err != nil {
+ var src bytes.Buffer
+ s := bufio.NewScanner(bytes.NewReader(original))
+ for line := 1; s.Scan(); line++ {
+ fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
+ }
+ return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
+ }
+ type Import struct {
+ Name string
+ Path string
+ }
+ // Prepare list of all imports
+ var importPaths []Import
+ for importPath := range g.packageNames {
+ importPaths = append(importPaths, Import{
+ Name: string(g.packageNames[importPath]),
+ Path: string(importPath),
+ })
+ }
+ for importPath := range g.manualImports {
+ if _, ok := g.packageNames[importPath]; ok {
+ continue
+ }
+ importPaths = append(importPaths, Import{
+ Name: "_",
+ Path: string(importPath),
+ })
+ }
+ // Sort imports by import path
+ sort.Slice(importPaths, func(i, j int) bool {
+ ci := getImportClass(importPaths[i].Path)
+ cj := getImportClass(importPaths[j].Path)
+ if ci == cj {
+ return importPaths[i].Path < importPaths[j].Path
+ }
+ return ci < cj
+ })
+ // Inject new import block into parsed AST
+ if len(importPaths) > 0 {
+ // Find import block position
+ pos := file.Package
+ tokFile := fset.File(file.Package)
+ pkgLine := tokFile.Line(file.Package)
+ for _, c := range file.Comments {
+ if tokFile.Line(c.Pos()) > pkgLine {
+ break
+ }
+ pos = c.End()
+ }
+ // Prepare the import block
+ impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
+ for i, importPath := range importPaths {
+ var name *ast.Ident
+ if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
+ name = &ast.Ident{Name: importPath.Name, NamePos: pos}
+ }
+ value := strconv.Quote(importPath.Path)
+ if i < len(importPaths)-1 {
+ if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
+ value += "\n"
+ }
+ }
+ impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
+ Name: name,
+ Path: &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
+ EndPos: pos,
+ })
+ }
+
+ file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
+ }
+ // Reformat source code
+ var out bytes.Buffer
+ cfg := &printer.Config{
+ Mode: printer.TabIndent | printer.UseSpaces,
+ Tabwidth: 8,
+ }
+ if err = cfg.Fprint(&out, fset, file); err != nil {
+ return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
+ }
+ return out.Bytes(), nil
+}
+