From 77c82221795d27fcd345b865bb67a4777de33bb3 Mon Sep 17 00:00:00 2001 From: Christopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Sat, 7 Feb 2026 11:28:57 -0800 Subject: [PATCH 1/3] feat: add check-template command --- cmd/check-templates/main.go | 48 ++ cmd/check-templates/main_test.go | 40 ++ .../err_additional_parsefs_missing_field.txt | 43 ++ .../testdata/err_aliased_import.txt | 38 ++ .../testdata/err_closure_missing_field.txt | 38 ++ .../testdata/err_funcs_chain.txt | 42 ++ .../testdata/err_imported_type.txt | 42 ++ .../testdata/err_inline_struct.txt | 37 ++ .../testdata/err_local_var_missing_field.txt | 35 ++ .../testdata/err_missing_field.txt | 38 ++ .../testdata/err_multiple_errors.txt | 49 ++ .../testdata/err_multiple_template_vars.txt | 49 ++ .../testdata/err_nested_template.txt | 41 ++ .../testdata/err_shadowed_var.txt | 48 ++ .../testdata/err_text_template.txt | 38 ++ .../testdata/pass_additional_parsefs.txt | 51 ++ .../pass_additional_parsefs_no_must.txt | 47 ++ .../testdata/pass_aliased_import.txt | 38 ++ cmd/check-templates/testdata/pass_closure.txt | 38 ++ .../testdata/pass_different_identifier.txt | 38 ++ .../testdata/pass_funcs_chain.txt | 40 ++ .../testdata/pass_imported_type.txt | 41 ++ .../testdata/pass_inline_struct.txt | 36 ++ .../testdata/pass_local_var.txt | 35 ++ .../testdata/pass_multiple_calls.txt | 48 ++ .../testdata/pass_multiple_template_vars.txt | 46 ++ .../testdata/pass_nested_template.txt | 40 ++ .../testdata/pass_no_execute_calls.txt | 13 + .../testdata/pass_non_template_execute.txt | 28 + .../testdata/pass_parse_call.txt | 29 + .../testdata/pass_shadowed_var.txt | 47 ++ cmd/check-templates/testdata/pass_simple.txt | 38 ++ .../testdata/pass_text_template.txt | 37 ++ func.go | 10 +- go.mod | 1 + go.sum | 2 + internal/asteval/errors.go | 13 + internal/asteval/forrest.go | 20 + internal/asteval/string.go | 30 + internal/asteval/template.go | 542 ++++++++++++++++++ .../testdata/template/assets_dir.txtar | 23 + .../testdata/template/bad_embed_pattern.txtar | 16 + .../asteval/testdata/template/delims.txtar | 22 + .../asteval/testdata/template/funcs.txtar | 41 ++ .../asteval/testdata/template/parse.txtar | 15 + .../testdata/template/template_ParseFS.txtar | 38 ++ .../asteval/testdata/template/templates.txtar | 84 +++ internal/astgen/format.go | 23 + internal/astgen/queries.go | 36 ++ package.go | 301 ++++++++++ 50 files changed, 2532 insertions(+), 1 deletion(-) create mode 100644 cmd/check-templates/main.go create mode 100644 cmd/check-templates/main_test.go create mode 100644 cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt create mode 100644 cmd/check-templates/testdata/err_aliased_import.txt create mode 100644 cmd/check-templates/testdata/err_closure_missing_field.txt create mode 100644 cmd/check-templates/testdata/err_funcs_chain.txt create mode 100644 cmd/check-templates/testdata/err_imported_type.txt create mode 100644 cmd/check-templates/testdata/err_inline_struct.txt create mode 100644 cmd/check-templates/testdata/err_local_var_missing_field.txt create mode 100644 cmd/check-templates/testdata/err_missing_field.txt create mode 100644 cmd/check-templates/testdata/err_multiple_errors.txt create mode 100644 cmd/check-templates/testdata/err_multiple_template_vars.txt create mode 100644 cmd/check-templates/testdata/err_nested_template.txt create mode 100644 cmd/check-templates/testdata/err_shadowed_var.txt create mode 100644 cmd/check-templates/testdata/err_text_template.txt create mode 100644 cmd/check-templates/testdata/pass_additional_parsefs.txt create mode 100644 cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt create mode 100644 cmd/check-templates/testdata/pass_aliased_import.txt create mode 100644 cmd/check-templates/testdata/pass_closure.txt create mode 100644 cmd/check-templates/testdata/pass_different_identifier.txt create mode 100644 cmd/check-templates/testdata/pass_funcs_chain.txt create mode 100644 cmd/check-templates/testdata/pass_imported_type.txt create mode 100644 cmd/check-templates/testdata/pass_inline_struct.txt create mode 100644 cmd/check-templates/testdata/pass_local_var.txt create mode 100644 cmd/check-templates/testdata/pass_multiple_calls.txt create mode 100644 cmd/check-templates/testdata/pass_multiple_template_vars.txt create mode 100644 cmd/check-templates/testdata/pass_nested_template.txt create mode 100644 cmd/check-templates/testdata/pass_no_execute_calls.txt create mode 100644 cmd/check-templates/testdata/pass_non_template_execute.txt create mode 100644 cmd/check-templates/testdata/pass_parse_call.txt create mode 100644 cmd/check-templates/testdata/pass_shadowed_var.txt create mode 100644 cmd/check-templates/testdata/pass_simple.txt create mode 100644 cmd/check-templates/testdata/pass_text_template.txt create mode 100644 internal/asteval/errors.go create mode 100644 internal/asteval/forrest.go create mode 100644 internal/asteval/string.go create mode 100644 internal/asteval/template.go create mode 100644 internal/asteval/testdata/template/assets_dir.txtar create mode 100644 internal/asteval/testdata/template/bad_embed_pattern.txtar create mode 100644 internal/asteval/testdata/template/delims.txtar create mode 100644 internal/asteval/testdata/template/funcs.txtar create mode 100644 internal/asteval/testdata/template/parse.txtar create mode 100644 internal/asteval/testdata/template/template_ParseFS.txtar create mode 100644 internal/asteval/testdata/template/templates.txtar create mode 100644 internal/astgen/format.go create mode 100644 internal/astgen/queries.go create mode 100644 package.go diff --git a/cmd/check-templates/main.go b/cmd/check-templates/main.go new file mode 100644 index 0000000..1a2bd82 --- /dev/null +++ b/cmd/check-templates/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "go/token" + "io" + "os" + + "golang.org/x/tools/go/packages" + + "github.com/typelate/check" +) + +func main() { + os.Exit(run(os.Args[1:], os.Stdout, os.Stderr)) +} + +func run(args []string, stdout, stderr io.Writer) int { + dir := "." + if len(args) > 0 { + dir = args[0] + } + + fset := token.NewFileSet() + pkgs, err := packages.Load(&packages.Config{ + Fset: fset, + Mode: packages.NeedTypesInfo | packages.NeedName | packages.NeedFiles | + packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | + packages.NeedEmbedFiles | packages.NeedImports | packages.NeedModule, + Dir: dir, + }, dir) + if err != nil { + fmt.Fprintf(stderr, "failed to load packages: %v\n", err) + return 1 + } + exitCode := 0 + for _, pkg := range pkgs { + for _, e := range pkg.Errors { + fmt.Fprintln(stderr, e) + exitCode = 1 + } + if err := check.Package(pkg); err != nil { + fmt.Fprintln(stderr, err) + exitCode = 1 + } + } + return exitCode +} diff --git a/cmd/check-templates/main_test.go b/cmd/check-templates/main_test.go new file mode 100644 index 0000000..acc3f24 --- /dev/null +++ b/cmd/check-templates/main_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "bytes" + "path/filepath" + "testing" + + "rsc.io/script" + "rsc.io/script/scripttest" +) + +func Test(t *testing.T) { + e := script.NewEngine() + e.Quiet = true + e.Cmds = scripttest.DefaultCmds() + e.Cmds["check-templates"] = checkTemplatesCommand() + ctx := t.Context() + scripttest.Test(t, ctx, e, nil, filepath.FromSlash("testdata/*.txt")) +} + +func checkTemplatesCommand() script.Cmd { + return script.Command(script.CmdUsage{ + Summary: "check-templates [dir]", + Args: "[dir]", + }, func(state *script.State, args ...string) (script.WaitFunc, error) { + return func(state *script.State) (string, string, error) { + var stdout, stderr bytes.Buffer + cmdArgs := args + if len(cmdArgs) == 0 { + cmdArgs = []string{state.Getwd()} + } + code := run(cmdArgs, &stdout, &stderr) + var err error + if code != 0 { + err = script.ErrUsage + } + return stdout.String(), stderr.String(), err + }, nil + }) +} diff --git a/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt new file mode 100644 index 0000000..100c55e --- /dev/null +++ b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt @@ -0,0 +1,43 @@ +# Template defined at top level, modified in a function. The additional template +# has a field that doesn't exist on the data type. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + template.Must(ts.ParseFS(source, "about.gohtml")) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_aliased_import.txt b/cmd/check-templates/testdata/err_aliased_import.txt new file mode 100644 index 0000000..cf4adc8 --- /dev/null +++ b/cmd/check-templates/testdata/err_aliased_import.txt @@ -0,0 +1,38 @@ +# Template import with a non-standard alias should still detect errors. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + htmltpl "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = htmltpl.Must(htmltpl.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_closure_missing_field.txt b/cmd/check-templates/testdata/err_closure_missing_field.txt new file mode 100644 index 0000000..3c01045 --- /dev/null +++ b/cmd/check-templates/testdata/err_closure_missing_field.txt @@ -0,0 +1,38 @@ +# Template parsed in outer function, closure calls ExecuteTemplate with missing field. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func routes(mux *http.ServeMux) { + ts := template.Must(template.ParseFS(source, "*")) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) + }) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_funcs_chain.txt b/cmd/check-templates/testdata/err_funcs_chain.txt new file mode 100644 index 0000000..c14f893 --- /dev/null +++ b/cmd/check-templates/testdata/err_funcs_chain.txt @@ -0,0 +1,42 @@ +# Template constructed with Funcs() chained before ParseFS should still +# report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + "strings" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.New("").Funcs(template.FuncMap{ + "upper": strings.ToUpper, + }).ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing | upper}} diff --git a/cmd/check-templates/testdata/err_imported_type.txt b/cmd/check-templates/testdata/err_imported_type.txt new file mode 100644 index 0000000..a9a3a60 --- /dev/null +++ b/cmd/check-templates/testdata/err_imported_type.txt @@ -0,0 +1,42 @@ +# Types imported from other packages should report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app/internal/model\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- internal/model/types.go -- +package model + +type Page struct { + Title string +} +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + + "example.com/app/internal/model" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/err_inline_struct.txt b/cmd/check-templates/testdata/err_inline_struct.txt new file mode 100644 index 0000000..bb5fe87 --- /dev/null +++ b/cmd/check-templates/testdata/err_inline_struct.txt @@ -0,0 +1,37 @@ +# Inline anonymous struct types should report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on struct\{Title string\}' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + var data struct { + Title string + } + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/err_local_var_missing_field.txt b/cmd/check-templates/testdata/err_local_var_missing_field.txt new file mode 100644 index 0000000..bde48d4 --- /dev/null +++ b/cmd/check-templates/testdata/err_local_var_missing_field.txt @@ -0,0 +1,35 @@ +# Template defined as local variable reports errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "*")) + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_missing_field.txt b/cmd/check-templates/testdata/err_missing_field.txt new file mode 100644 index 0000000..c5b9e13 --- /dev/null +++ b/cmd/check-templates/testdata/err_missing_field.txt @@ -0,0 +1,38 @@ +# Template check fails when a field does not exist on the data type. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_multiple_errors.txt b/cmd/check-templates/testdata/err_multiple_errors.txt new file mode 100644 index 0000000..63e1a5f --- /dev/null +++ b/cmd/check-templates/testdata/err_multiple_errors.txt @@ -0,0 +1,49 @@ +# Multiple ExecuteTemplate calls can each report errors. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.IndexPage' +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Unknown>: Unknown not found on example\.com/app\.AboutPage' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

