From 77c82221795d27fcd345b865bb67a4777de33bb3 Mon Sep 17 00:00:00 2001
From: Christopher Hunter <8398225+crhntr@users.noreply.github.com>
Date: Sat, 7 Feb 2026 11:28:57 -0800
Subject: [PATCH 1/3] feat: add check-template command
---
cmd/check-templates/main.go | 48 ++
cmd/check-templates/main_test.go | 40 ++
.../err_additional_parsefs_missing_field.txt | 43 ++
.../testdata/err_aliased_import.txt | 38 ++
.../testdata/err_closure_missing_field.txt | 38 ++
.../testdata/err_funcs_chain.txt | 42 ++
.../testdata/err_imported_type.txt | 42 ++
.../testdata/err_inline_struct.txt | 37 ++
.../testdata/err_local_var_missing_field.txt | 35 ++
.../testdata/err_missing_field.txt | 38 ++
.../testdata/err_multiple_errors.txt | 49 ++
.../testdata/err_multiple_template_vars.txt | 49 ++
.../testdata/err_nested_template.txt | 41 ++
.../testdata/err_shadowed_var.txt | 48 ++
.../testdata/err_text_template.txt | 38 ++
.../testdata/pass_additional_parsefs.txt | 51 ++
.../pass_additional_parsefs_no_must.txt | 47 ++
.../testdata/pass_aliased_import.txt | 38 ++
cmd/check-templates/testdata/pass_closure.txt | 38 ++
.../testdata/pass_different_identifier.txt | 38 ++
.../testdata/pass_funcs_chain.txt | 40 ++
.../testdata/pass_imported_type.txt | 41 ++
.../testdata/pass_inline_struct.txt | 36 ++
.../testdata/pass_local_var.txt | 35 ++
.../testdata/pass_multiple_calls.txt | 48 ++
.../testdata/pass_multiple_template_vars.txt | 46 ++
.../testdata/pass_nested_template.txt | 40 ++
.../testdata/pass_no_execute_calls.txt | 13 +
.../testdata/pass_non_template_execute.txt | 28 +
.../testdata/pass_parse_call.txt | 29 +
.../testdata/pass_shadowed_var.txt | 47 ++
cmd/check-templates/testdata/pass_simple.txt | 38 ++
.../testdata/pass_text_template.txt | 37 ++
func.go | 10 +-
go.mod | 1 +
go.sum | 2 +
internal/asteval/errors.go | 13 +
internal/asteval/forrest.go | 20 +
internal/asteval/string.go | 30 +
internal/asteval/template.go | 542 ++++++++++++++++++
.../testdata/template/assets_dir.txtar | 23 +
.../testdata/template/bad_embed_pattern.txtar | 16 +
.../asteval/testdata/template/delims.txtar | 22 +
.../asteval/testdata/template/funcs.txtar | 41 ++
.../asteval/testdata/template/parse.txtar | 15 +
.../testdata/template/template_ParseFS.txtar | 38 ++
.../asteval/testdata/template/templates.txtar | 84 +++
internal/astgen/format.go | 23 +
internal/astgen/queries.go | 36 ++
package.go | 301 ++++++++++
50 files changed, 2532 insertions(+), 1 deletion(-)
create mode 100644 cmd/check-templates/main.go
create mode 100644 cmd/check-templates/main_test.go
create mode 100644 cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt
create mode 100644 cmd/check-templates/testdata/err_aliased_import.txt
create mode 100644 cmd/check-templates/testdata/err_closure_missing_field.txt
create mode 100644 cmd/check-templates/testdata/err_funcs_chain.txt
create mode 100644 cmd/check-templates/testdata/err_imported_type.txt
create mode 100644 cmd/check-templates/testdata/err_inline_struct.txt
create mode 100644 cmd/check-templates/testdata/err_local_var_missing_field.txt
create mode 100644 cmd/check-templates/testdata/err_missing_field.txt
create mode 100644 cmd/check-templates/testdata/err_multiple_errors.txt
create mode 100644 cmd/check-templates/testdata/err_multiple_template_vars.txt
create mode 100644 cmd/check-templates/testdata/err_nested_template.txt
create mode 100644 cmd/check-templates/testdata/err_shadowed_var.txt
create mode 100644 cmd/check-templates/testdata/err_text_template.txt
create mode 100644 cmd/check-templates/testdata/pass_additional_parsefs.txt
create mode 100644 cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt
create mode 100644 cmd/check-templates/testdata/pass_aliased_import.txt
create mode 100644 cmd/check-templates/testdata/pass_closure.txt
create mode 100644 cmd/check-templates/testdata/pass_different_identifier.txt
create mode 100644 cmd/check-templates/testdata/pass_funcs_chain.txt
create mode 100644 cmd/check-templates/testdata/pass_imported_type.txt
create mode 100644 cmd/check-templates/testdata/pass_inline_struct.txt
create mode 100644 cmd/check-templates/testdata/pass_local_var.txt
create mode 100644 cmd/check-templates/testdata/pass_multiple_calls.txt
create mode 100644 cmd/check-templates/testdata/pass_multiple_template_vars.txt
create mode 100644 cmd/check-templates/testdata/pass_nested_template.txt
create mode 100644 cmd/check-templates/testdata/pass_no_execute_calls.txt
create mode 100644 cmd/check-templates/testdata/pass_non_template_execute.txt
create mode 100644 cmd/check-templates/testdata/pass_parse_call.txt
create mode 100644 cmd/check-templates/testdata/pass_shadowed_var.txt
create mode 100644 cmd/check-templates/testdata/pass_simple.txt
create mode 100644 cmd/check-templates/testdata/pass_text_template.txt
create mode 100644 internal/asteval/errors.go
create mode 100644 internal/asteval/forrest.go
create mode 100644 internal/asteval/string.go
create mode 100644 internal/asteval/template.go
create mode 100644 internal/asteval/testdata/template/assets_dir.txtar
create mode 100644 internal/asteval/testdata/template/bad_embed_pattern.txtar
create mode 100644 internal/asteval/testdata/template/delims.txtar
create mode 100644 internal/asteval/testdata/template/funcs.txtar
create mode 100644 internal/asteval/testdata/template/parse.txtar
create mode 100644 internal/asteval/testdata/template/template_ParseFS.txtar
create mode 100644 internal/asteval/testdata/template/templates.txtar
create mode 100644 internal/astgen/format.go
create mode 100644 internal/astgen/queries.go
create mode 100644 package.go
diff --git a/cmd/check-templates/main.go b/cmd/check-templates/main.go
new file mode 100644
index 0000000..1a2bd82
--- /dev/null
+++ b/cmd/check-templates/main.go
@@ -0,0 +1,48 @@
+package main
+
+import (
+ "fmt"
+ "go/token"
+ "io"
+ "os"
+
+ "golang.org/x/tools/go/packages"
+
+ "github.com/typelate/check"
+)
+
+func main() {
+ os.Exit(run(os.Args[1:], os.Stdout, os.Stderr))
+}
+
+func run(args []string, stdout, stderr io.Writer) int {
+ dir := "."
+ if len(args) > 0 {
+ dir = args[0]
+ }
+
+ fset := token.NewFileSet()
+ pkgs, err := packages.Load(&packages.Config{
+ Fset: fset,
+ Mode: packages.NeedTypesInfo | packages.NeedName | packages.NeedFiles |
+ packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns |
+ packages.NeedEmbedFiles | packages.NeedImports | packages.NeedModule,
+ Dir: dir,
+ }, dir)
+ if err != nil {
+ fmt.Fprintf(stderr, "failed to load packages: %v\n", err)
+ return 1
+ }
+ exitCode := 0
+ for _, pkg := range pkgs {
+ for _, e := range pkg.Errors {
+ fmt.Fprintln(stderr, e)
+ exitCode = 1
+ }
+ if err := check.Package(pkg); err != nil {
+ fmt.Fprintln(stderr, err)
+ exitCode = 1
+ }
+ }
+ return exitCode
+}
diff --git a/cmd/check-templates/main_test.go b/cmd/check-templates/main_test.go
new file mode 100644
index 0000000..acc3f24
--- /dev/null
+++ b/cmd/check-templates/main_test.go
@@ -0,0 +1,40 @@
+package main
+
+import (
+ "bytes"
+ "path/filepath"
+ "testing"
+
+ "rsc.io/script"
+ "rsc.io/script/scripttest"
+)
+
+func Test(t *testing.T) {
+ e := script.NewEngine()
+ e.Quiet = true
+ e.Cmds = scripttest.DefaultCmds()
+ e.Cmds["check-templates"] = checkTemplatesCommand()
+ ctx := t.Context()
+ scripttest.Test(t, ctx, e, nil, filepath.FromSlash("testdata/*.txt"))
+}
+
+func checkTemplatesCommand() script.Cmd {
+ return script.Command(script.CmdUsage{
+ Summary: "check-templates [dir]",
+ Args: "[dir]",
+ }, func(state *script.State, args ...string) (script.WaitFunc, error) {
+ return func(state *script.State) (string, string, error) {
+ var stdout, stderr bytes.Buffer
+ cmdArgs := args
+ if len(cmdArgs) == 0 {
+ cmdArgs = []string{state.Getwd()}
+ }
+ code := run(cmdArgs, &stdout, &stderr)
+ var err error
+ if code != 0 {
+ err = script.ErrUsage
+ }
+ return stdout.String(), stderr.String(), err
+ }, nil
+ })
+}
diff --git a/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt
new file mode 100644
index 0000000..100c55e
--- /dev/null
+++ b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt
@@ -0,0 +1,43 @@
+# Template defined at top level, modified in a function. The additional template
+# has a field that doesn't exist on the data type.
+
+! check-templates
+stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type Page struct {
+ Title string
+}
+
+var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml"))
+
+func setup() {
+ template.Must(ts.ParseFS(source, "about.gohtml"))
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "about.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+
{{.Title}}
+-- about.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_aliased_import.txt b/cmd/check-templates/testdata/err_aliased_import.txt
new file mode 100644
index 0000000..cf4adc8
--- /dev/null
+++ b/cmd/check-templates/testdata/err_aliased_import.txt
@@ -0,0 +1,38 @@
+# Template import with a non-standard alias should still detect errors.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ htmltpl "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = htmltpl.Must(htmltpl.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_closure_missing_field.txt b/cmd/check-templates/testdata/err_closure_missing_field.txt
new file mode 100644
index 0000000..3c01045
--- /dev/null
+++ b/cmd/check-templates/testdata/err_closure_missing_field.txt
@@ -0,0 +1,38 @@
+# Template parsed in outer function, closure calls ExecuteTemplate with missing field.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type Page struct {
+ Title string
+}
+
+func routes(mux *http.ServeMux) {
+ ts := template.Must(template.ParseFS(source, "*"))
+
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+ })
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_funcs_chain.txt b/cmd/check-templates/testdata/err_funcs_chain.txt
new file mode 100644
index 0000000..c14f893
--- /dev/null
+++ b/cmd/check-templates/testdata/err_funcs_chain.txt
@@ -0,0 +1,42 @@
+# Template constructed with Funcs() chained before ParseFS should still
+# report errors for missing fields.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+ "strings"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.New("").Funcs(template.FuncMap{
+ "upper": strings.ToUpper,
+ }).ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing | upper}}
diff --git a/cmd/check-templates/testdata/err_imported_type.txt b/cmd/check-templates/testdata/err_imported_type.txt
new file mode 100644
index 0000000..a9a3a60
--- /dev/null
+++ b/cmd/check-templates/testdata/err_imported_type.txt
@@ -0,0 +1,42 @@
+# Types imported from other packages should report errors for missing fields.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app/internal/model\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- internal/model/types.go --
+package model
+
+type Page struct {
+ Title string
+}
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+
+ "example.com/app/internal/model"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_inline_struct.txt b/cmd/check-templates/testdata/err_inline_struct.txt
new file mode 100644
index 0000000..bb5fe87
--- /dev/null
+++ b/cmd/check-templates/testdata/err_inline_struct.txt
@@ -0,0 +1,37 @@
+# Inline anonymous struct types should report errors for missing fields.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on struct\{Title string\}'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+func render() {
+ var data struct {
+ Title string
+ }
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data)
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_local_var_missing_field.txt b/cmd/check-templates/testdata/err_local_var_missing_field.txt
new file mode 100644
index 0000000..bde48d4
--- /dev/null
+++ b/cmd/check-templates/testdata/err_local_var_missing_field.txt
@@ -0,0 +1,35 @@
+# Template defined as local variable reports errors for missing fields.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ ts := template.Must(template.ParseFS(source, "*"))
+ _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_missing_field.txt b/cmd/check-templates/testdata/err_missing_field.txt
new file mode 100644
index 0000000..c5b9e13
--- /dev/null
+++ b/cmd/check-templates/testdata/err_missing_field.txt
@@ -0,0 +1,38 @@
+# Template check fails when a field does not exist on the data type.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_multiple_errors.txt b/cmd/check-templates/testdata/err_multiple_errors.txt
new file mode 100644
index 0000000..63e1a5f
--- /dev/null
+++ b/cmd/check-templates/testdata/err_multiple_errors.txt
@@ -0,0 +1,49 @@
+# Multiple ExecuteTemplate calls can each report errors.
+
+! check-templates
+stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.IndexPage'
+stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Unknown>: Unknown not found on example\.com/app\.AboutPage'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"})
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Missing}}
+-- about.gohtml --
+{{.Unknown}}
diff --git a/cmd/check-templates/testdata/err_multiple_template_vars.txt b/cmd/check-templates/testdata/err_multiple_template_vars.txt
new file mode 100644
index 0000000..7dd7062
--- /dev/null
+++ b/cmd/check-templates/testdata/err_multiple_template_vars.txt
@@ -0,0 +1,49 @@
+# Multiple template variables must not cross-contaminate: only the incorrect
+# call should produce an error.
+
+! check-templates
+stderr 'type check failed:.*about\.gohtml:1:2: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage'
+! stderr 'Title not found'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+var templatesA = template.Must(template.ParseFS(source, "index.gohtml"))
+var templatesB = template.Must(template.ParseFS(source, "about.gohtml"))
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+func renderIndex() {
+ _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{})
+}
+
+func renderAbout() {
+ _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", IndexPage{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/err_nested_template.txt b/cmd/check-templates/testdata/err_nested_template.txt
new file mode 100644
index 0000000..de654b4
--- /dev/null
+++ b/cmd/check-templates/testdata/err_nested_template.txt
@@ -0,0 +1,41 @@
+# A nested template call should check the invoked template against the data type.
+
+! check-templates
+stderr 'type check failed:.*header\.gohtml:1:6: executing "header\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{template "header.gohtml" .}}
+body
+-- header.gohtml --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/err_shadowed_var.txt b/cmd/check-templates/testdata/err_shadowed_var.txt
new file mode 100644
index 0000000..479f1e8
--- /dev/null
+++ b/cmd/check-templates/testdata/err_shadowed_var.txt
@@ -0,0 +1,48 @@
+# A shadowed template variable must not confuse resolution: the inner scope
+# has its own template, and each ExecuteTemplate call is checked against the
+# correct definition. The outer ts parses index.gohtml (which uses .Title),
+# and the inner ts parses about.gohtml (which uses .Name). Passing the wrong
+# data type to the inner call should produce an error only for that call.
+
+! check-templates
+stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage'
+! stderr 'Title not found'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type IndexPage struct {
+ Title string
+}
+
+var ts = template.Must(template.ParseFS(source, "index.gohtml"))
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"})
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ ts := template.Must(template.ParseFS(source, "about.gohtml"))
+ _ = ts.ExecuteTemplate(w, "about.gohtml", IndexPage{Title: "oops"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/err_text_template.txt b/cmd/check-templates/testdata/err_text_template.txt
new file mode 100644
index 0000000..7e6e187
--- /dev/null
+++ b/cmd/check-templates/testdata/err_text_template.txt
@@ -0,0 +1,38 @@
+# text/template should be checked the same as html/template.
+
+! check-templates
+stderr 'type check failed:.*index\.gotmpl:1:2: executing "index\.gotmpl" at <\.Missing>: Missing not found on example\.com/app\.Page'
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "io"
+ "text/template"
+)
+
+var (
+ //go:embed *.gotmpl
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gotmpl --
+{{.Missing}}
diff --git a/cmd/check-templates/testdata/pass_additional_parsefs.txt b/cmd/check-templates/testdata/pass_additional_parsefs.txt
new file mode 100644
index 0000000..50059f6
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_additional_parsefs.txt
@@ -0,0 +1,51 @@
+# Template defined at top level, then has more templates parsed inside a function.
+# Both sets of templates should be available.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml"))
+
+func setup() {
+ template.Must(ts.ParseFS(source, "about.gohtml"))
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"})
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt
new file mode 100644
index 0000000..09975df
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt
@@ -0,0 +1,47 @@
+# Template defined at top level, modified in a function without template.Must.
+# The ParseFS call result is discarded but should still be detected.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml"))
+
+func setup() {
+ _, _ = ts.ParseFS(source, "about.gohtml")
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/pass_aliased_import.txt b/cmd/check-templates/testdata/pass_aliased_import.txt
new file mode 100644
index 0000000..96a9837
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_aliased_import.txt
@@ -0,0 +1,38 @@
+# Template import with a non-standard alias should still be resolved.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ htmltpl "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = htmltpl.Must(htmltpl.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_closure.txt b/cmd/check-templates/testdata/pass_closure.txt
new file mode 100644
index 0000000..02f97d3
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_closure.txt
@@ -0,0 +1,38 @@
+# Template parsed in outer function, closures call ExecuteTemplate.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type Page struct {
+ Title string
+}
+
+func routes(mux *http.ServeMux) {
+ ts := template.Must(template.ParseFS(source, "*"))
+
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+ })
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_different_identifier.txt b/cmd/check-templates/testdata/pass_different_identifier.txt
new file mode 100644
index 0000000..be176ca
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_different_identifier.txt
@@ -0,0 +1,38 @@
+# Template variable with a non-standard identifier name.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ tmpl = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = tmpl.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_funcs_chain.txt b/cmd/check-templates/testdata/pass_funcs_chain.txt
new file mode 100644
index 0000000..3a2dc12
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_funcs_chain.txt
@@ -0,0 +1,40 @@
+# Template constructed with Funcs() chained before ParseFS should be resolved.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+ "strings"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.New("").Funcs(template.FuncMap{
+ "upper": strings.ToUpper,
+ }).ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title | upper}}
diff --git a/cmd/check-templates/testdata/pass_imported_type.txt b/cmd/check-templates/testdata/pass_imported_type.txt
new file mode 100644
index 0000000..b6770ea
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_imported_type.txt
@@ -0,0 +1,41 @@
+# Types imported from other packages should be resolved correctly.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- internal/model/types.go --
+package model
+
+type Page struct {
+ Title string
+}
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+
+ "example.com/app/internal/model"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_inline_struct.txt b/cmd/check-templates/testdata/pass_inline_struct.txt
new file mode 100644
index 0000000..cfecd07
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_inline_struct.txt
@@ -0,0 +1,36 @@
+# Inline anonymous struct types should be resolved correctly.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+func render() {
+ var data struct {
+ Title string
+ }
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data)
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_local_var.txt b/cmd/check-templates/testdata/pass_local_var.txt
new file mode 100644
index 0000000..dfa6fd3
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_local_var.txt
@@ -0,0 +1,35 @@
+# Template defined as a local variable in the same function as ExecuteTemplate.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ ts := template.Must(template.ParseFS(source, "*"))
+ _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+type Page struct {
+ Title string
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_multiple_calls.txt b/cmd/check-templates/testdata/pass_multiple_calls.txt
new file mode 100644
index 0000000..36573ba
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_multiple_calls.txt
@@ -0,0 +1,48 @@
+# Multiple ExecuteTemplate calls with different types all pass.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"})
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/pass_multiple_template_vars.txt b/cmd/check-templates/testdata/pass_multiple_template_vars.txt
new file mode 100644
index 0000000..ad95b96
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_multiple_template_vars.txt
@@ -0,0 +1,46 @@
+# Multiple template variables should each resolve independently.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+var templatesA = template.Must(template.ParseFS(source, "index.gohtml"))
+var templatesB = template.Must(template.ParseFS(source, "about.gohtml"))
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+func renderIndex() {
+ _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{})
+}
+
+func renderAbout() {
+ _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", AboutPage{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/pass_nested_template.txt b/cmd/check-templates/testdata/pass_nested_template.txt
new file mode 100644
index 0000000..e6c04e0
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_nested_template.txt
@@ -0,0 +1,40 @@
+# Nested template calls should propagate the data type correctly.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "io"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{template "header.gohtml" .}}
+body
+-- header.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_no_execute_calls.txt b/cmd/check-templates/testdata/pass_no_execute_calls.txt
new file mode 100644
index 0000000..000f01f
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_no_execute_calls.txt
@@ -0,0 +1,13 @@
+# Package with no ExecuteTemplate calls passes without error.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+func main() {}
diff --git a/cmd/check-templates/testdata/pass_non_template_execute.txt b/cmd/check-templates/testdata/pass_non_template_execute.txt
new file mode 100644
index 0000000..b7a572e
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_non_template_execute.txt
@@ -0,0 +1,28 @@
+# A custom type with an ExecuteTemplate method should not be confused
+# with html/template. The tool should not report errors for it.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "io"
+ "net/http"
+)
+
+type Renderer struct{}
+
+func (r *Renderer) ExecuteTemplate(w io.Writer, name string, data any) error {
+ return nil
+}
+
+func handle(w http.ResponseWriter, r *http.Request) {
+ renderer := &Renderer{}
+ _ = renderer.ExecuteTemplate(w, "index.gohtml", struct{ Missing string }{})
+}
diff --git a/cmd/check-templates/testdata/pass_parse_call.txt b/cmd/check-templates/testdata/pass_parse_call.txt
new file mode 100644
index 0000000..d1ea925
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_parse_call.txt
@@ -0,0 +1,29 @@
+# Template constructed with Parse (not ParseFS) passes check.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var templates = template.Must(template.New("greeting").Parse(`Hello, {{.Name}}!
`))
+
+type Greeting struct {
+ Name string
+}
+
+func handle(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "greeting", Greeting{Name: "World"})
+}
+
+var _ = fmt.Sprint
diff --git a/cmd/check-templates/testdata/pass_shadowed_var.txt b/cmd/check-templates/testdata/pass_shadowed_var.txt
new file mode 100644
index 0000000..213415c
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_shadowed_var.txt
@@ -0,0 +1,47 @@
+# A shadowed template variable should resolve to the correct definition at each call site.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+//go:embed *.gohtml
+var source embed.FS
+
+type IndexPage struct {
+ Title string
+}
+
+type AboutPage struct {
+ Name string
+}
+
+var ts = template.Must(template.ParseFS(source, "index.gohtml"))
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"})
+}
+
+func handleAbout(w http.ResponseWriter, r *http.Request) {
+ ts := template.Must(template.ParseFS(source, "about.gohtml"))
+ _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
+-- about.gohtml --
+{{.Name}}
diff --git a/cmd/check-templates/testdata/pass_simple.txt b/cmd/check-templates/testdata/pass_simple.txt
new file mode 100644
index 0000000..de22050
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_simple.txt
@@ -0,0 +1,38 @@
+# Simple template check passes when fields match.
+
+check-templates
+! stderr .
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "html/template"
+ "net/http"
+)
+
+var (
+ //go:embed *.gohtml
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func handleIndex(w http.ResponseWriter, r *http.Request) {
+ _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gohtml --
+{{.Title}}
diff --git a/cmd/check-templates/testdata/pass_text_template.txt b/cmd/check-templates/testdata/pass_text_template.txt
new file mode 100644
index 0000000..cbf8999
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_text_template.txt
@@ -0,0 +1,37 @@
+# text/template should be checked the same as html/template.
+
+check-templates
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import (
+ "embed"
+ "fmt"
+ "io"
+ "text/template"
+)
+
+var (
+ //go:embed *.gotmpl
+ source embed.FS
+
+ templates = template.Must(template.ParseFS(source, "*"))
+)
+
+type Page struct {
+ Title string
+}
+
+func render() {
+ _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"})
+}
+
+var _ = fmt.Sprint
+
+-- index.gotmpl --
+{{.Title}}
diff --git a/func.go b/func.go
index 72e2b6f..80c51c8 100644
--- a/func.go
+++ b/func.go
@@ -31,7 +31,15 @@ func DefaultFunctions(pkg *types.Package) Functions {
} {
if p, ok := findPackage(pkg, pn); ok && p != nil {
for funcIdent, templateFunc := range idents {
- fns[templateFunc] = p.Scope().Lookup(funcIdent).Type().(*types.Signature)
+ obj := p.Scope().Lookup(funcIdent)
+ if obj == nil {
+ continue
+ }
+ sig, ok := obj.Type().(*types.Signature)
+ if !ok {
+ continue
+ }
+ fns[templateFunc] = sig
}
}
}
diff --git a/go.mod b/go.mod
index 7d0808a..bad63f5 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/stretchr/testify v1.11.1
golang.org/x/tools v0.41.0
+ rsc.io/script v0.0.2
)
require (
diff --git a/go.sum b/go.sum
index ddb0899..7e5240f 100644
--- a/go.sum
+++ b/go.sum
@@ -16,3 +16,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+rsc.io/script v0.0.2 h1:eYoG7A3GFC3z1pRx3A2+s/vZ9LA8cxojHyCvslnj4RI=
+rsc.io/script v0.0.2/go.mod h1:cKBjCtFBBeZ0cbYFRXkRoxP+xGqhArPa9t3VWhtXfzU=
diff --git a/internal/asteval/errors.go b/internal/asteval/errors.go
new file mode 100644
index 0000000..fb2c83e
--- /dev/null
+++ b/internal/asteval/errors.go
@@ -0,0 +1,13 @@
+package asteval
+
+import (
+ "fmt"
+ "go/token"
+ "path/filepath"
+)
+
+func wrapWithFilename(workingDirectory string, set *token.FileSet, pos token.Pos, err error) error {
+ p := set.Position(pos)
+ p.Filename, _ = filepath.Rel(workingDirectory, p.Filename)
+ return fmt.Errorf("%s: %w", p, err)
+}
diff --git a/internal/asteval/forrest.go b/internal/asteval/forrest.go
new file mode 100644
index 0000000..a3083bb
--- /dev/null
+++ b/internal/asteval/forrest.go
@@ -0,0 +1,20 @@
+package asteval
+
+import (
+ "html/template"
+ "text/template/parse"
+)
+
+type Forrest template.Template
+
+func NewForrest(templates *template.Template) *Forrest {
+ return (*Forrest)(templates)
+}
+
+func (f *Forrest) FindTree(name string) (*parse.Tree, bool) {
+ ts := (*template.Template)(f).Lookup(name)
+ if ts == nil {
+ return nil, false
+ }
+ return ts.Tree, true
+}
diff --git a/internal/asteval/string.go b/internal/asteval/string.go
new file mode 100644
index 0000000..ba8b813
--- /dev/null
+++ b/internal/asteval/string.go
@@ -0,0 +1,30 @@
+package asteval
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strconv"
+
+ "github.com/typelate/check/internal/astgen"
+)
+
+func StringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr) (string, error) {
+ arg, ok := exp.(*ast.BasicLit)
+ if !ok || arg.Kind != token.STRING {
+ return "", wrapWithFilename(wd, set, exp.Pos(), fmt.Errorf("expected string literal got %s", astgen.Format(exp)))
+ }
+ return strconv.Unquote(arg.Value)
+}
+
+func StringLiteralExpressionList(wd string, set *token.FileSet, list []ast.Expr) ([]string, error) {
+ result := make([]string, 0, len(list))
+ for _, a := range list {
+ s, err := StringLiteralExpression(wd, set, a)
+ if err != nil {
+ return result, err
+ }
+ result = append(result, s)
+ }
+ return result, nil
+}
diff --git a/internal/asteval/template.go b/internal/asteval/template.go
new file mode 100644
index 0000000..bdfe916
--- /dev/null
+++ b/internal/asteval/template.go
@@ -0,0 +1,542 @@
+package asteval
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/token"
+ "go/types"
+ "html/template"
+ "os"
+ "path/filepath"
+ "slices"
+ "strconv"
+ "strings"
+ "text/template/parse"
+ "unicode"
+
+ "github.com/typelate/check/internal/astgen"
+)
+
+// TemplateMetadata accumulates metadata during template evaluation.
+type TemplateMetadata struct {
+ EmbedFilePaths []string
+ ParseCalls []*ast.BasicLit
+}
+
+func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm template.FuncMap, meta *TemplateMetadata) (*template.Template, string, string, error) {
+ call, ok := expression.(*ast.CallExpr)
+ if !ok {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression"))
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unexpected expression %T: %s", call.Fun, astgen.Format(call.Fun)))
+ }
+ switch x := sel.X.(type) {
+ default:
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args)))
+ case *ast.Ident:
+ if !isTemplatePkgIdent(typesInfo, x) {
+ if ts == nil {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected template package got %s", astgen.Format(sel.X)))
+ }
+ // Variable receiver — apply method to existing template.
+ switch sel.Sel.Name {
+ case "ParseFS":
+ filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ if meta != nil {
+ meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
+ }
+ t, err := parseFiles(ts, fm, lDelim, rDelim, filePaths...)
+ return t, lDelim, rDelim, err
+ case "Parse":
+ if len(call.Args) != 1 {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument"))
+ }
+ if meta != nil {
+ if bl, ok := call.Args[0].(*ast.BasicLit); ok {
+ meta.ParseCalls = append(meta.ParseCalls, bl)
+ }
+ }
+ sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0])
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ t, err := ts.Parse(sl)
+ return t, lDelim, rDelim, err
+ case "Funcs":
+ if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ return ts.Funcs(fm), lDelim, rDelim, nil
+ case "Option":
+ list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ return ts.Option(list...), lDelim, rDelim, nil
+ case "Delims":
+ if len(call.Args) != 2 {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments"))
+ }
+ list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ return ts.Delims(list[0], list[1]), list[0], list[1], nil
+ default:
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s on variable receiver", sel.Sel.Name))
+ }
+ }
+ switch sel.Sel.Name {
+ case "Must":
+ if len(call.Args) != 1 {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args)))
+ }
+ return EvaluateTemplateSelector(ts, pkg, typesInfo, call.Args[0], workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta)
+ case "New":
+ if len(call.Args) != 1 {
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument"))
+ }
+ templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ return template.New(templateNames[0]), lDelim, rDelim, nil
+ case "ParseFS":
+ filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ if meta != nil {
+ meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
+ }
+ t, err := parseFiles(nil, fm, lDelim, rDelim, filePaths...)
+ return t, lDelim, rDelim, err
+ default:
+ return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name))
+ }
+ case *ast.CallExpr:
+ up, upLDelim, upRDelim, err := EvaluateTemplateSelector(ts, pkg, typesInfo, sel.X, workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta)
+ if err != nil {
+ return nil, lDelim, rDelim, err
+ }
+ switch sel.Sel.Name {
+ case "Delims":
+ if len(call.Args) != 2 {
+ return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments"))
+ }
+ list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ return up.Delims(list[0], list[1]), list[0], list[1], nil
+ case "Parse":
+ if len(call.Args) != 1 {
+ return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument"))
+ }
+ if meta != nil {
+ if bl, ok := call.Args[0].(*ast.BasicLit); ok {
+ meta.ParseCalls = append(meta.ParseCalls, bl)
+ }
+ }
+ sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0])
+ if err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ t, err := up.Parse(sl)
+ return t, upLDelim, upRDelim, err
+ case "New":
+ if len(call.Args) != 1 {
+ return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument"))
+ }
+ templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ return up.New(templateNames[0]), upLDelim, upRDelim, nil
+ case "ParseFS":
+ filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths)
+ if err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ if meta != nil {
+ meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
+ }
+ t, err := parseFiles(up, fm, upLDelim, upRDelim, filePaths...)
+ return t, upLDelim, upRDelim, err
+ case "Option":
+ list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
+ if err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ return up.Option(list...), upLDelim, upRDelim, nil
+ case "Funcs":
+ if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil {
+ return nil, upLDelim, upRDelim, err
+ }
+ return up.Funcs(fm), upLDelim, upRDelim, nil
+ default:
+ return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s", sel.Sel.Name))
+ }
+ }
+}
+
+// isTemplatePkgIdent reports whether ident refers to the "html/template"
+// or "text/template" package via the type checker.
+func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool {
+ if info == nil {
+ return false
+ }
+ obj := info.Uses[ident]
+ pkgName, ok := obj.(*types.PkgName)
+ if !ok {
+ return false
+ }
+ path := pkgName.Imported().Path()
+ return path == "html/template" || path == "text/template"
+}
+
+func builtins() template.FuncMap {
+ type nothing struct{}
+ return template.FuncMap{
+ "and": func() (_ nothing) { return },
+ "call": func() (_ nothing) { return },
+ "html": func() (_ nothing) { return },
+ "index": func() (_ nothing) { return },
+ "slice": func() (_ nothing) { return },
+ "js": func() (_ nothing) { return },
+ "len": func() (_ nothing) { return },
+ "not": func() (_ nothing) { return },
+ "or": func() (_ nothing) { return },
+ "print": func() (_ nothing) { return },
+ "printf": func() (_ nothing) { return },
+ "println": func() (_ nothing) { return },
+ "urlquery": func() (_ nothing) { return },
+
+ // Comparisons
+ "eq": func() (_ nothing) { return },
+ "ge": func() (_ nothing) { return },
+ "gt": func() (_ nothing) { return },
+ "le": func() (_ nothing) { return },
+ "lt": func() (_ nothing) { return },
+ "ne": func() (_ nothing) { return },
+ }
+}
+
+func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim string, filenames ...string) (*template.Template, error) {
+ if len(filenames) == 0 {
+ return nil, fmt.Errorf("html/template: no files named in call to ParseFiles")
+ }
+ for _, filename := range filenames {
+ templateName := filepath.Base(filename)
+ b, err := os.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ s := string(b)
+ var tmpl *template.Template
+ if t == nil {
+ t = template.New(templateName)
+ }
+ if templateName == t.Name() {
+ tmpl = t
+ } else {
+ tmpl = t.New(templateName)
+ }
+ trees, err := parse.Parse(templateName, s, leftDelim, rightDelim, fm, builtins())
+ if err != nil {
+ return nil, err
+ }
+ absoluteFilename, err := filepath.Abs(filename)
+ if err != nil {
+ return nil, err
+ }
+ for _, tree := range trees {
+ tree.ParseName = absoluteFilename
+ if _, err = tmpl.AddParseTree(tree.Name, tree); err != nil {
+ return nil, err
+ }
+ }
+ }
+ return t, nil
+}
+
+func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFunctions) error {
+ if len(call.Args) != 1 {
+ return wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument"))
+ }
+ arg := call.Args[0]
+ lit, ok := arg.(*ast.CompositeLit)
+ if !ok {
+ return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a template.FuncMap composite literal got %s", astgen.Format(arg)))
+ }
+ if typesInfo != nil {
+ litType := typesInfo.TypeOf(lit)
+ if named, ok := litType.(*types.Named); ok {
+ obj := named.Obj()
+ if obj.Pkg() == nil || obj.Name() != "FuncMap" {
+ return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType))
+ }
+ path := obj.Pkg().Path()
+ if path != "html/template" && path != "text/template" {
+ return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType))
+ }
+ }
+ }
+ var buf bytes.Buffer
+ for i, exp := range lit.Elts {
+ el, ok := exp.(*ast.KeyValueExpr)
+ if !ok {
+ return wrapWithFilename(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, astgen.Format(exp)))
+ }
+ funcName, err := StringLiteralExpression(workingDirectory, fileSet, el.Key)
+ if err != nil {
+ return err
+ }
+
+ // template.Parse does not evaluate the function signature parameters;
+ // it ensures the function name is in scope and there is one or two results.
+ // we could use something like func() string { return "" } for this signature
+ // but this function from fmt works just fine.
+ //
+ // to explore the known requirements run:
+ // fm[funcName] = nil // will fail because nil does not have `reflect.Kind` Func
+ // or
+ // fm[funcName] = func() {} // will fail because there are no results
+ // or
+ // fm[funcName] = func() (int, int) {return 0, 0} // will fail because the second result is not an error
+ fm[funcName] = fmt.Sprintln
+
+ if pkg == nil {
+ continue
+ }
+ buf.Reset()
+ if err := format.Node(&buf, fileSet, el.Value); err != nil {
+ return err
+ }
+ tv, err := types.Eval(fileSet, pkg, lit.Pos(), buf.String())
+ if err != nil {
+ return err
+ }
+ funcTypesMap[funcName] = tv.Type.(*types.Signature)
+ }
+ return nil
+}
+
+func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet, call *ast.CallExpr, files []*ast.File, embeddedPaths []string) ([]string, error) {
+ if len(call.Args) < 1 {
+ return nil, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("missing required arguments"))
+ }
+ matches, err := embedFSFilePaths(workingDirectory, fileSet, files, call.Args[0], embeddedPaths)
+ if err != nil {
+ return nil, err
+ }
+ templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args[1:])
+ if err != nil {
+ return nil, err
+ }
+ filtered := matches[:0]
+ for _, ef := range matches {
+ for j, pattern := range templateNames {
+ match, err := filepath.Match(pattern, ef)
+ if err != nil {
+ return nil, wrapWithFilename(workingDirectory, fileSet, call.Args[j+1].Pos(), fmt.Errorf("bad pattern %q: %w", pattern, err))
+ }
+ if !match {
+ continue
+ }
+ filtered = append(filtered, ef)
+ break
+ }
+ }
+ return joinFilePaths(workingDirectory, filtered...), nil
+}
+
+func embedFSFilePaths(dir string, fileSet *token.FileSet, files []*ast.File, exp ast.Expr, embeddedFiles []string) ([]string, error) {
+ varIdent, ok := exp.(*ast.Ident)
+ if !ok {
+ return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("first argument to ParseFS must be an identifier"))
+ }
+ for _, decl := range astgen.IterateGenDecl(files, token.VAR) {
+ for _, s := range decl.Specs {
+ spec, ok := s.(*ast.ValueSpec)
+ if !ok || !slices.ContainsFunc(spec.Names, func(e *ast.Ident) bool { return e.Name == varIdent.Name }) {
+ continue
+ }
+ var comment strings.Builder
+ commentNode := readComments(&comment, decl.Doc, spec.Doc)
+ templateNames := parseTemplateNames(comment.String())
+ absMat, err := embeddedFilesMatchingTemplateNameList(dir, fileSet, commentNode, templateNames, embeddedFiles)
+ if err != nil {
+ return nil, err
+ }
+ return absMat, nil
+ }
+ }
+ return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("variable %s not found", varIdent))
+}
+
+func embeddedFilesMatchingTemplateNameList(dir string, set *token.FileSet, comment ast.Node, templateNames, embeddedFiles []string) ([]string, error) {
+ var matches []string
+ for _, fp := range embeddedFiles {
+ for _, pattern := range templateNames {
+ pat := filepath.FromSlash(pattern)
+ if !strings.ContainsAny(pat, "*[]") {
+ prefix := filepath.FromSlash(pat + "/")
+ if strings.HasPrefix(fp, prefix) {
+ matches = append(matches, fp)
+ continue
+ }
+ }
+ if matched, err := filepath.Match(pat, fp); err != nil {
+ return nil, wrapWithFilename(dir, set, comment.Pos(), fmt.Errorf("embed comment malformed: %w", err))
+ } else if matched {
+ matches = append(matches, fp)
+ }
+ }
+ }
+ return slices.Clip(matches), nil
+}
+
+const goEmbedCommentPrefix = "//go:embed"
+
+func readComments(s *strings.Builder, groups ...*ast.CommentGroup) ast.Node {
+ var n ast.Node
+ for _, c := range groups {
+ if c == nil {
+ continue
+ }
+ for _, line := range c.List {
+ if !strings.HasPrefix(line.Text, goEmbedCommentPrefix) {
+ continue
+ }
+ s.WriteString(strings.TrimSpace(strings.TrimPrefix(line.Text, goEmbedCommentPrefix)))
+ s.WriteByte(' ')
+ }
+ n = c
+ break
+ }
+ return n
+}
+
+func parseTemplateNames(input string) []string {
+ // todo: refactor to use strconv.QuotedPrefix
+ var (
+ templateNames []string
+ currentTemplateName strings.Builder
+ inQuote = false
+ quoteChar rune
+ )
+
+ for _, r := range input {
+ switch {
+ case r == '"' || r == '`':
+ if !inQuote {
+ inQuote = true
+ quoteChar = r
+ continue
+ }
+ if r != quoteChar {
+ currentTemplateName.WriteRune(r)
+ continue
+ }
+ templateNames = append(templateNames, currentTemplateName.String())
+ currentTemplateName.Reset()
+ inQuote = false
+ case unicode.IsSpace(r):
+ if inQuote {
+ currentTemplateName.WriteRune(r)
+ continue
+ }
+ if currentTemplateName.Len() > 0 {
+ templateNames = append(templateNames, currentTemplateName.String())
+ currentTemplateName.Reset()
+ }
+ default:
+ currentTemplateName.WriteRune(r)
+ }
+ }
+
+ // Import any remaining pattern
+ if currentTemplateName.Len() > 0 {
+ templateNames = append(templateNames, currentTemplateName.String())
+ }
+
+ return templateNames
+}
+
+func joinFilePaths(wd string, rel ...string) []string {
+ result := slices.Clone(rel)
+ for i := range result {
+ result[i] = filepath.Join(wd, result[i])
+ }
+ return result
+}
+
+func RelativeFilePaths(wd string, abs ...string) ([]string, error) {
+ result := slices.Clone(abs)
+ for i, p := range result {
+ r, err := filepath.Rel(wd, p)
+ if err != nil {
+ return nil, err
+ }
+ result[i] = r
+ }
+ return result, nil
+}
+
+type TemplateFunctions map[string]*types.Signature
+
+func DefaultFunctions(pkg *types.Package) TemplateFunctions {
+ funcTypeMap := make(TemplateFunctions)
+ fmtPkg, ok := findPackage(pkg, "fmt")
+ if !ok || fmtPkg == nil {
+ return funcTypeMap
+ }
+ funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature)
+ funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature)
+ funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature)
+ return funcTypeMap
+}
+
+func findPackage(pkg *types.Package, path string) (*types.Package, bool) {
+ if pkg == nil || pkg.Path() == path {
+ return pkg, true
+ }
+ for _, im := range pkg.Imports() {
+ if p, ok := findPackage(im, path); ok {
+ return p, true
+ }
+ }
+ return nil, false
+}
+
+func (functions TemplateFunctions) FindFunction(name string) (*types.Signature, bool) {
+ m := (map[string]*types.Signature)(functions)
+ fn, ok := m[name]
+ if !ok {
+ return nil, false
+ }
+ return fn, true
+}
+
+func BasicLiteralString(node ast.Node) (string, bool) {
+ name, ok := node.(*ast.BasicLit)
+ if !ok {
+ return "", false
+ }
+ if name.Kind != token.STRING {
+ return "", false
+ }
+ templateName, err := strconv.Unquote(name.Value)
+ if err != nil {
+ return "", false
+ }
+ return templateName, true
+}
diff --git a/internal/asteval/testdata/template/assets_dir.txtar b/internal/asteval/testdata/template/assets_dir.txtar
new file mode 100644
index 0000000..29b7738
--- /dev/null
+++ b/internal/asteval/testdata/template/assets_dir.txtar
@@ -0,0 +1,23 @@
+-- template.go --
+package main
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed assets
+ assetsFS embed.FS
+
+ templates = template.Must(template.ParseFS(assetsFS, "assets/*"))
+)
+-- assets/index.gohtml --
+
+{{define "home"}}{{end}}
+
+-- assets/form.gohtml --
+
+{{define "create"}}{{end}}
+
+{{define "update"}}{{end}}
diff --git a/internal/asteval/testdata/template/bad_embed_pattern.txtar b/internal/asteval/testdata/template/bad_embed_pattern.txtar
new file mode 100644
index 0000000..0d4137a
--- /dev/null
+++ b/internal/asteval/testdata/template/bad_embed_pattern.txtar
@@ -0,0 +1,16 @@
+-- template.go --
+package main
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed "[fail"
+ assetsFS embed.FS
+
+ templates = template.Must(template.ParseFS(assetsFS, "*"))
+)
+-- greeting.gohtml --
+Hello, friend!
diff --git a/internal/asteval/testdata/template/delims.txtar b/internal/asteval/testdata/template/delims.txtar
new file mode 100644
index 0000000..852baa2
--- /dev/null
+++ b/internal/asteval/testdata/template/delims.txtar
@@ -0,0 +1,22 @@
+-- templates.go --
+package main
+
+var (
+ //go:embed *.gohtml
+ files embed.FS
+)
+
+var (
+ templates = template.Must(
+ template.Must(
+ template.Must(
+ template.New("").ParseFS(files, "default.gohtml")).
+ Delims("(((", ")))").ParseFS(files, "triple_parens.gohtml")).
+ Delims("[[", "]]").ParseFS(files, "double_square.gohtml"))
+)
+-- default.gohtml --
+{{- define "default" -}}{{- end -}}
+-- triple_parens.gohtml --
+(((- define "parens" -)))(((- end -)))
+-- double_square.gohtml --
+[[- define "square" -]][[end]]
diff --git a/internal/asteval/testdata/template/funcs.txtar b/internal/asteval/testdata/template/funcs.txtar
new file mode 100644
index 0000000..619aa40
--- /dev/null
+++ b/internal/asteval/testdata/template/funcs.txtar
@@ -0,0 +1,41 @@
+-- template.go --
+package main
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed *.gohtml
+ src embed.FS
+
+ templates = template.New("x").Funcs(template.FuncMap{
+ "greet": func() string { return "Hello" },
+ }).ParseFS(src, "greet.gohtml")
+
+ templatesFuncNotDefined = template.New("x").Funcs(template.FuncMap{
+ "greet": func() string { return "Hello" },
+ }).ParseFS(src, "missing_func.gohtml")
+
+ templatesWrongArg = template.New("x").Funcs(wrong)
+
+ templatesTwoArgs = template.New("x").Funcs(wrong, fail)
+
+ templatesNoArgs = template.New("x").Funcs()
+
+ templatesWrongTypePackageName = template.New("x").Funcs(wrong.FuncMap{})
+
+ templatesWrongTypeName = template.New("x").Funcs(template.Wrong{})
+
+ templatesWrongTypeExpression = template.New("x").Funcs(wrong{})
+
+ templatesWrongTypeElem = template.New("x").Funcs(template.FuncMap{wrong})
+
+ templatesWrongElemKey = template.New("x").Funcs(template.FuncMap{wrong: func() string { return "" }})
+)
+-- greet.gohtml --
+{{greet}}, world!
+
+-- missing_func.gohtml --
+{{greet}}, {{enemy}}!
diff --git a/internal/asteval/testdata/template/parse.txtar b/internal/asteval/testdata/template/parse.txtar
new file mode 100644
index 0000000..7157867
--- /dev/null
+++ b/internal/asteval/testdata/template/parse.txtar
@@ -0,0 +1,15 @@
+-- parse.go --
+package main
+
+import "html/template"
+
+var templates = template.New("GET /").Parse(`Hello, world!
`)
+
+var multiple = template.New("").Parse(`
+{{define "GET /"}}Hello, world!
{{end}}
+{{define "GET /{name}"}}Hello, {{.PathValue "name"}}!
{{end}}
+`)
+
+var noArg = template.New("").Parse()
+
+var wrongArg = template.New("").Parse(500)
\ No newline at end of file
diff --git a/internal/asteval/testdata/template/template_ParseFS.txtar b/internal/asteval/testdata/template/template_ParseFS.txtar
new file mode 100644
index 0000000..b07673b
--- /dev/null
+++ b/internal/asteval/testdata/template/template_ParseFS.txtar
@@ -0,0 +1,38 @@
+-- template.go --
+package main
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed *.gohtml
+ templateSource embed.FS
+
+ templates = template.Must(template.ParseFS(templateSource, "*"))
+
+ // allHTML used by templatesHTML and templatesGoHTML to test pattern filtering
+ // this comment ensures the comment parser skips lines not preceded by go:embed
+ //go:embed *html
+ allHTML embed.FS
+
+ templatesHTML = template.Must(template.ParseFS(allHTML, "*.*html"))
+ templatesGoHTML = template.Must(template.ParseFS(allHTML, "*.gohtml"))
+
+ templateEmbedVariableNotFound = template.Must(template.ParseFS(hiding, "*"))
+)
+-- index.gohtml --
+
+{{define "home"}}{{end}}
+
+-- form.gohtml --
+
+{{define "create"}}{{end}}
+
+{{define "update"}}{{end}}
+
+-- script.html --
+{{define "console_log"}}
+
+{{end}}
diff --git a/internal/asteval/testdata/template/templates.txtar b/internal/asteval/testdata/template/templates.txtar
new file mode 100644
index 0000000..66a8671
--- /dev/null
+++ b/internal/asteval/testdata/template/templates.txtar
@@ -0,0 +1,84 @@
+-- template.go --
+package main
+
+import (
+ "embed"
+ "html/template"
+)
+
+var (
+ //go:embed *.gohtml
+ templateSource embed.FS
+
+ templateNew = template.New("some-name")
+
+ templateParseFSNew = template.Must(template.ParseFS(templateSource, "*")).New("greetings")
+
+ templateNewParseFS = template.Must(template.New("greetings").ParseFS(templateSource, "*"))
+
+ templateNewMissingArg = template.New()
+
+ templateWrongX = UNKNOWN.New()
+
+ templateWrongArgCount = template.New("one", "two")
+
+ templateNewOnIndexed = ts[0].New("one", "two")
+
+ templateNewArg42 = template.New(42)
+
+ templateNewArgIdent = template.New(TemplateName)
+
+ templateNewErrUpstream = template.New(fail).New("n")
+
+ templatesIdent = someIdent
+
+ unsupportedMethod = template.Unknown()
+
+ unexpectedFunExpression = x[3]()
+
+ templateMustNonIdentReceiver = f().Must(template.ParseFS(templateSource, "*"))
+
+ templateMustCalledWithTwoArgs = template.Must(nil, nil)
+
+ templateMustCalledWithNoArg s = template.Must()
+
+ templateMustWrongPackageIdent = wrong.Must()
+
+ templateParseFSWrongPackageIdent = wrong.ParseFS(templateSource, "*")
+
+ templateParseFSReceiverErr = template.New().ParseFS(templateSource, "*")
+
+ templateParseFSUnexpectedReceiver = x[0].ParseFS(templateSource, "*")
+
+ templateParseFSNoArgs = template.ParseFS()
+
+ templateParseFSFirstArgNonIdent = template.ParseFS(os.DirFS("."), "*")
+
+ templateParseFSNonStringLiteralGlob = template.ParseFS(templateSource, "w", 42, "x")
+
+ templateParseFSWithBadGlob = template.ParseFS(templateSource, "[fail")
+
+ templateNewHasWrongNumberOfArgs = template.Must(template.New("x").ParseFS(templateSource, "*")).New()
+
+ templateNewHasWrongTypeOfArgs = template.New("x").New(9000)
+
+ templateNewHasTooManyArgs = template.New("x").New("x", "y")
+
+ templateDelimsGetsNoArgs = template.New("x").Delims()
+
+ templateDelimsGetsTooMany = template.New("x").Delims("x", "y", "")
+
+ templateDelimsWrongExpressionArg = template.New("x").Delims("x", y)
+
+ templateParseFSMethodFails = template.New("x").ParseFS(templateSource, fail)
+
+ templateOptionsRequiresStringLiterals = template.New("x").Option(fail)
+
+ templateUnknownMethod = template.New("x").Unknown()
+
+ templateOptionCall = template.New("x").Option("missingkey=default").ParseFS(templateSource, "*")
+
+ templateOptionCallUnknownArg = template.New("x").Option("unknown").ParseFS(templateSource, "*")
+)
+-- index.gohtml --
+Hello, friend!
diff --git a/internal/astgen/format.go b/internal/astgen/format.go
new file mode 100644
index 0000000..748cebe
--- /dev/null
+++ b/internal/astgen/format.go
@@ -0,0 +1,23 @@
+package astgen
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/printer"
+ "go/token"
+)
+
+// Format converts an AST node to formatted Go source code
+func Format(node ast.Node) string {
+ var buf bytes.Buffer
+ if err := printer.Fprint(&buf, token.NewFileSet(), node); err != nil {
+ return fmt.Sprintf("formatting error: %v", err)
+ }
+ out, err := format.Source(buf.Bytes())
+ if err != nil {
+ return fmt.Sprintf("formatting error: %v", err)
+ }
+ return string(bytes.ReplaceAll(out, []byte("\n}\nfunc "), []byte("\n}\n\nfunc ")))
+}
diff --git a/internal/astgen/queries.go b/internal/astgen/queries.go
new file mode 100644
index 0000000..e262076
--- /dev/null
+++ b/internal/astgen/queries.go
@@ -0,0 +1,36 @@
+package astgen
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+// IterateGenDecl returns an iterator over GenDecl nodes with the specified token type
+func IterateGenDecl(files []*ast.File, tok token.Token) func(func(*ast.File, *ast.GenDecl) bool) {
+ return func(yield func(*ast.File, *ast.GenDecl) bool) {
+ for _, file := range files {
+ for _, decl := range file.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != tok {
+ continue
+ }
+ if !yield(file, d) {
+ return
+ }
+ }
+ }
+ }
+}
+
+// IterateValueSpecs returns an iterator over ValueSpec nodes in var declarations
+func IterateValueSpecs(files []*ast.File) func(func(*ast.File, *ast.ValueSpec) bool) {
+ return func(yield func(*ast.File, *ast.ValueSpec) bool) {
+ for file, decl := range IterateGenDecl(files, token.VAR) {
+ for _, s := range decl.Specs {
+ if !yield(file, s.(*ast.ValueSpec)) {
+ return
+ }
+ }
+ }
+ }
+}
diff --git a/package.go b/package.go
new file mode 100644
index 0000000..0e102dd
--- /dev/null
+++ b/package.go
@@ -0,0 +1,301 @@
+package check
+
+import (
+ "errors"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "html/template"
+
+ "golang.org/x/tools/go/packages"
+
+ "github.com/typelate/check/internal/asteval"
+ "github.com/typelate/check/internal/astgen"
+)
+
+// Package discovers all .ExecuteTemplate calls in the given package,
+// resolves receiver variables to their template construction chains,
+// and type-checks each call.
+//
+// ExecuteTemplate must be called with a string literal for the second parameter.
+func Package(pkg *packages.Package) error {
+ // Phase 1: Find all ExecuteTemplate calls and collect receiver objects.
+ type pendingCall struct {
+ receiverObj types.Object
+ templateName string
+ dataType types.Type
+ }
+ var pending []pendingCall
+ receiverSet := make(map[types.Object]struct{})
+
+ for _, file := range pkg.Syntax {
+ ast.Inspect(file, func(node ast.Node) bool {
+ call, ok := node.(*ast.CallExpr)
+ if !ok || len(call.Args) != 3 {
+ return true
+ }
+ sel, ok := call.Fun.(*ast.SelectorExpr)
+ if !ok || sel.Sel.Name != "ExecuteTemplate" {
+ return true
+ }
+ // Verify the method belongs to html/template or text/template.
+ if !isTemplateMethod(pkg.TypesInfo, sel) {
+ return true
+ }
+ receiverIdent, ok := sel.X.(*ast.Ident)
+ if !ok {
+ return true
+ }
+ obj := pkg.TypesInfo.Uses[receiverIdent]
+ if obj == nil {
+ return true
+ }
+ templateName, ok := asteval.BasicLiteralString(call.Args[1])
+ if !ok {
+ return true
+ }
+ dataType := pkg.TypesInfo.TypeOf(call.Args[2])
+ pending = append(pending, pendingCall{
+ receiverObj: obj,
+ templateName: templateName,
+ dataType: dataType,
+ })
+ receiverSet[obj] = struct{}{}
+ return true
+ })
+ }
+
+ // Phase 2: Resolve each unique receiver object to its template construction chain.
+ type resolvedTemplate struct {
+ ts *template.Template
+ funcs asteval.TemplateFunctions
+ meta *asteval.TemplateMetadata
+ }
+ resolved := make(map[types.Object]*resolvedTemplate)
+
+ workingDirectory := packageDirectory(pkg)
+ embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...)
+ if err != nil {
+ return fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
+ }
+
+ resolveValueSpec := func(tv *ast.ValueSpec) {
+ for i, name := range tv.Names {
+ if i >= len(tv.Values) {
+ continue
+ }
+ obj := pkg.TypesInfo.Defs[name]
+ if obj == nil {
+ continue
+ }
+ if _, needed := receiverSet[obj]; !needed {
+ continue
+ }
+
+ funcTypeMap := asteval.DefaultFunctions(pkg.Types)
+ meta := &asteval.TemplateMetadata{}
+ ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, tv.Values[i], workingDirectory, name.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
+ if err != nil {
+ return
+ }
+ resolved[obj] = &resolvedTemplate{
+ ts: ts,
+ funcs: funcTypeMap,
+ meta: meta,
+ }
+ }
+ }
+
+ resolveAssignStmt := func(stmt *ast.AssignStmt) {
+ if stmt.Tok != token.DEFINE {
+ return
+ }
+ for i, lhs := range stmt.Lhs {
+ if i >= len(stmt.Rhs) {
+ continue
+ }
+ ident, ok := lhs.(*ast.Ident)
+ if !ok {
+ continue
+ }
+ obj := pkg.TypesInfo.Defs[ident]
+ if obj == nil {
+ continue
+ }
+ if _, needed := receiverSet[obj]; !needed {
+ continue
+ }
+
+ funcTypeMap := asteval.DefaultFunctions(pkg.Types)
+ meta := &asteval.TemplateMetadata{}
+ ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, stmt.Rhs[i], workingDirectory, ident.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
+ if err != nil {
+ return
+ }
+ resolved[obj] = &resolvedTemplate{
+ ts: ts,
+ funcs: funcTypeMap,
+ meta: meta,
+ }
+ }
+ }
+
+ // Resolve top-level var declarations.
+ for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) {
+ resolveValueSpec(tv)
+ }
+
+ // Resolve function-local var declarations and short variable declarations.
+ for _, file := range pkg.Syntax {
+ ast.Inspect(file, func(node ast.Node) bool {
+ switch n := node.(type) {
+ case *ast.DeclStmt:
+ gd, ok := n.Decl.(*ast.GenDecl)
+ if !ok || gd.Tok != token.VAR {
+ return true
+ }
+ for _, spec := range gd.Specs {
+ vs, ok := spec.(*ast.ValueSpec)
+ if !ok {
+ continue
+ }
+ resolveValueSpec(vs)
+ }
+ case *ast.AssignStmt:
+ resolveAssignStmt(n)
+ }
+ return true
+ })
+ }
+
+ // Phase 2b: Find additional ParseFS/Parse calls on resolved template variables.
+ for _, file := range pkg.Syntax {
+ ast.Inspect(file, func(node ast.Node) bool {
+ call, ok := node.(*ast.CallExpr)
+ if !ok {
+ return true
+ }
+ obj := findModificationReceiver(call, pkg.TypesInfo)
+ if obj == nil {
+ return true
+ }
+ rt, ok := resolved[obj]
+ if !ok {
+ return true
+ }
+ meta := &asteval.TemplateMetadata{}
+ ts, _, _, err := asteval.EvaluateTemplateSelector(rt.ts, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.funcs, make(template.FuncMap), meta)
+ if err != nil {
+ return true
+ }
+ rt.ts = ts
+ rt.meta.EmbedFilePaths = append(rt.meta.EmbedFilePaths, meta.EmbedFilePaths...)
+ rt.meta.ParseCalls = append(rt.meta.ParseCalls, meta.ParseCalls...)
+ return true
+ })
+ }
+
+ // Phase 3: Type-check each ExecuteTemplate call.
+ mergedFunctions := make(Functions)
+ if pkg.Types != nil {
+ mergedFunctions = DefaultFunctions(pkg.Types)
+ }
+ for _, rt := range resolved {
+ for name, sig := range rt.funcs {
+ mergedFunctions[name] = sig
+ }
+ }
+
+ var errs []error
+ for _, p := range pending {
+ rt, ok := resolved[p.receiverObj]
+ if !ok {
+ continue
+ }
+ looked := rt.ts.Lookup(p.templateName)
+ if looked == nil {
+ continue
+ }
+ treeFinder := (*asteval.Forrest)(rt.ts)
+ global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions)
+ if err := Execute(global, looked.Tree, p.dataType); err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ return errors.Join(errs...)
+}
+
+// isTemplateMethod reports whether sel refers to a method on
+// *html/template.Template or *text/template.Template.
+func isTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool {
+ if typesInfo == nil {
+ return false
+ }
+ selection, ok := typesInfo.Selections[sel]
+ if !ok {
+ return false
+ }
+ fn, ok := selection.Obj().(*types.Func)
+ if !ok {
+ return false
+ }
+ fnPkg := fn.Pkg()
+ if fnPkg == nil {
+ return false
+ }
+ return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template"
+}
+
+// findModificationReceiver unwraps template.Must and returns the types.Object
+// of the variable receiver for a method call like ts.ParseFS(...) or
+// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found.
+func findModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object {
+ sel, ok := expr.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return nil
+ }
+ switch x := sel.X.(type) {
+ case *ast.Ident:
+ if isTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 {
+ inner, ok := expr.Args[0].(*ast.CallExpr)
+ if !ok {
+ return nil
+ }
+ return findModificationReceiver(inner, typesInfo)
+ }
+ if isTemplatePkgIdent(typesInfo, x) {
+ return nil
+ }
+ return typesInfo.Uses[x]
+ }
+ return nil
+}
+
+// isTemplatePkgIdent reports whether ident refers to the "html/template"
+// or "text/template" package via the type checker.
+func isTemplatePkgIdent(typesInfo *types.Info, ident *ast.Ident) bool {
+ if typesInfo == nil {
+ return false
+ }
+ obj := typesInfo.Uses[ident]
+ pkgName, ok := obj.(*types.PkgName)
+ if !ok {
+ return false
+ }
+ path := pkgName.Imported().Path()
+ return path == "html/template" || path == "text/template"
+}
+
+func packageDirectory(pkg *packages.Package) string {
+ if len(pkg.GoFiles) > 0 {
+ p := pkg.GoFiles[0]
+ for i := len(p) - 1; i >= 0; i-- {
+ if p[i] == '/' || p[i] == '\\' {
+ return p[:i]
+ }
+ }
+ }
+ return "."
+}
From ec8a2451b728ca93868e1b38551d5b2583f3d57a Mon Sep 17 00:00:00 2001
From: Christopher Hunter <8398225+crhntr@users.noreply.github.com>
Date: Sat, 7 Feb 2026 12:35:15 -0800
Subject: [PATCH 2/3] chore: refactor Package function
break it up and add more to asteval
---
internal/asteval/template.go | 52 +++++++-
package.go | 240 ++++++++++++++---------------------
2 files changed, 143 insertions(+), 149 deletions(-)
diff --git a/internal/asteval/template.go b/internal/asteval/template.go
index bdfe916..2ee2af2 100644
--- a/internal/asteval/template.go
+++ b/internal/asteval/template.go
@@ -38,7 +38,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
default:
return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args)))
case *ast.Ident:
- if !isTemplatePkgIdent(typesInfo, x) {
+ if !IsTemplatePkgIdent(typesInfo, x) {
if ts == nil {
return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected template package got %s", astgen.Format(sel.X)))
}
@@ -187,9 +187,9 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
}
}
-// isTemplatePkgIdent reports whether ident refers to the "html/template"
+// IsTemplatePkgIdent reports whether ident refers to the "html/template"
// or "text/template" package via the type checker.
-func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool {
+func IsTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool {
if info == nil {
return false
}
@@ -202,6 +202,52 @@ func isTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool {
return path == "html/template" || path == "text/template"
}
+// IsTemplateMethod reports whether sel refers to a method on
+// *html/template.Template or *text/template.Template.
+func IsTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool {
+ if typesInfo == nil {
+ return false
+ }
+ selection, ok := typesInfo.Selections[sel]
+ if !ok {
+ return false
+ }
+ fn, ok := selection.Obj().(*types.Func)
+ if !ok {
+ return false
+ }
+ fnPkg := fn.Pkg()
+ if fnPkg == nil {
+ return false
+ }
+ return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template"
+}
+
+// FindModificationReceiver unwraps template.Must and returns the types.Object
+// of the variable receiver for a method call like ts.ParseFS(...) or
+// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found.
+func FindModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object {
+ sel, ok := expr.Fun.(*ast.SelectorExpr)
+ if !ok {
+ return nil
+ }
+ switch x := sel.X.(type) {
+ case *ast.Ident:
+ if IsTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 {
+ inner, ok := expr.Args[0].(*ast.CallExpr)
+ if !ok {
+ return nil
+ }
+ return FindModificationReceiver(inner, typesInfo)
+ }
+ if IsTemplatePkgIdent(typesInfo, x) {
+ return nil
+ }
+ return typesInfo.Uses[x]
+ }
+ return nil
+}
+
func builtins() template.FuncMap {
type nothing struct{}
return template.FuncMap{
diff --git a/package.go b/package.go
index 0e102dd..a2ed138 100644
--- a/package.go
+++ b/package.go
@@ -7,6 +7,7 @@ import (
"go/token"
"go/types"
"html/template"
+ "path/filepath"
"golang.org/x/tools/go/packages"
@@ -14,18 +15,36 @@ import (
"github.com/typelate/check/internal/astgen"
)
+type pendingCall struct {
+ receiverObj types.Object
+ templateName string
+ dataType types.Type
+}
+
+type resolvedTemplate struct {
+ templates *template.Template
+ functions asteval.TemplateFunctions
+ metadata *asteval.TemplateMetadata
+}
+
// Package discovers all .ExecuteTemplate calls in the given package,
// resolves receiver variables to their template construction chains,
// and type-checks each call.
//
// ExecuteTemplate must be called with a string literal for the second parameter.
func Package(pkg *packages.Package) error {
- // Phase 1: Find all ExecuteTemplate calls and collect receiver objects.
- type pendingCall struct {
- receiverObj types.Object
- templateName string
- dataType types.Type
+ pending, receivers := findExecuteCalls(pkg)
+ resolved, err := resolveTemplates(pkg, receivers)
+ if err != nil {
+ return err
}
+ return checkCalls(pkg, pending, resolved)
+}
+
+// findExecuteCalls walks the package syntax looking for ExecuteTemplate calls
+// and returns the pending calls along with the set of receiver objects that
+// need template resolution.
+func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]struct{}) {
var pending []pendingCall
receiverSet := make(map[types.Object]struct{})
@@ -40,7 +59,7 @@ func Package(pkg *packages.Package) error {
return true
}
// Verify the method belongs to html/template or text/template.
- if !isTemplateMethod(pkg.TypesInfo, sel) {
+ if !asteval.IsTemplateMethod(pkg.TypesInfo, sel) {
return true
}
receiverIdent, ok := sel.X.(*ast.Ident)
@@ -66,86 +85,51 @@ func Package(pkg *packages.Package) error {
})
}
- // Phase 2: Resolve each unique receiver object to its template construction chain.
- type resolvedTemplate struct {
- ts *template.Template
- funcs asteval.TemplateFunctions
- meta *asteval.TemplateMetadata
- }
+ return pending, receiverSet
+}
+
+// resolveTemplates resolves each unique receiver object to its template
+// construction chain, including additional ParseFS/Parse modifications.
+func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, error) {
resolved := make(map[types.Object]*resolvedTemplate)
workingDirectory := packageDirectory(pkg)
embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...)
if err != nil {
- return fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
+ return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
}
- resolveValueSpec := func(tv *ast.ValueSpec) {
- for i, name := range tv.Names {
- if i >= len(tv.Values) {
- continue
- }
- obj := pkg.TypesInfo.Defs[name]
- if obj == nil {
- continue
- }
- if _, needed := receiverSet[obj]; !needed {
- continue
- }
-
- funcTypeMap := asteval.DefaultFunctions(pkg.Types)
- meta := &asteval.TemplateMetadata{}
- ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, tv.Values[i], workingDirectory, name.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
- if err != nil {
- return
- }
- resolved[obj] = &resolvedTemplate{
- ts: ts,
- funcs: funcTypeMap,
- meta: meta,
- }
+ resolveExpr := func(obj types.Object, name string, expr ast.Expr) {
+ if _, needed := receivers[obj]; !needed {
+ return
}
- }
-
- resolveAssignStmt := func(stmt *ast.AssignStmt) {
- if stmt.Tok != token.DEFINE {
+ funcTypeMap := asteval.DefaultFunctions(pkg.Types)
+ meta := &asteval.TemplateMetadata{}
+ ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
+ if err != nil {
return
}
- for i, lhs := range stmt.Lhs {
- if i >= len(stmt.Rhs) {
- continue
- }
- ident, ok := lhs.(*ast.Ident)
- if !ok {
+ resolved[obj] = &resolvedTemplate{
+ templates: ts,
+ functions: funcTypeMap,
+ metadata: meta,
+ }
+ }
+
+ // Resolve top-level var declarations.
+ for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) {
+ for i, ident := range tv.Names {
+ if i >= len(tv.Values) {
continue
}
obj := pkg.TypesInfo.Defs[ident]
if obj == nil {
continue
}
- if _, needed := receiverSet[obj]; !needed {
- continue
- }
-
- funcTypeMap := asteval.DefaultFunctions(pkg.Types)
- meta := &asteval.TemplateMetadata{}
- ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, stmt.Rhs[i], workingDirectory, ident.Name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
- if err != nil {
- return
- }
- resolved[obj] = &resolvedTemplate{
- ts: ts,
- funcs: funcTypeMap,
- meta: meta,
- }
+ resolveExpr(obj, ident.Name, tv.Values[i])
}
}
- // Resolve top-level var declarations.
- for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) {
- resolveValueSpec(tv)
- }
-
// Resolve function-local var declarations and short variable declarations.
for _, file := range pkg.Syntax {
ast.Inspect(file, func(node ast.Node) bool {
@@ -160,23 +144,48 @@ func Package(pkg *packages.Package) error {
if !ok {
continue
}
- resolveValueSpec(vs)
+ for i, ident := range vs.Names {
+ if i >= len(vs.Values) {
+ continue
+ }
+ obj := pkg.TypesInfo.Defs[ident]
+ if obj == nil {
+ continue
+ }
+ resolveExpr(obj, ident.Name, vs.Values[i])
+ }
}
case *ast.AssignStmt:
- resolveAssignStmt(n)
+ if n.Tok != token.DEFINE {
+ return true
+ }
+ for i, lhs := range n.Lhs {
+ if i >= len(n.Rhs) {
+ continue
+ }
+ ident, ok := lhs.(*ast.Ident)
+ if !ok {
+ continue
+ }
+ obj := pkg.TypesInfo.Defs[ident]
+ if obj == nil {
+ continue
+ }
+ resolveExpr(obj, ident.Name, n.Rhs[i])
+ }
}
return true
})
}
- // Phase 2b: Find additional ParseFS/Parse calls on resolved template variables.
+ // Find additional ParseFS/Parse calls on resolved template variables.
for _, file := range pkg.Syntax {
ast.Inspect(file, func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}
- obj := findModificationReceiver(call, pkg.TypesInfo)
+ obj := asteval.FindModificationReceiver(call, pkg.TypesInfo)
if obj == nil {
return true
}
@@ -185,24 +194,29 @@ func Package(pkg *packages.Package) error {
return true
}
meta := &asteval.TemplateMetadata{}
- ts, _, _, err := asteval.EvaluateTemplateSelector(rt.ts, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.funcs, make(template.FuncMap), meta)
+ ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(template.FuncMap), meta)
if err != nil {
return true
}
- rt.ts = ts
- rt.meta.EmbedFilePaths = append(rt.meta.EmbedFilePaths, meta.EmbedFilePaths...)
- rt.meta.ParseCalls = append(rt.meta.ParseCalls, meta.ParseCalls...)
+ rt.templates = ts
+ rt.metadata.EmbedFilePaths = append(rt.metadata.EmbedFilePaths, meta.EmbedFilePaths...)
+ rt.metadata.ParseCalls = append(rt.metadata.ParseCalls, meta.ParseCalls...)
return true
})
}
- // Phase 3: Type-check each ExecuteTemplate call.
+ return resolved, nil
+}
+
+// checkCalls type-checks each pending ExecuteTemplate call against its
+// resolved template.
+func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate) error {
mergedFunctions := make(Functions)
if pkg.Types != nil {
mergedFunctions = DefaultFunctions(pkg.Types)
}
for _, rt := range resolved {
- for name, sig := range rt.funcs {
+ for name, sig := range rt.functions {
mergedFunctions[name] = sig
}
}
@@ -213,11 +227,11 @@ func Package(pkg *packages.Package) error {
if !ok {
continue
}
- looked := rt.ts.Lookup(p.templateName)
+ looked := rt.templates.Lookup(p.templateName)
if looked == nil {
continue
}
- treeFinder := (*asteval.Forrest)(rt.ts)
+ treeFinder := (*asteval.Forrest)(rt.templates)
global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions)
if err := Execute(global, looked.Tree, p.dataType); err != nil {
errs = append(errs, err)
@@ -227,75 +241,9 @@ func Package(pkg *packages.Package) error {
return errors.Join(errs...)
}
-// isTemplateMethod reports whether sel refers to a method on
-// *html/template.Template or *text/template.Template.
-func isTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool {
- if typesInfo == nil {
- return false
- }
- selection, ok := typesInfo.Selections[sel]
- if !ok {
- return false
- }
- fn, ok := selection.Obj().(*types.Func)
- if !ok {
- return false
- }
- fnPkg := fn.Pkg()
- if fnPkg == nil {
- return false
- }
- return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template"
-}
-
-// findModificationReceiver unwraps template.Must and returns the types.Object
-// of the variable receiver for a method call like ts.ParseFS(...) or
-// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found.
-func findModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object {
- sel, ok := expr.Fun.(*ast.SelectorExpr)
- if !ok {
- return nil
- }
- switch x := sel.X.(type) {
- case *ast.Ident:
- if isTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 {
- inner, ok := expr.Args[0].(*ast.CallExpr)
- if !ok {
- return nil
- }
- return findModificationReceiver(inner, typesInfo)
- }
- if isTemplatePkgIdent(typesInfo, x) {
- return nil
- }
- return typesInfo.Uses[x]
- }
- return nil
-}
-
-// isTemplatePkgIdent reports whether ident refers to the "html/template"
-// or "text/template" package via the type checker.
-func isTemplatePkgIdent(typesInfo *types.Info, ident *ast.Ident) bool {
- if typesInfo == nil {
- return false
- }
- obj := typesInfo.Uses[ident]
- pkgName, ok := obj.(*types.PkgName)
- if !ok {
- return false
- }
- path := pkgName.Imported().Path()
- return path == "html/template" || path == "text/template"
-}
-
func packageDirectory(pkg *packages.Package) string {
if len(pkg.GoFiles) > 0 {
- p := pkg.GoFiles[0]
- for i := len(p) - 1; i >= 0; i-- {
- if p[i] == '/' || p[i] == '\\' {
- return p[:i]
- }
- }
+ return filepath.Dir(pkg.GoFiles[0])
}
return "."
}
From 45e0ccd122515888e2a4d709cc8d2eb74d353491 Mon Sep 17 00:00:00 2001
From: Christopher Hunter <8398225+crhntr@users.noreply.github.com>
Date: Sat, 7 Feb 2026 16:53:34 -0800
Subject: [PATCH 3/3] chore: add test coverage for text vs html template
packages
---
.../err_html_template_err_branch_end.txt | 36 ++++++++++
.../pass_text_template_no_err_branch_end.txt | 35 ++++++++++
internal/asteval/template.go | 56 ++++++++++++----
internal/asteval/template_html.go | 67 +++++++++++++++++++
internal/asteval/template_packages.go | 19 ++++++
internal/asteval/template_text.go | 67 +++++++++++++++++++
package.go | 29 ++++----
7 files changed, 281 insertions(+), 28 deletions(-)
create mode 100644 cmd/check-templates/testdata/err_html_template_err_branch_end.txt
create mode 100644 cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt
create mode 100644 internal/asteval/template_html.go
create mode 100644 internal/asteval/template_packages.go
create mode 100644 internal/asteval/template_text.go
diff --git a/cmd/check-templates/testdata/err_html_template_err_branch_end.txt b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt
new file mode 100644
index 0000000..a5d2fbd
--- /dev/null
+++ b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt
@@ -0,0 +1,36 @@
+# html/template errors with ErrBranchEnd at Execute time.
+# The go test exercises the html/template escape analysis which rejects
+# templates whose {{if}} branches end in different HTML contexts.
+# Companion: pass_text_template_no_err_branch_end.txt
+
+check-templates
+exec go test
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import "html/template"
+
+var templates = template.Must(template.New("page").Parse(`{{end}}
`))
+
+func main() {}
+
+-- main_test.go --
+package main
+
+import (
+ "io"
+ "testing"
+)
+
+func TestExecuteBranchEnd(t *testing.T) {
+ err := templates.Execute(io.Discard, false)
+ if err == nil {
+ t.Fatal("expected html/template ErrBranchEnd error, got nil")
+ }
+ t.Logf("got expected error: %s", err.Error())
+}
diff --git a/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt
new file mode 100644
index 0000000..1ae730d
--- /dev/null
+++ b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt
@@ -0,0 +1,35 @@
+# text/template does not error with ErrBranchEnd.
+# The go test confirms text/template.Execute succeeds on the same template
+# that html/template rejects due to branches ending in different contexts.
+# Companion: err_html_template_err_branch_end.txt
+
+check-templates
+exec go test
+
+-- go.mod --
+module example.com/app
+
+go 1.25.0
+-- main.go --
+package main
+
+import "text/template"
+
+var templates = template.Must(template.New("page").Parse(`{{end}}
`))
+
+func main() {}
+
+-- main_test.go --
+package main
+
+import (
+ "io"
+ "testing"
+)
+
+func TestExecuteNoBranchEnd(t *testing.T) {
+ err := templates.Execute(io.Discard, false)
+ if err != nil {
+ t.Fatalf("text/template should not produce ErrBranchEnd, got: %v", err)
+ }
+}
diff --git a/internal/asteval/template.go b/internal/asteval/template.go
index 2ee2af2..b0293e9 100644
--- a/internal/asteval/template.go
+++ b/internal/asteval/template.go
@@ -7,7 +7,6 @@ import (
"go/format"
"go/token"
"go/types"
- "html/template"
"os"
"path/filepath"
"slices"
@@ -19,13 +18,28 @@ import (
"github.com/typelate/check/internal/astgen"
)
+// Template abstracts over html/template.Template and text/template.Template
+// so that the correct template package is used based on the user's import.
+type Template interface {
+ New(name string) Template
+ Parse(text string) (Template, error)
+ Funcs(funcMap map[string]any) Template
+ Option(opt ...string) Template
+ Delims(left, right string) Template
+ Lookup(name string) Template
+ Name() string
+ AddParseTree(name string, tree *parse.Tree) (Template, error)
+ Tree() *parse.Tree
+ FindTree(name string) (*parse.Tree, bool)
+}
+
// TemplateMetadata accumulates metadata during template evaluation.
type TemplateMetadata struct {
EmbedFilePaths []string
ParseCalls []*ast.BasicLit
}
-func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm template.FuncMap, meta *TemplateMetadata) (*template.Template, string, string, error) {
+func EvaluateTemplateSelector(ts Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm map[string]any, meta *TemplateMetadata) (Template, string, string, error) {
call, ok := expression.(*ast.CallExpr)
if !ok {
return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression"))
@@ -52,7 +66,8 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
if meta != nil {
meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
}
- t, err := parseFiles(ts, fm, lDelim, rDelim, filePaths...)
+ pkgPath := templatePkgPath(typesInfo, x)
+ t, err := parseFiles(ts, pkgPath, fm, lDelim, rDelim, filePaths...)
return t, lDelim, rDelim, err
case "Parse":
if len(call.Args) != 1 {
@@ -93,6 +108,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s on variable receiver", sel.Sel.Name))
}
}
+ pkgPath := templatePkgPath(typesInfo, x)
switch sel.Sel.Name {
case "Must":
if len(call.Args) != 1 {
@@ -107,7 +123,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
if err != nil {
return nil, lDelim, rDelim, err
}
- return template.New(templateNames[0]), lDelim, rDelim, nil
+ return NewTemplate(pkgPath, templateNames[0]), lDelim, rDelim, nil
case "ParseFS":
filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths)
if err != nil {
@@ -116,7 +132,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
if meta != nil {
meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
}
- t, err := parseFiles(nil, fm, lDelim, rDelim, filePaths...)
+ t, err := parseFiles(nil, pkgPath, fm, lDelim, rDelim, filePaths...)
return t, lDelim, rDelim, err
default:
return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name))
@@ -168,7 +184,7 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
if meta != nil {
meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...)
}
- t, err := parseFiles(up, fm, upLDelim, upRDelim, filePaths...)
+ t, err := parseFiles(up, "", fm, upLDelim, upRDelim, filePaths...)
return t, upLDelim, upRDelim, err
case "Option":
list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args)
@@ -187,6 +203,20 @@ func EvaluateTemplateSelector(ts *template.Template, pkg *types.Package, typesIn
}
}
+// templatePkgPath extracts the import path ("html/template" or "text/template")
+// from an AST identifier that refers to a template package.
+func templatePkgPath(info *types.Info, ident *ast.Ident) string {
+ if info == nil {
+ return "html/template"
+ }
+ obj := info.Uses[ident]
+ pkgName, ok := obj.(*types.PkgName)
+ if !ok {
+ return "html/template"
+ }
+ return pkgName.Imported().Path()
+}
+
// IsTemplatePkgIdent reports whether ident refers to the "html/template"
// or "text/template" package via the type checker.
func IsTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool {
@@ -248,9 +278,9 @@ func FindModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.O
return nil
}
-func builtins() template.FuncMap {
+func builtins() map[string]any {
type nothing struct{}
- return template.FuncMap{
+ return map[string]any{
"and": func() (_ nothing) { return },
"call": func() (_ nothing) { return },
"html": func() (_ nothing) { return },
@@ -275,9 +305,9 @@ func builtins() template.FuncMap {
}
}
-func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim string, filenames ...string) (*template.Template, error) {
+func parseFiles(t Template, pkgPath string, fm map[string]any, leftDelim, rightDelim string, filenames ...string) (Template, error) {
if len(filenames) == 0 {
- return nil, fmt.Errorf("html/template: no files named in call to ParseFiles")
+ return nil, fmt.Errorf("template: no files named in call to ParseFiles")
}
for _, filename := range filenames {
templateName := filepath.Base(filename)
@@ -286,9 +316,9 @@ func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim
return nil, err
}
s := string(b)
- var tmpl *template.Template
+ var tmpl Template
if t == nil {
- t = template.New(templateName)
+ t = NewTemplate(pkgPath, templateName)
}
if templateName == t.Name() {
tmpl = t
@@ -313,7 +343,7 @@ func parseFiles(t *template.Template, fm template.FuncMap, leftDelim, rightDelim
return t, nil
}
-func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm template.FuncMap, funcTypesMap TemplateFunctions) error {
+func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm map[string]any, funcTypesMap TemplateFunctions) error {
if len(call.Args) != 1 {
return wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument"))
}
diff --git a/internal/asteval/template_html.go b/internal/asteval/template_html.go
new file mode 100644
index 0000000..31c75fc
--- /dev/null
+++ b/internal/asteval/template_html.go
@@ -0,0 +1,67 @@
+package asteval
+
+import (
+ "html/template"
+ "text/template/parse"
+)
+
+// htmlTemplate wraps *html/template.Template.
+type htmlTemplate struct {
+ t *template.Template
+}
+
+func (h *htmlTemplate) New(name string) Template {
+ return &htmlTemplate{t: h.t.New(name)}
+}
+
+func (h *htmlTemplate) Parse(text string) (Template, error) {
+ t, err := h.t.Parse(text)
+ if err != nil {
+ return nil, err
+ }
+ return &htmlTemplate{t: t}, nil
+}
+
+func (h *htmlTemplate) Funcs(funcMap map[string]any) Template {
+ return &htmlTemplate{t: h.t.Funcs(funcMap)}
+}
+
+func (h *htmlTemplate) Option(opt ...string) Template {
+ return &htmlTemplate{t: h.t.Option(opt...)}
+}
+
+func (h *htmlTemplate) Delims(left, right string) Template {
+ return &htmlTemplate{t: h.t.Delims(left, right)}
+}
+
+func (h *htmlTemplate) Lookup(name string) Template {
+ t := h.t.Lookup(name)
+ if t == nil {
+ return nil
+ }
+ return &htmlTemplate{t: t}
+}
+
+func (h *htmlTemplate) Name() string {
+ return h.t.Name()
+}
+
+func (h *htmlTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) {
+ t, err := h.t.AddParseTree(name, tree)
+ if err != nil {
+ return nil, err
+ }
+ return &htmlTemplate{t: t}, nil
+}
+
+func (h *htmlTemplate) Tree() *parse.Tree {
+ return h.t.Tree
+}
+
+func (h *htmlTemplate) FindTree(name string) (*parse.Tree, bool) {
+ t := h.t.Lookup(name)
+ if t == nil {
+ return nil, false
+ }
+ return t.Tree, true
+}
diff --git a/internal/asteval/template_packages.go b/internal/asteval/template_packages.go
new file mode 100644
index 0000000..2b4428a
--- /dev/null
+++ b/internal/asteval/template_packages.go
@@ -0,0 +1,19 @@
+package asteval
+
+import (
+ html "html/template"
+ text "text/template"
+)
+
+// NewTemplate creates a Template backed by the appropriate template
+// package either: "text/template" or "html/template".
+func NewTemplate(pkgPath, name string) Template {
+ switch pkgPath {
+ case "text/template":
+ return &textTemplate{t: text.New(name)}
+ case "html/template":
+ return &htmlTemplate{t: html.New(name)}
+ default:
+ return nil
+ }
+}
diff --git a/internal/asteval/template_text.go b/internal/asteval/template_text.go
new file mode 100644
index 0000000..d0e427c
--- /dev/null
+++ b/internal/asteval/template_text.go
@@ -0,0 +1,67 @@
+package asteval
+
+import (
+ "text/template"
+ "text/template/parse"
+)
+
+// textTemplate wraps *text/template.Template.
+type textTemplate struct {
+ t *template.Template
+}
+
+func (s *textTemplate) New(name string) Template {
+ return &textTemplate{t: s.t.New(name)}
+}
+
+func (s *textTemplate) Parse(text string) (Template, error) {
+ t, err := s.t.Parse(text)
+ if err != nil {
+ return nil, err
+ }
+ return &textTemplate{t: t}, nil
+}
+
+func (s *textTemplate) Funcs(funcMap map[string]any) Template {
+ return &textTemplate{t: s.t.Funcs(funcMap)}
+}
+
+func (s *textTemplate) Option(opt ...string) Template {
+ return &textTemplate{t: s.t.Option(opt...)}
+}
+
+func (s *textTemplate) Delims(left, right string) Template {
+ return &textTemplate{t: s.t.Delims(left, right)}
+}
+
+func (s *textTemplate) Lookup(name string) Template {
+ t := s.t.Lookup(name)
+ if t == nil {
+ return nil
+ }
+ return &textTemplate{t: t}
+}
+
+func (s *textTemplate) Name() string {
+ return s.t.Name()
+}
+
+func (s *textTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) {
+ t, err := s.t.AddParseTree(name, tree)
+ if err != nil {
+ return nil, err
+ }
+ return &textTemplate{t: t}, nil
+}
+
+func (s *textTemplate) Tree() *parse.Tree {
+ return s.t.Tree
+}
+
+func (s *textTemplate) FindTree(name string) (*parse.Tree, bool) {
+ t := s.t.Lookup(name)
+ if t == nil {
+ return nil, false
+ }
+ return t.Tree, true
+}
diff --git a/package.go b/package.go
index a2ed138..e44de5f 100644
--- a/package.go
+++ b/package.go
@@ -6,7 +6,6 @@ import (
"go/ast"
"go/token"
"go/types"
- "html/template"
"path/filepath"
"golang.org/x/tools/go/packages"
@@ -22,7 +21,7 @@ type pendingCall struct {
}
type resolvedTemplate struct {
- templates *template.Template
+ templates asteval.Template
functions asteval.TemplateFunctions
metadata *asteval.TemplateMetadata
}
@@ -34,11 +33,9 @@ type resolvedTemplate struct {
// ExecuteTemplate must be called with a string literal for the second parameter.
func Package(pkg *packages.Package) error {
pending, receivers := findExecuteCalls(pkg)
- resolved, err := resolveTemplates(pkg, receivers)
- if err != nil {
- return err
- }
- return checkCalls(pkg, pending, resolved)
+ resolved, resolveErrs := resolveTemplates(pkg, receivers)
+ callErr := checkCalls(pkg, pending, resolved)
+ return errors.Join(append(resolveErrs, callErr)...)
}
// findExecuteCalls walks the package syntax looking for ExecuteTemplate calls
@@ -90,23 +87,26 @@ func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]st
// resolveTemplates resolves each unique receiver object to its template
// construction chain, including additional ParseFS/Parse modifications.
-func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, error) {
+func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, []error) {
resolved := make(map[types.Object]*resolvedTemplate)
workingDirectory := packageDirectory(pkg)
embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...)
if err != nil {
- return nil, fmt.Errorf("failed to calculate relative path for embedded files: %w", err)
+ return nil, []error{fmt.Errorf("failed to calculate relative path for embedded files: %w", err)}
}
+ var resolveErrs []error
+
resolveExpr := func(obj types.Object, name string, expr ast.Expr) {
if _, needed := receivers[obj]; !needed {
return
}
funcTypeMap := asteval.DefaultFunctions(pkg.Types)
meta := &asteval.TemplateMetadata{}
- ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(template.FuncMap), meta)
+ ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(map[string]any), meta)
if err != nil {
+ resolveErrs = append(resolveErrs, err)
return
}
resolved[obj] = &resolvedTemplate{
@@ -194,7 +194,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}
return true
}
meta := &asteval.TemplateMetadata{}
- ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(template.FuncMap), meta)
+ ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(map[string]any), meta)
if err != nil {
return true
}
@@ -205,7 +205,7 @@ func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}
})
}
- return resolved, nil
+ return resolved, resolveErrs
}
// checkCalls type-checks each pending ExecuteTemplate call against its
@@ -231,9 +231,8 @@ func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types
if looked == nil {
continue
}
- treeFinder := (*asteval.Forrest)(rt.templates)
- global := NewGlobal(pkg.Types, pkg.Fset, treeFinder, mergedFunctions)
- if err := Execute(global, looked.Tree, p.dataType); err != nil {
+ global := NewGlobal(pkg.Types, pkg.Fset, rt.templates, mergedFunctions)
+ if err := Execute(global, looked.Tree(), p.dataType); err != nil {
errs = append(errs, err)
}
}