diff --git a/check.go b/check.go index cd33fb7..f79df09 100644 --- a/check.go +++ b/check.go @@ -51,6 +51,7 @@ type Global struct { typeNodeMapping TypeNodeMapping InspectTemplateNode TemplateNodeInspectorFunc + InspectCallNode ExecuteTemplateNodeInspectorFunc // Qualifier controls how types are printed in error messages. // If nil, types are printed with their full package path. @@ -58,7 +59,7 @@ type Global struct { Qualifier types.Qualifier } -type TemplateNodeInspectorFunc func(t *parse.Tree, node *parse.TemplateNode, tp types.Type) +type TemplateNodeInspectorFunc func(node *parse.TemplateNode, t *parse.Tree, tp types.Type) func NewGlobal(pkg *types.Package, fileSet *token.FileSet, trees TreeFinder, fnChecker CallChecker) *Global { return &Global{ @@ -295,7 +296,7 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem x = types.Typ[types.UntypedNil] } if fn := s.global.InspectTemplateNode; fn != nil { - fn(tree, n, x) + fn(n, tree, x) } childTree, ok := s.global.trees.FindTree(n.Name) if !ok { diff --git a/check_test.go b/check_test.go index 2ed5402..418bc7a 100644 --- a/check_test.go +++ b/check_test.go @@ -946,7 +946,7 @@ func TestTemplateNodeTypeHook(t *testing.T) { // Set the hook var hookCalls []HookCall - global.InspectTemplateNode = func(t *parse.Tree, node *parse.TemplateNode, tp types.Type) { + global.InspectTemplateNode = func(node *parse.TemplateNode, t *parse.Tree, tp types.Type) { hookCalls = append(hookCalls, HookCall{ treeName: t.Name, nodeName: node.Name, diff --git a/cmd/check-templates/main.go b/cmd/check-templates/main.go index 1a2bd82..c26b571 100644 --- a/cmd/check-templates/main.go +++ b/cmd/check-templates/main.go @@ -2,9 +2,13 @@ package main import ( "fmt" + "go/ast" "go/token" + "go/types" "io" + "log" "os" + "text/template/parse" "golang.org/x/tools/go/packages" @@ -12,13 +16,17 @@ import ( ) func main() { - os.Exit(run(os.Args[1:], os.Stdout, os.Stderr)) + wd, err := os.Getwd() + if err != nil { + log.Fatalln(err) + } + os.Exit(run(wd, os.Args[1:], os.Stdout, os.Stderr)) } -func run(args []string, stdout, stderr io.Writer) int { - dir := "." +func run(dir string, args []string, stdout, stderr io.Writer) int { + loadArgs := []string{"."} if len(args) > 0 { - dir = args[0] + loadArgs = args } fset := token.NewFileSet() @@ -28,19 +36,23 @@ func run(args []string, stdout, stderr io.Writer) int { packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles | packages.NeedImports | packages.NeedModule, Dir: dir, - }, dir) + }, loadArgs...) if err != nil { - fmt.Fprintf(stderr, "failed to load packages: %v\n", err) + _, _ = fmt.Fprintf(stderr, "failed to load packages: %v\n", err) return 1 } exitCode := 0 for _, pkg := range pkgs { for _, e := range pkg.Errors { - fmt.Fprintln(stderr, e) + _, _ = fmt.Fprintln(stderr, e) exitCode = 1 } - if err := check.Package(pkg); err != nil { - fmt.Fprintln(stderr, err) + if err := check.Package(pkg, func(node *ast.CallExpr, t *parse.Tree, tp types.Type) { + + }, func(node *parse.TemplateNode, t *parse.Tree, tp types.Type) { + + }); err != nil { + _, _ = fmt.Fprintln(stderr, err) exitCode = 1 } } diff --git a/cmd/check-templates/main_test.go b/cmd/check-templates/main_test.go index acc3f24..f504eab 100644 --- a/cmd/check-templates/main_test.go +++ b/cmd/check-templates/main_test.go @@ -25,11 +25,7 @@ func checkTemplatesCommand() script.Cmd { }, func(state *script.State, args ...string) (script.WaitFunc, error) { return func(state *script.State) (string, string, error) { var stdout, stderr bytes.Buffer - cmdArgs := args - if len(cmdArgs) == 0 { - cmdArgs = []string{state.Getwd()} - } - code := run(cmdArgs, &stdout, &stderr) + code := run(state.Getwd(), args, &stdout, &stderr) var err error if code != 0 { err = script.ErrUsage diff --git a/package.go b/package.go index e44de5f..401445c 100644 --- a/package.go +++ b/package.go @@ -7,6 +7,7 @@ import ( "go/token" "go/types" "path/filepath" + "text/template/parse" "golang.org/x/tools/go/packages" @@ -15,6 +16,7 @@ import ( ) type pendingCall struct { + call *ast.CallExpr receiverObj types.Object templateName string dataType types.Type @@ -26,15 +28,17 @@ type resolvedTemplate struct { metadata *asteval.TemplateMetadata } +type ExecuteTemplateNodeInspectorFunc func(node *ast.CallExpr, t *parse.Tree, tp types.Type) + // Package discovers all .ExecuteTemplate calls in the given package, // resolves receiver variables to their template construction chains, // and type-checks each call. // // ExecuteTemplate must be called with a string literal for the second parameter. -func Package(pkg *packages.Package) error { +func Package(pkg *packages.Package, inspectCall ExecuteTemplateNodeInspectorFunc, inspectTemplate TemplateNodeInspectorFunc) error { pending, receivers := findExecuteCalls(pkg) resolved, resolveErrs := resolveTemplates(pkg, receivers) - callErr := checkCalls(pkg, pending, resolved) + callErr := checkCalls(pkg, pending, resolved, inspectCall, inspectTemplate) return errors.Join(append(resolveErrs, callErr)...) } @@ -73,6 +77,7 @@ func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]st } dataType := pkg.TypesInfo.TypeOf(call.Args[2]) pending = append(pending, pendingCall{ + call: call, receiverObj: obj, templateName: templateName, dataType: dataType, @@ -210,7 +215,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{} // checkCalls type-checks each pending ExecuteTemplate call against its // resolved template. -func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate) error { +func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate, inspectCall ExecuteTemplateNodeInspectorFunc, inspectTemplate TemplateNodeInspectorFunc) error { mergedFunctions := make(Functions) if pkg.Types != nil { mergedFunctions = DefaultFunctions(pkg.Types) @@ -232,6 +237,10 @@ func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types continue } global := NewGlobal(pkg.Types, pkg.Fset, rt.templates, mergedFunctions) + global.InspectTemplateNode = inspectTemplate + if inspectCall != nil { + inspectCall(p.call, looked.Tree(), p.dataType) + } if err := Execute(global, looked.Tree(), p.dataType); err != nil { errs = append(errs, err) }