+-- about.gohtml -- +

{{.Unknown}}

diff --git a/cmd/check-templates/testdata/err_multiple_template_vars.txt b/cmd/check-templates/testdata/err_multiple_template_vars.txt new file mode 100644 index 0000000..7dd7062 --- /dev/null +++ b/cmd/check-templates/testdata/err_multiple_template_vars.txt @@ -0,0 +1,49 @@ +# Multiple template variables must not cross-contaminate: only the incorrect +# call should produce an error. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:2: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage' +! stderr 'Title not found' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +//go:embed *.gohtml +var source embed.FS + +var templatesA = template.Must(template.ParseFS(source, "index.gohtml")) +var templatesB = template.Must(template.ParseFS(source, "about.gohtml")) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func renderIndex() { + _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{}) +} + +func renderAbout() { + _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", IndexPage{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} +-- about.gohtml -- +{{.Name}} diff --git a/cmd/check-templates/testdata/err_nested_template.txt b/cmd/check-templates/testdata/err_nested_template.txt new file mode 100644 index 0000000..de654b4 --- /dev/null +++ b/cmd/check-templates/testdata/err_nested_template.txt @@ -0,0 +1,41 @@ +# A nested template call should check the invoked template against the data type. + +! check-templates +stderr 'type check failed:.*header\.gohtml:1:6: executing "header\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{template "header.gohtml" .}} +

body

+-- header.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_shadowed_var.txt b/cmd/check-templates/testdata/err_shadowed_var.txt new file mode 100644 index 0000000..479f1e8 --- /dev/null +++ b/cmd/check-templates/testdata/err_shadowed_var.txt @@ -0,0 +1,48 @@ +# A shadowed template variable must not confuse resolution: the inner scope +# has its own template, and each ExecuteTemplate call is checked against the +# correct definition. The outer ts parses index.gohtml (which uses .Title), +# and the inner ts parses about.gohtml (which uses .Name). Passing the wrong +# data type to the inner call should produce an error only for that call. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage' +! stderr 'Title not found' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +var ts = template.Must(template.ParseFS(source, "index.gohtml")) + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "about.gohtml")) + _ = ts.ExecuteTemplate(w, "about.gohtml", IndexPage{Title: "oops"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/err_text_template.txt b/cmd/check-templates/testdata/err_text_template.txt new file mode 100644 index 0000000..7e6e187 --- /dev/null +++ b/cmd/check-templates/testdata/err_text_template.txt @@ -0,0 +1,38 @@ +# text/template should be checked the same as html/template. + +! check-templates +stderr 'type check failed:.*index\.gotmpl:1:2: executing "index\.gotmpl" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "io" + "text/template" +) + +var ( + //go:embed *.gotmpl + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gotmpl -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/pass_additional_parsefs.txt b/cmd/check-templates/testdata/pass_additional_parsefs.txt new file mode 100644 index 0000000..50059f6 --- /dev/null +++ b/cmd/check-templates/testdata/pass_additional_parsefs.txt @@ -0,0 +1,51 @@ +# Template defined at top level, then has more templates parsed inside a function. +# Both sets of templates should be available. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + template.Must(ts.ParseFS(source, "about.gohtml")) +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt new file mode 100644 index 0000000..09975df --- /dev/null +++ b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt @@ -0,0 +1,47 @@ +# Template defined at top level, modified in a function without template.Must. +# The ParseFS call result is discarded but should still be detected. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + _, _ = ts.ParseFS(source, "about.gohtml") +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_aliased_import.txt b/cmd/check-templates/testdata/pass_aliased_import.txt new file mode 100644 index 0000000..96a9837 --- /dev/null +++ b/cmd/check-templates/testdata/pass_aliased_import.txt @@ -0,0 +1,38 @@ +# Template import with a non-standard alias should still be resolved. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + htmltpl "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = htmltpl.Must(htmltpl.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_closure.txt b/cmd/check-templates/testdata/pass_closure.txt new file mode 100644 index 0000000..02f97d3 --- /dev/null +++ b/cmd/check-templates/testdata/pass_closure.txt @@ -0,0 +1,38 @@ +# Template parsed in outer function, closures call ExecuteTemplate. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func routes(mux *http.ServeMux) { + ts := template.Must(template.ParseFS(source, "*")) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) + }) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_different_identifier.txt b/cmd/check-templates/testdata/pass_different_identifier.txt new file mode 100644 index 0000000..be176ca --- /dev/null +++ b/cmd/check-templates/testdata/pass_different_identifier.txt @@ -0,0 +1,38 @@ +# Template variable with a non-standard identifier name. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + tmpl = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = tmpl.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_funcs_chain.txt b/cmd/check-templates/testdata/pass_funcs_chain.txt new file mode 100644 index 0000000..3a2dc12 --- /dev/null +++ b/cmd/check-templates/testdata/pass_funcs_chain.txt @@ -0,0 +1,40 @@ +# Template constructed with Funcs() chained before ParseFS should be resolved. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + "strings" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.New("").Funcs(template.FuncMap{ + "upper": strings.ToUpper, + }).ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title | upper}} diff --git a/cmd/check-templates/testdata/pass_imported_type.txt b/cmd/check-templates/testdata/pass_imported_type.txt new file mode 100644 index 0000000..b6770ea --- /dev/null +++ b/cmd/check-templates/testdata/pass_imported_type.txt @@ -0,0 +1,41 @@ +# Types imported from other packages should be resolved correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- internal/model/types.go -- +package model + +type Page struct { + Title string +} +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + + "example.com/app/internal/model" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} diff --git a/cmd/check-templates/testdata/pass_inline_struct.txt b/cmd/check-templates/testdata/pass_inline_struct.txt new file mode 100644 index 0000000..cfecd07 --- /dev/null +++ b/cmd/check-templates/testdata/pass_inline_struct.txt @@ -0,0 +1,36 @@ +# Inline anonymous struct types should be resolved correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + var data struct { + Title string + } + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} diff --git a/cmd/check-templates/testdata/pass_local_var.txt b/cmd/check-templates/testdata/pass_local_var.txt new file mode 100644 index 0000000..dfa6fd3 --- /dev/null +++ b/cmd/check-templates/testdata/pass_local_var.txt @@ -0,0 +1,35 @@ +# Template defined as a local variable in the same function as ExecuteTemplate. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +func handleIndex(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "*")) + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +type Page struct { + Title string +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_multiple_calls.txt b/cmd/check-templates/testdata/pass_multiple_calls.txt new file mode 100644 index 0000000..36573ba --- /dev/null +++ b/cmd/check-templates/testdata/pass_multiple_calls.txt @@ -0,0 +1,48 @@ +# Multiple ExecuteTemplate calls with different types all pass. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_multiple_template_vars.txt b/cmd/check-templates/testdata/pass_multiple_template_vars.txt new file mode 100644 index 0000000..ad95b96 --- /dev/null +++ b/cmd/check-templates/testdata/pass_multiple_template_vars.txt @@ -0,0 +1,46 @@ +# Multiple template variables should each resolve independently. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +//go:embed *.gohtml +var source embed.FS + +var templatesA = template.Must(template.ParseFS(source, "index.gohtml")) +var templatesB = template.Must(template.ParseFS(source, "about.gohtml")) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func renderIndex() { + _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{}) +} + +func renderAbout() { + _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", AboutPage{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} +-- about.gohtml -- +{{.Name}} diff --git a/cmd/check-templates/testdata/pass_nested_template.txt b/cmd/check-templates/testdata/pass_nested_template.txt new file mode 100644 index 0000000..e6c04e0 --- /dev/null +++ b/cmd/check-templates/testdata/pass_nested_template.txt @@ -0,0 +1,40 @@ +# Nested template calls should propagate the data type correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{template "header.gohtml" .}} +

body

