diff --git a/cmd/check-templates/main.go b/cmd/check-templates/main.go new file mode 100644 index 0000000..1a2bd82 --- /dev/null +++ b/cmd/check-templates/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "go/token" + "io" + "os" + + "golang.org/x/tools/go/packages" + + "github.com/typelate/check" +) + +func main() { + os.Exit(run(os.Args[1:], os.Stdout, os.Stderr)) +} + +func run(args []string, stdout, stderr io.Writer) int { + dir := "." + if len(args) > 0 { + dir = args[0] + } + + fset := token.NewFileSet() + pkgs, err := packages.Load(&packages.Config{ + Fset: fset, + Mode: packages.NeedTypesInfo | packages.NeedName | packages.NeedFiles | + packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | + packages.NeedEmbedFiles | packages.NeedImports | packages.NeedModule, + Dir: dir, + }, dir) + if err != nil { + fmt.Fprintf(stderr, "failed to load packages: %v\n", err) + return 1 + } + exitCode := 0 + for _, pkg := range pkgs { + for _, e := range pkg.Errors { + fmt.Fprintln(stderr, e) + exitCode = 1 + } + if err := check.Package(pkg); err != nil { + fmt.Fprintln(stderr, err) + exitCode = 1 + } + } + return exitCode +} diff --git a/cmd/check-templates/main_test.go b/cmd/check-templates/main_test.go new file mode 100644 index 0000000..acc3f24 --- /dev/null +++ b/cmd/check-templates/main_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "bytes" + "path/filepath" + "testing" + + "rsc.io/script" + "rsc.io/script/scripttest" +) + +func Test(t *testing.T) { + e := script.NewEngine() + e.Quiet = true + e.Cmds = scripttest.DefaultCmds() + e.Cmds["check-templates"] = checkTemplatesCommand() + ctx := t.Context() + scripttest.Test(t, ctx, e, nil, filepath.FromSlash("testdata/*.txt")) +} + +func checkTemplatesCommand() script.Cmd { + return script.Command(script.CmdUsage{ + Summary: "check-templates [dir]", + Args: "[dir]", + }, func(state *script.State, args ...string) (script.WaitFunc, error) { + return func(state *script.State) (string, string, error) { + var stdout, stderr bytes.Buffer + cmdArgs := args + if len(cmdArgs) == 0 { + cmdArgs = []string{state.Getwd()} + } + code := run(cmdArgs, &stdout, &stderr) + var err error + if code != 0 { + err = script.ErrUsage + } + return stdout.String(), stderr.String(), err + }, nil + }) +} diff --git a/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt new file mode 100644 index 0000000..100c55e --- /dev/null +++ b/cmd/check-templates/testdata/err_additional_parsefs_missing_field.txt @@ -0,0 +1,43 @@ +# Template defined at top level, modified in a function. The additional template +# has a field that doesn't exist on the data type. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + template.Must(ts.ParseFS(source, "about.gohtml")) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_aliased_import.txt b/cmd/check-templates/testdata/err_aliased_import.txt new file mode 100644 index 0000000..cf4adc8 --- /dev/null +++ b/cmd/check-templates/testdata/err_aliased_import.txt @@ -0,0 +1,38 @@ +# Template import with a non-standard alias should still detect errors. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + htmltpl "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = htmltpl.Must(htmltpl.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_closure_missing_field.txt b/cmd/check-templates/testdata/err_closure_missing_field.txt new file mode 100644 index 0000000..3c01045 --- /dev/null +++ b/cmd/check-templates/testdata/err_closure_missing_field.txt @@ -0,0 +1,38 @@ +# Template parsed in outer function, closure calls ExecuteTemplate with missing field. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func routes(mux *http.ServeMux) { + ts := template.Must(template.ParseFS(source, "*")) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) + }) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_funcs_chain.txt b/cmd/check-templates/testdata/err_funcs_chain.txt new file mode 100644 index 0000000..c14f893 --- /dev/null +++ b/cmd/check-templates/testdata/err_funcs_chain.txt @@ -0,0 +1,42 @@ +# Template constructed with Funcs() chained before ParseFS should still +# report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + "strings" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.New("").Funcs(template.FuncMap{ + "upper": strings.ToUpper, + }).ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing | upper}} diff --git a/cmd/check-templates/testdata/err_html_template_err_branch_end.txt b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt new file mode 100644 index 0000000..a5d2fbd --- /dev/null +++ b/cmd/check-templates/testdata/err_html_template_err_branch_end.txt @@ -0,0 +1,36 @@ +# html/template errors with ErrBranchEnd at Execute time. +# The go test exercises the html/template escape analysis which rejects +# templates whose {{if}} branches end in different HTML contexts. +# Companion: pass_text_template_no_err_branch_end.txt + +check-templates +exec go test + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import "html/template" + +var templates = template.Must(template.New("page").Parse(`
{{end}}
`)) + +func main() {} + +-- main_test.go -- +package main + +import ( + "io" + "testing" +) + +func TestExecuteBranchEnd(t *testing.T) { + err := templates.Execute(io.Discard, false) + if err == nil { + t.Fatal("expected html/template ErrBranchEnd error, got nil") + } + t.Logf("got expected error: %s", err.Error()) +} diff --git a/cmd/check-templates/testdata/err_imported_type.txt b/cmd/check-templates/testdata/err_imported_type.txt new file mode 100644 index 0000000..a9a3a60 --- /dev/null +++ b/cmd/check-templates/testdata/err_imported_type.txt @@ -0,0 +1,42 @@ +# Types imported from other packages should report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app/internal/model\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- internal/model/types.go -- +package model + +type Page struct { + Title string +} +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + + "example.com/app/internal/model" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/err_inline_struct.txt b/cmd/check-templates/testdata/err_inline_struct.txt new file mode 100644 index 0000000..bb5fe87 --- /dev/null +++ b/cmd/check-templates/testdata/err_inline_struct.txt @@ -0,0 +1,37 @@ +# Inline anonymous struct types should report errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:2: executing "index\.gohtml" at <\.Missing>: Missing not found on struct\{Title string\}' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + var data struct { + Title string + } + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/err_local_var_missing_field.txt b/cmd/check-templates/testdata/err_local_var_missing_field.txt new file mode 100644 index 0000000..bde48d4 --- /dev/null +++ b/cmd/check-templates/testdata/err_local_var_missing_field.txt @@ -0,0 +1,35 @@ +# Template defined as local variable reports errors for missing fields. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "*")) + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_missing_field.txt b/cmd/check-templates/testdata/err_missing_field.txt new file mode 100644 index 0000000..c5b9e13 --- /dev/null +++ b/cmd/check-templates/testdata/err_missing_field.txt @@ -0,0 +1,38 @@ +# Template check fails when a field does not exist on the data type. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_multiple_errors.txt b/cmd/check-templates/testdata/err_multiple_errors.txt new file mode 100644 index 0000000..63e1a5f --- /dev/null +++ b/cmd/check-templates/testdata/err_multiple_errors.txt @@ -0,0 +1,49 @@ +# Multiple ExecuteTemplate calls can each report errors. + +! check-templates +stderr 'type check failed:.*index\.gohtml:1:6: executing "index\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.IndexPage' +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Unknown>: Unknown not found on example\.com/app\.AboutPage' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Missing}}

+-- about.gohtml -- +

{{.Unknown}}

diff --git a/cmd/check-templates/testdata/err_multiple_template_vars.txt b/cmd/check-templates/testdata/err_multiple_template_vars.txt new file mode 100644 index 0000000..7dd7062 --- /dev/null +++ b/cmd/check-templates/testdata/err_multiple_template_vars.txt @@ -0,0 +1,49 @@ +# Multiple template variables must not cross-contaminate: only the incorrect +# call should produce an error. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:2: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage' +! stderr 'Title not found' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +//go:embed *.gohtml +var source embed.FS + +var templatesA = template.Must(template.ParseFS(source, "index.gohtml")) +var templatesB = template.Must(template.ParseFS(source, "about.gohtml")) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func renderIndex() { + _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{}) +} + +func renderAbout() { + _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", IndexPage{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} +-- about.gohtml -- +{{.Name}} diff --git a/cmd/check-templates/testdata/err_nested_template.txt b/cmd/check-templates/testdata/err_nested_template.txt new file mode 100644 index 0000000..de654b4 --- /dev/null +++ b/cmd/check-templates/testdata/err_nested_template.txt @@ -0,0 +1,41 @@ +# A nested template call should check the invoked template against the data type. + +! check-templates +stderr 'type check failed:.*header\.gohtml:1:6: executing "header\.gohtml" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{template "header.gohtml" .}} +

body

+-- header.gohtml -- +

{{.Missing}}

diff --git a/cmd/check-templates/testdata/err_shadowed_var.txt b/cmd/check-templates/testdata/err_shadowed_var.txt new file mode 100644 index 0000000..479f1e8 --- /dev/null +++ b/cmd/check-templates/testdata/err_shadowed_var.txt @@ -0,0 +1,48 @@ +# A shadowed template variable must not confuse resolution: the inner scope +# has its own template, and each ExecuteTemplate call is checked against the +# correct definition. The outer ts parses index.gohtml (which uses .Title), +# and the inner ts parses about.gohtml (which uses .Name). Passing the wrong +# data type to the inner call should produce an error only for that call. + +! check-templates +stderr 'type check failed:.*about\.gohtml:1:5: executing "about\.gohtml" at <\.Name>: Name not found on example\.com/app\.IndexPage' +! stderr 'Title not found' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +var ts = template.Must(template.ParseFS(source, "index.gohtml")) + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "about.gohtml")) + _ = ts.ExecuteTemplate(w, "about.gohtml", IndexPage{Title: "oops"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/err_text_template.txt b/cmd/check-templates/testdata/err_text_template.txt new file mode 100644 index 0000000..7e6e187 --- /dev/null +++ b/cmd/check-templates/testdata/err_text_template.txt @@ -0,0 +1,38 @@ +# text/template should be checked the same as html/template. + +! check-templates +stderr 'type check failed:.*index\.gotmpl:1:2: executing "index\.gotmpl" at <\.Missing>: Missing not found on example\.com/app\.Page' + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "io" + "text/template" +) + +var ( + //go:embed *.gotmpl + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gotmpl -- +{{.Missing}} diff --git a/cmd/check-templates/testdata/pass_additional_parsefs.txt b/cmd/check-templates/testdata/pass_additional_parsefs.txt new file mode 100644 index 0000000..50059f6 --- /dev/null +++ b/cmd/check-templates/testdata/pass_additional_parsefs.txt @@ -0,0 +1,51 @@ +# Template defined at top level, then has more templates parsed inside a function. +# Both sets of templates should be available. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + template.Must(ts.ParseFS(source, "about.gohtml")) +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt new file mode 100644 index 0000000..09975df --- /dev/null +++ b/cmd/check-templates/testdata/pass_additional_parsefs_no_must.txt @@ -0,0 +1,47 @@ +# Template defined at top level, modified in a function without template.Must. +# The ParseFS call result is discarded but should still be detected. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.New("app").ParseFS(source, "index.gohtml")) + +func setup() { + _, _ = ts.ParseFS(source, "about.gohtml") +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_aliased_import.txt b/cmd/check-templates/testdata/pass_aliased_import.txt new file mode 100644 index 0000000..96a9837 --- /dev/null +++ b/cmd/check-templates/testdata/pass_aliased_import.txt @@ -0,0 +1,38 @@ +# Template import with a non-standard alias should still be resolved. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + htmltpl "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = htmltpl.Must(htmltpl.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_closure.txt b/cmd/check-templates/testdata/pass_closure.txt new file mode 100644 index 0000000..02f97d3 --- /dev/null +++ b/cmd/check-templates/testdata/pass_closure.txt @@ -0,0 +1,38 @@ +# Template parsed in outer function, closures call ExecuteTemplate. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type Page struct { + Title string +} + +func routes(mux *http.ServeMux) { + ts := template.Must(template.ParseFS(source, "*")) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) + }) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_different_identifier.txt b/cmd/check-templates/testdata/pass_different_identifier.txt new file mode 100644 index 0000000..be176ca --- /dev/null +++ b/cmd/check-templates/testdata/pass_different_identifier.txt @@ -0,0 +1,38 @@ +# Template variable with a non-standard identifier name. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + tmpl = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = tmpl.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_funcs_chain.txt b/cmd/check-templates/testdata/pass_funcs_chain.txt new file mode 100644 index 0000000..3a2dc12 --- /dev/null +++ b/cmd/check-templates/testdata/pass_funcs_chain.txt @@ -0,0 +1,40 @@ +# Template constructed with Funcs() chained before ParseFS should be resolved. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + "strings" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.New("").Funcs(template.FuncMap{ + "upper": strings.ToUpper, + }).ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title | upper}} diff --git a/cmd/check-templates/testdata/pass_imported_type.txt b/cmd/check-templates/testdata/pass_imported_type.txt new file mode 100644 index 0000000..b6770ea --- /dev/null +++ b/cmd/check-templates/testdata/pass_imported_type.txt @@ -0,0 +1,41 @@ +# Types imported from other packages should be resolved correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- internal/model/types.go -- +package model + +type Page struct { + Title string +} +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" + + "example.com/app/internal/model" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", model.Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} diff --git a/cmd/check-templates/testdata/pass_inline_struct.txt b/cmd/check-templates/testdata/pass_inline_struct.txt new file mode 100644 index 0000000..cfecd07 --- /dev/null +++ b/cmd/check-templates/testdata/pass_inline_struct.txt @@ -0,0 +1,36 @@ +# Inline anonymous struct types should be resolved correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +func render() { + var data struct { + Title string + } + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", data) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} diff --git a/cmd/check-templates/testdata/pass_local_var.txt b/cmd/check-templates/testdata/pass_local_var.txt new file mode 100644 index 0000000..dfa6fd3 --- /dev/null +++ b/cmd/check-templates/testdata/pass_local_var.txt @@ -0,0 +1,35 @@ +# Template defined as a local variable in the same function as ExecuteTemplate. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +func handleIndex(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "*")) + _ = ts.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +type Page struct { + Title string +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_multiple_calls.txt b/cmd/check-templates/testdata/pass_multiple_calls.txt new file mode 100644 index 0000000..36573ba --- /dev/null +++ b/cmd/check-templates/testdata/pass_multiple_calls.txt @@ -0,0 +1,48 @@ +# Multiple ExecuteTemplate calls with different types all pass. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_multiple_template_vars.txt b/cmd/check-templates/testdata/pass_multiple_template_vars.txt new file mode 100644 index 0000000..ad95b96 --- /dev/null +++ b/cmd/check-templates/testdata/pass_multiple_template_vars.txt @@ -0,0 +1,46 @@ +# Multiple template variables should each resolve independently. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +//go:embed *.gohtml +var source embed.FS + +var templatesA = template.Must(template.ParseFS(source, "index.gohtml")) +var templatesB = template.Must(template.ParseFS(source, "about.gohtml")) + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +func renderIndex() { + _ = templatesA.ExecuteTemplate(io.Discard, "index.gohtml", IndexPage{}) +} + +func renderAbout() { + _ = templatesB.ExecuteTemplate(io.Discard, "about.gohtml", AboutPage{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{.Title}} +-- about.gohtml -- +{{.Name}} diff --git a/cmd/check-templates/testdata/pass_nested_template.txt b/cmd/check-templates/testdata/pass_nested_template.txt new file mode 100644 index 0000000..e6c04e0 --- /dev/null +++ b/cmd/check-templates/testdata/pass_nested_template.txt @@ -0,0 +1,40 @@ +# Nested template calls should propagate the data type correctly. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "io" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gohtml", Page{}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +{{template "header.gohtml" .}} +

body

+-- header.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_no_execute_calls.txt b/cmd/check-templates/testdata/pass_no_execute_calls.txt new file mode 100644 index 0000000..000f01f --- /dev/null +++ b/cmd/check-templates/testdata/pass_no_execute_calls.txt @@ -0,0 +1,13 @@ +# Package with no ExecuteTemplate calls passes without error. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +func main() {} diff --git a/cmd/check-templates/testdata/pass_non_template_execute.txt b/cmd/check-templates/testdata/pass_non_template_execute.txt new file mode 100644 index 0000000..b7a572e --- /dev/null +++ b/cmd/check-templates/testdata/pass_non_template_execute.txt @@ -0,0 +1,28 @@ +# A custom type with an ExecuteTemplate method should not be confused +# with html/template. The tool should not report errors for it. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "io" + "net/http" +) + +type Renderer struct{} + +func (r *Renderer) ExecuteTemplate(w io.Writer, name string, data any) error { + return nil +} + +func handle(w http.ResponseWriter, r *http.Request) { + renderer := &Renderer{} + _ = renderer.ExecuteTemplate(w, "index.gohtml", struct{ Missing string }{}) +} diff --git a/cmd/check-templates/testdata/pass_parse_call.txt b/cmd/check-templates/testdata/pass_parse_call.txt new file mode 100644 index 0000000..d1ea925 --- /dev/null +++ b/cmd/check-templates/testdata/pass_parse_call.txt @@ -0,0 +1,29 @@ +# Template constructed with Parse (not ParseFS) passes check. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "fmt" + "html/template" + "net/http" +) + +var templates = template.Must(template.New("greeting").Parse(`

