mirror of
https://github.com/fleetdm/fleet
synced 2026-05-06 06:48:54 +00:00
435 lines
11 KiB
Go
435 lines
11 KiB
Go
// impl generates method stubs for implementing an interface.
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/build"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/printer"
|
|
"go/token"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"text/template"
|
|
|
|
"golang.org/x/tools/imports"
|
|
)
|
|
|
|
// usage is printed to stderr when impl is not given exactly two arguments.
const usage = `impl [-o output.go] <recv> <iface>

impl generates method stubs for recv to implement iface.
The result is written to stdout unless -o is given.

Examples:

impl 'f *File' io.Reader
impl 'm *Murmur' hash.Hash

Don't forget the single quotes around the receiver type
to prevent shell globbing.
`
|
|
|
|
// findInterface returns the import path and identifier of an interface.
|
|
// For example, given "http.ResponseWriter", findInterface returns
|
|
// "net/http", "ResponseWriter".
|
|
// If a fully qualified interface is given, such as "net/http.ResponseWriter",
|
|
// it simply parses the input.
|
|
func findInterface(iface string) (path string, id string, err error) {
|
|
if len(strings.Fields(iface)) != 1 {
|
|
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
|
|
}
|
|
|
|
if slash := strings.LastIndex(iface, "/"); slash > -1 {
|
|
// package path provided
|
|
dot := strings.LastIndex(iface, ".")
|
|
// make sure iface does not end with "/" (e.g. reject net/http/)
|
|
if slash+1 == len(iface) {
|
|
return "", "", fmt.Errorf("interface name cannot end with a '/' character: %s", iface)
|
|
}
|
|
// make sure iface does not end with "." (e.g. reject net/http.)
|
|
if dot+1 == len(iface) {
|
|
return "", "", fmt.Errorf("interface name cannot end with a '.' character: %s", iface)
|
|
}
|
|
// make sure iface has exactly one "." after "/" (e.g. reject net/http/httputil)
|
|
if strings.Count(iface[slash:], ".") != 1 {
|
|
return "", "", fmt.Errorf("invalid interface name: %s", iface)
|
|
}
|
|
return iface[:dot], iface[dot+1:], nil
|
|
}
|
|
|
|
src := []byte("package hack\n" + "var i " + iface)
|
|
// If we couldn't determine the import path, goimports will
|
|
// auto fix the import path.
|
|
imp, err := imports.Process(".", src, nil)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
|
|
}
|
|
|
|
// imp should now contain an appropriate import.
|
|
// Parse out the import and the identifier.
|
|
fset := token.NewFileSet()
|
|
f, err := parser.ParseFile(fset, "", imp, 0)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if len(f.Imports) == 0 {
|
|
return "", "", fmt.Errorf("unrecognized interface: %s", iface)
|
|
}
|
|
raw := f.Imports[0].Path.Value // "io"
|
|
path, err = strconv.Unquote(raw) // io
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader
|
|
spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader
|
|
sel := spec.Type.(*ast.SelectorExpr) // io.Reader
|
|
id = sel.Sel.Name // Reader
|
|
return path, id, nil
|
|
}
|
|
|
|
// Pkg pairs a package located with go/build with the token.FileSet into which
// its source files were parsed. Keeping both together lets AST nodes taken
// from the package be printed later with valid position information.
type Pkg struct {
	// Package carries the package metadata (name, directory, file lists).
	*build.Package
	// FileSet holds the positions of every file parsed from the package.
	*token.FileSet
}
|
|
|
|
// typeSpec locates the *ast.TypeSpec for type id in the import path.
|
|
func typeSpec(path string, id string) (Pkg, *ast.TypeSpec, error) {
|
|
pkg, err := build.Import(path, "", 0)
|
|
if err != nil {
|
|
return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err)
|
|
}
|
|
|
|
fset := token.NewFileSet() // share one fset across the whole package
|
|
for _, file := range pkg.GoFiles {
|
|
f, err := parser.ParseFile(fset, filepath.Join(pkg.Dir, file), nil, 0)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, decl := range f.Decls {
|
|
decl, ok := decl.(*ast.GenDecl)
|
|
if !ok || decl.Tok != token.TYPE {
|
|
continue
|
|
}
|
|
for _, spec := range decl.Specs {
|
|
spec := spec.(*ast.TypeSpec)
|
|
if spec.Name.Name != id {
|
|
continue
|
|
}
|
|
return Pkg{Package: pkg, FileSet: fset}, spec, nil
|
|
}
|
|
}
|
|
}
|
|
return Pkg{}, nil, fmt.Errorf("type %s not found in %s", id, path)
|
|
}
|
|
|
|
// gofmt pretty-prints e.
|
|
func (p Pkg) gofmt(e ast.Expr) string {
|
|
var buf bytes.Buffer
|
|
printer.Fprint(&buf, p.FileSet, e)
|
|
return buf.String()
|
|
}
|
|
|
|
// fullType returns the fully qualified type of e.
|
|
// Examples, assuming package net/http:
|
|
//
|
|
// fullType(int) => "int"
|
|
// fullType(Handler) => "http.Handler"
|
|
// fullType(io.Reader) => "io.Reader"
|
|
// fullType(*Request) => "*http.Request"
|
|
func (p Pkg) fullType(e ast.Expr) string {
|
|
ast.Inspect(e, func(n ast.Node) bool {
|
|
switch n := n.(type) {
|
|
case *ast.Ident:
|
|
// Using typeSpec instead of IsExported here would be
|
|
// more accurate, but it'd be crazy expensive, and if
|
|
// the type isn't exported, there's no point trying
|
|
// to implement it anyway.
|
|
if n.IsExported() {
|
|
n.Name = p.Package.Name + "." + n.Name
|
|
}
|
|
case *ast.SelectorExpr:
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return p.gofmt(e)
|
|
}
|
|
|
|
func (p Pkg) params(field *ast.Field, defaultName string) []Param {
|
|
var params []Param
|
|
typ := p.fullType(field.Type)
|
|
for _, name := range field.Names {
|
|
params = append(params, Param{Name: name.Name, Type: typ})
|
|
}
|
|
// Handle anonymous params
|
|
if len(params) == 0 {
|
|
params = []Param{{Type: typ, Name: defaultName}}
|
|
}
|
|
return params
|
|
}
|
|
|
|
// Method represents a method signature together with its receiver; it is the
// data passed to the stub template. Recv is the full receiver expression
// (e.g. "f *File") and RecvShort the receiver variable used in the body.
type Method struct {
	RecvShort string
	Recv      string
	Func
}

// Struc is the data passed to the mock-struct templates for one method.
// IName holds the mock type's name (currently not referenced by those
// templates).
type Struc struct {
	IName string
	Func
}

// Func represents a function signature.
type Func struct {
	Name   string
	Params []Param
	Res    []Param
}

// Param represents a parameter in a function or method signature.
type Param struct {
	Name string
	Type string
}

// variadicPrefix matches the "..." that starts a variadic parameter's type.
// It is compiled once instead of on every CalledArgument call.
var variadicPrefix = regexp.MustCompile(`^[.]{3}`)

// CalledArgument returns the expression used to pass p on to another
// function: the bare name, or "name..." when p is variadic so the slice is
// spread into the callee's variadic parameter.
func (p *Param) CalledArgument() string {
	if variadicPrefix.MatchString(p.Type) {
		return p.Name + "..."
	}
	return p.Name
}
|
|
|
|
func (p Pkg) funcsig(f *ast.Field) Func {
|
|
fn := Func{Name: f.Names[0].Name}
|
|
typ := f.Type.(*ast.FuncType)
|
|
if typ.Params != nil {
|
|
for pos, field := range typ.Params.List {
|
|
defaultName := fmt.Sprintf("p%d", pos)
|
|
fn.Params = append(fn.Params, p.params(field, defaultName)...)
|
|
}
|
|
}
|
|
if typ.Results != nil {
|
|
for _, field := range typ.Results.List {
|
|
fn.Res = append(fn.Res, p.params(field, "")...)
|
|
}
|
|
}
|
|
return fn
|
|
}
|
|
|
|
// funcs returns the set of methods required to implement iface.
|
|
// It is called funcs rather than methods because the
|
|
// function descriptions are functions; there is no receiver.
|
|
func funcs(iface string) ([]Func, error) {
|
|
// Locate the interface.
|
|
path, id, err := findInterface(iface)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse the package and find the interface declaration.
|
|
p, spec, err := typeSpec(path, id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("interface %s not found: %s", iface, err)
|
|
}
|
|
idecl, ok := spec.Type.(*ast.InterfaceType)
|
|
if !ok {
|
|
return nil, fmt.Errorf("not an interface: %s", iface)
|
|
}
|
|
|
|
if idecl.Methods == nil {
|
|
return nil, fmt.Errorf("empty interface: %s", iface)
|
|
}
|
|
|
|
var fns []Func
|
|
for _, fndecl := range idecl.Methods.List {
|
|
if len(fndecl.Names) == 0 {
|
|
// Embedded interface: recurse
|
|
embedded, err := funcs(p.fullType(fndecl.Type))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fns = append(fns, embedded...)
|
|
continue
|
|
}
|
|
|
|
fn := p.funcsig(fndecl)
|
|
fns = append(fns, fn)
|
|
}
|
|
return fns, nil
|
|
}
|
|
|
|
// stub is the text/template for one generated method, executed with a
// Method. The body locks the mock's mu, sets <Name>FuncInvoked, unlocks, and
// then calls the <Name>Func field with the method's arguments, returning its
// results when the method has any. Trailing commas and spacing are cleaned
// up afterwards by format.Source.
const stub = `func ({{.Recv}}) {{.Name}}({{range .Params}}{{.Name}} {{.Type}}, {{end}})({{range .Res}}{{.Name}} {{.Type}}, {{end}}){
{{.RecvShort}}.mu.Lock()
{{.RecvShort}}.{{.Name}}FuncInvoked = true
{{.RecvShort}}.mu.Unlock()
{{if .Res}}return {{end}}{{.RecvShort}}.{{.Name}}Func({{range .Params}}{{.CalledArgument}}, {{end}})
}

`

// tmpl is the parsed stub template; template.Must panics at init if stub is
// malformed.
var tmpl = template.Must(template.New("test").Parse(stub))
|
|
|
|
// genStubs prints nicely formatted method stubs
|
|
// for fns using receiver expression recv.
|
|
// If recv is not a valid receiver expression,
|
|
// genStubs will panic.
|
|
func genStubs(recv string, fns []Func) []byte {
|
|
var buf bytes.Buffer
|
|
for _, fn := range fns {
|
|
meth := Method{Recv: recv, RecvShort: shortRecv(recv), Func: fn}
|
|
tmpl.Execute(&buf, meth) //nolint:errcheck
|
|
}
|
|
|
|
pretty, err := format.Source(buf.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return pretty
|
|
}
|
|
|
|
// shortRecv returns the receiver variable name from a receiver expression,
// e.g. "f" for "f *File", "f*File" or "s Store". The generated method bodies
// use it to reach the mock's fields.
//
// Surrounding whitespace is ignored. The name ends at the first space, tab,
// newline or '*'. When none of those occur, the trimmed input is returned.
func shortRecv(recv string) string {
	recv = strings.TrimSpace(recv)
	if end := strings.IndexAny(recv, " \t\r\n*"); end >= 0 {
		return recv[:end]
	}
	return recv
}
|
|
|
|
// packageStr is the header of every generated file. Its first line follows
// the standard "// Code generated ... DO NOT EDIT." form, so Go tooling
// (go vet, gopls, linters, ast.IsGenerated) recognises the output as
// generated code.
const packageStr = "// Code generated by mockimpl. DO NOT EDIT." +
	"\n\n" + "package mock" + "\n\n"

// str renders one method's entries in the mock struct, given a Struc:
// a <Name>Func field of the matching function type, and a <Name>FuncInvoked
// flag that the generated stub sets.
const str = "{{.Name}}Func {{.Name}}Func" +
	"\n" + "{{.Name}}FuncInvoked bool" +
	"\n\n"

// funcTypeStr renders the named function type "<Name>Func" with the
// method's parameter and result lists.
const funcTypeStr = "type {{.Name}}Func func" +
	"({{range .Params}}{{.Name}} {{.Type}}, {{end}})" +
	"({{range .Res}}{{.Name}} {{.Type}}, {{end}})" +
	"\n\n"

// Parsed forms of str and funcTypeStr; template.Must panics at init if either
// is malformed.
var (
	tmplStr         = template.Must(template.New("testtwo").Parse(str))
	funcTypetmplStr = template.Must(template.New("funcTypetmpl").Parse(funcTypeStr))
)
|
|
|
|
func genStr(name string, fns []Func) []byte {
|
|
var buf bytes.Buffer
|
|
for _, fn := range fns {
|
|
meth := Struc{IName: name, Func: fn}
|
|
funcTypetmplStr.Execute(&buf, meth) //nolint:errcheck
|
|
}
|
|
buf.WriteString("type ")
|
|
buf.WriteString(name)
|
|
buf.WriteString(" struct {\n")
|
|
for _, fn := range fns {
|
|
meth := Struc{IName: name, Func: fn}
|
|
tmplStr.Execute(&buf, meth) //nolint:errcheck
|
|
}
|
|
buf.WriteString("\n")
|
|
buf.WriteString("mu sync.Mutex")
|
|
buf.WriteString("\n")
|
|
buf.WriteString("}")
|
|
|
|
pretty, err := format.Source(buf.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return pretty
|
|
// return buf.Bytes()
|
|
}
|
|
|
|
// validReceiver reports whether recv is a valid receiver expression for a
// generated mock.
//
// recv must make "func (<recv>) Foo()" parse as exactly one body-less method
// declaration with exactly one named, non-blank receiver. This rejects:
//   - input that smuggles in extra declarations or a body, such as
//     "f *F) Foo() {}\nfunc (g *G";
//   - anonymous receivers ("*File");
//   - blank receivers ("_ *File").
//
// The generated stubs reach the mock's fields through the receiver name, so
// anonymous and blank receivers cannot be used.
func validReceiver(recv string) bool {
	if recv == "" {
		// The parser will parse empty receivers, but we don't want to accept
		// them, since it won't generate a usable code snippet.
		return false
	}
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", "package hack\nfunc ("+recv+") Foo()", 0)
	if err != nil || len(f.Decls) != 1 {
		return false
	}
	fn, ok := f.Decls[0].(*ast.FuncDecl)
	if !ok || fn.Name.Name != "Foo" || fn.Body != nil || fn.Recv == nil || len(fn.Recv.List) != 1 {
		return false
	}
	names := fn.Recv.List[0].Names
	return len(names) == 1 && names[0].Name != "_"
}
|
|
|
|
func main() {
|
|
flOut := flag.String("o", "", "output file")
|
|
flag.Parse()
|
|
args := flag.Args()
|
|
if len(args) != 2 {
|
|
fmt.Fprint(os.Stderr, usage)
|
|
os.Exit(2)
|
|
}
|
|
recv, iface := args[0], args[1]
|
|
if !validReceiver(recv) {
|
|
fatal(fmt.Sprintf("invalid receiver: %q", recv))
|
|
}
|
|
|
|
fns, err := funcs(iface)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
// Remove duplicates from fns
|
|
uniqueFns := make(map[string]Func, len(fns))
|
|
dedupedFns := make([]Func, 0, len(fns))
|
|
for _, fn := range fns {
|
|
if _, exists := uniqueFns[fn.Name]; !exists {
|
|
uniqueFns[fn.Name] = fn
|
|
dedupedFns = append(dedupedFns, fn)
|
|
}
|
|
}
|
|
fns = dedupedFns
|
|
|
|
src := genStubs(recv, fns)
|
|
recName := strings.SplitN(recv, " ", 2)
|
|
name := strings.TrimPrefix(recName[1], "*")
|
|
src2 := genStr(name, fns)
|
|
|
|
path, ifaceID, err := findInterface(iface)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
fmt.Fprint(&buf, packageStr)
|
|
fmt.Fprintf(&buf, "import \"%s\"\n\n", path)
|
|
fmt.Fprintf(&buf, "var _ %s.%s = (*%s)(nil)\n\n", filepath.Base(path), ifaceID, name)
|
|
fmt.Fprint(&buf, string(src2))
|
|
buf.WriteString("\n")
|
|
fmt.Fprint(&buf, string(src))
|
|
pretty, err := format.Source(buf.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
imp, err := imports.Process("", pretty, nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
switch *flOut {
|
|
case "":
|
|
fmt.Println(string(imp))
|
|
default:
|
|
f, err := os.Create(*flOut)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
_, err = f.Write(imp)
|
|
if err != nil {
|
|
log.Fatal(err) //nolint:gocritic // ignore exitAfterDefer
|
|
}
|
|
}
|
|
}
|
|
|
|
// fatal prints msg (formatted with %v) and a newline to standard error, then
// terminates the process with exit status 1.
func fatal(msg interface{}) {
	fmt.Fprintf(os.Stderr, "%v\n", msg)
	os.Exit(1)
}
|