+-- header.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_no_execute_calls.txt b/cmd/check-templates/testdata/pass_no_execute_calls.txt new file mode 100644 index 0000000..000f01f --- /dev/null +++ b/cmd/check-templates/testdata/pass_no_execute_calls.txt @@ -0,0 +1,13 @@ +# Package with no ExecuteTemplate calls passes without error. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +func main() {} diff --git a/cmd/check-templates/testdata/pass_non_template_execute.txt b/cmd/check-templates/testdata/pass_non_template_execute.txt new file mode 100644 index 0000000..b7a572e --- /dev/null +++ b/cmd/check-templates/testdata/pass_non_template_execute.txt @@ -0,0 +1,28 @@ +# A custom type with an ExecuteTemplate method should not be confused +# with html/template. The tool should not report errors for it. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "io" + "net/http" +) + +type Renderer struct{} + +func (r *Renderer) ExecuteTemplate(w io.Writer, name string, data any) error { + return nil +} + +func handle(w http.ResponseWriter, r *http.Request) { + renderer := &Renderer{} + _ = renderer.ExecuteTemplate(w, "index.gohtml", struct{ Missing string }{}) +} diff --git a/cmd/check-templates/testdata/pass_parse_call.txt b/cmd/check-templates/testdata/pass_parse_call.txt new file mode 100644 index 0000000..d1ea925 --- /dev/null +++ b/cmd/check-templates/testdata/pass_parse_call.txt @@ -0,0 +1,29 @@ +# Template constructed with Parse (not ParseFS) passes check. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "fmt" + "html/template" + "net/http" +) + +var templates = template.Must(template.New("greeting").Parse(`

Hello, {{.Name}}!

