summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/stretchr/testify/_codegen/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/stretchr/testify/_codegen/main.go')
-rw-r--r--vendor/github.com/stretchr/testify/_codegen/main.go37
1 files changed, 33 insertions, 4 deletions
diff --git a/vendor/github.com/stretchr/testify/_codegen/main.go b/vendor/github.com/stretchr/testify/_codegen/main.go
index 328009f84..2e5e8124f 100644
--- a/vendor/github.com/stretchr/testify/_codegen/main.go
+++ b/vendor/github.com/stretchr/testify/_codegen/main.go
@@ -1,5 +1,5 @@
// This program reads all assertion functions from the assert package and
-// automatically generates the corersponding requires and forwarded assertions
+// automatically generates the corresponding requires and forwarded assertions
package main
@@ -10,6 +10,7 @@ import (
"go/ast"
"go/build"
"go/doc"
+ "go/format"
"go/importer"
"go/parser"
"go/token"
@@ -19,6 +20,7 @@ import (
"log"
"os"
"path"
+ "regexp"
"strings"
"text/template"
@@ -27,6 +29,7 @@ import (
var (
pkg = flag.String("assert-path", "github.com/stretchr/testify/assert", "Path to the assert package")
+ includeF = flag.Bool("include-format-funcs", false, "include format functions such as Errorf and Equalf")
outputPkg = flag.String("output-package", "", "package for the resulting code")
tmplFile = flag.String("template", "", "What file to load the function template from")
out = flag.String("out", "", "What file to write the source code to")
@@ -77,13 +80,18 @@ func generateCode(importer imports.Importer, funcs []testFunc) error {
}
}
+ code, err := format.Source(buff.Bytes())
+ if err != nil {
+ return err
+ }
+
// Write file
output, err := outputFile()
if err != nil {
return err
}
defer output.Close()
- _, err = io.Copy(output, buff)
+ _, err = io.Copy(output, bytes.NewReader(code))
return err
}
@@ -133,7 +141,7 @@ func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []tes
if !ok {
continue
}
- // Check function signatuer has at least two arguments
+ // Check function signature has at least two arguments
sig := fn.Type().(*types.Signature)
if sig.Params().Len() < 2 {
continue
@@ -151,13 +159,18 @@ func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []tes
continue
}
+ // Skip functions ending with f
+ if strings.HasSuffix(fdocs.Name, "f") && !*includeF {
+ continue
+ }
+
funcs = append(funcs, testFunc{*outputPkg, fdocs, fn})
importer.AddImportsFrom(sig.Params())
}
return importer, funcs, nil
}
-// parsePackageSource returns the types scope and the package documentation from the pa
+// parsePackageSource returns the types scope and the package documentation from the package
func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
pd, err := build.Import(pkg, ".", 0)
if err != nil {
@@ -258,10 +271,26 @@ func (f *testFunc) ForwardedParams() string {
return p
}
+func (f *testFunc) ParamsFormat() string {
+ return strings.Replace(f.Params(), "msgAndArgs", "msg string, args", 1)
+}
+
+func (f *testFunc) ForwardedParamsFormat() string {
+ return strings.Replace(f.ForwardedParams(), "msgAndArgs", "append([]interface{}{msg}, args...)", 1)
+}
+
func (f *testFunc) Comment() string {
return "// " + strings.Replace(strings.TrimSpace(f.DocInfo.Doc), "\n", "\n// ", -1)
}
+func (f *testFunc) CommentFormat() string {
+ search := fmt.Sprintf("%s", f.DocInfo.Name)
+ replace := fmt.Sprintf("%sf", f.DocInfo.Name)
+ comment := strings.Replace(f.Comment(), search, replace, -1)
+ exp := regexp.MustCompile(replace + `\(((\(\)|[^)])+)\)`)
+ return exp.ReplaceAllString(comment, replace+`($1, "error message %s", "formatted")`)
+}
+
func (f *testFunc) CommentWithoutT(receiver string) string {
search := fmt.Sprintf("assert.%s(t, ", f.DocInfo.Name)
replace := fmt.Sprintf("%s.%s(", receiver, f.DocInfo.Name)