Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ type Global struct {
typeNodeMapping TypeNodeMapping

InspectTemplateNode TemplateNodeInspectorFunc
InspectCallNode ExecuteTemplateNodeInspectorFunc

// Qualifier controls how types are printed in error messages.
// If nil, types are printed with their full package path.
// See types.WriteType for details.
Qualifier types.Qualifier
}

type TemplateNodeInspectorFunc func(t *parse.Tree, node *parse.TemplateNode, tp types.Type)
type TemplateNodeInspectorFunc func(node *parse.TemplateNode, t *parse.Tree, tp types.Type)

func NewGlobal(pkg *types.Package, fileSet *token.FileSet, trees TreeFinder, fnChecker CallChecker) *Global {
return &Global{
Expand Down Expand Up @@ -295,7 +296,7 @@ func (s *scope) checkTemplateNode(tree *parse.Tree, dot types.Type, n *parse.Tem
x = types.Typ[types.UntypedNil]
}
if fn := s.global.InspectTemplateNode; fn != nil {
fn(tree, n, x)
fn(n, tree, x)
}
childTree, ok := s.global.trees.FindTree(n.Name)
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ func TestTemplateNodeTypeHook(t *testing.T) {

// Set the hook
var hookCalls []HookCall
global.InspectTemplateNode = func(t *parse.Tree, node *parse.TemplateNode, tp types.Type) {
global.InspectTemplateNode = func(node *parse.TemplateNode, t *parse.Tree, tp types.Type) {
hookCalls = append(hookCalls, HookCall{
treeName: t.Name,
nodeName: node.Name,
Expand Down
30 changes: 21 additions & 9 deletions cmd/check-templates/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,31 @@ package main

import (
"fmt"
"go/ast"
"go/token"
"go/types"
"io"
"log"
"os"
"text/template/parse"

"golang.org/x/tools/go/packages"

"github.com/typelate/check"
)

func main() {
os.Exit(run(os.Args[1:], os.Stdout, os.Stderr))
wd, err := os.Getwd()
if err != nil {
log.Fatalln(err)
}
os.Exit(run(wd, os.Args[1:], os.Stdout, os.Stderr))
}

func run(args []string, stdout, stderr io.Writer) int {
dir := "."
func run(dir string, args []string, stdout, stderr io.Writer) int {
loadArgs := []string{"."}
if len(args) > 0 {
dir = args[0]
loadArgs = args
}

fset := token.NewFileSet()
Expand All @@ -28,19 +36,23 @@ func run(args []string, stdout, stderr io.Writer) int {
packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns |
packages.NeedEmbedFiles | packages.NeedImports | packages.NeedModule,
Dir: dir,
}, dir)
}, loadArgs...)
if err != nil {
fmt.Fprintf(stderr, "failed to load packages: %v\n", err)
_, _ = fmt.Fprintf(stderr, "failed to load packages: %v\n", err)
return 1
}
exitCode := 0
for _, pkg := range pkgs {
for _, e := range pkg.Errors {
fmt.Fprintln(stderr, e)
_, _ = fmt.Fprintln(stderr, e)
exitCode = 1
}
if err := check.Package(pkg); err != nil {
fmt.Fprintln(stderr, err)
if err := check.Package(pkg, func(node *ast.CallExpr, t *parse.Tree, tp types.Type) {

}, func(node *parse.TemplateNode, t *parse.Tree, tp types.Type) {

}); err != nil {
_, _ = fmt.Fprintln(stderr, err)
exitCode = 1
}
}
Expand Down
6 changes: 1 addition & 5 deletions cmd/check-templates/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ func checkTemplatesCommand() script.Cmd {
}, func(state *script.State, args ...string) (script.WaitFunc, error) {
return func(state *script.State) (string, string, error) {
var stdout, stderr bytes.Buffer
cmdArgs := args
if len(cmdArgs) == 0 {
cmdArgs = []string{state.Getwd()}
}
code := run(cmdArgs, &stdout, &stderr)
code := run(state.Getwd(), args, &stdout, &stderr)
var err error
if code != 0 {
err = script.ErrUsage
Expand Down
15 changes: 12 additions & 3 deletions package.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"go/token"
"go/types"
"path/filepath"
"text/template/parse"

"golang.org/x/tools/go/packages"

Expand All @@ -15,6 +16,7 @@ import (
)

type pendingCall struct {
call *ast.CallExpr
receiverObj types.Object
templateName string
dataType types.Type
Expand All @@ -26,15 +28,17 @@ type resolvedTemplate struct {
metadata *asteval.TemplateMetadata
}

type ExecuteTemplateNodeInspectorFunc func(node *ast.CallExpr, t *parse.Tree, tp types.Type)

// Package discovers all .ExecuteTemplate calls in the given package,
// resolves receiver variables to their template construction chains,
// and type-checks each call.
//
// ExecuteTemplate must be called with a string literal for the second parameter.
func Package(pkg *packages.Package) error {
func Package(pkg *packages.Package, inspectCall ExecuteTemplateNodeInspectorFunc, inspectTemplate TemplateNodeInspectorFunc) error {
pending, receivers := findExecuteCalls(pkg)
resolved, resolveErrs := resolveTemplates(pkg, receivers)
callErr := checkCalls(pkg, pending, resolved)
callErr := checkCalls(pkg, pending, resolved, inspectCall, inspectTemplate)
return errors.Join(append(resolveErrs, callErr)...)
}

Expand Down Expand Up @@ -73,6 +77,7 @@ func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]st
}
dataType := pkg.TypesInfo.TypeOf(call.Args[2])
pending = append(pending, pendingCall{
call: call,
receiverObj: obj,
templateName: templateName,
dataType: dataType,
Expand Down Expand Up @@ -210,7 +215,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}

// checkCalls type-checks each pending ExecuteTemplate call against its
// resolved template.
func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate) error {
func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate, inspectCall ExecuteTemplateNodeInspectorFunc, inspectTemplate TemplateNodeInspectorFunc) error {
mergedFunctions := make(Functions)
if pkg.Types != nil {
mergedFunctions = DefaultFunctions(pkg.Types)
Expand All @@ -232,6 +237,10 @@ func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types
continue
}
global := NewGlobal(pkg.Types, pkg.Fset, rt.templates, mergedFunctions)
global.InspectTemplateNode = inspectTemplate
if inspectCall != nil {
inspectCall(p.call, looked.Tree(), p.dataType)
}
if err := Execute(global, looked.Tree(), p.dataType); err != nil {
errs = append(errs, err)
}
Expand Down