`)) + +type Greeting struct { + Name string +} + +func handle(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "greeting", Greeting{Name: "World"}) +} + +var _ = fmt.Sprint diff --git a/cmd/check-templates/testdata/pass_shadowed_var.txt b/cmd/check-templates/testdata/pass_shadowed_var.txt new file mode 100644 index 0000000..213415c --- /dev/null +++ b/cmd/check-templates/testdata/pass_shadowed_var.txt @@ -0,0 +1,47 @@ +# A shadowed template variable should resolve to the correct definition at each call site. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.ParseFS(source, "index.gohtml")) + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "about.gohtml")) + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_simple.txt b/cmd/check-templates/testdata/pass_simple.txt new file mode 100644 index 0000000..de22050 --- /dev/null +++ b/cmd/check-templates/testdata/pass_simple.txt @@ -0,0 +1,38 @@ +# Simple template check passes when fields match. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_text_template.txt b/cmd/check-templates/testdata/pass_text_template.txt new file mode 100644 index 0000000..cbf8999 --- /dev/null +++ b/cmd/check-templates/testdata/pass_text_template.txt @@ -0,0 +1,37 @@ +# text/template should be checked the same as html/template. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "io" + "text/template" +) + +var ( + //go:embed *.gotmpl + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gotmpl -- +{{.Title}} diff --git a/func.go b/func.go index 72e2b6f..80c51c8 100644 --- a/func.go +++ b/func.go @@ -31,7 +31,15 @@ func DefaultFunctions(pkg *types.Package) Functions { } { if p, ok := findPackage(pkg, pn); ok && p != nil { for funcIdent, templateFunc := range idents { - fns[templateFunc] = p.Scope().Lookup(funcIdent).Type().(*types.Signature) + obj := p.Scope().Lookup(funcIdent) + if obj == nil { + continue + } + sig, ok := obj.Type().(*types.Signature) + if !ok { + continue + } + fns[templateFunc] = sig } } } diff --git a/go.mod b/go.mod index 7d0808a..bad63f5 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( github.com/stretchr/testify v1.11.1 golang.org/x/tools v0.41.0 + rsc.io/script v0.0.2 ) require ( diff --git a/go.sum b/go.sum index ddb0899..7e5240f 100644 --- a/go.sum +++ b/go.sum @@ -16,3 +16,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/script v0.0.2 h1:eYoG7A3GFC3z1pRx3A2+s/vZ9LA8cxojHyCvslnj4RI= +rsc.io/script v0.0.2/go.mod h1:cKBjCtFBBeZ0cbYFRXkRoxP+xGqhArPa9t3VWhtXfzU= diff --git a/internal/asteval/errors.go b/internal/asteval/errors.go new file mode 100644 index 0000000..fb2c83e --- /dev/null +++ b/internal/asteval/errors.go @@ -0,0 +1,13 @@ +package asteval + +import ( + "fmt" + "go/token" + "path/filepath" +) + +func wrapWithFilename(workingDirectory string, set *token.FileSet, pos token.Pos, err error) error { + p := set.Position(pos) + p.Filename, _ = filepath.Rel(workingDirectory, p.Filename) + return fmt.Errorf("%s: %w", p, err) +} diff --git a/internal/asteval/forrest.go b/internal/asteval/forrest.go new file mode 100644 index 0000000..a3083bb --- /dev/null +++ b/internal/asteval/forrest.go @@ -0,0 +1,20 @@ +package asteval + +import ( + "html/template" + "text/template/parse" +) + +type Forrest template.Template + +func NewForrest(templates *template.Template) *Forrest { + return (*Forrest)(templates) +} + +func (f *Forrest) FindTree(name string) (*parse.Tree, bool) { + ts := (*template.Template)(f).Lookup(name) + if ts == nil { + return nil, false + } + return ts.Tree, true +} diff --git a/internal/asteval/string.go b/internal/asteval/string.go new file mode 100644 index 0000000..ba8b813 --- /dev/null +++ b/internal/asteval/string.go @@ -0,0 +1,30 @@ +package asteval + +import ( + "fmt" + "go/ast" + "go/token" + "strconv" + + "github.com/typelate/check/internal/astgen" +) + +func StringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr) (string, error) { + arg, ok := exp.(*ast.BasicLit) + if !ok || arg.Kind != token.STRING { + return "", wrapWithFilename(wd, set, exp.Pos(), fmt.Errorf("expected string literal got %s", astgen.Format(exp))) + } + return strconv.Unquote(arg.Value) +} + +func StringLiteralExpressionList(wd string, set *token.FileSet, list []ast.Expr) ([]string, error) { + result := make([]string, 0, len(list)) + for _, a := range list { + s, err := StringLiteralExpression(wd, set, a) + if err != nil { + return result, err + } + result = append(result, s) + } + return result, nil +} diff --git a/internal/asteval/template.go b/internal/asteval/template.go new file mode 100644 index 0000000..bdfe916 --- /dev/null +++ b/internal/asteval/template.go @@ -0,0 +1,542 @@ +package asteval + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "html/template" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "text/template/parse" + "unicode" + + "github.com/typelate/check/internal/astgen" +) + +// TemplateMetadata accumulates metadata during template evaluation. +type TemplateMetadata struct { + EmbedFilePaths []string + ParseCalls []*ast.BasicLit +} + +func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm template.FuncMap, meta *TemplateMetadata) (*template.Template, string, string, error) { + call, ok := expression.(*ast.CallExpr) + if !ok { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression")) + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unexpected expression %T: %s", call.Fun, astgen.Format(call.Fun))) + } + switch x := sel.X.(type) { + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args))) + case *ast.Ident: + if !isTemplatePkgIdent(typesInfo, x) { + if ts == nil { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected template package got %s", astgen.Format(sel.X))) + } + // Variable receiver — apply method to existing template. + switch sel.Sel.Name { + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, lDelim, rDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + t, err := parseFiles(ts, fm, lDelim, rDelim, filePaths...) + return t, lDelim, rDelim, err + case "Parse": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + if meta != nil { + if bl, ok := call.Args[0].(*ast.BasicLit); ok { + meta.ParseCalls = append(meta.ParseCalls, bl) + } + } + sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0]) + if err != nil { + return nil, lDelim, rDelim, err + } + t, err := ts.Parse(sl) + return t, lDelim, rDelim, err + case "Funcs": + if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil { + return nil, lDelim, rDelim, err + } + return ts.Funcs(fm), lDelim, rDelim, nil + case "Option": + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return ts.Option(list...), lDelim, rDelim, nil + case "Delims": + if len(call.Args) != 2 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments")) + } + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return ts.Delims(list[0], list[1]), list[0], list[1], nil + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s on variable receiver", sel.Sel.Name)) + } + } + switch sel.Sel.Name { + case "Must": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args))) + } + return EvaluateTemplateSelector(ts, pkg, typesInfo, call.Args[0], workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta) + case "New": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return template.New(templateNames[0]), lDelim, rDelim, nil + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, lDelim, rDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + t, err := parseFiles(nil, fm, lDelim, rDelim, filePaths...) + return t, lDelim, rDelim, err + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name)) + } + case *ast.CallExpr: + up, upLDelim, upRDelim, err := EvaluateTemplateSelector(ts, pkg, typesInfo, sel.X, workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta) + if err != nil { + return nil, lDelim, rDelim, err + } + switch sel.Sel.Name { + case "Delims": + if len(call.Args) != 2 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments")) + } + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Delims(list[0], list[1]), list[0], list[1], nil + case "Parse": + if len(call.Args) != 1 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + if meta != nil { + if bl, ok := call.Args[0].(*ast.BasicLit); ok { + meta.ParseCalls = append(meta.ParseCalls, bl) + } + } + sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0]) + if err != nil { + return nil, upLDelim, upRDelim, err + } + t, err := up.Parse(sl) + return t, upLDelim, upRDelim, err + case "New": + if len(call.Args) != 1 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.New(templateNames[0]), upLDelim, upRDelim, nil + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, upLDelim, upRDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + t, err := parseFiles(up, fm, upLDelim, upRDelim, filePaths...) + return t, upLDelim, upRDelim, err + case "Option": + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Option(list...), upLDelim, upRDelim, nil + case "Funcs": + if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Funcs(fm), upLDelim, upRDelim, nil + default: + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s", sel.Sel.Name)) + } + } +} + +// isTemplatePkgIdent reports whether ident refers to the "html/template" +// or "text/template" package via the type checker. +func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { + if info == nil { + return false + } + obj := info.Uses[ident] + pkgName, ok := obj.(*types.PkgName) + if !ok { + return false + } + path := pkgName.Imported().Path() + return path == "html/template" || path == "text/template" +} + +func builtins() template.FuncMap { + type nothing struct{} + return template.FuncMap{ + "and": func() (_ nothing) { return }, + "call": func() (_ nothing) { return }, + "html": func() (_ nothing) { return }, + "index": func() (_ nothing) { return }, + "slice": func() (_ nothing) { return }, + "js": func() (_ nothing) { return }, + "len": func() (_ nothing) { return }, + "not": func() (_ nothing) { return }, + "or": func() (_ nothing) { return }, + "print": func() (_ nothing) { return }, + "printf": func() (_ nothing) { return }, + "println": func() (_ nothing) { return }, + "urlquery": func() (_ nothing) { return }, + + // Comparisons + "eq": func() (_ nothing) { return }, + "ge": func() (_ nothing) { return }, + "gt": func() (_ nothing) { return }, + "le": func() (_ nothing) { return }, + "lt": func() (_ nothing) { return }, + "ne": func() (_ nothing) { return }, + } +} + +func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim string, filenames ...string) (*template.Template, error) { + if len(filenames) == 0 { + return nil, fmt.Errorf("html/template: no files named in call to ParseFiles") + } + for _, filename := range filenames { + templateName := filepath.Base(filename) + b, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + s := string(b) + var tmpl *template.Template + if t == nil { + t = template.New(templateName) + } + if templateName == t.Name() { + tmpl = t + } else { + tmpl = t.New(templateName) + } + trees, err := parse.Parse(templateName, s, leftDelim, rightDelim, fm, builtins()) + if err != nil { + return nil, err + } + absoluteFilename, err := filepath.Abs(filename) + if err != nil { + return nil, err + } + for _, tree := range trees { + tree.ParseName = absoluteFilename + if _, err = tmpl.AddParseTree(tree.Name, tree); err != nil { + return nil, err + } + } + } + return t, nil +} + +func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFunctions) error { + if len(call.Args) != 1 { + return wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) + } + arg := call.Args[0] + lit, ok := arg.(*ast.CompositeLit) + if !ok { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a template.FuncMap composite literal got %s", astgen.Format(arg))) + } + if typesInfo != nil { + litType := typesInfo.TypeOf(lit) + if named, ok := litType.(*types.Named); ok { + obj := named.Obj() + if obj.Pkg() == nil || obj.Name() != "FuncMap" { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType)) + } + path := obj.Pkg().Path() + if path != "html/template" && path != "text/template" { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType)) + } + } + } + var buf bytes.Buffer + for i, exp := range lit.Elts { + el, ok := exp.(*ast.KeyValueExpr) + if !ok { + return wrapWithFilename(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, astgen.Format(exp))) + } + funcName, err := StringLiteralExpression(workingDirectory, fileSet, el.Key) + if err != nil { + return err + } + + // template.Parse does not evaluate the function signature parameters; + // it ensures the function name is in scope and there is one or two results. + // we could use something like func() string { return "" } for this signature + // but this function from fmt works just fine. + // + // to explore the known requirements run: + // fm[funcName] = nil // will fail because nil does not have `reflect.Kind` Func + // or + // fm[funcName] = func() {} // will fail because there are no results + // or + // fm[funcName] = func() (int, int) {return 0, 0} // will fail because the second result is not an error + fm[funcName] = fmt.Sprintln + + if pkg == nil { + continue + } + buf.Reset() + if err := format.Node(&buf, fileSet, el.Value); err != nil { + return err + } + tv, err := types.Eval(fileSet, pkg, lit.Pos(), buf.String()) + if err != nil { + return err + } + funcTypesMap[funcName] = tv.Type.(*types.Signature) + } + return nil +} + +func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet, call *ast.CallExpr, files []*ast.File, embeddedPaths []string) ([]string, error) { + if len(call.Args) < 1 { + return nil, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("missing required arguments")) + } + matches, err := embedFSFilePaths(workingDirectory, fileSet, files, call.Args[0], embeddedPaths) + if err != nil { + return nil, err + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args[1:]) + if err != nil { + return nil, err + } + filtered := matches[:0] + for _, ef := range matches { + for j, pattern := range templateNames { + match, err := filepath.Match(pattern, ef) + if err != nil { + return nil, wrapWithFilename(workingDirectory, fileSet, call.Args[j+1].Pos(), fmt.Errorf("bad pattern %q: %w", pattern, err)) + } + if !match { + continue + } + filtered = append(filtered, ef) + break + } + } + return joinFilePaths(workingDirectory, filtered...), nil +} + +func embedFSFilePaths(dir string, fileSet *token.FileSet, files []*ast.File, exp ast.Expr, embeddedFiles []string) ([]string, error) { + varIdent, ok := exp.(*ast.Ident) + if !ok { + return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("first argument to ParseFS must be an identifier")) + } + for _, decl := range astgen.IterateGenDecl(files, token.VAR) { + for _, s := range decl.Specs { + spec, ok := s.(*ast.ValueSpec) + if !ok || !slices.ContainsFunc(spec.Names, func(e *ast.Ident) bool { return e.Name == varIdent.Name }) { + continue + } + var comment strings.Builder + commentNode := readComments(&comment, decl.Doc, spec.Doc) + templateNames := parseTemplateNames(comment.String()) + absMat, err := embeddedFilesMatchingTemplateNameList(dir, fileSet, commentNode, templateNames, embeddedFiles) + if err != nil { + return nil, err + } + return absMat, nil + } + } + return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("variable %s not found", varIdent)) +} + +func embeddedFilesMatchingTemplateNameList(dir string, set *token.FileSet, comment ast.Node, templateNames, embeddedFiles []string) ([]string, error) { + var matches []string + for _, fp := range embeddedFiles { + for _, pattern := range templateNames { + pat := filepath.FromSlash(pattern) + if !strings.ContainsAny(pat, "*[]") { + prefix := filepath.FromSlash(pat + "/") + if strings.HasPrefix(fp, prefix) { + matches = append(matches, fp) + continue + } + } + if matched, err := filepath.Match(pat, fp); err != nil { + return nil, wrapWithFilename(dir, set, comment.Pos(), fmt.Errorf("embed comment malformed: %w", err)) + } else if matched { + matches = append(matches, fp) + } + } + } + return slices.Clip(matches), nil +} + +const goEmbedCommentPrefix = "//go:embed" + +func readComments(s *strings.Builder, groups ...*ast.CommentGroup) ast.Node { + var n ast.Node + for _, c := range groups { + if c == nil { + continue + } + for _, line := range c.List { + if !strings.HasPrefix(line.Text, goEmbedCommentPrefix) { + continue + } + s.WriteString(strings.TrimSpace(strings.TrimPrefix(line.Text, goEmbedCommentPrefix))) + s.WriteByte(' ') + } + n = c + break + } + return n +} + +func parseTemplateNames(input string) []string { + // todo: refactor to use strconv.QuotedPrefix + var ( + templateNames []string + currentTemplateName strings.Builder + inQuote = false + quoteChar rune + ) + + for _, r := range input { + switch { + case r == '"' || r == '`': + if !inQuote { + inQuote = true + quoteChar = r + continue + } + if r != quoteChar { + currentTemplateName.WriteRune(r) + continue + } + templateNames = append(templateNames, currentTemplateName.String()) + currentTemplateName.Reset() + inQuote = false + case unicode.IsSpace(r): + if inQuote { + currentTemplateName.WriteRune(r) + continue + } + if currentTemplateName.Len() > 0 { + templateNames = append(templateNames, currentTemplateName.String()) + currentTemplateName.Reset() + } + default: + currentTemplateName.WriteRune(r) + } + } + + // Import any remaining pattern + if currentTemplateName.Len() > 0 { + templateNames = append(templateNames, currentTemplateName.String()) + } + + return templateNames +} + +func joinFilePaths(wd string, rel ...string) []string { + result := slices.Clone(rel) + for i := range result { + result[i] = filepath.Join(wd, result[i]) + } + return result +} + +func RelativeFilePaths(wd string, abs ...string) ([]string, error) { + result := slices.Clone(abs) + for i, p := range result { + r, err := filepath.Rel(wd, p) + if err != nil { + return nil, err + } + result[i] = r + } + return result, nil +} + +type TemplateFunctions map[string]*types.Signature + +func DefaultFunctions(pkg *types.Package) TemplateFunctions { + funcTypeMap := make(TemplateFunctions) + fmtPkg, ok := findPackage(pkg, "fmt") + if !ok || fmtPkg == nil { + return funcTypeMap + } + funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature) + funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature) + funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature) + return funcTypeMap +} + +func findPackage(pkg *types.Package, path string) (*types.Package, bool) { + if pkg == nil || pkg.Path() == path { + return pkg, true + } + for _, im := range pkg.Imports() { + if p, ok := findPackage(im, path); ok { + return p, true + } + } + return nil, false +} + +func (functions TemplateFunctions) FindFunction(name string) (*types.Signature, bool) { + m := (map[string]*types.Signature)(functions) + fn, ok := m[name] + if !ok { + return nil, false + } + return fn, true +} + +func BasicLiteralString(node ast.Node) (string, bool) { + name, ok := node.(*ast.BasicLit) + if !ok { + return "", false + } + if name.Kind != token.STRING { + return "", false + } + templateName, err := strconv.Unquote(name.Value) + if err != nil { + return "", false + } + return templateName, true +} diff --git a/internal/asteval/testdata/template/assets_dir.txtar b/internal/asteval/testdata/template/assets_dir.txtar new file mode 100644 index 0000000..29b7738 --- /dev/null +++ b/internal/asteval/testdata/template/assets_dir.txtar @@ -0,0 +1,23 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed assets + assetsFS embed.FS + + templates = template.Must(template.ParseFS(assetsFS, "assets/*")) +) +-- assets/index.gohtml -- + +{{define "home"}}{{end}} + +-- assets/form.gohtml -- + +{{define "create"}}{{end}} + +{{define "update"}}{{end}} diff --git a/internal/asteval/testdata/template/bad_embed_pattern.txtar b/internal/asteval/testdata/template/bad_embed_pattern.txtar new file mode 100644 index 0000000..0d4137a --- /dev/null +++ b/internal/asteval/testdata/template/bad_embed_pattern.txtar @@ -0,0 +1,16 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed "[fail" + assetsFS embed.FS + + templates = template.Must(template.ParseFS(assetsFS, "*")) +) +-- greeting.gohtml -- +Hello, friend! diff --git a/internal/asteval/testdata/template/delims.txtar b/internal/asteval/testdata/template/delims.txtar new file mode 100644 index 0000000..852baa2 --- /dev/null +++ b/internal/asteval/testdata/template/delims.txtar @@ -0,0 +1,22 @@ +-- templates.go -- +package main + +var ( + //go:embed *.gohtml + files embed.FS +) + +var ( + templates = template.Must( + template.Must( + template.Must( + template.New("").ParseFS(files, "default.gohtml")). + Delims("(((", ")))").ParseFS(files, "triple_parens.gohtml")). + Delims("[[", "]]").ParseFS(files, "double_square.gohtml")) +) +-- default.gohtml -- +{{- define "default" -}}{{- end -}} +-- triple_parens.gohtml -- +(((- define "parens" -)))(((- end -))) +-- double_square.gohtml -- +[[- define "square" -]][[end]] diff --git a/internal/asteval/testdata/template/funcs.txtar b/internal/asteval/testdata/template/funcs.txtar new file mode 100644 index 0000000..619aa40 --- /dev/null +++ b/internal/asteval/testdata/template/funcs.txtar @@ -0,0 +1,41 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + src embed.FS + + templates = template.New("x").Funcs(template.FuncMap{ + "greet": func() string { return "Hello" }, + }).ParseFS(src, "greet.gohtml") + + templatesFuncNotDefined = template.New("x").Funcs(template.FuncMap{ + "greet": func() string { return "Hello" }, + }).ParseFS(src, "missing_func.gohtml") + + templatesWrongArg = template.New("x").Funcs(wrong) + + templatesTwoArgs = template.New("x").Funcs(wrong, fail) + + templatesNoArgs = template.New("x").Funcs() + + templatesWrongTypePackageName = template.New("x").Funcs(wrong.FuncMap{}) + + templatesWrongTypeName = template.New("x").Funcs(template.Wrong{}) + + templatesWrongTypeExpression = template.New("x").Funcs(wrong{}) + + templatesWrongTypeElem = template.New("x").Funcs(template.FuncMap{wrong}) + + templatesWrongElemKey = template.New("x").Funcs(template.FuncMap{wrong: func() string { return "" }}) +) +-- greet.gohtml -- +{{greet}}, world! + +-- missing_func.gohtml -- +{{greet}}, {{enemy}}! diff --git a/internal/asteval/testdata/template/parse.txtar b/internal/asteval/testdata/template/parse.txtar new file mode 100644 index 0000000..7157867 --- /dev/null +++ b/internal/asteval/testdata/template/parse.txtar @@ -0,0 +1,15 @@ +-- parse.go -- +package main + +import "html/template" + +var templates = template.New("GET /").Parse(`