Hello, {{.Name}}!

`)) + +type Greeting struct { + Name string +} + +func handle(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "greeting", Greeting{Name: "World"}) +} + +var _ = fmt.Sprint diff --git a/cmd/check-templates/testdata/pass_shadowed_var.txt b/cmd/check-templates/testdata/pass_shadowed_var.txt new file mode 100644 index 0000000..213415c --- /dev/null +++ b/cmd/check-templates/testdata/pass_shadowed_var.txt @@ -0,0 +1,47 @@ +# A shadowed template variable should resolve to the correct definition at each call site. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +//go:embed *.gohtml +var source embed.FS + +type IndexPage struct { + Title string +} + +type AboutPage struct { + Name string +} + +var ts = template.Must(template.ParseFS(source, "index.gohtml")) + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = ts.ExecuteTemplate(w, "index.gohtml", IndexPage{Title: "Home"}) +} + +func handleAbout(w http.ResponseWriter, r *http.Request) { + ts := template.Must(template.ParseFS(source, "about.gohtml")) + _ = ts.ExecuteTemplate(w, "about.gohtml", AboutPage{Name: "World"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

+-- about.gohtml -- +

{{.Name}}

diff --git a/cmd/check-templates/testdata/pass_simple.txt b/cmd/check-templates/testdata/pass_simple.txt new file mode 100644 index 0000000..de22050 --- /dev/null +++ b/cmd/check-templates/testdata/pass_simple.txt @@ -0,0 +1,38 @@ +# Simple template check passes when fields match. + +check-templates +! stderr . + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "html/template" + "net/http" +) + +var ( + //go:embed *.gohtml + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func handleIndex(w http.ResponseWriter, r *http.Request) { + _ = templates.ExecuteTemplate(w, "index.gohtml", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gohtml -- +

{{.Title}}

diff --git a/cmd/check-templates/testdata/pass_text_template.txt b/cmd/check-templates/testdata/pass_text_template.txt new file mode 100644 index 0000000..cbf8999 --- /dev/null +++ b/cmd/check-templates/testdata/pass_text_template.txt @@ -0,0 +1,37 @@ +# text/template should be checked the same as html/template. + +check-templates + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import ( + "embed" + "fmt" + "io" + "text/template" +) + +var ( + //go:embed *.gotmpl + source embed.FS + + templates = template.Must(template.ParseFS(source, "*")) +) + +type Page struct { + Title string +} + +func render() { + _ = templates.ExecuteTemplate(io.Discard, "index.gotmpl", Page{Title: "Home"}) +} + +var _ = fmt.Sprint + +-- index.gotmpl -- +{{.Title}} diff --git a/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt new file mode 100644 index 0000000..1ae730d --- /dev/null +++ b/cmd/check-templates/testdata/pass_text_template_no_err_branch_end.txt @@ -0,0 +1,35 @@ +# text/template does not error with ErrBranchEnd. +# The go test confirms text/template.Execute succeeds on the same template +# that html/template rejects due to branches ending in different contexts. +# Companion: err_html_template_err_branch_end.txt + +check-templates +exec go test + +-- go.mod -- +module example.com/app + +go 1.25.0 +-- main.go -- +package main + +import "text/template" + +var templates = template.Must(template.New("page").Parse(`
{{end}}
`)) + +func main() {} + +-- main_test.go -- +package main + +import ( + "io" + "testing" +) + +func TestExecuteNoBranchEnd(t *testing.T) { + err := templates.Execute(io.Discard, false) + if err != nil { + t.Fatalf("text/template should not produce ErrBranchEnd, got: %v", err) + } +} diff --git a/func.go b/func.go index 72e2b6f..80c51c8 100644 --- a/func.go +++ b/func.go @@ -31,7 +31,15 @@ func DefaultFunctions(pkg *types.Package) Functions { } { if p, ok := findPackage(pkg, pn); ok && p != nil { for funcIdent, templateFunc := range idents { - fns[templateFunc] = p.Scope().Lookup(funcIdent).Type().(*types.Signature) + obj := p.Scope().Lookup(funcIdent) + if obj == nil { + continue + } + sig, ok := obj.Type().(*types.Signature) + if !ok { + continue + } + fns[templateFunc] = sig } } } diff --git a/go.mod b/go.mod index 7d0808a..bad63f5 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( github.com/stretchr/testify v1.11.1 golang.org/x/tools v0.41.0 + rsc.io/script v0.0.2 ) require ( diff --git a/go.sum b/go.sum index ddb0899..7e5240f 100644 --- a/go.sum +++ b/go.sum @@ -16,3 +16,5 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/script v0.0.2 h1:eYoG7A3GFC3z1pRx3A2+s/vZ9LA8cxojHyCvslnj4RI= +rsc.io/script v0.0.2/go.mod h1:cKBjCtFBBeZ0cbYFRXkRoxP+xGqhArPa9t3VWhtXfzU= diff --git a/internal/asteval/errors.go b/internal/asteval/errors.go new file mode 100644 index 0000000..fb2c83e --- /dev/null +++ b/internal/asteval/errors.go @@ -0,0 +1,13 @@ +package asteval + +import ( + "fmt" + "go/token" + "path/filepath" +) + +func wrapWithFilename(workingDirectory string, set *token.FileSet, pos token.Pos, err error) error { + p := set.Position(pos) + p.Filename, _ = filepath.Rel(workingDirectory, p.Filename) + return fmt.Errorf("%s: %w", p, err) +} diff --git a/internal/asteval/forrest.go b/internal/asteval/forrest.go new file mode 100644 index 0000000..a3083bb --- /dev/null +++ b/internal/asteval/forrest.go @@ -0,0 +1,20 @@ +package asteval + +import ( + "html/template" + "text/template/parse" +) + +type Forrest template.Template + +func NewForrest(templates *template.Template) *Forrest { + return (*Forrest)(templates) +} + +func (f *Forrest) FindTree(name string) (*parse.Tree, bool) { + ts := (*template.Template)(f).Lookup(name) + if ts == nil { + return nil, false + } + return ts.Tree, true +} diff --git a/internal/asteval/string.go b/internal/asteval/string.go new file mode 100644 index 0000000..ba8b813 --- /dev/null +++ b/internal/asteval/string.go @@ -0,0 +1,30 @@ +package asteval + +import ( + "fmt" + "go/ast" + "go/token" + "strconv" + + "github.com/typelate/check/internal/astgen" +) + +func StringLiteralExpression(wd string, set *token.FileSet, exp ast.Expr) (string, error) { + arg, ok := exp.(*ast.BasicLit) + if !ok || arg.Kind != token.STRING { + return "", wrapWithFilename(wd, set, exp.Pos(), fmt.Errorf("expected string literal got %s", astgen.Format(exp))) + } + return strconv.Unquote(arg.Value) +} + +func StringLiteralExpressionList(wd string, set *token.FileSet, list []ast.Expr) ([]string, error) { + result := make([]string, 0, len(list)) + for _, a := range list { + s, err := StringLiteralExpression(wd, set, a) + if err != nil { + return result, err + } + result = append(result, s) + } + return result, nil +} diff --git a/internal/asteval/template.go b/internal/asteval/template.go new file mode 100644 index 0000000..b0293e9 --- /dev/null +++ b/internal/asteval/template.go @@ -0,0 +1,618 @@ +package asteval + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "text/template/parse" + "unicode" + + "github.com/typelate/check/internal/astgen" +) + +// Template abstracts over html/template.Template and text/template.Template +// so that the correct template package is used based on the user's import. +type Template interface { + New(name string) Template + Parse(text string) (Template, error) + Funcs(funcMap map[string]any) Template + Option(opt ...string) Template + Delims(left, right string) Template + Lookup(name string) Template + Name() string + AddParseTree(name string, tree *parse.Tree) (Template, error) + Tree() *parse.Tree + FindTree(name string) (*parse.Tree, bool) +} + +// TemplateMetadata accumulates metadata during template evaluation. +type TemplateMetadata struct { + EmbedFilePaths []string + ParseCalls []*ast.BasicLit +} + +func EvaluateTemplateSelector(ts Template, pkg *types.Package, typesInfo *types.Info, expression ast.Expr, workingDirectory, templatesVariable, rDelim, lDelim string, fileSet *token.FileSet, files []*ast.File, embeddedPaths []string, funcTypeMaps TemplateFunctions, fm map[string]any, meta *TemplateMetadata) (Template, string, string, error) { + call, ok := expression.(*ast.CallExpr) + if !ok { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, expression.Pos(), fmt.Errorf("expected call expression")) + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unexpected expression %T: %s", call.Fun, astgen.Format(call.Fun))) + } + switch x := sel.X.(type) { + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args))) + case *ast.Ident: + if !IsTemplatePkgIdent(typesInfo, x) { + if ts == nil { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, sel.X.Pos(), fmt.Errorf("expected template package got %s", astgen.Format(sel.X))) + } + // Variable receiver — apply method to existing template. + switch sel.Sel.Name { + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, lDelim, rDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + pkgPath := templatePkgPath(typesInfo, x) + t, err := parseFiles(ts, pkgPath, fm, lDelim, rDelim, filePaths...) + return t, lDelim, rDelim, err + case "Parse": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + if meta != nil { + if bl, ok := call.Args[0].(*ast.BasicLit); ok { + meta.ParseCalls = append(meta.ParseCalls, bl) + } + } + sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0]) + if err != nil { + return nil, lDelim, rDelim, err + } + t, err := ts.Parse(sl) + return t, lDelim, rDelim, err + case "Funcs": + if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil { + return nil, lDelim, rDelim, err + } + return ts.Funcs(fm), lDelim, rDelim, nil + case "Option": + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return ts.Option(list...), lDelim, rDelim, nil + case "Delims": + if len(call.Args) != 2 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments")) + } + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return ts.Delims(list[0], list[1]), list[0], list[1], nil + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s on variable receiver", sel.Sel.Name)) + } + } + pkgPath := templatePkgPath(typesInfo, x) + switch sel.Sel.Name { + case "Must": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one argument %s got %d", astgen.Format(sel.X), len(call.Args))) + } + return EvaluateTemplateSelector(ts, pkg, typesInfo, call.Args[0], workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta) + case "New": + if len(call.Args) != 1 { + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, lDelim, rDelim, err + } + return NewTemplate(pkgPath, templateNames[0]), lDelim, rDelim, nil + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, lDelim, rDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + t, err := parseFiles(nil, pkgPath, fm, lDelim, rDelim, filePaths...) + return t, lDelim, rDelim, err + default: + return nil, lDelim, rDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported function %s", sel.Sel.Name)) + } + case *ast.CallExpr: + up, upLDelim, upRDelim, err := EvaluateTemplateSelector(ts, pkg, typesInfo, sel.X, workingDirectory, templatesVariable, rDelim, lDelim, fileSet, files, embeddedPaths, funcTypeMaps, fm, meta) + if err != nil { + return nil, lDelim, rDelim, err + } + switch sel.Sel.Name { + case "Delims": + if len(call.Args) != 2 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly two string literal arguments")) + } + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Delims(list[0], list[1]), list[0], list[1], nil + case "Parse": + if len(call.Args) != 1 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + if meta != nil { + if bl, ok := call.Args[0].(*ast.BasicLit); ok { + meta.ParseCalls = append(meta.ParseCalls, bl) + } + } + sl, err := StringLiteralExpression(workingDirectory, fileSet, call.Args[0]) + if err != nil { + return nil, upLDelim, upRDelim, err + } + t, err := up.Parse(sl) + return t, upLDelim, upRDelim, err + case "New": + if len(call.Args) != 1 { + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly one string literal argument")) + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.New(templateNames[0]), upLDelim, upRDelim, nil + case "ParseFS": + filePaths, err := evaluateCallParseFilesArgs(workingDirectory, fileSet, call, files, embeddedPaths) + if err != nil { + return nil, upLDelim, upRDelim, err + } + if meta != nil { + meta.EmbedFilePaths = append(meta.EmbedFilePaths, filePaths...) + } + t, err := parseFiles(up, "", fm, upLDelim, upRDelim, filePaths...) + return t, upLDelim, upRDelim, err + case "Option": + list, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args) + if err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Option(list...), upLDelim, upRDelim, nil + case "Funcs": + if err := evaluateFuncMap(workingDirectory, typesInfo, pkg, fileSet, call, fm, funcTypeMaps); err != nil { + return nil, upLDelim, upRDelim, err + } + return up.Funcs(fm), upLDelim, upRDelim, nil + default: + return nil, upLDelim, upRDelim, wrapWithFilename(workingDirectory, fileSet, call.Fun.Pos(), fmt.Errorf("unsupported method %s", sel.Sel.Name)) + } + } +} + +// templatePkgPath extracts the import path ("html/template" or "text/template") +// from an AST identifier that refers to a template package. +func templatePkgPath(info *types.Info, ident *ast.Ident) string { + if info == nil { + return "html/template" + } + obj := info.Uses[ident] + pkgName, ok := obj.(*types.PkgName) + if !ok { + return "html/template" + } + return pkgName.Imported().Path() +} + +// IsTemplatePkgIdent reports whether ident refers to the "html/template" +// or "text/template" package via the type checker. +func IsTemplatePkgIdent(info *types.Info, ident *ast.Ident) bool { + if info == nil { + return false + } + obj := info.Uses[ident] + pkgName, ok := obj.(*types.PkgName) + if !ok { + return false + } + path := pkgName.Imported().Path() + return path == "html/template" || path == "text/template" +} + +// IsTemplateMethod reports whether sel refers to a method on +// *html/template.Template or *text/template.Template. +func IsTemplateMethod(typesInfo *types.Info, sel *ast.SelectorExpr) bool { + if typesInfo == nil { + return false + } + selection, ok := typesInfo.Selections[sel] + if !ok { + return false + } + fn, ok := selection.Obj().(*types.Func) + if !ok { + return false + } + fnPkg := fn.Pkg() + if fnPkg == nil { + return false + } + return fnPkg.Path() == "html/template" || fnPkg.Path() == "text/template" +} + +// FindModificationReceiver unwraps template.Must and returns the types.Object +// of the variable receiver for a method call like ts.ParseFS(...) or +// template.Must(ts.ParseFS(...)). Returns nil if no variable receiver is found. +func FindModificationReceiver(expr *ast.CallExpr, typesInfo *types.Info) types.Object { + sel, ok := expr.Fun.(*ast.SelectorExpr) + if !ok { + return nil + } + switch x := sel.X.(type) { + case *ast.Ident: + if IsTemplatePkgIdent(typesInfo, x) && sel.Sel.Name == "Must" && len(expr.Args) == 1 { + inner, ok := expr.Args[0].(*ast.CallExpr) + if !ok { + return nil + } + return FindModificationReceiver(inner, typesInfo) + } + if IsTemplatePkgIdent(typesInfo, x) { + return nil + } + return typesInfo.Uses[x] + } + return nil +} + +func builtins() map[string]any { + type nothing struct{} + return map[string]any{ + "and": func() (_ nothing) { return }, + "call": func() (_ nothing) { return }, + "html": func() (_ nothing) { return }, + "index": func() (_ nothing) { return }, + "slice": func() (_ nothing) { return }, + "js": func() (_ nothing) { return }, + "len": func() (_ nothing) { return }, + "not": func() (_ nothing) { return }, + "or": func() (_ nothing) { return }, + "print": func() (_ nothing) { return }, + "printf": func() (_ nothing) { return }, + "println": func() (_ nothing) { return }, + "urlquery": func() (_ nothing) { return }, + + // Comparisons + "eq": func() (_ nothing) { return }, + "ge": func() (_ nothing) { return }, + "gt": func() (_ nothing) { return }, + "le": func() (_ nothing) { return }, + "lt": func() (_ nothing) { return }, + "ne": func() (_ nothing) { return }, + } +} + +func parseFiles(t Template, pkgPath string, fm map[string]any, leftDelim, rightDelim string, filenames ...string) (Template, error) { + if len(filenames) == 0 { + return nil, fmt.Errorf("template: no files named in call to ParseFiles") + } + for _, filename := range filenames { + templateName := filepath.Base(filename) + b, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + s := string(b) + var tmpl Template + if t == nil { + t = NewTemplate(pkgPath, templateName) + } + if templateName == t.Name() { + tmpl = t + } else { + tmpl = t.New(templateName) + } + trees, err := parse.Parse(templateName, s, leftDelim, rightDelim, fm, builtins()) + if err != nil { + return nil, err + } + absoluteFilename, err := filepath.Abs(filename) + if err != nil { + return nil, err + } + for _, tree := range trees { + tree.ParseName = absoluteFilename + if _, err = tmpl.AddParseTree(tree.Name, tree); err != nil { + return nil, err + } + } + } + return t, nil +} + +func evaluateFuncMap(workingDirectory string, typesInfo *types.Info, pkg *types.Package, fileSet *token.FileSet, call *ast.CallExpr, fm map[string]any, funcTypesMap TemplateFunctions) error { + if len(call.Args) != 1 { + return wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("expected exactly 1 template.FuncMap composite literal argument")) + } + arg := call.Args[0] + lit, ok := arg.(*ast.CompositeLit) + if !ok { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected a template.FuncMap composite literal got %s", astgen.Format(arg))) + } + if typesInfo != nil { + litType := typesInfo.TypeOf(lit) + if named, ok := litType.(*types.Named); ok { + obj := named.Obj() + if obj.Pkg() == nil || obj.Name() != "FuncMap" { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType)) + } + path := obj.Pkg().Path() + if path != "html/template" && path != "text/template" { + return wrapWithFilename(workingDirectory, fileSet, arg.Pos(), fmt.Errorf("expected template.FuncMap got %s", litType)) + } + } + } + var buf bytes.Buffer + for i, exp := range lit.Elts { + el, ok := exp.(*ast.KeyValueExpr) + if !ok { + return wrapWithFilename(workingDirectory, fileSet, exp.Pos(), fmt.Errorf("expected element at index %d to be a key value pair got %s", i, astgen.Format(exp))) + } + funcName, err := StringLiteralExpression(workingDirectory, fileSet, el.Key) + if err != nil { + return err + } + + // template.Parse does not evaluate the function signature parameters; + // it ensures the function name is in scope and there is one or two results. + // we could use something like func() string { return "" } for this signature + // but this function from fmt works just fine. + // + // to explore the known requirements run: + // fm[funcName] = nil // will fail because nil does not have `reflect.Kind` Func + // or + // fm[funcName] = func() {} // will fail because there are no results + // or + // fm[funcName] = func() (int, int) {return 0, 0} // will fail because the second result is not an error + fm[funcName] = fmt.Sprintln + + if pkg == nil { + continue + } + buf.Reset() + if err := format.Node(&buf, fileSet, el.Value); err != nil { + return err + } + tv, err := types.Eval(fileSet, pkg, lit.Pos(), buf.String()) + if err != nil { + return err + } + funcTypesMap[funcName] = tv.Type.(*types.Signature) + } + return nil +} + +func evaluateCallParseFilesArgs(workingDirectory string, fileSet *token.FileSet, call *ast.CallExpr, files []*ast.File, embeddedPaths []string) ([]string, error) { + if len(call.Args) < 1 { + return nil, wrapWithFilename(workingDirectory, fileSet, call.Lparen, fmt.Errorf("missing required arguments")) + } + matches, err := embedFSFilePaths(workingDirectory, fileSet, files, call.Args[0], embeddedPaths) + if err != nil { + return nil, err + } + templateNames, err := StringLiteralExpressionList(workingDirectory, fileSet, call.Args[1:]) + if err != nil { + return nil, err + } + filtered := matches[:0] + for _, ef := range matches { + for j, pattern := range templateNames { + match, err := filepath.Match(pattern, ef) + if err != nil { + return nil, wrapWithFilename(workingDirectory, fileSet, call.Args[j+1].Pos(), fmt.Errorf("bad pattern %q: %w", pattern, err)) + } + if !match { + continue + } + filtered = append(filtered, ef) + break + } + } + return joinFilePaths(workingDirectory, filtered...), nil +} + +func embedFSFilePaths(dir string, fileSet *token.FileSet, files []*ast.File, exp ast.Expr, embeddedFiles []string) ([]string, error) { + varIdent, ok := exp.(*ast.Ident) + if !ok { + return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("first argument to ParseFS must be an identifier")) + } + for _, decl := range astgen.IterateGenDecl(files, token.VAR) { + for _, s := range decl.Specs { + spec, ok := s.(*ast.ValueSpec) + if !ok || !slices.ContainsFunc(spec.Names, func(e *ast.Ident) bool { return e.Name == varIdent.Name }) { + continue + } + var comment strings.Builder + commentNode := readComments(&comment, decl.Doc, spec.Doc) + templateNames := parseTemplateNames(comment.String()) + absMat, err := embeddedFilesMatchingTemplateNameList(dir, fileSet, commentNode, templateNames, embeddedFiles) + if err != nil { + return nil, err + } + return absMat, nil + } + } + return nil, wrapWithFilename(dir, fileSet, exp.Pos(), fmt.Errorf("variable %s not found", varIdent)) +} + +func embeddedFilesMatchingTemplateNameList(dir string, set *token.FileSet, comment ast.Node, templateNames, embeddedFiles []string) ([]string, error) { + var matches []string + for _, fp := range embeddedFiles { + for _, pattern := range templateNames { + pat := filepath.FromSlash(pattern) + if !strings.ContainsAny(pat, "*[]") { + prefix := filepath.FromSlash(pat + "/") + if strings.HasPrefix(fp, prefix) { + matches = append(matches, fp) + continue + } + } + if matched, err := filepath.Match(pat, fp); err != nil { + return nil, wrapWithFilename(dir, set, comment.Pos(), fmt.Errorf("embed comment malformed: %w", err)) + } else if matched { + matches = append(matches, fp) + } + } + } + return slices.Clip(matches), nil +} + +const goEmbedCommentPrefix = "//go:embed" + +func readComments(s *strings.Builder, groups ...*ast.CommentGroup) ast.Node { + var n ast.Node + for _, c := range groups { + if c == nil { + continue + } + for _, line := range c.List { + if !strings.HasPrefix(line.Text, goEmbedCommentPrefix) { + continue + } + s.WriteString(strings.TrimSpace(strings.TrimPrefix(line.Text, goEmbedCommentPrefix))) + s.WriteByte(' ') + } + n = c + break + } + return n +} + +func parseTemplateNames(input string) []string { + // todo: refactor to use strconv.QuotedPrefix + var ( + templateNames []string + currentTemplateName strings.Builder + inQuote = false + quoteChar rune + ) + + for _, r := range input { + switch { + case r == '"' || r == '`': + if !inQuote { + inQuote = true + quoteChar = r + continue + } + if r != quoteChar { + currentTemplateName.WriteRune(r) + continue + } + templateNames = append(templateNames, currentTemplateName.String()) + currentTemplateName.Reset() + inQuote = false + case unicode.IsSpace(r): + if inQuote { + currentTemplateName.WriteRune(r) + continue + } + if currentTemplateName.Len() > 0 { + templateNames = append(templateNames, currentTemplateName.String()) + currentTemplateName.Reset() + } + default: + currentTemplateName.WriteRune(r) + } + } + + // Import any remaining pattern + if currentTemplateName.Len() > 0 { + templateNames = append(templateNames, currentTemplateName.String()) + } + + return templateNames +} + +func joinFilePaths(wd string, rel ...string) []string { + result := slices.Clone(rel) + for i := range result { + result[i] = filepath.Join(wd, result[i]) + } + return result +} + +func RelativeFilePaths(wd string, abs ...string) ([]string, error) { + result := slices.Clone(abs) + for i, p := range result { + r, err := filepath.Rel(wd, p) + if err != nil { + return nil, err + } + result[i] = r + } + return result, nil +} + +type TemplateFunctions map[string]*types.Signature + +func DefaultFunctions(pkg *types.Package) TemplateFunctions { + funcTypeMap := make(TemplateFunctions) + fmtPkg, ok := findPackage(pkg, "fmt") + if !ok || fmtPkg == nil { + return funcTypeMap + } + funcTypeMap["printf"] = fmtPkg.Scope().Lookup("Sprintf").Type().(*types.Signature) + funcTypeMap["print"] = fmtPkg.Scope().Lookup("Sprint").Type().(*types.Signature) + funcTypeMap["println"] = fmtPkg.Scope().Lookup("Sprintln").Type().(*types.Signature) + return funcTypeMap +} + +func findPackage(pkg *types.Package, path string) (*types.Package, bool) { + if pkg == nil || pkg.Path() == path { + return pkg, true + } + for _, im := range pkg.Imports() { + if p, ok := findPackage(im, path); ok { + return p, true + } + } + return nil, false +} + +func (functions TemplateFunctions) FindFunction(name string) (*types.Signature, bool) { + m := (map[string]*types.Signature)(functions) + fn, ok := m[name] + if !ok { + return nil, false + } + return fn, true +} + +func BasicLiteralString(node ast.Node) (string, bool) { + name, ok := node.(*ast.BasicLit) + if !ok { + return "", false + } + if name.Kind != token.STRING { + return "", false + } + templateName, err := strconv.Unquote(name.Value) + if err != nil { + return "", false + } + return templateName, true +} diff --git a/internal/asteval/template_html.go b/internal/asteval/template_html.go new file mode 100644 index 0000000..31c75fc --- /dev/null +++ b/internal/asteval/template_html.go @@ -0,0 +1,67 @@ +package asteval + +import ( + "html/template" + "text/template/parse" +) + +// htmlTemplate wraps *html/template.Template. +type htmlTemplate struct { + t *template.Template +} + +func (h *htmlTemplate) New(name string) Template { + return &htmlTemplate{t: h.t.New(name)} +} + +func (h *htmlTemplate) Parse(text string) (Template, error) { + t, err := h.t.Parse(text) + if err != nil { + return nil, err + } + return &htmlTemplate{t: t}, nil +} + +func (h *htmlTemplate) Funcs(funcMap map[string]any) Template { + return &htmlTemplate{t: h.t.Funcs(funcMap)} +} + +func (h *htmlTemplate) Option(opt ...string) Template { + return &htmlTemplate{t: h.t.Option(opt...)} +} + +func (h *htmlTemplate) Delims(left, right string) Template { + return &htmlTemplate{t: h.t.Delims(left, right)} +} + +func (h *htmlTemplate) Lookup(name string) Template { + t := h.t.Lookup(name) + if t == nil { + return nil + } + return &htmlTemplate{t: t} +} + +func (h *htmlTemplate) Name() string { + return h.t.Name() +} + +func (h *htmlTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) { + t, err := h.t.AddParseTree(name, tree) + if err != nil { + return nil, err + } + return &htmlTemplate{t: t}, nil +} + +func (h *htmlTemplate) Tree() *parse.Tree { + return h.t.Tree +} + +func (h *htmlTemplate) FindTree(name string) (*parse.Tree, bool) { + t := h.t.Lookup(name) + if t == nil { + return nil, false + } + return t.Tree, true +} diff --git a/internal/asteval/template_packages.go b/internal/asteval/template_packages.go new file mode 100644 index 0000000..2b4428a --- /dev/null +++ b/internal/asteval/template_packages.go @@ -0,0 +1,19 @@ +package asteval + +import ( + html "html/template" + text "text/template" +) + +// NewTemplate creates a Template backed by the appropriate template +// package either: "text/template" or "html/template". +func NewTemplate(pkgPath, name string) Template { + switch pkgPath { + case "text/template": + return &textTemplate{t: text.New(name)} + case "html/template": + return &htmlTemplate{t: html.New(name)} + default: + return nil + } +} diff --git a/internal/asteval/template_text.go b/internal/asteval/template_text.go new file mode 100644 index 0000000..d0e427c --- /dev/null +++ b/internal/asteval/template_text.go @@ -0,0 +1,67 @@ +package asteval + +import ( + "text/template" + "text/template/parse" +) + +// textTemplate wraps *text/template.Template. +type textTemplate struct { + t *template.Template +} + +func (s *textTemplate) New(name string) Template { + return &textTemplate{t: s.t.New(name)} +} + +func (s *textTemplate) Parse(text string) (Template, error) { + t, err := s.t.Parse(text) + if err != nil { + return nil, err + } + return &textTemplate{t: t}, nil +} + +func (s *textTemplate) Funcs(funcMap map[string]any) Template { + return &textTemplate{t: s.t.Funcs(funcMap)} +} + +func (s *textTemplate) Option(opt ...string) Template { + return &textTemplate{t: s.t.Option(opt...)} +} + +func (s *textTemplate) Delims(left, right string) Template { + return &textTemplate{t: s.t.Delims(left, right)} +} + +func (s *textTemplate) Lookup(name string) Template { + t := s.t.Lookup(name) + if t == nil { + return nil + } + return &textTemplate{t: t} +} + +func (s *textTemplate) Name() string { + return s.t.Name() +} + +func (s *textTemplate) AddParseTree(name string, tree *parse.Tree) (Template, error) { + t, err := s.t.AddParseTree(name, tree) + if err != nil { + return nil, err + } + return &textTemplate{t: t}, nil +} + +func (s *textTemplate) Tree() *parse.Tree { + return s.t.Tree +} + +func (s *textTemplate) FindTree(name string) (*parse.Tree, bool) { + t := s.t.Lookup(name) + if t == nil { + return nil, false + } + return t.Tree, true +} diff --git a/internal/asteval/testdata/template/assets_dir.txtar b/internal/asteval/testdata/template/assets_dir.txtar new file mode 100644 index 0000000..29b7738 --- /dev/null +++ b/internal/asteval/testdata/template/assets_dir.txtar @@ -0,0 +1,23 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed assets + assetsFS embed.FS + + templates = template.Must(template.ParseFS(assetsFS, "assets/*")) +) +-- assets/index.gohtml -- + +{{define "home"}}{{end}} + +-- assets/form.gohtml -- + +{{define "create"}}{{end}} + +{{define "update"}}{{end}} diff --git a/internal/asteval/testdata/template/bad_embed_pattern.txtar b/internal/asteval/testdata/template/bad_embed_pattern.txtar new file mode 100644 index 0000000..0d4137a --- /dev/null +++ b/internal/asteval/testdata/template/bad_embed_pattern.txtar @@ -0,0 +1,16 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed "[fail" + assetsFS embed.FS + + templates = template.Must(template.ParseFS(assetsFS, "*")) +) +-- greeting.gohtml -- +Hello, friend! diff --git a/internal/asteval/testdata/template/delims.txtar b/internal/asteval/testdata/template/delims.txtar new file mode 100644 index 0000000..852baa2 --- /dev/null +++ b/internal/asteval/testdata/template/delims.txtar @@ -0,0 +1,22 @@ +-- templates.go -- +package main + +var ( + //go:embed *.gohtml + files embed.FS +) + +var ( + templates = template.Must( + template.Must( + template.Must( + template.New("").ParseFS(files, "default.gohtml")). + Delims("(((", ")))").ParseFS(files, "triple_parens.gohtml")). + Delims("[[", "]]").ParseFS(files, "double_square.gohtml")) +) +-- default.gohtml -- +{{- define "default" -}}{{- end -}} +-- triple_parens.gohtml -- +(((- define "parens" -)))(((- end -))) +-- double_square.gohtml -- +[[- define "square" -]][[end]] diff --git a/internal/asteval/testdata/template/funcs.txtar b/internal/asteval/testdata/template/funcs.txtar new file mode 100644 index 0000000..619aa40 --- /dev/null +++ b/internal/asteval/testdata/template/funcs.txtar @@ -0,0 +1,41 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + src embed.FS + + templates = template.New("x").Funcs(template.FuncMap{ + "greet": func() string { return "Hello" }, + }).ParseFS(src, "greet.gohtml") + + templatesFuncNotDefined = template.New("x").Funcs(template.FuncMap{ + "greet": func() string { return "Hello" }, + }).ParseFS(src, "missing_func.gohtml") + + templatesWrongArg = template.New("x").Funcs(wrong) + + templatesTwoArgs = template.New("x").Funcs(wrong, fail) + + templatesNoArgs = template.New("x").Funcs() + + templatesWrongTypePackageName = template.New("x").Funcs(wrong.FuncMap{}) + + templatesWrongTypeName = template.New("x").Funcs(template.Wrong{}) + + templatesWrongTypeExpression = template.New("x").Funcs(wrong{}) + + templatesWrongTypeElem = template.New("x").Funcs(template.FuncMap{wrong}) + + templatesWrongElemKey = template.New("x").Funcs(template.FuncMap{wrong: func() string { return "" }}) +) +-- greet.gohtml -- +{{greet}}, world! + +-- missing_func.gohtml -- +{{greet}}, {{enemy}}! diff --git a/internal/asteval/testdata/template/parse.txtar b/internal/asteval/testdata/template/parse.txtar new file mode 100644 index 0000000..7157867 --- /dev/null +++ b/internal/asteval/testdata/template/parse.txtar @@ -0,0 +1,15 @@ +-- parse.go -- +package main + +import "html/template" + +var templates = template.New("GET /").Parse(`