Hello, world!

`) + +var multiple = template.New("").Parse(` +{{define "GET /"}}

Hello, world!

{{end}} +{{define "GET /{name}"}}

Hello, {{.PathValue "name"}}!

{{end}} +`) + +var noArg = template.New("").Parse() + +var wrongArg = template.New("").Parse(500) \ No newline at end of file diff --git a/internal/asteval/testdata/template/template_ParseFS.txtar b/internal/asteval/testdata/template/template_ParseFS.txtar new file mode 100644 index 0000000..b07673b --- /dev/null +++ b/internal/asteval/testdata/template/template_ParseFS.txtar @@ -0,0 +1,38 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + templateSource embed.FS + + templates = template.Must(template.ParseFS(templateSource, "*")) + + // allHTML used by templatesHTML and templatesGoHTML to test pattern filtering + // this comment ensures the comment parser skips lines not preceded by go:embed + //go:embed *html + allHTML embed.FS + + templatesHTML = template.Must(template.ParseFS(allHTML, "*.*html")) + templatesGoHTML = template.Must(template.ParseFS(allHTML, "*.gohtml")) + + templateEmbedVariableNotFound = template.Must(template.ParseFS(hiding, "*")) +) +-- index.gohtml -- + +{{define "home"}}{{end}} + +-- form.gohtml -- + +{{define "create"}}{{end}} + +{{define "update"}}{{end}} + +-- script.html -- +{{define "console_log"}} + +{{end}} diff --git a/internal/asteval/testdata/template/templates.txtar b/internal/asteval/testdata/template/templates.txtar new file mode 100644 index 0000000..66a8671 --- /dev/null +++ b/internal/asteval/testdata/template/templates.txtar @@ -0,0 +1,84 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + templateSource embed.FS + + templateNew = template.New("some-name") + + templateParseFSNew = template.Must(template.ParseFS(templateSource, "*")).New("greetings") + + templateNewParseFS = template.Must(template.New("greetings").ParseFS(templateSource, "*")) + + templateNewMissingArg = template.New() + + templateWrongX = UNKNOWN.New() + + templateWrongArgCount = template.New("one", "two") + + templateNewOnIndexed = ts[0].New("one", "two") + + templateNewArg42 = template.New(42) + + templateNewArgIdent = template.New(TemplateName) + + templateNewErrUpstream = template.New(fail).New("n") + + templatesIdent = someIdent + + unsupportedMethod = template.Unknown() + + unexpectedFunExpression = x[3]() + + templateMustNonIdentReceiver = f().Must(template.ParseFS(templateSource, "*")) + + templateMustCalledWithTwoArgs = template.Must(nil, nil) + + templateMustCalledWithNoArg s = template.Must() + + templateMustWrongPackageIdent = wrong.Must() + + templateParseFSWrongPackageIdent = wrong.ParseFS(templateSource, "*") + + templateParseFSReceiverErr = template.New().ParseFS(templateSource, "*") + + templateParseFSUnexpectedReceiver = x[0].ParseFS(templateSource, "*") + + templateParseFSNoArgs = template.ParseFS() + + templateParseFSFirstArgNonIdent = template.ParseFS(os.DirFS("."), "*") + + templateParseFSNonStringLiteralGlob = template.ParseFS(templateSource, "w", 42, "x") + + templateParseFSWithBadGlob = template.ParseFS(templateSource, "[fail") + + templateNewHasWrongNumberOfArgs = template.Must(template.New("x").ParseFS(templateSource, "*")).New() + + templateNewHasWrongTypeOfArgs = template.New("x").New(9000) + + templateNewHasTooManyArgs = template.New("x").New("x", "y") + + templateDelimsGetsNoArgs = template.New("x").Delims() + + templateDelimsGetsTooMany = template.New("x").Delims("x", "y", "") + + templateDelimsWrongExpressionArg = template.New("x").Delims("x", y) + + templateParseFSMethodFails = template.New("x").ParseFS(templateSource, fail) + + templateOptionsRequiresStringLiterals = template.New("x").Option(fail) + + templateUnknownMethod = template.New("x").Unknown() + + templateOptionCall = template.New("x").Option("missingkey=default").ParseFS(templateSource, "*") + + templateOptionCallUnknownArg = template.New("x").Option("unknown").ParseFS(templateSource, "*") +) +-- index.gohtml -- +Hello, friend! diff --git a/internal/astgen/format.go b/internal/astgen/format.go new file mode 100644 index 0000000..748cebe --- /dev/null +++ b/internal/astgen/format.go @@ -0,0 +1,23 @@ +package astgen + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/printer" + "go/token" +) + +// Format converts an AST node to formatted Go source code +func Format(node ast.Node) string { + var buf bytes.Buffer + if err := printer.Fprint(&buf, token.NewFileSet(), node); err != nil { + return fmt.Sprintf("formatting error: %v", err) + } + out, err := format.Source(buf.Bytes()) + if err != nil { + return fmt.Sprintf("formatting error: %v", err) + } + return string(bytes.ReplaceAll(out, []byte("\n}\nfunc "), []byte("\n}\n\nfunc "))) +} diff --git a/internal/astgen/queries.go b/internal/astgen/queries.go new file mode 100644 index 0000000..e262076 --- /dev/null +++ b/internal/astgen/queries.go @@ -0,0 +1,36 @@ +package astgen + +import ( + "go/ast" + "go/token" +) + +// IterateGenDecl returns an iterator over GenDecl nodes with the specified token type +func IterateGenDecl(files []*ast.File, tok token.Token) func(func(*ast.File, *ast.GenDecl) bool) { + return func(yield func(*ast.File, *ast.GenDecl) bool) { + for _, file := range files { + for _, decl := range file.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != tok { + continue + } + if !yield(file, d) { + return + } + } + } + } +} + +// IterateValueSpecs returns an iterator over ValueSpec nodes in var declarations +func IterateValueSpecs(files []*ast.File) func(func(*ast.File, *ast.ValueSpec) bool) { + return func(yield func(*ast.File, *ast.ValueSpec) bool) { + for file, decl := range IterateGenDecl(files, token.VAR) { + for _, s := range decl.Specs { + if !yield(file, s.(*ast.ValueSpec)) { + return + } + } + } + } +} diff --git a/package.go b/package.go new file mode 100644 index 0000000..0e102dd --- /dev/null +++ b/package.go @@ -0,0 +1,301 @@ +package check + +import ( + "errors" + "fmt" + "go/ast" + "go/token" + "go/types" + "html/template" + + "golang.org/x/tools/go/packages" + + "github.com/typelate/check/internal/asteval" + "github.com/typelate/check/internal/astgen" +) + +// Package discovers all .ExecuteTemplate calls in the given package, +// resolves receiver variables to their template construction chains, +// and type-checks each call. +// +// ExecuteTemplate must be called with a string literal for the second parameter. +func Package(pkg *packages.Package) error { + // Phase 1: Find all ExecuteTemplate calls and collect receiver objects. + type pendingCall struct { + receiverObj types.Object + templateName string + dataType types.Type + } + var pending []pendingCall + receiverSet := make(map[types.Object]struct{}) + + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok || len(call.Args) != 3 { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "ExecuteTemplate" { + return true + } + // Verify the method belongs to html/template or text/template. + if !isTemplateMethod(pkg.TypesInfo, sel) { + return true + } + receiverIdent, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + obj := pkg.TypesInfo.Uses[receiverIdent] + if obj == nil { + return true + } + templateName, ok := asteval.BasicLiteralString(call.Args[1]) + if !ok { + return true + } + dataType := pkg.TypesInfo.TypeOf(call.Args[2]) + pending = append(pending, pendingCall{ + receiverObj: obj, + templateName: templateName, + dataType: dataType, + }) + receiverSet[obj] = struct{}{} + return true + }) + } + + // Phase 2: Resolve each unique receiver object to its template construction chain. + type resolvedTemplate struct { + ts *template.Template + funcs asteval.TemplateFunctions + meta *asteval.TemplateMetadata + } + resolved := make(map[types.Object]*resolvedTemplate) + + workingDirectory := packageDirectory(pkg) + embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...) + if err != nil { + return fmt.Errorf("failed to calculate relative path for embedded files: %w", err) + } + + resolveValueSpec := func(tv *ast.ValueSpec) { + for i, name := range tv.Names { + if i >= len(tv.Values) { + continue + } + obj := pkg.TypesInfo.Defs[name] + if obj == nil { + continue + } + if _, needed := receiverSet[obj]; !needed { + continue + } + + funcTypeMap := asteval.DefaultFunctions(pkg.Types) + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, tv.Values[i], workingDirectory, name.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) + if err != nil { + return + } + resolved[obj] = &resolvedTemplate{ + ts: ts, + funcs: funcTypeMap, + meta: meta, + } + } + } + + resolveAssignStmt := func(stmt *ast.AssignStmt) { + if stmt.Tok != token.DEFINE { + return + } + for i, lhs := range stmt.Lhs { + if i >= len(stmt.Rhs) { + continue + } + ident, ok := lhs.(*ast.Ident) + if !ok { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + if _, needed := receiverSet[obj]; !needed { + continue + } + + funcTypeMap := asteval.DefaultFunctions(pkg.Types) + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, stmt.Rhs[i], workingDirectory, ident.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) + if err != nil { + return + } + resolved[obj] = &resolvedTemplate{ + ts: ts, + funcs: funcTypeMap, + meta: meta, + } + } + } + + // Resolve top-level var declarations. + for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) { + resolveValueSpec(tv) + } + + // Resolve function-local var declarations and short variable declarations. + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.DeclStmt: + gd, ok := n.Decl.(*ast.GenDecl) + if !ok || gd.Tok != token.VAR { + return true + } + for _, spec := range gd.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + resolveValueSpec(vs) + } + case *ast.AssignStmt: + resolveAssignStmt(n) + } + return true + }) + } + + // Phase 2b: Find additional ParseFS/Parse calls on resolved template variables. + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + obj := findModificationReceiver(call, pkg.TypesInfo) + if obj == nil { + return true + } + rt, ok := resolved[obj] + if !ok { + return true + } + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(rt.ts, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.funcs, make(template.FuncMap), meta) + if err != nil { + return true + } + rt.ts = ts + rt.meta.EmbedFilePaths = append(rt.meta.EmbedFilePaths, meta.EmbedFilePaths...) + rt.meta.ParseCalls = append(rt.meta.ParseCalls, meta.ParseCalls...) + return true + }) + } + + // Phase 3: Type-check each ExecuteTemplate call. + mergedFunctions := make(Functions) + if pkg.Types != nil { + mergedFunctions = DefaultFunctions(pkg.Types) + } + for _, rt := range resolved { + for name, sig := range rt.funcs { + mergedFunctions[name] = sig + } + } + + var errs []error + for _, p := range pending { + rt, ok := resolved[p.receiverObj] + if !ok { + continue + } + looked := rt.ts.Lookup(p.templateName) + if looked == nil { + continue + } + treeFinder := (*asteval.Forrest)(rt.ts) + global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions) + if err := Execute(global, looked.Tree, p.dataType); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +// isTemplateMethod reports whether sel refers to a method on +// *html/template.Template or *text/template.Template. +func isTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool { + if typesInfo == nil { + return false + } + selection, ok := typesInfo.Selections[sel] + if !ok { + return false + } + fn, ok := selection.Obj().(*types.Func) + if !ok { + return false + } + fnPkg := fn.Pkg() + if fnPkg == nil { + return false + } + return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template" +} + +// findModificationReceiver unwraps template.Must and returns the types.Object +// of the variable receiver for a method call like ts.ParseFS(...) or +// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found. +func findModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object { + sel, ok := expr.Fun.(*ast.SelectorExpr) + if !ok { + return nil + } + switch x := sel.X.(type) { + case *ast.Ident: + if isTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 { + inner, ok := expr.Args[0].(*ast.CallExpr) + if !ok { + return nil + } + return findModificationReceiver(inner, typesInfo) + } + if isTemplatePkgIdent(typesInfo, x) { + return nil + } + return typesInfo.Uses[x] + } + return nil +} + +// isTemplatePkgIdent reports whether ident refers to the "html/template" +// or "text/template" package via the type checker. +func isTemplatePkgIdent(typesInfo *types.Info, ident *ast.Ident) bool { + if typesInfo == nil { + return false + } + obj := typesInfo.Uses[ident] + pkgName, ok := obj.(*types.PkgName) + if !ok { + return false + } + path := pkgName.Imported().Path() + return path == "html/template" || path == "text/template" +} + +func packageDirectory(pkg *packages.Package) string { + if len(pkg.GoFiles) > 0 { + p := pkg.GoFiles[0] + for i := len(p) - 1; i >= 0; i-- { + if p[i] == '/' || p[i] == '\\' { + return p[:i] + } + } + } + return "." +} From ec8a2451b728ca93868e1b38551d5b2583f3d57a Mon Sep 17 00:00:00 2001 From: Christopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Sat, 7 Feb 2026 12:35:15 -0800 Subject: [PATCH 2/3] chore: refactor Package function break it up and add more to asteval --- internal/asteval/template.go | 52 +++++++- package.go | 240 ++++++++++++++--------------------- 2 files changed, 143 insertions(+), 149 deletions(-) diff --git a/internal/asteval/template.go b/internal/asteval/template.go index bdfe916..2ee2af2 100644 --- a/internal/asteval/template.go +++ b/internal/asteval/template.go @@ -38,7 +38,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn default: return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args))) case *ast.Ident: - if !isTemplatePkgIdent(typesInfo, x) { + if !IsTemplatePkgIdent(typesInfo, x) { if ts == nil { return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected template package got %s", astgen.Format(sel.X))) } @@ -187,9 +187,9 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn } } -// isTemplatePkgIdent reports whether ident refers to the "html/template" +// IsTemplatePkgIdent reports whether ident refers to the "html/template" // or "text/template" package via the type checker. -func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { +func IsTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { if info == nil { return false } @@ -202,6 +202,52 @@ func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { return path == "html/template" || path == "text/template" } +// IsTemplateMethod reports whether sel refers to a method on +// *html/template.Template or *text/template.Template. +func IsTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool { + if typesInfo == nil { + return false + } + selection, ok := typesInfo.Selections[sel] + if !ok { + return false + } + fn, ok := selection.Obj().(*types.Func) + if !ok { + return false + } + fnPkg := fn.Pkg() + if fnPkg == nil { + return false + } + return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template" +} + +// FindModificationReceiver unwraps template.Must and returns the types.Object +// of the variable receiver for a method call like ts.ParseFS(...) or +// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found. +func FindModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object { + sel, ok := expr.Fun.(*ast.SelectorExpr) + if !ok { + return nil + } + switch x := sel.X.(type) { + case *ast.Ident: + if IsTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 { + inner, ok := expr.Args[0].(*ast.CallExpr) + if !ok { + return nil + } + return FindModificationReceiver(inner, typesInfo) + } + if IsTemplatePkgIdent(typesInfo, x) { + return nil + } + return typesInfo.Uses[x] + } + return nil +} + func builtins() template.FuncMap { type nothing struct{} return template.FuncMap{ diff --git a/package.go b/package.go index 0e102dd..a2ed138 100644 --- a/package.go +++ b/package.go @@ -7,6 +7,7 @@ import ( "go/token" "go/types" "html/template" + "path/filepath" "golang.org/x/tools/go/packages" @@ -14,18 +15,36 @@ import ( "github.com/typelate/check/internal/astgen" ) +type pendingCall struct { + receiverObj types.Object + templateName string + dataType types.Type +} + +type resolvedTemplate struct { + templates *template.Template + functions asteval.TemplateFunctions + metadata *asteval.TemplateMetadata +} + // Package discovers all .ExecuteTemplate calls in the given package, // resolves receiver variables to their template construction chains, // and type-checks each call. // // ExecuteTemplate must be called with a string literal for the second parameter. func Package(pkg *packages.Package) error { - // Phase 1: Find all ExecuteTemplate calls and collect receiver objects. - type pendingCall struct { - receiverObj types.Object - templateName string - dataType types.Type + pending, receivers := findExecuteCalls(pkg) + resolved, err := resolveTemplates(pkg, receivers) + if err != nil { + return err } + return checkCalls(pkg, pending, resolved) +} + +// findExecuteCalls walks the package syntax looking for ExecuteTemplate calls +// and returns the pending calls along with the set of receiver objects that +// need template resolution. +func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]struct{}) { var pending []pendingCall receiverSet := make(map[types.Object]struct{}) @@ -40,7 +59,7 @@ func Package(pkg *packages.Package) error { return true } // Verify the method belongs to html/template or text/template. - if !isTemplateMethod(pkg.TypesInfo, sel) { + if !asteval.IsTemplateMethod(pkg.TypesInfo, sel) { return true } receiverIdent, ok := sel.X.(*ast.Ident) @@ -66,86 +85,51 @@ func Package(pkg *packages.Package) error { }) } - // Phase 2: Resolve each unique receiver object to its template construction chain. - type resolvedTemplate struct { - ts *template.Template - funcs asteval.TemplateFunctions - meta *asteval.TemplateMetadata - } + return pending, receiverSet +} + +// resolveTemplates resolves each unique receiver object to its template +// construction chain, including additional ParseFS/Parse modifications. +func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, error) { resolved := make(map[types.Object]*resolvedTemplate) workingDirectory := packageDirectory(pkg) embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...) if err != nil { - return fmt.Errorf("failed to calculate relative path for embedded files: %w", err) + return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err) } - resolveValueSpec := func(tv *ast.ValueSpec) { - for i, name := range tv.Names { - if i >= len(tv.Values) { - continue - } - obj := pkg.TypesInfo.Defs[name] - if obj == nil { - continue - } - if _, needed := receiverSet[obj]; !needed { - continue - } - - funcTypeMap := asteval.DefaultFunctions(pkg.Types) - meta := &asteval.TemplateMetadata{} - ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, tv.Values[i], workingDirectory, name.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) - if err != nil { - return - } - resolved[obj] = &resolvedTemplate{ - ts: ts, - funcs: funcTypeMap, - meta: meta, - } + resolveExpr := func(obj types.Object, name string, expr ast.Expr) { + if _, needed := receivers[obj]; !needed { + return } - } - - resolveAssignStmt := func(stmt *ast.AssignStmt) { - if stmt.Tok != token.DEFINE { + funcTypeMap := asteval.DefaultFunctions(pkg.Types) + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) + if err != nil { return } - for i, lhs := range stmt.Lhs { - if i >= len(stmt.Rhs) { - continue - } - ident, ok := lhs.(*ast.Ident) - if !ok { + resolved[obj] = &resolvedTemplate{ + templates: ts, + functions: funcTypeMap, + metadata: meta, + } + } + + // Resolve top-level var declarations. + for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) { + for i, ident := range tv.Names { + if i >= len(tv.Values) { continue } obj := pkg.TypesInfo.Defs[ident] if obj == nil { continue } - if _, needed := receiverSet[obj]; !needed { - continue - } - - funcTypeMap := asteval.DefaultFunctions(pkg.Types) - meta := &asteval.TemplateMetadata{} - ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, stmt.Rhs[i], workingDirectory, ident.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) - if err != nil { - return - } - resolved[obj] = &resolvedTemplate{ - ts: ts, - funcs: funcTypeMap, - meta: meta, - } + resolveExpr(obj, ident.Name, tv.Values[i]) } } - // Resolve top-level var declarations. - for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) { - resolveValueSpec(tv) - } - // Resolve function-local var declarations and short variable declarations. for _, file := range pkg.Syntax { ast.Inspect(file, func(node ast.Node) bool { @@ -160,23 +144,48 @@ func Package(pkg *packages.Package) error { if !ok { continue } - resolveValueSpec(vs) + for i, ident := range vs.Names { + if i >= len(vs.Values) { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + resolveExpr(obj, ident.Name, vs.Values[i]) + } } case *ast.AssignStmt: - resolveAssignStmt(n) + if n.Tok != token.DEFINE { + return true + } + for i, lhs := range n.Lhs { + if i >= len(n.Rhs) { + continue + } + ident, ok := lhs.(*ast.Ident) + if !ok { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + resolveExpr(obj, ident.Name, n.Rhs[i]) + } } return true }) } - // Phase 2b: Find additional ParseFS/Parse calls on resolved template variables. + // Find additional ParseFS/Parse calls on resolved template variables. for _, file := range pkg.Syntax { ast.Inspect(file, func(node ast.Node) bool { call, ok := node.(*ast.CallExpr) if !ok { return true } - obj := findModificationReceiver(call, pkg.TypesInfo) + obj := asteval.FindModificationReceiver(call, pkg.TypesInfo) if obj == nil { return true } @@ -185,24 +194,29 @@ func Package(pkg *packages.Package) error { return true } meta := &asteval.TemplateMetadata{} - ts, _, _, err := asteval.EvaluateTemplateSelector(rt.ts, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.funcs, make(template.FuncMap), meta) + ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(template.FuncMap), meta) if err != nil { return true } - rt.ts = ts - rt.meta.EmbedFilePaths = append(rt.meta.EmbedFilePaths, meta.EmbedFilePaths...) - rt.meta.ParseCalls = append(rt.meta.ParseCalls, meta.ParseCalls...) + rt.templates = ts + rt.metadata.EmbedFilePaths = append(rt.metadata.EmbedFilePaths, meta.EmbedFilePaths...) + rt.metadata.ParseCalls = append(rt.metadata.ParseCalls, meta.ParseCalls...) return true }) } - // Phase 3: Type-check each ExecuteTemplate call. + return resolved, nil +} + +// checkCalls type-checks each pending ExecuteTemplate call against its +// resolved template. +func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate) error { mergedFunctions := make(Functions) if pkg.Types != nil { mergedFunctions = DefaultFunctions(pkg.Types) } for _, rt := range resolved { - for name, sig := range rt.funcs { + for name, sig := range rt.functions { mergedFunctions[name] = sig } } @@ -213,11 +227,11 @@ func Package(pkg *packages.Package) error { if !ok { continue } - looked := rt.ts.Lookup(p.templateName) + looked := rt.templates.Lookup(p.templateName) if looked == nil { continue } - treeFinder := (*asteval.Forrest)(rt.ts) + treeFinder := (*asteval.Forrest)(rt.templates) global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions) if err := Execute(global, looked.Tree, p.dataType); err != nil { errs = append(errs, err) @@ -227,75 +241,9 @@ func Package(pkg *packages.Package) error { return errors.Join(errs...) } -// isTemplateMethod reports whether sel refers to a method on -// *html/template.Template or *text/template.Template. -func isTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool { - if typesInfo == nil { - return false - } - selection, ok := typesInfo.Selections[sel] - if !ok { - return false - } - fn, ok := selection.Obj().(*types.Func) - if !ok { - return false - } - fnPkg := fn.Pkg() - if fnPkg == nil { - return false - } - return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template" -} - -// findModificationReceiver unwraps template.Must and returns the types.Object -// of the variable receiver for a method call like ts.ParseFS(...) or -// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found. -func findModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object { - sel, ok := expr.Fun.(*ast.SelectorExpr) - if !ok { - return nil - } - switch x := sel.X.(type) { - case *ast.Ident: - if isTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 { - inner, ok := expr.Args[0].(*ast.CallExpr) - if !ok { - return nil - } - return findModificationReceiver(inner, typesInfo) - } - if isTemplatePkgIdent(typesInfo, x) { - return nil - } - return typesInfo.Uses[x] - } - return nil -} - -// isTemplatePkgIdent reports whether ident refers to the "html/template" -// or "text/template" package via the type checker. -func isTemplatePkgIdent(typesInfo *types.Info, ident *ast.Ident) bool { - if typesInfo == nil { - return false - } - obj := typesInfo.Uses[ident] - pkgName, ok := obj.(*types.PkgName) - if !ok { - return false - } - path := pkgName.Imported().Path() - return path == "html/template" || path == "text/template" -} - func packageDirectory(pkg *packages.Package) string { if len(pkg.GoFiles) > 0 { - p := pkg.GoFiles[0] - for i := len(p) - 1; i >= 0; i-- { - if p[i] == '/' || p[i] == '\\' { - return p[:i] - } - } + return filepath.Dir(pkg.GoFiles[0]) } return "." } From 45e0ccd122515888e2a4d709cc8d2eb74d353491 Mon Sep 17 00:00:00 2001 From: Christopher Hunter <8398225+crhntr@users.noreply.github.com> Date: Sat, 7 Feb 2026 16:53:34 -0800 Subject: [PATCH 3/3] chore: add test coverage for text vs html template packages --- .../err_html_template_err_branch_end.txt | 36 ++++++++++ .../pass_text_template_no_err_branch_end.txt | 35 ++++++++++ internal/asteval/template.go | 56 ++++++++++++---- internal/asteval/template_html.go | 67 +++++++++++++++++++ internal/asteval/template_packages.go | 19 ++++++ internal/asteval/template_text.go | 67 +++++++++++++++++++ package.go | 29 ++++---- 7 files changed, 281 insertions(+), 28 deletions(-) create mode 100644 cmd/check-templates/testdata/err_html_template_err_branch_end.txt create mode 100644 cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt create mode 100644 internal/asteval/template_html.go create mode 100644 internal/asteval/template_packages.go create mode 100644 internal/asteval/template_text.go diff --git a/cmd/check-templates/testdata/err_html_template_err_branch_end.txt b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt new file mode 100644 index 0000000..a5d2fbd --- /dev/null +++ b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt @@ -0,0 +1,36 @@ +# html/template errors with ErrBranchEnd at Execute time. +# The go test exercises the html/template escape analysis which rejects +# templates whose {{if}} branches end in different HTML contexts. +# Companion: pass_text_template_no_err_branch_end.txt + +check-templates +exec go test + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import "html/template" + +var templates = template.Must(template.New("page").Parse(`
{{end}}
`)) + +func main() {} + +-- main_test.go -- +package main + +import ( + "io" + "testing" +) + +func TestExecuteBranchEnd(t *testing.T) { + err := templates.Execute(io.Discard, false) + if err == nil { + t.Fatal("expected html/template ErrBranchEnd error, got nil") + } + t.Logf("got expected error: %s", err.Error()) +} diff --git a/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt new file mode 100644 index 0000000..1ae730d --- /dev/null +++ b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt @@ -0,0 +1,35 @@ +# text/template does not error with ErrBranchEnd. +# The go test confirms text/template.Execute succeeds on the same template +# that html/template rejects due to branches ending in different contexts. +# Companion: err_html_template_err_branch_end.txt + +check-templates +exec go test + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import "text/template" + +var templates = template.Must(template.New("page").Parse(`
{{end}}
`)) + +func main() {} + +-- main_test.go -- +package main + +import ( + "io" + "testing" +) + +func TestExecuteNoBranchEnd(t *testing.T) { + err := templates.Execute(io.Discard, false) + if err != nil { + t.Fatalf("text/template should not produce ErrBranchEnd, got: %v", err) + } +} diff --git a/internal/asteval/template.go b/internal/asteval/template.go index 2ee2af2..b0293e9 100644 --- a/internal/asteval/template.go +++ b/internal/asteval/template.go @@ -7,7 +7,6 @@ import ( "go/format" "go/token" "go/types" - "html/template" "os" "path/filepath" "slices" @@ -19,13 +18,28 @@ import ( "github.com/typelate/check/internal/astgen" ) +// Template abstracts over html/template.Template and text/template.Template +// so that the correct template package is used based on the user's import. +type Template interface { + New(name string) Template + Parse(text string) (Template, error) + Funcs(funcMap map[string]any) Template + Option(opt ...string) Template + Delims(left, right string) Template + Lookup(name string) Template + Name() string + AddParseTree(name string, tree *parse.Tree) (Template, error) + Tree() *parse.Tree + FindTree(name string) (*parse.Tree, bool) +} + // TemplateMetadata accumulates metadata during template evaluation. type TemplateMetadata struct { EmbedFilePaths []string ParseCalls []*ast.BasicLit } -func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm template.FuncMap, meta *TemplateMetadata) (*template.Template, string, string, error) { +func EvaluateTemplateSelector(ts Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm map[string]any, meta *TemplateMetadata) (Template, string, string, error) { call, ok := expression.(*ast.CallExpr) if !ok { return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression")) @@ -52,7 +66,8 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn if meta != nil { meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) } - t, err := parseFiles(ts, fm, lDelim, rDelim, filePaths...) + pkgPath := templatePkgPath(typesInfo, x) + t, err := parseFiles(ts, pkgPath, fm, lDelim, rDelim, filePaths...) return t, lDelim, rDelim, err case "Parse": if len(call.Args) != 1 { @@ -93,6 +108,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s on variable receiver", sel.Sel.Name)) } } + pkgPath := templatePkgPath(typesInfo, x) switch sel.Sel.Name { case "Must": if len(call.Args) != 1 { @@ -107,7 +123,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn if err != nil { return nil, lDelim, rDelim, err } - return template.New(templateNames[0]), lDelim, rDelim, nil + return NewTemplate(pkgPath, templateNames[0]), lDelim, rDelim, nil case "ParseFS": filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) if err != nil { @@ -116,7 +132,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn if meta != nil { meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) } - t, err := parseFiles(nil, fm, lDelim, rDelim, filePaths...) + t, err := parseFiles(nil, pkgPath, fm, lDelim, rDelim, filePaths...) return t, lDelim, rDelim, err default: return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name)) @@ -168,7 +184,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn if meta != nil { meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) } - t, err := parseFiles(up, fm, upLDelim, upRDelim, filePaths...) + t, err := parseFiles(up, "", fm, upLDelim, upRDelim, filePaths...) return t, upLDelim, upRDelim, err case "Option": list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) @@ -187,6 +203,20 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn } } +// templatePkgPath extracts the import path ("html/template" or "text/template") +// from an AST identifier that refers to a template package. +func templatePkgPath(info *types.Info, ident *ast.Ident) string { + if info == nil { + return "html/template" + } + obj := info.Uses[ident] + pkgName, ok := obj.(*types.PkgName) + if !ok { + return "html/template" + } + return pkgName.Imported().Path() +} + // IsTemplatePkgIdent reports whether ident refers to the "html/template" // or "text/template" package via the type checker. func IsTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { @@ -248,9 +278,9 @@ func FindModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.O return nil } -func builtins() template.FuncMap { +func builtins() map[string]any { type nothing struct{} - return template.FuncMap{ + return map[string]any{ "and": func() (_ nothing) { return }, "call": func() (_ nothing) { return }, "html": func() (_ nothing) { return }, @@ -275,9 +305,9 @@ func builtins() template.FuncMap { } } -func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim string, filenames ...string) (*template.Template, error) { +func parseFiles(t Template, pkgPath string, fm map[string]any, leftDelim, rightDelim string, filenames ...string) (Template, error) { if len(filenames) == 0 { - return nil, fmt.Errorf("html/template: no files named in call to ParseFiles") + return nil, fmt.Errorf("template: no files named in call to ParseFiles") } for _, filename := range filenames { templateName := filepath.Base(filename) @@ -286,9 +316,9 @@ func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim return nil, err } s := string(b) - var tmpl *template.Template + var tmpl Template if t == nil { - t = template.New(templateName) + t = NewTemplate(pkgPath, templateName) } if templateName == t.Name() { tmpl = t @@ -313,7 +343,7 @@ func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim return t, nil } -func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFunctions) error { +func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm map[string]any, funcTypesMap TemplateFunctions) error { if len(call.Args) != 1 { return wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) } diff --git a/internal/asteval/template_html.go b/internal/asteval/template_html.go new file mode 100644 index 0000000..31c75fc --- /dev/null +++ b/internal/asteval/template_html.go @@ -0,0 +1,67 @@ +package asteval + +import ( + "html/template" + "text/template/parse" +) + +// htmlTemplate wraps *html/template.Template. +type htmlTemplate struct { + t *template.Template +} + +func (h *htmlTemplate) New(name string) Template { + return &htmlTemplate{t: h.t.New(name)} +} + +func (h *htmlTemplate) Parse(text string) (Template, error) { + t, err := h.t.Parse(text) + if err != nil { + return nil, err + } + return &htmlTemplate{t: t}, nil +} + +func (h *htmlTemplate) Funcs(funcMap map[string]any) Template { + return &htmlTemplate{t: h.t.Funcs(funcMap)} +} + +func (h *htmlTemplate) Option(opt ...string) Template { + return &htmlTemplate{t: h.t.Option(opt...)} +} + +func (h *htmlTemplate) Delims(left, right string) Template { + return &htmlTemplate{t: h.t.Delims(left, right)} +} + +func (h *htmlTemplate) Lookup(name string) Template { + t := h.t.Lookup(name) + if t == nil { + return nil + } + return &htmlTemplate{t: t} +} + +func (h *htmlTemplate) Name() string { + return h.t.Name() +} + +func (h *htmlTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) { + t, err := h.t.AddParseTree(name, tree) + if err != nil { + return nil, err + } + return &htmlTemplate{t: t}, nil +} + +func (h *htmlTemplate) Tree() *parse.Tree { + return h.t.Tree +} + +func (h *htmlTemplate) FindTree(name string) (*parse.Tree, bool) { + t := h.t.Lookup(name) + if t == nil { + return nil, false + } + return t.Tree, true +} diff --git a/internal/asteval/template_packages.go b/internal/asteval/template_packages.go new file mode 100644 index 0000000..2b4428a --- /dev/null +++ b/internal/asteval/template_packages.go @@ -0,0 +1,19 @@ +package asteval + +import ( + html "html/template" + text "text/template" +) + +// NewTemplate creates a Template backed by the appropriate template +// package either: "text/template" or "html/template". +func NewTemplate(pkgPath, name string) Template { + switch pkgPath { + case "text/template": + return &textTemplate{t: text.New(name)} + case "html/template": + return &htmlTemplate{t: html.New(name)} + default: + return nil + } +} diff --git a/internal/asteval/template_text.go b/internal/asteval/template_text.go new file mode 100644 index 0000000..d0e427c --- /dev/null +++ b/internal/asteval/template_text.go @@ -0,0 +1,67 @@ +package asteval + +import ( + "text/template" + "text/template/parse" +) + +// textTemplate wraps *text/template.Template. +type textTemplate struct { + t *template.Template +} + +func (s *textTemplate) New(name string) Template { + return &textTemplate{t: s.t.New(name)} +} + +func (s *textTemplate) Parse(text string) (Template, error) { + t, err := s.t.Parse(text) + if err != nil { + return nil, err + } + return &textTemplate{t: t}, nil +} + +func (s *textTemplate) Funcs(funcMap map[string]any) Template { + return &textTemplate{t: s.t.Funcs(funcMap)} +} + +func (s *textTemplate) Option(opt ...string) Template { + return &textTemplate{t: s.t.Option(opt...)} +} + +func (s *textTemplate) Delims(left, right string) Template { + return &textTemplate{t: s.t.Delims(left, right)} +} + +func (s *textTemplate) Lookup(name string) Template { + t := s.t.Lookup(name) + if t == nil { + return nil + } + return &textTemplate{t: t} +} + +func (s *textTemplate) Name() string { + return s.t.Name() +} + +func (s *textTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) { + t, err := s.t.AddParseTree(name, tree) + if err != nil { + return nil, err + } + return &textTemplate{t: t}, nil +} + +func (s *textTemplate) Tree() *parse.Tree { + return s.t.Tree +} + +func (s *textTemplate) FindTree(name string) (*parse.Tree, bool) { + t := s.t.Lookup(name) + if t == nil { + return nil, false + } + return t.Tree, true +} diff --git a/package.go b/package.go index a2ed138..e44de5f 100644 --- a/package.go +++ b/package.go @@ -6,7 +6,6 @@ import ( "go/ast" "go/token" "go/types" - "html/template" "path/filepath" "golang.org/x/tools/go/packages" @@ -22,7 +21,7 @@ type pendingCall struct { } type resolvedTemplate struct { - templates *template.Template + templates asteval.Template functions asteval.TemplateFunctions metadata *asteval.TemplateMetadata } @@ -34,11 +33,9 @@ type resolvedTemplate struct { // ExecuteTemplate must be called with a string literal for the second parameter. func Package(pkg *packages.Package) error { pending, receivers := findExecuteCalls(pkg) - resolved, err := resolveTemplates(pkg, receivers) - if err != nil { - return err - } - return checkCalls(pkg, pending, resolved) + resolved, resolveErrs := resolveTemplates(pkg, receivers) + callErr := checkCalls(pkg, pending, resolved) + return errors.Join(append(resolveErrs, callErr)...) } // findExecuteCalls walks the package syntax looking for ExecuteTemplate calls @@ -90,23 +87,26 @@ func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]st // resolveTemplates resolves each unique receiver object to its template // construction chain, including additional ParseFS/Parse modifications. -func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, error) { +func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, []error) { resolved := make(map[types.Object]*resolvedTemplate) workingDirectory := packageDirectory(pkg) embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...) if err != nil { - return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err) + return nil, []error{fmt.Errorf("failed to calculate relative path for embedded files: %w", err)} } + var resolveErrs []error + resolveExpr := func(obj types.Object, name string, expr ast.Expr) { if _, needed := receivers[obj]; !needed { return } funcTypeMap := asteval.DefaultFunctions(pkg.Types) meta := &asteval.TemplateMetadata{} - ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta) + ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(map[string]any), meta) if err != nil { + resolveErrs = append(resolveErrs, err) return } resolved[obj] = &resolvedTemplate{ @@ -194,7 +194,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{} return true } meta := &asteval.TemplateMetadata{} - ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(template.FuncMap), meta) + ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(map[string]any), meta) if err != nil { return true } @@ -205,7 +205,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{} }) } - return resolved, nil + return resolved, resolveErrs } // checkCalls type-checks each pending ExecuteTemplate call against its @@ -231,9 +231,8 @@ func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types if looked == nil { continue } - treeFinder := (*asteval.Forrest)(rt.templates) - global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions) - if err := Execute(global, looked.Tree, p.dataType); err != nil { + global := NewGlobal(pkg.Types, pkg.Fset, rt.templates, mergedFunctions) + if err := Execute(global, looked.Tree(), p.dataType); err != nil { errs = append(errs, err) } }