Hello, world!

`) + +var multiple = template.New("").Parse(` +{{define "GET /"}}

Hello, world!

{{end}} +{{define "GET /{name}"}}

Hello, {{.PathValue "name"}}!

{{end}} +`) + +var noArg = template.New("").Parse() + +var wrongArg = template.New("").Parse(500) \ No newline at end of file diff --git a/internal/asteval/testdata/template/template_ParseFS.txtar b/internal/asteval/testdata/template/template_ParseFS.txtar new file mode 100644 index 0000000..b07673b --- /dev/null +++ b/internal/asteval/testdata/template/template_ParseFS.txtar @@ -0,0 +1,38 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + templateSource embed.FS + + templates = template.Must(template.ParseFS(templateSource, "*")) + + // allHTML used by templatesHTML and templatesGoHTML to test pattern filtering + // this comment ensures the comment parser skips lines not preceded by go:embed + //go:embed *html + allHTML embed.FS + + templatesHTML = template.Must(template.ParseFS(allHTML, "*.*html")) + templatesGoHTML = template.Must(template.ParseFS(allHTML, "*.gohtml")) + + templateEmbedVariableNotFound = template.Must(template.ParseFS(hiding, "*")) +) +-- index.gohtml -- + +{{define "home"}}{{end}} + +-- form.gohtml -- + +{{define "create"}}{{end}} + +{{define "update"}}{{end}} + +-- script.html -- +{{define "console_log"}} + +{{end}} diff --git a/internal/asteval/testdata/template/templates.txtar b/internal/asteval/testdata/template/templates.txtar new file mode 100644 index 0000000..66a8671 --- /dev/null +++ b/internal/asteval/testdata/template/templates.txtar @@ -0,0 +1,84 @@ +-- template.go -- +package main + +import ( + "embed" + "html/template" +) + +var ( + //go:embed *.gohtml + templateSource embed.FS + + templateNew = template.New("some-name") + + templateParseFSNew = template.Must(template.ParseFS(templateSource, "*")).New("greetings") + + templateNewParseFS = template.Must(template.New("greetings").ParseFS(templateSource, "*")) + + templateNewMissingArg = template.New() + + templateWrongX = UNKNOWN.New() + + templateWrongArgCount = template.New("one", "two") + + templateNewOnIndexed = ts[0].New("one", "two") + + templateNewArg42 = template.New(42) + + templateNewArgIdent = template.New(TemplateName) + + templateNewErrUpstream = template.New(fail).New("n") + + templatesIdent = someIdent + + unsupportedMethod = template.Unknown() + + unexpectedFunExpression = x[3]() + + templateMustNonIdentReceiver = f().Must(template.ParseFS(templateSource, "*")) + + templateMustCalledWithTwoArgs = template.Must(nil, nil) + + templateMustCalledWithNoArg s = template.Must() + + templateMustWrongPackageIdent = wrong.Must() + + templateParseFSWrongPackageIdent = wrong.ParseFS(templateSource, "*") + + templateParseFSReceiverErr = template.New().ParseFS(templateSource, "*") + + templateParseFSUnexpectedReceiver = x[0].ParseFS(templateSource, "*") + + templateParseFSNoArgs = template.ParseFS() + + templateParseFSFirstArgNonIdent = template.ParseFS(os.DirFS("."), "*") + + templateParseFSNonStringLiteralGlob = template.ParseFS(templateSource, "w", 42, "x") + + templateParseFSWithBadGlob = template.ParseFS(templateSource, "[fail") + + templateNewHasWrongNumberOfArgs = template.Must(template.New("x").ParseFS(templateSource, "*")).New() + + templateNewHasWrongTypeOfArgs = template.New("x").New(9000) + + templateNewHasTooManyArgs = template.New("x").New("x", "y") + + templateDelimsGetsNoArgs = template.New("x").Delims() + + templateDelimsGetsTooMany = template.New("x").Delims("x", "y", "") + + templateDelimsWrongExpressionArg = template.New("x").Delims("x", y) + + templateParseFSMethodFails = template.New("x").ParseFS(templateSource, fail) + + templateOptionsRequiresStringLiterals = template.New("x").Option(fail) + + templateUnknownMethod = template.New("x").Unknown() + + templateOptionCall = template.New("x").Option("missingkey=default").ParseFS(templateSource, "*") + + templateOptionCallUnknownArg = template.New("x").Option("unknown").ParseFS(templateSource, "*") +) +-- index.gohtml -- +Hello, friend! diff --git a/internal/astgen/format.go b/internal/astgen/format.go new file mode 100644 index 0000000..748cebe --- /dev/null +++ b/internal/astgen/format.go @@ -0,0 +1,23 @@ +package astgen + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/printer" + "go/token" +) + +// Format converts an AST node to formatted Go source code +func Format(node ast.Node) string { + var buf bytes.Buffer + if err := printer.Fprint(&buf, token.NewFileSet(), node); err != nil { + return fmt.Sprintf("formatting error: %v", err) + } + out, err := format.Source(buf.Bytes()) + if err != nil { + return fmt.Sprintf("formatting error: %v", err) + } + return string(bytes.ReplaceAll(out, []byte("\n}\nfunc "), []byte("\n}\n\nfunc "))) +} diff --git a/internal/astgen/queries.go b/internal/astgen/queries.go new file mode 100644 index 0000000..e262076 --- /dev/null +++ b/internal/astgen/queries.go @@ -0,0 +1,36 @@ +package astgen + +import ( + "go/ast" + "go/token" +) + +// IterateGenDecl returns an iterator over GenDecl nodes with the specified token type +func IterateGenDecl(files []*ast.File, tok token.Token) func(func(*ast.File, *ast.GenDecl) bool) { + return func(yield func(*ast.File, *ast.GenDecl) bool) { + for _, file := range files { + for _, decl := range file.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != tok { + continue + } + if !yield(file, d) { + return + } + } + } + } +} + +// IterateValueSpecs returns an iterator over ValueSpec nodes in var declarations +func IterateValueSpecs(files []*ast.File) func(func(*ast.File, *ast.ValueSpec) bool) { + return func(yield func(*ast.File, *ast.ValueSpec) bool) { + for file, decl := range IterateGenDecl(files, token.VAR) { + for _, s := range decl.Specs { + if !yield(file, s.(*ast.ValueSpec)) { + return + } + } + } + } +} diff --git a/package.go b/package.go new file mode 100644 index 0000000..e44de5f --- /dev/null +++ b/package.go @@ -0,0 +1,248 @@ +package check + +import ( + "errors" + "fmt" + "go/ast" + "go/token" + "go/types" + "path/filepath" + + "golang.org/x/tools/go/packages" + + "github.com/typelate/check/internal/asteval" + "github.com/typelate/check/internal/astgen" +) + +type pendingCall struct { + receiverObj types.Object + templateName string + dataType types.Type +} + +type resolvedTemplate struct { + templates asteval.Template + functions asteval.TemplateFunctions + metadata *asteval.TemplateMetadata +} + +// Package discovers all .ExecuteTemplate calls in the given package, +// resolves receiver variables to their template construction chains, +// and type-checks each call. +// +// ExecuteTemplate must be called with a string literal for the second parameter. +func Package(pkg *packages.Package) error { + pending, receivers := findExecuteCalls(pkg) + resolved, resolveErrs := resolveTemplates(pkg, receivers) + callErr := checkCalls(pkg, pending, resolved) + return errors.Join(append(resolveErrs, callErr)...) +} + +// findExecuteCalls walks the package syntax looking for ExecuteTemplate calls +// and returns the pending calls along with the set of receiver objects that +// need template resolution. +func findExecuteCalls(pkg *packages.Package) ([]pendingCall, map[types.Object]struct{}) { + var pending []pendingCall + receiverSet := make(map[types.Object]struct{}) + + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok || len(call.Args) != 3 { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "ExecuteTemplate" { + return true + } + // Verify the method belongs to html/template or text/template. + if !asteval.IsTemplateMethod(pkg.TypesInfo, sel) { + return true + } + receiverIdent, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + obj := pkg.TypesInfo.Uses[receiverIdent] + if obj == nil { + return true + } + templateName, ok := asteval.BasicLiteralString(call.Args[1]) + if !ok { + return true + } + dataType := pkg.TypesInfo.TypeOf(call.Args[2]) + pending = append(pending, pendingCall{ + receiverObj: obj, + templateName: templateName, + dataType: dataType, + }) + receiverSet[obj] = struct{}{} + return true + }) + } + + return pending, receiverSet +} + +// resolveTemplates resolves each unique receiver object to its template +// construction chain, including additional ParseFS/Parse modifications. +func resolveTemplates(pkg *packages.Package, receivers map[types.Object]struct{}) (map[types.Object]*resolvedTemplate, []error) { + resolved := make(map[types.Object]*resolvedTemplate) + + workingDirectory := packageDirectory(pkg) + embeddedPaths, err := asteval.RelativeFilePaths(workingDirectory, pkg.EmbedFiles...) + if err != nil { + return nil, []error{fmt.Errorf("failed to calculate relative path for embedded files: %w", err)} + } + + var resolveErrs []error + + resolveExpr := func(obj types.Object, name string, expr ast.Expr) { + if _, needed := receivers[obj]; !needed { + return + } + funcTypeMap := asteval.DefaultFunctions(pkg.Types) + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(nil, pkg.Types, pkg.TypesInfo, expr, workingDirectory, name, "", "", pkg.Fset, pkg.Syntax, embeddedPaths, funcTypeMap, make(map[string]any), meta) + if err != nil { + resolveErrs = append(resolveErrs, err) + return + } + resolved[obj] = &resolvedTemplate{ + templates: ts, + functions: funcTypeMap, + metadata: meta, + } + } + + // Resolve top-level var declarations. + for _, tv := range astgen.IterateValueSpecs(pkg.Syntax) { + for i, ident := range tv.Names { + if i >= len(tv.Values) { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + resolveExpr(obj, ident.Name, tv.Values[i]) + } + } + + // Resolve function-local var declarations and short variable declarations. + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.DeclStmt: + gd, ok := n.Decl.(*ast.GenDecl) + if !ok || gd.Tok != token.VAR { + return true + } + for _, spec := range gd.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for i, ident := range vs.Names { + if i >= len(vs.Values) { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + resolveExpr(obj, ident.Name, vs.Values[i]) + } + } + case *ast.AssignStmt: + if n.Tok != token.DEFINE { + return true + } + for i, lhs := range n.Lhs { + if i >= len(n.Rhs) { + continue + } + ident, ok := lhs.(*ast.Ident) + if !ok { + continue + } + obj := pkg.TypesInfo.Defs[ident] + if obj == nil { + continue + } + resolveExpr(obj, ident.Name, n.Rhs[i]) + } + } + return true + }) + } + + // Find additional ParseFS/Parse calls on resolved template variables. + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + call, ok := node.(*ast.CallExpr) + if !ok { + return true + } + obj := asteval.FindModificationReceiver(call, pkg.TypesInfo) + if obj == nil { + return true + } + rt, ok := resolved[obj] + if !ok { + return true + } + meta := &asteval.TemplateMetadata{} + ts, _, _, err := asteval.EvaluateTemplateSelector(rt.templates, pkg.Types, pkg.TypesInfo, call, workingDirectory, "", "", "", pkg.Fset, pkg.Syntax, embeddedPaths, rt.functions, make(map[string]any), meta) + if err != nil { + return true + } + rt.templates = ts + rt.metadata.EmbedFilePaths = append(rt.metadata.EmbedFilePaths, meta.EmbedFilePaths...) + rt.metadata.ParseCalls = append(rt.metadata.ParseCalls, meta.ParseCalls...) + return true + }) + } + + return resolved, resolveErrs +} + +// checkCalls type-checks each pending ExecuteTemplate call against its +// resolved template. +func checkCalls(pkg *packages.Package, pending []pendingCall, resolved map[types.Object]*resolvedTemplate) error { + mergedFunctions := make(Functions) + if pkg.Types != nil { + mergedFunctions = DefaultFunctions(pkg.Types) + } + for _, rt := range resolved { + for name, sig := range rt.functions { + mergedFunctions[name] = sig + } + } + + var errs []error + for _, p := range pending { + rt, ok := resolved[p.receiverObj] + if !ok { + continue + } + looked := rt.templates.Lookup(p.templateName) + if looked == nil { + continue + } + global := NewGlobal(pkg.Types, pkg.Fset, rt.templates, mergedFunctions) + if err := Execute(global, looked.Tree(), p.dataType); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func packageDirectory(pkg *packages.Package) string { + if len(pkg.GoFiles) > 0 { + return filepath.Dir(pkg.GoFiles[0]) + } + return "." +}